Project: DQN on CartPole
Before we begin
Train a Deep Q-Network (DQN) in PyTorch on CartPole-v1. This is your first deep RL project — experience replay and a target network are required for stability, not optional extras.
How this connects to Module 4
| Lesson | Where you use it |
|---|---|
| From Q-learning to DQN | Same Bellman target, function approximator = MLP |
| Experience replay | Uniform sample from buffer breaks correlation |
| Target network | y = r + γ max_a Q_target(s', a) |
| Double / dueling | Optional stretch goals |
| Hyperparameters | Batch size, ε decay, target sync interval |
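The target-network row above can be worked by hand before any network is involved. The following is a minimal plain-Python sketch, not the project code: `dqn_targets` is a hypothetical helper and the Q-values are made up. It also applies the `(1 - done)` mask that the Step 2 training code uses, so terminal transitions do not bootstrap.

```python
def dqn_targets(rewards, next_q_values, dones, gamma=0.99):
    """Compute y = r + gamma * max_a Q_target(s', a) * (1 - done) per transition."""
    return [
        r + gamma * max(q_next) * (1.0 - d)
        for r, q_next, d in zip(rewards, next_q_values, dones)
    ]


rewards = [1.0, 1.0]
next_q = [[2.0, 3.0], [5.0, 4.0]]  # made-up Q_target(s', .) rows, one per transition
dones = [0.0, 1.0]                 # the second transition ends its episode

targets = dqn_targets(rewards, next_q, dones)
print(targets)  # roughly [3.97, 1.0]: the terminal transition keeps only its reward
```

The second target ignores the large next-state value of 5.0 because the episode is over, which is the whole purpose of the mask.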
What you will build
| Piece | Purpose |
|---|---|
| ReplayBuffer | Store (s, a, r, s', done) tuples |
| QNetwork | MLP: obs_dim → hidden → n_actions |
| DQNAgent | ε-greedy + periodic target net sync |
| returns.png | Rolling mean return |
Folder layout:
```text
dqn-cartpole/
  buffer.py
  model.py
  agent.py
  train.py
  checkpoints/best.pt
  outputs/returns.png
```

Estimated time: 4–6 hours.
Before you start
- Finish the Module 4 quiz.
- Install the dependencies: `pip install gymnasium torch numpy matplotlib`
Step 1 — Replay buffer & Q-network
```python
# buffer.py
import random
from collections import deque

import numpy as np


class ReplayBuffer:
    """Fixed-size FIFO store of (s, a, r, s2, done) transitions."""

    def __init__(self, capacity=50_000):
        self.buf = deque(maxlen=capacity)  # oldest transitions are evicted first

    def push(self, transition):
        self.buf.append(transition)

    def sample(self, batch_size):
        # Uniform sampling without replacement breaks temporal correlation.
        batch = random.sample(self.buf, batch_size)
        s, a, r, s2, d = map(np.array, zip(*batch))
        return s, a, r, s2, d

    def __len__(self):
        return len(self.buf)
```
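The two mechanics the buffer relies on, eviction by `deque(maxlen=...)` and the `zip(*batch)` transpose, can be seen in isolation. This is a standard-library sketch with made-up scalar "states" and a tiny capacity chosen only so eviction is visible; it is not a replacement for `ReplayBuffer`.

```python
import random
from collections import deque

buf = deque(maxlen=3)  # tiny capacity, for illustration only
for t in range(5):
    # (s, a, r, s2, done) with made-up scalar states for readability
    buf.append((t, t % 2, 1.0, t + 1, False))

oldest_first = [tr[0] for tr in buf]
print(oldest_first)  # [2, 3, 4]: the two oldest transitions were evicted

random.seed(0)
batch = random.sample(buf, 2)   # uniform sample without replacement
s, a, r, s2, d = zip(*batch)    # transpose: list of tuples -> one tuple per field
print(len(s), len(a), len(r), len(s2), len(d))  # 2 2 2 2 2
```

In the real buffer each field is then wrapped in `np.array`, which stacks the per-transition observations into a `(batch_size, obs_dim)` array.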
```python
# model.py
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """MLP mapping an observation to one Q-value per action."""

    def __init__(self, obs_dim=4, n_actions=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        return self.net(x)
```

Step 2 — DQN training step
```python
import torch
import torch.nn.functional as F


def train_step(online, target, optimizer, batch, gamma=0.99):
    """Run one gradient step on a sampled batch and return the scalar loss."""
    s, a, r, s2, d = batch
    s = torch.tensor(s, dtype=torch.float32)
    a = torch.tensor(a, dtype=torch.int64)
    r = torch.tensor(r, dtype=torch.float32)
    s2 = torch.tensor(s2, dtype=torch.float32)
    d = torch.tensor(d, dtype=torch.float32)

    # Q(s, a) for the actions that were actually taken.
    q_sa = online(s).gather(1, a.unsqueeze(1)).squeeze(1)

    # Bellman target from the frozen target network; no gradient flows through it.
    with torch.no_grad():
        q_next = target(s2).max(dim=1).values
        y = r + gamma * q_next * (1 - d)  # (1 - d) removes the bootstrap term at terminal states

    loss = F.mse_loss(q_sa, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Sync the target weights every C steps with `target.load_state_dict(online.state_dict())`.
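One way to implement the "every C steps" rule is a global step counter. The sketch below is illustrative only: `maybe_sync` is a hypothetical helper, `nn.Linear` stands in for `QNetwork`, the weight nudge stands in for a real gradient update, and `sync_every = 3` is far smaller than any value you would train with.

```python
import torch
import torch.nn as nn


def maybe_sync(online: nn.Module, target: nn.Module, step: int, sync_every: int) -> bool:
    """Hard-copy the online weights into the target every `sync_every` steps."""
    if step % sync_every == 0:
        target.load_state_dict(online.state_dict())
        return True
    return False


torch.manual_seed(0)
online = nn.Linear(4, 2)  # stand-in for QNetwork
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())

sync_every = 3  # hypothetical value of C, kept small for the demo
synced_at = []
for step in range(1, 8):
    # Stand-in for a gradient update: nudge the online weights.
    with torch.no_grad():
        online.weight.add_(0.1)
    if maybe_sync(online, target, step, sync_every):
        synced_at.append(step)

print(synced_at)  # [3, 6]
# Step 7 changed the online net after the last sync, so the two now differ.
print(torch.equal(online.weight, target.weight))  # False
```

Between syncs the target network lags the online network, which is what keeps the regression target from moving on every update.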
Step 3 — Main loop
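```python
# Assumes this script sits next to buffer.py and model.py from Step 1.
import gymnasium as gym
import numpy as np
import torch

from buffer import ReplayBuffer
from model import QNetwork
from agent import train_step  # assumes you saved Step 2's train_step in agent.py

env = gym.make("CartPole-v1")
online = QNetwork()
target = QNetwork()
target.load_state_dict(online.state_dict())  # start both networks identical
optimizer = torch.optim.Adam(online.parameters(), lr=1e-3)
buffer = ReplayBuffer()
epsilon, eps_min, eps_decay = 1.0, 0.05, 0.995

for episode in range(500):
    s, _ = env.reset()
    done = False
    while not done:
        # ε-greedy action selection
        if np.random.random() < epsilon:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                a = int(online(torch.tensor(s, dtype=torch.float32)).argmax())
        s2, r, term, trunc, _ = env.step(a)
        done = term or trunc
        # Store only true termination: a time-limit truncation should still bootstrap.
        buffer.push((s, a, r, s2, float(term)))
        s = s2
        if len(buffer) >= 64:
            train_step(online, target, optimizer, buffer.sample(64))
    epsilon = max(eps_min, epsilon * eps_decay)  # decay once per episode
    if episode % 10 == 0:
        # Episode-based sync; a step-based interval C works the same way.
        target.load_state_dict(online.state_dict())
```

Success criteria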
| Criterion | Target |
|---|---|
| Replay buffer + target net used | Required |
| Mean return over last 100 episodes ≥ 200 | Project pass mark (Gymnasium registers CartPole-v1 with a reward threshold of 475) |
| Saved best.pt checkpoint | Required |
| One ablation (no replay OR no target) fails to solve | Recommended |
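The rolling-mean criterion in the table can be checked with a few lines of standard-library Python. This is a sketch under stated assumptions: `solved_at` is a hypothetical helper, and the returns are synthetic numbers chosen for illustration, not results from a training run.

```python
from collections import deque


def solved_at(returns, window=100, threshold=200.0):
    """Return the first 1-based episode where the mean of the last
    `window` returns reaches `threshold`, or None if it never does."""
    recent = deque(maxlen=window)
    for episode, ret in enumerate(returns, start=1):
        recent.append(ret)
        if len(recent) == window and sum(recent) / window >= threshold:
            return episode
    return None


# Synthetic returns for illustration only: 100 weak episodes, then 150 strong ones.
fake_returns = [50.0] * 100 + [250.0] * 150
print(solved_at(fake_returns))  # 175
```

With these numbers the window needs 75 strong episodes before its mean reaches 200, so the criterion is first met at episode 175. In the training loop you would append each episode's total reward to a list and call the same check after every episode.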
Extension ideas
- Double DQN: use online net to select action, target net to evaluate.
- Dueling architecture: separate value and advantage streams.
- Log loss and Q magnitudes to catch divergence early.
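For the Double DQN idea, the only change is how the next-state value in the target is chosen. The plain-Python sketch below contrasts the two targets on one transition; the function names and Q-values are made up for illustration, and a real implementation would do the same thing on batched tensors.

```python
def argmax(values):
    """Index of the largest entry in a list."""
    return max(range(len(values)), key=values.__getitem__)


def vanilla_target(r, q_target_next, done, gamma=0.99):
    # Target net both selects and evaluates the next action.
    return r + gamma * max(q_target_next) * (1.0 - done)


def double_dqn_target(r, q_online_next, q_target_next, done, gamma=0.99):
    a_star = argmax(q_online_next)                            # online net selects
    return r + gamma * q_target_next[a_star] * (1.0 - done)   # target net evaluates


q_online_next = [1.0, 2.0]  # made-up Q_online(s', .)
q_target_next = [3.0, 1.5]  # made-up Q_target(s', .)

y_vanilla = vanilla_target(1.0, q_target_next, done=0.0)
y_double = double_dqn_target(1.0, q_online_next, q_target_next, done=0.0)
print(y_vanilla, y_double)  # roughly 3.97 and 2.485
```

Here the online net prefers action 1, so the Double DQN target uses the target net's value for action 1 (1.5) instead of the target net's own maximum (3.0), which is how it dampens overestimation.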
What's next
Return to the course curriculum and continue to the next module when your project runs end-to-end.