Project: fine-tuned image classifier
Before we begin
This is your first deep learning CV project. You will take a pretrained ResNet (already trained on ImageNet), adapt it to a new task, measure performance honestly on a validation set, inspect failure cases, and export the model for deployment in Module 7.
You are not training a CNN from scratch on millions of images — you are doing transfer learning, exactly as described in Lesson 3.
[Figure: What you are building]
How this connects to Module 3
| Lesson | Where you use it |
|---|---|
| CNNs for images | ResNet18 backbone — conv blocks shrink space, grow channels |
| Detection lesson (preview) | Same preprocessing discipline matters for any vision model |
| Transfer learning | ImageNet weights + new fc layer for your class count |
| Augmentation | RandomResizedCrop, RandomHorizontalFlip on train only |
| Normalization | ImageNet mean/std — must match at inference and ONNX export |
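To make "conv blocks shrink space, grow channels" concrete, the sketch below traces the spatial size of a 224×224 input through the ResNet18 stages using the standard convolution output-size formula. Each residual stage is summarized by its first convolution only, and the helper name `conv_out` is made up for this illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


# (stage name, kernel, stride, padding, output channels)
# Each residual stage is represented by its first conv, which sets the stride.
stages = [
    ("conv1", 7, 2, 3, 64),
    ("maxpool", 3, 2, 1, 64),
    ("layer1", 3, 1, 1, 64),
    ("layer2", 3, 2, 1, 128),
    ("layer3", 3, 2, 1, 256),
    ("layer4", 3, 2, 1, 512),
]

size = 224
trace = []
for name, kernel, stride, padding, channels in stages:
    size = conv_out(size, kernel, stride, padding)
    trace.append((name, size, channels))
    print(f"{name:8s} {size:3d} x {size:<3d} channels={channels}")
```

The final 7×7×512 feature map is average-pooled to a 512-vector, which is why the replacement `fc` layer takes 512 input features.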
Folder layout:
pet-classifier/
  data/                  # torchvision download cache (auto)
  train_classifier.py    # main script: train, plot, export
  outputs/
    learning_curves.png
    misclassifications.png
    best_model.pt
    classifier.onnx
  README.md              # val accuracy + reflection
Estimated time: 3–5 hours (first time with PyTorch; ~2 hours if comfortable).
What you will build
- Load the Oxford-IIIT Pet dataset (37 cat/dog breeds) with a train/val split.
- Fine-tune ResNet18 with proper train vs val transforms.
- Track train loss and val accuracy every epoch — save the best checkpoint.
- Plot learning curves.
- Visualize 9 misclassified validation images — learn from mistakes.
- Export to ONNX for the Module 7 deployment project.
Before you start
- Finish the Module 3 quiz.
- Python 3.10+ with:
mkdir pet-classifier && cd pet-classifier
python -m venv .venv
# Windows: .venv\Scripts\activate
# macOS/Linux: source .venv/bin/activate
pip install torch torchvision matplotlib numpy onnx onnxruntime
- GPU optional: ResNet18 on Oxford Pets runs in ~10–20 min on CPU for 10 epochs; GPU is faster.
Step 1 — Load Oxford-IIIT Pet with a train/val split
Goal: Get labeled images into PyTorch DataLoaders. This project reports accuracy on the validation set, so use that set only to pick the best epoch; if you also tune hyperparameters on it, the reported number becomes optimistic.
The Oxford-IIIT Pet dataset has ~7,400 images across 37 breeds (cats and dogs). torchvision downloads it automatically.
# train_classifier.py
import random
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
ROOT = Path(__file__).resolve().parent
DATA_DIR = ROOT / "data"
OUT_DIR = ROOT / "outputs"
OUT_DIR.mkdir(exist_ok=True)
Transforms (train and val must differ):
| Transform | Train | Val | Why |
|---|---|---|---|
| `RandomResizedCrop(224)` | ✓ | ✗ | Simulates zoom/scale changes; acts as regularization |
| `CenterCrop(224)` | ✗ | ✓ | Fixed, deterministic eval |
| `RandomHorizontalFlip` | ✓ | ✗ | A mirrored pet is still the same breed, so flipping is safe |
| `Normalize(ImageNet)` | ✓ | ✓ | ResNet was pretrained with these stats |
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
train_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
Important: We need two views of the dataset with different transforms. A clean pattern is to load the dataset once without a transform, split the indices, then wrap each index list in a small Dataset class that applies its own transform:
# Download full set once (no transform yet — we split indices first)
full_ds = datasets.OxfordIIITPet(
    root=DATA_DIR,
    split="trainval",
    target_types="category",
    download=True,
    transform=None,
)
n = len(full_ds)
indices = list(range(n))
random.shuffle(indices)
val_size = int(0.15 * n) # ~15% for validation
val_idx = indices[:val_size]
train_idx = indices[val_size:]
# Apply transforms via separate wrapped datasets
class TransformSubset(torch.utils.data.Dataset):
    def __init__(self, base, indices, transform):
        self.base = base
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        img, label = self.base[self.indices[i]]
        if self.transform:
            img = self.transform(img)
        return img, label
train_ds = TransformSubset(full_ds, train_idx, train_tf)
val_ds = TransformSubset(full_ds, val_idx, val_tf)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)
num_classes = 37
class_names = full_ds.classes # breed names — use for plots
print(f"train: {len(train_ds)} val: {len(val_ds)} classes: {num_classes}")
Checkpoint: Why do we shuffle indices before splitting, rather than use a fixed torchvision split?
Answer sketch: `trainval` is one pool, so we create our own holdout. Shuffling first keeps the breed distribution similar in train and val (not perfect stratification, but good enough for a course project).
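If you want the per-breed proportions to match exactly rather than approximately, a stratified split is one option. The sketch below is a plain-Python illustration, not part of the project script: `stratified_split` is a hypothetical helper, and it assumes you already have a list of integer labels, one per image (how you collect those labels from the dataset is left open).

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.15, seed=42):
    """Return (train_idx, val_idx) with about val_fraction of each class in val."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Keep at least one validation sample per class.
        n_val = max(1, round(val_fraction * len(members)))
        val_idx.extend(members[:n_val])
        train_idx.extend(members[n_val:])
    return train_idx, val_idx


# Toy labels: 3 classes with 20 samples each (a stand-in for the 37 breeds).
toy_labels = [c for c in range(3) for _ in range(20)]
train_idx, val_idx = stratified_split(toy_labels)
print(len(train_idx), len(val_idx))
```

The returned index lists can be passed to `TransformSubset` in place of the shuffled `train_idx` and `val_idx`.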
Step 2 — Build the model (transfer learning)
Goal: Load ImageNet weights, replace the classifier head for 37 breeds.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# ResNet18's final layer is `fc`: Linear(512, 1000) for ImageNet
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
| Piece | Role |
|---|---|
| `ResNet18_Weights.IMAGENET1K_V1` | Pretrained conv layers: edges, textures, parts |
| New `fc` | Random init; must learn the breed mapping |
| `CrossEntropyLoss` | Standard multi-class loss; expects raw logits |
| Adam, `lr=1e-3` | Workable default here; a smaller rate such as `1e-4` is a common choice when fine-tuning the whole network |
Optional but recommended: freeze the backbone for the first few epochs (2–3) and train only the head, then unfreeze all layers. See the transfer learning lesson. For simplicity, full fine-tuning from the start also works here.
Step 3 — Training loop with validation accuracy
Goal: One epoch = one full pass over training data. After each epoch, measure val accuracy and save the best model.
def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total
epochs = 10
best_val_acc = 0.0
loss_history = []
val_acc_history = []
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    n_samples = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        n_samples += images.size(0)
    avg_loss = running_loss / n_samples
    val_acc = evaluate(val_loader)
    loss_history.append(avg_loss)
    val_acc_history.append(val_acc)
    print(f"epoch {epoch + 1}/{epochs} train_loss={avg_loss:.4f} val_acc={val_acc:.4f}")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "num_classes": num_classes,
            "class_names": class_names,
            "imagenet_mean": IMAGENET_MEAN,
            "imagenet_std": IMAGENET_STD,
        }, OUT_DIR / "best_model.pt")
        print(" → saved best checkpoint")
print(f"\nbest val accuracy: {best_val_acc:.4f}")
What to watch for:
| Signal | Meaning |
|---|---|
| Train loss ↓, val acc ↑ | Healthy learning |
| Train loss ↓, val acc flat | May need more epochs or lower LR |
| Train loss ↓, val acc ↓ | Overfitting — stop early, add augmentation, or freeze backbone |
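Expect roughly 35–55% val accuracy on 37 fine-grained breeds with ResNet18 in 10 epochs; exact numbers vary from run to run and across hardware. That sounds low next to MNIST, but telling apart 37 similar-looking breeds is much harder than 10 digits, and chance level here is only 1/37 ≈ 2.7%.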
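The "stop early" advice in the table can be automated with a small patience counter. The class below is a minimal sketch with a hypothetical name (`EarlyStopper`); the accuracy values are invented purely to exercise the logic and are not results from this project.

```python
class EarlyStopper:
    """Signal a stop when val accuracy has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_acc):
        """Record one epoch's val accuracy; return True if training should stop."""
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up accuracy sequence, only to show when the stopper fires.
stopper = EarlyStopper(patience=2)
fake_history = [0.30, 0.41, 0.45, 0.44, 0.43, 0.42]
stopped_at = None
for epoch, acc in enumerate(fake_history, start=1):
    if stopper.step(acc):
        stopped_at = epoch
        break
print("stopped at epoch", stopped_at, "best", stopper.best)
```

In the training loop you would call `stopper.step(val_acc)` right after `evaluate(val_loader)` and `break` when it returns `True`; the best checkpoint is already on disk by then.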
Step 4 — Plot learning curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(range(1, epochs + 1), loss_history, marker="o")
ax1.set_xlabel("epoch")
ax1.set_ylabel("train loss")
ax1.set_title("Training loss")
ax2.plot(range(1, epochs + 1), val_acc_history, marker="o", color="green")
ax2.set_xlabel("epoch")
ax2.set_ylabel("val accuracy")
ax2.set_title("Validation accuracy")
plt.tight_layout()
plt.savefig(OUT_DIR / "learning_curves.png", dpi=150)
plt.close()
Include learning_curves.png in your README. It shows you monitored training, not just ran one epoch.
Step 5 — Misclassification grid
Goal: Find where the model fails — similar breeds, bad lighting, partial crops.
Load best checkpoint, run on val set, collect wrong predictions:
# Reload best weights
ckpt = torch.load(OUT_DIR / "best_model.pt", map_location=device)
model.load_state_dict(ckpt["state_dict"])
model.eval()
misclassified = [] # list of (tensor_image, true_idx, pred_idx)
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        wrong = preds != labels
        for i in range(images.size(0)):
            if wrong[i] and len(misclassified) < 9:
                misclassified.append((
                    images[i].cpu(),
                    labels[i].item(),
                    preds[i].item(),
                ))
# Denormalize for display
def denorm(tensor):
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for ax, (img, true_i, pred_i) in zip(axes.flat, misclassified):
    ax.imshow(denorm(img).permute(1, 2, 0).numpy())
    ax.set_title(f"true: {class_names[true_i][:20]}\npred: {class_names[pred_i][:20]}", fontsize=8)
    ax.axis("off")
plt.suptitle("Misclassified validation images", fontsize=12)
plt.tight_layout()
plt.savefig(OUT_DIR / "misclassifications.png", dpi=150)
plt.close()
Reflection prompt: Are mistakes between visually similar breeds (e.g. two terriers)? Or bad images (blur, occlusion)? Write 3 sentences in your README.
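Nine images are a small sample, so it can help to count which breed pairs are confused most often across the whole validation set. The sketch below assumes you have collected two parallel lists of true and predicted class indices; `top_confusions` is a hypothetical helper, and the toy names and predictions are invented for illustration.

```python
from collections import Counter


def top_confusions(true_labels, pred_labels, class_names, k=3):
    """Return the k most frequent (true, predicted) name pairs among the errors."""
    pairs = Counter(
        (class_names[t], class_names[p])
        for t, p in zip(true_labels, pred_labels)
        if t != p
    )
    return pairs.most_common(k)


# Toy data: three class names and a handful of invented predictions.
names = ["Beagle", "Basset Hound", "Pug"]
y_true = [0, 0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 0, 0, 1, 2, 0]
confusions = top_confusions(y_true, y_pred, names)
for (true_name, pred_name), count in confusions:
    print(f"{true_name} -> {pred_name}: {count}")
```

The most frequent pairs are a good starting point for the three README sentences.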
Step 6 — Export to ONNX
Goal: Portable model file for Module 7's FastAPI serving project.
model.eval()
dummy = torch.randn(1, 3, 224, 224, device=device)
onnx_path = OUT_DIR / "classifier.onnx"
torch.onnx.export(
    model,
    dummy,
    onnx_path,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=17,
)
print("exported:", onnx_path)
Verify the export loads and runs:
import onnxruntime as ort
sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
test_input = dummy.cpu().numpy()
onnx_out = sess.run(None, {"input": test_input})[0]
pytorch_out = model(dummy).detach().cpu().numpy()
print("max diff pytorch vs onnx:", np.abs(pytorch_out - onnx_out).max())
# Should be ~1e-5 or smaller
| Check | Why |
|---|---|
| `model.eval()` before export | Puts BatchNorm in inference mode (ResNet18 has no dropout), so the exported graph is deterministic |
| Same 224×224 input shape | Must match training crop size |
| `dynamic_axes` batch dim | Server can batch 1 or N images later |
| ONNX sanity check | Catches export bugs before deployment |
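The ONNX file contains only the network, so the serving code has to repeat the validation preprocessing itself. The NumPy sketch below shows one way to do that, assuming the image has already been resized and center-cropped to 224×224 RGB; `preprocess` and `softmax` are hypothetical helpers, and the random arrays stand in for a real photo and real model output.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image_hwc_uint8):
    """Convert one 224x224 RGB uint8 image (H, W, C) to a normalized NCHW batch."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # same scaling as ToTensor
    x = (x - IMAGENET_MEAN) / IMAGENET_STD          # same stats as training
    x = np.transpose(x, (2, 0, 1))                  # HWC -> CHW
    return x[np.newaxis, ...]                       # add the batch dimension


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Random stand-in image; real code would resize and center-crop a photo first.
rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
batch = preprocess(fake_image)
print(batch.shape, batch.dtype)

# Stand-in logits; in the project these would come from sess.run(...).
fake_logits = rng.normal(size=(1, 37)).astype(np.float32)
probs = softmax(fake_logits)
print(probs.shape)
```

Feeding `batch` to `sess.run(None, {"input": batch})` and applying `softmax` to the result gives per-breed probabilities.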
Troubleshooting
| Problem | Fix |
|---|---|
| `CUDA out of memory` | Lower `batch_size` to 16 or 8 |
| Val accuracy stuck near 1/37 ≈ 2.7% | Labels wrong, or model not training; check that `loss.backward()` and `optimizer.step()` run |
| Val accuracy 100% on epoch 1 | Data leakage: train/val overlap; check your split indices |
| ONNX max diff large (>0.01) | Re-export after calling `model.eval()`; if it persists, try a different `opset_version` |
| `num_workers` crash on Windows | Set `num_workers=0` in `DataLoader` |
| Images look wrong in grid | Forgot `denorm`; normalized tensors look gray/wrong |
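For the data-leakage row, a quick sanity check on the split indices might look like the following. `check_split` is a hypothetical helper written for this example; it only inspects index lists, so it cannot detect duplicate image files stored under different indices.

```python
def check_split(train_idx, val_idx, n_total):
    """Raise ValueError if the index lists overlap or do not cover 0..n_total-1 once."""
    train_set, val_set = set(train_idx), set(val_idx)
    overlap = train_set & val_set
    if overlap:
        raise ValueError(f"{len(overlap)} samples appear in both train and val")
    covers_all = (train_set | val_set) == set(range(n_total))
    if len(train_idx) + len(val_idx) != n_total or not covers_all:
        raise ValueError("split does not cover every sample exactly once")
    return True


# A clean split passes.
good = check_split([0, 1, 2, 3], [4, 5], n_total=6)

# A leaky split (index 3 is in both lists) raises.
try:
    check_split([0, 1, 2, 3], [3, 4, 5], n_total=6)
    leak_message = ""
except ValueError as err:
    leak_message = str(err)
print(good, "|", leak_message)
```

In the project script you could call `check_split(train_idx, val_idx, len(full_ds))` right after building the two index lists.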
Deliverables checklist
- `outputs/best_model.pt`: best val checkpoint.
- `outputs/learning_curves.png`: loss + val accuracy.
- `outputs/misclassifications.png`: 9 wrong examples, labeled.
- `outputs/classifier.onnx`: verified with ONNX Runtime.
- `README.md` with:
  - Best val accuracy (4 decimal places).
  - Hardware used (CPU/GPU).
  - 3 sentences: what breeds or image types confuse the model.
  - One thing you'd try next (freeze backbone, more epochs, stronger aug).
Optional extensions
- Binary cat vs dog: map 37 breeds → 2 classes; expect much higher accuracy.
- Freeze backbone for 3 epochs, then unfreeze with `lr=1e-4`.
- Confusion matrix on val set (`sklearn.metrics.confusion_matrix`): which breed pairs swap?
- torchvision EfficientNet: swap backbone, compare val acc per training minute.
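For the binary cat-vs-dog extension, one possible approach is a lookup from breed index to species. The sketch below lists the 12 cat breeds in the dataset and treats everything else as a dog. It assumes `class_names` uses these title-cased spellings, so compare them against `full_ds.classes` before relying on it; `breed_to_species` is a hypothetical helper.

```python
# The 12 cat breeds in Oxford-IIIT Pet; the remaining 25 breeds are dogs.
# Spellings are an assumption: verify them against full_ds.classes.
CAT_BREEDS = {
    "Abyssinian", "Bengal", "Birman", "Bombay", "British Shorthair",
    "Egyptian Mau", "Maine Coon", "Persian", "Ragdoll", "Russian Blue",
    "Siamese", "Sphynx",
}


def breed_to_species(class_names):
    """Return a list mapping each breed index to 0 (cat) or 1 (dog)."""
    return [0 if name in CAT_BREEDS else 1 for name in class_names]


# Small stand-in for the real 37-name list.
sample_names = ["Abyssinian", "American Bulldog", "Bengal", "Beagle"]
species = breed_to_species(sample_names)
print(species)
```

With this lookup, a breed label `y` becomes the binary label `species[y]`, and the model head shrinks to `nn.Linear(in_features, 2)`.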
What's next
Welcome to Module 4 — object detection. Your ONNX file returns in Module 7 when you deploy a vision API.