Teaching AI New Thanksgiving Traditions using Retrieval Augmentation

One of the issues with foundation models, from an enterprise perspective, is that they are pre-trained on generic data with a fixed cut-off date. Companies want to interact with AI solutions using data that is relevant to them. The challenge is how to make that data available to a model, given that training a model on company data is expensive, and that keeping responses up to date would require constantly feeding new data to the model.

One way to do this is to use a Retrieval Augmented Generation (RAG) pattern. Simply put, we pass the relevant knowledge, along with the question, to set the context, so that the LLM can answer the question using information it does not have in its training set. The aim of this article is to introduce the Retrieval Augmented Generation pattern and provide a (very) simple example to reinforce the concept.
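At its core the pattern is just prompt assembly. The sketch below is purely illustrative; retrieve and generate stand in for whatever retrieval and generation functions are plugged in, and the concrete versions used in this article appear later.

# Minimal sketch of the RAG pattern (names are illustrative, not from any library):
# 1) retrieve knowledge relevant to the question, 2) append it to the prompt,
# 3) let the model generate an answer conditioned on both.
def answer_with_rag(question, retrieve, generate):
    context = retrieve(question)                 # e.g. search a knowledge base
    augmented_prompt = question + " " + context
    return generate(augmented_prompt)            # any LLM text-generation call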

To keep this as simple as possible, our example uses HuggingFace's T5ForConditionalGeneration model, an implementation of the T5 text-to-text transfer transformer architecture. Unlike something like ChatGPT, T5 has a smaller, fixed-size vocabulary, and the model runs locally after loading the pretrained checkpoint via HuggingFace Transformers. The benefit is that this allows fully offline usage, with no dependency on external services.
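For reference, here is a minimal sketch of loading the checkpoint for offline use. It assumes the "t5-base" checkpoint has already been downloaded into the local HuggingFace cache; local_files_only=True simply prevents any network access when loading.

from transformers import T5ForConditionalGeneration, AutoTokenizer

# Load the cached checkpoint without touching the network
model = T5ForConditionalGeneration.from_pretrained("t5-base", local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained("t5-base", local_files_only=True)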

Again, for simplicity, the example focuses purely on text retrieval to demonstrate the RAG pattern, and therefore does not demonstrate encoding context into dense vectors before augmentation (which is commonly done when building retrieval around commercial models such as GPT-3 and Bard).

‘T5’ in the model’s name stands for "Text-To-Text Transfer Transformer"; the model was created by Google AI. It is pretrained on a multi-task mixture of unsupervised and supervised tasks, which allows it to be fine-tuned on a wide range of Natural Language Processing (NLP) tasks.

T5 transforms all text inputs to text outputs, framing every task as text generation. This unifies the model architecture across many tasks. It is pretrained on the C4 corpus, which contains hundreds of gigabytes of English text from the web.
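For example, the pretrained checkpoints recognize plain-text task prefixes, so different NLP tasks are expressed as nothing more than differently prefixed input strings (the example strings below are made up for illustration):

# Every task is "text in, text out" - only the prefix on the input string changes
translation_input = "translate English to German: The house is wonderful."
summarization_input = "summarize: Thanksgiving is a national holiday celebrating the harvest and blessings of the past year."
qa_input = "question: When is Thanksgiving? context: Thanksgiving is celebrated on the fourth Thursday of November."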

From the perspective of our example, T5ForConditionalGeneration provides an easy API for text generation, taking care of passing the encoder states to the decoder. This makes it suitable for tasks like translation, summarization and, importantly for our example, question answering, simply by feeding in the task-specific input text.

T5 comes in a range of sizes, from T5-Small up to T5-11B, with increasing model capacity. T5-Base is a good starting point, and the HuggingFace implementation makes it easy to use T5 for augmentation.

The simple code below demonstrates the following steps:

Input: The input is the prompt "Tell me about Thanksgiving".

Retrieve: The input is used to retrieve relevant context from Wikipedia using the wikipediaapi module.

Augment Context: The retrieved context is concatenated with the original input prompt.

Pass to T5: The combined text is passed to the T5 model.

Generate Output: T5 generates the final output text conditioned on the input prompt plus the retrieved context.

This breaks out into the following code:

(To run the code below you will need PyTorch, TensorFlow, Transformers, SentencePiece and Wikipedia-API installed in your Python environment.)

import wikipediaapi
from transformers import T5ForConditionalGeneration, AutoTokenizer
import logging

# Set level to only output errors
logging.getLogger("transformers").setLevel(logging.ERROR)

# Initialize Wikipedia API usage
wiki = wikipediaapi.Wikipedia(user_agent='RAG-Example ([email protected])')

model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)

def retrieve_context(input_prompt):
    # Derive a Wikipedia topic from the prompt and pull the page text
    topic = extract_topic(input_prompt)
    page = wiki.page(topic)
    if not page.exists():
        return "Page does not exist"
    context = page.summary
    for s in page.sections:
        context += " " + s.text
    return context

def extract_topic(input_prompt):
    # Naive topic extraction: take everything after the word 'about'
    topic = input_prompt.split('about')[-1].strip()
    return topic

input_prompt = "Tell me about Thanksgiving"
context = retrieve_context(input_prompt)

# Augment the prompt with the retrieved context
input_text = input_prompt + " " + context

# Tokenize, truncating to the model's maximum input length
inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
output = model.generate(max_length=600, **inputs)

# Show the original retrieval context
print("This is the original context from Wikipedia")
print()
print(context)
print()

# Show the context as summarized from the T5 model
print("This is the summarized LLM output")
print()

# Decode and strip special tokens
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(output_text)

Running this code prints the original context retrieved from Wikipedia, followed by the summary generated by T5.

What we have shown is how to provide additional knowledge to an LLM so that it can produce a natural language summary answer based on the input prompt. This is a very simplified example, using a single returned Wikipedia page and a concise summary, to demonstrate the core retrieval concept. If we wanted a longer summary, we could increase the max_length parameter to allow more tokens to be generated.
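For instance, assuming the model and inputs variables from the code above, a longer and often more fluent summary could be requested by adjusting the generation call; max_length, num_beams and early_stopping are standard transformers generate() arguments, and the specific values here are just illustrative:

# Allow more output tokens and use beam search for a longer, more fluent summary
output = model.generate(**inputs, max_length=1024, num_beams=4, early_stopping=True)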

In this example the T5 model does not automatically know which parts of the retrieved context are most relevant to focus on when generating text. If we wanted to enhance even this simple example, we could consider highlighting key sentences or passages with special tokens such as <hl> to indicate importance. We could also enhance the example to query additional context from related Wikipedia pages, not just the topic article; a sketch of this is shown below. With more varied context beyond a single document, T5 has more knowledge to draw on and therefore better potential to synthesize complementary information from different sources.
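As a rough sketch of the second idea, wikipediaapi exposes the links on a page, so the summaries of a few linked pages could be appended to the context. This reuses the wiki object from the code above; the max_related limit and the choice to simply take the first few links are arbitrary simplifications for illustration.

def retrieve_related_context(topic, max_related=3):
    # Pull the main article plus summaries from a few linked (related) pages
    page = wiki.page(topic)
    if not page.exists():
        return ""
    context = page.summary
    for i, related_page in enumerate(page.links.values()):
        if i >= max_related:
            break
        if related_page.exists():
            context += " " + related_page.summary
    return context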

For this extremely simple example, using Wikipedia serves its purpose, but retrieval augmentation can also be used effectively with in-house SQL databases, other APIs and/or search solutions. We could replace our Wikipedia retrieval with something like the following pseudo code:

def retrieve_from_sql(input_prompt):
    topic = extract_topic(input_prompt)

    # Query the SQL database for content matching the topic
    # (sql_database and format_sql_results are placeholder helpers)
    results = sql_database.query("""
        SELECT * FROM contents
        WHERE topic = {input_topic}
    """, input_topic=topic)

    return format_sql_results(results)

The key idea is that the same context-preparation step can be reused while swapping out the retrieval mechanism, pulling data from a SQL database instead of Wikipedia. Although this is a very simple example, the modular design makes it easy to plug in different context sources.
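One way to make that modularity explicit is to treat the retriever as a parameter. build_model_input below is a hypothetical helper, not part of the code above, and retrieve_from_sql is assumed to follow the pseudo code sketched earlier.

def build_model_input(input_prompt, retriever):
    # 'retriever' is any callable mapping a prompt to context text,
    # e.g. retrieve_context (Wikipedia) or retrieve_from_sql (SQL database)
    context = retriever(input_prompt)
    return input_prompt + " " + context

# Swap retrieval sources without changing the rest of the pipeline
wiki_input = build_model_input("Tell me about Thanksgiving", retrieve_context)
# sql_input = build_model_input("Tell me about Thanksgiving", retrieve_from_sql)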

In our walkthrough, for simplicity, we have focused purely on text retrieval for augmenting the document context. When using external AI solutions such as OpenAI's models and Bard, however, vectorization of knowledge for storage and lookup is a key component of real-world retrieval augmented generation systems, and I'll cover a worked example of this in a future post.

