
Module 4 — Deep Q-networks

Project: DQN on CartPole

PyTorch DQN with replay and target net; reach 200+ mean return on CartPole-v1.

~150 min read + exercises


Before we begin

Train a Deep Q-Network (DQN) in PyTorch on CartPole-v1. This is your first deep RL project — experience replay and a target network are required for stability, not optional extras.


How this connects to Module 4

| Lesson | Where you use it |
|---|---|
| From Q-learning to DQN | Same Bellman target; the Q-table is replaced by an MLP function approximator |
| Experience replay | Sampling uniformly from a buffer breaks the correlation between consecutive transitions |
| Target network | y = r + γ (1 − done) · max_a' Q_target(s', a') |
| Double / dueling | Optional stretch goals |
| Hyperparameters | Batch size, ε decay, target sync interval |
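Written out in full, the rows above combine into one regression loss. In this sketch θ denotes the online weights, θ⁻ the periodically synced target weights, d the terminal flag (1 at a true terminal state) and 𝒟 the replay buffer sampled uniformly; the symbols are our notation, and the expectation is what `F.mse_loss` estimates over a minibatch in Step 2.

```latex
y = r + \gamma \, (1 - d) \max_{a'} Q_{\theta^-}(s', a')
\qquad
\mathcal{L}(\theta) = \mathbb{E}_{(s, a, r, s', d) \sim \mathcal{D}} \big[ \big( y - Q_{\theta}(s, a) \big)^2 \big]
```

Gradients flow only through Q_θ(s, a); y is treated as a constant, which is why Step 2 computes it under `torch.no_grad()`.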

What you will build

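| Piece | Purpose |
|---|---|
| ReplayBuffer | Store (s, a, r, s', done) tuples and sample uniform minibatches |
| QNetwork | MLP: obs_dim → hidden → hidden → n_actions |
| DQNAgent | ε-greedy action selection + periodic target-net sync |
| returns.png | Plot of the rolling mean return |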
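The steps below never show `DQNAgent` itself; Step 3 inlines its logic in the training loop. If you want it as a class in agent.py, here is a minimal sketch, assuming the agent wraps any `nn.Module` and exposes hypothetical `act`, `step` and `end_episode` methods. The method names and the `sync_every=500` default are illustrative choices, not the course's reference solution.

```python
import copy
import random

import torch
import torch.nn as nn


class DQNAgent:
    """Hypothetical sketch: owns online/target nets, picks ε-greedy actions,
    and hard-syncs the target net every `sync_every` calls to step()."""

    def __init__(self, online: nn.Module, n_actions: int, sync_every: int = 500,
                 epsilon: float = 1.0, eps_min: float = 0.05, eps_decay: float = 0.995):
        self.online = online
        self.target = copy.deepcopy(online)  # starts as an exact copy of the online net
        self.n_actions = n_actions
        self.sync_every = sync_every
        self.epsilon, self.eps_min, self.eps_decay = epsilon, eps_min, eps_decay
        self.steps = 0

    def act(self, obs) -> int:
        # Explore with probability ε, otherwise act greedily w.r.t. the online net.
        if random.random() < self.epsilon:
            return random.randrange(self.n_actions)
        with torch.no_grad():
            q = self.online(torch.as_tensor(obs, dtype=torch.float32))
        return int(q.argmax())

    def step(self) -> None:
        # Call once per environment step; copies online weights every `sync_every` steps.
        self.steps += 1
        if self.steps % self.sync_every == 0:
            self.target.load_state_dict(self.online.state_dict())

    def end_episode(self) -> None:
        # Per-episode multiplicative ε decay, floored at eps_min (same schedule as Step 3).
        self.epsilon = max(self.eps_min, self.epsilon * self.eps_decay)
```

Keeping ε and the sync counter inside the agent means train.py only calls `act`, `step` and `end_episode`, which makes the ablations in the success criteria easier to toggle.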

Folder layout:

```text
dqn-cartpole/
  buffer.py
  model.py
  agent.py
  train.py
  checkpoints/best.pt
  outputs/returns.png
```
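The project asks for outputs/returns.png but no step shows how to produce it. One possible sketch, assuming you keep a plain list of per-episode returns during training; `rolling_mean` and `plot_returns` are hypothetical helper names, and the 100-episode window mirrors the success criterion.

```python
import os

import matplotlib
matplotlib.use("Agg")  # render straight to a file, no display needed
import matplotlib.pyplot as plt
import numpy as np


def rolling_mean(returns, window=100):
    """Mean of the last `window` returns at each episode (shorter window at the start)."""
    r = np.asarray(returns, dtype=np.float64)
    csum = np.cumsum(np.insert(r, 0, 0.0))      # csum[i] = sum of the first i returns
    idx = np.arange(1, len(r) + 1)
    start = np.maximum(idx - window, 0)
    return (csum[idx] - csum[start]) / (idx - start)


def plot_returns(returns, path="outputs/returns.png", window=100):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(returns, alpha=0.3, label="episode return")
    ax.plot(rolling_mean(returns, window), label=f"rolling mean ({window})")
    ax.set_xlabel("episode")
    ax.set_ylabel("return")
    ax.legend()
    fig.savefig(path, dpi=120)
    plt.close(fig)
```

For example, `plot_returns(returns)` after training writes the figure into the outputs/ folder shown above, creating it if needed.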

Estimated time: 4–6 hours.


Before you start

  • Finish the Module 4 quiz.
  • pip install gymnasium torch numpy matplotlib

Step 1 — Replay buffer & Q-network

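```python
# buffer.py
import random
from collections import deque

import numpy as np


class ReplayBuffer:
    """Fixed-size FIFO store of (s, a, r, s2, done) transitions."""

    def __init__(self, capacity=50_000):
        # Oldest transitions are dropped automatically once capacity is reached.
        self.buf = deque(maxlen=capacity)

    def push(self, transition):
        self.buf.append(transition)

    def sample(self, batch_size):
        # Uniform sampling (without replacement) breaks the temporal
        # correlation between consecutive transitions.
        batch = random.sample(self.buf, batch_size)
        s, a, r, s2, d = map(np.array, zip(*batch))
        return s, a, r, s2, d

    def __len__(self):
        return len(self.buf)
```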
```python
# model.py
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """MLP mapping an observation to one Q-value per action."""

    def __init__(self, obs_dim=4, n_actions=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),  # no output activation: Q-values are unbounded
        )

    def forward(self, x):
        # x: (batch, obs_dim) -> (batch, n_actions)
        return self.net(x)
```
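The deque buffer is perfectly adequate for CartPole. If you later want cheaper sampling from large buffers, one alternative sketch preallocates one NumPy array per field and samples with a single fancy-index. `ArrayReplayBuffer` is a hypothetical name; it keeps the same `push`/`sample`/`len` interface but assumes fixed-size float observations, and unlike `random.sample` it draws indices with replacement.

```python
import numpy as np


class ArrayReplayBuffer:
    """Hypothetical alternative to ReplayBuffer: a preallocated NumPy ring buffer."""

    def __init__(self, obs_dim=4, capacity=50_000, seed=None):
        self.s = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.a = np.zeros(capacity, dtype=np.int64)
        self.r = np.zeros(capacity, dtype=np.float32)
        self.s2 = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.d = np.zeros(capacity, dtype=np.float32)
        self.capacity, self.pos, self.size = capacity, 0, 0
        self.rng = np.random.default_rng(seed)

    def push(self, transition):
        s, a, r, s2, d = transition
        i = self.pos
        self.s[i], self.a[i], self.r[i], self.s2[i], self.d[i] = s, a, r, s2, d
        # Overwrite the oldest slot once full (same FIFO behaviour as deque(maxlen=...)).
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        # Indices are drawn uniformly *with* replacement.
        idx = self.rng.integers(0, self.size, size=batch_size)
        return self.s[idx], self.a[idx], self.r[idx], self.s2[idx], self.d[idx]

    def __len__(self):
        return self.size
```

Because `sample` already returns typed arrays, the `torch.tensor(..., dtype=...)` conversions in Step 2 become cheap copies.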

Step 2 — DQN training step

```python
# agent.py (placing train_step here is an assumption; the page does not name its file)
import torch
import torch.nn.functional as F


def train_step(online, target, optimizer, batch, gamma=0.99):
    s, a, r, s2, d = batch
    s = torch.tensor(s, dtype=torch.float32)
    a = torch.tensor(a, dtype=torch.int64)
    r = torch.tensor(r, dtype=torch.float32)
    s2 = torch.tensor(s2, dtype=torch.float32)
    d = torch.tensor(d, dtype=torch.float32)  # 1.0 if s2 is a true terminal state

    # Q(s, a) for the actions actually taken: shape (batch,)
    q_sa = online(s).gather(1, a.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Bootstrapped target from the frozen target net; (1 - d) drops it at terminals.
        q_next = target(s2).max(dim=1).values
        y = r + gamma * q_next * (1 - d)

    loss = F.mse_loss(q_sa, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
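To see what the `(1 - d)` mask does, here is a hand-checkable example with made-up numbers: two transitions with reward 1, γ = 0.99, and hypothetical target-network outputs for s2. The second transition ends in a true terminal state, so its target collapses to the reward alone.

```python
import torch

gamma = 0.99
r = torch.tensor([1.0, 1.0])
d = torch.tensor([0.0, 1.0])             # transition 1 ends in a true terminal state
q_next_all = torch.tensor([[0.5, 2.0],   # made-up target(s2) output for transition 0
                           [3.0, 1.0]])  # made-up target(s2) output for transition 1
q_next = q_next_all.max(dim=1).values    # greedy value per row: [2.0, 3.0]
y = r + gamma * q_next * (1 - d)
print(y)  # tensor([2.9800, 1.0000])
```

Without the mask, the second target would be 1 + 0.99 · 3.0 = 3.97, bootstrapping from a state the episode never reaches.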

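Every C environment steps, hard-copy the online weights into the target network with target.load_state_dict(online.state_dict()). Between syncs the target stays frozen, so the regression target y does not shift after every gradient step.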
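A small sketch of what a hard sync means in practice, using toy `nn.Linear` layers in place of `QNetwork` and an arbitrary stand-in loss (both assumptions for illustration): the online weights move on every optimizer step, while the target only catches up on every C-th step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
online = nn.Linear(4, 2)
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())  # start identical

opt = torch.optim.SGD(online.parameters(), lr=0.1)
x = torch.randn(8, 4)
C = 3  # sync interval (illustrative value)

history = []
for step in range(1, 7):
    loss = online(x).pow(2).mean()  # stand-in loss; only its gradient matters here
    opt.zero_grad()
    loss.backward()
    opt.step()                      # online weights move; target does not
    if step % C == 0:
        target.load_state_dict(online.state_dict())  # hard copy
    history.append(torch.equal(online.weight, target.weight))
print(history)  # [False, False, True, False, False, True]
```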


Step 3 — Main loop

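```python
# train.py
import gymnasium as gym
import numpy as np
import torch

from agent import train_step  # assumes train_step from Step 2 is saved in agent.py
from buffer import ReplayBuffer
from model import QNetwork

env = gym.make("CartPole-v1")
online = QNetwork()
target = QNetwork()
target.load_state_dict(online.state_dict())
optimizer = torch.optim.Adam(online.parameters(), lr=1e-3)
buffer = ReplayBuffer()

epsilon, eps_min, eps_decay = 1.0, 0.05, 0.995
batch_size, target_sync = 64, 500  # C = 500 env steps is a tunable choice
step = 0
episode_returns = []
for episode in range(500):
    s, _ = env.reset()
    done, ep_return = False, 0.0
    while not done:
        # ε-greedy action selection
        if np.random.random() < epsilon:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                a = int(online(torch.tensor(s, dtype=torch.float32)).argmax())
        s2, r, term, trunc, _ = env.step(a)
        done = term or trunc
        # Store only `term` as the done flag: hitting the 500-step time limit
        # (trunc) is not a real terminal state, so the target should still bootstrap.
        buffer.push((s, a, r, s2, float(term)))
        s = s2
        ep_return += r
        step += 1
        if len(buffer) >= batch_size:
            train_step(online, target, optimizer, buffer.sample(batch_size))
        # Hard target sync every C environment steps
        if step % target_sync == 0:
            target.load_state_dict(online.state_dict())
    episode_returns.append(ep_return)
    epsilon = max(eps_min, epsilon * eps_decay)
```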
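With ε decayed once per episode, it is worth checking when exploration actually bottoms out. A quick calculation with the loop's own values shows ε is still about 0.082 after 500 decays and would only reach eps_min = 0.05 after 598 episodes, so a 500-episode run never hits the floor.

```python
import math

eps0, eps_min, eps_decay = 1.0, 0.05, 0.995

# ε after 500 per-episode decays (floored at eps_min).
eps_at_500 = max(eps_min, eps0 * eps_decay ** 500)

# Smallest n with eps0 * eps_decay**n <= eps_min.
n_floor = math.ceil(math.log(eps_min / eps0) / math.log(eps_decay))

print(round(eps_at_500, 3), n_floor)  # 0.082 598
```

If you want a longer near-greedy phase within the same budget, a smaller eps_decay (faster decay) or more episodes both move the floor earlier relative to the end of training.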

Success criteria

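| Criterion | Target |
|---|---|
| Replay buffer + target net used | Required |
| Mean return over last 100 episodes ≥ 200 | Solved (course target; Gymnasium registers CartPole-v1 with a reward threshold of 475) |
| Saved best.pt checkpoint | Required |
| One ablation (no replay OR no target) fails to solve | Recommended |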
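The criteria call for a 100-episode mean and a saved best.pt, but the training loop computes neither. A minimal sketch, assuming you keep a list of per-episode returns; `mean_last`, `save_if_best` and `solved` are hypothetical helper names. Calling `best = save_if_best(online, returns, best)` once per episode keeps only the best-scoring weights.

```python
import os

import numpy as np
import torch
import torch.nn as nn


def mean_last(returns, window=100):
    """Mean of the last `window` episode returns (NaN until `window` episodes exist)."""
    if len(returns) < window:
        return float("nan")
    return float(np.mean(returns[-window:]))


def save_if_best(model: nn.Module, returns, best, path="checkpoints/best.pt", window=100):
    """Save the model's state_dict when the windowed mean improves; return the new best."""
    score = mean_last(returns, window)
    if not np.isnan(score) and score > best:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(model.state_dict(), path)
        return score
    return best


def solved(returns, threshold=200.0, window=100):
    score = mean_last(returns, window)
    return not np.isnan(score) and score >= threshold
```

Start with `best = float("-inf")`; reload later with `model.load_state_dict(torch.load(path))`.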

Extension ideas

  • Double DQN: use the online net to select the next-state action and the target net to evaluate it.
  • Dueling architecture: separate value and advantage streams, recombined into Q-values.
  • Log the loss and Q-value magnitudes to catch divergence early.
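For the first two extensions, a hedged sketch: `double_dqn_target` would replace the max over the target net inside `train_step`, and `DuelingQNetwork` is a drop-in for `QNetwork`. Both names and the layer sizes are illustrative choices; the dueling head uses the common mean-subtracted aggregation Q = V + A − mean(A).

```python
import torch
import torch.nn as nn


def double_dqn_target(online, target, r, s2, d, gamma=0.99):
    """Double DQN target: the online net picks a', the target net evaluates Q_target(s2, a')."""
    with torch.no_grad():
        a_star = online(s2).argmax(dim=1, keepdim=True)   # action selection (online)
        q_next = target(s2).gather(1, a_star).squeeze(1)  # action evaluation (target)
        return r + gamma * q_next * (1 - d)


class DuelingQNetwork(nn.Module):
    """Dueling head: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

    def __init__(self, obs_dim=4, n_actions=2, hidden=128):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.value = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.adv = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, n_actions))

    def forward(self, x):
        h = self.body(x)
        v, a = self.value(h), self.adv(h)
        # Subtracting the mean advantage makes V and A identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
```

Decoupling selection from evaluation is what reduces the overestimation that a single max over noisy Q-values introduces.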

What's next

Return to the course curriculum and continue to the next module when your project runs end-to-end.