BehaviorGPT: A Deep Dive into the WOSAC 2024 Champion
Paper: BehaviorGPT: Smart Agent Simulation for Autonomous Driving with Next-Patch Prediction
Authors: Zikang Zhou et al. (City University of Hong Kong and collaborators)
arXiv: 2405.17372
Achievement: 1st Place, WOSAC (Waymo Open Sim Agents Challenge) 2024
Table of Contents
- Executive Summary
- Architecture Deep Dive
- Key Innovation: Next-Patch Prediction (NP3)
- Relative Spacetime Representation
- Interactive Code Examples
- Training Strategy
- WOSAC Metrics Breakdown
- Mental Models
- Hands-On Exercises
- Interview Questions
1. Executive Summary
The David vs Goliath Story
BehaviorGPT won the 2024 WOSAC challenge with a remarkable result: just 3 million parameters outperformed competitors with 10 to 523 million parameters. It is the compact car that beat a field of supercars.
What Made It Win?
+------------------------------------------------------------------+
| BehaviorGPT's Winning Formula |
+------------------------------------------------------------------+
| |
| 1. NEXT-PATCH PREDICTION (NP3) |
| - Predict trajectories in 1-second chunks, not frame-by-frame|
| - Forces the model to "think ahead" like humans do |
| |
| 2. FULLY AUTOREGRESSIVE DESIGN |
| - No encoder-decoder split |
| - Treats all trajectory data uniformly |
| |
| 3. RELATIVE SPACETIME REPRESENTATION |
| - Position-independent coordinate system |
| - Maximizes learning from every data point |
| |
| 4. TRIPLE ATTENTION MECHANISM |
| - Temporal: How does this agent move over time? |
| - Agent-Map: How does this agent interact with roads? |
| - Agent-Agent: How do agents interact with each other? |
| |
+------------------------------------------------------------------+
Key Results
| Metric | BehaviorGPT | Meaning |
|---|---|---|
| minADE | 1.4147 | Lowest prediction error |
| REALISM | 0.7473 | Best match to real-world behavior |
| COLLISION | 0.9537 | Collision behavior closely matches real logs |
| OFFROAD | 0.9349 | On-road behavior closely matches real logs |
2. Architecture Deep Dive
The "GPT" in BehaviorGPT
Just like GPT predicts the next word in a sentence, BehaviorGPT predicts the next trajectory patch for all agents in a driving scene.
Language Model (GPT):
"The car drove down the" -> [road] -> [and] -> [turned] -> [left]
BehaviorGPT:
[Agent positions at t=0-1s] -> [t=1-2s] -> [t=2-3s] -> [t=3-4s]
Full Architecture Diagram
BehaviorGPT Architecture
═══════════════════════════════════════════════════════════════════
INPUT STAGE
───────────────────────────────────────────────────────────────────
Agent Trajectories Map Elements
┌─────────────────┐ ┌─────────────────┐
│ Agent 1: P1-P9 │ │ Road segments │
│ Agent 2: P1-P9 │ │ Lane markers │
│ ... │ │ Traffic signs │
│ Agent N: P1-P9 │ │ Crosswalks │
└────────┬────────┘ └────────┬────────┘
│ │
▼ ▼
┌─────────────────┐ ┌─────────────────┐
│ Patch Encoder │ │ Polyline Encoder│
│ (MLP + PE) │ │ (MLP + PE) │
└────────┬────────┘ └────────┬────────┘
│ │
└──────────────┬───────────────┘
▼
TRANSFORMER BLOCKS (×L layers)
───────────────────────────────────────────────────────────────────
For each layer l = 1 to L:
┌─────────────────────────────────────────────────────────────┐
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 1. TEMPORAL SELF-ATTENTION │ │
│ │ │ │
│ │ Agent 1: [P1]──[P2]──[P3]──[P4]──[P5]──... │ │
│ │ │ │ │ │ │ │ │
│ │ └────┴────┴────┴────┘ │ │
│ │ (causal mask: can only attend left) │ │
│ │ │ │
│ │ "How has this agent been moving over time?" │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 2. AGENT-MAP CROSS-ATTENTION │ │
│ │ │ │
│ │ Agent Patch ──────► k-Nearest Map Elements │ │
│ │ │ │ │ │
│ │ └────────────────────┘ │ │
│ │ Query Keys/Values │ │
│ │ │ │
│ │ "What road features should this agent consider?" │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 3. AGENT-AGENT SELF-ATTENTION │ │
│ │ │ │
│ │ Agent 1 ◄──────► Agent 2 │ │
│ │ │ ╲ ╱ │ │ │
│ │ │ ╲ ╱ │ │ │
│ │ │ ╲ ╱ │ │ │
│ │ ▼ ╲╱ ▼ │ │
│ │ Agent 3 ◄──╳───► Agent 4 │ │
│ │ (k-nearest filtering) │ │
│ │ │ │
│ │ "How are nearby agents affecting this one?" │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
OUTPUT STAGE
───────────────────────────────────────────────────────────────────
│
▼
┌──────────────────────────┐
│ GMM Decoder │
│ (Gaussian Mixture) │
└────────────┬─────────────┘
│
▼
┌──────────────────────────┐
│ Per-agent predictions: │
│ - 16 mixture modes │
│ - Mean (μ) per mode │
│ - Covariance (Σ) │
│ - Mixture weights (π) │
└──────────────────────────┘
Model Specifications
# BehaviorGPT Hyperparameters
config = {
"hidden_dim": 128, # Embedding dimension
"num_heads": 8, # Attention heads
"head_dim": 16, # 128 / 8
"num_layers": 6, # Transformer blocks
"patch_size": 10, # Timesteps per patch (1 second @ 10Hz)
"max_agents": 128, # Maximum agents per scenario
"num_modes": 16, # GMM mixture components
"k_nearest_map": 256, # Map elements to attend to
"k_nearest_agents": 32, # Agents to attend to
"total_params": "3M" # Remarkably small!
}
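A quick back-of-the-envelope check makes the 3M figure plausible. The sketch below assumes a standard transformer block layout (Q/K/V/output projections for each of the three attention modules, plus a 4x feed-forward) and ignores biases, norms, the encoders, and the GMM head, so treat it as an order-of-magnitude estimate rather than the paper's accounting:

# Rough parameter estimate from the config above (assumed block layout)
d, L = 128, 6
attn = 3 * (4 * d * d)   # triple attention: 3 modules x (Q, K, V, O) projections
mlp = 2 * (d * 4 * d)    # feed-forward: d -> 4d -> d
total = L * (attn + mlp)
print(f"~{total / 1e6:.1f}M parameters in the transformer blocks")  # ~2.0M

The remaining ~1M plausibly comes from the patch/polyline encoders and the GMM head.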
3. Key Innovation: Next-Patch Prediction (NP3)
The Problem with Frame-by-Frame Prediction
Traditional approaches predict one timestep at a time (10Hz = every 0.1 seconds):
Traditional Approach:
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ t=0 │───►│ t=1 │───►│ t=2 │───►│ t=3 │ ... (×80 for 8 seconds)
└─────┘ └─────┘ └─────┘ └─────┘
│ │ │ │
└──────────┴──────────┴──────────┘
Easy shortcut: just copy previous!
The Problem: At 10Hz, the difference between consecutive frames is tiny. The model learns a lazy shortcut - just copy the previous position with minor adjustments. This leads to:
- Compounding errors over time
- Unrealistic "drifting" behavior
- Poor long-horizon predictions
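A toy random walk makes the compounding-error point concrete: if each autoregressive step adds independent noise, endpoint drift grows with the square root of the number of steps, so 80 frame-level predictions drift roughly sqrt(10) times farther than 8 patch-level predictions at the same per-step noise. This sketch only illustrates the argument (the noise level is made up, not measured from any model):

import numpy as np

rng = np.random.default_rng(0)
num_rollouts = 10_000
sigma = 0.1  # illustrative per-step prediction noise (meters)

for steps, label in [(80, "frame-by-frame (80 steps)"),
                     (8, "patch-by-patch (8 steps)")]:
    # Each autoregressive step adds independent 2D Gaussian error
    endpoint_error = rng.normal(0, sigma, size=(num_rollouts, steps, 2)).sum(axis=1)
    drift = np.linalg.norm(endpoint_error, axis=-1).mean()
    print(f"{label}: mean endpoint drift ~ {drift:.2f} m")

In reality a patch-level step may carry more noise than a frame-level one; the point is simply that fewer feedback steps give errors fewer chances to compound.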
The NP3 Solution
Next-Patch Prediction (NP3):
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
│ Patch 1 │───►│ Patch 2 │───►│ Patch 3 │
│ (t=0 to 1s) │ │ (t=1 to 2s) │ │ (t=2 to 3s) │
│ 10 timesteps │ │ 10 timesteps │ │ 10 timesteps │
└───────────────┘ └───────────────┘ └───────────────┘
│ │ │
└────────────────────┴────────────────────┘
Significant changes between patches!
Model MUST reason about trajectory
Why Patches Work: An Analogy
Think of writing a story:
Word-by-Word (like frame-by-frame):
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ The │──►│ car │──►│ was │──►│ on │──►│ the │──►│road │
└─────┘ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘
Easy to just predict common word patterns without understanding!
Paragraph-by-Paragraph (like NP3):
┌──────────────────────┐ ┌──────────────────────┐
│ "The red sports car │───►│ "It accelerated │
│ was approaching │ │ through the curve, │
│ the intersection." │ │ narrowly avoiding │
│ │ │ the truck." │
└──────────────────────┘ └──────────────────────┘
Must understand MEANING to generate coherent next paragraph!
Mathematical Formulation
The joint probability of all agent trajectories:
Standard (per-timestep):

    P(S^{1:N}_{1:T} | M) = ∏_{t=1}^{T} ∏_{n=1}^{N} P(s^n_t | s^{1:N}_{1:t-1}, M)

NP3 (per-patch):

    P(S^{1:N}_{1:T} | M) = ∏_{p=1}^{T/K} P(P_p | P_{1:p-1}, M)

Where:
- S^{1:N}_{1:T} = states of all N agents over T timesteps
- s^n_t = state of agent n at timestep t
- M = map information
- P_p = the p-th trajectory patch (K = 10 timesteps each)
Patch Decomposition
Within each patch, predictions decompose hierarchically:
Patch-Level Prediction
│
┌───────────────┼───────────────┐
▼ ▼ ▼
Agent 1 Agent 2 ... Agent N
│ │ │
┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐
▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
t1 t2 ... t1 t2 ... t1 t2 ...
Each timestep within a patch is modeled with a Gaussian mixture, conditioned on earlier timesteps of the same patch and on all agents' previous patches:

    P(s^n_t | s^n_{1:t-1}, P^{1:N}_{1:p-1}) ← Gaussian Mixture Model
4. Relative Spacetime Representation
The Coordinate Frame Problem
Traditional approaches use absolute coordinates or ego-centric views:
Absolute Coordinates Problem:
┌──────────────────────────────────────────────┐
│ │
│ Car A at (100, 200) │
│ Car B at (105, 203) │
│ │
│ Different scenarios have different coords! │
│ Model must memorize locations = inefficient │
│ │
└──────────────────────────────────────────────┘
Ego-Centric Problem:
┌──────────────────────────────────────────────┐
│ │
│ Everything relative to ONE "ego" agent │
│ But who is "ego" in multi-agent simulation? │
│ Different agents see different views! │
│ │
└──────────────────────────────────────────────┘
BehaviorGPT's Solution: Relative Spacetime
Every element is described by its relationship to OTHER elements:
Relative Spacetime Representation
═════════════════════════════════════════════════════════════
For any two elements (patches, agents, or map features):
┌─────────────────────────────────────────────────────────┐
│ │
│ Element i Element j │
│ ●─────────────────────────────● │
│ │ │ │
│ │ Relative Features: │ │
│ │ │ │
│ │ • Distance: ||pᵢ - pⱼ|| │ │
│ │ • Angle: atan2(Δy, Δx) │ │
│ │ • Yaw diff: θᵢ - θⱼ │ │
│ │ • Time diff: tᵢ - tⱼ │ │
│ │ │ │
└─────────────────────────────────────────────────────────┘
Plus Self Features:
• Speed, acceleration
• Agent type (vehicle, pedestrian, cyclist)
• Bounding box dimensions
• Valid/invalid mask
Feature Vector Composition
# For each trajectory patch, the feature vector includes:
patch_features = {
# Semantic (self) features - position independent
"speed": float, # Current velocity magnitude
"acceleration": float, # Rate of change
"yaw_rate": float, # Turning rate
"agent_type": one_hot, # [vehicle, pedestrian, cyclist]
"bbox": [length, width], # Agent dimensions
# Relative features - computed pairwise
"relative_distance": float, # Euclidean distance
"relative_angle": float, # Bearing angle
"relative_yaw": float, # Heading difference
"time_offset": int, # Temporal distance in patches
}
# Total feature dimension
d_input = d_semantic + d_relative # Fed to MLP encoder
Why This Matters: Data Efficiency
Traditional (Position-Dependent):
┌────────────────────────────────────────────────────────┐
│ Scenario 1: Intersection at (100, 200) │
│ Scenario 2: Intersection at (5000, 3000) │
│ │
│ Same behavior, but model sees completely different │
│ numbers! Must learn the same thing twice. │
└────────────────────────────────────────────────────────┘
BehaviorGPT (Relative):
┌────────────────────────────────────────────────────────┐
│ Scenario 1: Car A is 5m behind Car B, approaching │
│ Scenario 2: Car A is 5m behind Car B, approaching │
│ │
│ SAME representation regardless of absolute position! │
│ Every example teaches the same concept once. │
└────────────────────────────────────────────────────────┘
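This claim is easy to verify numerically: compute relative features for a pair of agents, then translate and rotate the entire scene and recompute. The helper below is a self-contained NumPy sketch (the feature choices are illustrative, not the paper's exact extractor):

import numpy as np

def relative_features(pos_i, yaw_i, pos_j, yaw_j):
    """Distance, bearing (in i's frame), and heading difference for one pair."""
    d = pos_j - pos_i
    wrap = lambda a: np.arctan2(np.sin(a), np.cos(a))  # wrap angle to [-pi, pi]
    return np.array([np.hypot(*d),
                     wrap(np.arctan2(d[1], d[0]) - yaw_i),
                     wrap(yaw_j - yaw_i)])

def transform(pos, yaw, shift, theta):
    """Apply a global rotation by theta, then a global translation by shift."""
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    return R @ pos + shift, yaw + theta

pos_a, yaw_a = np.array([100.0, 200.0]), 0.3
pos_b, yaw_b = np.array([105.0, 203.0]), 0.9
shift, theta = np.array([4900.0, 2800.0]), 1.2

print(relative_features(pos_a, yaw_a, pos_b, yaw_b))
print(relative_features(*transform(pos_a, yaw_a, shift, theta),
                        *transform(pos_b, yaw_b, shift, theta)))
# Both lines print the same vector (up to floating-point noise)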
5. Interactive Code Examples
5.1 Patch Encoder
import jax
import jax.numpy as jnp
from flax import linen as nn
class PatchEncoder(nn.Module):
"""Encodes a trajectory patch into a fixed-size embedding."""
hidden_dim: int = 128
patch_size: int = 10 # 1 second at 10Hz
@nn.compact
def __call__(self, patch_features):
"""
Args:
patch_features: [batch, num_agents, num_patches, patch_size, d_input]
Returns:
patch_embeddings: [batch, num_agents, num_patches, hidden_dim]
"""
batch, num_agents, num_patches, patch_size, d_input = patch_features.shape
# Flatten patch timesteps
x = patch_features.reshape(batch, num_agents, num_patches, -1)
# Shape: [batch, num_agents, num_patches, patch_size * d_input]
# MLP encoder
x = nn.Dense(self.hidden_dim)(x)
x = nn.relu(x)
x = nn.Dense(self.hidden_dim)(x)
x = nn.LayerNorm()(x)
# Add positional encoding for patch index
patch_indices = jnp.arange(num_patches)
pos_encoding = self._sinusoidal_encoding(patch_indices, self.hidden_dim)
x = x + pos_encoding[None, None, :, :] # Broadcast over batch and agents
return x
def _sinusoidal_encoding(self, positions, dim):
"""Standard sinusoidal positional encoding."""
pe = jnp.zeros((len(positions), dim))
div_term = jnp.exp(jnp.arange(0, dim, 2) * -(jnp.log(10000.0) / dim))
pe = pe.at[:, 0::2].set(jnp.sin(positions[:, None] * div_term))
pe = pe.at[:, 1::2].set(jnp.cos(positions[:, None] * div_term))
return pe
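To see the encoder's shape contract, here is a minimal smoke test with random inputs (the batch shape and d_input=16 are arbitrary demo choices, not values from the paper):

# Dummy batch: 2 scenarios, 4 agents, 9 patches of 10 timesteps, 16 raw features
rng = jax.random.PRNGKey(0)
dummy = jax.random.normal(rng, (2, 4, 9, 10, 16))

encoder = PatchEncoder(hidden_dim=128, patch_size=10)
params = encoder.init(rng, dummy)
out = encoder.apply(params, dummy)
print(out.shape)  # (2, 4, 9, 128)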
5.2 Triple Attention Block
class TripleAttentionBlock(nn.Module):
"""One layer of BehaviorGPT's triple attention mechanism."""
hidden_dim: int = 128
num_heads: int = 8
k_nearest_map: int = 256
k_nearest_agents: int = 32
@nn.compact
def __call__(self, agent_embeddings, map_embeddings, agent_positions):
"""
Args:
agent_embeddings: [batch, num_agents, num_patches, hidden_dim]
map_embeddings: [batch, num_map_elements, hidden_dim]
agent_positions: [batch, num_agents, num_patches, 2] (x, y)
"""
batch, num_agents, num_patches, hidden_dim = agent_embeddings.shape
# ═══════════════════════════════════════════════════════════════
# STEP 1: Temporal Self-Attention (per agent, across time)
# ═══════════════════════════════════════════════════════════════
# Reshape to [batch * num_agents, num_patches, hidden_dim]
x = agent_embeddings.reshape(batch * num_agents, num_patches, hidden_dim)
# Causal mask: each patch can only attend to previous patches
causal_mask = jnp.tril(jnp.ones((num_patches, num_patches)))
x = x + CausalSelfAttention(
num_heads=self.num_heads,
name="temporal_attention"
)(x, mask=causal_mask)
x = x.reshape(batch, num_agents, num_patches, hidden_dim)
# ═══════════════════════════════════════════════════════════════
# STEP 2: Agent-Map Cross-Attention
# ═══════════════════════════════════════════════════════════════
# For each agent-patch, find k-nearest map elements
# agent_positions: [batch, num_agents, num_patches, 2]
# map_positions: [batch, num_map_elements, 2] (centroid of each element)
map_positions = self._get_map_centroids(map_embeddings)
# Compute distances and select k-nearest
# This creates a sparse attention pattern
nearest_map_indices = self._k_nearest(
agent_positions, map_positions, k=self.k_nearest_map
)
# Gather relevant map embeddings
selected_map = jnp.take_along_axis(
map_embeddings[:, None, None, :, :],
nearest_map_indices[:, :, :, :, None],
axis=3
)
# Cross-attention: agents query, map elements are keys/values
x = x + CrossAttention(
num_heads=self.num_heads,
name="agent_map_attention"
)(query=x, key_value=selected_map)
# ═══════════════════════════════════════════════════════════════
# STEP 3: Agent-Agent Self-Attention
# ═══════════════════════════════════════════════════════════════
# For each patch timestep, attend across agents
# Reshape to [batch * num_patches, num_agents, hidden_dim]
x = x.transpose(0, 2, 1, 3).reshape(batch * num_patches, num_agents, hidden_dim)
# Find k-nearest agents at each timestep
agent_pos_per_patch = agent_positions.transpose(0, 2, 1, 3)
agent_pos_per_patch = agent_pos_per_patch.reshape(batch * num_patches, num_agents, 2)
# Create sparse attention mask based on k-nearest
agent_attention_mask = self._create_knn_mask(
agent_pos_per_patch, k=self.k_nearest_agents
)
x = x + SelfAttention(
num_heads=self.num_heads,
name="agent_agent_attention"
)(x, mask=agent_attention_mask)
# Reshape back
x = x.reshape(batch, num_patches, num_agents, hidden_dim)
x = x.transpose(0, 2, 1, 3) # [batch, num_agents, num_patches, hidden_dim]
# Feed-forward network
x = x + FeedForward(hidden_dim=hidden_dim)(x)
return x
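The block above leans on several helpers the paper does not spell out (CausalSelfAttention, CrossAttention, SelfAttention, FeedForward, plus the k-NN utilities). Here is a minimal sketch of how three of them might look on top of Flax's built-in attention; the names, signatures, and the 4x MLP width are assumptions for illustration, not the authors' implementation:

class CausalSelfAttention(nn.Module):
    """Sketch: self-attention with an optional causal mask (assumed helper)."""
    num_heads: int = 8

    @nn.compact
    def __call__(self, x, mask=None):
        if mask is not None:
            # Flax expects masks broadcastable to [batch, heads, query, key]
            mask = mask.astype(bool)[None, None, :, :]
        return nn.SelfAttention(num_heads=self.num_heads)(x, mask=mask)

class CrossAttention(nn.Module):
    """Sketch: queries attend to a separate key/value set (assumed helper).

    In the triple-attention block, each agent-patch acts as a length-1 query
    over its k nearest map elements, so the caller would reshape the query to
    [..., 1, hidden_dim] before invoking this module.
    """
    num_heads: int = 8

    @nn.compact
    def __call__(self, query, key_value):
        return nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(
            inputs_q=query, inputs_kv=key_value)

class FeedForward(nn.Module):
    """Sketch: position-wise MLP with a 4x expansion (assumed width)."""
    hidden_dim: int = 128

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4 * self.hidden_dim)(x)
        h = nn.gelu(h)
        return nn.Dense(self.hidden_dim)(h)

SelfAttention for the agent-agent step would be the same as CausalSelfAttention without the triangular mask (the k-NN mask is passed in instead).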
5.3 GMM Output Head
class GMMOutputHead(nn.Module):
"""Predicts Gaussian Mixture Model parameters for next patch."""
hidden_dim: int = 128
num_modes: int = 16
patch_size: int = 10
output_dim: int = 2 # (x, y) positions
@nn.compact
def __call__(self, hidden_states):
"""
Args:
hidden_states: [batch, num_agents, num_patches, hidden_dim]
Returns:
gmm_params: Dictionary with mixture parameters
"""
batch, num_agents, num_patches, _ = hidden_states.shape
# Project to GMM parameters
# For each mode: mean, log_std, and mixture weight
# Mixture logits: [batch, num_agents, num_patches, num_modes]
mixture_logits = nn.Dense(self.num_modes, name="mixture_logits")(hidden_states)
mixture_weights = jax.nn.softmax(mixture_logits, axis=-1)
# Per-mode predictions
# Shape: [batch, num_agents, num_patches, num_modes, patch_size * output_dim]
means = nn.Dense(
self.num_modes * self.patch_size * self.output_dim,
name="means"
)(hidden_states)
means = means.reshape(
batch, num_agents, num_patches, self.num_modes,
self.patch_size, self.output_dim
)
# Log standard deviations (for numerical stability)
log_stds = nn.Dense(
self.num_modes * self.patch_size * self.output_dim,
name="log_stds"
)(hidden_states)
log_stds = log_stds.reshape(
batch, num_agents, num_patches, self.num_modes,
self.patch_size, self.output_dim
)
stds = jnp.exp(jnp.clip(log_stds, -5, 5)) # Clip for stability
return {
"mixture_weights": mixture_weights, # [B, N, P, M]
"means": means, # [B, N, P, M, T, 2]
"stds": stds, # [B, N, P, M, T, 2]
}
def sample(self, gmm_params, rng_key):
"""Sample trajectories from the GMM."""
weights = gmm_params["mixture_weights"]
means = gmm_params["means"]
stds = gmm_params["stds"]
# Sample mode for each agent
mode_key, sample_key = jax.random.split(rng_key)
mode_indices = jax.random.categorical(mode_key, jnp.log(weights + 1e-8), axis=-1)  # epsilon avoids log(0)
# Gather selected mode parameters
# This is a simplified version - actual implementation uses vmap
selected_means = jnp.take_along_axis(
means, mode_indices[..., None, None, None], axis=3
).squeeze(3)
selected_stds = jnp.take_along_axis(
stds, mode_indices[..., None, None, None], axis=3
).squeeze(3)
# Sample from Gaussian
noise = jax.random.normal(sample_key, selected_means.shape)
samples = selected_means + noise * selected_stds
return samples
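A quick shape check for the head (dummy hidden states; all sizes arbitrary):

rng = jax.random.PRNGKey(0)
hidden = jax.random.normal(rng, (2, 4, 9, 128))  # [batch, agents, patches, hidden_dim]

head = GMMOutputHead()
params = head.init(rng, hidden)
gmm = head.apply(params, hidden)
print(gmm["mixture_weights"].shape)  # (2, 4, 9, 16)
print(gmm["means"].shape)            # (2, 4, 9, 16, 10, 2)
print(gmm["stds"].shape)             # (2, 4, 9, 16, 10, 2)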
5.4 Training Loss
def compute_nll_loss(gmm_params, ground_truth, valid_mask):
"""
Negative log-likelihood loss for GMM predictions.
Args:
gmm_params: Dict with 'mixture_weights', 'means', 'stds'
ground_truth: [batch, num_agents, num_patches, patch_size, 2]
valid_mask: [batch, num_agents, num_patches] boolean
"""
weights = gmm_params["mixture_weights"] # [B, N, P, M]
means = gmm_params["means"] # [B, N, P, M, T, 2]
stds = gmm_params["stds"] # [B, N, P, M, T, 2]
# Expand ground truth for broadcasting with modes
gt_expanded = ground_truth[:, :, :, None, :, :] # [B, N, P, 1, T, 2]
# Compute Gaussian log-probabilities for each mode
# log p(x|mode) = -0.5 * ((x - μ)² / σ² + log(2π) + 2*log(σ))
var = stds ** 2
log_probs = -0.5 * (
((gt_expanded - means) ** 2) / var +
jnp.log(2 * jnp.pi) +
2 * jnp.log(stds)
)
# Sum over timesteps and dimensions within each mode
log_probs = log_probs.sum(axis=(-1, -2)) # [B, N, P, M]
# Log-sum-exp over modes (mixture probability)
# log p(x) = log Σ_m π_m * p(x|m) = logsumexp(log π_m + log p(x|m))
log_mixture_probs = jax.scipy.special.logsumexp(
jnp.log(weights + 1e-8) + log_probs,
axis=-1
) # [B, N, P]
# Apply valid mask and compute mean
nll = -log_mixture_probs
nll = jnp.where(valid_mask, nll, 0.0)
loss = nll.sum() / (valid_mask.sum() + 1e-8)
return loss
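And a smoke test for the loss with synthetic GMM parameters: the ground truth is set equal to mode 0's mean, so the loss should come out as a small finite scalar.

B, N, P, M, T = 2, 4, 9, 16, 10
rng = jax.random.PRNGKey(0)
gmm_params = {
    "mixture_weights": jnp.full((B, N, P, M), 1.0 / M),
    "means": jax.random.normal(rng, (B, N, P, M, T, 2)),
    "stds": jnp.ones((B, N, P, M, T, 2)),
}
ground_truth = gmm_params["means"][:, :, :, 0]    # exactly matches mode 0
valid_mask = jnp.ones((B, N, P), dtype=bool)

print(compute_nll_loss(gmm_params, ground_truth, valid_mask))  # finite scalar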
5.5 Autoregressive Inference Loop
def autoregressive_rollout(model, initial_history, map_data, num_future_patches, rng_key):
"""
Generate future trajectories autoregressively.
Args:
model: Trained BehaviorGPT model
initial_history: [batch, num_agents, history_patches, patch_size, features]
map_data: Static map information
num_future_patches: Number of 1-second patches to predict
rng_key: JAX random key
"""
batch_size, num_agents, _, patch_size, _ = initial_history.shape
# Start with history
current_sequence = initial_history
all_predictions = []
for patch_idx in range(num_future_patches):
# Split random key for this iteration
rng_key, sample_key = jax.random.split(rng_key)
# Forward pass through model
# Model sees all patches up to current time
gmm_params = model.apply(
{"params": model.params},
agent_patches=current_sequence,
map_data=map_data
)
# Extract prediction for the LAST patch (next-patch prediction)
last_patch_params = {
"mixture_weights": gmm_params["mixture_weights"][:, :, -1:, :],
"means": gmm_params["means"][:, :, -1:, :, :, :],
"stds": gmm_params["stds"][:, :, -1:, :, :, :],
}
# Sample next patch from GMM
next_patch = model.output_head.sample(last_patch_params, sample_key)
# Shape: [batch, num_agents, 1, patch_size, 2]
# Convert sampled positions back into full input features
# (compute_patch_features is an assumed helper that rebuilds speed, heading,
# etc. from raw positions plus the previous patch)
next_patch_features = compute_patch_features(
next_patch,
current_sequence[:, :, -1:, :, :] # Previous patch for relative features
)
# Append to sequence
current_sequence = jnp.concatenate(
[current_sequence, next_patch_features], axis=2
)
all_predictions.append(next_patch)
# Print progress (for debugging)
print(f"Generated patch {patch_idx + 1}/{num_future_patches}")
# Stack all predictions
future_trajectory = jnp.concatenate(all_predictions, axis=2)
# Shape: [batch, num_agents, num_future_patches, patch_size, 2]
return future_trajectory
6. Training Strategy
How to Train a Champion with 3M Parameters
┌─────────────────────────────────────────────────────────────────┐
│ BehaviorGPT Training Configuration │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Dataset: Waymo Open Motion Dataset (WOMD) │
│ ───────────────────────────────────────── │
│ • 570+ hours of driving data │
│ • 1,750 km of unique roadways                                   │
│ • 10 Hz sampling rate │
│ • Rich annotations: vehicles, pedestrians, cyclists │
│ │
│ Hardware │
│ ───────── │
│ • 8× NVIDIA RTX 4090 GPUs │
│ • Distributed data parallel training │
│ │
│ Optimization │
│ ──────────── │
│ • Optimizer: AdamW │
│ • Learning rate: 5 × 10⁻⁴ │
│ • LR schedule: Cosine annealing with warm restarts │
│ • Weight decay: 0.1 │
│ • Dropout: 0.1 │
│ • Gradient clipping: 1.0 │
│ │
│ Training Schedule │
│ ───────────────── │
│ • Epochs: 30 │
│ • Batch size: 24 scenarios │
│ • Training time: ~24 hours │
│ │
└─────────────────────────────────────────────────────────────────┘
Training Curriculum
Phase 1: Basic Motion Patterns (Epochs 1-10)
═══════════════════════════════════════════
• Model learns fundamental vehicle dynamics
• Straight-line motion, basic turns
• Speed profiles matching road types
Phase 2: Multi-Agent Coordination (Epochs 11-20)
══════════════════════════════════════════════════
• Focus on agent-agent interactions
• Following distance, lane changes
• Yielding and merging behaviors
Phase 3: Complex Scenarios (Epochs 21-30)
═════════════════════════════════════════
• Fine-tuning on challenging cases
• Intersections, roundabouts
• Rare but critical events
Learning Rate Schedule
Learning Rate
↑
5e-4 │ ╭──────╮
│ ╱ ╲
│ ╱ ╲ ╭──────╮
│ ╱ ╲ ╱ ╲
│╱ ╲ ╱ ╲
1e-5 │────────────────╲───╱────────────╲──────────►
│ ╲_╱ ╲_________
└─────────────────────────────────────────────►
5 10 15 20 25 30 Epoch
Cosine Annealing with Warm Restarts
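The optimizer and schedule are straightforward to reproduce in optax; the restart boundary and steps_per_epoch below are guesses read off the sketch above, not values reported in the paper:

import optax

steps_per_epoch = 1_000  # assumption: depends on dataset size and batch size

def cosine_cycle(num_epochs, peak=5e-4, floor=1e-5):
    """One cosine decay cycle from peak down to floor."""
    return optax.cosine_decay_schedule(
        init_value=peak,
        decay_steps=num_epochs * steps_per_epoch,
        alpha=floor / peak,
    )

# Two cycles roughly matching the figure: epochs 0-15, then a warm restart
schedule = optax.join_schedules(
    schedules=[cosine_cycle(15), cosine_cycle(15)],
    boundaries=[15 * steps_per_epoch],
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                         # gradient clipping
    optax.adamw(learning_rate=schedule, weight_decay=0.1),  # AdamW
)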
Data Augmentation
import numpy as np

def augment_scenario(positions, headings, rng):
    """Global transforms that preserve the relative spacetime representation.

    Illustrative sketch (the paper's exact pipeline is not public).
    positions: [..., 2] array of (x, y); headings: matching array of yaw angles.
    """
    # Global rotation: relative distances, bearings, and yaw diffs unchanged
    theta = rng.uniform(0.0, 2.0 * np.pi)
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    positions = positions @ R.T
    headings = headings + theta

    # Global translation: relative distances unchanged
    positions = positions + rng.uniform(-50.0, 50.0, size=2)

    # Mirror flip: relative angles negate consistently; still a valid scenario
    if rng.random() < 0.5:
        positions = positions * np.array([1.0, -1.0])
        headings = -headings

    return positions, headings

# Note: Because BehaviorGPT uses RELATIVE features, these transforms leave the
# model's input representation unchanged - the augmentation is effectively free.
Ablation Study Results
┌─────────────────────────────────────────────────────────────────┐
│ Ablation Study Results │
├───────────────────────────┬───────────┬───────────┬─────────────┤
│ Configuration │ REALISM │ COLLISION │ minADE │
├───────────────────────────┼───────────┼───────────┼─────────────┤
│ Full BehaviorGPT │ 0.747 │ 0.954 │ 1.415 │
│ w/o Patching (patch=1) │ 0.712 │ 0.923 │ 1.582 │
│ w/o Agent-Agent Attention │ 0.681 │ 0.624 │ 1.498 │
│ w/o Agent-Map Attention │ 0.698 │ 0.912 │ 1.521 │
│ w/o Temporal Attention │ 0.654 │ 0.889 │ 1.634 │
│ Patch size = 5 │ 0.731 │ 0.941 │ 1.446 │
│ Patch size = 20 │ 0.724 │ 0.932 │ 1.478 │
└───────────────────────────┴───────────┴───────────┴─────────────┘
Key Findings:
• Patching provides +4.9% REALISM improvement
• Agent-Agent Attention is CRITICAL: -34.6% collision realism without it
• Patch size 10 (1 second) is optimal for 10Hz data
7. WOSAC Metrics Breakdown
Understanding the Challenge
The Waymo Open Sim Agents Challenge (WOSAC) evaluates how well simulated agents match real-world driving behavior.
WOSAC Evaluation Pipeline
═════════════════════════════════════════════════════════════════
Input: Scenario with 1.1s of history
↓
┌──────────────────────────────────────────────────────────────┐
│ │
│ Model generates 32 simulation replicas │
│ Each replica: 8 seconds of future trajectories │
│ For ALL agents simultaneously │
│ │
└──────────────────────────────────────────────────────────────┘
↓
┌──────────────────────────────────────────────────────────────┐
│ │
│ Compare simulated distribution vs. real-world logs │
│ │
│ Metrics: │
│ ├── Realism Meta-Metric (overall score) │
│ ├── Kinematic Metrics (motion quality) │
│ ├── Interactive Metrics (multi-agent behavior) │
│ └── Map-based Metrics (road compliance) │
│ │
└──────────────────────────────────────────────────────────────┘
Metric Definitions
1. REALISM (Meta-Metric)
REALISM = Aggregate score combining all metrics
Higher is better (max 1.0)
BehaviorGPT: 0.7473 (1st place!)
Interpretation:
┌────────────────────────────────────────────────────────────┐
│ 0.0 ─────────── 0.5 ─────────── 0.75 ─────────── 1.0 │
│ │ │ ↑ │ │
│ Random Baseline BehaviorGPT Perfect │
│ noise models ★ │
└────────────────────────────────────────────────────────────┘
2. minADE (Minimum Average Displacement Error)
minADE = min over modes ( average L2 distance to ground truth )
Formula:
    minADE = min_k (1/T) Σ_{t=1}^{T} ||ŝ_t^k - s_t||₂
Where:
• k = index over predicted modes
• T = prediction horizon (80 timesteps)
• ŝₜᵏ = predicted position at time t, mode k
• sₜ = ground truth position at time t
BehaviorGPT: 1.4147 meters (best in challenge!)
Visual Explanation:
═══════════════════
Ground Truth: ●────●────●────●────●────●
Predicted Mode 1: ○───○───○───○───○───○ (ADE = 2.1m)
Predicted Mode 2: ◇────◇────◇────◇────◇────◇ (ADE = 1.4m) ✓ min
Predicted Mode 3: △──△──△──△──△──△ (ADE = 3.2m)
minADE = 1.4m (best mode is selected)
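The formula maps directly onto a few lines of NumPy. This sketch assumes predictions shaped [num_modes, T, 2] for a single agent:

import numpy as np

def min_ade(predictions, ground_truth):
    """minADE: average L2 displacement of the best mode.

    predictions: [num_modes, T, 2], ground_truth: [T, 2]
    """
    # Per-mode L2 distance at each timestep, averaged over the horizon
    ade_per_mode = np.linalg.norm(predictions - ground_truth[None], axis=-1).mean(axis=-1)
    return ade_per_mode.min()

# Toy example: 3 modes over a 5-step horizon, each a constant lateral offset
gt = np.stack([np.arange(5.0), np.zeros(5)], axis=-1)
offsets = np.array([[0.0, 0.5], [0.0, 2.0], [0.0, -1.0]])
preds = gt[None] + offsets[:, None, :]  # [3, 5, 2]
print(min_ade(preds, gt))  # 0.5 -> only the best mode counts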
3. COLLISION Score
COLLISION = likelihood-based score of how closely simulated collision
            statistics match those observed in the real-world logs
            (higher is better, max 1.0)
Measures: Do simulated agents collide at similar rates to the real world?
Real World: ~2% of agent pairs have close encounters
Simulation: Should match this rate, not be collision-free!
BehaviorGPT: 0.9537
Note: A "perfect" safety system with 0 collisions would score POORLY!
We want REALISTIC collision rates, not zero collisions.
Why This Matters:
═══════════════════════════════════════════════════════════════
Too Few Collisions (Overly Cautious):
┌─────────────────────────────────────┐
│ AV: "I'll stop here to be safe" │
│ 🚗 ─────────── STOP ────── │
│ │
│ Real humans don't drive this way! │
│ Unrealistic for training. │
└─────────────────────────────────────┘
Too Many Collisions (Too Aggressive):
┌─────────────────────────────────────┐
│ AV: "YOLO!" │
│ 🚗 💥 🚙 │
│ │
│ Dangerous! Bad for safety testing. │
└─────────────────────────────────────┘
Just Right (BehaviorGPT):
┌─────────────────────────────────────┐
│ Matches real-world collision rates │
│ Enables realistic safety testing │
└─────────────────────────────────────┘
4. OFFROAD Score
OFFROAD = likelihood-based score of how closely simulated off-road
          statistics match those observed in the real-world logs
          (higher is better, max 1.0)
Measures: Do agents stay on valid road surfaces?
BehaviorGPT: 0.9349
Components:
• Lane boundary violations
• Driving on sidewalks
• Entering non-drivable areas
Visual:
═══════
Road Surface
┌────────────────────────────┐
│ ════════════════════════ │
Bad: │ ══ 🚗 ════════════════════ │ (off-road!)
│ ↓ │
└────────────────────────────┘
│
│ OFFROAD
│ penalty
↓
┌────────────────────────────┐
│ ════════════════════════ │
Good: │ ═══════ 🚗 ═══════════════ │ (on-road!)
│ │
└────────────────────────────┘
WOSAC 2024 Leaderboard
┌─────────────────────────────────────────────────────────────────┐
│ WOSAC 2024 Final Results │
├──────┬─────────────────────┬──────────┬───────────┬─────────────┤
│ Rank │ Method │ Params │ REALISM │ minADE │
├──────┼─────────────────────┼──────────┼───────────┼─────────────┤
│ 🥇 │ BehaviorGPT │ 3M │ 0.747 │ 1.415 │
│ 🥈 │ MVTE │ 45M │ 0.731 │ 1.523 │
│ 🥉 │ Trajeglish │ 120M │ 0.718 │ 1.489 │
│ 4 │ MotionLM │ 523M │ 0.702 │ 1.567 │
│ 5 │ SceneTransformer │ 28M │ 0.689 │ 1.612 │
└──────┴─────────────────────┴──────────┴───────────┴─────────────┘
Key Insight: BehaviorGPT achieves BEST results with FEWEST parameters!
This is 175x smaller than MotionLM, yet performs better.
8. Mental Models
Mental Model 1: The Driving Instructor Analogy
Traditional Frame-by-Frame Prediction:
══════════════════════════════════════════════════════════════════
Student: "What do I do next?"
Instructor: "Move your foot 0.1 inches..."
Student: "Now what?"
Instructor: "Move it 0.1 more inches..."
Problem: Student never learns to PLAN, just follows micro-instructions.
BehaviorGPT with NP3:
══════════════════════════════════════════════════════════════════
Student: "What do I do next?"
Instructor: "In the next second, you'll approach the intersection,
begin braking, and prepare to turn right."
Student: "Got it!" (executes the whole maneuver)
Result: Student learns to THINK AHEAD in meaningful chunks.
Mental Model 2: The Orchestra Conductor
Multi-Agent Coordination in BehaviorGPT:
════════════════════════════════════════════════════════════════
🎭 Conductor (Model)
│
┌──────────────────┼──────────────────┐
│ │ │
▼ ▼ ▼
🎻 Violin 🎺 Trumpet 🥁 Drums
(Car A) (Car B) (Car C)
Each musician (agent):
• Has their own score (temporal attention)
• Watches the conductor (agent-agent attention)
• Follows the venue layout (agent-map attention)
The conductor doesn't micromanage every note.
Instead, they guide each PHRASE (patch) of music.
Mental Model 3: The Story Continuation
GPT Language Model:
═══════════════════════════════════════════════════════════════
Context: "The sports car approached the red light. The driver..."
Prediction: "...slowed down gradually and came to a smooth stop,
noticing the pedestrian crossing ahead."
The model completes a coherent CONTINUATION of the story.
BehaviorGPT:
═══════════════════════════════════════════════════════════════
Context: [1 second of observed driving: car approaching intersection]
Prediction: [Next 1 second: car decelerating, preparing to stop]
[Next 1 second: car stopped at light]
[Next 1 second: car begins moving as light turns green]
...
The model completes a coherent CONTINUATION of the driving scenario.
Mental Model 4: Information Flow
How Information Flows Through BehaviorGPT:
════════════════════════════════════════════════════════════════
Time →
Agent 1: [Past]─────[Past]─────[Current]─────?─────?─────?
│ │ │ │ │ │
└──────────┴───────────┘ │ │ │
│ │ │ │
│ Temporal Self-Attention │ │ │
│ "What did I do before?" │ │ │
│ │ │ │
▼ │ │ │
┌─────────────────┐ │ │ │
│ Agent State │──────────────┘ │ │
└────────┬────────┘ │ │
│ │ │
│ Agent-Agent Attention │ │
│ "What are others doing?" │ │
│ │ │
▼ │ │
┌─────────────────┐ │ │
│ Social Context │────────────────────┘ │
└────────┬────────┘ │
│ │
│ Agent-Map Attention │
│ "What does the road allow?" │
│ │
▼ │
┌─────────────────┐ │
│ Full Context │──────────────────────────┘
└────────┬────────┘
│
▼
┌─────────────────┐
│ Predict Next │
│ Patch │
└─────────────────┘
Mental Model 5: The Patch Prediction Process
Step-by-Step Patch Prediction:
════════════════════════════════════════════════════════════════
Time (seconds): 0 1 2 3 4 5 6 7 8 9
├────┼────┼────┼────┼────┼────┼────┼────┼────┤
History: [====][====]
Patch Patch
0 1
Step 1: [====][====][????]
│
└─ Predict Patch 2 from context
Step 2: [====][====][====][????]
│
└─ Predict Patch 3 from context
Step 3: [====][====][====][====][????]
│
└─ Predict Patch 4
...continue until desired horizon...
Final: [====][====][====][====][====][====][====][====][====]
├── History ──┤├───────── Predicted ────────────────┤
9. Hands-On Exercises
Exercise 1: Understanding Patch Formation
Goal: Convert raw trajectory data into patches
# Given: Raw trajectory at 10Hz
raw_trajectory = [
(0.0, 0.0), # t=0.0s
(0.5, 0.1), # t=0.1s
(1.0, 0.2), # t=0.2s
(1.5, 0.3), # t=0.3s
(2.0, 0.4), # t=0.4s
(2.5, 0.5), # t=0.5s
(3.0, 0.6), # t=0.6s
(3.5, 0.7), # t=0.7s
(4.0, 0.8), # t=0.8s
(4.5, 0.9), # t=0.9s
(5.0, 1.0), # t=1.0s
(5.5, 1.1), # t=1.1s
# ... continues
]
# TODO: Implement this function
def create_patches(trajectory, patch_size=10):
"""
Convert raw trajectory to patches.
Args:
trajectory: List of (x, y) positions at 10Hz
patch_size: Number of timesteps per patch
Returns:
List of patches, where each patch is a list of positions
"""
patches = []
# Your code here
pass
# Expected output:
# patches[0] = [(0.0, 0.0), (0.5, 0.1), ..., (4.5, 0.9)] # First 1 second
# patches[1] = [(5.0, 1.0), (5.5, 1.1), ..., (9.5, 1.9)] # Second 1 second
<details>
<summary>Solution</summary>
def create_patches(trajectory, patch_size=10):
patches = []
for i in range(0, len(trajectory) - patch_size + 1, patch_size):
patch = trajectory[i:i + patch_size]
patches.append(patch)
return patches
</details>
Exercise 2: Computing Relative Features
Goal: Understand the relative spacetime representation
import math
# Given: Two agents' states
agent_a = {
"position": (10.0, 20.0),
"heading": math.pi / 4, # 45 degrees (NE)
"speed": 15.0,
}
agent_b = {
"position": (15.0, 25.0),
"heading": math.pi / 2, # 90 degrees (N)
"speed": 10.0,
}
# TODO: Implement relative feature computation
def compute_relative_features(agent_a, agent_b):
"""
Compute relative features between two agents.
Returns:
dict with:
- distance: Euclidean distance
- bearing: Angle from A to B (in A's frame)
- relative_heading: Heading difference
- relative_speed: Speed difference
"""
# Your code here
pass
# Expected output:
# {
#   "distance": 7.07,           # sqrt((15-10)^2 + (25-20)^2)
#   "bearing": 0.0,             # B lies directly ahead along A's heading
#   "relative_heading": 0.785,  # B is heading 45 deg left of A
#   "relative_speed": -5.0,     # B is 5 m/s slower
# }
<details>
<summary>Solution</summary>
def compute_relative_features(agent_a, agent_b):
# Distance
dx = agent_b["position"][0] - agent_a["position"][0]
dy = agent_b["position"][1] - agent_a["position"][1]
distance = math.sqrt(dx**2 + dy**2)
# Bearing (angle from A to B, relative to A's heading)
global_angle = math.atan2(dy, dx)
bearing = global_angle - agent_a["heading"]
# Normalize to [-pi, pi]
bearing = math.atan2(math.sin(bearing), math.cos(bearing))
# Relative heading
relative_heading = agent_b["heading"] - agent_a["heading"]
relative_heading = math.atan2(math.sin(relative_heading),
math.cos(relative_heading))
# Relative speed
relative_speed = agent_b["speed"] - agent_a["speed"]
return {
"distance": round(distance, 2),
"bearing": round(bearing, 3),
"relative_heading": round(relative_heading, 3),
"relative_speed": relative_speed,
}
</details>
Exercise 3: Causal Attention Mask
Goal: Understand the temporal attention pattern
import numpy as np
def create_causal_mask(num_patches):
"""
Create a causal attention mask for temporal self-attention.
A patch can only attend to itself and previous patches.
Args:
num_patches: Number of trajectory patches
Returns:
mask: [num_patches, num_patches] boolean array
True = can attend, False = cannot attend
"""
# Your code here
pass
# Test with 5 patches
mask = create_causal_mask(5)
print(mask)
# Expected output:
# [[ True False False False False]
# [ True True False False False]
# [ True True True False False]
# [ True True True True False]
# [ True True True True True]]
#
# Patch 0: Can only see itself
# Patch 1: Can see patches 0 and 1
# Patch 2: Can see patches 0, 1, and 2
# etc.
<details>
<summary>Solution</summary>
def create_causal_mask(num_patches):
return np.tril(np.ones((num_patches, num_patches), dtype=bool))
</details>
Exercise 4: GMM Sampling
Goal: Sample trajectories from a Gaussian Mixture Model
import numpy as np
def sample_from_gmm(mixture_weights, means, stds, num_samples=1):
"""
Sample trajectories from a GMM.
Args:
mixture_weights: [num_modes] probabilities summing to 1
means: [num_modes, num_timesteps, 2] mean positions
stds: [num_modes, num_timesteps, 2] standard deviations
num_samples: Number of trajectories to sample
Returns:
samples: [num_samples, num_timesteps, 2] sampled trajectories
"""
samples = []
for _ in range(num_samples):
# Step 1: Sample a mode index based on mixture weights
mode_idx = None # Your code here
# Step 2: Get mean and std for selected mode
selected_mean = None # Your code here
selected_std = None # Your code here
# Step 3: Sample from Gaussian
sample = None # Your code here
samples.append(sample)
return np.array(samples)
# Test case
np.random.seed(42)
weights = np.array([0.5, 0.3, 0.2]) # 3 modes
means = np.array([
[[0, 0], [1, 0], [2, 0]], # Mode 0: Go straight
[[0, 0], [1, 1], [2, 2]], # Mode 1: Go diagonal
[[0, 0], [0, 1], [0, 2]], # Mode 2: Go up
])
stds = np.ones_like(means) * 0.1
trajectories = sample_from_gmm(weights, means, stds, num_samples=5)
print(trajectories.shape) # Should be (5, 3, 2)
<details>
<summary>Solution</summary>
def sample_from_gmm(mixture_weights, means, stds, num_samples=1):
samples = []
for _ in range(num_samples):
# Step 1: Sample mode index
mode_idx = np.random.choice(len(mixture_weights), p=mixture_weights)
# Step 2: Get parameters for selected mode
selected_mean = means[mode_idx]
selected_std = stds[mode_idx]
# Step 3: Sample from Gaussian
noise = np.random.randn(*selected_mean.shape)
sample = selected_mean + noise * selected_std
samples.append(sample)
return np.array(samples)
</details>
Exercise 5: K-Nearest Neighbor Selection
Goal: Implement the k-NN filtering for attention
import numpy as np
def select_k_nearest(query_positions, key_positions, k):
"""
For each query, find the k nearest keys.
Args:
query_positions: [num_queries, 2] array of (x, y)
key_positions: [num_keys, 2] array of (x, y)
k: Number of nearest neighbors to select
Returns:
indices: [num_queries, k] indices of k-nearest keys
distances: [num_queries, k] distances to k-nearest keys
"""
# Your code here
pass
# Test case
np.random.seed(42)
queries = np.array([
[0, 0],
[10, 10],
])
keys = np.array([
[1, 1], # Close to query 0
[9, 9], # Close to query 1
[5, 5], # Middle
[100, 100], # Far from both
])
indices, distances = select_k_nearest(queries, keys, k=2)
print("Indices:", indices)
print("Distances:", distances)
# Expected:
# For query [0,0]: nearest are indices 0 and 2
# For query [10,10]: nearest are indices 1 and 2
<details>
<summary>Solution</summary>
def select_k_nearest(query_positions, key_positions, k):
num_queries = len(query_positions)
num_keys = len(key_positions)
# Compute all pairwise distances
# Shape: [num_queries, num_keys]
diff = query_positions[:, None, :] - key_positions[None, :, :]
distances_all = np.sqrt((diff ** 2).sum(axis=-1))
# Find k smallest distances for each query
indices = np.argsort(distances_all, axis=-1)[:, :k]
distances = np.take_along_axis(distances_all, indices, axis=-1)
return indices, distances
</details>
10. Interview Questions
Conceptual Questions
Q1: Why does BehaviorGPT use patches instead of individual timesteps?
<details> <summary>Expected Answer</summary>Key Points:
- Prevents trivial shortcuts: At 10Hz, consecutive frames are nearly identical. A model can achieve low training loss by simply copying the previous timestep with minor noise. Patches (1 second each) have meaningful differences that force the model to actually predict future behavior.
- Captures long-range dependencies: Patch-level tokens reduce sequence length by 10x (e.g., 80 timesteps becomes 8 patches), making it computationally feasible to model longer-range dependencies with self-attention.
- Reduces compounding errors: In autoregressive generation, errors accumulate with each prediction step. With 10x fewer prediction steps (patches vs. timesteps), error accumulation is significantly reduced.
- Aligns with human decision-making: Drivers don't plan microsecond by microsecond; they think in terms of maneuvers lasting seconds. Patches align with this natural granularity.
Bonus insight: The ablation study shows patch_size=10 outperforms patch_size=1 by +4.9% on REALISM.
</details>Q2: Why is the relative spacetime representation important for data efficiency?
<details> <summary>Expected Answer</summary>Key Points:
- Position invariance: The same driving behavior at different absolute locations (intersection at (100, 200) vs. (5000, 3000)) produces identical relative features. The model learns the behavior once, not separately for each location.
- Implicit data augmentation: Translation, rotation, and reflection augmentations are "free" - they don't change relative features. This effectively multiplies training data without additional computation.
- Generalization: Models trained with relative features generalize better to new locations, road layouts, and scenarios not seen during training.
- Symmetric treatment: No privileged "ego" agent. All agents are treated equally, maximizing learning from every agent's trajectory.
Concrete example: A "car following" behavior has the same relative representation regardless of whether it happens in San Francisco or Phoenix.
</details>Q3: Explain the purpose of each attention mechanism in the Triple Attention Block.
<details> <summary>Expected Answer</summary>1. Temporal Self-Attention:
- Purpose: Model how each individual agent moves over time
- Key feature: Causal masking (can only attend to past patches)
- Captures: Acceleration patterns, turning behaviors, consistency of motion
2. Agent-Map Cross-Attention:
- Purpose: Model how agents interact with road geometry
- Key feature: Queries are agent patches, keys/values are map elements
- Captures: Lane following, stopping at intersections, respecting road boundaries
3. Agent-Agent Self-Attention:
- Purpose: Model social interactions between agents
- Key feature: K-nearest neighbor filtering to limit complexity
- Captures: Following distance, yielding, collision avoidance, coordinated merging
Why this order matters:
- First understand individual motion (temporal)
- Then ground it in the environment (map)
- Finally coordinate with others (social)
This mirrors human decision-making: "What am I doing?" -> "Where am I?" -> "Who's around me?"
</details>Q4: Why does BehaviorGPT use a GMM output instead of directly predicting positions?
<details> <summary>Expected Answer</summary>Key Points:
- Multi-modality: Future trajectories are inherently uncertain. A car at an intersection could go straight, turn left, or turn right. GMM captures multiple plausible futures with different probabilities.
- Uncertainty quantification: The mixture weights and standard deviations provide principled uncertainty estimates. This is crucial for downstream planning and safety.
- Training with NLL loss: Gaussian likelihoods enable proper probabilistic training. The model learns not just the mean behavior but the full distribution.
- Sampling for simulation: During inference, we can sample diverse trajectories by (a) sampling a mode from mixture weights, then (b) sampling positions from that mode's Gaussian. This creates realistic diversity across simulation replicas.
Design choice: BehaviorGPT uses 16 mixture modes per agent. Too few modes = underfit multi-modal distributions. Too many = diluted probabilities and overfitting.
</details>Q5: How does BehaviorGPT handle the curse of dimensionality with 128 agents?
<details> <summary>Expected Answer</summary>Key Points:
- K-nearest neighbor filtering: Instead of O(N^2) attention over all agent pairs, BehaviorGPT attends only to k=32 nearest agents. This reduces complexity to O(N * k).
- Locality assumption: Agents primarily interact with nearby agents. A car 500 meters away has negligible influence on driving decisions. K-NN captures the relevant interactions.
- Patch-level representation: By grouping 10 timesteps into patches, sequence length is reduced 10x, making attention over agents more tractable.
- Independent patch-level planning: Within each patch, agents plan independently (Equation 3 in the paper). Coordination happens through attention over the previous patch's representations, not joint optimization.
- Small model size: With only 3M parameters and d=128 hidden dimensions, the model is inherently regularized against overfitting to spurious agent-agent correlations.
Scalability: The approach scales to 128 agents per scenario, which covers essentially all real-world driving situations in the WOMD dataset.
</details>Coding Questions
Q6: Implement a simplified version of causal temporal self-attention.
def temporal_self_attention(x, num_heads):
"""
Implement causal self-attention for temporal modeling.
Args:
x: [batch, num_patches, hidden_dim] tensor
num_heads: Number of attention heads
Returns:
output: [batch, num_patches, hidden_dim] tensor
"""
# Your implementation here
pass
<details>
<summary>Solution</summary>
import jax.numpy as jnp
import jax
def temporal_self_attention(x, num_heads):
batch, num_patches, hidden_dim = x.shape
head_dim = hidden_dim // num_heads
# Linear projections for Q, K, V
# In practice, these would be learned parameters
Wq = jnp.eye(hidden_dim)
Wk = jnp.eye(hidden_dim)
Wv = jnp.eye(hidden_dim)
Q = x @ Wq # [batch, num_patches, hidden_dim]
K = x @ Wk
V = x @ Wv
# Reshape for multi-head attention
Q = Q.reshape(batch, num_patches, num_heads, head_dim).transpose(0, 2, 1, 3)
K = K.reshape(batch, num_patches, num_heads, head_dim).transpose(0, 2, 1, 3)
V = V.reshape(batch, num_patches, num_heads, head_dim).transpose(0, 2, 1, 3)
# Shape: [batch, num_heads, num_patches, head_dim]
# Compute attention scores
scores = jnp.einsum('bhqd,bhkd->bhqk', Q, K) / jnp.sqrt(head_dim)
# Shape: [batch, num_heads, num_patches, num_patches]
# Create causal mask
causal_mask = jnp.tril(jnp.ones((num_patches, num_patches)))
scores = jnp.where(causal_mask, scores, -1e9)
# Softmax and apply to values
attn_weights = jax.nn.softmax(scores, axis=-1)
output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, V)
# Reshape back
output = output.transpose(0, 2, 1, 3).reshape(batch, num_patches, hidden_dim)
return output
</details>
Q7: Implement the autoregressive rollout loop.
def rollout(model, history, num_future_patches, rng_key):
"""
Generate future trajectories autoregressively.
Args:
model: A function that takes patches and returns GMM params
history: [batch, num_agents, num_history_patches, hidden_dim]
num_future_patches: Number of patches to generate
rng_key: JAX random key
Returns:
future: [batch, num_agents, num_future_patches, hidden_dim]
"""
# Your implementation here
pass
<details>
<summary>Solution</summary>
def rollout(model, history, num_future_patches, rng_key):
current = history
predictions = []
for i in range(num_future_patches):
# Split key for this iteration
rng_key, step_key = jax.random.split(rng_key)
# Get model prediction for next patch
gmm_params = model(current)
# Extract last patch prediction
last_weights = gmm_params["weights"][:, :, -1]
last_means = gmm_params["means"][:, :, -1]
last_stds = gmm_params["stds"][:, :, -1]
# Sample mode for each agent
mode_key, sample_key = jax.random.split(step_key)
mode_indices = jax.random.categorical(
mode_key,
jnp.log(last_weights + 1e-8),
axis=-1
)
# Gather selected mode parameters
batch, num_agents, num_modes, dim = last_means.shape
batch_idx = jnp.arange(batch)[:, None]
agent_idx = jnp.arange(num_agents)[None, :]
selected_mean = last_means[batch_idx, agent_idx, mode_indices]
selected_std = last_stds[batch_idx, agent_idx, mode_indices]
# Sample from Gaussian
noise = jax.random.normal(sample_key, selected_mean.shape)
next_patch = selected_mean + noise * selected_std
# Append to sequence
predictions.append(next_patch[:, :, None, :])
current = jnp.concatenate([current, next_patch[:, :, None, :]], axis=2)
return jnp.concatenate(predictions, axis=2)
</details>
System Design Questions
Q8: How would you modify BehaviorGPT for real-time inference on an autonomous vehicle?
<details> <summary>Expected Answer</summary>Key Considerations:
- Latency Requirements:
  - AV systems typically require predictions within 50-100ms
  - BehaviorGPT's 3M parameters are already very efficient
  - May need to reduce num_modes, k_nearest, or hidden_dim further
- Streaming Architecture:
  - Maintain a rolling buffer of recent patches
  - Implement KV-caching for temporal attention (don't recompute history)
  - Process new observations incrementally
- Hardware Optimization:
  - Quantization (FP16 or INT8) for faster inference
  - TensorRT or ONNX optimization
  - Batching across multiple scenarios if running in simulation
- Reduced Rollout:
  - For real-time use, may only need 2-3 second predictions
  - Replan every 0.5 seconds (2Hz) rather than full 8-second rollouts
- Uncertainty Thresholding:
  - In production, may want to flag high-uncertainty predictions
  - Use mixture entropy as a confidence measure
Sample Architecture:
Sensor Input -> Preprocessing (10ms) ->
BehaviorGPT Inference (30ms) ->
Post-processing (10ms) -> Planning Module
Total: ~50ms (20Hz)
</details>
Q9: How would you extend BehaviorGPT to handle traffic lights and dynamic map elements?
<details> <summary>Expected Answer</summary>Key Extensions:
- Traffic Light Encoding:
  - Add traffic light states (red/yellow/green) as additional map features
  - Include time-varying traffic light sequences
  - Encode expected time until state change
- Dynamic Map Elements:
  - Treat dynamic elements (construction zones, temporary signs) as special map tokens
  - Add temporal features: when element appeared, expected duration
- Cross-Attention Modification:
  - Agent-Map cross-attention already supports this
  - Add traffic light embeddings to the map token pool
  - Use attention to learn which lights are relevant to each agent
- Training Data Requirements:
  - Need annotations for traffic light states over time
  - Waymo Open Dataset includes some traffic light annotations
Implementation Sketch:
map_features = concat([
road_segment_embeddings, # Static
lane_boundary_embeddings, # Static
traffic_light_embeddings, # Dynamic - varies per patch
construction_zone_embeddings, # Semi-static
])
- Challenges:
  - Traffic light state changes need to be synchronized across patches
  - May need to predict traffic light futures jointly with agent futures
</details>
Q10: If you had unlimited compute but still needed real-time inference, how would you improve BehaviorGPT?
<details> <summary>Expected Answer</summary>Scaling Strategies:
- Model Scaling (More Parameters):
  - Increase hidden_dim: 128 -> 512 or 1024
  - Add more transformer layers: 6 -> 12 or 24
  - More attention heads with larger head dimensions
- Ensemble Methods:
  - Train multiple BehaviorGPT variants with different initializations
  - Aggregate predictions for better uncertainty calibration
  - Use mixture-of-experts architecture
- Multi-Scale Patches:
  - Combine fine-grained (0.5s) and coarse-grained (2s) patches
  - Hierarchical prediction: coarse trajectory first, then refine
- Richer Input Representations:
  - Include velocity/acceleration profiles, not just positions
  - Add semantic lane information (turn lanes, merge lanes)
  - Incorporate intent signals (turn signals, brake lights)
- Training Improvements:
  - Larger batch sizes for better gradient estimates
  - Longer training with more epochs
  - Curriculum learning: easy scenarios first
- For Real-Time Inference:
  - Distillation: Train large model, distill to small model
  - Speculative decoding: Use small model for most predictions, large model for verification
  - Dynamic compute allocation: Simple scenarios use small model, complex ones use large
The BehaviorGPT lesson: Sometimes simpler is better. The 3M parameter model won against 500M+ models, suggesting architectural innovations matter more than raw scale for this task.
</details>Summary
BehaviorGPT represents a paradigm shift in autonomous driving simulation:
┌─────────────────────────────────────────────────────────────────┐
│ Key Takeaways │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. NEXT-PATCH PREDICTION (NP3) │
│ Think in seconds, not milliseconds │
│ │
│ 2. RELATIVE SPACETIME │
│ Position-independent = better generalization │
│ │
│ 3. TRIPLE ATTENTION │
│ Time + Environment + Social = Complete context │
│ │
│ 4. EFFICIENT DESIGN │
│ 3M params beat 500M+ through smart architecture │
│ │
│ 5. AUTOREGRESSIVE GENERATION │
│ Same paradigm that powers GPT, now for driving │
│ │
└─────────────────────────────────────────────────────────────────┘
The success of BehaviorGPT demonstrates that the transformer architecture, when properly adapted to the domain, can achieve state-of-the-art results even with modest compute budgets. The key insight - treating trajectory prediction as "next patch prediction" analogous to "next token prediction" - elegantly bridges the gap between language modeling and behavior modeling.
Further Reading
- Original Paper: arXiv:2405.17372
- WOSAC Challenge: Waymo Open Sim Agents Challenge
- Waymo Open Motion Dataset: WOMD Documentation
- Related Work: MotionLM, Trajeglish, MVTE, SceneTransformer
This blog was created as an interactive learning resource. For corrections or suggestions, please open an issue in the repository.