Experience replay & target networks
Before we begin
The original DQN paper (Mnih et al., 2015) made two stabilizing ideas standard: experience replay breaks temporal correlation in training data, and target networks slow the movement of bootstrap targets. Without them, Q-learning with a neural network often diverges or stalls on CartPole, a task that tabular Q-learning on a discretized state space can solve.
Learning objectives
- Implement a circular replay buffer with uniform sampling.
- Explain why i.i.d. minibatches matter for SGD on Q-networks.
- Maintain a target network θ⁻ updated every C steps (or soft update).
- Compute DQN loss using targets from θ⁻ and predictions from θ.
- Read training curves for signs of replay or target misconfiguration.
Experience replay mechanics
Store transitions (s, a, r, s′, done) in a fixed-size deque. Once the deque is full, each new transition evicts the oldest one.
```python
import random
from collections import deque


class ReplayBuffer:
    def __init__(self, capacity: int = 100_000):
        # deque with maxlen drops the oldest entry automatically once full
        self.buf = deque(maxlen=capacity)

    def push(self, s, a, r, s2, done):
        self.buf.append((s, a, r, s2, done))

    def sample(self, batch_size: int):
        # uniform sampling without replacement
        batch = random.sample(self.buf, batch_size)
        # unzip into per-field tuples for numpy / torch stacking
        return list(zip(*batch))

    def __len__(self):
        return len(self.buf)
```

Warm-up: Do not train until len(buffer) >= batch_size. In practice, 1k–10k transitions are often collected before the first gradient step.
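For readers who want the circular overwrite made explicit, here is a minimal sketch of the same buffer built on a plain list with a write index. The class name `RingReplayBuffer` and its attribute names are illustrative, not from any library. The intended behavior matches the deque version above: uniform sampling without replacement, with the oldest entry overwritten first. One practical difference is that list indexing is O(1), whereas indexing into the middle of a deque is O(n), which can matter when sampling from large buffers.

```python
import random


class RingReplayBuffer:
    """Fixed-capacity replay buffer with an explicit write index (illustrative)."""

    def __init__(self, capacity: int = 100_000):
        self.capacity = capacity
        self.data = []   # grows until it reaches capacity
        self.pos = 0     # next slot to write; wraps around at capacity

    def push(self, s, a, r, s2, done):
        transition = (s, a, r, s2, done)
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            self.data[self.pos] = transition  # overwrite the oldest entry
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int):
        batch = random.sample(self.data, batch_size)  # uniform, without replacement
        return list(zip(*batch))  # (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.data)


# Push 5 transitions into a buffer of capacity 3: only the last 3 survive.
buf = RingReplayBuffer(capacity=3)
for t in range(5):
    buf.push(t, 0, 1.0, t + 1, False)
stored_states = sorted(tr[0] for tr in buf.data)
print(len(buf), stored_states)
```

The interface is the same as `ReplayBuffer`, so either class can back the training step later in this section.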
Target network
Maintain target_net, a copy of policy_net whose weights stay frozen between updates, and use it to compute the bootstrap target:
y = r + γ (1 − done) max_{a′} Q(s′, a′; θ⁻)

Hard update every C steps (classic DQN: C = 10,000 gradient steps):

```python
def hard_update(target, source):
    target.load_state_dict(source.state_dict())


# inside training loop, every C steps:
if step % target_update_freq == 0:
    hard_update(target_net, policy_net)
```

Soft update (Polyak averaging, used in DDPG/SAC later):

θ⁻ ← τ θ + (1 − τ) θ⁻, with small τ, e.g. 0.005

Worked example: replay batch composition
Buffer holds 50,000 transitions from the last 500 episodes. You sample batch_size = 64.
| Property | Without replay | With replay |
|---|---|---|
| Adjacent samples | Highly correlated | Decorrelated |
| Rare events | Seen once then forgotten | Resampled many times |
| Data efficiency | Each transition used once | Each transition reused |
Suppose a fall state appears in 2% of stored frames. The expected number of fall transitions in one batch of 64 uniform draws from the 50k buffer is 64 × 0.02 ≈ 1.3, and roughly 73% of batches (1 − 0.98^64) contain at least one. The agent repeatedly learns from critical failures instead of drowning in pole-balancing frames.
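The arithmetic in this worked example can be checked with a few lines of Python. This is a sketch under simplifying assumptions: draws are treated as independent (a close approximation when the batch is much smaller than the buffer), and the reuse estimate pretends the buffer contents stay fixed. `num_updates` is an illustrative value, not a number from the example.

```python
buffer_size = 50_000
batch_size = 64
fall_fraction = 0.02  # share of stored transitions that are fall states

# Expected number of fall transitions in one batch.
expected_per_batch = batch_size * fall_fraction

# Probability that a batch contains at least one fall transition,
# treating the draws as independent.
p_at_least_one = 1.0 - (1.0 - fall_fraction) ** batch_size

# Expected number of times a single stored transition is drawn over
# num_updates gradient steps, if the buffer contents were held fixed.
num_updates = 10_000  # illustrative assumption
expected_reuse = num_updates * batch_size / buffer_size

print(f"expected falls per batch: {expected_per_batch:.2f}")
print(f"P(at least one fall):     {p_at_least_one:.3f}")
print(f"expected reuse per item:  {expected_reuse:.1f}")
```

In a live buffer old transitions are evicted, so the actual reuse count depends on how fast new data arrives relative to the number of gradient steps.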
Full training step skeleton
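```python
import copy

import numpy as np
import torch
import torch.nn as nn

# Assumption: QNet is not defined in this section. Any nn.Module that maps a
# batch of observations to one Q-value per action works; this small MLP is a
# placeholder sized for CartPole-v1 (4 observation dims, 2 actions).
obs_dim, n_actions = 4, 2


class QNet(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        return self.net(x)


policy_net = QNet(obs_dim, n_actions)
target_net = copy.deepcopy(policy_net)  # starts with the same weights
target_net.eval()
optimizer = torch.optim.Adam(policy_net.parameters(), lr=2.5e-4)
buffer = ReplayBuffer(100_000)  # ReplayBuffer from the replay section


def train_step(batch_size=32, gamma=0.99):
    if len(buffer) < batch_size:
        return None  # warm-up: not enough transitions yet
    s, a, r, s2, d = buffer.sample(batch_size)
    s = torch.tensor(np.array(s), dtype=torch.float32)
    a = torch.tensor(a, dtype=torch.int64)
    r = torch.tensor(r, dtype=torch.float32)
    s2 = torch.tensor(np.array(s2), dtype=torch.float32)
    d = torch.tensor(d, dtype=torch.float32)
    # Bootstrap target from the frozen target network; no gradients flow here.
    with torch.no_grad():
        next_q = target_net(s2).max(dim=1).values
        target = r + gamma * (1.0 - d) * next_q
    # Q(s, a; θ) for the actions actually taken.
    q = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
    loss = torch.nn.functional.mse_loss(q, target)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.0)
    optimizer.step()
    return loss.item()
```

Hyperparameter starting points (CartPole-v1)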
| Hyperparameter | Typical range | Notes |
|---|---|---|
| Buffer size | 50k–1M | Larger for Atari |
| Batch size | 32–128 | GPU memory bound |
| Target update C | 500–10,000 steps | Too frequent → instability |
| Learning rate | 1e-4 – 2.5e-4 | Used with Adam in many repos; lower than PyTorch's Adam default of 1e-3 |
| ε decay | 1.0 → 0.05 over 10k–100k frames | Environment dependent |
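The ε decay row gives a start value, an end value, and a horizon, but not the shape of the schedule. A linear anneal followed by a constant floor is one common choice, sketched below. The function name `linear_epsilon` and the default of 50,000 decay steps are assumptions picked from inside the table's range, not recommended values.

```python
def linear_epsilon(step: int, eps_start: float = 1.0, eps_end: float = 0.05,
                   decay_steps: int = 50_000) -> float:
    """Linearly anneal epsilon from eps_start to eps_end, then hold it."""
    fraction = min(max(step, 0) / decay_steps, 1.0)
    return eps_start + fraction * (eps_end - eps_start)


# Epsilon at the start, halfway through, at the end, and past the end.
for step in (0, 25_000, 50_000, 100_000):
    print(step, round(linear_epsilon(step), 3))
```

During action selection the agent would take a random action with probability `linear_epsilon(step)` and the greedy action otherwise.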
Checkpoint: Target nets separate the value used for bootstrapping from the value being trained, like evaluating a chess position with a frozen evaluation function while you train a new one.
Summary: Replay = better SGD data; target net = stable Bellman targets.
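The soft (Polyak) update from the Target network section, θ⁻ ← τ θ + (1 − τ) θ⁻, can be sketched without any framework. This is an illustration, not library code: `soft_update` is a hypothetical helper, and the weights are plain Python lists standing in for parameter tensors. With PyTorch the same rule would be applied to each parameter tensor. The example also shows how slowly the target moves: with τ = 0.005 it closes about 63% of the gap to a fixed source after 1/τ = 200 updates.

```python
def soft_update(target, source, tau=0.005):
    """In-place Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for name, src_values in source.items():
        tgt_values = target[name]
        for i, src in enumerate(src_values):
            tgt_values[i] = tau * src + (1.0 - tau) * tgt_values[i]


# Toy "networks": one named weight vector each (stand-ins for parameter tensors).
policy_weights = {"layer.weight": [1.0, -2.0]}
target_weights = {"layer.weight": [0.0, 0.0]}

for _ in range(200):  # 200 updates = 1 / tau
    soft_update(target_weights, policy_weights, tau=0.005)

gap_closed = target_weights["layer.weight"][0] / policy_weights["layer.weight"][0]
print(f"fraction of gap closed after 200 updates: {gap_closed:.3f}")
```

Unlike the hard update, this runs after every gradient step, so the target never jumps but always lags the policy network.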
Common mistakes
- Updating target every step — defeats the purpose; targets move as fast as predictions.
- Never clearing replay after major distribution shift — less critical in stationary envs; matters in fine-tuning.
- Storing wrong action — must store the action actually taken during ε-greedy, not the greedy argmax.
- Omitting done flag — terminal bootstrap error propagates bogus high Q at episode boundaries.
- Batch norm on tiny batches with correlated early data — LayerNorm or no norm in small MLP DQN is often safer.
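To make the done-flag mistake concrete, the sketch below computes the target for a single terminal transition with and without the (1 − done) mask. `td_target` is a hypothetical helper and the numbers are invented for illustration. One related caveat, assuming the Gymnasium API: `env.step` returns separate `terminated` and `truncated` flags, and only `terminated` should zero the bootstrap term, since hitting a time limit does not mean the next state has zero value.

```python
def td_target(reward, next_q_max, done, gamma=0.99, use_done_mask=True):
    """One-step TD target y = r + gamma * (1 - done) * max_a' Q(s', a')."""
    mask = (1.0 - float(done)) if use_done_mask else 1.0
    return reward + gamma * mask * next_q_max


# Terminal transition: the pole has fallen, so no future reward exists,
# yet the target network still outputs some value for s' (made-up: 50.0).
correct = td_target(reward=1.0, next_q_max=50.0, done=True)
buggy = td_target(reward=1.0, next_q_max=50.0, done=True, use_done_mask=False)
print(correct, buggy)
```

The unmasked target is far larger than the true return of the terminal step, and because later targets bootstrap from it, the error spreads to earlier states.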