Waymax Deep Dive: A Comprehensive Learning Guide
Paper: Waymax: An Accelerated, Data-Driven Simulator for Large-Scale Autonomous Driving Research
Authors: Cole Gulino, Justin Fu, Wenjie Luo, et al. (Waymo Research)
Published: NeurIPS 2023
Table of Contents
- Executive Summary
- Architecture Deep Dive
- Key Concepts Explained
- Interactive Code Examples
- Mental Models
- Key Equations
- Hands-On Exercises
- Common Pitfalls
- Interview Questions
- Further Reading
Executive Summary
The Problem: Simulation Speed vs. Realism Trade-off
Autonomous driving research faces a fundamental tension:
Traditional Simulators:
+------------------+ +------------------+
| CARLA/SUMO | | Data-Driven |
| - Realistic | | Approaches |
| - Slow (10Hz) | vs. | - Fast |
| - Synthetic | | - Less realistic|
| - CPU-bound | | - Limited scope |
+------------------+ +------------------+
Real-world costs: Testing an autonomous vehicle policy requires millions of simulation steps. At 10Hz on CPU, evaluating a single policy across 44,000 scenarios takes hours. This bottleneck cripples research iteration speed.
What Waymax Solves
Waymax breaks this trade-off by combining:
- Hardware Acceleration: Built entirely in JAX, runs on GPUs/TPUs
- Real-World Data: Uses Waymo Open Motion Dataset (100K+ real driving scenarios)
- Multi-Agent Simulation: All vehicles can be controlled, not just ego
- Differentiability: Enables gradient-based optimization through the simulator
The Result:
- 1000+ Hz simulation speed (100x faster than CARLA)
- 44,000 scenarios in < 2 minutes on 8 V100 GPUs
- Real driving behaviors, not synthetic
Why This Matters
Think of it like this: If training a self-driving AI is like teaching someone to drive, traditional simulators are like using a driving textbook with static images. Waymax is like having access to 100,000 dashcam recordings that you can rewind, modify, and replay at superhuman speed.
Impact Areas:
- Reinforcement Learning: Train policies with millions of environment interactions
- Imitation Learning: Learn from real expert trajectories
- Motion Prediction: Benchmark prediction models in closed-loop settings
- Safety Testing: Evaluate edge cases at scale
Architecture Deep Dive
High-Level System Design
+============================================================================+
| WAYMAX ARCHITECTURE |
+============================================================================+
| |
| +-------------------------+ +-----------------------------------+ |
| | WAYMO OPEN MOTION | | JAX COMPUTATION GRAPH | |
| | DATASET | | +---------------------------+ | |
| | +------------------+ | | | | | |
| | | 100K+ Scenarios | |--->| | GPU/TPU Accelerated | | |
| | | 7.64M Objects | | | | Vectorized Operations | | |
| | | 10Hz, 9 seconds | | | | | | |
| | +------------------+ | | +---------------------------+ | |
| +-------------------------+ +-----------------------------------+ |
| | |
| v |
| +----------------------------------------------------------------------+ |
| | SIMULATION CORE | |
| | +----------------+ +----------------+ +----------------------+ | |
| | | DYNAMICS | | STATE | | METRICS | | |
| | | MODELS | | MANAGEMENT | | | | |
| | | - Delta | | - Position | | - Collision | | |
| | | - Bicycle | | - Velocity | | - Off-road | | |
| | | | | - Rotation | | - Log divergence | | |
| | +----------------+ +----------------+ +----------------------+ | |
| +----------------------------------------------------------------------+ |
| | |
| v |
| +----------------------------------------------------------------------+ |
| | AGENT LAYER | |
| | +--------------+ +--------------+ +--------------+ +--------------+ | |
| | | EXPERT | | BEHAVIOR | | WAYFORMER | | RL | | |
| | | (Log) | | CLONING | | PREDICTION | | (DQN) | | |
| | +--------------+ +--------------+ +--------------+ +--------------+ | |
| +----------------------------------------------------------------------+ |
| | |
| v |
| +----------------------------------------------------------------------+ |
| | RL FRAMEWORK ADAPTERS | |
| | +------------------+ +------------------+ | |
| | | dm-env | | Brax | | |
| | +------------------+ +------------------+ | |
| +----------------------------------------------------------------------+ |
+============================================================================+
Core Data Flow
SCENARIO LOADING SIMULATION LOOP
================ ================
TFRecord Files For each timestep:
|
v +---> Observe State
+-----------+ | |
| Dataloader| | v
+-----------+ | Select Action
| | |
v | v
+-----------+ | Apply Dynamics
| Scenario | | |
| State |---------------------> v
+-----------+ | Compute Metrics
| | |
| Contains: | v
| - Agent trajectories +----Update State
| - Road geometry
| - Traffic signals
| - Route information
State Representation
+===========================================================================+
| SIMULATOR STATE |
+===========================================================================+
| |
| DYNAMIC STATE (per agent, per timestep) |
| +---------------------------------------------------------------------+ |
| | Position (x, y) | Heading (theta) | Velocity (vx, vy) | |
| | Bounding Box | Valid Mask | Agent Type | |
| +---------------------------------------------------------------------+ |
| |
| STATIC STATE (per scenario) |
| +---------------------------------------------------------------------+ |
| | Road Graph | |
| | +---------------------------------------------------------------+ | |
| | | Lane Centers | Lane Boundaries | Road Edges | | |
| | | Stop Signs | Crosswalks | Speed Bumps | | |
| | +---------------------------------------------------------------+ | |
| | | |
| | Traffic Lights | |
| | +---------------------------------------------------------------+ | |
| | | State per timestep (red/yellow/green) | | |
| | | Associated lane IDs | | |
| | +---------------------------------------------------------------+ | |
| +---------------------------------------------------------------------+ |
| |
| ROUTE INFORMATION |
| +---------------------------------------------------------------------+ |
| | Logged trajectory union | Driveable futures | |
| +---------------------------------------------------------------------+ |
| |
+===========================================================================+
Multi-Agent Control Architecture
CONTROL HIERARCHY
=================
+-----------------------------------------+
| SCENARIO (N agents) |
+-----------------------------------------+
|
+----------------+----------------+
| | |
v v v
+-----------+ +-----------+ +-----------+
| Agent 1 | | Agent 2 | | Agent N |
| CONTROLLED| | SIMULATED | | SIMULATED |
+-----------+ +-----------+ +-----------+
| | |
v v v
+-----------+ +-----------+ +-----------+
| Policy | | IDM | | Log |
| Network | | Model | | Playback |
+-----------+ +-----------+ +-----------+
| | |
+----------------+----------------+
|
v
+------------------------+
| DYNAMICS MODEL |
| (Bicycle / Delta) |
+------------------------+
|
v
+------------------------+
| NEXT STATE |
+------------------------+
Key Concepts Explained
1. Data-Driven Simulation
The Traditional Approach (Synthetic): Imagine building a driving game from scratch. You design roads, program car AI to follow rules, add pedestrians with random behavior patterns. It's like creating a fictional city - it works, but it's not real.
The Waymax Approach (Data-Driven): Instead, Waymax starts with actual recorded driving data. It's like having security camera footage from thousands of intersections. Every scenario you simulate actually happened in the real world.
SYNTHETIC SIMULATION DATA-DRIVEN SIMULATION
================== ====================
Designer creates: Real world provides:
+---------------+ +---------------+
| Road Layout | | Actual Roads |
| (Imagined) | | (SF, PHX...) |
+---------------+ +---------------+
| |
v v
+---------------+ +---------------+
| NPC Behavior | | Real Drivers |
| (Rule-based) | | (Human data) |
+---------------+ +---------------+
| |
v v
+---------------+ +---------------+
| Edge Cases | | Natural Edge |
| (Must guess) | | Cases |
+---------------+ +---------------+
Why This Matters: When your policy succeeds in Waymax, it handled situations that actually occurred. When it fails, you can analyze real failure modes.
2. Closed-Loop vs. Open-Loop Evaluation
Analogy: Think of training a chess AI.
- Open-Loop: You show the AI 1 million chess positions and ask "What's the best next move?" The AI never sees the consequences of its choices.
- Closed-Loop: You let the AI play full games. Each move affects the board, and it must handle the situations its own moves create.
OPEN-LOOP EVALUATION CLOSED-LOOP EVALUATION
==================== ======================
Time t: [Observe] --> [Predict] Time t: [Observe] --> [Act]
| |
v v
Compare to Environment
ground truth Updates
| |
(No feedback) v
Time t+1: [Observe] --> [Act]
|
(Compounding effects)
The Problem Waymax Solves: Most motion prediction benchmarks are open-loop. But autonomous vehicles operate closed-loop! A prediction model might look great on benchmarks but fail in deployment because it never trained on its own error propagation.
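To see the difference concretely, here is a tiny self-contained toy (plain Python, not the Waymax API) where a policy with a small constant bias looks fine open-loop but accumulates error closed-loop:

def toy_dynamics(x, action):
    return x + action               # toy 1-D world: position moves by the action

def toy_policy(x):
    return 1.1                      # expert moves +1.0 per step; this policy has a +0.1 bias

logged = [float(t) for t in range(10)]          # expert positions: 0, 1, 2, ...

# Open-loop: predict each transition from the *logged* state; error stays ~0.1.
open_loop = [abs(toy_dynamics(x, toy_policy(x)) - x_next)
             for x, x_next in zip(logged[:-1], logged[1:])]

# Closed-loop: roll the policy on its own states; the 0.1 bias compounds.
x, closed_loop = logged[0], []
for x_next in logged[1:]:
    x = toy_dynamics(x, toy_policy(x))
    closed_loop.append(abs(x - x_next))

print([round(e, 1) for e in open_loop])    # [0.1, 0.1, 0.1, ...]
print([round(e, 1) for e in closed_loop])  # [0.1, 0.2, 0.3, ...]

The open-loop error stays at the single-step bias, while the closed-loop error grows every step because the policy keeps acting from states produced by its own earlier mistakes.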
3. The Bicycle Model
Intuition: Why is it called a "bicycle" model for cars?
Imagine looking at a car from above:
REAL CAR (4 wheels) BICYCLE MODEL (2 wheels)
=================== ========================
+--[F]----[F]--+ +--[F]--+
| | | |
| ^ | ==> | ^ |
| | | | | |
+--[R]----[R]--+ +--[R]--+
Front and rear Front and rear
axles with 2 as single points
wheels each
For motion planning, we can simplify a 4-wheel car to 2 wheels (front and rear) because:
- Left and right wheels on an axle move together
- The key constraint is the turning radius
The Key Insight: A bicycle can't move sideways (no-slip constraint). Neither can a car. This constraint makes the math tractable while preserving realistic motion.
4. Route Conditioning
The Problem: Without knowing where an agent wants to go, predicting behavior is ambiguous.
+----------+
| |
+---------->| TURN |
| | LEFT |
| +----------+
|
+-------+-------+
| |
| CURRENT |--------> STRAIGHT
| POSITION |
| |
+-------+-------+
|
| +----------+
| | |
+---------->| TURN |
| RIGHT |
+----------+
Without route info: All three are plausible!
With route info: Only one makes sense.
What Waymax Does: It extracts "routes" from logged trajectories - the path each vehicle actually took. During training, the model learns to follow these routes, dramatically reducing ambiguity.
Results:
- Without routes: 2.31% off-route rate
- With routes: ~1% off-route rate
5. Sim Agents: The Multi-Agent Challenge
The Dilemma: When training your ego vehicle policy, what should other cars do?
OPTION 1: Log Playback OPTION 2: Reactive Agents
===================== ======================
Other cars follow Other cars respond
their recorded paths to ego's actions
[Log Car] [IDM Car]
| |
| (ignores ego) | (reacts to ego)
v v
+--------+ +--------+
| Drives | | Slows |
| through| | down |
| ego! | | |
+--------+ +--------+
Problem: Unrealistic Problem: May be "too easy"
collisions, ego learns to exploit, domain gap
to avoid ghosts at evaluation
The Key Finding: The paper identifies a critical issue: RL agents trained against IDM (rule-based) sim agents learn to exploit their predictable behavior. When the same agents are evaluated against log playback, their collision rates are roughly 4x higher.
This is a form of overfitting to the simulator.
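A minimal toy sketch of this failure mode (plain Python, not the Waymax API; the numbers are made up for illustration): an ego policy that learned to cut in aggressively looks safe against a yielding, IDM-like agent but collides against log playback.

def simulate_cut_in(other_behavior, num_steps=40, dt=0.1):
    # Ego merges 10 m ahead of a faster car and keeps going, assuming the
    # other driver will yield.
    ego_pos, ego_speed = 10.0, 5.0       # ego after the cut-in, going 5 m/s
    other_pos, other_speed = 0.0, 15.0   # other car approaches at 15 m/s
    for _ in range(num_steps):
        gap = ego_pos - other_pos
        if gap <= 0.0:
            return "collision"
        if other_behavior == "reactive" and gap < 15.0:
            # IDM-like car-following: brakes hard when the gap closes
            other_speed = max(5.0, other_speed - 8.0 * dt)
        # "log" playback: the other car ignores the ego and keeps its recorded speed
        ego_pos += ego_speed * dt
        other_pos += other_speed * dt
    return "no collision"

print("policy vs reactive sim agent:", simulate_cut_in("reactive"))  # no collision
print("same policy vs log playback: ", simulate_cut_in("log"))       # collision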
Interactive Code Examples
Example 1: Basic Environment Setup
"""
Waymax Quick Start: Loading and Stepping Through Scenarios
==========================================================
This example shows the fundamental workflow for using Waymax.
"""
import jax
import jax.numpy as jnp
from waymax import config as waymax_config
from waymax import dataloader
from waymax import dynamics
from waymax import env as waymax_env
from waymax import datatypes
# ============================================================
# STEP 1: Load Data from Waymo Open Motion Dataset
# ============================================================
# Configure data loading (uses WOMD v1.1.0 training split)
data_config = waymax_config.WOD_1_1_0_TRAINING
# Create a scenario generator (lazy loading)
scenario_generator = dataloader.simulator_state_generator(data_config)
# Get one scenario
scenario = next(scenario_generator)
print(f"Scenario loaded!")
print(f" - Number of objects: {scenario.num_objects}")
print(f" - Number of timesteps: {scenario.remaining_timesteps}")
print(f" - Current timestep: {scenario.timestep}")
# ============================================================
# STEP 2: Configure the Environment
# ============================================================
# Choose a dynamics model
# Option 1: Bicycle model (realistic vehicle kinematics)
dynamics_model = dynamics.InvertibleBicycleModel()
# Option 2: Delta model (direct state control, less realistic)
# dynamics_model = dynamics.DeltaLocal()
# Create environment configuration
env_config = waymax_config.EnvironmentConfig(
# Which agents are controlled (vs. log playback)
controlled_object=waymax_config.ObjectType.SDC, # Self-driving car only
)
# Initialize the environment
environment = waymax_env.PlanningAgentEnvironment(
dynamics_model=dynamics_model,
config=env_config,
)
# ============================================================
# STEP 3: Run a Simulation Loop
# ============================================================
# Reset environment with our scenario
state = environment.reset(scenario)
# Storage for trajectory
trajectory = []
# Simulation loop
for step in range(80): # 8 seconds at 10Hz
# Get observation (what the agent sees)
observation = environment.observe(state)
    # Create an action (here: zero acceleration and zero steering, as a placeholder)
    # In practice, this would come from your policy network
action = datatypes.Action(
data=jnp.zeros((1, 2)), # [acceleration, steering]
valid=jnp.ones((1,), dtype=bool),
)
# Step the environment
state = environment.step(state, action)
# Store position for visualization
ego_pos = state.current_sim_trajectory.xy[0, state.timestep]
trajectory.append(ego_pos)
# Check if done
if state.is_done:
break
print(f"Simulation complete! Ran for {len(trajectory)} steps")
Example 2: Computing Metrics
"""
Waymax Metrics: Evaluating Agent Performance
============================================
This example demonstrates how to compute and interpret
the six core metrics in Waymax.
"""
import jax.numpy as jnp
from waymax import metrics
from waymax.metrics import comfort
from waymax.metrics import overlap
from waymax.metrics import roadgraph as roadgraph_metrics
from waymax.metrics import route as route_metrics
def evaluate_trajectory(state, env):
"""
Compute all Waymax metrics for a simulation run.
Args:
state: Final simulator state after rollout
env: The Waymax environment
Returns:
Dictionary of metric values
"""
results = {}
# --------------------------------------------------------
# METRIC 1: Log Divergence (Imitation Quality)
# --------------------------------------------------------
# How far did the agent deviate from the logged trajectory?
# Lower is better (means closer to human expert)
log_divergence = metrics.log_divergence(
state.sim_trajectory,
state.log_trajectory,
)
results['log_divergence_l2'] = float(jnp.mean(log_divergence.value))
print(f"Log Divergence (L2): {results['log_divergence_l2']:.3f} meters")
print(f" -> Agent deviated {results['log_divergence_l2']:.1f}m on average from expert")
# --------------------------------------------------------
# METRIC 2: Collision Detection
# --------------------------------------------------------
# Did the agent collide with any other object?
# Binary: True = collision occurred
collision_result = overlap.detect_collision(state)
results['collision'] = bool(jnp.any(collision_result.value))
if results['collision']:
print(f"Collision: YES - Agent hit another object!")
else:
print(f"Collision: NO - Agent avoided all collisions")
# --------------------------------------------------------
# METRIC 3: Off-Road Detection
# --------------------------------------------------------
# Did the agent drive outside driveable area?
# Binary: True = went off-road
offroad_result = roadgraph_metrics.offroad(state)
results['offroad'] = bool(jnp.any(offroad_result.value))
if results['offroad']:
print(f"Off-Road: YES - Agent left the road!")
else:
print(f"Off-Road: NO - Agent stayed on road")
# --------------------------------------------------------
# METRIC 4: Route Progress
# --------------------------------------------------------
# How much of the intended route did the agent complete?
# Higher is better (1.0 = completed full route)
# Note: Requires route information in the scenario
route_progress = route_metrics.route_progress(state)
results['route_progress'] = float(route_progress.value[0])
print(f"Route Progress: {results['route_progress']*100:.1f}%")
# --------------------------------------------------------
# METRIC 5: Kinematic Feasibility
# --------------------------------------------------------
# Were the agent's actions physically plausible?
# Checks: acceleration < 6 m/s^2, curvature < 0.3 m^-1
comfort_result = comfort.kinematic_infeasibility(state)
results['kinematic_infeasible'] = bool(jnp.any(comfort_result.value))
if results['kinematic_infeasible']:
print(f"Kinematic Feasibility: VIOLATED - Actions too aggressive!")
else:
print(f"Kinematic Feasibility: PASSED - Actions are realistic")
# --------------------------------------------------------
# METRIC 6: Wrong Way Detection
# --------------------------------------------------------
# Is the agent driving against traffic flow?
wrong_way = roadgraph_metrics.wrong_way(state)
results['wrong_way'] = bool(jnp.any(wrong_way.value))
return results
# Example usage in evaluation loop
def run_evaluation(policy, scenarios, env, num_scenarios=100):
    """
    Evaluate a policy across multiple scenarios.

    Assumes a user-provided `rollout_policy(policy, scenario)` helper that
    rolls the policy out in `env` and returns the final simulator state.
    """
all_results = []
for i, scenario in enumerate(scenarios):
if i >= num_scenarios:
break
# Run rollout with policy
state = rollout_policy(policy, scenario)
# Compute metrics
results = evaluate_trajectory(state, env)
all_results.append(results)
# Aggregate results
collision_rate = sum(r['collision'] for r in all_results) / len(all_results)
offroad_rate = sum(r['offroad'] for r in all_results) / len(all_results)
avg_divergence = sum(r['log_divergence_l2'] for r in all_results) / len(all_results)
print(f"\n=== EVALUATION SUMMARY ({num_scenarios} scenarios) ===")
print(f"Collision Rate: {collision_rate*100:.2f}%")
print(f"Off-Road Rate: {offroad_rate*100:.2f}%")
print(f"Average Log Divergence: {avg_divergence:.3f}m")
Example 3: Implementing a Simple Policy
"""
Waymax Policy: Building Your First Planning Agent
=================================================
This example shows how to implement a simple neural network
policy that can control a vehicle in Waymax.
"""
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple
class SimpleWaymaxPolicy(nn.Module):
"""
A simple MLP policy for Waymax.
Takes in observations and outputs actions for the bicycle model:
- Acceleration (m/s^2)
- Steering curvature (1/m)
"""
hidden_dims: Tuple[int, ...] = (256, 256, 128)
@nn.compact
def __call__(self, observation):
"""
Forward pass through the policy.
Args:
observation: Dictionary containing:
- ego_state: Current vehicle state [x, y, theta, vx, vy]
- roadgraph: Nearby road geometry
- other_agents: States of other vehicles
- route: Target waypoints to follow
Returns:
action: [acceleration, steering_curvature]
"""
# --------------------------------------------------------
# Step 1: Extract and flatten features
# --------------------------------------------------------
# Ego vehicle state (position, heading, velocity)
ego_state = observation['ego_state'] # Shape: [5]
# Route waypoints (next N points to follow)
route = observation['route'] # Shape: [N, 2]
route_flat = route.reshape(-1)
# Nearby agents (positions and velocities)
agents = observation['nearby_agents'] # Shape: [M, 4]
agents_flat = agents.reshape(-1)
# Concatenate all features
features = jnp.concatenate([
ego_state,
route_flat,
agents_flat,
])
# --------------------------------------------------------
# Step 2: MLP backbone
# --------------------------------------------------------
x = features
for dim in self.hidden_dims:
x = nn.Dense(dim)(x)
x = nn.relu(x)
x = nn.LayerNorm()(x)
# --------------------------------------------------------
# Step 3: Output heads
# --------------------------------------------------------
# Acceleration head: output in range [-4, 4] m/s^2
accel_logits = nn.Dense(1)(x)
acceleration = 4.0 * jnp.tanh(accel_logits)
# Steering head: output in range [-0.3, 0.3] 1/m
steer_logits = nn.Dense(1)(x)
steering = 0.3 * jnp.tanh(steer_logits)
# Combine into action
action = jnp.concatenate([acceleration, steering])
return action
def create_policy_step_fn(policy, params):
"""
Create a JIT-compiled policy step function.
This function can be used efficiently in the simulation loop.
"""
@jax.jit
def policy_step(observation):
return policy.apply(params, observation)
return policy_step
# ============================================================
# Training Loop Skeleton
# ============================================================
def train_behavior_cloning(
policy,
scenarios,
num_epochs: int = 100,
batch_size: int = 32,
learning_rate: float = 1e-4,
):
"""
Train policy using behavior cloning (imitation learning).
The agent learns to mimic the logged expert trajectories.
"""
import optax
# Initialize parameters
rng = jax.random.PRNGKey(0)
    dummy_obs = get_dummy_observation()  # your own helper: builds an observation dict with the right shapes
params = policy.init(rng, dummy_obs)
# Optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)
# Loss function: predict expert actions
def loss_fn(params, observations, expert_actions):
predicted_actions = policy.apply(params, observations)
# MSE loss between predicted and expert actions
loss = jnp.mean((predicted_actions - expert_actions) ** 2)
return loss
# JIT-compiled training step
@jax.jit
def train_step(params, opt_state, batch):
observations, expert_actions = batch
loss, grads = jax.value_and_grad(loss_fn)(
params, observations, expert_actions
)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
# Training loop
for epoch in range(num_epochs):
epoch_loss = 0.0
num_batches = 0
        for batch in create_batches(scenarios, batch_size):  # your own helper yielding (observations, expert_actions)
params, opt_state, loss = train_step(params, opt_state, batch)
epoch_loss += loss
num_batches += 1
avg_loss = epoch_loss / num_batches
print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
return params
Example 4: Vectorized Multi-Scenario Evaluation
"""
Waymax Vectorization: Evaluating Many Scenarios in Parallel
==========================================================
This is where Waymax really shines - batch processing scenarios
on GPU for massive speedup.
"""
import jax
import jax.numpy as jnp
from waymax import env as waymax_env
def create_vectorized_env(dynamics_model, env_config, batch_size):
"""
Create a vectorized environment for parallel scenario evaluation.
Instead of running scenarios one-by-one, we run `batch_size`
scenarios simultaneously on the GPU.
"""
# Create base environment
base_env = waymax_env.PlanningAgentEnvironment(
dynamics_model=dynamics_model,
config=env_config,
)
# Vectorize the step function across scenarios
# vmap = "vectorized map" - applies function to batched inputs
@jax.jit
def batched_step(states, actions):
"""
Step multiple environments in parallel.
Args:
states: Batched simulator states [batch_size, ...]
actions: Batched actions [batch_size, action_dim]
Returns:
next_states: Updated states [batch_size, ...]
"""
return jax.vmap(base_env.step)(states, actions)
@jax.jit
def batched_reset(scenarios):
"""Reset multiple scenarios in parallel."""
return jax.vmap(base_env.reset)(scenarios)
@jax.jit
def batched_observe(states):
"""Get observations for multiple states in parallel."""
return jax.vmap(base_env.observe)(states)
return batched_reset, batched_step, batched_observe
def benchmark_throughput(env, scenarios, num_steps=80):
"""
Benchmark simulation throughput.
This demonstrates the speed advantage of Waymax.
"""
import time
# Prepare batched scenarios
batched_reset, batched_step, batched_observe = env
    # Stack scenarios into a batch (stack_scenarios is your own helper, e.g.
    # a jax.tree_util.tree_map over jnp.stack across the scenario pytrees)
    scenario_batch = stack_scenarios(scenarios)
# Warm-up (JIT compilation happens here)
states = batched_reset(scenario_batch)
dummy_actions = jnp.zeros((len(scenarios), 2))
_ = batched_step(states, dummy_actions)
# Timed run
start_time = time.time()
states = batched_reset(scenario_batch)
for _ in range(num_steps):
actions = jnp.zeros((len(scenarios), 2)) # Dummy actions
states = batched_step(states, actions)
# Wait for GPU to finish
jax.block_until_ready(states)
elapsed = time.time() - start_time
# Calculate throughput
total_steps = len(scenarios) * num_steps
steps_per_second = total_steps / elapsed
print(f"=== BENCHMARK RESULTS ===")
print(f"Batch size: {len(scenarios)}")
print(f"Steps per scenario: {num_steps}")
print(f"Total simulation steps: {total_steps}")
print(f"Total time: {elapsed:.3f} seconds")
print(f"Throughput: {steps_per_second:.0f} steps/second")
print(f"Equivalent Hz: {steps_per_second/len(scenarios):.0f} Hz per scenario")
# Expected output (on V100 GPU):
# === BENCHMARK RESULTS ===
# Batch size: 16
# Steps per scenario: 80
# Total simulation steps: 1280
# Total time: 0.256 seconds
# Throughput: 5000 steps/second
# Equivalent Hz: 312 Hz per scenario
Mental Models
Mental Model 1: The Replay-and-Branch Metaphor
Think of Waymax like a VCR with superpowers:
TRADITIONAL VIDEO REPLAY WAYMAX "BRANCHING" REPLAY
====================== =========================
[Play] ---> [Watch] [Play] ---> [Watch]
|
|---> [What if I
| brake here?]
|
|---> [What if I
| accelerate?]
|
+---> [What if I
change lanes?]
Can only watch what Can intervene and see
happened alternative futures
Mental Model 2: The Driving Instructor Analogy
OPEN-LOOP (Traditional) CLOSED-LOOP (Waymax)
======================= ====================
Instructor: "Watch this video Instructor: "Now YOU drive
of expert driving" in this recorded scenario"
Student: [Watches passively] Student: [Takes the wheel]
|
v
Student makes a mistake
|
v
Scenario continues with
the consequences!
|
v
Student learns from
compounding errors
Mental Model 3: The Simulation Fidelity Spectrum
LOW FIDELITY HIGH FIDELITY
(Fast, Simple) (Slow, Realistic)
| |
v v
+--------+ +--------+ +--------+ +--------+ +--------+
| Point | | Simple | | | | Physics| | Full |
| Mass |---->| Bicycle|---->| WAYMAX |---->| Sim |---->| Digital|
| Model | | Model | | | | (CARLA)| | Twin |
+--------+ +--------+ +--------+ +--------+ +--------+
^
|
SWEET SPOT
- Real data
- Fast (GPU)
- Good enough physics
- Multi-agent
Mental Model 4: The Action Space Trade-off
DELTA ACTION SPACE BICYCLE ACTION SPACE
================== ====================
Direct control: Indirect control:
"Move to (x+dx, y+dy)" "Accelerate by a, steer by k"
+---+ +---+
| * | <-- New position | * | <-- Current position
+---+ +---+
^ |
| | Physics
| Teleport! | simulation
| v
+---+ +---+
| * | <-- Current position | * | <-- New position
+---+ +---+
PROS: PROS:
- Can simulate anything - Physically realistic
- Works for pedestrians - Transferable to real cars
- Simple - Natural constraints
CONS: CONS:
- Unrealistic motion - Only works for vehicles
- Hard to transfer - More complex
Mental Model 5: The Sim Agent Problem
THE MULTI-AGENT DILEMMA
=======================
You're training an autonomous car. But what should the OTHER cars do?
Option A: "Ghost Replay"
========================
Other cars follow their recorded paths, ignoring your car.
Your car: [Learns to drive through "ghosts"]
Result: Unrealistic collisions, weird behaviors
Your Car Other Car
| |
v v
+---+ +---+
| > | crash! | < | (keeps going, ignores you)
+---+ +---+
Option B: "Reactive Agents (IDM)"
=================================
Other cars react to your car using simple rules.
Your car: [Learns to exploit predictable behavior]
Result: Overfits to IDM, fails against real humans
Your Car IDM Car
| |
v v
+---+ +---+
| > | | | (always yields)
+---+ +---+
^
|
[IDM always brakes]
[Your car learns to be aggressive]
THE WAYMAX INSIGHT: You must evaluate against BOTH
=================================================
Train vs IDM, but test vs Log Playback to catch overfitting!
Key Equations
Equation 1: Bicycle Model Dynamics
The bicycle model approximates vehicle motion with two key control inputs:
- a: acceleration (m/s^2)
- k: steering curvature (1/m, i.e., inverse turning radius)
STATE UPDATE EQUATIONS
======================
Current state: (x, y, theta, v) -- position, heading, speed
Action: (a, k) -- acceleration, curvature
Time step: dt = 0.1s -- 10 Hz simulation
New speed:
v' = v + a * dt
Intuition: Speed increases by acceleration times time.
If a = 2 m/s^2 and dt = 0.1s, speed increases by 0.2 m/s
New heading:
theta' = theta + v * k * dt
Intuition: Heading change = arc length / radius
Arc length = v * dt (distance traveled)
Curvature k = 1/radius
So: d(theta) = v * dt * k
New position:
x' = x + v * cos(theta) * dt
y' = y + v * sin(theta) * dt
Intuition: Project velocity onto x and y axes
Worked Example:
Starting state: x=0, y=0, theta=0 (facing right), v=10 m/s
Action: a=0 (constant speed), k=0.1 (turning left)
dt = 0.1s
Step 1: v' = 10 + 0 * 0.1 = 10 m/s
Step 2: theta' = 0 + 10 * 0.1 * 0.1 = 0.1 rad (5.7 degrees)
Step 3: x' = 0 + 10 * cos(0) * 0.1 = 1.0 m
y' = 0 + 10 * sin(0) * 0.1 = 0.0 m
After one step: (1.0, 0.0, 0.1 rad, 10 m/s)
Moved forward 1m, started turning left
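The worked example can be checked with a few lines of Python. This is a minimal sketch following the update order written above; Waymax's InvertibleBicycleModel may differ in integration details:

import math

def bicycle_step(x, y, theta, v, a, k, dt=0.1):
    """One step of the kinematic bicycle model described above."""
    v_new = v + a * dt
    theta_new = theta + v * k * dt
    x_new = x + v * math.cos(theta) * dt
    y_new = y + v * math.sin(theta) * dt
    return x_new, y_new, theta_new, v_new

print(bicycle_step(x=0.0, y=0.0, theta=0.0, v=10.0, a=0.0, k=0.1))
# (1.0, 0.0, 0.1, 10.0) -- matches the worked example above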
Equation 2: Delta Action Space
For non-vehicle objects (pedestrians, cyclists) or when you want direct control:
DELTA UPDATES
=============
Action: (dx, dy, d_theta) -- change in position and heading
x' = x + dx
y' = y + dy
theta' = theta + d_theta
Intuition: Direct "teleportation" - specify exactly where to go.
No physics constraints, so any motion is possible.
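As a sketch (not the Waymax delta dynamics API, which also handles batching, validity masks, and local-frame variants), the delta update is a one-liner:

def delta_step(x, y, theta, dx, dy, dtheta):
    # Direct state update: no physics, any displacement is allowed.
    return x + dx, y + dy, theta + dtheta

print(delta_step(0.0, 0.0, 0.0, dx=0.5, dy=-0.2, dtheta=0.05))  # (0.5, -0.2, 0.05)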
Equation 3: Log Divergence Metric
Measures how far the simulated trajectory deviates from the logged expert:
LOG DIVERGENCE (L2)
===================
Given:
- Simulated trajectory: [(x_1, y_1), (x_2, y_2), ..., (x_T, y_T)]
- Logged trajectory: [(x'_1, y'_1), (x'_2, y'_2), ..., (x'_T, y'_T)]
Log Divergence = (1/T) * SUM_t sqrt((x_t - x'_t)^2 + (y_t - y'_t)^2)
Intuition: Average Euclidean distance between simulated and logged
positions across all timesteps.
Example:
Simulated: [(0,0), (1,0), (2,0)]
Logged: [(0,0), (1,1), (2,2)]
Distances: [0, 1, 2]
Log Divergence = (0 + 1 + 2) / 3 = 1.0 meter
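The same computation in a few lines of jax.numpy, including a validity mask as discussed in the pitfalls section (a sketch; Waymax's log-divergence metric is the reference implementation):

import jax.numpy as jnp

sim_xy = jnp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
log_xy = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
valid  = jnp.array([True, True, True])

dist = jnp.linalg.norm(sim_xy - log_xy, axis=-1)          # per-timestep L2 distance
log_divergence = jnp.sum(dist * valid) / jnp.sum(valid)   # average over valid steps
print(log_divergence)  # 1.0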
Equation 4: Collision Detection
Bounding box overlap check for all pairs of objects:
COLLISION DETECTION
===================
Each object has a bounding box defined by:
- Center: (cx, cy)
- Dimensions: (length, width)
- Heading: theta
Collision occurs when bounding boxes overlap.
Simplified 2D box overlap (axis-aligned):
Box A: [x_min_A, x_max_A] x [y_min_A, y_max_A]
Box B: [x_min_B, x_max_B] x [y_min_B, y_max_B]
Overlap = (x_min_A < x_max_B) AND (x_max_A > x_min_B) AND
(y_min_A < y_max_B) AND (y_max_A > y_min_B)
For rotated boxes, use Separating Axis Theorem (SAT).
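A minimal version of the axis-aligned check above (Waymax's collision metric handles rotated boxes, so treat this only as the simplified case):

def aabb_overlap(a_min, a_max, b_min, b_max):
    """a_min/a_max/b_min/b_max are (x, y) corner tuples of the two boxes."""
    return (a_min[0] < b_max[0] and a_max[0] > b_min[0] and
            a_min[1] < b_max[1] and a_max[1] > b_min[1])

print(aabb_overlap((0, 0), (4, 2), (3, 1), (7, 3)))   # True  -- boxes overlap
print(aabb_overlap((0, 0), (4, 2), (5, 0), (9, 2)))   # False -- separated in x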
Equation 5: Kinematic Feasibility Constraints
Ensures actions are physically plausible:
KINEMATIC CONSTRAINTS
=====================
Acceleration limit: |a| <= 6 m/s^2
- Typical car: 0-60 mph in ~4.5s requires ~6 m/s^2
- Hard braking: typically 8-10 m/s^2 max
Curvature limit: |k| <= 0.3 m^-1
- Minimum turning radius: r = 1/k = 3.3 meters
- Typical car can turn tighter at low speeds
Infeasibility metric:
kinematic_infeasible = (|a| > 6) OR (|k| > 0.3)
If violated: Agent is executing physically impossible maneuvers
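A sketch of the same check over a batch of actions, using the thresholds listed above (the exact thresholds and any smoothing in Waymax's comfort metric may differ):

import jax.numpy as jnp

def kinematic_infeasible(accel, curvature, a_limit=6.0, k_limit=0.3):
    """True for any action whose magnitude exceeds the limits."""
    return (jnp.abs(accel) > a_limit) | (jnp.abs(curvature) > k_limit)

accels     = jnp.array([2.0, -7.5, 4.0])
curvatures = jnp.array([0.1,  0.0, 0.4])
print(kinematic_infeasible(accels, curvatures))  # [False  True  True]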
Equation 6: Route Progress
Measures completion along intended path:
ROUTE PROGRESS
==============
Given:
- Route: sequence of waypoints defining intended path
- Current position: (x, y)
Progress = (distance traveled along route) / (total route length)
Calculated via projection onto route polyline:
1. Find closest point on route to current position
2. Compute distance from route start to that point
3. Divide by total route length
Range: [0, 1] where 1.0 = completed full route
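The three steps above translate directly into a polyline-projection routine. This is a self-contained sketch, not Waymax's route metric:

import jax.numpy as jnp

def route_progress(route_xy, position):
    """route_xy: [N, 2] waypoints; position: [2] current (x, y)."""
    starts, ends = route_xy[:-1], route_xy[1:]
    seg_vec = ends - starts                                   # [N-1, 2]
    seg_len = jnp.linalg.norm(seg_vec, axis=-1)               # segment lengths
    # Step 1: project the position onto every segment, clamped to [0, 1].
    t = jnp.einsum('nd,nd->n', position - starts, seg_vec) / (seg_len ** 2 + 1e-8)
    t = jnp.clip(t, 0.0, 1.0)
    closest = starts + t[:, None] * seg_vec                   # closest point per segment
    best = jnp.argmin(jnp.linalg.norm(position - closest, axis=-1))
    # Step 2: arc length from the route start to the projected point.
    cum_len = jnp.concatenate([jnp.zeros(1), jnp.cumsum(seg_len)])
    travelled = cum_len[best] + t[best] * seg_len[best]
    # Step 3: normalize by total route length.
    return travelled / cum_len[-1]

route = jnp.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]])    # 20 m L-shaped route
print(route_progress(route, jnp.array([10.0, 5.0])))          # 0.75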
Hands-On Exercises
Exercise 1: Data Exploration (Beginner)
Goal: Understand the structure of Waymo Open Motion Dataset scenarios.
"""
Exercise 1: Exploring a Waymax Scenario
=======================================
Tasks:
1. Load a scenario and print basic statistics
2. Identify the ego vehicle (SDC)
3. Count different object types
4. Visualize the road graph
"""
from waymax import config, dataloader
import jax.numpy as jnp
# TODO: Load a scenario
scenarios = dataloader.simulator_state_generator(config.WOD_1_1_0_TRAINING)
scenario = next(scenarios)
# Task 1: Print basic statistics
# HINT: scenario.num_objects, scenario.remaining_timesteps
print("=== SCENARIO STATISTICS ===")
# YOUR CODE HERE
# Expected output: number of objects, timesteps, etc.
# Task 2: Find the ego vehicle (SDC = Self-Driving Car)
# HINT: scenario.object_metadata contains object types and an is_sdc flag
# that marks which object is the self-driving car
# YOUR CODE HERE
# Task 3: Count object types
# Types: 1=Vehicle, 2=Pedestrian, 3=Cyclist
# HINT: Use jnp.unique with return_counts=True
# YOUR CODE HERE
# Task 4: Examine road graph structure
# HINT: scenario.roadgraph_points contains road geometry
# YOUR CODE HERE
Solution:
# Task 1
print(f"Number of objects: {scenario.num_objects}")
print(f"Number of timesteps: {scenario.remaining_timesteps}")
print(f"Current timestep: {scenario.timestep}")
# Task 2
sdc_idx = int(jnp.argmax(scenario.object_metadata.is_sdc))  # is_sdc flags the SDC
sdc_trajectory = scenario.log_trajectory.xy[sdc_idx]
print(f"SDC trajectory shape: {sdc_trajectory.shape}")
# Task 3
object_types = scenario.object_metadata.object_types
unique, counts = jnp.unique(object_types, return_counts=True)
type_names = {1: 'Vehicle', 2: 'Pedestrian', 3: 'Cyclist'}
for t, c in zip(unique.tolist(), counts.tolist()):
if t in type_names:
print(f"{type_names[t]}: {c}")
# Task 4
roadgraph = scenario.roadgraph_points
print(f"Road graph points: {roadgraph.xy.shape}")
Exercise 2: Implementing IDM (Intermediate)
Goal: Implement the Intelligent Driver Model for sim agents.
"""
Exercise 2: Intelligent Driver Model Implementation
==================================================
The IDM is a car-following model that determines acceleration
based on the current speed and distance to the leading vehicle.
IDM Equation:
a = a_max * [1 - (v/v_desired)^4 - (s*/s)^2]
Where:
- a_max: maximum acceleration (e.g., 2 m/s^2)
- v: current speed
- v_desired: desired speed (e.g., 15 m/s for urban)
- s: current gap to leading vehicle
- s*: desired gap = s_0 + v*T + (v*dv)/(2*sqrt(a_max*b))
- s_0: minimum gap (e.g., 2 m)
- T: safe time headway (e.g., 1.5 s)
- b: comfortable deceleration (e.g., 3 m/s^2)
- dv: velocity difference (ego - leading)
"""
import jax.numpy as jnp
def idm_acceleration(
ego_speed: float,
leading_speed: float,
gap: float,
v_desired: float = 15.0,
a_max: float = 2.0,
b: float = 3.0,
s_0: float = 2.0,
T: float = 1.5,
) -> float:
"""
Compute IDM acceleration.
Args:
ego_speed: Current speed of ego vehicle (m/s)
leading_speed: Speed of leading vehicle (m/s)
gap: Distance to leading vehicle (m)
v_desired: Desired cruising speed (m/s)
a_max: Maximum acceleration (m/s^2)
b: Comfortable deceleration (m/s^2)
s_0: Minimum gap (m)
T: Safe time headway (s)
Returns:
Acceleration command (m/s^2)
"""
# Task: Implement the IDM equation
# YOUR CODE HERE
# Step 1: Compute velocity difference
dv = None # YOUR CODE
# Step 2: Compute desired gap s*
s_star = None # YOUR CODE
# Step 3: Compute acceleration using IDM equation
acceleration = None # YOUR CODE
return acceleration
# Test your implementation
def test_idm():
# Test case 1: Free road (large gap)
a1 = idm_acceleration(ego_speed=10.0, leading_speed=15.0, gap=100.0)
print(f"Free road: a = {a1:.2f} m/s^2 (should be positive, accelerating)")
# Test case 2: Close following (small gap)
a2 = idm_acceleration(ego_speed=15.0, leading_speed=15.0, gap=10.0)
print(f"Close following: a = {a2:.2f} m/s^2 (should be near zero)")
# Test case 3: Emergency braking (very small gap, leading slower)
a3 = idm_acceleration(ego_speed=15.0, leading_speed=5.0, gap=5.0)
print(f"Emergency: a = {a3:.2f} m/s^2 (should be negative, braking)")
test_idm()
Solution:
def idm_acceleration(
ego_speed: float,
leading_speed: float,
gap: float,
v_desired: float = 15.0,
a_max: float = 2.0,
b: float = 3.0,
s_0: float = 2.0,
T: float = 1.5,
) -> float:
# Step 1: Velocity difference (positive if ego faster than leading)
dv = ego_speed - leading_speed
# Step 2: Desired gap s*
s_star = s_0 + ego_speed * T + (ego_speed * dv) / (2 * jnp.sqrt(a_max * b))
s_star = jnp.maximum(s_star, s_0) # Ensure minimum gap
# Step 3: IDM acceleration
# a = a_max * [1 - (v/v_desired)^4 - (s*/s)^2]
free_road_term = (ego_speed / v_desired) ** 4
interaction_term = (s_star / gap) ** 2
acceleration = a_max * (1 - free_road_term - interaction_term)
# Clamp to reasonable range
acceleration = jnp.clip(acceleration, -b, a_max)
return acceleration
Exercise 3: Behavior Cloning (Advanced)
Goal: Train a simple behavior cloning agent.
"""
Exercise 3: Training a Behavior Cloning Agent
============================================
Implement a training loop that learns to imitate expert trajectories.
"""
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Dict, Tuple
class BCAgent(nn.Module):
"""Simple behavior cloning agent."""
@nn.compact
def __call__(self, observation: Dict) -> jnp.ndarray:
"""
Predict action from observation.
Args:
observation: Dict with keys:
- 'ego_state': [x, y, theta, vx, vy]
- 'goal': [goal_x, goal_y]
Returns:
action: [acceleration, steering_curvature]
"""
# Concatenate inputs
ego = observation['ego_state']
goal = observation['goal']
x = jnp.concatenate([ego, goal])
# MLP
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(64)(x)
x = nn.relu(x)
# Output heads
acceleration = nn.Dense(1)(x)
steering = nn.Dense(1)(x)
return jnp.concatenate([acceleration, steering])
def extract_expert_actions(scenario, dynamics_model):
"""
Extract expert actions from logged trajectories.
This uses inverse dynamics: given the trajectory,
compute what actions would produce it.
Args:
scenario: Waymax scenario with logged trajectories
dynamics_model: Dynamics model (bicycle or delta)
Returns:
observations: List of observation dicts
actions: Array of expert actions [T, 2]
"""
# TODO: Implement expert action extraction
# HINT: Use dynamics_model.inverse() to get actions from trajectory
pass
def create_train_step(model, optimizer):
"""
Create a JIT-compiled training step.
Args:
model: BCAgent instance
optimizer: Optax optimizer
Returns:
train_step function
"""
def loss_fn(params, observations, expert_actions):
# TODO: Implement loss function
# Predict actions and compute MSE loss
pass
@jax.jit
def train_step(params, opt_state, batch):
observations, expert_actions = batch
loss, grads = jax.value_and_grad(loss_fn)(
params, observations, expert_actions
)
updates, new_opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state, loss
return train_step
# Main training loop
def train_bc_agent(scenarios, num_epochs=10, batch_size=32):
"""
Train a behavior cloning agent.
Tasks:
1. Initialize model and optimizer
2. Extract expert demonstrations
3. Run training loop
4. Evaluate on held-out scenarios
"""
# Initialize
model = BCAgent()
rng = jax.random.PRNGKey(0)
# TODO: Initialize parameters
# HINT: model.init(rng, dummy_observation)
# TODO: Create optimizer (try Adam with lr=1e-3)
# TODO: Training loop
pass
Exercise 4: Multi-Agent Rollout (Expert)
Goal: Implement coordinated multi-agent simulation with different agent types.
"""
Exercise 4: Multi-Agent Coordination Challenge
=============================================
Implement a simulation where:
- Agent 0 (SDC): Your learned policy
- Agent 1-5: IDM-controlled vehicles
- Agent 6+: Log playback
Evaluate collision rates and route progress.
"""
import jax
import jax.numpy as jnp
from waymax import env, config, dynamics, datatypes
def create_mixed_agent_controller(
policy_fn,
idm_params,
num_controlled: int = 6,
):
"""
Create a controller that uses different strategies for different agents.
Args:
policy_fn: Neural network policy for SDC
idm_params: Parameters for IDM agents
num_controlled: Number of agents (beyond SDC) to control with IDM
Returns:
Controller function that returns actions for all agents
"""
def controller(state, observations):
"""
Generate actions for all agents.
Args:
state: Current simulator state
observations: Observations for all agents
Returns:
actions: Actions for all controllable agents
"""
num_agents = state.num_objects
actions = []
for agent_idx in range(num_agents):
if agent_idx == 0:
# SDC: Use learned policy
action = policy_fn(observations[agent_idx])
elif agent_idx < num_controlled:
# IDM agents: Use car-following model
action = idm_action(state, agent_idx, idm_params)
else:
# Log playback: Extract action from logged trajectory
action = extract_log_action(state, agent_idx)
actions.append(action)
return jnp.stack(actions)
return controller
def evaluate_multi_agent(
controller,
scenarios,
num_scenarios: int = 100,
max_steps: int = 80,
):
"""
Evaluate multi-agent controller across scenarios.
Returns:
Dictionary of aggregated metrics
"""
results = {
'sdc_collision_rate': 0.0,
'sdc_offroad_rate': 0.0,
'sdc_route_progress': 0.0,
'idm_collision_rate': 0.0,
}
# TODO: Implement evaluation loop
# 1. Reset environment for each scenario
# 2. Run simulation with mixed controller
# 3. Compute metrics
# 4. Aggregate results
pass
Common Pitfalls
Pitfall 1: Ignoring the Sim Agent Gap
The Mistake: Training against IDM sim agents and assuming good performance will transfer to real driving.
What Goes Wrong:
Training: Evaluation:
========= ===========
Your Policy vs IDM Your Policy vs Log Playback
IDM: [Always yields] Log: [Human driver behavior]
| |
v v
Your Policy: [Learns to be Your Policy: [Crashes!]
aggressive, cut off IDM cars] Human didn't yield like IDM
The Fix:
- Always evaluate against log playback
- Train with mixture of IDM and log agents
- Report metrics for both evaluation settings
Pitfall 2: Forgetting to Account for Valid Masks
The Mistake: Processing all agents/timesteps without checking validity masks.
What Goes Wrong:
# WRONG: Processes invalid (padded) data
all_positions = scenario.log_trajectory.xy # Includes invalid entries!
mean_position = jnp.mean(all_positions) # Garbage result
# RIGHT: Use validity masks (note the broadcast over the trailing (x, y) axis)
valid_mask = scenario.log_trajectory.valid
positions = scenario.log_trajectory.xy
mean_position = jnp.sum(positions * valid_mask[..., None], axis=(0, 1)) / jnp.sum(valid_mask)
The Fix: Always check and use validity masks for:
- Agent trajectories
- Road graph points
- Traffic light states
Pitfall 3: Coordinate Frame Confusion
The Mistake: Mixing global and local coordinate frames.
What Goes Wrong:
Global Frame (Map): Local Frame (Ego):
================= ==================
N Forward
^ ^
| |
W <---+---> E Left <--+---> Right
| |
S Backward
Agent heading in global frame: 45 degrees (NE)
Agent sees obstacle at global (100, 100)
WRONG: Feed global coords to policy
Policy thinks obstacle is to the northeast
RIGHT: Transform to ego-centric coords
Obstacle is "ahead and to the right"
The Fix:
def global_to_local(global_pos, ego_pos, ego_heading):
"""Transform global coordinates to ego-centric frame."""
# Translate
relative = global_pos - ego_pos
# Rotate
cos_h = jnp.cos(-ego_heading)
sin_h = jnp.sin(-ego_heading)
local_x = relative[0] * cos_h - relative[1] * sin_h
local_y = relative[0] * sin_h + relative[1] * cos_h
return jnp.array([local_x, local_y])
Pitfall 4: Misunderstanding Timestep Indexing
The Mistake: Confusing which timestep you're at vs. which timestep you're predicting.
What Goes Wrong:
Timeline:
=========
t=0 t=1 t=2 t=3 ... t=10 (current) t=11 (next)
[log] [log] [log] [log] ... [sim state] [prediction]
WRONG: state.timestep gives you t=10
You compute action for t=11
But you read observation from t=9 by mistake
RIGHT:
observation = state at t=10
action = for transition t=10 -> t=11
next_state = state at t=11
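A toy sketch of the correct ordering (plain Python stand-ins, not the Waymax state API):

class ToyState:
    def __init__(self, timestep):
        self.timestep = timestep

def observe(state):
    return f"obs@t={state.timestep}"

def step(state, action):
    print(f"applying {action} for transition t={state.timestep} -> t={state.timestep + 1}")
    return ToyState(state.timestep + 1)

state = ToyState(timestep=10)
obs = observe(state)           # observation for t=10, the current sim state
action = f"action({obs})"      # the policy conditions on t=10, never on t=9
state = step(state, action)    # state.timestep is now 11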
Pitfall 5: Ignoring Scenario Diversity
The Mistake: Testing only on "easy" scenarios (highway driving, no interactions).
What Goes Wrong:
Easy Scenarios: Hard Scenarios:
=============== ================
Highway, light traffic Dense urban
No turns Complex intersections
No pedestrians Pedestrians crossing
No construction Construction zones
Policy trained on easy: 99% success
Same policy on hard: 45% success
The Fix:
- Use scenario tags to ensure diversity
- Track per-scenario-type metrics
- Weight training towards failure cases
Pitfall 6: JAX Compilation Overhead
The Mistake: Re-compiling JAX functions inside loops.
What Goes Wrong:
# WRONG: Recompiles every iteration
for scenario in scenarios:
@jax.jit
def step(state, action): # New function object each time!
return env.step(state, action)
for t in range(80):
state = step(state, action) # Compiles on first call each loop!
# RIGHT: Compile once, reuse
@jax.jit
def step(state, action):
return env.step(state, action)
for scenario in scenarios:
for t in range(80):
state = step(state, action) # Uses cached compilation
Interview Questions
Conceptual Questions
Q1: Why is closed-loop evaluation important for motion prediction models?
Expected Answer: Open-loop evaluation only measures single-step prediction accuracy. In deployment, prediction errors compound - a small error at t=1 affects the state at t=2, which affects predictions at t=2, creating a cascade. Closed-loop evaluation captures this error propagation by feeding predictions back as input, revealing how models behave under distribution shift from their own errors.
Q2: Explain the trade-off between log playback and reactive (IDM) sim agents.
Expected Answer:
- Log playback is realistic (real human behavior) but non-reactive (ignores ego actions, can cause "ghost" collisions)
- IDM agents react to ego vehicle but follow simple rules that can be exploited
- The paper shows RL agents trained against IDM have 4x higher collision rates when evaluated against log playback
- Best practice: train against reactive agents but evaluate against both to detect overfitting
Q3: Why does Waymax use JAX instead of PyTorch/TensorFlow?
Expected Answer:
- JAX's functional design enables efficient vectorization (vmap) across scenarios
- XLA compilation optimizes the entire simulation graph for GPU/TPU
- In-graph compilation means no Python interpreter overhead during rollouts
- Differentiability enables gradient-based optimization through the simulator
- Result: 1000+ Hz single-scenario speed vs. ~10 Hz for CPU-based simulators (see the jit/vmap sketch below)
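A minimal, self-contained illustration of the two JAX features this answer leans on, jit and vmap (toy arrays, not Waymax data structures):

import jax
import jax.numpy as jnp

def step(state, action):
    # Stand-in for one simulator step: state and action are small arrays.
    return state + action

batched_step = jax.jit(jax.vmap(step))   # one compiled kernel for the whole batch

states  = jnp.zeros((1024, 2))           # 1024 "scenarios", 2-D toy state
actions = jnp.ones((1024, 2)) * 0.1
states = batched_step(states, actions)   # all 1024 advance in one device call
print(states.shape)                      # (1024, 2)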
Technical Questions
Q4: Derive the bicycle model update equations for heading angle.
Expected Answer:
For a bicycle model with:
- Speed v
- Steering curvature k = 1/r (inverse turning radius)
- Time step dt
The vehicle traces an arc with:
- Arc length s = v * dt
- Radius r = 1/k
Angular change = arc length / radius
= s / r
= v * dt * k
Therefore: theta_new = theta_old + v * k * dt
Q5: How would you detect if a trained agent is "cheating" by exploiting simulator artifacts?
Expected Answer:
- Compare performance on IDM vs. log playback sim agents (large gap = exploitation)
- Check for impossible maneuvers (kinematic infeasibility metric)
- Analyze behavior patterns: does agent always take the same action regardless of context?
- Test on held-out scenario types (new cities, new road layouts)
- Compare to expert (log divergence) - too good might mean exploitation
Q6: Design a reward function for training a lane-change policy.
Expected Answer:
def lane_change_reward(state, action, next_state):
reward = 0.0
# Positive: Progress toward destination
reward += 0.1 * route_progress_delta(state, next_state)
# Positive: Successful lane change when needed
if intended_lane_change and in_target_lane(next_state):
reward += 1.0
# Negative: Collisions (terminal)
if collision(next_state):
reward -= 10.0
# Negative: Off-road (terminal)
if offroad(next_state):
reward -= 5.0
# Negative: Uncomfortable driving
reward -= 0.01 * jerk_penalty(action)
reward -= 0.01 * lateral_acceleration_penalty(state, action)
# Negative: Unsafe distance to other vehicles
reward -= 0.1 * unsafe_distance_penalty(next_state)
return reward
System Design Questions
Q7: How would you scale Waymax evaluation to 1 million scenarios?
Expected Answer:
- Use data parallelism: distribute scenarios across multiple GPUs/TPUs
- Batch scenarios with similar lengths together (minimize padding)
- Use async data loading to overlap I/O with computation
- Implement checkpointing to handle failures gracefully
- Estimated time: 1M scenarios * 80 steps / (5000 steps/sec per GPU) / 8 GPUs ~ 30 minutes
Q8: Design a curriculum for training an RL agent in Waymax.
Expected Answer:
Stage 1: Lane Following (Easy)
- Highway scenarios, light traffic
- Reward: route progress
- Success criterion: >95% route completion
Stage 2: Car Following (Medium)
- Add leading vehicles
- Introduce IDM agents
- Reward: route progress + safe following distance
Stage 3: Dense Traffic (Hard)
- Urban scenarios, heavy traffic
- Reactive + log playback agents
- Reward: route progress + collision avoidance
Stage 4: Full Interaction (Expert)
- All scenario types
- Pedestrians, cyclists
- Complex intersections
Further Reading
Core Papers
- Waymo Open Motion Dataset
- "Large Scale Interactive Motion Forecasting for Autonomous Driving" (Ettinger et al., 2021)
- The dataset that powers Waymax
- arXiv:2104.10133
- Wayformer
- "Wayformer: Motion Forecasting via Simple & Efficient Attention Networks" (Nayakanti et al., 2022)
- Transformer architecture used in Waymax baselines
- arXiv:2207.05844
- Symphony
- "Symphony: Learning Realistic and Diverse Agents for Autonomous Driving Simulation" (Igl et al., 2022)
- Multi-agent simulation with learned sim agents
- arXiv:2205.03195
Related Simulators
- CARLA
- High-fidelity 3D simulator
- Slower but more visually realistic
- carla.org
- nuPlan
- Motional's planning benchmark
- Similar data-driven approach
- nuplan.org
- InterSim
- "InterSim: Interactive Traffic Simulation via Explicit Relation Modeling" (Sun et al., 2022)
- Focus on multi-agent interactions
- arXiv:2210.08235
Foundational Concepts
- Intelligent Driver Model (IDM)
- "Congested Traffic States in Empirical Observations and Microscopic Simulations" (Treiber et al., 2000)
- The car-following model used for sim agents
- Physical Review E, 62(2), 1805
- Behavior Cloning
- "A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning" (Ross et al., 2011)
- DAgger algorithm for addressing distribution shift
- arXiv:1011.0686
- Sim-to-Real Transfer
- "Domain Randomization for Transferring Deep Neural Networks from Simulation to the Real World" (Tobin et al., 2017)
- Techniques for bridging the simulation-reality gap
- arXiv:1703.06907
JAX Resources
- JAX Documentation
- Official docs: jax.readthedocs.io
- Key concepts: vmap, jit, grad
- Brax
- "Brax: A Differentiable Physics Engine for Large Scale Rigid Body Simulation" (Freeman et al., 2021)
- JAX-based physics simulator, Waymax uses similar design patterns
- arXiv:2106.13281
Summary
Waymax represents a paradigm shift in autonomous driving simulation:
| Aspect | Traditional Simulators | Waymax |
|---|---|---|
| Speed | 10-100 Hz | 1000+ Hz |
| Data | Synthetic | Real-world |
| Hardware | CPU | GPU/TPU |
| Multi-agent | Limited | Full support |
| Differentiable | No | Yes |
Key Takeaways:
- Data-driven realism: Using real driving data grounds simulation in actual human behavior
- Hardware acceleration: JAX + GPU enables research at scale (millions of interactions)
- Closed-loop evaluation: Essential for understanding real-world performance
- Sim agent challenge: Beware of overfitting to sim agents - always evaluate against log playback
- Metrics matter: The six metrics capture different failure modes; optimize for all
Getting Started:
pip install git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax
Resources:
- Paper: arxiv.org/abs/2310.08710
- Code: github.com/waymo-research/waymax
- Dataset: waymo.com/open
This learning guide was created to help researchers and practitioners deeply understand the Waymax simulator. For corrections or suggestions, please open an issue on the repository.