Module 5 — Policy gradients

REINFORCE & the policy gradient theorem

Monte Carlo policy gradient, log-derivative trick, and episodic updates.

~70 min read + exercises

Before we begin

REINFORCE (Monte Carlo policy gradient) updates θ using full-episode returns weighted by the log-probabilities of the actions taken. The policy gradient theorem justifies replacing an unwieldy derivative over whole trajectories with an expectation that can be estimated from samples: ∇J ∝ E[∇ log π(a|s) · G_t].


Learning objectives

  • Derive the REINFORCE update at an intuitive level (no measure theory required).
  • Implement episodic REINFORCE with return-to-go G_t.
  • Explain why log π turns policy optimization into a weighted score function estimator.
  • Connect discount γ to return-to-go weighting.
  • Recognize high variance as the main practical weakness.

Policy gradient theorem (intuition)

Goal: maximize J(θ) = E_π[G_0]. The trick:

text
∇_θ J(θ) = E_π [ sum_t ∇_θ log π(a_t|s_t; θ) · G_t ]
  • ∇ log π points in the direction in parameter space that increases the probability of action a_t in state s_t.
  • G_t is the return from time t: if the episode went well, amplify the actions that led there; if it went poorly, suppress them.
  • The expectation is estimated by averaging over sampled episodes (Monte Carlo).
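To make the first two bullets concrete, here is a minimal sketch in plain Python. It assumes a state-free softmax policy whose parameters are the action logits, which is an illustrative simplification and not the lesson's network; the helper names `softmax`, `grad_log_softmax`, and `score_function_step` are hypothetical. For this policy the score function is one_hot(a) − π, so a step along it scaled by a positive return should raise π(a), and a negative return should lower it.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_log_softmax(logits, action):
    """Gradient of log pi(action) w.r.t. the logits: one_hot(action) - pi."""
    probs = softmax(logits)
    return [(1.0 if i == action else 0.0) - p for i, p in enumerate(probs)]

def score_function_step(logits, action, G, alpha=0.1):
    """One update theta <- theta + alpha * G * grad log pi(action)."""
    grad = grad_log_softmax(logits, action)
    return [x + alpha * G * g for x, g in zip(logits, grad)]

theta = [0.0, 0.0, 0.0]  # assumed starting point: uniform policy over 3 actions
p_before = softmax(theta)[1]
p_good = softmax(score_function_step(theta, action=1, G=+1.0))[1]
p_bad = softmax(score_function_step(theta, action=1, G=-1.0))[1]
print(p_before, p_good, p_bad)
```

The three printed probabilities should be ordered p_bad < p_before < p_good, which is the "amplify or suppress" behaviour described above.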

REINFORCE algorithm

text
for each episode:
    generate τ = (s_0,a_0,r_1, ..., s_T) using π_θ
    for each t:
        G_t = sum_{k=t}^{T-1} γ^{k-t} r_{k+1}
    for each t:
        θ ← θ + α · γ^t · ∇_θ log π(a_t|s_t; θ) · G_t

Many implementations weight each step by the return-to-go G_t alone and drop the extra γ^t factor on the gradient. Implementations vary, so be consistent within one codebase.
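The two conventions can be compared side by side with a short sketch. The function names `returns_to_go` and `update_weights` and the flag `use_gamma_t` are hypothetical, chosen only for this illustration; `rewards[t]` is assumed to hold r_{t+1}.

```python
def returns_to_go(rewards, gamma):
    """Return [G_0, ..., G_{T-1}], where rewards[t] holds r_{t+1}."""
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G  # G_t = r_{t+1} + gamma * G_{t+1}
        returns.append(G)
    returns.reverse()
    return returns

def update_weights(rewards, gamma, use_gamma_t=True):
    """Scalar that multiplies grad log pi(a_t|s_t) at each step t."""
    returns = returns_to_go(rewards, gamma)
    if use_gamma_t:
        return [gamma ** t * G for t, G in enumerate(returns)]
    return returns

rewards = [0.0, 0.0, 1.0]
print(returns_to_go(rewards, gamma=0.5))                      # [0.25, 0.5, 1.0]
print(update_weights(rewards, gamma=0.5, use_gamma_t=True))   # [0.25, 0.25, 0.25]
print(update_weights(rewards, gamma=0.5, use_gamma_t=False))  # [0.25, 0.5, 1.0]
```

With the γ^t factor, later steps are down-weighted a second time; without it, each step is weighted by its own return-to-go only.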

PyTorch REINFORCE (one episode)

python
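import torch

def reinforce_episode(policy, optimizer, episode, gamma=0.99):
    """Run one REINFORCE update from a single on-policy episode.

    `episode` is a sequence of (obs, action, reward) triples, and
    `policy(obs)` must return an object with a `log_prob` method
    (for example a torch.distributions distribution).
    """
    log_probs = []
    rewards = []

    for obs, action, reward in episode:
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        dist = policy(obs_t)
        a = torch.as_tensor(action)
        log_probs.append(dist.log_prob(a))
        rewards.append(reward)

    # Returns-to-go, accumulated backwards: G_t = r_{t+1} + gamma * G_{t+1}.
    G = 0.0
    returns = []
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)

    # Optional: normalize returns for stability. Skipped for one-step
    # episodes, where the sample standard deviation is NaN.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # Minimize the negative objective: -sum_t log pi(a_t|s_t) * G_t.
    loss = torch.stack([-log_p * Gt for log_p, Gt in zip(log_probs, returns)]).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()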
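The function above expects `policy(obs)` to return a distribution object. One way to satisfy that is sketched below, with the caveat that `CategoricalPolicy`, its layer sizes, and the made-up two-step episode are assumptions for illustration, not part of the lesson's code.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

class CategoricalPolicy(nn.Module):
    """Hypothetical policy: maps an observation to a Categorical over actions."""

    def __init__(self, obs_dim, n_actions, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, obs):
        # Logits are normalized into action probabilities by Categorical.
        return Categorical(logits=self.net(obs))

torch.manual_seed(0)
policy = CategoricalPolicy(obs_dim=4, n_actions=2)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

# A made-up episode of (obs, action, reward) triples, only to show the shapes.
episode = [([0.0, 0.1, 0.0, -0.1], 1, 1.0), ([0.1, 0.0, -0.1, 0.0], 0, 1.0)]

dist = policy(torch.tensor(episode[0][0], dtype=torch.float32))
log_p = dist.log_prob(torch.tensor(episode[0][1]))  # scalar tensor with grad
```

With these objects in scope, `reinforce_episode(policy, optimizer, episode)` should perform one gradient step; in a real run the episode would come from rolling out `policy` in the environment.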

Worked example — two-step episode

Actions a_0, a_1 with log π = −0.5, −0.3. Rewards r_1=0, r_2=+1. γ=1.

| t | G_t | ∇θ log π | Update direction |
|---|-----|----------|------------------|
| 0 | 1.0 | increases a_0 prob | positive (good episode) |
| 1 | 1.0 | increases a_1 prob | positive |

If the final reward were −1, both returns would be −1 and the same gradients would flip sign, making both actions less likely.
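The arithmetic of this example can be checked with a few lines of plain Python. This is a sketch that treats the log-probabilities as fixed numbers (no autograd) and applies no return normalization; `surrogate_loss` is a hypothetical helper name.

```python
log_probs = [-0.5, -0.3]  # log pi(a_0|s_0), log pi(a_1|s_1) from the example
gamma = 1.0

def surrogate_loss(log_probs, rewards, gamma):
    """Return (returns-to-go, sum over t of -log pi(a_t|s_t) * G_t)."""
    G, returns = 0.0, []
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    loss = sum(-lp * Gt for lp, Gt in zip(log_probs, returns))
    return returns, loss

good_returns, good_loss = surrogate_loss(log_probs, [0.0, 1.0], gamma)
bad_returns, bad_loss = surrogate_loss(log_probs, [0.0, -1.0], gamma)
print(good_returns, good_loss)
print(bad_returns, bad_loss)
```

The loss value itself is not meaningful; what matters is that its derivative with respect to each log π(a_t|s_t) is −G_t. Gradient descent therefore raises the log-probabilities when G_t is positive and lowers them when G_t is negative.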

Log-derivative trick (why log?)

text
∇_θ π(a|s; θ) = π(a|s; θ) ∇_θ log π(a|s; θ)

Rewriting the gradient as an expectation of the score function ∇ log π times the return avoids differentiating through the environment's transition dynamics: only the policy needs to be differentiable.
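The step from the identity to the estimator can be written out as a short derivation. The notation here is an assumption introduced for this sketch: p_θ(τ) is the probability of a trajectory under π_θ, R(τ) is its return, ρ is the initial-state distribution, and P is the transition model.

```latex
\begin{align*}
\nabla_\theta \, \mathbb{E}_{\tau \sim p_\theta}[R(\tau)]
  &= \nabla_\theta \int p_\theta(\tau)\, R(\tau)\, d\tau \\
  &= \int p_\theta(\tau)\, \nabla_\theta \log p_\theta(\tau)\, R(\tau)\, d\tau \\
  &= \mathbb{E}_{\tau \sim p_\theta}\big[ \nabla_\theta \log p_\theta(\tau)\, R(\tau) \big], \\[4pt]
\log p_\theta(\tau)
  &= \log \rho(s_0)
   + \sum_{t=0}^{T-1} \log \pi(a_t \mid s_t; \theta)
   + \sum_{t=0}^{T-1} \log P(s_{t+1} \mid s_t, a_t), \\
\nabla_\theta \log p_\theta(\tau)
  &= \sum_{t=0}^{T-1} \nabla_\theta \log \pi(a_t \mid s_t; \theta).
\end{align*}
```

The terms involving ρ and P do not depend on θ, so their gradients vanish. That is why the dynamics never need to be differentiated or even known.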

Variance and sample complexity

REINFORCE uses full-episode returns, so early actions are credited (or blamed) for randomness that occurs much later, and the variance of the gradient estimate grows with the horizon. CartPole typically needs thousands of episodes with REINFORCE; actor–critic methods (Modules 5–6) learn faster on the same environment.

| Property | REINFORCE | Q-learning |
|----------|-----------|------------|
| Bootstrapping | No (MC) | Yes (TD) |
| Variance | High | Lower (with replay) |
| On-policy | Yes | No (off-policy) |
| Step complexity | Simple | Replay + targets |
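The claim that variance grows with the horizon can be illustrated with a toy simulation. This is not an RL environment: it assumes each step yields an independent reward of +1 or −1 with equal probability and no discounting, so the variance of G_0 should come out close to the horizon T. The function names are hypothetical.

```python
import random
import statistics

def sample_return(horizon, rng):
    """Undiscounted return of one episode with i.i.d. rewards of +1 or -1."""
    return sum(rng.choice((-1.0, 1.0)) for _ in range(horizon))

def return_variance(horizon, n_episodes=2000, seed=0):
    """Empirical variance of G_0 over n_episodes simulated episodes."""
    rng = random.Random(seed)
    samples = [sample_return(horizon, rng) for _ in range(n_episodes)]
    return statistics.pvariance(samples)

var_short = return_variance(10)
var_long = return_variance(100)
print(var_short, var_long)
```

Because the return multiplies every ∇ log π term in the episode, this noise passes straight into the gradient estimate for early actions.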

Checkpoint: REINFORCE is the cleanest proof of concept for policy gradients. If it never learns CartPole, check return normalization and the learning rate before adding critics. Summary: ∇J weights ∇ log π by how good the rest of the episode was.

Common mistakes

  1. Backpropagating through the environment: only the policy parameters receive gradients.
  2. Using the total episode return for every t: each step should use G_t from t onward, not G_0.
  3. No return normalization: gradients can explode when the scale of returns varies.
  4. Learning rate too high: the policy collapses to near-deterministic bad actions.
  5. Off-policy data with REINFORCE: standard REINFORCE assumes on-policy samples.
