Deep Dive #8 | 55 min read

Distributed Training Deep Dive

Scaling RL to billions of steps: PureJaxRL, actor-learner architectures, and GPU-accelerated simulation infrastructure.

Distributed Training Deep Dive: Scaling RL for Autonomous Driving

Focus: Infrastructure for training at billions of environment steps
Key Papers: PureJaxRL, V-Max, IMPALA, Isaac Gym, Brax
Read Time: 55 min


Table of Contents

  1. Executive Summary
  2. The Scale Challenge
  3. Distributed Training Architectures
  4. JAX-Specific Distributed Training
  5. Key Systems and Papers
  6. Performance Optimization
  7. Infrastructure Components
  8. Practical Implementation
  9. Code Examples
  10. Interview Questions
  11. Further Reading

Executive Summary

The Core Problem

Training autonomous driving agents requires an astronomical number of environment interactions:

┌─────────────────────────────────────────────────────────────────────────┐
│                    SCALE OF AV TRAINING                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   Research Scale:        ~30 million agent steps                        │
│   Production Scale:      ~2.5 billion agent steps                       │
│   Scaling Factor:        ~83x                                           │
│                                                                          │
│   At 10 Hz simulation:                                                   │
│   • 30M steps = 35 days of simulated driving                            │
│   • 2.5B steps = 8 years of simulated driving                           │
│                                                                          │
│   Without acceleration, this is intractable.                            │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

The Solution: End-to-End GPU Training

Traditional RL pipelines run the simulator on the CPU and the policy network on the GPU, so every environment step pays for a CPU-GPU data transfer:

Traditional Approach (Slow):
┌─────────────┐     Copy      ┌─────────────┐
│   CPU       │ ◄──────────► │   GPU       │
│ Environment │   (10K+ ns)   │ Neural Net  │
└─────────────┘               └─────────────┘
        ↑                           ↑
        └───────────────────────────┘
          Bottleneck: Data transfer

End-to-End GPU (Fast):
┌─────────────────────────────────────────┐
│              GPU                         │
│  ┌─────────────┐    ┌─────────────┐    │
│  │ Environment │◄──►│ Neural Net  │    │
│  │   (JAX)     │    │   (JAX)     │    │
│  └─────────────┘    └─────────────┘    │
│         ↑                 ↑             │
│         └─────────────────┘             │
│           All on-device, no copy        │
└─────────────────────────────────────────┘

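Result: PureJaxRL reports speedups of up to ~4000x over traditional CPU-environment implementations, with the largest gains when many agents are trained in parallel on one GPU.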
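To make the end-to-end idea concrete, here is a minimal sketch of a JAX-native environment. It assumes a hypothetical 1-D point-mass `env_step` and a toy linear `policy` (neither comes from PureJaxRL); both are vmapped over thousands of environments and rolled out with `jax.lax.scan`, so simulation and inference stay on the accelerator with no host copies inside the loop.

```python
import functools

import jax
import jax.numpy as jnp

DT = 0.1  # simulation timestep in seconds (10 Hz); illustrative value

def env_step(state, action):
    """Hypothetical 1-D point mass: state = (position, velocity)."""
    pos, vel = state
    vel = vel + DT * action
    pos = pos + DT * vel
    reward = -jnp.abs(pos)  # reward for staying near the origin
    return (pos, vel), reward

def policy(params, state):
    """Toy linear feedback policy: action = -(k_p * pos + k_d * vel)."""
    pos, vel = state
    return -(params[0] * pos + params[1] * vel)

# vmap over the environment axis: one call advances every environment
batched_step = jax.vmap(env_step)
batched_policy = jax.vmap(policy, in_axes=(None, 0))

@functools.partial(jax.jit, static_argnames="num_steps")
def rollout(params, init_state, num_steps):
    def step(state, _):
        action = batched_policy(params, state)   # inference on device
        state, reward = batched_step(state, action)  # simulation on device
        return state, reward
    final_state, rewards = jax.lax.scan(step, init_state, None, length=num_steps)
    return final_state, rewards  # rewards: (num_steps, num_envs)

num_envs = 4096
key = jax.random.PRNGKey(0)
init_pos = jax.random.uniform(key, (num_envs,), minval=-1.0, maxval=1.0)
init_state = (init_pos, jnp.zeros(num_envs))
params = jnp.array([1.0, 1.5])
final_state, rewards = rollout(params, init_state, num_steps=100)
print(rewards.shape)  # (100, 4096)
```

The whole rollout compiles into a single XLA program; in a real pipeline the gradient update would be traced into the same jitted function.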


The Scale Challenge

Why Billions of Steps?

  1. Long-Tail Coverage: Rare scenarios require extensive sampling
  2. Scaling Laws: Larger models need proportionally more data
  3. Multi-Agent Complexity: N agents = O(N²) interaction space
  4. Generalization: Diverse scenarios prevent overfitting
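The O(N²) term comes from pairwise agent interactions. A minimal sketch, assuming agents are represented only by 2-D positions and a hypothetical fixed safety radius, builds the full N×N distance matrix in one vectorized JAX call; its memory and compute grow quadratically with the number of agents.

```python
import jax.numpy as jnp

def pairwise_conflicts(positions, radius=2.0):
    """Flag every agent pair closer than `radius` metres.

    positions: (N, 2) array of agent x/y positions (illustrative).
    Returns an (N, N) boolean matrix with the diagonal masked out.
    """
    diffs = positions[:, None, :] - positions[None, :, :]  # (N, N, 2): O(N^2)
    dists = jnp.linalg.norm(diffs, axis=-1)                 # (N, N)
    close = dists < radius
    return close & ~jnp.eye(positions.shape[0], dtype=bool)

positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
conflicts = pairwise_conflicts(positions)
print(int(conflicts.sum()) // 2)  # 1 conflicting pair (agents 0 and 1)
```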

Scaling Laws in AV RL

From "Scaling Is All You Need" (2024):

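Model Size  | Dataset (hours) | Failure Rate (vs. Small) | Notes
------------|-----------------|--------------------------|--------------------------------------------
Small (1M)  | 600             | Baseline                 | Quick to train
Medium (5M) | 600             | -15%                     | Better sample efficiency
Large (25M) | 600             | +10%                     | Data-limited (too little data for its size)
Large (25M) | 6000            | -64%                     | Best performance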

Key Insight: Larger models underperform on small datasets but excel with larger datasets. The intersection point determines optimal model-data balance.

Compute Requirements

# Rough compute estimates for AV RL training
compute_requirements = {
    'research_experiment': {
        'agent_steps': 30_000_000,
        'hardware': 'Single GPU (V100)',
        'training_time': '2-4 hours',
    },
    'medium_scale': {
        'agent_steps': 200_000_000,
        'hardware': '8x V100 GPUs',
        'training_time': '12-48 hours',
    },
    'production_scale': {
        'agent_steps': 2_500_000_000,
        'hardware': '32+ GPUs or TPU pod',
        'training_time': 'Days to weeks',
    },
}

Distributed Training Architectures

Data Parallelism

The dominant paradigm for RL training:

┌─────────────────────────────────────────────────────────────────────────┐
│                    DATA PARALLELISM                                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   Device 0           Device 1           Device 2           Device 3     │
│   ┌───────────┐      ┌───────────┐      ┌───────────┐      ┌───────────┐│
│   │ Full Model│      │ Full Model│      │ Full Model│      │ Full Model││
│   └─────┬─────┘      └─────┬─────┘      └─────┬─────┘      └─────┬─────┘│
│         │                  │                  │                  │      │
│         ▼                  ▼                  ▼                  ▼      │
│   ┌───────────┐      ┌───────────┐      ┌───────────┐      ┌───────────┐│
│   │ Data Shard│      │ Data Shard│      │ Data Shard│      │ Data Shard││
│   │    0      │      │    1      │      │    2      │      │    3      ││
│   └─────┬─────┘      └─────┬─────┘      └─────┬─────┘      └─────┬─────┘│
│         │                  │                  │                  │      │
│         ▼                  ▼                  ▼                  ▼      │
│   ┌───────────┐      ┌───────────┐      ┌───────────┐      ┌───────────┐│
│   │ Gradients │      │ Gradients │      │ Gradients │      │ Gradients ││
│   └─────┬─────┘      └─────┬─────┘      └─────┬─────┘      └─────┬─────┘│
│         │                  │                  │                  │      │
│         └──────────────────┴──────────────────┴──────────────────┘      │
│                                    │                                     │
│                                    ▼                                     │
│                            ┌─────────────┐                              │
│                            │  All-Reduce │                              │
│                            │  (Average)  │                              │
│                            └─────────────┘                              │
│                                    │                                     │
│                                    ▼                                     │
│                          Synchronized Update                             │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Advantages:

  • Simple to implement
  • Throughput scales near-linearly with device count until communication dominates
  • Well-supported in frameworks

Disadvantages:

  • Synchronization overhead
  • The full model must fit in a single device's memory
  • Stragglers slow the entire system
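The all-reduce step can be checked on a single machine. The sketch below assumes a toy linear-regression loss and uses `jax.vmap` with an `axis_name` to stand in for four devices: each "device" computes gradients on its own shard, `jax.lax.pmean` averages them, and the result matches the full-batch gradient because the shards are equal-sized and the loss is a mean.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    """Mean-squared error of a linear model (toy example)."""
    return jnp.mean((x @ w - y) ** 2)

def shard_grad(w, x_shard, y_shard):
    g = jax.grad(loss_fn)(w, x_shard, y_shard)
    # All-reduce (average) across the named 'devices' axis
    return jax.lax.pmean(g, axis_name="devices")

num_devices, per_device, dim = 4, 8, 3
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (num_devices * per_device, dim))
y = jax.random.normal(ky, (num_devices * per_device,))
w = jax.random.normal(kw, (dim,))

# Split the global batch into equal shards: (devices, per_device, ...)
x_shards = x.reshape(num_devices, per_device, dim)
y_shards = y.reshape(num_devices, per_device)

# vmap with axis_name simulates the devices; jax.pmap would use real ones
synced = jax.vmap(shard_grad, in_axes=(None, 0, 0), axis_name="devices")(
    w, x_shards, y_shards)
full = jax.grad(loss_fn)(w, x, y)
print(jnp.allclose(synced[0], full, atol=1e-4))  # True
```

Every row of `synced` is identical, which is exactly what lets each replica apply the same update and stay in sync.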

Actor-Learner Architecture (IMPALA)

Decouples experience collection from policy updates:

┌─────────────────────────────────────────────────────────────────────────┐
│                    IMPALA ARCHITECTURE                                   │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   ACTORS (Many)                           LEARNER (Centralized)         │
│   ┌─────────────┐                         ┌─────────────────────┐       │
│   │   Actor 0   │────────────────────────►│                     │       │
│   │ Env + Policy│     Trajectories        │    Policy Update    │       │
│   └─────────────┘                         │                     │       │
│   ┌─────────────┐                         │    • Compute loss   │       │
│   │   Actor 1   │────────────────────────►│    • Update params  │       │
│   │ Env + Policy│                         │    • V-trace        │       │
│   └─────────────┘                         │      correction     │       │
│   ┌─────────────┐                         │                     │       │
│   │   Actor N   │────────────────────────►│                     │       │
│   │ Env + Policy│                         └──────────┬──────────┘       │
│   └─────────────┘                                    │                  │
│         ▲                                            │                  │
│         └────────────────────────────────────────────┘                  │
│                        Updated Policy                                    │
│                                                                          │
│   V-trace Correction:                                                    │
│   Handles off-policy data from stale actor policies                     │
│   Importance sampling with truncated weights                            │
│                                                                          │
│   Performance: 250,000 frames/second (30x single-machine A3C)           │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
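V-trace value targets can be computed with a backward `jax.lax.scan`. The sketch below follows the recursive form from the IMPALA paper, v_s = V(x_s) + δ_s + γ c_s (v_{s+1} − V(x_{s+1})), with δ_s = ρ_s (r_s + γ V(x_{s+1}) − V(x_s)), ρ_s = min(ρ̄, π/μ) and c_s = λ min(c̄, π/μ). It assumes a single trajectory with no episode terminations; handling `done` flags (zeroing the discount) and batching are left out, and the function name is illustrative.

```python
import jax
import jax.numpy as jnp

def vtrace_targets(values, bootstrap_value, rewards, log_rhos,
                   gamma=0.99, rho_bar=1.0, c_bar=1.0, lam=1.0):
    """V-trace targets v_s for one trajectory of length T (no terminations).

    values:          V(x_s) for s = 0..T-1, shape (T,)
    bootstrap_value: V(x_T), scalar
    rewards:         r_s, shape (T,)
    log_rhos:        log pi(a_s|x_s) - log mu(a_s|x_s), shape (T,)
    """
    ratios = jnp.exp(log_rhos)
    rhos = jnp.minimum(rho_bar, ratios)     # truncated IS weights (TD error)
    cs = lam * jnp.minimum(c_bar, ratios)   # truncated trace coefficients
    next_values = jnp.append(values[1:], bootstrap_value)
    deltas = rhos * (rewards + gamma * next_values - values)

    def backward(acc, inputs):
        # acc holds v_{s+1} - V(x_{s+1}); it is zero past the end (v_T = V(x_T))
        delta, c = inputs
        acc = delta + gamma * c * acc
        return acc, acc

    init = jnp.zeros((), dtype=values.dtype)
    _, vs_minus_v = jax.lax.scan(backward, init, (deltas, cs), reverse=True)
    return values + vs_minus_v

T = 5
values = jnp.zeros(T)
rewards = jnp.ones(T)
vs = vtrace_targets(values, jnp.array(0.0), rewards, jnp.zeros(T), gamma=0.9)
print(vs[0])  # ≈ 4.0951 = 1 + 0.9 + 0.81 + 0.729 + 0.6561 (on-policy, zero values)
```

With π = μ (all log-ratios zero) the targets reduce to ordinary bootstrapped discounted returns; when π/μ exceeds the thresholds, truncation caps the correction instead of letting the weights explode.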

SEED RL: Centralized Inference

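SEED RL addresses a limitation of IMPALA: there, every actor runs policy inference on its own copy of the model, so parameters must be shipped to every actor and inference cannot be batched across actors: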

┌─────────────────────────────────────────────────────────────────────────┐
│                    SEED RL ARCHITECTURE                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   ACTORS (CPU)                            LEARNER (TPU)                 │
│   ┌─────────────┐                         ┌─────────────────────┐       │
│   │   Actor 0   │                         │                     │       │
│   │  Env Only   │◄───────────────────────►│  Policy Inference   │       │
│   └─────────────┘   Observations/Actions  │         +           │       │
│   ┌─────────────┐   (Streaming gRPC)      │  Policy Training    │       │
│   │   Actor 1   │◄───────────────────────►│                     │       │
│   │  Env Only   │                         │  All NN compute     │       │
│   └─────────────┘                         │  on accelerator     │       │
│   ┌─────────────┐                         │                     │       │
│   │   Actor N   │◄───────────────────────►│                     │       │
│   │  Env Only   │                         └─────────────────────┘       │
│   └─────────────┘                                                        │
│                                                                          │
│   Benefits:                                                              │
│   • No model parameters sent to actors                                  │
│   • Efficient batched inference                                         │
│   • Lower bandwidth requirements                                         │
│   • Better TPU utilization                                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

Async vs. Sync Training

Aspect             | Synchronous           | Asynchronous
-------------------|-----------------------|---------------
Convergence        | Stable, predictable   | May oscillate
Throughput         | Limited by stragglers | Higher
Gradient freshness | Always current        | May be stale
Implementation     | Simpler               | More complex
Best for           | Small clusters        | Large clusters

Hybrid Approach (Stale Synchronous Parallel):

  • Allow bounded staleness (e.g., max 2 iterations behind)
  • Balance throughput and convergence
  • Common in production systems
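The bounded-staleness rule itself is simple bookkeeping. A minimal pure-Python sketch, assuming each worker reports an integer clock (completed iterations) to a coordinator; the `SSPClock` class and its methods are hypothetical names. A worker may start its next iteration only while it is at most `staleness` iterations ahead of the slowest worker.

```python
class SSPClock:
    """Per-worker iteration counters for stale synchronous parallel (SSP)."""

    def __init__(self, num_workers: int, staleness: int):
        self.clocks = [0] * num_workers  # completed iterations per worker
        self.staleness = staleness

    def can_proceed(self, worker: int) -> bool:
        # Start the next iteration only while at most `staleness` ahead of the slowest
        return self.clocks[worker] - min(self.clocks) <= self.staleness

    def tick(self, worker: int) -> None:
        """Advance `worker` by one iteration; raises if the bound forbids it."""
        if not self.can_proceed(worker):
            raise RuntimeError(f"worker {worker} must wait for stragglers")
        self.clocks[worker] += 1


clock = SSPClock(num_workers=3, staleness=2)
for _ in range(3):
    clock.tick(0)              # worker 0 races ahead: clocks = [3, 0, 0]
print(clock.can_proceed(0))    # False: 3 - 0 > 2, so worker 0 blocks
clock.tick(1)
clock.tick(2)                  # stragglers advance: clocks = [3, 1, 1]
print(clock.can_proceed(0))    # True: 3 - 1 <= 2
```

Setting `staleness=0` recovers fully synchronous training: a worker that finishes an iteration waits until every other worker has caught up.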

JAX-Specific Distributed Training

jax.pmap (Parallel Map)

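import functools

import jax
import jax.numpy as jnp

# Assumes `model` (a Flax module), `params`, `lr`, `num_epochs`, and a
# `dataloader` yielding dicts of arrays are defined elsewhere.
num_devices = jax.local_device_count()

# Replicate model across devices: add a leading device axis to every leaf
replicated_params = jax.tree.map(
    lambda x: jnp.broadcast_to(x, (num_devices,) + x.shape),
    params
)

# axis_name must be declared here for the pmean collectives below to resolve
@functools.partial(jax.pmap, axis_name='batch')
def train_step(params, batch):
    """Run on each device with its local batch."""

    def loss_fn(p):
        predictions = model.apply(p, batch['observations'])
        return jnp.mean((predictions - batch['targets']) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)

    # Average gradients (and the loss, for logging) across all devices
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')

    # Plain SGD update; identical on all devices because grads are synced
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

    return params, loss

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        # Split batch across devices: (num_devices, per_device_batch, ...);
        # the global batch size must be divisible by num_devices
        batch = jax.tree.map(
            lambda x: x.reshape(num_devices, -1, *x.shape[1:]),
            batch
        )
        replicated_params, loss = train_step(replicated_params, batch)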

Device Mesh and Sharding

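import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Create 2D mesh: (data_parallel, model_parallel). Requires 8 devices; on CPU
# they can be simulated with XLA_FLAGS=--xla_force_host_platform_device_count=8
devices = np.array(jax.devices()).reshape(2, 4)  # 8 devices -> 2x4 grid
mesh = Mesh(devices, axis_names=('data', 'model'))

# Example 2-D weight matrix and batch (dimensions must divide the mesh axes).
# For a full parameter pytree, build one sharding per leaf instead.
weights = jnp.zeros((512, 1024))
batch = jnp.ones((256, 512))

# Shard the weight's output dimension across the 'model' axis
param_sharding = NamedSharding(mesh, P(None, 'model'))
sharded_weights = jax.device_put(weights, param_sharding)

# Shard the batch dimension across the 'data' axis
data_sharding = NamedSharding(mesh, P('data', None))
sharded_batch = jax.device_put(batch, data_sharding)

# JIT-compiled function respects input shardings; XLA inserts any communication
@jax.jit
def forward(weights, batch):
    return batch @ weights

output = forward(sharded_weights, sharded_batch)
# Output sharding inferred automatically (typically P('data', 'model') here)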

Collective Operations

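import functools

import jax
import jax.numpy as jnp
import jax.lax as lax

n = jax.local_device_count()  # size of the mapped 'batch' axis on this host

@functools.partial(jax.pmap, axis_name='batch')
def collectives_demo(local_value):
    # local_value: this device's slice, shape (n,), so psum_scatter can split it

    # All-reduce: Sum across devices, result on all devices
    total = lax.psum(local_value, axis_name='batch')

    # All-reduce mean: Average across devices
    mean = lax.pmean(local_value, axis_name='batch')

    # All-gather: Collect slices from all devices -> shape (n, n) on each device
    full_data = lax.all_gather(local_value, axis_name='batch')

    # Reduce-scatter: Sum across devices, then device i keeps element i
    scattered = lax.psum_scatter(local_value, axis_name='batch')

    # Ring permutation: device i sends its data to device (i + 1) % n
    shifted = lax.ppermute(
        local_value,
        axis_name='batch',
        perm=[(i, (i + 1) % n) for i in range(n)]
    )
    return total, mean, full_data, scattered, shifted

outputs = collectives_demo(jnp.arange(n * n, dtype=jnp.float32).reshape(n, n))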

TPU Pod Patterns

import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

# Multi-host initialization (once per process, before any other JAX computation)
jax.distributed.initialize()

print(f"Global devices: {jax.device_count()}")
print(f"Local devices: {jax.local_device_count()}")
print(f"Process index: {jax.process_index()}")

# 1-D mesh spanning the devices of every host
global_mesh = Mesh(
    np.array(jax.devices()),
    axis_names=('data',)
)

# Each host loads its own shard (load_data_for_host is a hypothetical helper)
local_data = load_data_for_host(jax.process_index())

# Combine the per-host shards into one global array sharded along 'data'
global_array = multihost_utils.host_local_array_to_global_array(
    local_data,
    global_mesh,
    P('data')
)

Key Systems and Papers

PureJaxRL

Breakthrough: Implements entire RL pipeline in JAX, enabling full JIT compilation.

# PureJaxRL-style end-to-end training (structural sketch).
# Assumes a JAX-native, vmapped `env` (get_obs/step), a `policy_network`
# with a `sample` method, and user-supplied pure functions `compute_gae`,
# `create_minibatches`, and `ppo_update` (all hypothetical here).
import functools

import jax


class SimplePPO:
    def __init__(self, env, policy_network, value_network,
                 rollout_length=128, num_minibatches=4, ppo_epochs=4):
        self.env = env  # JAX-native environment!
        self.policy = policy_network
        self.value = value_network
        self.rollout_length = rollout_length
        self.num_minibatches = num_minibatches
        self.ppo_epochs = ppo_epochs

    # `self` only holds static configuration, so mark it static for jit
    @functools.partial(jax.jit, static_argnums=0)
    def train_step(self, state, params, opt_state, key):
        """Entire training step is JIT-compiled."""
        rollout_key, update_key = jax.random.split(key)

        # Collect rollouts: scan steps through time, while the vmapped env
        # advances all parallel environments at each step
        def rollout_step(carry, _):
            env_state, key = carry
            key, subkey = jax.random.split(key)

            # Policy inference
            obs = self.env.get_obs(env_state)
            action, log_prob = self.policy.sample(params, obs, subkey)

            # Environment step (JAX-native!)
            env_state, reward, done = self.env.step(env_state, action)

            return (env_state, key), (obs, action, reward, log_prob, done)

        (env_state, _), trajectories = jax.lax.scan(
            rollout_step,
            (state.env_state, rollout_key),
            None,
            length=self.rollout_length
        )

        # Compute advantages and returns
        advantages, returns = compute_gae(trajectories, params)

        # PPO update (multiple epochs over collected data)
        def ppo_epoch(carry, _):
            params, opt_state, key = carry
            key, subkey = jax.random.split(key)

            # Shuffle and stack minibatches along a leading axis
            # of size num_minibatches
            minibatches = create_minibatches(
                subkey, trajectories, advantages, returns, self.num_minibatches
            )

            # Inner scan replaces a Python for-loop over minibatches
            def minibatch_step(carry, mb):
                params, opt_state = carry
                params, opt_state, loss = ppo_update(params, opt_state, mb)
                return (params, opt_state), loss

            (params, opt_state), losses = jax.lax.scan(
                minibatch_step, (params, opt_state), minibatches
            )
            return (params, opt_state, key), losses.mean()

        (params, opt_state, _), losses = jax.lax.scan(
            ppo_epoch,
            (params, opt_state, update_key),
            None,
            length=self.ppo_epochs
        )

        return state._replace(env_state=env_state), params, opt_state, losses

Performance:

  • CartPole: 1M frames in 0.05s (vs 46s traditional)
  • Speedup: ~1000x for a single training run (46s / 0.05s ≈ 920x)
  • With vectorization across many parallel training runs: ~4000x total
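A common way to get this extra factor is to vectorize whole training runs, for example over random seeds, with `jax.vmap`. A minimal sketch of the pattern, assuming a toy `train` function (a few SGD steps on a hypothetical quadratic objective with seed-dependent noise) in place of a real PPO loop:

```python
import jax
import jax.numpy as jnp

def train(key, num_steps=50, lr=0.1):
    """Toy stand-in for a full training run: noisy SGD toward w = 3.0."""
    init_key, noise_key = jax.random.split(key)
    w = jax.random.normal(init_key, ())
    noise = 0.01 * jax.random.normal(noise_key, (num_steps,))

    def sgd_step(w, eps):
        grad = 2.0 * (w - 3.0) + eps  # gradient of (w - 3)^2 plus noise
        return w - lr * grad, None

    w, _ = jax.lax.scan(sgd_step, w, noise)
    return w

# One compiled program trains 256 independent runs in parallel
keys = jax.random.split(jax.random.PRNGKey(0), 256)
final_weights = jax.jit(jax.vmap(train))(keys)
print(final_weights.shape)  # (256,)
```

Because each run is a pure function of its key, vmapping over keys turns many small, launch-bound jobs into one large batched computation that keeps the GPU busy.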

V-Max

Production-grade RL pipeline for Waymax:

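# V-Max training setup (schematic: names such as ScenarioMax and SACConfig,
# and the variables config, total_steps and log_interval, are illustrative
# and not imported here; check the V-Max repository for its actual API)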
from vmax import VMaxEnv, VmapWrapper, AutoResetWrapper
from vmax.algorithms import SAC

# Load scenarios from multiple datasets
scenarios = ScenarioMax.load(
    datasets=['womd', 'nuplan', 'argoverse'],
    split='train'
)

# Create vectorized environment
base_env = VMaxEnv(scenarios, config)
env = VmapWrapper(AutoResetWrapper(base_env), num_envs=256)

# Configure SAC
sac_config = SACConfig(
    learning_rate=3e-4,
    discount=0.99,
    tau=0.005,
    batch_size=256,
    replay_buffer_size=1_000_000,
)

# Train
agent = SAC(env, sac_config)
for step in range(total_steps):
    metrics = agent.train_step()
    if step % log_interval == 0:
        print(f"Step {step}: {metrics}")

Results:

  • Single L4 GPU: 25M steps in 36 hours
  • SAC achieves 97.44% accuracy, 1.74% collision rate

Isaac Gym / Isaac Lab

GPU-accelerated physics simulation:

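# Isaac Lab multi-GPU setup (schematic: SpotEnvCfg, Runner and PPOAgent stand
# in for a task config and an RL-library runner; Isaac Lab's actual
# multi-GPU launch path may differ, so treat these names as placeholders)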
from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.utils.runner import Runner

# Configuration
env_cfg = SpotEnvCfg(
    num_envs=4096,  # Per GPU
    device="cuda",
)

# Multi-GPU training
runner = Runner(
    env=env_cfg,
    agent=PPOAgent,
    num_gpus=4,  # Distribute across GPUs
)

runner.train(total_timesteps=10_000_000)
# ~90,000 FPS on RTX A6000

Brax

Differentiable physics for JAX:

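from brax import envs
# Note: this imports the ppo *module*; the training function is ppo.train
from brax.training.agents.ppo import train as ppo

# Create the (unbatched) environment; ppo.train vectorizes it via num_envs
env = envs.get_environment('ant')

# Train with PPO
make_inference_fn, params, metrics = ppo.train(
    environment=env,
    num_timesteps=10_000_000,
    episode_length=1000,
    num_envs=2048,
    batch_size=1024,       # minibatch size
    num_minibatches=32,    # batch_size * num_minibatches must be divisible by num_envs
    learning_rate=3e-4,
    # Runs entirely on GPU/TPU
)

# Brax reports millions of physics steps per second on TPU,
# and scales further across multiple devices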

Performance Optimization

GPU Memory Optimization

import jax
import jax.numpy as jnp

# 1. Gradient Checkpointing (rematerialization)
# jax.checkpoint (alias jax.remat) discards intermediate activations on the
# forward pass and recomputes them during the backward pass.
@jax.checkpoint
def memory_efficient_layer(params, x):
    x = jnp.dot(x, params['w1']) + params['b1']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['w2']) + params['b2']
    return x

# 2. Mixed Precision (assumes `model` is a Flax module defined elsewhere)
def forward_mixed_precision(params, x):
    # Convert inputs and a copy of the weights to bfloat16 for compute;
    # the float32 master weights are left untouched
    x = x.astype(jnp.bfloat16)
    params_bf16 = jax.tree.map(lambda p: p.astype(jnp.bfloat16), params)

    # Compute in bfloat16
    output = model.apply(params_bf16, x)

    # Convert back for a numerically stable loss computation
    return output.astype(jnp.float32)

# 3. Gradient Accumulation
# Assumes loss_fn(params, batch) -> scalar and a plain SGD update with step size lr
def train_with_accumulation(params, batches, loss_fn, lr, accumulation_steps):
    accumulated_grads = jax.tree.map(jnp.zeros_like, params)
    grad_fn = jax.grad(loss_fn)

    for batch in batches[:accumulation_steps]:
        grads = grad_fn(params, batch)
        # Running average of the micro-batch gradients
        accumulated_grads = jax.tree.map(
            lambda a, g: a + g / accumulation_steps,
            accumulated_grads, grads
        )

    # Single update with accumulated gradients
    params = jax.tree.map(lambda p, g: p - lr * g, params, accumulated_grads)
    return params

Profiling and Bottleneck Identification

import jax

# Start profiler
jax.profiler.start_trace("/tmp/tensorboard")

# Run training (assumes train_step, params, and data_iter are defined)
for step in range(100):
    params, metrics = train_step(params, next(data_iter))

# JAX dispatches asynchronously: wait for queued device work so it lands
# inside the trace, then stop the profiler
jax.block_until_ready(params)
jax.profiler.stop_trace()

# View in TensorBoard:
# tensorboard --logdir=/tmp/tensorboard

# Equivalent context-manager form
with jax.profiler.trace("/tmp/jax-trace"):
    result = train_step(params, batch)
    # result is a tuple, so block on the whole pytree
    jax.block_until_ready(result)

# Key things to look for:
# 1. XLA compilation time (first step)
# 2. Device-to-host copies (should be minimal)
# 3. AllReduce duration (communication overhead)
# 4. Idle gaps between kernels (host-side dispatch overhead or input-pipeline stalls)
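Before opening a full trace, a quick wall-clock check separates compilation from steady-state execution. A minimal sketch, assuming a hypothetical jitted `step` workload; because JAX dispatches work asynchronously, each measurement calls `jax.block_until_ready` so the timer covers the actual device work.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Hypothetical workload standing in for a training step
    return jnp.tanh(x @ x.T).sum()

def timed(fn, *args):
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start

x = jnp.ones((512, 512))
out, first = timed(step, x)   # includes tracing + XLA compilation
out, second = timed(step, x)  # cached executable: execution only
print(f"first call: {first:.3f}s, second call: {second:.3f}s")
```

A large gap between the two numbers is compilation; if later steps are also slow, or recompile unexpectedly (for example because input shapes keep changing), that is where the profiler trace becomes worth reading.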

Infrastructure Components

Replay Buffers at Scale

import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ReplayBuffer(NamedTuple):
    observations: jax.Array
    actions: jax.Array
    rewards: jax.Array
    next_observations: jax.Array
    dones: jax.Array
    priorities: jax.Array  # For prioritized replay
    position: jax.Array    # Next write index (int32 scalar)
    size: jax.Array        # Number of valid entries (int32 scalar)

def create_replay_buffer(capacity: int, obs_shape: tuple, action_shape: tuple):
    """Create pre-allocated replay buffer on device."""
    return ReplayBuffer(
        observations=jnp.zeros((capacity,) + obs_shape),
        actions=jnp.zeros((capacity,) + action_shape),
        rewards=jnp.zeros(capacity),
        next_observations=jnp.zeros((capacity,) + obs_shape),
        dones=jnp.zeros(capacity, dtype=bool),
        priorities=jnp.ones(capacity),
        # Arrays rather than Python ints, so jit sees the same types every call
        position=jnp.array(0, dtype=jnp.int32),
        size=jnp.array(0, dtype=jnp.int32),
    )

@jax.jit
def add_transition(buffer: ReplayBuffer, transition: dict) -> ReplayBuffer:
    """Add transition to buffer (JIT-compiled), overwriting the oldest when full."""
    pos = buffer.position
    capacity = buffer.observations.shape[0]  # static under jit

    # New transitions get the current max priority so they are sampled soon
    valid = jnp.arange(capacity) < buffer.size
    max_priority = jnp.where(
        buffer.size > 0, jnp.max(jnp.where(valid, buffer.priorities, 0.0)), 1.0
    )

    new_buffer = buffer._replace(
        observations=buffer.observations.at[pos].set(transition['obs']),
        actions=buffer.actions.at[pos].set(transition['action']),
        rewards=buffer.rewards.at[pos].set(transition['reward']),
        next_observations=buffer.next_observations.at[pos].set(transition['next_obs']),
        dones=buffer.dones.at[pos].set(transition['done']),
        priorities=buffer.priorities.at[pos].set(max_priority),
        position=(pos + 1) % capacity,
        size=jnp.minimum(buffer.size + 1, capacity),
    )

    return new_buffer

# batch_size determines output shapes, so it must be static under jit
@functools.partial(jax.jit, static_argnames='batch_size')
def sample_batch(buffer: ReplayBuffer, key: jax.Array, batch_size: int):
    """Sample batch with simplified prioritized replay (probability ∝ priority).

    Omits the usual priority exponent alpha and importance-sampling weights.
    Call only after at least one transition has been added.
    """
    capacity = buffer.observations.shape[0]

    # Mask unfilled slots instead of slicing: jit requires static shapes
    valid = jnp.arange(capacity) < buffer.size
    masked = jnp.where(valid, buffer.priorities, 0.0)
    probs = masked / masked.sum()

    # Sample indices (unfilled slots have zero probability)
    indices = jax.random.choice(key, capacity, shape=(batch_size,), p=probs)

    return {
        'observations': buffer.observations[indices],
        'actions': buffer.actions[indices],
        'rewards': buffer.rewards[indices],
        'next_observations': buffer.next_observations[indices],
        'dones': buffer.dones[indices],
    }

Checkpointing with Orbax

import jax
import orbax.checkpoint as ocp

# Create a checkpoint manager (assumes a recent orbax-checkpoint release with
# the `ocp.args` API). It saves asynchronously by default: save() returns
# quickly and the write continues in the background.
mngr = ocp.CheckpointManager(
    "/checkpoints",
    options=ocp.CheckpointManagerOptions(max_to_keep=3),
)

def save_checkpoint(step, params, opt_state):
    # The step number is tracked by the manager itself
    ckpt = {
        'params': params,
        'opt_state': opt_state,
    }
    mngr.save(step, args=ocp.args.StandardSave(ckpt))

# Restore checkpoint
def restore_checkpoint(step, params, opt_state):
    # Abstract target built from the live state (shapes, dtypes, shardings),
    # so arrays are restored directly with the correct sharding
    target = jax.tree.map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding),
        {'params': params, 'opt_state': opt_state},
    )
    return mngr.restore(step, args=ocp.args.StandardRestore(target))

# Best practice: checkpoint periodically
for step in range(total_steps):
    params, opt_state, metrics = train_step(params, opt_state, batch)

    if step % checkpoint_interval == 0:
        save_checkpoint(step, params, opt_state)

# Block until any in-flight async save has finished before exiting
mngr.wait_until_finished()

Practical Implementation

Single GPU to Multi-GPU Migration

import functools

import jax
import optax

# Assumes create_env, init_params, loss_fn, collect_rollout, update,
# reshape_for_pmap, total_steps, and an Optax `optimizer` are defined
# elsewhere (hypothetical helpers).

# Step 1: Single GPU baseline
def train_single_gpu():
    env = create_env(num_envs=32)
    params = init_params()
    opt_state = optimizer.init(params)

    for step in range(total_steps):
        batch = collect_rollout(env, params)
        params, opt_state = update(params, opt_state, batch)

# Step 2: Add data parallelism
def train_multi_gpu():
    num_devices = jax.local_device_count()

    # Scale environments per device
    env = create_env(num_envs=32 * num_devices)

    # Initialize once, then replicate params and optimizer state
    params = init_params()
    opt_state = optimizer.init(params)
    params = jax.device_put_replicated(params, jax.local_devices())
    opt_state = jax.device_put_replicated(opt_state, jax.local_devices())

    @functools.partial(jax.pmap, axis_name='batch')
    def parallel_update(params, opt_state, batch):
        grads = jax.grad(loss_fn)(params, batch)
        grads = jax.lax.pmean(grads, axis_name='batch')  # Average across devices
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    for step in range(total_steps):
        # Collect across all environments with one unreplicated copy of params
        single_params = jax.tree.map(lambda x: x[0], params)
        batch = collect_rollout(env, single_params)

        # Reshape for pmap: (devices, batch_per_device, ...)
        batch = reshape_for_pmap(batch, num_devices)

        params, opt_state = parallel_update(params, opt_state, batch)

Scaling Checklist

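from math import sqrt

# Before scaling, verify (the helper functions below are hypothetical placeholders):
scaling_checklist = {
    # Learning rate scaling: sqrt scaling shown; linear scaling
    # (base_lr * num_devices) is the other common heuristic
    'learning_rate': base_lr * sqrt(num_devices),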

    # Batch size scaling
    'effective_batch_size': base_batch * num_devices,

    # Verify convergence matches single GPU
    'convergence_test': compare_loss_curves(single_gpu, multi_gpu),

    # Check for communication bottlenecks
    'comm_overhead': measure_allreduce_time(),

    # Memory per device
    'memory_usage': profile_memory_per_device(),

    # Determinism
    'reproducibility': verify_same_results_given_seed(),
}
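The two learning-rate heuristics in the checklist fit in one helper. This is a sketch of the rules of thumb only (the function name and signature are illustrative); neither rule guarantees unchanged convergence, which is why the checklist also compares loss curves.

```python
import math

def scale_learning_rate(base_lr: float, base_batch: int, new_batch: int,
                        rule: str = "sqrt") -> float:
    """Rescale a tuned learning rate when the effective batch size changes."""
    k = new_batch / base_batch
    if rule == "linear":
        return base_lr * k             # LR grows in proportion to batch size
    if rule == "sqrt":
        return base_lr * math.sqrt(k)  # more conservative for large k
    raise ValueError(f"unknown rule: {rule}")

# Going from 1 GPU (batch 256) to 8 GPUs (effective batch 2048):
print(scale_learning_rate(3e-4, 256, 2048, "linear"))  # ≈ 0.0024
print(scale_learning_rate(3e-4, 256, 2048, "sqrt"))    # ≈ 0.000849
```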

Common Pitfalls

Pitfall                     | Symptom                   | Solution
----------------------------|---------------------------|------------------------------
OOM during backward         | CUDA OOM error            | Gradient checkpointing
Slow first step             | 60s+ first iteration      | Expected (XLA compilation)
Different results per run   | Non-reproducible training | Set all RNG seeds
Linear scaling doesn't hold | Diminishing returns       | Check communication overhead
Gradient explosion          | NaN losses                | Gradient clipping, lower LR
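For the gradient-explosion row, clipping by global norm rescales the whole gradient pytree whenever its combined L2 norm exceeds a threshold. A minimal pure-JAX sketch (the function name and the threshold of 1.0 are illustrative); Optax ships the same idea as the `optax.clip_by_global_norm` transformation, which can be chained in front of an optimizer.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm):
    """Rescale a gradient pytree so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale factor is 1 when the gradient is already within the bound
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

# A deliberately huge gradient, as often seen right before a NaN loss
grads = {"w": jnp.array([1e6, -1e6, 1e6]), "b": jnp.array(1e6)}
clipped = clip_by_global_norm(grads, max_norm=1.0)
norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(clipped)))
print(norm)  # ≈ 1.0
```

Scaling every leaf by the same factor preserves the gradient's direction, unlike per-element clipping.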

Interview Questions

Conceptual Questions

Q1: Explain why CPU-GPU data transfer is the primary bottleneck in traditional RL training.

Expected Answer:

  • RL requires many environment steps (millions to billions)
  • Each step requires: obs → policy → action → env → obs
  • If env is on CPU and policy on GPU, every step requires:
    • Copy observation CPU → GPU (~10K ns)
    • Run policy inference
    • Copy action GPU → CPU (~10K ns)
    • Environment step
  • For small policy networks and cheap environments, these per-step copies (plus the synchronization they force) dominate total time
  • Solution: Run everything on GPU (PureJaxRL, Brax, Isaac Gym)

Q2: Compare IMPALA and SEED RL architectures. When would you choose each?

Expected Answer:

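Aspect             | IMPALA                     | SEED RL
-------------------|----------------------------|-----------------------------
Inference location | Actors                     | Learner
Bandwidth          | High (params to actors)    | Low (obs/actions only)
Actor hardware     | Runs inference (CPU or GPU)| CPU sufficient
Latency            | Lower                      | Higher (network round trip)
Best for           | Many actors, GPU-rich      | TPU pods, limited actor GPUs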

Q3: How does V-trace handle off-policy data in actor-learner architectures?

Expected Answer:

  • Actors use stale policies (π_old)
  • Learner trains with current policy (π)
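  • V-trace uses truncated importance sampling to correct:
    • ρ = π(a|s) / π_old(a|s)
    • Truncate with two thresholds: ρ_t = min(ρ̄, ρ) weights the TD error, c_t = min(c̄, ρ) weights the traces
    • ρ̄ sets which value function V-trace converges to; c̄ mainly affects convergence speed
  • Corrects for policy drift without the high variance of untruncated importance weights
  • Keeps training stable when actor policies lag the learner by a few updates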

Technical Questions

Q4: Design a distributed training system for 2.5B agent steps with fault tolerance.

Expected Answer:

Architecture:
1. Coordinator
   - Tracks training progress
   - Manages checkpoints
   - Handles worker failures

2. Workers (8-32 GPUs)
   - Data parallel training
   - Local replay buffers
   - Periodic sync with coordinator

3. Fault Tolerance
   - Checkpoint every 10K steps
   - Worker heartbeats
   - On failure: restore from checkpoint, redistribute work

4. Performance Targets
   - Assume ~800K steps/hour per GPU (roughly V-Max's 25M steps in 36 hours on one L4): 2.5B / 800K ≈ 3125 GPU-hours
   - With 8 GPUs (linear scaling): ~390 hours ≈ 16 days
   - With 32 GPUs (linear scaling): ~100 hours ≈ 4 days
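This kind of estimate is easy to parameterize. A sketch, assuming a per-GPU throughput of 800K steps/hour (roughly V-Max's reported 25M steps in 36 hours on one L4) and adding a hypothetical scaling-efficiency factor, since communication overhead means real clusters rarely scale perfectly linearly:

```python
def training_days(total_steps: float, steps_per_gpu_hour: float,
                  num_gpus: int, scaling_efficiency: float = 1.0) -> float:
    """Wall-clock days to reach total_steps on num_gpus GPUs."""
    effective_rate = steps_per_gpu_hour * num_gpus * scaling_efficiency
    return total_steps / effective_rate / 24.0

for gpus, eff in [(1, 1.0), (8, 1.0), (32, 1.0), (32, 0.8)]:
    days = training_days(2.5e9, 800_000, gpus, eff)
    print(f"{gpus:>2} GPUs @ {eff:.0%} efficiency: {days:6.1f} days")
```

With these inputs, one GPU needs about 130 days, 8 GPUs about 16.3 days, and 32 GPUs about 4.1 days, rising to about 5.1 days at 80% scaling efficiency.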

Q5: Explain when to use gradient accumulation vs. larger batch sizes.

Expected Answer:

  • Gradient accumulation: When batch doesn't fit in memory

    • Simulates larger batch without memory increase
    • Same total compute, but micro-batches run sequentially (less parallelism, slower)
    • Use when: memory-constrained, need large effective batch
  • Larger batch sizes: When memory allows

    • More efficient (one large pass instead of several small ones)
    • Better GPU utilization
    • Use when: have sufficient memory
  • Learning rate adjustment: Both change the effective batch size, so the LR usually needs rescaling

    • Linear scaling: LR *= batch_scale_factor
    • Square root scaling: LR *= sqrt(batch_scale_factor)
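To see why accumulation "simulates a larger batch", the sketch below assumes a toy mean-squared-error loss and equal-sized micro-batches. It accumulates gradients with `jax.lax.scan`, so the loop compiles into one program and each iteration only needs one micro-batch's activations, then checks that the averaged result equals the full-batch gradient.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    """Mean-squared error of a linear model (toy example)."""
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def accumulated_grad(w, x_micro, y_micro):
    """Average gradients over micro-batches stacked on the leading axis."""
    num_micro = x_micro.shape[0]  # static under jit

    def body(acc, batch):
        x, y = batch
        g = jax.grad(loss_fn)(w, x, y)
        return acc + g / num_micro, None

    acc, _ = jax.lax.scan(body, jnp.zeros_like(w), (x_micro, y_micro))
    return acc

kx, ky, kw = jax.random.split(jax.random.PRNGKey(1), 3)
x = jax.random.normal(kx, (64, 5))
y = jax.random.normal(ky, (64,))
w = jax.random.normal(kw, (5,))

# 4 micro-batches of 16 examples vs. one batch of 64
g_acc = accumulated_grad(w, x.reshape(4, 16, 5), y.reshape(4, 16))
g_full = jax.grad(loss_fn)(w, x, y)
print(jnp.allclose(g_acc, g_full, atol=1e-4))  # True
```

The equality holds because the loss is a mean and the micro-batches are the same size; with unequal micro-batches, each gradient would need to be weighted by its share of the examples.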

Further Reading

Essential Papers

  1. "PureJaxRL" (2023) - Chris Lu et al.

  2. "V-Max" (2025) - Valeo AI

  3. "IMPALA" (2018) - DeepMind

  4. "Brax" (2021) - Google

  5. "Isaac Gym" (2021) - NVIDIA



Summary: Key Takeaways

  1. Scale matters - Larger models need proportionally more data; production requires billions of steps.

  2. End-to-end GPU training is transformative - Eliminating CPU-GPU copies enables 1000-4000x speedups.

  3. Actor-learner separation enables scale - IMPALA/SEED architectures maximize throughput while handling off-policy data.

  4. JAX primitives are powerful - pmap, vmap, scan, and sharding enable flexible parallelization strategies.

  5. Infrastructure is critical - Replay buffers, checkpointing, and profiling are essential for production systems.

  6. Async training trades some convergence stability for throughput - Use V-trace or similar for off-policy correction.

  7. Scaling requires systematic approach - Learning rate adjustment, communication profiling, and reproducibility checks are essential.


Last updated: January 2025