Optimizing Retrieval Augmented Generation with Learned Chunking
Krupa Galiya
In collaboration with Elan Markowitz (PhDc)
Introduction
Retrieval Augmented Generation (RAG) techniques recall relevant portions of text to give an LLM context before answering a query. However, this process relies on the documents being preprocessed into smaller, more manageable chunks of text.
Chunking is thus a crucial, but understudied step in the RAG process. The RAG process starts by chunking a corpus of documents into smaller portions. These chunks will then be embedded so that they can be retrieved according to their similarity to a user’s query.
These chunks are preprocessed, embedded with an embedding model, and stored in a database.
These document chunks can then be used during the RAG process. At inference time, the user query is similarly embedded and then the most similar document chunks are retrieved. These chunks are then passed into the context of the LLM for generating its answer.
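This retrieval step can be sketched in a few lines. The code below is a minimal illustration, not the actual system: it assumes embeddings are already computed as plain float lists, and the cosine-similarity ranking and top-k cutoff are generic choices, not details taken from this project.

```python
import math

def cosine(a, b):
    """Cosine similarity between two (nonzero) embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def retrieve(query_emb, chunk_embs, k=2):
    """Return the indices of the k chunks most similar to the query."""
    ranked = sorted(range(len(chunk_embs)),
                    key=lambda i: cosine(query_emb, chunk_embs[i]),
                    reverse=True)
    return ranked[:k]
```

The retrieved chunks' texts would then be concatenated into the LLM's prompt.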
The LLM that generates answers is optimized for generating text, and the embedding model is optimized for retrieval. However, no existing method provides an optimizable framework for learning the chunking algorithm itself.
This blog post explores one such approach to learning chunking algorithms.
Motivation
Current text splitters generally rely on rules or heuristics, such as splitting on a fixed number of tokens or on sentence and paragraph boundaries.
Our approach is Learned Chunking. We train a model that predicts where the next chunk break should go based on improvement in retrieving answers for downstream Question Answering.
Train the Chunker End-to-End on Question Answering
The goal is to train the chunker to directly optimize the retrieval process. To do this we use a preference optimization approach where the model has a choice between two sampled locations to set as the chunk divider, and we empirically determine which leads to better question answering and context retrieval.
Once the chunker is trained, it can be used as follows. Given a window of an unchunked portion of a document, the chunker predicts where the next chunk break should go. This creates a new document chunk and then the window slides forward. This process is repeated until the chunker reaches the end of the document.
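The inference loop above can be sketched as follows. This is a simplified illustration: `predict_break` stands in for the trained Flan-T5 chunker (a hypothetical callable, not the real model interface), and the document is pre-split into sentences for readability.

```python
def chunk_document(sentences, predict_break, window=8):
    """Greedily chunk a document with a learned break predictor.

    `predict_break` maps a window of sentences to the number of
    sentences (1..window) that should form the next chunk.
    """
    chunks, start = [], 0
    while start < len(sentences):
        win = sentences[start:start + window]
        k = predict_break(win)           # model picks the break point
        k = max(1, min(k, len(win)))     # guard against out-of-range picks
        chunks.append(" ".join(win[:k])) # emit the new chunk
        start += k                       # slide the window forward
    return chunks
```

Each emitted chunk would then be embedded and stored, exactly as in a heuristic chunking pipeline.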
We use Google’s “Flan-T5-Base” as our base model.
Data Collection and Processing
We use SQuAD (the Stanford Question Answering Dataset) for training and evaluation. SQuAD is a contextual question answering dataset: each question-answer pair is mapped to a specific context text, derived from Wikipedia, that is needed to answer the question. Moreover, each QA pair contains the index of where the answer appears in the context. We use this information in our training process.
To turn the dataset from a contextual question answering task into a RAG task, we incorporate additional retrieval documents from Cohere/wikipedia-22-12-en-embeddings. We include the Cohere documents and embeddings derived from articles whose titles match document titles in SQuAD. This gives us a corpus of 35,211 embedded document chunks from 401 Wikipedia articles.
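The corpus construction reduces to a title filter over the Cohere dataset. The sketch below assumes each Cohere row is a dict with `title`, `text`, and `emb` fields; these field names are illustrative, not the dataset's actual schema.

```python
def build_corpus(cohere_rows, squad_titles):
    """Keep only Cohere chunks from articles whose titles appear in SQuAD.

    cohere_rows: iterable of dicts with (assumed) keys 'title',
    'text', and 'emb'; squad_titles: article titles used by SQuAD.
    """
    titles = set(squad_titles)  # set lookup keeps the filter O(1) per row
    return [row for row in cohere_rows if row["title"] in titles]
```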
Model Configuration and Training Procedure
The training procedure is as follows. For a given question-answer pair and context, we run the chunker greedily over the context document until we produce a chunk containing the answer sentence. Once that happens, we also sample an alternative chunk location. In parallel, we continue to greedily chunk the context document from each of these two locations. This leaves us with two alternative chunkings of the context document whose first difference occurs near the answer location.
We then embed the chunks from each of the two sampled chunkings and run a standard retrieval process for answering the question. We use a subset of embeddings from Cohere/wikipedia-22-12-en-embeddings as distractor documents for the retrieval process. This mirrors the actual RAG process, in which retrieval is done over an entire corpus.
We then assess retrieval quality with a metric we call tokens-to-answer. The retrieval process produces a ranking of relevant chunks, and the metric sums the cumulative tokens in retrieved chunks up to and including the chunk containing the answer. It favors rankings where the answer chunk appears higher, since fewer chunks, and thus fewer tokens, precede the answer. All else being equal, it also favors smaller chunks, since those contribute fewer tokens. Fewer tokens make it more likely that RAG fits the answer into a model's context window when actually tested.
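The metric itself is a short computation. The sketch below uses whitespace tokenization as a stand-in for the real tokenizer, which is an assumption on our part:

```python
def tokens_to_answer(ranked_chunks, answer_rank,
                     count_tokens=lambda s: len(s.split())):
    """Cumulative tokens over retrieved chunks up to and including
    the answer chunk.

    ranked_chunks: chunk texts in retrieval order.
    answer_rank: position (0-based) of the answer chunk in that ranking.
    count_tokens: tokenizer stand-in; whitespace split for illustration.
    """
    return sum(count_tokens(c) for c in ranked_chunks[:answer_rank + 1])
```

Lower is better: moving the answer chunk up the ranking, or shrinking the chunks above it, both reduce the score.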
This process generates preference data for the chunker: the sample with the lower resulting tokens-to-answer is the preferred one. We then train with a cross-entropy loss.
Backpropagation and Optimizing the Loss
Once a batch of preference pairs is established, we re-run the forward pass for the chunking step at which the samples diverge. We then apply a loss that favors the preferred chunk location over the alternative.
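One minimal way to write such a loss, assuming the chunker's forward pass yields a score per candidate break position, is a cross-entropy with the preferred position as the target. This is a plain-Python sketch of the idea, not the project's actual training code:

```python
import math

def preference_loss(logits, preferred_idx):
    """Cross-entropy pushing probability mass toward the preferred break.

    logits: unnormalized scores over candidate break positions from
    the chunker's forward pass (a list of floats here for illustration).
    preferred_idx: the break position with the lower tokens-to-answer.
    """
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[preferred_idx]  # = -log p(preferred)
```

Minimizing this loss raises the probability of the break location that empirically led to better retrieval.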
Initial Results
Using Weights & Biases, we monitored the average token counts for both samples, the loss, and entropy values.
Observations
We see hints that this approach can work: it can reduce tokens-to-answer by over 100 tokens. However, this is a modest improvement from small-scale experiments.
We see that learning can happen but is heavily limited. One challenge is that the training process is extremely noisy: because we train on the ability to answer only a single question from a passage at a time, rather than all questions for that passage, the signal is limited.
Limitations
Learnable chunking approaches could become an important part of the RAG process; however, the current implementation has a number of limitations.
Conclusion
Our Learned Chunking approach could be a powerful way to optimize a currently heuristic-based part of the RAG process. While the current version is not ready for production use, future iterations could make learned chunking practical.