U-Net architecture
Before we begin
Lesson 2 left you with a problem: a plain encoder–decoder squashes the image to a small bottleneck, then tries to rebuild sharp masks from coarse features alone. Borders suffer.
U-Net (Ronneberger, Fischer, Brox — 2015, biomedical cell segmentation) fixes this with skip connections: copy high-resolution encoder feature maps across to the decoder and concatenate them before the next conv block.
The name comes from the U shape — down the left, up the right, bridges in the middle.
Figure: U-Net schematic
What you will learn
- Read a U-Net diagram and map each part to code.
- Explain skip connections with a “what vs where” intuition.
- Track tensor shapes through one forward pass.
- Recognize the double conv block and 1×1 output head.
- Know what changes in modern variants (awareness, not implementation).
High-level forward pass
Input (3, H, W)
enc1 → feature map e1 (C, H, W) ───skip───┐
pool → enc2 → e2 (2C, H/2, W/2) ──skip──┤
pool → enc3 → e3 (4C, H/4, W/4) ─skip───┤
pool → bottleneck (8C, H/8, W/8)
up → concat(e3) → dec3
up → concat(e2) → dec2
up → concat(e1) → dec1
head 1×1 conv → logits (num_classes, H, W)
Exact channel multipliers vary by implementation; your project uses base=32 → 32, 64, 128, 256, 512, that is, four encoder levels plus the bottleneck, one level deeper than this sketch.
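To map the diagram onto code, here is a minimal sketch of a three-level U-Net in PyTorch. The class name `TinyUNet`, the helper `double_conv`, and the choice of `nn.ConvTranspose2d` for the "up" step are assumptions made for illustration; your project code may upsample differently or use one more level.

```python
import torch
import torch.nn as nn


def double_conv(in_ch, out_ch):
    # Two 3x3 convs with ReLU; padding=1 keeps H and W unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical three-level U-Net mirroring the diagram (C = base)."""

    def __init__(self, in_ch=3, num_classes=3, base=32):
        super().__init__()
        c = base
        self.pool = nn.MaxPool2d(2)
        self.enc1 = double_conv(in_ch, c)
        self.enc2 = double_conv(c, 2 * c)
        self.enc3 = double_conv(2 * c, 4 * c)
        self.bottleneck = double_conv(4 * c, 8 * c)
        # Assumption: transposed convs double H, W and halve the channels.
        self.up3 = nn.ConvTranspose2d(8 * c, 4 * c, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(4 * c, 2 * c, kernel_size=2, stride=2)
        self.up1 = nn.ConvTranspose2d(2 * c, c, kernel_size=2, stride=2)
        # Decoder inputs are wider because of the concatenated skip.
        self.dec3 = double_conv(8 * c, 4 * c)
        self.dec2 = double_conv(4 * c, 2 * c)
        self.dec1 = double_conv(2 * c, c)
        self.head = nn.Conv2d(c, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                      # (C,  H,   W)   saved
        e2 = self.enc2(self.pool(e1))          # (2C, H/2, W/2) saved
        e3 = self.enc3(self.pool(e2))          # (4C, H/4, W/4) saved
        b = self.bottleneck(self.pool(e3))     # (8C, H/8, W/8)
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.head(d1)                   # (num_classes, H, W)


model = TinyUNet(base=8)                       # small base keeps the demo fast
logits = model(torch.randn(1, 3, 64, 64))
print(logits.shape)
```

With three pooling steps, H and W need to be divisible by 8 for the upsampled maps to line up with the saved skips.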
Double conv block — the repeating unit
At each level, U-Net typically applies two 3×3 convolutions with ReLU:
nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU(inplace=True),
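)
Why two? A single 3×3 conv sees only a 3×3 neighborhood and applies one nonlinearity. Two stacked convs cover a 5×5 receptive field with two nonlinearities, so the block learns richer local features before you pool away resolution. This has been standard since the original paper.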
padding=1 on 3×3 conv preserves H×W (when stride=1).
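The padding claim follows from the standard convolution output-size formula, out = floor((in + 2·padding − kernel) / stride) + 1 (ignoring dilation). A small sketch to check the arithmetic; the helper name `conv_out_size` is made up for this example:

```python
def conv_out_size(size, kernel, padding=0, stride=1):
    """Output length along one spatial axis of a conv or pool (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# 3x3 conv, padding=1, stride=1: the size is preserved.
same = conv_out_size(256, kernel=3, padding=1)

# Without padding, each 3x3 conv trims one pixel from every border.
unpadded_once = conv_out_size(256, kernel=3)
unpadded_twice = conv_out_size(unpadded_once, kernel=3)

# 2x2 max pool with stride 2 halves the size.
pooled = conv_out_size(256, kernel=2, stride=2)

print(same, unpadded_once, unpadded_twice, pooled)
```

The original paper used unpadded convolutions, which is why its output map was smaller than its input; padding=1 avoids that bookkeeping.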
Skip connections — the core idea
Before you MaxPool away e1, save it. When the decoder upsamples back to the same resolution, concatenate the saved encoder map:
d1 = self.up(dec2)               # upsampled decoder feature
d1 = torch.cat([d1, e1], dim=1)  # channel counts add along dim=1 (double when both sides match)
d1 = self.dec1_conv(d1)          # mix fused features
| Source | Information |
|---|---|
| Bottleneck / deep decoder | Semantic context — “there is a pet in this region” |
| Shallow encoder skip | Spatial detail — edges, corners, texture |
Analogy: the bottleneck writes the caption (“a dog on grass”); the skip feeds back the original outline sketch so the decoder traces the boundary correctly.
| Without skips | With skips |
|---|---|
| Blobby masks | Sharper object borders |
| Lost hair / legs | Thin structures survive |
| Small objects vanish | Local cues preserved |
Add vs concat: U-Net uses concat (channels stack). Some models add feature maps (same channel count required). Concat is more flexible — the following conv learns how to merge.
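A quick shape check makes the add-versus-concat difference concrete. The tensors below are random stand-ins, and the channel counts (32 and 16) are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

up = torch.randn(1, 32, 64, 64)    # upsampled decoder features
skip = torch.randn(1, 32, 64, 64)  # saved encoder features, same H x W

fused_cat = torch.cat([up, skip], dim=1)  # channels stack: 32 + 32 = 64
fused_add = up + skip                     # elementwise sum: stays at 32

# After concat, the following conv learns how to weight the two sources.
merge = nn.Conv2d(64, 32, kernel_size=3, padding=1)
print(fused_cat.shape, fused_add.shape, merge(fused_cat).shape)

# Concat tolerates unequal channel counts; `up + narrow_skip` would raise.
narrow_skip = torch.randn(1, 16, 64, 64)
print(torch.cat([up, narrow_skip], dim=1).shape)  # 32 + 16 = 48 channels
```

In both cases H and W must already match, which is why the decoder upsamples before fusing.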
Worked shape example (H=W=256, base=32)
Simplified one forward pass (batch size 1):
| Step | Tensor shape | Notes |
|---|---|---|
| Input | (1, 3, 256, 256) | RGB image |
| e1 | (1, 32, 256, 256) | saved for skip |
| e2 | (1, 64, 128, 128) | after pool; saved |
| e3 | (1, 128, 64, 64) | saved |
| e4 | (1, 256, 32, 32) | saved |
| bottleneck | (1, 512, 16, 16) | smallest grid |
| up + concat e4 | (1, 512, 32, 32) | channels = 256+256 after concat |
| … continue up … | | |
| output logits | (1, 3, 256, 256) | 3 classes (pet / bg / border) |
Inference: pred = logits.argmax(dim=1) → (1, 256, 256) integer mask.
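If you want to check the table by hand, the shape bookkeeping can be sketched in plain Python with no tensors involved. The function name `trace_unet_shapes` is hypothetical. The decoder rows rest on an assumption the table does not spell out: each up step halves the channels and doubles H and W, and each decoder block outputs as many channels as the skip it consumed.

```python
def trace_unet_shapes(height, width, base=32, levels=4,
                      in_ch=3, num_classes=3, batch=1):
    """List (step, shape) pairs for one forward pass.

    Assumption: up halves channels and doubles H, W; each decoder block
    outputs the channel count of the skip it just consumed.
    """
    rows = [("input", (batch, in_ch, height, width))]
    ch, h, w = base, height, width
    skips = []
    for level in range(1, levels + 1):
        rows.append((f"e{level}", (batch, ch, h, w)))
        skips.append((ch, h, w))                 # saved before pooling
        ch, h, w = ch * 2, h // 2, w // 2        # 2x2 pool, then a wider block
    rows.append(("bottleneck", (batch, ch, h, w)))
    for level in range(levels, 0, -1):
        skip_ch, h, w = skips[level - 1]
        rows.append((f"up + concat e{level}", (batch, 2 * skip_ch, h, w)))
        rows.append((f"dec{level}", (batch, skip_ch, h, w)))
    rows.append(("output logits", (batch, num_classes, height, width)))
    rows.append(("argmax mask", (batch, height, width)))
    return rows


for step, shape in trace_unet_shapes(256, 256):
    print(f"{step:<18} {shape}")
```

Changing `levels` or `base` shows how quickly the bottleneck grid shrinks and the channel count grows.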
Output head
self.head = nn.Conv2d(base, num_classes, kernel_size=1)
Each spatial location gets its own vector of num_classes logits (a 3-vector for 3 classes). CrossEntropyLoss applies log-softmax across the class dimension independently at each pixel, so pass it raw logits.
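One way to convince yourself that the 1×1 head is a per-pixel linear classifier is to redo its arithmetic by hand at a single location. This is an illustrative sketch: the random `features` tensor stands in for the last decoder output, and the 8×8 grid and pixel (2, 5) are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base, num_classes = 32, 3
head = nn.Conv2d(base, num_classes, kernel_size=1)

features = torch.randn(1, base, 8, 8)   # stand-in for the final decoder map
logits = head(features)                 # (1, 3, 8, 8)

# Apply the same weights by hand to the channel vector at one pixel.
h, w = 2, 5
weight = head.weight.view(num_classes, base)          # (3, 32)
manual = weight @ features[0, :, h, w] + head.bias    # (3,)
assert torch.allclose(manual, logits[0, :, h, w], atol=1e-5)

# CrossEntropyLoss accepts (N, C, H, W) logits with (N, H, W) integer targets.
target = torch.randint(0, num_classes, (1, 8, 8))
loss = nn.CrossEntropyLoss()(logits, target)          # scalar, mean over pixels
pred = logits.argmax(dim=1)                           # (1, 8, 8) class indices
print(logits.shape, pred.shape, loss.item())
```

Because the kernel covers a single pixel, no neighbor contributes to a location's logits at this stage; all spatial context has to come from the layers before the head.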
U-Net vs classification CNN (ResNet-style)
| Piece | Classifier | U-Net |
|---|---|---|
| Deepest layers | Rich semantics | Same |
| Final pooling | Global average → one vector | None — keep grid |
| Spatial output | Lost | H×W preserved |
| Labels | One per image | One per pixel |
| Data needed | Often large | U-Net often works with smaller medical / niche sets |
Why U-Net for this course project
Oxford-IIIT Pet gives ~7k images with trimap masks — not ImageNet scale. U-Net is the standard teaching choice because:
- Implementation fits in ~80 lines of PyTorch.
- Trains on a single GPU in minutes to tens of minutes.
- Skip connections visibly improve borders in overlay plots.
- Same ideas transfer to portrait segmentation and medical imaging.
Modern variants — covered in the next lessons
U-Net is your implementation target in this module. Lesson 4 surveys FCN, DeepLab, PSPNet, SegFormer; Lesson 5 covers instance segmentation with Mask R-CNN. Read those before the quiz — they explain when to move past vanilla U-Net.
Checkpoint
- What tensors are stored for skip connections — before or after pooling?
- After torch.cat([decoder, encoder], dim=1), what happens to channel count?
- Why is a 1×1 conv a valid “per-pixel classifier”?
Answers: (1) Save encoder outputs before pooling at each level (the high-res version). (2) It becomes the sum of the decoder and encoder channel counts, which is double when the two match. (3) It applies the same linear map at every spatial position, so each (h, w) is classified independently from its channel vector.