Building a simple LSTM network

Using PyTorch, we will build a simple LSTM network to understand how it works and how information flows through it.

Show Code
# import 
import torch
import torch.nn as nn
Show Code
# Toy corpus: the model will learn to continue "hello" one character at a time.
text = "hello"

# Vocabulary = the distinct characters, sorted so that indices are deterministic.
chars = sorted({*text})                                   # ['e', 'h', 'l', 'o']
char_to_idx = {ch: pos for pos, ch in enumerate(chars)}   # char  -> index
idx_to_char = dict(enumerate(chars))                      # index -> char

vocab_size = len(chars)                                   # 4


def one_hot(c, *, mapping=None, size=None):
    """Return a float one-hot vector for character ``c``.

    Args:
        c: the character to encode.
        mapping: optional char -> index dict. Defaults to the module-level
            ``char_to_idx``.
        size: optional vector length. Defaults to the module-level
            ``vocab_size`` when ``mapping`` is omitted, otherwise ``len(mapping)``.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``size`` with a single 1.0
        at ``mapping[c]``.

    Raises:
        KeyError: if ``c`` is not in the mapping.
    """
    if mapping is None:
        # Original behaviour: encode against the notebook's global vocabulary.
        mapping = char_to_idx
        if size is None:
            size = vocab_size
    elif size is None:
        size = len(mapping)

    # Named ``vec`` rather than ``one_hot`` so the function name is not shadowed.
    vec = torch.zeros(size)
    vec[mapping[c]] = 1.0
    return vec

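The class below defines a simple LSTM cell. It takes the current input \(x_t\), the previous hidden state \(h_{t-1}\) and the previous cell state \(c_{t-1}\), and returns the new hidden state \(h_t\), the new cell state \(c_t\), and the intermediate gate activations.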

Show Code
class LSTMByHand(nn.Module):
    """A single LSTM cell written out gate by gate.

    Each call advances the recurrence by one timestep::

        f_t = σ(W_f·[h_{t-1}, x_t] + b_f)
        i_t = σ(W_i·[h_{t-1}, x_t] + b_i)
        c̃_t = tanh(W_c·[h_{t-1}, x_t] + b_c)
        o_t = σ(W_o·[h_{t-1}, x_t] + b_o)
        c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
        h_t = o_t ⊙ tanh(c_t)
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Every gate sees the concatenation [h_{t-1}, x_t].
        combined_size = input_size + hidden_size

        # Each gate is its own weight matrix W and bias b.
        # nn.Linear(combined, hidden) computes exactly:  W · [h_{t-1}, x_t] + b
        # (Creation order is kept fixed so seeded initialisation is reproducible.)
        self.forget_gate = nn.Linear(combined_size, hidden_size)     # W_f, b_f
        self.input_gate = nn.Linear(combined_size, hidden_size)      # W_i, b_i
        self.candidate_gate = nn.Linear(combined_size, hidden_size)  # W_c, b_c
        self.output_gate = nn.Linear(combined_size, hidden_size)     # W_o, b_o

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one timestep.

        Args:
            x: current input; concatenated with ``h_prev`` along dim 1, so it is
                expected to be 2-D ``(batch, input_size)``.
            h_prev: previous hidden state ``(batch, hidden_size)``.
            c_prev: previous cell state ``(batch, hidden_size)``.

        Returns:
            ``(h, c, (f, i, c_tilde, o))``. These are the new hidden state, the
            new cell state, and the four gate/candidate activations, which are
            returned for inspection.
        """
        # 1. Concatenate previous hidden state with current input → [h_{t-1}, x_t]
        combined = torch.cat([h_prev, x], dim=1)

        # 2. The four gate computations:
        # f_t = σ(W_f·[h,x] + b_f)
        f = torch.sigmoid(self.forget_gate(combined))
        # i_t = σ(W_i·[h,x] + b_i)
        i = torch.sigmoid(self.input_gate(combined))
        # c̃_t = tanh(W_c·[h,x] + b_c)
        c_tilde = torch.tanh(self.candidate_gate(combined))
        # o_t = σ(W_o·[h,x] + b_o)
        o = torch.sigmoid(self.output_gate(combined))

        # 3. Update the cell state (long-term memory)
        # c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
        c = f * c_prev + i * c_tilde

        # 4. Expose part of it as the hidden state (short-term memory)
        # h_t = o_t ⊙ tanh(c_t)
        h = o * torch.tanh(c)

        return h, c, (f, i, c_tilde, o)

After defining the LSTM cell, let’s inspect the gates and the cell state after passing a single one-hot input (the character 'h') through the randomly initialized cell.

Show Code
# Fixed seed so the randomly initialised gate weights (and the printout) are reproducible.
torch.manual_seed(0)
demo_hidden_size = 4
cell = LSTMByHand(input_size=vocab_size, hidden_size=demo_hidden_size)

h = torch.zeros(1, demo_hidden_size)   # h_0: blank short-term memory
c = torch.zeros(1, demo_hidden_size)   # c_0: blank long-term memory
x = one_hot('h').unsqueeze(0)          # shape (1, vocab_size) → batch of 1

h, c, (f, i, c_tilde, o) = cell(x, h, c)

# .detach() drops the autograd graph so the tensors print as plain values.
print("forget gate  f :", f.detach().round(decimals=2))
print("input  gate  i :", i.detach().round(decimals=2))
print("candidate   c̃ :", c_tilde.detach().round(decimals=2))
print("output gate  o :", o.detach().round(decimals=2))
print("cell state   c :", c.detach().round(decimals=2))
print("hidden state h :", h.detach().round(decimals=2))
forget gate  f : tensor([[0.47, 0.40, 0.51, 0.40]])
input  gate  i : tensor([[0.51, 0.58, 0.37, 0.46]])
candidate   c̃ : tensor([[-0.57, -0.07, -0.13, -0.03]])
output gate  o : tensor([[0.66, 0.56, 0.51, 0.49]])
cell state   c : tensor([[-0.29, -0.04, -0.05, -0.02]])
hidden state h : tensor([[-0.19, -0.02, -0.02, -0.01]])
Show Code
class CharLSTM(nn.Module):
    """Character-level model: an unrolled ``LSTMByHand`` cell plus a linear read-out.

    At every timestep the cell consumes one input vector together with the
    previous ``(h, c)``. The new hidden state is then projected by ``self.fc``
    to one score per output class.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The hand-written LSTM cell defined earlier. It is created before
        # ``fc`` so that seeded parameter initialisation stays unchanged.
        self.cell = LSTMByHand(input_size, hidden_size)
        # hidden state → scores over chars
        self.fc = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size

    def forward(self, inputs, return_states=False):
        """Run the cell over a whole sequence, one timestep at a time.

        Args:
            inputs: tensor of shape ``(seq_len, batch, input_size)``.
            return_states: if True, also return the per-step states.

        Returns:
            ``logits`` of shape ``(seq_len, batch, output_size)``. When
            ``return_states`` is True, returns ``(logits, states)`` instead,
            where ``states[t]`` is ``(h_t, c_t, gates_t)`` as produced by the cell.
        """
        batch = inputs.shape[1]
        # h_0 / c_0: blank short- and long-term memory. They are allocated on
        # the inputs' device so the model also works when inputs live on a GPU.
        h = torch.zeros(batch, self.hidden_size, device=inputs.device)
        c = torch.zeros(batch, self.hidden_size, device=inputs.device)

        logits, states = [], []
        for x in inputs:                            # walk the sequence one timestep at a time
            # ← carry h and c forward (this IS the recurrence)
            h, c, gates = self.cell(x, h, c)
            logits.append(self.fc(h))
            states.append((h, c, gates))

        if logits:
            # (seq_len, batch, output_size)
            stacked = torch.stack(logits)
        else:
            # Empty sequence: torch.stack([]) would raise, so return an empty
            # (0, batch, output_size) tensor instead.
            stacked = torch.zeros(0, batch, self.fc.out_features, device=inputs.device)
        return (stacked, states) if return_states else stacked
Show Code
# Next-character prediction: at step t the model reads text[t] and must output text[t + 1].
input_chars, target_chars = text[:-1], text[1:]   # "hell", "ello"

# One-hot rows stacked into (seq_len, input_size), then a batch axis of size 1
# is inserted in the middle → (seq_len, batch=1, input_size).
inputs = torch.stack(list(map(one_hot, input_chars)))[:, None, :]

# Class indices (not one-hot), which is what cross-entropy expects as targets.
targets = torch.tensor([char_to_idx[ch] for ch in target_chars])

print("inputs shape :", inputs.shape)    # (4, 1, 4)
print("targets      :", targets.tolist())
inputs shape : torch.Size([4, 1, 4])
targets      : [0, 2, 2, 3]
Show Code
# Fixed seed → identical initial weights, so the loss curve below is reproducible.
torch.manual_seed(0)
model = CharLSTM(input_size=vocab_size, hidden_size=8, output_size=vocab_size)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

num_epochs = 200
log_every = 20

for epoch in range(num_epochs):
    optimizer.zero_grad()

    logits = model(inputs)                     # (seq_len, 1, vocab_size) — forward pass
    # CrossEntropyLoss wants (N, C) scores: fold the time and batch axes together
    # so each row lines up with one entry of the flat `targets` vector.
    loss = loss_fn(logits.reshape(-1, vocab_size), targets)

    loss.backward()                            # ← BPTT happens here, automatically
    optimizer.step()                           # update every parameter: the four gates and the read-out layer

    if epoch % log_every == 0:
        print(f"epoch {epoch:3d}   loss {loss.item():.4f}")
epoch   0   loss 1.2731
epoch  20   loss 0.0013
epoch  40   loss 0.0001
epoch  60   loss 0.0000
epoch  80   loss 0.0000
epoch 100   loss 0.0000
epoch 120   loss 0.0000
epoch 140   loss 0.0000
epoch 160   loss 0.0000
epoch 180   loss 0.0000
Show Code
# Replay the trained model once (no gradients needed) and keep every
# intermediate (h, c, gates) tuple for inspection.
with torch.no_grad():
    _, states = model(inputs, return_states=True)

# Two decimals and no scientific notation make the gate values easy to scan.
torch.set_printoptions(precision=2, sci_mode=False)

for t, (ch, (h, c, (f, i, c_tilde, o))) in enumerate(zip(input_chars, states)):
    print(f"\n──── timestep {t}: reading '{ch}' ────")
    # One row per quantity: (label, tensor, explanation); batch element 0 is shown.
    rows = (
        ("forget  f_t", f, "(how much old memory to KEEP)"),
        ("input   i_t", i, "(how much new info to WRITE)"),
        ("cand.   c̃_t", c_tilde, "(the new content proposed)"),
        ("output  o_t", o, "(how much memory to EXPOSE)"),
        ("cell    c_t", c, "← long-term memory after this step"),
        ("hidden  h_t", h, "← what gets passed on / predicted from"),
    )
    for label, value, note in rows:
        print(f"  {label} : {value[0]}   {note}")

──── timestep 0: reading 'h' ────
  forget  f_t : tensor([0.36, 0.64, 0.43, 0.55, 0.59, 0.74, 0.63, 0.79])   (how much old memory to KEEP)
  input   i_t : tensor([0.87, 0.92, 0.97, 0.95, 0.94, 0.98, 0.94, 0.40])   (how much new info to WRITE)
  cand.   c̃_t : tensor([-0.98,  0.96,  0.99,  1.00,  0.99, -1.00,  0.98,  0.50])   (the new content proposed)
  output  o_t : tensor([0.93, 0.95, 0.98, 0.96, 0.82, 0.97, 0.93, 0.28])   (how much memory to EXPOSE)
  cell    c_t : tensor([-0.85,  0.89,  0.96,  0.95,  0.94, -0.98,  0.91,  0.20])   ← long-term memory after this step
  hidden  h_t : tensor([-0.64,  0.68,  0.73,  0.71,  0.60, -0.73,  0.67,  0.05])   ← what gets passed on / predicted from

──── timestep 1: reading 'e' ────
  forget  f_t : tensor([0.04, 0.01, 0.00, 0.01, 1.00, 1.00, 0.05, 1.00])   (how much old memory to KEEP)
  input   i_t : tensor([1.00, 0.36, 0.99, 0.03, 0.95, 0.80, 0.16, 1.00])   (how much new info to WRITE)
  cand.   c̃_t : tensor([ 1.00, -0.63, -1.00,  0.92,  1.00, -1.00,  0.51, -1.00])   (the new content proposed)
  output  o_t : tensor([1.00, 0.10, 0.94, 0.00, 1.00, 0.98, 0.07, 1.00])   (how much memory to EXPOSE)
  cell    c_t : tensor([ 0.97, -0.22, -0.98,  0.04,  1.89, -1.78,  0.13, -0.80])   ← long-term memory after this step
  hidden  h_t : tensor([ 0.75, -0.02, -0.71,  0.00,  0.95, -0.92,  0.01, -0.66])   ← what gets passed on / predicted from

──── timestep 2: reading 'l' ────
  forget  f_t : tensor([0.50, 0.98, 0.98, 1.00, 0.33, 0.47, 0.99, 0.99])   (how much old memory to KEEP)
  input   i_t : tensor([1.00, 1.00, 1.00, 1.00, 1.00, 0.99, 1.00, 0.99])   (how much new info to WRITE)
  cand.   c̃_t : tensor([ 0.80, -1.00, -1.00,  1.00,  0.73, -0.65, -1.00, -1.00])   (the new content proposed)
  output  o_t : tensor([1.00, 1.00, 1.00, 0.88, 1.00, 1.00, 1.00, 1.00])   (how much memory to EXPOSE)
  cell    c_t : tensor([ 1.28, -1.21, -1.96,  1.04,  1.35, -1.48, -0.87, -1.78])   ← long-term memory after this step
  hidden  h_t : tensor([ 0.85, -0.84, -0.96,  0.69,  0.87, -0.90, -0.70, -0.94])   ← what gets passed on / predicted from

──── timestep 3: reading 'l' ────
  forget  f_t : tensor([0.07, 1.00, 1.00, 1.00, 0.02, 0.02, 1.00, 0.98])   (how much old memory to KEEP)
  input   i_t : tensor([0.99, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 0.99])   (how much new info to WRITE)
  cand.   c̃_t : tensor([-0.99, -1.00, -1.00,  1.00, -0.98,  0.97, -1.00, -1.00])   (the new content proposed)
  output  o_t : tensor([1.00, 1.00, 1.00, 0.99, 1.00, 1.00, 1.00, 0.99])   (how much memory to EXPOSE)
  cell    c_t : tensor([-0.89, -2.21, -2.95,  2.04, -0.96,  0.94, -1.87, -2.74])   ← long-term memory after this step
  hidden  h_t : tensor([-0.71, -0.98, -0.99,  0.96, -0.74,  0.74, -0.95, -0.98])   ← what gets passed on / predicted from

In the next notebook, we will train an LSTM network to predict the next word in a sentence using a simple Arabic dataset.


Comments