Show Code
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l
plt.rcParams["figure.figsize"] = (6, 4)This notebook is a controlled experiment to understand why adding Batch Normalization (BN) can improve training dynamics and final accuracy.
“BatchNorm reduces training loss faster and makes optimization more stable.”
The code below is inspired by the wonderful book Dive into Deep Learning; its authors created their own d2l package that wraps many of the complex implementations for training a model. This makes it easy to try many different approaches to test assumptions and theories.
We will create the following setup to make sure that the only difference between runs is “with BN vs without BN”:
def set_seed(seed: int = 42):
    """Seed every RNG we rely on so repeated runs are comparable."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for bit-wise reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Single source of truth for the experiment hyperparameters, so the
# "with BN" and "without BN" runs differ only in the model.
CONFIG = dict(
    seed=42,
    lr=0.1,
    batch_size=128,
    max_epochs=10,
    num_gpus=1,   # set to 0 if you don't have CUDA
    num_workers=0,
)
def make_data():
    """Build a fresh FashionMNIST loader.

    Call after seeding so the shuffling order is comparable across runs.
    """
    dataset = d2l.FashionMNIST(batch_size=CONFIG["batch_size"])
    dataset.num_workers = CONFIG["num_workers"]
    return dataset
def make_trainer():
    """Build a d2l Trainer with the shared experiment configuration."""
    return d2l.Trainer(max_epochs=CONFIG["max_epochs"], num_gpus=CONFIG["num_gpus"])


def init_cnn(module):  # @save
    """Initialize weights for CNNs with Xavier-uniform init.

    Intended for ``model.apply_init``; only Linear/Conv2d weights are
    touched, every other module is left as-is.
    """
    # isinstance is the idiomatic type check (and also covers subclasses),
    # unlike the original `type(module) == ...` comparison.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(module.weight)
class LeNet(d2l.Classifier):  # @save
    """The LeNet-5 model (sigmoid activations, average pooling)."""

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        # Buffers filled by the activation/gradient hooks during training.
        self.activations = {}
        self.gradients = {}
        self.save_hyperparameters()
        layers = [
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
            nn.Linear(120, 84), nn.Sigmoid(),
            nn.Linear(84, num_classes),
        ]
        self.net = nn.Sequential(*layers)
@d2l.add_to_class(d2l.Classifier)  # @save
def layer_summary(self, X_shape):
    """Push a random tensor of ``X_shape`` through ``self.net``,
    printing each layer's class name and output shape."""
    activations = torch.randn(*X_shape)
    for module in self.net:
        activations = module(activations)
        print(module.__class__.__name__, 'output shape:\t', activations.shape)
model = LeNet()
model.layer_summary((1, 1, 28, 28))Conv2d output shape: torch.Size([1, 6, 28, 28])
Sigmoid output shape: torch.Size([1, 6, 28, 28])
AvgPool2d output shape: torch.Size([1, 6, 14, 14])
Conv2d output shape: torch.Size([1, 16, 10, 10])
Sigmoid output shape: torch.Size([1, 16, 10, 10])
AvgPool2d output shape: torch.Size([1, 16, 5, 5])
Flatten output shape: torch.Size([1, 400])
Linear output shape: torch.Size([1, 120])
Sigmoid output shape: torch.Size([1, 120])
Linear output shape: torch.Size([1, 84])
Sigmoid output shape: torch.Size([1, 84])
Linear output shape: torch.Size([1, 10])
class BNLeNet(d2l.Classifier):
    """LeNet-5 with a BatchNorm layer inserted before every sigmoid."""

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        # Buffers filled by the activation/gradient hooks during training.
        self.activations = {}
        self.gradients = {}
        self.save_hyperparameters()
        self.net = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.BatchNorm2d(6), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
            nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
            nn.Linear(84, num_classes),
        )
# NOTE(review): this re-registers the same layer_summary already attached to
# d2l.Classifier earlier in the notebook — the second registration is
# redundant (it overwrites the first with identical code) but harmless.
@d2l.add_to_class(d2l.Classifier)  # @save
def layer_summary(self, X_shape):
    """Forward a random tensor of ``X_shape`` through ``self.net``,
    printing each layer's class name and output shape."""
    X = torch.randn(*X_shape)
    for layer in self.net:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)
