
Module 5 — Image segmentation

Project: U-Net pet segmentation

Build a U-Net in PyTorch on Oxford-IIIT Pet masks, track mIoU, save overlay visualizations, optionally compare against DeepLab, and add an upload demo.

~420 min read + exercises


Before we begin

You will build a U-Net from scratch in PyTorch, train it on Oxford-IIIT Pet segmentation masks, measure IoU, and visualize predictions overlaid on photos. This is the course’s first dense vision project — harder than MNIST or spam, closer to real CV work.

After your from-scratch build, an optional stretch compares it against a pretrained DeepLab (Lesson 4), so you can see when a pretrained backbone beats a small custom U-Net.

What you are building

Figure: Project pipeline, from pixel labels to interactive mask preview (Pet images → Mask labels → U-Net train → IoU eval → Mask overlay UI).
Pet photos + pixel masks → train U-Net → IoU + mask overlay demo.

How this connects to Module 5

Lesson | Where you use it
--- | ---
Segmentation types | 3-class pet trimap: background / pet / border
Encoder–decoder | Downsample path + upsample path in your UNet class
U-Net skips | torch.cat encoder features into decoder
Beyond U-Net | Stretch: fine-tune DeepLab with segmentation_models.pytorch
Instance seg | Awareness only: the pet trimap is semantic; counting pets needs Mask R-CNN
IoU / CE loss | CrossEntropyLoss + compute_iou() on the val set

Folder layout:

text
unet-segmentation/
  data/                  # torchvision download
  train_unet.py          # model + loop + metrics + plots
  predict.py             # single image inference + overlay
  compare_deeplab.py     # optional stretch — pretrained baseline
  outputs/
    loss_curve.png
    iou_curve.png
    samples/             # image | GT | pred overlays
  checkpoints/
    unet_best.pt
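One way to create this layout up front is the small pathlib sketch below. `make_layout` is a hypothetical helper, not one of the project files listed above. Creating `checkpoints/` early also matters because `torch.save` does not create missing folders.

```python
from pathlib import Path


def make_layout(root: Path) -> list[Path]:
    """Create the folders from the layout above (hypothetical helper)."""
    folders = [
        root / "data",                 # torchvision downloads the dataset here
        root / "outputs" / "samples",  # overlay images; parents=True also makes outputs/
        root / "checkpoints",          # unet_best.pt goes here
    ]
    for folder in folders:
        folder.mkdir(parents=True, exist_ok=True)  # safe to re-run
    return folders
```

Call `make_layout(Path("unet-segmentation"))` once. Because of `exist_ok=True`, re-running it is harmless.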

What you will build

  1. Load Oxford-IIIT Pet with segmentation masks (torchvision).
  2. Implement U-Net (double conv, pool, upsample, skips).
  3. Train with per-pixel cross-entropy; track val mIoU.
  4. Save best checkpoint; plot loss and IoU curves.
  5. Visualize predictions — color mask overlaid on input.
  6. Optional: Flask or Next.js route — upload image → return mask PNG.
  7. Stretch: compare val mIoU against a pretrained DeepLabv3+ on the same split.

Estimated time: 6–8 hours (core); +2–3 hours for DeepLab stretch.




Step 1 — Dataset (Oxford-IIIT Pet)

Goal: Paired images and pixel masks — torchvision handles download.

python
# train_unet.py
from pathlib import Path

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

IMG_SIZE = 256  # keep divisible by 16: the U-Net pools 2x four times
data_dir = Path(__file__).resolve().parent / "data"

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mask_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
    # PIL -> uint8 tensor (1, H, W) with raw label values; ToTensor() would rescale them to [0, 1]
    transforms.PILToTensor(),
])

class PetSegDataset(torch.utils.data.Dataset):
    def __init__(self, split="trainval"):
        self.ds = datasets.OxfordIIITPet(
            root=data_dir,
            split=split,
            target_types="segmentation",
            download=True,
            transform=transform,
            target_transform=mask_transform,
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        img, mask = self.ds[i]
        # Trimap values: 1=pet, 2=background, 3=border -> remap to 0, 1, 2 for CrossEntropyLoss
        mask = mask.squeeze(0).long()
        mask = torch.where(mask == 1, 0, torch.where(mask == 2, 1, 2))
        return img, mask

full = PetSegDataset("trainval")
n_val = int(0.15 * len(full))
n_train = len(full) - n_val
# Fixed seed, so the DeepLab stretch can reuse exactly the same val split
train_ds, val_ds = random_split(full, [n_train, n_val], generator=torch.Generator().manual_seed(42))

# On Windows/macOS, num_workers > 0 needs the script body under `if __name__ == "__main__":`
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=2)

Piece | Why
--- | ---
NEAREST resize on mask | Preserves class IDs; bilinear would blend labels
Normalize on image only | ImageNet mean/std, standard for conv encoders
3 classes | Pet / background / trimap border
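The toy check below shows why the mask needs NEAREST. It uses `torch.nn.functional.interpolate` as a stand-in for the PIL resize that `transforms.Resize` actually performs. The input is a hypothetical 1×4 strip of raw trimap labels (1 = pet, 3 = border).

```python
import torch
import torch.nn.functional as F

# Hypothetical raw trimap strip: two pet pixels (1) next to two border pixels (3).
# Shaped (N, C, H, W) because F.interpolate expects a batch and channel dim.
labels = torch.tensor([[[[1.0, 1.0, 3.0, 3.0]]]])

nearest = F.interpolate(labels, size=(1, 8), mode="nearest")
bilinear = F.interpolate(labels, size=(1, 8), mode="bilinear", align_corners=False)

print(nearest.flatten().tolist())           # only 1s and 3s survive
print(bilinear.round().flatten().tolist())  # the pet/border edge picks up 2s
```

After rounding, the bilinear result turns the pet/border edge into label 2, which the dataset defines as background. That label is in range but wrong, so nothing downstream would flag it as an error.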

Step 2 — U-Net building blocks

python
import torch.nn as nn
import torch.nn.functional as F
 
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        )
 
    def forward(self, x):
        return self.net(x)
 
class UNet(nn.Module):
    def __init__(self, in_ch=3, num_classes=3, base=32):
        super().__init__()
        self.enc1 = DoubleConv(in_ch, base)
        self.enc2 = DoubleConv(base, base * 2)
        self.enc3 = DoubleConv(base * 2, base * 4)
        self.enc4 = DoubleConv(base * 4, base * 8)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base * 8, base * 16)
 
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, stride=2)
        self.dec4 = DoubleConv(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, stride=2)
        self.dec3 = DoubleConv(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = DoubleConv(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = DoubleConv(base * 2, base)
        self.head = nn.Conv2d(base, num_classes, 1)
 
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
 
        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        return self.head(d1)

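Skip connections: torch.cat([upsampled, encoder_feature], dim=1) stacks the two tensors along the channel dimension. Each decoder DoubleConv therefore receives twice the channels of its up-conv output (dec4 takes base * 8 + base * 8 = base * 16). The spatial sizes must match exactly, which is why IMG_SIZE should be divisible by 16 (four 2x poolings).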
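A quick shape check of one decoder stage follows. It uses dummy tensors with the sizes a 256×256 input produces at base=32; the names mirror Step 2, but the snippet stands alone. The hypothetical `pooled_sizes` helper shows what goes wrong when the input size is not divisible by 16.

```python
import torch
import torch.nn as nn

base = 32
# Dummy tensors with the shapes a 256x256 input produces (batch size 1)
b = torch.randn(1, base * 16, 16, 16)   # bottleneck output: 256 / 2**4 = 16
e4 = torch.randn(1, base * 8, 32, 32)   # enc4 output, kept for the skip

up4 = nn.ConvTranspose2d(base * 16, base * 8, kernel_size=2, stride=2)
d4 = up4(b)                             # (1, 256, 32, 32): spatial x2, channels halved
merged = torch.cat([d4, e4], dim=1)     # (1, 512, 32, 32): channels add up for dec4
print(tuple(d4.shape), tuple(merged.shape))


def pooled_sizes(size, depth=4):
    """Spatial size after each MaxPool2d(2); odd sizes are floored."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)
    return sizes


print(pooled_sizes(256))  # [256, 128, 64, 32, 16]: every upsample matches its skip
print(pooled_sizes(250))  # [250, 125, 62, 31, 15]: 15 * 2 = 30 != 31, so torch.cat fails
```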

Start with base=32 on a laptop; try base=64 if GPU memory allows.
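To put numbers on the base=32 vs base=64 choice, here is a hand-derived parameter count for the UNet class in Step 2. It is pure Python and needs no model; the helper names are made up for this sketch. The result should match `sum(p.numel() for p in model.parameters())` for that class.

```python
def conv_params(c_in, c_out, k):
    """Weights + bias of a k x k Conv2d or ConvTranspose2d."""
    return c_in * c_out * k * k + c_out


def double_conv_params(c_in, c_out):
    return conv_params(c_in, c_out, 3) + conv_params(c_out, c_out, 3)


def unet_params(in_ch=3, num_classes=3, base=32):
    """Parameter count of the Step 2 UNet, derived layer by layer."""
    widths = [base, base * 2, base * 4, base * 8]
    total, c_in = 0, in_ch
    for w in widths:                                   # enc1 .. enc4
        total += double_conv_params(c_in, w)
        c_in = w
    total += double_conv_params(base * 8, base * 16)   # bottleneck
    for w in reversed(widths):                         # up4/dec4 .. up1/dec1
        total += conv_params(2 * w, w, 2)              # ConvTranspose2d(2w, w, 2, stride=2)
        total += double_conv_params(2 * w, w)          # DoubleConv on the concatenated tensor
    total += conv_params(base, num_classes, 1)         # 1x1 head
    return total


print(f"{unet_params(base=32):,}")  # 7,760,163
print(f"{unet_params(base=64):,}")  # 31,031,875
```

The ratio lands just under 4 because almost every conv weight scales with base on both its input and output side. Activation memory, which often dominates during training, grows roughly 2x instead, since channel counts double while spatial sizes stay the same.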


Step 3 — IoU metric

python
import numpy as np
 
def compute_iou(preds: torch.Tensor, targets: torch.Tensor, num_classes: int):
    """preds, targets: (N,H,W) int64"""
    ious = []
    for c in range(num_classes):
        p = preds == c
        t = targets == c
        inter = (p & t).sum().item()
        union = (p | t).sum().item()
        ious.append(inter / union if union > 0 else float("nan"))
    return float(np.nanmean(ious))

Log mIoU each epoch — expect gradual improvement over 15–30 epochs.
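compute_iou needs every val prediction in memory at once, which is why Step 4 concatenates them. The sketch below is an equivalent alternative that assumes NumPy arrays (or CPU tensors) of class ids. It accumulates a confusion matrix batch by batch and also reports per-class IoU, which helps when writing up border-class confusion. The function names are illustrative, not from a library.

```python
import numpy as np


def confusion_matrix(preds, targets, num_classes):
    """Counts of (true class, predicted class) pairs; rows = ground truth."""
    t = np.asarray(targets).reshape(-1).astype(np.int64)
    p = np.asarray(preds).reshape(-1).astype(np.int64)
    return np.bincount(num_classes * t + p, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def iou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN); classes absent from preds and targets become NaN."""
    tp = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm)
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class = np.where(union > 0, tp / union, np.nan)
    return per_class, float(np.nanmean(per_class))


# Accumulate over batches instead of concatenating all predictions
cm = np.zeros((3, 3), dtype=np.int64)
toy_batches = [  # hypothetical (preds, targets); in practice: model argmax + val masks
    (np.array([[0, 0, 1]]), np.array([[0, 1, 1]])),
    (np.array([[1, 2, 2]]), np.array([[1, 2, 0]])),
]
for preds, targets in toy_batches:
    cm += confusion_matrix(preds, targets, num_classes=3)
per_class, miou = iou_from_confusion(cm)
print(per_class, miou)  # per-class about 0.333, 0.667, 0.5; mIoU 0.5
```

compute_iou pools intersections and unions over the whole val set. Summing the per-batch confusion matrices therefore gives the same per-class values.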


Step 4 — Training loop

python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_ch=3, num_classes=3, base=32).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

ckpt_dir = Path(__file__).resolve().parent / "checkpoints"
ckpt_dir.mkdir(exist_ok=True)  # torch.save does not create missing folders
history = {"loss": [], "miou": []}  # per-epoch values for loss_curve.png / iou_curve.png

best_iou = 0.0
for epoch in range(25):
    model.train()
    train_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(imgs)               # (N, 3, H, W)
        loss = criterion(logits, masks)    # masks: (N, H, W) class ids
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * imgs.size(0)

    model.eval()
    all_preds, all_masks = [], []
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device)
            preds = model(imgs).argmax(dim=1).cpu()
            all_preds.append(preds)
            all_masks.append(masks)
    preds_cat = torch.cat(all_preds)
    masks_cat = torch.cat(all_masks)
    val_iou = compute_iou(preds_cat, masks_cat, num_classes=3)

    epoch_loss = train_loss / len(train_ds)
    history["loss"].append(epoch_loss)
    history["miou"].append(val_iou)
    print(f"epoch {epoch+1}  loss={epoch_loss:.4f}  val_mIoU={val_iou:.4f}")

    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(model.state_dict(), ckpt_dir / "unet_best.pt")

Make it meaningful:

  • Plot train_loss and val_mIoU with matplotlib.
  • Save 5 sample overlays when val_iou improves.
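The plotting sketch below is minimal. It assumes you append each epoch's train loss and val mIoU to two Python lists inside the Step 4 loop. `save_curves` is a hypothetical helper name, and the Agg backend lets it run without a display.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # render straight to files; no window needed
import matplotlib.pyplot as plt


def save_curves(losses, mious, out_dir="outputs"):
    """Write loss_curve.png and iou_curve.png, one point per epoch."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for values, ylabel, name in [
        (losses, "train loss", "loss_curve.png"),
        (mious, "val mIoU", "iou_curve.png"),
    ]:
        fig, ax = plt.subplots(figsize=(5, 3.5))
        ax.plot(range(1, len(values) + 1), values, marker="o")
        ax.set_xlabel("epoch")
        ax.set_ylabel(ylabel)
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(out / name, dpi=120)
        plt.close(fig)  # avoid piling up open figures
    return out / "loss_curve.png", out / "iou_curve.png"
```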

Step 5 — Visualize predictions

python
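# predict.py
import matplotlib.pyplot as plt
import numpy as np

# Class ids after the Step 1 remap: 0 = pet, 1 = background, 2 = border
PALETTE = np.array([
    [128, 0, 128],   # pet: purple
    [0, 0, 0],       # background: black
    [0, 255, 0],     # border: green
], dtype=np.uint8)

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def overlay_mask(image_tensor, pred_mask):
    """image_tensor: normalized (3, H, W); pred_mask: (H, W) class ids. Returns (H, W, 3) floats in [0, 1]."""
    img = image_tensor.detach().cpu().permute(1, 2, 0).numpy()
    img = (img * STD + MEAN).clip(0, 1)           # undo ImageNet normalization
    color = PALETTE[pred_mask.cpu().numpy()]      # class id -> RGB
    blended = (0.6 * img + 0.4 * (color / 255)).clip(0, 1)
    return blended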

Save side-by-side: input | ground truth | prediction in outputs/samples/.
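Here is a NumPy + Pillow sketch of that three-panel image. It assumes the input is already de-normalized to floats in [0, 1] (as inside overlay_mask) and that both masks hold class ids 0..2. `side_by_side` and `save_panel` are hypothetical helpers. Ground truth and prediction are blended the same way so the two panels are directly comparable.

```python
from pathlib import Path

import numpy as np
from PIL import Image

# Same colors as Step 5: 0 = pet (purple), 1 = background (black), 2 = border (green)
PALETTE = np.array([[128, 0, 128], [0, 0, 0], [0, 255, 0]], dtype=np.uint8)


def side_by_side(img, gt_mask, pred_mask, alpha=0.4):
    """img: (H, W, 3) floats in [0, 1]; masks: (H, W) class ids.
    Returns an (H, 3W, 3) uint8 strip: input | ground truth | prediction."""
    def blend(mask):
        color = PALETTE[np.asarray(mask)].astype(np.float64) / 255
        return (1 - alpha) * img + alpha * color

    strip = np.concatenate([img, blend(gt_mask), blend(pred_mask)], axis=1)
    return (strip.clip(0, 1) * 255).round().astype(np.uint8)


def save_panel(panel, out_dir, name):
    """Write one panel as PNG, creating the folder if needed."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    path = out / name
    Image.fromarray(panel).save(path)
    return path
```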

Target: val mIoU ≥ 0.55 with base=32 and 256×256 inputs after ~20 epochs (training time varies with hardware).


Step 6 — Optional web demo

Same pattern as MNIST / spam:

  • Flask POST /segment with image bytes → returns mask PNG.
  • Or a Next.js API route that calls the Python script as a subprocess.

Show mask overlay in the browser — strong portfolio piece.


Step 7 — Stretch: compare pretrained DeepLab

Goal: Same val split, same mIoU function — see if a ResNet-backed DeepLabv3+ beats your scratch U-Net after a short fine-tune.

python
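# compare_deeplab.py (requires: pip install segmentation-models-pytorch)
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

# Reuse train_loader / val_loader / compute_iou from train_unet.py. Importing that
# module runs its top-level training loop, so first move the loop under
# `if __name__ == "__main__":` (or copy the dataset code here).
from train_unet import train_loader, val_loader, compute_iou

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

deeplab = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights="imagenet",  # ImageNet-pretrained encoder
    in_channels=3,
    classes=3,
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(deeplab.parameters(), lr=1e-4)  # lower LR, common when fine-tuning
# Train 10–15 epochs with the Step 4 loop; the ImageNet encoder already knows edges and textures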

Observation | What it teaches
--- | ---
DeepLab mIoU higher after few epochs | Pretrained context + ASPP help on natural photos
U-Net competitive on borders | Skips still strong when data is small
DeepLab slower / more VRAM | Trade-off for deployment
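A fair comparison runs both models through the same evaluation code on the same val_loader. The sketch below assumes any model that maps (N, 3, H, W) images to (N, C, H, W) logits. `evaluate_miou` is a hypothetical helper that accumulates a confusion matrix on the fly, using the same pooled definition as compute_iou; `torch.nanmean` needs PyTorch 1.10+.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate_miou(model: nn.Module, loader, num_classes: int, device="cpu") -> float:
    """Dataset-level mIoU: intersections and unions pooled over all val pixels."""
    model.eval()
    cm = torch.zeros(num_classes * num_classes, dtype=torch.long)
    for imgs, masks in loader:
        preds = model(imgs.to(device)).argmax(dim=1).cpu()
        # Encode each (true, pred) pair as one index and count occurrences
        idx = num_classes * masks.reshape(-1).long() + preds.reshape(-1)
        cm += torch.bincount(idx, minlength=num_classes * num_classes)
    cm = cm.reshape(num_classes, num_classes).double()  # rows = ground truth
    inter = cm.diag()
    union = cm.sum(dim=0) + cm.sum(dim=1) - inter
    per_class = inter / union  # 0/0 -> NaN for classes that never appear
    return torch.nanmean(per_class).item()
```

Call `evaluate_miou(model, val_loader, 3, device)` for the U-Net and `evaluate_miou(deeplab, val_loader, 3, device)` for DeepLab to get the two numbers for the README table.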

Write one paragraph in your README: when you would pick U-Net vs DeepLab for a new product.


Troubleshooting

Symptom | Fix
--- | ---
mIoU stuck near 0 | Mask label mapping wrong; check the torch.where remap
Loss NaN | Lower the LR; check mask values are in 0..C-1
OOM | batch_size=4, IMG_SIZE=128, or base=16
Blurry borders only | Train longer; confirm skips are concatenated (not accidentally dropped)
Checkerboard artifacts | Switch transposed conv to upsample + conv
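For the checkerboard fix, the sketch below shows one common resize-then-convolve block as a hypothetical `UpConv` module. It produces the same output shape as the `nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)` layers in Step 2, so it can replace `self.up4` through `self.up1` directly.

```python
import torch
import torch.nn as nn


class UpConv(nn.Module):
    """Resize-then-convolve replacement for nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2).

    A stride-2 transposed conv uses different kernel weights for each position in
    every 2x2 output block, which can surface as a periodic checkerboard. Here a
    parameter-free bilinear upsample is followed by one shared 3x3 conv instead.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


# Same output shape as the transposed conv it replaces
x = torch.randn(1, 64, 16, 16)
print(tuple(UpConv(64, 32)(x).shape))                          # (1, 32, 32, 32)
print(tuple(nn.ConvTranspose2d(64, 32, 2, stride=2)(x).shape))  # (1, 32, 32, 32)
```

In Step 2 this means, for example, `self.up4 = UpConv(base * 16, base * 8)`. The decoder DoubleConv blocks stay unchanged.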

Deliverables

  • UNet class with skip connections
  • loss_curve.png + iou_curve.png
  • Best val mIoU recorded in README
  • At least 5 overlay images in outputs/samples/
  • Short write-up: one failure mode you observed (e.g. border class confusion)
  • (Stretch) U-Net vs DeepLab mIoU table on the same val split

What's next

Module 5 complete. Continue to Module 6 — Transformers (core of GenAI) when ready.
