Using LLMs as Context-Aware Text Embedding Models - NV-Embed Paper Review
Victor Dibia, PhD
Principal RDSE at Microsoft Research (Generative AI, Agents) | Carnegie Mellon Alumnus
Can you harness the immense language understanding capabilities of generative models (e.g., large language models) to generate high-quality text embeddings? Yes!
An earlier version of this post is on my Substack newsletter!
This paper, NV-Embed (NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models), demonstrates how to finetune a base LLM (Mistral 7B) to produce state-of-the-art (SOTA) text embeddings. Perhaps more importantly, the approach offers a way to move from static embeddings (traditional embeddings whose behaviour is fixed once the model is trained) to dynamic, context-aware embeddings (embeddings that can be steered via instructions).
This post includes my quick review of the NV-Embed model (currently #1 on the MTEB leaderboard), its key ideas, and results from a small experiment.
TL;DR
Text Embeddings
Text embeddings play a crucial role in numerous real-world applications. They are essential for tasks such as data analysis (generating semantic embeddings, clustering, and deriving insights), recommendation systems (suggesting items similar to a given item), and many others.
An embedding is a vector representation of a data point (which could be text, an image, a video, etc.) that encodes the semantic meaning of that data point. These embeddings can be utilized for various applications that require understanding the relationships between data points, e.g., computing the relevance or similarity between them.
One of the key advantages of embeddings is that they enable efficient processing at scale, thanks to fast approximate nearest neighbor search libraries (e.g., FAISS, Annoy, ScaNN). This scalability is fundamental to the Retrieval-Augmented Generation (RAG) pattern, where embeddings are used to identify the most relevant documents to feed into a language model for a given task.
In a typical workflow, all relevant passages are first embedded. Then, at query time, the input query is also embedded, and similarity is computed using metrics such as cosine similarity. This process allows for rapid retrieval of the most semantically relevant information from large document collections.
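As a rough sketch of this workflow, the snippet below embeds a handful of made-up passages and ranks them against a query by cosine similarity. It uses an off-the-shelf sentence-transformers model purely for illustration; it is not the NV-Embed setup discussed later.

```python
# Minimal sketch of embedding-based retrieval (illustrative; not the NV-Embed pipeline).
from sentence_transformers import SentenceTransformer, util

# Any off-the-shelf embedding model works for this illustration.
model = SentenceTransformer("all-MiniLM-L6-v2")

passages = [
    "Acme Health builds tools for clinical data analysis.",          # toy examples
    "RoboFarm automates greenhouse operations with computer vision.",
    "Ledgerly is accounting software for small businesses.",
]
passage_embeddings = model.encode(passages, convert_to_tensor=True)

# At query time, embed the query and rank passages by cosine similarity.
query = "startups working on healthcare analytics"
query_embedding = model.encode(query, convert_to_tensor=True)
scores = util.cos_sim(query_embedding, passage_embeddings)[0]

best = scores.argmax().item()
print(passages[best], scores[best].item())
```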
Critically, the quality of data analysis systems, recommendation engines, and RAG pipelines is heavily dependent on the effectiveness of the retrieval step. Poor retrieval can compromise the entire workflow. Some common challenges in this area include:
These challenges underscore the excitement surrounding papers like NV-Embed, which introduce innovative ideas for repurposing Large Language Models (LLMs) as generalist embedding models. NV-Embed attempts to address these three key problems and has achieved the top ranking on the MTEB (Massive Text Embedding Benchmark) leaderboard.
What is the NV-Embed Model?
The NV-Embed model is a generalist embedding model designed to significantly enhance the performance of decoder-only large language models (LLMs) for embedding and retrieval tasks. The primary motivations behind its development were threefold: to improve the performance of decoder-only LLMs as versatile embedding models, to create a state-of-the-art embedding model using only publicly available data, and to enhance performance across a wide range of tasks, including retrieval, classification, and clustering. The model aims to address these goals through novel architectural designs and a two-stage training procedure, ultimately achieving superior results on comprehensive embedding benchmarks without relying on proprietary synthetic data from frontier LLMs like GPT-4.
Key Decisions with NV-Embed
The NV-Embed model is based on Mistral-7B, enabling it to leverage the strong language understanding capabilities inherent in a large, pre-trained model.
IMO, some of the key design decisions in the paper include:
Experiment: Exploring Differences in Embedding Structures
Given that the model is trained to condition on instruction prompts, the developer has latitude to influence the behaviour and quality of embeddings at inference time, at zero cost (no finetuning or additional training needed).
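As a concrete illustration, the snippet below sketches how the same text can be embedded under different instruction prefixes. It follows the general usage pattern from the NV-Embed Hugging Face model card; the encode signature comes from the model's remote code and may differ between versions, and the example company description is made up.

```python
# Sketch of instruction-conditioned embeddings with NV-Embed.
# encode() comes from the model's remote code (per the HF model card);
# exact arguments may differ across versions.
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/NV-Embed-v1", trust_remote_code=True)

texts = ["OpenHealth builds AI tools for clinical documentation."]  # toy example

# Same text, two different "lenses", expressed purely through the instruction prefix.
retrieval_prefix = (
    "Instruct: Given a company description, retrieve other companies that are "
    "semantically similar or are in the same domain.\nQuery: "
)
classification_prefix = (
    "Instruct: Classify the company description as Artificial Intelligence (AI) "
    "or not artificial intelligence.\nQuery: "
)

retrieval_emb = model.encode(texts, instruction=retrieval_prefix, max_length=4096)
classification_emb = model.encode(texts, instruction=classification_prefix, max_length=4096)
```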
Method
To explore this, I ran a simple experiment to visualize the structural impact of instructions on the extracted embeddings, using data on Y Combinator (YC) companies. The overall process was:
Instruction Templates
I used three main instruction template conditions (below) and a fourth condition where no instruction is given.
Given a company description, retrieve other companies that are semantically similar or are in the same domain. \nQuery:
Classify the company description as Artificial Intelligence (AI) or not artificial intelligence. \nQuery:
Classify the company description as health domain or not health domain. \nQuery:
The first follows a retrieval template while the next two follow a classification template. The code used for this experiment is available in the reference section. The relevant section is here:
# Instruction used to condition the embeddings toward semantic/retrieval structure.
semantic_clustering_instruction = "Given a company description, retrieve other companies that are semantically similar or are in the same domain."
# Embed all YC company descriptions with the instruction prepended.
semantic_clustering_embeddings = get_embeddings(model, yc_desc, semantic_clustering_instruction, max_length=max_seq_length)
# Reduce to 2D for visualization, save the raw embeddings, and plot clusters colored by label.
semantic_reduced_dims = reduce_dimensions(semantic_clustering_embeddings, 2)
save_json(tensor_to_json(semantic_clustering_embeddings), 'data/semantic_embeddings.json')
plot_clusters(semantic_reduced_dims, df, color_by='mentions_ai', title="Semantic Instruction Clustering of YC Companies")
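get_embeddings, reduce_dimensions, save_json, and plot_clusters are helpers defined in the linked repo. As a rough stand-in for the dimensionality-reduction step, a UMAP-based sketch like the one below would do the job (the actual helper may use a different reducer such as t-SNE or PCA):

```python
# Hypothetical stand-in for reduce_dimensions; the repo's implementation may differ.
import numpy as np
import umap  # umap-learn


def reduce_dimensions(embeddings, n_components=2, random_state=42):
    """Project high-dimensional embeddings down to n_components for plotting."""
    # Accept either torch tensors or array-likes.
    X = embeddings.detach().cpu().numpy() if hasattr(embeddings, "detach") else np.asarray(embeddings)
    reducer = umap.UMAP(n_components=n_components, metric="cosine", random_state=random_state)
    return reducer.fit_transform(X)  # shape: (num_points, n_components)
```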
The general hypothesis here is:
Evaluation
So how do we verify or interpret the quality of these embeddings? The right way would be to carefully construct a benchmark with labels and compute standard retrieval metrics such as NDCG@k, along with classification metrics. For classification, we'd typically look at metrics like accuracy, precision, recall, and F1-score, which provide a quantitative measure of how well the embeddings can be used to categorize companies into different groups.
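For reference, here is what those metrics look like with scikit-learn, using placeholder labels and scores rather than real experiment outputs:

```python
# Placeholder labels/scores to illustrate the metrics mentioned above.
import numpy as np
from sklearn.metrics import accuracy_score, ndcg_score, precision_recall_fscore_support

# Classification: ground-truth vs. predicted "is AI company" labels (hypothetical).
y_true = [1, 0, 1, 1, 0]
y_pred = [1, 0, 0, 1, 0]
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
print(f"accuracy={accuracy:.2f} precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")

# Retrieval: NDCG@k compares similarity scores against graded relevance judgments.
true_relevance = np.asarray([[3, 2, 0, 0, 1]])
similarity_scores = np.asarray([[0.9, 0.7, 0.4, 0.2, 0.6]])
print("ndcg@5:", ndcg_score(true_relevance, similarity_scores, k=5))
```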
In this case, we'll take some liberties and infer a few labels, then iteratively explore each visualization to make sense of the data. For labels, we'll add a "mentions_ai" column and a "mentions_health" column to the dataset, both derived from regular expressions over the company descriptions.
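A minimal sketch of that labeling step with pandas might look like the following; the description column name and the keyword patterns here are assumptions, not the exact regexes used in the repo:

```python
# Hypothetical regex-based labeling; column name and patterns are illustrative.
import pandas as pd

df = pd.DataFrame({"description": [
    "We build an AI copilot for radiologists.",
    "A marketplace for used farm equipment.",
]})

ai_pattern = r"\b(?:ai|artificial intelligence|machine learning|llm)\b"
health_pattern = r"\b(?:health|clinical|medical|patient|biotech)\b"

df["mentions_ai"] = df["description"].str.contains(ai_pattern, case=False, regex=True)
df["mentions_health"] = df["description"].str.contains(health_pattern, case=False, regex=True)
```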
It's important to note that in a production setting you'd typically want a more rigorous benchmark than this. However, this approach serves as a starting point for our analysis.
Next, we'll plot the data points colored by these labels to see how well the embeddings map to the data. For example, we expect that the condition instructed on semantic relationships will show visible clusters, and the condition instructed on AI will show clear separation between AI and non-AI companies. By visually inspecting these plots, we can gain insights into how well our embeddings capture the intended semantic information.
Results
Retrieval
Base embedding vs semantic instruction
We see much better defined clusters when using the semantic instruction embeddings. In theory, this cluster structure should yield better retrieval results.
Classification
Base embedding vs AI instruction embedding.
The results here were particularly fascinating. The TLDR is that all companies that use AI were effectively clustered together, with small micro-clusters representing areas like AI for biology, AI for clinical infrastructure, AI for security, etc. This opens up new ways to visualize and explore data with a specific context or lens.
We see similar results when comparing the base embedding with the health instruction embedding. In the latter, we see well-structured clustering of health companies, with sub-clusters within the health domain. Again, this structure can be exploited to design hierarchical retrieval systems.
Extra Credit
Conclusion
Some high-level takeaways from this exercise:
References
Code for experiments here