Geek Out Time: AI Model Routing — Dynamically Choose Models Based on Question Complexity
(Also on Constellar tech blog https://medium.com/the-constellar-digital-technology-blog/geek-out-time-ai-model-routing-dynamically-choose-models-based-on-question-complexity-c6a37cbeef85)
Dynamically routing questions to AI models based on their complexity is a simple, cost-effective approach: whether the query is basic trivia or a request to explain a dense topic, this solution uses GPT-4 to score the question's complexity and then routes it to the most suitable model for generating the answer. Let's dive into this exciting topic and geek out!
The Idea
AI models excel in different tasks based on their size and training. Why use a heavyweight model for a simple question like “What is AI?” when a lightweight model like DistilBERT can do the job? Conversely, you wouldn’t want DistilBERT to tackle a complex topic like “Explain the significance of transfer learning in deep learning models.”
Now let's walk through the implementation step by step:
1. Scoring Question Complexity
We use GPT-4 to classify questions into three categories:
- Simple: short factual lookups (e.g., "What is the capital of France?")
- Moderate: questions that need some explanation or comparison
- Complex: questions that call for in-depth reasoning or open-ended generation
GPT-4 assigns each category a score between 0 and 1, and the three scores sum to 1.
Implementation
Here's the code for scoring question complexity. If GPT-4 fails to return parseable scores, we fall back to a heuristic based on the question's length.
import re
import logging
from typing import Tuple

from openai import OpenAI

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class QuestionComplexityScorer:
    """Score question complexity using GPT-4, with a length-based fallback."""

    def __init__(self, api_key: str, model: str = "gpt-4"):
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def _fallback_scores(self, question: str) -> Tuple[float, float, float]:
        """Fallback scoring based on question length."""
        logger.debug(f"Using fallback heuristic for question: {question}")
        words = len(question.split())
        if words <= 5:
            return (1.0, 0.0, 0.0)  # Simple
        elif words <= 15:
            return (0.0, 1.0, 0.0)  # Moderate
        else:
            return (0.0, 0.0, 1.0)  # Complex

    def score_question(self, question: str) -> Tuple[float, float, float]:
        """Get complexity scores from GPT-4; fall back to the heuristic on failure."""
        prompt = f"""Rate the complexity of the following question with numbers between 0 and 1.
Respond ONLY in the format: Simple: <number>, Moderate: <number>, Complex: <number>.
The numbers must sum to 1.
Example 1:
Question: What is the capital of France?
Answer: Simple: 0.9, Moderate: 0.1, Complex: 0.0
Example 2:
Question: Explain quantum mechanics in simple terms.
Answer: Simple: 0.1, Moderate: 0.4, Complex: 0.5
Now evaluate this question:
Question: {question}
Answer:"""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0,
                max_tokens=50
            )
            gpt4_response = response.choices[0].message.content
            logger.debug(f"GPT-4 response for question '{question}': {gpt4_response}")
            # Parse the three scores out of the response with a regex
            scores = re.findall(r"(?:Simple|Moderate|Complex):\s*(\d*\.?\d+)", gpt4_response)
            if len(scores) != 3:
                raise ValueError("Invalid GPT-4 response")
            return tuple(float(score) for score in scores)
        except Exception as e:
            logger.error(f"GPT-4 scoring failed: {e}")
            return self._fallback_scores(question)
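Before wiring the scorer into the full pipeline, it helps to smoke-test it on its own. Here's a minimal sketch, assuming your key lives in the OPENAI_API_KEY environment variable (the variable name and the sample question are just illustrative):

import os

# Quick standalone check of the scorer (assumes OPENAI_API_KEY is set)
scorer = QuestionComplexityScorer(api_key=os.environ["OPENAI_API_KEY"])
simple, moderate, complex_score = scorer.score_question("What is AI?")
print(f"Simple: {simple:.2f}, Moderate: {moderate:.2f}, Complex: {complex_score:.2f}")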
2. Routing Questions to the Right Model
Once we have the complexity scores, we route the question to whichever model matches the highest score:
- Simple → DistilBERT (distilbert-base-cased-distilled-squad), a lightweight extractive QA model
- Moderate → BERT (deepset/bert-base-uncased-squad2), a base-size extractive QA model
- Complex → T5 (t5-base), a generative model for open-ended answers
Implementation
Here’s the code for routing and processing questions using the selected model:
from typing import Tuple

import torch
from transformers import (
    DistilBertForQuestionAnswering, DistilBertTokenizer,
    BertForQuestionAnswering, BertTokenizer,
    T5ForConditionalGeneration, T5Tokenizer
)


class ModelRouter:
    """Route questions to the appropriate model based on complexity scores."""

    def __init__(self, context: str = ""):
        self.context = context or "AI is the simulation of human intelligence in machines."
        self.models = {
            "simple": (
                DistilBertForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad"),
                DistilBertTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
            ),
            "moderate": (
                BertForQuestionAnswering.from_pretrained("deepset/bert-base-uncased-squad2"),
                BertTokenizer.from_pretrained("deepset/bert-base-uncased-squad2")
            ),
            "complex": (
                T5ForConditionalGeneration.from_pretrained("t5-base"),
                T5Tokenizer.from_pretrained("t5-base")
            )
        }

    def route(self, question: str, scores: Tuple[float, float, float]) -> str:
        """Route the question to the highest-scoring model and generate an answer."""
        labels = ["simple", "moderate", "complex"]
        selected_label = labels[scores.index(max(scores))]
        model, tokenizer = self.models[selected_label]
        if selected_label in ["simple", "moderate"]:
            # Extractive QA: pick the answer span with the highest start/end logits
            inputs = tokenizer(question, self.context, return_tensors="pt", truncation=True, max_length=512)
            with torch.no_grad():
                outputs = model(**inputs)
            start_idx = torch.argmax(outputs.start_logits)
            end_idx = torch.argmax(outputs.end_logits)
            answer_tokens = inputs.input_ids[0][start_idx:end_idx + 1]
            return tokenizer.decode(answer_tokens, skip_special_tokens=True)
        else:
            # Generative QA using T5's text-to-text input format
            inputs = tokenizer(f"question: {question} context: {self.context}", return_tensors="pt", truncation=True)
            with torch.no_grad():
                outputs = model.generate(**inputs, max_length=50)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
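The routing logic itself doesn't depend on GPT-4, so you can exercise it with hand-crafted score tuples. A minimal sketch (the scores below are made up purely to force each branch):

# Exercise the router without any GPT-4 calls, using made-up scores
router = ModelRouter(context="Transfer learning reuses knowledge from pre-trained models.")

# Score tuples are (simple, moderate, complex); the argmax picks the model
print(router.route("What is AI?", (0.9, 0.1, 0.0)))                 # → DistilBERT
print(router.route("What is transfer learning?", (0.1, 0.2, 0.7)))  # → T5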
3. The Complete Workflow
Finally, we integrate scoring and routing into a simple end-to-end workflow:
def main():
    scorer = QuestionComplexityScorer(api_key="your_openai_api_key")
    router = ModelRouter()
    test_questions = [
        "What is AI?",
        "How does reinforcement learning differ from supervised learning?",
        "Explain the significance of transfer learning in deep learning models with examples."
    ]
    model_names = {
        "simple": "DistilBERT (Lightweight model for simple QA)",
        "moderate": "BERT (Base model for moderate complexity QA)",
        "complex": "T5 (Advanced model for complex questions and generation)"
    }
    labels = ["simple", "moderate", "complex"]
    for question in test_questions:
        print("\n" + "=" * 80)
        scores = scorer.score_question(question)
        print(f"\nQuestion: {question}")
        print("Complexity Analysis:")
        for name, score in zip(["Simple", "Moderate", "Complex"], scores):
            print(f"  - {name}: {score:.2f}")
        selected = labels[scores.index(max(scores))]
        print(f"Selected Model: {model_names[selected]}")
        answer = router.route(question, scores)
        print(f"Generated Answer: {answer}")

if __name__ == "__main__":
    main()
Outputs
Here’s the output:
================================================================================
Question: What is AI?
Complexity Analysis:
- Simple: 0.20
- Moderate: 0.60
- Complex: 0.20
Selected Model: BERT (Base model for moderate complexity QA)
Generated Answer: artificial intelligence
================================================================================
Question: How does reinforcement learning differ from supervised learning?
Complexity Analysis:
- Simple: 0.10
- Moderate: 0.40
- Complex: 0.50
Selected Model: T5 (Advanced model for complex questions and generation)
Generated Answer: learning by interacting with an environment and receiving rewards or penalties
================================================================================
Question: Explain the significance of transfer learning in deep learning models with examples.
Complexity Analysis:
- Simple: 0.10
- Moderate: 0.30
- Complex: 0.60
Selected Model: T5 (Advanced model for complex questions and generation)
Generated Answer: allows models to leverage knowledge from pre-trained models
Conclusion
This solution showcases the power of combining LLMs and model routing to build efficient, context-aware systems: lightweight models handle simple questions cheaply and quickly, while the heavyweight models are reserved for the queries that actually need them.
For even better routing decisions and answer quality, we could further explore Reinforcement Learning from Human Feedback (RLHF) to fine-tune the complexity scoring and model outputs. RLHF has shown great potential in aligning AI outputs with human preferences, making it an exciting direction for future iterations.
Give it a try and let me know how it goes on your end. Have fun experimenting!