Show Code
# import
import torch
import torch.nn as nn
Using PyTorch, we will build a simple LSTM network to understand how it works and how information flows through the network.
The class below defines a simple LSTM cell: it takes the current input, the previous hidden state, and the previous cell state, and outputs the new hidden state \(h_t\) and the new cell state \(c_t\).
class LSTMByHand(nn.Module):
    """A single LSTM cell with each of its four gates written out explicitly.

    Every gate is a separate ``nn.Linear`` applied to the concatenation
    ``[h_{t-1}, x_t]`` of the previous hidden state and the current input.
    """

    def __init__(self, input_size, hidden_size):
        """Create the four gate layers.

        Args:
            input_size: Number of features in each input vector x_t.
            hidden_size: Number of units in the hidden state and cell state.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Each gate is fed [h_{t-1}, x_t], so its input width is the sum.
        combined = input_size + hidden_size
        # Each gate is its own weight matrix W and bias b.
        # nn.Linear(combined, hidden) computes exactly: W · [h_{t-1}, x_t] + b
        # NOTE: keep this creation order; the seeded initial weights depend on it.
        self.forget_gate = nn.Linear(combined, hidden_size)     # W_f, b_f
        self.input_gate = nn.Linear(combined, hidden_size)      # W_i, b_i
        self.candidate_gate = nn.Linear(combined, hidden_size)  # W_c, b_c
        self.output_gate = nn.Linear(combined, hidden_size)     # W_o, b_o

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one timestep.

        Args:
            x: Current input x_t, shape (batch, input_size).
            h_prev: Previous hidden state h_{t-1}, shape (batch, hidden_size).
            c_prev: Previous cell state c_{t-1}, shape (batch, hidden_size).

        Returns:
            A tuple ``(h, c, (f, i, c_tilde, o))``: the new hidden state, the
            new cell state, and the four gate activations used to compute them.
        """
        # 1. Concatenate previous hidden state with current input → [h_{t-1}, x_t]
        combined = torch.cat([h_prev, x], dim=1)

        # 2. The four gate computations:
        # f_t = σ(W_f·[h,x] + b_f)
        f = torch.sigmoid(self.forget_gate(combined))
        # i_t = σ(W_i·[h,x] + b_i)
        i = torch.sigmoid(self.input_gate(combined))
        # c̃_t = tanh(W_c·[h,x] + b_c)
        c_tilde = torch.tanh(self.candidate_gate(combined))
        # o_t = σ(W_o·[h,x] + b_o)
        o = torch.sigmoid(self.output_gate(combined))

        # 3. Update the cell state (long-term memory)
        # c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
        c = f * c_prev + i * c_tilde

        # 4. Expose part of it as the hidden state (short-term memory)
        # h_t = o_t ⊙ tanh(c_t)
        h = o * torch.tanh(c)
        return h, c, (f, i, c_tilde, o)


