In my ongoing journey to become a top product management SME in the ML/AI space, I've generated and printed enough outlines to clear a small forest. For fellow product managers interested in ML/AI concepts without the technical overload, I'll be sharing these digestible outlines from time to time.
Today's focus: Core pre-training methods for LLMs, or how LLMs learn to do what they do.
TL;DR: Pre-training is a crucial, resource-intensive phase where models undergo 'unsupervised learning'. Instead of studying labelled data to predict the label of new data, unsupervised learning is about deciphering context (a whole other rabbit hole) to generate meaningful outputs (like when you run to Perplexity for THAT authentic Argentinian chimichurri recipe).
During pre-training, LLMs devour massive datasets (often text) and leverage various methods (listed below) to learn contextual relationships. This learned context shapes the model's parameters/weights, ultimately driving its performance.
While not exhaustive, this list covers the major methods in pre-training tasks today.
Quick note: The one consistent term you will see in the below summaries is 'cross-entropy loss'. This is 'machine learning-speak' for measuring the difference between a model's prediction and the actual outcome it is trying to predict. If you're an accountant, think of it as your 'budget vs actual': the smaller the difference, the better.
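If you like to see the arithmetic, here's a toy sketch of that 'budget vs actual' idea in plain Python (the probabilities are made up purely for illustration):

```python
import math

# A made-up prediction: the model's probabilities for the next token.
predicted_probs = {"lazy": 0.70, "sleeping": 0.20, "purple": 0.10}
actual_token = "lazy"  # the "actual" in budget-vs-actual

# Cross-entropy loss for one prediction: -log(probability assigned to the
# correct token). A confident, correct prediction gives a loss near 0.
loss = -math.log(predicted_probs[actual_token])
print(f"Loss when fairly confident: {loss:.3f}")   # ~0.357

# If the model had only put 0.10 on "lazy", the loss balloons.
print(f"Loss when unsure: {-math.log(0.10):.3f}")  # ~2.303
```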
Disclaimer: There are likely a few mistakes and/or misinterpretations here, so I'll gladly welcome and accept all corrections and revise when needed.
1. Masked Language Modeling (MLM)
... Your basic ML sobriety test.
Core Mechanism:
- Masking: During pre-training, a portion of the input tokens (typically 15%) is randomly selected and replaced with a special [MASK] token.
- Prediction: The model's objective is to predict the original identity of these masked tokens based on the context provided by the remaining unmasked tokens in the sequence.
Variations in Masking Strategies:
- Basic Masking: The simplest approach where the selected tokens are directly replaced with [MASK].
- Whole Word Masking: If a subword of a word is selected for masking, the entire word is masked. This encourages the model to understand word-level semantics.
- Dynamic Masking: The masking pattern is randomly generated for each training instance, preventing the model from memorizing specific masking patterns.
- Span Masking: Consecutive sequences of tokens are masked, promoting the model's ability to understand longer-range dependencies.
Training Objective:
- The model typically uses a cross-entropy loss function to compare its predicted probability distribution over the vocabulary with the true distribution of the masked token.
Benefits and Impact:
- Bidirectional Contextual Understanding: By predicting masked tokens based on both preceding and following context, the model learns to represent words in relation to their entire surrounding context, not just one direction.
- Foundation for Language Understanding: MLM serves as a powerful pre-training objective that enables the model to acquire general language knowledge, which can then be fine-tuned for specific downstream tasks.
- Self-Supervised Learning: MLM leverages the vast amount of unlabeled text data available, eliminating the need for expensive and time-consuming manual annotation.
Real-world Applications:
- Improved performance on various NLP tasks: MLM has been shown to significantly improve performance on tasks like:
- Text Classification: Sentiment analysis, topic categorization
- Question Answering: Extractive and generative question answering
- Machine Translation: Translating text from one language to another
- Text Summarization: Generating concise summaries of longer texts
- Natural Language Generation: Generating coherent and contextually relevant text
Beyond the Basics:
- Advanced Techniques: Researchers continue to explore more sophisticated masking strategies and training objectives to further enhance the effectiveness of MLM.
- Model Architectures: While MLM is commonly associated with transformer-based models, it can also be applied to other neural network architectures.
Example (with more context):
- Input: "The quick brown fox jumps over the [MASK] dog."
- Contextual Clues:
- "The quick brown fox jumps over" suggests a physical action.
- "dog" at the end suggests the masked word describes the dog.
- Predicted output: "lazy" (or other adjectives describing a dog, like "sleeping", "small")
2. Next Sentence Prediction (NSP)
... Who invited the 'corpus' crasher?
Core Mechanism:
- Sentence Pairs: The model is presented with pairs of sentences (Sentence A and Sentence B).
- Binary Classification: The task is to classify whether Sentence B is the actual next sentence that follows Sentence A in the original text or if it's a randomly selected sentence from the corpus.
Data Preparation:
- Positive Examples: Sentence pairs where Sentence B directly follows Sentence A in the original text.
- Negative Examples: Sentence pairs where Sentence B is randomly chosen from the corpus and is unrelated to Sentence A.
- Ratio: Typically, a 50/50 ratio of positive and negative examples is used during training.
Training Objective:
- The model typically uses a binary cross-entropy loss function to measure its performance in predicting whether a sentence pair is "IsNext" or "NotNext."
Benefits and Impact:
- Understanding Sentence Relationships: NSP helps the model learn to recognize coherence and discourse structure in text. It learns to identify logical connections, topic continuity, and other relationships between sentences.
- Foundation for Downstream Tasks: This understanding of sentence relationships is crucial for tasks that require processing and generating coherent multi-sentence text, such as:
- Text Summarization: Identifying the most important sentences in a document and organizing them into a coherent summary.
- Question Answering: Understanding the context of a question and identifying relevant information across multiple sentences in a passage.
- Dialogue Systems: Generating contextually appropriate and coherent responses in a conversation.
Real-world Applications:
- Improved performance on tasks requiring sentence-level understanding: NSP has been shown to contribute to improved performance on tasks like:
- Document Classification: Categorizing documents based on their overall theme or topic.
- Natural Language Inference: Determining the logical relationship between two sentences (entailment, contradiction, or neutral).
- Paraphrase Detection: Identifying whether two sentences express the same meaning.
Evolution and Alternatives:
- NSP in BERT: NSP was initially introduced in the BERT model.
- Limitations and Alternatives: Some studies have questioned the effectiveness of NSP, and alternative pre-training tasks like Sentence Order Prediction (SOP) have been proposed to better capture discourse-level understanding.
- Continued Research: The field of NLP is constantly evolving, and researchers are actively exploring new and improved pre-training objectives to enhance language models' understanding of sentence relationships and discourse.
Example (with more context):
- Input:
- Sentence A: "The concert was amazing last night."
- Sentence B: "The band played all their greatest hits."
- Contextual Clues:
- Sentence B provides additional details about the concert mentioned in Sentence A, suggesting a logical continuation.
- Predicted output: IsNext
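For a hands-on feel, here's a minimal sketch that scores the example above with a BERT checkpoint via Hugging Face transformers (assumes the library is installed; the label convention shown is the one used by bert-base-uncased, where index 0 means 'IsNext'):

```python
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

sentence_a = "The concert was amazing last night."
sentence_b = "The band played all their greatest hits."

inputs = tokenizer(sentence_a, sentence_b, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Index 0 = IsNext, index 1 = NotNext for this checkpoint.
probs = torch.softmax(logits, dim=-1)[0]
print(f"IsNext: {probs[0]:.3f}  NotNext: {probs[1]:.3f}")
```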
3. Causal Language Modeling (CLM)
... Predicting the future, one word at a time.
Core Mechanism:
- Sequential Prediction: The model processes the input text from left to right, one token at a time. At each step, it predicts the probability distribution over the entire vocabulary for the next token, given all the previous tokens in the sequence.
- Autoregressive Nature: The model generates text by repeatedly sampling from this predicted distribution and appending the sampled token to the sequence. This process continues until a special end-of-sequence token is generated or a desired length is reached.
Training Objective:
- The model is typically trained using a maximum likelihood estimation (MLE) objective, which aims to maximize the probability of the actual next token in the training data given the previous context. The loss function used is often cross-entropy loss.
Benefits and Impact:
- Generative Capabilities: CLM empowers the model to generate coherent and contextually relevant text, making it suitable for a wide range of natural language generation tasks.
- Foundation for Text Generation: CLM serves as the backbone for tasks like:
- Text Completion: Predicting the most likely continuation of a given text prompt.
- Machine Translation: Generating translations of sentences or documents from one language to another.
- Dialogue Systems/Chatbots: Generating responses that are relevant and engaging in a conversation.
- Creative Writing: Generating poems, stories, or other forms of creative text.
- Code Generation: Assisting in writing code snippets or completing code based on context.
Real-world Applications:
- Text Prediction and Autocomplete: CLM powers features like autocomplete in search engines and messaging apps, suggesting the next word or phrase as you type.
- Machine Translation Services: Online translation tools rely heavily on CLM-based models to provide accurate and fluent translations.
- Chatbots and Virtual Assistants: CLM enables chatbots to generate human-like responses and engage in meaningful conversations.
- Content Creation: CLM can be used to generate various forms of content, including product descriptions, blog posts, and social media captions.
Considerations and Challenges:
- Bias and Sensitivity: CLM models can inherit biases present in the training data, leading to potentially harmful or discriminatory outputs. Careful data curation and bias mitigation techniques are essential.
- Control and Safety: Generating text that is factually accurate, safe, and aligned with human values remains a challenge. Researchers are actively working on developing methods to control and steer the output of CLM models.
Example (with more context):
- Input: "The quick brown fox jumps over the"
- Model's Processing:
- The model processes each token sequentially ("The", "quick", "brown", "fox", "jumps", "over", "the").
- At each step, it predicts the probability distribution for the next token based on the preceding context.
- Predicted output: "lazy"
- The model assigns the highest probability to "lazy" among all possible tokens in its vocabulary, considering the context of the sentence.
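Here's a minimal sketch of CLM-style generation using Hugging Face transformers and the gpt2 checkpoint, a model pre-trained with CLM (assumes the library is installed; exact output depends on the decoding settings):

```python
from transformers import pipeline

# GPT-2 was pre-trained with CLM, so it continues a prompt token by token.
generator = pipeline("text-generation", model="gpt2")

result = generator("The quick brown fox jumps over the", max_new_tokens=5)
print(result[0]["generated_text"])
# A likely continuation includes "lazy dog", though decoding settings
# and sampling can change the exact words.
```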
4. Permuted Language Modeling (PLM)
... Shake 'my tokens' like a polaroid picture.
Core Mechanism:
- Permutation/Rearrangement: The input sequence of tokens is randomly shuffled or permuted. The model's task is to predict the original, correct order of these tokens.
- No Masking: Unlike Masked Language Modeling (MLM), PLM doesn't involve replacing tokens with [MASK]. The entire sequence is present, but in a jumbled order.
Variations in Permutation Strategies:
- Random Shuffling: The simplest approach where the tokens are randomly rearranged.
- Controlled Permutations: The permutations might be constrained to maintain some local word order or grammatical structure, making the task slightly easier while still encouraging the model to learn long-range dependencies.
Training Objective:
- The model typically uses a cross-entropy loss function to compare its predicted probability distribution over all possible permutations with the true original order.
Benefits and Impact:
- Deeper Contextual Understanding: By reconstructing the original order from a permuted sequence, the model is forced to learn complex relationships between tokens that might be far apart in the sequence. This helps it capture long-range dependencies and understand the overall structure of the sentence.
- Enhanced Language Modeling: PLM has been shown to improve the model's ability to generate fluent and grammatically correct text, as it learns to consider the broader context when predicting the position of each token.
Real-world Applications:
- Improved performance on various NLP tasks: PLM has been shown to benefit tasks like:
- Machine Translation: Capturing long-range dependencies helps in generating more accurate and contextually appropriate translations.
- Text Summarization: Understanding the overall structure of a document aids in identifying and organizing the most important information.
- Language Generation: Generating text that is not only fluent but also grammatically correct and structurally sound.
Relationship to Other Pre-training Tasks:
- Complementary to MLM: PLM can be used in conjunction with MLM to provide a more comprehensive pre-training experience. MLM focuses on local context and word prediction, while PLM emphasizes global sentence structure and long-range dependencies.
Example (with more context):
- Input: "jumps fox brown quick the over lazy dog the"
- Model's Processing:
- The model analyzes the permuted sequence and tries to identify clues about the original order.
- It might recognize common phrases ("the lazy dog"), grammatical structures (adjective-noun pairs like "quick brown"), and semantic relationships (subject-verb-object like "fox jumps over").
- Predicted output: "The quick brown fox jumps over the lazy dog"
- The model rearranges the tokens to form the most likely and coherent sentence based on its understanding of language structure and semantics.
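As a toy, hypothetical sketch of what a PLM-style training pair could look like (real implementations such as XLNet permute the prediction order internally rather than literally shuffling the text, but the intuition is the same):

```python
import random

original = "The quick brown fox jumps over the lazy dog".split()

# Shuffle the token positions to create the permuted input.
permutation = list(range(len(original)))
random.shuffle(permutation)
permuted_tokens = [original[i] for i in permutation]

print("Model input :", " ".join(permuted_tokens))
print("Target      :", " ".join(original))  # the order to be reconstructed
```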
5. Sentence Order Prediction (SOP)
... When you really want to mess with someone.
Core Mechanism:
- Shuffled Sentences: The model receives a set of sentences that originally formed a coherent paragraph or passage, but their order has been randomly shuffled.
- Order Prediction: The task is to predict the correct, original order of these sentences, restoring the logical flow and coherence of the text.
Variations in Shuffling Strategies:
- Random Shuffling: The simplest approach where the sentences are completely randomly rearranged.
- Controlled Shuffling: The shuffling might be constrained to maintain some local coherence or avoid overly complex permutations, making the task slightly easier while still encouraging the model to learn discourse structure.
Training Objective:
- The model typically uses a permutation language modeling objective, where it predicts the probability of each possible sentence order and is trained to maximize the probability of the correct order. The loss function used is often cross-entropy loss.
Benefits and Impact:
- Discourse Structure Understanding: SOP forces the model to learn about the relationships between sentences, including temporal order, cause-and-effect, and logical flow. This helps it understand how sentences contribute to a larger narrative or argument.
- Coherence Modeling: The model learns to recognize what makes a sequence of sentences coherent and meaningful, which is crucial for generating or processing natural language that is easy to understand and follow.
Real-world Applications:
- Text Summarization: Identifying the most important sentences and arranging them in a logical order to create a coherent summary.
- Question Answering: Understanding the context of a question and the relationships between sentences in a passage to provide accurate answers.
- Dialogue Systems: Generating responses that are not only contextually relevant but also maintain the flow and coherence of the conversation.
- Document Organization: Automatically organizing information within a document or across multiple documents based on the logical flow of ideas.
Relationship to Other Pre-training Tasks:
- Complementary to NSP: While Next Sentence Prediction (NSP) focuses on predicting whether two sentences are consecutive, SOP extends this to understanding the order of multiple sentences within a larger context.
- Enhances Language Understanding: SOP, along with other pre-training tasks, contributes to the model's overall ability to understand and generate coherent and meaningful text.
Example (with more context):
- Input:
- (1) He scored the winning goal in the final minute.
- (2) The crowd erupted in cheers.
- (3) The soccer match was intense.
- Model's Processing:
- The model analyzes the shuffled sentences and tries to identify clues about their original order.
- It might recognize that sentence 3 sets the scene, sentence 1 describes a climactic event, and sentence 2 describes the reaction to that event.
- Predicted output: 3, 1, 2
- The model rearranges the sentences to create a coherent narrative: "The soccer match was intense. He scored the winning goal in the final minute. The crowd erupted in cheers."
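A toy, hypothetical sketch of how an SOP training example could be assembled, with the original sentence order kept as the label the model must recover:

```python
import random

original_passage = [
    "The soccer match was intense.",
    "He scored the winning goal in the final minute.",
    "The crowd erupted in cheers.",
]

# Shuffle the sentences; order[slot] is the original index now sitting at slot.
order = list(range(len(original_passage)))
random.shuffle(order)
shuffled = [original_passage[i] for i in order]

# The target is the inverse permutation that restores the original order.
target = sorted(range(len(order)), key=lambda slot: order[slot])
assert [shuffled[slot] for slot in target] == original_passage

print("Model input :", shuffled)
print("Target order:", target)
```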
6. Denoising Autoencoder (DAE)
... I mean, someone has to handle your typos.
Core Mechanism:
- Noise Introduction: The original input text is intentionally corrupted by applying various noise functions. Common noise types include:
- Token Replacement: Randomly replacing some tokens with other tokens from the vocabulary or with a special [MASK] token.
- Token Deletion: Randomly removing some tokens from the sequence.
- Token Shuffling: Rearranging the order of tokens within a certain window or the entire sequence.
- Character-Level Noise: Introducing typos or misspellings at the character level.
- Reconstruction: The model's task is to take this noisy input and reconstruct the original, clean text.
Training Objective:
- The model is typically trained to minimize the reconstruction error, which measures the difference between the predicted output and the original clean text. The loss function used is often mean squared error (MSE) or cross-entropy loss.
Benefits and Impact:
- Robust Language Representations: By learning to reconstruct clean text from noisy input, the model develops a deeper understanding of the underlying meaning and structure of language, even in the presence of errors or variations.
- Improved Generalization: The model becomes more resilient to noise and variations in real-world data, leading to better performance on downstream tasks where the input might be imperfect or contain errors.
Real-world Applications:
- Machine Translation: Handling noisy or grammatically incorrect input in the source language and generating fluent and accurate translations.
- Text Summarization: Extracting the key information from noisy or poorly written text.
- Grammatical Error Correction: Identifying and correcting grammatical errors or typos in user-generated text.
- Speech Recognition: Handling noisy or incomplete speech input and transcribing it into accurate text.
- Data Cleaning and Preprocessing: Automatically correcting errors and inconsistencies in large text datasets.
Relationship to Other Pre-training Tasks:
- Complementary to MLM: While MLM focuses on predicting masked tokens within a clean context, DAE focuses on reconstructing the entire sequence from a noisy context. Both tasks encourage the model to learn robust language representations.
Example (with more context):
- Input: "The quikc brwn fox jmps oevr the lzy dog." (corrupted with typos)
- Model's Processing:
- The model analyzes the noisy input and tries to identify the underlying meaning despite the errors.
- It might leverage its knowledge of common word spellings, grammar rules, and semantic relationships to infer the intended words.
- Predicted output: "The quick brown fox jumps over the lazy dog."
- The model reconstructs the original, clean sentence by correcting the typos.
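A toy, hypothetical sketch of the 'noise' half of a denoising autoencoder: corrupt the clean text with random deletions and swaps, and keep the clean text as the reconstruction target:

```python
import random

def add_noise(tokens, drop_prob=0.1, swap_prob=0.1):
    """Randomly delete some tokens and swap some adjacent pairs."""
    noisy = [t for t in tokens if random.random() > drop_prob]
    for i in range(len(noisy) - 1):
        if random.random() < swap_prob:
            noisy[i], noisy[i + 1] = noisy[i + 1], noisy[i]
    return noisy

clean = "The quick brown fox jumps over the lazy dog".split()
noisy = add_noise(clean)

print("Model input :", " ".join(noisy))   # corrupted sequence
print("Target      :", " ".join(clean))   # what the model must reconstruct
```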
7. Contrastive Predictive Coding (CPC)
Core Mechanism:
- Context and Target: The model is given a sequence of tokens representing the past context and a target token (or sequence of tokens) that follows the context.
- Positive and Negative Samples:
- Positive Sample: The actual token(s) that follow the context in the original text.
- Negative Samples: Randomly sampled tokens (or sequences) from the corpus that are not the actual continuation of the context.
- Discrimination: The model is trained to distinguish between the positive sample and the negative samples, assigning a higher probability to the positive sample.
Training Objective:
- The model typically uses a contrastive loss function, such as the InfoNCE loss, which trains it to score the true continuation above the negative samples, effectively maximizing a lower bound on the mutual information between the context and its future tokens.
Benefits and Impact:
- Learning High-Level Representations: CPC encourages the model to learn representations that capture the underlying semantic and syntactic structure of the language, as it needs to distinguish between meaningful continuations and random noise.
- Contextual Understanding: The model learns to predict future information based on the past context, which is crucial for understanding the flow and meaning of text.
Real-world Applications:
- Text Generation: Generating coherent and contextually relevant text by predicting the most likely continuation of a given prompt.
- Machine Translation: Capturing the semantic and syntactic relationships between words and phrases in different languages to produce accurate translations.
- Sentiment Analysis: Understanding the sentiment expressed in a text by considering the context and predicting the likely sentiment of subsequent words or phrases.
- Recommendation Systems: Predicting user preferences and recommending items based on their past behavior and context.
Relationship to Other Pre-training Tasks:
- Complementary to MLM and CLM: CPC focuses on learning high-level representations and contextual understanding, while MLM and CLM focus more on predicting specific tokens. These tasks can be used together to provide a more comprehensive pre-training experience.
Example (with more context):
- Input: "The cat sat on the"
- Positive Example: "mat"
- Negative Examples: "dog", "car", "apple"
- Model's Processing:
- The model encodes the context ("The cat sat on the") into a representation.
- It also encodes the positive and negative samples into representations.
- It then computes a similarity score between the context representation and each sample representation.
- Predicted output:
- The model is trained to assign a higher similarity score to the positive sample ("mat") compared to the negative samples ("dog", "car", "apple"). This indicates that "mat" is the most likely continuation of the given context.
Additional Considerations:
- Number of Negative Samples: The number of negative samples used during training can influence the model's performance. More negative samples generally lead to better discrimination but also increase computational cost.
- Sampling Strategies: Different strategies can be used to sample negative examples, such as random sampling from the corpus or using techniques like importance sampling to select more challenging negative examples.
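For the math-curious, here's a minimal, hypothetical sketch of an InfoNCE-style loss in PyTorch (assumes torch is installed; the embeddings are random stand-ins for real encoder outputs):

```python
import torch
import torch.nn.functional as F

embedding_dim = 8
context   = torch.randn(1, embedding_dim)   # stand-in for encoded "The cat sat on the"
positive  = torch.randn(1, embedding_dim)   # stand-in for encoded "mat"
negatives = torch.randn(3, embedding_dim)   # stand-ins for "dog", "car", "apple"

# Score the context against the positive and every negative.
candidates = torch.cat([positive, negatives], dim=0)  # shape (4, dim)
scores = context @ candidates.T                       # shape (1, 4)

# InfoNCE is cross-entropy with the positive sitting at index 0:
# the model is pushed to rank the true continuation above the negatives.
loss = F.cross_entropy(scores, torch.tensor([0]))
print(f"InfoNCE loss: {loss.item():.3f}")
```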
8. Translation Language Modeling (TLM)
Core Mechanism:
- Parallel Corpus: The model is trained on a dataset containing pairs of sentences that are translations of each other in two or more languages (e.g., English and French).
- Masking and Prediction: Similar to Masked Language Modeling (MLM), some tokens in both the source and target sentences are masked. The model is then trained to predict the original masked tokens in both languages, leveraging the parallel context.
Training Objective:
- The model typically uses a cross-entropy loss function to compare its predicted probability distributions over the vocabularies of both languages with the true distributions of the masked tokens.
Benefits and Impact:
- Cross-lingual Representations: TLM encourages the model to learn representations that capture the semantic and syntactic similarities and differences between languages.
- Alignment and Translation: The model learns to align words and phrases across languages, which is crucial for machine translation tasks.
Real-world Applications:
- Machine Translation: TLM has been shown to significantly improve the performance of machine translation systems, especially for low-resource languages where parallel data is scarce.
- Cross-lingual Transfer Learning: The model can leverage knowledge learned from one language to improve its performance on tasks in other languages, even with limited labeled data in those languages.
- Multilingual NLP: TLM can be used to build models that can understand and generate text in multiple languages, enabling applications like multilingual search and information retrieval.
Relationship to Other Pre-training Tasks:
- Extension of MLM: TLM can be seen as an extension of MLM to the multilingual setting, where the model learns to predict masked tokens in multiple languages simultaneously.
- Synergistic with Other Tasks: TLM can be combined with other pre-training tasks like MLM and NSP to further enhance the model's language understanding and generation capabilities across multiple languages.
Example:
- Input:
- English: "The cat sat on the mat."
- French: "Le chat était assis sur le tapis."
- Masked English: "The cat sat on the [MASK]."
- Masked French: "Le chat était assis sur le [MASK]."
- Predicted output:
- English: "mat"
- French: "tapis"
9. Replaced Token Detection (RTD)
... Fake news? Not on my watch.
Core Mechanism:
- Token Replacement: The original input text is modified by replacing some tokens with other tokens that are semantically or syntactically plausible but incorrect in the given context.
- Binary Classification: For each token in the modified sequence, the model's task is to classify whether it's the original token or a replaced token.
Training Objective:
- The model is typically trained using a binary cross-entropy loss function to measure its performance in identifying replaced tokens.
Benefits and Impact:
- Enhanced Contextual Understanding: RTD forces the model to pay close attention to the subtle relationships between words and their context, as it needs to identify tokens that are semantically or syntactically inconsistent with their surroundings.
- Improved Language Representation: The model learns to create more nuanced and context-sensitive representations of words, which can benefit various downstream tasks.
Real-world Applications:
- Grammatical Error Correction: Identifying and correcting words that are used incorrectly in a sentence.
- Text Style Transfer: Recognizing and modifying words to change the style or tone of a text.
- Adversarial Attack Detection: Detecting subtle modifications made to text with the intent to deceive or manipulate.
Relationship to Other Pre-training Tasks:
- Complementary to MLM: While MLM focuses on predicting masked tokens, RTD focuses on identifying replaced tokens, both encouraging the model to learn robust language representations.
Example:
- Original Input: "The quick brown fox jumps over the lazy dog."
- Modified Input: "The quick brown fox leaps over the idle dog."
- Predicted output:
- The: Original
- quick: Original
- brown: Original
- fox: Original
- leaps: Replaced (original: jumps)
- over: Original
- the: Original
- idle: Replaced (original: lazy)
- dog: Original
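A toy sketch of how the RTD labels for this example line up token by token (this is the setup popularized by ELECTRA; in practice the replacements come from a small generator model rather than being hand-picked):

```python
original = "The quick brown fox jumps over the lazy dog .".split()
modified = "The quick brown fox leaps over the idle dog .".split()

# Label 1 = replaced, label 0 = original; the discriminator predicts these.
labels = [int(o != m) for o, m in zip(original, modified)]

for token, label in zip(modified, labels):
    print(f"{token:>6}  {'Replaced' if label else 'Original'}")
```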