model2 = BNLeNet()
model2.layer_summary((128, 1, 28, 28))Conv2d output shape: torch.Size([128, 6, 28, 28])
BatchNorm2d output shape: torch.Size([128, 6, 28, 28])
Sigmoid output shape: torch.Size([128, 6, 28, 28])
AvgPool2d output shape: torch.Size([128, 6, 14, 14])
Conv2d output shape: torch.Size([128, 16, 10, 10])
BatchNorm2d output shape: torch.Size([128, 16, 10, 10])
Sigmoid output shape: torch.Size([128, 16, 10, 10])
AvgPool2d output shape: torch.Size([128, 16, 5, 5])
Flatten output shape: torch.Size([128, 400])
Linear output shape: torch.Size([128, 120])
BatchNorm1d output shape: torch.Size([128, 120])
Sigmoid output shape: torch.Size([128, 120])
Linear output shape: torch.Size([128, 84])
BatchNorm1d output shape: torch.Size([128, 84])
Sigmoid output shape: torch.Size([128, 84])
Linear output shape: torch.Size([128, 10])
The Trainer class from the d2l library already logs the loss and step, so below we define a simple extractor to store those values into arrays.
# Metric-logging Trainer (per-step train loss + per-epoch val accuracy)
import torch
def _to_float(x):
if isinstance(x, torch.Tensor):
return float(x.detach().cpu())
return float(x)
def _extract_loss(training_step_output):
"""
Handles common patterns:
- Tensor loss
- dict like {"loss": ...}
- tuple/list like (loss, ...)
"""
out = training_step_output
if isinstance(out, dict):
if "loss" not in out:
raise KeyError("training_step returned dict but has no key 'loss'")
return out["loss"]
if isinstance(out, (tuple, list)):
if len(out) == 0:
raise ValueError("training_step returned empty tuple/list")
return out[0]
return out
@torch.no_grad()
def evaluate_accuracy(model, dataloader):
    """Generic top-1 accuracy for classification.

    Assumes ``dataloader`` yields (X, y) pairs and ``model(X)`` returns
    logits of shape [N, C] — possibly wrapped in a (logits, ...) tuple
    or a {"logits": ...} dict. Returns 0.0 for an empty dataloader.
    """
    model.eval()
    device = next(model.parameters()).device
    hits, seen = 0, 0
    for features, labels in dataloader:
        features, labels = features.to(device), labels.to(device)
        outputs = model(features)
        # Unwrap the common container shapes around the logits.
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]
        if isinstance(outputs, dict) and "logits" in outputs:
            outputs = outputs["logits"]
        hits += int((outputs.argmax(dim=1) == labels).sum())
        seen += int(labels.numel())
    # max(..., 1) guards against division by zero on an empty loader.
    return hits / max(seen, 1)
class MetricTrainer(d2l.Trainer):
    """d2l Trainer that additionally logs:
      - train_loss per step (batch)
      - val_acc per epoch (end of epoch)
    """

    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        super().__init__(max_epochs=max_epochs, num_gpus=num_gpus,
                         gradient_clip_val=gradient_clip_val)
        # Parallel lists: train_step[i] pairs with train_loss[i], and
        # val_epoch[j] pairs with val_acc[j].
        self.history = {
            "train_step": [],
            "train_loss": [],
            "val_epoch": [],
            "val_acc": [],
        }

    def fit_epoch(self):
        # Overrides d2l.Trainer.fit_epoch so we can record metrics.
        # NOTE(review): relies on attributes set up by d2l.Trainer.fit
        # (self.model, self.optim, self.train_dataloader, batch counters)
        # — confirm against the installed d2l version.
        # ---- train loop (log per batch) ----
        self.model.train()
        for batch in self.train_dataloader:
            batch = self.prepare_batch(batch)
            out = self.model.training_step(batch)
            loss = _extract_loss(out)
            self.optim.zero_grad()
            loss.backward()
            if self.gradient_clip_val > 0:
                self.clip_gradients(self.gradient_clip_val, self.model)
            self.optim.step()
            # Record the scalar loss for this optimization step.
            self.history["train_step"].append(self.train_batch_idx)
            self.history["train_loss"].append(_to_float(loss))
            self.train_batch_idx += 1
        # ---- val loop (optional, log per epoch) ----
        if self.val_dataloader is None:
            return
        # You can still call the model's validation_step if you want its internal logging:
        self.model.eval()
        for batch in self.val_dataloader:
            batch = self.prepare_batch(batch)
            self.model.validation_step(batch)
            self.val_batch_idx += 1
        # Compute a clean scalar accuracy at epoch end:
        val_acc = evaluate_accuracy(self.model, self.val_dataloader)
        self.history["val_epoch"].append(self.epoch)
        self.history["val_acc"].append(val_acc)


