Deep Dive: Scaling Autonomous Driving with JAX-Accelerated Reinforcement Learning
Paper: Scaling Is All You Need: Autonomous Driving with JAX-Accelerated Reinforcement Learning
Target Audience: ML Infrastructure Engineers, Distributed Systems Engineers, RL Practitioners
Reading Time: 45-60 minutes (with exercises: 2-3 hours)
Table of Contents
- Executive Summary
- JAX Primitives Deep Dive
- Pure Functional Simulation
- Distributed Training Architecture
- Performance Optimization
- Pre-training Strategy
- Interactive Code Examples
- Benchmarks Analysis
- Common Pitfalls
- Hands-On Exercises
- Interview Questions
1. Executive Summary
The Core Insight
Scaling reinforcement learning for autonomous driving requires three pillars:
- Hardware-accelerated simulation - JAX enables GPU-native environment execution
- Distributed asynchronous training - Multiple actors/learners with gradient synchronization
- Massive real-world data - 6000+ hours of human driving scenarios
Key Results
| Metric | State-of-the-Art | This Paper | Improvement |
|---|---|---|---|
| Failure Rate | 2.81% | 0.88% | -64% |
| Progress Ratio | 87.6% | 120.8% | +25% |
| Agent Steps | 30M (Waymax) | 2.5B | 83x more |
Why This Matters for ML Infrastructure
This paper demonstrates how to:
- Eliminate Python overhead via JIT compilation
- Achieve near-linear GPU scaling (8 to 32 GPUs with only 15% overhead)
- Combine simulation and training into a single XLA computation graph
- Handle off-policy corrections in distributed settings with V-trace
2. JAX Primitives Deep Dive
JAX is the foundation of the paper's performance gains. Let's understand each primitive in depth.
2.1 jax.jit - Just-In-Time Compilation
What it does: Traces Python functions and compiles them to XLA (Accelerated Linear Algebra) operations that run on GPU/TPU.
Why it matters: Eliminates Python interpreter overhead. The paper achieves 0.52 ms per simulation step vs 0.75 ms for the Waymax baseline - roughly a 30% reduction in per-step time from JIT alone.
import jax
import jax.numpy as jnp
from functools import partial
# WITHOUT JIT - Python interpreter overhead on every call
def slow_matmul_chain(x, weights):
for w in weights:
x = jnp.dot(x, w)
x = jax.nn.relu(x)
return x
# WITH JIT - Compiled once, runs as native XLA
@jax.jit
def fast_matmul_chain(x, weights):
for w in weights:
x = jnp.dot(x, w)
x = jax.nn.relu(x)
return x
# Benchmark comparison
import time
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 256))
weights = [jax.random.normal(key, (256, 256)) for _ in range(10)]
# Warm-up JIT compilation
_ = fast_matmul_chain(x, weights)
# Timing (block on results so async dispatch doesn't skew the numbers)
start = time.time()
for _ in range(100):
    result = slow_matmul_chain(x, weights)
result.block_until_ready()
slow_time = time.time() - start

start = time.time()
for _ in range(100):
    result = fast_matmul_chain(x, weights)
result.block_until_ready()
fast_time = time.time() - start

print(f"Without JIT: {slow_time:.3f}s")
print(f"With JIT: {fast_time:.3f}s")
print(f"Speedup: {slow_time/fast_time:.1f}x")
Key insight from paper: By JIT-compiling the entire simulation step, the authors avoid CPU-GPU data transfers between timesteps.
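To make that claim concrete, here is a minimal sketch (not the paper's code) contrasting a simulation step split across separately jitted pieces with a single jitted step; the dynamics and reward helpers are hypothetical stand-ins. In the fused version, XLA sees the whole step as one program, so intermediates stay on the device and Python dispatches only once per timestep.
import jax
import jax.numpy as jnp

def dynamics(state, action):
    """Toy dynamics update (hypothetical stand-in for the real simulator)."""
    return state + 0.1 * action

def reward(state):
    """Toy reward: stay near the origin."""
    return -jnp.sum(state ** 2)

# Fragmented: each piece is dispatched separately from Python,
# and XLA cannot fuse work across the call boundary.
jit_dynamics = jax.jit(dynamics)
jit_reward = jax.jit(reward)

def fragmented_step(state, action):
    new_state = jit_dynamics(state, action)
    return new_state, jit_reward(new_state)

# Fused: one trace, one XLA program per simulation step.
@jax.jit
def fused_step(state, action):
    new_state = dynamics(state, action)
    return new_state, reward(new_state)

state, action = jnp.zeros(4), jnp.ones(4)
state, r = fused_step(state, action)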
2.2 jax.vmap - Automatic Vectorization
What it does: Transforms a function that operates on single examples into one that operates on batches, without changing the function's code.
Why it matters: The paper processes 512 agents simultaneously. Writing batch-aware code manually is error-prone; vmap handles it automatically.
import jax
import jax.numpy as jnp
# Function written for a SINGLE agent
def compute_agent_reward(
position: jnp.ndarray, # shape: (2,)
velocity: jnp.ndarray, # shape: (2,)
goal: jnp.ndarray, # shape: (2,)
max_speed: float = 10.0
) -> float:
"""Compute reward for one agent."""
# Distance to goal (negative = bad)
dist_to_goal = jnp.linalg.norm(position - goal)
# Speed penalty if exceeding limit
speed = jnp.linalg.norm(velocity)
speed_penalty = jnp.maximum(0, speed - max_speed) ** 2
return -dist_to_goal - 0.1 * speed_penalty
# Automatically vectorize over batch of agents
batched_reward = jax.vmap(compute_agent_reward)
# Now works on batches!
batch_size = 512
key = jax.random.PRNGKey(42)
positions = jax.random.uniform(key, (batch_size, 2)) * 100
velocities = jax.random.normal(key, (batch_size, 2)) * 5
goals = jax.random.uniform(key, (batch_size, 2)) * 100
# Single call processes all 512 agents in parallel
rewards = batched_reward(positions, velocities, goals)
print(f"Rewards shape: {rewards.shape}") # (512,)
Nested vmap for multi-dimensional batching:
# Paper uses scenarios x agents x timesteps
# vmap can be nested for multiple batch dimensions
def single_step(state, action):
"""Process single (scenario, agent, timestep)."""
return state + action * 0.1
# Vectorize over agents within a scenario
agents_step = jax.vmap(single_step, in_axes=(0, 0))
# Vectorize over scenarios
scenarios_agents_step = jax.vmap(agents_step, in_axes=(0, 0))
# Now handles (num_scenarios, num_agents, state_dim)
num_scenarios = 16
num_agents = 512
state_dim = 6
states = jnp.zeros((num_scenarios, num_agents, state_dim))
actions = jnp.ones((num_scenarios, num_agents, state_dim))
new_states = scenarios_agents_step(states, actions)
print(f"Output shape: {new_states.shape}") # (16, 512, 6)
2.3 jax.pmap - Parallel Mapping Across Devices
What it does: Distributes computation across multiple GPUs/TPUs with automatic data parallelism.
Why it matters: The paper scales from 8 to 32 GPUs with near-linear speedup. pmap handles the device coordination.
import jax
import jax.numpy as jnp
from jax import pmap
# Check available devices
print(f"Available devices: {jax.devices()}")
n_devices = jax.local_device_count()
# Function to run on each device
def train_step(params, batch):
"""Single gradient step - runs on one device."""
def loss_fn(p):
predictions = jnp.dot(batch['x'], p['w']) + p['b']
return jnp.mean((predictions - batch['y']) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Simple SGD update
new_params = jax.tree.map(
lambda p, g: p - 0.01 * g, params, grads
)
return new_params, loss
# pmap automatically:
# 1. Shards data across devices (first axis)
# 2. Runs computation in parallel
# 3. Handles cross-device communication
parallel_train_step = pmap(train_step)
# Initialize replicated params (same on each device)
params = {
'w': jnp.ones((10, 1)),
'b': jnp.zeros((1,))
}
replicated_params = jax.tree.map(
lambda x: jnp.stack([x] * n_devices), params
)
# Create sharded batch (different data per device)
batch_per_device = 32
total_batch = batch_per_device * n_devices
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (total_batch, 10))
y = jax.random.normal(key, (total_batch, 1))
# Reshape for pmap: (n_devices, batch_per_device, ...)
sharded_batch = {
'x': x.reshape(n_devices, batch_per_device, 10),
'y': y.reshape(n_devices, batch_per_device, 1)
}
# Run parallel training step
new_params, losses = parallel_train_step(replicated_params, sharded_batch)
print(f"Losses per device: {losses}")
Gradient synchronization with pmap:
# The paper synchronizes gradients across learners using NCCL
# In JAX, use lax.pmean for this
def synchronized_train_step(params, batch):
"""Train step with gradient averaging across devices."""
def loss_fn(p):
predictions = jnp.dot(batch['x'], p['w']) + p['b']
return jnp.mean((predictions - batch['y']) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Average gradients across all devices
# 'batch' is the axis name from pmap
grads = jax.lax.pmean(grads, axis_name='batch')
loss = jax.lax.pmean(loss, axis_name='batch')
new_params = jax.tree.map(
lambda p, g: p - 0.01 * g, params, grads
)
return new_params, loss
# axis_name connects pmean to the pmap dimension
parallel_sync_step = pmap(synchronized_train_step, axis_name='batch')
2.4 jax.lax.scan - Efficient Sequential Operations
What it does: Compiles a loop into a single XLA operation, avoiding Python loop overhead and enabling automatic differentiation through time.
Why it matters: The paper "scans along the time axis of dynamic data" to process 151 timesteps efficiently.
import jax
import jax.numpy as jnp
from jax import lax
# Environment step function
def env_step(carry, action):
"""
Single environment step.
carry: (state, cumulative_reward)
action: action to take
Returns: new_carry, output_for_this_step
"""
state, cum_reward = carry
# Physics update (simplified)
position, velocity = state[:2], state[2:4]
new_velocity = velocity + action * 0.1
new_position = position + new_velocity * 0.1
new_state = jnp.concatenate([new_position, new_velocity])
# Reward calculation
reward = -jnp.sum(new_position ** 2) # Move towards origin
new_carry = (new_state, cum_reward + reward)
output = {'state': new_state, 'reward': reward}
return new_carry, output
# Roll out 151 timesteps efficiently
@jax.jit
def rollout(initial_state, actions):
"""
Rollout environment for T timesteps.
actions: (T, action_dim)
"""
initial_carry = (initial_state, 0.0)
final_carry, trajectory = lax.scan(
env_step,
initial_carry,
actions
)
final_state, total_reward = final_carry
return trajectory, total_reward
# Usage
T = 151 # Timesteps (from paper)
action_dim = 2
state_dim = 4
initial_state = jnp.zeros(state_dim)
actions = jax.random.normal(jax.random.PRNGKey(0), (T, action_dim))
trajectory, total_reward = rollout(initial_state, actions)
print(f"Trajectory states shape: {trajectory['state'].shape}") # (151, 4)
print(f"Total reward: {total_reward:.2f}")
Combining scan with vmap for batched rollouts:
# The paper batches scenarios AND time using scan + vmap
@jax.jit
def batched_rollout(initial_states, actions_batch):
"""
Rollout multiple scenarios in parallel.
initial_states: (batch_size, state_dim)
actions_batch: (batch_size, T, action_dim)
"""
# vmap over batch dimension
return jax.vmap(rollout)(initial_states, actions_batch)
# Run 16 scenarios with 151 timesteps each
batch_size = 16
initial_states = jnp.zeros((batch_size, state_dim))
actions_batch = jax.random.normal(
jax.random.PRNGKey(1),
(batch_size, T, action_dim)
)
trajectories, total_rewards = batched_rollout(initial_states, actions_batch)
print(f"Batched states shape: {trajectories['state'].shape}") # (16, 151, 4)
print(f"Total rewards shape: {total_rewards.shape}") # (16,)
3. Pure Functional Simulation
Why Functional Programming Matters
The paper enforces strict functional programming principles:
"All functions need to be pure, meaning they cannot have any side effects"
This enables:
- JIT compilation - XLA can only compile pure functions
- Automatic differentiation - Gradients require deterministic computation
- Parallelization - No shared state means no race conditions
- Reproducibility - Same inputs always produce same outputs
The Anti-Pattern: Stateful Simulation
# BAD: Stateful simulation - CANNOT be JIT compiled
import numpy as np

class StatefulEnv:
def __init__(self):
self.state = None
self.step_count = 0
self.rng = np.random.RandomState(42) # Hidden state!
def reset(self):
self.state = np.zeros(4)
self.step_count = 0
return self.state
def step(self, action):
# Mutates internal state - side effect!
self.state = self.state + action * 0.1
self.step_count += 1
# Uses hidden RNG state - non-deterministic!
noise = self.rng.randn(*self.state.shape) * 0.01
self.state += noise
reward = -np.sum(self.state ** 2)
done = self.step_count >= 100
return self.state.copy(), reward, done
Problems with stateful design:
- Cannot JIT compile (Python objects store state)
- Non-deterministic (RNG state hidden)
- Cannot vmap (shared state across batch)
- Cannot differentiate through (state mutation)
The Pattern: Pure Functional Simulation
import jax
import jax.numpy as jnp
from typing import NamedTuple
# State is explicit, immutable, and passed through functions
class EnvState(NamedTuple):
position: jnp.ndarray
velocity: jnp.ndarray
step_count: int
rng_key: jnp.ndarray
# GOOD: Pure functional environment
@jax.jit
def env_reset(rng_key: jnp.ndarray) -> EnvState:
"""Pure reset - no side effects."""
return EnvState(
position=jnp.zeros(2),
velocity=jnp.zeros(2),
step_count=0,
rng_key=rng_key
)
@jax.jit
def env_step(
state: EnvState,
action: jnp.ndarray
) -> tuple[EnvState, float, bool]:
"""
Pure step function.
- Takes state explicitly
- Returns new state (no mutation)
- RNG key is part of state
"""
# Split RNG key for this step
rng_key, noise_key = jax.random.split(state.rng_key)
# Deterministic physics
new_velocity = state.velocity + action * 0.1
new_position = state.position + new_velocity * 0.1
# Explicit randomness with passed key
noise = jax.random.normal(noise_key, new_position.shape) * 0.01
new_position = new_position + noise
# Create NEW state (immutable)
new_state = EnvState(
position=new_position,
velocity=new_velocity,
step_count=state.step_count + 1,
rng_key=rng_key # Updated key for next step
)
reward = -jnp.sum(new_position ** 2)
done = new_state.step_count >= 100
return new_state, reward, done
# Usage - state flows explicitly
key = jax.random.PRNGKey(42)
state = env_reset(key)
for _ in range(10):
action = jnp.array([0.1, 0.2])
state, reward, done = env_step(state, action)
print(f"Step {state.step_count}: reward = {reward:.3f}")
Handling Conditional Logic Without Branching
The paper notes:
"Logical branching is eliminated by executing all code branches and selecting the required result after calculation"
This is necessary because XLA compiles static computation graphs - a Python if statement on a traced value cannot be resolved at trace time.
import jax
import jax.numpy as jnp
# BAD: Python branching - breaks JIT
def bad_reward(collision: bool, at_goal: bool) -> float:
if collision:
return -100.0
elif at_goal:
return 100.0
else:
return -1.0
# GOOD: JAX-compatible conditional using jnp.where
@jax.jit
def good_reward(collision: jnp.ndarray, at_goal: jnp.ndarray) -> jnp.ndarray:
"""All branches computed, result selected."""
# Compute ALL possible rewards
collision_reward = -100.0
goal_reward = 100.0
step_reward = -1.0
# Select based on conditions (no Python branching)
reward = jnp.where(
collision,
collision_reward,
jnp.where(at_goal, goal_reward, step_reward)
)
return reward
# GOOD: Using jax.lax.cond for expensive branches
@jax.jit
def expensive_reward(collision: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
"""Use lax.cond when branches have different compute costs."""
    def collision_branch(s):
        # Cheap branch: return an array so both branches have matching output types
        return jnp.array(-100.0)
def normal_branch(s):
# Expensive computation (only runs if not collision)
complex_reward = jnp.sum(jnp.exp(-s ** 2))
return complex_reward
# Only evaluates one branch at runtime
return jax.lax.cond(collision, collision_branch, normal_branch, state)
Data Padding for Uniform Shapes
The paper pads all scenarios to uniform dimensions:
- 512 agents maximum per scenario
- 151 timesteps (30 seconds at 5 Hz)
- 128 road elements maximum
import jax.numpy as jnp
from typing import NamedTuple
class PaddedScenario(NamedTuple):
"""Fixed-shape scenario for JIT compilation."""
# Agent data: (max_agents, timesteps, features)
agent_positions: jnp.ndarray # (512, 151, 2)
agent_velocities: jnp.ndarray # (512, 151, 2)
agent_valid_mask: jnp.ndarray # (512,) - which agents are real
# Road data: (max_elements, features)
road_points: jnp.ndarray # (128, 10, 2)
road_valid_mask: jnp.ndarray # (128,)
# Metadata
num_valid_agents: int
num_valid_roads: int
def pad_scenario(
positions: jnp.ndarray, # (actual_agents, actual_steps, 2)
max_agents: int = 512,
max_steps: int = 151
) -> jnp.ndarray:
"""Pad variable-size data to fixed shape."""
actual_agents, actual_steps, feat_dim = positions.shape
# Create padded array
padded = jnp.zeros((max_agents, max_steps, feat_dim))
# Copy actual data (JAX-compatible slicing)
padded = padded.at[:actual_agents, :actual_steps, :].set(positions)
return padded
def mask_padded_agents(
rewards: jnp.ndarray, # (512,)
valid_mask: jnp.ndarray # (512,) boolean
) -> jnp.ndarray:
"""Zero out rewards for padded (invalid) agents."""
return jnp.where(valid_mask, rewards, 0.0)
4. Distributed Training Architecture
System Overview
The paper implements an asynchronous actor-learner architecture similar to OpenAI's Dota 2 system:
+--------------------------------------------------+
| DISTRIBUTED TRAINING SYSTEM |
+--------------------------------------------------+
EXPERIENCE FLOW
|
+---------------+---------------+
| | |
+----v----+ +----v----+ +-----v----+
| Actor | | Actor | | Actor |
| Group 1 | | Group 2 | | Group N |
| (GPU 1) | | (GPU 2) | | (GPU N) |
+----+----+ +----+----+ +-----+----+
| | |
| trajectories | trajectories |
| | |
+----v---------------v---------------v----+
| EXPERIENCE BUFFERS |
| (Replay buffers per learner group) |
+----+---------------+---------------+----+
| | |
| | |
+----v----+ +----v----+ +-----v----+
| Learner | | Learner | | Learner |
| Group 1 | | Group 2 | | Group N |
| (GPUs) | | (GPUs) | | (GPUs) |
+----+----+ +----+----+ +-----+----+
| | |
+---------------+---------------+
|
NCCL AllReduce
(Gradient Sync)
|
+------v------+
| UPDATED |
| POLICY |
+------+------+
|
+------------+------------+
| | |
+----v----+ +----v----+ +----v----+
| Actor | | Actor | | Actor |
| (pulls | | (pulls | | (pulls |
| latest) | | latest) | | latest) |
+---------+ +---------+ +---------+
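The paper does not publish its orchestration code, so the sketch below is a single-machine stand-in for the flow in the diagram: actor threads run inference with whatever parameters are currently published, push trajectories into an experience buffer, and a learner consumes the buffer and publishes updated parameters. The `policy_step`, `learner_update` helpers and the queue sizes are illustrative assumptions, not the paper's implementation.
import queue
import threading
import jax
import jax.numpy as jnp

def policy_step(params, obs):
    """Hypothetical actor-side inference: a linear policy score."""
    return jnp.tanh(obs @ params)

def learner_update(params, batch):
    """Hypothetical learner update: one SGD step on a dummy regression loss."""
    loss_fn = lambda p: jnp.mean((batch['obs'] @ p - batch['target']) ** 2)
    return params - 0.01 * jax.grad(loss_fn)(params)

experience_buffer = queue.Queue(maxsize=64)  # stands in for the per-learner experience buffer
latest_params = jnp.zeros((8,))              # actors always read the newest published policy
stop = threading.Event()

def actor_loop(seed):
    key = jax.random.PRNGKey(seed)
    while not stop.is_set():
        key, k = jax.random.split(key)
        obs = jax.random.normal(k, (16, 8))     # fake rollout observations
        acts = policy_step(latest_params, obs)  # may use a slightly stale policy
        try:
            experience_buffer.put({'obs': obs, 'target': acts}, timeout=0.1)
        except queue.Full:
            pass  # illustrative only: drop experience if the learner falls behind

def learner_loop(num_updates=20):
    global latest_params
    for _ in range(num_updates):
        batch = experience_buffer.get()                        # consume actor experience
        latest_params = learner_update(latest_params, batch)   # publish updated policy
    stop.set()

actors = [threading.Thread(target=actor_loop, args=(i,)) for i in range(2)]
for t in actors:
    t.start()
learner_loop()
for t in actors:
    t.join()
In the real system, the queues are replaced by per-learner-group experience buffers, actors and learners run on separate GPUs, and learners synchronize gradients with NCCL AllReduce; the policy staleness visible here is exactly what the V-trace correction below accounts for.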
Asynchronous PPO with V-trace
The Problem: In asynchronous training, actors use slightly stale policies. When the learner updates, the collected experience is "off-policy."
The Solution: V-trace importance sampling correction.
import jax
import jax.numpy as jnp
from typing import NamedTuple
class VTraceOutput(NamedTuple):
vs: jnp.ndarray # Corrected value estimates
advantages: jnp.ndarray # Policy gradient advantages
rhos: jnp.ndarray # Truncated importance weights
def compute_vtrace(
behavior_log_probs: jnp.ndarray, # Log probs from actor's policy
target_log_probs: jnp.ndarray, # Log probs from learner's policy
rewards: jnp.ndarray, # (T,)
values: jnp.ndarray, # V(s) estimates (T+1,)
dones: jnp.ndarray, # Episode termination (T,)
gamma: float = 0.99,
rho_clip: float = 1.0, # Importance weight clip
c_clip: float = 1.0 # Trace coefficient clip
) -> VTraceOutput:
"""
V-trace off-policy correction.
Key insight: When actor policy != learner policy,
we need to correct for the distribution mismatch.
"""
T = rewards.shape[0]
# Importance sampling ratios
log_rhos = target_log_probs - behavior_log_probs
rhos = jnp.exp(log_rhos)
# Clip importance weights (controls variance)
clipped_rhos = jnp.minimum(rhos, rho_clip)
clipped_cs = jnp.minimum(rhos, c_clip)
# TD errors with importance correction
# delta_t = rho_t * (r_t + gamma * V(s_{t+1}) - V(s_t))
not_done = 1.0 - dones
deltas = clipped_rhos * (
rewards + gamma * not_done * values[1:] - values[:-1]
)
# Backward pass to compute V-trace targets
# vs_t = V(s_t) + sum_{k=t}^{T-1} gamma^{k-t} * (prod_{i=t}^{k-1} c_i) * delta_k
def scan_fn(carry, inputs):
"""Accumulate discounted corrections backward."""
acc = carry
delta, c, not_d = inputs
acc = delta + gamma * c * not_d * acc
return acc, acc
# Scan backward through time
_, corrections = jax.lax.scan(
scan_fn,
jnp.zeros(()),
(deltas[::-1], clipped_cs[::-1], not_done[::-1])
)
corrections = corrections[::-1] # Reverse back to forward order
# V-trace value estimates
vs = values[:-1] + corrections
# Advantages for policy gradient
# Use rho-clipped importance weights
vs_plus_1 = jnp.concatenate([vs[1:], values[-1:]])
advantages = clipped_rhos * (
rewards + gamma * not_done * vs_plus_1 - values[:-1]
)
return VTraceOutput(vs=vs, advantages=advantages, rhos=clipped_rhos)
# Example usage
T = 32 # Sequence length from paper
key = jax.random.PRNGKey(0)
# Simulated off-policy data
behavior_log_probs = jax.random.normal(key, (T,)) - 1.0 # Actor's policy
target_log_probs = jax.random.normal(key, (T,)) # Learner's policy
rewards = jax.random.uniform(key, (T,))
values = jax.random.uniform(key, (T + 1,))
dones = jnp.zeros(T)
vtrace = compute_vtrace(
behavior_log_probs, target_log_probs,
rewards, values, dones
)
print(f"V-trace targets shape: {vtrace.vs.shape}")
print(f"Advantages shape: {vtrace.advantages.shape}")
PPO Loss with V-trace
import jax
import jax.numpy as jnp
from typing import NamedTuple
class PPOLossOutput(NamedTuple):
total_loss: jnp.ndarray
policy_loss: jnp.ndarray
value_loss: jnp.ndarray
entropy_loss: jnp.ndarray
def ppo_loss(
log_probs: jnp.ndarray, # Current policy log probs
old_log_probs: jnp.ndarray, # Behavior policy log probs
values: jnp.ndarray, # Current value estimates
vtrace_targets: jnp.ndarray, # V-trace corrected targets
advantages: jnp.ndarray, # V-trace advantages
entropy: jnp.ndarray, # Policy entropy
clip_eps: float = 0.3, # From paper
value_coef: float = 0.5,
entropy_coef: float = 0.03 # From paper
) -> PPOLossOutput:
"""
PPO-clip loss with V-trace corrections.
"""
# Policy loss with clipping
ratio = jnp.exp(log_probs - old_log_probs)
clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
# Pessimistic policy update
policy_loss = -jnp.minimum(
ratio * advantages,
clipped_ratio * advantages
).mean()
# Value loss (MSE with V-trace targets)
value_loss = jnp.mean((values - vtrace_targets) ** 2)
# Entropy bonus (encourages exploration)
entropy_loss = -entropy.mean()
total_loss = (
policy_loss +
value_coef * value_loss +
entropy_coef * entropy_loss
)
return PPOLossOutput(
total_loss=total_loss,
policy_loss=policy_loss,
value_loss=value_loss,
entropy_loss=entropy_loss
)
Gradient Synchronization Architecture
+--------------------------------------------------------+
| GRADIENT SYNC (NCCL AllReduce) |
+--------------------------------------------------------+
Learner 0 Learner 1 Learner 2
+-------+ +-------+ +-------+
|Grads_0| |Grads_1| |Grads_2|
+---+---+ +---+---+ +---+---+
| | |
| Ring AllReduce (NCCL) |
+------------------+------------------+
|
Average Gradients
|
+------------------+------------------+
| | |
+---v---+ +---v---+ +---v---+
|Avg_G | |Avg_G | |Avg_G |
+---+---+ +---+---+ +---+---+
| | |
Apply to Apply to Apply to
Local Params Local Params Local Params
Implementation in JAX:
- pmap with axis_name='devices'
- lax.pmean(grads, 'devices') for averaging
- All devices get identical gradients
- Synchronized parameter updates
import jax
import jax.numpy as jnp
import optax
from jax import pmap, lax
def create_distributed_train_step(model_apply, optimizer):
"""
Create a training step that synchronizes across devices.
"""
def train_step(params, opt_state, batch, rng):
"""Single training step with gradient sync."""
def loss_fn(p):
# Forward pass
outputs = model_apply(p, batch['observations'])
# Compute PPO loss (simplified)
log_probs = outputs['log_probs']
values = outputs['values']
# V-trace would be computed here
advantages = batch['advantages']
vtrace_targets = batch['vtrace_targets']
policy_loss = -jnp.mean(log_probs * advantages)
value_loss = jnp.mean((values - vtrace_targets) ** 2)
return policy_loss + 0.5 * value_loss
loss, grads = jax.value_and_grad(loss_fn)(params)
# CRITICAL: Synchronize gradients across all devices
grads = lax.pmean(grads, axis_name='devices')
loss = lax.pmean(loss, axis_name='devices')
# Apply synchronized gradients
updates, new_opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state, loss
# Parallelize across devices
return pmap(train_step, axis_name='devices')
5. Performance Optimization
GPU Utilization Strategy
The paper achieves high GPU utilization through:
- Double-buffered data loading
- Fused simulation + inference
- Asynchronous experience transfer
+--------------------------------------------------------+
| GPU UTILIZATION TIMELINE |
+--------------------------------------------------------+
GPU 0 (Simulation):
|==Sim Batch 0==|==Sim Batch 1==|==Sim Batch 2==|
^ ^
| |
Pre-load B1 Pre-load B2
GPU Memory:
| Batch 0 (running) | Batch 1 (pre-loaded) |
| Batch 2 (loading) |
"To achieve high GPU utilization one scenario batch
is pre-loaded to GPU memory while a simulation of
a different batch is running"
import jax
import jax.numpy as jnp
from typing import Iterator
import threading
from queue import Queue
class DoubleBufferedLoader:
"""
Pre-load next batch while current batch is processing.
Hides data transfer latency behind compute.
"""
def __init__(self, data_iterator: Iterator, prefetch_size: int = 2):
self.data_iter = data_iterator
self.prefetch_queue = Queue(maxsize=prefetch_size)
self.stop_event = threading.Event()
# Start background prefetch thread
self.prefetch_thread = threading.Thread(target=self._prefetch_loop)
self.prefetch_thread.start()
def _prefetch_loop(self):
"""Background thread: load and transfer to GPU."""
for batch in self.data_iter:
if self.stop_event.is_set():
break
# Transfer to GPU in background
gpu_batch = jax.device_put(batch)
self.prefetch_queue.put(gpu_batch)
def get_batch(self):
"""Get pre-loaded GPU batch (blocks if not ready)."""
return self.prefetch_queue.get()
def close(self):
self.stop_event.set()
self.prefetch_thread.join()
Fused Simulation + Inference
The key optimization: compile simulation AND policy inference into a single XLA graph.
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
def create_fused_rollout(env_step_fn, policy_fn):
"""
Fuse environment step and policy inference.
This creates a SINGLE compiled XLA graph that:
1. Steps the environment
2. Computes policy on new state
3. Samples action
4. Repeats
No Python loop overhead, no CPU-GPU transfers between steps.
"""
def single_step(carry, _):
state, rng = carry
# Split RNG
rng, action_rng, step_rng = jax.random.split(rng, 3)
# Policy inference (ON GPU)
action_logits = policy_fn(state.observation)
action = jax.random.categorical(action_rng, action_logits)
log_prob = jax.nn.log_softmax(action_logits)[action]
# Environment step (ON GPU)
new_state, reward, done = env_step_fn(state, action)
output = {
'observation': state.observation,
'action': action,
'log_prob': log_prob,
'reward': reward,
'done': done
}
new_carry = (new_state, rng)
return new_carry, output
    @partial(jax.jit, static_argnums=(2,))  # num_steps sets the scan length, so it must be static
def fused_rollout(initial_state, rng, num_steps):
"""Execute full rollout as single compiled graph."""
initial_carry = (initial_state, rng)
# lax.scan compiles entire loop
final_carry, trajectory = lax.scan(
single_step,
initial_carry,
xs=None, # No input sequence
length=num_steps
)
return trajectory
return fused_rollout
Batch Size Optimization
The paper shows batch size dramatically affects throughput:
| Batch Size | Simulator Time (V100) | Speedup |
|---|---|---|
| 1 | 0.52 ms | 1.0x |
| 16 | 0.82 ms | 9.8x |
import jax
import jax.numpy as jnp
import time
def benchmark_batch_sizes(env_step_fn, batch_sizes, num_iters=100):
"""
Benchmark different batch sizes to find optimal throughput.
Key insight: Larger batches amortize kernel launch overhead
but eventually hit memory limits or diminishing returns.
"""
key = jax.random.PRNGKey(0)
results = {}
for batch_size in batch_sizes:
# Create batched environment step
batched_step = jax.vmap(env_step_fn)
jit_batched_step = jax.jit(batched_step)
# Create dummy data
states = jnp.zeros((batch_size, 64))
actions = jnp.zeros((batch_size, 2))
# Warmup (JIT compilation)
_ = jit_batched_step(states, actions)
# Benchmark
start = time.time()
for _ in range(num_iters):
states, _, _ = jit_batched_step(states, actions)
jax.block_until_ready(states) # Wait for async ops
elapsed = time.time() - start
steps_per_second = (batch_size * num_iters) / elapsed
time_per_step = elapsed / num_iters * 1000 # ms
results[batch_size] = {
'steps_per_second': steps_per_second,
'time_per_step_ms': time_per_step
}
print(f"Batch {batch_size:4d}: {steps_per_second:.0f} steps/s, "
f"{time_per_step:.2f} ms/step")
return results
Near-Linear Scaling Analysis
The paper achieves remarkable scaling efficiency:
| GPUs | Runtime (hours) | Normalized GPU-hours | Efficiency |
|---|---|---|---|
| 8 | 41.58 | 332.64 | 100% |
| 16 | 22.83 | 365.28 | 91% |
| 32 | 11.99 | 383.68 | 87% |
def analyze_scaling_efficiency(gpu_counts, runtimes):
"""
Analyze how efficiently training scales with more GPUs.
Perfect linear scaling: 2x GPUs = 0.5x runtime
Real-world: Communication overhead reduces efficiency
"""
base_gpus = gpu_counts[0]
base_runtime = runtimes[0]
base_gpu_hours = base_gpus * base_runtime
print("Scaling Analysis")
print("-" * 60)
for gpus, runtime in zip(gpu_counts, runtimes):
gpu_hours = gpus * runtime
# Ideal runtime with perfect scaling
ideal_runtime = base_runtime * (base_gpus / gpus)
# Efficiency = ideal_time / actual_time
efficiency = ideal_runtime / runtime * 100
# Overhead = extra GPU-hours beyond baseline
overhead = (gpu_hours - base_gpu_hours) / base_gpu_hours * 100
print(f"GPUs: {gpus:2d} | Runtime: {runtime:5.2f}h | "
f"Efficiency: {efficiency:5.1f}% | Overhead: {overhead:+5.1f}%")
# Data from paper
analyze_scaling_efficiency(
gpu_counts=[8, 16, 32],
runtimes=[41.58, 22.83, 11.99]
)
6. Pre-training Strategy
Why Pre-training Matters
The paper makes a critical observation:
"Pre-training only the policy via behavioral cloning makes it challenging to further train this policy with an actor-critic RL method"
The problem: untrained critic produces garbage value estimates, destabilizing policy gradients.
WITHOUT Value Pre-training:
+-----------------+ +-------------------+ +------------------+
| Pre-trained | | Random Critic | | Garbage |
| Policy (BC) | --> | V(s) = noise | --> | Advantages |
| pi(a|s) | | | | A = R - V |
+-----------------+ +-------------------+ +------------------+
|
v
Policy gradients are
essentially random!
WITH Value Pre-training (Paper's Approach):
+-----------------+ +-------------------+ +------------------+
| Pre-trained | | Pre-trained | | Meaningful |
| Policy (BC) | --> | Critic (returns) | --> | Advantages |
| pi(a|s) | | V(s) ~ G_0 | | A = R - V |
+-----------------+ +-------------------+ +------------------+
|
v
Policy gradients point
toward improvement!
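A quick synthetic check (not from the paper; every number below is made up) makes the diagrams concrete: with a random critic, the computed advantage A = R - V is dominated by critic noise and barely correlates with the true advantage, whereas a critic that approximates the returns preserves the signal the policy gradient needs.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_state, k_ret, k_rand, k_pre = jax.random.split(key, 4)
n = 10_000

# Synthetic setup: each state has a true value, and the sampled return
# is that value plus action/environment noise.
true_state_value = jax.random.normal(k_state, (n,))
return_noise = 0.5 * jax.random.normal(k_ret, (n,))
returns = true_state_value + return_noise
true_advantage = return_noise  # R - V*(s)

# Untrained critic: large, state-unrelated value estimates.
adv_random_critic = returns - 5.0 * jax.random.normal(k_rand, (n,))
# Pre-trained critic: V(s) close to the true state value.
adv_pretrained = returns - (true_state_value + 0.1 * jax.random.normal(k_pre, (n,)))

def corr(a, b):
    a = (a - a.mean()) / a.std()
    b = (b - b.mean()) / b.std()
    return jnp.mean(a * b)

# The policy gradient is driven by the advantage; with a random critic it is
# nearly uncorrelated with the true advantage, i.e. "essentially random".
print(f"corr with true advantage, random critic:      {corr(adv_random_critic, true_advantage):.2f}")
print(f"corr with true advantage, pre-trained critic: {corr(adv_pretrained, true_advantage):.2f}")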
Implementation
import jax
import jax.numpy as jnp
import optax
from typing import NamedTuple
class BCPretrainState(NamedTuple):
params: dict
opt_state: optax.OptState
step: int
def create_bc_pretrainer(model_apply, action_space_size):
"""
Behavioral cloning pre-trainer that trains BOTH policy and value.
Key insight from paper: Value pre-training uses discounted returns
from human demonstrations as targets.
"""
def compute_returns(rewards, dones, gamma=0.99):
"""Compute discounted returns G_0 = sum_{t=0}^T gamma^t * R_t."""
T = rewards.shape[0]
def scan_fn(carry, inputs):
future_return = carry
reward, done = inputs
# Reset return at episode boundary
current_return = reward + gamma * future_return * (1 - done)
return current_return, current_return
# Scan backward
_, returns = jax.lax.scan(
scan_fn,
jnp.zeros(()),
(rewards[::-1], dones[::-1])
)
return returns[::-1]
def bc_loss(params, batch):
"""
Combined behavioral cloning loss.
Loss = CrossEntropy(policy, expert_actions)
+ value_scale * MSE(values, returns)
"""
outputs = model_apply(params, batch['observations'])
# Policy loss: cross-entropy with expert actions
action_logits = outputs['action_logits']
expert_actions = batch['actions']
policy_loss = optax.softmax_cross_entropy_with_integer_labels(
action_logits, expert_actions
).mean()
# Value loss: MSE with computed returns
values = outputs['values']
returns = batch['returns'] # Pre-computed discounted returns
value_loss = jnp.mean((values - returns) ** 2)
# From paper: value_loss_scale = 1e-4
total_loss = policy_loss + 1e-4 * value_loss
return total_loss, {
'policy_loss': policy_loss,
'value_loss': value_loss
}
@jax.jit
def train_step(state: BCPretrainState, batch):
"""Single BC pre-training step."""
(loss, metrics), grads = jax.value_and_grad(
bc_loss, has_aux=True
)(state.params, batch)
updates, new_opt_state = optax.adam(2e-3).update(
grads, state.opt_state, state.params
)
new_params = optax.apply_updates(state.params, updates)
new_state = BCPretrainState(
params=new_params,
opt_state=new_opt_state,
step=state.step + 1
)
return new_state, metrics
return train_step, compute_returns
# Training loop structure
def pretrain_loop(
train_step_fn,
initial_state,
dataset,
num_epochs=20, # From paper
batch_size=32768 # From paper
):
"""
Pre-training loop for behavioral cloning.
Paper trains for 20 epochs with batch size 32768.
"""
state = initial_state
for epoch in range(num_epochs):
epoch_losses = []
for batch in dataset.shuffle().batch(batch_size):
state, metrics = train_step_fn(state, batch)
epoch_losses.append(metrics['policy_loss'])
avg_loss = jnp.mean(jnp.array(epoch_losses))
print(f"Epoch {epoch + 1}/{num_epochs}: Loss = {avg_loss:.4f}")
return state
Value Target Computation
def prepare_bc_dataset(trajectories, gamma=0.99):
"""
Prepare behavioral cloning dataset with value targets.
Each trajectory contains:
- observations: (T, obs_dim)
- actions: (T,) expert actions
- rewards: (T,) per-step rewards
- dones: (T,) episode boundaries
"""
processed = []
for traj in trajectories:
# Compute discounted returns for value targets
returns = compute_discounted_returns(traj['rewards'], traj['dones'], gamma)
processed.append({
'observations': traj['observations'],
'actions': traj['actions'],
'returns': returns # Value pre-training targets!
})
return processed
def compute_discounted_returns(rewards, dones, gamma):
"""G_0 = sum_{t=0}^T gamma^t * R(s_t)"""
T = len(rewards)
returns = jnp.zeros(T)
future_return = 0.0
for t in range(T - 1, -1, -1):
if dones[t]:
future_return = 0.0
future_return = rewards[t] + gamma * future_return
returns = returns.at[t].set(future_return)
return returns
Pre-training Impact (Ablation)
From the paper's ablation study:
| Method | Steps to 2% Failure Rate |
|---|---|
| Without pre-training | 2.5B steps |
| With BC + Value pre-training | 0.5B steps |
5x faster convergence with proper pre-training!
7. Interactive Code Examples
Example 1: Complete JAX Environment
import jax
import jax.numpy as jnp
from jax import lax
from typing import NamedTuple
from functools import partial
# ============================================================
# PURE FUNCTIONAL DRIVING ENVIRONMENT
# ============================================================
class VehicleState(NamedTuple):
"""Immutable vehicle state."""
x: jnp.ndarray # Position x
y: jnp.ndarray # Position y
vx: jnp.ndarray # Velocity x
vy: jnp.ndarray # Velocity y
heading: jnp.ndarray # Heading angle
step: jnp.ndarray # Current timestep
class EnvConfig(NamedTuple):
"""Environment configuration (static)."""
dt: float = 0.1 # Timestep
max_steps: int = 151 # From paper
max_speed: float = 15.0 # m/s
goal_x: float = 100.0
goal_y: float = 0.0
def reset(rng: jnp.ndarray, config: EnvConfig) -> VehicleState:
"""Initialize vehicle at origin."""
return VehicleState(
x=jnp.array(0.0),
y=jnp.array(0.0),
vx=jnp.array(0.0),
vy=jnp.array(0.0),
heading=jnp.array(0.0),
step=jnp.array(0)
)
def step(
state: VehicleState,
action: jnp.ndarray, # [acceleration, steering_rate]
config: EnvConfig
) -> tuple[VehicleState, jnp.ndarray, jnp.ndarray]:
"""
Pure functional environment step.
Returns: (new_state, reward, done)
"""
accel, steer_rate = action[0], action[1]
# Kinematic bicycle model (simplified)
new_heading = state.heading + steer_rate * config.dt
new_vx = state.vx + accel * jnp.cos(new_heading) * config.dt
new_vy = state.vy + accel * jnp.sin(new_heading) * config.dt
# Clip speed
speed = jnp.sqrt(new_vx**2 + new_vy**2)
scale = jnp.minimum(1.0, config.max_speed / (speed + 1e-6))
new_vx = new_vx * scale
new_vy = new_vy * scale
# Update position
new_x = state.x + new_vx * config.dt
new_y = state.y + new_vy * config.dt
new_state = VehicleState(
x=new_x, y=new_y,
vx=new_vx, vy=new_vy,
heading=new_heading,
step=state.step + 1
)
# Reward: progress toward goal + penalties
dist_to_goal = jnp.sqrt(
(new_x - config.goal_x)**2 + (new_y - config.goal_y)**2
)
progress_reward = -dist_to_goal * 0.01
# Penalties (from paper)
speed_penalty = -jnp.maximum(0, speed - config.max_speed)**2 * 0.1
accel_penalty = -accel**2 * 0.01
steer_penalty = -steer_rate**2 * 0.01
reward = progress_reward + speed_penalty + accel_penalty + steer_penalty
# Done conditions
at_goal = dist_to_goal < 1.0
timeout = state.step >= config.max_steps
done = jnp.logical_or(at_goal, timeout)
return new_state, reward, done
# Vectorize over batch of vehicles
batched_reset = jax.vmap(reset, in_axes=(0, None))
batched_step = jax.vmap(step, in_axes=(0, 0, None))
# JIT compile
jit_reset = jax.jit(batched_reset, static_argnums=(1,))
jit_step = jax.jit(batched_step, static_argnums=(2,))
# ============================================================
# USAGE EXAMPLE
# ============================================================
def main():
config = EnvConfig()
batch_size = 512 # From paper
# Initialize
rngs = jax.random.split(jax.random.PRNGKey(0), batch_size)
states = jit_reset(rngs, config)
print(f"Initial positions: x={states.x[:5]}")
# Random actions
actions = jax.random.uniform(
jax.random.PRNGKey(1),
(batch_size, 2),
minval=-1.0, maxval=1.0
)
# Step
new_states, rewards, dones = jit_step(states, actions, config)
print(f"Rewards: {rewards[:5]}")
print(f"New positions: x={new_states.x[:5]}")
if __name__ == "__main__":
main()
Example 2: Attention-based Policy Network
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence
class MultiHeadAttention(nn.Module):
"""Multi-head attention for agent interactions."""
num_heads: int
head_dim: int
@nn.compact
def __call__(self, queries, keys, values, mask=None):
# Project to heads
q = nn.Dense(self.num_heads * self.head_dim)(queries)
k = nn.Dense(self.num_heads * self.head_dim)(keys)
v = nn.Dense(self.num_heads * self.head_dim)(values)
# Reshape: (batch, seq, num_heads, head_dim)
batch, q_len, _ = q.shape
_, kv_len, _ = k.shape
q = q.reshape(batch, q_len, self.num_heads, self.head_dim)
k = k.reshape(batch, kv_len, self.num_heads, self.head_dim)
v = v.reshape(batch, kv_len, self.num_heads, self.head_dim)
# Attention scores
scale = jnp.sqrt(self.head_dim)
scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / scale
if mask is not None:
scores = jnp.where(mask, scores, -1e9)
attn_weights = jax.nn.softmax(scores, axis=-1)
# Weighted sum
output = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, v)
output = output.reshape(batch, q_len, -1)
return nn.Dense(self.num_heads * self.head_dim)(output)
class PerceiverEncoder(nn.Module):
    """
    Simplified Perceiver-style encoder inspired by the paper.
    All modalities are projected, concatenated into one token sequence,
    and processed with masked self-attention.
    """
hidden_dim: int = 256
num_heads: int = 8
num_layers: int = 4
@nn.compact
def __call__(
self,
ego_features, # (batch, ego_dim)
agent_features, # (batch, num_agents, agent_dim)
road_features, # (batch, num_roads, road_dim)
agent_mask, # (batch, num_agents)
road_mask # (batch, num_roads)
):
batch_size = ego_features.shape[0]
# Project all inputs to hidden dim
ego = nn.Dense(self.hidden_dim)(ego_features)[:, None, :]
agents = nn.Dense(self.hidden_dim)(agent_features)
roads = nn.Dense(self.hidden_dim)(road_features)
# Concatenate for self-attention
# Shape: (batch, 1 + num_agents + num_roads, hidden_dim)
tokens = jnp.concatenate([ego, agents, roads], axis=1)
# Combined mask
ego_mask = jnp.ones((batch_size, 1), dtype=bool)
full_mask = jnp.concatenate([ego_mask, agent_mask, road_mask], axis=1)
attn_mask = full_mask[:, None, None, :] # Broadcast for attention
# Transformer layers
for _ in range(self.num_layers):
# Self-attention
attended = MultiHeadAttention(
num_heads=self.num_heads,
head_dim=self.hidden_dim // self.num_heads
)(tokens, tokens, tokens, attn_mask)
tokens = nn.LayerNorm()(tokens + attended)
# FFN
ffn_out = nn.Dense(self.hidden_dim * 4)(tokens)
ffn_out = nn.gelu(ffn_out)
ffn_out = nn.Dense(self.hidden_dim)(ffn_out)
tokens = nn.LayerNorm()(tokens + ffn_out)
# Return ego token (first position)
return tokens[:, 0, :]
class DrivingPolicy(nn.Module):
"""
Full driving policy with shared encoder.
Architecture from paper:
- Shared Perceiver encoder
- Separate policy and value heads
"""
hidden_dim: int = 256
num_accel_actions: int = 71 # From paper
num_steer_actions: int = 21
@nn.compact
def __call__(self, observations):
# Shared encoder
encoded = PerceiverEncoder(hidden_dim=self.hidden_dim)(
ego_features=observations['ego'],
agent_features=observations['agents'],
road_features=observations['roads'],
agent_mask=observations['agent_mask'],
road_mask=observations['road_mask']
)
# Policy head (Dense Residual Blocks)
policy_hidden = nn.Dense(self.hidden_dim)(encoded)
policy_hidden = nn.relu(policy_hidden)
policy_hidden = policy_hidden + nn.Dense(self.hidden_dim)(
nn.relu(nn.Dense(self.hidden_dim)(policy_hidden))
)
# Action logits
accel_logits = nn.Dense(self.num_accel_actions)(policy_hidden)
steer_logits = nn.Dense(self.num_steer_actions)(policy_hidden)
# Value head
value_hidden = nn.Dense(self.hidden_dim)(encoded)
value_hidden = nn.relu(value_hidden)
value = nn.Dense(1)(value_hidden).squeeze(-1)
return {
'accel_logits': accel_logits,
'steer_logits': steer_logits,
'value': value
}
# ============================================================
# USAGE
# ============================================================
def create_dummy_observations(batch_size, num_agents=64, num_roads=128):
"""Create dummy observations matching paper's structure."""
return {
'ego': jnp.zeros((batch_size, 10)),
'agents': jnp.zeros((batch_size, num_agents, 20)),
'roads': jnp.zeros((batch_size, num_roads, 15)),
'agent_mask': jnp.ones((batch_size, num_agents), dtype=bool),
'road_mask': jnp.ones((batch_size, num_roads), dtype=bool)
}
# Initialize and run
policy = DrivingPolicy()
obs = create_dummy_observations(batch_size=32)
params = policy.init(jax.random.PRNGKey(0), obs)
outputs = policy.apply(params, obs)
print(f"Accel logits shape: {outputs['accel_logits'].shape}")
print(f"Steer logits shape: {outputs['steer_logits'].shape}")
print(f"Value shape: {outputs['value'].shape}")
Example 3: Complete Training Loop
import jax
import jax.numpy as jnp
from jax import lax
import optax
from typing import NamedTuple
from functools import partial
# ============================================================
# TRAINING STATE
# ============================================================
class TrainState(NamedTuple):
params: dict
opt_state: optax.OptState
rng: jnp.ndarray
step: int
# ============================================================
# PPO TRAINING STEP
# ============================================================
def create_train_step(
policy_apply,
env_step,
env_config,
lr=5.6e-5, # From paper
gamma=0.99,
gae_lambda=0.95,
clip_eps=0.3, # From paper
entropy_coef=0.03, # From paper
value_coef=0.5
):
"""Create JIT-compiled training step."""
optimizer = optax.adam(lr)
def compute_gae(rewards, values, dones):
"""Generalized Advantage Estimation."""
def scan_fn(last_adv, inputs):
reward, value, next_value, done = inputs
delta = reward + gamma * next_value * (1 - done) - value
advantage = delta + gamma * gae_lambda * (1 - done) * last_adv
return advantage, advantage
next_values = jnp.concatenate([values[1:], values[-1:]])
_, advantages = lax.scan(
scan_fn,
jnp.array(0.0),
(rewards[::-1], values[::-1], next_values[::-1], dones[::-1])
)
return advantages[::-1]
def ppo_loss(params, batch, old_log_probs, advantages, returns):
"""PPO clipped objective."""
outputs = policy_apply(params, batch['observations'])
# Current log probs
accel_log_probs = jax.nn.log_softmax(outputs['accel_logits'])
steer_log_probs = jax.nn.log_softmax(outputs['steer_logits'])
accel_lp = accel_log_probs[
jnp.arange(len(batch['accel_actions'])),
batch['accel_actions']
]
steer_lp = steer_log_probs[
jnp.arange(len(batch['steer_actions'])),
batch['steer_actions']
]
log_probs = accel_lp + steer_lp
# Policy loss
ratio = jnp.exp(log_probs - old_log_probs)
clipped_ratio = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps)
policy_loss = -jnp.minimum(
ratio * advantages,
clipped_ratio * advantages
).mean()
# Value loss
value_loss = jnp.mean((outputs['value'] - returns) ** 2)
# Entropy bonus
accel_entropy = -jnp.sum(
jnp.exp(accel_log_probs) * accel_log_probs, axis=-1
).mean()
steer_entropy = -jnp.sum(
jnp.exp(steer_log_probs) * steer_log_probs, axis=-1
).mean()
entropy = accel_entropy + steer_entropy
total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
return total_loss, {
'policy_loss': policy_loss,
'value_loss': value_loss,
'entropy': entropy
}
@jax.jit
def train_step(state: TrainState, batch):
"""Single training iteration."""
# Compute old log probs (before update)
old_outputs = policy_apply(state.params, batch['observations'])
old_accel_lp = jax.nn.log_softmax(old_outputs['accel_logits'])
old_steer_lp = jax.nn.log_softmax(old_outputs['steer_logits'])
old_log_probs = (
old_accel_lp[jnp.arange(len(batch['accel_actions'])), batch['accel_actions']] +
old_steer_lp[jnp.arange(len(batch['steer_actions'])), batch['steer_actions']]
)
# Compute advantages and returns
advantages = compute_gae(
batch['rewards'],
old_outputs['value'],
batch['dones']
)
returns = advantages + old_outputs['value']
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Gradient step
(loss, metrics), grads = jax.value_and_grad(
ppo_loss, has_aux=True
)(state.params, batch, old_log_probs, advantages, returns)
updates, new_opt_state = optimizer.update(
grads, state.opt_state, state.params
)
new_params = optax.apply_updates(state.params, updates)
new_state = TrainState(
params=new_params,
opt_state=new_opt_state,
rng=state.rng,
step=state.step + 1
)
return new_state, metrics
return train_step
# ============================================================
# ROLLOUT COLLECTION
# ============================================================
def create_rollout_fn(policy_apply, env_step, env_reset, config, rollout_length=32):
"""Create function to collect experience."""
@jax.jit
def collect_rollout(params, env_state, rng):
"""Collect rollout_length steps of experience."""
def step_fn(carry, _):
state, rng = carry
            rng, accel_rng, steer_rng = jax.random.split(rng, 3)
            # Get action from policy
            obs = state_to_obs(state)  # Convert state to observation dict (placeholder helper)
            outputs = policy_apply(params, obs)
            accel_action = jax.random.categorical(
                accel_rng, outputs['accel_logits']
            )
            steer_action = jax.random.categorical(
                steer_rng, outputs['steer_logits']
            )
# Execute in environment
action = actions_to_controls(accel_action, steer_action)
new_state, reward, done = env_step(state, action, config)
# Store transition
transition = {
'observations': obs,
'accel_actions': accel_action,
'steer_actions': steer_action,
'rewards': reward,
'dones': done,
'values': outputs['value']
}
# Reset if done (for continuous collection)
rng, reset_rng = jax.random.split(rng)
new_state = jax.lax.cond(
done,
lambda: env_reset(reset_rng, config),
lambda: new_state
)
return (new_state, rng), transition
(final_state, _), trajectory = lax.scan(
step_fn,
(env_state, rng),
xs=None,
length=rollout_length
)
return final_state, trajectory
return collect_rollout
8. Benchmarks Analysis
Simulator Performance
| Simulator | Device | Batch 1 | Batch 16 | Batch-16 Throughput |
|---|---|---|---|---|
| Waymax | V100 | 0.75 ms | 2.48 ms | ~6.5k steps/s |
| Paper | V100 | 0.52 ms | 0.82 ms | ~19.5k steps/s |
Key insights:
- JIT compilation reduces per-step overhead (0.75 ms -> 0.52 ms at batch 1)
- Batching is far more efficient: a batch of 16 costs only ~1.6x the single-scenario step time, versus ~3.3x for Waymax
- Effective throughput at batch 16: 19,512 steps/second vs 6,452 steps/second
def analyze_throughput(single_time_ms, batch_time_ms, batch_size):
"""Analyze batching efficiency."""
# Ideal: batch_time = single_time (perfect parallelism)
ideal_time = single_time_ms
# Actual overhead
overhead_ms = batch_time_ms - single_time_ms
overhead_per_item = overhead_ms / batch_size
# Throughput
throughput = batch_size / (batch_time_ms / 1000) # steps/sec
# Efficiency
efficiency = (single_time_ms * batch_size) / batch_time_ms
print(f"Batch size: {batch_size}")
print(f" Single step: {single_time_ms:.2f} ms")
print(f" Batched step: {batch_time_ms:.2f} ms")
print(f" Overhead per item: {overhead_per_item:.3f} ms")
print(f" Throughput: {throughput:.0f} steps/sec")
print(f" Efficiency: {efficiency:.1%}")
print("=== Paper's Simulator ===")
analyze_throughput(0.52, 0.82, 16)
print("\n=== Waymax Baseline ===")
analyze_throughput(0.75, 2.48, 16)
Training Scaling Efficiency
+----------------------------------------------------------+
| SCALING EFFICIENCY CHART |
+----------------------------------------------------------+
Runtime (hours)
|
45 + *
|
40 + \.
| \. Ideal (linear)
35 + \. ----------------
| \.
30 + \.
| \.
25 + \.-- Actual -+
| \. |
20 + * | Gap = overhead
| \. |
15 + \. |
| \. |
10 + \.*
|
5 +
|
+----+----+----+----+----+----+
8 12 16 20 24 32
GPUs
Overhead Analysis:
| GPUs | Actual (h) | Ideal (h) | Overhead |
|---|---|---|---|
| 8 | 41.58 | 41.58 | 0% |
| 16 | 22.83 | 20.79 | 9.8% |
| 32 | 11.99 | 10.40 | 15.3% |
The overhead comes from:
- Gradient synchronization via NCCL AllReduce
- Data loading across more workers
- Experience buffer coordination
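As a rough back-of-the-envelope model (not from the paper), a ring AllReduce moves about 2(N-1)/N times the gradient payload through each GPU per step, so you can sanity-check how much of the scaling gap raw gradient traffic explains. The parameter count, interconnect bandwidth, and per-step compute time below are hypothetical placeholders.
def estimate_allreduce_fraction(
    num_gpus: int,
    num_params: float = 25e6,       # hypothetical parameter count
    bytes_per_param: int = 4,       # fp32 gradients
    bandwidth_gb_s: float = 100.0,  # hypothetical per-GPU interconnect bandwidth (GB/s)
    compute_time_ms: float = 50.0,  # hypothetical per-step compute time per GPU
) -> float:
    """Fraction of step time spent in ring AllReduce under a bandwidth-only model."""
    payload_bytes = num_params * bytes_per_param
    # Ring AllReduce: each GPU sends/receives ~2*(N-1)/N of the payload per step.
    comm_bytes = 2 * (num_gpus - 1) / num_gpus * payload_bytes
    comm_time_ms = comm_bytes / (bandwidth_gb_s * 1e9) * 1e3
    return comm_time_ms / (comm_time_ms + compute_time_ms)

for n in (8, 16, 32):
    print(f"{n:2d} GPUs: ~{estimate_allreduce_fraction(n):.1%} of step time in gradient sync")
Under these made-up numbers, gradient bandwidth alone accounts for only a few percent and grows slowly with GPU count, which is consistent with attributing much of the remaining gap to data loading and buffer coordination as listed above.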
Model Size vs Data Size
+----------------------------------------------------------+
| FAILURE RATE BY MODEL & DATA SIZE |
+----------------------------------------------------------+
Failure Rate (%)
|
3.0+ .
| : .
2.5+ : . Model sizes:
| : . --- 0.75M params
2.0+ : . -.- 2.5M params
| : . ... 25M params
1.5+ : .
| : .
1.0+ : . . . . . .
| : : . . . . .
0.5+ : . . . . . . . . *
|
+------+------+------+------+------+------+
600 1200 2000 3500 5000 6000
Data (hours)
Key Insight: Larger models only help with MORE data
Critical finding from paper:
"Increasing the model size only helps when sufficient real-world driving data is available"
This follows scaling laws observed in language models - there's an optimal compute allocation between model size and data size.
9. Common Pitfalls
Pitfall 1: Python Control Flow in JIT
# BAD: Python if statement
@jax.jit
def bad_reward(x):
if x > 0: # This is traced ONCE at compile time!
return x * 2
else:
return x * -1
# What happens: calling bad_reward raises a TracerBoolConversionError -
# under jit, x is an abstract tracer, so Python cannot evaluate `x > 0`.
# (If the condition depended on a static Python value instead, the branch
# taken at trace time would silently be baked into the compiled function.)
# GOOD: Use jnp.where or lax.cond
@jax.jit
def good_reward(x):
return jnp.where(x > 0, x * 2, x * -1)
# Or for expensive branches:
@jax.jit
def good_reward_cond(x):
return jax.lax.cond(
x > 0,
lambda x: x * 2,
lambda x: x * -1,
x
)
Pitfall 2: In-Place Mutation
# BAD: NumPy-style mutation
@jax.jit
def bad_update(arr, idx, value):
arr[idx] = value # ERROR: JAX arrays are immutable
return arr
# GOOD: Functional update with .at[]
@jax.jit
def good_update(arr, idx, value):
return arr.at[idx].set(value)
# GOOD: For multiple updates
@jax.jit
def good_batch_update(arr, indices, values):
return arr.at[indices].set(values)
Pitfall 3: Random Number Generation
# BAD: Using same key multiple times
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10,))
y = jax.random.normal(key, (10,)) # SAME VALUES as x!
# GOOD: Split keys
key = jax.random.PRNGKey(0)
key, x_key, y_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (10,))
y = jax.random.normal(y_key, (10,)) # Different values
# GOOD: Inside JIT functions, pass key explicitly
@jax.jit
def sample_action(key, logits):
return jax.random.categorical(key, logits)
Pitfall 4: Dynamic Shapes
# BAD: Variable-length sequences
@jax.jit
def bad_sum_variable(x):
# Shape of x changes -> recompilation every time!
return jnp.sum(x)
# GOOD: Pad to fixed size with mask
@jax.jit
def good_sum_padded(x, mask):
# x always same shape, mask indicates valid elements
return jnp.sum(x * mask)
# The paper pads all data:
# - 512 agents (fixed)
# - 151 timesteps (fixed)
# - 128 road elements (fixed)
Pitfall 5: Forgetting block_until_ready
import time
# BAD: Timing without synchronization
@jax.jit
def compute(x):
return jnp.dot(x, x.T)
x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
start = time.time()
result = compute(x) # Returns immediately (async!)
print(f"Time: {time.time() - start:.4f}s") # Wrong! Too fast
# GOOD: Block until computation completes
start = time.time()
result = compute(x)
result.block_until_ready() # Wait for GPU
print(f"Time: {time.time() - start:.4f}s") # Correct timing
Pitfall 6: Inefficient Tree Operations
# BAD: Accessing pytree elements in loop
def bad_update_params(params, grads, lr):
for key in params:
params[key] = params[key] - lr * grads[key]
return params
# GOOD: Use jax.tree.map (vectorized)
def good_update_params(params, grads, lr):
return jax.tree.map(
lambda p, g: p - lr * g,
params, grads
)
# BETTER: Use optax for optimizer state
import optax
optimizer = optax.adam(1e-3)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
Pitfall 7: Gradient Through Non-Differentiable Operations
# BAD: argmax is not differentiable
@jax.jit
def bad_discrete_choice(logits):
action = jnp.argmax(logits) # No gradient!
return action
# GOOD: Use Gumbel-Softmax for differentiable sampling
@jax.jit
def good_discrete_choice(key, logits, temperature=1.0):
gumbel_noise = jax.random.gumbel(key, logits.shape)
soft_sample = jax.nn.softmax((logits + gumbel_noise) / temperature)
return soft_sample # Differentiable!
# GOOD: For RL, use log_prob gradients (REINFORCE/PPO)
def policy_gradient_loss(log_probs, advantages):
return -jnp.mean(log_probs * jax.lax.stop_gradient(advantages))
Pitfall 8: Memory Leaks with JIT Cache
# BAD: Creating new functions in loop (cache grows forever)
for i in range(1000):
@jax.jit
def compute(x):
return x + i # Different closure each time!
result = compute(data)
# GOOD: Define the function once and pass i as a regular (traced) argument
# (marking i static would still recompile once per distinct value)
@jax.jit
def compute_once(x, i):
    return x + i
for i in range(1000):
    result = compute_once(data, jnp.asarray(i))  # compiled once, reused for every i
# BETTER: Vectorize over the parameter
@jax.jit
def compute_all(x, i_values):
return jax.vmap(lambda i: x + i)(i_values)
10. Hands-On Exercises
Exercise 1: Implement a Pure Functional Environment
Goal: Create a JAX-compatible gridworld environment.
"""
TODO: Implement a pure functional gridworld environment.
Requirements:
1. State includes: position (x, y), goal position, step count
2. Actions: 0=up, 1=down, 2=left, 3=right
3. Reward: -1 per step, +10 for reaching goal
4. Episode ends at goal or after 100 steps
Constraints:
- All functions must be JIT-compilable
- No Python control flow (if/for) inside JIT functions
- State must be immutable (NamedTuple or dataclass)
"""
import jax
import jax.numpy as jnp
from typing import NamedTuple
class GridState(NamedTuple):
x: jnp.ndarray
y: jnp.ndarray
goal_x: jnp.ndarray
goal_y: jnp.ndarray
step: jnp.ndarray
def reset(rng: jnp.ndarray) -> GridState:
"""Initialize environment. Goal at (7, 7), start at (0, 0)."""
# TODO: Implement
pass
def step(state: GridState, action: jnp.ndarray) -> tuple:
"""
Execute action and return (new_state, reward, done).
Hint: Use jnp.where for conditional logic
"""
# TODO: Implement
# Action mapping: 0=up (+y), 1=down (-y), 2=left (-x), 3=right (+x)
pass
# Test your implementation
def test_env():
reset_fn = jax.jit(reset)
step_fn = jax.jit(step)
state = reset_fn(jax.random.PRNGKey(0))
print(f"Initial state: ({state.x}, {state.y})")
# Move right 7 times, then up 7 times
for action in [3]*7 + [0]*7:
state, reward, done = step_fn(state, jnp.array(action))
print(f"Action {action}: pos=({state.x}, {state.y}), r={reward}, done={done}")
Exercise 2: Batch Rollouts with scan
Goal: Collect parallel rollouts using vmap and scan.
"""
TODO: Implement batched rollout collection.
Requirements:
1. Use lax.scan for time steps
2. Use vmap for parallel environments
3. Return complete trajectories for PPO training
"""
def collect_batch_rollouts(
policy_fn,
env_reset,
env_step,
params,
rng,
num_envs: int = 64,
rollout_length: int = 32
):
"""
Collect rollouts from num_envs parallel environments.
Returns:
Dictionary with keys:
- observations: (num_envs, rollout_length, obs_dim)
- actions: (num_envs, rollout_length)
- rewards: (num_envs, rollout_length)
- dones: (num_envs, rollout_length)
- values: (num_envs, rollout_length)
- log_probs: (num_envs, rollout_length)
"""
# TODO: Implement
# Hint:
# 1. Initialize num_envs environments with vmap(reset)
# 2. Define single_step function for scan
# 3. Use vmap over environments, scan over time
pass
Exercise 3: V-trace Implementation
Goal: Implement V-trace off-policy correction.
"""
TODO: Implement V-trace algorithm.
Reference: IMPALA paper (Espeholt et al., 2018)
"""
def vtrace_targets(
behavior_log_probs: jnp.ndarray, # (T,)
target_log_probs: jnp.ndarray, # (T,)
rewards: jnp.ndarray, # (T,)
values: jnp.ndarray, # (T+1,) including bootstrap
dones: jnp.ndarray, # (T,)
gamma: float = 0.99,
rho_bar: float = 1.0,
c_bar: float = 1.0
) -> tuple:
"""
Compute V-trace targets and advantages.
Returns:
vs: V-trace value targets (T,)
advantages: Policy gradient advantages (T,)
Algorithm:
1. Compute importance weights rho = exp(target_lp - behavior_lp)
2. Clip rho to rho_bar, c to c_bar
3. Compute TD errors: delta = rho * (r + gamma * V(s') - V(s))
4. Accumulate backward: vs = V(s) + sum of discounted deltas
5. Advantages = rho * (r + gamma * vs' - V(s))
"""
# TODO: Implement
pass
# Test with known values
def test_vtrace():
T = 10
# On-policy case (behavior = target): targets should reduce to discounted n-step returns (GAE with lambda = 1)
log_probs = jnp.zeros(T)
rewards = jnp.ones(T)
values = jnp.zeros(T + 1)
dones = jnp.zeros(T)
vs, advantages = vtrace_targets(
log_probs, log_probs, rewards, values, dones
)
print(f"V-trace targets: {vs}")
print(f"Advantages: {advantages}")
Exercise 4: Distributed Gradient Averaging
Goal: Implement synchronized training across devices.
"""
TODO: Implement multi-device training with gradient synchronization.
Use pmap and lax.pmean to average gradients across devices.
"""
import jax
from jax import pmap, lax
def create_distributed_trainer(model_apply, optimizer):
"""
Create a training function that works across multiple devices.
Requirements:
1. Each device processes its own batch
2. Gradients are averaged across devices
3. All devices end up with identical parameters
"""
def train_step(params, opt_state, batch):
"""
Single training step with gradient synchronization.
Args:
params: Replicated parameters (n_devices, ...)
opt_state: Replicated optimizer state
batch: Sharded batch (n_devices, batch_per_device, ...)
"""
# TODO: Implement
# 1. Compute loss and gradients
# 2. Average gradients with lax.pmean
# 3. Apply updates
pass
# TODO: Wrap with pmap, specify axis_name for pmean
return None
def test_distributed():
n_devices = jax.local_device_count()
print(f"Testing with {n_devices} devices")
# Create simple linear model
def model_apply(params, x):
return jnp.dot(x, params['w']) + params['b']
# TODO: Initialize replicated params and test training step
pass
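One common shape for this pattern, sketched with plain SGD instead of the exercise's optimizer/opt_state plumbing so it stays self-contained (loss_fn(params, batch) -> scalar is assumed; the optimizer state follows the same replicate-and-update path):
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

def create_distributed_trainer(loss_fn, learning_rate=1e-3):
    # Data parallelism: every device computes gradients on its own shard,
    # then pmean makes the update identical everywhere, keeping params in sync.
    @partial(jax.pmap, axis_name='devices')
    def train_step(params, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        grads = lax.pmean(grads, axis_name='devices')
        loss = lax.pmean(loss, axis_name='devices')   # averaged loss, handy for logging
        params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss

    return train_step
Before calling train_step, replicate the parameters (for example with jax.device_put_replicated) and shard the batch so it has a leading axis of size jax.local_device_count().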
Exercise 5: Full Mini-Pipeline
Goal: Build a complete training pipeline combining all concepts.
"""
TODO: Build a complete mini RL training pipeline.
Components:
1. Pure functional CartPole-like environment
2. Simple MLP policy with value head
3. PPO loss with GAE
4. Batched rollout collection
5. Training loop with logging
Target: Train for 100 iterations, achieve >100 average reward.
"""
def build_training_pipeline():
"""
Build and return all components of the training pipeline.
Returns:
dict with keys:
- 'env_reset': JIT-compiled reset function
- 'env_step': JIT-compiled step function
- 'policy': Policy network
- 'collect_rollouts': Rollout collection function
- 'train_step': Single PPO update step
- 'train': Full training loop function
"""
# TODO: Implement all components
pass
def main():
pipeline = build_training_pipeline()
# Initialize
rng = jax.random.PRNGKey(0)
# TODO: Initialize params, opt_state
# Training loop
for iteration in range(100):
# TODO: Collect rollouts
# TODO: Compute advantages
# TODO: PPO update
# TODO: Log metrics
pass
if __name__ == "__main__":
main()
11. Interview Questions
Beginner Level
Q1: What is the difference between jax.jit and jax.vmap?
- jax.jit: Just-In-Time compilation. Traces a function and compiles it to optimized XLA code, eliminating Python interpreter overhead; the function then runs on GPU/TPU as fused XLA kernels.
- jax.vmap: Vectorizing map. Takes a function that operates on single examples and automatically transforms it to operate on batches, adding a batch dimension without changing the function's implementation.
Key insight: They compose! jax.jit(jax.vmap(fn)) gives you a compiled, batched function.
@jax.jit
def fast_fn(x):
return x ** 2
batched_fast_fn = jax.jit(jax.vmap(fast_fn))
Q2: Why can't you use Python if statements inside JIT-compiled functions?
JIT compilation happens at trace time, not runtime. When JAX traces a function:
- It executes the Python code once with abstract "tracer" values
- Records all operations into a computation graph
- Compiles this graph
Python if statements are evaluated during tracing. If the condition depends only on static Python values, the chosen branch is "baked in" to the compiled function; if it depends on traced array data, JAX raises an error, because a tracer has no concrete boolean value to branch on.
Solution: Use jnp.where() for element-wise conditionals or jax.lax.cond() for scalar conditionals.
# Bad
@jax.jit
def bad(x):
    if x > 0:  # Error at trace time: a tracer has no concrete truth value
        return x
    return -x
# Good
@jax.jit
def good(x):
return jnp.where(x > 0, x, -x)
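For scalar conditions on traced values, jax.lax.cond keeps both branches in the compiled graph and selects one on-device at runtime (a small sketch):
# Also good: runtime branching on a traced scalar
@jax.jit
def also_good(x):
    return jax.lax.cond(jnp.sum(x) > 0, lambda v: v, lambda v: -v, x)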
Q3: What is the purpose of jax.random.split() and why is it necessary?
JAX uses explicit PRNG (Pseudo-Random Number Generator) keys for reproducibility and parallelization:
- Explicit state: Unlike NumPy's hidden global RNG state, JAX requires you to pass RNG keys explicitly
- Functional purity: Using the same key twice gives the same "random" numbers
- Parallelization: Split keys can be used independently on different devices
split() creates new independent keys from a parent key:
key = jax.random.PRNGKey(0)
key, subkey1, subkey2 = jax.random.split(key, 3)
# subkey1 and subkey2 produce different random sequences
# key is now a new key for future splits
This enables reproducible yet statistically independent random sampling across parallel computations.
Intermediate Level
Q4: Explain the actor-learner architecture used in the paper. What is the off-policy problem and how does V-trace solve it?
Architecture:
- Actors: Run the policy, collect experience, send to learners
- Learners: Receive experience, compute gradients, update policy
- Asynchronous: Actors don't wait for learner updates
Off-policy problem: When an actor collects experience with policy v1, by the time the learner processes it, the policy might be v5. The collected experience is "off-policy" - it doesn't represent the current policy's behavior.
V-trace solution: V-trace uses importance sampling to correct for the policy mismatch:
- Compute importance weight: rho = pi_current(a|s) / pi_behavior(a|s)
- Clip weights to reduce variance: rho_clipped = min(rho, rho_bar)
- Use clipped weights in TD error: delta = rho_clipped * (r + gamma*V(s') - V(s))
- Accumulate corrected values backward through time
In the limit rho_bar -> infinity this recovers unbiased importance-sampled value estimates; with clipping, it trades a small, controlled bias for much lower variance, so learning stays stable even with stale behavior policies.
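A minimal JAX sketch of these steps (the signature mirrors Exercise 3 above; values carries the bootstrap value at index T, and dones zero out the discount at episode boundaries):
def vtrace_targets(behavior_log_probs, target_log_probs, rewards, values, dones,
                   gamma=0.99, rho_bar=1.0, c_bar=1.0):
    rho = jnp.exp(target_log_probs - behavior_log_probs)   # importance weights
    rho_c = jnp.minimum(rho, rho_bar)
    c = jnp.minimum(rho, c_bar)
    discounts = gamma * (1.0 - dones)
    deltas = rho_c * (rewards + discounts * values[1:] - values[:-1])

    def backward(acc, inputs):
        delta_t, discount_t, c_t = inputs
        acc = delta_t + discount_t * c_t * acc              # vs_t - V(s_t)
        return acc, acc

    _, vs_minus_v = jax.lax.scan(backward, jnp.zeros_like(deltas[0]),
                                 (deltas, discounts, c), reverse=True)
    vs = values[:-1] + vs_minus_v
    vs_next = jnp.concatenate([vs[1:], values[-1:]])        # bootstrap with V(s_T) at the end
    advantages = rho_c * (rewards + discounts * vs_next - values[:-1])
    return vs, advantages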
Q5: The paper pads all scenarios to fixed sizes (512 agents, 151 timesteps). Why is this necessary and what are the trade-offs?
Why necessary:
- XLA compilation requires static shapes
- JIT compiles a new function for each unique shape
- Dynamic shapes would cause constant recompilation
- vmap/pmap require uniform dimensions across batch
Trade-offs:
Pros:
- Single compiled kernel handles all scenarios
- Efficient GPU memory allocation
- Enables vectorized operations across batch
Cons:
- Memory overhead for small scenarios
- Compute waste on padded elements
- Need to track valid elements with masks
Mitigation:
- Use masks to zero out padded element contributions
- Choose pad sizes based on data distribution (512 covers 99%+ of real scenarios)
- Mask-aware attention/aggregation operations
# Masked mean example
def masked_mean(values, mask):
return jnp.sum(values * mask) / jnp.sum(mask)
Q6: How does jax.lax.scan differ from a Python for loop, and why is it crucial for this paper's performance?
Python for loop:
- Interpreted at runtime
- Each iteration is a separate GPU kernel call
- CPU-GPU synchronization between iterations
- Gradients computed via list accumulation
jax.lax.scan:
- Compiled into a single XLA while loop
- Entire loop runs as one fused kernel
- No CPU involvement during execution
- Automatic gradient computation via scan's adjoint
Performance impact: The paper processes 151 timesteps per rollout. With a Python loop, that's 151 kernel launches with synchronization. With scan, it's one kernel that runs the entire trajectory.
For batched rollouts (16 scenarios x 512 agents x 151 steps), scan enables:
- Single compilation
- No Python overhead
- Efficient GPU memory management
- Automatic backprop through time
# 151 kernel launches
for t in range(151):
state = step(state, actions[t])
# Single fused kernel
final_state, trajectory = jax.lax.scan(step, initial_state, actions)
Advanced Level
Q7: The paper achieves 87% efficiency when scaling from 8 to 32 GPUs. What are the sources of overhead and how would you improve scaling further?
Sources of overhead (13% loss):
- Gradient synchronization: AllReduce across 32 GPUs has O(log n) latency
- Bandwidth saturation: Large model gradients compete for network bandwidth
- Stragglers: Synchronous updates wait for slowest worker
- Data loading: More workers compete for I/O
Improvement strategies:
- Gradient compression:
  - Quantize gradients to lower precision (FP16, INT8)
  - Top-k sparsification (only send the largest gradients)
- Asynchronous SGD:
  - Don't wait for all gradients (the paper uses this with V-trace correction)
  - Local SGD: sync every N steps instead of every step
- Pipeline parallelism:
  - Split the model across GPUs
  - Overlap forward/backward passes of different microbatches
- Better communication topology:
  - Ring AllReduce is bandwidth-optimal (per-GPU traffic roughly independent of device count), but its latency grows linearly with the number of devices
  - Hierarchical AllReduce for multi-node setups
- Overlap computation and communication:
  - Start AllReduce on earlier layers while later layers are still computing
  - Requires careful dependency management
# Gradient compression example
def compress_gradients(grads, top_k_ratio=0.1):
    # Top-k sparsification: keep only the largest-magnitude fraction of each gradient leaf
    def sparsify(g):
        flat = g.ravel()
        k = max(1, int(flat.size * top_k_ratio))
        threshold = jnp.sort(jnp.abs(flat))[-k]   # k-th largest magnitude
        return jnp.where(jnp.abs(flat) >= threshold, flat, 0.0).reshape(g.shape)
    return jax.tree.map(sparsify, grads)
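Usage sketch (assumes it runs inside a pmap'd step with axis_name='devices'; note that the bandwidth win only materializes if the sparse structure itself is communicated, not dense arrays full of zeros):
sparse_grads = compress_gradients(grads, top_k_ratio=0.1)
synced_grads = jax.lax.pmean(sparse_grads, axis_name='devices')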
Q8: The paper pre-trains both policy and value networks. Design an experiment to quantify the contribution of value pre-training vs policy pre-training.
Experimental design:
| Condition | Policy Pre-train | Value Pre-train | Description |
|---|---|---|---|
| Baseline | No | No | Random init |
| Policy-only | Yes | No | BC for policy, random value |
| Value-only | No | Yes | Random policy, returns for value |
| Full | Yes | Yes | Paper's approach |
Metrics:
- Training curves: reward vs steps
- Time to threshold: steps to reach X% performance
- Final performance: asymptotic reward
- Training stability: variance across seeds
Hypothesis from paper:
- Policy-only will be unstable (random value -> noisy advantages)
- Value-only won't help much (random policy explores poorly)
- Full pre-training gives 5x faster convergence
Implementation:
def run_ablation(config):
    # behavioral_cloning, random_init, train_ppo and demo_data below are placeholders
    # for the project's own pre-training and RL-training utilities
    results = {}
for pretrain_policy in [True, False]:
for pretrain_value in [True, False]:
name = f"policy={pretrain_policy}_value={pretrain_value}"
# Pre-training phase
if pretrain_policy or pretrain_value:
params = behavioral_cloning(
demo_data,
train_policy=pretrain_policy,
train_value=pretrain_value
)
else:
params = random_init()
# RL training
curve = train_ppo(params, max_steps=2.5e9)
results[name] = curve
return results
Expected findings:
- Value pre-training is crucial for actor-critic stability
- Policy pre-training speeds up exploration
- Both together are synergistic, not just additive
Q9: Explain how you would debug a training run where loss is NaN after switching from 1 GPU to 8 GPUs.
Systematic debugging approach:
- Verify basic pmap correctness:
# Check that all devices hold identical initial params
@pmap
def check_params(params):
    # Reduce the whole pytree to one scalar fingerprint per device
    return sum(jnp.sum(x) for x in jax.tree.leaves(params))
sums = check_params(replicated_params)   # shape: (n_devices,)
assert jnp.allclose(sums, sums[0]), "Params not synchronized!"
- Check gradient synchronization:
# Print gradient norms per device before/after pmean
def debug_train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grad_norm_before = jax.tree.map(jnp.linalg.norm, grads)
    jax.debug.print("Grad norm before sync: {}", grad_norm_before)
    grads = lax.pmean(grads, 'devices')   # requires pmap(..., axis_name='devices')
    grad_norm_after = jax.tree.map(jnp.linalg.norm, grads)
    jax.debug.print("Grad norm after sync: {}", grad_norm_after)
return grads
- Check for effective batch size mismatch:
# 8 GPUs = 8x effective batch size
# May need to scale learning rate: lr_new = lr_old * sqrt(8)
# Or adjust batch normalization statistics
- Check data sharding:
# Ensure each device gets different data
def check_data_sharding(batch):
    # batch shape: (n_devices, batch_per_device, ...)
    n_devices = batch.shape[0]
    threshold = batch[0].size   # flag shards that are element-for-element identical
    for i in range(n_devices):
        for j in range(i + 1, n_devices):
            overlap = jnp.sum(batch[i] == batch[j])
            assert overlap < threshold, f"Devices {i},{j} have duplicate data!"
- Look for numerical instabilities magnified by scale:
# Common culprits:
# - log(0) in log_probs -> add epsilon
# - exp(large) in softmax -> use log-sum-exp trick
# - Division by zero in normalization -> check masks
# Add numerical guards
log_probs = jnp.log(probs + 1e-8)
advantages = (adv - mean) / (std + 1e-8)
- Check random key handling:
# Each device needs different RNG keys
keys = jax.random.split(master_key, n_devices)
# NOT: replicated_key = jnp.stack([key] * n_devices)
Most likely culprits:
- Learning rate not scaled for larger effective batch
- Gradient explosion from accidentally using psum where pmean was intended
- Data duplication across devices
- RNG key reuse causing identical samples
Q10: Design a system to achieve 10x the scale of this paper (25B agent steps instead of 2.5B). What are the key bottlenecks and how would you address them?
Key bottlenecks at 10x scale:
-
Data throughput:
- 10x more experience to collect and process
- Solution: More actors, faster simulators, better I/O
-
Gradient computation:
- 10x more gradient steps (assuming same batch size)
- Solution: Larger batches, more learners
-
Communication overhead:
- More devices = more synchronization
- Solution: Hierarchical communication, local SGD
-
Memory:
- Replay buffers grow with experience
- Solution: On-policy (no buffer) or distributed storage
Architecture for 25B steps:
+------------------------------------------------------------------+
| 10x SCALE ARCHITECTURE |
+------------------------------------------------------------------+
Tier 1: Actor Pods (100 pods x 8 GPUs each = 800 GPUs)
+-----------+ +-----------+ +-----------+
| Actor Pod | | Actor Pod | ... | Actor Pod |
| 8 GPUs | | 8 GPUs | | 8 GPUs |
+-----------+ +-----------+ +-----------+
| | |
+-------+------+-------+-----------+
|
Experience Stream (Kafka/Redis)
|
+-------------+-------------+
| | |
v v v
+-----------+ +-----------+ +-----------+
| Learner | | Learner | | Learner |
| Shard 0 | | Shard 1 | | Shard 2 |
| 16 GPUs | | 16 GPUs | | 16 GPUs |
+-----------+ +-----------+ +-----------+
| | |
+------+------+------+------+
|
Model Parallelism
(if model > GPU memory)
Key design decisions:
- Decouple actors and learners completely:
  - Actors push to a message queue
  - Learners pull asynchronously
  - No blocking synchronization
- Hierarchical gradient sync:
  - AllReduce within a learner shard (fast)
  - Periodic sync across shards (slower, less frequent)
- Population-based training:
  - Multiple policies evolving in parallel
  - Cross-pollination of best performers
- Importance sampling buffer:
  - Prioritize rare/difficult scenarios
  - Learned difficulty estimator
- Mixed-precision training (see the sketch after this list):
  - FP16 for forward pass and gradients
  - FP32 master weights
  - 2x memory savings, 2x throughput
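As a concrete example of the last item, a minimal mixed-precision sketch: bfloat16 for the compute pass with FP32 master weights (FP16 would additionally need loss scaling, omitted here; loss_fn is assumed to be defined elsewhere):
import jax
import jax.numpy as jnp

def mixed_precision_grads(loss_fn, params_fp32, batch):
    # Run forward/backward in bfloat16 while the optimizer keeps FP32 master weights
    params_bf16 = jax.tree.map(lambda p: p.astype(jnp.bfloat16), params_fp32)
    loss, grads_bf16 = jax.value_and_grad(loss_fn)(params_bf16, batch)
    # Cast gradients back up so the update is applied to the FP32 master copy
    grads_fp32 = jax.tree.map(lambda g: g.astype(jnp.float32), grads_bf16)
    return loss, grads_fp32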
Estimated resources:
- ~1000 GPUs
- ~1 week training time
- ~100TB experience data throughput
- Cost: ~$500k-1M cloud compute
Verification:
- Linear speedup tests at each scale doubling
- Convergence parity checks (same final performance)
- Ablations on buffer size, sync frequency
Summary
This paper demonstrates that autonomous driving policies can be dramatically improved through scale:
- JAX enables scale: JIT, vmap, pmap, and scan eliminate Python overhead and enable massive parallelization
- Pure functional design: Immutable state, explicit RNG, and no side effects enable compilation and parallelization
- Distributed training: Async actor-learner with V-trace correction achieves near-linear GPU scaling
- Pre-training matters: BC + value pre-training gives 5x faster RL convergence
- Data-model scaling: Larger models only help with sufficient data (scaling laws apply)
The engineering principles here apply broadly to any large-scale RL system. Master these concepts and you'll be well-prepared for ML infrastructure roles at top companies.
References
- Original Paper: arXiv:2312.15122
- JAX Documentation: jax.readthedocs.io
- IMPALA (V-trace): arXiv:1802.01561
- PPO: arXiv:1707.06347
- Waymax: waymo.com/open/data/motion