Advantage estimation & GAE
Before we begin
Generalized Advantage Estimation (GAE) blends multi-step returns with TD residuals to control bias vs variance. PPO, A2C, and most modern policy gradient code compute advantages with GAE(λ) before the policy update — not raw Monte Carlo returns.
Learning objectives
- Write TD residual δ_t = r + γV(s′) − V(s).
- Compute GAE advantage as exponentially weighted sum of δ.
- Interpret λ=0 (one-step TD) vs λ=1 (Monte Carlo).
- Implement backward GAE pass over a trajectory.
- Connect GAE to PPO's batch construction.
TD residual
At each step:
δ_t = r_{t+1} + γ V(s_{t+1}) − V(s_t)
This is a single-step advantage estimate: biased if V is wrong, but low variance.
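A minimal sketch of the residual computed for a whole trajectory at once. The function name `td_residuals` and the sample numbers are illustrative assumptions, and `rewards[t]` is assumed to hold the reward received on the transition out of s_t (written r_{t+1} in the formula above).

```python
import numpy as np

def td_residuals(rewards, values, dones, gamma=0.99):
    """One-step TD residuals for a trajectory of length T.

    rewards, dones: length T. values: length T+1, last entry is the bootstrap value.
    (Hypothetical helper for illustration.)
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    # mask is 0 where the transition ended the episode, so no bootstrap there
    mask = 1.0 - np.asarray(dones, dtype=np.float64)
    return rewards + gamma * values[1:] * mask - values[:-1]

# Made-up numbers: reward only on the final, terminal step
deltas = td_residuals(
    rewards=[0.0, 0.0, 1.0],
    values=[0.5, 0.6, 0.8, 0.0],
    dones=[0.0, 0.0, 1.0],
    gamma=0.9,
)
print(deltas)  # approximately 0.04, 0.12, 0.2
```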
GAE(λ) formula
A_t^{GAE} = sum_{l=0}^{∞} (γλ)^l δ_{t+l}
Computed efficiently backward:
A_t = δ_t + γλ A_{t+1}
| λ | Effect |
|---|---|
| 0 | A_t = δ_t (high bias, low variance) |
| 1 | A_t = sum of δ = G_t − V(s_t) (low bias, high variance) |
| 0.95 | Standard PPO default |
Typical γ=0.99, λ=0.95 for episodic control.
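As a sanity check, the sketch below compares the direct weighted sum with the backward recursion on a short list of residuals, and confirms that λ=0 leaves only δ_t. The helper names and the sample residuals are assumptions made for illustration; the infinite sum is truncated at the end of the trajectory.

```python
import numpy as np

def gae_direct(deltas, gamma, lam):
    """A_t = sum_l (gamma*lam)^l * delta_{t+l}, truncated at the trajectory end."""
    deltas = np.asarray(deltas, dtype=np.float64)
    T = len(deltas)
    adv = np.zeros(T)
    for t in range(T):
        weights = (gamma * lam) ** np.arange(T - t)
        adv[t] = np.sum(weights * deltas[t:])
    return adv

def gae_backward(deltas, gamma, lam):
    """Same quantity via A_t = delta_t + gamma*lam*A_{t+1}, one pass from the end."""
    deltas = np.asarray(deltas, dtype=np.float64)
    adv = np.zeros(len(deltas))
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv

sample_deltas = [0.04, 0.12, 0.2]  # made-up residuals
direct = gae_direct(sample_deltas, gamma=0.9, lam=0.95)
backward = gae_backward(sample_deltas, gamma=0.9, lam=0.95)
one_step = gae_backward(sample_deltas, gamma=0.9, lam=0.0)
print(np.allclose(direct, backward))        # both forms agree
print(np.allclose(one_step, sample_deltas)) # lam=0 reduces to delta_t
```

The backward form costs O(T) per trajectory, while the direct sum as written here is O(T²).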
Python implementation
import numpy as np

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    rewards: length T
    values: length T+1 (bootstrap V(s_T) at end)
    dones: length T, 1.0 if terminal transition
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        # mask is 0 on a terminal transition: no bootstrap, no carry-over
        mask = 1.0 - dones[t]
        # TD residual for step t
        delta = rewards[t] + gamma * values[t + 1] * mask - values[t]
        # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
        gae = delta + gamma * lam * mask * gae
        advantages[t] = gae
    # Value targets: advantage plus the baseline it was measured against
    returns = advantages + values[:-1]
    return advantages, returns
The returns are the regression targets for the value head; the advantages weight the policy update.
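To show where the two outputs end up in a PPO batch, here is a rough NumPy sketch of the clipped policy loss and a squared-error value loss. The function name `ppo_losses`, the `clip_eps=0.2` default, and the 0.5 factor on the value loss are assumptions for illustration, not taken from any particular library; a real implementation would compute these in an autodiff framework.

```python
import numpy as np

def ppo_losses(new_logp, old_logp, advantages, values_pred, returns, clip_eps=0.2):
    """Clipped-surrogate policy loss and MSE value loss (hypothetical helper)."""
    # Probability ratio between the updated policy and the one that collected the data
    ratio = np.exp(new_logp - old_logp)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Advantages enter only the policy term
    policy_loss = -np.mean(np.minimum(unclipped, clipped))
    # Returns enter only the value term
    value_loss = 0.5 * np.mean((returns - values_pred) ** 2)
    return policy_loss, value_loss

# Made-up batch: the policy has not moved yet, so the ratio is 1 everywhere
logp = np.log(np.array([0.5, 0.25, 0.25]))
policy_loss, value_loss = ppo_losses(
    new_logp=logp,
    old_logp=logp,
    advantages=np.array([1.0, -1.0, 2.0]),
    values_pred=np.array([1.0, 2.0, 5.0]),
    returns=np.array([1.0, 2.0, 3.0]),
)
print(policy_loss, value_loss)
```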
Worked example — 3-step episode
Rewards [1, 1, 1], V = [10, 11, 11.5, 0], γ=1, λ=1. The episode ends after the third step, so the last entry of V is the terminal bootstrap value 0.
| t | δ_t | A_t (λ=1) |
|---|---|---|
| 2 | 1 + 0 − 11.5 = −10.5 | −10.5 |
| 1 | 1 + 11.5 − 11 = 1.5 | 1.5 + 1×(−10.5) = −9.0 |
| 0 | 1 + 11 − 10 = 2 | 2 + 1×(−9) = −7 |
This illustrates credit assignment: the critic's large error at the last state (δ_2 = −10.5) propagates backward into every earlier advantage, which is why an accurate V matters.
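The table can be reproduced in a few lines. This is a standalone sketch using the numbers from the example; it also checks the λ=1 identity A_t = G_t − V(s_t) against Monte Carlo returns.

```python
import numpy as np

rewards = np.array([1.0, 1.0, 1.0])
values = np.array([10.0, 11.0, 11.5, 0.0])  # last entry: terminal bootstrap of 0
gamma, lam = 1.0, 1.0

# TD residuals for t = 0, 1, 2
deltas = rewards + gamma * values[1:] - values[:-1]

# Backward pass: A_t = delta_t + gamma * lam * A_{t+1}
advantages = np.zeros_like(rewards)
running = 0.0
for t in reversed(range(len(rewards))):
    running = deltas[t] + gamma * lam * running
    advantages[t] = running

# Monte Carlo returns G_t (a reversed cumulative sum is valid because gamma = 1)
mc_returns = np.cumsum(rewards[::-1])[::-1]

print(deltas)                    # residuals: 2.0, 1.5, -10.5
print(advantages)                # advantages: -7.0, -9.0, -10.5
print(mc_returns - values[:-1])  # same values as the advantages when lam = 1
```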
Normalizing advantages for PPO
def normalize_advantages(adv, eps=1e-8):
    return (adv - adv.mean()) / (adv.std() + eps)
Apply this over the full batch gathered across workers in A2C/PPO; it stabilizes the scale of the clipped objective.
Relationship to n-step returns
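GAE(λ) is equivalent to an exponentially weighted average of n-step advantages, with weight (1−λ)λ^{n−1} on the n-step estimate. You get a smooth interpolation between one-step TD and Monte Carlo without running a separate n for each update.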
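The equivalence can be checked numerically. The sketch below assumes a finite trajectory whose residuals are zero after the last step, so all n-step estimates at or beyond the remaining horizon N coincide and their combined weight collapses to λ^{N−1}. Function names and sample residuals are illustrative.

```python
import numpy as np

def n_step_advantage(deltas, t, n, gamma):
    """A_t^(n) = sum_{l=0}^{n-1} gamma^l * delta_{t+l}."""
    window = deltas[t:t + n]
    return np.sum(gamma ** np.arange(len(window)) * window)

def gae_from_n_step(deltas, t, gamma, lam):
    """Weighted average of n-step advantages with weights (1-lam)*lam^(n-1)."""
    N = len(deltas) - t  # steps remaining in the trajectory
    total = 0.0
    for n in range(1, N):
        total += (1.0 - lam) * lam ** (n - 1) * n_step_advantage(deltas, t, n, gamma)
    # All estimates with n >= N are identical; their weights sum to lam^(N-1)
    total += lam ** (N - 1) * n_step_advantage(deltas, t, N, gamma)
    return total

def gae_direct(deltas, t, gamma, lam):
    """A_t = sum_l (gamma*lam)^l * delta_{t+l}."""
    tail = deltas[t:]
    return np.sum((gamma * lam) ** np.arange(len(tail)) * tail)

sample = np.array([0.04, 0.12, 0.2, -0.3, 0.1])  # made-up residuals
via_n_step = np.array([gae_from_n_step(sample, t, 0.99, 0.95) for t in range(len(sample))])
via_sum = np.array([gae_direct(sample, t, 0.99, 0.95) for t in range(len(sample))])
print(np.allclose(via_n_step, via_sum))
```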
Checkpoint: If PPO policy updates are tiny, check whether advantages are near zero (critic too good) or unnormalized (clip always binds).
Summary: GAE is an exponentially weighted sum of TD surprises along the trajectory.
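The two checks from the checkpoint can be turned into a small diagnostic. This is a sketch only: `advantage_diagnostics`, its dictionary keys, and the `clip_eps=0.2` default are hypothetical, and `ratio` is assumed to be the new-to-old policy probability ratio for each sample.

```python
import numpy as np

def advantage_diagnostics(advantages, ratio, clip_eps=0.2):
    """Summary statistics for debugging small PPO updates (hypothetical helper)."""
    advantages = np.asarray(advantages, dtype=np.float64)
    ratio = np.asarray(ratio, dtype=np.float64)
    return {
        # Near-zero spread suggests the critic already explains the returns
        "adv_mean": float(advantages.mean()),
        "adv_std": float(advantages.std()),
        # Very large magnitudes suggest the advantages were not normalized
        "adv_abs_max": float(np.abs(advantages).max()),
        # Fraction of samples whose ratio lies outside [1 - eps, 1 + eps]
        "clip_fraction": float(np.mean(np.abs(ratio - 1.0) > clip_eps)),
    }

stats = advantage_diagnostics(
    advantages=[0.5, -0.5, 1.5, -1.5],
    ratio=[1.0, 1.3, 0.7, 1.1],
)
print(stats)
```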
Common mistakes
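- Wrong values length: values needs T+1 entries, including the bootstrap V(s_T) at the end of the rollout; off-by-one bugs are common.
- Not masking the terminal δ: never bootstrap across a true episode end, and reset the running GAE sum there as well.
- λ=1 on long or noisy trajectories: the estimate becomes the Monte Carlo return minus V(s_t), so variance blows up; use 0.9–0.97.
- Computing GAE on stale trajectories: advantages should come from rollouts collected under the current policy and scored with the current critic; some PPO implementations also recompute them before each epoch.
- Mixing truncated and terminated: in Gymnasium, terminated marks a true terminal state (no bootstrap), while truncated marks a time-limit cutoff where bootstrapping from the next state's value may still be appropriate (an algorithm choice).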
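One way to keep the last two cases apart is to use separate masks: bootstrap unless the step was truly terminal, but stop the backward carry at any episode boundary. The sketch below is an assumption-laden illustration, not a library API. It assumes `terminated` and `truncated` are the per-step flags returned by Gymnasium's `env.step`, and that `next_values[t]` was evaluated on the actual next observation before any reset; the function name is hypothetical.

```python
import numpy as np

def compute_gae_with_truncation(rewards, values, next_values, terminated, truncated,
                                gamma=0.99, lam=0.95):
    """GAE with separate handling of termination and time-limit truncation.

    values[t] = V(s_t); next_values[t] = V(s_{t+1}) on the real next observation.
    (Hypothetical helper for illustration.)
    """
    terminated = np.asarray(terminated, dtype=np.float64)
    truncated = np.asarray(truncated, dtype=np.float64)
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # Bootstrap from V(s_{t+1}) unless the state was truly terminal
        bootstrap = 1.0 - terminated[t]
        # Cut the recursion at any episode boundary, terminal or truncated
        carry = 1.0 - max(terminated[t], truncated[t])
        delta = rewards[t] + gamma * next_values[t] * bootstrap - values[t]
        gae = delta + gamma * lam * carry * gae
        advantages[t] = gae
    return advantages

# Same single transition, once truncated and once terminated (made-up numbers)
adv_truncated = compute_gae_with_truncation([1.0], [2.0], [3.0], [0.0], [1.0], gamma=0.5)
adv_terminated = compute_gae_with_truncation([1.0], [2.0], [3.0], [1.0], [0.0], gamma=0.5)
print(adv_truncated, adv_terminated)  # 0.5 with bootstrap, -1.0 without
```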