Deep Dive #5
40 min read

BehaviorGPT Deep Dive

State-of-the-art sim agent modeling with transformers, Next-Patch Prediction, and the 2024 WOSAC winner approach.

BehaviorGPT: A Deep Dive into the WOSAC 2024 Champion

Paper: BehaviorGPT: Smart Agent Simulation for Autonomous Driving with Next-Patch Prediction
Authors: Zikang Zhou et al.
arXiv: 2405.17372
Achievement: 1st Place, WOSAC (Waymo Open Sim Agents Challenge) 2024


Table of Contents

  1. Executive Summary
  2. Architecture Deep Dive
  3. Key Innovation: Next-Patch Prediction (NP3)
  4. Relative Spacetime Representation
  5. Interactive Code Examples
  6. Training Strategy
  7. WOSAC Metrics Breakdown
  8. Mental Models
  9. Hands-On Exercises
  10. Interview Questions

1. Executive Summary

The David vs Goliath Story

BehaviorGPT won the 2024 WOSAC Challenge with a remarkable achievement: at just 3 million parameters, it outperformed competitors ranging from 10 to 523 million. This is like a compact car winning a race against a field of supercars.

What Made It Win?

+------------------------------------------------------------------+
|                    BehaviorGPT's Winning Formula                  |
+------------------------------------------------------------------+
|                                                                   |
|  1. NEXT-PATCH PREDICTION (NP3)                                  |
|     - Predict trajectories in 1-second chunks, not frame-by-frame|
|     - Forces the model to "think ahead" like humans do           |
|                                                                   |
|  2. FULLY AUTOREGRESSIVE DESIGN                                  |
|     - No encoder-decoder split                                   |
|     - Treats all trajectory data uniformly                       |
|                                                                   |
|  3. RELATIVE SPACETIME REPRESENTATION                            |
|     - Position-independent coordinate system                     |
|     - Maximizes learning from every data point                   |
|                                                                   |
|  4. TRIPLE ATTENTION MECHANISM                                   |
|     - Temporal: How does this agent move over time?              |
|     - Agent-Map: How does this agent interact with roads?        |
|     - Agent-Agent: How do agents interact with each other?       |
|                                                                   |
+------------------------------------------------------------------+

Key Results

┌───────────┬─────────────┬────────────────────────────────────┐
│ Metric    │ BehaviorGPT │ Meaning                            │
├───────────┼─────────────┼────────────────────────────────────┤
│ minADE    │ 1.4147      │ Lowest prediction error            │
│ REALISM   │ 0.7473      │ Best match to real-world behavior  │
│ COLLISION │ 0.9537      │ Collision rates match real logs    │
│ OFFROAD   │ 0.9349      │ Off-road rates match real logs     │
└───────────┴─────────────┴────────────────────────────────────┘

2. Architecture Deep Dive

The "GPT" in BehaviorGPT

Just like GPT predicts the next word in a sentence, BehaviorGPT predicts the next trajectory patch for all agents in a driving scene.

Language Model (GPT):
"The car drove down the" -> [road] -> [and] -> [turned] -> [left]

BehaviorGPT:
[Agent positions at t=0-1s] -> [t=1-2s] -> [t=2-3s] -> [t=3-4s]

Full Architecture Diagram

                         BehaviorGPT Architecture
 ═══════════════════════════════════════════════════════════════════

 INPUT STAGE
 ───────────────────────────────────────────────────────────────────

     Agent Trajectories              Map Elements
    ┌─────────────────┐            ┌─────────────────┐
    │  Agent 1: P1-P9 │            │  Road segments  │
    │  Agent 2: P1-P9 │            │  Lane markers   │
    │  ...            │            │  Traffic signs  │
    │  Agent N: P1-P9 │            │  Crosswalks     │
    └────────┬────────┘            └────────┬────────┘
             │                              │
             ▼                              ▼
    ┌─────────────────┐            ┌─────────────────┐
    │ Patch Encoder   │            │ Polyline Encoder│
    │ (MLP + PE)      │            │ (MLP + PE)      │
    └────────┬────────┘            └────────┬────────┘
             │                              │
             └──────────────┬───────────────┘
                            ▼

 TRANSFORMER BLOCKS (×L layers)
 ───────────────────────────────────────────────────────────────────

 For each layer l = 1 to L:

    ┌─────────────────────────────────────────────────────────────┐
    │                                                             │
    │   ┌─────────────────────────────────────────────────────┐  │
    │   │         1. TEMPORAL SELF-ATTENTION                  │  │
    │   │                                                     │  │
    │   │   Agent 1: [P1]──[P2]──[P3]──[P4]──[P5]──...      │  │
    │   │              │    │    │    │    │                 │  │
    │   │              └────┴────┴────┴────┘                 │  │
    │   │              (causal mask: can only attend left)   │  │
    │   │                                                     │  │
    │   │   "How has this agent been moving over time?"      │  │
    │   └─────────────────────────────────────────────────────┘  │
    │                            │                                │
    │                            ▼                                │
    │   ┌─────────────────────────────────────────────────────┐  │
    │   │         2. AGENT-MAP CROSS-ATTENTION                │  │
    │   │                                                     │  │
    │   │   Agent Patch ──────► k-Nearest Map Elements       │  │
    │   │        │                    │                       │  │
    │   │        └────────────────────┘                       │  │
    │   │         Query          Keys/Values                  │  │
    │   │                                                     │  │
    │   │   "What road features should this agent consider?" │  │
    │   └─────────────────────────────────────────────────────┘  │
    │                            │                                │
    │                            ▼                                │
    │   ┌─────────────────────────────────────────────────────┐  │
    │   │         3. AGENT-AGENT SELF-ATTENTION               │  │
    │   │                                                     │  │
    │   │      Agent 1 ◄──────► Agent 2                      │  │
    │   │         │    ╲      ╱    │                         │  │
    │   │         │     ╲    ╱     │                         │  │
    │   │         │      ╲  ╱      │                         │  │
    │   │         ▼       ╲╱       ▼                         │  │
    │   │      Agent 3 ◄──╳───► Agent 4                      │  │
    │   │              (k-nearest filtering)                  │  │
    │   │                                                     │  │
    │   │   "How are nearby agents affecting this one?"      │  │
    │   └─────────────────────────────────────────────────────┘  │
    │                                                             │
    └─────────────────────────────────────────────────────────────┘

 OUTPUT STAGE
 ───────────────────────────────────────────────────────────────────

                            │
                            ▼
              ┌──────────────────────────┐
              │     GMM Decoder          │
              │   (Gaussian Mixture)     │
              └────────────┬─────────────┘
                           │
                           ▼
              ┌──────────────────────────┐
              │  Per-agent predictions:  │
              │  - 16 mixture modes      │
              │  - Mean (μ) per mode     │
              │  - Covariance (Σ)        │
              │  - Mixture weights (π)   │
              └──────────────────────────┘

Model Specifications

# BehaviorGPT Hyperparameters
config = {
    "hidden_dim": 128,           # Embedding dimension
    "num_heads": 8,              # Attention heads
    "head_dim": 16,              # 128 / 8
    "num_layers": 6,             # Transformer blocks
    "patch_size": 10,            # Timesteps per patch (1 second @ 10Hz)
    "max_agents": 128,           # Maximum agents per scenario
    "num_modes": 16,             # GMM mixture components
    "k_nearest_map": 256,        # Map elements to attend to
    "k_nearest_agents": 32,      # Agents to attend to
    "total_params": "3M"         # Remarkably small!
}
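
As a sanity check on that last line, here is a back-of-envelope estimate of the transformer trunk's size from the config above. It assumes standard projections (4·d² weights per attention module) and a 4× feed-forward expansion; these are assumptions for illustration, not the paper's exact layer shapes.

# Rough parameter count (assumptions: 4*d^2 weights per attention module,
# 4x FFN expansion; biases, norms, encoders, and the GMM head ignored)
d = config["hidden_dim"]                    # 128
attn_params = 4 * d * d                     # Q, K, V, output projections
ffn_params = 2 * 4 * d * d                  # up- and down-projection
per_layer = 3 * attn_params + ffn_params    # triple attention + FFN
total = config["num_layers"] * per_layer    # 6 layers
print(f"~{total / 1e6:.1f}M params in the trunk")  # ~2.0M
# Patch/map encoders and the GMM head plausibly account for the rest of ~3M.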

3. Key Innovation: Next-Patch Prediction (NP3)

The Problem with Frame-by-Frame Prediction

Traditional approaches predict one timestep at a time (10Hz = every 0.1 seconds):

Traditional Approach:
┌─────┐    ┌─────┐    ┌─────┐    ┌─────┐
│ t=0 │───►│ t=1 │───►│ t=2 │───►│ t=3 │  ... (×80 for 8 seconds)
└─────┘    └─────┘    └─────┘    └─────┘
   │          │          │          │
   └──────────┴──────────┴──────────┘
            Easy shortcut: just copy previous!

The Problem: At 10Hz, the difference between consecutive frames is tiny. The model learns a lazy shortcut - just copy the previous position with minor adjustments. This leads to:

  • Compounding errors over time
  • Unrealistic "drifting" behavior
  • Poor long-horizon predictions
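
A toy simulation makes the compounding point concrete. Under a simplified error model where every autoregressive step adds independent Gaussian noise (an assumption for illustration, not the paper's analysis), endpoint drift grows with the square root of the step count, so 80 frame-level steps drift far more than 8 patch-level steps at equal per-step noise:

import numpy as np

rng = np.random.default_rng(0)
sigma = 0.1        # assumed per-step prediction error, in meters
num_trials = 10_000

def mean_final_drift(num_steps):
    """Average endpoint error after num_steps noisy autoregressive steps."""
    noise = rng.normal(0.0, sigma, size=(num_trials, num_steps, 2))
    return np.linalg.norm(noise.sum(axis=1), axis=-1).mean()

print(f"80 frame-level steps: {mean_final_drift(80):.2f} m drift")
print(f" 8 patch-level steps: {mean_final_drift(8):.2f} m drift")

Real per-step errors are correlated and each patch-level step is individually harder, but the direction of the effect holds: fewer autoregressive steps means less accumulated drift.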

The NP3 Solution

Next-Patch Prediction (NP3):
┌───────────────┐    ┌───────────────┐    ┌───────────────┐
│   Patch 1     │───►│   Patch 2     │───►│   Patch 3     │
│  (t=0 to 1s)  │    │  (t=1 to 2s)  │    │  (t=2 to 3s)  │
│  10 timesteps │    │  10 timesteps │    │  10 timesteps │
└───────────────┘    └───────────────┘    └───────────────┘
       │                    │                    │
       └────────────────────┴────────────────────┘
              Significant changes between patches!
              Model MUST reason about trajectory

Why Patches Work: An Analogy

Think of writing a story:

Word-by-Word (like frame-by-frame):
┌─────┐   ┌─────┐   ┌─────┐   ┌─────┐   ┌─────┐   ┌─────┐
│ The │──►│ car │──►│ was │──►│ on  │──►│ the │──►│road │
└─────┘   └─────┘   └─────┘   └─────┘   └─────┘   └─────┘
  Easy to just predict common word patterns without understanding!

Paragraph-by-Paragraph (like NP3):
┌──────────────────────┐    ┌──────────────────────┐
│ "The red sports car  │───►│ "It accelerated      │
│  was approaching     │    │  through the curve,  │
│  the intersection."  │    │  narrowly avoiding   │
│                      │    │  the truck."         │
└──────────────────────┘    └──────────────────────┘
  Must understand MEANING to generate coherent next paragraph!

Mathematical Formulation

The joint probability of all agent trajectories:

Standard (per-timestep):
P(S^{1:N}_{1:T} | M) = ∏_{t=1}^{T} ∏_{n=1}^{N} P(s^n_t | s^{1:N}_{1:t-1}, M)

NP3 (per-patch):
P(S^{1:N}_{1:T} | M) = ∏_{p} P(P_p | P_{1:p-1}, M)

Where:
- S = All agent states
- N = Number of agents
- T = Total timesteps
- M = Map information
- P = Trajectory patches (10 timesteps each)

Patch Decomposition

Within each patch, predictions decompose hierarchically:

                    Patch-Level Prediction
                           │
           ┌───────────────┼───────────────┐
           ▼               ▼               ▼
       Agent 1         Agent 2    ...  Agent N
           │               │               │
     ┌─────┴─────┐   ┌─────┴─────┐   ┌─────┴─────┐
     ▼     ▼     ▼   ▼     ▼     ▼   ▼     ▼     ▼
    t1    t2   ...  t1    t2   ...  t1    t2   ...

Each timestep within patch:
  P(s^n_t | s^n_{1:t-1}, P^{all}_{1:p-1}) ← Gaussian Mixture Model

4. Relative Spacetime Representation

The Coordinate Frame Problem

Traditional approaches use absolute coordinates or ego-centric views:

Absolute Coordinates Problem:
┌──────────────────────────────────────────────┐
│                                              │
│  Car A at (100, 200)                         │
│  Car B at (105, 203)                         │
│                                              │
│  Different scenarios have different coords!  │
│  Model must memorize locations = inefficient │
│                                              │
└──────────────────────────────────────────────┘

Ego-Centric Problem:
┌──────────────────────────────────────────────┐
│                                              │
│  Everything relative to ONE "ego" agent      │
│  But who is "ego" in multi-agent simulation? │
│  Different agents see different views!       │
│                                              │
└──────────────────────────────────────────────┘

BehaviorGPT's Solution: Relative Spacetime

Every element is described by its relationship to OTHER elements:

Relative Spacetime Representation
═════════════════════════════════════════════════════════════

For any two elements (patches, agents, or map features):

┌─────────────────────────────────────────────────────────┐
│                                                         │
│    Element i                    Element j               │
│        ●─────────────────────────────●                 │
│        │                             │                 │
│        │     Relative Features:      │                 │
│        │                             │                 │
│        │  • Distance: ||pᵢ - pⱼ||    │                 │
│        │  • Angle: atan2(Δy, Δx)     │                 │
│        │  • Yaw diff: θᵢ - θⱼ        │                 │
│        │  • Time diff: tᵢ - tⱼ       │                 │
│        │                             │                 │
└─────────────────────────────────────────────────────────┘

Plus Self Features:
• Speed, acceleration
• Agent type (vehicle, pedestrian, cyclist)
• Bounding box dimensions
• Valid/invalid mask

Feature Vector Composition

# For each trajectory patch, the feature vector includes:
patch_features = {
    # Semantic (self) features - position independent
    "speed": float,              # Current velocity magnitude
    "acceleration": float,       # Rate of change
    "yaw_rate": float,          # Turning rate
    "agent_type": one_hot,      # [vehicle, pedestrian, cyclist]
    "bbox": [length, width],    # Agent dimensions

    # Relative features - computed pairwise
    "relative_distance": float,  # Euclidean distance
    "relative_angle": float,     # Bearing angle
    "relative_yaw": float,       # Heading difference
    "time_offset": int,          # Temporal distance in patches
}

# Total feature dimension
d_input = d_semantic + d_relative  # Fed to MLP encoder
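
To make the schematic above concrete, here is one hypothetical way to flatten these fields into a single vector; the ordering and exact field set are illustrative, and the paper's layout may differ.

import numpy as np

def build_patch_feature(speed, accel, yaw_rate, agent_type_idx, bbox,
                        rel_dist, rel_angle, rel_yaw, time_offset):
    """Concatenate semantic and relative features into one input vector."""
    one_hot = np.eye(3)[agent_type_idx]   # [vehicle, pedestrian, cyclist]
    semantic = np.array([speed, accel, yaw_rate, *bbox])
    relative = np.array([rel_dist, rel_angle, rel_yaw, time_offset])
    return np.concatenate([semantic, one_hot, relative])

x = build_patch_feature(12.0, 0.5, 0.01, 0, (4.5, 1.9), 7.1, 0.2, -0.1, 1)
print(x.shape)  # (12,) -> d_semantic = 8, d_relative = 4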

Why This Matters: Data Efficiency

Traditional (Position-Dependent):
┌────────────────────────────────────────────────────────┐
│  Scenario 1: Intersection at (100, 200)               │
│  Scenario 2: Intersection at (5000, 3000)             │
│                                                        │
│  Same behavior, but model sees completely different   │
│  numbers! Must learn the same thing twice.            │
└────────────────────────────────────────────────────────┘

BehaviorGPT (Relative):
┌────────────────────────────────────────────────────────┐
│  Scenario 1: Car A is 5m behind Car B, approaching    │
│  Scenario 2: Car A is 5m behind Car B, approaching    │
│                                                        │
│  SAME representation regardless of absolute position! │
│  Every example teaches the same concept once.         │
└────────────────────────────────────────────────────────┘
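
This invariance is easy to verify numerically: apply the same global rotation and translation to every agent, and the pairwise relative features come out identical. A minimal sketch with made-up poses:

import numpy as np

def relative_features(pos_i, yaw_i, pos_j, yaw_j):
    """Pairwise relative features: distance, bearing in i's frame, yaw diff."""
    delta = pos_j - pos_i
    distance = np.linalg.norm(delta)
    bearing = np.arctan2(delta[1], delta[0]) - yaw_i
    yaw_diff = yaw_j - yaw_i
    return np.array([distance, bearing, yaw_diff])

def transform(pos, yaw, shift, angle):
    """Apply a global rotation followed by a translation to one pose."""
    c, s = np.cos(angle), np.sin(angle)
    R = np.array([[c, -s], [s, c]])
    return R @ pos + shift, yaw + angle

pos_a, yaw_a = np.array([100.0, 200.0]), 0.3
pos_b, yaw_b = np.array([105.0, 203.0]), 0.8

before = relative_features(pos_a, yaw_a, pos_b, yaw_b)
pos_a2, yaw_a2 = transform(pos_a, yaw_a, np.array([4900.0, 2800.0]), 1.2)
pos_b2, yaw_b2 = transform(pos_b, yaw_b, np.array([4900.0, 2800.0]), 1.2)
after = relative_features(pos_a2, yaw_a2, pos_b2, yaw_b2)

print(np.allclose(before, after))  # True: relative features are invariant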

5. Interactive Code Examples

5.1 Patch Encoder

import jax
import jax.numpy as jnp
from flax import linen as nn

class PatchEncoder(nn.Module):
    """Encodes a trajectory patch into a fixed-size embedding."""
    hidden_dim: int = 128
    patch_size: int = 10  # 1 second at 10Hz

    @nn.compact
    def __call__(self, patch_features):
        """
        Args:
            patch_features: [batch, num_agents, num_patches, patch_size, d_input]
        Returns:
            patch_embeddings: [batch, num_agents, num_patches, hidden_dim]
        """
        batch, num_agents, num_patches, patch_size, d_input = patch_features.shape

        # Flatten patch timesteps
        x = patch_features.reshape(batch, num_agents, num_patches, -1)
        # Shape: [batch, num_agents, num_patches, patch_size * d_input]

        # MLP encoder
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.LayerNorm()(x)

        # Add positional encoding for patch index
        patch_indices = jnp.arange(num_patches)
        pos_encoding = self._sinusoidal_encoding(patch_indices, self.hidden_dim)
        x = x + pos_encoding[None, None, :, :]  # Broadcast over batch and agents

        return x

    def _sinusoidal_encoding(self, positions, dim):
        """Standard sinusoidal positional encoding."""
        pe = jnp.zeros((len(positions), dim))
        div_term = jnp.exp(jnp.arange(0, dim, 2) * -(jnp.log(10000.0) / dim))
        pe = pe.at[:, 0::2].set(jnp.sin(positions[:, None] * div_term))
        pe = pe.at[:, 1::2].set(jnp.cos(positions[:, None] * div_term))
        return pe
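
A quick shape check of the encoder above, using dummy dimensions (the input feature width of 9 is an arbitrary placeholder here):

# Run the encoder on dummy inputs to verify the output shape
key = jax.random.PRNGKey(0)
dummy = jnp.zeros((2, 8, 4, 10, 9))  # [batch, agents, patches, patch_size, d_input]

encoder = PatchEncoder(hidden_dim=128, patch_size=10)
params = encoder.init(key, dummy)
out = encoder.apply(params, dummy)
print(out.shape)  # (2, 8, 4, 128)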

5.2 Triple Attention Block

class TripleAttentionBlock(nn.Module):
    """One layer of BehaviorGPT's triple attention mechanism."""
    hidden_dim: int = 128
    num_heads: int = 8
    k_nearest_map: int = 256
    k_nearest_agents: int = 32

    @nn.compact
    def __call__(self, agent_embeddings, map_embeddings, agent_positions):
        """
        Args:
            agent_embeddings: [batch, num_agents, num_patches, hidden_dim]
            map_embeddings: [batch, num_map_elements, hidden_dim]
            agent_positions: [batch, num_agents, num_patches, 2] (x, y)
        """
        batch, num_agents, num_patches, hidden_dim = agent_embeddings.shape

        # ═══════════════════════════════════════════════════════════════
        # STEP 1: Temporal Self-Attention (per agent, across time)
        # ═══════════════════════════════════════════════════════════════

        # Reshape to [batch * num_agents, num_patches, hidden_dim]
        x = agent_embeddings.reshape(batch * num_agents, num_patches, hidden_dim)

        # Causal mask: each patch can only attend to previous patches
        causal_mask = jnp.tril(jnp.ones((num_patches, num_patches)))

        x = x + CausalSelfAttention(
            num_heads=self.num_heads,
            name="temporal_attention"
        )(x, mask=causal_mask)

        x = x.reshape(batch, num_agents, num_patches, hidden_dim)

        # ═══════════════════════════════════════════════════════════════
        # STEP 2: Agent-Map Cross-Attention
        # ═══════════════════════════════════════════════════════════════

        # For each agent-patch, find k-nearest map elements
        # agent_positions: [batch, num_agents, num_patches, 2]
        # map_positions: [batch, num_map_elements, 2] (centroid of each element)

        map_positions = self._get_map_centroids(map_embeddings)

        # Compute distances and select k-nearest
        # This creates a sparse attention pattern
        nearest_map_indices = self._k_nearest(
            agent_positions, map_positions, k=self.k_nearest_map
        )

        # Gather relevant map embeddings
        selected_map = jnp.take_along_axis(
            map_embeddings[:, None, None, :, :],
            nearest_map_indices[:, :, :, :, None],
            axis=3
        )

        # Cross-attention: agents query, map elements are keys/values
        x = x + CrossAttention(
            num_heads=self.num_heads,
            name="agent_map_attention"
        )(query=x, key_value=selected_map)

        # ═══════════════════════════════════════════════════════════════
        # STEP 3: Agent-Agent Self-Attention
        # ═══════════════════════════════════════════════════════════════

        # For each patch timestep, attend across agents
        # Reshape to [batch * num_patches, num_agents, hidden_dim]
        x = x.transpose(0, 2, 1, 3).reshape(batch * num_patches, num_agents, hidden_dim)

        # Find k-nearest agents at each timestep
        agent_pos_per_patch = agent_positions.transpose(0, 2, 1, 3)
        agent_pos_per_patch = agent_pos_per_patch.reshape(batch * num_patches, num_agents, 2)

        # Create sparse attention mask based on k-nearest
        agent_attention_mask = self._create_knn_mask(
            agent_pos_per_patch, k=self.k_nearest_agents
        )

        x = x + SelfAttention(
            num_heads=self.num_heads,
            name="agent_agent_attention"
        )(x, mask=agent_attention_mask)

        # Reshape back
        x = x.reshape(batch, num_patches, num_agents, hidden_dim)
        x = x.transpose(0, 2, 1, 3)  # [batch, num_agents, num_patches, hidden_dim]

        # Feed-forward network
        x = x + FeedForward(hidden_dim=hidden_dim)(x)

        return x

5.3 GMM Output Head

class GMMOutputHead(nn.Module):
    """Predicts Gaussian Mixture Model parameters for next patch."""
    hidden_dim: int = 128
    num_modes: int = 16
    patch_size: int = 10
    output_dim: int = 2  # (x, y) positions

    @nn.compact
    def __call__(self, hidden_states):
        """
        Args:
            hidden_states: [batch, num_agents, num_patches, hidden_dim]
        Returns:
            gmm_params: Dictionary with mixture parameters
        """
        batch, num_agents, num_patches, _ = hidden_states.shape

        # Project to GMM parameters
        # For each mode: mean, log_std, and mixture weight

        # Mixture logits: [batch, num_agents, num_patches, num_modes]
        mixture_logits = nn.Dense(self.num_modes, name="mixture_logits")(hidden_states)
        mixture_weights = jax.nn.softmax(mixture_logits, axis=-1)

        # Per-mode predictions
        # Shape: [batch, num_agents, num_patches, num_modes, patch_size * output_dim]
        means = nn.Dense(
            self.num_modes * self.patch_size * self.output_dim,
            name="means"
        )(hidden_states)
        means = means.reshape(
            batch, num_agents, num_patches, self.num_modes,
            self.patch_size, self.output_dim
        )

        # Log standard deviations (for numerical stability)
        log_stds = nn.Dense(
            self.num_modes * self.patch_size * self.output_dim,
            name="log_stds"
        )(hidden_states)
        log_stds = log_stds.reshape(
            batch, num_agents, num_patches, self.num_modes,
            self.patch_size, self.output_dim
        )
        stds = jnp.exp(jnp.clip(log_stds, -5, 5))  # Clip for stability

        return {
            "mixture_weights": mixture_weights,  # [B, N, P, M]
            "means": means,                       # [B, N, P, M, T, 2]
            "stds": stds,                         # [B, N, P, M, T, 2]
        }

    def sample(self, gmm_params, rng_key):
        """Sample trajectories from the GMM."""
        weights = gmm_params["mixture_weights"]
        means = gmm_params["means"]
        stds = gmm_params["stds"]

        # Sample mode for each agent
        mode_key, sample_key = jax.random.split(rng_key)
        mode_indices = jax.random.categorical(mode_key, jnp.log(weights + 1e-8), axis=-1)

        # Gather selected mode parameters
        # This is a simplified version - actual implementation uses vmap
        selected_means = jnp.take_along_axis(
            means, mode_indices[..., None, None, None], axis=3
        ).squeeze(3)
        selected_stds = jnp.take_along_axis(
            stds, mode_indices[..., None, None, None], axis=3
        ).squeeze(3)

        # Sample from Gaussian
        noise = jax.random.normal(sample_key, selected_means.shape)
        samples = selected_means + noise * selected_stds

        return samples

5.4 Training Loss

def compute_nll_loss(gmm_params, ground_truth, valid_mask):
    """
    Negative log-likelihood loss for GMM predictions.

    Args:
        gmm_params: Dict with 'mixture_weights', 'means', 'stds'
        ground_truth: [batch, num_agents, num_patches, patch_size, 2]
        valid_mask: [batch, num_agents, num_patches] boolean
    """
    weights = gmm_params["mixture_weights"]  # [B, N, P, M]
    means = gmm_params["means"]              # [B, N, P, M, T, 2]
    stds = gmm_params["stds"]                # [B, N, P, M, T, 2]

    # Expand ground truth for broadcasting with modes
    gt_expanded = ground_truth[:, :, :, None, :, :]  # [B, N, P, 1, T, 2]

    # Compute Gaussian log-probabilities for each mode
    # log p(x|mode) = -0.5 * ((x - μ)² / σ² + log(2π) + 2*log(σ))
    var = stds ** 2
    log_probs = -0.5 * (
        ((gt_expanded - means) ** 2) / var +
        jnp.log(2 * jnp.pi) +
        2 * jnp.log(stds)
    )

    # Sum over timesteps and dimensions within each mode
    log_probs = log_probs.sum(axis=(-1, -2))  # [B, N, P, M]

    # Log-sum-exp over modes (mixture probability)
    # log p(x) = log Σ_m π_m * p(x|m) = logsumexp(log π_m + log p(x|m))
    log_mixture_probs = jax.scipy.special.logsumexp(
        jnp.log(weights + 1e-8) + log_probs,
        axis=-1
    )  # [B, N, P]

    # Apply valid mask and compute mean
    nll = -log_mixture_probs
    nll = jnp.where(valid_mask, nll, 0.0)
    loss = nll.sum() / (valid_mask.sum() + 1e-8)

    return loss
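
A smoke test with random tensors (shapes follow the docstring; values are arbitrary) confirms the loss evaluates to a finite scalar:

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
B, N, P, M, T = 2, 4, 3, 16, 10

gmm_params = {
    "mixture_weights": jax.nn.softmax(jax.random.normal(k1, (B, N, P, M)), axis=-1),
    "means": jax.random.normal(k2, (B, N, P, M, T, 2)),
    "stds": jnp.ones((B, N, P, M, T, 2)),
}
ground_truth = jax.random.normal(k3, (B, N, P, T, 2))
valid_mask = jnp.ones((B, N, P), dtype=bool)

print(compute_nll_loss(gmm_params, ground_truth, valid_mask))  # finite scalar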

5.5 Autoregressive Inference Loop

def autoregressive_rollout(model, initial_history, map_data, num_future_patches, rng_key):
    """
    Generate future trajectories autoregressively.

    Args:
        model: Trained BehaviorGPT model
        initial_history: [batch, num_agents, history_patches, patch_size, features]
        map_data: Static map information
        num_future_patches: Number of 1-second patches to predict
        rng_key: JAX random key
    """
    batch_size, num_agents, _, patch_size, _ = initial_history.shape

    # Start with history
    current_sequence = initial_history
    all_predictions = []

    for patch_idx in range(num_future_patches):
        # Split random key for this iteration
        rng_key, sample_key = jax.random.split(rng_key)

        # Forward pass through model
        # Model sees all patches up to current time
        gmm_params = model.apply(
            {"params": model.params},
            agent_patches=current_sequence,
            map_data=map_data
        )

        # Extract prediction for the LAST patch (next-patch prediction)
        last_patch_params = {
            "mixture_weights": gmm_params["mixture_weights"][:, :, -1:, :],
            "means": gmm_params["means"][:, :, -1:, :, :, :],
            "stds": gmm_params["stds"][:, :, -1:, :, :, :],
        }

        # Sample next patch from GMM
        next_patch = model.output_head.sample(last_patch_params, sample_key)
        # Shape: [batch, num_agents, 1, patch_size, 2]

        # Convert sampled positions to full features
        next_patch_features = compute_patch_features(
            next_patch,
            current_sequence[:, :, -1:, :, :]  # Previous patch for relative features
        )

        # Append to sequence
        current_sequence = jnp.concatenate(
            [current_sequence, next_patch_features], axis=2
        )

        all_predictions.append(next_patch)

        # Print progress (for debugging)
        print(f"Generated patch {patch_idx + 1}/{num_future_patches}")

    # Stack all predictions
    future_trajectory = jnp.concatenate(all_predictions, axis=2)
    # Shape: [batch, num_agents, num_future_patches, patch_size, 2]

    return future_trajectory

6. Training Strategy

How to Train a Champion with 3M Parameters

┌─────────────────────────────────────────────────────────────────┐
│              BehaviorGPT Training Configuration                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Dataset: Waymo Open Motion Dataset (WOMD)                     │
│  ─────────────────────────────────────────                     │
│  • 570+ hours of driving data                                  │
│  • 1,750 km of unique roadways                                 │
│  • 10 Hz sampling rate                                         │
│  • Rich annotations: vehicles, pedestrians, cyclists           │
│                                                                 │
│  Hardware                                                       │
│  ─────────                                                      │
│  • 8× NVIDIA RTX 4090 GPUs                                     │
│  • Distributed data parallel training                          │
│                                                                 │
│  Optimization                                                   │
│  ────────────                                                   │
│  • Optimizer: AdamW                                            │
│  • Learning rate: 5 × 10⁻⁴                                     │
│  • LR schedule: Cosine annealing with warm restarts            │
│  • Weight decay: 0.1                                           │
│  • Dropout: 0.1                                                │
│  • Gradient clipping: 1.0                                      │
│                                                                 │
│  Training Schedule                                              │
│  ─────────────────                                              │
│  • Epochs: 30                                                  │
│  • Batch size: 24 scenarios                                    │
│  • Training time: ~24 hours                                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Training Curriculum

Phase 1: Basic Motion Patterns (Epochs 1-10)
═══════════════════════════════════════════
• Model learns fundamental vehicle dynamics
• Straight-line motion, basic turns
• Speed profiles matching road types

Phase 2: Multi-Agent Coordination (Epochs 11-20)
══════════════════════════════════════════════════
• Focus on agent-agent interactions
• Following distance, lane changes
• Yielding and merging behaviors

Phase 3: Complex Scenarios (Epochs 21-30)
═════════════════════════════════════════
• Fine-tuning on challenging cases
• Intersections, roundabouts
• Rare but critical events

Learning Rate Schedule

Learning Rate
     ↑
5e-4 │    ╭──────╮
     │   ╱        ╲
     │  ╱          ╲         ╭──────╮
     │ ╱            ╲       ╱        ╲
     │╱              ╲     ╱          ╲
1e-5 │────────────────╲───╱────────────╲──────────►
     │                 ╲_╱              ╲_________
     └─────────────────────────────────────────────►
         5      10     15     20     25     30    Epoch

         Cosine Annealing with Warm Restarts
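
In optax, this schedule and optimizer stack might look like the sketch below. The cycle length and warmup steps are illustrative assumptions; the document gives the optimizer family, peak rate, weight decay, and clipping, but not the exact restart boundaries.

import optax

steps_per_cycle = 10_000  # assumed; actual cycle length not specified

def one_cycle():
    """One warmup + cosine-decay cycle; repeating it mimics warm restarts."""
    return optax.warmup_cosine_decay_schedule(
        init_value=1e-5, peak_value=5e-4,
        warmup_steps=1_000, decay_steps=steps_per_cycle, end_value=1e-5,
    )

schedule = optax.join_schedules(
    schedules=[one_cycle(), one_cycle()],
    boundaries=[steps_per_cycle],
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                          # gradient clipping
    optax.adamw(learning_rate=schedule, weight_decay=0.1),   # AdamW
)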

Data Augmentation

def augment_scenario(scenario, rng_key):
    """Augmentations that preserve relative spacetime representation."""

    # These augmentations are FREE with relative representation!
    augmentations = {
        # Global translation: Relative distances unchanged
        "translate": random_translation(rng_key),

        # Global rotation: Relative angles unchanged
        "rotate": random_rotation(rng_key),

        # Mirror flip: Still valid driving scenario
        "flip": random_horizontal_flip(rng_key),
    }

    # Note: Because BehaviorGPT uses RELATIVE features, global
    # translations and rotations leave the inputs unchanged: the model
    # is invariant to them by construction, so it never wastes capacity
    # learning that invariance. Mirror flips DO alter relative angles,
    # so flipping remains a genuine augmentation.

    return apply_augmentations(scenario, augmentations)

Ablation Study Results

┌─────────────────────────────────────────────────────────────────┐
│                     Ablation Study Results                       │
├───────────────────────────┬───────────┬───────────┬─────────────┤
│ Configuration             │  REALISM  │ COLLISION │   minADE    │
├───────────────────────────┼───────────┼───────────┼─────────────┤
│ Full BehaviorGPT          │   0.747   │   0.954   │    1.415    │
│ w/o Patching (patch=1)    │   0.712   │   0.923   │    1.582    │
│ w/o Agent-Agent Attention │   0.681   │   0.624   │    1.498    │
│ w/o Agent-Map Attention   │   0.698   │   0.912   │    1.521    │
│ w/o Temporal Attention    │   0.654   │   0.889   │    1.634    │
│ Patch size = 5            │   0.731   │   0.941   │    1.446    │
│ Patch size = 20           │   0.724   │   0.932   │    1.478    │
└───────────────────────────┴───────────┴───────────┴─────────────┘

Key Findings:
• Patching provides +4.9% REALISM improvement
• Agent-Agent Attention is CRITICAL: -34.6% collision realism without it
• Patch size 10 (1 second) is optimal for 10Hz data

7. WOSAC Metrics Breakdown

Understanding the Challenge

The Waymo Open Sim Agents Challenge (WOSAC) evaluates how well simulated agents match real-world driving behavior.

WOSAC Evaluation Pipeline
═════════════════════════════════════════════════════════════════

Input: Scenario with 1.1 s of history (11 frames at 10 Hz)
       ↓
┌──────────────────────────────────────────────────────────────┐
│                                                              │
│   Model generates 32 simulation replicas                     │
│   Each replica: 8 seconds of future trajectories             │
│   For ALL agents simultaneously                              │
│                                                              │
└──────────────────────────────────────────────────────────────┘
       ↓
┌──────────────────────────────────────────────────────────────┐
│                                                              │
│   Compare simulated distribution vs. real-world logs         │
│                                                              │
│   Metrics:                                                   │
│   ├── Realism Meta-Metric (overall score)                   │
│   ├── Kinematic Metrics (motion quality)                    │
│   ├── Interactive Metrics (multi-agent behavior)            │
│   └── Map-based Metrics (road compliance)                   │
│                                                              │
└──────────────────────────────────────────────────────────────┘
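
Producing the 32 replicas is embarrassingly parallel over random seeds. A sketch using the autoregressive_rollout function from Section 5.5 (model, history, and map_data are assumed in scope; a plain Python loop is used for clarity):

# One rollout per replica, each with its own RNG key
rng = jax.random.PRNGKey(0)
replica_keys = jax.random.split(rng, 32)

replicas = [
    autoregressive_rollout(model, history, map_data,
                           num_future_patches=8,  # 8 seconds
                           rng_key=k)
    for k in replica_keys
]
replicas = jnp.stack(replicas)
# Shape: [32, batch, num_agents, 8, patch_size, 2]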

Metric Definitions

1. REALISM (Meta-Metric)

REALISM = Aggregate score combining all metrics
         Higher is better (max 1.0)

BehaviorGPT: 0.7473 (1st place!)

Interpretation:
┌────────────────────────────────────────────────────────────┐
│  0.0 ─────────── 0.5 ─────────── 0.75 ─────────── 1.0    │
│   │               │               ↑                  │    │
│ Random         Baseline      BehaviorGPT        Perfect   │
│ noise          models            ★                        │
└────────────────────────────────────────────────────────────┘

2. minADE (Minimum Average Displacement Error)

minADE = min over modes ( average L2 distance to ground truth )

Formula:
                   1   T
minADE = min_k   ─── Σ  ||ŝₜᵏ - sₜ||₂
                   T  t=1

Where:
• k = index over predicted modes
• T = prediction horizon (80 timesteps)
• ŝₜᵏ = predicted position at time t, mode k
• sₜ = ground truth position at time t

BehaviorGPT: 1.4147 meters (best in challenge!)

Visual Explanation:
═══════════════════

Ground Truth:     ●────●────●────●────●────●

Predicted Mode 1: ○───○───○───○───○───○  (ADE = 2.1m)
Predicted Mode 2: ◇────◇────◇────◇────◇────◇  (ADE = 1.4m) ✓ min
Predicted Mode 3: △──△──△──△──△──△  (ADE = 3.2m)

minADE = 1.4m (best mode is selected)
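
The formula transcribes directly into numpy. A minimal sketch reproducing the toy numbers from the diagram (shapes and offsets are made up for illustration):

import numpy as np

def min_ade(pred_modes, ground_truth):
    """
    pred_modes:   [K, T, 2] one predicted trajectory per mode
    ground_truth: [T, 2] logged trajectory
    """
    # Per-mode ADE: L2 distance to ground truth, averaged over the horizon
    ade_per_mode = np.linalg.norm(pred_modes - ground_truth, axis=-1).mean(axis=-1)
    return ade_per_mode.min()

gt = np.stack([np.arange(6.0), np.zeros(6)], axis=-1)  # straight-line log
modes = np.stack([
    gt + np.array([0.0, 2.1]),   # mode 1: constant 2.1 m offset
    gt + np.array([0.0, 1.4]),   # mode 2: constant 1.4 m offset
    gt + np.array([0.0, 3.2]),   # mode 3: constant 3.2 m offset
])
print(f"{min_ade(modes, gt):.1f}")  # 1.4 (best mode wins)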

3. COLLISION Score

COLLISION = likelihood that simulated collision statistics match the
            real-world (logged) distribution; higher is better (max 1.0)

Measures: Do simulated agents collide at similar rates to real world?

Real World:   ~2% of agent pairs have close encounters
Simulation:   Should match this rate, not be collision-free!

BehaviorGPT: 0.9537

Note: A "perfect" safety system with 0 collisions would score POORLY!
      We want REALISTIC collision rates, not zero collisions.

Why This Matters:
═══════════════════════════════════════════════════════════════

Too Few Collisions (Overly Cautious):
┌─────────────────────────────────────┐
│  AV:  "I'll stop here to be safe"  │
│        🚗 ─────────── STOP ──────  │
│                                     │
│  Real humans don't drive this way! │
│  Unrealistic for training.         │
└─────────────────────────────────────┘

Too Many Collisions (Too Aggressive):
┌─────────────────────────────────────┐
│  AV:  "YOLO!"                       │
│        🚗 💥 🚙                      │
│                                     │
│  Dangerous! Bad for safety testing. │
└─────────────────────────────────────┘

Just Right (BehaviorGPT):
┌─────────────────────────────────────┐
│  Matches real-world collision rates │
│  Enables realistic safety testing   │
└─────────────────────────────────────┘

4. OFFROAD Score

OFFROAD = likelihood that simulated off-road statistics match the
          real-world (logged) distribution; higher is better (max 1.0)

Measures: Do agents stay on valid road surfaces?

BehaviorGPT: 0.9349

Components:
• Lane boundary violations
• Driving on sidewalks
• Entering non-drivable areas

Visual:
═══════

                  Road Surface
         ┌────────────────────────────┐
         │ ════════════════════════   │
    Bad: │ ══ 🚗 ════════════════════ │ (off-road!)
         │     ↓                      │
         └────────────────────────────┘
              │
              │ OFFROAD
              │ penalty
              ↓

         ┌────────────────────────────┐
         │ ════════════════════════   │
   Good: │ ═══════ 🚗 ═══════════════ │ (on-road!)
         │                            │
         └────────────────────────────┘

WOSAC 2024 Leaderboard

┌─────────────────────────────────────────────────────────────────┐
│                    WOSAC 2024 Final Results                      │
├──────┬─────────────────────┬──────────┬───────────┬─────────────┤
│ Rank │ Method              │ Params   │  REALISM  │   minADE    │
├──────┼─────────────────────┼──────────┼───────────┼─────────────┤
│  🥇  │ BehaviorGPT         │   3M     │   0.747   │    1.415    │
│  🥈  │ MVTE                │   45M    │   0.731   │    1.523    │
│  🥉  │ Trajeglish          │   120M   │   0.718   │    1.489    │
│  4   │ MotionLM            │   523M   │   0.702   │    1.567    │
│  5   │ SceneTransformer    │   28M    │   0.689   │    1.612    │
└──────┴─────────────────────┴──────────┴───────────┴─────────────┘

Key Insight: BehaviorGPT achieves BEST results with FEWEST parameters!
             This is 175x smaller than MotionLM, yet performs better.

8. Mental Models

Mental Model 1: The Driving Instructor Analogy

Traditional Frame-by-Frame Prediction:
══════════════════════════════════════════════════════════════════

Student: "What do I do next?"
Instructor: "Move your foot 0.1 inches..."
Student: "Now what?"
Instructor: "Move it 0.1 more inches..."

Problem: Student never learns to PLAN, just follows micro-instructions.


BehaviorGPT with NP3:
══════════════════════════════════════════════════════════════════

Student: "What do I do next?"
Instructor: "In the next second, you'll approach the intersection,
            begin braking, and prepare to turn right."
Student: "Got it!" (executes the whole maneuver)

Result: Student learns to THINK AHEAD in meaningful chunks.

Mental Model 2: The Orchestra Conductor

Multi-Agent Coordination in BehaviorGPT:
════════════════════════════════════════════════════════════════

                        🎭 Conductor (Model)
                             │
          ┌──────────────────┼──────────────────┐
          │                  │                  │
          ▼                  ▼                  ▼
      🎻 Violin          🎺 Trumpet        🥁 Drums
      (Car A)            (Car B)          (Car C)

Each musician (agent):
• Has their own score (temporal attention)
• Watches the conductor (agent-agent attention)
• Follows the venue layout (agent-map attention)

The conductor doesn't micromanage every note.
Instead, they guide each PHRASE (patch) of music.

Mental Model 3: The Story Continuation

GPT Language Model:
═══════════════════════════════════════════════════════════════

Context: "The sports car approached the red light. The driver..."

Prediction:  "...slowed down gradually and came to a smooth stop,
              noticing the pedestrian crossing ahead."

The model completes a coherent CONTINUATION of the story.


BehaviorGPT:
═══════════════════════════════════════════════════════════════

Context: [1 second of observed driving: car approaching intersection]

Prediction: [Next 1 second: car decelerating, preparing to stop]
            [Next 1 second: car stopped at light]
            [Next 1 second: car begins moving as light turns green]
            ...

The model completes a coherent CONTINUATION of the driving scenario.

Mental Model 4: Information Flow

How Information Flows Through BehaviorGPT:
════════════════════════════════════════════════════════════════

                    Time →

Agent 1:  [Past]─────[Past]─────[Current]─────?─────?─────?
             │          │           │         │     │     │
             └──────────┴───────────┘         │     │     │
                    │                         │     │     │
                    │ Temporal Self-Attention │     │     │
                    │ "What did I do before?" │     │     │
                    │                         │     │     │
                    ▼                         │     │     │
             ┌─────────────────┐              │     │     │
             │   Agent State   │──────────────┘     │     │
             └────────┬────────┘                    │     │
                      │                             │     │
                      │ Agent-Agent Attention       │     │
                      │ "What are others doing?"    │     │
                      │                             │     │
                      ▼                             │     │
             ┌─────────────────┐                    │     │
             │  Social Context │────────────────────┘     │
             └────────┬────────┘                          │
                      │                                   │
                      │ Agent-Map Attention               │
                      │ "What does the road allow?"       │
                      │                                   │
                      ▼                                   │
             ┌─────────────────┐                          │
             │   Full Context  │──────────────────────────┘
             └────────┬────────┘
                      │
                      ▼
             ┌─────────────────┐
             │  Predict Next   │
             │     Patch       │
             └─────────────────┘

Mental Model 5: The Patch Prediction Process

Step-by-Step Patch Prediction:
════════════════════════════════════════════════════════════════

Time (seconds):   0    1    2    3    4    5    6    7    8    9
                  ├────┼────┼────┼────┼────┼────┼────┼────┼────┤

History:         [====][====]
                 Patch Patch
                   0     1

Step 1:          [====][====][????]
                              │
                              └─ Predict Patch 2 from context

Step 2:          [====][====][====][????]
                                    │
                                    └─ Predict Patch 3 from context

Step 3:          [====][====][====][====][????]
                                          │
                                          └─ Predict Patch 4

...continue until desired horizon...

Final:           [====][====][====][====][====][====][====][====][====]
                 ├── History ──┤├───────── Predicted ────────────────┤

9. Hands-On Exercises

Exercise 1: Understanding Patch Formation

Goal: Convert raw trajectory data into patches

# Given: Raw trajectory at 10Hz
raw_trajectory = [
    (0.0, 0.0),   # t=0.0s
    (0.5, 0.1),   # t=0.1s
    (1.0, 0.2),   # t=0.2s
    (1.5, 0.3),   # t=0.3s
    (2.0, 0.4),   # t=0.4s
    (2.5, 0.5),   # t=0.5s
    (3.0, 0.6),   # t=0.6s
    (3.5, 0.7),   # t=0.7s
    (4.0, 0.8),   # t=0.8s
    (4.5, 0.9),   # t=0.9s
    (5.0, 1.0),   # t=1.0s
    (5.5, 1.1),   # t=1.1s
    # ... continues
]

# TODO: Implement this function
def create_patches(trajectory, patch_size=10):
    """
    Convert raw trajectory to patches.

    Args:
        trajectory: List of (x, y) positions at 10Hz
        patch_size: Number of timesteps per patch

    Returns:
        List of patches, where each patch is a list of positions
    """
    patches = []
    # Your code here
    pass

# Expected output:
# patches[0] = [(0.0, 0.0), (0.5, 0.1), ..., (4.5, 0.9)]  # First 1 second
# patches[1] = [(5.0, 1.0), (5.5, 1.1), ..., (9.5, 1.9)]  # Second 1 second
<details> <summary>Solution</summary>
def create_patches(trajectory, patch_size=10):
    patches = []
    for i in range(0, len(trajectory) - patch_size + 1, patch_size):
        patch = trajectory[i:i + patch_size]
        patches.append(patch)
    return patches
</details>

Exercise 2: Computing Relative Features

Goal: Understand the relative spacetime representation

import math

# Given: Two agents' states
agent_a = {
    "position": (10.0, 20.0),
    "heading": math.pi / 4,  # 45 degrees (NE)
    "speed": 15.0,
}

agent_b = {
    "position": (15.0, 25.0),
    "heading": math.pi / 2,  # 90 degrees (N)
    "speed": 10.0,
}

# TODO: Implement relative feature computation
def compute_relative_features(agent_a, agent_b):
    """
    Compute relative features between two agents.

    Returns:
        dict with:
        - distance: Euclidean distance
        - bearing: Angle from A to B (in A's frame)
        - relative_heading: Heading difference
        - relative_speed: Speed difference
    """
    # Your code here
    pass

# Expected output:
# {
#     "distance": 7.07,  # sqrt((15-10)^2 + (25-20)^2)
#     "bearing": 0.785,  # 45 degrees relative to A's heading
#     "relative_heading": 0.785,  # B is heading 45 deg left of A
#     "relative_speed": -5.0,  # B is 5 m/s slower
# }
<details> <summary>Solution</summary>
def compute_relative_features(agent_a, agent_b):
    # Distance
    dx = agent_b["position"][0] - agent_a["position"][0]
    dy = agent_b["position"][1] - agent_a["position"][1]
    distance = math.sqrt(dx**2 + dy**2)

    # Bearing (angle from A to B, relative to A's heading)
    global_angle = math.atan2(dy, dx)
    bearing = global_angle - agent_a["heading"]
    # Normalize to [-pi, pi]
    bearing = math.atan2(math.sin(bearing), math.cos(bearing))

    # Relative heading
    relative_heading = agent_b["heading"] - agent_a["heading"]
    relative_heading = math.atan2(math.sin(relative_heading),
                                   math.cos(relative_heading))

    # Relative speed
    relative_speed = agent_b["speed"] - agent_a["speed"]

    return {
        "distance": round(distance, 2),
        "bearing": round(bearing, 3),
        "relative_heading": round(relative_heading, 3),
        "relative_speed": relative_speed,
    }
</details>

Exercise 3: Causal Attention Mask

Goal: Understand the temporal attention pattern

import numpy as np

def create_causal_mask(num_patches):
    """
    Create a causal attention mask for temporal self-attention.

    A patch can only attend to itself and previous patches.

    Args:
        num_patches: Number of trajectory patches

    Returns:
        mask: [num_patches, num_patches] boolean array
              True = can attend, False = cannot attend
    """
    # Your code here
    pass

# Test with 5 patches
mask = create_causal_mask(5)
print(mask)

# Expected output:
# [[ True False False False False]
#  [ True  True False False False]
#  [ True  True  True False False]
#  [ True  True  True  True False]
#  [ True  True  True  True  True]]
#
# Patch 0: Can only see itself
# Patch 1: Can see patches 0 and 1
# Patch 2: Can see patches 0, 1, and 2
# etc.
<details> <summary>Solution</summary>
def create_causal_mask(num_patches):
    return np.tril(np.ones((num_patches, num_patches), dtype=bool))
</details>

Exercise 4: GMM Sampling

Goal: Sample trajectories from a Gaussian Mixture Model

import numpy as np

def sample_from_gmm(mixture_weights, means, stds, num_samples=1):
    """
    Sample trajectories from a GMM.

    Args:
        mixture_weights: [num_modes] probabilities summing to 1
        means: [num_modes, num_timesteps, 2] mean positions
        stds: [num_modes, num_timesteps, 2] standard deviations
        num_samples: Number of trajectories to sample

    Returns:
        samples: [num_samples, num_timesteps, 2] sampled trajectories
    """
    samples = []

    for _ in range(num_samples):
        # Step 1: Sample a mode index based on mixture weights
        mode_idx = None  # Your code here

        # Step 2: Get mean and std for selected mode
        selected_mean = None  # Your code here
        selected_std = None   # Your code here

        # Step 3: Sample from Gaussian
        sample = None  # Your code here

        samples.append(sample)

    return np.array(samples)

# Test case
np.random.seed(42)
weights = np.array([0.5, 0.3, 0.2])  # 3 modes
means = np.array([
    [[0, 0], [1, 0], [2, 0]],    # Mode 0: Go straight
    [[0, 0], [1, 1], [2, 2]],    # Mode 1: Go diagonal
    [[0, 0], [0, 1], [0, 2]],    # Mode 2: Go up
])
stds = np.ones_like(means) * 0.1

trajectories = sample_from_gmm(weights, means, stds, num_samples=5)
print(trajectories.shape)  # Should be (5, 3, 2)
<details> <summary>Solution</summary>
def sample_from_gmm(mixture_weights, means, stds, num_samples=1):
    samples = []

    for _ in range(num_samples):
        # Step 1: Sample mode index
        mode_idx = np.random.choice(len(mixture_weights), p=mixture_weights)

        # Step 2: Get parameters for selected mode
        selected_mean = means[mode_idx]
        selected_std = stds[mode_idx]

        # Step 3: Sample from Gaussian
        noise = np.random.randn(*selected_mean.shape)
        sample = selected_mean + noise * selected_std

        samples.append(sample)

    return np.array(samples)
</details>

Exercise 5: K-Nearest Neighbor Selection

Goal: Implement the k-NN filtering for attention

import numpy as np

def select_k_nearest(query_positions, key_positions, k):
    """
    For each query, find the k nearest keys.

    Args:
        query_positions: [num_queries, 2] array of (x, y)
        key_positions: [num_keys, 2] array of (x, y)
        k: Number of nearest neighbors to select

    Returns:
        indices: [num_queries, k] indices of k-nearest keys
        distances: [num_queries, k] distances to k-nearest keys
    """
    # Your code here
    pass

# Test case
np.random.seed(42)
queries = np.array([
    [0, 0],
    [10, 10],
])
keys = np.array([
    [1, 1],    # Close to query 0
    [9, 9],    # Close to query 1
    [5, 5],    # Middle
    [100, 100], # Far from both
])

indices, distances = select_k_nearest(queries, keys, k=2)
print("Indices:", indices)
print("Distances:", distances)

# Expected:
# For query [0,0]: nearest are indices 0 and 2
# For query [10,10]: nearest are indices 1 and 2
<details> <summary>Solution</summary>
def select_k_nearest(query_positions, key_positions, k):
    num_queries = len(query_positions)
    num_keys = len(key_positions)

    # Compute all pairwise distances
    # Shape: [num_queries, num_keys]
    diff = query_positions[:, None, :] - key_positions[None, :, :]
    distances_all = np.sqrt((diff ** 2).sum(axis=-1))

    # Find k smallest distances for each query
    indices = np.argsort(distances_all, axis=-1)[:, :k]
    distances = np.take_along_axis(distances_all, indices, axis=-1)

    return indices, distances
</details>

10. Interview Questions

Conceptual Questions

Q1: Why does BehaviorGPT use patches instead of individual timesteps?

<details> <summary>Expected Answer</summary>

Key Points:

  1. Prevents trivial shortcuts: At 10Hz, consecutive frames are nearly identical. A model can achieve low training loss by simply copying the previous timestep with minor noise. Patches (1 second each) have meaningful differences that force the model to actually predict future behavior.

  2. Captures long-range dependencies: Patch-level tokens reduce sequence length by 10x (e.g., 80 timesteps becomes 8 patches), making it computationally feasible to model longer-range dependencies with self-attention.

  3. Reduces compounding errors: In autoregressive generation, errors accumulate with each prediction step. With 10x fewer prediction steps (patches vs. timesteps), error accumulation is significantly reduced.

  4. Aligns with human decision-making: Drivers don't plan microsecond by microsecond; they think in terms of maneuvers lasting seconds. Patches align with this natural granularity.

Bonus insight: The ablation study shows patch_size=10 outperforms patch_size=1 by +4.9% on REALISM.

</details>

Q2: Why is the relative spacetime representation important for data efficiency?

<details> <summary>Expected Answer</summary>

Key Points:

  1. Position invariance: The same driving behavior at different absolute locations (intersection at (100, 200) vs. (5000, 3000)) produces identical relative features. The model learns the behavior once, not separately for each location.

  2. Implicit data augmentation: Translation, rotation, and reflection augmentations are "free" - they don't change relative features. This effectively multiplies training data without additional computation.

  3. Generalization: Models trained with relative features generalize better to new locations, road layouts, and scenarios not seen during training.

  4. Symmetric treatment: No privileged "ego" agent. All agents are treated equally, maximizing learning from every agent's trajectory.

Concrete example: A "car following" behavior has the same relative representation regardless of whether it happens in San Francisco or Phoenix.

</details>

Q3: Explain the purpose of each attention mechanism in the Triple Attention Block.

<details> <summary>Expected Answer</summary>

1. Temporal Self-Attention:

  • Purpose: Model how each individual agent moves over time
  • Key feature: Causal masking (can only attend to past patches)
  • Captures: Acceleration patterns, turning behaviors, consistency of motion

2. Agent-Map Cross-Attention:

  • Purpose: Model how agents interact with road geometry
  • Key feature: Queries are agent patches, keys/values are map elements
  • Captures: Lane following, stopping at intersections, respecting road boundaries

3. Agent-Agent Self-Attention:

  • Purpose: Model social interactions between agents
  • Key feature: K-nearest neighbor filtering to limit complexity
  • Captures: Following distance, yielding, collision avoidance, coordinated merging

Why this order matters:

  1. First understand individual motion (temporal)
  2. Then ground it in the environment (map)
  3. Finally coordinate with others (social)

This mirrors human decision-making: "What am I doing?" -> "Where am I?" -> "Who's around me?"

</details>

Q4: Why does BehaviorGPT use a GMM output instead of directly predicting positions?

<details> <summary>Expected Answer</summary>

Key Points:

  1. Multi-modality: Future trajectories are inherently uncertain. A car at an intersection could go straight, turn left, or turn right. GMM captures multiple plausible futures with different probabilities.

  2. Uncertainty quantification: The mixture weights and standard deviations provide principled uncertainty estimates. This is crucial for downstream planning and safety.

  3. Training with NLL loss: Gaussian likelihoods enable proper probabilistic training. The model learns not just the mean behavior but the full distribution.

  4. Sampling for simulation: During inference, we can sample diverse trajectories by (a) sampling a mode from mixture weights, then (b) sampling positions from that mode's Gaussian. This creates realistic diversity across simulation replicas.

Design choice: BehaviorGPT uses 16 mixture modes per agent. Too few modes = underfit multi-modal distributions. Too many = diluted probabilities and overfitting.

</details>

Q5: How does BehaviorGPT handle the curse of dimensionality with 128 agents?

<details> <summary>Expected Answer</summary>

Key Points:

  1. K-nearest neighbor filtering: Instead of O(N^2) attention over all agent pairs, BehaviorGPT attends only to k=32 nearest agents. This reduces complexity to O(N * k).

  2. Locality assumption: Agents primarily interact with nearby agents. A car 500 meters away has negligible influence on driving decisions. K-NN captures the relevant interactions.

  3. Patch-level representation: By grouping 10 timesteps into patches, sequence length is reduced 10x, making attention over agents more tractable.

  4. Independent patch-level planning: Within each patch, agents plan independently (Equation 3 in the paper). Coordination happens through attention over the previous patch's representations, not joint optimization.

  5. Small model size: With only 3M parameters and d=128 hidden dimensions, the model is inherently regularized against overfitting to spurious agent-agent correlations.

Scalability: The approach scales to 128 agents per scenario, which covers essentially all real-world driving situations in the WOMD dataset.

</details>

Coding Questions

Q6: Implement a simplified version of causal temporal self-attention.

def temporal_self_attention(x, num_heads):
    """
    Implement causal self-attention for temporal modeling.

    Args:
        x: [batch, num_patches, hidden_dim] tensor
        num_heads: Number of attention heads

    Returns:
        output: [batch, num_patches, hidden_dim] tensor
    """
    # Your implementation here
    pass
<details> <summary>Solution</summary>
import jax
import jax.numpy as jnp

def temporal_self_attention(x, num_heads):
    batch, num_patches, hidden_dim = x.shape
    head_dim = hidden_dim // num_heads

    # Linear projections for Q, K, V
    # In practice, these would be learned parameters
    Wq = jnp.eye(hidden_dim)
    Wk = jnp.eye(hidden_dim)
    Wv = jnp.eye(hidden_dim)

    Q = x @ Wq  # [batch, num_patches, hidden_dim]
    K = x @ Wk
    V = x @ Wv

    # Reshape for multi-head attention
    Q = Q.reshape(batch, num_patches, num_heads, head_dim).transpose(0, 2, 1, 3)
    K = K.reshape(batch, num_patches, num_heads, head_dim).transpose(0, 2, 1, 3)
    V = V.reshape(batch, num_patches, num_heads, head_dim).transpose(0, 2, 1, 3)
    # Shape: [batch, num_heads, num_patches, head_dim]

    # Compute attention scores
    scores = jnp.einsum('bhqd,bhkd->bhqk', Q, K) / jnp.sqrt(head_dim)
    # Shape: [batch, num_heads, num_patches, num_patches]

    # Create causal mask (boolean; True = may attend, so each query sees only past/self)
    causal_mask = jnp.tril(jnp.ones((num_patches, num_patches), dtype=bool))
    scores = jnp.where(causal_mask, scores, -1e9)

    # Softmax and apply to values
    attn_weights = jax.nn.softmax(scores, axis=-1)
    output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, V)

    # Reshape back
    output = output.transpose(0, 2, 1, 3).reshape(batch, num_patches, hidden_dim)

    return output
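
A quick shape sanity check for the solution above (toy input; the projections are identity matrices as written):

x = jnp.ones((2, 9, 128))                      # [batch, num_patches, hidden_dim]
out = temporal_self_attention(x, num_heads=8)
assert out.shape == (2, 9, 128)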
</details>

Q7: Implement the autoregressive rollout loop.

def rollout(model, history, num_future_patches, rng_key):
    """
    Generate future trajectories autoregressively.

    Args:
        model: A function that takes patches and returns GMM params
        history: [batch, num_agents, num_history_patches, hidden_dim]
        num_future_patches: Number of patches to generate
        rng_key: JAX random key

    Returns:
        future: [batch, num_agents, num_future_patches, hidden_dim]
    """
    # Your implementation here
    pass
<details> <summary>Solution</summary>
import jax
import jax.numpy as jnp

def rollout(model, history, num_future_patches, rng_key):
    current = history
    predictions = []

    for i in range(num_future_patches):
        # Split key for this iteration
        rng_key, step_key = jax.random.split(rng_key)

        # Get model prediction for next patch
        gmm_params = model(current)

        # Extract last patch prediction
        last_weights = gmm_params["weights"][:, :, -1]
        last_means = gmm_params["means"][:, :, -1]
        last_stds = gmm_params["stds"][:, :, -1]

        # Sample mode for each agent
        mode_key, sample_key = jax.random.split(step_key)
        mode_indices = jax.random.categorical(
            mode_key,
            jnp.log(last_weights + 1e-8),
            axis=-1
        )

        # Gather selected mode parameters
        batch, num_agents, num_modes, dim = last_means.shape
        batch_idx = jnp.arange(batch)[:, None]
        agent_idx = jnp.arange(num_agents)[None, :]

        selected_mean = last_means[batch_idx, agent_idx, mode_indices]
        selected_std = last_stds[batch_idx, agent_idx, mode_indices]

        # Sample from Gaussian
        noise = jax.random.normal(sample_key, selected_mean.shape)
        next_patch = selected_mean + noise * selected_std

        # Append to sequence
        predictions.append(next_patch[:, :, None, :])
        current = jnp.concatenate([current, next_patch[:, :, None, :]], axis=2)

    return jnp.concatenate(predictions, axis=2)
</details>

System Design Questions

Q8: How would you modify BehaviorGPT for real-time inference on an autonomous vehicle?

<details> <summary>Expected Answer</summary>

Key Considerations:

  1. Latency Requirements:

    • AV systems typically require predictions within 50-100ms
    • BehaviorGPT's 3M parameters are already very efficient
    • May need to reduce num_modes, k_nearest, or hidden_dim further
  2. Streaming Architecture:

    • Maintain a rolling buffer of recent patches
    • Implement KV-caching for temporal attention so past patches are not recomputed (see the sketch at the end of this answer)
    • Process new observations incrementally
  3. Hardware Optimization:

    • Quantization (FP16 or INT8) for faster inference
    • TensorRT or ONNX optimization
    • Batching across multiple scenarios if running in simulation
  4. Reduced Rollout:

    • For real-time use, may only need 2-3 second predictions
    • Replan every 0.5 seconds (2Hz) rather than full 8-second rollouts
  5. Uncertainty Thresholding:

    • In production, may want to flag high-uncertainty predictions
    • Use mixture entropy as a confidence measure

Sample Architecture:

Sensor Input -> Preprocessing (10ms) ->
BehaviorGPT Inference (30ms) ->
Post-processing (10ms) -> Planning Module
Total: ~50ms (20Hz)
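
As a concrete illustration of the KV-caching point above, here is a minimal single-head sketch for one agent's temporal attention (identity projections as in the earlier toy examples; the cache layout is an assumption, not the paper's implementation):

import jax
import jax.numpy as jnp

def attend_with_cache(new_patch, k_cache, v_cache):
    """
    new_patch: [d]    embedding of the newest patch (the query)
    k_cache:   [t, d] keys for all previously seen patches
    v_cache:   [t, d] values for all previously seen patches
    Returns the attention output for the new patch and the updated caches.
    """
    # Append the new patch's key/value instead of recomputing the history.
    k_cache = jnp.concatenate([k_cache, new_patch[None, :]], axis=0)
    v_cache = jnp.concatenate([v_cache, new_patch[None, :]], axis=0)
    # Causality comes for free: the query only sees cached (past) entries.
    scores = k_cache @ new_patch / jnp.sqrt(new_patch.shape[-1])
    out = jax.nn.softmax(scores) @ v_cache
    return out, k_cache, v_cache

Each replanning step then costs O(t * d) per agent instead of re-running O(t^2 * d) attention over the full patch history.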
</details>

Q9: How would you extend BehaviorGPT to handle traffic lights and dynamic map elements?

<details> <summary>Expected Answer</summary>

Key Extensions:

  1. Traffic Light Encoding:

    • Add traffic light states (red/yellow/green) as additional map features
    • Include time-varying traffic light sequences
    • Encode expected time until state change
  2. Dynamic Map Elements:

    • Treat dynamic elements (construction zones, temporary signs) as special map tokens
    • Add temporal features: when element appeared, expected duration
  3. Cross-Attention Modification:

    • Agent-Map cross-attention already supports this
    • Add traffic light embeddings to the map token pool
    • Use attention to learn which lights are relevant to each agent
  4. Training Data Requirements:

    • Need annotations for traffic light states over time
    • Waymo Open Dataset includes some traffic light annotations

Implementation Sketch:

# Illustrative sketch; the embedding names are hypothetical, not from the paper.
map_features = jnp.concatenate([
    road_segment_embeddings,      # Static
    lane_boundary_embeddings,     # Static
    traffic_light_embeddings,     # Dynamic - varies per patch
    construction_zone_embeddings, # Semi-static
], axis=0)
  5. Challenges:

    • Traffic light state changes need to be synchronized across patches
    • May need to predict traffic light futures jointly with agent futures
</details>

Q10: If you had unlimited compute but still needed real-time inference, how would you improve BehaviorGPT?

<details> <summary>Expected Answer</summary>

Scaling Strategies:

  1. Model Scaling (More Parameters):

    • Increase hidden_dim: 128 -> 512 or 1024
    • Add more transformer layers: 6 -> 12 or 24
    • More attention heads with larger head dimensions
  2. Ensemble Methods:

    • Train multiple BehaviorGPT variants with different initializations
    • Aggregate predictions for better uncertainty calibration
    • Use mixture-of-experts architecture
  3. Multi-Scale Patches:

    • Combine fine-grained (0.5s) and coarse-grained (2s) patches
    • Hierarchical prediction: coarse trajectory first, then refine
  4. Richer Input Representations:

    • Include velocity/acceleration profiles, not just positions
    • Add semantic lane information (turn lanes, merge lanes)
    • Incorporate intent signals (turn signals, brake lights)
  5. Training Improvements:

    • Larger batch sizes for better gradient estimates
    • Longer training with more epochs
    • Curriculum learning: easy scenarios first
  6. For Real-Time Inference:

    • Distillation: Train large model, distill to small model
    • Speculative decoding: Use small model for most predictions, large model for verification
    • Dynamic compute allocation: Simple scenarios use small model, complex ones use large

The BehaviorGPT lesson: Sometimes simpler is better. The 3M parameter model won against 500M+ models, suggesting architectural innovations matter more than raw scale for this task.

</details>

Summary

BehaviorGPT represents a paradigm shift in autonomous driving simulation:

┌─────────────────────────────────────────────────────────────────┐
│                    Key Takeaways                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. NEXT-PATCH PREDICTION (NP3)                                │
│     Think in seconds, not milliseconds                         │
│                                                                 │
│  2. RELATIVE SPACETIME                                         │
│     Position-independent = better generalization               │
│                                                                 │
│  3. TRIPLE ATTENTION                                           │
│     Time + Environment + Social = Complete context             │
│                                                                 │
│  4. EFFICIENT DESIGN                                           │
│     3M params beat 500M+ through smart architecture            │
│                                                                 │
│  5. AUTOREGRESSIVE GENERATION                                  │
│     Same paradigm that powers GPT, now for driving             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

The success of BehaviorGPT demonstrates that the transformer architecture, when properly adapted to the domain, can achieve state-of-the-art results even with modest compute budgets. The key insight - treating trajectory prediction as "next patch prediction" analogous to "next token prediction" - elegantly bridges the gap between language modeling and behavior modeling.



This blog was created as an interactive learning resource. For corrections or suggestions, please open an issue in the repository.