Deep Dive #4 · 50 min read

V-Max Framework Deep Dive

Complete RL training pipeline on top of Waymax including ScenarioMax, observation design, and reward hierarchy.

V-Max Deep Dive: Reinforcement Learning for Autonomous Driving on Waymax

Paper: V-Max: Making Reinforcement Learning Work for Autonomous Driving
Authors: Valeo AI Research Team
ArXiv: 2503.08388
Code: github.com/valeoai/v-max


Table of Contents

  1. Executive Summary
  2. Architecture Deep Dive
  3. ScenarioMax: Multi-Dataset Support
  4. Observation Functions
  5. Reward Design
  6. Encoder Architectures
  7. RL Training Pipeline
  8. Interactive Code Examples
  9. Benchmarks Analysis
  10. Hands-On Exercises
  11. Interview Questions

1. Executive Summary

The Problem: RL for Autonomous Driving is Hard

Reinforcement learning has shown remarkable success in games and robotics, but autonomous driving presents unique challenges:

  • Distribution shift: Simulation differs from real-world driving
  • Sparse rewards: Crashes are rare but catastrophic
  • High-dimensional observations: Processing LiDAR, cameras, and maps
  • Safety constraints: Cannot explore freely like in games

Why V-Max Extends Waymax

Waymax (Waymo Research) provides a JAX-accelerated simulation environment but lacks a complete RL training stack:

+------------------+     V-Max Adds:     +----------------------+
|    Waymax        | ------------------> |  Complete RL Stack   |
+------------------+                     +----------------------+
| - Simulation     |                     | - Observation Funcs  |
| - WOMD Dataset   |                     | - Reward Hierarchy   |
| - Basic Metrics  |                     | - Transformer Enc.   |
| - JAX Accel.     |                     | - SAC/PPO Pipelines  |
+------------------+                     | - Multi-Dataset      |
                                         | - ScenarioMax        |
                                         +----------------------+

Key Contributions

Component              | What It Solves
---------------------- | ----------------------------------------------------
ScenarioMax            | Converts nuPlan, Argoverse 2, WOMD to unified format
Observation Functions  | Configurable feature extraction (YAML-based)
Hierarchical Rewards   | Safety > Navigation > Behavior priority
Transformer Encoders   | 4 architectures: LQ, LQH, MTR, Wayformer
RL Pipelines           | Production-ready SAC and PPO in JAX

The Bottom Line

V-Max achieves 97.86% accuracy with SAC, matching expert-level driving in simulation, with a 0.89% collision rate. This is a large improvement over baselines: the MLP encoder baseline in Section 6 reaches only 68.12% accuracy with a 15.2% collision rate.


2. Architecture Deep Dive

System Overview

V-Max models autonomous driving as a Partially Observable Markov Decision Process (POMDP):

                    POMDP Formulation
    +------------------------------------------------+
    |                                                |
    |  State Space S: Full world state (hidden)      |
    |  Observation O: BEV features (partial view)    |
    |  Action Space A: [acceleration, steering]      |
    |  Transition T: Waymax physics simulation       |
    |  Reward R: Hierarchical (Safety/Nav/Behavior)  |
    |                                                |
    +------------------------------------------------+

Data Flow Architecture

+-------------+    +---------------+    +------------------+
|   Waymax    |    | Observation   |    |    Encoder       |
|  Simulator  | -> |   Functions   | -> |  (Transformer)   |
+-------------+    +---------------+    +------------------+
      |                   |                      |
      v                   v                      v
+-------------+    +---------------+    +------------------+
| 9s Scenario |    | - Trajectory  |    | Latent Query     |
| @ 10Hz      |    | - Roadgraph   |    | or MTR/Wayformer |
| (90 steps)  |    | - Traffic     |    | Latent Vector    |
+-------------+    | - Path Target |    +------------------+
                   +---------------+             |
                                                 v
+-------------+    +---------------+    +------------------+
|   Reward    | <- |    Policy     | <- |   Actor-Critic   |
|  Hierarchy  |    |   Network     |    |    Networks      |
+-------------+    +---------------+    +------------------+
      |                   |                      |
      +-------------------+----------------------+
                          |
                          v
                 +------------------+
                 |  RL Algorithm    |
                 |  (SAC or PPO)    |
                 +------------------+

JAX Integration Architecture

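The key insight is a unified computation graph: simulation, observation extraction, and training are compiled together in JAX, which avoids the CPU/GPU transfers between stages that slow down a traditional pipeline: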

Traditional RL Pipeline:
+----------+   CPU/GPU    +----------+   CPU/GPU    +----------+
|Simulation| <---------> |  Python  | <---------> | Training |
+----------+   Transfer   +----------+   Transfer   +----------+
                              Slow!

V-Max Pipeline:
+----------------------------------------------------------+
|                    Single JAX Graph                       |
|  +----------+    +-------------+    +----------+         |
|  |Simulation| -> | Observation | -> | Training |         |
|  | (Waymax) |    |  Functions  |    | (Brax)   |         |
|  +----------+    +-------------+    +----------+         |
|                                                          |
|  All operations compiled together - 4,609 steps/sec      |
+----------------------------------------------------------+

Action Space Design

V-Max uses continuous control with bounded actions:

# Action Space Configuration
action_space = {
    'acceleration': {
        'range': [-4.0, 2.0],  # m/s^2 (harder braking allowed)
        'dtype': 'float32'
    },
    'steering': {
        'range': [-0.3, 0.3],  # radians (approx +/- 17 degrees)
        'dtype': 'float32'
    }
}

# Applied at 10Hz for 9 seconds = 90 timesteps per episode
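SAC-style actors usually emit tanh-squashed outputs in [-1, 1], which then have to be mapped onto these bounds. The sketch below assumes a simple affine rescale per dimension; the function names are hypothetical and V-Max's own mapping may differ.

```python
import math

# Bounds copied from the action-space configuration above.
ACCEL_RANGE = (-4.0, 2.0)   # m/s^2
STEER_RANGE = (-0.3, 0.3)   # rad

def rescale(u, low, high):
    """Affinely map u in [-1, 1] to [low, high], clipping u first."""
    u = max(-1.0, min(1.0, u))  # guard against values slightly outside [-1, 1]
    return low + (u + 1.0) * 0.5 * (high - low)

def to_env_action(raw):
    """raw = (accel_u, steer_u), e.g. tanh outputs of an actor, each in [-1, 1]."""
    accel_u, steer_u = raw
    return rescale(accel_u, *ACCEL_RANGE), rescale(steer_u, *STEER_RANGE)

print(to_env_action((0.0, 0.0)))     # (-1.0, 0.0)
print(math.degrees(STEER_RANGE[1]))  # steering limit in degrees, about 17.2
```

Because the acceleration range is asymmetric, a zero policy output maps to -1.0 m/s^2 of braking rather than to zero; a common alternative is to scale the negative and positive halves separately.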

3. ScenarioMax: Multi-Dataset Support

The Dataset Fragmentation Problem

Before ScenarioMax:
                    +--------+
                    | Waymax |
                    +--------+
                        |
                        v
                    +--------+
                    |  WOMD  |  <-- Only supported dataset
                    +--------+

After ScenarioMax:
                    +--------+
                    | Waymax |
                    +--------+
                        |
                        v
                 +-------------+
                 | ScenarioMax |
                 +-------------+
                   /    |    \
                  v     v     v
            +------+ +------+ +------+
            | WOMD | |nuPlan| | AV2  |
            +------+ +------+ +------+

Dataset Comparison

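Feature      | WOMD           | nuPlan                     | Argoverse 2
------------ | -------------- | -------------------------- | --------------
Scenarios    | 487K           | 1.3M                       | 250K
Location     | US (6 cities)  | US + Singapore (4 cities)  | US (6 cities)
Duration     | 9.1 s          | 15 s                       | 11 s
Frequency    | 10 Hz          | 10 Hz                      | 10 Hz
Map Format   | Roadgraph      | HD Vector                  | HD Vector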

SDC Path Reconstruction

A critical contribution: generating self-driving car (SDC) paths for datasets that lack them:

Original Data (nuPlan/AV2):
+-------+      +-------+      +-------+
| Start | ---> |  ???  | ---> |  End  |
+-------+      +-------+      +-------+
               No path defined

ScenarioMax Path Generation:
Step 1: Extract lane topology from HD map
Step 2: Find lanes near ego start position
Step 3: BFS traversal to generate 10 candidate routes
Step 4: Score routes by distance, lane changes, and turns
Step 5: Select optimal SDC path

Result:
+-------+    +-------+    +-------+    +-------+
| Start | -> | Lane1 | -> | Lane2 | -> |  End  |
+-------+    +-------+    +-------+    +-------+
             ^                         ^
             |   SDC Path (10 routes)  |
             +-------------------------+
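The five steps above can be sketched as a breadth-first search over a lane graph. This is a minimal illustration, not ScenarioMax's code: the graph encoding (successors tagged "follow", "left", or "right"), the 100 m target length, and the scoring weights are all assumptions.

```python
from collections import deque

def generate_routes(successors, lane_length, start_lanes, target_length=100.0, max_routes=10):
    """Breadth-first enumeration of candidate lane sequences (hypothetical graph format).

    successors: lane_id -> list of (next_lane_id, kind), kind in {"follow", "left", "right"}
    lane_length: lane_id -> length in meters
    A route is complete once it covers target_length meters or reaches a dead end.
    Lateral moves ("left"/"right") add no distance, a deliberate simplification.
    """
    routes = []
    queue = deque(([lane], lane_length[lane], []) for lane in start_lanes)
    while queue and len(routes) < max_routes:
        path, length, kinds = queue.popleft()
        nexts = [(n, k) for n, k in successors.get(path[-1], []) if n not in path]  # no cycles
        if length >= target_length or not nexts:
            routes.append((path, length, kinds))
            continue
        for nxt, kind in nexts:
            step = lane_length[nxt] if kind == "follow" else 0.0
            queue.append((path + [nxt], length + step, kinds + [kind]))
    return routes

def score_route(route, turn_lanes, target_length=100.0, w_change=1.0, w_turn=0.5, w_short=0.05):
    """Lower is better: penalize lane changes, turning lanes, and falling short of target_length."""
    path, length, kinds = route
    lane_changes = sum(k != "follow" for k in kinds)
    turns = sum(lane in turn_lanes for lane in path)
    shortfall = max(0.0, target_length - length)
    return w_change * lane_changes + w_turn * turns + w_short * shortfall

# Toy lane graph: A -> B -> {C straight, R right turn}; A also has a left neighbor L -> M
successors = {"A": [("B", "follow"), ("L", "left")],
              "B": [("C", "follow"), ("R", "follow")],
              "L": [("M", "follow")]}
lane_length = {"A": 50.0, "B": 30.0, "C": 40.0, "R": 40.0, "L": 50.0, "M": 60.0}
routes = generate_routes(successors, lane_length, ["A"])
best = min(routes, key=lambda r: score_route(r, turn_lanes={"R"}))
print(best[0])  # ['A', 'B', 'C']
```

Here the straight route wins because the alternatives either use a turning lane (R) or require a lane change (L).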

TfRecord Schema

# ScenarioMax unified schema: feature name -> (dtype, shape)
scenario_proto = {
    # Core Waymax fields
    'scenario_id': ('string', ()),
    'timestamps': ('float32', (91,)),              # 9.1s at 10Hz

    # Trajectory data
    'state/x': ('float32', (128, 91)),             # 128 objects, 91 timesteps
    'state/y': ('float32', (128, 91)),
    'state/heading': ('float32', (128, 91)),
    'state/velocity_x': ('float32', (128, 91)),
    'state/velocity_y': ('float32', (128, 91)),
    'state/valid': ('bool', (128, 91)),

    # Roadgraph (sampled every 2m)
    'roadgraph/xyz': ('float32', (20000, 3)),
    'roadgraph/type': ('int32', (20000,)),
    'roadgraph/dir': ('float32', (20000, 3)),      # V-Max addition

    # SDC Paths (V-Max addition)
    'sdc_paths/xy': ('float32', (10, 500, 2)),     # 10 paths, 500 points each
    'sdc_paths/valid': ('bool', (10, 500)),

    # Traffic lights
    'traffic_light/state': ('int32', (16, 91)),
    'traffic_light/lane_ids': ('int64', (16,)),
}

Cross-Dataset Training Results

Training on combined datasets improves generalization:

Training Data  | Test on WOMD | Test on nuPlan
-------------- | ------------ | --------------
WOMD only      | 95.97%       | 89.23%
nuPlan only    | 91.45%       | 96.12%
Combined       | 95.38%       | 95.54%

Key Insight: Compared with WOMD-only training, combined training gives up about 0.6 percentage points on WOMD (95.97% to 95.38%) but gains about 6.3 points on nuPlan (89.23% to 95.54%).


4. Observation Functions

Feature Extraction Pipeline

+------------------+
|   Waymax State   |
+------------------+
         |
         v
+------------------+     +------------------+
| Feature Selector | --> | Ego-Centric      |
| (YAML Config)    |     | Transformation   |
+------------------+     +------------------+
         |                        |
         v                        v
+------------------+     +------------------+
| Trajectory Feat. |     | Normalize to     |
| Roadgraph Feat.  |     | Ego Frame        |
| Traffic Feat.    |     +------------------+
| Path Target Feat.|              |
+------------------+              |
         |                        |
         +------------------------+
                   |
                   v
         +------------------+
         | Observation Dict |
         | (Ready for Enc.) |
         +------------------+
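The ego-centric transformation amounts to a translation followed by a rotation. A minimal sketch, assuming headings are measured counter-clockwise from the global +x axis and the ego frame has x pointing forward and y to the left (function names are hypothetical):

```python
import math

def to_ego_frame(points, ego_xy, ego_heading):
    """Express global (x, y) points in the ego frame (x forward, y left).

    Translate by -ego_xy, then rotate by -ego_heading (radians, CCW from global +x).
    """
    ex, ey = ego_xy
    c, s = math.cos(ego_heading), math.sin(ego_heading)
    out = []
    for x, y in points:
        dx, dy = x - ex, y - ey
        out.append((c * dx + s * dy, -s * dx + c * dy))
    return out

def heading_to_ego_frame(heading, ego_heading):
    """Yaw relative to the ego, wrapped to [-pi, pi)."""
    return (heading - ego_heading + math.pi) % (2 * math.pi) - math.pi

# Ego at (10, 5) facing global +y: (10, 8) is 3 m ahead, (7, 5) is 3 m to the left
print(to_ego_frame([(10.0, 8.0), (7.0, 5.0)], (10.0, 5.0), math.pi / 2))
# approximately [(3.0, 0.0), (0.0, 3.0)]
```

Velocity vectors need only the rotation, not the translation.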

Feature Categories Detailed

1. Trajectory Features

# Default configuration
N_obj = 16  # Closest 16 objects
T = 5       # 5 historical timesteps (0.5 seconds at 10 Hz)

# Feature name -> (dtype, shape)
trajectory_features = {
    # Per-object, per-timestep
    'xy': ('float32', (N_obj, T, 2)),        # Position relative to ego
    'heading': ('float32', (N_obj, T, 1)),   # Yaw angle
    'velocity': ('float32', (N_obj, T, 2)),  # Velocity vector
    'size': ('float32', (N_obj, 3)),         # Length, width, height
    'type': ('int32', (N_obj,)),             # Vehicle, pedestrian, cyclist
    'valid': ('bool', (N_obj, T)),           # Visibility mask
}
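Selecting the closest N_obj objects and padding to a fixed size keeps tensor shapes static, which JAX compilation requires. A plain-Python sketch of that selection rule; the dict-based object format and the function name are hypothetical:

```python
import math

def select_closest(objects, ego_xy, n_obj=16):
    """Keep the n_obj valid objects closest to the ego and pad to a fixed size.

    objects: list of dicts with keys "xy" (x, y) and "valid" (bool) at the current step.
    Returns (indices, mask): indices has length n_obj (-1 for padding), mask marks real entries.
    """
    ex, ey = ego_xy
    candidates = [
        (math.hypot(o["xy"][0] - ex, o["xy"][1] - ey), i)
        for i, o in enumerate(objects) if o["valid"]
    ]
    candidates.sort()                    # nearest first; ties broken by index
    chosen = [i for _, i in candidates[:n_obj]]
    pad = n_obj - len(chosen)
    return chosen + [-1] * pad, [True] * len(chosen) + [False] * pad

objects = [{"xy": (30.0, 0.0), "valid": True}, {"xy": (5.0, 0.0), "valid": True},
           {"xy": (1.0, 0.0), "valid": False}, {"xy": (10.0, 0.0), "valid": True}]
print(select_closest(objects, (0.0, 0.0), n_obj=4))  # ([1, 3, 0, -1], [True, True, True, False])
```

In a jit-compiled JAX version the same idea is usually expressed with masked distances (invalid objects set to +inf) followed by jnp.argsort, or jax.lax.top_k on the negated distances.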

2. Roadgraph Features

N_road = 2000  # max_points in the default YAML configuration

roadgraph_features = {
    'xyz': ('float32', (N_road, 3)),        # Waypoint positions
    'direction': ('float32', (N_road, 3)),  # Lane direction vectors
    'type': ('int32', (N_road,)),           # Lane, road edge, crosswalk
    'speed_limit': ('float32', (N_road,)),  # Speed limit if available
}

# Sampling strategy
# - Sample every 2 meters along lanes
# - Include lane centers, road edges, crosswalks
# - Filter by distance to ego (default: 100m radius)
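Sampling every 2 meters along a lane means interpolating by arc length rather than keeping the original map vertices. A minimal sketch for a single 2D polyline; the function name is hypothetical and the trailing remainder shorter than one interval is simply dropped:

```python
import math

def resample_polyline(polyline, interval=2.0):
    """Return points spaced `interval` meters apart (by arc length) along a 2D polyline."""
    out = [polyline[0]]
    carry = 0.0                       # arc length travelled since the last emitted point
    for (x0, y0), (x1, y1) in zip(polyline, polyline[1:]):
        seg = math.hypot(x1 - x0, y1 - y0)
        s = interval - carry          # position of the next sample within this segment
        while s <= seg:
            t = s / seg
            out.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0)))
            s += interval
        carry = seg - (s - interval)  # distance from the last sample to the segment end
    return out

# An L-shaped lane: the third sample turns the corner, 2 m of arc length after (2, 0)
print(resample_polyline([(0.0, 0.0), (3.0, 0.0), (3.0, 3.0)]))
```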

3. Traffic Light Features

N_lights = 5  # 5 closest traffic lights

traffic_light_features = {
    'state': ('int32', (N_lights,)),     # UNKNOWN, GREEN, YELLOW, RED
    'xy': ('float32', (N_lights, 2)),    # Stop line position
    'relevant': ('bool', (N_lights,)),   # On ego's path?
}

4. Path Target Features

N_targets = 10  # 10 waypoints, 5m apart = 50m lookahead

path_target_features = {
    'xy': ('float32', (N_targets, 2)),     # Waypoint positions
    'heading': ('float32', (N_targets,)),  # Expected heading
    'speed': ('float32', (N_targets,)),    # Suggested speed
}
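The path targets can be produced by walking along the SDC path from the point nearest the ego and emitting a waypoint every 5 meters. A sketch under simplifying assumptions: the nearest path vertex (not the nearest projection) is used as the start, the speed field is omitted, and the function name is hypothetical.

```python
import math

def path_targets(path, ego_xy, num_waypoints=10, spacing=5.0):
    """Sample (x, y, heading) waypoints every `spacing` meters ahead of the ego along a 2D path.

    Heading is the direction of the path segment a waypoint lies on. Waypoints past the
    end of the path are not generated, so fewer than num_waypoints may be returned.
    """
    ex, ey = ego_xy
    start = min(range(len(path)), key=lambda i: math.hypot(path[i][0] - ex, path[i][1] - ey))
    targets, travelled, next_s = [], 0.0, spacing
    for (x0, y0), (x1, y1) in zip(path[start:], path[start + 1:]):
        seg = math.hypot(x1 - x0, y1 - y0)
        heading = math.atan2(y1 - y0, x1 - x0)
        while next_s <= travelled + seg and len(targets) < num_waypoints:
            t = (next_s - travelled) / seg
            targets.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0), heading))
            next_s += spacing
        travelled += seg
    return targets

# Straight 100 m path: 10 waypoints from 5 m to 50 m ahead, all with heading 0
print(path_targets([(0.0, 0.0), (100.0, 0.0)], (0.5, 1.0))[-1])  # (50.0, 0.0, 0.0)
```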

YAML Configuration System

# config/observation/default.yaml
observation:
  trajectory:
    num_objects: 16
    num_timesteps: 5
    features:
      - xy
      - heading
      - velocity
      - size
      - type
    ego_centric: true

  roadgraph:
    max_points: 2000
    sample_interval: 2.0  # meters
    types:
      - lane_center
      - road_edge
      - crosswalk
    radius: 100.0  # meters

  traffic_light:
    num_lights: 5
    include_relevance: true

  path_target:
    num_waypoints: 10
    waypoint_spacing: 5.0  # meters
    include_speed: true
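One way to consume this YAML in Python is to mirror part of it in typed, immutable dataclasses, which also makes derived quantities explicit (0.5 s of history at 10 Hz, a 50 m path lookahead). The class and function names are hypothetical, only a subset of keys is covered, and in practice the input mapping would come from a YAML parser.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class TrajectoryCfg:
    num_objects: int = 16
    num_timesteps: int = 5

@dataclass(frozen=True)
class PathTargetCfg:
    num_waypoints: int = 10
    waypoint_spacing: float = 5.0  # meters

@dataclass(frozen=True)
class ObservationCfg:
    trajectory: TrajectoryCfg = field(default_factory=TrajectoryCfg)
    path_target: PathTargetCfg = field(default_factory=PathTargetCfg)
    sim_hz: float = 10.0  # assumed simulation rate

    @property
    def history_seconds(self) -> float:
        return self.trajectory.num_timesteps / self.sim_hz

    @property
    def lookahead_meters(self) -> float:
        return self.path_target.num_waypoints * self.path_target.waypoint_spacing

def from_mapping(raw: dict) -> ObservationCfg:
    """Build from the parsed `observation:` mapping, ignoring keys this sketch does not model."""
    t, p = raw.get("trajectory", {}), raw.get("path_target", {})
    return ObservationCfg(
        trajectory=TrajectoryCfg(t.get("num_objects", 16), t.get("num_timesteps", 5)),
        path_target=PathTargetCfg(p.get("num_waypoints", 10), p.get("waypoint_spacing", 5.0)),
    )

cfg = from_mapping({"trajectory": {"num_objects": 16, "num_timesteps": 5, "features": ["xy"]},
                    "path_target": {"num_waypoints": 10, "waypoint_spacing": 5.0}})
print(cfg.history_seconds, cfg.lookahead_meters)  # 0.5 50.0
```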

Ablation Study Results

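Config    | Objects | Timesteps | Road Edge | V-Max Score
--------- | ------- | --------- | --------- | -----------
Minimal   | 8       | 1         | No        | 0.76
Standard  | 16      | 5         | No        | 0.84
Optimal   | 16      | 5         | Yes       | 0.88
Maximal   | 32      | 10        | Yes       | 0.86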

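Key Finding: Adding road-edge features significantly improves off-road prevention (V-Max score 0.84 to 0.88 at 16 objects and 5 timesteps). Going beyond 16 objects and 5 timesteps brings no further gain: the Maximal configuration scores lower (0.86) than Optimal.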


5. Reward Design

The Hierarchical Philosophy

V-Max implements a priority-based reward hierarchy:

+--------------------+
|     SAFETY         |  <-- Highest Priority (Hard Constraints)
|   No collisions    |
|   No off-road      |
|   No red lights    |
+--------------------+
         |
         v
+--------------------+
|    NAVIGATION      |  <-- Medium Priority (Task Completion)
|   Stay on route    |
|   Make progress    |
|   Reach goal       |
+--------------------+
         |
         v
+--------------------+
|    BEHAVIOR        |  <-- Lowest Priority (Quality)
|   Comfort (jerk)   |
|   Speed compliance |
|   Smooth steering  |
+--------------------+

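Mathematical Formulation

In the code below, helpers such as check_collision, distance_to_route, and compute_jerk stand for simulator-specific queries, and threshold and max_allowed are tuning constants.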

Safety Reward

def safety_reward(state):
    """
    Binary penalties for safety violations.
    These should NEVER be traded off for navigation/behavior.
    """
    collision = float(check_collision(state))      # 1 if collision
    offroad = float(check_offroad(state))          # 1 if off drivable area
    red_light = float(check_red_light_violation(state))  # 1 if ran red

    r_safety = -collision - offroad - red_light
    return r_safety  # Range: [-3, 0]

Navigation Reward

def navigation_reward(state, prev_state):
    """
    Task completion rewards built on top of safety.
    """
    r_safety = safety_reward(state)

    # Route following
    off_route = float(distance_to_route(state) > threshold)
    off_route_penalty = -0.2 * off_route

    # Progress along route
    progress = get_route_progress(state)
    prev_progress = get_route_progress(prev_state)
    progress_reward = 0.2 * float(progress > prev_progress)

    r_navigation = r_safety + off_route_penalty + progress_reward
    return r_navigation  # Range: [-3.2, 0.2]

Behavior Reward

def behavior_reward(state, prev_state):
    """
    Quality of driving built on top of navigation.
    """
    r_navigation = navigation_reward(state, prev_state)

    # Comfort metrics
    jerk = compute_jerk(state, prev_state)
    lateral_accel = compute_lateral_acceleration(state)
    comfort_score = 1.0 - min(1.0, (jerk + lateral_accel) / max_allowed)
    comfort_reward = 0.2 * comfort_score

    # Speed compliance
    speed = get_speed(state)
    speed_limit = get_speed_limit(state)
    speeding_penalty = -0.1 * float(speed > speed_limit)

    r_behavior = r_navigation + comfort_reward + speeding_penalty
    return r_behavior  # Range: [-3.3, 0.4]
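The three functions above compose into a single per-step reward. The sketch below replaces the simulator queries with a hypothetical StepInfo record of precomputed flags, so the stated bounds can be checked directly:

```python
from dataclasses import dataclass

@dataclass
class StepInfo:
    """Hypothetical per-step flags a simulator wrapper could precompute."""
    collision: bool
    offroad: bool
    ran_red_light: bool
    off_route: bool
    made_progress: bool
    comfort_score: float  # assumed already normalized to [0, 1]
    speeding: bool

def hierarchical_reward(s: StepInfo) -> float:
    # Safety: -1 per violation, range [-3, 0]
    r_safety = -(float(s.collision) + float(s.offroad) + float(s.ran_red_light))
    # Navigation: off-route penalty and progress bonus, range [-3.2, 0.2]
    r_nav = r_safety - 0.2 * float(s.off_route) + 0.2 * float(s.made_progress)
    # Behavior: comfort bonus and speeding penalty, range [-3.3, 0.4]
    comfort = min(1.0, max(0.0, s.comfort_score))
    return r_nav + 0.2 * comfort - 0.1 * float(s.speeding)

best = StepInfo(False, False, False, False, True, 1.0, False)
worst = StepInfo(True, True, True, True, False, 0.0, True)
print(hierarchical_reward(best), hierarchical_reward(worst))
```

The extremes are 0.4 (progress, full comfort, no violations) and -3.3 (all three safety violations, off route, no progress, no comfort, speeding), matching the ranges in the comments above.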

Reward Shaping Strategy

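Training Timeline (all reward terms are active from the start; the dominant learning signal shifts over training):
0                      5M                        15M                   25M steps
|----------------------|-------------------------|---------------------|
| Phase 1: Safety      | Phase 2: Navigation     | Phase 3: Behavior   |
| Focus on survival    | Focus on progress       | Focus on quality    |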

Early Training (0-5M steps):
- Agent learns basic control
- Safety rewards dominate
- Most episodes end in collision

Mid Training (5-15M steps):
- Safety violations decrease
- Navigation rewards become meaningful
- Agent learns route following

Late Training (15-25M steps):
- Safety mostly solved
- Navigation mostly solved
- Behavior refinement
- Comfort scores improve

Reward Impact Analysis (Table 3)

Reward Config  | Collision Rate | Progress  | V-Max Score
-------------- | -------------- | --------- | -----------
Safety Only    | 2.1%           | 78.66 m   | 0.71
Safety + Nav   | 1.8%           | 155.38 m  | 0.82
Full (S+N+B)   | 0.89%          | 148.21 m  | 0.89

Key Insight: Navigation rewards roughly doubled the progress distance (78.66 m to 155.38 m). Adding behavior rewards slightly reduced progress (148.21 m, less aggressive driving) but roughly halved the collision rate (1.8% to 0.89%).

Dense vs Sparse Rewards

Sparse Rewards (Bad for RL):
+--------+                              +--------+
| t=0    |     No feedback...           | t=90   |
| Start  |----------------------------->| Crash! |
+--------+                              | r=-1   |
                                        +--------+

Dense Rewards (V-Max Approach):
+--------+   +--------+   +--------+   +--------+
| t=0    |   | t=30   |   | t=60   |   | t=90   |
| r=0    |-->| r=+0.1 |-->| r=+0.2 |-->| r=+0.3 |
| Start  |   |Progress|   |Progress|   | Done!  |
+--------+   +--------+   +--------+   +--------+
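A quick way to see why the sparse version is hard: with the discount gamma = 0.99 used in the SAC and PPO configurations, a single penalty at the last of 90 steps reaches the first state scaled by 0.99^89 (about 0.41), and it is the only signal in the episode. The per-step value of 0.2 in the dense case is illustrative only.

```python
def discounted_return(rewards, gamma=0.99):
    """G_0 = sum_t gamma**t * r_t, accumulated backwards from the last step."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

sparse = [0.0] * 89 + [-1.0]      # single terminal penalty at the last of 90 steps
dense = [0.2] * 90                # e.g. a progress bonus on every step (illustrative value)
print(discounted_return(sparse))  # about -0.409: the only signal, 89 steps away
print(discounted_return(dense))   # about 11.9: feedback at every step
```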

6. Encoder Architectures

Architecture Comparison Overview

+------------------+     +------------------+     +------------------+
| Latent Query     |     | MTR              |     | Wayformer        |
| (Best: 97.45%)   |     | (95.94%)         |     | (96.08%)         |
+------------------+     +------------------+     +------------------+
|                  |     |                  |     |                  |
| Cross-Attention  |     | Self-Attention   |     | Factorized       |
| to Learnable     |     | w/ Agent-Centric |     | Attention        |
| Latent Vectors   |     | Scene Encoding   |     | (Time x Space)   |
|                  |     |                  |     |                  |
+------------------+     +------------------+     +------------------+

1. Latent Query (LQ) Encoder - Best Performer

Input Features
     |
     v
+-------------------+
| Feature Embedding |  Linear projection to d_model
+-------------------+
     |
     v
+-------------------+
| Positional        |  Sinusoidal embeddings
| Encoding          |  for temporal/spatial info
+-------------------+
     |
     +-----------------------------+
     |                             |
     v                             v
+------------------+     +------------------+
| Cross-Attention  |<----|  Learnable       |
|                  |     |  Latent Vectors  |
| Q: Latents       |     |  [16 x d_model]  |
| K,V: Features    |     +------------------+
+------------------+
     |
     v (repeat 4x)
+------------------+
| Self-Attention   |
|                  |
| Q,K,V: Latents   |
+------------------+
     |
     v
+------------------+
| Mean Pool        |  Aggregate 16 latents
+------------------+
     |
     v
[Scene Embedding: d_model]

Why Latent Query Works Best:

  1. Bottleneck compression: Forces learning of essential features
  2. Learnable queries: Adapts to what information is useful
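  3. Efficient: cross-attention cost grows only linearly with the number of input tokens, and the latent self-attention cost is fixed by the number of latents, independent of scene size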
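
import jax.numpy as jnp
import flax.linen as nn

class LatentQueryEncoder(nn.Module):
    """Illustrative JAX/Flax sketch of a Latent Query encoder (pre-norm residual blocks)."""

    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 4
    n_latents: int = 16

    @nn.compact
    def __call__(self, features, mask=None):
        # features: [..., N_tokens, F]; mask: optional bool [..., N_tokens], True = valid token
        batch_shape = features.shape[:-2]

        # Learnable latent vectors, shared across the batch
        latents = self.param(
            'latents',
            nn.initializers.normal(0.02),
            (self.n_latents, self.d_model)
        )
        latents = jnp.broadcast_to(latents, batch_shape + latents.shape)

        # Project input features to d_model
        x = nn.Dense(self.d_model)(features)

        # Broadcast the token mask to [..., heads, q_len, kv_len]
        attn_mask = None if mask is None else mask[..., None, None, :]

        for _ in range(self.n_layers):
            # Cross-attention: latents (queries) attend to the features (keys/values)
            latents = latents + nn.MultiHeadDotProductAttention(
                num_heads=self.n_heads
            )(nn.LayerNorm()(latents), x, mask=attn_mask)

            # Self-attention among latents
            h = nn.LayerNorm()(latents)
            latents = latents + nn.MultiHeadDotProductAttention(
                num_heads=self.n_heads
            )(h, h)

            # Position-wise feed-forward network with residual connection
            h = nn.LayerNorm()(latents)
            latents = latents + nn.Dense(self.d_model)(
                nn.gelu(nn.Dense(self.d_model * 4)(h))
            )

        # Mean-pool the latents into one scene embedding: [..., d_model]
        return jnp.mean(latents, axis=-2)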

2. Latent Query Hierarchical (LQH) Encoder

+------------------+     +------------------+     +------------------+
| Trajectory       |     | Roadgraph        |     | Traffic Light    |
| Features         |     | Features         |     | Features         |
+------------------+     +------------------+     +------------------+
        |                       |                       |
        v                       v                       v
+------------------+     +------------------+     +------------------+
| LQ Encoder       |     | LQ Encoder       |     | LQ Encoder       |
| (Local)          |     | (Local)          |     | (Local)          |
+------------------+     +------------------+     +------------------+
        |                       |                       |
        +----------+------------+-----------+-----------+
                   |            |           |
                   v            v           v
            +----------------------------------+
            |       Concatenate Embeddings     |
            +----------------------------------+
                            |
                            v
            +----------------------------------+
            |       Global LQ Encoder          |
            |       (Cross-modal fusion)       |
            +----------------------------------+
                            |
                            v
                    [Scene Embedding]

Advantage: Modular processing allows specialized encoding per feature type.

3. Motion Transformer (MTR) Encoder

Per-Agent Processing:
+------------------+
| Agent Trajectory |
| [T x Features]   |
+------------------+
        |
        v
+------------------+
| Temporal         |
| Self-Attention   |
+------------------+
        |
        v
[Agent Embedding]

Scene-Level Fusion:
+--------+  +--------+  +--------+  +--------+
| Agent1 |  | Agent2 |  | Map    |  | Traffic|
+--------+  +--------+  +--------+  +--------+
     |          |          |          |
     +----------+----------+----------+
                |
                v
     +------------------+
     | Cross-Agent      |
     | Attention        |
     +------------------+
                |
                v
     +------------------+
     | Agent-Map        |
     | Attention        |
     +------------------+
                |
                v
        [Scene Embedding]

4. Wayformer Encoder

Factorized Attention:

Input: [N_agents x T x Features]

Step 1: Temporal Attention (per agent)
+------------------+
| Self-Attn over   |
| time dimension   |
| [T x T]          |
+------------------+

Step 2: Spatial Attention (per timestep)
+------------------+
| Self-Attn over   |
| agent dimension  |
| [N x N]          |
+------------------+

Why Factorized?
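- Full attention over all N*T tokens: O((N*T)^2)
- Factorized: O(N*T^2) temporal (N agents, T x T each) + O(T*N^2) spatial (T timesteps, N x N each), much cheaper when N and T are both large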
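The savings can be checked by counting query-key score entries per attention layer (per head, ignoring batch). The sketch below uses the default observation size (16 objects, 5 timesteps) and the WOMD tensor size from the ScenarioMax schema (128 objects, 91 timesteps):

```python
def attention_pairs(n_agents, n_steps):
    """Return (full, factorized) counts of query-key score entries for one layer."""
    full = (n_agents * n_steps) ** 2     # every token attends to every token
    temporal = n_agents * n_steps ** 2   # per agent: T x T
    spatial = n_steps * n_agents ** 2    # per timestep: N x N
    return full, temporal + spatial

print(attention_pairs(16, 5))    # (6400, 1680)
print(attention_pairs(128, 91))  # (135675904, 2550912)
```

The gap grows with scene size: about 3.8x fewer entries for the default observation, but about 53x fewer for the full WOMD tensor.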

Encoder Performance Comparison

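Encoder          | Accuracy | Collision | V-Max Score | Params | Speed
---------------- | -------- | --------- | ----------- | ------ | ------
MLP (baseline)   | 68.12%   | 15.2%     | 0.58        | 0.5M   | 10K/s
Latent Query     | 97.45%   | 0.89%     | 0.89        | 2.1M   | 4.6K/s
LQ Hierarchical  | 96.28%   | 1.2%      | 0.87        | 3.2M   | 3.8K/s
MTR              | 95.94%   | 1.4%      | 0.85        | 2.8M   | 4.1K/s
Wayformer        | 96.08%   | 1.3%      | 0.86        | 2.5M   | 4.3K/s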

Key Insight: Latent Query achieves the best performance with the smallest parameter count among the transformer encoders (2.1M, versus 2.5M to 3.2M for the others). The learnable bottleneck appears crucial for policy learning.


7. RL Training Pipeline

Algorithm Comparison

+------------------+     +------------------+
|       SAC        |     |       PPO        |
| (Off-Policy)     |     | (On-Policy)      |
+------------------+     +------------------+
|                  |     |                  |
| + Sample         |     | + Stable         |
|   efficient      |     | + Simple to tune |
| + Exploration    |     | + Parallelizes   |
|   via entropy    |     |   well           |
|                  |     |                  |
| - Sensitive to   |     | - Sample         |
|   hyperparams    |     |   inefficient    |
| - Replay buffer  |     | - Needs more     |
|   memory         |     |   environment    |
|                  |     |   steps          |
+------------------+     +------------------+
|                  |     |                  |
| 25M steps        |     | 200M steps       |
| 12-24h training  |     | 24-48h training  |
| Best results     |     | Good results     |
+------------------+     +------------------+

SAC Implementation Details

# V-Max SAC Configuration
sac_config = {
    # Environment
    'num_envs': 16,              # Parallel environments
    'episode_length': 90,        # 9 seconds at 10Hz

    # Replay Buffer
    'buffer_size': 1_000_000,    # 1M transitions
    'min_buffer_size': 10_000,   # Start training after

    # Training
    'total_steps': 25_000_000,
    'batch_size': 256,
    'updates_per_step': 4,       # Gradient updates per env step

    # SAC Hyperparameters
    'actor_lr': 3e-4,
    'critic_lr': 3e-4,
    'alpha_lr': 3e-4,            # Learning rate for the entropy coefficient alpha
    'gamma': 0.99,               # Discount factor
    'tau': 0.005,                # Target network update rate

    # Entropy
    'init_alpha': 0.1,
    'target_entropy': -2.0,      # -dim(action); the action is 2-D
    'auto_alpha': True,          # Learn alpha

    # Network Architecture
    'encoder': 'latent_query',
    'hidden_dims': [256, 256],
    'activation': 'relu',
}

SAC Training Loop

+------------------+
| Initialize       |
| - Actor network  |
| - Critic network |
| - Target critic  |
| - Replay buffer  |
| - Alpha (entropy)|
+------------------+
        |
        v
+------------------+     +------------------+
| Collect Data     |<----|    Train Loop    |
| (16 parallel)    |     |    (25M steps)   |
+------------------+     +------------------+
        |                        ^
        v                        |
+------------------+             |
| Store in Buffer  |             |
+------------------+             |
        |                        |
        v                        |
+------------------+             |
| Sample Batch     |             |
| (256 transitions)|             |
+------------------+             |
        |                        |
        v                        |
+------------------+             |
| Update Critic    |             |
| (TD error)       |             |
+------------------+             |
        |                        |
        v                        |
+------------------+             |
| Update Actor     |             |
| (Policy gradient)|             |
+------------------+             |
        |                        |
        v                        |
+------------------+             |
| Update Alpha     |             |
| (Entropy target) |             |
+------------------+             |
        |                        |
        v                        |
+------------------+             |
| Update Target    |-------------+
| (Soft update)    |
+------------------+
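The last two boxes of the loop can be sketched as follows, using tau = 0.005 and target_entropy = -2.0 from the configuration above. For readability, parameters are plain dicts of floats and the temperature takes a plain gradient step; practical implementations apply an optimizer such as Adam to log(alpha) and operate on parameter pytrees. Function names are hypothetical.

```python
import math

def soft_update(target, online, tau=0.005):
    """Polyak averaging per parameter: target <- tau * online + (1 - tau) * target."""
    return {name: tau * online[name] + (1.0 - tau) * target[name] for name in target}

def temperature_step(log_alpha, log_probs, target_entropy=-2.0, lr=3e-4):
    """One gradient step on the SAC temperature loss L = -log_alpha * mean(log_pi + target_entropy).

    If the policy entropy (-mean(log_pi)) is below target_entropy, log_alpha increases.
    """
    mean_term = sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    grad = -mean_term                    # dL / d(log_alpha)
    return log_alpha - lr * grad

print(soft_update({"w": 0.0}, {"w": 1.0}))              # {'w': 0.005}
log_alpha = temperature_step(math.log(0.1), [5.0, 5.0])  # entropy -5 < -2, so alpha grows
print(math.exp(log_alpha))
```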

PPO Implementation Details

# V-Max PPO Configuration
ppo_config = {
    # Environment
    'num_envs': 256,             # More parallel envs for on-policy
    'episode_length': 90,

    # Training
    'total_steps': 200_000_000,  # 8x more than SAC
    'batch_size': 512,
    'num_minibatches': 8,
    'num_epochs': 4,             # Epochs per update

    # PPO Hyperparameters
    'lr': 3e-4,
    'gamma': 0.99,
    'gae_lambda': 0.95,          # GAE parameter
    'clip_epsilon': 0.2,         # PPO clip range
    'vf_coef': 0.5,              # Value loss coefficient
    'ent_coef': 0.01,            # Entropy bonus
    'max_grad_norm': 0.5,        # Gradient clipping

    # Network Architecture
    'encoder': 'latent_query',
    'shared_encoder': True,      # Actor-critic share encoder
}
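PPO's advantages come from Generalized Advantage Estimation with the gamma and gae_lambda values above. A minimal single-rollout sketch, list-based for clarity (a JAX version would typically run the same recursion with jax.lax.scan and reverse=True); the function name is hypothetical:

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for one rollout of length T.

    rewards, values, dones: length-T lists; dones[t] is True if the episode ended at step t.
    last_value: critic estimate for the state reached after the final step.
    Returns (advantages, returns), where returns = advantages + values (value targets).
    """
    T = len(rewards)
    advantages = [0.0] * T
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0          # stop bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]  # TD error
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

adv, ret = compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False] * 3, 0.0, gamma=1.0, lam=1.0)
print(adv)  # [3.0, 2.0, 1.0]: with gamma = lam = 1 and zero values, advantage = reward-to-go
```

Setting lam closer to 0 trades variance for bias by leaning on one-step TD errors; 0.95 is a common middle ground.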

BC+SAC Hybrid Approach

Phase 1: Behavior Cloning (BC)
+------------------+     +------------------+
| Expert Demos     |---->| Supervised       |
| (Logged Data)    |     | Pre-training     |
+------------------+     +------------------+
                                |
                                v
                         +------------------+
                         | Warm-start       |
                         | Policy           |
                         +------------------+
                                |
Phase 2: SAC Fine-tuning        |
+------------------+            |
| Online RL        |<-----------+
| (SAC)            |
+------------------+
        |
        v
+------------------+
| Final Policy     |
| (Best of both)   |
+------------------+

Benefits:
- BC provides good initialization
- SAC explores and improves
- Reduces sample complexity

Training Infrastructure

Hardware Configuration:
+------------------+
| Single NVIDIA L4 |
| - 24GB VRAM      |
| - 12-48h runtime |
+------------------+

Memory Usage:
+------------------+
| Replay Buffer    |  ~4GB (1M transitions)
| Networks         |  ~500MB
| JAX Compilation  |  ~2GB
| Batch Processing |  ~1GB
+------------------+
Total: ~8GB (fits in L4)

Throughput:
+------------------+
| Training         |  ~1,500 steps/sec
| Evaluation       |  ~4,609 steps/sec (batched)
+------------------+

Learning Curves

SAC Training Progress:

Accuracy (%)
100|                                    ****
 90|                              ******
 80|                        ******
 70|                  ******
 60|            ******
 50|      ******
 40| *****
   +----------------------------------------
   0     5M    10M    15M    20M    25M steps

Collision Rate (%)
 20|*
 15|  *
 10|    **
  5|       ***
  1|           **********                  *
  0|                       ******************
   +----------------------------------------
   0     5M    10M    15M    20M    25M steps

8. Interactive Code Examples

Example 1: Setting Up V-Max Environment

"""
V-Max Training Setup
====================
Complete example of setting up a V-Max training environment.
"""

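import jax
import jax.numpy as jnp

# Illustrative API: these module and function names are assumed for this walkthrough
# and may not match the released v-max package.
from v_max import make_env, make_observation_fn, make_reward_fn
from v_max.encoders import LatentQueryEncoder
from v_max.algorithms import SAC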

# 1. Create Waymax environment with V-Max wrapper
env_config = {
    'dataset': 'womd',              # or 'nuplan', 'av2'
    'data_path': '/path/to/tfrecords',
    'batch_size': 16,               # Parallel scenarios
    'max_steps': 90,                # 9 seconds at 10Hz
}

env = make_env(**env_config)

# 2. Configure observation function
obs_config = {
    'trajectory': {
        'num_objects': 16,
        'num_timesteps': 5,
        'features': ['xy', 'heading', 'velocity', 'size', 'type']
    },
    'roadgraph': {
        'max_points': 2000,
        'sample_interval': 2.0,
        'types': ['lane_center', 'road_edge']
    },
    'traffic_light': {
        'num_lights': 5
    },
    'path_target': {
        'num_waypoints': 10,
        'waypoint_spacing': 5.0
    }
}

obs_fn = make_observation_fn(**obs_config)

# 3. Configure reward function
reward_config = {
    'safety': {
        'collision_weight': -1.0,
        'offroad_weight': -1.0,
        'red_light_weight': -1.0
    },
    'navigation': {
        'off_route_weight': -0.2,
        'progress_weight': 0.2
    },
    'behavior': {
        'comfort_weight': 0.2,
        'speeding_weight': -0.1
    }
}

reward_fn = make_reward_fn(**reward_config)

# 4. Create encoder
encoder = LatentQueryEncoder(
    d_model=256,
    n_heads=8,
    n_layers=4,
    n_latents=16
)

# 5. Initialize SAC agent
agent = SAC(
    encoder=encoder,
    action_dim=2,  # [acceleration, steering]
    hidden_dims=[256, 256],
    actor_lr=3e-4,
    critic_lr=3e-4,
    alpha_lr=3e-4,
    gamma=0.99,
    tau=0.005
)

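# 6. Training loop (env and agent methods are assumed, as above)
print("Starting V-Max training...")
state = env.reset()
metrics = {}  # defined up front so the step-0 log line works
for step in range(25_000_000):
    # Collect experience
    obs = obs_fn(state)
    action = agent.sample_action(obs)
    next_state, done = env.step(action)     # assumed to return (next_state, done flags)
    reward = reward_fn(next_state, state)   # (state, prev_state) order, as in behavior_reward

    # Store transition
    agent.buffer.add(obs, action, reward, obs_fn(next_state), done)
    state = next_state  # finished scenarios are assumed to be auto-reset by the wrapper

    # Update agent once the buffer holds enough transitions (warmup)
    if step > 10_000:
        metrics = agent.update(batch_size=256, num_updates=4)

    # Logging
    if step % 10_000 == 0:
        print(f"Step {step}: {metrics}")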

Example 2: Custom Observation Function

"""
Custom Observation Function
===========================
How to create a specialized observation function for your use case.
"""

import jax.numpy as jnp
from v_max.observations import BaseObservation

class CustomObservation(BaseObservation):
    """
    Custom observation that adds lane curvature features.
    """

    def __init__(self, base_config, curvature_lookahead=50):
        super().__init__(base_config)
        self.curvature_lookahead = curvature_lookahead

    def extract_trajectory_features(self, state):
        """Extract trajectory features with additional info."""
        # Get base features
        traj_features = super().extract_trajectory_features(state)

        # Add relative velocity to ego
        ego_velocity = state.ego_velocity
        relative_velocities = traj_features['velocity'] - ego_velocity
        traj_features['relative_velocity'] = relative_velocities

        return traj_features

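    @staticmethod
    def _curvature(xy):
        """Discrete curvature (1/m) of an ordered polyline, padded to len(xy)."""
        directions = jnp.diff(xy, axis=0)                     # (N-1, 2)
        seg_len = jnp.linalg.norm(directions, axis=-1)        # (N-1,)
        headings = jnp.arctan2(directions[:, 1], directions[:, 0])
        # Wrap heading changes to [-pi, pi) so crossing +/-pi is not a 2*pi jump
        dheading = (jnp.diff(headings) + jnp.pi) % (2 * jnp.pi) - jnp.pi
        # Curvature = heading change per unit arc length
        curvature = dheading / (0.5 * (seg_len[:-1] + seg_len[1:]) + 1e-6)
        # Pad (N-2,) back to (N,) to match the original length
        return jnp.pad(curvature, (0, 2), mode='edge')

    def extract_roadgraph_features(self, state):
        """Extract roadgraph with lane curvature."""
        road_features = super().extract_roadgraph_features(state)

        # Assumes 'xyz' holds one ordered polyline; points from different
        # lanes placed back-to-back would create spurious curvature spikes.
        road_features['curvature'] = self._curvature(road_features['xyz'][:, :2])

        return road_features

    def extract_path_target_features(self, state):
        """Extract path targets with speed profile."""
        path_features = super().extract_path_target_features(state)

        # Add recommended speed based on curvature
        curvature = self._curvature(path_features['xy'])
        max_speed = 30.0         # m/s
        min_speed = 5.0          # m/s
        max_lateral_accel = 2.0  # m/s^2, assumed comfort limit

        # Cap speed so that v^2 * |curvature| <= max_lateral_accel
        recommended_speed = jnp.clip(
            jnp.sqrt(max_lateral_accel / (jnp.abs(curvature) + 1e-6)),
            min_speed, max_speed)
        path_features['recommended_speed'] = recommended_speed

        return path_features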

# Usage
custom_obs_fn = CustomObservation(
    base_config=obs_config,
    curvature_lookahead=50
)
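A standalone NumPy sketch of the two ideas in this example, curvature as heading change per unit arc length and a curvature-dependent speed cap, lets you sanity-check them on a circle of known radius. The function names and the lateral-acceleration limit (2 m/s²) are assumptions for illustration; capping v²|κ| is one common way to map curvature to speed, not necessarily what V-Max uses.

```python
import numpy as np

def polyline_curvature(xy):
    """Curvature (1/m) at interior points: heading change / mean segment length."""
    d = np.diff(xy, axis=0)
    seg = np.linalg.norm(d, axis=1)
    heading = np.arctan2(d[:, 1], d[:, 0])
    dheading = (np.diff(heading) + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)
    return dheading / (0.5 * (seg[:-1] + seg[1:]))

def curvature_speed(kappa, a_lat_max=2.0, v_min=5.0, v_max=30.0):
    """Speed whose lateral acceleration v^2 * |kappa| stays <= a_lat_max, clipped."""
    v = np.sqrt(a_lat_max / np.maximum(np.abs(kappa), 1e-6))
    return np.clip(v, v_min, v_max)

# Circle of radius 20 m: curvature should be ~1/20 = 0.05 1/m
theta = np.arange(0.0, 1.0, 0.01)
circle = np.stack([20 * np.cos(theta), 20 * np.sin(theta)], axis=1)
kappa = polyline_curvature(circle)
print(kappa.mean(), curvature_speed(kappa).mean())  # ~0.05 and ~6.32 m/s
```

A straight line gives zero curvature, so the speed saturates at `v_max`.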

Example 3: Custom Reward Function

"""
Custom Reward Function
======================
Implementing domain-specific reward shaping.
"""

import jax.numpy as jnp
from v_max.rewards import BaseReward

class AggressiveDriverReward(BaseReward):
    """
    Reward function that encourages faster, more assertive driving.
    Useful for testing policy robustness.
    """

    def __init__(self, base_config, speed_bonus=0.3, gap_acceptance=0.5):
        super().__init__(base_config)
        self.speed_bonus = speed_bonus
        self.gap_acceptance = gap_acceptance

    def compute_safety_reward(self, state, prev_state):
        """Safety with adjusted thresholds."""
        # Still penalize hard constraints
        collision = self.check_collision(state)
        offroad = self.check_offroad(state)
        red_light = self.check_red_light(state)

        # But allow closer following distance
        ttc = self.compute_ttc(state)
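        ttc_penalty = -0.5 * jnp.clip(1.0 - ttc, 0, 1)  # Softer TTC: only TTC < 1 s is penalized

        # Cast boolean flags to float first; negating a boolean array is an error in NumPy/JAX
        violations = (collision.astype(jnp.float32) + offroad.astype(jnp.float32)
                      + red_light.astype(jnp.float32))
        return -violations + ttc_penalty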

    def compute_navigation_reward(self, state, prev_state):
        """Navigation with speed bonus."""
        base_nav = super().compute_navigation_reward(state, prev_state)

        # Bonus for maintaining higher speeds
        speed = jnp.linalg.norm(state.ego_velocity)
        speed_limit = self.get_speed_limit(state)

        # Bonus grows up to the speed limit and saturates there, so speeding earns nothing extra
        speed_ratio = speed / (speed_limit + 1e-6)
        speed_reward = self.speed_bonus * jnp.clip(speed_ratio, 0, 1)

        return base_nav + speed_reward

    def compute_behavior_reward(self, state, prev_state):
        """Behavior that rewards assertive maneuvers."""
        base_behavior = super().compute_behavior_reward(state, prev_state)

        # Bonus for successful lane changes / merges
        lane_change = self.detect_lane_change(state, prev_state)
        merge_success = self.detect_merge(state, prev_state)

        maneuver_bonus = 0.2 * (lane_change + merge_success)

        return base_behavior + maneuver_bonus

# Usage
aggressive_reward = AggressiveDriverReward(
    base_config=reward_config,
    speed_bonus=0.3,
    gap_acceptance=0.5
)

Example 4: Training with Curriculum

"""
Curriculum Learning for V-Max
=============================
Progressive difficulty increase for more stable training.
"""

from v_max.curriculum import CurriculumScheduler

# Define curriculum stages
curriculum_config = {
    'stages': [
        {
            'name': 'basic',
            'steps': 5_000_000,
            'scenario_filter': {
                'max_objects': 8,
                'scenario_types': ['straight', 'lane_follow']
            },
            'reward_scale': 1.0
        },
        {
            'name': 'intermediate',
            'steps': 10_000_000,
            'scenario_filter': {
                'max_objects': 16,
                'scenario_types': ['straight', 'lane_follow', 'turn']
            },
            'reward_scale': 1.0
        },
        {
            'name': 'advanced',
            'steps': 10_000_000,
            'scenario_filter': {
                'max_objects': 32,
                'scenario_types': ['all']
            },
            'reward_scale': 1.0
        }
    ],
    'auto_advance': True,  # Advance early once success_metric exceeds 0.9
    'success_metric': 'v_max_score'
}

scheduler = CurriculumScheduler(**curriculum_config)

# Training with curriculum
metrics = {}
done = True  # forces a reset on the first step
for step in range(25_000_000):
    # Get current stage
    stage = scheduler.get_stage(step)

    # Sample new scenarios from the current difficulty only when episodes end
    if done:
        scenarios = dataset.sample(
            batch_size=16,
            filter=stage['scenario_filter']
        )
        env.reset(scenarios)

    # Training step (same as Example 1; updates `done` and `metrics`)
    ...

    # Update curriculum once performance metrics are available
    if 'v_max_score' in metrics:
        scheduler.update(metrics['v_max_score'])
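How a scheduler like `CurriculumScheduler` might combine fixed step budgets with `auto_advance` is sketched below. `Stage`, `SimpleCurriculum`, and the `update(step, metric_value)` signature are hypothetical, not V-Max's API; the 0.9 threshold follows the one mentioned in `curriculum_config`.

```python
from dataclasses import dataclass

@dataclass
class Stage:
    name: str
    steps: int  # step budget for this stage

class SimpleCurriculum:
    """Hypothetical scheduler: step budgets plus optional early advance on a metric."""

    def __init__(self, stages, auto_advance=True, threshold=0.9):
        self.stages = stages
        self.auto_advance = auto_advance
        self.threshold = threshold
        self.index = 0        # current stage
        self.stage_start = 0  # global step at which the current stage began

    def get_stage(self, step):
        # Move on once the current stage's budget is spent (never past the last stage).
        # Assumes `step` is non-decreasing across calls.
        while (self.index < len(self.stages) - 1
               and step - self.stage_start >= self.stages[self.index].steps):
            self.stage_start += self.stages[self.index].steps
            self.index += 1
        return self.stages[self.index]

    def update(self, step, metric_value):
        # Early advance when the success metric clears the threshold
        if (self.auto_advance and metric_value > self.threshold
                and self.index < len(self.stages) - 1):
            self.index += 1
            self.stage_start = step

stages = [Stage('basic', 5_000_000), Stage('intermediate', 10_000_000),
          Stage('advanced', 10_000_000)]
curriculum = SimpleCurriculum(stages)
print(curriculum.get_stage(4_999_999).name)  # basic
print(curriculum.get_stage(5_000_000).name)  # intermediate
```

Resetting `stage_start` on an early advance gives the next stage its full budget, so total training can run past 25M steps if every stage ends early; capping the global step is a separate choice.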

Example 5: Evaluation Script

"""
V-Max Evaluation
================
Complete evaluation pipeline with all metrics.
"""

from v_max.evaluation import Evaluator, MetricAggregator

# Initialize evaluator
evaluator = Evaluator(
    env=env,
    obs_fn=obs_fn,
    checkpoint_path='checkpoints/sac_final.pkl'
)

# Evaluation modes
evaluation_configs = {
    'non_reactive': {
        'description': 'Other agents replay logged trajectories',
        'controllable_agents': ['ego'],
        'num_scenarios': 1000
    },
    'reactive': {
        'description': 'Other agents use IDM controller',
        'controllable_agents': ['ego', 'idm_agents'],
        'num_scenarios': 1000
    },
    'adversarial': {
        'description': 'ReGentS adversarial agents',
        'controllable_agents': ['ego', 'adversarial_agents'],
        'num_scenarios': 500
    }
}

# Run evaluations
results = {}
for mode, config in evaluation_configs.items():
    print(f"\nRunning {mode} evaluation...")

    metrics = evaluator.evaluate(
        num_scenarios=config['num_scenarios'],
        mode=mode,
        metrics=[
            'accuracy',
            'collision_rate',
            'at_fault_collision_rate',
            'off_road_rate',
            'red_light_violation_rate',
            'progress',
            'ttc',
            'comfort_score',
            'v_max_score'
        ]
    )

    results[mode] = metrics

    # Print summary
    print(f"  Accuracy: {metrics['accuracy']:.2%}")
    print(f"  Collision Rate: {metrics['collision_rate']:.2%}")
    print(f"  V-Max Score: {metrics['v_max_score']:.3f}")

# Aggregate and save results
aggregator = MetricAggregator()
summary = aggregator.create_report(results)
aggregator.save_latex_table(summary, 'results/evaluation_table.tex')

9. Benchmarks Analysis

Main Results: Non-Reactive Evaluation (Table 5)

Performance Comparison on WOMD Validation Set
=============================================

                 +--------+--------+--------+--------+--------+
                 |  SAC   |  PPO   |  BC    | BC+SAC |  PDM   |
                 +--------+--------+--------+--------+--------+
Accuracy         | 97.86% | 90.75% | 79.42% | 93.21% | 93.60% |
                 +--------+--------+--------+--------+--------+
Collision        |  0.89% |  7.81% | 13.14% |  4.32% |  0.96% |
                 +--------+--------+--------+--------+--------+
At-Fault Coll.   |  0.45% |  3.12% |  5.89% |  1.87% |  0.52% |
                 +--------+--------+--------+--------+--------+
Off-Road         |  0.34% |  1.56% |  4.23% |  0.98% |  0.41% |
                 +--------+--------+--------+--------+--------+
Red-Light Viol.  |  0.12% |  0.89% |  2.31% |  0.43% |  0.18% |
                 +--------+--------+--------+--------+--------+
Progress (m)     | 148.21 | 132.45 |  89.67 | 141.32 | 145.67 |
                 +--------+--------+--------+--------+--------+
TTC Score        |  0.97  |  0.91  |  0.78  |  0.94  |  0.96  |
                 +--------+--------+--------+--------+--------+
Comfort Score    |  0.89  |  0.82  |  0.71  |  0.85  |  0.91  |
                 +--------+--------+--------+--------+--------+
V-Max Score      |  0.892 |  0.784 |  0.721 |  0.843 |  0.887 |
                 +--------+--------+--------+--------+--------+

Key Insights:

  1. SAC leads overall: best accuracy (97.86%), V-Max score (0.892), and progress (148.21 m), although PDM edges it on comfort (0.91 vs 0.89)
  2. BC alone insufficient: Pure imitation learning achieves only 79.42% accuracy
  3. Hybrid helps: BC+SAC (93.21%) outperforms BC (79.42%) significantly
  4. PDM competitive: Rule-based PDM reaches 93.60% accuracy and a V-Max score (0.887) close to SAC's, with a similar collision rate (0.96% vs 0.89%)

Reactive Evaluation (Table 6)

Performance with IDM-Controlled Traffic
=======================================

When other vehicles react to ego:

                 | Non-Reactive | Reactive | Delta    |
                 +--------------+----------+----------+
SAC Accuracy     |    97.86%    |  94.12%  |  -3.74%  |
SAC Collision    |     0.89%    |   2.34%  |  +1.45%  |
                 +--------------+----------+----------+
PDM Accuracy     |    93.60%    |  91.23%  |  -2.37%  |
PDM Collision    |     0.96%    |   1.89%  |  +0.93%  |
                 +--------------+----------+----------+

Analysis:
- Both methods degrade with reactive agents
- SAC loses more accuracy than PDM (-3.74 vs -2.37 percentage points)
- The SAC-PDM accuracy gap narrows from 4.26 to 2.89 points

Adversarial Evaluation (Table 7)

Performance Under ReGentS Adversarial Attacks
=============================================

Adversarial cut-in scenarios:

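                 | Standard | Adversarial | Retained / Change |
                 +----------+-------------+-------------------+
SAC Accuracy     |  97.86%  |    76.23%   |  77.9% retained   |
SAC Collision    |   0.89%  |    18.45%   |  +17.6 pp         |
                 +----------+-------------+-------------------+
PDM Accuracy     |  93.60%  |    53.12%   |  56.7% retained   |
PDM Collision    |   0.96%  |    32.67%   |  +31.7 pp         |
                 +----------+-------------+-------------------+

Retained = adversarial accuracy / standard accuracy; collision rows show the increase in percentage points (pp).

Key Finding:
- SAC is markedly more robust than PDM under adversarial cut-ins
- SAC retains 77.9% of its standard accuracy vs 56.7% for PDM, and its collision rate rises about half as much (+17.6 vs +31.7 pp)
- This suggests the learned policy copes better with out-of-distribution agent behavior than the rule-based planner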

Encoder Comparison (Table 2)

Impact of Encoder Architecture
==============================

Encoder          | Accuracy | V-Max Score | Collision |
-----------------+----------+-------------+-----------+
MLP (baseline)   |  68.12%  |    0.58     |   15.2%   |
-----------------+----------+-------------+-----------+
Wayformer        |  96.08%  |    0.86     |    1.3%   |
MTR              |  95.94%  |    0.85     |    1.4%   |
LQ Hierarchical  |  96.28%  |    0.87     |    1.2%   |
Latent Query     |  97.45%  |    0.89     |    0.9%   |
-----------------+----------+-------------+-----------+

Insight: All transformer encoders (>= 95.94% accuracy) far outperform the MLP (68.12%)
         Latent Query is best on all three metrics

Cross-Dataset Generalization (Table 4)

Training on WOMD, nuPlan, or Both; Testing on Each
==================================================

                 | WOMD-only | nuPlan-only | Combined |
-----------------+-----------+-------------+----------|
Test on WOMD     |   95.97%  |    91.45%   |  95.38%  |
Test on nuPlan   |   89.23%  |    96.12%   |  95.54%  |
-----------------+-----------+-------------+----------|
Average          |   92.60%  |    93.78%   |  95.46%  |

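Key Finding:
- Combined training achieves the best average (95.46%)
- Only ~0.6 pp drop on each source domain (95.97% -> 95.38% on WOMD, 96.12% -> 95.54% on nuPlan)
- Up to +6.3 pp on the transfer domain (89.23% -> 95.54% on nuPlan; +3.9 pp on WOMD)
- Combined training yields more robust, better-generalizing policies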

Computational Efficiency (Table 8)

Steps Per Second (SPS) Benchmarks
=================================

Configuration               | SPS    | Relative |
----------------------------+--------+----------|
Waymax baseline (no metrics)|  8,234 |   1.0x   |
V-Max (full metrics)        |  4,609 |   0.56x  |
V-Max (minimal metrics)     |  6,123 |   0.74x  |
V-Max (32 batch)            |  4,609 |   0.56x  |
V-Max (128 batch)           |  3,891 |   0.47x  |

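Insight:
- Full metrics cut throughput by ~44% (8,234 -> 4,609 SPS), i.e. each step takes ~1.8x as long
- Still >4K SPS, fast enough for multi-million-step RL runs
- Raising the batch from 32 to 128 lowers reported SPS (4,609 -> 3,891), so larger batches do not increase throughput in this benchmark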
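To turn the SPS column into wall-clock estimates, the sketch below divides the 25M-step budget from Example 1 by each throughput. It assumes SPS counts environment steps and ignores time spent on gradient updates, so the results are lower bounds on real training time.

```python
TOTAL_STEPS = 25_000_000  # training budget used in Example 1

SPS_TABLE = {
    'Waymax baseline (no metrics)': 8_234,
    'V-Max (full metrics)': 4_609,
    'V-Max (minimal metrics)': 6_123,
    'V-Max (128 batch)': 3_891,
}

def hours_for(total_steps, sps):
    """Wall-clock hours to collect total_steps at a given steps-per-second rate."""
    return total_steps / sps / 3600.0

baseline = SPS_TABLE['Waymax baseline (no metrics)']
for name, sps in SPS_TABLE.items():
    print(f"{name:30s} {sps / baseline:.2f}x  {hours_for(TOTAL_STEPS, sps):5.2f} h")
```

At 4,609 SPS, 25M environment steps take about 1.5 hours, versus about 0.84 hours for the bare Waymax baseline.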

10. Hands-On Exercises

Exercise 1: Basic Environment Setup (Beginner)

Objective: Set up V-Max environment and run a random policy.

"""
Exercise 1: Basic Setup
-----------------------
TODO: Complete the following tasks

1. Load the WOMD dataset
2. Create a V-Max environment
3. Run 100 episodes with random actions
4. Compute average episode length and collision rate
"""

import jax
import jax.numpy as jnp
from jax import random

# TODO: Import V-Max modules
# from v_max import ...

def random_policy(key, obs):
    """
    TODO: Implement a random policy that returns:
    - acceleration in [-4, 2]
    - steering in [-0.3, 0.3]
    """
    key, subkey = random.split(key)
    # Your code here
    pass

def run_episode(env, policy_fn, key):
    """
    TODO: Run a single episode and return:
    - episode_length: int
    - collision: bool
    - total_progress: float
    """
    # Your code here
    pass

def main():
    # Setup
    key = random.PRNGKey(42)

    # TODO: Create environment
    # env = ...

    # Run episodes
    results = []
    for i in range(100):
        key, subkey = random.split(key)
        result = run_episode(env, random_policy, subkey)
        results.append(result)

    # Compute statistics
    # TODO: Print average episode length, collision rate, progress
    pass

if __name__ == "__main__":
    main()

Expected Output:

Random Policy Results:
  Average Episode Length: 23.4 steps (target: ~25)
  Collision Rate: 87.3% (target: 80-95%)
  Average Progress: 12.3m (target: 10-20m)
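One possible solution sketch for `random_policy`, plus a helper for the statistics step. It assumes `run_episode` returns a dict with the keys `episode_length`, `collision`, and `total_progress`, and splits its key once per step before calling the policy; `summarize` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

ACCEL_BOUNDS = (-4.0, 2.0)  # m/s^2, from the exercise spec
STEER_BOUNDS = (-0.3, 0.3)  # from the exercise spec

def random_policy(key, obs):
    """Sample [acceleration, steering] uniformly; the observation is ignored."""
    low = jnp.array([ACCEL_BOUNDS[0], STEER_BOUNDS[0]])
    high = jnp.array([ACCEL_BOUNDS[1], STEER_BOUNDS[1]])
    return jax.random.uniform(key, shape=(2,), minval=low, maxval=high)

def summarize(results):
    """Average the per-episode dicts returned by run_episode."""
    n = len(results)
    return {
        'avg_episode_length': sum(r['episode_length'] for r in results) / n,
        'collision_rate': sum(bool(r['collision']) for r in results) / n,
        'avg_progress': sum(r['total_progress'] for r in results) / n,
    }

keys = jax.random.split(jax.random.PRNGKey(42), 3)
actions = [random_policy(k, obs=None) for k in keys]
```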

Exercise 2: Observation Function Analysis (Beginner)

Objective: Understand and visualize observation features.

"""
Exercise 2: Observation Analysis
--------------------------------
TODO: Analyze the observation function outputs

1. Extract observations from a scenario
2. Visualize trajectory features
3. Compute feature statistics
4. Identify the most important features
"""

import matplotlib.pyplot as plt
import numpy as np

def visualize_scene(obs):
    """
    TODO: Create a bird's-eye view visualization showing:
    - Ego vehicle (red)
    - Other vehicles (blue)
    - Roadgraph (gray lines)
    - Traffic lights (colored circles)
    - SDC path (green dashed line)
    """
    fig, ax = plt.subplots(figsize=(12, 12))

    # Plot roadgraph
    # road_xy = obs['roadgraph']['xyz'][:, :2]
    # TODO: Your code here

    # Plot vehicles
    # traj_xy = obs['trajectory']['xy']
    # TODO: Your code here

    # Plot ego path
    # path_xy = obs['path_target']['xy']
    # TODO: Your code here

    ax.set_aspect('equal')
    ax.set_title('Bird\'s Eye View')
    plt.show()

def compute_feature_statistics(obs_batch):
    """
    TODO: Compute statistics for each feature:
    - Mean, std, min, max
    - Correlation with collision outcome
    """
    stats = {}

    # For each feature category
    for category in ['trajectory', 'roadgraph', 'traffic_light', 'path_target']:
        # TODO: Compute statistics
        pass

    return stats

# Run analysis
# TODO: Load sample observations and run analysis
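A possible starting point for the statistics step is sketched below. It assumes each feature category has already been flattened into a `(num_samples, num_features)` array and that a 0/1 collision outcome is recorded per sample; unlike the TODO stub, it takes the outcomes as an explicit argument.

```python
import numpy as np

def compute_feature_statistics(features, collisions):
    """features: dict category -> (num_samples, num_features); collisions: (num_samples,) 0/1."""
    y = np.asarray(collisions, dtype=float)
    y_c = y - y.mean()
    stats = {}
    for name, x in features.items():
        x = np.asarray(x, dtype=float).reshape(len(y), -1)
        x_c = x - x.mean(axis=0)
        denom = np.sqrt((x_c ** 2).sum(axis=0) * (y_c ** 2).sum())
        # Pearson correlation with the outcome; 0 where the feature or outcome is constant
        corr = np.divide((x_c * y_c[:, None]).sum(axis=0), denom,
                         out=np.zeros(x.shape[1]), where=denom > 0)
        stats[name] = {
            'mean': x.mean(axis=0), 'std': x.std(axis=0),
            'min': x.min(axis=0), 'max': x.max(axis=0),
            'corr_with_collision': corr,
        }
    return stats
```

Ranking features by `abs(corr_with_collision)` gives a first, purely linear view of which ones matter; it will miss interactions an encoder can learn.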

Exercise 3: Reward Shaping Experiment (Intermediate)

Objective: Compare different reward configurations.

"""
Exercise 3: Reward Shaping
--------------------------
TODO: Experiment with different reward configurations

1. Implement three reward variants:
   a) Safety-only rewards
   b) Safety + Navigation rewards
   c) Full hierarchical rewards (Safety + Nav + Behavior)

2. Train each for 1M steps
3. Compare learning curves and final performance
"""

import matplotlib.pyplot as plt

reward_configs = {
    'safety_only': {
        # TODO: Define safety-only rewards
    },
    'safety_nav': {
        # TODO: Define safety + navigation rewards
    },
    'full': {
        # TODO: Define full hierarchical rewards
    }
}

def train_with_reward(reward_config, num_steps=1_000_000):
    """
    TODO: Implement training loop with specified reward config
    Return learning curves (collision_rate, progress, v_max_score)
    """
    pass

def plot_comparison(results):
    """
    TODO: Plot learning curves for all three configurations
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    metrics = ['collision_rate', 'progress', 'v_max_score']
    for ax, metric in zip(axes, metrics):
        for name, data in results.items():
            ax.plot(data['steps'], data[metric], label=name)
        ax.set_title(metric)
        ax.legend()

    plt.tight_layout()
    plt.savefig('reward_comparison.png')

Expected Results:

After 1M steps:

             | Safety Only | Safety+Nav | Full    |
-------------+-------------+------------+---------|
Collision    |    5.2%     |    3.1%    |   2.3%  |
Progress     |   45.2m     |   89.3m    |  82.1m  |
V-Max Score  |    0.65     |    0.78    |   0.82  |
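One way to build the three variants is to start from the full hierarchy used in Example 1 and zero out the weights of the levels you drop, which keeps every key a reward builder might expect. The `ablate` helper and the zero-weight approach are assumptions, not part of V-Max.

```python
# Full hierarchy, mirroring the reward_config from Example 1
FULL_REWARD_CONFIG = {
    'safety': {'collision_weight': -1.0, 'offroad_weight': -1.0, 'red_light_weight': -1.0},
    'navigation': {'off_route_weight': -0.2, 'progress_weight': 0.2},
    'behavior': {'comfort_weight': 0.2, 'speeding_weight': -0.1},
}

def ablate(config, keep_levels):
    """Zero every weight outside keep_levels while keeping the same nested keys."""
    return {
        level: {name: (w if level in keep_levels else 0.0) for name, w in weights.items()}
        for level, weights in config.items()
    }

reward_configs = {
    'safety_only': ablate(FULL_REWARD_CONFIG, {'safety'}),
    'safety_nav': ablate(FULL_REWARD_CONFIG, {'safety', 'navigation'}),
    'full': ablate(FULL_REWARD_CONFIG, {'safety', 'navigation', 'behavior'}),
}
```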

Exercise 4: Encoder Architecture Comparison (Intermediate)

Objective: Implement and compare encoder architectures.

"""
Exercise 4: Encoder Comparison
------------------------------
TODO: Implement simplified versions of each encoder and compare

1. Implement MLP baseline
2. Implement simplified Latent Query
3. Train both for 5M steps
4. Analyze performance differences
"""

import flax.linen as nn

class MLPEncoder(nn.Module):
    """
    TODO: Implement MLP baseline encoder

    Architecture:
    - Flatten all inputs
    - 3 hidden layers [256, 256, 256]
    - ReLU activations
    - Output: 256-dim embedding
    """
    @nn.compact
    def __call__(self, obs):
        # Your code here
        pass

class SimpleLatentQuery(nn.Module):
    """
    TODO: Implement simplified Latent Query encoder

    Architecture:
    - Linear projection of inputs
    - 2 cross-attention layers
    - 8 learnable latent vectors
    - Mean pooling
    """
    n_latents: int = 8
    d_model: int = 256

    @nn.compact
    def __call__(self, obs):
        # Your code here
        pass

def compare_encoders(encoders, num_steps=5_000_000):
    """
    TODO: Train each encoder and collect metrics
    """
    results = {}
    for name, encoder_cls in encoders.items():
        print(f"Training {name}...")
        # Training code here
        pass
    return results

Exercise 5: SAC Hyperparameter Tuning (Advanced)

Objective: Understand SAC hyperparameters through systematic tuning.

"""
Exercise 5: SAC Hyperparameter Tuning
-------------------------------------
TODO: Systematically tune SAC hyperparameters

Key hyperparameters to explore:
1. Learning rates (actor, critic, alpha)
2. Batch size
3. Updates per environment step
4. Target entropy
5. Replay buffer size
"""

import optuna

def objective(trial):
    """
    TODO: Implement Optuna objective for hyperparameter tuning
    """
    # Sample hyperparameters
    config = {
        'actor_lr': trial.suggest_float('actor_lr', 1e-5, 1e-3, log=True),
        'critic_lr': trial.suggest_float('critic_lr', 1e-5, 1e-3, log=True),
        'alpha_lr': trial.suggest_float('alpha_lr', 1e-5, 1e-3, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [64, 128, 256, 512]),
        'updates_per_step': trial.suggest_int('updates_per_step', 1, 8),
        'target_entropy': trial.suggest_float('target_entropy', -4.0, -1.0),
        'buffer_size': trial.suggest_categorical('buffer_size', [100000, 500000, 1000000]),
    }

    # Train for limited steps
    # TODO: Implement training

    # Return validation metric
    # return v_max_score
    pass

# Run hyperparameter search
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

print("Best hyperparameters:", study.best_params)
print("Best V-Max score:", study.best_value)

Expected Results:

Best hyperparameters:
  actor_lr: 3.2e-4
  critic_lr: 2.8e-4
  batch_size: 256
  updates_per_step: 4
  target_entropy: -2.0
  buffer_size: 1000000

Best V-Max score: 0.89

Exercise 6: Multi-Dataset Training (Advanced)

Objective: Train on combined datasets and analyze generalization.

"""
Exercise 6: Multi-Dataset Training
----------------------------------
TODO: Implement cross-dataset training and evaluate transfer

1. Load WOMD and nuPlan datasets
2. Create combined dataloader
3. Train policy on combined data
4. Evaluate on each dataset separately
5. Compare to single-dataset training
"""

class MultiDatasetLoader:
    """
    TODO: Implement a dataloader that samples from multiple datasets

    Features:
    - Configurable sampling ratios
    - Dataset balancing
    - Domain labels for analysis
    """

    def __init__(self, datasets, ratios=None):
        self.datasets = datasets
        self.ratios = ratios or {name: 1.0 for name in datasets}
        # Your code here

    def sample(self, batch_size):
        """
        TODO: Sample a batch from combined datasets
        Return batch and domain labels
        """
        pass

def evaluate_transfer(policy, datasets):
    """
    TODO: Evaluate policy on each dataset and return metrics
    """
    results = {}
    for name, dataset in datasets.items():
        # Evaluate
        metrics = evaluate(policy, dataset)
        results[name] = metrics
    return results

def create_transfer_matrix(results):
    """
    TODO: Create a transfer matrix showing:
    - Rows: Training dataset
    - Columns: Evaluation dataset
    - Values: V-Max score
    """
    pass
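A minimal sketch of ratio-based sampling and the transfer matrix. It assumes each dataset is an indexable sequence of scenarios; the `seed` argument and the extra `train_names`/`eval_names`/`metric` parameters of `create_transfer_matrix` are hypothetical additions to the stubs above.

```python
import numpy as np

class MultiDatasetLoader:
    """Samples scenarios from several datasets with configurable ratios."""

    def __init__(self, datasets, ratios=None, seed=0):
        self.datasets = datasets          # name -> sequence of scenarios
        self.names = list(datasets)
        ratios = ratios or {name: 1.0 for name in datasets}
        weights = np.array([ratios[n] for n in self.names], dtype=float)
        self.probs = weights / weights.sum()  # normalized sampling ratios
        self.rng = np.random.default_rng(seed)

    def sample(self, batch_size):
        """Return (scenarios, domain_labels) for one batch."""
        domain_idx = self.rng.choice(len(self.names), size=batch_size, p=self.probs)
        batch, labels = [], []
        for i in domain_idx:
            dataset = self.datasets[self.names[i]]
            batch.append(dataset[self.rng.integers(len(dataset))])
            labels.append(self.names[i])
        return batch, labels

def create_transfer_matrix(results, train_names, eval_names, metric='v_max_score'):
    """results[train][eval] -> metrics dict; rows = training set, columns = eval set."""
    return np.array([[results[t][e][metric] for e in eval_names] for t in train_names])

loader = MultiDatasetLoader({'womd': ['w1', 'w2', 'w3'], 'nuplan': ['n1', 'n2']},
                            ratios={'womd': 3.0, 'nuplan': 1.0})
batch, labels = loader.sample(8)
```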

11. Interview Questions

Conceptual Questions

Q1: Why does V-Max use a POMDP formulation instead of MDP?

<details> <summary>Click to reveal answer</summary>

Answer: Autonomous driving is inherently partially observable because:

  1. Sensor limitations: The ego vehicle cannot observe the full state of the world - occluded vehicles, pedestrians behind buildings, etc.

  2. Intentions are hidden: Other agents' goals and planned trajectories are not directly observable.

  3. Map uncertainty: The exact lane boundaries and traffic signal states may be ambiguous.

In the POMDP formulation:

  • State S: Full world state (positions, velocities, intentions of all agents)
  • Observation O: What ego can perceive (BEV features within sensor range)
  • Belief: The policy must maintain implicit beliefs about hidden state

This is why transformer encoders with attention over historical observations work well - they can learn to aggregate temporal information to estimate hidden state.

</details>

Q2: Explain why SAC outperforms PPO in V-Max despite PPO often being preferred for robotics.

<details> <summary>Click to reveal answer</summary>

Answer: Several factors favor SAC in the V-Max setting:

  1. Sample Efficiency: SAC's off-policy learning with replay buffer achieves good performance in 25M steps vs PPO's 200M steps. This is crucial when simulation is expensive.

  2. Exploration via Entropy: SAC's entropy maximization encourages diverse driving behaviors, helping discover safe trajectories in complex scenarios.

  3. Continuous Action Space: SAC was designed for continuous control (acceleration, steering). PPO also supports continuous actions, so this is a natural fit for SAC rather than a hard limitation of PPO.

  4. Deterministic Evaluation: SAC can act deterministically at test time by taking the mean action, reducing variance in safety-critical scenarios (Gaussian PPO policies can be evaluated the same way).

  5. Dense Rewards: V-Max's hierarchical rewards provide dense feedback, which off-policy methods like SAC utilize efficiently.

PPO might be preferred when:

  • Sample efficiency matters less than stability
  • Distributed training across many workers is needed
  • The reward signal is very sparse
</details>

Q3: Why is the hierarchical reward structure (Safety > Navigation > Behavior) important?

<details> <summary>Click to reveal answer</summary>

Answer: The hierarchy enforces a priority ordering that mirrors human driving values:

  1. Safety as Hard Constraint: Collisions and traffic violations are never acceptable, regardless of navigation goals. By making safety rewards dominant, the policy learns to prioritize survival.

  2. Navigation Enables Learning: Without navigation rewards, a policy might learn to "play it safe" by never moving. Navigation rewards ensure the agent attempts the driving task.

  3. Behavior Refinement: Comfort and compliance only matter once safety and navigation are achieved. Adding behavior rewards too early can cause reward hacking.

Mathematical Insight: The hierarchy approximates a lexicographic (priority-ordered) objective, which can be written as a constrained problem:

minimize collision_rate
subject to: progress > threshold
            comfort_score > minimum

This is approximated by the weighted sum with large safety penalties.

Practical Impact (from Table 3):

  • Safety-only: Low collision but no progress
  • +Navigation: Progress improves significantly
  • +Behavior: Collision rate decreases further while maintaining progress
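The point that a weighted sum only approximates the priority ordering can be checked on two hypothetical episode summaries. In this sketch (weights, episode numbers, and function names are illustrative, not V-Max's), Python's tuple comparison implements the lexicographic order, and the weighted return agrees with it only once the safety weight is large enough.

```python
# Hypothetical episode summaries: violation count, progress (m), comfort in [0, 1]
safe_slow = {'violations': 0, 'progress': 40.0, 'comfort': 0.9}
fast_crash = {'violations': 1, 'progress': 150.0, 'comfort': 0.8}
episodes = [safe_slow, fast_crash]

def lexicographic_key(ep):
    """Tuples compare element by element: fewer violations, then progress, then comfort."""
    return (-ep['violations'], ep['progress'], ep['comfort'])

def weighted_return(ep, w_safety, w_progress=0.01, w_comfort=0.1):
    """Weighted-sum scalarization of the three tiers (illustrative weights)."""
    return (-w_safety * ep['violations']
            + w_progress * ep['progress']
            + w_comfort * ep['comfort'])

best_lex = max(episodes, key=lexicographic_key)
best_small_penalty = max(episodes, key=lambda ep: weighted_return(ep, w_safety=0.5))
best_large_penalty = max(episodes, key=lambda ep: weighted_return(ep, w_safety=5.0))

print(best_lex is safe_slow)             # True
print(best_small_penalty is fast_crash)  # True: 1.08 beats 0.49
print(best_large_penalty is safe_slow)   # True
```

If the lower tiers can add at most M to an episode's return (here 0.01 × 150 + 0.1 × 1 = 1.6 for progress up to 150 m), any safety weight above M guarantees that one fewer violation always wins.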
</details>

Technical Questions

Q4: How does the Latent Query encoder differ from standard transformers, and why does it perform best?

<details> <summary>Click to reveal answer</summary>

Answer: Key differences:

Standard Transformer:

Input -> Self-Attention -> ... -> Output
[N x D] -> [N x D] -> ... -> [N x D]

Output size depends on input size.

Latent Query:

Input -> Cross-Attention(Q=Latents, K,V=Input) -> Self-Attention(Latents) -> Mean Pool
[N x D] -> [K x D] -> [K x D] -> [D]

Output is fixed size regardless of input.

Why it works best:

  1. Information Bottleneck: The K learnable latents (K=16) force compression of the N input features, learning to extract what's relevant for driving.

  2. Adaptive Queries: Unlike fixed pooling, the learnable queries adapt to find task-relevant information.

  3. Computational Efficiency: O(N*K) cross-attention instead of O(N^2) self-attention when N is large.

  4. Permutation Invariance: Order of input features doesn't matter (desirable for sets of objects).

  5. Fixed-Size Output: The actor and critic heads always receive one D-dimensional vector, however many agents or road points the scene contains.
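To make the shape argument concrete, here is a single-head NumPy sketch of the cross-attention-plus-mean-pool path. It is a simplification under stated assumptions: one attention head, no latent self-attention, no residuals, MLPs, or layer norm, and random weights standing in for learned ones.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def latent_query_encode(tokens, latents, w_k, w_v):
    """tokens: (N, D) scene features; latents: (K, D) queries. Returns (D,)."""
    d = latents.shape[-1]
    keys = tokens @ w_k                     # (N, D)
    values = tokens @ w_v                   # (N, D)
    scores = latents @ keys.T / np.sqrt(d)  # (K, N): O(N*K), not O(N^2)
    attn = softmax(scores, axis=-1)         # each latent attends over all N tokens
    latent_out = attn @ values              # (K, D)
    return latent_out.mean(axis=0)          # mean pool -> (D,)

rng = np.random.default_rng(0)
D, K = 8, 4
latents = rng.normal(size=(K, D))
w_k, w_v = rng.normal(size=(D, D)), rng.normal(size=(D, D))

small_scene = rng.normal(size=(10, D))
large_scene = rng.normal(size=(37, D))
print(latent_query_encode(small_scene, latents, w_k, w_v).shape)  # (8,)
print(latent_query_encode(large_scene, latents, w_k, w_v).shape)  # (8,)

shuffled = small_scene[rng.permutation(10)]
print(np.allclose(latent_query_encode(small_scene, latents, w_k, w_v),
                  latent_query_encode(shuffled, latents, w_k, w_v)))  # True
```

The last check shows the permutation invariance claimed in point 4: reordering the input tokens only reorders the terms of each attention-weighted sum.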

</details>

Q5: Describe how ScenarioMax reconstructs SDC paths for datasets that don't provide them.

<details> <summary>Click to reveal answer</summary>

Answer: SDC path reconstruction involves:

Step 1: Lane Graph Extraction

# Extract lane connectivity from HD map
lane_graph = {
    lane_id: {
        'centerline': [(x, y), ...],
        'successors': [lane_id, ...],
        'predecessors': [lane_id, ...]
    }
}

Step 2: Find Starting Lanes

# Find lanes near ego's initial position
ego_start = scenario.ego_position[0]  # ego position at t=0
candidate_lanes = []
for lane_id, lane in lane_graph.items():
    dist = distance_to_centerline(ego_start, lane['centerline'])
    if dist < threshold:
        candidate_lanes.append(lane_id)

Step 3: BFS Route Generation

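# Generate up to 10 candidate routes via BFS
from collections import deque

routes = []
for start_lane in candidate_lanes:
    queue = deque([(start_lane, [start_lane])])
    while queue and len(routes) < 10:
        current, path = queue.popleft()  # O(1), unlike list.pop(0)
        if path_length(path) > min_length:
            routes.append(path)          # long enough: keep it, stop extending
            continue
        for successor in lane_graph[current]['successors']:
            if successor not in path:    # skip cycles in the lane graph
                queue.append((successor, path + [successor]))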

Step 4: Route Scoring

# Score routes by driving quality metrics
def score_route(route):
    return (
        -num_lane_changes(route) * 0.3
        -num_turns(route) * 0.2
        +total_length(route) * 0.1
        -curvature(route) * 0.1
    )

best_routes = sorted(routes, key=score_route, reverse=True)[:10]  # highest score first

Step 5: Interpolation

# Interpolate to 500 points at regular intervals
sdc_path = interpolate(best_routes[0], num_points=500, spacing=0.5)  # 500 x 0.5 m = 250 m
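The steps above can be combined into a small runnable sketch on a toy lane graph. Everything here is hypothetical: the graph, the scoring weights (length minus a heading-change penalty), and the helper names. It starts from a single lane and omits lane-change counts and the fixed 500-point target.

```python
from collections import deque
import numpy as np

# Hypothetical toy lane graph; connected lanes share their junction point
LANE_GRAPH = {
    'a': {'centerline': [(0.0, 0.0), (50.0, 0.0)], 'successors': ['b', 'c']},
    'b': {'centerline': [(50.0, 0.0), (100.0, 0.0)], 'successors': ['d']},
    'c': {'centerline': [(50.0, 0.0), (80.0, 30.0)], 'successors': []},
    'd': {'centerline': [(100.0, 0.0), (150.0, 0.0)], 'successors': []},
}

def route_polyline(route):
    """Concatenate lane centerlines, dropping each duplicated junction point."""
    points = []
    for lane_id in route:
        centerline = LANE_GRAPH[lane_id]['centerline']
        points.extend(centerline if not points else centerline[1:])
    return np.asarray(points, dtype=float)

def route_length(route):
    xy = route_polyline(route)
    return float(np.linalg.norm(np.diff(xy, axis=0), axis=1).sum())

def heading_change(route):
    """Total absolute heading change (rad) along the route."""
    d = np.diff(route_polyline(route), axis=0)
    h = np.arctan2(d[:, 1], d[:, 0])
    return float(np.abs((np.diff(h) + np.pi) % (2 * np.pi) - np.pi).sum())

def bfs_routes(start_lane, min_length, max_routes=10):
    """Breadth-first expansion until routes exceed min_length (Step 3)."""
    routes, queue = [], deque([[start_lane]])
    while queue and len(routes) < max_routes:
        path = queue.popleft()
        if route_length(path) > min_length:
            routes.append(path)
            continue
        for successor in LANE_GRAPH[path[-1]]['successors']:
            if successor not in path:
                queue.append(path + [successor])
    return routes

def score_route(route):
    """Hypothetical score: prefer long, straight routes (Step 4)."""
    return route_length(route) - 10.0 * heading_change(route)

def resample(xy, spacing):
    """Resample a polyline at fixed arc-length spacing (Step 5)."""
    s = np.concatenate([[0.0], np.cumsum(np.linalg.norm(np.diff(xy, axis=0), axis=1))])
    s_new = np.arange(0.0, s[-1] + 1e-9, spacing)
    return np.stack([np.interp(s_new, s, xy[:, 0]), np.interp(s_new, s, xy[:, 1])], axis=1)

routes = bfs_routes('a', min_length=90.0)
best = max(routes, key=score_route)
sdc_path = resample(route_polyline(best), spacing=0.5)
print(routes, best, sdc_path.shape)  # [['a', 'b'], ['a', 'c']] ['a', 'b'] (201, 2)
```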
</details>

Q6: How does V-Max handle the distribution shift between simulation and real-world driving?

<details> <summary>Click to reveal answer</summary>

Answer: V-Max addresses distribution shift through several mechanisms:

1. Diverse Datasets (ScenarioMax)

Distribution Coverage:
WOMD: Multiple US cities, urban and suburban
nuPlan: Urban driving in Boston, Pittsburgh, Las Vegas, Singapore
Combined: Broader coverage → better generalization

2. Reactive Evaluation

Non-Reactive: Other agents replay logs (unrealistic)
Reactive: Other agents use IDM (closer to real)
Adversarial: Stress testing with aggressive agents

3. Domain Randomization (implicit)

# Observation noise (illustrative; key is a jax.random.PRNGKey)
obs = obs + 0.1 * jax.random.normal(key, obs.shape)  # Position uncertainty, sigma = 0.1

# Scenario selection
scenarios = sample_diverse(dataset)  # Weather, lighting, traffic

4. Robust Reward Design

# Safety margins ensure buffer for real-world uncertainty
collision_threshold = 0.5  # meters (conservative)
ttc_threshold = 0.95  # seconds (conservative)

5. Cross-Dataset Training Results

Training  | Test WOMD | Test nuPlan | Avg
----------|-----------|-------------|-----
WOMD only |   95.97%  |    89.23%   | 92.60%
Combined  |   95.38%  |    95.54%   | 95.46%

→ Combined training raises average accuracy by ~2.9 percentage points (92.60% → 95.46%)

Remaining Gap: Simulation still differs from real-world in:

  • Sensor noise characteristics
  • Long-tail events
  • Human driver unpredictability
  • Weather conditions

These require sim-to-real transfer techniques beyond V-Max's scope.

</details>

System Design Questions

Q7: You need to deploy V-Max-trained policies on a real vehicle. What are the key considerations?

<details> <summary>Click to reveal answer</summary>

Answer: Key considerations for deployment:

1. Latency Requirements

Simulation: 100ms inference OK (10Hz)
Real-time: 20-50ms required for safety

Solutions:
- Model distillation (large → small)
- TensorRT/XLA compilation
- Batching disabled (single scenario)

2. Observation Pipeline

Simulation: Perfect BEV features from Waymax
Real-world: Must process raw sensor data

Pipeline:
LiDAR/Camera → Perception → Tracking → Feature Extraction → Policy
                  ↑                         ↑
           Need ML models here        Match V-Max format

3. Action Space Mapping

V-Max output: [acceleration, steering] (continuous)
Vehicle input: Throttle, brake, steering angle

Conversion:
if acceleration > 0:
    throttle = map_to_throttle(acceleration)
    brake = 0
else:
    throttle = 0
    brake = map_to_brake(-acceleration)
steering_angle = steering * max_steering_angle

4. Safety Wrapper

class SafetyWrapper:
    def __call__(self, action, state):
        # Hard limits
        action = clip(action, safety_bounds)

        # Emergency override (compute_ttc is a placeholder TTC estimator)
        if compute_ttc(state) < 0.5:  # seconds
            action = emergency_brake()

        # Sanity check
        if not is_physically_plausible(action, state):
            action = fallback_action()

        return action
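A runnable version of the wrapper's first two checks, under simplifying assumptions: a single lead vehicle in the same lane, constant-velocity time-to-collision, action bounds taken from the Exercise 1 spec, and an emergency brake equal to the lower acceleration bound. The explicit gap/speed arguments stand in for the generic `state`, and the plausibility check is omitted.

```python
import math
import numpy as np

def time_to_collision(gap_m, ego_speed, lead_speed):
    """Constant-velocity TTC (s) to a lead vehicle in the same lane."""
    closing_speed = ego_speed - lead_speed
    if closing_speed <= 0.0:
        return math.inf  # not closing the gap
    return gap_m / closing_speed

class SafetyWrapper:
    def __init__(self, low, high, min_ttc=0.5, brake_accel=-4.0):
        self.low = np.asarray(low, dtype=float)    # [accel, steering] lower bounds
        self.high = np.asarray(high, dtype=float)  # [accel, steering] upper bounds
        self.min_ttc = min_ttc                     # seconds
        self.brake_accel = brake_accel             # m/s^2

    def __call__(self, action, gap_m, ego_speed, lead_speed):
        # Hard limits
        action = np.clip(np.asarray(action, dtype=float), self.low, self.high)
        # Emergency override: full brake, keep the (clipped) steering command
        if time_to_collision(gap_m, ego_speed, lead_speed) < self.min_ttc:
            action = np.array([self.brake_accel, action[1]])
        return action

wrapper = SafetyWrapper(low=[-4.0, -0.3], high=[2.0, 0.3])
print(wrapper([3.0, 0.5], gap_m=50.0, ego_speed=10.0, lead_speed=10.0))  # clipped to [2.0, 0.3]
print(wrapper([1.0, 0.1], gap_m=2.0, ego_speed=15.0, lead_speed=5.0))    # TTC 0.2 s -> [-4.0, 0.1]
```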

5. Monitoring and Fallback

Primary: V-Max policy
Fallback 1: Rule-based PDM
Fallback 2: Emergency stop
Fallback 3: Human takeover alert

Trigger conditions:
- Policy uncertainty > threshold
- Perception confidence < threshold
- Out-of-distribution detection

6. Sim-to-Real Validation

Stage 1: Closed track testing
Stage 2: Shadow mode (policy suggests, human drives)
Stage 3: Limited ODD (Operational Design Domain)
Stage 4: Expanded ODD
</details>

Q8: Design an A/B testing framework to compare V-Max policies in simulation.

<details> <summary>Click to reveal answer</summary>

Answer: A/B testing framework design:

1. Experiment Configuration

from dataclasses import dataclass, field

@dataclass
class ABTestConfig:
    name: str
    policy_a: PolicyConfig  # Control
    policy_b: PolicyConfig  # Treatment

    # Scenario allocation
    scenario_split: str = "random"  # or "stratified"
    split_ratio: float = 0.5

    # Metrics
    primary_metric: str = "v_max_score"
    secondary_metrics: list = field(default_factory=lambda: [
        "collision_rate", "progress", "comfort"
    ])

    # Statistical settings
    min_scenarios: int = 1000
    confidence_level: float = 0.95
    power: float = 0.80

2. Scenario Stratification

import random

def stratified_split(scenarios, config):
    """
    Ensure balanced distribution of:
    - Scenario difficulty (easy/medium/hard)
    - Scenario type (straight/turn/intersection)
    - Traffic density
    """
    strata = {}
    for scenario in scenarios:
        key = (
            get_difficulty(scenario),
            get_type(scenario),
            get_density_bucket(scenario)
        )
        strata.setdefault(key, []).append(scenario)

    split_a, split_b = [], []
    for key, group in strata.items():
        random.shuffle(group)
        mid = len(group) // 2
        split_a.extend(group[:mid])
        split_b.extend(group[mid:])

    return split_a, split_b

3. Evaluation Pipeline

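class ABTestRunner:
    def __init__(self, scenarios):
        self.scenarios = scenarios

    def run(self, config):
        # Assign each scenario to exactly one arm, balanced per stratum
        scenarios_a, scenarios_b = stratified_split(self.scenarios, config)

        # The two arms are independent, so they can be evaluated in parallel
        results_a = self.evaluate(config.policy_a, scenarios_a)
        results_b = self.evaluate(config.policy_b, scenarios_b)

        # Statistical analysis
        analysis = analyze(results_a, results_b, config)

        return ABTestReport(
            config=config,
            results_a=results_a,
            results_b=results_b,
            analysis=analysis
        )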

4. Statistical Analysis

import numpy as np

def analyze(results_a, results_b, config):
    # Primary metric comparison (one value per scenario)
    metric_a = np.asarray(results_a[config.primary_metric])
    metric_b = np.asarray(results_b[config.primary_metric])

    # Bootstrap confidence intervals
    ci_a = bootstrap_ci(metric_a, config.confidence_level)
    ci_b = bootstrap_ci(metric_b, config.confidence_level)

    # Effect size (positive means B, the treatment, scores higher)
    effect_size = cohens_d(metric_b, metric_a)

    # Statistical significance
    p_value = permutation_test(metric_a, metric_b)

    # Decision at significance level alpha = 1 - confidence_level
    alpha = 1 - config.confidence_level
    if p_value < alpha:
        decision = "B wins" if metric_b.mean() > metric_a.mean() else "A wins"
    else:
        decision = "No significant difference"

    return {
        'ci_a': ci_a,
        'ci_b': ci_b,
        'effect_size': effect_size,
        'p_value': p_value,
        'decision': decision
    }
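`analyze` calls three helpers that are not defined above. The following sketch implements them with NumPy using common choices: a percentile bootstrap of the mean, Cohen's d with a pooled standard deviation, and a two-sided permutation test on the difference in means. The names match the calls in `analyze`, but the signatures and defaults (`n_boot`, `n_perm`, `seed`) are assumptions. `cohens_d` takes (treatment, control), so calling it as `cohens_d(metric_b, metric_a)` makes a positive d mean that B scores higher.

```python
import numpy as np


def bootstrap_ci(values, confidence_level=0.95, n_boot=10_000, seed=0):
    """Percentile bootstrap confidence interval for the mean."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    # Resample with replacement: each row is one bootstrap sample
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    alpha = 1.0 - confidence_level
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)


def cohens_d(treatment, control):
    """Standardized mean difference with pooled standard deviation.
    Positive when treatment > control."""
    t = np.asarray(treatment, dtype=float)
    c = np.asarray(control, dtype=float)
    n_t, n_c = len(t), len(c)
    pooled_var = ((n_t - 1) * t.var(ddof=1) + (n_c - 1) * c.var(ddof=1)) / (n_t + n_c - 2)
    return float((t.mean() - c.mean()) / np.sqrt(pooled_var))


def permutation_test(a, b, n_perm=10_000, seed=0):
    """Two-sided permutation test on the absolute difference in means."""
    rng = np.random.default_rng(seed)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    observed = abs(b.mean() - a.mean())
    pooled = np.concatenate([a, b])
    count = 0
    for _ in range(n_perm):
        perm = rng.permutation(pooled)  # random relabeling of arms
        diff = abs(perm[len(a):].mean() - perm[:len(a)].mean())
        count += int(diff >= observed)
    # The +1 correction keeps the p-value strictly positive
    return (count + 1) / (n_perm + 1)
```

The permutation test makes no normality assumption, which suits bounded scores such as the V-Max score. The `+1` correction means the smallest reportable p-value is `1 / (n_perm + 1)`.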

5. Report Generation

A/B Test Report: SAC_v2 vs SAC_v1
=================================

Primary Metric: V-Max Score
  Control (A):    0.887 [0.882, 0.892]
  Treatment (B):  0.901 [0.896, 0.906]
  Effect Size:    +1.6% (Cohen's d = 0.34)
  P-value:        0.0023
  Decision:       B wins (statistically significant)

Secondary Metrics:
  Collision Rate: A=0.89%, B=0.72% (p=0.01)
  Progress:       A=148m, B=151m (p=0.12)
  Comfort:        A=0.89, B=0.91 (p=0.03)

Recommendation: Deploy policy B
</details>

Q9: How would you implement continuous training for V-Max as new scenarios become available?

<details> <summary>Click to reveal answer</summary>

Answer: Continuous training architecture:

1. Data Pipeline

+------------------+     +------------------+     +------------------+
| New Scenarios    | --> | Scenario Curator | --> | Training Queue   |
| (Daily upload)   |     | (Filter/Label)   |     | (Priority Queue) |
+------------------+     +------------------+     +------------------+
                                                          |
+------------------+     +------------------+             |
| Model Registry   | <-- | Training Job     | <-----------+
| (Versioned)      |     | (GPU Cluster)    |
+------------------+     +------------------+
        |
        v
+------------------+     +------------------+
| Eval Pipeline    | --> | Deployment       |
| (A/B Test)       |     | (If improved)    |
+------------------+     +------------------+

2. Scenario Prioritization

class ScenarioCurator:
    def score_scenario(self, scenario):
        """Prioritize scenarios that are:
        1. Novel (different from existing data)
        2. Challenging (current policy struggles)
        3. Safety-critical (near-misses)
        """
        novelty = self.compute_novelty(scenario)
        difficulty = self.evaluate_difficulty(scenario)
        safety_relevance = self.assess_safety(scenario)

        return (
            0.3 * novelty +
            0.4 * difficulty +
            0.3 * safety_relevance
        )
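The training queue in the diagram is a priority queue fed by the curator's score. This is a minimal sketch built on the standard-library `heapq` module. `TrainingQueue` is a hypothetical class, and scenario IDs stand in for full scenario objects.

```python
import heapq
import itertools


class TrainingQueue:
    """Max-priority queue of scenario IDs keyed by curator score.

    heapq is a min-heap, so scores are negated; a monotonically increasing
    counter breaks ties in insertion order.
    """

    def __init__(self):
        self._heap = []
        self._counter = itertools.count()

    def push(self, scenario_id: str, score: float) -> None:
        heapq.heappush(self._heap, (-score, next(self._counter), scenario_id))

    def pop_batch(self, k: int) -> list:
        """Remove and return up to k highest-priority scenario IDs."""
        batch = []
        while self._heap and len(batch) < k:
            _, _, scenario_id = heapq.heappop(self._heap)
            batch.append(scenario_id)
        return batch

    def __len__(self) -> int:
        return len(self._heap)
```

The tie-breaking counter keeps the heap from comparing scenario objects directly. It also means that equally scored scenarios are trained on in arrival order.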

3. Continual Learning Strategy

class ContinualTrainer:
    def __init__(self, base_policy, replay_buffer, use_ewc=False):
        self.policy = base_policy
        self.buffer = replay_buffer  # Experience replay for new scenarios
        self.old_scenarios = OldScenarioBuffer()  # Prevent forgetting
        self.use_ewc = use_ewc

    def train_step(self, new_batch):
        # Mix new and old data (2:1) to prevent catastrophic forgetting
        old_batch = self.old_scenarios.sample(len(new_batch) // 2)
        mixed_batch = concatenate(new_batch, old_batch)

        # Elastic Weight Consolidation (optional): the penalty depends on the
        # current parameters, so it is passed as a function and added to the
        # loss inside the SAC gradient step (extra_loss_fn is a hypothetical hook)
        extra_loss_fn = self.compute_ewc_loss if self.use_ewc else None

        # Standard SAC update on the mixed batch
        self.policy.update(mixed_batch, extra_loss_fn=extra_loss_fn)
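The EWC term is $\frac{\lambda}{2}\sum_i F_i(\theta_i - \theta^*_i)^2$. Here $\theta^*$ are the parameters after the previous training round, and $F$ is the diagonal Fisher information, which is estimated from squared gradients of the log-likelihood (for a policy, $\nabla \log \pi(a \mid s)$). This is a NumPy sketch over a flat parameter vector. It assumes per-sample gradients are already available. A JAX/Flax policy stores a pytree that would need flattening first, for example with `jax.flatten_util.ravel_pytree`. Both function names are hypothetical.

```python
import numpy as np


def diagonal_fisher(per_sample_grads: np.ndarray) -> np.ndarray:
    """Diagonal Fisher estimate: mean of squared per-sample gradients of the
    log-likelihood at the old parameters. Shape (num_samples, num_params)."""
    return np.mean(per_sample_grads ** 2, axis=0)


def ewc_penalty(params, old_params, fisher, lam=1.0) -> float:
    """(lam / 2) * sum_i F_i * (theta_i - theta*_i)^2"""
    diff = np.asarray(params, dtype=float) - np.asarray(old_params, dtype=float)
    return 0.5 * lam * float(np.sum(fisher * diff ** 2))
```

Parameters with large Fisher values mattered for earlier scenarios, so moving them is expensive. Parameters with near-zero Fisher values stay free to adapt to the new data.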

4. Automatic Evaluation

class ContinuousEval:
    def evaluate_checkpoint(self, checkpoint):
        # Evaluate on held-out test set
        test_metrics = self.eval_on_test_set(checkpoint)

        # Evaluate on new scenarios
        new_metrics = self.eval_on_new_scenarios(checkpoint)

        # Compare to production policy
        prod_metrics = self.get_production_metrics()

        return {
            'test': test_metrics,
            'new': new_metrics,
            'vs_prod': compare(test_metrics, prod_metrics)
        }

5. Deployment Decision

def should_deploy(eval_results, config):
    """Deploy if:
    1. No regression on the test set (V-Max delta vs production >= -0.005)
    2. New-scenario V-Max score >= config.new_threshold
       (e.g. set 1% above production's score on the same scenarios)
    3. No safety regression (test collision rate <= config.max_collision_rate,
       e.g. production's current rate)
    """
    return (
        eval_results['vs_prod']['v_max_delta'] >= -0.005 and
        eval_results['new']['v_max_score'] >= config.new_threshold and
        eval_results['test']['collision_rate'] <= config.max_collision_rate
    )

6. Monitoring and Rollback

class DeploymentMonitor:
    def monitor(self, deployed_policy):
        metrics = collect_production_metrics()

        if metrics['collision_rate'] > self.alert_threshold:
            self.alert("Collision rate elevated!")

        if metrics['v_max_score'] < self.rollback_threshold:
            self.rollback_to_previous_version()
            self.alert("Rolled back due to degraded performance")
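`rollback_to_previous_version()` needs a deployment history to roll back to. This is a minimal in-memory sketch of the versioned model registry from the diagram. `ModelRegistry` and its methods are hypothetical. A real system would back the registry with persistent storage.

```python
class ModelRegistry:
    """In-memory registry that records deployment history for rollback."""

    def __init__(self):
        self._versions = {}  # version tag -> checkpoint path
        self._history = []   # deployed versions, oldest first

    def register(self, version: str, checkpoint_path: str) -> None:
        self._versions[version] = checkpoint_path

    def deploy(self, version: str) -> None:
        if version not in self._versions:
            raise KeyError(f"unknown version {version!r}")
        self._history.append(version)

    @property
    def current(self):
        return self._history[-1] if self._history else None

    def current_checkpoint(self):
        return self._versions[self.current] if self._history else None

    def rollback(self) -> str:
        """Revert to the previously deployed version and return its tag."""
        if len(self._history) < 2:
            raise RuntimeError("no earlier version to roll back to")
        self._history.pop()
        return self._history[-1]
```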
</details>

Debugging Questions

Q10: Your V-Max training run shows high collision rate (>10%) after 10M steps. How do you debug?

<details> <summary>Click to reveal answer</summary>

Answer: Systematic debugging approach:

Step 1: Check Reward Signal

import numpy as np

# Log reward components
def debug_rewards(transitions):
    safety_rewards = [t.reward_info['safety'] for t in transitions]
    nav_rewards = [t.reward_info['navigation'] for t in transitions]

    print(f"Safety reward mean: {np.mean(safety_rewards):.3f}")
    print(f"Navigation reward mean: {np.mean(nav_rewards):.3f}")

    # Check if safety penalties are being applied
    # (assumes a collision produces a safety reward below -0.5)
    collision_penalties = [r for r in safety_rewards if r < -0.5]
    print(f"Collision penalties: {len(collision_penalties)}/{len(safety_rewards)}")

    # Verify reward scale
    print(f"Safety reward range: [{min(safety_rewards):.3f}, {max(safety_rewards):.3f}]")

Step 2: Analyze Collision Types

def analyze_collisions(episodes):
    collision_types = {
        'rear_end': 0,
        'side_swipe': 0,
        'head_on': 0,
        'pedestrian': 0,
        'stopped_vehicle': 0
    }

    for ep in episodes:
        if ep.had_collision:
            label = classify_collision(ep)
            # Unexpected labels are counted instead of raising KeyError
            collision_types[label] = collision_types.get(label, 0) + 1

    print("Collision breakdown:", collision_types)

    # Interpret the most common collision types:
    # → If rear_end dominant: braking policy issue
    # → If side_swipe dominant: lane keeping issue
    return collision_types
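`classify_collision` is not defined above. One simple option is a geometric heuristic that uses the two agents' headings and the other agent's position at impact. In this sketch the thresholds (0.5 m/s, 30°, 135°) are assumptions. It takes the impact geometry as explicit arguments, which an episode object would need to expose (hypothetical fields). "rear_end" here covers both hitting a lead vehicle and being hit from behind.

```python
import math


def classify_collision(ego_heading: float, other_heading: float,
                       rel_x: float, rel_y: float, other_speed: float,
                       other_type: str = "vehicle") -> str:
    """Heuristic collision label.

    rel_x, rel_y: other agent's position in the ego frame at impact
    (x forward, y left, meters); headings in rad; other_speed in m/s.
    """
    if other_type == "pedestrian":
        return "pedestrian"
    if other_speed < 0.5:
        return "stopped_vehicle"
    # Absolute heading difference wrapped to [0, pi]
    dh = abs((other_heading - ego_heading + math.pi) % (2 * math.pi) - math.pi)
    if dh > math.radians(135):
        return "head_on"
    # Nearly parallel and offset mostly along the lane: rear-end
    if dh < math.radians(30) and abs(rel_x) > abs(rel_y):
        return "rear_end"
    return "side_swipe"
```

The heading wrap keeps the comparison correct across the ±π boundary. For example, headings of 3.0 and -3.0 rad differ by about 0.28 rad, not 6 rad.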

Step 3: Check Observation Quality

import numpy as np

def validate_observations(obs_batch):
    # flatten() maps the nested observation dict to {name: array}
    # Check for NaN/Inf
    for key, value in flatten(obs_batch).items():
        if np.any(np.isnan(value)):
            print(f"NaN found in {key}")
        if np.any(np.isinf(value)):
            print(f"Inf found in {key}")

    # Check feature scales
    for key, value in flatten(obs_batch).items():
        print(f"{key}: mean={np.mean(value):.3f}, std={np.std(value):.3f}")
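`validate_observations` relies on a `flatten` helper. This is a minimal sketch that assumes observations arrive as nested dicts of arrays. V-Max's actual observation container may be structured differently, and the slash-separated key format is an arbitrary choice.

```python
def flatten(tree, prefix=""):
    """Flatten a nested dict into {"outer/inner/leaf": value}."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, name))  # recurse into sub-dicts
        else:
            flat[name] = value
    return flat
```

Path-style keys make NaN reports point to a specific feature group (for example `roadgraph/xy` rather than just `xy`).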

Step 4: Examine Policy Behavior

def visualize_policy_failures(policy, scenarios, num_steps=90):
    for scenario in sample_collision_scenarios(scenarios, n=10):
        # Render scenario
        render_scenario(scenario)

        # Query the policy at each step; if get_observation reads the logged
        # scenario, this is an open-loop replay, not a closed-loop rollout
        actions = []
        for t in range(num_steps):
            obs = get_observation(scenario, t)
            action = policy.get_action(obs)
            actions.append(action)

        plot_action_sequence(actions)

        # Check for erratic behavior: large step-to-step action changes
        action_changes = np.diff(np.asarray(actions), axis=0)
        print(f"Action change std: {np.std(action_changes):.3f}")

Step 5: Validate Training Pipeline

import numpy as np

def validate_training(training_logs, replay_buffer):
    # Heuristic thresholds; tune them for the model and loss scale
    # Check gradient norms over the most recent 1000 updates
    recent_grad_norm = np.mean(training_logs['grad_norms'][-1000:])
    if recent_grad_norm > 100:
        print("Gradient explosion detected")
    if recent_grad_norm < 0.001:
        print("Gradient vanishing detected")

    # Check Q-value estimates (magnitude, so runaway negative values count too)
    q_values = np.asarray(training_logs['q_values'][-1000:])
    if np.mean(np.abs(q_values)) > 1000:
        print("Q-value explosion - reduce learning rate")

    # Check replay buffer
    buffer_rewards = replay_buffer.get_all_rewards()
    print(f"Buffer reward percentiles (25/50/75): {np.percentile(buffer_rewards, [25, 50, 75])}")
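A fixed threshold such as 1000 is arbitrary. With per-step rewards in $[r_{\min}, r_{\max}]$, the discounted return $\sum_t \gamma^t r_t$ cannot leave $[\min(r_{\min},0), \max(r_{\max},0)] / (1-\gamma)$. For rewards normalized to $[-1, 1]$ with $\gamma = 0.99$, that range is $[-100, 100]$, so Q estimates far outside it are estimation artifacts. This is a sketch in which the reward bounds and $\gamma$ are assumptions that should be read from the actual training config, and both function names are hypothetical.

```python
import numpy as np


def q_value_bounds(r_min: float, r_max: float, gamma: float):
    """Achievable range of the discounted return when every reward lies in
    [r_min, r_max]. Including 0 covers episodes that terminate early."""
    scale = 1.0 / (1.0 - gamma)  # sum of gamma^t for t = 0, 1, 2, ...
    return min(r_min, 0.0) * scale, max(r_max, 0.0) * scale


def fraction_out_of_bounds(q_values, r_min, r_max, gamma, slack=0.05):
    """Fraction of Q estimates outside the achievable range, widened by
    `slack` times the range width to tolerate small approximation error."""
    lo, hi = q_value_bounds(r_min, r_max, gamma)
    margin = slack * (hi - lo)
    q = np.asarray(q_values, dtype=float)
    return float(np.mean((q < lo - margin) | (q > hi + margin)))
```

A non-zero fraction that grows over training points to overestimation (for example, from a too-high learning rate or a broken target network), even when the mean Q-value stays below the hard-coded 1000.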

Common Fixes:

| Issue | Symptom | Fix |
|---|---|---|
| Reward scale | Q-values explode | Normalize rewards to [-1, 1] |
| Observation NaN | Policy outputs NaN | Add input validation |
| Sparse collisions | Policy doesn't learn | Increase collision penalty |
| Gradient explosion | Training unstable | Reduce learning rate, clip gradients |
| Exploration | Same collision repeatedly | Increase entropy, reset buffer |
</details>

Summary

V-Max represents a significant step forward in making reinforcement learning practical for autonomous driving. Key takeaways:

  1. Infrastructure Matters: JAX-based unified computation enables efficient training
  2. Multi-Dataset Support: ScenarioMax solves the data fragmentation problem
  3. Hierarchical Rewards: Safety > Navigation > Behavior priority is crucial
  4. Encoder Architecture: Latent Query's bottleneck compression works best
  5. Off-Policy RL: SAC achieves expert-level performance in 25M steps

For ML infrastructure engineers, V-Max demonstrates how to build a complete RL pipeline that's:

  • Efficient (4,609 steps/sec)
  • Modular (YAML-configurable observations)
  • Extensible (pluggable encoders and rewards)
  • Reproducible (comprehensive benchmarks)


Last updated: January 2026