Demystifying Transformers: Building a Toy Model and Understanding the Landscape of Modern NLP Tools


Transformers have revolutionized the field of natural language processing (NLP), enabling advancements in tasks ranging from machine translation to text generation. In this article, we'll embark on a journey to build a simple transformer model from scratch, helping you understand the inner workings of this powerful architecture. But our exploration doesn’t stop there. We'll also delve into the broader landscape of NLP tools, highlighting when to use transformers, sentence transformers, and large language models (LLMs).

Understanding the Basics: What is a Transformer?

Before we dive into the code, let’s briefly revisit what a transformer is. The transformer architecture was introduced in the 2017 landmark research paper "Attention Is All You Need," which fundamentally changed how we approach sequence-to-sequence tasks in NLP. Unlike previous models that relied on recurrent neural networks (RNNs), transformers utilize self-attention mechanisms to process data more efficiently, capturing relationships between all elements in a sequence simultaneously.


Figure: the main components of the transformer architecture, as illustrated in the original "Attention Is All You Need" paper.


Building a Toy Transformer Model

To truly understand transformers, there's no better way than to build one yourself. We'll walk through the process of building a toy transformer model using PyTorch, starting with the basics, gradually introducing more complex concepts, and finally seeing how the model performs at generating short, meaningful sentences.

Defining the Transformer’s Agenda: Data, Expectations, and Strategy

Before diving into the technical details, it's important to understand what we aim to achieve with our transformer model. In this project, we’ll be working with simple, structured sentences composed of four words each. These sentences cover various natural phenomena, such as "The sun rises slowly" or "The fire burns brightly." The goal of our transformer is to learn the patterns within these sentences and generate coherent sequences when given a prompt like "The sun."

In the interest of keeping the model simple and focused, we have deliberately chosen these straightforward 4-word sentences. This simplicity allows the model to effectively capture and learn the relationships between words without being overwhelmed by the complexity of longer or more varied sentences.

Our transformer model itself is designed to balance simplicity and effectiveness, with a dimensionality of 256, 6 layers, and 8 attention heads. These parameters ensure that the model has sufficient capacity to learn and generate meaningful sequences while remaining accessible for educational purposes.

# Model configuration parameters
d_model = 256  # Dimensionality of the model (size of the embedding vectors)
d_ff = 1024    # Dimensionality of the feed-forward layers
num_heads = 8  # Number of attention heads in the multi-head attention mechanism
num_layers = 6 # Number of layers in the encoder and decoder
dropout_rate = 0.3  # Dropout rate for regularization        

To coax the transformer into generating the desired outputs, we’ll train it on these structured sentences, ensuring that it learns the relationships between the words. By carefully designing the training data and employing techniques like beam search during inference, we guide the model to predict the most likely sequence of words, reinforcing its ability to produce meaningful sentences based on the learned patterns. This structured approach sets the stage for a successful and insightful exploration of how transformers work.

Setting Up the Environment

The first step in any machine learning project is to set up the environment. In our code, we use PyTorch, a popular deep learning framework, to build and train our transformer model. We begin by setting a random seed to ensure that our results are reproducible. We also check if a GPU is available; if it is, we use it to accelerate training.

torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'        

Creating the Vocabulary

Before we can train our model, we need to convert our text data into a format that the model can understand. This involves creating a vocabulary, which is a mapping of each word in our dataset to a unique integer ID. The function create_vocab handles this by processing all the sentences in our dataset, removing any punctuation, and assigning each unique word an ID.

def create_vocab(sentences):
    vocab = {}
    for sentence in sentences:
        for word in sentence.lower().split():
            word = re.sub(r'[^\w\s]', '', word)
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab        
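
As a quick sanity check, here's what the mapping looks like for the first two training sentences (an illustrative example, assuming re has been imported and the create_vocab function above is defined; the exact IDs simply follow the order in which words first appear):

sample = ["The sun rises slowly.", "The sun sets quietly."]
print(create_vocab(sample))
# {'the': 0, 'sun': 1, 'rises': 2, 'slowly': 3, 'sets': 4, 'quietly': 5}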

Encoding the Sentences

Once we have our vocabulary, we need to convert our sentences into tensors of word indices. These tensors are the inputs that our model will learn from. The encode_sentences function takes care of this by converting each word in the sentence into its corresponding index from the vocabulary.

def encode_sentences(sentences, vocab):
    encoded = []
    for sentence in sentences:
        tokens = []
        for word in sentence.lower().split():
            word = re.sub(r'[^\w\s]', '', word)
            tokens.append(vocab[word])
        encoded.append(torch.tensor(tokens))
    return encoded        
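
Continuing the small example from above (again an illustrative sketch, assuming the two functions defined so far plus import torch and import re), the two sentences encode into tensors of vocabulary indices:

sample = ["The sun rises slowly.", "The sun sets quietly."]
vocab = create_vocab(sample)
print(encode_sentences(sample, vocab))
# [tensor([0, 1, 2, 3]), tensor([0, 1, 4, 5])]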

Understanding the Transformer Model

Transformers are built from several key components: multi-head attention, feed-forward networks, and positional encoding. These components work together to allow the model to understand the relationships between words in a sentence, regardless of their position.

Positional Encoding

Transformers don’t inherently understand the order of words, so we need to add some information about word positions. The PositionalEncoding layer adds this information by applying sine and cosine functions to create unique positional encodings for each word in the sequence.

class PositionalEncoding(nn.Module):
    ...        

Multi-Head Attention

The MultiHeadAttention layer allows the model to focus on different parts of the input sentence when making predictions. This is crucial for understanding context, as it enables the model to consider how each word relates to every other word in the sentence.

class MultiHeadAttention(nn.Module):
    ...        
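
At the heart of every attention head is scaled dot-product attention; the full multi-head version appears in the complete code further down. Here is a minimal, standalone sketch of just that core computation, for query, key, and value tensors of shape (batch, sequence length, depth):

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # Similarity between every query and every key, scaled by the square root of the depth
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # block masked positions
    weights = F.softmax(scores, dim=-1)  # attention weights sum to 1 for each query
    return torch.matmul(weights, v)      # weighted sum of the values

# Tiny example: batch of 1, sequence of 4 tokens, depth of 8
q = k = v = torch.randn(1, 4, 8)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([1, 4, 8])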

Feed-Forward Networks and Layer Normalization

Each transformer layer includes a feed-forward network, which processes the information from the attention layer and helps the model make more complex predictions. The output is then normalized using a technique called layer normalization, which helps stabilize the learning process.

class FeedForward(nn.Module):
    ...        

Building the Transformer Model

Our transformer model is composed of multiple encoder and decoder blocks, each of which contains the layers we’ve just discussed. The encoder processes the input sentence, while the decoder generates predictions word by word. The final output is produced by a linear layer that projects the model's predictions onto the vocabulary.

class Transformer(nn.Module):
    ...        

Training the Model

Training a transformer involves feeding it input sentences and their corresponding target sentences. The model learns to predict the next word in the sequence by minimizing the difference between its predictions and the actual target words. We use a simple two-stage schedule: the model first trains on the full set of training sentences, then switches to a smaller set of fine-tuning sentences halfway through training to reinforce the specific patterns we want it to reproduce.

def train_transformer(model, input_tensors, target_tensors, fine_tuning_tensors=None, num_epochs=500):
    ...        

Making Predictions with Beam Search

After training, we can use the model to generate predictions. Beam search is a technique that helps the model find the most likely sequence of words by considering multiple possible continuations at each step. This helps ensure that the model generates coherent and contextually appropriate sentences.

def predict_two_words_beam_search_full_context_dynamic(model, prompt, beam_width=5, repetition_penalty=5.0):
    ...        
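
The full, model-specific implementation appears in the complete code below. To see the algorithm in isolation, here is a minimal, model-agnostic sketch of beam search over next-word log-probabilities; next_log_probs is a hypothetical stand-in for the transformer's output distribution:

import math

def beam_search(next_log_probs, start_tokens, steps=2, beam_width=3):
    # Each beam is a (token sequence, cumulative log-probability) pair
    beams = [(list(start_tokens), 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            # next_log_probs(tokens) returns a dict of candidate word -> log-probability
            for word, logp in next_log_probs(tokens).items():
                candidates.append((tokens + [word], score + logp))
        # Keep only the highest-scoring sequences for the next step
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
    return beams[0][0]  # best-scoring sequence

# Toy usage with a hard-coded "language model"
table = {
    ("the", "sun"): {"sets": math.log(0.6), "rises": math.log(0.4)},
    ("the", "sun", "sets"): {"quietly": math.log(0.9), "loudly": math.log(0.1)},
    ("the", "sun", "rises"): {"slowly": math.log(0.8), "quietly": math.log(0.2)},
}
print(beam_search(lambda t: table.get(tuple(t), {}), ["the", "sun"]))
# ['the', 'sun', 'sets', 'quietly']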

Complete Code

Here's the code in full, with helpful comments throughout, for you to understand and even try it out on your system.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import re
import os

# Set up the environment by setting a random seed for reproducibility and choosing the device (GPU if available)
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Flags to control saving and loading of the model
SAVE_MODEL = True
LOAD_MODEL = False
MODEL_PATH = "toy_transformer_model_1.pth"

# Training sentences, each with exactly 4 words, to train the model
training_sentences = [
    "The sun rises slowly.",
    "The sun sets quietly.",
    "The fire crackles loudly.",
    "The fire burns steadily.",
    "The bird sings sweetly.",
    "The bird flies high.",
    "The bell tolls loudly.",
    "The bell rings softly.",
    "The wind blows gently.",
    "The wind howls fiercely.",
    "The river flows smoothly.",
    "The river rushes swiftly."
]

# Fine-tuning sentences, also with 4 words, to refine the model’s understanding of these specific patterns
fine_tuning_sentences = [
    "The sun sets quietly.",
    "The fire burns brightly.",
    "The bird sings sweetly.",
    "The bell rings softly.",
    "The wind blows gently.",
    "The river flows smoothly."
]

def create_vocab(sentences):
    """
    Create a vocabulary from a list of sentences.

    Args:
    sentences (list of str): List of sentences from which to create the vocabulary.

    Returns:
    vocab (dict): Dictionary mapping each word to a unique integer ID.
    """
    vocab = {}
    for sentence in sentences:
        for word in sentence.lower().split():
            word = re.sub(r'[^\w\s]', '', word)  # Remove punctuation
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab

# Generate vocabulary and inverse vocabulary for decoding
vocab = create_vocab(training_sentences + fine_tuning_sentences)
inv_vocab = {v: k for k, v in vocab.items()}

def encode_sentences(sentences, vocab):
    """
    Encode sentences into tensors of word indices based on the provided vocabulary.

    Args:
    sentences (list of str): List of sentences to encode.
    vocab (dict): Vocabulary mapping words to indices.

    Returns:
    encoded (list of torch.Tensor): List of encoded sentences as tensors of word indices.
    """
    encoded = []
    for sentence in sentences:
        tokens = []
        for word in sentence.lower().split():
            word = re.sub(r'[^\w\s]', '', word)  # Remove punctuation
            tokens.append(vocab[word])
        encoded.append(torch.tensor(tokens))
    return encoded

# Encode the training and fine-tuning sentences into tensors
input_tensors = encode_sentences(training_sentences, vocab)
target_tensors = encode_sentences(training_sentences, vocab)

# Model configuration parameters
d_model = 256  # Dimensionality of the model (size of the embedding vectors)
d_ff = 1024    # Dimensionality of the feed-forward layers
num_heads = 8  # Number of attention heads in the multi-head attention mechanism
num_layers = 6 # Number of layers in the encoder and decoder
dropout_rate = 0.3  # Dropout rate for regularization

# Positional Encoding layer to add positional information to the input embeddings
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=50):
        """
        Initialize the positional encoding layer.

        Args:
        d_model (int): Dimensionality of the embeddings.
        max_len (int): Maximum length of the input sequences.
        """
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Forward pass for positional encoding.

        Args:
        x (torch.Tensor): Input tensor to which positional encoding will be added.

        Returns:
        torch.Tensor: Positional encoded tensor.
        """
        return x + self.pe[:, :x.size(1), :]

# Multi-Head Attention layer that allows the model to focus on different parts of the input sequence
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        Initialize the multi-head attention layer.

        Args:
        d_model (int): Dimensionality of the embeddings.
        num_heads (int): Number of attention heads.
        """
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0  # Ensure the dimensionality is divisible by the number of heads
        self.depth = d_model // num_heads

        # Linear layers to project the queries, keys, and values
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)  # Output linear layer

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (num_heads, depth).

        Args:
        x (torch.Tensor): Input tensor to split.
        batch_size (int): Batch size of the input.

        Returns:
        torch.Tensor: Tensor with split heads.
        """
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(1, 2)

    def forward(self, v, k, q, mask=None):
        """
        Forward pass for multi-head attention.

        Args:
        v (torch.Tensor): Value tensor.
        k (torch.Tensor): Key tensor.
        q (torch.Tensor): Query tensor.
        mask (torch.Tensor, optional): Mask tensor to prevent attention to certain positions.

        Returns:
        torch.Tensor: Output tensor after applying attention.
        """
        batch_size = q.size(0)
        
        # Apply the linear layers to the queries, keys, and values
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        
        # Combine the attention output and apply the output linear layer
        context_layer = torch.matmul(attention_weights, v)
        context_layer = context_layer.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.dense(context_layer)
        return output

# Feed-Forward Network layer applied after attention layers
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=dropout_rate):
        """
        Initialize the feed-forward network.

        Args:
        d_model (int): Dimensionality of the embeddings.
        d_ff (int): Dimensionality of the feed-forward layers.
        dropout (float): Dropout rate for regularization.
        """
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass for the feed-forward network.

        Args:
        x (torch.Tensor): Input tensor.

        Returns:
        torch.Tensor: Output tensor after feed-forward layers.
        """
        x = F.relu(self.linear1(x))  # First linear layer with ReLU activation
        x = self.dropout(x)  # Apply dropout for regularization
        x = self.linear2(x)  # Second linear layer
        return x

# Add & Norm layer to normalize the output and apply residual connections
class AddNorm(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        """
        Initialize the Add & Norm layer.

        Args:
        d_model (int): Dimensionality of the embeddings.
        dropout (float): Dropout rate for regularization.
        """
        super(AddNorm, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_output):
        """
        Forward pass for Add & Norm.

        Args:
        x (torch.Tensor): Input tensor.
        sublayer_output (torch.Tensor): Output tensor from the sublayer.

        Returns:
        torch.Tensor: Normalized tensor with residual connection.
        """
        return self.norm(x + self.dropout(sublayer_output))

# Encoder Block consisting of Multi-Head Attention, Feed-Forward Network, and Add & Norm layers
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Initialize the encoder block.

        Args:
        d_model (int): Dimensionality of the embeddings.
        num_heads (int): Number of attention heads.
        d_ff (int): Dimensionality of the feed-forward layers.
        dropout (float): Dropout rate for regularization.
        """
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.add_norm1 = AddNorm(d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm2 = AddNorm(d_model, dropout)

    def forward(self, x, mask=None):
        """
        Forward pass for the encoder block.

        Args:
        x (torch.Tensor): Input tensor.
        mask (torch.Tensor, optional): Mask tensor to prevent attention to certain positions.

        Returns:
        torch.Tensor: Output tensor after passing through the encoder block.
        """
        attn_output = self.attention(x, x, x, mask)  # Self-attention
        x = self.add_norm1(x, attn_output)  # Add & Norm
        ffn_output = self.ffn(x)  # Feed-forward network
        x = self.add_norm2(x, ffn_output)  # Add & Norm
        return x

# Decoder Block similar to Encoder Block, but with additional encoder-decoder attention
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """
        Initialize the decoder block.

        Args:
        d_model (int): Dimensionality of the embeddings.
        num_heads (int): Number of attention heads.
        d_ff (int): Dimensionality of the feed-forward layers.
        dropout (float): Dropout rate for regularization.
        """
        super(DecoderBlock, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.add_norm1 = AddNorm(d_model, dropout)
        self.enc_dec_attention = MultiHeadAttention(d_model, num_heads)
        self.add_norm2 = AddNorm(d_model, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm3 = AddNorm(d_model, dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Forward pass for the decoder block.

        Args:
        x (torch.Tensor): Input tensor.
        enc_output (torch.Tensor): Output tensor from the encoder.
        src_mask (torch.Tensor, optional): Source mask tensor to prevent attention to certain positions in the source.
        tgt_mask (torch.Tensor, optional): Target mask tensor to prevent attention to certain positions in the target.

        Returns:
        torch.Tensor: Output tensor after passing through the decoder block.
        """
        self_attn_output = self.self_attention(x, x, x, tgt_mask)  # Self-attention
        x = self.add_norm1(x, self_attn_output)  # Add & Norm
        enc_dec_attn_output = self.enc_dec_attention(enc_output, enc_output, x, src_mask)  # Encoder-decoder attention
        x = self.add_norm2(x, enc_dec_attn_output)  # Add & Norm
        ffn_output = self.ffn(x)  # Feed-forward network
        x = self.add_norm3(x, ffn_output)  # Add & Norm
        return x

# Transformer model consisting of multiple encoder and decoder blocks
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, vocab_size, dropout=0.1):
        """
        Initialize the transformer model.

        Args:
        d_model (int): Dimensionality of the embeddings.
        num_heads (int): Number of attention heads.
        d_ff (int): Dimensionality of the feed-forward layers.
        num_layers (int): Number of encoder and decoder layers.
        vocab_size (int): Size of the vocabulary.
        dropout (float): Dropout rate for regularization.
        """
        super(Transformer, self).__init__()
        self.encoder = nn.ModuleList([EncoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)  # Final linear layer to project to the vocabulary size

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Forward pass for the transformer model.

        Args:
        src (torch.Tensor): Source input tensor.
        tgt (torch.Tensor): Target input tensor.
        src_mask (torch.Tensor, optional): Source mask tensor.
        tgt_mask (torch.Tensor, optional): Target mask tensor.

        Returns:
        torch.Tensor: Output tensor with predictions for each position in the sequence.
        """
        for layer in self.encoder:
            src = layer(src, src_mask)  # Pass through the encoder layers
        for layer in self.decoder:
            tgt = layer(tgt, src, src_mask, tgt_mask)  # Pass through the decoder layers
        output = self.fc_out(tgt)  # Final output layer
        return output

def generate_square_subsequent_mask(sz):
    """
    Generate a square subsequent mask for the sequence, masking out subsequent positions.

    Args:
    sz (int): Size of the mask.

    Returns:
    torch.Tensor: Lower-triangular mask with 1s at allowed positions and 0s at
    future positions, matching the `mask == 0` check used in MultiHeadAttention.
    """
    # Position (i, j) is allowed only when j <= i, so a word cannot attend to later words
    mask = torch.tril(torch.ones(sz, sz))
    return mask

def train_transformer(model, input_tensors, target_tensors, fine_tuning_tensors=None, num_epochs=500):
    """
    Train the transformer model using the provided training data.

    Args:
    model (Transformer): The transformer model to train.
    input_tensors (list of torch.Tensor): List of input sentence tensors.
    target_tensors (list of torch.Tensor): List of target sentence tensors.
    fine_tuning_tensors (list of torch.Tensor, optional): Fine-tuning data tensors.
    num_epochs (int): Number of training epochs.

    Returns:
    None
    """
    criterion = nn.CrossEntropyLoss(ignore_index=-100)  # Loss function for training
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)  # Optimizer with learning rate and weight decay
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)  # Learning rate scheduler

    for epoch in range(num_epochs):
        epoch_loss = 0
        for i, (src, tgt) in enumerate(zip(input_tensors, target_tensors)):
            src, tgt = src.unsqueeze(0).to(device), tgt.unsqueeze(0).to(device)  # Add batch dimension and move to device
            tgt_input = tgt[:, :-1]  # Input all but the last word
            tgt_output = tgt[:, 1:]  # Predict from the second word onward
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)  # Generate target mask

            # Forward pass through the model
            output = model(pos_encoder(embedding(src)), pos_encoder(embedding(tgt_input)), tgt_mask=tgt_mask)

            # Reshape output for loss computation
            output = output.view(-1, len(vocab))
            tgt_output = tgt_output.contiguous().view(-1)  # Flatten the target

            loss = criterion(output, tgt_output)  # Compute the loss
            optimizer.zero_grad()  # Zero out the gradients
            loss.backward()  # Backpropagation
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()  # Update model parameters
            epoch_loss += loss.item()

        if fine_tuning_tensors and epoch == num_epochs // 2:  # Fine-tuning the model midway through training
            print(f"Starting fine-tuning at epoch {epoch}")
            input_tensors = fine_tuning_tensors
            target_tensors = fine_tuning_tensors

        scheduler.step()  # Adjust learning rate
        if epoch % 50 == 0:
            print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(input_tensors):.4f}')

def save_model(model, path):
    """
    Save the model state to a file.

    Args:
    model (nn.Module): Model to save.
    path (str): File path to save the model state.

    Returns:
    None
    """
    torch.save(model.state_dict(), path)

def load_model(model, path):
    """
    Load the model state from a file.

    Args:
    model (nn.Module): Model to load.
    path (str): File path from which to load the model state.

    Returns:
    None
    """
    model.load_state_dict(torch.load(path))

# Initialize the positional encoding and embedding layers
pos_encoder = PositionalEncoding(d_model).to(device)
embedding = nn.Embedding(len(vocab), d_model).to(device)

# Initialize the transformer model
model = Transformer(d_model, num_heads, d_ff, num_layers, len(vocab)).to(device)

if LOAD_MODEL and os.path.exists(MODEL_PATH):
    print("Loading model from checkpoint...")
    load_model(model, MODEL_PATH)  # Load the model if the flag is set and the file exists
else:
    print("Training the model...")
    fine_tuning_tensors = encode_sentences(fine_tuning_sentences, vocab)  # Encode fine-tuning sentences
    train_transformer(model, input_tensors, target_tensors, fine_tuning_tensors, num_epochs=500)  # Train the model

    if SAVE_MODEL:
        print("Saving model checkpoint...")
        save_model(model, MODEL_PATH)  # Save the model after training

def predict_two_words_beam_search_full_context_dynamic(model, prompt, beam_width=5, repetition_penalty=5.0):
    """
    Generate a prediction of two words using beam search with full context.

    Args:
    model (Transformer): Trained transformer model.
    prompt (str): Initial prompt to start the prediction.
    beam_width (int): Number of beams to keep during beam search.
    repetition_penalty (float): Penalty to apply for repeating the same word.

    Returns:
    str: Predicted sentence including the prompt and two predicted words.
    """
    model.eval()  # Set model to evaluation mode
    prompt = re.sub(r'[^\w\s]', '', prompt.lower())  # Basic preprocessing of the prompt
    
    # Encode the source (input) tensor
    src = torch.tensor([vocab[word] for word in prompt.split() if word in vocab], dtype=torch.long).unsqueeze(0).to(device)
    
    # Start with the initial prompt tokens
    initial_tokens = [vocab[word] for word in prompt.split() if word in vocab]
    beams = [(initial_tokens, 0)]  # (token sequence, score)

    for _ in range(2):  # Predict exactly two words
        new_beams = []
        for tokens, score in beams:
            # Re-encode the entire sequence for every prediction step
            tgt = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
            tgt_encoded = pos_encoder(embedding(tgt).float())

            src_encoded = pos_encoder(embedding(src).float())  # Re-encode the source
            tgt_mask = generate_square_subsequent_mask(tgt_encoded.size(1)).to(device)

            # Forward pass through the transformer model
            output = model(src_encoded, tgt_encoded, tgt_mask=tgt_mask)

            # Get the logits for the last time step
            logits = output[:, -1, :]

            # Apply repetition penalty to discourage repeating the same token
            for token_id in set(tokens):
                logits[:, token_id] /= repetition_penalty

            # Get top-k predictions
            top_k_probs, top_k_indices = torch.topk(F.softmax(logits, dim=-1), beam_width)

            for i in range(beam_width):
                new_score = score + torch.log(top_k_probs[0][i])

                # Create new token list for the next step
                new_tokens = tokens + [top_k_indices[0][i].item()]

                # Store the new beam
                new_beams.append((new_tokens, new_score))

        # Keep the top beams
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]

    best_tokens = beams[0][0]  # Choose the beam with the highest score

    # Return the entire predicted sequence, including the prompt and the predicted words
    return " ".join([inv_vocab[token] for token in best_tokens])

# Example inference with the updated beam search function
predicted_sentence_1 = predict_two_words_beam_search_full_context_dynamic(model, "The sun")
print("Predicted Sentence #1:", predicted_sentence_1)

predicted_sentence_2 = predict_two_words_beam_search_full_context_dynamic(model, "The fire")
print("Predicted Sentence #2:", predicted_sentence_2)

predicted_sentence_3 = predict_two_words_beam_search_full_context_dynamic(model, "The bird")
print("Predicted Sentence #3:", predicted_sentence_3)

predicted_sentence_4 = predict_two_words_beam_search_full_context_dynamic(model, "The wind")
print("Predicted Sentence #4:", predicted_sentence_4)

predicted_sentence_5 = predict_two_words_beam_search_full_context_dynamic(model, "The river")
print("Predicted Sentence #5:", predicted_sentence_5)        

The sample output:

(ai_conda_env) [balapa@system1 ~/bala/toy_transformer]$ python build_tf.py 
Training the model...
Epoch 1/500, Loss: 4.1267
Epoch 51/500, Loss: 0.0024
Epoch 101/500, Loss: 0.0011
Epoch 151/500, Loss: 0.0010
Epoch 201/500, Loss: 0.0009
Starting fine-tuning at epoch 250
Epoch 251/500, Loss: 0.0018
Epoch 301/500, Loss: 0.0064
Epoch 351/500, Loss: 0.0055
Epoch 401/500, Loss: 0.0052
Epoch 451/500, Loss: 0.0052
Saving model checkpoint...
Predicted Sentence #1: the sun sets quietly
Predicted Sentence #2: the fire crackles loudly
Predicted Sentence #3: the bird sings sweetly
Predicted Sentence #4: the wind howls fiercely
Predicted Sentence #5: the river rushes swiftly        


Understanding the NLP Toolbox: Transformers, Sentence Transformers, and LLMs

With a variety of NLP models at your disposal, understanding their unique capabilities can unlock the full potential of your projects. In this section, we'll explore when to leverage transformers, sentence transformers, and LLMs, ensuring you apply the best tool for each task.

1. Transformers: The Foundation of Modern NLP

When to Use:

Sequential Data Processing: Transformers are ideal for tasks that involve processing sequential data, where the order of elements and their relationships are crucial.

Building Blocks: They serve as the foundational architecture for more complex models, offering flexibility and efficiency in handling long-range dependencies in sequences.

Examples:

Machine Translation: Converting text from one language to another by understanding the sequence of words in a sentence.

Text Classification: Categorizing text into predefined categories (e.g., spam detection, sentiment analysis).

Why They Work:

Self-Attention Mechanism: Transformers utilize self-attention to capture relationships between all elements in a sequence simultaneously, allowing the model to weigh the importance of each word relative to others.
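
To make the text classification example above concrete, here is a short sketch using the Hugging Face transformers library (assuming it is installed and its default sentiment-analysis model can be downloaded):

from transformers import pipeline

# Loads a default pre-trained sentiment-analysis model on first use
classifier = pipeline("sentiment-analysis")
print(classifier("The bird sings sweetly in the morning."))
# Example output: [{'label': 'POSITIVE', 'score': 0.99}]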

2. Sentence Transformers: Capturing the Meaning of Sentences

When to Use:

Sentence-Level Understanding: Sentence transformers are your go-to models when the task requires understanding or comparing the meaning of entire sentences.

Semantic Search and Similarity: They are specifically designed to generate dense vector representations of sentences, making them ideal for tasks that involve comparing or searching for meaning within large text corpora.

Examples:

Semantic Search: Finding documents or sentences that are semantically similar to a given query.

Sentence Similarity: Measuring how similar two sentences are in meaning, which is useful in tasks like duplicate detection or paraphrase identification.

Why They Work:

Dense Embeddings: Sentence transformers produce compact, meaningful embeddings that capture the semantic essence of a sentence, enabling efficient comparison and clustering of textual data.
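
As a short illustration of dense sentence embeddings and similarity (assuming the sentence-transformers package is installed and the all-MiniLM-L6-v2 checkpoint can be downloaded):

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")
sentences = ["The river flows smoothly.",
             "The stream runs gently.",
             "The fire burns brightly."]
embeddings = model.encode(sentences, convert_to_tensor=True)

# Cosine similarity of the first sentence against the other two
print(util.cos_sim(embeddings[0], embeddings[1:]))
# The river/stream pair should score noticeably higher than the river/fire pair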

3. Large Language Models (LLMs): The Powerhouse of Text Generation

When to Use:

General-Purpose Text Generation: LLMs are best used when you need a powerful, versatile model that can generate human-like text across a variety of contexts with minimal additional training.

Dialogue and Question-Answering: They excel in tasks that require understanding context and generating coherent, contextually appropriate responses.

Examples:

Chatbots and Virtual Assistants: Engaging users in natural, human-like conversation by understanding and responding to complex queries.

Content Generation: Creating blog posts, articles, or even code snippets based on minimal input or prompts.

Why They Work:

Massive Pre-Training: LLMs are pre-trained on vast amounts of text data, allowing them to capture a wide range of language patterns, knowledge, and nuances. This extensive training makes them adaptable to many different tasks with little additional training required.
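
For a quick taste of general-purpose text generation (assuming the Hugging Face transformers library is installed and GPT-2, a small openly available model, can be downloaded):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
result = generator("The transformer architecture changed NLP because", max_new_tokens=30)
print(result[0]["generated_text"])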

Conclusion

By building a toy transformer model, we’ve gained hands-on experience with the foundational architecture that powers many of today’s NLP advancements. Understanding the distinctions between transformers, sentence transformers, and large language models (LLMs) equips us with the knowledge to choose the most effective tool for our specific NLP tasks. As the field continues to evolve, let's keep exploring and learning together, striving to harness the full potential of these powerful technologies in our work.


#AI #MachineLearning #Transformers #DataScience #GenerativeAI #NLP #AIResearch #AIApplications #LearningTogether

