From Q-learning to DQN
Before we begin
Tabular Q-learning updates one table cell at a time. Deep Q-Networks (DQN) replace that table with a neural network Q(s,a; θ) so the same update rule scales to images, sensors, and continuous state vectors. This lesson bridges the Bellman backup you already know to the first practical deep RL algorithm.
Learning objectives
- Recall the tabular Q-learning backup and its off-policy nature.
- State the DQN loss: minimize squared TD error on a minibatch.
- Explain why naïve online NN Q-learning diverges (moving targets + correlated samples).
- Sketch a minimal DQN training loop: act, store, sample, backprop.
- Connect Gymnasium observations to network inputs and discrete action heads.
From tabular backup to neural fitting
Tabular Q-learning for discrete actions:
Q(s,a) ← Q(s,a) + α [ r + γ max_a' Q(s',a') − Q(s,a) ]

DQN replaces Q(s,a) with a network output. For a batch of transitions (s, a, r, s′, done):

y_i = r_i + γ (1 − done_i) max_a' Q(s'_i, a'; θ)

L(θ) = mean_i ( y_i − Q(s_i, a_i; θ) )²

The target y_i is treated as a constant, so no gradient flows through it. The max in the bootstrap target still evaluates the greedy next action: same semantics as Q-learning, different representation.
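For comparison, the tabular backup at the top of this section can be sketched in plain Python. The function name `q_learning_update`, the dictionary-based table, the state labels, and the default α = 0.1 are illustrative assumptions, not part of the lesson; the `done` check mirrors the (1 − done) factor in the DQN target.

```python
from collections import defaultdict


def q_learning_update(Q, s, a, r, s_next, done, n_actions, alpha=0.1, gamma=0.99):
    """Apply one tabular Q-learning backup in place and return the TD error."""
    # Bootstrap from the greedy next action; no bootstrap past a terminal state.
    best_next = 0.0 if done else max(Q[(s_next, a2)] for a2 in range(n_actions))
    td_error = r + gamma * best_next - Q[(s, a)]
    Q[(s, a)] += alpha * td_error  # only this one table cell changes
    return td_error


# Hypothetical two-action problem; unseen (state, action) pairs default to 0.0.
Q = defaultdict(float)
Q[("s1", 0)] = 2.0
Q[("s1", 1)] = 5.0
td = q_learning_update(Q, s="s0", a=1, r=1.0, s_next="s1", done=False, n_actions=2)
```

Only the cell for ("s0", 1) moves here. A network update with the same TD error would also shift the Q-values of nearby states, which is where both the generalization and the instability of DQN come from.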
Worked example — one transition
Suppose the CartPole state is a 4-vector, the agent takes action 1 (push right), receives reward +1, and lands in a non-terminal next state s′, with γ = 0.99.
| Quantity | Value |
|---|---|
| Q(s, a=1; θ) before update | 12.4 |
| max_a' Q(s′, a′; θ) | 14.0 |
| TD target y | 1 + 0.99 × 14.0 = 14.86 |
| TD error | 14.86 − 12.4 = 2.46 |
Gradient descent nudges Q(s,1;θ) upward. Over thousands of transitions the network generalizes: similar states get similar Q-values.
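The arithmetic in the table above can be checked with a few lines of Python. The variable names are illustrative; the squared term is this single transition's contribution to the loss before averaging over a batch.

```python
gamma = 0.99
reward = 1.0
done = 0.0           # the transition is not terminal
q_sa = 12.4          # Q(s, a=1; θ) before the update
max_next_q = 14.0    # max_a' Q(s', a'; θ)

td_target = reward + gamma * (1.0 - done) * max_next_q
td_error = td_target - q_sa
sample_loss = td_error ** 2

print(f"target={td_target:.2f} error={td_error:.2f} loss={sample_loss:.4f}")
# prints: target=14.86 error=2.46 loss=6.0516
```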
Minimal DQN network (PyTorch)
import torch
import torch.nn as nn


class QNet(nn.Module):
    """MLP that maps an observation to one Q-value per discrete action."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # shape (batch, n_actions)


# One gradient step (Adam) on a synthetic batch
net = QNet(obs_dim=4, n_actions=2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
states = torch.randn(32, 4)
actions = torch.randint(0, 2, (32,))
rewards = torch.ones(32)
next_states = torch.randn(32, 4)
dones = torch.zeros(32)
gamma = 0.99

# Targets are constants: no gradient flows through the bootstrap term.
with torch.no_grad():
    next_q = net(next_states).max(dim=1).values
    targets = rewards + gamma * (1.0 - dones) * next_q

# Q(s_i, a_i; θ) for the action actually taken in each transition
q_sa = net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
loss = ((targets - q_sa) ** 2).mean()

opt.zero_grad()
loss.backward()
opt.step()

ε-greedy acting while learning
DQN is off-policy: behavior can explore with ε-greedy while targets use max over actions. Typical schedule: ε starts near 1.0, decays to 0.05 over 50k–200k frames.
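The decay schedule described above can be sketched as a small helper. The linear shape, the name `linear_epsilon`, and the 100,000-step horizon are assumptions chosen to sit inside the 50k–200k range mentioned; other shapes (for example exponential decay) are also common.

```python
def linear_epsilon(step: int, eps_start: float = 1.0, eps_end: float = 0.05,
                   decay_steps: int = 100_000) -> float:
    """Linearly anneal epsilon from eps_start to eps_end, then hold it there."""
    fraction = min(max(step, 0) / decay_steps, 1.0)  # progress clamped to [0, 1]
    return eps_start + fraction * (eps_end - eps_start)


eps_at_start = linear_epsilon(0)        # fully exploratory
eps_halfway = linear_epsilon(50_000)    # midway through the decay
eps_after = linear_epsilon(250_000)     # held at the floor after the decay ends
```

Calling the helper with the current environment step yields the ε to pass to the action-selection function at that step.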
import random


def select_action(q_values, epsilon: float) -> int:
    """ε-greedy over a 1-D array or tensor of Q-values for a single state."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))  # explore: uniform random action
    return int(q_values.argmax())  # exploit: greedy action

Checkpoint

Why off-policy matters here: replay stores transitions from an old exploratory policy; the max in the target still estimates the greedy policy's value.

Summary: DQN = Q-learning backup + function approximator + (later) replay and target nets for stability.
Why naïve online NN Q-learning fails
| Failure mode | Symptom | Root cause |
|---|---|---|
| Diverging Q | Loss explodes, returns collapse | Bootstrapping from a network that changes every step |
| Catastrophic forgetting | Good states regress | Correlated consecutive frames overwrite weights |
| Overestimation | Q too optimistic | max operator on noisy estimates |
Module 3's deadly triad (function approximation + bootstrapping + off-policy) appears in full force. DQN's fixes — experience replay and target networks — are the next lesson.
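The overestimation row in the table can be illustrated with a small simulation. This is a sketch under stated assumptions, not a DQN experiment: every true Q-value is set to 0, each estimate gets independent Gaussian noise with standard deviation 1, and the function name `mean_max_of_noisy_estimates` is hypothetical.

```python
import random


def mean_max_of_noisy_estimates(n_actions=4, noise_std=1.0, n_trials=10_000, seed=0):
    """Average of max_a (true_q + noise) when every true Q-value is 0."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trials):
        # Each action's estimate is the true value (0.0) plus zero-mean noise.
        noisy_q = [0.0 + rng.gauss(0.0, noise_std) for _ in range(n_actions)]
        total += max(noisy_q)
    return total / n_trials


bias = mean_max_of_noisy_estimates()                    # clearly above 0
no_max_bias = mean_max_of_noisy_estimates(n_actions=1)  # close to 0
```

Although each individual estimate is unbiased, the max over several of them is biased upward, and bootstrapping feeds that bias back into the targets.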
Common mistakes
- Training on consecutive frames only: gradients correlate and learning is unstable. Fix: replay buffer (next lesson).
- Using the same network for target and prediction: the update chases a moving target. Fix: frozen target copy.
- Forgetting to zero the bootstrap at terminal states: use `(1 - done)` in the target, or Q blows up at episode end.
- Treating continuous actions as discrete without discretization: vanilla DQN needs a finite action head; continuous control needs different algorithms (Module 8).
- No gradient clipping on early prototypes: occasional loss spikes can destroy weights; `clip_grad_norm_` helps.
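To show what the gradient-clipping advice in the last bullet does, here is a plain-Python sketch of clipping by global norm. It illustrates the idea only and is not PyTorch's implementation; the helper name `clip_by_global_norm` and the example values are assumptions. In a PyTorch loop the equivalent call is `torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)`, placed between `loss.backward()` and `opt.step()`.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a flat list of gradients so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Scale is 1.0 when the norm is already small; the epsilon avoids division by zero.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm


spiky_grads = [3.0, 4.0]  # L2 norm of 5.0, as after a loss spike
clipped, norm_before = clip_by_global_norm(spiky_grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
```

The direction of the gradient is preserved and only its length is capped, so a single outlier batch cannot move the weights arbitrarily far.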