
Module 4 — Deep Q-networks

Experience replay & target networks

Replay buffer decorrelation, fixed targets, and soft target updates.

~70 min read + exercises

Before we begin

The DQN paper (Mnih et al., 2015) made two stabilizing ideas standard: experience replay, which breaks the temporal correlation in training data, and target networks, which slow the movement of the bootstrap targets. Without them, Q-learning with a neural network is often unstable even on CartPole, a task that tabular Q-learning over a discretized state space can solve.


Learning objectives

  • Implement a circular replay buffer with uniform sampling.
  • Explain why i.i.d. minibatches matter for SGD on Q-networks.
  • Maintain a target network θ⁻ updated every C steps (or soft update).
  • Compute DQN loss using targets from θ⁻ and predictions from θ.
  • Read training curves for signs of replay or target misconfiguration.

Experience replay mechanics

Store transitions (s, a, r, s′, done) in a fixed-size deque. When full, overwrite oldest entries.

python
import random
from collections import deque
 
class ReplayBuffer:
    def __init__(self, capacity: int = 100_000):
        self.buf = deque(maxlen=capacity)
 
    def push(self, s, a, r, s2, done):
        self.buf.append((s, a, r, s2, done))
 
    def sample(self, batch_size: int):
        batch = random.sample(self.buf, batch_size)
        # transpose the batch into five tuples (states, actions, rewards,
        # next states, dones) ready for numpy / torch stacking
        return list(zip(*batch))
 
    def __len__(self):
        return len(self.buf)

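Warm-up: Training cannot start until len(buffer) >= batch_size, because sampling would fail otherwise. In practice, collect a larger warm-up of 1k–10k transitions first so that early minibatches are not drawn from a handful of correlated steps.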
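
The deque version above is the simplest option, but indexing into the middle of a deque takes time proportional to its length, which can slow sampling on large buffers. A minimal sketch of an alternative follows, assuming a plain list with a write index is acceptable; `RingReplayBuffer` and `ready_to_train` are hypothetical names, and the default warm-up of 1,000 is only one value from the range mentioned above.

```python
import random


class RingReplayBuffer:
    """Fixed-capacity buffer that overwrites the oldest transition when full."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.storage = []   # grows until it reaches capacity
        self.pos = 0        # index of the next slot to write

    def push(self, s, a, r, s2, done):
        transition = (s, a, r, s2, done)
        if len(self.storage) < self.capacity:
            self.storage.append(transition)
        else:
            self.storage[self.pos] = transition  # overwrite the oldest entry
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int):
        # Uniform sampling without replacement; list indexing is O(1).
        batch = random.sample(self.storage, batch_size)
        return list(zip(*batch))

    def __len__(self):
        return len(self.storage)


def ready_to_train(buffer, batch_size: int, warmup: int = 1_000) -> bool:
    """Hypothetical helper: True once the warm-up threshold is met."""
    return len(buffer) >= max(batch_size, warmup)
```

The interface matches ReplayBuffer, so either class can back the training step; the warm-up check would typically guard the call to the training function.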

Target network

Maintain target_net, a copy of policy_net whose weights stay frozen between updates while policy_net takes many gradient steps. Bootstrap targets are computed from the frozen copy θ⁻:

text
y = r + γ (1 − done) max_a' Q(s', a'; θ⁻)
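
To make the role of the (1 − done) factor concrete, here is a small numeric sketch of the target for a single transition. The Q-values are made up for illustration and `td_target` is a hypothetical helper, not part of any library.

```python
def td_target(r: float, done: bool, next_q_values: list, gamma: float = 0.99) -> float:
    """y = r + gamma * (1 - done) * max_a' Q(s', a'; theta_minus)."""
    bootstrap = max(next_q_values)  # max over actions of the target net's output
    return r + gamma * (1.0 - float(done)) * bootstrap


# Made-up target-network outputs for s' with two actions.
next_q = [10.0, 12.0]

y_nonterminal = td_target(r=1.0, done=False, next_q_values=next_q)  # 1 + 0.99 * 12
y_terminal = td_target(r=1.0, done=True, next_q_values=next_q)      # bootstrap term is zeroed
```

For the non-terminal transition the target is about 12.88; for the terminal one it collapses to the reward alone, 1.0.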

Hard update every C steps (classic DQN: C = 10,000 gradient steps):

python
def hard_update(target, source):
    target.load_state_dict(source.state_dict())
 
# inside training loop, every C steps:
if step % target_update_freq == 0:
    hard_update(target_net, policy_net)

Soft update (Polyak averaging, used in DDPG/SAC later):

text
θ⁻ ← τ θ + (1 − τ) θ⁻   with small τ e.g. 0.005
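
The soft update rule can be sketched in PyTorch as follows. This is one possible implementation, assuming both networks share the same architecture; `soft_update` is a hypothetical helper name, and the two `nn.Linear` layers are stand-ins used only to show the call.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    """In place: theta_minus <- tau * theta + (1 - tau) * theta_minus."""
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.mul_(1.0 - tau).add_(s_param, alpha=tau)


# Tiny stand-in networks, only to show the call.
policy_net = nn.Linear(4, 2)
target_net = copy.deepcopy(policy_net)

with torch.no_grad():
    policy_net.weight.add_(1.0)  # pretend a gradient step moved the online weights

before = target_net.weight.clone()
soft_update(target_net, policy_net, tau=0.005)
# The target has moved 0.5% of the way toward the online weights.
```

Note that parameters() does not include buffers such as BatchNorm running statistics, so a network with such layers would need those copied separately.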

Worked example — replay batch composition

Buffer holds 50,000 transitions from the last 500 episodes. You sample batch_size = 64.

Property         | Without replay            | With replay
Adjacent samples | Highly correlated         | Decorrelated
Rare events      | Seen once then forgotten  | Resampled many times
Data efficiency  | Each transition used once | Each transition reused

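Suppose fall states make up 2% of the stored transitions. A uniform minibatch of 64 drawn from the 50k buffer then contains 64 × 0.02 ≈ 1.3 of them on average. Each fall transition also stays in the buffer and can be resampled many times before it is overwritten, so the agent keeps learning from critical failures rather than seeing each one once and discarding it.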
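
The arithmetic can be checked in a few lines. The sketch below also estimates how often a batch contains at least one fall state, under the simplifying assumption that draws are independent, which is a close approximation when the buffer is far larger than the batch.

```python
batch_size = 64
p_rare = 0.02  # fraction of stored transitions that are fall states

expected_per_batch = batch_size * p_rare
# Treat draws as independent; reasonable when the buffer (50k) dwarfs the batch.
p_at_least_one = 1.0 - (1.0 - p_rare) ** batch_size

print(f"expected fall states per batch: {expected_per_batch:.2f}")
print(f"P(batch contains at least one): {p_at_least_one:.2f}")
```

Under these assumptions the expected count is 1.28, and roughly 73% of batches contain at least one fall state.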

Full training step skeleton

python
import copy
import numpy as np
import torch
 
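# QNet, obs_dim and n_actions are assumed to be defined elsewhere
# (any nn.Module mapping observations to one Q-value per action).
policy_net = QNet(obs_dim, n_actions)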
target_net = copy.deepcopy(policy_net)
target_net.eval()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=2.5e-4)
buffer = ReplayBuffer(100_000)
 
def train_step(batch_size=32, gamma=0.99):
    if len(buffer) < batch_size:
        return None
    s, a, r, s2, d = buffer.sample(batch_size)
    s = torch.tensor(np.array(s), dtype=torch.float32)
    a = torch.tensor(a, dtype=torch.int64)
    r = torch.tensor(r, dtype=torch.float32)
    s2 = torch.tensor(np.array(s2), dtype=torch.float32)
    d = torch.tensor(d, dtype=torch.float32)
 
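    # Targets come from the frozen target network; no gradients flow through them.
    with torch.no_grad():
        next_q = target_net(s2).max(dim=1).values
        target = r + gamma * (1.0 - d) * next_q
 
    # Q(s, a; θ) for the actions that were actually taken.
    q = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
    loss = torch.nn.functional.mse_loss(q, target)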
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.0)
    optimizer.step()
    return loss.item()
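
The skeleton uses QNet without defining it. One plausible definition is a small MLP, sketched below; the hidden width of 128 and the two hidden layers are assumptions for illustration, not values taken from this lesson.

```python
import torch
from torch import nn


class QNet(nn.Module):
    """Small MLP mapping an observation to one Q-value per action."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


# CartPole-v1 has a 4-dimensional observation and 2 discrete actions.
obs_dim, n_actions = 4, 2
q_net = QNet(obs_dim, n_actions)
q_values = q_net(torch.zeros(5, obs_dim))  # batch of 5 dummy observations
```

The output has one row per observation and one column per action, which is the shape the gather and max calls in the training step expect.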

Hyperparameter starting points (CartPole-v1)

Hyperparameter  | Typical range                   | Notes
Buffer size     | 50k–1M                          | Larger for Atari
Batch size      | 32–128                          | GPU memory bound
Target update C | 500–10,000 steps                | Too frequent → instability
Learning rate   | 1e-4 – 2.5e-4                   | Common Adam setting in many repos
ε decay         | 1.0 → 0.05 over 10k–100k frames | Environment dependent
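
The ε decay row can be turned into a schedule function. The table does not say which shape the decay takes; the sketch below assumes a linear ramp followed by a constant floor, and `linear_epsilon` is a hypothetical helper name.

```python
def linear_epsilon(step: int, eps_start: float = 1.0, eps_end: float = 0.05,
                   decay_steps: int = 10_000) -> float:
    """Linearly anneal epsilon from eps_start to eps_end, then hold it."""
    fraction = min(step / decay_steps, 1.0)
    return eps_start + fraction * (eps_end - eps_start)
```

With the defaults, ε starts at 1.0, passes 0.525 halfway through the ramp, and stays at 0.05 from step 10,000 onward.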

Checkpoint: A target network decouples the network that supplies the bootstrap values from the network we are actively changing, much like evaluating a chess position with a frozen evaluation function while you train a new one.
Summary: Replay = better SGD data; target net = stable Bellman targets.

Common mistakes

  1. Updating the target every step: this defeats the purpose, because the targets then move as fast as the predictions.
  2. Never clearing the replay buffer after a major distribution shift: less critical in stationary environments, but it matters when fine-tuning.
  3. Storing the wrong action: store the action actually taken under ε-greedy, not the greedy argmax.
  4. Omitting the done flag: bootstrapping past a terminal state inflates Q-values at episode boundaries, and the error propagates backward from there.
  5. Batch norm on tiny batches with correlated early data: LayerNorm or no normalization is often safer in a small MLP DQN.
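
Mistake 3 is easy to avoid if action selection returns a single value that is both sent to the environment and stored. A minimal sketch follows, using made-up Q-values and a plain list in place of the replay buffer; `select_action` is a hypothetical helper.

```python
import random


def select_action(q_values: list, epsilon: float, rng: random.Random) -> int:
    """Epsilon-greedy: random action with probability epsilon, else argmax."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda i: q_values[i])


rng = random.Random(0)
q_values = [0.2, 0.9]  # made-up Q-values; the greedy action is 1
transitions = []

action = select_action(q_values, epsilon=0.1, rng=rng)
# Store `action`, the one actually sent to the environment,
# even when exploration made it differ from the greedy choice.
transitions.append(("s", action, 1.0, "s2", False))
```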
