Geek Out Time: AI Model Routing — Dynamically Choose Models Based on Question Complexity

(Also published on the Constellar tech blog: https://medium.com/the-constellar-digital-technology-blog/geek-out-time-ai-model-routing-dynamically-choose-models-based-on-question-complexity-c6a37cbeef85)

Dynamically routing questions to AI models based on their complexity is an interesting and cost-effective approach. Whether the task is answering simple trivia or explaining a complex topic, this solution uses GPT-4 to score question complexity and then routes each query to the most suitable model for generating the answer. Let's dive in and geek out!

The Idea

AI models excel in different tasks based on their size and training. Why use a heavyweight model for a simple question like “What is AI?” when a lightweight model like DistilBERT can do the job? Conversely, you wouldn’t want DistilBERT to tackle a complex topic like “Explain the significance of transfer learning in deep learning models.”

Now let’s walk through:

  • Scoring question complexity using GPT-4.
  • Routing questions to DistilBERT, BERT, or T5 based on complexity.
  • Generating accurate answers tailored to the question’s difficulty.

1. Scoring Question Complexity

We use GPT-4 to classify questions into three categories:

  • Simple
  • Moderate
  • Complex

Implementation

Here’s the code for scoring question complexity. If GPT-4 fails to provide scores, we use a fallback mechanism based on the question’s length.

import re
import logging
from typing import Tuple

from openai import OpenAI

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class QuestionComplexityScorer:
    """Score question complexity using GPT-4, with a length-based fallback."""

    def __init__(self, api_key: str, model: str = "gpt-4"):
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def _fallback_scores(self, question: str) -> Tuple[float, float, float]:
        """Fallback scoring based on question length."""
        logger.debug(f"Using fallback heuristic for question: {question}")
        words = len(question.split())
        if words <= 5:
            return (1.0, 0.0, 0.0)  # Simple
        elif words <= 15:
            return (0.0, 1.0, 0.0)  # Moderate
        else:
            return (0.0, 0.0, 1.0)  # Complex

    def get_gpt4_scores(self, question: str) -> Tuple[float, float, float]:
        """Get complexity scores using GPT-4."""
        prompt = f"""Rate the complexity of the following question with numbers between 0 and 1.
Respond ONLY in the format: Simple: <number>, Moderate: <number>, Complex: <number>.
The numbers must sum to 1.
Example 1:
Question: What is the capital of France?
Answer: Simple: 0.9, Moderate: 0.1, Complex: 0.0
Example 2:
Question: Explain quantum mechanics in simple terms.
Answer: Simple: 0.1, Moderate: 0.4, Complex: 0.5
Now evaluate this question:
Question: {question}
Answer:"""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0,
                max_tokens=50
            )
            gpt4_response = response.choices[0].message.content
            logger.debug(f"GPT-4 response for question '{question}': {gpt4_response}")
            # Parse the three scores out of the response using a regex
            scores = re.findall(r"(?:Simple|Moderate|Complex):\s*(\d*\.?\d+)", gpt4_response)
            if len(scores) != 3:
                raise ValueError("Invalid GPT-4 response")
            return tuple(float(score) for score in scores)
        except Exception as e:
            logger.error(f"GPT-4 scoring failed: {e}")
            return self._fallback_scores(question)

    def score_question(self, question: str) -> Tuple[float, float, float]:
        """Public entry point: score a question, falling back to the heuristic on failure."""
        return self.get_gpt4_scores(question)
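As a quick sanity check, here's a minimal sketch of calling the scorer on its own. Reading the key from the OPENAI_API_KEY environment variable is my assumption; use whatever key management you prefer.

import os

# Minimal sketch: standalone use of the scorer defined above.
# Assumes the API key is available as the OPENAI_API_KEY environment variable.
scorer = QuestionComplexityScorer(api_key=os.environ["OPENAI_API_KEY"])

simple, moderate, complex_ = scorer.score_question("What is AI?")
print(f"Simple: {simple:.2f}, Moderate: {moderate:.2f}, Complex: {complex_:.2f}")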

2. Routing Questions to the Right Model

Once we have the complexity scores, we route the question to one of the following:

  • DistilBERT for simple questions.
  • BERT for moderately complex questions.
  • T5 for complex and generative tasks.

Implementation

Here’s the code for routing and processing questions using the selected model:

from typing import Tuple

import torch
from transformers import (
    DistilBertForQuestionAnswering, DistilBertTokenizer,
    BertForQuestionAnswering, BertTokenizer,
    T5ForConditionalGeneration, T5Tokenizer
)


class ModelRouter:
    """Route questions to the appropriate model based on complexity scores."""

    def __init__(self, context: str = ""):
        self.context = context or "AI is the simulation of human intelligence in machines."
        self.models = {
            "simple": (
                DistilBertForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad"),
                DistilBertTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
            ),
            "moderate": (
                BertForQuestionAnswering.from_pretrained("deepset/bert-base-uncased-squad2"),
                BertTokenizer.from_pretrained("deepset/bert-base-uncased-squad2")
            ),
            "complex": (
                T5ForConditionalGeneration.from_pretrained("t5-base"),
                T5Tokenizer.from_pretrained("t5-base")
            )
        }

    def route(self, question: str, scores: Tuple[float, float, float]) -> str:
        """Route the question to the appropriate model and generate an answer."""
        labels = ["simple", "moderate", "complex"]
        selected_label = labels[scores.index(max(scores))]
        model, tokenizer = self.models[selected_label]
        if selected_label in ["simple", "moderate"]:
            # Extractive QA: pick the highest-scoring start/end span from the context.
            inputs = tokenizer(question, self.context, return_tensors="pt", truncation=True, max_length=512)
            with torch.no_grad():
                outputs = model(**inputs)
            start_idx = torch.argmax(outputs.start_logits)
            end_idx = torch.argmax(outputs.end_logits)
            answer_tokens = inputs.input_ids[0][start_idx:end_idx + 1]
            return tokenizer.decode(answer_tokens, skip_special_tokens=True)
        else:
            # Generative QA: T5 produces free-form text conditioned on question + context.
            inputs = tokenizer(f"question: {question} context: {self.context}", return_tensors="pt", truncation=True)
            with torch.no_grad():
                outputs = model.generate(**inputs, max_length=50)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
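To exercise the routing logic on its own, here's a minimal sketch that feeds hand-picked scores to the router; the scores and the expected answer are illustrative assumptions, not actual outputs.

# Minimal sketch: standalone use of the router defined above with hand-picked scores.
router = ModelRouter()

# Pretend the scorer judged this question mostly "simple", so DistilBERT handles it.
answer = router.route("What is AI?", scores=(0.8, 0.15, 0.05))
# With the default context, expect a span like "the simulation of human intelligence in machines"
print(answer)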

3. The Complete Workflow

Integrate scoring and routing into a simple workflow.

def main():
    scorer = QuestionComplexityScorer(api_key="your_openai_api_key")
    router = ModelRouter()
    test_questions = [
        "What is AI?",
        "How does reinforcement learning differ from supervised learning?",
        "Explain the significance of transfer learning in deep learning models with examples."
    ]
    for question in test_questions:
        print("\n" + "=" * 80)
        scores = scorer.score_question(question)
        print(f"\nQuestion: {question}")
        print(f"Complexity Scores: Simple: {scores[0]:.2f}, Moderate: {scores[1]:.2f}, Complex: {scores[2]:.2f}")
        answer = router.route(question, scores)
        print(f"Answer: {answer}")


if __name__ == "__main__":
    main()

Outputs

Here’s the output from a sample run:

================================================================================

Question: What is AI?

Complexity Analysis:
- Simple:   0.20
- Moderate: 0.60
- Complex:  0.20

Selected Model: BERT (Base model for moderate complexity QA)

Generated Answer: artificial intelligence

================================================================================

Question: How does reinforcement learning differ from supervised learning?

Complexity Analysis:
- Simple:   0.10
- Moderate: 0.40
- Complex:  0.50

Selected Model: T5 (Advanced model for complex questions and generation)

Generated Answer: learning by interacting with an environment and receiving rewards or penalties

================================================================================

Question: Explain the significance of transfer learning in deep learning models with examples.

Complexity Analysis:
- Simple:   0.10
- Moderate: 0.30
- Complex:  0.60

Selected Model: T5 (Advanced model for complex questions and generation)

Generated Answer: allows models to leverage knowledge from pre-trained models        

Conclusion

This solution showcases the power of combining LLMs and model routing to build efficient, context-aware systems:

  • Dynamic Resource Allocation: Lightweight models for simple tasks, saving compute power.
  • Scalability: Seamlessly handles a wide variety of questions, from simple to complex.
  • Customizability: Easily expandable to include more models, refine the scoring logic, or enhance the context (a quick sketch follows below).
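For instance, here's a minimal sketch of swapping in a stronger generative model for the complex tier. The choice of google/flan-t5-base is an illustrative assumption; any T5-style checkpoint with the same interface would do.

from transformers import T5ForConditionalGeneration, T5Tokenizer

router = ModelRouter()

# Hypothetical swap: back the "complex" tier with a larger instruction-tuned checkpoint.
# flan-t5-base shares T5's interface, so route() keeps working unchanged.
router.models["complex"] = (
    T5ForConditionalGeneration.from_pretrained("google/flan-t5-base"),
    T5Tokenizer.from_pretrained("google/flan-t5-base")
)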

For even better routing decisions and answer quality, we could further explore Reinforcement Learning from Human Feedback (RLHF) to fine-tune the complexity scoring and model outputs. RLHF has shown great potential in aligning AI outputs with human preferences, making it an exciting direction for future iterations.

Give it a try and let me know how it goes on your end. Have fun experimenting!
