
JAX Scaling RL Deep Dive

How to scale RL training across GPUs/TPUs using JAX primitives: jit, vmap, pmap, scan, and distributed PPO.

Deep Dive: Scaling Autonomous Driving with JAX-Accelerated Reinforcement Learning

Paper: Scaling Is All You Need: Autonomous Driving with JAX-Accelerated Reinforcement Learning

Target Audience: ML Infrastructure Engineers, Distributed Systems Engineers, RL Practitioners

Reading Time: 45-60 minutes (with exercises: 2-3 hours)


Table of Contents

  1. Executive Summary
  2. JAX Primitives Deep Dive
  3. Pure Functional Simulation
  4. Distributed Training Architecture
  5. Performance Optimization
  6. Pre-training Strategy
  7. Interactive Code Examples
  8. Benchmarks Analysis
  9. Common Pitfalls
  10. Hands-On Exercises
  11. Interview Questions

1. Executive Summary

The Core Insight

Scaling reinforcement learning for autonomous driving requires three pillars:

  1. Hardware-accelerated simulation - JAX enables GPU-native environment execution
  2. Distributed asynchronous training - Multiple actors/learners with gradient synchronization
  3. Massive real-world data - 6000+ hours of human driving scenarios

Key Results

Metric           State-of-the-Art   This Paper   Improvement
Failure Rate     2.81%              0.88%        -64%
Progress Ratio   87.6%              120.8%       +25%
Agent Steps      30M (Waymax)       2.5B         83x more

Why This Matters for ML Infrastructure

This paper demonstrates how to:

  • Eliminate Python overhead via JIT compilation
  • Achieve near-linear GPU scaling (8 to 32 GPUs with only 15% overhead)
  • Combine simulation and training into a single XLA computation graph
  • Handle off-policy corrections in distributed settings with V-trace

2. JAX Primitives Deep Dive

JAX is the foundation of the paper's performance gains. Let's understand each primitive in depth.

2.1 jax.jit - Just-In-Time Compilation

What it does: Traces Python functions and compiles them to XLA (Accelerated Linear Algebra) operations that run on GPU/TPU.

Why it matters: Eliminates Python interpreter overhead. The paper reports 0.52 ms per step vs 0.75 ms for the baseline - roughly a 30% per-step time reduction from JIT alone.

import jax
import jax.numpy as jnp
from functools import partial

# WITHOUT JIT - Python interpreter overhead on every call
def slow_matmul_chain(x, weights):
    for w in weights:
        x = jnp.dot(x, w)
        x = jax.nn.relu(x)
    return x

# WITH JIT - Compiled once, runs as native XLA
@jax.jit
def fast_matmul_chain(x, weights):
    for w in weights:
        x = jnp.dot(x, w)
        x = jax.nn.relu(x)
    return x

# Benchmark comparison
import time

key = jax.random.PRNGKey(0)
x_key, *w_keys = jax.random.split(key, 11)  # independent keys so each weight matrix differs
x = jax.random.normal(x_key, (1000, 256))
weights = [jax.random.normal(k, (256, 256)) for k in w_keys]

# Warm-up JIT compilation
_ = fast_matmul_chain(x, weights)

# Timing (block_until_ready accounts for JAX's asynchronous dispatch)
start = time.time()
for _ in range(100):
    result = slow_matmul_chain(x, weights)
result.block_until_ready()
slow_time = time.time() - start

start = time.time()
for _ in range(100):
    result = fast_matmul_chain(x, weights)
result.block_until_ready()
fast_time = time.time() - start

print(f"Without JIT: {slow_time:.3f}s")
print(f"With JIT: {fast_time:.3f}s")
print(f"Speedup: {slow_time/fast_time:.1f}x")

Key insight from paper: By JIT-compiling the entire simulation step, the authors avoid CPU-GPU data transfers between timesteps.

2.2 jax.vmap - Automatic Vectorization

What it does: Transforms a function that operates on single examples into one that operates on batches, without changing the function's code.

Why it matters: The paper processes 512 agents simultaneously. Writing batch-aware code manually is error-prone; vmap handles it automatically.

import jax
import jax.numpy as jnp

# Function written for a SINGLE agent
def compute_agent_reward(
    position: jnp.ndarray,    # shape: (2,)
    velocity: jnp.ndarray,    # shape: (2,)
    goal: jnp.ndarray,        # shape: (2,)
    max_speed: float = 10.0
) -> float:
    """Compute reward for one agent."""
    # Distance to goal (negative = bad)
    dist_to_goal = jnp.linalg.norm(position - goal)

    # Speed penalty if exceeding limit
    speed = jnp.linalg.norm(velocity)
    speed_penalty = jnp.maximum(0, speed - max_speed) ** 2

    return -dist_to_goal - 0.1 * speed_penalty

# Automatically vectorize over batch of agents
batched_reward = jax.vmap(compute_agent_reward)

# Now works on batches!
batch_size = 512
key = jax.random.PRNGKey(42)
pos_key, vel_key, goal_key = jax.random.split(key, 3)  # independent keys per array
positions = jax.random.uniform(pos_key, (batch_size, 2)) * 100
velocities = jax.random.normal(vel_key, (batch_size, 2)) * 5
goals = jax.random.uniform(goal_key, (batch_size, 2)) * 100

# Single call processes all 512 agents in parallel
rewards = batched_reward(positions, velocities, goals)
print(f"Rewards shape: {rewards.shape}")  # (512,)

Nested vmap for multi-dimensional batching:

# Paper uses scenarios x agents x timesteps
# vmap can be nested for multiple batch dimensions

def single_step(state, action):
    """Process single (scenario, agent, timestep)."""
    return state + action * 0.1

# Vectorize over agents within a scenario
agents_step = jax.vmap(single_step, in_axes=(0, 0))

# Vectorize over scenarios
scenarios_agents_step = jax.vmap(agents_step, in_axes=(0, 0))

# Now handles (num_scenarios, num_agents, state_dim)
num_scenarios = 16
num_agents = 512
state_dim = 6

states = jnp.zeros((num_scenarios, num_agents, state_dim))
actions = jnp.ones((num_scenarios, num_agents, state_dim))

new_states = scenarios_agents_step(states, actions)
print(f"Output shape: {new_states.shape}")  # (16, 512, 6)

2.3 jax.pmap - Parallel Mapping Across Devices

What it does: Distributes computation across multiple GPUs/TPUs with automatic data parallelism.

Why it matters: The paper scales from 8 to 32 GPUs with near-linear speedup. pmap handles the device coordination.

import jax
import jax.numpy as jnp
from jax import pmap

# Check available devices
print(f"Available devices: {jax.devices()}")
n_devices = jax.local_device_count()

# Function to run on each device
def train_step(params, batch):
    """Single gradient step - runs on one device."""
    def loss_fn(p):
        predictions = jnp.dot(batch['x'], p['w']) + p['b']
        return jnp.mean((predictions - batch['y']) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)

    # Simple SGD update
    new_params = jax.tree.map(
        lambda p, g: p - 0.01 * g, params, grads
    )
    return new_params, loss

# pmap automatically:
# 1. Shards data across devices (first axis)
# 2. Runs computation in parallel
# 3. Handles cross-device communication
parallel_train_step = pmap(train_step)

# Initialize replicated params (same on each device)
params = {
    'w': jnp.ones((10, 1)),
    'b': jnp.zeros((1,))
}
replicated_params = jax.tree.map(
    lambda x: jnp.stack([x] * n_devices), params
)

# Create sharded batch (different data per device)
batch_per_device = 32
total_batch = batch_per_device * n_devices

key = jax.random.PRNGKey(0)
x_key, y_key = jax.random.split(key)
x = jax.random.normal(x_key, (total_batch, 10))
y = jax.random.normal(y_key, (total_batch, 1))

# Reshape for pmap: (n_devices, batch_per_device, ...)
sharded_batch = {
    'x': x.reshape(n_devices, batch_per_device, 10),
    'y': y.reshape(n_devices, batch_per_device, 1)
}

# Run parallel training step
new_params, losses = parallel_train_step(replicated_params, sharded_batch)
print(f"Losses per device: {losses}")

Gradient synchronization with pmap:

# The paper synchronizes gradients across learners using NCCL
# In JAX, use lax.pmean for this

def synchronized_train_step(params, batch):
    """Train step with gradient averaging across devices."""
    def loss_fn(p):
        predictions = jnp.dot(batch['x'], p['w']) + p['b']
        return jnp.mean((predictions - batch['y']) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)

    # Average gradients across all devices
    # 'batch' is the axis name from pmap
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')

    new_params = jax.tree.map(
        lambda p, g: p - 0.01 * g, params, grads
    )
    return new_params, loss

# axis_name connects pmean to the pmap dimension
parallel_sync_step = pmap(synchronized_train_step, axis_name='batch')

2.4 jax.lax.scan - Efficient Sequential Operations

What it does: Compiles a loop into a single XLA operation, avoiding Python loop overhead and enabling automatic differentiation through time.

Why it matters: The paper "scans along the time axis of dynamic data" to process 151 timesteps efficiently.

import jax
import jax.numpy as jnp
from jax import lax

# Environment step function
def env_step(carry, action):
    """
    Single environment step.
    carry: (state, cumulative_reward)
    action: action to take
    Returns: new_carry, output_for_this_step
    """
    state, cum_reward = carry

    # Physics update (simplified)
    position, velocity = state[:2], state[2:4]
    new_velocity = velocity + action * 0.1
    new_position = position + new_velocity * 0.1

    new_state = jnp.concatenate([new_position, new_velocity])

    # Reward calculation
    reward = -jnp.sum(new_position ** 2)  # Move towards origin

    new_carry = (new_state, cum_reward + reward)
    output = {'state': new_state, 'reward': reward}

    return new_carry, output

# Roll out 151 timesteps efficiently
@jax.jit
def rollout(initial_state, actions):
    """
    Rollout environment for T timesteps.
    actions: (T, action_dim)
    """
    initial_carry = (initial_state, 0.0)

    final_carry, trajectory = lax.scan(
        env_step,
        initial_carry,
        actions
    )

    final_state, total_reward = final_carry
    return trajectory, total_reward

# Usage
T = 151  # Timesteps (from paper)
action_dim = 2
state_dim = 4

initial_state = jnp.zeros(state_dim)
actions = jax.random.normal(jax.random.PRNGKey(0), (T, action_dim))

trajectory, total_reward = rollout(initial_state, actions)
print(f"Trajectory states shape: {trajectory['state'].shape}")  # (151, 4)
print(f"Total reward: {total_reward:.2f}")

Combining scan with vmap for batched rollouts:

# The paper batches scenarios AND time using scan + vmap

@jax.jit
def batched_rollout(initial_states, actions_batch):
    """
    Rollout multiple scenarios in parallel.
    initial_states: (batch_size, state_dim)
    actions_batch: (batch_size, T, action_dim)
    """
    # vmap over batch dimension
    return jax.vmap(rollout)(initial_states, actions_batch)

# Run 16 scenarios with 151 timesteps each
batch_size = 16
initial_states = jnp.zeros((batch_size, state_dim))
actions_batch = jax.random.normal(
    jax.random.PRNGKey(1),
    (batch_size, T, action_dim)
)

trajectories, total_rewards = batched_rollout(initial_states, actions_batch)
print(f"Batched states shape: {trajectories['state'].shape}")  # (16, 151, 4)
print(f"Total rewards shape: {total_rewards.shape}")  # (16,)

3. Pure Functional Simulation

Why Functional Programming Matters

The paper enforces strict functional programming principles:

"All functions need to be pure, meaning they cannot have any side effects"

This enables:

  1. JIT compilation - XLA can only compile pure functions
  2. Automatic differentiation - Gradients require deterministic computation
  3. Parallelization - No shared state means no race conditions
  4. Reproducibility - Same inputs always produce same outputs (see the sketch below)
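
A quick way to see point 4 in practice: with explicit PRNG keys and no hidden state, calling the same pure function twice with the same inputs gives bit-identical results. A minimal sketch (noisy_step is illustrative, not from the paper):

import jax
import jax.numpy as jnp

@jax.jit
def noisy_step(state, key):
    """Pure update: all randomness comes from the explicit key."""
    noise = jax.random.normal(key, state.shape) * 0.01
    return state + noise

state = jnp.zeros(4)
key = jax.random.PRNGKey(0)

out1 = noisy_step(state, key)
out2 = noisy_step(state, key)
print(bool(jnp.allclose(out1, out2)))  # True: same inputs, same outputs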

The Anti-Pattern: Stateful Simulation

# BAD: Stateful simulation - CANNOT be JIT compiled
class StatefulEnv:
    def __init__(self):
        self.state = None
        self.step_count = 0
        self.rng = np.random.RandomState(42)  # Hidden state!

    def reset(self):
        self.state = np.zeros(4)
        self.step_count = 0
        return self.state

    def step(self, action):
        # Mutates internal state - side effect!
        self.state = self.state + action * 0.1
        self.step_count += 1

        # Uses hidden RNG state - non-deterministic!
        noise = self.rng.randn(*self.state.shape) * 0.01
        self.state += noise

        reward = -np.sum(self.state ** 2)
        done = self.step_count >= 100

        return self.state.copy(), reward, done

Problems with stateful design:

  • Cannot JIT compile (Python objects store state)
  • Non-deterministic (RNG state hidden)
  • Cannot vmap (shared state across batch)
  • Cannot differentiate through (state mutation)

The Pattern: Pure Functional Simulation

import jax
import jax.numpy as jnp
from typing import NamedTuple

# State is explicit, immutable, and passed through functions
class EnvState(NamedTuple):
    position: jnp.ndarray
    velocity: jnp.ndarray
    step_count: int
    rng_key: jnp.ndarray

# GOOD: Pure functional environment
@jax.jit
def env_reset(rng_key: jnp.ndarray) -> EnvState:
    """Pure reset - no side effects."""
    return EnvState(
        position=jnp.zeros(2),
        velocity=jnp.zeros(2),
        step_count=0,
        rng_key=rng_key
    )

@jax.jit
def env_step(
    state: EnvState,
    action: jnp.ndarray
) -> tuple[EnvState, float, bool]:
    """
    Pure step function.
    - Takes state explicitly
    - Returns new state (no mutation)
    - RNG key is part of state
    """
    # Split RNG key for this step
    rng_key, noise_key = jax.random.split(state.rng_key)

    # Deterministic physics
    new_velocity = state.velocity + action * 0.1
    new_position = state.position + new_velocity * 0.1

    # Explicit randomness with passed key
    noise = jax.random.normal(noise_key, new_position.shape) * 0.01
    new_position = new_position + noise

    # Create NEW state (immutable)
    new_state = EnvState(
        position=new_position,
        velocity=new_velocity,
        step_count=state.step_count + 1,
        rng_key=rng_key  # Updated key for next step
    )

    reward = -jnp.sum(new_position ** 2)
    done = new_state.step_count >= 100

    return new_state, reward, done

# Usage - state flows explicitly
key = jax.random.PRNGKey(42)
state = env_reset(key)

for _ in range(10):
    action = jnp.array([0.1, 0.2])
    state, reward, done = env_step(state, action)
    print(f"Step {state.step_count}: reward = {reward:.3f}")

Handling Conditional Logic Without Branching

The paper notes:

"Logical branching is eliminated by executing all code branches and selecting the required result after calculation"

This is necessary because XLA compiles static graphs - it cannot handle Python if statements.

import jax
import jax.numpy as jnp

# BAD: Python branching - breaks JIT
def bad_reward(collision: bool, at_goal: bool) -> float:
    if collision:
        return -100.0
    elif at_goal:
        return 100.0
    else:
        return -1.0

# GOOD: JAX-compatible conditional using jnp.where
@jax.jit
def good_reward(collision: jnp.ndarray, at_goal: jnp.ndarray) -> jnp.ndarray:
    """All branches computed, result selected."""
    # Compute ALL possible rewards
    collision_reward = -100.0
    goal_reward = 100.0
    step_reward = -1.0

    # Select based on conditions (no Python branching)
    reward = jnp.where(
        collision,
        collision_reward,
        jnp.where(at_goal, goal_reward, step_reward)
    )
    return reward

# GOOD: Using jax.lax.cond for expensive branches
@jax.jit
def expensive_reward(collision: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
    """Use lax.cond when branches have different compute costs."""
    def collision_branch(s):
        # Simple computation
        return -100.0

    def normal_branch(s):
        # Expensive computation (only runs if not collision)
        complex_reward = jnp.sum(jnp.exp(-s ** 2))
        return complex_reward

    # Only evaluates one branch at runtime
    return jax.lax.cond(collision, collision_branch, normal_branch, state)

Data Padding for Uniform Shapes

The paper pads all scenarios to uniform dimensions:

  • 512 agents maximum per scenario
  • 151 timesteps (30 seconds at 5 Hz)
  • 128 road elements maximum

import jax.numpy as jnp
from typing import NamedTuple

class PaddedScenario(NamedTuple):
    """Fixed-shape scenario for JIT compilation."""
    # Agent data: (max_agents, timesteps, features)
    agent_positions: jnp.ndarray  # (512, 151, 2)
    agent_velocities: jnp.ndarray  # (512, 151, 2)
    agent_valid_mask: jnp.ndarray  # (512,) - which agents are real

    # Road data: (max_elements, features)
    road_points: jnp.ndarray  # (128, 10, 2)
    road_valid_mask: jnp.ndarray  # (128,)

    # Metadata
    num_valid_agents: int
    num_valid_roads: int

def pad_scenario(
    positions: jnp.ndarray,  # (actual_agents, actual_steps, 2)
    max_agents: int = 512,
    max_steps: int = 151
) -> jnp.ndarray:
    """Pad variable-size data to fixed shape."""
    actual_agents, actual_steps, feat_dim = positions.shape

    # Create padded array
    padded = jnp.zeros((max_agents, max_steps, feat_dim))

    # Copy actual data (JAX-compatible slicing)
    padded = padded.at[:actual_agents, :actual_steps, :].set(positions)

    return padded

def mask_padded_agents(
    rewards: jnp.ndarray,  # (512,)
    valid_mask: jnp.ndarray  # (512,) boolean
) -> jnp.ndarray:
    """Zero out rewards for padded (invalid) agents."""
    return jnp.where(valid_mask, rewards, 0.0)

4. Distributed Training Architecture

System Overview

The paper implements an asynchronous actor-learner architecture similar to OpenAI's Dota 2 system:

+--------------------------------------------------+
|            DISTRIBUTED TRAINING SYSTEM            |
+--------------------------------------------------+

                    EXPERIENCE FLOW
                         |
         +---------------+---------------+
         |               |               |
    +----v----+    +----v----+    +-----v----+
    |  Actor  |    |  Actor  |    |  Actor   |
    | Group 1 |    | Group 2 |    | Group N  |
    | (GPU 1) |    | (GPU 2) |    | (GPU N)  |
    +----+----+    +----+----+    +-----+----+
         |               |               |
         | trajectories  | trajectories  |
         |               |               |
    +----v---------------v---------------v----+
    |           EXPERIENCE BUFFERS            |
    |    (Replay buffers per learner group)   |
    +----+---------------+---------------+----+
         |               |               |
         |               |               |
    +----v----+    +----v----+    +-----v----+
    | Learner |    | Learner |    | Learner  |
    | Group 1 |    | Group 2 |    | Group N  |
    | (GPUs)  |    | (GPUs)  |    | (GPUs)   |
    +----+----+    +----+----+    +-----+----+
         |               |               |
         +---------------+---------------+
                         |
                    NCCL AllReduce
                   (Gradient Sync)
                         |
                  +------v------+
                  |   UPDATED   |
                  |   POLICY    |
                  +------+------+
                         |
            +------------+------------+
            |            |            |
       +----v----+  +----v----+  +----v----+
       |  Actor  |  |  Actor  |  |  Actor  |
       | (pulls  |  | (pulls  |  | (pulls  |
       | latest) |  | latest) |  | latest) |
       +---------+  +---------+  +---------+

Asynchronous PPO with V-trace

The Problem: In asynchronous training, actors use slightly stale policies. When the learner updates, the collected experience is "off-policy."

The Solution: V-trace importance sampling correction.
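
For reference, the targets computed in the code below follow the V-trace definition from IMPALA (Espeholt et al., 2018). Writing π for the learner's (target) policy and μ for the actor's (behavior) policy, with clipping constants ρ̄ and c̄ (rho_clip and c_clip in the code):

$$
v_t = V(s_t) + \sum_{k=t}^{T-1} \gamma^{\,k-t} \Big(\prod_{i=t}^{k-1} c_i\Big) \delta_k,
\qquad
\delta_k = \rho_k \big(r_k + \gamma V(s_{k+1}) - V(s_k)\big)
$$

$$
\rho_k = \min\!\Big(\bar{\rho},\ \tfrac{\pi(a_k \mid s_k)}{\mu(a_k \mid s_k)}\Big),
\qquad
c_i = \min\!\Big(\bar{c},\ \tfrac{\pi(a_i \mid s_i)}{\mu(a_i \mid s_i)}\Big)
$$

The policy gradient then uses the advantage $\rho_t \big(r_t + \gamma v_{t+1} - V(s_t)\big)$, which is exactly what the function below returns.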

import jax
import jax.numpy as jnp
from typing import NamedTuple

class VTraceOutput(NamedTuple):
    vs: jnp.ndarray  # Corrected value estimates
    advantages: jnp.ndarray  # Policy gradient advantages
    rhos: jnp.ndarray  # Truncated importance weights

def compute_vtrace(
    behavior_log_probs: jnp.ndarray,  # Log probs from actor's policy
    target_log_probs: jnp.ndarray,    # Log probs from learner's policy
    rewards: jnp.ndarray,             # (T,)
    values: jnp.ndarray,              # V(s) estimates (T+1,)
    dones: jnp.ndarray,               # Episode termination (T,)
    gamma: float = 0.99,
    rho_clip: float = 1.0,            # Importance weight clip
    c_clip: float = 1.0               # Trace coefficient clip
) -> VTraceOutput:
    """
    V-trace off-policy correction.

    Key insight: When actor policy != learner policy,
    we need to correct for the distribution mismatch.
    """
    T = rewards.shape[0]

    # Importance sampling ratios
    log_rhos = target_log_probs - behavior_log_probs
    rhos = jnp.exp(log_rhos)

    # Clip importance weights (controls variance)
    clipped_rhos = jnp.minimum(rhos, rho_clip)
    clipped_cs = jnp.minimum(rhos, c_clip)

    # TD errors with importance correction
    # delta_t = rho_t * (r_t + gamma * V(s_{t+1}) - V(s_t))
    not_done = 1.0 - dones
    deltas = clipped_rhos * (
        rewards + gamma * not_done * values[1:] - values[:-1]
    )

    # Backward pass to compute V-trace targets
    # vs_t = V(s_t) + sum_{k=t}^{T-1} gamma^{k-t} * (prod_{i=t}^{k-1} c_i) * delta_k
    def scan_fn(carry, inputs):
        """Accumulate discounted corrections backward."""
        acc = carry
        delta, c, not_d = inputs
        acc = delta + gamma * c * not_d * acc
        return acc, acc

    # Scan backward through time
    _, corrections = jax.lax.scan(
        scan_fn,
        jnp.zeros(()),
        (deltas[::-1], clipped_cs[::-1], not_done[::-1])
    )
    corrections = corrections[::-1]  # Reverse back to forward order

    # V-trace value estimates
    vs = values[:-1] + corrections

    # Advantages for policy gradient
    # Use rho-clipped importance weights
    vs_plus_1 = jnp.concatenate([vs[1:], values[-1:]])
    advantages = clipped_rhos * (
        rewards + gamma * not_done * vs_plus_1 - values[:-1]
    )

    return VTraceOutput(vs=vs, advantages=advantages, rhos=clipped_rhos)

# Example usage
T = 32  # Sequence length from paper
key = jax.random.PRNGKey(0)
b_key, t_key, r_key, v_key = jax.random.split(key, 4)

# Simulated off-policy data (separate keys so the draws are independent)
behavior_log_probs = jax.random.normal(b_key, (T,)) - 1.0  # Actor's policy
target_log_probs = jax.random.normal(t_key, (T,))          # Learner's policy
rewards = jax.random.uniform(r_key, (T,))
values = jax.random.uniform(v_key, (T + 1,))
dones = jnp.zeros(T)

vtrace = compute_vtrace(
    behavior_log_probs, target_log_probs,
    rewards, values, dones
)
print(f"V-trace targets shape: {vtrace.vs.shape}")
print(f"Advantages shape: {vtrace.advantages.shape}")

PPO Loss with V-trace

import jax
import jax.numpy as jnp
from typing import NamedTuple

class PPOLossOutput(NamedTuple):
    total_loss: jnp.ndarray
    policy_loss: jnp.ndarray
    value_loss: jnp.ndarray
    entropy_loss: jnp.ndarray

def ppo_loss(
    log_probs: jnp.ndarray,          # Current policy log probs
    old_log_probs: jnp.ndarray,      # Behavior policy log probs
    values: jnp.ndarray,             # Current value estimates
    vtrace_targets: jnp.ndarray,     # V-trace corrected targets
    advantages: jnp.ndarray,         # V-trace advantages
    entropy: jnp.ndarray,            # Policy entropy
    clip_eps: float = 0.3,           # From paper
    value_coef: float = 0.5,
    entropy_coef: float = 0.03       # From paper
) -> PPOLossOutput:
    """
    PPO-clip loss with V-trace corrections.
    """
    # Policy loss with clipping
    ratio = jnp.exp(log_probs - old_log_probs)
    clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)

    # Pessimistic policy update
    policy_loss = -jnp.minimum(
        ratio * advantages,
        clipped_ratio * advantages
    ).mean()

    # Value loss (MSE with V-trace targets)
    value_loss = jnp.mean((values - vtrace_targets) ** 2)

    # Entropy bonus (encourages exploration)
    entropy_loss = -entropy.mean()

    total_loss = (
        policy_loss +
        value_coef * value_loss +
        entropy_coef * entropy_loss
    )

    return PPOLossOutput(
        total_loss=total_loss,
        policy_loss=policy_loss,
        value_loss=value_loss,
        entropy_loss=entropy_loss
    )

Gradient Synchronization Architecture

+--------------------------------------------------------+
|              GRADIENT SYNC (NCCL AllReduce)            |
+--------------------------------------------------------+

  Learner 0          Learner 1          Learner 2
  +-------+          +-------+          +-------+
  |Grads_0|          |Grads_1|          |Grads_2|
  +---+---+          +---+---+          +---+---+
      |                  |                  |
      |    Ring AllReduce (NCCL)           |
      +------------------+------------------+
                         |
                  Average Gradients
                         |
      +------------------+------------------+
      |                  |                  |
  +---v---+          +---v---+          +---v---+
  |Avg_G  |          |Avg_G  |          |Avg_G  |
  +---+---+          +---+---+          +---+---+
      |                  |                  |
  Apply to           Apply to           Apply to
  Local Params       Local Params       Local Params


Implementation in JAX:
- pmap with axis_name='devices'
- lax.pmean(grads, 'devices') for averaging
- All devices get identical gradients
- Synchronized parameter updates

import jax
import jax.numpy as jnp
from jax import pmap, lax
import optax  # needed for optax.apply_updates below

def create_distributed_train_step(model_apply, optimizer):
    """
    Create a training step that synchronizes across devices.
    """

    def train_step(params, opt_state, batch, rng):
        """Single training step with gradient sync."""

        def loss_fn(p):
            # Forward pass
            outputs = model_apply(p, batch['observations'])

            # Compute PPO loss (simplified)
            log_probs = outputs['log_probs']
            values = outputs['values']

            # V-trace would be computed here
            advantages = batch['advantages']
            vtrace_targets = batch['vtrace_targets']

            policy_loss = -jnp.mean(log_probs * advantages)
            value_loss = jnp.mean((values - vtrace_targets) ** 2)

            return policy_loss + 0.5 * value_loss

        loss, grads = jax.value_and_grad(loss_fn)(params)

        # CRITICAL: Synchronize gradients across all devices
        grads = lax.pmean(grads, axis_name='devices')
        loss = lax.pmean(loss, axis_name='devices')

        # Apply synchronized gradients
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)

        return new_params, new_opt_state, loss

    # Parallelize across devices
    return pmap(train_step, axis_name='devices')

5. Performance Optimization

GPU Utilization Strategy

The paper achieves high GPU utilization through:

  1. Double-buffered data loading
  2. Fused simulation + inference
  3. Asynchronous experience transfer

+--------------------------------------------------------+
|              GPU UTILIZATION TIMELINE                   |
+--------------------------------------------------------+

GPU 0 (Simulation):
|==Sim Batch 0==|==Sim Batch 1==|==Sim Batch 2==|
                ^               ^
                |               |
          Pre-load B1     Pre-load B2

GPU Memory:
|  Batch 0 (running)  |  Batch 1 (pre-loaded)  |
                      |  Batch 2 (loading)      |

"To achieve high GPU utilization one scenario batch
is pre-loaded to GPU memory while a simulation of
a different batch is running"

import jax
import jax.numpy as jnp
from typing import Iterator
import threading
from queue import Queue

class DoubleBufferedLoader:
    """
    Pre-load next batch while current batch is processing.
    Hides data transfer latency behind compute.
    """

    def __init__(self, data_iterator: Iterator, prefetch_size: int = 2):
        self.data_iter = data_iterator
        self.prefetch_queue = Queue(maxsize=prefetch_size)
        self.stop_event = threading.Event()

        # Start background prefetch thread
        self.prefetch_thread = threading.Thread(target=self._prefetch_loop)
        self.prefetch_thread.start()

    def _prefetch_loop(self):
        """Background thread: load and transfer to GPU."""
        for batch in self.data_iter:
            if self.stop_event.is_set():
                break

            # Transfer to GPU in background
            gpu_batch = jax.device_put(batch)
            self.prefetch_queue.put(gpu_batch)

    def get_batch(self):
        """Get pre-loaded GPU batch (blocks if not ready)."""
        return self.prefetch_queue.get()

    def close(self):
        self.stop_event.set()
        self.prefetch_thread.join()

Fused Simulation + Inference

The key optimization: compile simulation AND policy inference into a single XLA graph.

import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

def create_fused_rollout(env_step_fn, policy_fn):
    """
    Fuse environment step and policy inference.

    This creates a SINGLE compiled XLA graph that:
    1. Steps the environment
    2. Computes policy on new state
    3. Samples action
    4. Repeats

    No Python loop overhead, no CPU-GPU transfers between steps.
    """

    def single_step(carry, _):
        state, rng = carry

        # Split RNG
        rng, action_rng, step_rng = jax.random.split(rng, 3)

        # Policy inference (ON GPU)
        action_logits = policy_fn(state.observation)
        action = jax.random.categorical(action_rng, action_logits)
        log_prob = jax.nn.log_softmax(action_logits)[action]

        # Environment step (ON GPU)
        new_state, reward, done = env_step_fn(state, action)

        output = {
            'observation': state.observation,
            'action': action,
            'log_prob': log_prob,
            'reward': reward,
            'done': done
        }

        new_carry = (new_state, rng)
        return new_carry, output

    # num_steps must be static: lax.scan's `length` is a compile-time constant
    @partial(jax.jit, static_argnames='num_steps')
    def fused_rollout(initial_state, rng, num_steps):
        """Execute full rollout as single compiled graph."""
        initial_carry = (initial_state, rng)

        # lax.scan compiles entire loop
        final_carry, trajectory = lax.scan(
            single_step,
            initial_carry,
            xs=None,  # No input sequence
            length=num_steps
        )

        return trajectory

    return fused_rollout

Batch Size Optimization

The paper shows batch size dramatically affects throughput:

Batch Size   Simulator Time (V100)   Speedup
1            0.52 ms                 1.0x
16           0.82 ms                 9.8x

import jax
import jax.numpy as jnp
import time

def benchmark_batch_sizes(env_step_fn, batch_sizes, num_iters=100):
    """
    Benchmark different batch sizes to find optimal throughput.

    Key insight: Larger batches amortize kernel launch overhead
    but eventually hit memory limits or diminishing returns.
    """
    key = jax.random.PRNGKey(0)

    results = {}
    for batch_size in batch_sizes:
        # Create batched environment step
        batched_step = jax.vmap(env_step_fn)
        jit_batched_step = jax.jit(batched_step)

        # Create dummy data
        states = jnp.zeros((batch_size, 64))
        actions = jnp.zeros((batch_size, 2))

        # Warmup (JIT compilation)
        _ = jit_batched_step(states, actions)

        # Benchmark
        start = time.time()
        for _ in range(num_iters):
            states, _, _ = jit_batched_step(states, actions)
        jax.block_until_ready(states)  # Wait for async ops
        elapsed = time.time() - start

        steps_per_second = (batch_size * num_iters) / elapsed
        time_per_step = elapsed / num_iters * 1000  # ms

        results[batch_size] = {
            'steps_per_second': steps_per_second,
            'time_per_step_ms': time_per_step
        }

        print(f"Batch {batch_size:4d}: {steps_per_second:.0f} steps/s, "
              f"{time_per_step:.2f} ms/step")

    return results

Near-Linear Scaling Analysis

The paper achieves remarkable scaling efficiency:

GPUs   Runtime (hours)   Normalized GPU-hours   Efficiency
8      41.58             332.64                 100%
16     22.83             365.28                 91%
32     11.99             383.68                 87%

def analyze_scaling_efficiency(gpu_counts, runtimes):
    """
    Analyze how efficiently training scales with more GPUs.

    Perfect linear scaling: 2x GPUs = 0.5x runtime
    Real-world: Communication overhead reduces efficiency
    """
    base_gpus = gpu_counts[0]
    base_runtime = runtimes[0]
    base_gpu_hours = base_gpus * base_runtime

    print("Scaling Analysis")
    print("-" * 60)

    for gpus, runtime in zip(gpu_counts, runtimes):
        gpu_hours = gpus * runtime

        # Ideal runtime with perfect scaling
        ideal_runtime = base_runtime * (base_gpus / gpus)

        # Efficiency = ideal_time / actual_time
        efficiency = ideal_runtime / runtime * 100

        # Overhead = extra GPU-hours beyond baseline
        overhead = (gpu_hours - base_gpu_hours) / base_gpu_hours * 100

        print(f"GPUs: {gpus:2d} | Runtime: {runtime:5.2f}h | "
              f"Efficiency: {efficiency:5.1f}% | Overhead: {overhead:+5.1f}%")

# Data from paper
analyze_scaling_efficiency(
    gpu_counts=[8, 16, 32],
    runtimes=[41.58, 22.83, 11.99]
)

6. Pre-training Strategy

Why Pre-training Matters

The paper makes a critical observation:

"Pre-training only the policy via behavioral cloning makes it challenging to further train this policy with an actor-critic RL method"

The problem: untrained critic produces garbage value estimates, destabilizing policy gradients.

WITHOUT Value Pre-training:
+-----------------+     +-------------------+     +------------------+
|  Pre-trained    |     |  Random Critic    |     |  Garbage         |
|  Policy (BC)    | --> |  V(s) = noise     | --> |  Advantages      |
|  pi(a|s)        |     |                   |     |  A = R - V       |
+-----------------+     +-------------------+     +------------------+
                                                          |
                                                          v
                                               Policy gradients are
                                               essentially random!

WITH Value Pre-training (Paper's Approach):
+-----------------+     +-------------------+     +------------------+
|  Pre-trained    |     |  Pre-trained      |     |  Meaningful      |
|  Policy (BC)    | --> |  Critic (returns) | --> |  Advantages      |
|  pi(a|s)        |     |  V(s) ~ G_0       |     |  A = R - V       |
+-----------------+     +-------------------+     +------------------+
                                                          |
                                                          v
                                               Policy gradients point
                                               toward improvement!

Implementation

import jax
import jax.numpy as jnp
import optax
from typing import NamedTuple

class BCPretrainState(NamedTuple):
    params: dict
    opt_state: optax.OptState
    step: int

def create_bc_pretrainer(model_apply, action_space_size):
    """
    Behavioral cloning pre-trainer that trains BOTH policy and value.

    Key insight from paper: Value pre-training uses discounted returns
    from human demonstrations as targets.
    """

    def compute_returns(rewards, dones, gamma=0.99):
        """Compute discounted returns G_0 = sum_{t=0}^T gamma^t * R_t."""
        T = rewards.shape[0]

        def scan_fn(carry, inputs):
            future_return = carry
            reward, done = inputs

            # Reset return at episode boundary
            current_return = reward + gamma * future_return * (1 - done)
            return current_return, current_return

        # Scan backward
        _, returns = jax.lax.scan(
            scan_fn,
            jnp.zeros(()),
            (rewards[::-1], dones[::-1])
        )
        return returns[::-1]

    def bc_loss(params, batch):
        """
        Combined behavioral cloning loss.

        Loss = CrossEntropy(policy, expert_actions)
             + value_scale * MSE(values, returns)
        """
        outputs = model_apply(params, batch['observations'])

        # Policy loss: cross-entropy with expert actions
        action_logits = outputs['action_logits']
        expert_actions = batch['actions']
        policy_loss = optax.softmax_cross_entropy_with_integer_labels(
            action_logits, expert_actions
        ).mean()

        # Value loss: MSE with computed returns
        values = outputs['values']
        returns = batch['returns']  # Pre-computed discounted returns
        value_loss = jnp.mean((values - returns) ** 2)

        # From paper: value_loss_scale = 1e-4
        total_loss = policy_loss + 1e-4 * value_loss

        return total_loss, {
            'policy_loss': policy_loss,
            'value_loss': value_loss
        }

    @jax.jit
    def train_step(state: BCPretrainState, batch):
        """Single BC pre-training step."""
        (loss, metrics), grads = jax.value_and_grad(
            bc_loss, has_aux=True
        )(state.params, batch)

        updates, new_opt_state = optax.adam(2e-3).update(
            grads, state.opt_state, state.params
        )
        new_params = optax.apply_updates(state.params, updates)

        new_state = BCPretrainState(
            params=new_params,
            opt_state=new_opt_state,
            step=state.step + 1
        )

        return new_state, metrics

    return train_step, compute_returns

# Training loop structure
def pretrain_loop(
    train_step_fn,
    initial_state,
    dataset,
    num_epochs=20,  # From paper
    batch_size=32768  # From paper
):
    """
    Pre-training loop for behavioral cloning.

    Paper trains for 20 epochs with batch size 32768.
    """
    state = initial_state

    for epoch in range(num_epochs):
        epoch_losses = []

        for batch in dataset.shuffle().batch(batch_size):
            state, metrics = train_step_fn(state, batch)
            epoch_losses.append(metrics['policy_loss'])

        avg_loss = jnp.mean(jnp.array(epoch_losses))
        print(f"Epoch {epoch + 1}/{num_epochs}: Loss = {avg_loss:.4f}")

    return state

Value Target Computation

def prepare_bc_dataset(trajectories, gamma=0.99):
    """
    Prepare behavioral cloning dataset with value targets.

    Each trajectory contains:
    - observations: (T, obs_dim)
    - actions: (T,) expert actions
    - rewards: (T,) per-step rewards
    - dones: (T,) episode boundaries
    """
    processed = []

    for traj in trajectories:
        # Compute discounted returns for value targets
        returns = compute_discounted_returns(traj['rewards'], traj['dones'], gamma)

        processed.append({
            'observations': traj['observations'],
            'actions': traj['actions'],
            'returns': returns  # Value pre-training targets!
        })

    return processed

def compute_discounted_returns(rewards, dones, gamma):
    """G_0 = sum_{t=0}^T gamma^t * R(s_t)"""
    T = len(rewards)
    returns = jnp.zeros(T)
    future_return = 0.0

    for t in range(T - 1, -1, -1):
        if dones[t]:
            future_return = 0.0
        future_return = rewards[t] + gamma * future_return
        returns = returns.at[t].set(future_return)

    return returns

Pre-training Impact (Ablation)

From the paper's ablation study:

Method                         Steps to 2% Failure Rate
Without pre-training           2.5B steps
With BC + Value pre-training   0.5B steps

5x faster convergence with proper pre-training!


7. Interactive Code Examples

Example 1: Complete JAX Environment

import jax
import jax.numpy as jnp
from jax import lax
from typing import NamedTuple
from functools import partial

# ============================================================
# PURE FUNCTIONAL DRIVING ENVIRONMENT
# ============================================================

class VehicleState(NamedTuple):
    """Immutable vehicle state."""
    x: jnp.ndarray        # Position x
    y: jnp.ndarray        # Position y
    vx: jnp.ndarray       # Velocity x
    vy: jnp.ndarray       # Velocity y
    heading: jnp.ndarray  # Heading angle
    step: jnp.ndarray     # Current timestep

class EnvConfig(NamedTuple):
    """Environment configuration (static)."""
    dt: float = 0.1           # Timestep
    max_steps: int = 151      # From paper
    max_speed: float = 15.0   # m/s
    goal_x: float = 100.0
    goal_y: float = 0.0

def reset(rng: jnp.ndarray, config: EnvConfig) -> VehicleState:
    """Initialize vehicle at origin."""
    return VehicleState(
        x=jnp.array(0.0),
        y=jnp.array(0.0),
        vx=jnp.array(0.0),
        vy=jnp.array(0.0),
        heading=jnp.array(0.0),
        step=jnp.array(0)
    )

def step(
    state: VehicleState,
    action: jnp.ndarray,  # [acceleration, steering_rate]
    config: EnvConfig
) -> tuple[VehicleState, jnp.ndarray, jnp.ndarray]:
    """
    Pure functional environment step.

    Returns: (new_state, reward, done)
    """
    accel, steer_rate = action[0], action[1]

    # Kinematic bicycle model (simplified)
    new_heading = state.heading + steer_rate * config.dt
    new_vx = state.vx + accel * jnp.cos(new_heading) * config.dt
    new_vy = state.vy + accel * jnp.sin(new_heading) * config.dt

    # Clip speed
    speed = jnp.sqrt(new_vx**2 + new_vy**2)
    scale = jnp.minimum(1.0, config.max_speed / (speed + 1e-6))
    new_vx = new_vx * scale
    new_vy = new_vy * scale

    # Update position
    new_x = state.x + new_vx * config.dt
    new_y = state.y + new_vy * config.dt

    new_state = VehicleState(
        x=new_x, y=new_y,
        vx=new_vx, vy=new_vy,
        heading=new_heading,
        step=state.step + 1
    )

    # Reward: progress toward goal + penalties
    dist_to_goal = jnp.sqrt(
        (new_x - config.goal_x)**2 + (new_y - config.goal_y)**2
    )
    progress_reward = -dist_to_goal * 0.01

    # Penalties (from paper)
    speed_penalty = -jnp.maximum(0, speed - config.max_speed)**2 * 0.1
    accel_penalty = -accel**2 * 0.01
    steer_penalty = -steer_rate**2 * 0.01

    reward = progress_reward + speed_penalty + accel_penalty + steer_penalty

    # Done conditions
    at_goal = dist_to_goal < 1.0
    timeout = state.step >= config.max_steps
    done = jnp.logical_or(at_goal, timeout)

    return new_state, reward, done

# Vectorize over batch of vehicles
batched_reset = jax.vmap(reset, in_axes=(0, None))
batched_step = jax.vmap(step, in_axes=(0, 0, None))

# JIT compile
jit_reset = jax.jit(batched_reset, static_argnums=(1,))
jit_step = jax.jit(batched_step, static_argnums=(2,))

# ============================================================
# USAGE EXAMPLE
# ============================================================

def main():
    config = EnvConfig()
    batch_size = 512  # From paper

    # Initialize
    rngs = jax.random.split(jax.random.PRNGKey(0), batch_size)
    states = jit_reset(rngs, config)
    print(f"Initial positions: x={states.x[:5]}")

    # Random actions
    actions = jax.random.uniform(
        jax.random.PRNGKey(1),
        (batch_size, 2),
        minval=-1.0, maxval=1.0
    )

    # Step
    new_states, rewards, dones = jit_step(states, actions, config)
    print(f"Rewards: {rewards[:5]}")
    print(f"New positions: x={new_states.x[:5]}")

if __name__ == "__main__":
    main()

Example 2: Attention-based Policy Network

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence

class MultiHeadAttention(nn.Module):
    """Multi-head attention for agent interactions."""
    num_heads: int
    head_dim: int

    @nn.compact
    def __call__(self, queries, keys, values, mask=None):
        # Project to heads
        q = nn.Dense(self.num_heads * self.head_dim)(queries)
        k = nn.Dense(self.num_heads * self.head_dim)(keys)
        v = nn.Dense(self.num_heads * self.head_dim)(values)

        # Reshape: (batch, seq, num_heads, head_dim)
        batch, q_len, _ = q.shape
        _, kv_len, _ = k.shape

        q = q.reshape(batch, q_len, self.num_heads, self.head_dim)
        k = k.reshape(batch, kv_len, self.num_heads, self.head_dim)
        v = v.reshape(batch, kv_len, self.num_heads, self.head_dim)

        # Attention scores
        scale = jnp.sqrt(self.head_dim)
        scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / scale

        if mask is not None:
            scores = jnp.where(mask, scores, -1e9)

        attn_weights = jax.nn.softmax(scores, axis=-1)

        # Weighted sum
        output = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, v)
        output = output.reshape(batch, q_len, -1)

        return nn.Dense(self.num_heads * self.head_dim)(output)

class PerceiverEncoder(nn.Module):
    """
    Perceiver-style encoder from the paper.

    Uses cross-attention between modalities and self-attention within.
    """
    hidden_dim: int = 256
    num_heads: int = 8
    num_layers: int = 4

    @nn.compact
    def __call__(
        self,
        ego_features,      # (batch, ego_dim)
        agent_features,    # (batch, num_agents, agent_dim)
        road_features,     # (batch, num_roads, road_dim)
        agent_mask,        # (batch, num_agents)
        road_mask          # (batch, num_roads)
    ):
        batch_size = ego_features.shape[0]

        # Project all inputs to hidden dim
        ego = nn.Dense(self.hidden_dim)(ego_features)[:, None, :]
        agents = nn.Dense(self.hidden_dim)(agent_features)
        roads = nn.Dense(self.hidden_dim)(road_features)

        # Concatenate for self-attention
        # Shape: (batch, 1 + num_agents + num_roads, hidden_dim)
        tokens = jnp.concatenate([ego, agents, roads], axis=1)

        # Combined mask
        ego_mask = jnp.ones((batch_size, 1), dtype=bool)
        full_mask = jnp.concatenate([ego_mask, agent_mask, road_mask], axis=1)
        attn_mask = full_mask[:, None, None, :]  # Broadcast for attention

        # Transformer layers
        for _ in range(self.num_layers):
            # Self-attention
            attended = MultiHeadAttention(
                num_heads=self.num_heads,
                head_dim=self.hidden_dim // self.num_heads
            )(tokens, tokens, tokens, attn_mask)

            tokens = nn.LayerNorm()(tokens + attended)

            # FFN
            ffn_out = nn.Dense(self.hidden_dim * 4)(tokens)
            ffn_out = nn.gelu(ffn_out)
            ffn_out = nn.Dense(self.hidden_dim)(ffn_out)

            tokens = nn.LayerNorm()(tokens + ffn_out)

        # Return ego token (first position)
        return tokens[:, 0, :]

class DrivingPolicy(nn.Module):
    """
    Full driving policy with shared encoder.

    Architecture from paper:
    - Shared Perceiver encoder
    - Separate policy and value heads
    """
    hidden_dim: int = 256
    num_accel_actions: int = 71   # From paper
    num_steer_actions: int = 21

    @nn.compact
    def __call__(self, observations):
        # Shared encoder
        encoded = PerceiverEncoder(hidden_dim=self.hidden_dim)(
            ego_features=observations['ego'],
            agent_features=observations['agents'],
            road_features=observations['roads'],
            agent_mask=observations['agent_mask'],
            road_mask=observations['road_mask']
        )

        # Policy head (Dense Residual Blocks)
        policy_hidden = nn.Dense(self.hidden_dim)(encoded)
        policy_hidden = nn.relu(policy_hidden)
        policy_hidden = policy_hidden + nn.Dense(self.hidden_dim)(
            nn.relu(nn.Dense(self.hidden_dim)(policy_hidden))
        )

        # Action logits
        accel_logits = nn.Dense(self.num_accel_actions)(policy_hidden)
        steer_logits = nn.Dense(self.num_steer_actions)(policy_hidden)

        # Value head
        value_hidden = nn.Dense(self.hidden_dim)(encoded)
        value_hidden = nn.relu(value_hidden)
        value = nn.Dense(1)(value_hidden).squeeze(-1)

        return {
            'accel_logits': accel_logits,
            'steer_logits': steer_logits,
            'value': value
        }

# ============================================================
# USAGE
# ============================================================

def create_dummy_observations(batch_size, num_agents=64, num_roads=128):
    """Create dummy observations matching paper's structure."""
    return {
        'ego': jnp.zeros((batch_size, 10)),
        'agents': jnp.zeros((batch_size, num_agents, 20)),
        'roads': jnp.zeros((batch_size, num_roads, 15)),
        'agent_mask': jnp.ones((batch_size, num_agents), dtype=bool),
        'road_mask': jnp.ones((batch_size, num_roads), dtype=bool)
    }

# Initialize and run
policy = DrivingPolicy()
obs = create_dummy_observations(batch_size=32)

params = policy.init(jax.random.PRNGKey(0), obs)
outputs = policy.apply(params, obs)

print(f"Accel logits shape: {outputs['accel_logits'].shape}")
print(f"Steer logits shape: {outputs['steer_logits'].shape}")
print(f"Value shape: {outputs['value'].shape}")

Example 3: Complete Training Loop

import jax
import jax.numpy as jnp
from jax import lax
import optax
from typing import NamedTuple
from functools import partial

# ============================================================
# TRAINING STATE
# ============================================================

class TrainState(NamedTuple):
    params: dict
    opt_state: optax.OptState
    rng: jnp.ndarray
    step: int

# ============================================================
# PPO TRAINING STEP
# ============================================================

def create_train_step(
    policy_apply,
    env_step,
    env_config,
    lr=5.6e-5,      # From paper
    gamma=0.99,
    gae_lambda=0.95,
    clip_eps=0.3,   # From paper
    entropy_coef=0.03,  # From paper
    value_coef=0.5
):
    """Create JIT-compiled training step."""

    optimizer = optax.adam(lr)

    def compute_gae(rewards, values, dones):
        """Generalized Advantage Estimation."""
        T = rewards.shape[0]
        advantages = jnp.zeros(T)
        last_advantage = 0.0

        def scan_fn(last_adv, inputs):
            reward, value, next_value, done = inputs

            delta = reward + gamma * next_value * (1 - done) - value
            advantage = delta + gamma * gae_lambda * (1 - done) * last_adv

            return advantage, advantage

        next_values = jnp.concatenate([values[1:], values[-1:]])
        _, advantages = lax.scan(
            scan_fn,
            jnp.array(0.0),
            (rewards[::-1], values[::-1], next_values[::-1], dones[::-1])
        )

        return advantages[::-1]

    def ppo_loss(params, batch, old_log_probs, advantages, returns):
        """PPO clipped objective."""
        outputs = policy_apply(params, batch['observations'])

        # Current log probs
        accel_log_probs = jax.nn.log_softmax(outputs['accel_logits'])
        steer_log_probs = jax.nn.log_softmax(outputs['steer_logits'])

        accel_lp = accel_log_probs[
            jnp.arange(len(batch['accel_actions'])),
            batch['accel_actions']
        ]
        steer_lp = steer_log_probs[
            jnp.arange(len(batch['steer_actions'])),
            batch['steer_actions']
        ]
        log_probs = accel_lp + steer_lp

        # Policy loss
        ratio = jnp.exp(log_probs - old_log_probs)
        clipped_ratio = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps)
        policy_loss = -jnp.minimum(
            ratio * advantages,
            clipped_ratio * advantages
        ).mean()

        # Value loss
        value_loss = jnp.mean((outputs['value'] - returns) ** 2)

        # Entropy bonus
        accel_entropy = -jnp.sum(
            jnp.exp(accel_log_probs) * accel_log_probs, axis=-1
        ).mean()
        steer_entropy = -jnp.sum(
            jnp.exp(steer_log_probs) * steer_log_probs, axis=-1
        ).mean()
        entropy = accel_entropy + steer_entropy

        total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

        return total_loss, {
            'policy_loss': policy_loss,
            'value_loss': value_loss,
            'entropy': entropy
        }

    @jax.jit
    def train_step(state: TrainState, batch):
        """Single training iteration."""

        # Compute old log probs (before update)
        old_outputs = policy_apply(state.params, batch['observations'])
        old_accel_lp = jax.nn.log_softmax(old_outputs['accel_logits'])
        old_steer_lp = jax.nn.log_softmax(old_outputs['steer_logits'])

        old_log_probs = (
            old_accel_lp[jnp.arange(len(batch['accel_actions'])), batch['accel_actions']] +
            old_steer_lp[jnp.arange(len(batch['steer_actions'])), batch['steer_actions']]
        )

        # Compute advantages and returns
        advantages = compute_gae(
            batch['rewards'],
            old_outputs['value'],
            batch['dones']
        )
        returns = advantages + old_outputs['value']

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Gradient step
        (loss, metrics), grads = jax.value_and_grad(
            ppo_loss, has_aux=True
        )(state.params, batch, old_log_probs, advantages, returns)

        updates, new_opt_state = optimizer.update(
            grads, state.opt_state, state.params
        )
        new_params = optax.apply_updates(state.params, updates)

        new_state = TrainState(
            params=new_params,
            opt_state=new_opt_state,
            rng=state.rng,
            step=state.step + 1
        )

        return new_state, metrics

    return train_step

# ============================================================
# ROLLOUT COLLECTION
# ============================================================

def create_rollout_fn(policy_apply, env_step, env_reset, config, rollout_length=32):
    """Create function to collect experience."""

    @jax.jit
    def collect_rollout(params, env_state, rng):
        """Collect rollout_length steps of experience."""

        def step_fn(carry, _):
            state, rng = carry
            rng, accel_rng, steer_rng = jax.random.split(rng, 3)

            # Get action from policy
            obs = state_to_obs(state)  # Convert state to observation dict (helper assumed)
            outputs = policy_apply(params, obs)

            # Separate keys for the two action heads (never reuse a key)
            accel_action = jax.random.categorical(
                accel_rng, outputs['accel_logits']
            )
            steer_action = jax.random.categorical(
                steer_rng, outputs['steer_logits']
            )

            # Execute in environment
            action = actions_to_controls(accel_action, steer_action)  # discrete bins -> controls (helper assumed)
            new_state, reward, done = env_step(state, action, config)

            # Store transition
            transition = {
                'observations': obs,
                'accel_actions': accel_action,
                'steer_actions': steer_action,
                'rewards': reward,
                'dones': done,
                'values': outputs['value']
            }

            # Reset if done (for continuous collection)
            rng, reset_rng = jax.random.split(rng)
            new_state = jax.lax.cond(
                done,
                lambda: env_reset(reset_rng, config),
                lambda: new_state
            )

            return (new_state, rng), transition

        (final_state, _), trajectory = lax.scan(
            step_fn,
            (env_state, rng),
            xs=None,
            length=rollout_length
        )

        return final_state, trajectory

    return collect_rollout

8. Benchmarks Analysis

Simulator Performance

Simulator   Device   Batch 1   Batch 16   Throughput Gain
Waymax      V100     0.75 ms   2.48 ms    6.5x
Paper       V100     0.52 ms   0.82 ms    19.5x

Key insights:

  1. JIT compilation reduces per-step overhead (0.75 -> 0.52 ms)
  2. Batch efficiency is dramatically better: batch 16 costs only 1.6x the batch-1 time (0.52 -> 0.82 ms) vs 3.3x for Waymax (0.75 -> 2.48 ms)
  3. Effective throughput at batch 16: 19,512 steps/second vs 6,452 steps/second for Waymax

def analyze_throughput(single_time_ms, batch_time_ms, batch_size):
    """Analyze batching efficiency."""

    # Ideal: batch_time = single_time (perfect parallelism)
    ideal_time = single_time_ms

    # Actual overhead
    overhead_ms = batch_time_ms - single_time_ms
    overhead_per_item = overhead_ms / batch_size

    # Throughput
    throughput = batch_size / (batch_time_ms / 1000)  # steps/sec

    # Efficiency
    efficiency = (single_time_ms * batch_size) / batch_time_ms

    print(f"Batch size: {batch_size}")
    print(f"  Single step: {single_time_ms:.2f} ms")
    print(f"  Batched step: {batch_time_ms:.2f} ms")
    print(f"  Overhead per item: {overhead_per_item:.3f} ms")
    print(f"  Throughput: {throughput:.0f} steps/sec")
    print(f"  Efficiency: {efficiency:.1%}")

print("=== Paper's Simulator ===")
analyze_throughput(0.52, 0.82, 16)

print("\n=== Waymax Baseline ===")
analyze_throughput(0.75, 2.48, 16)

Training Scaling Efficiency

+----------------------------------------------------------+
|                 SCALING EFFICIENCY CHART                  |
+----------------------------------------------------------+

Runtime (hours)
    |
 45 +  *
    |
 40 +   \.
    |     \.    Ideal (linear)
 35 +       \.   ----------------
    |         \.
 30 +           \.
    |             \.
 25 +               \.-- Actual -+
    |                 \.         |
 20 +                   *        |  Gap = overhead
    |                    \.      |
 15 +                      \.    |
    |                        \.  |
 10 +                          \.*
    |
  5 +
    |
    +----+----+----+----+----+----+
         8   12   16   20   24   32
                  GPUs

Overhead Analysis:

GPUs | Actual (h) | Ideal (h) | Overhead
8    | 41.58      | 41.58     | 0%
16   | 22.83      | 20.79     | 9.8%
32   | 11.99      | 10.40     | 15.3%

The overhead comes from:

  1. Gradient synchronization via NCCL AllReduce
  2. Data loading across more workers
  3. Experience buffer coordination
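
To make the overhead column concrete, here is a small sketch that reproduces the numbers in the table above, taking the 8-GPU run as the linear-scaling baseline:

def scaling_overhead(base_gpus, base_hours, gpus, actual_hours):
    # Ideal runtime assumes perfect linear scaling from the baseline run
    ideal_hours = base_hours * base_gpus / gpus
    overhead = actual_hours / ideal_hours - 1
    print(f"{gpus} GPUs: ideal {ideal_hours:.2f} h, "
          f"actual {actual_hours:.2f} h, overhead {overhead:.1%}")

for gpus, hours in [(8, 41.58), (16, 22.83), (32, 11.99)]:
    scaling_overhead(8, 41.58, gpus, hours)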

Model Size vs Data Size

+----------------------------------------------------------+
|              FAILURE RATE BY MODEL & DATA SIZE            |
+----------------------------------------------------------+

Failure Rate (%)
    |
 3.0+  .
    |  : .
 2.5+  :   .                      Model sizes:
    |  :     .                    --- 0.75M params
 2.0+  :       .                  -.- 2.5M params
    |  :         .                ... 25M params
 1.5+  :           .
    |  :             .
 1.0+  :  .  .  .  .  .  .
    |     :           :   .  .  .  .  .
 0.5+        :  .  .  .  .  .  .  .  .  *
    |
    +------+------+------+------+------+------+
          600   1200   2000   3500   5000   6000
                    Data (hours)

Key Insight: Larger models only help with MORE data

Critical finding from paper:

"Increasing the model size only helps when sufficient real-world driving data is available"

This follows scaling laws observed in language models - there's an optimal compute allocation between model size and data size.


9. Common Pitfalls

Pitfall 1: Python Control Flow in JIT

# BAD: Python if statement
@jax.jit
def bad_reward(x):
    if x > 0:  # This is traced ONCE at compile time!
        return x * 2
    else:
        return x * -1

# What happens:
# - Calling bad_reward on a JAX array raises an error: the traced value has no
#   concrete boolean, so the Python `if` cannot branch on it at trace time
# - If x were a static Python scalar instead, the branch taken during tracing
#   would be baked into the compiled function and reused for every call

# GOOD: Use jnp.where or lax.cond
@jax.jit
def good_reward(x):
    return jnp.where(x > 0, x * 2, x * -1)

# Or for expensive branches:
@jax.jit
def good_reward_cond(x):
    return jax.lax.cond(
        x > 0,
        lambda x: x * 2,
        lambda x: x * -1,
        x
    )

Pitfall 2: In-Place Mutation

# BAD: NumPy-style mutation
@jax.jit
def bad_update(arr, idx, value):
    arr[idx] = value  # ERROR: JAX arrays are immutable
    return arr

# GOOD: Functional update with .at[]
@jax.jit
def good_update(arr, idx, value):
    return arr.at[idx].set(value)

# GOOD: For multiple updates
@jax.jit
def good_batch_update(arr, indices, values):
    return arr.at[indices].set(values)

Pitfall 3: Random Number Generation

# BAD: Using same key multiple times
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10,))
y = jax.random.normal(key, (10,))  # SAME VALUES as x!

# GOOD: Split keys
key = jax.random.PRNGKey(0)
key, x_key, y_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (10,))
y = jax.random.normal(y_key, (10,))  # Different values

# GOOD: Inside JIT functions, pass key explicitly
@jax.jit
def sample_action(key, logits):
    return jax.random.categorical(key, logits)

Pitfall 4: Dynamic Shapes

# BAD: Variable-length sequences
@jax.jit
def bad_sum_variable(x):
    # Shape of x changes -> recompilation every time!
    return jnp.sum(x)

# GOOD: Pad to fixed size with mask
@jax.jit
def good_sum_padded(x, mask):
    # x always same shape, mask indicates valid elements
    return jnp.sum(x * mask)

# The paper pads all data:
# - 512 agents (fixed)
# - 151 timesteps (fixed)
# - 128 road elements (fixed)

Pitfall 5: Forgetting block_until_ready

import time

# BAD: Timing without synchronization
@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
_ = compute(x).block_until_ready()  # Warm-up: trigger JIT compilation first

start = time.time()
result = compute(x)  # Dispatch returns immediately (async!)
print(f"Time: {time.time() - start:.4f}s")  # Wrong! Measures dispatch, not compute

# GOOD: Block until computation completes
start = time.time()
result = compute(x)
result.block_until_ready()  # Wait for GPU
print(f"Time: {time.time() - start:.4f}s")  # Correct timing

Pitfall 6: Inefficient Tree Operations

# BAD: Accessing pytree elements in loop
def bad_update_params(params, grads, lr):
    for key in params:
        params[key] = params[key] - lr * grads[key]
    return params

# GOOD: Use jax.tree.map (vectorized)
def good_update_params(params, grads, lr):
    return jax.tree.map(
        lambda p, g: p - lr * g,
        params, grads
    )

# BETTER: Use optax for optimizer state
import optax
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Pitfall 7: Gradient Through Non-Differentiable Operations

# BAD: argmax is not differentiable
@jax.jit
def bad_discrete_choice(logits):
    action = jnp.argmax(logits)  # No gradient!
    return action

# GOOD: Use Gumbel-Softmax for differentiable sampling
@jax.jit
def good_discrete_choice(key, logits, temperature=1.0):
    gumbel_noise = jax.random.gumbel(key, logits.shape)
    soft_sample = jax.nn.softmax((logits + gumbel_noise) / temperature)
    return soft_sample  # Differentiable!

# GOOD: For RL, use log_prob gradients (REINFORCE/PPO)
def policy_gradient_loss(log_probs, advantages):
    return -jnp.mean(log_probs * jax.lax.stop_gradient(advantages))

Pitfall 8: Memory Leaks with JIT Cache

# BAD: Creating new functions in loop (cache grows forever)
for i in range(1000):
    @jax.jit
    def compute(x):
        return x + i  # Different closure each time!
    result = compute(data)

# GOOD: Define the function once and pass i as a regular (traced) argument
@jax.jit
def compute_with_arg(x, i):
    return x + i

for i in range(1000):
    result = compute_with_arg(data, i)

# Note: marking i static (static_argnums) would recompile once per distinct
# value; only use it when the value must be known at compile time (e.g. shapes)

# BETTER: Vectorize over the parameter
@jax.jit
def compute_all(x, i_values):
    return jax.vmap(lambda i: x + i)(i_values)

10. Hands-On Exercises

Exercise 1: Implement a Pure Functional Environment

Goal: Create a JAX-compatible gridworld environment.

"""
TODO: Implement a pure functional gridworld environment.

Requirements:
1. State includes: position (x, y), goal position, step count
2. Actions: 0=up, 1=down, 2=left, 3=right
3. Reward: -1 per step, +10 for reaching goal
4. Episode ends at goal or after 100 steps

Constraints:
- All functions must be JIT-compilable
- No Python control flow (if/for) inside JIT functions
- State must be immutable (NamedTuple or dataclass)
"""

import jax
import jax.numpy as jnp
from typing import NamedTuple

class GridState(NamedTuple):
    x: jnp.ndarray
    y: jnp.ndarray
    goal_x: jnp.ndarray
    goal_y: jnp.ndarray
    step: jnp.ndarray

def reset(rng: jnp.ndarray) -> GridState:
    """Initialize environment. Goal at (7, 7), start at (0, 0)."""
    # TODO: Implement
    pass

def step(state: GridState, action: jnp.ndarray) -> tuple:
    """
    Execute action and return (new_state, reward, done).

    Hint: Use jnp.where for conditional logic
    """
    # TODO: Implement
    # Action mapping: 0=up (+y), 1=down (-y), 2=left (-x), 3=right (+x)
    pass

# Test your implementation
def test_env():
    reset_fn = jax.jit(reset)
    step_fn = jax.jit(step)

    state = reset_fn(jax.random.PRNGKey(0))
    print(f"Initial state: ({state.x}, {state.y})")

    # Move right 7 times, then up 7 times
    for action in [3]*7 + [0]*7:
        state, reward, done = step_fn(state, jnp.array(action))
        print(f"Action {action}: pos=({state.x}, {state.y}), r={reward}, done={done}")

Exercise 2: Batch Rollouts with scan

Goal: Collect parallel rollouts using vmap and scan.

"""
TODO: Implement batched rollout collection.

Requirements:
1. Use lax.scan for time steps
2. Use vmap for parallel environments
3. Return complete trajectories for PPO training
"""

def collect_batch_rollouts(
    policy_fn,
    env_reset,
    env_step,
    params,
    rng,
    num_envs: int = 64,
    rollout_length: int = 32
):
    """
    Collect rollouts from num_envs parallel environments.

    Returns:
        Dictionary with keys:
        - observations: (num_envs, rollout_length, obs_dim)
        - actions: (num_envs, rollout_length)
        - rewards: (num_envs, rollout_length)
        - dones: (num_envs, rollout_length)
        - values: (num_envs, rollout_length)
        - log_probs: (num_envs, rollout_length)
    """
    # TODO: Implement
    # Hint:
    # 1. Initialize num_envs environments with vmap(reset)
    # 2. Define single_step function for scan
    # 3. Use vmap over environments, scan over time
    pass

Exercise 3: V-trace Implementation

Goal: Implement V-trace off-policy correction.

"""
TODO: Implement V-trace algorithm.

Reference: IMPALA paper (Espeholt et al., 2018)
"""

def vtrace_targets(
    behavior_log_probs: jnp.ndarray,  # (T,)
    target_log_probs: jnp.ndarray,    # (T,)
    rewards: jnp.ndarray,             # (T,)
    values: jnp.ndarray,              # (T+1,) including bootstrap
    dones: jnp.ndarray,               # (T,)
    gamma: float = 0.99,
    rho_bar: float = 1.0,
    c_bar: float = 1.0
) -> tuple:
    """
    Compute V-trace targets and advantages.

    Returns:
        vs: V-trace value targets (T,)
        advantages: Policy gradient advantages (T,)

    Algorithm:
    1. Compute importance weights rho = exp(target_lp - behavior_lp)
    2. Clip rho to rho_bar, c to c_bar
    3. Compute TD errors: delta = rho * (r + gamma * V(s') - V(s))
    4. Accumulate backward: vs = V(s) + sum of discounted deltas
    5. Advantages = rho * (r + gamma * vs' - V(s))
    """
    # TODO: Implement
    pass

# Test with known values
def test_vtrace():
    T = 10

    # On-policy case: behavior = target, should match GAE
    log_probs = jnp.zeros(T)
    rewards = jnp.ones(T)
    values = jnp.zeros(T + 1)
    dones = jnp.zeros(T)

    vs, advantages = vtrace_targets(
        log_probs, log_probs, rewards, values, dones
    )

    print(f"V-trace targets: {vs}")
    print(f"Advantages: {advantages}")

Exercise 4: Distributed Gradient Averaging

Goal: Implement synchronized training across devices.

"""
TODO: Implement multi-device training with gradient synchronization.

Use pmap and lax.pmean to average gradients across devices.
"""

import jax
from jax import pmap, lax

def create_distributed_trainer(model_apply, optimizer):
    """
    Create a training function that works across multiple devices.

    Requirements:
    1. Each device processes its own batch
    2. Gradients are averaged across devices
    3. All devices end up with identical parameters
    """

    def train_step(params, opt_state, batch):
        """
        Single training step with gradient synchronization.

        Args:
            params: Replicated parameters (n_devices, ...)
            opt_state: Replicated optimizer state
            batch: Sharded batch (n_devices, batch_per_device, ...)
        """
        # TODO: Implement
        # 1. Compute loss and gradients
        # 2. Average gradients with lax.pmean
        # 3. Apply updates
        pass

    # TODO: Wrap with pmap, specify axis_name for pmean
    return None

def test_distributed():
    n_devices = jax.local_device_count()
    print(f"Testing with {n_devices} devices")

    # Create simple linear model
    def model_apply(params, x):
        return jnp.dot(x, params['w']) + params['b']

    # TODO: Initialize replicated params and test training step
    pass

Exercise 5: Full Mini-Pipeline

Goal: Build a complete training pipeline combining all concepts.

"""
TODO: Build a complete mini RL training pipeline.

Components:
1. Pure functional CartPole-like environment
2. Simple MLP policy with value head
3. PPO loss with GAE
4. Batched rollout collection
5. Training loop with logging

Target: Train for 100 iterations, achieve >100 average reward.
"""

def build_training_pipeline():
    """
    Build and return all components of the training pipeline.

    Returns:
        dict with keys:
        - 'env_reset': JIT-compiled reset function
        - 'env_step': JIT-compiled step function
        - 'policy': Policy network
        - 'collect_rollouts': Rollout collection function
        - 'train_step': Single PPO update step
        - 'train': Full training loop function
    """
    # TODO: Implement all components
    pass

def main():
    pipeline = build_training_pipeline()

    # Initialize
    rng = jax.random.PRNGKey(0)
    # TODO: Initialize params, opt_state

    # Training loop
    for iteration in range(100):
        # TODO: Collect rollouts
        # TODO: Compute advantages
        # TODO: PPO update
        # TODO: Log metrics
        pass

if __name__ == "__main__":
    main()

11. Interview Questions

Beginner Level

Q1: What is the difference between jax.jit and jax.vmap?

<details> <summary>Answer</summary>
  • jax.jit: Just-In-Time compilation. Traces a function and compiles it to optimized XLA code. Eliminates Python interpreter overhead. The function runs on GPU/TPU as a single fused kernel.

  • jax.vmap: Vectorizing map. Takes a function that operates on single examples and automatically transforms it to operate on batches. Adds a batch dimension without changing the function's implementation.

Key insight: They compose! jax.jit(jax.vmap(fn)) gives you a compiled, batched function.

@jax.jit
def fast_fn(x):
    return x ** 2

batched_fast_fn = jax.jit(jax.vmap(fast_fn))
</details>

Q2: Why can't you use Python if statements inside JIT-compiled functions?

<details> <summary>Answer</summary>

JIT compilation happens at trace time, not runtime. When JAX traces a function:

  1. It executes the Python code once with abstract "tracer" values
  2. Records all operations into a computation graph
  3. Compiles this graph

Python if statements are evaluated during tracing with abstract tracer values, not actual data. If the condition depends on a traced array, JAX raises an error because the tracer has no concrete boolean value; if it depends only on static Python values, the chosen branch is "baked in" to the compiled function.

Solution: Use jnp.where() for element-wise conditionals or jax.lax.cond() for scalar conditionals.

# Bad
@jax.jit
def bad(x):
    if x > 0:  # Raises an error: traced x has no concrete bool at trace time
        return x
    return -x

# Good
@jax.jit
def good(x):
    return jnp.where(x > 0, x, -x)
</details>

Q3: What is the purpose of jax.random.split() and why is it necessary?

<details> <summary>Answer</summary>

JAX uses explicit PRNG (Pseudo-Random Number Generator) keys for reproducibility and parallelization:

  1. Explicit state: Unlike NumPy's hidden global RNG state, JAX requires you to pass RNG keys explicitly
  2. Functional purity: Using the same key twice gives the same "random" numbers
  3. Parallelization: Split keys can be used independently on different devices

split() creates new independent keys from a parent key:

key = jax.random.PRNGKey(0)
key, subkey1, subkey2 = jax.random.split(key, 3)
# subkey1 and subkey2 produce different random sequences
# key is now a new key for future splits

This enables reproducible yet statistically independent random sampling across parallel computations.

</details>

Intermediate Level

Q4: Explain the actor-learner architecture used in the paper. What is the off-policy problem and how does V-trace solve it?

<details> <summary>Answer</summary>

Architecture:

  • Actors: Run the policy, collect experience, send to learners
  • Learners: Receive experience, compute gradients, update policy
  • Asynchronous: Actors don't wait for learner updates

Off-policy problem: When an actor collects experience with policy v1, by the time the learner processes it, the policy might be v5. The collected experience is "off-policy" - it doesn't represent the current policy's behavior.

V-trace solution: V-trace uses importance sampling to correct for the policy mismatch:

  1. Compute importance weight: rho = pi_current(a|s) / pi_behavior(a|s)
  2. Clip weights to reduce variance: rho_clipped = min(rho, rho_bar)
  3. Use clipped weights in TD error: delta = rho_clipped * (r + gamma*V(s') - V(s))
  4. Accumulate corrected values backward through time

This gives unbiased (when rho_bar -> infinity) value estimates even with stale behavior policies.
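
A compact sketch of the steps above (a reference implementation under the stated clipping scheme, not the paper's exact code), assuming length-T per-step arrays and a length-(T+1) values array whose last entry is the bootstrap value:

def vtrace_sketch(behavior_log_probs, target_log_probs, rewards, values, dones,
                  gamma=0.99, rho_bar=1.0, c_bar=1.0):
    rhos = jnp.exp(target_log_probs - behavior_log_probs)
    clipped_rhos = jnp.minimum(rhos, rho_bar)
    clipped_cs = jnp.minimum(rhos, c_bar)
    discounts = gamma * (1.0 - dones)
    # Clipped TD errors
    deltas = clipped_rhos * (rewards + discounts * values[1:] - values[:-1])

    def backward_step(acc, xs):
        delta, discount, c = xs
        acc = delta + discount * c * acc
        return acc, acc

    # Accumulate corrections backward through time (accumulator starts at 0)
    _, corrections = jax.lax.scan(
        backward_step, jnp.zeros_like(values[0]),
        (deltas, discounts, clipped_cs), reverse=True)

    vs = values[:-1] + corrections
    # v_{t+1}, bootstrapping the final step with V(s_T)
    vs_tp1 = jnp.concatenate([vs[1:], values[-1:]])
    advantages = clipped_rhos * (rewards + discounts * vs_tp1 - values[:-1])
    return vs, advantages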

</details>

Q5: The paper pads all scenarios to fixed sizes (512 agents, 151 timesteps). Why is this necessary and what are the trade-offs?

<details> <summary>Answer</summary>

Why necessary:

  • XLA compilation requires static shapes
  • JIT compiles a new function for each unique shape
  • Dynamic shapes would cause constant recompilation
  • vmap/pmap require uniform dimensions across batch

Trade-offs:

Pros:

  • Single compiled kernel handles all scenarios
  • Efficient GPU memory allocation
  • Enables vectorized operations across batch

Cons:

  • Memory overhead for small scenarios
  • Compute waste on padded elements
  • Need to track valid elements with masks

Mitigation:

  • Use masks to zero out padded element contributions
  • Choose pad sizes based on data distribution (512 covers 99%+ of real scenarios)
  • Mask-aware attention/aggregation operations
# Masked mean example
def masked_mean(values, mask):
    return jnp.sum(values * mask) / jnp.sum(mask)
</details>

Q6: How does jax.lax.scan differ from a Python for loop, and why is it crucial for this paper's performance?

<details> <summary>Answer</summary>

Python for loop:

  • Interpreted at runtime
  • Each iteration is a separate GPU kernel call
  • CPU-GPU synchronization between iterations
  • Gradients computed via list accumulation

jax.lax.scan:

  • Compiled into a single XLA while loop
  • Entire loop runs as one fused kernel
  • No CPU involvement during execution
  • Automatic gradient computation via scan's adjoint

Performance impact: The paper processes 151 timesteps per rollout. With a Python loop, that's 151 kernel launches with synchronization. With scan, it's one kernel that runs the entire trajectory.

For batched rollouts (16 scenarios x 512 agents x 151 steps), scan enables:

  • Single compilation
  • No Python overhead
  • Efficient GPU memory management
  • Automatic backprop through time
# 151 kernel launches
for t in range(151):
    state = step(state, actions[t])

# Single fused kernel (here step returns (new_state, per_step_output), as scan requires)
final_state, trajectory = jax.lax.scan(step, initial_state, actions)
</details>

Advanced Level

Q7: The paper achieves 87% efficiency when scaling from 8 to 32 GPUs. What are the sources of overhead and how would you improve scaling further?

<details> <summary>Answer</summary>

Sources of overhead (13% loss):

  1. Gradient synchronization: AllReduce across 32 GPUs has O(log n) latency
  2. Bandwidth saturation: Large model gradients compete for network bandwidth
  3. Stragglers: Synchronous updates wait for slowest worker
  4. Data loading: More workers compete for I/O

Improvement strategies:

  1. Gradient compression:

    • Quantize gradients to lower precision (FP16, INT8)
    • Top-k sparsification (only send largest gradients)
  2. Asynchronous SGD:

    • Don't wait for all gradients (paper uses this with V-trace correction)
    • Local SGD: sync every N steps instead of every step
  3. Pipeline parallelism:

    • Split model across GPUs
    • Overlap forward/backward of different microbatches
  4. Better communication topology:

    • Ring AllReduce is O(n) bandwidth but O(1) latency
    • Hierarchical AllReduce for multi-node
  5. Overlap computation and communication:

    • Start AllReduce on earlier layers while computing later layers
    • Requires careful dependency management
# Gradient compression example: per-parameter top-k sparsification
def compress_gradients(grads, top_k_ratio=0.1):
    def sparsify(g):
        flat = jnp.ravel(g)
        k = max(1, int(flat.size * top_k_ratio))
        # Zero out everything below the k-th largest magnitude
        threshold = jnp.sort(jnp.abs(flat))[-k]
        return jnp.where(jnp.abs(g) >= threshold, g, 0.0)
    return jax.tree.map(sparsify, grads)
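
And a minimal local-SGD sketch for strategy 2 (a sketch, not the paper's setup: loss_fn, lr, sync_every, and the data_stream iterator are hypothetical, and params is assumed to be already replicated across devices):

from functools import partial

@partial(jax.pmap, axis_name='devices')
def local_update(params, batch):
    # Purely local SGD step: no cross-device communication
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

@partial(jax.pmap, axis_name='devices')
def sync_params(params):
    # Average parameters across devices (the only communication)
    return jax.tree.map(lambda p: jax.lax.pmean(p, 'devices'), params)

for step, batch in enumerate(data_stream):
    params = local_update(params, batch)
    if step % sync_every == 0:
        params = sync_params(params)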
</details>

Q8: The paper pre-trains both policy and value networks. Design an experiment to quantify the contribution of value pre-training vs policy pre-training.

<details> <summary>Answer</summary>

Experimental design:

Condition   | Policy Pre-train | Value Pre-train | Description
Baseline    | No               | No              | Random init
Policy-only | Yes              | No              | BC for policy, random value
Value-only  | No               | Yes             | Random policy, returns for value
Full        | Yes              | Yes             | Paper's approach

Metrics:

  • Training curves: reward vs steps
  • Time to threshold: steps to reach X% performance
  • Final performance: asymptotic reward
  • Training stability: variance across seeds

Hypothesis from paper:

  • Policy-only will be unstable (random value -> noisy advantages)
  • Value-only won't help much (random policy explores poorly)
  • Full pre-training gives 5x faster convergence

Implementation:

def run_ablation(config):
    results = {}

    for pretrain_policy in [True, False]:
        for pretrain_value in [True, False]:
            name = f"policy={pretrain_policy}_value={pretrain_value}"

            # Pre-training phase
            if pretrain_policy or pretrain_value:
                params = behavioral_cloning(
                    demo_data,
                    train_policy=pretrain_policy,
                    train_value=pretrain_value
                )
            else:
                params = random_init()

            # RL training
            curve = train_ppo(params, max_steps=2.5e9)
            results[name] = curve

    return results

Expected findings:

  • Value pre-training is crucial for actor-critic stability
  • Policy pre-training speeds up exploration
  • Both together are synergistic, not just additive
</details>

Q9: Explain how you would debug a training run where loss is NaN after switching from 1 GPU to 8 GPUs.

<details> <summary>Answer</summary>

Systematic debugging approach:

  1. Verify basic pmap correctness:
# Check all devices hold identical parameters
@pmap
def check_params(params):
    # Reduce every parameter to one scalar per device
    return sum(jax.tree.leaves(jax.tree.map(jnp.sum, params)))

sums = check_params(replicated_params)  # shape (n_devices,)
assert jnp.allclose(sums, sums[0]), "Params not synchronized!"
  2. Check gradient synchronization:
# Print gradient norms per device before/after pmean
def debug_train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)

    # Global gradient norm before cross-device averaging
    norm_before = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree.leaves(grads)))
    jax.debug.print("Grad norm before sync: {}", norm_before)

    grads = lax.pmean(grads, 'devices')

    norm_after = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree.leaves(grads)))
    jax.debug.print("Grad norm after sync: {}", norm_after)

    return loss, grads
  3. Check for effective batch size mismatch:
# 8 GPUs = 8x effective batch size
# May need to scale learning rate: lr_new = lr_old * sqrt(8)
# Or adjust batch normalization statistics
  4. Check data sharding:
# Ensure each device gets different data
def check_data_sharding(batch, max_overlap=0):
    # batch shape: (n_devices, batch_per_device, ...)
    n_devices = batch.shape[0]
    for i in range(n_devices):
        for j in range(i + 1, n_devices):
            overlap = jnp.sum(batch[i] == batch[j])
            assert overlap <= max_overlap, f"Devices {i},{j} have duplicate data!"
  5. Look for numerical instabilities magnified by scale:
# Common culprits:
# - log(0) in log_probs -> add epsilon
# - exp(large) in softmax -> use log-sum-exp trick
# - Division by zero in normalization -> check masks

# Add numerical guards
log_probs = jnp.log(probs + 1e-8)
advantages = (adv - mean) / (std + 1e-8)
  6. Check random key handling:
# Each device needs different RNG keys
keys = jax.random.split(master_key, n_devices)
# NOT: replicated_key = jnp.stack([key] * n_devices)

Most likely culprits:

  • Learning rate not scaled for larger effective batch
  • Gradient explosion from pmean (sum not mean?)
  • Data duplication across devices
  • RNG key reuse causing identical samples
</details>

Q10: Design a system to achieve 10x the scale of this paper (25B agent steps instead of 2.5B). What are the key bottlenecks and how would you address them?

<details> <summary>Answer</summary>

Key bottlenecks at 10x scale:

  1. Data throughput:

    • 10x more experience to collect and process
    • Solution: More actors, faster simulators, better I/O
  2. Gradient computation:

    • 10x more gradient steps (assuming same batch size)
    • Solution: Larger batches, more learners
  3. Communication overhead:

    • More devices = more synchronization
    • Solution: Hierarchical communication, local SGD
  4. Memory:

    • Replay buffers grow with experience
    • Solution: On-policy (no buffer) or distributed storage

Architecture for 25B steps:

+------------------------------------------------------------------+
|                    10x SCALE ARCHITECTURE                         |
+------------------------------------------------------------------+

Tier 1: Actor Pods (100 pods x 8 GPUs each = 800 GPUs)
+-----------+  +-----------+       +-----------+
| Actor Pod |  | Actor Pod |  ...  | Actor Pod |
| 8 GPUs    |  | 8 GPUs    |       | 8 GPUs    |
+-----------+  +-----------+       +-----------+
      |              |                   |
      +-------+------+-------+-----------+
              |
        Experience Stream (Kafka/Redis)
              |
+-------------+-------------+
|             |             |
v             v             v
+-----------+ +-----------+ +-----------+
| Learner   | | Learner   | | Learner   |
| Shard 0   | | Shard 1   | | Shard 2   |
| 16 GPUs   | | 16 GPUs   | | 16 GPUs   |
+-----------+ +-----------+ +-----------+
      |             |             |
      +------+------+------+------+
             |
       Model Parallelism
       (if model > GPU memory)

Key design decisions:

  1. Decouple actors and learners completely:

    • Actors push to message queue
    • Learners pull async
    • No blocking synchronization
  2. Hierarchical gradient sync:

    • AllReduce within learner shard (fast)
    • Periodic sync across shards (slower, less frequent)
  3. Population-based training:

    • Multiple policies evolving in parallel
    • Cross-pollination of best performers
  4. Importance sampling buffer:

    • Prioritize rare/difficult scenarios
    • Learned difficulty estimator
  5. Mixed-precision training:

    • FP16 for forward pass and gradients
    • FP32 master weights
    • 2x memory savings, 2x throughput
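
A minimal mixed-precision sketch for design decision 5 (assumptions: a hypothetical loss_fn and a fixed loss scale; production systems typically use dynamic loss scaling):

# fp32 master weights, fp16 compute; scale the loss to avoid fp16 underflow
def mixed_precision_grads(master_params, batch, loss_scale=1024.0):
    half_params = jax.tree.map(lambda p: p.astype(jnp.float16), master_params)
    loss, grads = jax.value_and_grad(
        lambda p: loss_fn(p, batch) * loss_scale)(half_params)
    # Unscale and accumulate gradients in fp32 before the optimizer update
    grads = jax.tree.map(lambda g: g.astype(jnp.float32) / loss_scale, grads)
    return loss / loss_scale, grads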

Estimated resources:

  • ~1000 GPUs
  • ~1 week training time
  • ~100TB experience data throughput
  • Cost: ~$500k-1M cloud compute

Verification:

  • Linear speedup tests at each scale doubling
  • Convergence parity checks (same final performance)
  • Ablations on buffer size, sync frequency
</details>

Summary

This paper demonstrates that autonomous driving policies can be dramatically improved through scale:

  1. JAX enables scale: JIT, vmap, pmap, scan eliminate Python overhead and enable massive parallelization

  2. Pure functional design: Immutable state, explicit RNG, no side effects enable compilation and parallelization

  3. Distributed training: Async actor-learner with V-trace correction achieves near-linear GPU scaling

  4. Pre-training matters: BC + value pre-training gives 5x faster RL convergence

  5. Data-model scaling: Larger models only help with sufficient data (scaling laws apply)

The engineering principles here apply broadly to any large-scale RL system. Master these concepts and you'll be well-prepared for ML infrastructure roles at top companies.


References

  1. Original Paper: arXiv:2312.15122
  2. JAX Documentation: jax.readthedocs.io
  3. IMPALA (V-trace): arXiv:1802.01561
  4. PPO: arXiv:1707.06347
  5. Waymax: waymo.com/open/data/motion