```python
import torch
import torch.nn as nn
```

In this notebook, we will implement a simple LSTM network for Arabic text generation. In the previous notebook we trained a simple RNN on the same data; here we will see how an LSTM can give better results by capturing long-term dependencies in the data.
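As in every training job, we start by defining the model class, which specifies the architecture of the network. The class below combines an embedding layer, PyTorch's built-in nn.LSTM and a linear output layer. It takes a batch of character indices (and, optionally, the previous hidden state) and returns scores for the next character together with the new hidden state \(h_t\) and cell state \(c_t\).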
Notice that we defined an embedding layer to convert the input characters into dense vectors. Why?

- The input is a sequence of characters, and each character is represented by an index in the vocabulary (ب → 5, ا → 2…). These indices carry no meaning for the model; they are just numbers.
- The embedding layer learns to map each index to a dense vector of a specified size (embed_size). It gives every character its own small learnable vector, and the model adjusts those vectors during training to capture the relationships between characters. For example, if characters such as "ا" and "ب" tend to appear in similar contexts in Arabic words, the model may move their vectors closer together in the embedding space.
An analogy: think of one-hot encoding as giving every character a separate, unconnected locker; locker 5 and locker 6 have no relationship. An embedding instead places each character as a point on a map, and the model is free to move related characters close together and unrelated ones far apart. Those distances are what carry meaning.
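The locker-versus-map picture can be checked directly. An embedding lookup gives the same result as multiplying a one-hot vector by a learnable weight matrix; the embedding layer simply skips building the one-hot vectors. Below is a small sketch using the notebook's sizes (vocab_size = 47, embed_size = 32). The two input sequences are arbitrary indices chosen for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, embed_size = 47, 32
embed = nn.Embedding(vocab_size, embed_size)

# a batch of 2 sequences, 4 character indices each: (batch, seq_len)
idx = torch.tensor([[11, 34, 35, 38],
                    [12, 21, 13, 1]])

vectors = embed(idx)            # table lookup -> (batch, seq_len, embed_size)
print(vectors.shape)            # torch.Size([2, 4, 32])

# same result: one-hot "lockers" multiplied by the learnable weight matrix
one_hot = F.one_hot(idx, num_classes=vocab_size).float()   # (2, 4, 47)
via_matmul = one_hot @ embed.weight                        # (2, 4, 32)
print(torch.allclose(vectors, via_matmul))                 # True
```

The lookup is preferred in practice because it avoids building and multiplying mostly-zero vectors of length vocab_size.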
So the data flows in three clean stages: character → embedding vector → LSTM memory → score for every possible next character.
This flow is defined in the forward method of the model, which describes the actual journey of the data: what happens, in order, when we feed the model a batch of text.
1. We start with the input x, which is just a batch of character indices of shape (batch, seq_len).
2. We pass it through the embedding layer to get a batch of embedding vectors of shape (batch, seq_len, embed_size).
3. We then pass the embedding vectors through the LSTM layer. It walks through the sequence one character at a time, updating its hidden state and cell state as it goes, and returns two things. The first, out, is the hidden state \(h_t\) at every time step, with shape (batch, seq_len, hidden_size). The second, hidden, is an (h, c) tuple holding the final hidden state and the final cell state \(c\), each of shape (num_layers, batch, hidden_size). batch_first=True changes the layout of the input and of out, but not of h and c.
4. Finally, we pass out through a linear layer to get the scores for the next character. For each position in the sequence we get a score for every character in the vocabulary, so the final output shape is (batch, seq_len, vocab_size).
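Step 3 can be verified with nn.LSTM alone. Running it over a whole sequence gives the same outputs as feeding one character at a time while passing (h, c) back in, and the last output equals the final hidden state. Below is a sketch with random tensors standing in for five embedded characters; the sizes match this notebook's defaults:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size, hidden_size = 32, 128
lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)

x = torch.randn(1, 5, embed_size)      # stand-in for 5 embedded characters

# 1) the whole sequence at once
out, (h_n, c_n) = lstm(x)
print(out.shape, h_n.shape, c_n.shape)
# torch.Size([1, 5, 128]) torch.Size([1, 1, 128]) torch.Size([1, 1, 128])

# 2) one character at a time, carrying (h, c) forward by hand
hidden = None
step_outputs = []
for t in range(x.size(1)):
    o, hidden = lstm(x[:, t:t + 1], hidden)   # slice keeps shape (1, 1, embed_size)
    step_outputs.append(o)
stepwise_out = torch.cat(step_outputs, dim=1)  # (1, 5, hidden_size)

print(torch.allclose(out, stepwise_out, atol=1e-5))  # True
print(torch.allclose(out[:, -1], h_n[-1]))           # True: last output == final h
```

This equivalence is what lets a generation loop feed one character per call and still keep the full memory.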
```python
class CharGenerator(nn.Module):
    def __init__(self, vocab_size, embed_size=32, hidden_size=128, num_layers=1):
        super().__init__()
        # char index → dense vector
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size,
                            num_layers=num_layers, batch_first=True)  # ← all 4 gates, BPTT, for free
        # hidden state → score per char
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len) of character indices
        x = self.embed(x)                    # → (batch, seq_len, embed_size)
        # pass last step's full memory (h, c) back in next time
        out, hidden = self.lstm(x, hidden)
        # out: (batch, seq_len, hidden_size) of hidden states at every step
        # hidden: (h, c) tuple of final hidden state and cell state
        # pass the hidden state at every step through the output layer
        # to get a score for every possible next character
        logits = self.fc(out)                # → (batch, seq_len, vocab_size)
        return logits, hidden
```

```
input : torch.Size([1, 4])
h_n : torch.Size([1, 1, 128])
c_n : torch.Size([1, 1, 128])
logits: torch.Size([1, 4, 47])
```
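You can reproduce these shapes by pushing a dummy batch through an untrained model. The sketch below assumes vocab_size = 47 (the vocabulary size found below) and a random input of 4 character indices. The class is repeated in compact form so the snippet runs on its own:

```python
import torch
import torch.nn as nn

class CharGenerator(nn.Module):  # same architecture as above, repeated for self-containment
    def __init__(self, vocab_size, embed_size=32, hidden_size=128, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(self.embed(x), hidden)
        return self.fc(out), hidden

vocab_size = 47                                 # assumption: matches the data below
model = CharGenerator(vocab_size)
x = torch.randint(0, vocab_size, (1, 4))        # one dummy sequence of 4 indices
logits, (h_n, c_n) = model(x)

print("input :", x.shape)        # torch.Size([1, 4])
print("h_n :", h_n.shape)        # torch.Size([1, 1, 128])
print("c_n :", c_n.shape)        # torch.Size([1, 1, 128])
print("logits:", logits.shape)   # torch.Size([1, 4, 47])
```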
Below, we read the data, create two dictionaries that map characters to indices and back, and then convert the entire text into a tensor of indices.
```python
text = open("data/arabic_text.txt", encoding="utf-8").read()  # or any string
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}  # char → index
itos = {i: c for i, c in enumerate(chars)}  # index → char
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
print(f"text length: {len(text)}, vocab size: {vocab_size}")
```

```
text length: 3296, vocab size: 47
```
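The two dictionaries are inverses of each other, so encoding the text and then decoding it should return the original string. Below is a small sketch on a toy Arabic string that stands in for data/arabic_text.txt. The helpers encode and decode are hypothetical names introduced here; they are not part of the notebook:

```python
import torch

text = "اللغة العربية جميلة."                  # toy stand-in for the training file
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}   # char -> index
itos = {i: c for i, c in enumerate(chars)}   # index -> char

def encode(s):
    """Hypothetical helper: string -> list of indices."""
    return [stoi[c] for c in s]

def decode(ids):
    """Hypothetical helper: list of indices -> string."""
    return "".join(itos[i] for i in ids)

data = torch.tensor(encode(text), dtype=torch.long)
print(len(chars), data.shape)
print(decode(data.tolist()) == text)         # True: the mapping is lossless
```

Because the vocabulary is built from the training text, any character that never appears in it (for example, in a generation seed) has no entry in stoi and raises a KeyError.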
Then, we create x and y tensors, where x is the input sequence and y is the target sequence. The target sequence is just the input sequence shifted by one character.
```python
seq_len = 64
batch_size = 32

def get_batch():
    # pick batch_size random start positions, leaving room for the shifted target
    ix = torch.randint(0, len(data) - seq_len - 1, (batch_size,))
    x = torch.stack([data[i: i + seq_len] for i in ix])          # (batch, seq_len)
    y = torch.stack([data[i + 1: i + seq_len + 1] for i in ix])  # shifted by 1
    return x, y

x, y = get_batch()
print("x:", x.shape)    # (batch, seq_len)
print("y:", y.shape)    # (batch, seq_len)
print("x[0]:", x[0])    # example input sequence (as indices)
print("y[0]:", y[0])    # example target sequence (as indices)
```

```
x: torch.Size([32, 64])
y: torch.Size([32, 64])
x[0]: tensor([36, 32, 34, 1, 11, 34, 35, 29, 21, 31, 13, 1, 35, 36, 1, 11, 34, 24,
        21, 32, 1, 9, 34, 39, 1, 11, 34, 30, 21, 12, 1, 29, 12, 21, 1, 11,
        34, 32, 21, 38, 36, 1, 11, 34, 38, 23, 27, 39, 2, 0, 11, 34, 34, 30,
        13, 1, 11, 34, 29, 21, 12, 40, 13, 1])
y[0]: tensor([32, 34, 1, 11, 34, 35, 29, 21, 31, 13, 1, 35, 36, 1, 11, 34, 24, 21,
        32, 1, 9, 34, 39, 1, 11, 34, 30, 21, 12, 1, 29, 12, 21, 1, 11, 34,
        32, 21, 38, 36, 1, 11, 34, 38, 23, 27, 39, 2, 0, 11, 34, 34, 30, 13,
        1, 11, 34, 29, 21, 12, 40, 13, 1, 34])
```
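You can check the one-character shift mechanically: apart from the last column, every target row equals its input row moved one position to the left. The sketch below uses a toy data = torch.arange(100) in place of the encoded text, so each index is easy to read. It also uses smaller seq_len and batch_size values purely for readability:

```python
import torch

torch.manual_seed(0)

data = torch.arange(100)        # toy "text": index i stands for the i-th character
seq_len, batch_size = 8, 4

def get_batch():
    ix = torch.randint(0, len(data) - seq_len - 1, (batch_size,))
    x = torch.stack([data[i: i + seq_len] for i in ix])
    y = torch.stack([data[i + 1: i + seq_len + 1] for i in ix])
    return x, y

x, y = get_batch()
# target position t holds the character that follows input position t
print(torch.equal(y[:, :-1], x[:, 1:]))   # True
print(torch.equal(y, x + 1))              # True here only because data is arange
```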
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CharGenerator(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
loss_fn = nn.CrossEntropyLoss()

for step in range(2000):
    x, y = get_batch()
    x, y = x.to(device), y.to(device)
    # (batch, seq_len, vocab)
    logits, _ = model(x)
    # flatten batch and time into one axis, same as before
    loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"step {step:4d} loss {loss.item():.4f}")
```

```
step    0 loss 3.8656
step  100 loss 2.1116
step  200 loss 1.3889
step  300 loss 0.7983
step  400 loss 0.4408
step  500 loss 0.3007
step  600 loss 0.2307
step  700 loss 0.1886
step  800 loss 0.1728
step  900 loss 0.1530
step 1000 loss 0.1615
step 1100 loss 0.1398
step 1200 loss 0.1298
step 1300 loss 0.1151
step 1400 loss 0.1119
step 1500 loss 0.1140
step 1600 loss 0.1098
step 1700 loss 0.1099
step 1800 loss 0.1124
step 1900 loss 0.1100
```
We have defined seq_len = 64 and batch_size = 32. This means the LSTM is unrolled for 64 time steps and produces one output per time step, so each sequence yields 64 outputs. Each output passes through the linear layer. This gives logits of shape (1, seq_len, vocab_size) for a single sequence, or (batch_size, seq_len, vocab_size) = (32, 64, 47) for a whole batch.
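The loss is computed at every position, so the training loop flattens the batch and time dimensions before calling CrossEntropyLoss. The 32 sequences × 64 positions become 2,048 independent next-character predictions. The sketch below uses random tensors, not the trained model, to show this flattening. It also shows an equivalent form: CrossEntropyLoss accepts logits of shape (batch, classes, seq_len) with targets of shape (batch, seq_len).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, vocab_size = 32, 64, 47
logits = torch.randn(batch_size, seq_len, vocab_size)     # stand-in model output
y = torch.randint(0, vocab_size, (batch_size, seq_len))   # stand-in targets
loss_fn = nn.CrossEntropyLoss()

# what the training loop does: one row per (sequence, position) pair
flat = loss_fn(logits.view(-1, vocab_size), y.view(-1))   # (2048, 47) vs (2048,)

# equivalent: move the class dimension to axis 1 instead of flattening
permuted = loss_fn(logits.permute(0, 2, 1), y)            # (32, 47, 64) vs (32, 64)

print(logits.view(-1, vocab_size).shape)   # torch.Size([2048, 47])
print(torch.allclose(flat, permuted))      # True
```

Both forms average the loss over all 2,048 predictions. Flattening is the more common idiom, while the permuted form avoids the reshape of the targets.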
After training the model, we use the code below to generate new text, starting from a seed text. To build the input x, we look up each character of the seed in stoi and wrap the indices in a tensor of shape (1, len(seed)), i.e. a batch containing one sequence.
Here are the detailed steps of the generation process; the numbered steps are commented in the code. First, the temperature parameter divides the logits before the softmax:

- temperature < 1 (e.g. 0.5) → dividing by a small number stretches the gaps between scores → the softmax becomes sharper → the top character dominates → safe, confident, repetitive text.
- temperature > 1 (e.g. 2.0) → dividing by a large number compresses the gaps → the softmax becomes flatter → the probabilities become more equal → adventurous, varied text with more typos.
- temperature = 1 → dividing by 1 changes nothing → the model's raw distribution.

```python
temperature = 0.8
seed = "الموج"  # any starting text
x = torch.tensor([[stoi[c] for c in seed]], device=device)  # same device as the model
print(f"seed -> x shape:{x.shape}, values: {x}, each char maps to an index")
last = x[:, -1].view(1, 1)  # start with just the last char
print(f"last char index: {last.item()}, char: '{itos[last.item()]}'")
result = list(seed)
hidden = None
model.eval()
with torch.no_grad():  # no gradients are needed while sampling
    for _ in range(300):
        logits, hidden = model(last, hidden)  # feed prev char + memory
        # logits: (1, 1, vocab_size); we only care about the last position
        last_logits = logits[0, -1]  # just the LAST position → shape (47,)
        # 1. scores → probabilities
        probs = torch.softmax(last_logits / temperature, dim=-1)  # (47,) summing to 1
        # 2. draw ONE character, randomly, weighted by those probabilities
        next_idx = torch.multinomial(probs, num_samples=1)  # e.g. tensor([12])
        # 3. turn the index back into a character
        result.append(itos[next_idx.item()])
        last = next_idx.view(1, 1)  # ← the new char becomes next input
print("last char shape:", last.shape)
print("last hidden shape:", hidden[0].shape)
print(f"last logits shape: {last_logits.shape}, and the logits tensor: {logits.shape}")
print("".join(result))
```

```
seed -> x shape:torch.Size([1, 5]), values: tensor([[11, 34, 35, 38, 16]]), each char maps to an index
last char index: 16, char: 'ج'
last char shape: torch.Size([1, 1])
last hidden shape: torch.Size([1, 1, 128])
last logits shape: torch.Size([47]), and the logits tensor: torch.Size([1, 1, 47])
```
ุงูู
ูุฌุฏูุฏุฉุ ูุชู ูุณุชุฎุฏู
ู.
ู
ุณุชูุจู ุงูุฅูุณุงููุฉ ู
ุฑููู ุจูุฏุฑุชูุง ุนูู ุงุณุชุฎุฏุงู
ุงูุชูููููุฌูุง ุจุญูู
ุฉ ูู
ุณุคูููุฉ.
ูู ุฒู
ู ุงูุบุฑุจุฉุ ุชุตุจุญ ุงูุฐูุฑูุงุช ูุทููุง ูุงูุญููู ุฏูููุง ูุงูุฃู
ู ูุจูุฉ.
ุงูุบุฑูุจ ูุญู
ู ูุทูู ูู ููุจู ุญูุซู
ุง ุณุงุฑุ ูุนูู ุดูุชูู ุงุณู
ุฃุญุจุงุจู ุฏุงุฆู
ูุง.
ุงูู
ุณุงูุฉ ูุง ุชูุจุนุฏ ู
ู ุชุญุจูู
ุ ุจู ุชุฌุนู ุงูุดูู ุฃุดุฏ ูุงูู
ุญุจุฉ ุฃุนู
ู ูุฃุตุฏู.
ุงููุทู ููุณ ู
ุฌุฑุฏ ุฃ
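The loop above starts from only the last seed character with hidden=None, so the earlier seed characters never reach the LSTM's memory. Below is a minimal sketch of a variant that fixes this. It first runs every seed character except the last through the model to warm up (h, c), then samples with temperature as before. A few parts exist only for illustration. generate is a hypothetical helper name. The toy vocabulary and untrained model only let the snippet run on its own. The temperature demo uses made-up logits.

```python
import torch
import torch.nn as nn

class CharGenerator(nn.Module):  # same architecture as above, repeated for self-containment
    def __init__(self, vocab_size, embed_size=32, hidden_size=128, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(self.embed(x), hidden)
        return self.fc(out), hidden

@torch.no_grad()
def generate(model, seed, stoi, itos, n_chars=300, temperature=0.8):
    """Hypothetical helper: warm up the LSTM on the whole seed, then sample."""
    model.eval()
    device = next(model.parameters()).device
    x = torch.tensor([[stoi[c] for c in seed]], device=device)  # (1, len(seed))
    hidden = None
    if x.size(1) > 1:
        _, hidden = model(x[:, :-1], hidden)   # read all seed chars except the last
    last = x[:, -1:]                           # (1, 1): the last seed char
    result = list(seed)
    for _ in range(n_chars):
        logits, hidden = model(last, hidden)   # memory now reflects the full seed
        probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
        next_idx = torch.multinomial(probs, num_samples=1)
        result.append(itos[next_idx.item()])
        last = next_idx.view(1, 1)
    return "".join(result)

# temperature in isolation: the top probability shrinks as temperature grows
toy_logits = torch.tensor([2.0, 1.0, 0.1])
top_probs = [torch.softmax(toy_logits / t, dim=-1).max().item() for t in (0.5, 1.0, 2.0)]
print(top_probs)

# usage with an untrained model and a toy vocabulary (illustration only)
chars = sorted(set("الموج جديدة"))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
torch.manual_seed(0)
sample = generate(CharGenerator(len(chars)), "الموج", stoi, itos, n_chars=20)
print(len(sample))  # 25: the 5 seed characters plus 20 sampled ones
```

Warming up with one call on x[:, :-1] relies on the LSTM producing the same (h, c) whether it reads the seed in one pass or one character at a time. This keeps the sampling loop itself unchanged.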