V-Max Deep Dive: Reinforcement Learning for Autonomous Driving on Waymax
Paper: V-Max: Making Reinforcement Learning Work for Autonomous Driving
Authors: Valeo AI Research Team
ArXiv: 2503.08388
Code: github.com/valeoai/v-max
Table of Contents
- Executive Summary
- Architecture Deep Dive
- ScenarioMax: Multi-Dataset Support
- Observation Functions
- Reward Design
- Encoder Architectures
- RL Training Pipeline
- Interactive Code Examples
- Benchmarks Analysis
- Hands-On Exercises
- Interview Questions
1. Executive Summary
The Problem: RL for Autonomous Driving is Hard
Reinforcement learning has shown remarkable success in games and robotics, but autonomous driving presents unique challenges:
- Distribution shift: Simulation differs from real-world driving
- Sparse rewards: Crashes are rare but catastrophic
- High-dimensional observations: Processing LiDAR, cameras, and maps
- Safety constraints: Cannot explore freely like in games
Why V-Max Extends Waymax
Waymax (Google DeepMind) provides a JAX-accelerated simulation environment but lacks:
+------------------+ V-Max Adds: +----------------------+
| Waymax | -----------------> | Complete RL Stack |
+------------------+ +----------------------+
| - Simulation | | - Observation Funcs |
| - WOMD Dataset | | - Reward Hierarchy |
| - Basic Metrics | | - Transformer Enc. |
| - JAX Accel. | | - SAC/PPO Pipelines |
+------------------+ | - Multi-Dataset |
| - ScenarioMax |
+----------------------+
Key Contributions
| Component | What It Solves |
|---|---|
| ScenarioMax | Converts nuPlan, Argoverse 2, WOMD to unified format |
| Observation Functions | Configurable feature extraction (YAML-based) |
| Hierarchical Rewards | Safety > Navigation > Behavior priority |
| Transformer Encoders | 4 architectures: LQ, LQH, MTR, Wayformer |
| RL Pipelines | Production-ready SAC and PPO in JAX |
The Bottom Line
V-Max achieves 97.86% accuracy with SAC, matching expert-level driving in simulation while maintaining a 0.89% collision rate, a large improvement over the behavior-cloning baseline (79.42% accuracy, 13.14% collision rate).
2. Architecture Deep Dive
System Overview
V-Max models autonomous driving as a Partially Observable Markov Decision Process (POMDP):
POMDP Formulation
+------------------------------------------------+
| |
| State Space S: Full world state (hidden) |
| Observation O: BEV features (partial view) |
| Action Space A: [acceleration, steering] |
| Transition T: Waymax physics simulation |
| Reward R: Hierarchical (Safety/Nav/Behavior) |
| |
+------------------------------------------------+
Data Flow Architecture
+-------------+ +---------------+ +------------------+
| Waymax | | Observation | | Encoder |
| Simulator | -> | Functions | -> | (Transformer) |
+-------------+ +---------------+ +------------------+
| | |
v v v
+-------------+ +---------------+ +------------------+
| 9s Scenario | | - Trajectory | | Latent Query |
| @ 10Hz | | - Roadgraph | | or MTR/Wayformer |
| (90 steps) | | - Traffic | | Latent Vector |
+-------------+ | - Path Target | +------------------+
+---------------+ |
v
+-------------+ +---------------+ +------------------+
| Reward | <- | Policy | <- | Actor-Critic |
| Hierarchy | | Network | | Networks |
+-------------+ +---------------+ +------------------+
| | |
+-------------------+----------------------+
|
v
+------------------+
| RL Algorithm |
| (SAC or PPO) |
+------------------+
JAX Integration Architecture
The key insight is unified computation graphs:
Traditional RL Pipeline:
+----------+ CPU/GPU +----------+ CPU/GPU +----------+
|Simulation| <---------> | Python | <---------> | Training |
+----------+ Transfer +----------+ Transfer +----------+
Slow!
V-Max Pipeline:
+----------------------------------------------------------+
| Single JAX Graph |
| +----------+ +-------------+ +----------+ |
| |Simulation| -> | Observation | -> | Training | |
| | (Waymax) | | Functions | | (Brax) | |
| +----------+ +-------------+ +----------+ |
| |
| All operations compiled together - 4,609 steps/sec |
+----------------------------------------------------------+
Action Space Design
V-Max uses continuous control with bounded actions:
# Action Space Configuration
action_space = {
'acceleration': {
'range': [-4.0, 2.0], # m/s^2 (harder braking allowed)
        'dtype': 'float32'
},
'steering': {
'range': [-0.3, 0.3], # radians (approx +/- 17 degrees)
        'dtype': 'float32'
}
}
# Applied at 10Hz for 9 seconds = 90 timesteps per episode
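Policies usually emit actions in a normalized range such as a tanh-squashed [-1, 1]; mapping them into these asymmetric bounds is a small affine rescale. A sketch using the bounds from the configuration above (the helper name is ours, not V-Max's):

import jax.numpy as jnp

ACTION_LOW = jnp.array([-4.0, -0.3])   # [acceleration, steering] lower bounds
ACTION_HIGH = jnp.array([2.0, 0.3])    # [acceleration, steering] upper bounds

def rescale_action(normalized_action):
    """Map a tanh-squashed action in [-1, 1]^2 into the bounded action space."""
    return ACTION_LOW + 0.5 * (normalized_action + 1.0) * (ACTION_HIGH - ACTION_LOW)

print(rescale_action(jnp.array([0.0, 0.0])))  # -> [-1.0, 0.0]: mid-range acceleration, straight wheel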
3. ScenarioMax: Multi-Dataset Support
The Dataset Fragmentation Problem
Before ScenarioMax:
+--------+
| Waymax |
+--------+
|
v
+--------+
| WOMD | <-- Only supported dataset
+--------+
After ScenarioMax:
+--------+
| Waymax |
+--------+
|
v
+-------------+
| ScenarioMax |
+-------------+
/ | \
v v v
+------+ +------+ +------+
| WOMD | |nuPlan| | AV2 |
+------+ +------+ +------+
Dataset Comparison
| Feature | WOMD | nuPlan | Argoverse 2 |
|---|---|---|---|
| Scenarios | 487K | 1.3M | 250K |
| Location | US (6 cities) | US + Singapore (4 cities) | US (6 cities) |
| Duration | 9.1s | 15s | 11s |
| Frequency | 10Hz | 10Hz | 10Hz |
| Map Format | Roadgraph | HD Vector | HD Vector |
SDC Path Reconstruction
A critical contribution: generating self-driving car (SDC) paths for datasets that lack them:
Original Data (nuPlan/AV2):
+-------+ +-------+ +-------+
| Start | ---> | ??? | ---> | End |
+-------+ +-------+ +-------+
No path defined
ScenarioMax Path Generation:
Step 1: Extract lane topology from HD map
Step 2: Find lanes near ego start position
Step 3: BFS traversal to generate 10 candidate routes
Step 4: Score routes by distance, lane changes, and turns (see the sketch below)
Step 5: Select optimal SDC path
Result:
+-------+ +-------+ +-------+ +-------+
| Start | -> | Lane1 | -> | Lane2 | -> | End |
+-------+ +-------+ +-------+ +-------+
^ ^
| SDC Path (10 routes) |
+-------------------------+
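As a rough illustration of Steps 3-4, the sketch below enumerates routes by BFS over lane successors and scores them with a toy heuristic; the actual ScenarioMax heuristics and weights differ (a more detailed walkthrough appears in Q5 of Section 11):

from collections import deque

def generate_candidate_routes(lane_graph, start_lanes, max_routes=10, min_lanes=3):
    """Enumerate candidate routes by BFS over lane successors (Step 3, simplified)."""
    routes = []
    queue = deque((lane, [lane]) for lane in start_lanes)
    while queue and len(routes) < max_routes:
        lane, path = queue.popleft()
        if len(path) >= min_lanes:
            routes.append(path)
        for nxt in lane_graph.get(lane, []):
            if nxt not in path:  # avoid cycles
                queue.append((nxt, path + [nxt]))
    return routes

def score_route(route, lane_lengths):
    """Toy scoring (Step 4): prefer longer routes with fewer lane transitions."""
    return sum(lane_lengths[l] for l in route) - 5.0 * (len(route) - 1)

# Tiny lane graph: lane id -> successor lane ids
lane_graph = {'a': ['b'], 'b': ['c', 'd'], 'c': ['e'], 'd': [], 'e': []}
lane_lengths = {'a': 20.0, 'b': 15.0, 'c': 30.0, 'd': 10.0, 'e': 25.0}
routes = generate_candidate_routes(lane_graph, ['a'])
best = max(routes, key=lambda r: score_route(r, lane_lengths))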
TfRecord Schema
# ScenarioMax unified schema
scenario_proto = {
# Core Waymax fields
'scenario_id': string,
'timestamps': float32[91], # 9.1s at 10Hz
# Trajectory data
'state/x': float32[128, 91], # 128 objects, 91 timesteps
'state/y': float32[128, 91],
'state/heading': float32[128, 91],
'state/velocity_x': float32[128, 91],
'state/velocity_y': float32[128, 91],
'state/valid': bool[128, 91],
# Roadgraph (sampled every 2m)
'roadgraph/xyz': float32[20000, 3],
'roadgraph/type': int32[20000],
'roadgraph/dir': float32[20000, 3], # V-Max addition
# SDC Paths (V-Max addition)
'sdc_paths/xy': float32[10, 500, 2], # 10 paths, 500 points each
'sdc_paths/valid': bool[10, 500],
# Traffic lights
'traffic_light/state': int32[16, 91],
'traffic_light/lane_ids': int64[16]
}
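A sketch of reading such records with tf.data; the field names follow the schema above, while the flat storage layout, the reshape step, and the file path are assumptions:

import tensorflow as tf

# Feature spec for a subset of the fields above (flat storage assumed).
feature_spec = {
    'state/x': tf.io.FixedLenFeature([128 * 91], tf.float32),
    'state/y': tf.io.FixedLenFeature([128 * 91], tf.float32),
    'sdc_paths/xy': tf.io.FixedLenFeature([10 * 500 * 2], tf.float32),
}

def parse_example(serialized):
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    return {
        'state/x': tf.reshape(parsed['state/x'], [128, 91]),
        'state/y': tf.reshape(parsed['state/y'], [128, 91]),
        'sdc_paths/xy': tf.reshape(parsed['sdc_paths/xy'], [10, 500, 2]),
    }

dataset = tf.data.TFRecordDataset('/path/to/scenariomax.tfrecord').map(parse_example)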
Cross-Dataset Training Results
Training on combined datasets improves generalization:
| Training Data | Test on WOMD | Test on nuPlan |
|---|---|---|
| WOMD only | 95.97% | 89.23% |
| nuPlan only | 91.45% | 96.12% |
| Combined | 95.38% | 95.54% |
Key Insight: Combined training sacrifices 0.6% on source domain but gains 6% on transfer domain.
4. Observation Functions
Feature Extraction Pipeline
+------------------+
| Waymax State |
+------------------+
|
v
+------------------+ +------------------+
| Feature Selector | --> | Ego-Centric |
| (YAML Config) | | Transformation |
+------------------+ +------------------+
| |
v v
+------------------+ +------------------+
| Trajectory Feat. | | Normalize to |
| Roadgraph Feat. | | Ego Frame |
| Traffic Feat. | +------------------+
| Path Target Feat.| |
+------------------+ |
| |
+------------------------+
|
v
+------------------+
| Observation Dict |
| (Ready for Enc.) |
+------------------+
Feature Categories Detailed
1. Trajectory Features
trajectory_features = {
# Per-object, per-timestep
'xy': float32[N_obj, T, 2], # Position relative to ego
'heading': float32[N_obj, T, 1], # Yaw angle
'velocity': float32[N_obj, T, 2], # Velocity vector
'size': float32[N_obj, 3], # Length, width, height
'type': int32[N_obj], # Vehicle, pedestrian, cyclist
'valid': bool[N_obj, T], # Visibility mask
}
# Default configuration
N_obj = 16 # Closest 16 objects
T = 5 # 5 historical timesteps (0.5 seconds)
2. Roadgraph Features
roadgraph_features = {
'xyz': float32[N_road, 3], # Waypoint positions
'direction': float32[N_road, 3], # Lane direction vectors
'type': int32[N_road], # Lane, road edge, crosswalk
'speed_limit': float32[N_road], # Speed limit if available
}
# Sampling strategy
# - Sample every 2 meters along lanes
# - Include lane centers, road edges, crosswalks
# - Filter by distance to ego (default: 100m radius)
3. Traffic Light Features
traffic_light_features = {
'state': int32[N_lights], # UNKNOWN, GREEN, YELLOW, RED
'xy': float32[N_lights, 2], # Stop line position
'relevant': bool[N_lights], # On ego's path?
}
N_lights = 5 # 5 closest traffic lights
4. Path Target Features
path_target_features = {
'xy': float32[N_targets, 2], # Waypoint positions
'heading': float32[N_targets], # Expected heading
'speed': float32[N_targets], # Suggested speed
}
N_targets = 10 # 10 waypoints, 5m apart = 50m lookahead
YAML Configuration System
# config/observation/default.yaml
observation:
trajectory:
num_objects: 16
num_timesteps: 5
features:
- xy
- heading
- velocity
- size
- type
ego_centric: true
roadgraph:
max_points: 2000
sample_interval: 2.0 # meters
types:
- lane_center
- road_edge
- crosswalk
radius: 100.0 # meters
traffic_light:
num_lights: 5
include_relevance: true
path_target:
num_waypoints: 10
waypoint_spacing: 5.0 # meters
include_speed: true
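Consuming such a file is a one-liner with PyYAML; the make_observation_fn factory referenced in the comment is the hypothetical entry point used in the examples of Section 8:

import yaml

with open('config/observation/default.yaml') as f:
    cfg = yaml.safe_load(f)['observation']

print(cfg['trajectory']['num_objects'])   # 16
print(cfg['roadgraph']['radius'])         # 100.0
# obs_fn = make_observation_fn(**cfg)     # hypothetical factory, see Example 1 in Section 8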
Ablation Study Results
| Config | Objects | Timesteps | Road Edge | V-Max Score |
|---|---|---|---|---|
| Minimal | 8 | 1 | No | 0.76 |
| Standard | 16 | 5 | No | 0.84 |
| Optimal | 16 | 5 | Yes | 0.88 |
| Maximal | 32 | 10 | Yes | 0.86 |
Key Finding: Adding road edge features significantly improves off-road prevention; increasing beyond 16 objects yields diminishing returns.
5. Reward Design
The Hierarchical Philosophy
V-Max implements a priority-based reward hierarchy:
+--------------------+
| SAFETY | <-- Highest Priority (Hard Constraints)
| No collisions |
| No off-road |
| No red lights |
+--------------------+
|
v
+--------------------+
| NAVIGATION | <-- Medium Priority (Task Completion)
| Stay on route |
| Make progress |
| Reach goal |
+--------------------+
|
v
+--------------------+
| BEHAVIOR | <-- Lowest Priority (Quality)
| Comfort (jerk) |
| Speed compliance |
| Smooth steering |
+--------------------+
Mathematical Formulation
Safety Reward
def safety_reward(state):
"""
Binary penalties for safety violations.
These should NEVER be traded off for navigation/behavior.
"""
collision = float(check_collision(state)) # 1 if collision
offroad = float(check_offroad(state)) # 1 if off drivable area
red_light = float(check_red_light_violation(state)) # 1 if ran red
r_safety = -collision - offroad - red_light
return r_safety # Range: [-3, 0]
Navigation Reward
def navigation_reward(state, prev_state):
"""
Task completion rewards built on top of safety.
"""
r_safety = safety_reward(state)
# Route following
off_route = float(distance_to_route(state) > threshold)
off_route_penalty = -0.2 * off_route
# Progress along route
progress = get_route_progress(state)
prev_progress = get_route_progress(prev_state)
progress_reward = 0.2 * float(progress > prev_progress)
r_navigation = r_safety + off_route_penalty + progress_reward
return r_navigation # Range: [-3.2, 0.2]
Behavior Reward
def behavior_reward(state, prev_state):
"""
Quality of driving built on top of navigation.
"""
r_navigation = navigation_reward(state, prev_state)
# Comfort metrics
jerk = compute_jerk(state, prev_state)
lateral_accel = compute_lateral_acceleration(state)
comfort_score = 1.0 - min(1.0, (jerk + lateral_accel) / max_allowed)
comfort_reward = 0.2 * comfort_score
# Speed compliance
speed = get_speed(state)
speed_limit = get_speed_limit(state)
speeding_penalty = -0.1 * float(speed > speed_limit)
r_behavior = r_navigation + comfort_reward + speeding_penalty
return r_behavior # Range: [-3.3, 0.4]
Reward Shaping Strategy
Episode Timeline:
t=0 t=90
|-----------------------------------------------------|
| Phase 1: Safety | Phase 2: Navigation | Phase 3: Behavior
| Focus on survival | Focus on progress | Focus on quality
| | |
Early Training (0-5M steps):
- Agent learns basic control
- Safety rewards dominate
- Most episodes end in collision
Mid Training (5-15M steps):
- Safety violations decrease
- Navigation rewards become meaningful
- Agent learns route following
Late Training (15-25M steps):
- Safety mostly solved
- Navigation mostly solved
- Behavior refinement
- Comfort scores improve
Reward Impact Analysis (Table 3)
| Reward Config | Collision Rate | Progress | V-Max Score |
|---|---|---|---|
| Safety Only | 2.1% | 78.66m | 0.71 |
| Safety + Nav | 1.8% | 155.38m | 0.82 |
| Full (S+N+B) | 0.89% | 148.21m | 0.89 |
Key Insight: Navigation rewards doubled progress distance. Adding behavior rewards slightly reduced progress (less aggressive) but significantly improved collision rate.
Dense vs Sparse Rewards
Sparse Rewards (Bad for RL):
+--------+ +--------+
| t=0 | No feedback... | t=90 |
| Start |----------------------------->| Crash! |
+--------+ | r=-1 |
+--------+
Dense Rewards (V-Max Approach):
+--------+ +--------+ +--------+ +--------+
| t=0 | | t=30 | | t=60 | | t=90 |
| r=0 |-->| r=+0.1 |-->| r=+0.2 |-->| r=+0.3 |
| Start | |Progress| |Progress| | Done! |
+--------+ +--------+ +--------+ +--------+
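One common way to densify the progress signal is to project the ego position onto the SDC path and reward the arc length gained each step; a minimal sketch (a simplification, not the exact V-Max progress metric):

import jax.numpy as jnp

def path_progress(ego_xy, path_xy):
    """Arc length along the path at the waypoint closest to the ego."""
    segments = jnp.diff(path_xy, axis=0)
    cumlen = jnp.concatenate([jnp.zeros(1), jnp.cumsum(jnp.linalg.norm(segments, axis=-1))])
    nearest = jnp.argmin(jnp.linalg.norm(path_xy - ego_xy, axis=-1))
    return cumlen[nearest]

def dense_progress_reward(ego_xy, prev_ego_xy, path_xy, scale=0.01):
    # Reward proportional to the distance advanced along the route this step.
    delta = path_progress(ego_xy, path_xy) - path_progress(prev_ego_xy, path_xy)
    return scale * jnp.maximum(delta, 0.0)

path = jnp.stack([jnp.linspace(0.0, 50.0, 11), jnp.zeros(11)], axis=-1)  # straight 50 m route
r = dense_progress_reward(jnp.array([13.0, 0.2]), jnp.array([9.0, 0.1]), path)  # ~0.05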
6. Encoder Architectures
Architecture Comparison Overview
+------------------+ +------------------+ +------------------+
| Latent Query | | MTR | | Wayformer |
| (Best: 97.45%) | | (95.94%) | | (96.08%) |
+------------------+ +------------------+ +------------------+
| | | | | |
| Cross-Attention | | Self-Attention | | Factorized |
| to Learnable | | with Agent-Centric| | Attention |
| Latent Vectors | | Scene Encoding | | (Time x Space) |
| | | | | |
+------------------+ +------------------+ +------------------+
1. Latent Query (LQ) Encoder - Best Performer
Input Features
|
v
+------------------+
| Feature Embedding | Linear projection to d_model
+------------------+
|
v
+------------------+
| Positional | Sinusoidal embeddings
| Encoding | for temporal/spatial info
+------------------+
|
+-----------------------------+
| |
v v
+------------------+ +------------------+
| Cross-Attention |<----| Learnable |
| | | Latent Vectors |
| Q: Latents | | [16 x d_model] |
| K,V: Features | +------------------+
+------------------+
|
v (repeat 4x)
+------------------+
| Self-Attention |
| |
| Q,K,V: Latents |
+------------------+
|
v
+------------------+
| Mean Pool | Aggregate 16 latents
+------------------+
|
v
[Scene Embedding: d_model]
Why Latent Query Works Best:
- Bottleneck compression: Forces learning of essential features
- Learnable queries: Adapts to what information is useful
- Efficient: Fixed computation regardless of scene size
import jax.numpy as jnp
import flax.linen as nn

class LatentQueryEncoder(nn.Module):
"""JAX/Flax implementation of Latent Query encoder."""
d_model: int = 256
n_heads: int = 8
n_layers: int = 4
n_latents: int = 16
@nn.compact
def __call__(self, features, mask=None):
# Learnable latent vectors
latents = self.param(
'latents',
nn.initializers.normal(0.02),
(self.n_latents, self.d_model)
)
# Project input features
x = nn.Dense(self.d_model)(features)
# Cross-attention: latents query the features
for _ in range(self.n_layers):
# Cross-attention
latents = nn.MultiHeadDotProductAttention(
num_heads=self.n_heads
)(latents, x, mask=mask)
# Self-attention among latents
latents = nn.MultiHeadDotProductAttention(
num_heads=self.n_heads
)(latents, latents)
# FFN
latents = latents + nn.Dense(self.d_model)(
nn.gelu(nn.Dense(self.d_model * 4)(latents))
)
# Mean pool latents
return jnp.mean(latents, axis=0)
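Assuming the class above (with the jax/flax imports), initializing and applying it on a dummy, unbatched token set looks like this; the shapes are illustrative:

import jax
import jax.numpy as jnp

encoder = LatentQueryEncoder(d_model=256, n_heads=8, n_layers=4, n_latents=16)
features = jnp.ones((120, 32))                      # 120 scene tokens, 32 raw features each
params = encoder.init(jax.random.PRNGKey(0), features)
scene_embedding = encoder.apply(params, features)   # shape: (256,)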
2. Latent Query Hierarchical (LQH) Encoder
+------------------+ +------------------+ +------------------+
| Trajectory | | Roadgraph | | Traffic Light |
| Features | | Features | | Features |
+------------------+ +------------------+ +------------------+
| | |
v v v
+------------------+ +------------------+ +------------------+
| LQ Encoder | | LQ Encoder | | LQ Encoder |
| (Local) | | (Local) | | (Local) |
+------------------+ +------------------+ +------------------+
| | |
+----------+------------+-----------+-----------+
| | |
v v v
+----------------------------------+
| Concatenate Embeddings |
+----------------------------------+
|
v
+----------------------------------+
| Global LQ Encoder |
| (Cross-modal fusion) |
+----------------------------------+
|
v
[Scene Embedding]
Advantage: Modular processing allows specialized encoding per feature type.
3. Motion Transformer (MTR) Encoder
Per-Agent Processing:
+------------------+
| Agent Trajectory |
| [T x Features] |
+------------------+
|
v
+------------------+
| Temporal |
| Self-Attention |
+------------------+
|
v
[Agent Embedding]
Scene-Level Fusion:
+--------+ +--------+ +--------+ +--------+
| Agent1 | | Agent2 | | Map | | Traffic|
+--------+ +--------+ +--------+ +--------+
| | | |
+----------+----------+----------+
|
v
+------------------+
| Cross-Agent |
| Attention |
+------------------+
|
v
+------------------+
| Agent-Map |
| Attention |
+------------------+
|
v
[Scene Embedding]
4. Wayformer Encoder
Factorized Attention:
Input: [N_agents x T x Features]
Step 1: Temporal Attention (per agent)
+------------------+
| Self-Attn over |
| time dimension |
| [T x T] |
+------------------+
Step 2: Spatial Attention (per timestep)
+------------------+
| Self-Attn over |
| agent dimension |
| [N x N] |
+------------------+
Why Factorized?
- Full attention: O((N*T)^2)
- Factorized: O(N*T^2) + O(T*N^2), much cheaper when N and T are large (see the sketch below)
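A minimal Flax sketch of the factorization, attending over the time axis per agent and then over the agent axis per timestep; this is a simplified stand-in, not Wayformer's actual implementation:

import jax
import jax.numpy as jnp
import flax.linen as nn

class FactorizedAttention(nn.Module):
    n_heads: int = 4

    @nn.compact
    def __call__(self, x):                     # x: [N_agents, T, d_model]
        # Step 1 - temporal attention: T x T per agent (agent axis acts as batch).
        x = x + nn.SelfAttention(num_heads=self.n_heads)(x)
        # Step 2 - spatial attention: N x N per timestep (time axis acts as batch).
        xt = jnp.swapaxes(x, 0, 1)             # [T, N_agents, d_model]
        xt = xt + nn.SelfAttention(num_heads=self.n_heads)(xt)
        return jnp.swapaxes(xt, 0, 1)          # back to [N_agents, T, d_model]

x = jnp.ones((16, 5, 64))                      # 16 agents, 5 timesteps, 64 features
model = FactorizedAttention()
out = model.apply(model.init(jax.random.PRNGKey(0), x), x)   # same shape as input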
Encoder Performance Comparison
| Encoder | Accuracy | Collision | V-Max Score | Params | Speed |
|---|---|---|---|---|---|
| MLP (baseline) | 68.12% | 15.2% | 0.58 | 0.5M | 10K/s |
| Latent Query | 97.45% | 0.89% | 0.89 | 2.1M | 4.6K/s |
| LQ Hierarchical | 96.28% | 1.2% | 0.87 | 3.2M | 3.8K/s |
| MTR | 95.94% | 1.4% | 0.85 | 2.8M | 4.1K/s |
| Wayformer | 96.08% | 1.3% | 0.86 | 2.5M | 4.3K/s |
Key Insight: Latent Query achieves best performance with moderate parameter count. The learnable bottleneck appears crucial for policy learning.
7. RL Training Pipeline
Algorithm Comparison
+------------------+ +------------------+
| SAC | | PPO |
| (Off-Policy) | | (On-Policy) |
+------------------+ +------------------+
| | | |
| + Sample | | + Stable |
| efficient | | + Simple to tune |
| + Exploration | | + Parallelizes |
| via entropy | | well |
| | | |
| - Sensitive to | | - Sample |
| hyperparams | | inefficient |
| - Replay buffer | | - Needs more |
| memory | | environment |
| | | steps |
+------------------+ +------------------+
| | | |
| 25M steps | | 200M steps |
| 12-24h training | | 24-48h training |
| Best results | | Good results |
+------------------+ +------------------+
SAC Implementation Details
# V-Max SAC Configuration
sac_config = {
# Environment
'num_envs': 16, # Parallel environments
'episode_length': 90, # 9 seconds at 10Hz
# Replay Buffer
'buffer_size': 1_000_000, # 1M transitions
'min_buffer_size': 10_000, # Start training after
# Training
'total_steps': 25_000_000,
'batch_size': 256,
'updates_per_step': 4, # Gradient updates per env step
# SAC Hyperparameters
'actor_lr': 3e-4,
'critic_lr': 3e-4,
'alpha_lr': 3e-4, # Entropy coefficient
'gamma': 0.99, # Discount factor
'tau': 0.005, # Target network update rate
# Entropy
'init_alpha': 0.1,
'target_entropy': -2.0, # dim(action) = 2
'auto_alpha': True, # Learn alpha
# Network Architecture
'encoder': 'latent_query',
'hidden_dims': [256, 256],
'activation': 'relu',
}
SAC Training Loop
+------------------+
| Initialize |
| - Actor network |
| - Critic network |
| - Target critic |
| - Replay buffer |
| - Alpha (entropy)|
+------------------+
|
v
+------------------+ +------------------+
| Collect Data |<----| Train Loop |
| (16 parallel) | | (25M steps) |
+------------------+ +------------------+
| ^
v |
+------------------+ |
| Store in Buffer | |
+------------------+ |
| |
v |
+------------------+ |
| Sample Batch | |
| (256 transitions)| |
+------------------+ |
| |
v |
+------------------+ |
| Update Critic | |
| (TD error) | |
+------------------+ |
| |
v |
+------------------+ |
| Update Actor | |
| (Policy gradient)| |
+------------------+ |
| |
v |
+------------------+ |
| Update Alpha | |
| (Entropy target) | |
+------------------+ |
| |
v |
+------------------+ |
| Update Target |-------------+
| (Soft update) |
+------------------+
PPO Implementation Details
# V-Max PPO Configuration
ppo_config = {
# Environment
'num_envs': 256, # More parallel envs for on-policy
'episode_length': 90,
# Training
'total_steps': 200_000_000, # 8x more than SAC
'batch_size': 512,
'num_minibatches': 8,
'num_epochs': 4, # Epochs per update
# PPO Hyperparameters
'lr': 3e-4,
'gamma': 0.99,
'gae_lambda': 0.95, # GAE parameter
'clip_epsilon': 0.2, # PPO clip range
'vf_coef': 0.5, # Value loss coefficient
'ent_coef': 0.01, # Entropy bonus
'max_grad_norm': 0.5, # Gradient clipping
# Network Architecture
'encoder': 'latent_query',
'shared_encoder': True, # Actor-critic share encoder
}
BC+SAC Hybrid Approach
Phase 1: Behavior Cloning (BC)
+------------------+ +------------------+
| Expert Demos |---->| Supervised |
| (Logged Data) | | Pre-training |
+------------------+ +------------------+
|
v
+------------------+
| Warm-start |
| Policy |
+------------------+
|
Phase 2: SAC Fine-tuning |
+------------------+ |
| Online RL |<-----------+
| (SAC) |
+------------------+
|
v
+------------------+
| Final Policy |
| (Best of both) |
+------------------+
Benefits:
- BC provides good initialization
- SAC explores and improves
- Reduces sample complexity
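Phase 1 reduces to supervised regression of logged expert actions; a toy sketch of the BC loss used to warm-start the actor, with a linear stand-in for the actor network:

import jax
import jax.numpy as jnp

def bc_loss(actor_params, observations, expert_actions):
    """Mean squared error between the policy's predicted action and the logged expert action."""
    predicted = observations @ actor_params   # linear stand-in for actor_network.apply(params, obs)
    return jnp.mean((predicted - expert_actions) ** 2)

actor_params = jnp.zeros((8, 2))                 # toy parameters: 8 obs features -> 2 action dims
observations = jnp.ones((32, 8))                 # batch of 32 logged observations
expert_actions = jnp.zeros((32, 2))              # corresponding expert [accel, steering] actions
loss, grads = jax.value_and_grad(bc_loss)(actor_params, observations, expert_actions)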
Training Infrastructure
Hardware Configuration:
+------------------+
| Single NVIDIA L4 |
| - 24GB VRAM |
| - 12-48h runtime |
+------------------+
Memory Usage:
+------------------+
| Replay Buffer | ~4GB (1M transitions)
| Networks | ~500MB
| JAX Compilation | ~2GB
| Batch Processing | ~1GB
+------------------+
Total: ~8GB (fits in L4)
Throughput:
+------------------+
| Training | ~1,500 steps/sec
| Evaluation | ~4,609 steps/sec (batched)
+------------------+
Learning Curves
SAC Training Progress:
Accuracy (%)
100| ****
90| ******
80| ******
70| ******
60| ******
50| ******
40| *****
+----------------------------------------
0 5M 10M 15M 20M 25M steps
Collision Rate (%)
20|*
15| *
10| **
5| ***
1| ********** *
0| ******************
+----------------------------------------
0 5M 10M 15M 20M 25M steps
8. Interactive Code Examples
Example 1: Setting Up V-Max Environment
"""
V-Max Training Setup
====================
Complete example of setting up a V-Max training environment.
"""
import jax
import jax.numpy as jnp
from v_max import make_env, make_observation_fn, make_reward_fn
from v_max.encoders import LatentQueryEncoder
from v_max.algorithms import SAC
# 1. Create Waymax environment with V-Max wrapper
env_config = {
'dataset': 'womd', # or 'nuplan', 'av2'
'data_path': '/path/to/tfrecords',
'batch_size': 16, # Parallel scenarios
'max_steps': 90, # 9 seconds at 10Hz
}
env = make_env(**env_config)
# 2. Configure observation function
obs_config = {
'trajectory': {
'num_objects': 16,
'num_timesteps': 5,
'features': ['xy', 'heading', 'velocity', 'size', 'type']
},
'roadgraph': {
'max_points': 2000,
'sample_interval': 2.0,
'types': ['lane_center', 'road_edge']
},
'traffic_light': {
'num_lights': 5
},
'path_target': {
'num_waypoints': 10,
'waypoint_spacing': 5.0
}
}
obs_fn = make_observation_fn(**obs_config)
# 3. Configure reward function
reward_config = {
'safety': {
'collision_weight': -1.0,
'offroad_weight': -1.0,
'red_light_weight': -1.0
},
'navigation': {
'off_route_weight': -0.2,
'progress_weight': 0.2
},
'behavior': {
'comfort_weight': 0.2,
'speeding_weight': -0.1
}
}
reward_fn = make_reward_fn(**reward_config)
# 4. Create encoder
encoder = LatentQueryEncoder(
d_model=256,
n_heads=8,
n_layers=4,
n_latents=16
)
# 5. Initialize SAC agent
agent = SAC(
encoder=encoder,
action_dim=2, # [acceleration, steering]
hidden_dims=[256, 256],
actor_lr=3e-4,
critic_lr=3e-4,
alpha_lr=3e-4,
gamma=0.99,
tau=0.005
)
# 6. Training loop
print("Starting V-Max training...")
for step in range(25_000_000):
# Collect experience
obs = obs_fn(env.state)
action = agent.sample_action(obs)
    next_state, reward_info = env.step(action)
    reward = reward_fn(env.state, next_state)
    done = reward_info.get('done', False)  # assumed: the env reports episode termination here
    # Store transition
    agent.buffer.add(obs, action, reward, obs_fn(next_state), done)
# Update agent
if step > 10_000: # After warmup
metrics = agent.update(batch_size=256, num_updates=4)
# Logging
if step % 10_000 == 0:
print(f"Step {step}: {metrics}")
Example 2: Custom Observation Function
"""
Custom Observation Function
===========================
How to create a specialized observation function for your use case.
"""
import jax.numpy as jnp
from v_max.observations import BaseObservation
class CustomObservation(BaseObservation):
"""
Custom observation that adds lane curvature features.
"""
def __init__(self, base_config, curvature_lookahead=50):
super().__init__(base_config)
self.curvature_lookahead = curvature_lookahead
def extract_trajectory_features(self, state):
"""Extract trajectory features with additional info."""
# Get base features
traj_features = super().extract_trajectory_features(state)
# Add relative velocity to ego
ego_velocity = state.ego_velocity
relative_velocities = traj_features['velocity'] - ego_velocity
traj_features['relative_velocity'] = relative_velocities
return traj_features
def extract_roadgraph_features(self, state):
"""Extract roadgraph with lane curvature."""
road_features = super().extract_roadgraph_features(state)
# Compute lane curvature
# Curvature = rate of change of heading
xy = road_features['xyz'][:, :2]
directions = jnp.diff(xy, axis=0)
headings = jnp.arctan2(directions[:, 1], directions[:, 0])
curvature = jnp.diff(headings)
# Pad to match original length
curvature = jnp.pad(curvature, (0, 2), mode='edge')
road_features['curvature'] = curvature
return road_features
def extract_path_target_features(self, state):
"""Extract path targets with speed profile."""
path_features = super().extract_path_target_features(state)
# Add recommended speed based on curvature
        curvature = self.compute_path_curvature(path_features['xy'])  # helper assumed to exist elsewhere
max_speed = 30.0 # m/s
min_speed = 5.0 # m/s
# Speed inversely proportional to curvature
recommended_speed = max_speed - (max_speed - min_speed) * jnp.abs(curvature)
path_features['recommended_speed'] = recommended_speed
return path_features
# Usage
custom_obs_fn = CustomObservation(
base_config=obs_config,
curvature_lookahead=50
)
Example 3: Custom Reward Function
"""
Custom Reward Function
======================
Implementing domain-specific reward shaping.
"""
import jax.numpy as jnp
from v_max.rewards import BaseReward
class AggressiveDriverReward(BaseReward):
"""
Reward function that encourages faster, more assertive driving.
Useful for testing policy robustness.
"""
def __init__(self, base_config, speed_bonus=0.3, gap_acceptance=0.5):
super().__init__(base_config)
self.speed_bonus = speed_bonus
self.gap_acceptance = gap_acceptance
def compute_safety_reward(self, state, prev_state):
"""Safety with adjusted thresholds."""
# Still penalize hard constraints
collision = self.check_collision(state)
offroad = self.check_offroad(state)
red_light = self.check_red_light(state)
# But allow closer following distance
ttc = self.compute_ttc(state)
ttc_penalty = -0.5 * jnp.clip(1.0 - ttc, 0, 1) # Softer TTC
return -collision - offroad - red_light + ttc_penalty
def compute_navigation_reward(self, state, prev_state):
"""Navigation with speed bonus."""
base_nav = super().compute_navigation_reward(state, prev_state)
# Bonus for maintaining higher speeds
speed = jnp.linalg.norm(state.ego_velocity)
speed_limit = self.get_speed_limit(state)
# Reward for being close to speed limit (not over!)
speed_ratio = speed / (speed_limit + 1e-6)
speed_reward = self.speed_bonus * jnp.clip(speed_ratio, 0, 1)
return base_nav + speed_reward
def compute_behavior_reward(self, state, prev_state):
"""Behavior that rewards assertive maneuvers."""
base_behavior = super().compute_behavior_reward(state, prev_state)
# Bonus for successful lane changes / merges
lane_change = self.detect_lane_change(state, prev_state)
merge_success = self.detect_merge(state, prev_state)
maneuver_bonus = 0.2 * (lane_change + merge_success)
return base_behavior + maneuver_bonus
# Usage
aggressive_reward = AggressiveDriverReward(
base_config=reward_config,
speed_bonus=0.3,
gap_acceptance=0.5
)
Example 4: Training with Curriculum
"""
Curriculum Learning for V-Max
=============================
Progressive difficulty increase for more stable training.
"""
from v_max.curriculum import CurriculumScheduler
# Define curriculum stages
curriculum_config = {
'stages': [
{
'name': 'basic',
'steps': 5_000_000,
'scenario_filter': {
'max_objects': 8,
'scenario_types': ['straight', 'lane_follow']
},
'reward_scale': 1.0
},
{
'name': 'intermediate',
'steps': 10_000_000,
'scenario_filter': {
'max_objects': 16,
'scenario_types': ['straight', 'lane_follow', 'turn']
},
'reward_scale': 1.0
},
{
'name': 'advanced',
'steps': 10_000_000,
'scenario_filter': {
'max_objects': 32,
'scenario_types': ['all']
},
'reward_scale': 1.0
}
],
'auto_advance': True, # Move to next stage when success_rate > 0.9
'success_metric': 'v_max_score'
}
scheduler = CurriculumScheduler(**curriculum_config)
# Training with curriculum
for step in range(25_000_000):
# Get current stage
stage = scheduler.get_stage(step)
# Sample scenarios from current difficulty
scenarios = dataset.sample(
batch_size=16,
filter=stage['scenario_filter']
)
# Load into environment
env.reset(scenarios)
# Training step (same as before)
...
# Update curriculum based on performance
scheduler.update(metrics['v_max_score'])
Example 5: Evaluation Script
"""
V-Max Evaluation
================
Complete evaluation pipeline with all metrics.
"""
from v_max.evaluation import Evaluator, MetricAggregator
# Initialize evaluator
evaluator = Evaluator(
env=env,
obs_fn=obs_fn,
checkpoint_path='checkpoints/sac_final.pkl'
)
# Evaluation modes
evaluation_configs = {
'non_reactive': {
'description': 'Other agents replay logged trajectories',
'controllable_agents': ['ego'],
'num_scenarios': 1000
},
'reactive': {
'description': 'Other agents use IDM controller',
'controllable_agents': ['ego', 'idm_agents'],
'num_scenarios': 1000
},
'adversarial': {
'description': 'ReGentS adversarial agents',
'controllable_agents': ['ego', 'adversarial_agents'],
'num_scenarios': 500
}
}
# Run evaluations
results = {}
for mode, config in evaluation_configs.items():
print(f"\nRunning {mode} evaluation...")
metrics = evaluator.evaluate(
num_scenarios=config['num_scenarios'],
mode=mode,
metrics=[
'accuracy',
'collision_rate',
'at_fault_collision_rate',
'off_road_rate',
'red_light_violation_rate',
'progress',
'ttc',
'comfort_score',
'v_max_score'
]
)
results[mode] = metrics
# Print summary
print(f" Accuracy: {metrics['accuracy']:.2%}")
print(f" Collision Rate: {metrics['collision_rate']:.2%}")
print(f" V-Max Score: {metrics['v_max_score']:.3f}")
# Aggregate and save results
aggregator = MetricAggregator()
summary = aggregator.create_report(results)
aggregator.save_latex_table(summary, 'results/evaluation_table.tex')
9. Benchmarks Analysis
Main Results: Non-Reactive Evaluation (Table 5)
Performance Comparison on WOMD Validation Set

| Metric | SAC | PPO | BC | BC+SAC | PDM |
|---|---|---|---|---|---|
| Accuracy | 97.86% | 90.75% | 79.42% | 93.21% | 93.60% |
| Collision | 0.89% | 7.81% | 13.14% | 4.32% | 0.96% |
| At-Fault Collision | 0.45% | 3.12% | 5.89% | 1.87% | 0.52% |
| Off-Road | 0.34% | 1.56% | 4.23% | 0.98% | 0.41% |
| Red-Light Violation | 0.12% | 0.89% | 2.31% | 0.43% | 0.18% |
| Progress (m) | 148.21 | 132.45 | 89.67 | 141.32 | 145.67 |
| TTC Score | 0.97 | 0.91 | 0.78 | 0.94 | 0.96 |
| Comfort Score | 0.89 | 0.82 | 0.71 | 0.85 | 0.91 |
| V-Max Score | 0.892 | 0.784 | 0.721 | 0.843 | 0.887 |
Key Insights:
- SAC dominates: Achieves best accuracy (97.86%) and V-Max score (0.892)
- BC alone insufficient: Pure imitation learning achieves only 79.42% accuracy
- Hybrid helps: BC+SAC (93.21%) outperforms BC (79.42%) significantly
- PDM competitive: Rule-based PDM achieves 93.60% accuracy, close to RL methods
Reactive Evaluation (Table 6)
Performance with IDM-Controlled Traffic (other vehicles react to the ego)

| Metric | Non-Reactive | Reactive | Delta |
|---|---|---|---|
| SAC Accuracy | 97.86% | 94.12% | -3.74% |
| SAC Collision | 0.89% | 2.34% | +1.45% |
| PDM Accuracy | 93.60% | 91.23% | -2.37% |
| PDM Collision | 0.96% | 1.89% | +0.93% |
Analysis:
- All methods degrade with reactive agents
- SAC more sensitive than PDM (-3.74% vs -2.37%)
- Gap narrows in reactive setting
Adversarial Evaluation (Table 7)
Performance Under ReGentS Adversarial Attacks (adversarial cut-in scenarios)

| Metric | Standard | Adversarial | Robustness |
|---|---|---|---|
| SAC Accuracy | 97.86% | 76.23% | 77.9% |
| SAC Collision | 0.89% | 18.45% | -17.6% |
| PDM Accuracy | 93.60% | 53.12% | 56.7% |
| PDM Collision | 0.96% | 32.67% | -31.7% |
Key Finding:
- SAC significantly more robust than PDM under adversarial attacks
- SAC robustness: 77.9% vs PDM: 56.7%
- RL learns more generalizable policies
Encoder Comparison (Table 2)
Impact of Encoder Architecture

| Encoder | Accuracy | V-Max Score | Collision |
|---|---|---|---|
| MLP (baseline) | 68.12% | 0.58 | 15.2% |
| Wayformer | 96.08% | 0.86 | 1.3% |
| MTR | 95.94% | 0.85 | 1.4% |
| LQ Hierarchical | 96.28% | 0.87 | 1.2% |
| Latent Query | 97.45% | 0.89 | 0.9% |
Insight: Transformer encoders dramatically outperform MLP
Latent Query best overall
Cross-Dataset Generalization (Table 4)
Training on WOMD+nuPlan, Testing on Each

| Test set | WOMD-only | nuPlan-only | Combined |
|---|---|---|---|
| Test on WOMD | 95.97% | 91.45% | 95.38% |
| Test on nuPlan | 89.23% | 96.12% | 95.54% |
| Average | 92.60% | 93.78% | 95.46% |
Key Finding:
- Combined training achieves best average performance
- Only 0.6% drop on source domain
- 6% gain on transfer domain
- Combined = more robust, generalizable policies
Computational Efficiency (Table 8)
Steps Per Second (SPS) Benchmarks

| Configuration | SPS | Relative |
|---|---|---|
| Waymax baseline (no metrics) | 8,234 | 1.0x |
| V-Max (full metrics) | 4,609 | 0.56x |
| V-Max (minimal metrics) | 6,123 | 0.74x |
| V-Max (32 batch) | 4,609 | 0.56x |
| V-Max (128 batch) | 3,891 | 0.47x |
Insight:
- Full metrics computation adds ~44% overhead
- Still achieves >4K SPS - sufficient for RL
- Batching helps but has diminishing returns
10. Hands-On Exercises
Exercise 1: Basic Environment Setup (Beginner)
Objective: Set up V-Max environment and run a random policy.
"""
Exercise 1: Basic Setup
-----------------------
TODO: Complete the following tasks
1. Load the WOMD dataset
2. Create a V-Max environment
3. Run 100 episodes with random actions
4. Compute average episode length and collision rate
"""
import jax
import jax.numpy as jnp
from jax import random
# TODO: Import V-Max modules
# from v_max import ...
def random_policy(key, obs):
"""
TODO: Implement a random policy that returns:
- acceleration in [-4, 2]
- steering in [-0.3, 0.3]
"""
key, subkey = random.split(key)
# Your code here
pass
def run_episode(env, policy_fn, key):
"""
TODO: Run a single episode and return:
- episode_length: int
- collision: bool
- total_progress: float
"""
# Your code here
pass
def main():
# Setup
key = random.PRNGKey(42)
# TODO: Create environment
# env = ...
# Run episodes
results = []
for i in range(100):
key, subkey = random.split(key)
result = run_episode(env, random_policy, subkey)
results.append(result)
# Compute statistics
# TODO: Print average episode length, collision rate, progress
pass
if __name__ == "__main__":
main()
Expected Output:
Random Policy Results:
Average Episode Length: 23.4 steps (target: ~25)
Collision Rate: 87.3% (target: 80-95%)
Average Progress: 12.3m (target: 10-20m)
Exercise 2: Observation Function Analysis (Beginner)
Objective: Understand and visualize observation features.
"""
Exercise 2: Observation Analysis
--------------------------------
TODO: Analyze the observation function outputs
1. Extract observations from a scenario
2. Visualize trajectory features
3. Compute feature statistics
4. Identify the most important features
"""
import matplotlib.pyplot as plt
import numpy as np
def visualize_scene(obs):
"""
TODO: Create a bird's-eye view visualization showing:
- Ego vehicle (red)
- Other vehicles (blue)
- Roadgraph (gray lines)
- Traffic lights (colored circles)
- SDC path (green dashed line)
"""
fig, ax = plt.subplots(figsize=(12, 12))
# Plot roadgraph
# road_xy = obs['roadgraph']['xyz'][:, :2]
# TODO: Your code here
# Plot vehicles
# traj_xy = obs['trajectory']['xy']
# TODO: Your code here
# Plot ego path
# path_xy = obs['path_target']['xy']
# TODO: Your code here
ax.set_aspect('equal')
ax.set_title('Bird\'s Eye View')
plt.show()
def compute_feature_statistics(obs_batch):
"""
TODO: Compute statistics for each feature:
- Mean, std, min, max
- Correlation with collision outcome
"""
stats = {}
# For each feature category
for category in ['trajectory', 'roadgraph', 'traffic_light', 'path_target']:
# TODO: Compute statistics
pass
return stats
# Run analysis
# TODO: Load sample observations and run analysis
Exercise 3: Reward Shaping Experiment (Intermediate)
Objective: Compare different reward configurations.
"""
Exercise 3: Reward Shaping
--------------------------
TODO: Experiment with different reward configurations
1. Implement three reward variants:
a) Safety-only rewards
b) Safety + Navigation rewards
c) Full hierarchical rewards (Safety + Nav + Behavior)
2. Train each for 1M steps
3. Compare learning curves and final performance
"""
reward_configs = {
'safety_only': {
# TODO: Define safety-only rewards
},
'safety_nav': {
# TODO: Define safety + navigation rewards
},
'full': {
# TODO: Define full hierarchical rewards
}
}
def train_with_reward(reward_config, num_steps=1_000_000):
"""
TODO: Implement training loop with specified reward config
Return learning curves (collision_rate, progress, v_max_score)
"""
pass
def plot_comparison(results):
"""
TODO: Plot learning curves for all three configurations
"""
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
metrics = ['collision_rate', 'progress', 'v_max_score']
for ax, metric in zip(axes, metrics):
for name, data in results.items():
ax.plot(data['steps'], data[metric], label=name)
ax.set_title(metric)
ax.legend()
plt.tight_layout()
plt.savefig('reward_comparison.png')
Expected Results:
After 1M steps:
| Metric | Safety Only | Safety+Nav | Full |
|---|---|---|---|
| Collision | 5.2% | 3.1% | 2.3% |
| Progress | 45.2m | 89.3m | 82.1m |
| V-Max Score | 0.65 | 0.78 | 0.82 |
Exercise 4: Encoder Architecture Comparison (Intermediate)
Objective: Implement and compare encoder architectures.
"""
Exercise 4: Encoder Comparison
------------------------------
TODO: Implement simplified versions of each encoder and compare
1. Implement MLP baseline
2. Implement simplified Latent Query
3. Train both for 5M steps
4. Analyze performance differences
"""
import flax.linen as nn
class MLPEncoder(nn.Module):
"""
TODO: Implement MLP baseline encoder
Architecture:
- Flatten all inputs
- 3 hidden layers [256, 256, 256]
- ReLU activations
- Output: 256-dim embedding
"""
@nn.compact
def __call__(self, obs):
# Your code here
pass
class SimpleLatentQuery(nn.Module):
"""
TODO: Implement simplified Latent Query encoder
Architecture:
- Linear projection of inputs
- 2 cross-attention layers
- 8 learnable latent vectors
- Mean pooling
"""
n_latents: int = 8
d_model: int = 256
@nn.compact
def __call__(self, obs):
# Your code here
pass
def compare_encoders(encoders, num_steps=5_000_000):
"""
TODO: Train each encoder and collect metrics
"""
results = {}
for name, encoder_cls in encoders.items():
print(f"Training {name}...")
# Training code here
pass
return results
Exercise 5: SAC Hyperparameter Tuning (Advanced)
Objective: Understand SAC hyperparameters through systematic tuning.
"""
Exercise 5: SAC Hyperparameter Tuning
-------------------------------------
TODO: Systematically tune SAC hyperparameters
Key hyperparameters to explore:
1. Learning rates (actor, critic, alpha)
2. Batch size
3. Updates per environment step
4. Target entropy
5. Replay buffer size
"""
import optuna
def objective(trial):
"""
TODO: Implement Optuna objective for hyperparameter tuning
"""
# Sample hyperparameters
config = {
'actor_lr': trial.suggest_float('actor_lr', 1e-5, 1e-3, log=True),
'critic_lr': trial.suggest_float('critic_lr', 1e-5, 1e-3, log=True),
'batch_size': trial.suggest_categorical('batch_size', [64, 128, 256, 512]),
'updates_per_step': trial.suggest_int('updates_per_step', 1, 8),
'target_entropy': trial.suggest_float('target_entropy', -4.0, -1.0),
'buffer_size': trial.suggest_categorical('buffer_size', [100000, 500000, 1000000]),
}
# Train for limited steps
# TODO: Implement training
# Return validation metric
# return v_max_score
pass
# Run hyperparameter search
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
print("Best hyperparameters:", study.best_params)
print("Best V-Max score:", study.best_value)
Expected Results:
Best hyperparameters:
actor_lr: 3.2e-4
critic_lr: 2.8e-4
batch_size: 256
updates_per_step: 4
target_entropy: -2.0
buffer_size: 1000000
Best V-Max score: 0.89
Exercise 6: Multi-Dataset Training (Advanced)
Objective: Train on combined datasets and analyze generalization.
"""
Exercise 6: Multi-Dataset Training
----------------------------------
TODO: Implement cross-dataset training and evaluate transfer
1. Load WOMD and nuPlan datasets
2. Create combined dataloader
3. Train policy on combined data
4. Evaluate on each dataset separately
5. Compare to single-dataset training
"""
class MultiDatasetLoader:
"""
TODO: Implement a dataloader that samples from multiple datasets
Features:
- Configurable sampling ratios
- Dataset balancing
- Domain labels for analysis
"""
def __init__(self, datasets, ratios=None):
self.datasets = datasets
self.ratios = ratios or {name: 1.0 for name in datasets}
# Your code here
def sample(self, batch_size):
"""
TODO: Sample a batch from combined datasets
Return batch and domain labels
"""
pass
def evaluate_transfer(policy, datasets):
"""
TODO: Evaluate policy on each dataset and return metrics
"""
results = {}
for name, dataset in datasets.items():
# Evaluate
metrics = evaluate(policy, dataset)
results[name] = metrics
return results
def create_transfer_matrix(results):
"""
TODO: Create a transfer matrix showing:
- Rows: Training dataset
- Columns: Evaluation dataset
- Values: V-Max score
"""
pass
11. Interview Questions
Conceptual Questions
Q1: Why does V-Max use a POMDP formulation instead of MDP?
<details> <summary>Click to reveal answer</summary>

Answer: Autonomous driving is inherently partially observable because:
- Sensor limitations: The ego vehicle cannot observe the full state of the world - occluded vehicles, pedestrians behind buildings, etc.
- Intentions are hidden: Other agents' goals and planned trajectories are not directly observable.
- Map uncertainty: The exact lane boundaries and traffic signal states may be ambiguous.
In the POMDP formulation:
- State S: Full world state (positions, velocities, intentions of all agents)
- Observation O: What ego can perceive (BEV features within sensor range)
- Belief: The policy must maintain implicit beliefs about hidden state
This is why transformer encoders with attention over historical observations work well - they can learn to aggregate temporal information to estimate hidden state.
</details>

Q2: Explain why SAC outperforms PPO in V-Max despite PPO often being preferred for robotics.
<details> <summary>Click to reveal answer</summary>

Answer: Several factors favor SAC in the V-Max setting:
- Sample Efficiency: SAC's off-policy learning with a replay buffer achieves good performance in 25M steps vs PPO's 200M steps. This is crucial when simulation is expensive.
- Exploration via Entropy: SAC's entropy maximization encourages diverse driving behaviors, helping discover safe trajectories in complex scenarios.
- Continuous Action Space: SAC was designed for continuous control (acceleration, steering), while PPO's clipping can be sensitive in high-dimensional continuous spaces.
- Deterministic Evaluation: SAC can use a deterministic policy at test time, reducing variance in safety-critical scenarios.
- Dense Rewards: V-Max's hierarchical rewards provide dense feedback, which off-policy methods like SAC utilize efficiently.
PPO might be preferred when:
- Sample efficiency matters less than stability
- Distributed training across many workers is needed
- The reward signal is very sparse
</details>

Q3: Why is the hierarchical reward structure (Safety > Navigation > Behavior) important?
<details> <summary>Click to reveal answer</summary>

Answer: The hierarchy enforces a priority ordering that mirrors human driving values:
- Safety as Hard Constraint: Collisions and traffic violations are never acceptable, regardless of navigation goals. By making safety rewards dominant, the policy learns to prioritize survival.
- Navigation Enables Learning: Without navigation rewards, a policy might learn to "play it safe" by never moving. Navigation rewards ensure the agent attempts the driving task.
- Behavior Refinement: Comfort and compliance only matter once safety and navigation are achieved. Adding behavior rewards too early can cause reward hacking.
Mathematical Insight: The hierarchy can be seen as lexicographic optimization:
minimize collision_rate
subject to: progress > threshold
comfort_score > minimum
This is approximated by the weighted sum with large safety penalties.
Practical Impact (from Table 3):
- Safety-only: Low collision but no progress
- +Navigation: Progress improves significantly
- +Behavior: Collision rate decreases further while maintaining progress
</details>

Technical Questions
Q4: How does the Latent Query encoder differ from standard transformers, and why does it perform best?
<details> <summary>Click to reveal answer</summary>

Answer: Key differences:
Standard Transformer:
Input -> Self-Attention -> ... -> Output
[N x D] -> [N x D] -> ... -> [N x D]
Output size depends on input size.
Latent Query:
Input -> Cross-Attention(Q=Latents, K,V=Input) -> Self-Attention(Latents) -> Mean Pool
[N x D] -> [K x D] -> [K x D] -> [D]
Output is fixed size regardless of input.
Why it works best:
- Information Bottleneck: The K learnable latents (K=16) force compression of the N input features, learning to extract what's relevant for driving.
- Adaptive Queries: Unlike fixed pooling, the learnable queries adapt to find task-relevant information.
- Computational Efficiency: O(N*K) cross-attention instead of O(N^2) self-attention when N is large.
- Permutation Invariance: Order of input features doesn't matter (desirable for sets of objects).
- Stable Gradients: The bottleneck prevents vanishing gradients in deep networks.
</details>

Q5: Describe how ScenarioMax reconstructs SDC paths for datasets that don't provide them.
<details> <summary>Click to reveal answer</summary>

Answer: SDC path reconstruction involves:
Step 1: Lane Graph Extraction
# Extract lane connectivity from HD map
lane_graph = {
lane_id: {
'centerline': [(x, y), ...],
'successors': [lane_id, ...],
'predecessors': [lane_id, ...]
}
}
Step 2: Find Starting Lanes
# Find lanes near ego's initial position
ego_start = scenario.ego_position[t=0]
candidate_lanes = []
for lane_id, lane in lane_graph.items():
dist = distance_to_centerline(ego_start, lane['centerline'])
if dist < threshold:
candidate_lanes.append(lane_id)
Step 3: BFS Route Generation
# Generate 10 candidate routes via BFS
routes = []
for start_lane in candidate_lanes:
queue = [(start_lane, [start_lane])]
while queue and len(routes) < 10:
current, path = queue.pop(0)
if path_length(path) > min_length:
routes.append(path)
for successor in lane_graph[current]['successors']:
queue.append((successor, path + [successor]))
Step 4: Route Scoring
# Score routes by driving quality metrics
def score_route(route):
return (
-num_lane_changes(route) * 0.3
-num_turns(route) * 0.2
+total_length(route) * 0.1
-curvature(route) * 0.1
)
best_routes = sorted(routes, key=score_route)[:10]
Step 5: Interpolation
# Interpolate to 500 points at regular intervals
sdc_path = interpolate(best_route, num_points=500, spacing=0.5)
</details>
Q6: How does V-Max handle the distribution shift between simulation and real-world driving?
<details> <summary>Click to reveal answer</summary>

Answer: V-Max addresses distribution shift through several mechanisms:
1. Diverse Datasets (ScenarioMax)
Distribution Coverage:
WOMD: US highways, suburban
nuPlan: US urban, complex intersections
Combined: Broader coverage → better generalization
2. Reactive Evaluation
Non-Reactive: Other agents replay logs (unrealistic)
Reactive: Other agents use IDM (closer to real)
Adversarial: Stress testing with aggressive agents
3. Domain Randomization (implicit)
# Observation noise
obs += normal(0, 0.1) # Position uncertainty
# Scenario selection
scenarios = sample_diverse(dataset) # Weather, lighting, traffic
4. Robust Reward Design
# Safety margins ensure buffer for real-world uncertainty
collision_threshold = 0.5 # meters (conservative)
ttc_threshold = 0.95 # seconds (conservative)
5. Cross-Dataset Training Results
Training | Test WOMD | Test nuPlan | Avg
----------|-----------|-------------|-----
WOMD only | 95.97% | 89.23% | 92.60%
Combined | 95.38% | 95.54% | 95.46%
→ Combined training improves robustness by ~3%
Remaining Gap: Simulation still differs from real-world in:
- Sensor noise characteristics
- Long-tail events
- Human driver unpredictability
- Weather conditions
These require sim-to-real transfer techniques beyond V-Max's scope.
</details>

System Design Questions
Q7: You need to deploy V-Max-trained policies on a real vehicle. What are the key considerations?
<details> <summary>Click to reveal answer</summary>

Answer: Key considerations for deployment:
1. Latency Requirements
Simulation: 100ms inference OK (10Hz)
Real-time: 20-50ms required for safety
Solutions:
- Model distillation (large → small)
- TensorRT/XLA compilation
- Batching disabled (single scenario)
2. Observation Pipeline
Simulation: Perfect BEV features from Waymax
Real-world: Must process raw sensor data
Pipeline:
LiDAR/Camera → Perception → Tracking → Feature Extraction → Policy
↑ ↑
Need ML models here Match V-Max format
3. Action Space Mapping
V-Max output: [acceleration, steering] (continuous)
Vehicle input: Throttle, brake, steering angle
Conversion:
if acceleration > 0:
throttle = map_to_throttle(acceleration)
brake = 0
else:
throttle = 0
brake = map_to_brake(-acceleration)
steering_angle = steering * max_steering_angle
4. Safety Wrapper
class SafetyWrapper:
def __call__(self, action, state):
# Hard limits
action = clip(action, safety_bounds)
# Emergency override
if ttc < 0.5:
action = emergency_brake()
# Sanity check
if not is_physically_plausible(action, state):
action = fallback_action()
return action
5. Monitoring and Fallback
Primary: V-Max policy
Fallback 1: Rule-based PDM
Fallback 2: Emergency stop
Fallback 3: Human takeover alert
Trigger conditions:
- Policy uncertainty > threshold
- Perception confidence < threshold
- Out-of-distribution detection
6. Sim-to-Real Validation
Stage 1: Closed track testing
Stage 2: Shadow mode (policy suggests, human drives)
Stage 3: Limited ODD (Operational Design Domain)
Stage 4: Expanded ODD
</details>
Q8: Design an A/B testing framework to compare V-Max policies in simulation.
<details> <summary>Click to reveal answer</summary>

Answer: A/B testing framework design:
1. Experiment Configuration
@dataclass
class ABTestConfig:
name: str
policy_a: PolicyConfig # Control
policy_b: PolicyConfig # Treatment
# Scenario allocation
scenario_split: str = "random" # or "stratified"
split_ratio: float = 0.5
# Metrics
primary_metric: str = "v_max_score"
secondary_metrics: list = field(default_factory=lambda: [
"collision_rate", "progress", "comfort"
])
# Statistical settings
min_scenarios: int = 1000
confidence_level: float = 0.95
power: float = 0.80
2. Scenario Stratification
def stratified_split(scenarios, config):
"""
Ensure balanced distribution of:
- Scenario difficulty (easy/medium/hard)
- Scenario type (straight/turn/intersection)
- Traffic density
"""
strata = {}
for scenario in scenarios:
key = (
get_difficulty(scenario),
get_type(scenario),
get_density_bucket(scenario)
)
strata.setdefault(key, []).append(scenario)
split_a, split_b = [], []
for key, group in strata.items():
random.shuffle(group)
mid = len(group) // 2
split_a.extend(group[:mid])
split_b.extend(group[mid:])
return split_a, split_b
3. Evaluation Pipeline
class ABTestRunner:
def run(self, config):
# Parallel evaluation
results_a = self.evaluate(config.policy_a, scenarios_a)
results_b = self.evaluate(config.policy_b, scenarios_b)
# Statistical analysis
analysis = self.analyze(results_a, results_b, config)
return ABTestReport(
config=config,
results_a=results_a,
results_b=results_b,
analysis=analysis
)
4. Statistical Analysis
def analyze(results_a, results_b, config):
# Primary metric comparison
metric_a = results_a[config.primary_metric]
metric_b = results_b[config.primary_metric]
# Bootstrap confidence intervals
ci_a = bootstrap_ci(metric_a, config.confidence_level)
ci_b = bootstrap_ci(metric_b, config.confidence_level)
# Effect size
effect_size = cohens_d(metric_a, metric_b)
# Statistical significance
p_value = permutation_test(metric_a, metric_b)
# Decision
if p_value < (1 - config.confidence_level):
if mean(metric_b) > mean(metric_a):
decision = "B wins"
else:
decision = "A wins"
else:
decision = "No significant difference"
return {
'ci_a': ci_a,
'ci_b': ci_b,
'effect_size': effect_size,
'p_value': p_value,
'decision': decision
}
5. Report Generation
A/B Test Report: SAC_v2 vs SAC_v1
=================================
Primary Metric: V-Max Score
Control (A): 0.887 [0.882, 0.892]
Treatment (B): 0.901 [0.896, 0.906]
Effect Size: +1.6% (Cohen's d = 0.34)
P-value: 0.0023
Decision: B wins (statistically significant)
Secondary Metrics:
Collision Rate: A=0.89%, B=0.72% (p=0.01)
Progress: A=148m, B=151m (p=0.12)
Comfort: A=0.89, B=0.91 (p=0.03)
Recommendation: Deploy policy B
</details>
Q9: How would you implement continuous training for V-Max as new scenarios become available?
<details> <summary>Click to reveal answer</summary>

Answer: Continuous training architecture:
1. Data Pipeline
+------------------+ +------------------+ +------------------+
| New Scenarios | --> | Scenario Curator | --> | Training Queue |
| (Daily upload) | | (Filter/Label) | | (Priority Queue) |
+------------------+ +------------------+ +------------------+
|
+------------------+ +------------------+ |
| Model Registry | <-- | Training Job | <-----------+
| (Versioned) | | (GPU Cluster) |
+------------------+ +------------------+
|
v
+------------------+ +------------------+
| Eval Pipeline | --> | Deployment |
| (A/B Test) | | (If improved) |
+------------------+ +------------------+
2. Scenario Prioritization
class ScenarioCurator:
def score_scenario(self, scenario):
"""Prioritize scenarios that are:
1. Novel (different from existing data)
2. Challenging (current policy struggles)
3. Safety-critical (near-misses)
"""
novelty = self.compute_novelty(scenario)
difficulty = self.evaluate_difficulty(scenario)
safety_relevance = self.assess_safety(scenario)
return (
0.3 * novelty +
0.4 * difficulty +
0.3 * safety_relevance
)
3. Continual Learning Strategy
class ContinualTrainer:
def __init__(self, base_policy, replay_buffer):
self.policy = base_policy
self.buffer = replay_buffer # Experience replay
self.old_scenarios = OldScenarioBuffer() # Prevent forgetting
def train_step(self, new_batch):
# Mix new and old data to prevent catastrophic forgetting
old_batch = self.old_scenarios.sample(len(new_batch) // 2)
mixed_batch = concatenate(new_batch, old_batch)
# Standard SAC update
self.policy.update(mixed_batch)
# Elastic Weight Consolidation (optional)
ewc_loss = self.compute_ewc_loss()
self.policy.apply_ewc_regularization(ewc_loss)
4. Automatic Evaluation
class ContinuousEval:
def evaluate_checkpoint(self, checkpoint):
# Evaluate on held-out test set
test_metrics = self.eval_on_test_set(checkpoint)
# Evaluate on new scenarios
new_metrics = self.eval_on_new_scenarios(checkpoint)
# Compare to production policy
prod_metrics = self.get_production_metrics()
return {
'test': test_metrics,
'new': new_metrics,
'vs_prod': compare(test_metrics, prod_metrics)
}
5. Deployment Decision
def should_deploy(eval_results, config):
"""Deploy if:
1. No regression on test set (>= -0.5% V-Max)
2. Improvement on new scenarios (>= +1% V-Max)
3. No safety regression (collision rate <= prod)
"""
return (
eval_results['vs_prod']['v_max_delta'] >= -0.005 and
eval_results['new']['v_max_score'] >= config.new_threshold and
eval_results['test']['collision_rate'] <= config.max_collision_rate
)
6. Monitoring and Rollback
class DeploymentMonitor:
def monitor(self, deployed_policy):
metrics = collect_production_metrics()
if metrics['collision_rate'] > self.alert_threshold:
self.alert("Collision rate elevated!")
if metrics['v_max_score'] < self.rollback_threshold:
self.rollback_to_previous_version()
self.alert("Rolled back due to degraded performance")
</details>
Debugging Questions
Q10: Your V-Max training run shows high collision rate (>10%) after 10M steps. How do you debug?
<details> <summary>Click to reveal answer</summary>

Answer: Systematic debugging approach:
Step 1: Check Reward Signal
# Log reward components
def debug_rewards(transitions):
safety_rewards = [t.reward_info['safety'] for t in transitions]
nav_rewards = [t.reward_info['navigation'] for t in transitions]
print(f"Safety reward mean: {np.mean(safety_rewards):.3f}")
print(f"Navigation reward mean: {np.mean(nav_rewards):.3f}")
# Check if safety penalties are being applied
collision_penalties = [r for r in safety_rewards if r < -0.5]
print(f"Collision penalties: {len(collision_penalties)}/{len(safety_rewards)}")
# Verify reward scale
print(f"Reward range: [{min(safety_rewards)}, {max(safety_rewards)}]")
Step 2: Analyze Collision Types
def analyze_collisions(episodes):
collision_types = {
'rear_end': 0,
'side_swipe': 0,
'head_on': 0,
'pedestrian': 0,
'stopped_vehicle': 0
}
for ep in episodes:
if ep.had_collision:
collision_types[classify_collision(ep)] += 1
print("Collision breakdown:", collision_types)
# Most common collision scenarios
# → If rear_end dominant: braking policy issue
# → If side_swipe dominant: lane keeping issue
Step 3: Check Observation Quality
def validate_observations(obs_batch):
# Check for NaN/Inf
for key, value in flatten(obs_batch).items():
if np.any(np.isnan(value)):
print(f"NaN found in {key}")
if np.any(np.isinf(value)):
print(f"Inf found in {key}")
# Check feature scales
for key, value in flatten(obs_batch).items():
print(f"{key}: mean={np.mean(value):.3f}, std={np.std(value):.3f}")
Step 4: Examine Policy Behavior
def visualize_policy_failures(policy, scenarios):
for scenario in sample_collision_scenarios(scenarios, n=10):
# Render scenario
render_scenario(scenario)
# Plot policy actions
actions = []
for t in range(90):
obs = get_observation(scenario, t)
action = policy.get_action(obs)
actions.append(action)
plot_action_sequence(actions)
# Check for erratic behavior
action_changes = np.diff(actions, axis=0)
print(f"Action jerk: {np.std(action_changes):.3f}")
Step 5: Validate Training Pipeline
def validate_training():
# Check gradient norms
grad_norms = training_logs['grad_norms']
if np.mean(grad_norms[-1000:]) > 100:
print("Gradient explosion detected")
if np.mean(grad_norms[-1000:]) < 0.001:
print("Gradient vanishing detected")
# Check Q-value estimates
q_values = training_logs['q_values']
if np.mean(q_values[-1000:]) > 1000:
print("Q-value explosion - reduce learning rate")
# Check replay buffer
buffer_rewards = replay_buffer.get_all_rewards()
print(f"Buffer reward distribution: {np.percentile(buffer_rewards, [25, 50, 75])}")
Common Fixes:
| Issue | Symptom | Fix |
|---|---|---|
| Reward scale | Q-values explode | Normalize rewards to [-1, 1] |
| Observation NaN | Policy outputs NaN | Add input validation |
| Sparse collisions | Policy doesn't learn | Increase collision penalty |
| Gradient explosion | Training unstable | Reduce learning rate, clip gradients |
| Exploration | Same collision repeatedly | Increase entropy, reset buffer |
</details>

Summary
V-Max represents a significant step forward in making reinforcement learning practical for autonomous driving. Key takeaways:
- Infrastructure Matters: JAX-based unified computation enables efficient training
- Multi-Dataset Support: ScenarioMax solves the data fragmentation problem
- Hierarchical Rewards: Safety > Navigation > Behavior priority is crucial
- Encoder Architecture: Latent Query's bottleneck compression works best
- Off-Policy RL: SAC achieves expert-level performance in 25M steps
For ML infrastructure engineers, V-Max demonstrates how to build a complete RL pipeline that's:
- Efficient (4,609 steps/sec)
- Modular (YAML-configurable observations)
- Extensible (pluggable encoders and rewards)
- Reproducible (comprehensive benchmarks)
Further Reading
- Waymax Paper - The simulation foundation
- SAC Paper - The RL algorithm
- Motion Transformer - Encoder inspiration
- nuPlan Benchmark - Evaluation metrics
Last updated: January 2026