# After defining the LSTM cell, let's inspect the gates and the cell state
# after passing a sample input through the cell.
# Push one character through a freshly initialised cell and print every gate.
torch.manual_seed(0)  # fixed seed so the printed values are reproducible
cell = LSTMByHand(input_size=vocab_size, hidden_size=4)
h = torch.zeros(1, cell.hidden_size)  # h_0: blank short-term memory
c = torch.zeros(1, cell.hidden_size)  # c_0: blank long-term memory
x = one_hot('h').unsqueeze(0)         # shape (1, 4) — batch of 1
h, c, (f, i, c_tilde, o) = cell(x, h, c)
# detach() drops the autograd graph so the tensors print without grad_fn.
print("forget gate f :", f.detach().round(decimals=2))
print("input gate i :", i.detach().round(decimals=2))
print("candidate c̃ :", c_tilde.detach().round(decimals=2))
print("output gate o :", o.detach().round(decimals=2))
print("cell state c :", c.detach().round(decimals=2))
print("hidden state h :", h.detach().round(decimals=2))
# Recorded output:
# forget gate f : tensor([[0.47, 0.40, 0.51, 0.40]])
# input gate i : tensor([[0.51, 0.58, 0.37, 0.46]])
# candidate c̃ : tensor([[-0.57, -0.07, -0.13, -0.03]])
# output gate o : tensor([[0.66, 0.56, 0.51, 0.49]])
# cell state c : tensor([[-0.29, -0.04, -0.05, -0.02]])
# hidden state h : tensor([[-0.19, -0.02, -0.02, -0.01]])
class CharLSTM(nn.Module):
    """Character model: the hand-written LSTM cell unrolled over a sequence.

    At every timestep the hidden state is mapped by a linear layer to one
    score per output class.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Build the recurrent cell and the read-out layer.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Number of units in the hidden and cell states.
            output_size: Number of scores produced per timestep.
        """
        super().__init__()
        # The explicit-gate cell (LSTMByHand).
        # NOTE: created before ``fc``; seeded initial weights depend on the order.
        self.cell = LSTMByHand(input_size, hidden_size)
        # hidden state → scores over chars
        self.fc = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size

    def forward(self, inputs, return_states=False):
        """Run the cell over a whole sequence.

        Args:
            inputs: Tensor of shape (seq_len, batch, input_size).
            return_states: If true, also return the per-timestep states.

        Returns:
            ``logits`` of shape (seq_len, batch, output_size), or
            ``(logits, states)`` when ``return_states`` is true, where
            ``states[t]`` is ``(h, c, (f, i, c_tilde, o))`` for timestep ``t``.
        """
        batch = inputs.shape[1]
        # h_0 / c_0: blank short- and long-term memory. new_zeros inherits the
        # device and dtype of ``inputs``, so the model also works when it has
        # been moved off the default device (torch.zeros would not follow).
        h = inputs.new_zeros(batch, self.hidden_size)
        c = inputs.new_zeros(batch, self.hidden_size)

        logits, states = [], []
        for x in inputs:  # walk the sequence one timestep at a time
            # Carry h and c forward (this IS the recurrence).
            h, c, gates = self.cell(x, h, c)
            logits.append(self.fc(h))
            states.append((h, c, gates))

        # (seq_len, batch, output_size)
        logits = torch.stack(logits)
        return (logits, states) if return_states else logits


input_chars = text[:-1]  # "hell"
# Next-character targets: the same text shifted left by one position.
target_chars = text[1:]  # "ello"
# inputs: (seq_len, batch=1, input_size)
inputs = torch.stack([one_hot(ch) for ch in input_chars]).unsqueeze(1)
# Class indices (not one-hot vectors), as expected by cross-entropy.
targets = torch.tensor([char_to_idx[ch] for ch in target_chars])
print("inputs shape :", inputs.shape)  # (4, 1, 4)
print("targets :", targets.tolist())
# Recorded output:
# inputs shape : torch.Size([4, 1, 4])
# targets : [0, 2, 2, 3]
# Train the character model on the single input/target sequence.
torch.manual_seed(0)  # fixed seed so the loss curve below is reproducible
model = CharLSTM(input_size=vocab_size, hidden_size=8, output_size=vocab_size)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
for epoch in range(200):
    optimizer.zero_grad()
    logits = model(inputs)  # (4, 1, 4) ← forward pass
    # Flatten (seq_len, batch, vocab) to (seq_len * batch, vocab) for the loss.
    loss = loss_fn(logits.view(-1, vocab_size), targets)
    loss.backward()   # ← BPTT happens here, automatically
    optimizer.step()  # update every parameter: the four gates and the read-out layer
    if epoch % 20 == 0:
        print(f"epoch {epoch:3d} loss {loss.item():.4f}")
# Recorded output:
# epoch 0 loss 1.2731
# epoch 20 loss 0.0013
# epoch 40 loss 0.0001
# epoch 60 loss 0.0000
# epoch 80 loss 0.0000
# epoch 100 loss 0.0000
# epoch 120 loss 0.0000
# epoch 140 loss 0.0000
# epoch 160 loss 0.0000
# epoch 180 loss 0.0000
# Re-run the trained model and print every gate at every timestep.
with torch.no_grad():  # inspection only, no gradients needed
    _, states = model(inputs, return_states=True)
torch.set_printoptions(precision=2, sci_mode=False)
for t, (ch, (h, c, (f, i, c_tilde, o))) in enumerate(zip(input_chars, states)):
    # Index [0] selects the single sequence in the batch.
    print(f"\n──── timestep {t}: reading '{ch}' ────")
    print(f" forget f_t : {f[0]} (how much old memory to KEEP)")
    print(f" input i_t : {i[0]} (how much new info to WRITE)")
    print(f" cand. c̃_t : {c_tilde[0]} (the new content proposed)")
    print(f" output o_t : {o[0]} (how much memory to EXPOSE)")
    print(f" cell c_t : {c[0]} ← long-term memory after this step")
    print(f" hidden h_t : {h[0]} ← what gets passed on / predicted from")
