
Waymax Deep Dive

Core simulator architecture, data-driven simulation, metrics system, and evaluation framework for autonomous driving.

Waymax Deep Dive: A Comprehensive Learning Guide

Paper: Waymax: An Accelerated, Data-Driven Simulator for Large-Scale Autonomous Driving Research
Authors: Cole Gulino, Justin Fu, Wenjie Luo, et al. (Waymo Research)
Published: NeurIPS 2023


Table of Contents

  1. Executive Summary
  2. Architecture Deep Dive
  3. Key Concepts Explained
  4. Interactive Code Examples
  5. Mental Models
  6. Key Equations
  7. Hands-On Exercises
  8. Common Pitfalls
  9. Interview Questions
  10. Further Reading

Executive Summary

The Problem: Simulation Speed vs. Realism Trade-off

Autonomous driving research faces a fundamental tension:

Traditional Simulators vs. Data-Driven Approaches:
+------------------+     +------------------+
|   CARLA/SUMO     |     |  Data-Driven     |
|   - Realistic    |     |  Approaches      |
|   - Slow (10Hz)  | vs. |  - Fast          |
|   - Synthetic    |     |  - Less realistic|
|   - CPU-bound    |     |  - Limited scope |
+------------------+     +------------------+

Real-world costs: Testing an autonomous vehicle policy requires millions of simulation steps. At 10 steps per second on a CPU, a sequential pass over 44,000 scenarios of 80 steps each (about 3.5 million steps) takes roughly four days. This bottleneck severely slows research iteration.
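A back-of-the-envelope helper makes this cost concrete. It assumes 80 simulated steps per scenario (8 s at 10 Hz, the rollout length used in Example 1) and purely sequential execution; the two throughputs below are the illustrative 10 Hz and 1000 Hz figures quoted in this summary, not measurements.

```python
def rollout_wall_clock_hours(num_scenarios: int, steps_per_scenario: int,
                             steps_per_second: float) -> float:
    """Wall-clock hours to simulate every step of every scenario, one after another."""
    total_steps = num_scenarios * steps_per_scenario
    return total_steps / steps_per_second / 3600.0


# 44,000 scenarios x 80 steps at 10 simulated steps per second vs. 1000 per second
slow = rollout_wall_clock_hours(44_000, 80, 10.0)
fast = rollout_wall_clock_hours(44_000, 80, 1000.0)
print(f"10 Hz:   {slow:.1f} hours")
print(f"1000 Hz: {fast:.2f} hours")
```

Throughput enters linearly, so a 100x faster simulator turns a multi-day sequential evaluation into one that finishes in under an hour, even before any parallelism across GPUs.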

What Waymax Solves

Waymax breaks this trade-off by combining:

  1. Hardware Acceleration: Built entirely in JAX, runs on GPUs/TPUs
  2. Real-World Data: Uses Waymo Open Motion Dataset (100K+ real driving scenarios)
  3. Multi-Agent Simulation: All vehicles can be controlled, not just ego
  4. Differentiability: Enables gradient-based optimization through the simulator

The Result:

  • 1000+ Hz simulation speed (100x faster than CARLA)
  • 44,000 scenarios in < 2 minutes on 8 V100 GPUs
  • Real driving behaviors, not synthetic

Why This Matters

Think of it like this: If training a self-driving AI is like teaching someone to drive, traditional simulators are like using a driving textbook with static images. Waymax is like having access to 100,000 dashcam recordings that you can rewind, modify, and replay at superhuman speed.

Impact Areas:

  • Reinforcement Learning: Train policies with millions of environment interactions
  • Imitation Learning: Learn from real expert trajectories
  • Motion Prediction: Benchmark prediction models in closed-loop settings
  • Safety Testing: Evaluate edge cases at scale

Architecture Deep Dive

High-Level System Design

+============================================================================+
|                           WAYMAX ARCHITECTURE                               |
+============================================================================+
|                                                                             |
|  +-------------------------+    +-----------------------------------+       |
|  |   WAYMO OPEN MOTION     |    |        JAX COMPUTATION GRAPH      |       |
|  |      DATASET            |    |   +---------------------------+   |       |
|  |  +------------------+   |    |   |                           |   |       |
|  |  | 100K+ Scenarios  |   |--->|   |   GPU/TPU Accelerated     |   |       |
|  |  | 7.64M Objects    |   |    |   |   Vectorized Operations   |   |       |
|  |  | 10Hz, 9 seconds  |   |    |   |                           |   |       |
|  |  +------------------+   |    |   +---------------------------+   |       |
|  +-------------------------+    +-----------------------------------+       |
|                                              |                              |
|                                              v                              |
|  +----------------------------------------------------------------------+  |
|  |                      SIMULATION CORE                                  |  |
|  |  +----------------+  +----------------+  +----------------------+     |  |
|  |  |   DYNAMICS     |  |    STATE       |  |      METRICS         |     |  |
|  |  |   MODELS       |  |  MANAGEMENT    |  |                      |     |  |
|  |  |  - Delta       |  |  - Position    |  |  - Collision         |     |  |
|  |  |  - Bicycle     |  |  - Velocity    |  |  - Off-road          |     |  |
|  |  |                |  |  - Rotation    |  |  - Log divergence    |     |  |
|  |  +----------------+  +----------------+  +----------------------+     |  |
|  +----------------------------------------------------------------------+  |
|                                              |                              |
|                                              v                              |
|  +----------------------------------------------------------------------+  |
|  |                         AGENT LAYER                                   |  |
|  |  +--------------+ +--------------+ +--------------+ +--------------+ |  |
|  |  |   EXPERT     | |   BEHAVIOR   | |   WAYFORMER  | |     RL       | |  |
|  |  |   (Log)      | |   CLONING    | |   PREDICTION | |    (DQN)     | |  |
|  |  +--------------+ +--------------+ +--------------+ +--------------+ |  |
|  +----------------------------------------------------------------------+  |
|                                              |                              |
|                                              v                              |
|  +----------------------------------------------------------------------+  |
|  |                    RL FRAMEWORK ADAPTERS                              |  |
|  |         +------------------+    +------------------+                  |  |
|  |         |     dm-env       |    |      Brax        |                  |  |
|  |         +------------------+    +------------------+                  |  |
|  +----------------------------------------------------------------------+  |
+============================================================================+

Core Data Flow

SCENARIO LOADING                    SIMULATION LOOP
================                    ================

  TFRecord Files                    For each timestep:
       |
       v                            +---> Observe State
  +-----------+                     |          |
  | Dataloader|                     |          v
  +-----------+                     |    Select Action
       |                            |          |
       v                            |          v
  +-----------+                     |    Apply Dynamics
  |  Scenario |                     |          |
  |   State   |--------------------->          v
  +-----------+                     |    Compute Metrics
       |                            |          |
       |  Contains:                 |          v
       |  - Agent trajectories      +----Update State
       |  - Road geometry
       |  - Traffic signals
       |  - Route information
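The simulation loop on the right of the diagram can be sketched as a plain-Python higher-order function. The toy callables (`observe`, `select_action`, `apply_dynamics`, `compute_metrics`) and the 1-D state are hypothetical stand-ins, not Waymax APIs. In Waymax each stage operates on JAX arrays, which is what allows a fixed-length loop like this to be compiled (for example with `jax.lax.scan`) instead of being driven step by step from Python.

```python
from typing import Callable, List, Tuple, TypeVar

S = TypeVar("S")  # simulator state
O = TypeVar("O")  # observation
A = TypeVar("A")  # action
M = TypeVar("M")  # per-step metrics


def rollout(state: S,
            observe: Callable[[S], O],
            select_action: Callable[[O], A],
            apply_dynamics: Callable[[S, A], S],
            compute_metrics: Callable[[S], M],
            num_steps: int) -> Tuple[S, List[M]]:
    """Repeat: observe -> select action -> apply dynamics -> compute metrics."""
    history: List[M] = []
    for _ in range(num_steps):
        obs = observe(state)
        action = select_action(obs)
        state = apply_dynamics(state, action)  # the update produces a new state
        history.append(compute_metrics(state))
    return state, history


# Toy 1-D vehicle: state = (position m, speed m/s), stepped at 10 Hz.
def observe(state: Tuple[float, float]) -> float:
    return state[1]  # the toy policy only sees its own speed


def select_action(speed: float) -> float:
    return 1.0 if speed < 5.0 else 0.0  # accelerate at 1 m/s^2 up to 5 m/s


def apply_dynamics(state: Tuple[float, float], accel: float,
                   dt: float = 0.1) -> Tuple[float, float]:
    position, speed = state
    return position + speed * dt, speed + accel * dt


def compute_metrics(state: Tuple[float, float]) -> bool:
    return state[1] > 5.5  # toy "speeding" flag


final_state, speeding = rollout((0.0, 0.0), observe, select_action,
                                apply_dynamics, compute_metrics, num_steps=80)
print(final_state, any(speeding))
```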

State Representation

+===========================================================================+
|                           SIMULATOR STATE                                  |
+===========================================================================+
|                                                                            |
|  DYNAMIC STATE (per agent, per timestep)                                   |
|  +---------------------------------------------------------------------+  |
|  |  Position (x, y)  |  Heading (theta)  |  Velocity (vx, vy)         |  |
|  |  Bounding Box     |  Valid Mask       |  Agent Type                |  |
|  +---------------------------------------------------------------------+  |
|                                                                            |
|  STATIC STATE (per scenario)                                               |
|  +---------------------------------------------------------------------+  |
|  |  Road Graph                                                         |  |
|  |  +---------------------------------------------------------------+ |  |
|  |  |  Lane Centers    |  Lane Boundaries  |  Road Edges            | |  |
|  |  |  Stop Signs      |  Crosswalks       |  Speed Bumps           | |  |
|  |  +---------------------------------------------------------------+ |  |
|  |                                                                     |  |
|  |  Traffic Lights                                                     |  |
|  |  +---------------------------------------------------------------+ |  |
|  |  |  State per timestep (red/yellow/green)                        | |  |
|  |  |  Associated lane IDs                                          | |  |
|  |  +---------------------------------------------------------------+ |  |
|  +---------------------------------------------------------------------+  |
|                                                                            |
|  ROUTE INFORMATION                                                         |
|  +---------------------------------------------------------------------+  |
|  |  Logged trajectory union    |    Driveable futures                 |  |
|  +---------------------------------------------------------------------+  |
|                                                                            |
+===========================================================================+
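Because JAX compiles for fixed array shapes, per-agent state like this is naturally stored as dense `[num_objects, num_timesteps]` arrays padded to a maximum object count, with the valid mask marking which entries hold real data. The pure-Python sketch below illustrates why the mask matters. `pad_to_fixed_shape` and `masked_mean_at` are hypothetical helpers (not Waymax functions), the speeds are made up, and the sketch assumes every track starts at timestep 0.

```python
from typing import List, Tuple


def pad_to_fixed_shape(tracks: List[List[float]], max_objects: int,
                       num_timesteps: int) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad ragged per-agent tracks into a dense [max_objects][num_timesteps] grid.

    Entries without real data are filled with 0.0 and marked invalid.
    """
    if len(tracks) > max_objects:
        raise ValueError("more agents than max_objects")
    values = [[0.0] * num_timesteps for _ in range(max_objects)]
    valid = [[False] * num_timesteps for _ in range(max_objects)]
    for i, track in enumerate(tracks):
        for t, v in enumerate(track[:num_timesteps]):
            values[i][t] = v
            valid[i][t] = True
    return values, valid


def masked_mean_at(values: List[List[float]], valid: List[List[bool]], t: int) -> float:
    """Mean over agents at timestep t, counting only valid entries."""
    entries = [row[t] for row, ok in zip(values, valid) if ok[t]]
    return sum(entries) / len(entries) if entries else 0.0


# Two real agents (the second leaves the scene after t=1), padded to 4 slots
speeds, speed_valid = pad_to_fixed_shape([[10.0, 11.0, 12.0], [5.0, 6.0]],
                                         max_objects=4, num_timesteps=3)
print(masked_mean_at(speeds, speed_valid, 2))       # 12.0: only agent 0 is valid
print(sum(row[2] for row in speeds) / len(speeds))  # 3.0: padding drags the mean down
```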

Multi-Agent Control Architecture

                    CONTROL HIERARCHY
                    =================

        +-----------------------------------------+
        |           SCENARIO (N agents)           |
        +-----------------------------------------+
                           |
          +----------------+----------------+
          |                |                |
          v                v                v
    +-----------+    +-----------+    +-----------+
    |  Agent 1  |    |  Agent 2  |    |  Agent N  |
    | CONTROLLED|    | SIMULATED |    | SIMULATED |
    +-----------+    +-----------+    +-----------+
          |                |                |
          v                v                v
    +-----------+    +-----------+    +-----------+
    |  Policy   |    |    IDM    |    |    Log    |
    |  Network  |    |  Model    |    |  Playback |
    +-----------+    +-----------+    +-----------+
          |                |                |
          +----------------+----------------+
                           |
                           v
              +------------------------+
              |   DYNAMICS MODEL       |
              |   (Bicycle / Delta)    |
              +------------------------+
                           |
                           v
              +------------------------+
              |   NEXT STATE           |
              +------------------------+
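One way to implement this hierarchy is to let every controller propose an action for every agent and then keep, per agent, the proposal from the controller that owns it. The sketch below assumes one (acceleration, curvature) action per agent; the controller names and constant actions are hypothetical stand-ins for a policy network, an IDM agent, and log playback.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Action = Tuple[float, float]                # (acceleration m/s^2, steering curvature 1/m)
Controller = Callable[[int], List[Action]]  # proposes an action for every agent


def merge_actions(num_agents: int,
                  controllers: Dict[str, Controller],
                  assignment: Sequence[str]) -> List[Action]:
    """Run every controller for all agents, then keep each agent's owner's proposal."""
    if len(assignment) != num_agents:
        raise ValueError("need exactly one controller name per agent")
    proposals = {name: ctrl(num_agents) for name, ctrl in controllers.items()}
    return [proposals[assignment[i]][i] for i in range(num_agents)]


controllers: Dict[str, Controller] = {
    "policy": lambda n: [(1.0, 0.0)] * n,  # stand-in for the learned policy
    "idm": lambda n: [(-0.5, 0.0)] * n,    # stand-in for a reactive IDM agent
    "log": lambda n: [(0.0, 0.01)] * n,    # stand-in for actions replayed from the log
}
merged = merge_actions(3, controllers, ["policy", "idm", "log"])
print(merged)  # [(1.0, 0.0), (-0.5, 0.0), (0.0, 0.01)]
```

Computing every proposal and then selecting wastes some work on discarded actions, but it avoids per-agent branching. That is the pattern that maps onto vectorized JAX code, where the selection becomes a masked `where` over arrays.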

Key Concepts Explained

1. Data-Driven Simulation

The Traditional Approach (Synthetic): Imagine building a driving game from scratch. You design roads, program car AI to follow rules, add pedestrians with random behavior patterns. It's like creating a fictional city - it works, but it's not real.

The Waymax Approach (Data-Driven): Instead, Waymax starts with actual recorded driving data. It's like having security camera footage from thousands of intersections. Every scenario you simulate actually happened in the real world.

SYNTHETIC SIMULATION              DATA-DRIVEN SIMULATION
==================               ====================

Designer creates:                 Real world provides:
+---------------+                 +---------------+
| Road Layout   |                 | Actual Roads  |
| (Imagined)    |                 | (SF, PHX...)  |
+---------------+                 +---------------+
       |                                 |
       v                                 v
+---------------+                 +---------------+
| NPC Behavior  |                 | Real Drivers  |
| (Rule-based)  |                 | (Human data)  |
+---------------+                 +---------------+
       |                                 |
       v                                 v
+---------------+                 +---------------+
| Edge Cases    |                 | Natural Edge  |
| (Must guess)  |                 | Cases         |
+---------------+                 +---------------+

Why This Matters: When your policy succeeds in Waymax, it handled situations that actually occurred. When it fails, you can analyze real failure modes.

2. Closed-Loop vs. Open-Loop Evaluation

Analogy: Think of training a chess AI.

  • Open-Loop: You show the AI 1 million chess positions and ask "What's the best next move?" The AI never sees the consequences of its choices.

  • Closed-Loop: You let the AI play full games. Each move affects the board, and it must handle the situations its own moves create.

OPEN-LOOP EVALUATION                    CLOSED-LOOP EVALUATION
====================                    ======================

Time t:   [Observe] --> [Predict]       Time t:   [Observe] --> [Act]
               |                                       |
               v                                       v
          Compare to                              Environment
          ground truth                              Updates
               |                                       |
          (No feedback)                                v
                                        Time t+1: [Observe] --> [Act]
                                                       |
                                              (Compounding effects)

The Problem Waymax Solves: Most motion prediction benchmarks are open-loop, but autonomous vehicles operate closed-loop. A prediction model can score well on an open-loop benchmark yet fail in deployment, because it is never tested on the states its own errors lead to, so those errors compound unchecked.
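A toy one-dimensional example makes the compounding effect concrete. Assume a hypothetical learned model that overestimates speed by a constant 0.2 m/s while the logged driver holds 10 m/s at 10 Hz. Evaluated open-loop from logged states, its one-step error stays at about 0.02 m; rolled out closed-loop from its own outputs, the same bias accumulates every step.

```python
from typing import List, Tuple


def one_step_model(position: float, speed: float = 10.0, bias: float = 0.2,
                   dt: float = 0.1) -> float:
    """Hypothetical learned model: predicts the next position, overestimating speed by `bias`."""
    return position + (speed + bias) * dt


def open_vs_closed_loop(num_steps: int = 80, speed: float = 10.0,
                        dt: float = 0.1) -> Tuple[List[float], List[float]]:
    # Logged expert: constant speed along a line
    log = [speed * dt * t for t in range(num_steps + 1)]

    # Open loop: every prediction starts from the *logged* position
    open_loop = [abs(one_step_model(log[t]) - log[t + 1]) for t in range(num_steps)]

    # Closed loop: every prediction starts from the model's *own* previous output
    closed_loop = []
    position = log[0]
    for t in range(num_steps):
        position = one_step_model(position)
        closed_loop.append(abs(position - log[t + 1]))
    return open_loop, closed_loop


open_err, closed_err = open_vs_closed_loop()
print(f"max open-loop error:     {max(open_err):.3f} m")
print(f"final closed-loop error: {closed_err[-1]:.3f} m")
```

Here the drift grows linearly because the bias is constant; with heading errors, lateral drift can grow even faster, which is why closed-loop evaluation matters.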

3. The Bicycle Model

Intuition: Why is it called a "bicycle" model for cars?

Imagine looking at a car from above:

REAL CAR (4 wheels)              BICYCLE MODEL (2 wheels)
===================              ========================

    +--[F]----[F]--+                    +--[F]--+
    |              |                    |       |
    |      ^       |        ==>         |   ^   |
    |      |       |                    |   |   |
    +--[R]----[R]--+                    +--[R]--+

    Front and rear                  Front and rear
    axles with 2                    as single points
    wheels each

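For motion planning, we can simplify a 4-wheel car to 2 wheels (one per axle) because:

  • The two wheels on an axle share approximately the same position, heading, and velocity, so each axle can be collapsed to a single wheel at its center
  • What planning needs is the vehicle's path curvature (how tightly it turns), which a single steered front wheel still captures

The Key Insight: A bicycle can't move sideways (no-slip constraint), and neither can a car at normal driving speeds. This constraint makes the math tractable while preserving realistic motion as long as the tires are not sliding.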
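A minimal kinematic bicycle step, written as an explicit-Euler sketch rather than Waymax's exact integration scheme. It uses acceleration and steering curvature as the two actions (the same pair Example 3's policy outputs) and clips them to 6 m/s^2 and 0.3 1/m, bounds taken from the kinematic-feasibility comment in Example 2; treat those defaults, and the no-reversing clamp on speed, as illustrative assumptions.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class VehicleState:
    x: float      # m
    y: float      # m
    yaw: float    # heading, rad
    speed: float  # m/s along the heading


def bicycle_step(s: VehicleState, accel: float, curvature: float,
                 dt: float = 0.1, max_accel: float = 6.0,
                 max_curvature: float = 0.3) -> VehicleState:
    """One explicit-Euler step of a kinematic bicycle model.

    The no-slip constraint is built in: velocity always points along the
    heading, so the vehicle never moves sideways.
    """
    accel = max(-max_accel, min(max_accel, accel))
    curvature = max(-max_curvature, min(max_curvature, curvature))
    return VehicleState(
        x=s.x + s.speed * math.cos(s.yaw) * dt,
        y=s.y + s.speed * math.sin(s.yaw) * dt,
        yaw=s.yaw + s.speed * curvature * dt,  # yaw rate = speed * curvature
        speed=max(0.0, s.speed + accel * dt),  # no reversing in this sketch
    )


# Drive a constant-curvature arc: 10 m/s with curvature 0.1 1/m (a 10 m radius)
state = VehicleState(0.0, 0.0, 0.0, 10.0)
for _ in range(10):
    state = bicycle_step(state, accel=0.0, curvature=0.1)
print(state)  # after 1 s the heading has turned by about 1 rad
```

Using curvature instead of a steering angle keeps the model independent of wheelbase: the yaw rate is simply speed times curvature.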

4. Route Conditioning

The Problem: Without knowing where an agent wants to go, predicting behavior is ambiguous.

                    +----------+
                    |          |
        +---------->| TURN     |
        |           | LEFT     |
        |           +----------+
        |
+-------+-------+
|               |
|   CURRENT     |--------> STRAIGHT
|   POSITION    |
|               |
+-------+-------+
        |
        |           +----------+
        |           |          |
        +---------->| TURN     |
                    | RIGHT    |
                    +----------+

Without route info: All three are plausible!
With route info: Only one makes sense.

What Waymax Does: It extracts "routes" from logged trajectories - the path each vehicle actually took. During training, the model learns to follow these routes, dramatically reducing ambiguity.

Results:

  • Without routes: 2.31% off-route rate
  • With routes: ~1% off-route rate
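An off-route rate like the one above needs a notion of distance from the route. A minimal sketch, assuming the route is a polyline of (x, y) waypoints and that a position counts as off-route when it is more than a hypothetical threshold (here 2 m) from that polyline; Waymax's own off-route metric may use a different definition and threshold.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def point_to_segment_distance(p: Point, a: Point, b: Point) -> float:
    """Euclidean distance from p to the segment a-b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    seg_len_sq = dx * dx + dy * dy
    if seg_len_sq == 0.0:
        return math.hypot(px - ax, py - ay)
    # Project p onto the segment, clamped to its endpoints
    t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / seg_len_sq))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def distance_to_route(p: Point, route: Sequence[Point]) -> float:
    """Distance from p to the nearest segment of the route polyline."""
    return min(point_to_segment_distance(p, a, b) for a, b in zip(route, route[1:]))


def off_route_rate(positions: List[Point], route: Sequence[Point],
                   threshold_m: float = 2.0) -> float:
    """Fraction of positions farther than threshold_m from the route."""
    flagged = sum(distance_to_route(p, route) > threshold_m for p in positions)
    return flagged / len(positions)


# A right-angle route: east for 10 m, then north for 10 m
route = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0)]
print(off_route_rate([(5.0, 1.0), (12.0, 5.0), (5.0, -3.0)], route))  # 1 of 3 positions
```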

5. Sim Agents: The Multi-Agent Challenge

The Dilemma: When training your ego vehicle policy, what should other cars do?

OPTION 1: Log Playback              OPTION 2: Reactive Agents
=====================               ======================

Other cars follow                   Other cars respond
their recorded paths                to ego's actions

   [Log Car]                           [IDM Car]
      |                                   |
      |  (ignores ego)                    |  (reacts to ego)
      v                                   v
   +--------+                         +--------+
   | Drives |                         | Slows  |
   | through|                         | down   |
   | ego!   |                         |        |
   +--------+                         +--------+

Problem: Unrealistic                Problem: May be "too easy"
collisions, ego learns              to exploit, domain gap
to avoid ghosts                     at evaluation

The Key Finding: RL agents trained against IDM (rule-based) sim agents learn to exploit their predictable behavior. According to the paper, when these agents are then evaluated against log playback, their collision rates are about 4x higher.

This is a form of overfitting to the simulator.
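For reference, the standard Intelligent Driver Model (Treiber et al.) computes a follower's acceleration from its speed, its gap to the lead vehicle, and the closing speed. The sketch below implements that textbook formula with illustrative parameter values; they are assumptions, not the parameters Waymax's IDM sim agents use.

```python
import math


def idm_acceleration(speed: float, gap: float, lead_speed: float,
                     desired_speed: float = 15.0,  # v0, m/s
                     time_headway: float = 1.5,    # T, s
                     min_gap: float = 2.0,         # s0, m
                     max_accel: float = 1.5,       # a, m/s^2
                     comfort_decel: float = 2.0,   # b, m/s^2
                     delta: float = 4.0) -> float:
    """Intelligent Driver Model acceleration for a follower behind a lead vehicle."""
    if gap <= 0.0:
        raise ValueError("gap must be positive")
    closing_speed = speed - lead_speed
    # Desired dynamic gap s*: standstill gap + headway term + braking-interaction term
    desired_gap = min_gap + max(
        0.0,
        speed * time_headway
        + speed * closing_speed / (2.0 * math.sqrt(max_accel * comfort_decel)),
    )
    free_road = (speed / desired_speed) ** delta
    interaction = (desired_gap / gap) ** 2
    return max_accel * (1.0 - free_road - interaction)


print(idm_acceleration(speed=0.0, gap=1e9, lead_speed=0.0))   # free road from rest: ~max_accel
print(idm_acceleration(speed=10.0, gap=5.0, lead_speed=0.0))  # closing fast on a stopped car: hard braking
```

The output is a deterministic function of three quantities, which is exactly what makes it exploitable: a policy can learn, for example, that cutting in close will reliably make the follower brake, a response that logged human drivers do not provide.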


Interactive Code Examples

Example 1: Basic Environment Setup

"""
Waymax Quick Start: Loading and Stepping Through Scenarios
==========================================================
This example shows the fundamental workflow for using Waymax.
"""

import jax
import jax.numpy as jnp
from waymax import config as waymax_config
from waymax import dataloader
from waymax import dynamics
from waymax import env as waymax_env
from waymax import datatypes

# ============================================================
# STEP 1: Load Data from Waymo Open Motion Dataset
# ============================================================

# Configure data loading (uses WOMD v1.1.0 training split)
data_config = waymax_config.WOD_1_1_0_TRAINING

# Create a scenario generator (lazy loading)
scenario_generator = dataloader.simulator_state_generator(data_config)

# Get one scenario
scenario = next(scenario_generator)

print(f"Scenario loaded!")
print(f"  - Number of objects: {scenario.num_objects}")
print(f"  - Number of timesteps: {scenario.remaining_timesteps}")
print(f"  - Current timestep: {scenario.timestep}")

# ============================================================
# STEP 2: Configure the Environment
# ============================================================

# Choose a dynamics model
# Option 1: Bicycle model (realistic vehicle kinematics)
dynamics_model = dynamics.InvertibleBicycleModel()

# Option 2: Delta model (direct state control, less realistic)
# dynamics_model = dynamics.DeltaLocal()

# Create environment configuration
env_config = waymax_config.EnvironmentConfig(
    # Which agents are controlled (vs. log playback)
    controlled_object=waymax_config.ObjectType.SDC,  # Self-driving car only
)

# Initialize the environment
environment = waymax_env.PlanningAgentEnvironment(
    dynamics_model=dynamics_model,
    config=env_config,
)

# ============================================================
# STEP 3: Run a Simulation Loop
# ============================================================

# Reset environment with our scenario
state = environment.reset(scenario)

# Storage for trajectory
trajectory = []

# Simulation loop
for step in range(80):  # 8 seconds at 10Hz

    # Get observation (what the agent sees)
    observation = environment.observe(state)

    # Create an action (here: just follow the log for demonstration)
    # In practice, this would come from your policy network
    action = datatypes.Action(
        data=jnp.zeros((1, 2)),  # [acceleration, steering]
        valid=jnp.ones((1,), dtype=bool),
    )

    # Step the environment
    state = environment.step(state, action)

    # Store position for visualization
    ego_pos = state.current_sim_trajectory.xy[0, state.timestep]
    trajectory.append(ego_pos)

    # Check if done
    if state.is_done:
        break

print(f"Simulation complete! Ran for {len(trajectory)} steps")

Example 2: Computing Metrics

"""
Waymax Metrics: Evaluating Agent Performance
============================================
This example demonstrates how to compute and interpret
the six core metrics in Waymax.
"""

from waymax import metrics
from waymax.metrics import comfort
from waymax.metrics import overlap
from waymax.metrics import roadgraph as roadgraph_metrics
from waymax.metrics import route as route_metrics

def evaluate_trajectory(state, env):
    """
    Compute all Waymax metrics for a simulation run.

    Args:
        state: Final simulator state after rollout
        env: The Waymax environment

    Returns:
        Dictionary of metric values
    """

    results = {}

    # --------------------------------------------------------
    # METRIC 1: Log Divergence (Imitation Quality)
    # --------------------------------------------------------
    # How far did the agent deviate from the logged trajectory?
    # Lower is better (means closer to human expert)

    log_divergence = metrics.log_divergence(
        state.sim_trajectory,
        state.log_trajectory,
    )
    results['log_divergence_l2'] = float(jnp.mean(log_divergence.value))

    print(f"Log Divergence (L2): {results['log_divergence_l2']:.3f} meters")
    print(f"  -> Agent deviated {results['log_divergence_l2']:.1f}m on average from expert")

    # --------------------------------------------------------
    # METRIC 2: Collision Detection
    # --------------------------------------------------------
    # Did the agent collide with any other object?
    # Binary: True = collision occurred

    collision_result = overlap.detect_collision(state)
    results['collision'] = bool(jnp.any(collision_result.value))

    if results['collision']:
        print(f"Collision: YES - Agent hit another object!")
    else:
        print(f"Collision: NO - Agent avoided all collisions")

    # --------------------------------------------------------
    # METRIC 3: Off-Road Detection
    # --------------------------------------------------------
    # Did the agent drive outside driveable area?
    # Binary: True = went off-road

    offroad_result = roadgraph_metrics.offroad(state)
    results['offroad'] = bool(jnp.any(offroad_result.value))

    if results['offroad']:
        print(f"Off-Road: YES - Agent left the road!")
    else:
        print(f"Off-Road: NO - Agent stayed on road")

    # --------------------------------------------------------
    # METRIC 4: Route Progress
    # --------------------------------------------------------
    # How much of the intended route did the agent complete?
    # Higher is better (1.0 = completed full route)

    # Note: Requires route information in the scenario
    route_progress = route_metrics.route_progress(state)
    results['route_progress'] = float(route_progress.value[0])

    print(f"Route Progress: {results['route_progress']*100:.1f}%")

    # --------------------------------------------------------
    # METRIC 5: Kinematic Feasibility
    # --------------------------------------------------------
    # Were the agent's actions physically plausible?
    # Checks: acceleration < 6 m/s^2, curvature < 0.3 m^-1

    comfort_result = comfort.kinematic_infeasibility(state)
    results['kinematic_infeasible'] = bool(jnp.any(comfort_result.value))

    if results['kinematic_infeasible']:
        print(f"Kinematic Feasibility: VIOLATED - Actions too aggressive!")
    else:
        print(f"Kinematic Feasibility: PASSED - Actions are realistic")

    # --------------------------------------------------------
    # METRIC 6: Wrong Way Detection
    # --------------------------------------------------------
    # Is the agent driving against traffic flow?

    wrong_way = roadgraph_metrics.wrong_way(state)
    results['wrong_way'] = bool(jnp.any(wrong_way.value))

    return results


# Example usage in evaluation loop
def run_evaluation(policy, scenarios, num_scenarios=100):
    """
    Evaluate a policy across multiple scenarios.
    """

    all_results = []

    for i, scenario in enumerate(scenarios):
        if i >= num_scenarios:
            break

        # Run rollout with policy
        state = rollout_policy(policy, scenario)

        # Compute metrics
        results = evaluate_trajectory(state, env)
        all_results.append(results)

    # Aggregate results
    collision_rate = sum(r['collision'] for r in all_results) / len(all_results)
    offroad_rate = sum(r['offroad'] for r in all_results) / len(all_results)
    avg_divergence = sum(r['log_divergence_l2'] for r in all_results) / len(all_results)

    print(f"\n=== EVALUATION SUMMARY ({num_scenarios} scenarios) ===")
    print(f"Collision Rate: {collision_rate*100:.2f}%")
    print(f"Off-Road Rate: {offroad_rate*100:.2f}%")
    print(f"Average Log Divergence: {avg_divergence:.3f}m")

Example 3: Implementing a Simple Policy

"""
Waymax Policy: Building Your First Planning Agent
=================================================
This example shows how to implement a simple neural network
policy that can control a vehicle in Waymax.
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple

class SimpleWaymaxPolicy(nn.Module):
    """
    A simple MLP policy for Waymax.

    Takes in observations and outputs actions for the bicycle model:
    - Acceleration (m/s^2)
    - Steering curvature (1/m)
    """

    hidden_dims: Tuple[int, ...] = (256, 256, 128)

    @nn.compact
    def __call__(self, observation):
        """
        Forward pass through the policy.

        Args:
            observation: Dictionary containing:
                - ego_state: Current vehicle state [x, y, theta, vx, vy]
                - roadgraph: Nearby road geometry
                - other_agents: States of other vehicles
                - route: Target waypoints to follow

        Returns:
            action: [acceleration, steering_curvature]
        """

        # --------------------------------------------------------
        # Step 1: Extract and flatten features
        # --------------------------------------------------------

        # Ego vehicle state (position, heading, velocity)
        ego_state = observation['ego_state']  # Shape: [5]

        # Route waypoints (next N points to follow)
        route = observation['route']  # Shape: [N, 2]
        route_flat = route.reshape(-1)

        # Nearby agents (positions and velocities)
        agents = observation['nearby_agents']  # Shape: [M, 4]
        agents_flat = agents.reshape(-1)

        # Concatenate all features
        features = jnp.concatenate([
            ego_state,
            route_flat,
            agents_flat,
        ])

        # --------------------------------------------------------
        # Step 2: MLP backbone
        # --------------------------------------------------------

        x = features
        for dim in self.hidden_dims:
            x = nn.Dense(dim)(x)
            x = nn.relu(x)
            x = nn.LayerNorm()(x)

        # --------------------------------------------------------
        # Step 3: Output heads
        # --------------------------------------------------------

        # Acceleration head: output in range [-4, 4] m/s^2
        accel_logits = nn.Dense(1)(x)
        acceleration = 4.0 * jnp.tanh(accel_logits)

        # Steering head: output in range [-0.3, 0.3] 1/m
        steer_logits = nn.Dense(1)(x)
        steering = 0.3 * jnp.tanh(steer_logits)

        # Combine into action
        action = jnp.concatenate([acceleration, steering])

        return action


def create_policy_step_fn(policy, params):
    """
    Create a JIT-compiled policy step function.

    This function can be used efficiently in the simulation loop.
    """

    @jax.jit
    def policy_step(observation):
        return policy.apply(params, observation)

    return policy_step


# ============================================================
# Training Loop Skeleton
# ============================================================

def train_behavior_cloning(
    policy,
    dummy_obs,
    make_batches,
    num_epochs: int = 100,
    batch_size: int = 32,
    learning_rate: float = 1e-4,
):
    """
    Train policy using behavior cloning (imitation learning).

    The agent learns to mimic the logged expert trajectories.

    Args:
        policy: Flax module to train (e.g. SimpleWaymaxPolicy())
        dummy_obs: One example observation, used only to initialize parameters
        make_batches: Callable taking batch_size and yielding
            (observations, expert_actions) pairs. Building these from the
            logged trajectories is up to your data pipeline; this callable
            is a placeholder, not part of Waymax.
    """

    import optax

    # Initialize parameters
    rng = jax.random.PRNGKey(0)
    params = policy.init(rng, dummy_obs)

    # Optimizer
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    # Loss function: predict expert actions
    def loss_fn(params, observations, expert_actions):
        predicted_actions = policy.apply(params, observations)

        # MSE loss between predicted and expert actions
        loss = jnp.mean((predicted_actions - expert_actions) ** 2)

        return loss

    # JIT-compiled training step
    @jax.jit
    def train_step(params, opt_state, batch):
        observations, expert_actions = batch

        loss, grads = jax.value_and_grad(loss_fn)(
            params, observations, expert_actions
        )

        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        return params, opt_state, loss

    # Training loop
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch in make_batches(batch_size):
            params, opt_state, loss = train_step(params, opt_state, batch)
            epoch_loss += float(loss)
            num_batches += 1

        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

    return params

Example 4: Vectorized Multi-Scenario Evaluation

"""
Waymax Vectorization: Evaluating Many Scenarios in Parallel
==========================================================
This is where Waymax really shines - batch processing scenarios
on GPU for massive speedup.
"""

import jax
import jax.numpy as jnp
from functools import partial

def create_vectorized_env(dynamics_model, env_config, batch_size):
    """
    Create a vectorized environment for parallel scenario evaluation.

    Instead of running scenarios one-by-one, we run `batch_size`
    scenarios simultaneously on the GPU.
    """

    # Create base environment
    base_env = waymax_env.PlanningAgentEnvironment(
        dynamics_model=dynamics_model,
        config=env_config,
    )

    # Vectorize the step function across scenarios
    # vmap = "vectorized map" - applies function to batched inputs

    @jax.jit
    def batched_step(states, actions):
        """
        Step multiple environments in parallel.

        Args:
            states: Batched simulator states [batch_size, ...]
            actions: Batched actions [batch_size, action_dim]

        Returns:
            next_states: Updated states [batch_size, ...]
        """
        return jax.vmap(base_env.step)(states, actions)

    @jax.jit
    def batched_reset(scenarios):
        """Reset multiple scenarios in parallel."""
        return jax.vmap(base_env.reset)(scenarios)

    @jax.jit
    def batched_observe(states):
        """Get observations for multiple states in parallel."""
        return jax.vmap(base_env.observe)(states)

    return batched_reset, batched_step, batched_observe


def benchmark_throughput(env, scenarios, num_steps=80):
    """
    Benchmark simulation throughput.

    Args:
        env: The (batched_reset, batched_step, batched_observe) tuple
            returned by the vectorized-environment factory above.
        scenarios: Scenarios to simulate together as one batch.
        num_steps: Simulation steps per scenario.
    """
    import time

    # Unpack the batched functions
    batched_reset, batched_step, batched_observe = env

    # Stack scenarios into a batch. stack_scenarios is a helper you supply,
    # e.g. jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *scenarios)
    scenario_batch = stack_scenarios(scenarios)

    # Warm-up: the first calls trigger JIT compilation, so keep them out of
    # the timed region and wait for them to finish
    states = batched_reset(scenario_batch)
    dummy_actions = jnp.zeros((len(scenarios), 2))
    jax.block_until_ready(batched_step(states, dummy_actions))

    # Timed run
    start_time = time.perf_counter()

    states = batched_reset(scenario_batch)
    for _ in range(num_steps):
        actions = jnp.zeros((len(scenarios), 2))  # Dummy actions
        states = batched_step(states, actions)

    # JAX dispatches work asynchronously: wait for the device to finish
    jax.block_until_ready(states)

    elapsed = time.perf_counter() - start_time

    # Calculate throughput
    total_steps = len(scenarios) * num_steps
    steps_per_second = total_steps / elapsed

    print("=== BENCHMARK RESULTS ===")
    print(f"Batch size: {len(scenarios)}")
    print(f"Steps per scenario: {num_steps}")
    print(f"Total simulation steps: {total_steps}")
    print(f"Total time: {elapsed:.3f} seconds")
    print(f"Throughput: {steps_per_second:.0f} steps/second")
    print(f"Equivalent Hz: {steps_per_second/len(scenarios):.0f} Hz per scenario")


# Illustrative output (V100 GPU); actual numbers depend on hardware,
# batch size and JAX version:
# === BENCHMARK RESULTS ===
# Batch size: 16
# Steps per scenario: 80
# Total simulation steps: 1280
# Total time: 0.256 seconds
# Throughput: 5000 steps/second
# Equivalent Hz: 312 Hz per scenario

Mental Models

Mental Model 1: The Replay-and-Branch Metaphor

Think of Waymax like a VCR with superpowers:

TRADITIONAL VIDEO REPLAY          WAYMAX "BRANCHING" REPLAY
======================           =========================

     [Play] ---> [Watch]             [Play] ---> [Watch]
                                           |
                                           |---> [What if I
                                           |      brake here?]
                                           |
                                           |---> [What if I
                                           |      accelerate?]
                                           |
                                           +---> [What if I
                                                  change lanes?]

Can only watch what               Can intervene and see
happened                          alternative futures

Mental Model 2: The Driving Instructor Analogy

OPEN-LOOP (Traditional)              CLOSED-LOOP (Waymax)
=======================              ====================

Instructor: "Watch this video        Instructor: "Now YOU drive
of expert driving"                   in this recorded scenario"

Student: [Watches passively]         Student: [Takes the wheel]
                                               |
                                               v
                                     Student makes a mistake
                                               |
                                               v
                                     Scenario continues with
                                     the consequences!
                                               |
                                               v
                                     Student learns from
                                     compounding errors

Mental Model 3: The Simulation Fidelity Spectrum

LOW FIDELITY                                           HIGH FIDELITY
(Fast, Simple)                                         (Slow, Realistic)
     |                                                        |
     v                                                        v
+--------+     +--------+     +--------+     +--------+     +--------+
| Point  |     | Simple |     |        |     | Physics|     | Full   |
| Mass   |---->| Bicycle|---->| WAYMAX |---->| Sim    |---->| Digital|
| Model  |     | Model  |     |        |     | (CARLA)|     | Twin   |
+--------+     +--------+     +--------+     +--------+     +--------+
                                  ^
                                  |
                              SWEET SPOT
                              - Real data
                              - Fast (GPU)
                              - Good enough physics
                              - Multi-agent

Mental Model 4: The Action Space Trade-off

DELTA ACTION SPACE                   BICYCLE ACTION SPACE
==================                   ====================

Direct control:                      Indirect control:
"Move to (x+dx, y+dy)"              "Accelerate by a, steer by k"

     +---+                                +---+
     | * | <-- New position               | * | <-- Current position
     +---+                                +---+
       ^                                    |
       |                                    | Physics
       | Teleport!                          | simulation
       |                                    v
     +---+                                +---+
     | * | <-- Current position           | * | <-- New position
     +---+                                +---+

PROS:                                PROS:
- Can simulate anything              - Physically realistic
- Works for pedestrians              - Transferable to real cars
- Simple                             - Natural constraints

CONS:                                CONS:
- Unrealistic motion                 - Only works for vehicles
- Hard to transfer                   - More complex

Mental Model 5: The Sim Agent Problem

THE MULTI-AGENT DILEMMA
=======================

You're training an autonomous car. But what should the OTHER cars do?

Option A: "Ghost Replay"
========================
Other cars follow their recorded paths, ignoring your car.

Your car: [Learns to drive through "ghosts"]
Result: Unrealistic collisions, weird behaviors

     Your Car          Other Car
        |                  |
        v                  v
      +---+              +---+
      | > |  crash!      | < |  (keeps going, ignores you)
      +---+              +---+


Option B: "Reactive Agents (IDM)"
=================================
Other cars react to your car using simple rules.

Your car: [Learns to exploit predictable behavior]
Result: Overfits to IDM, fails against real humans

     Your Car          IDM Car
        |                  |
        v                  v
      +---+              +---+
      | > |              |   |  (always yields)
      +---+              +---+
                           ^
                           |
                    [IDM always brakes]
                    [Your car learns to be aggressive]


THE WAYMAX INSIGHT: You must evaluate against BOTH
=================================================
Even if you train against IDM, evaluate against both IDM and log playback:
a large gap between the two exposes overfitting to the sim agents!

Key Equations

Equation 1: Bicycle Model Dynamics

The bicycle model approximates vehicle motion with two key control inputs:

  • a: acceleration (m/s^2)
  • k: steering curvature (1/m, i.e., inverse turning radius)

STATE UPDATE EQUATIONS
======================

Current state: (x, y, theta, v)  -- position, heading, speed
Action: (a, k)                    -- acceleration, curvature
Time step: dt = 0.1s              -- 10 Hz simulation

New speed:
    v' = v + a * dt

    Intuition: Speed increases by acceleration times time.
    If a = 2 m/s^2 and dt = 0.1s, speed increases by 0.2 m/s

New heading:
    theta' = theta + v * k * dt

    Intuition: Heading change = arc length / radius
              Arc length = v * dt (distance traveled)
              Curvature k = 1/radius
              So: d(theta) = v * dt * k

New position:
    x' = x + v * cos(theta) * dt
    y' = y + v * sin(theta) * dt

    Intuition: Project velocity onto x and y axes

Worked Example:

Starting state: x=0, y=0, theta=0 (facing right), v=10 m/s
Action: a=0 (constant speed), k=0.1 (turning left)
dt = 0.1s

Step 1: v' = 10 + 0 * 0.1 = 10 m/s
Step 2: theta' = 0 + 10 * 0.1 * 0.1 = 0.1 rad (5.7 degrees)
Step 3: x' = 0 + 10 * cos(0) * 0.1 = 1.0 m
        y' = 0 + 10 * sin(0) * 0.1 = 0.0 m

After one step: (1.0, 0.0, 0.1 rad, 10 m/s)
               Moved forward 1m, started turning left
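The update rules and worked example above can be checked with a few lines of JAX. This is a minimal explicit-Euler sketch that mirrors the equations exactly as written (position and heading use the old speed and heading); it is not Waymax's own dynamics code, and the name `bicycle_step` is hypothetical.

```python
import jax.numpy as jnp

DT = 0.1  # 10 Hz, as in the equations above


def bicycle_step(state, action, dt=DT):
    """One explicit-Euler step of the simplified bicycle model.

    state:  array [x, y, theta, v]
    action: array [a, k]  (acceleration in m/s^2, curvature in 1/m)
    """
    x, y, theta, v = state
    a, k = action
    new_v = v + a * dt
    new_theta = theta + v * k * dt        # heading change = arc length * curvature
    new_x = x + v * jnp.cos(theta) * dt   # uses the *old* heading and speed
    new_y = y + v * jnp.sin(theta) * dt
    return jnp.array([new_x, new_y, new_theta, new_v])


# Worked example: x=0, y=0, theta=0, v=10 m/s; action a=0, k=0.1
state = jnp.array([0.0, 0.0, 0.0, 10.0])
print(bicycle_step(state, jnp.array([0.0, 0.1])))  # approximately [1.0, 0.0, 0.1, 10.0]
```

Because the function is a pure mapping from arrays to arrays, it can be wrapped in `jax.jit` or `jax.vmap` without modification.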

Equation 2: Delta Action Space

For non-vehicle objects (pedestrians, cyclists) or when you want direct control:

DELTA UPDATES
=============

Action: (dx, dy, d_theta) -- change in position and heading

x' = x + dx
y' = y + dy
theta' = theta + d_theta

Intuition: Direct "teleportation" - specify exactly where to go.
          No physics constraints, so any motion is possible.

Equation 3: Log Divergence Metric

Measures how far the simulated trajectory deviates from the logged expert:

LOG DIVERGENCE (L2)
===================

Given:
  - Simulated trajectory: [(x_1, y_1), (x_2, y_2), ..., (x_T, y_T)]
  - Logged trajectory:    [(x'_1, y'_1), (x'_2, y'_2), ..., (x'_T, y'_T)]

Log Divergence = (1/T) * SUM_t sqrt((x_t - x'_t)^2 + (y_t - y'_t)^2)

Intuition: Average Euclidean distance between simulated and logged
          positions across all timesteps.

Example:
  Simulated: [(0,0), (1,0), (2,0)]
  Logged:    [(0,0), (1,1), (2,2)]

  Distances: [0, 1, 2]
  Log Divergence = (0 + 1 + 2) / 3 = 1.0 meter
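The formula translates directly into array code. The sketch below is a minimal version, not Waymax's metric implementation; it adds an optional validity mask (an assumption on our part) because WOMD trajectories are padded, and the function name is hypothetical.

```python
import jax.numpy as jnp


def log_divergence(sim_xy, log_xy, valid=None):
    """Mean Euclidean distance between simulated and logged positions.

    sim_xy, log_xy: arrays of shape [T, 2]
    valid: optional boolean mask of shape [T]; invalid (padded) steps are ignored
    """
    dist = jnp.linalg.norm(sim_xy - log_xy, axis=-1)  # per-timestep distance, [T]
    if valid is None:
        valid = jnp.ones(dist.shape, dtype=bool)
    # Average only over valid steps (guard against an all-invalid mask)
    return jnp.sum(jnp.where(valid, dist, 0.0)) / jnp.maximum(jnp.sum(valid), 1)


sim = jnp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
logged = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
print(log_divergence(sim, logged))  # 1.0, matching the worked example
```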

Equation 4: Collision Detection

Bounding box overlap check for all pairs of objects:

COLLISION DETECTION
===================

Each object has a bounding box defined by:
  - Center: (cx, cy)
  - Dimensions: (length, width)
  - Heading: theta

Collision occurs when bounding boxes overlap.

Simplified 2D box overlap (axis-aligned):
  Box A: [x_min_A, x_max_A] x [y_min_A, y_max_A]
  Box B: [x_min_B, x_max_B] x [y_min_B, y_max_B]

  Overlap = (x_min_A < x_max_B) AND (x_max_A > x_min_B) AND
            (y_min_A < y_max_B) AND (y_max_A > y_min_B)

For rotated boxes, use Separating Axis Theorem (SAT).
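For oriented boxes, the Separating Axis Theorem reduces to a handful of projections: two rectangles are disjoint exactly when their projections fail to overlap on at least one of the four edge directions (two per box). A minimal JAX sketch under that formulation follows; it is not Waymax's collision code, and `box_corners` and `boxes_overlap` are hypothetical names. Boxes that merely touch count as overlapping here.

```python
import jax.numpy as jnp


def box_corners(cx, cy, length, width, theta):
    """Corners [4, 2] of a box centred at (cx, cy) and rotated by theta."""
    hl, hw = length / 2.0, width / 2.0
    local = jnp.array([[hl, hw], [hl, -hw], [-hl, -hw], [-hl, hw]])
    c, s = jnp.cos(theta), jnp.sin(theta)
    rot = jnp.array([[c, -s], [s, c]])
    return local @ rot.T + jnp.array([cx, cy])


def boxes_overlap(box_a, box_b):
    """SAT test for two oriented 2D boxes given as (cx, cy, length, width, theta)."""
    corners_a, corners_b = box_corners(*box_a), box_corners(*box_b)
    ta, tb = box_a[4], box_b[4]
    # Candidate separating axes: the two edge directions of each box
    axes = jnp.array([
        [jnp.cos(ta), jnp.sin(ta)], [-jnp.sin(ta), jnp.cos(ta)],
        [jnp.cos(tb), jnp.sin(tb)], [-jnp.sin(tb), jnp.cos(tb)],
    ])  # [4 axes, 2]
    proj_a = corners_a @ axes.T  # [4 corners, 4 axes]
    proj_b = corners_b @ axes.T
    # A gap on any axis means the boxes are separated
    separated = (proj_a.max(axis=0) < proj_b.min(axis=0)) | (
        proj_b.max(axis=0) < proj_a.min(axis=0))
    return jnp.logical_not(jnp.any(separated))


a = (0.0, 0.0, 2.0, 2.0, 0.0)
b_rot = (2.2, 0.0, 2.0, 2.0, jnp.pi / 4)  # rotated corner reaches x ~ 0.79
b_axis = (2.2, 0.0, 2.0, 2.0, 0.0)        # axis-aligned: spans x in [1.2, 3.2]
print(boxes_overlap(a, b_rot), boxes_overlap(a, b_axis))  # True False
```

The pair shows why the axis-aligned check is not enough: the same box center overlaps or not depending on heading. Since everything is array arithmetic, `jax.vmap` can apply the test to all object pairs at once.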

Equation 5: Kinematic Feasibility Constraints

Ensures actions are physically plausible:

KINEMATIC CONSTRAINTS
=====================

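Acceleration limit: |a| <= 6 m/s^2
  - Typical car: 0-60 mph (26.8 m/s) in ~4.5s requires ~6 m/s^2
  - Hard braking typically reaches 8-10 m/s^2, so this limit also
    flags emergency stops as infeasible

Curvature limit: |k| <= 0.3 m^-1
  - Minimum turning radius: r = 1/k = 3.3 meters
  - Most passenger cars cannot turn tighter than roughly 5 m, so this
    bound only flags clearly impossible turns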

Infeasibility metric:
  kinematic_infeasible = (|a| > 6) OR (|k| > 0.3)

If violated: Agent is executing physically impossible maneuvers
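Because Equation 1 is invertible, the check can be run on any trajectory by recovering (a, k) from consecutive states: a = (v' - v) / dt and k = (theta' - theta) / (v * dt). The sketch below assumes forward driving and the same explicit-Euler update as Equation 1; it is not Waymax's metric code, and the function names and the `min_speed` guard are hypothetical.

```python
import jax.numpy as jnp

MAX_ACCEL = 6.0      # m/s^2, threshold quoted above
MAX_CURVATURE = 0.3  # 1/m, threshold quoted above


def infer_controls(theta, v, dt=0.1, min_speed=1e-3):
    """Recover per-step (a, k) by inverting the bicycle updates.

    theta, v: arrays of shape [T] (heading in rad, speed in m/s, v >= 0).
    Returns accel and curvature, each of shape [T - 1].
    """
    accel = (v[1:] - v[:-1]) / dt
    diff = theta[1:] - theta[:-1]
    dtheta = jnp.arctan2(jnp.sin(diff), jnp.cos(diff))  # wrap to (-pi, pi]
    # theta' - theta = v * k * dt  =>  k = dtheta / (v * dt); guard v ~ 0
    curvature = dtheta / (jnp.maximum(v[:-1], min_speed) * dt)
    return accel, curvature


def kinematic_infeasible(accel, curvature):
    return (jnp.abs(accel) > MAX_ACCEL) | (jnp.abs(curvature) > MAX_CURVATURE)


theta = jnp.array([0.0, 0.1, 0.2])
v = jnp.array([10.0, 10.0, 11.0])     # second step gains 1 m/s in 0.1 s
accel, curvature = infer_controls(theta, v)
print(kinematic_infeasible(accel, curvature))  # [False  True]
```

At near-zero speed, small heading noise turns into very large curvature even with the guard, so you may want to skip low-speed steps before applying the check.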

Equation 6: Route Progress

Measures completion along intended path:

ROUTE PROGRESS
==============

Given:
  - Route: sequence of waypoints defining intended path
  - Current position: (x, y)

Progress = (distance traveled along route) / (total route length)

Calculated via projection onto route polyline:
  1. Find closest point on route to current position
  2. Compute distance from route start to that point
  3. Divide by total route length

Range: [0, 1] where 1.0 = completed full route
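The three projection steps can be sketched as follows. This is a minimal version under our own assumptions (a single polyline with at least two distinct waypoints, closest point chosen globally); it is not Waymax's route-progress implementation, and `route_progress` is a hypothetical name.

```python
import jax.numpy as jnp


def route_progress(route_xy, pos):
    """Fraction of a polyline route completed at position `pos`.

    route_xy: [N, 2] waypoints (N >= 2); pos: [2].
    """
    starts, ends = route_xy[:-1], route_xy[1:]
    seg = ends - starts                                   # [N-1, 2]
    seg_len = jnp.linalg.norm(seg, axis=-1)               # [N-1]
    # 1. Closest point on each segment: clamp the projection parameter to [0, 1]
    t = jnp.sum((pos - starts) * seg, axis=-1) / jnp.maximum(seg_len**2, 1e-9)
    t = jnp.clip(t, 0.0, 1.0)
    closest = starts + t[:, None] * seg
    best = jnp.argmin(jnp.linalg.norm(closest - pos, axis=-1))
    # 2. Arc length from the route start to that closest point
    cum_len = jnp.concatenate([jnp.zeros(1), jnp.cumsum(seg_len)])
    travelled = cum_len[best] + t[best] * seg_len[best]
    # 3. Normalise by the total route length
    return travelled / cum_len[-1]


route = jnp.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]])  # L-shaped, 20 m long
print(route_progress(route, jnp.array([10.0, 5.0])))       # 0.75 (15 m of 20 m)
```

Picking the globally closest segment can jump ahead on routes that loop back near themselves; tracking the previously matched segment is one way to avoid that.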

Hands-On Exercises

Exercise 1: Data Exploration (Beginner)

Goal: Understand the structure of Waymo Open Motion Dataset scenarios.

"""
Exercise 1: Exploring a Waymax Scenario
=======================================

Tasks:
1. Load a scenario and print basic statistics
2. Identify the ego vehicle (SDC)
3. Count different object types
4. Visualize the road graph
"""

from waymax import config, dataloader
import jax.numpy as jnp

# Load the first scenario from the WOMD v1.1.0 training split
scenarios = dataloader.simulator_state_generator(config.WOD_1_1_0_TRAINING)
scenario = next(scenarios)

# Task 1: Print basic statistics
# HINT: scenario.num_objects, scenario.remaining_timesteps
print("=== SCENARIO STATISTICS ===")
# YOUR CODE HERE
# Expected output: number of objects, timesteps, etc.

# Task 2: Find the ego vehicle (SDC = Self-Driving Car)
# HINT: scenario.object_metadata.is_sdc is a boolean mask over objects;
# the SDC is a vehicle (object type 1) but is not guaranteed to be index 0
# YOUR CODE HERE

# Task 3: Count object types
# Types: 1=Vehicle, 2=Pedestrian, 3=Cyclist
# HINT: Use jnp.unique with return_counts=True
# YOUR CODE HERE

# Task 4: Examine road graph structure
# HINT: scenario.roadgraph_points contains road geometry
# YOUR CODE HERE

Solution:

# Task 1
print(f"Number of objects: {scenario.num_objects}")
print(f"Number of timesteps: {scenario.remaining_timesteps}")
print(f"Current timestep: {scenario.timestep}")

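# Task 2
# object_metadata.is_sdc flags the SDC; it is not guaranteed to be index 0
sdc_idx = int(jnp.argmax(scenario.object_metadata.is_sdc))
sdc_trajectory = scenario.log_trajectory.xy[sdc_idx]
print(f"SDC index: {sdc_idx}")
print(f"SDC trajectory shape: {sdc_trajectory.shape}")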

# Task 3
object_types = scenario.object_metadata.object_types
unique, counts = jnp.unique(object_types, return_counts=True)
type_names = {1: 'Vehicle', 2: 'Pedestrian', 3: 'Cyclist'}
for t, c in zip(unique.tolist(), counts.tolist()):
    if t in type_names:
        print(f"{type_names[t]}: {c}")

# Task 4
roadgraph = scenario.roadgraph_points
print(f"Road graph points: {roadgraph.xy.shape}")

Exercise 2: Implementing IDM (Intermediate)

Goal: Implement the Intelligent Driver Model for sim agents.

"""
Exercise 2: Intelligent Driver Model Implementation
==================================================

The IDM is a car-following model that determines acceleration
based on the current speed and distance to the leading vehicle.

IDM Equation:
a = a_max * [1 - (v/v_desired)^4 - (s*/s)^2]

Where:
- a_max: maximum acceleration (e.g., 2 m/s^2)
- v: current speed
- v_desired: desired speed (e.g., 15 m/s for urban)
- s: current gap to leading vehicle
- s*: desired gap = s_0 + v*T + (v*dv)/(2*sqrt(a_max*b))
- s_0: minimum gap (e.g., 2 m)
- T: safe time headway (e.g., 1.5 s)
- b: comfortable deceleration (e.g., 3 m/s^2)
- dv: velocity difference (ego - leading)
"""

import jax.numpy as jnp

def idm_acceleration(
    ego_speed: float,
    leading_speed: float,
    gap: float,
    v_desired: float = 15.0,
    a_max: float = 2.0,
    b: float = 3.0,
    s_0: float = 2.0,
    T: float = 1.5,
) -> float:
    """
    Compute IDM acceleration.

    Args:
        ego_speed: Current speed of ego vehicle (m/s)
        leading_speed: Speed of leading vehicle (m/s)
        gap: Distance to leading vehicle (m)
        v_desired: Desired cruising speed (m/s)
        a_max: Maximum acceleration (m/s^2)
        b: Comfortable deceleration (m/s^2)
        s_0: Minimum gap (m)
        T: Safe time headway (s)

    Returns:
        Acceleration command (m/s^2)
    """

    # Task: Implement the IDM equation
    # YOUR CODE HERE

    # Step 1: Compute velocity difference
    dv = None  # YOUR CODE

    # Step 2: Compute desired gap s*
    s_star = None  # YOUR CODE

    # Step 3: Compute acceleration using IDM equation
    acceleration = None  # YOUR CODE

    return acceleration


# Test your implementation
def test_idm():
    # Test case 1: Free road (large gap)
    a1 = idm_acceleration(ego_speed=10.0, leading_speed=15.0, gap=100.0)
    print(f"Free road: a = {a1:.2f} m/s^2 (should be positive, accelerating)")

    # Test case 2: Close following (gap far below the desired gap)
    a2 = idm_acceleration(ego_speed=15.0, leading_speed=15.0, gap=10.0)
    print(f"Close following: a = {a2:.2f} m/s^2 (should be negative: desired gap is ~24.5 m)")

    # Test case 3: Emergency braking (very small gap, leading slower)
    a3 = idm_acceleration(ego_speed=15.0, leading_speed=5.0, gap=5.0)
    print(f"Emergency: a = {a3:.2f} m/s^2 (should be negative, braking)")

test_idm()

Solution:

import jax.numpy as jnp

def idm_acceleration(
    ego_speed: float,
    leading_speed: float,
    gap: float,
    v_desired: float = 15.0,
    a_max: float = 2.0,
    b: float = 3.0,
    s_0: float = 2.0,
    T: float = 1.5,
) -> float:

    # Step 1: Velocity difference (positive if ego faster than leading)
    dv = ego_speed - leading_speed

    # Step 2: Desired gap s*
    s_star = s_0 + ego_speed * T + (ego_speed * dv) / (2 * jnp.sqrt(a_max * b))
    s_star = jnp.maximum(s_star, s_0)  # Ensure minimum gap

    # Step 3: IDM acceleration
    # a = a_max * [1 - (v/v_desired)^4 - (s*/s)^2]
    free_road_term = (ego_speed / v_desired) ** 4
    interaction_term = (s_star / gap) ** 2

    acceleration = a_max * (1 - free_road_term - interaction_term)

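    # Clamp to a plausible range. Note: standard IDM is unclamped, and
    # clipping at -b caps emergency braking at the comfortable deceleration
    acceleration = jnp.clip(acceleration, -b, a_max)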

    return acceleration
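In a batched simulator you usually need IDM accelerations for many followers at once. The sketch below restates the same equation and vectorizes it with `jax.vmap`. Two choices are our own assumptions: a hypothetical `max_brake` of 9 m/s^2 replaces the `-b` clamp so emergency braking is not capped at the comfortable deceleration, and the gap is floored at 0.1 m to avoid division by zero.

```python
import jax
import jax.numpy as jnp


def idm_accel(v, v_lead, gap, v_desired=15.0, a_max=2.0, b=3.0,
              s_0=2.0, T=1.5, max_brake=9.0):
    """IDM acceleration for one follower (max_brake is an assumed limit)."""
    dv = v - v_lead
    # Desired gap s*, never below the minimum gap s_0
    s_star = s_0 + jnp.maximum(0.0, v * T + v * dv / (2.0 * jnp.sqrt(a_max * b)))
    gap = jnp.maximum(gap, 0.1)  # avoid division by zero when bumpers touch
    a = a_max * (1.0 - (v / v_desired) ** 4 - (s_star / gap) ** 2)
    return jnp.clip(a, -max_brake, a_max)


# Map over the first three arguments; IDM parameters stay shared scalars
batched_idm = jax.jit(jax.vmap(idm_accel, in_axes=(0, 0, 0)))

ego_speed = jnp.array([10.0, 15.0, 15.0])
lead_speed = jnp.array([15.0, 15.0, 5.0])
gaps = jnp.array([100.0, 10.0, 5.0])
print(batched_idm(ego_speed, lead_speed, gaps))
```

These are the three test cases from the exercise: the first follower accelerates (about 1.6 m/s^2), while the second and third are both far inside their desired gap and saturate at `-max_brake`.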

Exercise 3: Behavior Cloning (Advanced)

Goal: Train a simple behavior cloning agent.

"""
Exercise 3: Training a Behavior Cloning Agent
============================================

Implement a training loop that learns to imitate expert trajectories.
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Dict, Tuple

class BCAgent(nn.Module):
    """Simple behavior cloning agent."""

    @nn.compact
    def __call__(self, observation: Dict) -> jnp.ndarray:
        """
        Predict action from observation.

        Args:
            observation: Dict with keys:
                - 'ego_state': [x, y, theta, vx, vy]
                - 'goal': [goal_x, goal_y]

        Returns:
            action: [acceleration, steering_curvature]
        """
        # Concatenate inputs along the feature axis (works with or without a batch dim)
        ego = observation['ego_state']
        goal = observation['goal']
        x = jnp.concatenate([ego, goal], axis=-1)

        # MLP
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)

        # Output heads
        acceleration = nn.Dense(1)(x)
        steering = nn.Dense(1)(x)

        return jnp.concatenate([acceleration, steering], axis=-1)


def extract_expert_actions(scenario, dynamics_model):
    """
    Extract expert actions from logged trajectories.

    This uses inverse dynamics: given the trajectory,
    compute what actions would produce it.

    Args:
        scenario: Waymax scenario with logged trajectories
        dynamics_model: Dynamics model (bicycle or delta)

    Returns:
        observations: List of observation dicts
        actions: Array of expert actions [T, 2]
    """
    # TODO: Implement expert action extraction
    # HINT: Use dynamics_model.inverse() to get actions from trajectory
    pass


def create_train_step(model, optimizer):
    """
    Create a JIT-compiled training step.

    Args:
        model: BCAgent instance
        optimizer: Optax optimizer

    Returns:
        train_step function
    """

    def loss_fn(params, observations, expert_actions):
        # TODO: Implement loss function
        # Predict actions and compute MSE loss
        pass

    @jax.jit
    def train_step(params, opt_state, batch):
        observations, expert_actions = batch

        loss, grads = jax.value_and_grad(loss_fn)(
            params, observations, expert_actions
        )

        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss

    return train_step


# Main training loop
def train_bc_agent(scenarios, num_epochs=10, batch_size=32):
    """
    Train a behavior cloning agent.

    Tasks:
    1. Initialize model and optimizer
    2. Extract expert demonstrations
    3. Run training loop
    4. Evaluate on held-out scenarios
    """

    # Initialize
    model = BCAgent()
    rng = jax.random.PRNGKey(0)

    # TODO: Initialize parameters
    # HINT: model.init(rng, dummy_observation)

    # TODO: Create optimizer (try Adam with lr=1e-3)

    # TODO: Training loop

    pass
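The `loss_fn` TODO reduces to a mean squared error between predicted and expert actions, and with WOMD data it should also skip padded timesteps. Below is a minimal pure-JAX sketch of such a loss. Inside `loss_fn` the predictions would come from Flax's `model.apply(params, observations)`; the `valid` mask is an assumption about how you package the expert actions, and `masked_mse` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def masked_mse(pred_actions, expert_actions, valid):
    """Behavior-cloning MSE that ignores invalid (padded) examples.

    pred_actions, expert_actions: [B, 2]  (acceleration, curvature)
    valid: [B] boolean mask
    """
    per_example = jnp.sum((pred_actions - expert_actions) ** 2, axis=-1)  # [B]
    per_example = jnp.where(valid, per_example, 0.0)
    return jnp.sum(per_example) / jnp.maximum(jnp.sum(valid), 1)


pred = jnp.array([[1.0, 0.0], [0.0, 0.1], [5.0, 5.0]])
expert = jnp.array([[0.0, 0.0], [0.0, 0.1], [0.0, 0.0]])
valid = jnp.array([True, True, False])
print(masked_mse(pred, expert, valid))  # 0.5: the padded third row is ignored

# The padded row also receives zero gradient
grads = jax.grad(masked_mse)(pred, expert, valid)
```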

Exercise 4: Multi-Agent Rollout (Expert)

Goal: Implement coordinated multi-agent simulation with different agent types.

"""
Exercise 4: Multi-Agent Coordination Challenge
=============================================

Implement a simulation where:
- Agent 0 (SDC): Your learned policy
- Agent 1-5: IDM-controlled vehicles
- Agent 6+: Log playback

Evaluate collision rates and route progress.
"""

import jax
import jax.numpy as jnp
from waymax import env, config, dynamics, datatypes

def create_mixed_agent_controller(
    policy_fn,
    idm_params,
    num_controlled: int = 6,
):
    """
    Create a controller that uses different strategies for different agents.

    Args:
        policy_fn: Neural network policy for SDC
        idm_params: Parameters for IDM agents
        num_controlled: Total number of non-log agents, counting the SDC;
            agents 1..num_controlled-1 use IDM

    Returns:
        Controller function that returns actions for all agents
    """

    def controller(state, observations):
        """
        Generate actions for all agents.

        Args:
            state: Current simulator state
            observations: Observations for all agents

        Returns:
            actions: Actions for all controllable agents
        """
        num_agents = state.num_objects
        actions = []

        for agent_idx in range(num_agents):
            if agent_idx == 0:
                # SDC: Use learned policy
                action = policy_fn(observations[agent_idx])
            elif agent_idx < num_controlled:
                # IDM agents: Use car-following model
                action = idm_action(state, agent_idx, idm_params)
            else:
                # Log playback: Extract action from logged trajectory
                action = extract_log_action(state, agent_idx)

            actions.append(action)

        return jnp.stack(actions)

    return controller


def evaluate_multi_agent(
    controller,
    scenarios,
    num_scenarios: int = 100,
    max_steps: int = 80,
):
    """
    Evaluate multi-agent controller across scenarios.

    Returns:
        Dictionary of aggregated metrics
    """

    results = {
        'sdc_collision_rate': 0.0,
        'sdc_offroad_rate': 0.0,
        'sdc_route_progress': 0.0,
        'idm_collision_rate': 0.0,
    }

    # TODO: Implement evaluation loop
    # 1. Reset environment for each scenario
    # 2. Run simulation with mixed controller
    # 3. Compute metrics
    # 4. Aggregate results

    pass
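The Python `for` loop in `controller` is unrolled once per agent when traced, and each agent gets its own code path. A branch-free alternative is to compute all three candidate action arrays for every agent and select per agent with `jnp.where`. This is a minimal sketch that assumes the policy, IDM and log-replay actions are already available as dense `[num_agents, action_dim]` arrays (the hypothetical `idm_action` and `extract_log_action` helpers would need vectorized versions to provide them); `select_mixed_actions` is also a hypothetical name.

```python
import jax.numpy as jnp


def select_mixed_actions(policy_actions, idm_actions, log_actions, num_controlled):
    """Per-agent action selection without a Python loop.

    All action inputs are [num_agents, action_dim]. Agent 0 uses the policy,
    agents 1..num_controlled-1 use IDM, and the rest replay the log.
    """
    idx = jnp.arange(policy_actions.shape[0])[:, None]  # [num_agents, 1]
    return jnp.where(idx == 0, policy_actions,
                     jnp.where(idx < num_controlled, idm_actions, log_actions))


n = 8
out = select_mixed_actions(jnp.full((n, 2), 1.0),   # policy
                           jnp.full((n, 2), 2.0),   # IDM
                           jnp.full((n, 2), 3.0),   # log replay
                           num_controlled=6)
print(out[:, 0])  # [1. 2. 2. 2. 2. 2. 3. 3.]
```

The trade-off is extra compute (every candidate is evaluated for every agent) in exchange for a single fused, loop-free graph that vectorizes well on accelerators.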

Common Pitfalls

Pitfall 1: Ignoring the Sim Agent Gap

The Mistake: Training against IDM sim agents and assuming good performance will transfer to real driving.

What Goes Wrong:

Training:                           Evaluation:
=========                           ===========
Your Policy vs IDM                  Your Policy vs Log Playback

IDM: [Always yields]                Log: [Human driver behavior]
     |                                   |
     v                                   v
Your Policy: [Learns to be         Your Policy: [Crashes!]
aggressive, cut off IDM cars]      Human didn't yield like IDM

The Fix:

  • Always evaluate against log playback
  • Train with mixture of IDM and log agents
  • Report metrics for both evaluation settings

Pitfall 2: Forgetting to Account for Valid Masks

The Mistake: Processing all agents/timesteps without checking validity masks.

What Goes Wrong:

# WRONG: Processes invalid (padded) data
all_positions = scenario.log_trajectory.xy            # Includes invalid entries!
mean_position = jnp.mean(all_positions, axis=(0, 1))  # Garbage result

# RIGHT: Use validity masks
valid_mask = scenario.log_trajectory.valid   # [num_objects, num_timesteps]
positions = scenario.log_trajectory.xy       # [num_objects, num_timesteps, 2]
# Add a trailing axis so the mask broadcasts over (x, y)
mean_position = (jnp.sum(positions * valid_mask[..., None], axis=(0, 1))
                 / jnp.sum(valid_mask))

The Fix: Always check and use validity masks for:

  • Agent trajectories
  • Road graph points
  • Traffic light states

Pitfall 3: Coordinate Frame Confusion

The Mistake: Mixing global and local coordinate frames.

What Goes Wrong:

Global Frame (Map):              Local Frame (Ego):
=================               ==================

      N                              Forward
      ^                                 ^
      |                                 |
W <---+---> E                   Left <--+---> Right
      |                                 |
      S                              Backward

Agent at global (0, 50), heading 45 degrees (NE)
Obstacle at global (100, 100)

WRONG: Feed global coords to policy
       Policy sees raw map coordinates and must learn the rotation
       itself; the same situation looks different everywhere on the map

RIGHT: Transform to ego-centric coords (x forward, y left)
       Obstacle is at roughly (106, -35): "ahead and to the right"

The Fix:

import jax.numpy as jnp

def global_to_local(global_pos, ego_pos, ego_heading):
    """Transform a global point into the ego frame (x forward, y left)."""
    # Translate
    relative = global_pos - ego_pos

    # Rotate
    cos_h = jnp.cos(-ego_heading)
    sin_h = jnp.sin(-ego_heading)

    local_x = relative[0] * cos_h - relative[1] * sin_h
    local_y = relative[0] * sin_h + relative[1] * cos_h

    return jnp.array([local_x, local_y])
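The same transform in matrix form handles many points at once and has an obvious inverse, which makes it easy to sanity-check your convention with a round trip. This is a minimal sketch using the x-forward, y-left convention of `global_to_local`; `rotation`, `global_to_local_batch` and `local_to_global_batch` are hypothetical names.

```python
import jax.numpy as jnp


def rotation(heading):
    """2x2 matrix rotating ego-frame vectors into the global frame."""
    c, s = jnp.cos(heading), jnp.sin(heading)
    return jnp.array([[c, -s], [s, c]])


def global_to_local_batch(points, ego_pos, ego_heading):
    """[N, 2] global points -> [N, 2] ego frame (x forward, y left)."""
    # Row-vector form of R^T (p - ego_pos)
    return (points - ego_pos) @ rotation(ego_heading)


def local_to_global_batch(points, ego_pos, ego_heading):
    """[N, 2] ego-frame points -> [N, 2] global points (inverse of the above)."""
    return points @ rotation(ego_heading).T + ego_pos


ego = jnp.array([0.0, 50.0])
heading = jnp.pi / 4                      # 45 degrees, facing NE
obstacle = jnp.array([[100.0, 100.0]])
local = global_to_local_batch(obstacle, ego, heading)
back = local_to_global_batch(local, ego, heading)
print(local)  # approximately [[106.07, -35.36]]: ahead (x > 0) and right (y < 0)
```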

Pitfall 4: Misunderstanding Timestep Indexing

The Mistake: Confusing which timestep you're at vs. which timestep you're predicting.

What Goes Wrong:

Timeline:
=========
t=0    t=1    t=2    t=3    ...    t=10 (current)    t=11 (next)
[log]  [log]  [log]  [log]  ...    [sim state]       [prediction]

WRONG: state.timestep gives you t=10
       You compute action for t=11
       But you read observation from t=9 by mistake

RIGHT:
       observation = state at t=10
       action = for transition t=10 -> t=11
       next_state = state at t=11

Pitfall 5: Ignoring Scenario Diversity

The Mistake: Testing only on "easy" scenarios (highway driving, no interactions).

What Goes Wrong:

Easy Scenarios:                 Hard Scenarios:
===============                 ================
Highway, light traffic          Dense urban
No turns                        Complex intersections
No pedestrians                  Pedestrians crossing
No construction                 Construction zones

Illustrative example:
Policy trained on easy: 99% success
Same policy on hard: 45% success

The Fix:

  • Use scenario tags to ensure diversity
  • Track per-scenario-type metrics
  • Weight training towards failure cases

Pitfall 6: JAX Compilation Overhead

The Mistake: Re-compiling JAX functions inside loops.

What Goes Wrong:

# WRONG: Recompiles every iteration
for scenario in scenarios:
    @jax.jit
    def step(state, action):  # New function object each time!
        return env.step(state, action)

    state = env.reset(scenario)
    for t in range(80):
        state = step(state, action)  # Compiles on first call each loop!

# RIGHT: Compile once, reuse
@jax.jit
def step(state, action):
    return env.step(state, action)

for scenario in scenarios:
    state = env.reset(scenario)
    for t in range(80):
        # `action` comes from your policy; keep its shapes/dtypes fixed
        state = step(state, action)  # Uses cached compilation
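You can verify the recompilation claim directly. A Python side effect inside a jitted function runs only while JAX traces it, so counting those side effects counts traces. This minimal sketch uses a toy step function as a stand-in for `env.step`; `make_step` and `trace_count` are hypothetical names, and side effects like this should only be used for debugging, never in real jitted code.

```python
import jax
import jax.numpy as jnp

trace_count = {"inside": 0, "outside": 0}


def make_step(tag):
    def step(x):
        trace_count[tag] += 1  # Python side effect: runs only during tracing
        return 0.9 * x + 1.0   # toy stand-in for env.step
    return step


x = jnp.ones(4)

# WRONG: a fresh jitted function per scenario -> one trace per scenario
for scenario in range(3):
    step = jax.jit(make_step("inside"))
    for t in range(10):
        x = step(x)

# RIGHT: jit once, reuse the cached compilation
step = jax.jit(make_step("outside"))
for scenario in range(3):
    for t in range(10):
        x = step(x)

print(trace_count)  # {'inside': 3, 'outside': 1}
```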

Interview Questions

Conceptual Questions

Q1: Why is closed-loop evaluation important for motion prediction models?

Expected Answer: Open-loop evaluation measures prediction accuracy when every input comes from the logged (expert) history. In deployment, errors compound: a small error at t=1 shifts the state at t=2, which changes the prediction made at t=2, and so on. Closed-loop evaluation captures this propagation by feeding the model's own outputs back as inputs, revealing how the model behaves under the distribution shift created by its own mistakes.


Q2: Explain the trade-off between log playback and reactive (IDM) sim agents.

Expected Answer:

  • Log playback is realistic (real human behavior) but non-reactive (ignores ego actions, can cause "ghost" collisions)
  • IDM agents react to ego vehicle but follow simple rules that can be exploited
  • The paper shows RL agents trained against IDM have 4x higher collision rates when evaluated against log playback
  • Best practice: train against reactive agents but evaluate against both to detect overfitting

Q3: Why does Waymax use JAX instead of PyTorch/TensorFlow?

Expected Answer:

  • JAX's functional design enables efficient vectorization (vmap) across scenarios
  • XLA compilation optimizes the entire simulation graph for GPU/TPU
  • Compiling the whole rollout into one XLA program (e.g. with jax.lax.scan) removes Python interpreter overhead between steps
  • Differentiability enables gradient-based optimization through the simulator
  • Result: 1000+ Hz single-scenario speed vs. ~10 Hz for CPU-based simulators
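The point about Python overhead can be illustrated with `jax.lax.scan`: the whole multi-step rollout becomes a single compiled program instead of a Python loop issuing one call per step, and `jax.vmap` handles the batch inside each step. This minimal sketch uses the simplified bicycle update from Equation 1 rather than Waymax's dynamics; `bicycle_step` and `rollout` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def bicycle_step(state, action, dt=0.1):
    # state: [x, y, theta, v]; action: [a, k] (Equation 1, explicit Euler)
    x, y, theta, v = state
    a, k = action
    return jnp.array([x + v * jnp.cos(theta) * dt,
                      y + v * jnp.sin(theta) * dt,
                      theta + v * k * dt,
                      v + a * dt])


@jax.jit
def rollout(init_states, actions):
    """init_states: [B, 4]; actions: [T, B, 2] -> trajectory [T, B, 4]."""
    def body(states, actions_t):
        next_states = jax.vmap(bicycle_step)(states, actions_t)
        return next_states, next_states   # (new carry, per-step output)
    _, traj = jax.lax.scan(body, init_states, actions)
    return traj


init = jnp.array([[0.0, 0.0, 0.0, 10.0], [0.0, 0.0, 0.0, 5.0]])
actions = jnp.zeros((80, 2, 2))   # 80 steps, 2 vehicles, coasting straight
traj = rollout(init, actions)
print(traj.shape)                 # (80, 2, 4)
print(traj[-1, :, 0])             # approximately [80. 40.]
```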

Technical Questions

Q4: Derive the bicycle model update equations for heading angle.

Expected Answer:

For a bicycle model with:
- Speed v
- Steering curvature k = 1/r (inverse turning radius)
- Time step dt

The vehicle traces an arc with:
- Arc length s = v * dt
- Radius r = 1/k

Angular change = arc length / radius
               = s / r
               = v * dt * k

Therefore: theta_new = theta_old + v * k * dt

Q5: How would you detect if a trained agent is "cheating" by exploiting simulator artifacts?

Expected Answer:

  • Compare performance on IDM vs. log playback sim agents (large gap = exploitation)
  • Check for impossible maneuvers (kinematic infeasibility metric)
  • Analyze behavior patterns: does agent always take the same action regardless of context?
  • Test on held-out scenario types (new cities, new road layouts)
  • Compare to the expert (log divergence): results that look better than the logged human driver deserve scrutiny, since they may come from exploiting the simulator

Q6: Design a reward function for training a lane-change policy.

Expected Answer:

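import jax.numpy as jnp

# Sketch only: route_progress_delta, in_target_lane, collision, offroad,
# jerk_penalty, lateral_acceleration_penalty and unsafe_distance_penalty
# are placeholder helpers you would build (e.g. on top of Waymax metrics).
def lane_change_reward(state, prev_action, action, next_state, wants_lane_change):
    reward = 0.0

    # Positive: Progress toward destination
    reward += 0.1 * route_progress_delta(state, next_state)

    # Positive: Successful lane change when one is requested
    # (jnp.where instead of Python `if` keeps the function jit-compatible)
    reward += jnp.where(wants_lane_change & in_target_lane(next_state), 1.0, 0.0)

    # Negative: Collisions (terminal)
    reward -= jnp.where(collision(next_state), 10.0, 0.0)

    # Negative: Off-road (terminal)
    reward -= jnp.where(offroad(next_state), 5.0, 0.0)

    # Negative: Uncomfortable driving (jerk needs the previous action)
    reward -= 0.01 * jerk_penalty(prev_action, action)
    reward -= 0.01 * lateral_acceleration_penalty(state, action)

    # Negative: Unsafe distance to other vehicles
    reward -= 0.1 * unsafe_distance_penalty(next_state)

    return reward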

System Design Questions

Q7: How would you scale Waymax evaluation to 1 million scenarios?

Expected Answer:

  • Use data parallelism: distribute scenarios across multiple GPUs/TPUs
  • Batch scenarios with similar lengths together (minimize padding)
  • Use async data loading to overlap I/O with computation
  • Implement checkpointing to handle failures gracefully
  • Estimated time: 1M scenarios * 80 steps = 80M steps; 80M / (5,000 steps/sec per GPU * 8 GPUs) = 2,000 s, roughly 33 minutes
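The data-parallel bullet can be sketched with `jax.pmap`: pad the scenario batch to a multiple of the local device count, give each device one shard, and run a `lax.scan` rollout per shard. This is a minimal sketch with a toy constant-velocity step standing in for `env.step`; `toy_step` and `sharded_rollout` are hypothetical names. Newer JAX code often uses `jax.jit` with `jax.sharding` instead of `pmap`, which this sketch does not show.

```python
import jax
import jax.numpy as jnp


def toy_step(pos, vel, dt=0.1):
    """Stand-in for env.step: constant-velocity update of one scenario."""
    return pos + vel * dt


def sharded_rollout(pos, vel, num_steps=80):
    """Split [N, 2] scenarios across local devices and roll each shard out."""
    n_dev = jax.local_device_count()
    n = pos.shape[0]
    pad = (-n) % n_dev                       # rows needed for an even split
    pos_p = jnp.pad(pos, ((0, pad), (0, 0))).reshape(n_dev, -1, 2)
    vel_p = jnp.pad(vel, ((0, pad), (0, 0))).reshape(n_dev, -1, 2)

    def per_device(p, v):
        def body(p, _):
            return jax.vmap(toy_step)(p, v), None
        p, _ = jax.lax.scan(body, p, None, length=num_steps)
        return p

    # In real code, build the pmapped function once and reuse it (see Pitfall 6)
    out = jax.pmap(per_device)(pos_p, vel_p)  # [n_dev, padded_N / n_dev, 2]
    return out.reshape(-1, 2)[:n]            # drop padding rows


out = sharded_rollout(jnp.zeros((5, 2)), jnp.ones((5, 2)))
print(out)  # every row approximately [8. 8.] (80 steps * 0.1 s * 1 m/s)
```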

Q8: Design a curriculum for training an RL agent in Waymax.

Expected Answer:

Stage 1: Lane Following (Easy)
- Highway scenarios, light traffic
- Reward: route progress
- Success criterion: >95% route completion

Stage 2: Car Following (Medium)
- Add leading vehicles
- Introduce IDM agents
- Reward: route progress + safe following distance

Stage 3: Dense Traffic (Hard)
- Urban scenarios, heavy traffic
- Reactive + log playback agents
- Reward: route progress + collision avoidance

Stage 4: Full Interaction (Expert)
- All scenario types
- Pedestrians, cyclists
- Complex intersections

Further Reading

Core Papers

  1. Waymo Open Motion Dataset

    • "Large Scale Interactive Motion Forecasting for Autonomous Driving" (Ettinger et al., 2021)
    • The dataset that powers Waymax
    • arXiv:2104.10133
  2. Wayformer

    • "Wayformer: Motion Forecasting via Simple & Efficient Attention Networks" (Nayakanti et al., 2022)
    • Transformer architecture used in Waymax baselines
    • arXiv:2207.05844
  3. Symphony

    • "Symphony: Learning Realistic and Diverse Agents for Autonomous Driving Simulation" (Igl et al., 2022)
    • Multi-agent simulation with learned sim agents
    • arXiv:2205.03195
Related Simulators

  1. CARLA

    • High-fidelity 3D simulator
    • Slower but more visually realistic
    • carla.org
  2. nuPlan

    • Motional's planning benchmark
    • Similar data-driven approach
    • nuplan.org
  3. InterSim

    • "InterSim: Interactive Traffic Simulation via Explicit Relation Modeling" (Sun et al., 2022)
    • Focus on multi-agent interactions
    • arXiv:2210.08235

Foundational Concepts

  1. Intelligent Driver Model (IDM)

    • "Congested Traffic States in Empirical Observations and Microscopic Simulations" (Treiber et al., 2000)
    • The car-following model used for sim agents
    • Physical Review E, 62(2), 1805
  2. Behavior Cloning

    • "A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning" (Ross et al., 2011)
    • DAgger algorithm for addressing distribution shift
    • arXiv:1011.0686
  3. Sim-to-Real Transfer

    • "Domain Randomization for Transferring Deep Neural Networks from Simulation to the Real World" (Tobin et al., 2017)
    • Techniques for bridging the simulation-reality gap
    • arXiv:1703.06907

JAX Resources

  1. JAX Documentation

  2. Brax

    • "Brax: A Differentiable Physics Engine for Large Scale Rigid Body Simulation" (Freeman et al., 2021)
    • JAX-based physics simulator, Waymax uses similar design patterns
    • arXiv:2106.13281

Summary

Waymax represents a paradigm shift in autonomous driving simulation:

Aspect            Traditional Simulators    Waymax
----------------  ------------------------  ------------
Speed             10-100 Hz                 1000+ Hz
Data              Synthetic                 Real-world
Hardware          CPU                       GPU/TPU
Multi-agent       Limited                   Full support
Differentiable    No                        Yes

Key Takeaways:

  1. Data-driven realism: Using real driving data grounds simulation in actual human behavior

  2. Hardware acceleration: JAX + GPU enables research at scale (millions of interactions)

  3. Closed-loop evaluation: Essential for understanding real-world performance

  4. Sim agent challenge: Beware of overfitting to sim agents - always evaluate against log playback

  5. Metrics matter: The six metrics capture different failure modes; optimize for all

Getting Started:

pip install git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax


This learning guide was created to help researchers and practitioners deeply understand the Waymax simulator. For corrections or suggestions, please open an issue on the repository.