Project: MNIST digit classifier + draw UI
Before we begin
You will train a network on MNIST (28×28 handwritten digits), evaluate on a held-out test set, then wire a canvas UI so anyone can draw a digit and get a prediction.
Figure: What you are building
How this connects to Module 3
| Lesson | Where you use it |
|---|---|
| Perceptron / layers | Linear(784, 128) stacks weights + bias |
| Activation (ReLU) | Non-linearity between hidden and output |
| Forward & backward pass | loss.backward() computes gradients for all weights |
| Loss functions | CrossEntropyLoss for 10-class digits |
| Training loop | Epochs, batches, optimizer.step() |
Folder layout:
mnist-project/
    data/                # MNIST download (auto)
    train_mnist.py       # train + plot + save weights
    predict_server.py    # optional Flask sidecar (same pattern as Module 2)
    outputs/
        loss_curve.png
        accuracy_curve.png
    digit_net.pt         # copy to Next.js public/models/
What you will build
- Train a 2-layer MLP in PyTorch on MNIST.
- Plot training loss and test accuracy per epoch (the test split stands in for a validation set here).
- Save model weights for inference.
- Add an API route in this Next.js app that accepts a 28×28 pixel grid (784 values).
- Build a canvas component: draw a digit, see the predicted class and confidence.
Optional stretch: implement tiny backprop in NumPy for a 784→32→10 network on a subset.
Estimated time: 4–6 hours.
Before you start
- Finish the Module 3 quiz.
- Install the dependencies: pip install torch torchvision matplotlib flask (Flask is only needed for the Step 5 sidecar).
- Create a folder mnist-project/ beside your course notes (or anywhere local).
Step 1 — Load MNIST (PyTorch)
Goal: Download MNIST and wrap it in DataLoader batches.
# train_mnist.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL → float tensor [0,1], shape (1,28,28)
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std; use the same values at inference time
])
data_dir = Path(__file__).resolve().parent / "data"
train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

| Piece | Why it matters |
|---|---|
| ToTensor() | Converts pixels to floats in 0–1 |
| Normalize | Centers and scales the data; apply the same transform when the user draws on the canvas |
| shuffle=True | New random batch order each epoch, so updates do not depend on dataset order |
| Official test split | Never train on test data, so the reported accuracy stays honest |
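To make the Normalize row concrete, here is a minimal sketch of the per-pixel arithmetic, using the mean and std from the transform above. The helper name normalize_pixel is made up for illustration; torchvision applies the same formula to the whole tensor at once.

```python
# Hypothetical helper: the per-pixel formula behind Normalize((0.1307,), (0.3081,)).
MEAN, STD = 0.1307, 0.3081

def normalize_pixel(value: float) -> float:
    """Map a pixel in [0, 1] (ToTensor() scale) to the standardized scale."""
    return (value - MEAN) / STD

black = normalize_pixel(0.0)  # background pixel
white = normalize_pixel(1.0)  # full-intensity stroke pixel
print(f"black -> {black:.4f}, white -> {white:.4f}")
```

The model only ever sees values on this standardized scale during training, which is why the canvas input must go through the same subtraction and division before prediction.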
Step 2 — Define the network
class DigitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),             # (batch, 1, 28, 28) → (batch, 784)
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),       # 10 logits, one per digit 0–9
        )

    def forward(self, x):
        return self.net(x)

model = DigitNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Cross-entropy expects raw logits (no softmax in the model). PyTorch applies log-softmax internally. Labels are integers 0–9, not one-hot.
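As a rough illustration of what "log-softmax internally" means, the sketch below computes the cross-entropy for a single sample in plain Python. The function name and the 3-class logits are invented for the example; nn.CrossEntropyLoss does the equivalent on batched tensors and, by default, averages over the batch.

```python
import math

def cross_entropy_from_logits(logits: list[float], label: int) -> float:
    """Log-softmax followed by negative log-likelihood, for one sample."""
    peak = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -log_probs[label]  # integer label picks out one log-probability

example_logits = [2.0, 1.0, 0.1]  # made-up logits for a 3-class problem
loss_if_label_0 = cross_entropy_from_logits(example_logits, 0)
loss_if_label_2 = cross_entropy_from_logits(example_logits, 2)
print(f"label 0: {loss_if_label_0:.4f}, label 2: {loss_if_label_2:.4f}")
```

The loss is small when the largest logit belongs to the true class and grows when it does not, which is the signal backward() turns into gradients.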
Step 3 — Training loop with per-epoch metrics
Goal: Track loss and test accuracy every epoch — spot overfitting early.
import matplotlib.pyplot as plt
epochs = 5
loss_history = []
acc_history = []
def evaluate_accuracy(loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
    return correct / len(loader.dataset)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    avg_loss = total_loss / len(train_ds)
    acc = evaluate_accuracy(test_loader)
    loss_history.append(avg_loss)
    acc_history.append(acc)
    print(f"epoch {epoch+1}/{epochs} loss={avg_loss:.4f} test_acc={acc:.4f}")
out = Path(__file__).resolve().parent / "outputs"
out.mkdir(exist_ok=True)
plt.plot(loss_history)
plt.xlabel("epoch"); plt.ylabel("train loss"); plt.savefig(out / "loss_curve.png")
plt.figure()
plt.plot(acc_history)
plt.xlabel("epoch"); plt.ylabel("test accuracy"); plt.savefig(out / "accuracy_curve.png")

What each line does:
- model.train(): enables dropout/batchnorm training mode (none here, but a good habit).
- zero_grad(): clears old gradients before backward().
- evaluate_accuracy on test: quick sanity check each epoch (fine for a course project; in industry you'd use a separate validation set).
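One detail in the training loop is easy to miss: it multiplies loss.item() by x.size(0) before summing. CrossEntropyLoss returns a per-batch mean by default, and the last batch is smaller than the rest, so weighting by batch size recovers the true per-sample average. The sketch below shows the difference with made-up loss values and a tiny batch size.

```python
# Made-up per-sample losses for 5 samples, split into batches of 2 (last batch has 1).
sample_losses = [0.9, 0.7, 0.4, 0.2, 1.3]
batches = [sample_losses[0:2], sample_losses[2:4], sample_losses[4:5]]

# What a mean-reducing criterion would report for each batch.
batch_means = [sum(batch) / len(batch) for batch in batches]

true_mean = sum(sample_losses) / len(sample_losses)
# Weighted sum, as in total_loss += loss.item() * x.size(0)
weighted_mean = sum(m * len(b) for m, b in zip(batch_means, batches)) / len(sample_losses)
# Unweighted average of batch means over-counts the small last batch.
naive_mean = sum(batch_means) / len(batch_means)

# With 60,000 training images and batch_size=64, the last batch is this size:
last_batch_size = 60_000 % 64

print(true_mean, weighted_mean, naive_mean, last_batch_size)
```

On MNIST the gap is tiny because only one batch out of 938 is short, but the weighted form stays correct for any dataset and batch size.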
Step 4 — Final evaluation + save
final_acc = evaluate_accuracy(test_loader)
print(f"final test accuracy: {final_acc:.4f}") # expect ~97%+ after 5 epochs
torch.save({
    "state_dict": model.state_dict(),
    "mean": 0.1307,
    "std": 0.3081,
}, Path(__file__).resolve().parent / "digit_net.pt")

Copy digit_net.pt into your Next.js project at public/models/digit_net.pt (or keep a Python sidecar; see Step 5).
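For a rough sense of how large the saved file should be, you can count the parameters of the two Linear layers from Step 2. This is a back-of-the-envelope sketch that assumes float32 weights (PyTorch's default) and ignores serialization overhead, so the file on disk will be somewhat larger.

```python
# (in_features, out_features) for the two Linear layers in DigitNet
layer_shapes = [(784, 128), (128, 10)]

# Each Linear layer stores an in×out weight matrix plus one bias per output.
param_count = sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)
approx_bytes = param_count * 4  # 4 bytes per float32 value

print(f"{param_count:,} parameters, about {approx_bytes / 1024:.0f} KiB of weights")
```

If digit_net.pt is far from this size, the checkpoint probably holds a different architecture than the one you expect to load.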
Step 5 — Inference bridge (choose one)
Option A — Python sidecar (recommended for learning)
Same pattern as Module 2 spam project:
# predict_server.py
from flask import Flask, request, jsonify
import torch
# NOTE: this import runs train_mnist.py top to bottom (retraining on startup)
# unless its training code sits under `if __name__ == "__main__":`.
from train_mnist import DigitNet  # import your class

app = Flask(__name__)
ckpt = torch.load("digit_net.pt", map_location="cpu", weights_only=True)
model = DigitNet()
model.load_state_dict(ckpt["state_dict"])
model.eval()
MEAN, STD = ckpt["mean"], ckpt["std"]

@app.post("/predict")
def predict():
    pixels = request.json.get("pixels", [])
    if len(pixels) != 784:
        return jsonify({"error": "expected 784 pixels"}), 400
    # Pixels arrive in 0–1 (ToTensor scale); apply the training-time Normalize here.
    x = torch.tensor(pixels, dtype=torch.float32).view(1, 1, 28, 28)
    x = (x - MEAN) / STD
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0].tolist()
    digit = int(logits.argmax().item())
    return jsonify({"digit": digit, "probabilities": probs})

if __name__ == "__main__":
    app.run(port=5002)

Option B — Next.js calls sidecar
This route forwards the request to the sidecar from Option A, so keep predict_server.py running.
// app/api/predict-digit/route.ts
import { NextResponse } from "next/server";

const PY_URL = process.env.PY_DIGIT_URL ?? "http://127.0.0.1:5002/predict";

export async function POST(req: Request) {
  const { pixels } = await req.json();
  if (!Array.isArray(pixels) || pixels.length !== 784) {
    return NextResponse.json({ error: "expected 784 pixels" }, { status: 400 });
  }
  const res = await fetch(PY_URL, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ pixels }),
  });
  if (!res.ok) return NextResponse.json({ error: "inference failed" }, { status: 502 });
  return NextResponse.json(await res.json());
}

Step 6 — Draw canvas (React client component)
Goal: 280×280 drawing area → downsample to 28×28 → POST to API.
Key details:
| Detail | Why |
|---|---|
| MNIST is white on black | Invert if your canvas draws black strokes on white |
| Downsample 10× | 280÷28 = 10 — average pixel blocks or nearest-neighbor |
| Values 0–1 | Match ToTensor() before normalize in server |
| Center the digit | MNIST digits are roughly centered; cropping to the digit's bounding box is an optional stretch |
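The downsample, invert, and 0–1 scaling rows in the table can be sketched in a few lines. The real component runs in the browser; this Python version only mirrors the arithmetic so you can check it in isolation. The function name and the assumption that the canvas arrives as a 2D list of grayscale values (0 = black stroke, 255 = white paper) are illustrative, not part of the course code.

```python
def downsample_and_invert(canvas, factor=10):
    """Block-average a square grayscale canvas and flip it to white-on-black.

    canvas: 2D list of values 0..255, black strokes on white paper (assumed).
    Returns a flat, row-major list of floats in [0, 1], like ToTensor() output.
    """
    size = len(canvas) // factor  # 280 // 10 = 28
    pixels = []
    for row in range(size):
        for col in range(size):
            block = [
                canvas[row * factor + dy][col * factor + dx]
                for dy in range(factor)
                for dx in range(factor)
            ]
            mean = sum(block) / len(block)     # average of the 10×10 block, 0..255
            pixels.append(1.0 - mean / 255.0)  # invert and scale to 0..1
    return pixels

# White 280×280 canvas with a black 20×20 square drawn at rows/cols 50..69.
canvas = [[255] * 280 for _ in range(280)]
for y in range(50, 70):
    for x in range(50, 70):
        canvas[y][x] = 0

pixels = downsample_and_invert(canvas)
print(len(pixels), pixels[5 * 28 + 5], pixels[0])
```

Block averaging keeps thin strokes visible as gray pixels, whereas nearest-neighbor sampling can skip a stroke entirely if it falls between sampled points.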
Sketch flow:
"use client";
// 1. on pointer move → draw thick stroke on 280×280 canvas
// 2. on submit → get ImageData, resize to 28×28 grayscale array length 784
// 3. fetch("/api/predict-digit", { method: "POST", body: JSON.stringify({ pixels }) })
// 4. show data.digit and top-3 from data.probabilities

Make it meaningful: draw a messy 4 vs 9, or 3 vs 8, and note when the model fails. Those failures connect to robustness and data augmentation (Module 4 preview).
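When you log failure cases, the top-3 classes are often more informative than the single prediction. Here is a small sketch of ranking a probabilities list like the one the sidecar returns; the helper name, the example values, and the 0.10 margin are all made up for illustration.

```python
def top_k(probabilities, k=3):
    """Return the k most likely (digit, probability) pairs, highest first."""
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Invented response for a messy "4" that looks a lot like a "9".
probs = [0.01, 0.02, 0.01, 0.01, 0.48, 0.01, 0.01, 0.03, 0.01, 0.41]

top3 = top_k(probs)
margin = top3[0][1] - top3[1][1]
is_ambiguous = margin < 0.10  # hypothetical threshold for "worth writing down"

for digit, p in top3:
    print(f"{digit}: {p:.0%}")
print("ambiguous" if is_ambiguous else "confident")
```

A small gap between first and second place is a good cue for the failure case in your write-up.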
Optional — NumPy backprop (from scratch)
On 1,000 MNIST samples:
- Forward: 784→32→10 with ReLU hidden.
- Cross-entropy gradient on output.
- Backprop manually through one hidden layer.
- Compare val accuracy to PyTorch after same epochs.
Strongly reinforces Lesson 4 — optional but high value.
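A possible starting point for the steps above, as a sketch rather than a reference solution. It uses random toy data in place of the 1,000 MNIST samples so it runs with NumPy alone; swap in real images and labels for the actual exercise. The initialization scale, learning rate, and step count are arbitrary choices, and the gradient check at the end is an extra sanity test, not something the course requires.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for MNIST (assumption): 64 random "images" in [0, 1], random labels.
X = rng.random((64, 784))
y = rng.integers(0, 10, size=64)

W1 = rng.normal(0, 0.05, (784, 32)); b1 = np.zeros(32)
W2 = rng.normal(0, 0.05, (32, 10));  b2 = np.zeros(10)

def loss_and_grads(X, y, W1, b1, W2, b2):
    n = X.shape[0]
    # Forward: 784 → 32 (ReLU) → 10 logits → softmax
    z1 = X @ W1 + b1
    h = np.maximum(z1, 0.0)
    logits = h @ W2 + b2
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(n), y]).mean()

    # Backward: gradient of mean cross-entropy w.r.t. logits is (probs - one_hot) / n
    d_logits = probs.copy()
    d_logits[np.arange(n), y] -= 1.0
    d_logits /= n
    dW2 = h.T @ d_logits
    db2 = d_logits.sum(axis=0)
    d_h = d_logits @ W2.T
    d_z1 = d_h * (z1 > 0)  # ReLU passes gradient only where it was active
    dW1 = X.T @ d_z1
    db1 = d_z1.sum(axis=0)
    return loss, (dW1, db1, dW2, db2)

# Gradient check on one weight: central difference vs. the analytic gradient.
eps = 1e-5
_, (_, _, dW2_analytic, _) = loss_and_grads(X, y, W1, b1, W2, b2)
W2_plus = W2.copy();  W2_plus[0, 0] += eps
W2_minus = W2.copy(); W2_minus[0, 0] -= eps
loss_plus, _ = loss_and_grads(X, y, W1, b1, W2_plus, b2)
loss_minus, _ = loss_and_grads(X, y, W1, b1, W2_minus, b2)
numeric_grad = (loss_plus - loss_minus) / (2 * eps)
analytic_grad = dW2_analytic[0, 0]

# Plain full-batch gradient descent.
lr = 0.01
losses = []
for step in range(50):
    loss, (dW1, db1, dW2, db2) = loss_and_grads(X, y, W1, b1, W2, b2)
    losses.append(loss)
    W1 -= lr * dW1; b1 -= lr * db1
    W2 -= lr * dW2; b2 -= lr * db2

print(f"grad check diff: {abs(numeric_grad - analytic_grad):.2e}")
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

If the gradient check disagrees, fix the backward pass before tuning anything else; a training loop can appear to work for a while even with a wrong gradient.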
Troubleshooting
| Symptom | Fix |
|---|---|
| Accuracy stuck ~10% | Forgot Normalize at inference; or labels wrong shape |
| Canvas always predicts 1 | Invert colors — MNIST is white digit on black |
| CUDA out of memory | Use batch_size=32 or device="cpu" |
| Flask import error | Run from mnist-project/ so train_mnist is on path |
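When a symptom in the table could come from either the canvas or the server, it helps to post a known input straight to the sidecar and take the browser out of the loop. Below is a minimal standard-library sketch; the URL and port come from predict_server.py, while the helper names build_request and send are invented here.

```python
import json
import urllib.request

SIDECAR_URL = "http://127.0.0.1:5002/predict"  # port used by predict_server.py

def build_request(pixels, url=SIDECAR_URL):
    """Build the same JSON POST the Next.js route sends to the sidecar."""
    body = json.dumps({"pixels": pixels}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def send(req, timeout=5):
    """Send the request and decode the JSON reply (needs the sidecar running)."""
    with urllib.request.urlopen(req, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))

# A blank image: 784 zeros on the 0–1 scale the server expects.
blank = [0.0] * 784
req = build_request(blank)

# With the sidecar running, uncomment to see its answer:
# result = send(req)
# print(result["digit"], max(result["probabilities"]))
```

If this call returns a sensible response but the canvas does not, the problem is in the drawing or downsampling code rather than in the model.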
Deliverables checklist
- Train/val/test discipline (MNIST official test split)
- loss_curve.png + accuracy_curve.png saved
- Test accuracy recorded in README
- Working canvas → API → prediction
- Short write-up: one failure case you observed
What's next
Module 3 complete. Continue to Module 4 — Deep learning architectures when ready.
Return to the AI course curriculum anytime to track progress.