──── timestep 0: reading 'h' ────
forget f_t : tensor([0.36, 0.64, 0.43, 0.55, 0.59, 0.74, 0.63, 0.79]) (how much old memory to KEEP)
input i_t : tensor([0.87, 0.92, 0.97, 0.95, 0.94, 0.98, 0.94, 0.40]) (how much new info to WRITE)
cand. c̃_t : tensor([-0.98, 0.96, 0.99, 1.00, 0.99, -1.00, 0.98, 0.50]) (the new content proposed)
output o_t : tensor([0.93, 0.95, 0.98, 0.96, 0.82, 0.97, 0.93, 0.28]) (how much memory to EXPOSE)
cell c_t : tensor([-0.85, 0.89, 0.96, 0.95, 0.94, -0.98, 0.91, 0.20]) ← long-term memory after this step
hidden h_t : tensor([-0.64, 0.68, 0.73, 0.71, 0.60, -0.73, 0.67, 0.05]) ← what gets passed on / predicted from
──── timestep 1: reading 'e' ────
forget f_t : tensor([0.04, 0.01, 0.00, 0.01, 1.00, 1.00, 0.05, 1.00]) (how much old memory to KEEP)
input i_t : tensor([1.00, 0.36, 0.99, 0.03, 0.95, 0.80, 0.16, 1.00]) (how much new info to WRITE)
cand. c̃_t : tensor([ 1.00, -0.63, -1.00, 0.92, 1.00, -1.00, 0.51, -1.00]) (the new content proposed)
output o_t : tensor([1.00, 0.10, 0.94, 0.00, 1.00, 0.98, 0.07, 1.00]) (how much memory to EXPOSE)
cell c_t : tensor([ 0.97, -0.22, -0.98, 0.04, 1.89, -1.78, 0.13, -0.80]) ← long-term memory after this step
hidden h_t : tensor([ 0.75, -0.02, -0.71, 0.00, 0.95, -0.92, 0.01, -0.66]) ← what gets passed on / predicted from
──── timestep 2: reading 'l' ────
forget f_t : tensor([0.50, 0.98, 0.98, 1.00, 0.33, 0.47, 0.99, 0.99]) (how much old memory to KEEP)
input i_t : tensor([1.00, 1.00, 1.00, 1.00, 1.00, 0.99, 1.00, 0.99]) (how much new info to WRITE)
cand. c̃_t : tensor([ 0.80, -1.00, -1.00, 1.00, 0.73, -0.65, -1.00, -1.00]) (the new content proposed)
output o_t : tensor([1.00, 1.00, 1.00, 0.88, 1.00, 1.00, 1.00, 1.00]) (how much memory to EXPOSE)
cell c_t : tensor([ 1.28, -1.21, -1.96, 1.04, 1.35, -1.48, -0.87, -1.78]) ← long-term memory after this step
hidden h_t : tensor([ 0.85, -0.84, -0.96, 0.69, 0.87, -0.90, -0.70, -0.94]) ← what gets passed on / predicted from
──── timestep 3: reading 'l' ────
forget f_t : tensor([0.07, 1.00, 1.00, 1.00, 0.02, 0.02, 1.00, 0.98]) (how much old memory to KEEP)
input i_t : tensor([0.99, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 0.99]) (how much new info to WRITE)
cand. c̃_t : tensor([-0.99, -1.00, -1.00, 1.00, -0.98, 0.97, -1.00, -1.00]) (the new content proposed)
output o_t : tensor([1.00, 1.00, 1.00, 0.99, 1.00, 1.00, 1.00, 0.99]) (how much memory to EXPOSE)
cell c_t : tensor([-0.89, -2.21, -2.95, 2.04, -0.96, 0.94, -1.87, -2.74]) ← long-term memory after this step
hidden h_t : tensor([-0.71, -0.98, -0.99, 0.96, -0.74, 0.74, -0.95, -0.98]) ← what gets passed on / predicted from
In the next notebook, we will train the LSTM network to predict the next word in a sentence using a simple Arabic dataset.
Comments