# Before starting the training, we will add a hook function to register
# the activations before and after the sigmoid function.
def register_hooks(model):
    """Attach forward/backward hooks to every ``nn.Sigmoid`` in ``model``.

    Forward hooks store the pre/post-sigmoid activations in
    ``model.activations``; backward hooks store the matching gradients in
    ``model.gradients``. Keys are "<layer>_pre", "<layer>_post",
    "<layer>_pre_grad" and "<layer>_post_grad".
    """
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Sigmoid):
            # Bind `name` via a default argument so each closure keeps
            # its own layer name (avoids the late-binding pitfall).
            def forward_hook(module, input, output, layer_name=name):
                model.activations[layer_name + "_pre"] = input[0].detach()
                model.activations[layer_name + "_post"] = output.detach()

            def backward_hook(module, grad_input, grad_output, layer_name=name):
                model.gradients[layer_name + "_pre_grad"] = grad_input[0].detach()
                model.gradients[layer_name + "_post_grad"] = grad_output[0].detach()

            layer.register_forward_hook(forward_hook)
            # register_backward_hook is deprecated and can report incorrect
            # grad_input on some modules; the full variant is the supported API.
            layer.register_full_backward_hook(backward_hook)


# The below code trains the base model without BN.
# ---- Train the baseline model (no BatchNorm) ----
set_seed(42)
data = d2l.FashionMNIST(batch_size=128)
data.num_workers = 0
model = LeNet(lr=0.1)
register_hooks(model)
# Lazy-initialize the network with one real batch, then apply Xavier init.
model.apply_init([next(iter(data.get_dataloader(True)))[0]], init_cnn)
trainer = MetricTrainer(max_epochs=10, num_gpus=1)
trainer.fit(model, data)
# hist["train_loss"], hist["train_step"], hist["val_acc"], ...
hist = trainer.history
# Using the same setup and seed values as for the base model, we will
# train the BNLeNet, then record the results.
# ---- Train the BatchNorm model with the identical setup and seed ----
set_seed(42)
data = d2l.FashionMNIST(batch_size=128)
data.num_workers = 0
model2 = BNLeNet(lr=0.1)
register_hooks(model2)
# Lazy-initialize the network with one real batch, then apply Xavier init.
model2.apply_init([next(iter(data.get_dataloader(True)))[0]], init_cnn)
trainer2 = MetricTrainer(max_epochs=10, num_gpus=1)
trainer2.fit(model2, data)
# hist2["train_loss"], hist2["train_step"], hist2["val_acc"], ...
hist2 = trainer2.history
# Compute a rolling mean of train_loss (window size 20) for smoother visualization.
window_size = 20


def rolling_mean(x, w):
    """Simple moving average of ``x`` over a window of ``w`` samples.

    Uses 'valid' mode, so the output is len(x) - w + 1 points long.
    """
    window = np.ones(w)
    return np.convolve(x, window, 'valid') / w
# Plot the smoothed train_loss of both models against train_step on one axes.
smoothed = rolling_mean(hist["train_loss"], window_size)
smoothed_bn = rolling_mean(hist2["train_loss"], window_size)
# Drop the first window_size-1 steps so x and y lengths match.
plt.plot(hist["train_step"][window_size - 1:], smoothed, label="LeNet")
plt.plot(hist2["train_step"][window_size - 1:], smoothed_bn, label="BNLeNet")
plt.xlabel("step")
plt.ylabel("train_loss")
plt.legend()
plt.show()
# Why these results?
The flat line of the baseline model indicates that the model is stuck, possibly due to saturation and vanishing gradients.
In the BN model, however, we see an immediate drop in loss, which suggests gradients are flowing easily.
We will test this hypothesis:
Does BN reduce Sigmoid saturation?
Method of testing: record the pre- and post-activations (before and after each sigmoid layer). What to observe: whether the Sigmoid saturates the layer values (outputs near 0 or 1), making it difficult for the gradient to flow.
def collect_stats(model):
    """Return (activation means, gradient means): one mean |value| per hooked key."""
    act_means = {key: tensor.abs().mean().item()
                 for key, tensor in model.activations.items()}
    grad_means = {key: tensor.abs().mean().item()
                  for key, tensor in model.gradients.items()}
    return act_means, grad_means
# Gather per-layer statistics from the hooks of both trained models.
act1, grad1 = collect_stats(model)
act2, grad2 = collect_stats(model2)

plt.figure(figsize=(10, 5))
# Each model keeps its own key list (the hooks may expose different names).
layers1 = list(act1)
layers2 = list(act2)
x1 = np.arange(len(layers1))
x2 = np.arange(len(layers2))
lenet_vals = [act1[key] for key in layers1]
bn_vals = [act2[key] for key in layers2]
plt.plot(x1, lenet_vals, marker='o', label="LeNet")
plt.plot(x2, bn_vals, marker='o', label="BNLeNet")
# Shared x-axis: one tick per sigmoid layer index.
plt.xticks(x1, [f"Sigmoid {i}" for i in range(len(layers1))], rotation=45)
plt.ylabel("Mean Absolute Activation")
plt.title("Activation Magnitude Comparison")
plt.legend()
plt.tight_layout()
plt.show()
Comments