Using LLMs as Context-Aware Text Embedding Models - NV-Embed Paper Review


Can you harness the immense language understanding capabilities of generative models (e.g., large language models) to generate high-quality text embeddings? Yes!

An earlier version of this post is on my Substack newsletter!


This paper - NV-Embed (NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models) - demonstrates how to finetune a base LLM (Mistral 7B) to produce state-of-the-art (SOTA) text embeddings. Perhaps more importantly, the approach offers a way to move from static embeddings (traditional embeddings whose behaviour cannot be changed once the model is trained) to dynamic, context-aware embeddings (embeddings that can be steered via instructions).

This post includes a quick review of the NV-Embed model (currently No. 1 on the MTEB leaderboard), its key ideas, and results from a quick experiment.

TLDR;

  • The approach of instruction tuning embeddings is valuable. It offers the developer an additional lever to tune/optimize the embedding model for the task at hand. This includes building high quality RAG pipelines useful for building agentic systems.
  • My quick experiments show the approach is flexible (it generalizes to a new dataset I tested with - clustering and classifying YC company descriptions), potentially opening new ways to explore, analyze and visualize data.


Text Embeddings

Text embeddings play a crucial role in numerous real-world applications. They are essential for tasks such as data analysis (generating semantic embeddings, clustering, and deriving insights), recommendation systems (suggesting items similar to a given item), and many others.

An embedding is a vector representation of a data point (which could be text, an image, a video, etc.) that encodes the semantic meaning of that data point. These embeddings can be utilized for various applications that require understanding the relationship between data points e.g., computing the relevance or similarity between data points.

One of the key advantages of embeddings is that they enable efficient processing at scale, thanks to fast vector search algorithms like approximate nearest neighbor search (FAISS, Annoy, SCANN, etc). This scalability is fundamental to the Retrieval-Augmented Generation (RAG) pattern, where embeddings are used to identify the most relevant documents to feed into a language model for a given task.

In a typical workflow, all relevant passages are first embedded. Then, at query time, the input query is also embedded, and similarity is computed using metrics such as cosine similarity. This process allows for rapid retrieval of the most semantically relevant information from large document collections.
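
To make this workflow concrete, here is a minimal sketch of the embed-then-retrieve loop. The model and example texts are stand-ins of my own (any off-the-shelf embedding model would do), not something prescribed by NV-Embed:

import numpy as np
from sentence_transformers import SentenceTransformer  # stand-in embedding model, purely for illustration

model = SentenceTransformer("all-MiniLM-L6-v2")

# Offline: embed all passages once (at scale you would index these with FAISS, Annoy, SCANN, etc.)
passages = [
    "We build an AI copilot for radiologists.",
    "A marketplace for used industrial equipment.",
    "An open-source vector database for similarity search.",
]
passage_vecs = model.encode(passages, normalize_embeddings=True)

# Online: embed the query and rank passages by cosine similarity
# (a plain dot product here, since both sides are unit-normalized)
query_vec = model.encode(["healthcare companies using machine learning"], normalize_embeddings=True)[0]
scores = passage_vecs @ query_vec
top_k = np.argsort(-scores)[:2]
print([passages[i] for i in top_k])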

Critically, the quality of data analysis systems, recommendation engines, and RAG pipelines is heavily dependent on the effectiveness of the retrieval step. Poor retrieval can compromise the entire workflow. Some common challenges in this area include:

  1. Insufficient Semantic Modeling: Embedding models may inadequately capture the semantics of the data. This is a complex problem that has been the focus of extensive research in metric learning and contrastive learning - choosing the right model architecture, curating the right training data (e.g., hard negative mining), and designing the right objectives.
  2. Out-of-Distribution Application: Models are often applied to tasks that diverge significantly from their training data. This is particularly common when off-the-shelf models trained on academic datasets are used to embed text containing customer-specific jargon or domain-specific language.
  3. Query-Passage Structure Mismatch: The classic approach of embedding queries and passages, then using cosine similarity for retrieval, can break down when the structure of queries differs substantially from that of passages. While passages are often rich, self-contained chunks of text, queries can vary widely - they might be questions, concatenations of user context and direct questions, or samples from recent user interactions. In such cases, cosine similarity may not be an adequate measure of relevance.


These challenges underscore the excitement surrounding papers like NV-Embed, which introduce innovative ideas for repurposing large language models (LLMs) as generalist embedding models. NV-Embed attempts to address these three key problems and has achieved the top ranking on the MTEB (Massive Text Embedding Benchmark) leaderboard.


What is the NV Embed Model?

The NV-Embed model is a generalist embedding model designed to significantly enhance the performance of decoder-only large language models (LLMs) for embedding and retrieval tasks. The primary motivations behind its development were threefold: to improve the performance of decoder-only LLMs as versatile embedding models, to create a state-of-the-art embedding model using only publicly available data, and to enhance performance across a wide range of tasks, including retrieval, classification, and clustering. The model aims to address these goals through novel architectural designs and a two-stage training procedure, ultimately achieving superior results on comprehensive embedding benchmarks without relying on proprietary synthetic data from frontier LLMs like GPT-4.

Key Decisions with NV-Embed

The NV-Embed model is based on Mistral-7B, enabling it to leverage the strong language understanding capabilities inherent in a large, pre-trained model.

IMO, some of the key design decisions in the paper include:

  • Two-Stage Contrastive Instruction Tuning: The model is first contrastively trained on retrieval datasets with instructions, using in-batch negatives and curated hard negatives; a second stage then blends non-retrieval tasks (classification, clustering, STS) into the instruction tuning, with in-batch negatives disabled.
  • Removal of Causal Attention Masks: The model eliminates the causal attention constraint during contrastive training. This departure from traditional decoder-only LLM architectures allows the model to learn bidirectional contextual representations.
  • Latent Attention Layer: NV-Embed introduces a novel pooling mechanism featuring a trainable array of 512 latent vectors and 8 attention heads (a minimal sketch follows this list). This innovation aims to improve the quality of sentence embeddings compared to simpler strategies like mean pooling or using the last token's representation.
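
To make the latent attention idea concrete, below is a minimal PyTorch sketch of such a pooling layer. The sizes follow the paper's description (512 latents, 8 heads, Mistral-7B's 4096-dim hidden states), but the module itself is my own illustration rather than the released implementation: the decoder's token hidden states act as queries over a trainable latent array that serves as keys and values, and the result is mean-pooled into a single embedding.

import torch
import torch.nn as nn

class LatentAttentionPooling(nn.Module):
    """Illustrative latent-attention pooling: token states attend over trainable latents."""

    def __init__(self, hidden_dim=4096, num_latents=512, num_heads=8):
        super().__init__()
        # Trainable latent array shared across inputs, acting as keys/values for every sequence
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_dim) * 0.02)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4), nn.GELU(), nn.Linear(hidden_dim * 4, hidden_dim)
        )

    def forward(self, hidden_states):  # (batch, seq_len, hidden_dim) from the decoder's last layer
        batch = hidden_states.size(0)
        kv = self.latents.unsqueeze(0).expand(batch, -1, -1)
        # Token hidden states are the queries; the latent array provides keys and values
        attended, _ = self.attn(query=hidden_states, key=kv, value=kv)
        attended = attended + self.mlp(attended)  # MLP block; exact residual/norm details may differ from the paper
        return attended.mean(dim=1)  # mean-pool over tokens -> one embedding per sequence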


Experiment: Exploring differences in Embedding Structures

Given that the model is trained to consider instruction prompts, this provides the developer latitude to influence the behaviour and quality of embeddings (at zero cost, no finetuning or training needed).

Method

To explore this, I ran a simple experiment to visualize the structural impact of instructions on the extracted embeddings using data from YCombinator. Overall process was:

  • Using data on all YC companies from Jan 2010 - August 2024, extract company descriptions
  • Extract embeddings for each company using the NV-Embed model with multiple instruction template conditions.
  • Visualize the embeddings, i.e., reduce them to 2 dimensions using t-SNE and generate an interactive plot.


Instruction Templates

I had three main instruction template conditions (below) and a fourth case where no instruction is given.

Given a company description, retrieve other companies that are semantically similar or are in the same domain. \nQuery:
Classify the company description as Artificial Intelligence (AI) or not artificial intelligence. \nQuery:
Classify the company description as health domain or not health domain. \nQuery:

The first follows a retrieval template while the next two follow a classification template. The code used for this experiment is available in the reference section. The relevant section is here:

# Embed all YC company descriptions with the semantic-similarity instruction prepended
semantic_clustering_instruction = "Given a company description, retrieve other companies that are semantically similar or are in the same domain."
semantic_clustering_embeddings = get_embeddings(model, yc_desc, semantic_clustering_instruction, max_length=max_seq_length)

# Reduce to 2-D for plotting, persist the raw embeddings, then plot points colored by the mentions_ai label
semantic_reduced_dims = reduce_dimensions(semantic_clustering_embeddings, 2)
save_json(tensor_to_json(semantic_clustering_embeddings), 'data/semantic_embeddings.json')
plot_clusters(semantic_reduced_dims, df, color_by='mentions_ai', title="Semantic Instruction Clustering of YC Companies")
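
The helpers above (get_embeddings, reduce_dimensions, etc.) live in the linked repo. A rough sketch of what they might look like, assuming the encode API shown on the Hugging Face NV-Embed model card and scikit-learn's t-SNE (exact signatures may differ from the repo's code):

import torch
import torch.nn.functional as F
from transformers import AutoModel
from sklearn.manifold import TSNE

# Load the NV-Embed checkpoint; trust_remote_code is needed for its custom encode() method
model = AutoModel.from_pretrained("nvidia/NV-Embed-v1", trust_remote_code=True)

def get_embeddings(model, texts, instruction="", max_length=4096):
    # NV-Embed expects task instructions as an "Instruct: ...\nQuery: " prefix on the input;
    # an empty instruction gives the "no instruction" base condition
    prefix = f"Instruct: {instruction}\nQuery: " if instruction else ""
    embeddings = model.encode(texts, instruction=prefix, max_length=max_length)
    return F.normalize(torch.as_tensor(embeddings), p=2, dim=1)

def reduce_dimensions(embeddings, n_components=2):
    # t-SNE down to 2-D purely for visualization
    return TSNE(n_components=n_components).fit_transform(embeddings.detach().cpu().numpy())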


The general hypothesis here is:

  • Instruction embeddings are better than the base condition (no instruction). E.g., in the AI classification case we want to see better separation of AI vs non-AI companies in the embeddings compared to the base condition; similarly, the health instruction should better separate health vs non-health companies, and the semantic/retrieval instruction should produce more clearly defined clusters of related companies.


Evaluation

So how do we verify or interpret the quality of these embeddings? The right way would be to carefully construct a benchmark with labels and compute standard retrieval metrics such as nDCG@k alongside classification metrics. For classification, we'd typically look at metrics like accuracy, precision, recall, and F1-score. These metrics provide a quantitative measure of how well our embeddings can be used to categorize companies into different groups.
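
For instance, with a handful of hand-labeled query-passage pairs, nDCG@k is a one-liner with scikit-learn (the relevance labels and scores below are made up purely for illustration):

import numpy as np
from sklearn.metrics import ndcg_score

# Graded relevance labels for 5 candidate passages for one query,
# and the similarity scores our embeddings produced for those same passages
true_relevance = np.asarray([[3, 2, 0, 0, 1]])
similarity_scores = np.asarray([[0.91, 0.40, 0.35, 0.10, 0.62]])
print("nDCG@3:", ndcg_score(true_relevance, similarity_scores, k=3))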

In this case, we'll take some liberties and infer a few labels, then iteratively explore each visualization to make sense of the data. For labels, we'll add a "mentions_ai" column and a "mentions_health" column to our dataset, both based on regular expressions.
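
As a rough illustration, the labels might be derived with something like the following (the regex patterns and the description column name are illustrative, not the exact ones from the repo):

import re

# Assumes df is the YC dataframe with a free-text "description" column
ai_pattern = re.compile(r"\b(AI|artificial intelligence|machine learning|LLM)\b", re.IGNORECASE)
health_pattern = re.compile(r"\b(health|medical|clinic|patient|biotech)\b", re.IGNORECASE)

df["mentions_ai"] = df["description"].str.contains(ai_pattern)
df["mentions_health"] = df["description"].str.contains(health_pattern)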


It's important to note that in a production environment, you'd typically need a more carefully constructed benchmark than this. However, this approach serves as a starting point for our analysis.

Next, we'll plot the data points colored by these labels to see how well the embeddings map to the data. For example, we expect that the condition instructed on semantic relationships will show visible clusters, and the condition instructed on AI will show clear separation between AI and non-AI companies. By visually inspecting these plots, we can gain insights into how well our embeddings capture the intended semantic information.


Results


Retrieval

Base embedding vs semantic instruction

We see much better defined clusters when using the semantic instruction embeddings. In theory, this cluster structure should yield better retrieval results.


Classification

Base embedding vs AI instruction embedding.

The results here were particularly fascinating. The TLDR is that all companies that use AI were effectively clustered together, with small micro-clusters representing areas like AI for biology, AI for clinical infrastructure, AI for security, etc. This opens up new ways to visualize and explore data with a specific context or lens.


We see similar results when comparing the base embedding vs the health instruction embedding. In the latter, we see well-structured clustering of health companies, with sub-clusters within the health domain. Again, this structure can be exploited to design hierarchical retrieval systems.



Extra Credit

  • Compute simple clustering metrics like silhouette scores for each condition.
  • Train a simple classifier on our embeddings and evaluate its classification performance (accuracy, precision, F1). A combined sketch of both follows below.
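
A minimal sketch of both, assuming embeddings is the (n_companies, dim) NumPy array from one instruction condition and labels is the corresponding mentions_ai (or mentions_health) column:

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, silhouette_score
from sklearn.model_selection import train_test_split

X, y = embeddings, labels  # assumed to come from the earlier steps

# 1. Clustering structure: silhouette score of the labeled groups in embedding space
print("silhouette:", silhouette_score(X, y))

# 2. Linear probe: train a simple classifier on the embeddings and report accuracy/precision/F1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print(classification_report(y_test, clf.predict(X_test)))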


Conclusion

Some high-level takeaways from this exercise:

  • Models like NV-Embed provide a way to "tune" embedding quality post-training via instructions. This is a massive boost for the observant ML Engineer. It provides an extra surface to optimize and improve your system, be it a retrieval, recommendation, or RAG pipeline. Take advantage of it!
  • Results indeed show that instructions actually change the structure of the embeddings. Following the instruction template used to train the model matters. As a developer, you will likely construct multiple versions of prompts and embeddings for various tasks. This opens up new ways to visualize and explore data with a specific context or lens.
  • The model is still large - 7B parameters still require some GPU muscle. My explorations here were done on an NVIDIA A6000 GPU. There is some work to be done to optimize the inference process here.
  • Academic benchmarks are not your task. My example with YC companies is also not your task. You should still conduct your own experiments :) .


References

Code for experiments here

  1. Lee, C., Roy, R., Xu, M., Raiman, J., Shoeybi, M., Catanzaro, B., & Ping, W. (2024). NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models. arXiv preprint arXiv:2405.17428.
