Distributed Training Deep Dive: Scaling RL for Autonomous Driving
Focus: Infrastructure for training at billions of environment steps
Key Papers: PureJaxRL, V-Max, IMPALA, Isaac Gym, Brax
Read Time: 55 min
Table of Contents
- Executive Summary
- The Scale Challenge
- Distributed Training Architectures
- JAX-Specific Distributed Training
- Key Systems and Papers
- Performance Optimization
- Infrastructure Components
- Practical Implementation
- Interview Questions
- Further Reading
Executive Summary
The Core Problem
Training autonomous driving agents requires an astronomical number of environment interactions:
┌─────────────────────────────────────────────────────────────────────────┐
│ SCALE OF AV TRAINING │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Research Scale: ~30 million agent steps │
│ Production Scale: ~2.5 billion agent steps │
│ Scaling Factor: ~100x │
│ │
│ At 10 Hz simulation: │
│ • 30M steps = 35 days of simulated driving │
│ • 2.5B steps = 8 years of simulated driving │
│ │
│ Without acceleration, this is intractable. │
│ │
└─────────────────────────────────────────────────────────────────────────┘
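The simulated-time figures above follow directly from the step counts; a quick back-of-the-envelope check:

# Simulated driving time implied by the step counts (10 Hz simulation)
research_steps = 30_000_000
production_steps = 2_500_000_000
sim_hz = 10

print(research_steps / sim_hz / 86_400)             # ~34.7 days of simulated driving
print(production_steps / sim_hz / 86_400 / 365)     # ~7.9 years of simulated driving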
The Solution: End-to-End GPU Training
Traditional RL suffers from CPU-GPU data transfer bottlenecks:
Traditional Approach (Slow):
┌─────────────┐ Copy ┌─────────────┐
│ CPU │ ◄──────────► │ GPU │
│ Environment │ (10K+ ns) │ Neural Net │
└─────────────┘ └─────────────┘
↑ ↑
└───────────────────────────┘
Bottleneck: Data transfer
End-to-End GPU (Fast):
┌─────────────────────────────────────────┐
│ GPU │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Environment │◄──►│ Neural Net │ │
│ │ (JAX) │ │ (JAX) │ │
│ └─────────────┘ └─────────────┘ │
│ ↑ ↑ │
│ └─────────────────┘ │
│ All on-device, no copy │
└─────────────────────────────────────────┘
Result: PureJaxRL achieves 4000x speedup over traditional implementations.
The Scale Challenge
Why Billions of Steps?
- Long-Tail Coverage: Rare scenarios require extensive sampling
- Scaling Laws: Larger models need proportionally more data
- Multi-Agent Complexity: N agents = O(N²) interaction space
- Generalization: Diverse scenarios prevent overfitting
Scaling Laws in AV RL
From "Scaling Is All You Need" (2024):
| Model Size | Dataset (hours) | Failure Rate | Notes |
|---|---|---|---|
| Small (1M) | 600 | Baseline | Quick to train |
| Medium (5M) | 600 | -15% | Better sample efficiency |
| Large (25M) | 600 | +10% | Underfitting |
| Large (25M) | 6000 | -64% | Best performance |
Key Insight: Larger models underperform on small datasets but excel with larger datasets. The intersection point determines optimal model-data balance.
Compute Requirements
# Rough compute estimates for AV RL training
compute_requirements = {
'research_experiment': {
'agent_steps': 30_000_000,
'hardware': 'Single GPU (V100)',
'training_time': '2-4 hours',
},
'medium_scale': {
'agent_steps': 200_000_000,
'hardware': '8x V100 GPUs',
'training_time': '12-48 hours',
},
'production_scale': {
'agent_steps': 2_500_000_000,
'hardware': '32+ GPUs or TPU pod',
'training_time': 'Days to weeks',
},
}
Distributed Training Architectures
Data Parallelism
The dominant paradigm for RL training:
┌─────────────────────────────────────────────────────────────────────────┐
│ DATA PARALLELISM │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Device 0 Device 1 Device 2 Device 3 │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐│
│ │ Full Model│ │ Full Model│ │ Full Model│ │ Full Model││
│ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘│
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐│
│ │ Data Shard│ │ Data Shard│ │ Data Shard│ │ Data Shard││
│ │ 0 │ │ 1 │ │ 2 │ │ 3 ││
│ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘│
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐│
│ │ Gradients │ │ Gradients │ │ Gradients │ │ Gradients ││
│ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘│
│ │ │ │ │ │
│ └──────────────────┴──────────────────┴──────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ All-Reduce │ │
│ │ (Average) │ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ Synchronized Update │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Advantages:
- Simple to implement
- Scales linearly with data
- Well-supported in frameworks
Disadvantages:
- Synchronization overhead
- Memory limited by single device (for model)
- Stragglers slow entire system
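The all-reduce step works because, for a mean loss over equally sized shards, averaging per-shard gradients reproduces the full-batch gradient. A tiny numeric check of that identity (toy linear model, names arbitrary):

import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

w = jnp.ones(3)
x = jnp.arange(12.0).reshape(4, 3)
y = jnp.ones(4)

g_full = jax.grad(loss)(w, x, y)                                        # gradient over the full batch
g_shards = [jax.grad(loss)(w, x[i:i + 2], y[i:i + 2]) for i in (0, 2)]  # per-"device" shard gradients
assert jnp.allclose(g_full, (g_shards[0] + g_shards[1]) / 2)            # averaging recovers the full gradient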
Actor-Learner Architecture (IMPALA)
Decouples experience collection from policy updates:
┌─────────────────────────────────────────────────────────────────────────┐
│ IMPALA ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ACTORS (Many) LEARNER (Centralized) │
│ ┌─────────────┐ ┌─────────────────────┐ │
│ │ Actor 0 │────────────────────────►│ │ │
│ │ Env + Policy│ Trajectories │ Policy Update │ │
│ └─────────────┘ │ │ │
│ ┌─────────────┐ │ • Compute loss │ │
│ │ Actor 1 │────────────────────────►│ • Update params │ │
│ │ Env + Policy│ │ • V-trace │ │
│ └─────────────┘ │ correction │ │
│ ┌─────────────┐ │ │ │
│ │ Actor N │────────────────────────►│ │ │
│ │ Env + Policy│ └──────────┬──────────┘ │
│ └─────────────┘ │ │
│ ▲ │ │
│ └────────────────────────────────────────────┘ │
│ Updated Policy │
│ │
│ V-trace Correction: │
│ Handles off-policy data from stale actor policies │
│ Importance sampling with truncated weights │
│ │
│ Performance: 250,000 frames/second (30x single-machine A3C) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
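To make the V-trace correction concrete, here is a minimal single-trajectory sketch of the target computation (function name, argument layout, and clipping defaults are illustrative, not taken from the IMPALA codebase):

import jax
import jax.numpy as jnp

def vtrace_targets(values, rewards, discounts, rhos, clip_rho=1.0, clip_c=1.0):
    """V-trace targets for one trajectory.

    values:    V(x_t) for t = 0..T (includes bootstrap value), shape [T+1]
    rewards:   r_t for t = 0..T-1, shape [T]
    discounts: gamma * (1 - done_t), shape [T]
    rhos:      pi(a_t|x_t) / mu(a_t|x_t) importance ratios, shape [T]
    """
    clipped_rhos = jnp.minimum(rhos, clip_rho)
    clipped_cs = jnp.minimum(rhos, clip_c)
    # Truncated-importance-weighted TD errors
    deltas = clipped_rhos * (rewards + discounts * values[1:] - values[:-1])

    def backward_step(acc, inputs):
        delta, discount, c = inputs
        acc = delta + discount * c * acc
        return acc, acc

    # Accumulate corrections from the end of the trajectory backwards
    _, corrections = jax.lax.scan(
        backward_step,
        jnp.zeros_like(values[-1]),
        (deltas, discounts, clipped_cs),
        reverse=True,
    )
    return values[:-1] + corrections   # v_s targets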
SEED RL: Centralized Inference
Addresses IMPALA's limitation of running inference on actors:
┌─────────────────────────────────────────────────────────────────────────┐
│ SEED RL ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ACTORS (CPU/TPU) LEARNER (TPU) │
│ ┌─────────────┐ ┌─────────────────────┐ │
│ │ Actor 0 │ │ │ │
│ │ Env Only │◄───────────────────────►│ Policy Inference │ │
│ └─────────────┘ Observations/Actions │ + │ │
│ ┌─────────────┐ (Streaming gRPC) │ Policy Training │ │
│ │ Actor 1 │◄───────────────────────►│ │ │
│ │ Env Only │ │ All NN compute │ │
│ └─────────────┘ │ on accelerator │ │
│ ┌─────────────┐ │ │ │
│ │ Actor N │◄───────────────────────►│ │ │
│ │ Env Only │ └─────────────────────┘ │
│ └─────────────┘ │
│ │
│ Benefits: │
│ • No model parameters sent to actors │
│ • Efficient batched inference │
│ • Lower bandwidth requirements │
│ • Better TPU utilization │
│ │
└─────────────────────────────────────────────────────────────────────────┘
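The key mechanic is batched inference on the learner: one forward pass per tick serves every connected actor. A minimal sketch of that inner step (stand-in linear policy; the gRPC transport and actor loop are omitted entirely):

import jax
import jax.numpy as jnp

def policy_apply(params, obs):
    # Stand-in linear policy; in SEED RL this is the full network, living only on the learner
    return obs @ params['w']

@jax.jit
def batched_policy_step(params, stacked_obs, key):
    logits = policy_apply(params, stacked_obs)
    return jax.random.categorical(key, logits)   # one action index per actor

params = {'w': jnp.zeros((16, 4))}               # 16-dim obs -> 4 discrete actions
obs_from_actors = jnp.ones((32, 16))             # observations gathered from 32 actors
actions = batched_policy_step(params, obs_from_actors, jax.random.PRNGKey(0))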
Async vs. Sync Training
| Aspect | Synchronous | Asynchronous |
|---|---|---|
| Convergence | Stable, predictable | May oscillate |
| Throughput | Limited by stragglers | Higher |
| Gradient freshness | Always current | May be stale |
| Implementation | Simpler | More complex |
| Best for | Small clusters | Large clusters |
Hybrid Approach (Stale Synchronous Parallel):
- Allow bounded staleness (e.g., max 2 iterations behind)
- Balance throughput and convergence
- Common in production systems
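The bounded-staleness rule above amounts to a simple gate on each worker; a single-process sketch (a real system would track worker progress through a parameter server or coordinator):

MAX_STALENESS = 2  # iterations the fastest worker may run ahead of the slowest

def may_proceed(worker_step: int, slowest_worker_step: int) -> bool:
    """A worker starts its next iteration only while within the staleness bound."""
    return worker_step - slowest_worker_step <= MAX_STALENESS

assert may_proceed(worker_step=12, slowest_worker_step=10)       # within bound
assert not may_proceed(worker_step=13, slowest_worker_step=10)   # must wait for stragglers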
JAX-Specific Distributed Training
jax.pmap (Parallel Map)
import jax
import jax.numpy as jnp
from functools import partial
from jax import pmap
# Replicate model across devices
num_devices = jax.device_count()
replicated_params = jax.tree.map(
lambda x: jnp.broadcast_to(x, (num_devices,) + x.shape),
params
)
lr = 3e-4  # learning rate (module-level for brevity)

@partial(pmap, axis_name='batch')  # axis_name must match the pmean call below
def train_step(params, batch):
"""Run on each device with its local batch."""
def loss_fn(p):
predictions = model.apply(p, batch['observations'])
return jnp.mean((predictions - batch['targets']) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Average gradients across all devices
grads = jax.lax.pmean(grads, axis_name='batch')
# Update parameters (same on all devices)
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
return params, loss
# Training loop
for epoch in range(num_epochs):
for batch in dataloader:
# Split batch across devices
batch = jax.tree.map(
lambda x: x.reshape(num_devices, -1, *x.shape[1:]),
batch
)
replicated_params, loss = train_step(replicated_params, batch)
Device Mesh and Sharding
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Create 2D mesh: (data_parallel, model_parallel)
devices = np.array(jax.devices()).reshape(2, 4)  # 8 devices -> 2x4 grid
mesh = Mesh(devices, axis_names=('data', 'model'))
# Shard parameters across model dimension
param_sharding = NamedSharding(mesh, P(None, 'model'))
sharded_params = jax.device_put(params, param_sharding)
# Shard data across data dimension
data_sharding = NamedSharding(mesh, P('data', None))
sharded_batch = jax.device_put(batch, data_sharding)
# JIT-compiled function respects sharding
@jax.jit
def forward(params, batch):
return model.apply(params, batch)
output = forward(sharded_params, sharded_batch)
# Output sharding inferred automatically
Collective Operations
import jax
import jax.lax as lax
# Inside pmapped function:
# All-reduce: Sum across devices, result on all devices
total = lax.psum(local_value, axis_name='batch')
# All-reduce mean: Average across devices
mean = lax.pmean(local_value, axis_name='batch')
# All-gather: Collect shards from all devices
full_data = lax.all_gather(local_shard, axis_name='batch')
# Reduce-scatter: Sum then scatter results
scattered = lax.psum_scatter(values, axis_name='batch')
# Ring permutation: Shift data between devices (n = number of devices)
n = jax.device_count()
shifted = lax.ppermute(
    x,
    axis_name='batch',
    perm=[(i, (i + 1) % n) for i in range(n)]
)
TPU Pod Patterns
# Multi-host initialization
jax.distributed.initialize()
print(f"Global devices: {jax.device_count()}")
print(f"Local devices: {jax.local_device_count()}")
print(f"Process index: {jax.process_index()}")
# Create global array spanning all hosts
from jax.experimental import multihost_utils
global_mesh = Mesh(
jax.devices(),
axis_names=('data',)
)
# Each host creates its shard
local_data = load_data_for_host(jax.process_index())
# Combine into global array
global_array = multihost_utils.host_local_array_to_global_array(
local_data,
global_mesh,
P('data')
)
Key Systems and Papers
PureJaxRL
Breakthrough: Implements entire RL pipeline in JAX, enabling full JIT compilation.
# PureJaxRL-style end-to-end training
import jax
import jax.numpy as jnp
from functools import partial
from flax import linen as nn
class SimplePPO:
def __init__(self, env, policy_network, value_network):
self.env = env # JAX-native environment!
self.policy = policy_network
self.value = value_network
    @partial(jax.jit, static_argnums=0)  # `self` holds static config, not traced arrays
def train_step(self, state, params, opt_state, key):
"""Entire training step is JIT-compiled."""
# Collect rollouts (vectorized across environments)
def rollout_step(carry, _):
env_state, params, key = carry
key, subkey = jax.random.split(key)
# Policy inference
obs = self.env.get_obs(env_state)
action, log_prob = self.policy.sample(params, obs, subkey)
# Environment step (JAX-native!)
env_state, reward, done = self.env.step(env_state, action)
return (env_state, params, key), (obs, action, reward, log_prob, done)
# Vectorized rollout using scan
(env_state, _, _), trajectories = jax.lax.scan(
rollout_step,
(state.env_state, params, key),
None,
length=self.rollout_length
)
# Compute advantages and returns
advantages, returns = compute_gae(trajectories, params)
# PPO update (multiple epochs over collected data)
def ppo_epoch(carry, _):
params, opt_state, key = carry
key, subkey = jax.random.split(key)
# Shuffle and create minibatches
indices = jax.random.permutation(subkey, self.batch_size)
minibatches = create_minibatches(trajectories, indices)
            # Update on each minibatch (shown as a Python loop for clarity; a fully
            # jitted pipeline would express this as another lax.scan)
            for mb in minibatches:
params, opt_state, loss = self.ppo_update(
params, opt_state, mb, advantages
)
return (params, opt_state, key), loss
(params, opt_state, _), losses = jax.lax.scan(
ppo_epoch,
(params, opt_state, key),
None,
length=self.ppo_epochs
)
return state._replace(env_state=env_state), params, opt_state
Performance:
- CartPole: 1M frames in 0.05s (vs 46s traditional)
- Speedup: ~1000x per environment
- With vectorization: ~4000x total
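Much of the headline speedup comes from vmapping the entire jitted training function over random seeds (or hyperparameters), so many runs share one compiled program. A self-contained sketch of the pattern, with a stand-in training function in place of a real PureJaxRL make_train:

import jax
import jax.numpy as jnp

def make_train(config):
    # Stand-in for a PureJaxRL-style factory: returns a pure function of a PRNG key
    def train(rng):
        params = jax.random.normal(rng, (config['dim'],))
        return {'final_param_norm': jnp.linalg.norm(params)}
    return train

# vmap the whole training function over seeds -> many independent runs per compiled call
train_vmapped = jax.jit(jax.vmap(make_train({'dim': 8})))
rngs = jax.random.split(jax.random.PRNGKey(42), 64)
results = train_vmapped(rngs)   # 64 "training runs" executed in parallel on one device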
V-Max
Production-grade RL pipeline for Waymax:
# V-Max training setup (illustrative sketch; import paths and class names are
# paraphrased and may not match the released V-Max API exactly)
from vmax import VMaxEnv, VmapWrapper, AutoResetWrapper
from vmax.algorithms import SAC
# Load scenarios from multiple datasets
scenarios = ScenarioMax.load(
datasets=['womd', 'nuplan', 'argoverse'],
split='train'
)
# Create vectorized environment
base_env = VMaxEnv(scenarios, config)
env = VmapWrapper(AutoResetWrapper(base_env), num_envs=256)
# Configure SAC
sac_config = SACConfig(
learning_rate=3e-4,
discount=0.99,
tau=0.005,
batch_size=256,
replay_buffer_size=1_000_000,
)
# Train
agent = SAC(env, sac_config)
for step in range(total_steps):
metrics = agent.train_step()
if step % log_interval == 0:
print(f"Step {step}: {metrics}")
Results:
- Single L4 GPU: 25M steps in 36 hours
- SAC achieves 97.44% accuracy, 1.74% collision rate
Isaac Gym / Isaac Lab
GPU-accelerated physics simulation:
# Isaac Lab multi-GPU setup (illustrative sketch; the released Isaac Lab drives
# training through workflow scripts for RL libraries such as RSL-RL or skrl
# rather than this exact Runner API)
from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.utils.runner import Runner
# Configuration
env_cfg = SpotEnvCfg(
num_envs=4096, # Per GPU
device="cuda",
)
# Multi-GPU training
runner = Runner(
env=env_cfg,
agent=PPOAgent,
num_gpus=4, # Distribute across GPUs
)
runner.train(total_timesteps=10_000_000)
# ~90,000 FPS on RTX A6000
Brax
Differentiable physics for JAX:
import brax
from brax import envs
from brax.training.agents.ppo.train import train as ppo_train
# Create vectorized environment
env = envs.create('ant', batch_size=2048)
# Train with PPO
make_inference_fn, params, metrics = ppo_train(
environment=env,
num_timesteps=10_000_000,
episode_length=1000,
num_envs=2048,
learning_rate=3e-4,
# Runs entirely on GPU/TPU
)
# Brax achieves hundreds of millions of steps/second on TPU pods
Performance Optimization
GPU Memory Optimization
# 1. Gradient Checkpointing
import jax
from jax import checkpoint  # jax.checkpoint, a.k.a. jax.remat

@checkpoint  # Recompute activations during the backward pass instead of storing them
def memory_efficient_layer(params, x):
    x = jax.nn.relu(x @ params['w1'] + params['b1'])
    return jax.nn.relu(x @ params['w2'] + params['b2'])
# (Inside a Flax module, the equivalent wrapper is nn.remat.)
# 2. Mixed Precision
import jax.numpy as jnp
def forward_mixed_precision(params, x):
# Convert to bfloat16 for compute
x = x.astype(jnp.bfloat16)
params_bf16 = jax.tree.map(lambda p: p.astype(jnp.bfloat16), params)
# Compute in bfloat16
output = model.apply(params_bf16, x)
# Convert back for loss computation
return output.astype(jnp.float32)
# 3. Gradient Accumulation
def train_with_accumulation(params, batches, accumulation_steps):
accumulated_grads = jax.tree.map(jnp.zeros_like, params)
for i, batch in enumerate(batches[:accumulation_steps]):
grads = compute_gradients(params, batch)
accumulated_grads = jax.tree.map(
lambda a, g: a + g / accumulation_steps,
accumulated_grads, grads
)
# Single update with accumulated gradients
params = apply_gradients(params, accumulated_grads)
return params
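If optax is already in the stack, the manual accumulation loop above can be delegated to optax.MultiSteps, which applies the wrapped update only every k micro-batches (k and the inner optimizer here are arbitrary choices):

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((4, 4))}   # stand-in parameters
k = 4                               # micro-batches to accumulate
optimizer = optax.MultiSteps(optax.adam(3e-4), every_k_schedule=k)
opt_state = optimizer.init(params)
# Call optimizer.update(grads, opt_state, params) once per micro-batch;
# the wrapped adam step only takes effect on every k-th call.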
Profiling and Bottleneck Identification
import jax.profiler

# Start profiler
jax.profiler.start_trace("/tmp/tensorboard")
# Run training
for step in range(100):
params, metrics = train_step(params, next(data_iter))
# Stop profiler
jax.profiler.stop_trace()
# View in TensorBoard:
# tensorboard --logdir=/tmp/tensorboard
# Programmatic analysis with a context manager
with jax.profiler.trace("/tmp/jax-trace"):
result = train_step(params, batch)
result.block_until_ready()
# Key things to look for:
# 1. XLA compilation time (first step)
# 2. Device-to-host copies (should be minimal)
# 3. AllReduce duration (communication overhead)
# 4. Kernel execution gaps (memory bottlenecks)
Infrastructure Components
Replay Buffers at Scale
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
class ReplayBuffer(NamedTuple):
observations: jnp.ndarray
actions: jnp.ndarray
rewards: jnp.ndarray
next_observations: jnp.ndarray
dones: jnp.ndarray
priorities: jnp.ndarray # For prioritized replay
position: int
size: int
def create_replay_buffer(capacity: int, obs_shape: tuple, action_shape: tuple):
"""Create pre-allocated replay buffer on device."""
return ReplayBuffer(
observations=jnp.zeros((capacity,) + obs_shape),
actions=jnp.zeros((capacity,) + action_shape),
rewards=jnp.zeros(capacity),
next_observations=jnp.zeros((capacity,) + obs_shape),
dones=jnp.zeros(capacity, dtype=bool),
priorities=jnp.ones(capacity),
position=0,
size=0,
)
@jax.jit
def add_transition(buffer: ReplayBuffer, transition: dict) -> ReplayBuffer:
"""Add transition to buffer (JIT-compiled)."""
pos = buffer.position
new_buffer = buffer._replace(
observations=buffer.observations.at[pos].set(transition['obs']),
actions=buffer.actions.at[pos].set(transition['action']),
rewards=buffer.rewards.at[pos].set(transition['reward']),
next_observations=buffer.next_observations.at[pos].set(transition['next_obs']),
dones=buffer.dones.at[pos].set(transition['done']),
priorities=buffer.priorities.at[pos].set(1.0), # Max priority for new
position=(pos + 1) % buffer.observations.shape[0],
size=jnp.minimum(buffer.size + 1, buffer.observations.shape[0]),
)
return new_buffer
@partial(jax.jit, static_argnames=('batch_size',))  # batch size must be static under jit
def sample_batch(buffer: ReplayBuffer, key: jax.Array, batch_size: int):
    """Sample batch with prioritized replay."""
    capacity = buffer.priorities.shape[0]
    # Mask empty slots rather than slicing: slice bounds must be static under jit
    valid = jnp.arange(capacity) < buffer.size
    priorities = jnp.where(valid, buffer.priorities, 0.0)
    probs = priorities / priorities.sum()
    # Sample indices (with replacement) in proportion to priority
    indices = jax.random.choice(key, capacity, shape=(batch_size,), p=probs)
return {
'observations': buffer.observations[indices],
'actions': buffer.actions[indices],
'rewards': buffer.rewards[indices],
'next_observations': buffer.next_observations[indices],
'dones': buffer.dones[indices],
}
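Example usage of the buffer above (shapes are arbitrary; batch_size is passed by keyword so jit treats it as static):

buffer = create_replay_buffer(capacity=10_000, obs_shape=(8,), action_shape=(2,))
transition = {
    'obs': jnp.ones(8), 'action': jnp.zeros(2), 'reward': 1.0,
    'next_obs': jnp.ones(8), 'done': False,
}
buffer = add_transition(buffer, transition)
batch = sample_batch(buffer, jax.random.PRNGKey(0), batch_size=32)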
Checkpointing with Orbax
import orbax.checkpoint as ocp
# Create checkpointer
checkpointer = ocp.PyTreeCheckpointer()
# Save checkpoint (async for non-blocking)
def save_checkpoint(step, params, opt_state, metrics):
checkpoint_path = f"/checkpoints/step_{step}"
ckpt = {
'params': params,
'opt_state': opt_state,
'metrics': metrics,
'step': step,
}
# Async save - returns immediately
checkpointer.save(
checkpoint_path,
ckpt,
save_args=ocp.SaveArgs(aggregate=True)
)
# Restore checkpoint
def restore_checkpoint(path, target_sharding=None):
return checkpointer.restore(
path,
restore_args=ocp.RestoreArgs(
restore_type=ocp.RestoreType.CHECKPOINT,
sharding=target_sharding, # Restore with correct sharding
)
)
# Best practice: checkpoint periodically
for step in range(total_steps):
params, metrics = train_step(params, batch)
if step % checkpoint_interval == 0:
save_checkpoint(step, params, opt_state, metrics)
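For periodic checkpointing with retention, orbax's CheckpointManager can take over step tracking and cleanup (a sketch against the legacy manager API; argument names may differ across orbax versions):

import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=10_000)
manager = ocp.CheckpointManager('/checkpoints', ocp.PyTreeCheckpointer(), options=options)

# In the training loop:
#   manager.save(step, ckpt)        # no-op except on the save interval
# To resume:
#   latest = manager.latest_step()
#   ckpt = manager.restore(latest)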
Practical Implementation
Single GPU to Multi-GPU Migration
from functools import partial

# Step 1: Single GPU baseline
def train_single_gpu():
env = create_env(num_envs=32)
params = init_params()
opt_state = optimizer.init(params)
for step in range(total_steps):
batch = collect_rollout(env, params)
params, opt_state = update(params, opt_state, batch)
# Step 2: Add data parallelism
def train_multi_gpu():
num_devices = jax.device_count()
# Scale environments per device
env = create_env(num_envs=32 * num_devices)
    # Initialize once, then replicate params and optimizer state across devices
    params = init_params()
    opt_state = optimizer.init(params)
    params = jax.device_put_replicated(params, jax.devices())
    opt_state = jax.device_put_replicated(opt_state, jax.devices())

    @partial(jax.pmap, axis_name='batch')
def parallel_update(params, opt_state, batch):
grads = compute_gradients(params, batch)
grads = jax.lax.pmean(grads, 'batch') # Average across devices
params, opt_state = apply_gradients(params, opt_state, grads)
return params, opt_state
for step in range(total_steps):
# Collect across all devices
batch = collect_rollout(env, params)
# Reshape for pmap: (devices, batch_per_device, ...)
batch = reshape_for_pmap(batch, num_devices)
params, opt_state = parallel_update(params, opt_state, batch)
Scaling Checklist
# Before scaling, verify:
scaling_checklist = {
# Learning rate scaling
'learning_rate': base_lr * sqrt(num_devices), # Linear or sqrt scaling
# Batch size scaling
'effective_batch_size': base_batch * num_devices,
# Verify convergence matches single GPU
'convergence_test': compare_loss_curves(single_gpu, multi_gpu),
# Check for communication bottlenecks
'comm_overhead': measure_allreduce_time(),
# Memory per device
'memory_usage': profile_memory_per_device(),
# Determinism
'reproducibility': verify_same_results_given_seed(),
}
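A sketch of the learning-rate and clipping side of that checklist using optax (the scaling rule, warmup length, and clip norm are choices to validate, not fixed recommendations):

import optax

num_devices = 8
base_lr = 3e-4
schedule = optax.linear_schedule(
    init_value=base_lr,
    end_value=base_lr * num_devices,   # linear-scaling rule; sqrt scaling is the gentler option
    transition_steps=1_000,            # warmup steps (arbitrary)
)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),    # guards against the gradient-explosion pitfall below
    optax.adam(learning_rate=schedule),
)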
Common Pitfalls
| Pitfall | Symptom | Solution |
|---|---|---|
| OOM during backward | CUDA OOM error | Gradient checkpointing |
| Slow first step | 60s+ first iteration | Expected (XLA compilation) |
| Different results per run | Non-reproducible training | Set all RNG seeds |
| Linear scaling doesn't hold | Diminishing returns | Check communication overhead |
| Gradient explosion | NaN losses | Gradient clipping, lower LR |
Interview Questions
Conceptual Questions
Q1: Explain why CPU-GPU data transfer is the primary bottleneck in traditional RL training.
Expected Answer:
- RL requires many environment steps (millions to billions)
- Each step requires: obs → policy → action → env → obs
- If env is on CPU and policy on GPU, every step requires:
- Copy observation CPU → GPU (~10K ns)
- Run policy inference
- Copy action GPU → CPU (~10K ns)
- Environment step
- These copies dominate total time
- Solution: Run everything on GPU (PureJaxRL, Brax, Isaac Gym)
Q2: Compare IMPALA and SEED RL architectures. When would you choose each?
Expected Answer:
| Aspect | IMPALA | SEED RL |
|---|---|---|
| Inference location | Actors | Learner |
| Bandwidth | High (params to actors) | Low (obs/actions only) |
| Actor hardware | Need GPU/TPU | CPU sufficient |
| Latency | Lower | Higher (network round trip) |
| Best for | Many actors, GPU-rich | TPU pods, limited actor GPUs |
Q3: How does V-trace handle off-policy data in actor-learner architectures?
Expected Answer:
- Actors use stale policies (π_old)
- Learner trains with current policy (π)
- V-trace uses importance sampling to correct:
- ρ = π(a|s) / π_old(a|s)
- Truncate: ρ̄ = min(ρ, c)
- Corrects for policy drift without full importance weighting
- Enables stable training with highly off-policy data
Technical Questions
Q4: Design a distributed training system for 2.5B agent steps with fault tolerance.
Expected Answer:
Architecture:
1. Coordinator
- Tracks training progress
- Manages checkpoints
- Handles worker failures
2. Workers (8-32 GPUs)
- Data parallel training
- Local replay buffers
- Periodic sync with coordinator
3. Fault Tolerance
- Checkpoint every 10K steps
- Worker heartbeats
- On failure: restore from checkpoint, redistribute work
4. Performance Targets
- Assume ~800K agent steps/hour per GPU (in line with GPU-vectorized simulators such as V-Max)
- 2.5B steps / 800K steps/hour ≈ 3,125 GPU-hours (~130 days on one GPU)
- With 8 GPUs: ~390 hours = ~16 days
- With 32 GPUs: ~100 hours = ~4 days
Q5: Explain when to use gradient accumulation vs. larger batch sizes.
Expected Answer:
- Gradient accumulation: When batch doesn't fit in memory
- Simulates larger batch without memory increase
- Extra forward/backward passes (slower)
- Use when: memory-constrained, need large effective batch
- Larger batch sizes: When memory allows
- More efficient (single pass)
- Better GPU utilization
- Use when: have sufficient memory
- Learning rate adjustment: Both need LR scaling
- Linear scaling: LR *= batch_scale_factor
- Square root scaling: LR *= sqrt(batch_scale_factor)
Further Reading
Essential Papers
- "PureJaxRL" (2023) - Chris Lu et al.
- End-to-end JAX RL with 4000x speedup
- github.com/luchris429/purejaxrl
- "V-Max" (2025) - Valeo AI
- Production RL pipeline for Waymax
- arxiv.org/abs/2503.08388
- "IMPALA" (2018) - DeepMind
- Scalable distributed deep RL
- arxiv.org/abs/1802.01561
- "Brax" (2021) - Google
- Differentiable physics engine
- arxiv.org/abs/2106.13281
- "Isaac Gym" (2021) - NVIDIA
- High-performance GPU simulation
- arxiv.org/abs/2108.10470
Summary: Key Takeaways
- Scale matters exponentially - Larger models need proportionally more data; production requires billions of steps.
- End-to-end GPU training is transformative - Eliminating CPU-GPU copies enables 1000-4000x speedups.
- Actor-learner separation enables scale - IMPALA/SEED architectures maximize throughput while handling off-policy data.
- JAX primitives are powerful - pmap, vmap, scan, and sharding enable flexible parallelization strategies.
- Infrastructure is critical - Replay buffers, checkpointing, and profiling are essential for production systems.
- Async training trades convergence for throughput - Use V-trace or similar for off-policy correction.
- Scaling requires a systematic approach - Learning rate adjustment, communication profiling, and reproducibility checks are essential.
Last updated: January 2025