Data augmentation & training tricks
Before we begin
Module 3 Lesson 1 covered how CNNs work. This lesson covers how to train them well — especially when data is limited, as in your classifier project.
Learning objectives
- Apply train-only vs eval transforms correctly.
- Use common augmentations: crop, flip, jitter, rotation.
- Monitor train vs val curves for overfitting.
- Set learning rate, weight decay, and early stopping.
Train vs eval transforms
| Transform | Training | Validation / test |
|---|---|---|
| RandomResizedCrop | ✓ | ✗ — use CenterCrop |
| RandomHorizontalFlip | ✓ (if label-safe) | ✗ |
| ColorJitter | ✓ | ✗ |
| Normalize (mean/std) | ✓ | ✓ — must match |
Rule: Anything random stays in training only. Validation must be repeatable.
```python
from torchvision import transforms

# Training pipeline: random augmentations, then tensor conversion and normalization.
train_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Validation pipeline: deterministic resize and crop, same normalization as training.
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
```

Why augmentation works
Real cameras change lighting, framing, and noise. Augmentation simulates that variation so the model learns invariants instead of memorizing exact pixels.
| Augmentation | Simulates |
|---|---|
| Random crop | Object scale / framing |
| Flip | Left-right symmetry (when valid) |
| Color jitter | White balance, exposure |
| Slight rotation | Camera roll (use sparingly for text/faces) |
Do not augment in ways that change the label (e.g. vertical flip for "road vs sky" scenes with gravity cues).
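To make the label-safety point concrete, here is a toy sketch in plain Python. The tiny brightness grid and the helpers `hflip`, `vflip`, and `sky_is_on_top` are made up for illustration; the labeller is a deliberately crude stand-in for "gravity cues", not a real scene classifier.

```python
def hflip(img):
    """Mirror each row left to right."""
    return [row[::-1] for row in img]

def vflip(img):
    """Reverse the row order, turning the image upside down."""
    return img[::-1]

def sky_is_on_top(img):
    """Toy labeller (assumption): True when the top half is brighter than the bottom half."""
    half = len(img) // 2
    top = sum(sum(row) for row in img[:half])
    bottom = sum(sum(row) for row in img[half:])
    return top > bottom

# Hypothetical 4x3 brightness grid: bright sky on top, dark road below.
scene = [
    [9, 9, 8],
    [9, 8, 9],
    [2, 1, 2],
    [1, 2, 1],
]

print(sky_is_on_top(scene))         # True
print(sky_is_on_top(hflip(scene)))  # True: a horizontal flip keeps the label
print(sky_is_on_top(vflip(scene)))  # False: a vertical flip changes the label
```

The horizontal flip leaves the property the label depends on untouched, while the vertical flip produces an image whose original label no longer applies. The same question is worth asking of every augmentation before adding it to the training pipeline.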
Reading train vs val curves
| Pattern | Likely cause | Try |
|---|---|---|
| Train loss ↓, val acc ↑ | Healthy | Continue |
| Train loss ↓ slowly, val acc flat (train acc also low) | Underfitting or LR too low | More epochs, higher LR |
| Train acc high, val acc low | Overfitting | More aug, weight decay, freeze backbone, early stop |
Save the checkpoint with the best validation score, not the one from the last epoch.
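As a rough illustration of reading these curves programmatically, the sketch below classifies a run from its per-epoch accuracy histories and picks the epoch whose checkpoint should be kept. The function names and the `gap` and `floor` thresholds are arbitrary assumptions chosen for the example, not recommended values.

```python
def diagnose(train_acc, val_acc, gap=0.10, floor=0.70):
    """Label a run from its final accuracies.

    gap   -- assumed train/val difference above which we call it overfitting
    floor -- assumed train accuracy below which we call it underfitting
    """
    final_train, final_val = train_acc[-1], val_acc[-1]
    if final_train - final_val > gap:
        return "overfitting"
    if final_train < floor:
        return "underfitting"
    return "healthy"

def best_epoch(val_acc):
    """Index of the epoch with the highest validation accuracy (first one on ties)."""
    return max(range(len(val_acc)), key=lambda i: val_acc[i])

# Hypothetical histories, one value per epoch.
print(diagnose([0.60, 0.80, 0.95], [0.60, 0.70, 0.72]))  # overfitting
print(diagnose([0.50, 0.55, 0.60], [0.50, 0.54, 0.58]))  # underfitting
print(diagnose([0.70, 0.85, 0.90], [0.68, 0.82, 0.88]))  # healthy
print(best_epoch([0.60, 0.74, 0.72]))                    # 1
```

In the last call the final epoch is not the best one, which is the case where keeping only the last checkpoint would lose accuracy.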
Learning rate and weight decay
- Adam with `lr=1e-3`: a common starting point for fine-tuning heads.
- AdamW with `weight_decay=1e-2`: decoupled weight decay, applied directly to the weights rather than added to the gradient as an L2 penalty.
- Scheduler: `StepLR` or `CosineAnnealingLR` for longer runs.
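To see what "decoupled" means for AdamW, the following sketch performs one Adam update on a single scalar weight, once with the decay folded into the gradient (classic L2) and once applied directly to the weight. It is a hand-written simplification for illustration, not the library implementation; `adam_first_step` and the example values are assumptions.

```python
import math

def adam_first_step(w, grad, lr=1e-3, weight_decay=1e-2, decoupled=False,
                    beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar weight, starting from zero moment estimates."""
    if decoupled:
        # AdamW style: shrink the weight directly, independent of the gradient.
        w = w * (1.0 - lr * weight_decay)
    else:
        # Classic L2 penalty: the decay term becomes part of the gradient,
        # so Adam's normalization rescales it along with everything else.
        grad = grad + weight_decay * w
    m = (1 - beta1) * grad           # first-moment estimate
    v = (1 - beta2) * grad * grad    # second-moment estimate
    m_hat = m / (1 - beta1)          # bias correction at step 1
    v_hat = v / (1 - beta2)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

w_l2 = adam_first_step(1.0, 0.5, decoupled=False)
w_decoupled = adam_first_step(1.0, 0.5, decoupled=True)
print(w_l2, w_decoupled)
```

On the first step Adam's update has magnitude close to `lr` regardless of gradient size, so the L2 term is largely normalized away, while the decoupled version still shrinks the weight by the extra factor `lr * weight_decay`.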
For transfer learning, use a lower LR on the backbone than on the head (see next lesson).
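For intuition about the two schedulers named above, this sketch computes the learning rate each one would give at a chosen epoch using the closed-form expressions as I understand them from the PyTorch documentation. The helper names `step_lr` and `cosine_lr` and the default `step_size` and `gamma` values are assumptions for the example; in a real run you would use the scheduler classes themselves.

```python
import math

def step_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Step decay: multiply the LR by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Cosine annealing: glide from base_lr down to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Compare both schedules over a hypothetical 30-epoch run starting at 1e-3.
for epoch in (0, 10, 15, 20, 30):
    print(epoch, step_lr(1e-3, epoch), cosine_lr(1e-3, epoch, t_max=30))
```

Step decay drops abruptly at fixed epochs, whereas cosine annealing lowers the rate smoothly and reaches `eta_min` exactly at `t_max`.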
Early stopping
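```python
best_val = 0.0    # best validation accuracy seen so far
patience = 3      # epochs to wait without improvement before stopping
bad_epochs = 0    # consecutive epochs without improvement

for epoch in range(max_epochs):
    train_one_epoch(...)
    val_acc = evaluate(...)
    if val_acc > best_val:
        best_val = val_acc
        save_checkpoint()  # keep the best-so-far weights
        bad_epochs = 0
    else:
        bad_epochs += 1
    if bad_epochs >= patience:
        break
```

Checkpoint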
Why is ColorJitter disabled at validation time?
Answer sketch: Val metrics must be stable and comparable across epochs — random jitter would add noise to accuracy.
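The answer can be checked with a toy simulation. In the sketch below, a fixed threshold "model" is evaluated ten times on the same validation set, with and without a random perturbation of the inputs. Everything here (`VAL_SET`, `predict`, `evaluate`, the jitter size) is a made-up stand-in for a real model and a real ColorJitter transform.

```python
import random

# Toy validation set: brightness 0.00..0.99, label 1 when brightness >= 0.5.
VAL_SET = [(i / 100, int(i / 100 >= 0.5)) for i in range(100)]

def predict(brightness):
    """Hypothetical stand-in for a trained model: a fixed brightness threshold."""
    return int(brightness >= 0.5)

def evaluate(jitter=0.0, rng=None):
    """Accuracy on VAL_SET, optionally perturbing each input the way a random jitter would."""
    correct = 0
    for brightness, label in VAL_SET:
        if jitter:
            brightness += rng.uniform(-jitter, jitter)
        correct += predict(brightness) == label
    return correct / len(VAL_SET)

rng = random.Random(0)
clean_runs = [evaluate() for _ in range(10)]
jittered_runs = [evaluate(jitter=0.2, rng=rng) for _ in range(10)]

print("clean:   ", clean_runs)     # the same value on every run
print("jittered:", jittered_runs)  # fluctuates although the model never changed
```

With jitter enabled, the accuracy moves between runs even though the model is identical, so a change in the validation curve could no longer be attributed to training progress alone.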