LSTM for Arabic Text Generation

In this notebook, we implement a simple LSTM network for Arabic text generation. We trained a simple RNN on the same data in the previous notebook; here we will see how an LSTM can give better results by capturing long-term dependencies in the data.

import torch 
import torch.nn as nn

For each training job, we start by defining the model class, where we specify the architecture of the network. The class below wraps an LSTM layer: at each timestep it takes the current input and the previous state, and produces the new hidden state \(h_t\) and the new cell state \(c_t\).

Notice that we defined an embedding layer to convert the input characters into dense vectors. Why?

  • The input is a sequence of characters, and each character is represented by an index in the vocabulary (ب → 5, ا → 2, …). These indices carry no meaning for the model; they are just numbers.
  • The embedding layer learns to map each index to a dense vector of size embed_size. It gives every character its own small learnable vector, and the model adjusts those vectors during training to capture the relationships between characters. For example, it might learn that the characters “ا” and “ب” often appear together in Arabic words, and move their vectors closer together in the embedding space.

An analogy: think of one-hot encoding as giving every character a separate, unconnected locker: locker 5 and locker 6 have no relationship. An embedding instead places each character as a point on a map, where the model is free to move related characters close together and unrelated ones far apart. Those distances are what carry meaning.
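To make the locker-versus-map picture concrete, here is a minimal sketch showing that an embedding lookup gives the same result as multiplying a one-hot vector by the embedding matrix. The toy sizes (6 characters, 3 dimensions) and the indices 5 and 2 are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, embed_size = 6, 3           # toy sizes, for illustration only
embed = nn.Embedding(vocab_size, embed_size)

idx = torch.tensor([5, 2])              # two character indices
looked_up = embed(idx)                  # row lookup -> (2, embed_size)

# Same vectors, obtained the long way: one-hot rows times the embedding matrix
one_hot = F.one_hot(idx, num_classes=vocab_size).float()    # (2, vocab_size)
via_matmul = one_hot @ embed.weight                          # (2, embed_size)

print(looked_up.shape)
print(torch.allclose(looked_up, via_matmul))
```

The lookup is simply the cheaper way to select a row, and the rows themselves are the parameters that training moves around.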

So the data flows through three clean stages: character → embedding vector → LSTM memory → score for every possible next character.

This flow is defined in the forward method of the model, which describes the actual journey of the data: what happens, in order, when we feed the model a batch of text.

  1. We start with the input x, which is just a batch of character indices of shape (batch, seq_len).

  2. We pass it through the embedding layer to get a batch of embedding vectors (batch, seq_len, embed_size).

  3. We then pass the embedding vectors through the LSTM layer. It walks through the sequence one character at a time, updating its hidden state and cell state as it goes. The first output, out, is the hidden state \(h_t\) at every timestep, with shape (batch, seq_len, hidden_size). The second output, hidden, is an (h, c) tuple holding the final hidden state and the final cell state \(c\).

  4. Finally, we take the output out, which contains the hidden state at every timestep, and pass it through a linear layer to get the scores for the next character. For each position in the sequence, we get a score for every possible next character in the vocabulary, so the final output shape is (batch, seq_len, vocab_size).

class CharGenerator(nn.Module):
    def __init__(self, vocab_size, embed_size=32, hidden_size=128, num_layers=1):
        super().__init__()
        # char index → dense vector
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size,
                            num_layers=num_layers, batch_first=True)  # ← all 4 gates, BPTT, for free
        # hidden state → score per char
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len) of character indices
        x = self.embed(x)                  # → (batch, seq_len, embed_size)

        # pass last step's full memory back in next time
        out, hidden = self.lstm(x, hidden)
        # out: (batch, seq_len, hidden_size) of hidden states at every step
        # hidden: (h, c) tuple of final hidden state and cell state

        # pass the hidden state at every step through the output layer to get a score for every possible next character
        logits = self.fc(out)              # → (batch, seq_len, vocab_size)
        return logits, hidden
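
A quick shape check on a batch of four random character indices. The model is created here with vocab_size=47, the vocabulary size of the dataset loaded below:

model = CharGenerator(vocab_size=47)
x = torch.randint(0, model.fc.out_features, (1, 4))
logits, hidden = model(x)
print("input :", x.shape)

print("h_n   :", hidden[0].shape)
print("c_n   :", hidden[1].shape)

print("logits:", logits.shape)          # (1, 4, vocab_size)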
input : torch.Size([1, 4])
h_n   : torch.Size([1, 1, 128])
c_n   : torch.Size([1, 1, 128])
logits: torch.Size([1, 4, 47])
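
The relationship between the two LSTM outputs can be checked directly. The sketch below uses a standalone single-layer, unidirectional nn.LSTM with made-up sizes (assumptions, not the notebook's model) and confirms that the last timestep of out is the same tensor as the final hidden state in hidden[0].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes, chosen only for illustration
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = torch.randn(2, 5, 8)                 # (batch, seq_len, input_size)

out, (h_n, c_n) = lstm(x)
print(out.shape)                         # hidden state at every timestep
print(h_n.shape)                         # final hidden state, one entry per layer
print(c_n.shape)                         # final cell state, one entry per layer

# For one unidirectional layer, out at the last timestep is the final hidden state
same = torch.allclose(out[:, -1, :], h_n[-1])
print(same)
```

The cell state has no per-timestep counterpart in out; only its final value is returned.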

Below, we read the data and create two dictionaries that map characters to indices and back. We then convert the entire text into a tensor of indices.

text = open("data/arabic_text.txt", encoding="utf-8").read()  # or any string

chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}    # char → index
itos = {i: c for i, c in enumerate(chars)}    # index → char

data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
print(f"text length: {len(text)},  vocab size: {vocab_size}")
text length: 3296,  vocab size: 47
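
As a small illustration of the two dictionaries, the sketch below encodes a short string and decodes it back. The sample string is a made-up stand-in, not the notebook's dataset.

```python
# Hypothetical sample text, used only to illustrate the mapping
sample = "باب دار"

chars = sorted(set(sample))                    # unique characters, in code point order
stoi = {c: i for i, c in enumerate(chars)}     # char -> index
itos = {i: c for c, i in stoi.items()}         # index -> char

encoded = [stoi[c] for c in sample]            # text -> list of indices
decoded = "".join(itos[i] for i in encoded)    # list of indices -> text

print(encoded)
print(decoded == sample)
```

Because the vocabulary is built from the text itself, any character that never appears in the training data has no index and cannot be used in a seed.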

Then, we create x and y tensors, where x is the input sequence and y is the target sequence. The target sequence is just the input sequence shifted by one character.

seq_len = 64
batch_size = 32


def get_batch():
    ix = torch.randint(0, len(data) - seq_len - 1, (batch_size,))
    x = torch.stack([data[i: i + seq_len] for i in ix])  # (batch, seq_len)
    y = torch.stack([data[i + 1: i + seq_len + 1] for i in ix])  # shifted by 1
    return x, y

x, y = get_batch()
print("x:", x.shape)  # (batch, seq_len)
print("y:", y.shape)  # (batch, seq_len)
print("x[0]:", x[0])  # example input sequence (as indices)
print("y[0]:", y[0])  # example target sequence (as indices)
x: torch.Size([32, 64])
y: torch.Size([32, 64])
x[0]: tensor([36, 32, 34,  1, 11, 34, 35, 29, 21, 31, 13,  1, 35, 36,  1, 11, 34, 24,
        21, 32,  1,  9, 34, 39,  1, 11, 34, 30, 21, 12,  1, 29, 12, 21,  1, 11,
        34, 32, 21, 38, 36,  1, 11, 34, 38, 23, 27, 39,  2,  0, 11, 34, 34, 30,
        13,  1, 11, 34, 29, 21, 12, 40, 13,  1])
y[0]: tensor([32, 34,  1, 11, 34, 35, 29, 21, 31, 13,  1, 35, 36,  1, 11, 34, 24, 21,
        32,  1,  9, 34, 39,  1, 11, 34, 30, 21, 12,  1, 29, 12, 21,  1, 11, 34,
        32, 21, 38, 36,  1, 11, 34, 38, 23, 27, 39,  2,  0, 11, 34, 34, 30, 13,
         1, 11, 34, 29, 21, 12, 40, 13,  1, 34])
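
The shift by one is easier to see on a tiny example. This sketch uses a plain Python list of made-up indices in place of the data tensor and a short window of 4; both are assumptions for illustration.

```python
# Hypothetical stand-in for the encoded text
data = list(range(10, 20))     # [10, 11, ..., 19]
seq_len = 4
i = 2                          # start position of the window

x = data[i : i + seq_len]              # input window
y = data[i + 1 : i + seq_len + 1]      # same window, shifted right by one

# At every position, the target is simply the next item of the input stream
for inp, tgt in zip(x, y):
    print(f"input {inp} -> target {tgt}")
```

The same pattern is visible in the printed tensors above: y[0] is x[0] without its first index, plus one new index at the end.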

Next, we create the model, the optimizer, and the loss function, and train for 2000 steps on random batches.

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CharGenerator(vocab_size).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
loss_fn = nn.CrossEntropyLoss()

for step in range(2000):
    x, y = get_batch()
    x, y = x.to(device), y.to(device)

    # (batch, seq_len, vocab)
    logits, _ = model(x)
    # flatten, same as before
    loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"step {step:4d}   loss {loss.item():.4f}")
step    0   loss 3.8656
step  100   loss 2.1116
step  200   loss 1.3889
step  300   loss 0.7983
step  400   loss 0.4408
step  500   loss 0.3007
step  600   loss 0.2307
step  700   loss 0.1886
step  800   loss 0.1728
step  900   loss 0.1530
step 1000   loss 0.1615
step 1100   loss 0.1398
step 1200   loss 0.1298
step 1300   loss 0.1151
step 1400   loss 0.1119
step 1500   loss 0.1140
step 1600   loss 0.1098
step 1700   loss 0.1099
step 1800   loss 0.1124
step 1900   loss 0.1100

We have defined seq_len = 64 and batch_size = 32. This means the LSTM is unrolled for 64 timesteps, and each timestep produces one output, for a total of 64 outputs per sequence. Each output passes through the linear layer, giving logits of shape (1, seq_len, vocab_size) for a single sequence and (batch_size, seq_len, vocab_size) = (32, 64, 47) for the whole batch.
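
The flattening step in the training loop can be sketched in isolation. The example below uses all-zero logits as a stand-in for a model that considers every character equally likely (an assumption for illustration); in that case the cross-entropy is exactly ln(vocab_size), which is close to the step-0 loss of 3.8656 in the log above.

```python
import math

import torch
import torch.nn as nn

batch, seq_len, vocab_size = 32, 64, 47             # sizes used in this notebook

# Stand-in for model output: equal scores for every character
logits = torch.zeros(batch, seq_len, vocab_size)
y = torch.randint(0, vocab_size, (batch, seq_len))  # random targets

# CrossEntropyLoss expects (N, classes) and (N,), so batch and time are merged
flat_logits = logits.view(-1, vocab_size)           # (batch * seq_len, vocab_size)
flat_y = y.view(-1)                                 # (batch * seq_len,)

loss = nn.CrossEntropyLoss()(flat_logits, flat_y)
print(flat_logits.shape, flat_y.shape)
print(round(loss.item(), 4), round(math.log(vocab_size), 4))
```

Every one of the 32 × 64 positions counts as an independent prediction, and the loss is their average.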

Generating Text

After training the model, the code below generates new text. We start with a seed text: each character of the seed is converted to its index, giving an input tensor x of shape (1, seed_len).

Here are the detailed steps of the generation process:

  1. We start with the seed text, convert it to indices, and create the input tensor x.
  2. We take only the last character of the seed and pass it through the model, starting from an empty hidden state. The earlier seed characters appear in the printed result but are never fed to the model.
  3. The model therefore receives a single character as input and outputs the logits for the next character, together with the updated hidden state.
  4. We divide the logits by a temperature parameter, then apply softmax to get a probability for each possible next character.
  5. We sample one character from those probabilities, append it to the result, and feed it back into the model with the hidden state as the next input.

The temperature controls how sharp the sampling distribution is:

  • temperature < 1 (e.g. 0.5) → dividing by a small number stretches the gaps between logits → softmax becomes sharper → the top character dominates → safe, confident, repetitive.
  • temperature > 1 (e.g. 2.0) → dividing by a big number compresses the gaps between logits → softmax becomes flatter → probabilities more equal → adventurous, varied, more typos.
  • temperature = 1 → divide by 1 → no change → the model's raw distribution.
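
The effect of temperature can be sketched in a few lines of plain Python. The three logits are made-up values, and softmax_with_temperature is a hypothetical helper written for this illustration rather than part of the notebook.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide the logits by the temperature, then apply a numerically stable softmax."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                             # subtract the max to avoid overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]                        # made-up scores for three characters

for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures push more probability onto the highest-scoring character, while higher temperatures spread it more evenly across all three.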
temperature = 0.8

seed = "الموج"                                            # any starting text
x = torch.tensor([[stoi[c] for c in seed]], device=device)  # keep the input on the model's device
print(f"seed -> x shape:{x.shape}, values: {x}, each char maps to an index")
last = x[:, -1].view(1, 1)  # start with just the last char
print(f"last char index: {last.item()}, char: '{itos[last.item()]}'")

result = list(seed)
hidden = None

with torch.no_grad():                           # sampling needs no gradients
    for _ in range(300):
        logits, hidden = model(last, hidden)        # feed prev char + memory
        # logits: (1, 1, vocab_size) → we only care about the last position
        last_logits = logits[0, -1]           # just the LAST position → shape (47,)

        # 1. scores → probabilities
        probs = torch.softmax(last_logits / temperature, dim=-1)   # (47,) summing to 1

        # 2. draw ONE character, randomly, weighted by those probabilities
        next_idx = torch.multinomial(probs, num_samples=1)         # e.g. tensor([12])

        # 3. turn the index back into a character
        result.append(itos[next_idx.item()])
        last = next_idx.view(1, 1)  # ← the new char becomes next input

print("last char shape:", last.shape)
print("last hidden shape:", hidden[0].shape)
print(f"last logits shape: {last_logits.shape}, and the logits tensor: {logits.shape}")
print("".join(result))
seed -> x shape:torch.Size([1, 5]), values: tensor([[11, 34, 35, 38, 16]]), each char maps to an index
last char index: 16, char: 'ج'
last char shape: torch.Size([1, 1])
last hidden shape: torch.Size([1, 1, 128])
last logits shape: torch.Size([47]), and the logits tensor: torch.Size([1, 1, 47])
الموجديدة، فتن يستخدمه.
مستقبل الإنسانية مرهون بقدرتها على استخدام التكنولوجيا بحكمة ومسؤولية.

في زمن الغربة، تصبح الذكريات وطنًا والحنين دينًا والأمل قبلة.
الغريب يحمل وطنه في قلبه حيثما سار، وعلى شفتيه اسم أحبابه دائمًا.
المسافة لا تُبعد من تحبهم، بل تجعل الشوق أشد والمحبة أعمق وأصدق.
الوطن ليس مجرد أ
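
The generation loop above feeds only the last seed character into an empty hidden state, so the rest of the seed does not influence the model. One possible variant, sketched here under assumptions, passes the whole seed through the model first to warm up the hidden state. TinyCharLSTM is a hypothetical untrained stand-in with the same forward interface as CharGenerator, and generate_primed is an illustrative helper, not code from this notebook.

```python
import torch
import torch.nn as nn


class TinyCharLSTM(nn.Module):
    """Hypothetical stand-in with the same forward interface as CharGenerator."""

    def __init__(self, vocab_size=10, embed_size=8, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(self.embed(x), hidden)
        return self.fc(out), hidden


@torch.no_grad()
def generate_primed(model, seed_ids, n_new, temperature=0.8):
    """Run the full seed through the model once, then sample n_new indices."""
    x = torch.tensor([seed_ids])                 # (1, len(seed_ids))
    logits, hidden = model(x)                    # hidden now summarizes the whole seed
    ids = list(seed_ids)
    for _ in range(n_new):
        probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
        next_idx = torch.multinomial(probs, num_samples=1)
        ids.append(next_idx.item())
        logits, hidden = model(next_idx.view(1, 1), hidden)   # one char at a time
    return ids


torch.manual_seed(0)
toy_model = TinyCharLSTM()
generated = generate_primed(toy_model, seed_ids=[1, 4, 2], n_new=5)
print(len(generated))
```

With the trained model, the same helper could be called with the seed converted through stoi, provided the input tensor is created on the model's device.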
