
Module 5 — Policy gradients

Why learn policies directly

Stochastic policies, continuous actions, and parameterizing π_θ(a|s).

~55 min read + exercises

Before we begin

Value-based methods like DQN learn Q(s,a) then derive a policy via argmax. Policy gradient methods parameterize π(a|s; θ) directly and optimize expected return with gradient ascent. That matters for continuous actions, stochastic policies, and problems where the best policy is not deterministic.


Learning objectives

  • Contrast value-based vs policy-based vs actor–critic approaches.
  • State when argmax Q is insufficient (continuous actions, large discrete spaces).
  • Write the policy objective J(θ) = E_π [sum of discounted rewards].
  • Recognize stochastic policies as a built-in exploration mechanism.
  • Map policy outputs to Gymnasium action spaces (discrete softmax, continuous Gaussian).

Three families of RL algorithms

Family         Learns           Policy extraction                        Typical use
Value-based    Q(s,a) or V(s)   ε-greedy or argmax                       Discrete actions, Atari
Policy-based   π(a|s; θ)        Direct sampling
Actor–critic   π and V or Q     Sample from π, critic reduces variance   PPO, SAC, modern default

DQN picks actions by argmax over a finite set, so it cannot output a torque of −0.37 N·m unless the torque range is discretized into hundreds of bins. A policy network outputs the parameters of a continuous distribution directly.

Policy parameterization examples

Discrete (CartPole): softmax logits → categorical distribution.

python
import torch
import torch.nn as nn
from torch.distributions import Categorical
 
class PolicyDiscrete(nn.Module):
    def __init__(self, obs_dim, n_actions, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )
 
    def forward(self, obs):
        logits = self.net(obs)
        return Categorical(logits=logits)
 
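# usage: CartPole has a 4-dimensional observation and 2 discrete actions
policy = PolicyDiscrete(obs_dim=4, n_actions=2)
obs = torch.randn(4)               # stand-in for an environment observation
dist = policy(obs)
action = dist.sample()             # sampled action index
log_prob = dist.log_prob(action)   # log π(a|s; θ) of the sampled action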
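
To make the softmax-to-categorical step concrete, here is a minimal standard-library sketch of what sampling from logits involves. The logits and the helper names `softmax` and `sample_action` are illustrative assumptions, not part of the lesson's code, and this is not how PyTorch implements `Categorical` internally.

```python
import math
import random

def softmax(logits):
    """Turn unnormalized logits into probabilities (shifted for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(probs, rng):
    """Draw one action index with probability proportional to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)
logits = [0.5, -0.5]                 # hypothetical network output for 2 actions
probs = softmax(logits)              # categorical distribution over actions
action = sample_action(probs, rng)   # stochastic action choice
log_prob = math.log(probs[action])   # log-probability of the chosen action
```

The larger logit gets the larger probability, but the smaller one is still sampled some of the time, which is where the built-in exploration comes from.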

Continuous (Pendulum): Gaussian mean and learned log_std.

python
import torch
import torch.nn as nn
from torch.distributions import Normal
 
class PolicyContinuous(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )
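        # State-independent log std, learned with the network; zeros means std = 1 at init.
        self.log_std = nn.Parameter(torch.zeros(action_dim))
 
    def forward(self, obs):
        mean = self.net(obs)
        std = self.log_std.exp()  # exp keeps the standard deviation positive
        return Normal(mean, std)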
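
For actions with more than one dimension, `Normal.log_prob` returns one value per dimension, so the joint log-probability of an action is usually taken as their sum. The standard-library sketch below shows that computation for a diagonal Gaussian; the function name `gaussian_log_prob` and the example numbers are assumptions made for illustration.

```python
import math
import random

def gaussian_log_prob(action, mean, log_std):
    """Log-density of a diagonal Gaussian, summed over action dimensions."""
    total = 0.0
    for a, mu, ls in zip(action, mean, log_std):
        std = math.exp(ls)
        # log N(a | mu, std) = -0.5*((a-mu)/std)^2 - log(std) - 0.5*log(2*pi)
        total += -0.5 * ((a - mu) / std) ** 2 - ls - 0.5 * math.log(2 * math.pi)
    return total

rng = random.Random(0)
mean = [0.2, -1.0]       # hypothetical policy mean for a 2-D action
log_std = [0.0, -0.5]    # hypothetical learned log standard deviations
action = [rng.gauss(mu, math.exp(ls)) for mu, ls in zip(mean, log_std)]
joint_log_prob = gaussian_log_prob(action, mean, log_std)
```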

Worked example — why stochastic policies help

Consider rock–paper–scissors against an opponent who plays whatever beats your last move. A deterministic policy is exploitable: once you repeat a move, the opponent counters it every time. A mixed policy that plays each action with probability 1/3 cannot be exploited in expectation. Even in MDPs whose optimal policy is deterministic, sampling from a stochastic policy during training provides exploration without bolting ε-greedy onto a separate value function.
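
The exploitability argument can be checked numerically. The sketch below computes the worst-case expected payoff of a policy against an opponent who knows it and picks the best counter, which is a slightly stronger opponent than the one described above. The payoff encoding (0 = rock, 1 = paper, 2 = scissors; +1 win, 0 tie, −1 loss) is an assumption made for this example.

```python
def payoff(ours, theirs):
    """+1 if our move beats theirs, 0 on a tie, -1 if we lose."""
    if ours == theirs:
        return 0
    return 1 if (ours - theirs) % 3 == 1 else -1

def worst_case_value(policy):
    """Expected payoff against the opponent move that hurts this policy most."""
    return min(
        sum(p * payoff(a, b) for a, p in enumerate(policy))
        for b in range(3)
    )

always_rock = [1.0, 0.0, 0.0]     # deterministic policy
uniform = [1 / 3, 1 / 3, 1 / 3]   # mixed policy

det_value = worst_case_value(always_rock)   # opponent answers with paper
mixed_value = worst_case_value(uniform)     # no counter does better than a tie
```

The deterministic policy loses every round in the worst case, while the uniform policy breaks even no matter what the opponent plays.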

Objective function

Maximize expected discounted return:

text
J(θ) = E_{τ ~ π_θ} [ G_0 ]

where a trajectory τ is the sequence of states, actions, and rewards generated by running π_θ in the environment, and G_0 is its discounted return from the first step. There is no argmax: gradients flow through log π(a|s; θ) (covered in the next lesson).
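
Because J(θ) is an expectation over trajectories, it can be estimated by averaging the discounted return of sampled episodes. The sketch below does this for a hypothetical toy problem: a fixed 3-step episode in which the "good" action is chosen with probability `p_good` and earns reward 1. The reward indexing (first reward undiscounted), the toy environment, and all function names are assumptions for illustration.

```python
import random

def discounted_return(rewards, gamma):
    """G_0 = r_0 + gamma*r_1 + gamma^2*r_2 + ... for one trajectory."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

def rollout(p_good, horizon, rng):
    """Toy episode: the good action (probability p_good) earns 1, otherwise 0."""
    return [1.0 if rng.random() < p_good else 0.0 for _ in range(horizon)]

def estimate_J(p_good, gamma=0.9, horizon=3, n_episodes=5000, seed=0):
    """Monte Carlo estimate of J: the mean of G_0 over sampled trajectories."""
    rng = random.Random(seed)
    returns = [
        discounted_return(rollout(p_good, horizon, rng), gamma)
        for _ in range(n_episodes)
    ]
    return sum(returns) / n_episodes

estimate = estimate_J(p_good=0.7)
exact = 0.7 * (1 + 0.9 + 0.9 ** 2)   # closed form for this toy problem
```

Here the single number `p_good` plays the role of θ: raising it raises J, which is the direction gradient ascent would move it.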

When to prefer policy gradients

Situation                      Value-based                          Policy gradient
Discrete, moderate actions     Strong (DQN)                         Works (REINFORCE)
Continuous actions             Weak without discretization          Natural
Stochastic optimal policy      Suboptimal if forced deterministic   Natural
High-dimensional action        argmax expensive                     Factorized distributions
Need policy entropy / safety   Boltzmann on Q only                  Direct entropy bonus
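
The "direct entropy bonus" in the last row refers to adding the policy's entropy to the objective. As a rough sketch, the entropy of a categorical policy can be computed from its probabilities and added with a small weight; the coefficient `beta` and the function names are hypothetical choices for this example.

```python
import math

def softmax(logits):
    """Turn unnormalized logits into probabilities (shifted for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; zero-probability actions contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def objective_with_bonus(j_estimate, probs, beta=0.01):
    """Return estimate plus an entropy bonus weighted by beta (assumed value)."""
    return j_estimate + beta * entropy(probs)

uniform = softmax([0.0, 0.0, 0.0])    # maximum entropy for 3 actions: log(3)
peaked = softmax([10.0, 0.0, 0.0])    # nearly deterministic, entropy near 0

h_uniform = entropy(uniform)
h_peaked = entropy(peaked)
```

With equal return estimates, the bonus favors the higher-entropy policy, which discourages premature collapse to a single action.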

Checkpoint: Learn π directly when the policy itself is the object you need or when argmax Q is awkward. If your action space is Box in Gymnasium, start with policy gradients or actor–critic, not DQN.

Common mistakes

  1. Using DQN on continuous Box actions: it does not apply without discretization, and coarse bins destroy control quality.
  2. Forgetting to squash continuous actions to env bounds — use tanh scaling to low/high.
  3. Near-deterministic policy at initialization: zero logits give a uniform distribution, but large initial logits can make the policy almost degenerate, so check the entropy.
  4. Confusing policy loss sign — we maximize return, so ascent on J (or minimize negative J).
  5. Ignoring action masking — invalid moves in games need masked softmax, not raw logits.
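
Mistakes 2 and 5 above both come down to a small transformation of the network output. The sketch below shows one possible form of each: tanh scaling of an unbounded value into `[low, high]`, and a softmax that assigns zero probability to invalid actions. The function names, the ±2 bounds, and the mask are illustrative assumptions, and the masked version assumes at least one action is valid.

```python
import math

def squash_to_bounds(raw, low, high):
    """Map an unbounded value into [low, high] using tanh."""
    return low + 0.5 * (math.tanh(raw) + 1.0) * (high - low)

def masked_softmax(logits, valid):
    """Softmax over valid actions only; invalid actions get probability 0."""
    if not any(valid):
        raise ValueError("at least one action must be valid")
    # Setting a logit to -inf makes exp() return exactly 0 for that action.
    masked = [x if ok else -math.inf for x, ok in zip(logits, valid)]
    m = max(masked)
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

# Continuous case: hypothetical torque bounds of -2 and +2.
torque = squash_to_bounds(0.0, -2.0, 2.0)       # midpoint of the range
saturated = squash_to_bounds(100.0, -2.0, 2.0)  # large inputs approach high

# Discrete case: the middle action is illegal in the current state.
probs = masked_softmax([1.0, 2.0, 3.0], [True, False, True])
```

Masking before the softmax, rather than zeroing probabilities afterwards, keeps the remaining probabilities normalized.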
