
Module 3 — Neural networks basics

Project: MNIST digit classifier + draw UI

Train a network in PyTorch on MNIST, optional NumPy backprop sketch, Next.js canvas UI, and API route for live predictions.

~240 min read + exercises

Before we begin

You will train a network on MNIST (28×28 handwritten digits), evaluate on a held-out test set, then wire a canvas UI so anyone can draw a digit and get a prediction.

Figure: What you are building

Canvas (draw 0–9) → API (Next.js) → Model (PyTorch) → Result (digit + %)
Train in Python/PyTorch → serve via Next.js API → user draws on canvas.

How this connects to Module 3

| Lesson | Where you use it |
| --- | --- |
| Perceptron / layers | Linear(784, 128) stacks weights + bias |
| Activation (ReLU) | Non-linearity between hidden and output layers |
| Forward & backward pass | loss.backward() computes gradients for all weights |
| Loss functions | CrossEntropyLoss for 10-class digits |
| Training loop | Epochs, batches, optimizer.step() applies the weight update |

Folder layout:

text
mnist-project/
  data/                  # MNIST download (auto)
  train_mnist.py         # train + plot + save weights
  predict_server.py      # optional Flask sidecar (same pattern as Module 2)
  outputs/
    loss_curve.png
    accuracy_curve.png
  digit_net.pt           # copy to Next.js public/models/

What you will build

  1. Train a 2-layer MLP in PyTorch on MNIST.
  2. Plot training loss and test accuracy per epoch.
  3. Save model weights for inference.
  4. API route in this Next.js app that accepts a 28×28 pixel grid.
  5. Canvas component — draw a digit, see predicted class + confidence.

Optional stretch: implement tiny backprop in NumPy for a 784→32→10 network on a subset.

Estimated time: 4–6 hours.


Before you start

  • Finish the Module 3 quiz.
  • pip install torch torchvision matplotlib flask (flask is only needed for the Step 5 sidecar)

Create folder mnist-project/ beside your course notes (or anywhere local).


Step 1 — Load MNIST (PyTorch)

Goal: Download MNIST and wrap it in DataLoader batches.

python
# train_mnist.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
 
transform = transforms.Compose([
    transforms.ToTensor(),                          # PIL → float tensor [0,1], shape (1,28,28)
    transforms.Normalize((0.1307,), (0.3081,)),     # MNIST mean/std — match at inference time
])
 
data_dir = Path(__file__).resolve().parent / "data"
train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
 
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

| Piece | Why it matters |
| --- | --- |
| ToTensor() | Converts pixels to floats in 0–1 |
| Normalize | Centers and scales the data; apply the same transform when the user draws on the canvas |
| shuffle=True | Random order each epoch, which helps generalization |
| Official test split | Never train on test data, so the reported accuracy stays honest |
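
To make the two transforms concrete, here is a small worked example in plain Python that applies the same arithmetic to single pixel values. The helper name `normalize_pixel` is hypothetical and only mirrors what ToTensor() followed by Normalize((0.1307,), (0.3081,)) does per pixel; it is not part of torchvision.

```python
# Worked example: what ToTensor() + Normalize((0.1307,), (0.3081,)) do to one pixel.
# normalize_pixel is a hypothetical helper for illustration only.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(raw: int) -> float:
    """Map a raw 0-255 intensity to the value the network actually sees."""
    scaled = raw / 255.0                       # ToTensor(): 0-255 -> 0.0-1.0
    return (scaled - MNIST_MEAN) / MNIST_STD   # Normalize(): subtract mean, divide by std

for raw in (0, 128, 255):
    print(raw, round(normalize_pixel(raw), 4))
```

Black background (0) lands at roughly -0.42 and full white (255) at roughly 2.82, so a canvas grid sent to the model has to go through the same two steps to sit in the range the network was trained on.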

Step 2 — Define the network

python
class DigitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),              # (batch, 1, 28, 28) → (batch, 784)
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),        # 10 logits — one per digit 0–9
        )
 
    def forward(self, x):
        return self.net(x)
 
model = DigitNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Cross-entropy expects raw logits (no softmax in the model); nn.CrossEntropyLoss applies log-softmax internally. Labels are integer class indices 0–9, not one-hot vectors.
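
As a rough illustration of what the loss computes for one sample, the sketch below reimplements it in plain Python: log-softmax over the ten logits, then the negative log-probability of the integer label. The function name `cross_entropy_from_logits` is made up for this example, and PyTorch additionally averages the per-sample values over the batch by default.

```python
import math

def cross_entropy_from_logits(logits: list[float], label: int) -> float:
    """Loss for one sample: -log(softmax(logits)[label]), computed stably."""
    m = max(logits)                                   # subtract max to avoid overflow in exp
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]                # equals -log_softmax(logits)[label]

uniform = [0.0] * 10            # all ten classes equally likely
confident = [0.0] * 10
confident[3] = 8.0              # strongly favours digit 3

print(cross_entropy_from_logits(uniform, 3))      # ln(10), about 2.3026
print(cross_entropy_from_logits(confident, 3))    # small: confident and correct
print(cross_entropy_from_logits(confident, 7))    # large: confident but wrong
```

A training loss that stays near 2.30 therefore means the network is still guessing uniformly across the ten digits.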


Step 3 — Training loop with per-epoch metrics

Goal: Track loss and test accuracy every epoch — spot overfitting early.

python
import matplotlib.pyplot as plt
 
epochs = 5
loss_history = []
acc_history = []
 
def evaluate_accuracy(loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
    return correct / len(loader.dataset)
 
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
 
    avg_loss = total_loss / len(train_ds)
    acc = evaluate_accuracy(test_loader)
    loss_history.append(avg_loss)
    acc_history.append(acc)
    print(f"epoch {epoch+1}/{epochs}  loss={avg_loss:.4f}  test_acc={acc:.4f}")
 
out = Path(__file__).resolve().parent / "outputs"
out.mkdir(exist_ok=True)
plt.plot(loss_history)
plt.xlabel("epoch"); plt.ylabel("train loss"); plt.savefig(out / "loss_curve.png")
plt.figure()
plt.plot(acc_history)
plt.xlabel("epoch"); plt.ylabel("test accuracy"); plt.savefig(out / "accuracy_curve.png")

What each line does:

  • model.train() — enables dropout/batchnorm training mode (none here, but good habit).
  • zero_grad() — clears old gradients before backward().
  • evaluate_accuracy on test — quick sanity check each epoch (for a course project; in industry you'd use a separate val set).
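
One detail of the loop worth spelling out is why it accumulates loss.item() * x.size(0) instead of just loss.item(). The sketch below works through the arithmetic with made-up batch losses: 60,000 training images do not divide evenly by 64, so the last batch is shorter and a plain mean of batch means would over-weight it.

```python
# Worked example: per-sample mean loss vs. mean of batch means.
# The batch losses below are made-up numbers for illustration.
dataset_size = 60_000          # MNIST training set
batch_size = 64

full_batches, last_batch = divmod(dataset_size, batch_size)
print(full_batches, last_batch)     # 937 full batches, then one batch of 32

# Suppose every full batch has mean loss 0.30 and the short last batch 0.90.
batch_sizes = [batch_size] * full_batches + [last_batch]
batch_losses = [0.30] * full_batches + [0.90]

# What the training loop computes: sum(loss * batch_len) / len(train_ds)
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / dataset_size
# What you would get by averaging loss.item() over batches
unweighted = sum(batch_losses) / len(batch_losses)

print(f"per-sample mean:     {weighted:.5f}")
print(f"mean of batch means: {unweighted:.5f}")
```

The difference is small here, but weighting by batch size keeps the reported number an exact per-sample average whatever batch size you choose.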

Step 4 — Final evaluation + save

python
final_acc = evaluate_accuracy(test_loader)
print(f"final test accuracy: {final_acc:.4f}")  # expect ~97%+ after 5 epochs
 
torch.save({
    "state_dict": model.state_dict(),
    "mean": 0.1307,
    "std": 0.3081,
}, Path(__file__).resolve().parent / "digit_net.pt")

Copy digit_net.pt into your Next.js project at public/models/digit_net.pt (or keep a Python sidecar — see Step 5).


Step 5 — Inference bridge (choose one)

Option A — Python sidecar (recommended for learning)

Same pattern as Module 2 spam project:

python
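# predict_server.py
from pathlib import Path

import torch
from flask import Flask, jsonify, request
from torch import nn

# Same architecture as DigitNet in train_mnist.py. It is redefined here because
# importing train_mnist would run its top-level code (download + training).
class DigitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x)

app = Flask(__name__)
# Load the checkpoint from this file's folder so the working directory does not matter.
ckpt_path = Path(__file__).resolve().parent / "digit_net.pt"
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model = DigitNet()
model.load_state_dict(ckpt["state_dict"])
model.eval()
MEAN, STD = ckpt["mean"], ckpt["std"]

@app.post("/predict")
def predict():
    payload = request.get_json(silent=True)
    pixels = payload.get("pixels") if isinstance(payload, dict) else None
    if not isinstance(pixels, list) or len(pixels) != 784:
        return jsonify({"error": "expected 784 pixels"}), 400
    # Pixels: 784 floats in 0-1, white digit on black, row-major 28x28.
    x = torch.tensor(pixels, dtype=torch.float32).view(1, 1, 28, 28)
    x = (x - MEAN) / STD  # same normalization as training
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0].tolist()
    digit = int(logits.argmax(dim=1).item())
    return jsonify({"digit": digit, "probabilities": probs})

if __name__ == "__main__":
    app.run(port=5002)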
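
Before wiring up the UI, it can help to smoke-test the sidecar from a plain Python script. The sketch below uses only the standard library and assumes the server is running locally on port 5002 with the /predict route shown above; `build_request`, `predict_digit`, and `make_bar_grid` are hypothetical helper names.

```python
import json
import urllib.request

SIDECAR_URL = "http://127.0.0.1:5002/predict"   # port and path from predict_server.py

def build_request(pixels: list[float], url: str = SIDECAR_URL) -> urllib.request.Request:
    """Build the POST request the sidecar expects: JSON with a 784-value 'pixels' list."""
    if len(pixels) != 784:
        raise ValueError(f"expected 784 pixels, got {len(pixels)}")
    body = json.dumps({"pixels": pixels}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def predict_digit(pixels: list[float], url: str = SIDECAR_URL) -> dict:
    """Send the grid to a running sidecar and return its decoded JSON reply."""
    with urllib.request.urlopen(build_request(pixels, url), timeout=5) as resp:
        return json.load(resp)

def make_bar_grid() -> list[float]:
    """A crude vertical bar (white on black, values 0-1) as a stand-in for a drawn '1'."""
    return [
        1.0 if 13 <= col <= 14 and 4 <= row <= 23 else 0.0
        for row in range(28)
        for col in range(28)
    ]
```

With the sidecar running, calling `predict_digit(make_bar_grid())` should return a dictionary with the `digit` and `probabilities` keys; if the server is not up, `urlopen` raises `urllib.error.URLError`.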

Option B — Next.js calls sidecar

This route forwards the request to the Flask sidecar from Option A, so the sidecar must be running on port 5002 (or wherever PY_DIGIT_URL points).

typescript
// app/api/predict-digit/route.ts
import { NextResponse } from "next/server";
 
const PY_URL = process.env.PY_DIGIT_URL ?? "http://127.0.0.1:5002/predict";
 
export async function POST(req: Request) {
  const { pixels } = await req.json();
  if (!Array.isArray(pixels) || pixels.length !== 784) {
    return NextResponse.json({ error: "expected 784 pixels" }, { status: 400 });
  }
  const res = await fetch(PY_URL, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ pixels }),
  });
  if (!res.ok) return NextResponse.json({ error: "inference failed" }, { status: 502 });
  return NextResponse.json(await res.json());
}

Step 6 — Draw canvas (React client component)

Goal: 280×280 drawing area → downsample to 28×28 → POST to API.

Key details:

| Detail | Why |
| --- | --- |
| MNIST is white on black | Invert if your canvas draws black strokes on white |
| Downsample 10× | 280 ÷ 28 = 10; average each 10×10 pixel block or use nearest-neighbor |
| Values 0–1 | Match ToTensor() output; the server applies Normalize afterwards |
| Center the digit | MNIST digits are roughly centered; cropping to the bounding box is an optional stretch |
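
The invert and downsample steps are easy to get subtly wrong, so here is the arithmetic sketched in plain Python. The real version would run in the browser over the canvas ImageData; this sketch assumes the input is already a 280×280 grid of grayscale values in 0–255 with dark strokes on a light background, and `canvas_to_mnist_grid` is a hypothetical name.

```python
# Sketch of the canvas preprocessing: invert, then 10x10 block-average to 28x28.
CANVAS = 280
GRID = 28
BLOCK = CANVAS // GRID   # 280 / 28 = 10

def canvas_to_mnist_grid(gray: list[list[int]]) -> list[float]:
    """Return 784 floats in 0-1, white digit on black, row-major like MNIST."""
    pixels = []
    for gy in range(GRID):
        for gx in range(GRID):
            total = 0
            for y in range(gy * BLOCK, (gy + 1) * BLOCK):
                for x in range(gx * BLOCK, (gx + 1) * BLOCK):
                    total += 255 - gray[y][x]             # invert: dark stroke -> bright
            pixels.append(total / (BLOCK * BLOCK * 255))  # block mean, scaled to 0-1
    return pixels

# Demo: white canvas with one fully black 10x10 block and one half-covered block.
canvas = [[255] * CANVAS for _ in range(CANVAS)]
for y in range(50, 60):
    for x in range(100, 110):
        canvas[y][x] = 0          # fully covers grid cell (row 5, col 10)
for y in range(70, 80):
    for x in range(100, 105):
        canvas[y][x] = 0          # covers half of grid cell (row 7, col 10)

grid = canvas_to_mnist_grid(canvas)
print(len(grid), grid[5 * GRID + 10], grid[7 * GRID + 10], grid[0])
```

Block averaging keeps partial coverage as intermediate gray values, which is closer to the anti-aliased look of MNIST digits than nearest-neighbor sampling.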

Sketch flow:

tsx
"use client";
// 1. on pointer move → draw thick stroke on 280×280 canvas
// 2. on submit → get ImageData, resize to 28×28 grayscale array length 784
// 3. fetch("/api/predict-digit", { method: "POST", body: JSON.stringify({ pixels }) })
// 4. show data.digit and top-3 from data.probabilities

Make it meaningful: draw messy 4 vs 9, 3 vs 8 — note when the model fails. That connects to robustness and data augmentation (Module 4 preview).


Optional — NumPy backprop (from scratch)

On 1,000 MNIST samples:

  • Forward: 784→32→10 with ReLU hidden.
  • Cross-entropy gradient on output.
  • Backprop manually through one hidden layer.
  • Compare val accuracy to PyTorch after same epochs.

Strongly reinforces Lesson 4 — optional but high value.
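
One possible shape for this exercise is sketched below. It is a minimal starting point, not a reference solution: the He-style initialization, learning rate, and batch size are assumptions, and the demo runs on synthetic placeholder data (noisy class prototypes) rather than MNIST so that it stays self-contained. Swap in your 1,000 normalized MNIST samples as `X` (shape n×784) and integer labels `y`.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_params(n_in=784, n_hidden=32, n_out=10):
    """He-style random weights and zero biases for a 784->32->10 network."""
    return {
        "W1": rng.normal(0.0, np.sqrt(2.0 / n_in), (n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(0.0, np.sqrt(2.0 / n_hidden), (n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }

def forward(p, X):
    z1 = X @ p["W1"] + p["b1"]          # hidden pre-activation, shape (n, 32)
    h = np.maximum(z1, 0.0)             # ReLU
    logits = h @ p["W2"] + p["b2"]      # shape (n, 10)
    return z1, h, logits

def loss_and_grads(p, X, y):
    """Mean cross-entropy over the batch plus gradients for every parameter."""
    n = X.shape[0]
    z1, h, logits = forward(p, X)
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(n), y].mean()

    d_logits = np.exp(log_probs)        # softmax probabilities
    d_logits[np.arange(n), y] -= 1.0    # gradient w.r.t. logits: p - one_hot(y)
    d_logits /= n                       # because the loss is a batch mean

    d_h = d_logits @ p["W2"].T          # back through the output layer
    d_z1 = d_h * (z1 > 0)               # back through ReLU
    grads = {
        "W1": X.T @ d_z1, "b1": d_z1.sum(axis=0),
        "W2": h.T @ d_logits, "b2": d_logits.sum(axis=0),
    }
    return loss, grads

def train(p, X, y, epochs=20, lr=0.1, batch_size=50):
    """Plain mini-batch SGD; returns the per-sample mean loss of each epoch."""
    history = []
    for _ in range(epochs):
        order = rng.permutation(len(X))
        total = 0.0
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            loss, grads = loss_and_grads(p, X[idx], y[idx])
            for name in p:
                p[name] -= lr * grads[name]
            total += loss * len(idx)
        history.append(total / len(X))
    return history

def accuracy(p, X, y):
    return float((forward(p, X)[2].argmax(axis=1) == y).mean())

# Placeholder data (NOT MNIST): ten noisy class prototypes, roughly centered.
y_demo = rng.integers(0, 10, size=300)
prototypes = rng.random((10, 784))
X_demo = 0.8 * prototypes[y_demo] + 0.2 * rng.random((300, 784)) - 0.5

params = init_params()
history = train(params, X_demo, y_demo)
print(f"loss {history[0]:.3f} -> {history[-1]:.3f}, "
      f"train acc {accuracy(params, X_demo, y_demo):.3f}")
```

A finite-difference check (perturb one weight by a small epsilon and compare the change in loss with the analytic gradient) is a good way to confirm your own backward pass before comparing accuracy with PyTorch.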


Troubleshooting

| Symptom | Fix |
| --- | --- |
| Accuracy stuck ~10% | Forgot Normalize at inference, or labels have the wrong shape |
| Canvas always predicts 1 | Invert colors: MNIST is a white digit on black |
| CUDA out of memory | Use batch_size=32 or device="cpu" |
| Flask import or missing-file error | Run from mnist-project/ so local modules and digit_net.pt resolve |
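
For the inverted-canvas symptom, a quick check on the 784 values before sending them can save some guessing. The sketch below is a heuristic only: it relies on MNIST's mean pixel value being about 0.1307 (the Normalize mean used in training), and the 0.5 cutoff and the names `looks_inverted` and `describe` are assumptions made for this example.

```python
# Heuristic sketch: flag a canvas grid that is probably inverted (dark digit on light).
MNIST_MEAN = 0.1307   # mean pixel value used by Normalize during training

def looks_inverted(pixels: list[float], cutoff: float = 0.5) -> bool:
    """True when most of the 28x28 grid is bright, i.e. the background is likely white."""
    if len(pixels) != 784:
        raise ValueError(f"expected 784 pixels, got {len(pixels)}")
    return sum(pixels) / len(pixels) > cutoff

def describe(pixels: list[float]) -> str:
    mean = sum(pixels) / len(pixels)
    status = "probably inverted" if looks_inverted(pixels) else "looks MNIST-like"
    return f"mean={mean:.3f} (MNIST is about {MNIST_MEAN}): {status}"

# A thin bright stroke on black vs. the same stroke drawn dark on white.
white_on_black = [1.0 if i % 28 in (13, 14) else 0.0 for i in range(784)]
black_on_white = [1.0 - v for v in white_on_black]
print(describe(white_on_black))
print(describe(black_on_white))
```

If the check fires, apply `1 - value` to every pixel (or invert on the canvas side) before calling the API.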

Deliverables checklist

  • Train/val/test discipline (MNIST official test split)
  • loss_curve.png + accuracy_curve.png saved
  • Test accuracy recorded in README
  • Working canvas → API → prediction
  • Short write-up: one failure case you observed

What's next

Module 3 complete. Continue to Module 4 — Deep learning architectures when ready.
