Learning to Reason: Enhancing Robustness and Generalization in LLMs for Machine Reading Comprehension

Before We Dive In...

AI is advancing rapidly, but how much do language models really understand? Over the past few months, Shivam Chourasia and I have been researching this question, uncovering how LLMs often rely on surface-level patterns rather than true reasoning. This has major implications for their reliability in real-world applications.

In this blog, I’ll share what we found and how reasoning-based augmentation can help AI models become more robust. I’d love to hear your thoughts—let’s explore this together!

Introduction

In recent years, large language models (LLMs) such as OpenAI's GPT, Meta's Llama, and Google's Gemini have revolutionized natural language processing (NLP). These models are trained on vast amounts of text and exhibit impressive language comprehension abilities. Yet despite their success, they still struggle with true reasoning and generalization. Instead of genuinely understanding a passage, they often rely on spurious correlations: superficial patterns in the dataset that let them guess answers correctly without real comprehension.

One critical area where this issue manifests is Machine Reading Comprehension (MRC). MRC is a subfield of NLP in which models read a given text and answer questions based on it. Unlike simple retrieval-based systems, MRC requires deep comprehension, logical inference, and reasoning. SQuAD (the Stanford Question Answering Dataset) is one of the benchmark datasets most widely used to train and evaluate such models.

This article explores our research on improving MRC by incorporating reasoning-based data augmentation. This approach significantly enhances robustness against adversarial examples and improves generalization across different domains. We achieve this by fine-tuning Llama 3.2 1B using LoRA (Low-Rank Adaptation) [Hu et al., 2022] and enriching the training data with synthetic reasoning explanations generated by the Gemini Pro model.

Understanding the Problem: Spurious Correlations in MRC

Before diving into our solution, let's understand a fundamental issue in machine reading comprehension: spurious correlations.

What are Spurious Correlations?

A spurious correlation occurs when a model finds shortcuts in data instead of truly understanding it. Here are some common examples:

  • Positional Bias: If answers frequently appear at the beginning of a paragraph in a dataset, a model may learn to always select the first sentence instead of reasoning through the text.
  • Keyword Association: If many questions mentioning "Newton" have answers like "gravity," the model may blindly associate "Newton" with "gravity" without actually understanding the context.
  • Superficial Patterns: If questions phrased in a certain way always have similar answers, the model might just memorize the format rather than actually comprehending the passage.

These biases reduce the model's ability to generalize beyond its training data, making it vulnerable to adversarial attacks and domain shifts [Jia & Liang, 2017].

Our Approach: Enhancing MRC with Reasoning Augmentation

To tackle these challenges, we propose a novel solution: augmenting training data with reasoning-based explanations. The key idea is to make models not just predict answers but also learn the logical reasoning behind them.

Synthetic Reasoning Data Generation

To improve the LLM's comprehension, we use Gemini Pro to generate a reasoning explanation for each question-answer pair in the SQuAD dataset. The generated reasoning is then appended as an additional feature to the existing SQuAD records. This approach has the following advantages:

  • Provides a step-by-step logical reasoning path for the LLM
  • Helps the model understand how the passage leads to the correct answer
  • Reduces reliance on spurious correlations

Example of Reasoning-Augmented Data:

Passage: "The social identity of the children was strongly determined by the tribe's kinship system. Among the matrilineal tribes of the Southeast, the mixed-race children generally were accepted as and identified as Indian, as they gained their social status from their mother's clans and tribes, and often grew up with their mothers and their male relatives. By contrast, among the patrilineal Omaha, for example, the child of a white man and Omaha woman was considered "white"; such mixed-race children and their mothers would be protected, but the children could formally belong to the tribe as members only if adopted by a man. "

Question: "What were multiracial children with a tribal mother considered to be in Southeast tribes?"

Answer: "Indian"

Generated Reasoning: "Since these tribes followed matrilineal descent, children inherited their identity from their mother’s clan. The passage explicitly states that mixed-race children in these tribes were accepted as Indian."

With this enriched dataset, the model learns to justify its answers rather than guessing based on patterns.
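
Below is a minimal sketch of how such reasoning explanations could be generated and attached to SQuAD records, assuming the google-generativeai and Hugging Face datasets libraries are installed and an API key is configured. The prompt wording, model name, and the small data slice are illustrative choices, not the exact pipeline used in our experiments.

```python
# Sketch: augment SQuAD records with Gemini-generated reasoning.
# Assumptions: google-generativeai and datasets are installed, and a
# valid API key is available; prompt and model name are illustrative.
import google.generativeai as genai
from datasets import load_dataset

genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder key
gemini = genai.GenerativeModel("gemini-pro")

# A small slice keeps the illustration cheap; the full pipeline would
# iterate over the entire training split.
squad = load_dataset("squad", split="train[:100]")

def generate_reasoning(example):
    """Ask Gemini to explain how the passage supports the answer."""
    prompt = (
        "Passage: " + example["context"] + "\n"
        "Question: " + example["question"] + "\n"
        "Answer: " + example["answers"]["text"][0] + "\n"
        "Explain step by step how the passage leads to this answer."
    )
    response = gemini.generate_content(prompt)
    example["reasoning"] = response.text.strip()
    return example

# Append the generated reasoning as a new feature alongside the
# original SQuAD fields (context, question, answers).
augmented = squad.map(generate_reasoning)
augmented.save_to_disk("squad_with_reasoning")
```

In practice, the generation loop would also need rate limiting and light validation of the returned explanations before they are added to the training set.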

Supervised Fine-Tuning with LoRA

The augmented dataset retained all original SQuAD features (e.g., question, passage, answer) and introduced reasoning as an additional feature. This reasoning-augmented dataset was then used to fine-tune the LLM via supervised fine-tuning (SFT).

Supervised Fine-Tuning (SFT) is the process of refining a pre-trained model on a labeled dataset, where the model learns from input-output pairs under human supervision to adapt to specific tasks.

However, fine-tuning is computationally expensive. Instead of fully fine-tuning the LLM, we employed Low-Rank Adaptation (LoRA) with SFT.

Low-Rank Adaptation (LoRA) is an efficient technique that freezes most of the model's weights and introduces small trainable low-rank weight matrices to adapt the model.

LoRA has been shown to be an effective and resource-efficient way to fine-tune large models for domain-specific tasks [Hu et al., 2022]. By fine-tuning Llama 3.2 1B on the SQuAD dataset, we improve its ability to answer questions accurately.
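
As a rough illustration, the snippet below shows one way to run LoRA-based SFT on the reasoning-augmented data with the Hugging Face transformers and peft libraries. The model identifier, target modules, prompt template, and hyperparameters are assumptions made for the sketch, not our exact training configuration.

```python
# Sketch: LoRA fine-tuning of a causal LM on reasoning-augmented SQuAD.
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "meta-llama/Llama-3.2-1B"  # assumed Hugging Face model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(model_id)

# LoRA freezes the base weights and trains small low-rank adapters
# injected into the attention projection matrices.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only a small fraction of weights are trainable

dataset = load_from_disk("squad_with_reasoning")  # output of the augmentation step

def tokenize(example):
    # Train the model to produce the reasoning before the final answer.
    text = (
        f"Passage: {example['context']}\n"
        f"Question: {example['question']}\n"
        f"Reasoning: {example['reasoning']}\n"
        f"Answer: {example['answers']['text'][0]}{tokenizer.eos_token}"
    )
    return tokenizer(text, truncation=True, max_length=1024)

tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="llama-3.2-1b-squad-reasoning",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        learning_rate=2e-4,
    ),
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```

Because only the adapter matrices are updated, a setup like this typically fits on a single consumer GPU for a 1B-parameter model.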

Results: Improved Generalization and Robustness

In this section, we present the results of our experiments across the three datasets: SQuAD, BioASQ, and SQuADAdversarial. The performance of the LLM was evaluated using Exact Match (EM) and F1 Score, which are standard evaluation metrics for question-answering tasks.

Exact Match (EM): Measures how often the model's answer exactly matches the ground-truth answer after simple normalization.
F1 Score: The harmonic mean of token-level precision and recall between the predicted and ground-truth answers, giving partial credit for overlapping but non-identical answers.
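
For readers who want the mechanics, here is a minimal sketch of SQuAD-style EM and F1 for a single prediction and reference, using the usual normalization (lowercasing, stripping punctuation and articles). The official SQuAD evaluation script follows the same idea and additionally takes the maximum score over multiple reference answers.

```python
# Sketch: SQuAD-style Exact Match and token-level F1.
import re
import string
from collections import Counter

def normalize(text: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, ground_truth: str) -> float:
    return float(normalize(prediction) == normalize(ground_truth))

def f1_score(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize(prediction).split()
    gold_tokens = normalize(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match("the Indian", "Indian"))       # 1.0 after normalization
print(f1_score("accepted as Indian", "Indian"))  # 0.5: partial token overlap
```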

We compared the performance of the model in three settings: zero-shot inference, fine-tuning with supervised fine-tuning (SFT), and fine-tuning with reasoning-augmented data (SFT-Reasoning).

The following table records the performance of various methods tested.


Table 1: Performance comparison of the LLM across datasets in three settings: zero-shot, SFT, and SFT-Reasoning. Metrics include Exact Match (EM) and F1 Score, reported as percentages (%).

Key Takeaways:

  • Fine-tuning greatly boosts performance on all datasets, showing that adapting to specific tasks is crucial.
  • Reasoning augmentation further enhances performance, especially in out-of-domain biomedical questions and adversarial examples.
  • While improvements were observed, adversarial robustness remains a challenge, indicating the need for further refinements.

Understanding the Impact of Reasoning Augmentation

To further assess the model’s robustness, we examined its ability to correctly handle adversarial modifications using SQuADAdversarial. Table 2 shows correct predictions made by the LLM after reasoning-based SFT with LoRA, demonstrating improved robustness against adversarial inputs.


Table 2: Examples of correct predictions for SQuADAdversarial

Table 3, on the other hand, highlights failure cases where reasoning augmentation still struggled against complex distractors.


Table 3: Examples of incorrect predictions for SQuADAdversarial

Challenges and Future Directions

While our approach shows promising results, challenges remain:

  • Handling adversarial attacks: Despite improvements, the model still struggles with heavily paraphrased adversarial questions.
  • Scalability: Generating reasoning explanations for large datasets requires additional computational resources.
  • Evaluation Metrics: Standard metrics like EM may underestimate performance, as they don't credit answers that are semantically correct but worded differently from the ground truth.

Future Work:

  • Incorporating adversarial training to improve robustness.
  • Exploring self-supervised learning to reduce reliance on labelled data.
  • Refining evaluation metrics to better assess comprehension quality.

Conclusion

Our research demonstrates that augmenting datasets with reasoning explanations significantly improves the robustness and generalization of MRC models. However, as shown in Table 3, adversarial robustness remains a challenge. Future work will focus on adversarial training and refined evaluation metrics to further enhance model resilience.

As we continue to explore the depths of AI, it’s clear that there’s still so much to learn. While we’ve made progress, the road ahead is full of both challenges and exciting possibilities.

I’m grateful for the opportunity to dive into this research with Shivam Chourasia. I look forward to seeing how the AI community builds on these insights. Thank you for reading, and I can’t wait to hear your thoughts as we keep pushing the boundaries of what AI can really understand!

