Show Code
import torch
from torch import nn
from d2l import torch as d2l
Show Code
def init_cnn(module):  # @save
    """Initialize weights for CNNs.

    Applies Xavier (Glorot) uniform initialization to the ``weight`` of
    every linear and 2-D convolutional layer; all other module types are
    left untouched.  Intended to be used via ``model.apply(init_cnn)``.
    """
    # isinstance is the idiomatic type check and also accepts subclasses,
    # unlike the exact `type(module) == ...` comparison.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(module.weight)

class LeNet(d2l.Classifier):
    """Classic LeNet-5 for 28x28 single-channel images (e.g. Fashion-MNIST).

    Two sigmoid-activated conv + average-pool stages followed by three
    fully connected layers that produce ``num_classes`` raw logits.
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()

        # Stage 1: 1x28x28 -> 6x28x28 (padding=2 preserves size) -> 6x14x14.
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.pool1 = nn.AvgPool2d(2, 2)

        # Stage 2: 6x14x14 -> 16x10x10 (no padding) -> 16x5x5.
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.AvgPool2d(2, 2)

        # Classifier head: 16*5*5 = 400 flattened features.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, X):
        """Return class logits of shape (batch, num_classes).

        Fix: the original left one debug ``print("Output:", X.shape)``
        active (its siblings were commented out), spamming stdout on
        every forward pass; the debug scaffolding is removed.
        """
        X = torch.sigmoid(self.conv1(X))
        X = self.pool1(X)
        X = torch.sigmoid(self.conv2(X))
        X = self.pool2(X)
        X = torch.flatten(X, 1)  # keep batch dim, flatten C*H*W
        X = torch.sigmoid(self.fc1(X))
        X = torch.sigmoid(self.fc2(X))
        X = self.fc3(X)  # raw logits; the loss applies softmax
        return X

Build the same Trainer class from d2l, adding print statements to inspect the flow of data during training.

Show Code
class Trainer_test(d2l.HyperParameters):
    """The d2l base Trainer with print statements added in ``fit_epoch``
    to inspect the flow of data during training.

    Fix: the original pasted several chapters' versions of the class into
    one body, defining ``__init__``, ``fit_epoch``, ``prepare_batch`` and
    ``prepare_model`` twice each.  Python silently keeps only the LAST
    definition, so the earlier copies (including the ``num_gpus == 0``
    assert and the ``NotImplementedError`` stub) were dead code and have
    been removed; behavior is unchanged.

    Defined in :numref:`subsec_oo-design-models`"""

    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        """Defined in :numref:`sec_use_gpu`"""
        self.save_hyperparameters()
        # Use at most the number of GPUs actually available on this machine.
        self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]

    def prepare_data(self, data):
        """Cache the dataloaders and their batch counts."""
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        """Link model and trainer; move the model to the GPU if one is used.

        Defined in :numref:`sec_use_gpu`"""
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        if self.gpus:
            model.to(self.gpus[0])
        self.model = model

    def prepare_batch(self, batch):
        """Move every tensor in the batch to the training device.

        Defined in :numref:`sec_use_gpu`"""
        if self.gpus:
            batch = [d2l.to(a, self.gpus[0]) for a in batch]
        return batch

    def fit(self, model, data):
        """Run ``max_epochs`` epochs of training (and validation)."""
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        """One epoch of training then validation, printing batch shapes.

        Defined in :numref:`sec_linear_scratch`"""
        self.model.train()
        for batch in self.train_dataloader:
            print("\n--- New Batch ---")
            X, y = batch
            print("Input shape:", X.shape)
            print("Target shape:", y.shape)
            loss = self.model.training_step(self.prepare_batch(batch))
            self.optim.zero_grad()
            with torch.no_grad():
                loss.backward()
                if self.gradient_clip_val > 0:  # To be discussed later
                    self.clip_gradients(self.gradient_clip_val, self.model)
                self.optim.step()
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        self.model.eval()
        for batch in self.val_dataloader:
            with torch.no_grad():
                self.model.validation_step(self.prepare_batch(batch))
            self.val_batch_idx += 1

    def clip_gradients(self, grad_clip_val, model):
        """Rescale gradients so their global L2 norm is <= grad_clip_val.

        Defined in :numref:`sec_rnn-scratch`"""
        params = [p for p in model.parameters() if p.requires_grad]
        norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
        if norm > grad_clip_val:
            for param in params:
                param.grad[:] *= grad_clip_val / norm
Show Code
# Load Fashion-MNIST and peek at one training example to confirm its shape.
data = d2l.FashionMNIST(batch_size=128)
sample_image, sample_label = data.train[0]
print("One sample of training:", sample_image.shape)
print("One sample of training label:", sample_label)
One sample of training: torch.Size([1, 28, 28])
One sample of training label: 9

Below, we manually run a single training step (forward pass, loss, backward pass, and parameter update) to see the data flow.

Show Code
# Build the trainer, dataset, and model for a single manual training step.
trainer = Trainer_test(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
data.num_workers = 0
model = LeNet(lr=0.1)
model.apply(init_cnn)

# Wire the model and data into the trainer, then build the optimizer.
trainer.prepare_model(model)
trainer.prepare_data(data)
trainer.optim = model.configure_optimizers()

# training_step expects these counters to exist; fit() would normally set them.
trainer.epoch = 0
trainer.train_batch_idx = 0
trainer.val_batch_idx = 0

# Pull exactly one batch and report its shapes.
batch = next(iter(trainer.train_dataloader))
X, y = batch
print("Batch input shape:", X.shape)
print("Batch target shape:", y.shape)

# Move the batch to the training device when GPUs are configured.
batch = trainer.prepare_batch(batch)

# Forward pass: the model's training_step computes the loss for this batch.
loss = model.training_step(batch)
print("Loss:", loss.item())

# Backward pass: clear stale gradients, backprop, then apply the update.
trainer.optim.zero_grad()
loss.backward()
trainer.optim.step()

print("One training step completed.")
Batch input shape: torch.Size([128, 1, 28, 28])
Batch target shape: torch.Size([128])
Output: torch.Size([128, 10])
Loss: 2.457829713821411
One training step completed.

Comments