Project: U-Net pet segmentation
Before we begin
You will build a U-Net from scratch in PyTorch, train it on Oxford-IIIT Pet segmentation masks, measure IoU, and visualize predictions overlaid on photos. This is the course’s first dense vision project — harder than MNIST or spam, closer to real CV work.
After your from-scratch build, an optional stretch compares a pretrained DeepLab (Lesson 4) so you see when industrial backbones beat a small custom U-Net.
Figure: What you are building
How this connects to Module 5
| Lesson | Where you use it |
|---|---|
| Segmentation types | 3-class pet trimap: background / pet / border |
| Encoder–decoder | Downsample path + upsample path in your UNet class |
| U-Net skips | torch.cat encoder features into decoder |
| Beyond U-Net | Stretch: fine-tune DeepLab with segmentation_models.pytorch |
| Instance seg | Awareness — pet trimap is semantic; counting pets needs Mask R-CNN |
| IoU / CE loss | CrossEntropyLoss + compute_iou() on val set |
Folder layout:
unet-segmentation/
    data/                    # torchvision download
    train_unet.py            # model + loop + metrics + plots
    predict.py               # single image inference + overlay
    compare_deeplab.py       # optional stretch: pretrained baseline
    outputs/
        loss_curve.png
        iou_curve.png
        samples/             # image | GT | pred overlays
    checkpoints/
        unet_best.pt

What you will build
- Load Oxford-IIIT Pet with segmentation masks (torchvision).
- Implement U-Net (double conv, pool, upsample, skips).
- Train with per-pixel cross-entropy; track val mIoU.
- Save best checkpoint; plot loss and IoU curves.
- Visualize predictions — color mask overlaid on input.
- Optional: Flask or Next.js route — upload image → return mask PNG.
- Stretch: compare val mIoU against a pretrained DeepLabv3+ on the same split.
Estimated time: 6–8 hours (core); +2–3 hours for DeepLab stretch.
Before you start
- Finish the Module 5 quiz.
- Read Beyond U-Net before the stretch step.
- Module 4 — CNNs completed.
- Install the core dependencies:
    pip install torch torchvision matplotlib numpy
- Stretch:
    pip install segmentation-models-pytorch
Step 1 — Dataset (Oxford-IIIT Pet)
Goal: Paired images and pixel masks — torchvision handles download.
# train_unet.py
from pathlib import Path

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

IMG_SIZE = 256
data_dir = Path(__file__).resolve().parent / "data"

# Images: resize, convert to a float tensor, normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Masks: NEAREST keeps class IDs intact. PILToTensor converts the PIL mask to a
# (1, H, W) uint8 tensor without rescaling (ToTensor would divide by 255).
mask_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])


class PetSegDataset(torch.utils.data.Dataset):
    def __init__(self, split="trainval"):
        self.ds = datasets.OxfordIIITPet(
            root=data_dir,
            split=split,
            target_types="segmentation",
            download=True,
            transform=transform,
            target_transform=mask_transform,
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        img, mask = self.ds[i]
        # mask: 1=fg, 2=bg, 3=border → map to 0,1,2 for CrossEntropyLoss
        mask = mask.squeeze(0).long()
        mask = torch.where(mask == 1, 0, torch.where(mask == 2, 1, 2))
        return img, mask


full = PetSegDataset("trainval")
n_val = int(0.15 * len(full))
n_train = len(full) - n_val
# Fixed seed so the train/val split is identical across runs.
train_ds, val_ds = random_split(full, [n_train, n_val], generator=torch.Generator().manual_seed(42))

# With num_workers > 0 on Windows or macOS, run the script body under
# `if __name__ == "__main__":` so worker processes can start cleanly.
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=2)

| Piece | Why |
|---|---|
| NEAREST resize on mask | Preserves class IDs; bilinear would blend labels |
| Normalize on image only | ImageNet stats, standard for conv encoders |
| 3 classes | Pet / background / trimap border |
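To see why the mask needs NEAREST resizing, the following minimal sketch upsamples a toy label mask both ways. It uses `torch.nn.functional.interpolate` on a tensor as a stand-in for the torchvision `Resize` transform, which is an assumption made only to keep the example small; the toy mask is invented for illustration.

```python
import torch
import torch.nn.functional as F

# Toy 2x2 label mask containing only class IDs 0 and 2 (no pixel is class 1).
mask = torch.tensor([[0, 0],
                     [2, 2]])
x = mask[None, None].float()  # interpolate expects (N, C, H, W) floats

nearest = F.interpolate(x, size=(4, 4), mode="nearest")
bilinear = F.interpolate(x, size=(4, 4), mode="bilinear", align_corners=False)

nearest_ids = set(nearest.unique().tolist())
bilinear_ids = set(bilinear.unique().tolist())

print(nearest_ids)   # only the original class IDs: {0.0, 2.0}
print(bilinear_ids)  # also contains in-between values that are not class IDs
```

Rounding the blended values would invent pixels of a class that never appeared in the original mask, which is the label corruption the table warns about.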
Step 2 — U-Net building blocks
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    def __init__(self, in_ch=3, num_classes=3, base=32):
        super().__init__()
        self.enc1 = DoubleConv(in_ch, base)
        self.enc2 = DoubleConv(base, base * 2)
        self.enc3 = DoubleConv(base * 2, base * 4)
        self.enc4 = DoubleConv(base * 4, base * 8)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base * 8, base * 16)
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2)
        self.dec4 = DoubleConv(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, stride=2)
        self.dec3 = DoubleConv(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = DoubleConv(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = DoubleConv(base * 2, base)
        self.head = nn.Conv2d(base, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        return self.head(d1)

Skip connections: torch.cat([upsampled, encoder_feature], dim=1) doubles the channel dimension before each decoder DoubleConv.
Start with base=32 on a laptop; try base=64 if GPU memory allows.
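The channel arithmetic of one skip connection can be checked in isolation. This is a minimal sketch, assuming base=32 and a 256×256 input; the random tensors `b` and `e4` are stand-ins for the real bottleneck and enc4 outputs, not values produced by the model.

```python
import torch
import torch.nn as nn

base = 32
# Stand-in feature maps for one 256x256 input: three poolings before enc4
# give 32x32, and a fourth pooling gives 16x16 at the bottleneck.
b = torch.randn(1, base * 16, 16, 16)
e4 = torch.randn(1, base * 8, 32, 32)

up4 = nn.ConvTranspose2d(base * 16, base * 8, kernel_size=2, stride=2)
d4 = up4(b)                           # halves channels, doubles H and W
merged = torch.cat([d4, e4], dim=1)   # channels: base*8 + base*8 = base*16

print(d4.shape)      # torch.Size([1, 256, 32, 32])
print(merged.shape)  # torch.Size([1, 512, 32, 32])
```

This is why dec4 is declared as DoubleConv(base * 16, base * 8). It also suggests keeping the input height and width divisible by 16: with four poolings, other sizes leave the upsampled map and the encoder map with different spatial shapes, and torch.cat fails.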
Step 3 — IoU metric
import numpy as np


def compute_iou(preds: torch.Tensor, targets: torch.Tensor, num_classes: int):
    """preds, targets: (N,H,W) int64"""
    ious = []
    for c in range(num_classes):
        p = preds == c
        t = targets == c
        inter = (p & t).sum().item()
        union = (p | t).sum().item()
        ious.append(inter / union if union > 0 else float("nan"))
    return float(np.nanmean(ious))

Log mIoU each epoch; expect gradual improvement over 15–30 epochs.
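A tiny worked example helps confirm the metric behaves as intended before trusting it on real masks. The sketch below repeats the function from this step so it runs on its own; the 2×2 tensors are invented for illustration.

```python
import numpy as np
import torch


def compute_iou(preds, targets, num_classes):
    """Same logic as the function above: mean of per-class IoU, skipping absent classes."""
    ious = []
    for c in range(num_classes):
        p = preds == c
        t = targets == c
        inter = (p & t).sum().item()
        union = (p | t).sum().item()
        ious.append(inter / union if union > 0 else float("nan"))
    return float(np.nanmean(ious))


# One 2x2 "image": exactly one pixel is mislabeled (class 1 predicted as class 0).
preds = torch.tensor([[[0, 0], [1, 2]]])
targets = torch.tensor([[[0, 1], [1, 2]]])
# class 0: inter=1, union=2 -> 0.5; class 1: inter=1, union=2 -> 0.5; class 2: 1/1 -> 1.0
miou = compute_iou(preds, targets, num_classes=3)
print(round(miou, 4))  # 0.6667

# A class missing from both tensors yields NaN and is ignored by nanmean.
perfect = compute_iou(torch.tensor([[[0, 1]]]), torch.tensor([[[0, 1]]]), num_classes=3)
print(perfect)  # 1.0
```

Note that a single wrong pixel lowers the IoU of two classes at once: the class that was predicted and the class that was missed.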
Step 4 — Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_ch=3, num_classes=3, base=32).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# torch.save does not create missing directories, so make sure this one exists.
Path("checkpoints").mkdir(exist_ok=True)

best_iou = 0.0
train_losses, val_ious = [], []  # per-epoch history for the curves

for epoch in range(25):
    model.train()
    train_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(imgs)            # (N, 3, H, W)
        loss = criterion(logits, masks) # masks: (N, H, W) int64 class IDs
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * imgs.size(0)

    # Validation: collect predictions for the whole split, then score once.
    model.eval()
    all_preds, all_masks = [], []
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device)
            preds = model(imgs).argmax(dim=1).cpu()
            all_preds.append(preds)
            all_masks.append(masks)
    preds_cat = torch.cat(all_preds)
    masks_cat = torch.cat(all_masks)
    val_iou = compute_iou(preds_cat, masks_cat, num_classes=3)

    epoch_loss = train_loss / len(train_ds)
    train_losses.append(epoch_loss)
    val_ious.append(val_iou)
    print(f"epoch {epoch+1} loss={epoch_loss:.4f} val_mIoU={val_iou:.4f}")

    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(model.state_dict(), "checkpoints/unet_best.pt")

Make it meaningful:
- Plot train_loss and val_mIoU with matplotlib.
- Save 5 sample overlays when val_iou improves.
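One way to produce the two curve images is a small helper that takes a list of per-epoch values. This is a sketch, not the course's reference solution: `plot_curve` is a hypothetical helper name, and the numbers passed to it are placeholders for illustration, not real training results.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend, so this also works without a display
import matplotlib.pyplot as plt


def plot_curve(values, ylabel, out_path):
    """Plot one value per epoch and save the figure as a PNG."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    epochs = range(1, len(values) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, values, marker="o")
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)  # free the figure so repeated calls do not accumulate memory
    return out_path


# Placeholder histories; in train_unet.py, append one value per epoch instead.
loss_history = [0.9, 0.7, 0.6]
iou_history = [0.30, 0.42, 0.48]
loss_png = plot_curve(loss_history, "train loss", "outputs/loss_curve.png")
iou_png = plot_curve(iou_history, "val mIoU", "outputs/iou_curve.png")
```

Calling the helper once after the final epoch is enough; calling it every epoch instead lets you watch the curves while training runs.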
Step 5 — Visualize predictions
# predict.py
import matplotlib.pyplot as plt
import numpy as np

# One RGB color per class ID (0=pet, 1=background, 2=border).
PALETTE = np.array([
    [128, 0, 128],  # pet: purple
    [0, 0, 0],      # background: black
    [0, 255, 0],    # border: green
], dtype=np.uint8)


def overlay_mask(image_tensor, pred_mask):
    """image_tensor: normalized (3, H, W) CPU tensor; pred_mask: (H, W) int64 CPU tensor."""
    img = image_tensor.permute(1, 2, 0).numpy()
    # Undo the ImageNet normalization so the photo shows natural colors.
    img = (img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
    color = PALETTE[pred_mask.numpy()]  # (H, W, 3) uint8 lookup by class ID
    blended = (0.6 * img + 0.4 * (color / 255)).clip(0, 1)
    return blended

Save side-by-side: input | ground truth | prediction in outputs/samples/.
Target: val mIoU ≥ 0.55 with base=32 and 256×256 after ~20 epochs (hardware varies).
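The side-by-side figure can be sketched with three matplotlib panels. `save_triptych` is a hypothetical helper name, and the random arrays below only stand in for a real photo and its overlays; in practice you would pass the de-normalized image and the outputs of overlay_mask for the ground-truth and predicted masks.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend, so this also works without a display
import matplotlib.pyplot as plt
import numpy as np


def save_triptych(image, gt_overlay, pred_overlay, out_path):
    """Each image argument is an (H, W, 3) float array with values in [0, 1]."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    panels = (image, gt_overlay, pred_overlay)
    titles = ("input", "ground truth", "prediction")
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, panel, title in zip(axes, panels, titles):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    return out_path


# Random pixels as a placeholder for real data.
rng = np.random.default_rng(0)
dummy = rng.random((64, 64, 3))
saved = save_triptych(dummy, dummy, dummy, "outputs/samples/sample_000.png")
```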
Step 6 — Optional web demo
Same pattern as MNIST / spam:
- Flask POST /segment with image bytes → returns mask PNG.
- Or a Next.js API route calling a Python subprocess.
Show mask overlay in the browser — strong portfolio piece.
Step 7 — Stretch: compare pretrained DeepLab
Goal: Same val split, same mIoU function — see if a ResNet-backed DeepLabv3+ beats your scratch U-Net after a short fine-tune.
# compare_deeplab.py (requires segmentation-models-pytorch)
import segmentation_models_pytorch as smp
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

deeplab = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
).to(device)
optimizer = torch.optim.Adam(deeplab.parameters(), lr=1e-4)
# Reuse train_loader / val_loader / compute_iou from train_unet.py
# Train 10–15 epochs: the ImageNet encoder already knows edges and textures

| Observation | What it teaches |
|---|---|
| DeepLab mIoU higher after few epochs | Pretrained context + ASPP helps on natural photos |
| U-Net competitive on borders | Skips still strong when data is small |
| DeepLab slower / more VRAM | Trade-off for deployment |
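To keep the comparison fair, both models can go through the same training and prediction functions. The sketch below is one possible factoring, not the course's required structure: `train_one_epoch` and `predict_all` are hypothetical helper names, and the 1×1 convolution with random tensors is only a smoke test standing in for the real models and loaders.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over the loader and return the mean per-sample loss."""
    model.train()
    total, seen = 0.0, 0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), masks)
        loss.backward()
        optimizer.step()
        total += loss.item() * imgs.size(0)
        seen += imgs.size(0)
    return total / seen


@torch.no_grad()
def predict_all(model, loader, device):
    """Return (preds, targets), each shaped (N, H, W), ready for compute_iou."""
    model.eval()
    preds, targets = [], []
    for imgs, masks in loader:
        preds.append(model(imgs.to(device)).argmax(dim=1).cpu())
        targets.append(masks)
    return torch.cat(preds), torch.cat(targets)


# Smoke test: tiny stand-in model and random batches (not the pet dataset).
cpu = torch.device("cpu")
toy_model = nn.Conv2d(3, 3, kernel_size=1).to(cpu)
toy_loader = [(torch.randn(2, 3, 8, 8), torch.randint(0, 3, (2, 8, 8))) for _ in range(2)]
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4)
toy_loss = train_one_epoch(toy_model, toy_loader, nn.CrossEntropyLoss(), toy_optimizer, cpu)
toy_preds, toy_targets = predict_all(toy_model, toy_loader, cpu)
print(toy_preds.shape, toy_targets.shape)  # torch.Size([4, 8, 8]) torch.Size([4, 8, 8])
```

Passing the U-Net or the DeepLab model into the same two functions, with the same val_loader, means any mIoU difference comes from the model and its optimizer settings rather than from the evaluation code.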
Write one paragraph in your README: when you would pick U-Net vs DeepLab for a new product.
Troubleshooting
| Symptom | Fix |
|---|---|
| mIoU stuck near 0 | Mask label mapping wrong; check torch.where remap |
| Loss NaN | Lower LR; check mask values are 0..C-1 |
| OOM | batch_size=4, IMG_SIZE=128, or base=16 |
| Blurry borders only | Train longer; confirm skips are concatenated (not accidentally dropped) |
| Checkerboard artifacts | Switch transposed conv to upsample + conv |
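For the checkerboard row in the table above, an "upsample + conv" block could look like this. `UpConv` is a hypothetical name, and bilinear upsampling with a 3×3 convolution is one common choice rather than the only one.

```python
import torch
import torch.nn as nn


class UpConv(nn.Module):
    """Possible replacement for nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            # Fixed interpolation doubles H and W.
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            # Learned 3x3 conv then reduces the channel count.
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


x = torch.randn(1, 512, 16, 16)
y = UpConv(512, 256)(x)
print(y.shape)  # torch.Size([1, 256, 32, 32])
```

Because the output shape matches the transposed convolution it replaces, the skip concatenation and the DoubleConv channel counts in the UNet class stay unchanged.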
Deliverables
- UNet class with skip connections
- loss_curve.png + iou_curve.png
- Best val mIoU recorded in README
- At least 5 overlay images in outputs/samples/
- Short write-up: one failure mode you observed (e.g. border class confusion)
- (Stretch) U-Net vs DeepLab mIoU table on the same val split
What's next
Module 5 complete. Continue to Module 6 — Transformers (core of GenAI) when ready.
Return to the AI course curriculum anytime.