Module 5 — Image segmentation

U-Net architecture

Skip connections, double conv blocks, a shape table for a 256×256 forward pass, concat vs add, and a pointer to other model families.

~85 min read + exercises

Before we begin

Lesson 2 left you with a problem: a plain encoder–decoder squashes the image to a small bottleneck, then tries to rebuild sharp masks from coarse features alone. Borders suffer.

U-Net (Ronneberger, Fischer, Brox — 2015, biomedical cell segmentation) fixes this with skip connections: copy high-resolution encoder feature maps across to the decoder and concatenate them before the next conv block.

The name comes from the U shape — down the left, up the right, bridges in the middle.

Figure: U-Net schematic. Contracting path (encoder, left), expanding path (decoder, right), skip connections (dashed). Levels shown: 256×256 (3 ch), 128×128 (64 ch), 64×64 (128 ch), 32×32 (256 ch), 16×16 (512 ch, bottleneck).
Encoder (left) loses resolution but gains semantics; decoder (right) recovers resolution; gold dashes = skips.

What you will learn

  • Read a U-Net diagram and map each part to code.
  • Explain skip connections with a “what vs where” intuition.
  • Track tensor shapes through one forward pass.
  • Recognize the double conv block and 1×1 output head.
  • Know what changes in modern variants (awareness, not implementation).

Before this lesson


High-level forward pass

text
Input (3, H, W)
  enc1 → feature map e1  (C, H, W)     ───skip───┐
  pool → enc2 → e2       (2C, H/2, W/2) ──skip──┤
  pool → enc3 → e3       (4C, H/4, W/4) ─skip───┤
  pool → bottleneck      (8C, H/8, W/8)
  up   → concat(e3) → dec3
  up   → concat(e2) → dec2
  up   → concat(e1) → dec1
  head 1×1 conv → logits (num_classes, H, W)

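The sketch above shows three encoder levels for brevity, and exact channel multipliers vary by implementation. Your project uses base=32 with four encoder levels plus a bottleneck: 32, 64, 128, 256, 512 channels.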
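To connect the diagram to code, here is a minimal sketch of a three-level U-Net that follows the same enc/pool/up/concat/dec sequence. The class name TinyUNet, the helper double_conv, and the use of nn.ConvTranspose2d for upsampling are assumptions made for illustration; your project code may organize these pieces differently.

```python
import torch
import torch.nn as nn


def double_conv(in_ch, out_ch):
    # Two 3x3 conv + ReLU layers; padding=1 keeps H and W unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical three-level U-Net mirroring the diagram above."""

    def __init__(self, in_ch=3, num_classes=3, base=32):
        super().__init__()
        c = base
        self.enc1 = double_conv(in_ch, c)
        self.enc2 = double_conv(c, 2 * c)
        self.enc3 = double_conv(2 * c, 4 * c)
        self.bottleneck = double_conv(4 * c, 8 * c)
        self.pool = nn.MaxPool2d(2)
        # Assumption: learned 2x upsampling that also halves the channels.
        self.up3 = nn.ConvTranspose2d(8 * c, 4 * c, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(4 * c, 2 * c, kernel_size=2, stride=2)
        self.up1 = nn.ConvTranspose2d(2 * c, c, kernel_size=2, stride=2)
        # Each decoder block receives upsampled + skip channels (concat).
        self.dec3 = double_conv(8 * c, 4 * c)
        self.dec2 = double_conv(4 * c, 2 * c)
        self.dec1 = double_conv(2 * c, c)
        self.head = nn.Conv2d(c, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                     # (C, H, W), kept for the skip
        e2 = self.enc2(self.pool(e1))         # (2C, H/2, W/2)
        e3 = self.enc3(self.pool(e2))         # (4C, H/4, W/4)
        b = self.bottleneck(self.pool(e3))    # (8C, H/8, W/8)
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.head(d1)                  # (num_classes, H, W)


model = TinyUNet()
logits = model(torch.randn(1, 3, 64, 64))
print(logits.shape)
```

With three pooling steps, H and W must be divisible by 8 for the upsampled maps to line up with the saved encoder maps.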


Double conv block — the repeating unit

At each level, U-Net typically applies two 3×3 convolutions with ReLU:

python
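import torch.nn as nn

def double_conv(in_ch, out_ch):
    # Two 3×3 conv + ReLU layers; padding=1 keeps H×W unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.ReLU(inplace=True),
    )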

Why two? A single 3×3 conv sees only a 3×3 neighborhood. Two stacked 3×3 convs cover a 5×5 receptive field with an extra nonlinearity in between, so each level extracts richer local features before pooling discards resolution. The pattern dates back to the original paper.

padding=1 on 3×3 conv preserves H×W (when stride=1).
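The size claim follows from the standard convolution output formula, out = floor((in + 2·padding − kernel) / stride) + 1. The small helper below (conv_out_size is a hypothetical name, and it ignores dilation) is a sketch for checking sizes by hand:

```python
def conv_out_size(size, kernel=3, padding=0, stride=1):
    # Output length along one spatial axis for a conv or pooling layer
    # without dilation.
    return (size + 2 * padding - kernel) // stride + 1


same = conv_out_size(256, kernel=3, padding=1)      # 3x3 conv, padding=1
shrunk = conv_out_size(256, kernel=3, padding=0)    # 3x3 conv, no padding
pooled = conv_out_size(256, kernel=2, stride=2)     # 2x2 max pool, stride 2
print(same, shrunk, pooled)
```

The same formula covers pooling, which is why each 2×2 max pool with stride 2 halves H and W.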


Skip connections — the core idea

Before you MaxPool away e1, save it. When the decoder upsamples back to the same resolution, concatenate the saved encoder map:

python
d1 = self.up(dec2)                    # upsampled decoder feature
d1 = torch.cat([d1, e1], dim=1)      # channel counts add along dim=1
d1 = self.dec1_conv(d1)               # mix fused features

| Source | Information |
| --- | --- |
| Bottleneck / deep decoder | Semantic context: “there is a pet in this region” |
| Shallow encoder skip | Spatial detail: edges, corners, texture |

Analogy: the bottleneck writes the caption (“a dog on grass”); the skip feeds back the original outline sketch so the decoder traces the boundary correctly.

| Without skips | With skips |
| --- | --- |
| Blobby masks | Sharper object borders |
| Lost hair / legs | Thin structures survive |
| Small objects vanish | Local cues preserved |

Add vs concat: U-Net uses concat (channels stack). Some models add feature maps (same channel count required). Concat is more flexible — the following conv learns how to merge.
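The difference is easy to see on tensor shapes. The snippet below is an illustrative sketch with made-up sizes (64 and 32 channels on a 32×32 grid), not values taken from the project:

```python
import torch

dec = torch.randn(1, 64, 32, 32)   # upsampled decoder features
enc = torch.randn(1, 64, 32, 32)   # saved encoder features, same shape

fused_cat = torch.cat([dec, enc], dim=1)   # channels stack: 64 + 64
fused_add = dec + enc                      # element-wise: channels stay 64
print(fused_cat.shape, fused_add.shape)

# Concat accepts mismatched channel counts; addition does not.
enc_narrow = torch.randn(1, 32, 32, 32)
mixed_cat = torch.cat([dec, enc_narrow], dim=1)   # 64 + 32 channels
try:
    dec + enc_narrow
    add_ok = True
except RuntimeError:
    add_ok = False
print(mixed_cat.shape, add_ok)
```

Because concat leaves both inputs intact, the conv block that follows must accept the summed channel count as its in_ch.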


Worked shape example (H=W=256, base=32)

Simplified one forward pass (batch size 1):

| Step | Tensor shape | Notes |
| --- | --- | --- |
| Input | (1, 3, 256, 256) | RGB image |
| e1 | (1, 32, 256, 256) | saved for skip |
| e2 | (1, 64, 128, 128) | after pool |
| e3 | (1, 128, 64, 64) | saved |
| e4 | (1, 256, 32, 32) | saved |
| bottleneck | (1, 512, 16, 16) | smallest grid |
| up + concat e4 | (1, 512, 32, 32) | channels = 256+256 after concat |
| … continue up … | | |
| output logits | (1, 3, 256, 256) | 3 classes (pet / bg / border) |

Inference: pred = logits.argmax(dim=1) → (1, 256, 256) integer mask.
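The shape bookkeeping can also be reproduced with plain arithmetic. The helper below (unet_shapes is a hypothetical name) is a sketch that assumes each upsampling step halves the channels and doubles H and W, and that each decoder conv block reduces the channels back to the skip's count; an implementation that upsamples differently would give different intermediate channel counts.

```python
def unet_shapes(h=256, w=256, in_ch=3, base=32, levels=4, num_classes=3):
    """Return (step, shape) pairs for a batch of one image."""
    rows = [("input", (1, in_ch, h, w))]
    ch = base
    skips = []
    # Encoder: each level keeps H x W; pooling then halves it for the next.
    for i in range(1, levels + 1):
        rows.append((f"e{i}", (1, ch, h, w)))
        skips.append((i, ch))
        ch, h, w = ch * 2, h // 2, w // 2
    rows.append(("bottleneck", (1, ch, h, w)))
    # Decoder: upsample (assumed to halve channels), concat the skip,
    # then a conv block brings channels back to the skip's count.
    for i, skip_ch in reversed(skips):
        h, w = h * 2, w * 2
        ch = ch // 2
        rows.append((f"up + concat e{i}", (1, ch + skip_ch, h, w)))
        rows.append((f"dec{i}", (1, skip_ch, h, w)))
        ch = skip_ch
    rows.append(("logits", (1, num_classes, h, w)))
    return rows


shapes = dict(unet_shapes())
for step, shape in shapes.items():
    print(f"{step:16s} {shape}")
```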


Output head

python
self.head = nn.Conv2d(base, num_classes, kernel_size=1)

Each spatial location gets its own 3-vector of logits (for 3 classes). CrossEntropyLoss applies softmax across num_classes independently at each pixel.
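A short sketch of how the head, the loss, and the prediction fit together. The feature tensor and targets are random stand-ins, and the batch and grid sizes (2 images, 8×8) are arbitrary choices to keep the example small:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base, num_classes = 32, 3
head = nn.Conv2d(base, num_classes, kernel_size=1)

features = torch.randn(2, base, 8, 8)               # stand-in for the last decoder output
target = torch.randint(0, num_classes, (2, 8, 8))   # one integer class id per pixel

logits = head(features)                        # (2, 3, 8, 8)
loss = nn.CrossEntropyLoss()(logits, target)   # scalar, averaged over all pixels
probs = logits.softmax(dim=1)                  # class probabilities per pixel
pred = logits.argmax(dim=1)                    # (2, 8, 8) integer mask
print(logits.shape, pred.shape, loss.item())
```

Note that CrossEntropyLoss expects raw logits of shape (N, C, H, W) and integer targets of shape (N, H, W); it applies the softmax internally, so the explicit softmax here is only for inspection.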


U-Net vs classification CNN (ResNet-style)

| Piece | Classifier | U-Net |
| --- | --- | --- |
| Deepest layers | Rich semantics | Same |
| Final pooling | Global average → one vector | None: keep grid |
| Spatial output | Lost | H×W preserved |
| Labels | One per image | One per pixel |
| Data needed | Often large | U-Net often works with smaller medical / niche sets |
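The "final pooling" row can be illustrated by putting two heads on the same feature map. This is a sketch with hypothetical sizes (64 channels, a 16×16 grid, 10 image classes, 3 pixel classes), not a reproduction of ResNet or of the project model:

```python
import torch
import torch.nn as nn

features = torch.randn(1, 64, 16, 16)   # hypothetical feature map

# Classifier-style head: global average pool, then one linear layer.
clf_head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),   # (1, 64, 16, 16) -> (1, 64, 1, 1)
    nn.Flatten(),              # -> (1, 64)
    nn.Linear(64, 10),         # -> (1, 10)
)
# Segmentation-style head: a 1x1 conv keeps the spatial grid.
seg_head = nn.Conv2d(64, 3, kernel_size=1)

clf_out = clf_head(features)   # one score vector per image
seg_out = seg_head(features)   # one score vector per pixel
print(clf_out.shape, seg_out.shape)
```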

Why U-Net for this course project

Oxford-IIIT Pet gives ~7k images with trimap masks — not ImageNet scale. U-Net is the standard teaching choice because:

  • Implementation fits in ~80 lines of PyTorch.
  • Trains on a single GPU in minutes to tens of minutes.
  • Skip connections visibly improve borders in overlay plots.
  • Same ideas transfer to portrait segmentation and medical imaging.

Modern variants — covered in the next lessons

U-Net is your implementation target in this module. Lesson 4 surveys FCN, DeepLab, PSPNet, SegFormer; Lesson 5 covers instance segmentation with Mask R-CNN. Read those before the quiz — they explain when to move past vanilla U-Net.


Checkpoint

  1. What tensors are stored for skip connections — before or after pooling?
  2. After torch.cat([decoder, encoder], dim=1), what happens to channel count?
  3. Why is a 1×1 conv a valid “per-pixel classifier”?

Answers: (1) Save encoder outputs before pooling at each level (the high-res version). (2) It becomes the sum of the decoder and encoder channel counts, which is a doubling when the two match. (3) It applies the same linear map at every spatial position, so each (h, w) is classified independently from its channel vector.
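Answer (3) can be checked numerically: the output of a 1×1 conv at one pixel should match a plain linear layer applied to that pixel's channel vector with the same weights. The sizes and the chosen pixel below are arbitrary, illustrative values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
head = nn.Conv2d(32, 3, kernel_size=1)
x = torch.randn(1, 32, 4, 4)

out = head(x)                          # (1, 3, 4, 4)

# The conv weight has shape (3, 32, 1, 1); drop the 1x1 kernel axes.
weight = head.weight.reshape(3, 32)
row, col = 2, 1                        # any pixel works
pixel_logits = F.linear(x[0, :, row, col], weight, head.bias)

matches = torch.allclose(out[0, :, row, col], pixel_logits, atol=1e-5)
print(matches)
```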


What's next

Lesson 4 — Beyond U-Net: other segmentation models