Transformer feed-forward layers are key-value memories

Geva, M., Schuster, R., Berant, J., & Levy, O. (2020). Transformer feed-forward layers are key-value memories. arXiv preprint arXiv:2012.14913.

The paper analyzes the role of feed-forward layers in transformer models. Feed-forward layers account for two-thirds of a transformer model's parameters, yet their function has been far less explored than that of the self-attention layers. The authors' key observation is that a feed-forward layer is structurally almost identical to a key-value memory, only without normalization over the memory coefficients. Each feed-forward layer contains two parameter matrices: keys and values. The input is multiplied with the keys to produce a coefficient for each key, and these coefficients weight a sum over the values to produce the output.

An illustration of how a feed-forward layer emulates a key-value memory. Input vectors (here, x5) are multiplied by keys to produce memory coefficients (e.g., the memory coefficient for v1 is 0.2), which then weigh distributions over the output vocabulary, stored in the values. The feed-forward layer’s output is thus the weighted sum of its values.
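A minimal NumPy sketch of this reading of a feed-forward layer as an unnormalized key-value memory, FF(x) = f(x·Kᵀ)·V with bias terms omitted as in the paper's formulation. The dimensions (hidden size 1024, 4096 memory cells per layer) follow the paper's setup, and the ReLU nonlinearity matches its notion of "positive memory coefficients"; the random weights are placeholders for a trained model's parameters.

```python
import numpy as np

d_model, d_mem = 1024, 4096            # hidden size, number of memory cells per layer
K = np.random.randn(d_mem, d_model)    # rows k_i: the "keys"  (first FF matrix)
V = np.random.randn(d_mem, d_model)    # rows v_i: the "values" (second FF matrix)

def ff_as_memory(x):
    """FF(x) = f(x @ K.T) @ V: memory coefficients weight a sum of value vectors."""
    m = np.maximum(x @ K.T, 0.0)       # memory coefficient m_i for each key (ReLU)
    return m @ V                       # output = coefficient-weighted sum of values

x = np.random.randn(d_model)           # hidden state of one input prefix
out = ff_as_memory(x)                  # shape (1024,), added back to the residual stream
```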

The paper asks: what kind of memories are stored in these feed-forward layers? What patterns do the keys capture, and what is represented in their corresponding values? How does the final transformer output aggregate information from feed-forward memories across layers? Through extensive experiments, the paper shows that keys correlate with human-interpretable patterns in the inputs, such as n-grams or topics. The values represent probability distributions over the vocabulary that, especially in the upper layers, tend to predict the token following the key's input pattern.

Examples of human-identified patterns that trigger different memory keys.
Example values, their top prediction, the fraction of their key’s top-50 trigger examples that agree with their prediction, and a matching trigger example (with the target token marked in blue).
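A hedged sketch of how a value's "top prediction" can be read off: project the value vector through the model's output embedding matrix and take the softmax. Here `E` (vocab_size × d) and `id_to_token` are assumed inputs standing in for the trained model's output embeddings and vocabulary, not code from the paper.

```python
import numpy as np

def value_distribution(v, E):
    """Distribution over the vocabulary induced by a single value vector."""
    logits = E @ v                       # score for every vocabulary token
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()           # softmax over the vocabulary

def value_top_prediction(v, E, id_to_token):
    """The token a memory's value predicts most strongly, with its probability."""
    probs = value_distribution(v, E)
    top = int(np.argmax(probs))
    return id_to_token[top], float(probs[top])
```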

Lower-layer keys correlate more with shallow surface patterns, whereas upper-layer keys capture more semantic topics. For example, a lower-layer key may activate on inputs ending in "substitutes," whereas an upper-layer key may trigger on military-base-related contexts. Upper-layer values also have distributions that rank the correct next token higher, agreeing with the actual next token of the key's trigger examples. The key-value pairs thus learn to store continuation predictions for input patterns.
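A sketch of the agreement check described above, under assumed inputs: `prefix_vectors` holds the hidden states feeding this layer for a set of corpus prefixes, `next_tokens` their actual next-token ids, and `E` the output embeddings. For a given memory, take the top-50 trigger examples of its key and count how often the value's top prediction matches the real next token.

```python
import numpy as np

def agreement_rate(key, value, prefix_vectors, next_tokens, E, top_n=50):
    """Fraction of a key's top-n trigger examples whose next token matches the value's top prediction."""
    coeffs = np.maximum(prefix_vectors @ key, 0.0)   # memory coefficient for each prefix
    triggers = np.argsort(-coeffs)[:top_n]           # strongest trigger examples for this key
    predicted = int(np.argmax(E @ value))            # value's top vocabulary token
    hits = sum(int(next_tokens[j] == predicted) for j in triggers)
    return hits / top_n
```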

The fraction of active memories (i.e., with positive memory coefficient) out of 4096 memories in every layer, for a random sample of 4,000 examples.

Each feed-forward layer combines hundreds of active memories, composing a distribution that differs from any single value distribution. The residual connections between layers then allow sequential refinement of the prediction: the feed-forward output provides a parallel combination of memories, which is gently tuned across layers through the residuals, making small adjustments without drastically altering the information. In roughly a third of the examples, the model's final prediction already matches the residual in the early layers, showing that the model can commit to a decision early in many cases. Even for these agreed decisions, later layers continue to update the probabilities, increasing confidence. Only in rare cases does the feed-forward output directly override the residual; mostly, the composition produces a compromise between the two.
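A simplified sketch of this per-layer composition and of probing the residual stream for its current top prediction, as in the figures that follow. Attention, layer norm, and bias terms are omitted here for clarity, and all matrices are assumed to be a trained model's weights.

```python
import numpy as np

def layer_update(residual, K, V):
    """One layer's feed-forward contribution folded into the residual stream."""
    m = np.maximum(residual @ K.T, 0.0)   # hundreds of memories get a nonzero weight
    ff_out = m @ V                        # parallel combination of the active memories
    return residual + ff_out              # small refinement of the running prediction

def residual_top_prediction(residual, E):
    """Which token would the model output if it stopped at this layer?"""
    return int(np.argmax(E @ residual))
```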

Fraction of examples in each layer, where the residual’s top prediction matches the model’s output.

The authors conclude that feed-forward layers act as pattern detectors over the inputs at every layer, with the patterns becoming more semantic in higher layers. The values store predictions for the output tokens that follow these input patterns. The model constructs its output distribution bottom-up, by weighted aggregation of hundreds of memories within each layer, which is then sequentially refined through the residual connections.

This understanding of feed-forward layers opens several directions for future research: explaining why value distributions correlate more strongly with continuation tokens in the upper layers, which hints at transformations of the embedding space across layers; extending the key-value memory formulation to other transformer models and identifying which insights generalize; and developing practical applications of this view of memories, such as better interpretability, evaluation of privacy risks, and guidance for architectural innovations.

Probability of the token output by the model according to the residual of each layer.

In conclusion, the paper advances our understanding of transformer models by unveiling the role of the previously under-examined feed-forward layers. They operate as detectors of input patterns, stored together with predictions for the output tokens that tend to follow those patterns. The full model aggregates information from these memories within and across layers to compose its output distributions. These insights can guide future work on interpretability, privacy, and model architecture.
