
Module 4 — Deep Q-networks

From Q-learning to DQN

Neural Q-network, loss as MSE on TD target, and ε-greedy exploration.

~65 min read + exercises


Before we begin

Tabular Q-learning updates one table cell at a time. Deep Q-Networks (DQN) replace that table with a neural network Q(s,a; θ) so the same update rule scales to images, sensors, and continuous state vectors. This lesson bridges the Bellman backup you already know to the first practical deep RL algorithm.


Learning objectives

  • Recall the tabular Q-learning backup and its off-policy nature.
  • State the DQN loss: minimize squared TD error on a minibatch.
  • Explain why naïve online NN Q-learning diverges (moving targets + correlated samples).
  • Sketch a minimal DQN training loop: act, store, sample, backprop.
  • Connect Gymnasium observations to network inputs and discrete action heads.

From tabular backup to neural fitting

Tabular Q-learning for discrete actions:

```text
Q(s,a) ← Q(s,a) + α [ r + γ max_a' Q(s',a') − Q(s,a) ]
```
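A minimal sketch of this backup in plain Python, assuming a dict-based Q-table keyed by `(state, action)`. The helper name `q_update` and the `done` flag are illustrative additions, not part of the lesson's code. Note that the target takes the max over next actions no matter which action the agent will actually take next, which is what makes the update off-policy.

```python
from collections import defaultdict

def q_update(Q, s, a, r, s_next, actions, alpha=0.1, gamma=0.99, done=False):
    """One tabular Q-learning backup on a dict-like Q-table (hypothetical helper)."""
    # Bootstrap from the greedy next action; no bootstrap past a terminal state.
    best_next = 0.0 if done else max(Q[(s_next, a2)] for a2 in actions)
    td_target = r + gamma * best_next
    td_error = td_target - Q[(s, a)]
    Q[(s, a)] += alpha * td_error  # move only this one cell toward the target
    return td_error

Q = defaultdict(float)  # unseen (state, action) pairs start at 0.0
Q[("s1", 0)] = 2.0
Q[("s1", 1)] = 5.0
err = q_update(Q, "s0", 1, r=1.0, s_next="s1", actions=[0, 1])
print(err, Q[("s0", 1)])
```

Only the visited cell `("s0", 1)` changes; every other entry is untouched. That locality is exactly what a neural network gives up in exchange for generalization.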

DQN replaces Q(s,a) with a network output. For a batch of transitions (s, a, r, s′, done):

```text
y_i = r_i + γ (1 − done_i) max_a' Q(s'_i, a'; θ)
L(θ) = mean_i ( y_i − Q(s_i, a_i; θ) )²
```
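The two formulas can be checked by hand with plain Python lists standing in for network outputs. This is a sketch only: `dqn_targets_and_loss` is a hypothetical helper, and the Q-values are made-up numbers. The first transition reuses the numbers from the worked example that follows. The second is terminal, so `(1 − done)` removes its bootstrap term.

```python
def dqn_targets_and_loss(q_sa, next_q_max, rewards, dones, gamma=0.99):
    """Compute y_i = r_i + gamma * (1 - done_i) * max_a' Q(s'_i, a') and the mean squared TD error.

    q_sa       -- Q(s_i, a_i; theta) for each transition in the batch
    next_q_max -- max over a' of Q(s'_i, a'; theta) for each transition
    """
    targets = [r + gamma * (1.0 - d) * nq for r, d, nq in zip(rewards, dones, next_q_max)]
    loss = sum((y - q) ** 2 for y, q in zip(targets, q_sa)) / len(q_sa)
    return targets, loss

targets, loss = dqn_targets_and_loss(
    q_sa=[12.4, 3.0],
    next_q_max=[14.0, 9.0],
    rewards=[1.0, 1.0],
    dones=[0.0, 1.0],  # second transition is terminal: target is just the reward
)
print(targets, loss)
```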

The max in the bootstrap target still evaluates the greedy next action: same semantics as Q-learning, different representation.

Worked example — one transition

Suppose the CartPole state is a 4-vector, the agent takes action 1 (push right), receives reward +1, and lands in a non-terminal next state s′, with γ = 0.99.

| Quantity | Value |
|---|---|
| Q(s, a=1; θ) before update | 12.4 |
| max_a′ Q(s′, a′; θ) | 14.0 |
| TD target y | 1 + 0.99 × 14.0 = 14.86 |
| TD error | 14.86 − 12.4 = 2.46 |

Gradient descent nudges Q(s,1;θ) upward. Over thousands of transitions the network generalizes: similar states get similar Q-values.
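To see why the update moves the estimate upward, here is a toy sketch that treats Q(s,1) as a single free parameter. That is a simplifying assumption: a real network shares parameters across states, so one step also shifts other Q-values. The gradient of (y − Q)² with respect to Q is −2(y − Q), so each descent step adds 2 · lr · (TD error). The learning rate 0.1 is illustrative.

```python
def sgd_step_on_q(q, target, lr=0.1):
    """One gradient-descent step on (target - q)**2, treating q as the only parameter."""
    grad = -2.0 * (target - q)  # d/dq of (target - q)^2
    return q - lr * grad

gamma = 0.99
target = 1.0 + gamma * 14.0  # TD target from the worked example
q = 12.4
for step in range(3):
    q = sgd_step_on_q(q, target)
    print(step, round(q, 4), round(target - q, 4))  # TD error shrinks by a factor 0.8 per step
```

Here the target is held fixed. In the naïve network version it is recomputed from the same θ being updated, so it moves as Q moves.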

Minimal DQN network (PyTorch)

```python
import torch
import torch.nn as nn

class QNet(nn.Module):
    """MLP mapping an observation vector to one Q-value per discrete action."""
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # shape (batch, n_actions)

# One SGD step on a synthetic batch (CartPole-sized: 4-dim observations, 2 actions)
net = QNet(obs_dim=4, n_actions=2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

states = torch.randn(32, 4)
actions = torch.randint(0, 2, (32,))
rewards = torch.ones(32)
next_states = torch.randn(32, 4)
dones = torch.zeros(32)  # 1.0 would mark a terminal transition

gamma = 0.99
with torch.no_grad():
    # Target y = r + gamma * (1 - done) * max_a' Q(s', a'); no gradient flows through it
    next_q = net(next_states).max(dim=1).values
    targets = rewards + gamma * (1.0 - dones) * next_q

# Q(s, a) for the actions actually taken: one column per row
q_sa = net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
loss = ((targets - q_sa) ** 2).mean()  # mean squared TD error
opt.zero_grad()
loss.backward()
opt.step()
```

ε-greedy acting while learning

DQN is off-policy: the behavior policy can explore with ε-greedy while the targets use the max over actions. A typical schedule starts ε near 1.0 and decays it to 0.05 over 50k–200k frames.
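One common way to implement such a schedule is linear annealing, sketched below. The function name and the choice of `decay_frames=100_000` are assumptions (any value in the 50k–200k range fits the description). Exponential decay is an equally valid alternative.

```python
def linear_epsilon(frame, eps_start=1.0, eps_end=0.05, decay_frames=100_000):
    """Linearly anneal epsilon from eps_start to eps_end over decay_frames, then hold it."""
    fraction = min(frame / decay_frames, 1.0)
    return eps_start + fraction * (eps_end - eps_start)

for f in (0, 25_000, 50_000, 100_000, 500_000):
    print(f, round(linear_epsilon(f), 4))
```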

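```python
import random

def select_action(q_values, epsilon: float) -> int:
    # q_values: 1-D array or tensor of Q(s, ·) for the current state (must support .argmax())
    if random.random() < epsilon:
        return random.randrange(len(q_values))  # explore: uniform random action
    return int(q_values.argmax())  # exploit: greedy action
```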
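The `select_action` above expects a NumPy array or PyTorch tensor. Below is a plain-list variant (the name `select_action_list` is hypothetical) together with a quick empirical check. Because the random branch can also land on the greedy action, that action is chosen with probability 1 − ε + ε/|A|, not 1 − ε.

```python
import random

def select_action_list(q_values, epsilon, rng=random):
    """epsilon-greedy over a plain Python list of Q-values (hypothetical list-based variant)."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore: uniform over all actions
    return max(range(len(q_values)), key=lambda a: q_values[a])  # exploit: greedy action

rng = random.Random(0)
q_values = [0.1, 0.7, 0.3, 0.2]
epsilon = 0.2
n = 100_000
greedy_freq = sum(select_action_list(q_values, epsilon, rng) == 1 for _ in range(n)) / n
print(greedy_freq)  # theory: 1 - 0.2 + 0.2 / 4 = 0.85
```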

Checkpoint. Why off-policy matters here: replay stores transitions generated by an older, exploratory policy, yet the max in the target still estimates the value of the greedy policy. Summary: DQN = Q-learning backup + function approximator + (later) experience replay and target networks for stability.

Why naïve online NN Q-learning fails

| Failure mode | Symptom | Root cause |
|---|---|---|
| Diverging Q | Loss explodes, returns collapse | Bootstrapping from a network that changes every step |
| Catastrophic forgetting | Good states regress | Correlated consecutive frames overwrite weights |
| Overestimation | Q too optimistic | The max operator applied to noisy estimates |
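The overestimation row can be reproduced with a small simulation. It rests on a simplifying assumption: every action's true Q-value is 0, and each estimate carries independent zero-mean Gaussian noise. Each individual estimate is unbiased, yet the max over them is biased upward, and that max is exactly what the TD target bootstraps from. The function name and parameters are illustrative.

```python
import random

def mean_max_of_noisy_estimates(n_actions=10, noise_std=1.0, trials=20_000, seed=0):
    """All true Q-values are 0; each estimate adds zero-mean Gaussian noise.
    Returns the average over trials of max_a Q_hat(a), i.e. the average bootstrap value."""
    rng = random.Random(seed)
    maxima = [
        max(rng.gauss(0.0, noise_std) for _ in range(n_actions))
        for _ in range(trials)
    ]
    return sum(maxima) / len(maxima)

print(mean_max_of_noisy_estimates(n_actions=1))   # about 0: no max, no bias
print(mean_max_of_noisy_estimates(n_actions=10))  # well above the true value 0
```

The bias grows with the number of actions and with the noise level, which is why it is most visible early in training.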

Module 3's deadly triad (function approximation + bootstrapping + off-policy learning) appears in full force. DQN's fixes, experience replay and target networks, are the subject of the next lesson.
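A classic two-state toy shows all three ingredients at once. The setup and numbers are illustrative and are not DQN itself. A linear estimate uses one shared weight w with features 1 and 2, so V(s1) = w and V(s2) = 2w. Updates are applied only to the transition s1 → s2 with reward 0 (an off-policy update distribution). Each semi-gradient TD step multiplies w by 1 + α(2γ − 1), so for γ > 0.5 the estimate grows without bound even though the true values are 0.

```python
def two_state_divergence(w=1.0, alpha=0.1, gamma=0.99, steps=100):
    """Semi-gradient TD(0) on one transition s1 -> s2 (reward 0), features x(s1)=1, x(s2)=2.
    True values are 0, but bootstrapping through the shared weight w drags w upward."""
    x1, x2 = 1.0, 2.0
    history = [w]
    for _ in range(steps):
        td_error = 0.0 + gamma * (w * x2) - (w * x1)  # r + gamma * V(s2) - V(s1)
        w += alpha * td_error * x1                    # dV(s1)/dw = x1
        history.append(w)
    return history

hist = two_state_divergence()
print(round(hist[1], 3), round(hist[10], 3), round(hist[100], 1))
```

With `gamma=0.4` the per-step factor becomes 0.98 and w decays instead. The blow-up comes from combining a shared parameter, a bootstrapped target, and an update distribution that never corrects V(s2).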

Common mistakes

  1. Training on consecutive frames only: gradients are correlated and learning is unstable. Fix: a replay buffer (next lesson).
  2. Using the same network for target and prediction: the target moves with every update. Fix: a frozen target copy.
  3. Forgetting to zero the bootstrap at terminal states: use (1 − done) in the target, or values keep accumulating past the end of the episode.
  4. Treating continuous actions as discrete without discretization: vanilla DQN needs a finite action head; continuous control needs different algorithms (Module 8).
  5. No gradient clipping in early prototypes: occasional loss spikes can destroy weights; torch.nn.utils.clip_grad_norm_ helps.
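For item 5, PyTorch's `torch.nn.utils.clip_grad_norm_(parameters, max_norm)` is called between `loss.backward()` and `opt.step()`. The pure-Python sketch below only illustrates the idea behind global-norm clipping and is not PyTorch's implementation. If the combined L2 norm of all gradients exceeds `max_norm`, every gradient is scaled by the same factor, so the update direction is preserved and only its length shrinks. The helper name is hypothetical.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale a flat list of gradient values so their global L2 norm is at most max_norm.
    Returns the (possibly) clipped gradients and the norm before clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm  # small gradients pass through unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

clipped, norm_before = clip_by_global_norm([30.0, 40.0], max_norm=10.0)
print(norm_before, clipped)
```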
