Self-Extend: A Simple Yet Effective Approach to Extend Context Windows of LLMs


In the realm of LLMs, the context window length has been a limiting factor, hindering their ability to comprehend and generate text beyond a certain length. A new approach called Self-Extend challenges this limitation, unlocking the inherent potential of language models to handle long contexts without the need for extensive fine-tuning. This method, based on the premise that language models possess an innate capacity for long-context understanding, employs a simple yet effective strategy to extend their context window size. By mapping unseen relative positions to those encountered during training, Self-Extend enables language models to navigate longer sequences with remarkable proficiency.

Eliciting LLMs' Inherent Long Context Capabilities without Fine-tuning :

Self-Extend elicits LLMs' inherent long context capabilities without fine-tuning by mapping unseen relative positions to those seen during pre-training via the FLOOR operation. This is done by applying the FLOOR operation to each token's original position before the inner product is computed in the attention mechanism. The operation groups neighboring positions together, so the relative distances that enter the attention computation stay within the range seen during pre-training, allowing the LLM to attend to long-range dependencies without having to learn new positional encodings.

Additionally, Self-Extend maintains the attention mechanism unchanged for neighbor tokens within a certain range, ensuring that the LLM can still accurately model local relationships between tokens.
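
To make this concrete, here is a minimal, illustrative Python sketch (not the paper's implementation) that remaps a single query-key relative distance. The names neighbor_window and group_size are assumptions for illustration; the actual code, shown at the end of this article, applies the grouping to query and key positions separately before the attention inner product.

def remap_relative_position(distance: int, neighbor_window: int = 512, group_size: int = 8) -> int:
    """Map a query-key relative distance to one the model saw during pre-training."""
    if distance <= neighbor_window:
        # Neighbor tokens: keep the exact relative position (normal attention).
        return distance
    # Distant tokens: FLOOR-group the distance, then shift it so the mapping
    # continues smoothly from the edge of the neighbor window.
    return distance // group_size + neighbor_window - neighbor_window // group_size

for d in (10, 512, 520, 4096, 16384):
    print(d, "->", remap_relative_position(d))
# 10 -> 10, 512 -> 512, 520 -> 513, 4096 -> 960, 16384 -> 2496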

A Key Challenge in Long Context Management for LLMs :

The key challenge preventing LLMs from effectively managing extensive contexts is the out-of-distribution (O.O.D) issue related to positional encoding, known as the positional O.O.D issue. It arises when, during inference, LLMs encounter text sequences that exceed the length of their pre-training context window and are therefore exposed to relative distances that never appeared during pre-training. It is widely recognized that neural networks can behave unpredictably on O.O.D inputs.

Addressing the Positional O.O.D Issue in Self-Extend:

Self-Extend addresses the positional O.O.D issue by mapping unseen large relative positions (at inference) to positions known from training, allowing LLMs to maintain coherence over longer texts without additional fine-tuning.


More specifically, Self-Extend uses the simple FLOOR (//) operation as the mapping function to map unseen large relative positions to those encountered during pre-training. This idea stems from two intuitions:

1) For tokens that are far apart, the exact relative position does not need to be precise. Understanding the overall meaning of the text only requires that the relative ordering of its different parts be maintained.

2) In natural language, when a small bag of tokens (an n-gram) appears together in one region of the text, grammar usually permits only one ordering of those tokens, so precise positions matter mainly at close range.

By using the FLOOR operation, Self-Extend effectively groups tokens with similar relative positions together, allowing the LLM to generalize its knowledge from the training data to unseen long contexts. This simple yet effective approach enables LLMs to handle long sequences without the need for fine-tuning, preserving their inherent long context capabilities.
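
As a toy illustration (the numbers are made up: assume a model pre-trained with a context window of only 8 positions), the snippet below shows that FLOOR-grouping the positions brings every relative position of a longer input back into the range seen during training:

import torch

seq_len, group_size = 16, 4          # input twice as long as the toy 8-position window
pos = torch.arange(seq_len)

rel = pos[:, None] - pos[None, :]                                          # original relative positions
grouped_rel = (pos // group_size)[:, None] - (pos // group_size)[None, :]  # after FLOOR

print(rel.max().item())          # 15 -> never seen with an 8-position window
print(grouped_rel.max().item())  # 3  -> well inside the trained range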

Grouped Attention for Long-Distance Tokens and Normal Attention for Neighbor Tokens :

Self-Extend employs two types of attention mechanisms: grouped attention and normal attention.

1. Grouped Attention:

- Purpose: Handles tokens with long-distance relationships.

- Implementation:

- Applies the FLOOR operation to the positions of query and key tokens.

- Computes the attention scores using the modified positions.

- The FLOOR operation maps unseen large relative positions to known positions, allowing the LLM to maintain coherence over longer texts without additional fine-tuning.

2. Normal Attention:


- Purpose: Models the relationships between neighbor tokens.

- Implementation:

- Utilizes the original self-attention mechanism without any modifications.

- Maintains the attention mechanism unchanged for neighbor tokens within a certain range.

- Ensures that the generated sentence is fluent and the perplexity is not significantly increased.

Collaboration:

- The two attention mechanisms work together to achieve effective long context understanding:

- Grouped attention captures long-distance dependencies and enables the LLM to access information from distant parts of the input sequence.

- Normal attention precisely models the relationships between neighbor tokens, ensuring the fluency and coherence of the generated text.

- The combination of these attention mechanisms allows Self-Extend to effectively handle long contexts without sacrificing the quality of the generated text (a minimal sketch of this merge follows below).
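
The sketch below shows, under assumed tensor shapes and with illustrative names (merge_attention, neighbor_window), how the two sets of attention logits can be combined with a single boolean mask; the actual Mistral implementation is given at the end of this article and additionally applies the causal mask and the softmax:

import torch

def merge_attention(neighbor_scores: torch.Tensor,
                    group_scores: torch.Tensor,
                    neighbor_window: int) -> torch.Tensor:
    # Both score tensors have shape [..., q_len, kv_len].
    q_len, kv_len = neighbor_scores.shape[-2], neighbor_scores.shape[-1]
    pos_q = torch.arange(q_len)[:, None] + (kv_len - q_len)  # absolute query positions
    pos_k = torch.arange(kv_len)[None, :]                    # absolute key positions
    dist = pos_q - pos_k
    within_window = (dist >= 0) & (dist < neighbor_window)   # normal attention region
    return torch.where(within_window, neighbor_scores, group_scores)

merged = merge_attention(torch.randn(1, 2, 6, 6), torch.randn(1, 2, 6, 6), neighbor_window=3)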

Maintaining Low Perplexity Beyond the Pretraining Context Window with Self-Extend (A Combination of Grouped and Normal Attention) :

Self-Extend maintains a low perplexity (PPL) beyond the pre-training context window by merging the normal attention mechanism, used for neighbor tokens within a certain range, with the grouped attention mechanism, designed for tokens at long distances. The grouped attention mechanism applies the FLOOR operation to the positions, mapping unseen large relative positions to those seen during pretraining. This allows the model to maintain coherence over longer texts without additional fine-tuning.

Additionally, Self-Extend keeps the attention mechanism unchanged in the neighbor area, which ensures that the generated sentence is fluent and the PPL stays small. The transition from the normal attention area to the grouped attention area is smooth: the two parts are merged by replacing the attention values outside the neighbor token window with the attention values from the grouped attention. This keeps the PPL low beyond the pretraining context window.

Passkey Retrieval Task Results Demonstrating Self-Extend's Effectiveness :


On the passkey retrieval task, Self-Extend achieves 100% passkey retrieval accuracy across all tested depths and context lengths, while Mistral-7b-instruct-0.1 with SWA can hardly retrieve the passkey outside its sliding window. This result strongly suggests that Self-Extend can effectively extend the context window of LLMs and enable them to access information across long sequences, whereas Mistral-7b-instruct-0.1 with SWA, despite having a low perplexity beyond its pretraining context window, still cannot truly handle long contexts.

The passkey retrieval task is a simple task that requires a language model to retrieve a five-digit random number (the passkey) hidden in a long, meaningless text sequence. It tests whether an LLM can be aware of information at any position of the input sequence.
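
For reference, a passkey prompt can be constructed roughly as follows. This is a hedged sketch: build_passkey_prompt is a hypothetical helper, and the exact filler text and phrasing used in the paper's evaluation may differ.

import random

def build_passkey_prompt(n_filler: int = 2000) -> tuple[str, str]:
    """Hide a random five-digit passkey in a long stretch of meaningless filler text."""
    passkey = str(random.randint(10000, 99999))
    filler = "The grass is green. The sky is blue. The sun is yellow. " * n_filler
    cut = len(filler) // 2
    prompt = (
        filler[:cut]
        + f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
        + filler[cut:]
        + "\nWhat is the pass key?"
    )
    return prompt, passkey

prompt, answer = build_passkey_prompt()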

Comparing Self-Extend to Fine-tuned Models on Longbench and L-Eval Benchmarks :

On the Longbench benchmark, Self-Extend achieves comparable or better performance than many fine-tuned models on several datasets. For example, on the HotpotQA dataset, Self-Extend outperforms all fine-tuned counterparts. On the MultiNews dataset, Self-Extend has comparable performance to fine-tuned models.

On the L-Eval benchmark, Self-Extend also achieves comparable or better performance than fine-tuned models on most datasets. For example, on the Coursera dataset, Self-Extend outperforms all fine-tuned baselines. On the GSM dataset, Self-Extend has comparable performance to fine-tuned models.

Overall, Self-Extend compares favorably to fine-tuned models on both the Longbench and L-Eval benchmarks, demonstrating its effectiveness as a fine-tuning-free method for extending the context window of LLMs.


Limitations of Self-Extend and Potential Future Work: Addressing Implementation Issues, Exploring Sophisticated Mapping Methods, and Testing on Larger Models and Longer Contexts

Limited context window extension: Self-Extend can only extend the context window to a finite length, determined by the group size and the neighbor window size (a quick calculation follows this list). With the current implementation, the context window cannot be extended indefinitely.

Performance degradation with large group size: As the group size increases, the performance of Self-Extend may degrade. This is because a larger group size means that more tokens are mapped to the same relative position, which can lead to information loss.

Lack of implementation of Flash Attention: Self-Extend does not currently implement Flash Attention, which is a more efficient attention mechanism that can be used to speed up the computation of self-attention.
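
Regarding the first limitation, the maximum extended length is bounded by roughly (L - w_n) x G + w_n, where L is the pretraining context window, w_n the neighbor window, and G the group size. A quick back-of-the-envelope check (the 4096/1024/8 values are only illustrative):

def max_extended_length(pretrain_window: int, neighbor_window: int, group_size: int) -> int:
    # (L - w_n) * G + w_n: grouped positions each cover `group_size` tokens,
    # while the neighbor window keeps its exact positions.
    return (pretrain_window - neighbor_window) * group_size + neighbor_window

print(max_extended_length(4096, 1024, 8))  # 25600 tokens from a 4k pretraining window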

How can the limitations of Self-Extend be addressed in future work?

Develop more sophisticated mapping methods: The simple FLOOR operation used in Self-Extend can be replaced with more sophisticated mapping methods that can better preserve the positional information of tokens. This may lead to better long context understanding abilities and a longer extended context window length.

Implement Flash Attention: Implementing Flash Attention in Self-Extend can improve its efficiency and make it more practical for use in real-world applications.

Explore the use of Self-Extend with larger models and longer contexts: Self-Extend has been tested on models with a maximum context window size of 16k tokens. It would be interesting to see how Self-Extend performs on models with larger context window sizes and on tasks that require understanding longer contexts.

Potential Applications of Self-Extend in Real-World Scenarios :

Self-Extend, a fine-tuning-free method for extending the context window of pretrained LLMs, has various potential applications in real-world scenarios:

1. Long-Form Question Answering: Self-Extend can empower LLMs to answer questions that require reasoning over long contexts, such as answering questions about historical events, scientific discoveries, or legal cases. By extending the context window, LLMs can access more relevant information and generate more comprehensive and accurate answers.

2. Document Summarization: It can be utilized to summarize long documents, such as research papers, news articles, or legal documents. By considering the entire document as context, Self-Extend enables LLMs to capture the main points and generate concise and informative summaries.

3. Machine Translation: It can be applied to machine translation tasks involving long sentences or documents. By extending the context window, LLMs can better understand the context and generate more fluent and accurate translations.

4. Code Generation: It can assist in generating code for programming tasks. By providing the LLM with the entire codebase as context, Self-Extend allows the LLM to generate code that is consistent with the existing code and adheres to the programming language's syntax and semantics.

5. Conversational AI: It can enhance the capabilities of conversational AI systems by enabling them to maintain context over multiple turns of a conversation. By remembering and reasoning over the entire conversation history, Self-Extend allows conversational AI systems to generate more coherent and relevant responses.

6. Legal Research: It can be used to analyze legal documents and case law. By providing the LLM with the full text of a legal document or a collection of case law, Self-Extend enables the LLM to identify relevant legal provisions, extract key facts, and generate insights that can aid legal professionals in their research.

7. Medical Diagnosis: Self-Extend can be applied to medical diagnosis tasks. By providing the LLM with a patient's medical history, test results, and other relevant information, Self-Extend allows the LLM to consider all available data and generate more accurate diagnoses.

How does Self-Extend compare to other context window extension methods that do not require fine-tuning?

Self-Extend outperforms other context window extension methods that do not require fine-tuning in several ways:

- Effectiveness: Self-Extend is more effective at extending the context window length of LLMs without sacrificing performance. In experiments on the LongBench benchmark, Self-Extend achieved state-of-the-art results on a variety of long-context tasks, outperforming other non-fine-tuning methods such as NTK and further trained baselines such as Longchat1.5-7b-32k and Vicuna1.5-7b-32k.

- Efficiency: Self-Extend is more efficient than other non-fine-tuning methods. It only requires a few lines of code to modify the attention mechanism of an existing LLM, and it does not require any additional training or fine-tuning. This makes it easy to apply Self-Extend to any LLM, and it can be used to extend the context window length of LLMs on-the-fly.

- Generality: Self-Extend is more general than other non-fine-tuning methods. It can be applied to any LLM that uses RoPE for positional encoding, regardless of the size or architecture of the model. This makes it a versatile tool for extending the context window length of LLMs, and it can be used to improve the performance of LLMs on a wide range of tasks.

Overall, Self-Extend is a powerful and effective method for extending the context window length of LLMs without fine-tuning. It is easy to use, efficient, and general, and it can significantly improve the performance of LLMs on long-context tasks.


Self-Extend Implementation (Torch) :

self_extend for Mistral :

import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.cache_utils import Cache


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos[:,:, -q.shape[2]:]) + (rotate_half(q) * sin[:,:, -q.shape[2]:]) if q is not None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None
    return q_embed, k_embed

def apply_grouped_rotary_pos_emb(q, k, cos, sin, position_ids, g_size_1=1, g_size_2=4096):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    position_ids_q = position_ids//g_size_1 + g_size_2 - g_size_2//g_size_1
    position_ids_k = position_ids//g_size_1

    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos_q = cos[position_ids_q].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_q = sin[position_ids_q].unsqueeze(1)  # [bs, 1, seq_len, dim]
    cos_k = cos[position_ids_k].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_k = sin[position_ids_k].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos_q) + (rotate_half(q) * sin_q) if q is not None else None
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k) if k is not None else None

    return q_embed, k_embed

def apply_neighbor_rotary_pos_emb(q, k, cos, sin, position_ids, g_size=1):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    position_ids = position_ids % g_size

    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def apply_identical_rotary_pos_emb(q, k, cos, sin, position_ids, idd_position=1024):
    position_ids = torch.ones_like(position_ids) * idd_position

    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    group_size_1: Optional[float] = 8,
    group_size_2: Optional[float] = 2048,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    
    # Queries keep the incoming position ids; keys are re-enumerated over the full
    # cached sequence so that past tokens get positions 0..kv_seq_len-1.
    query_position_ids = position_ids
    key_position_ids = torch.arange(kv_seq_len, dtype=position_ids.dtype).to(query_position_ids.device).unsqueeze(0).expand(bsz, kv_seq_len)


    neighbor_query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin, query_position_ids) 
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, cos, sin, key_position_ids) 
    # If the whole sequence still fits inside the neighbor window, disable the group
    # shift (g2 - g2 // g1) so the shifted query positions do not exceed the largest
    # position seen so far.
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
    # Grouped RoPE: queries use shifted grouped positions, keys use grouped positions
    # over the full cached sequence.
    group_query_states, _ = apply_grouped_rotary_pos_emb(query_states, None, cos, sin, query_position_ids, g_size_1=group_size_1, g_size_2=_re_group_size_2)
    _, group_key_states = apply_grouped_rotary_pos_emb(None, key_states, cos, sin, key_position_ids, g_size_1=group_size_1, g_size_2=_re_group_size_2)


    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 


    if group_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {group_attn_weights.size()}"
        )
    
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        group_attn_weights = group_attn_weights + attention_mask
        neighbor_attn_weights = neighbor_attn_weights + attention_mask


    # Build a boolean map deciding, for each query-key pair, whether the normal
    # (neighbor) attention score or the grouped attention score is used.
    if q_len == 1:
        # Decoding a single new token: normal attention covers the last
        # `group_size_2` keys, grouped attention covers everything before that.
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask[:, -group_size_2:] = 1
    elif q_len == kv_seq_len:
        # Prefill: start from a causal mask, then carve out the region farther than
        # `group_size_2` tokens away, which falls back to grouped attention.
        neighbor_attention_mask = torch.ones((q_len, kv_seq_len), device=neighbor_attn_weights.device)
        neighbor_attention_mask = torch.tril(neighbor_attention_mask)
        if q_len-group_size_2 > 0:
            group_attention_mask = torch.tril(torch.ones((q_len-group_size_2, kv_seq_len-group_size_2), device=group_attn_weights.device))
            neighbor_attention_mask[group_size_2:, :-group_size_2] -= group_attention_mask

    else:
        raise ValueError("q_len should be 1 or seq_len.")


    # Merge: take the normal attention logits inside the neighbor window and the
    # grouped attention logits outside of it, then apply softmax over the result.
    neighbor_attention_mask = neighbor_attention_mask.bool()
    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value        
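
One possible way to wire self_extend_forward into a Hugging Face Mistral model is to replace each layer's attention forward after loading. This is only a hedged sketch, not the official patching utility: it assumes a transformers version around v4.36 (matching the signature above), and the model name and group sizes are illustrative.

from functools import partial

from transformers import AutoModelForCausalLM

def patch_mistral_with_self_extend(model, group_size_1: int = 8, group_size_2: int = 2048):
    # Bind self_extend_forward to every attention module and fix the group sizes.
    for layer in model.model.layers:
        layer.self_attn.forward = partial(
            self_extend_forward.__get__(layer.self_attn),
            group_size_1=group_size_1,
            group_size_2=group_size_2,
        )

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
patch_mistral_with_self_extend(model)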


Sources :

https://arxiv.org/pdf/2401.01325.pdf


  • By Kirouane Ayoub
