Exciting week experimenting with the newly released LLaMA 3.2 (1B) model, applied in a unique context: multi-label classification on medical records.
Medical Dataset: The dataset comprises transcriptions from over 3,000 anonymized medical records. Labels are multi-faceted, covering medical procedures, conditions, and specialties (e.g., “Surgery,” “Pneumonia,” “MRI”), with over 10,000 unique labels spanning both frequent and rare categories.
Our goal was a scalable, accurate model that classifies complex medical text (e.g., surgical notes, medical histories) into those categories, spanning medical specialties, procedures, conditions, and anatomical terms. Heavy class imbalance (e.g., “Surgery” vs. “Pediatrics”) and sparse labels made this far from a standard classification problem.
Multi-label classification of medical text presents specific hurdles, particularly with sparsely labeled data and lengthy transcripts. The LLaMA 3.2 1B model, built for generation tasks, required some creative adjustments to fit this setting.
Methods Used:
- Model Architecture: We leveraged the LLaMA 3.2 1B model with LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning on high-dimensional medical text, paired with BitsAndBytes quantization to keep the memory footprint manageable (see the setup sketch after this list).
- Dimensionality Reduction: As a Causal Language Model (CLM), LLaMA outputs logits over its entire vocabulary (~128K tokens), which is not ideal for a task that really needs the last_hidden_state. We therefore worked from the pooled hidden states instead, reducing the effective output dimensionality from the vocabulary size down to the 1B model's 2,048-dimensional hidden size before classification (see the pooling-head sketch after this list).
- Tokenization & Label Management: Multi-label encoding via scikit-learn's MultiLabelBinarizer handled the 10,000+ medical classes, with infrequent labels removed first. Tokenization used a max sequence length of 512 tokens for performance (see the encoding sketch after this list).
- Loss Functions: We experimented with several custom loss functions tailored for sparse multi-label classification (a frequency-weighted variant is sketched after this list):
  - MaskedBCELoss: targeted only relevant label positions.
  - SparseWeightedFocalLoss: focused on hard-to-predict classes and penalized false positives.
  - FrequencyWeightedFocalLoss: reweighted by label frequency for more balanced learning.
- Custom Validation Metrics: Designed task-specific metrics to evaluate (the overlap metric is sketched after this list):
  - Row-wise label-overlap accuracy.
  - Per-class precision, recall, and F1 scores.
  - Custom scoring that focuses on the labels relevant to each record while excluding trivial 0-vs-0 matches.
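A few sketches to make those bullets concrete. First, the fine-tuning setup: a minimal LoRA-plus-quantization configuration with Hugging Face transformers, peft, and bitsandbytes. The checkpoint name, rank, alpha, and target modules here are illustrative assumptions, not our exact config:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit quantization keeps the 1B model's memory footprint small.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",          # assumed checkpoint name
    quantization_config=bnb_config,
)

# LoRA adapters on the attention projections; rank/alpha are assumptions.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="FEATURE_EXTRACTION",     # we classify from hidden states
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```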
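Next, the dimensionality reduction: one way to bypass the vocabulary-sized logits is to mean-pool the final hidden states and attach a linear head over the label space. A minimal sketch, where the pooling strategy and num_labels are assumptions:

```python
import torch
import torch.nn as nn

class LlamaMultiLabelHead(nn.Module):
    """Pool LLaMA's last hidden layer and classify over the label space."""

    def __init__(self, backbone, hidden_size=2048, num_labels=10000):
        super().__init__()
        self.backbone = backbone                  # LLaMA body (e.g., the PEFT model)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        h = out.hidden_states[-1]                 # (batch, seq_len, 2048)
        # Masked mean-pooling so padding tokens don't dilute the representation.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.classifier(pooled)            # raw multi-label logits
```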
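The label encoding and tokenization step can look like the following: drop infrequent labels, binarize the rest with MultiLabelBinarizer, and tokenize at 512 tokens. The toy records and frequency cutoff are placeholders:

```python
from collections import Counter
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer

# Toy records standing in for the real transcriptions.
records = [
    ("Arthroscopic repair of the right knee ...", ["Surgery", "Orthopedic"]),
    ("Chest X-ray shows bilateral infiltrates ...", ["Radiology", "Pneumonia"]),
]

# Filter out infrequent labels first (MIN_COUNT=1 only for the toy data;
# a real cutoff would be higher).
MIN_COUNT = 1
counts = Counter(label for _, labels in records for label in labels)
kept = [[l for l in labels if counts[l] >= MIN_COUNT] for _, labels in records]

mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(kept)                # (n_records, n_labels) 0/1 matrix

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token  # LLaMA ships without a pad token
enc = tokenizer(
    [text for text, _ in records],
    truncation=True, max_length=512, padding="max_length",
    return_tensors="pt",
)
```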
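Our exact loss implementations aren't reproduced in this post, but a frequency-weighted focal loss for multi-label BCE can be sketched as below; the inverse-frequency weighting and gamma value are assumptions about one reasonable variant:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FrequencyWeightedFocalLoss(nn.Module):
    def __init__(self, label_counts, gamma=2.0, eps=1e-6):
        super().__init__()
        # Inverse-frequency class weights, normalized to mean 1,
        # so rare labels are penalized more heavily.
        freqs = label_counts.float() / label_counts.sum()
        w = 1.0 / (freqs + eps)
        self.register_buffer("weights", w / w.mean())
        self.gamma = gamma

    def forward(self, logits, targets):
        # Unreduced per-element BCE so we can reweight it below.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        p_t = targets * p + (1 - targets) * (1 - p)  # prob. assigned to the truth
        focal = (1.0 - p_t) ** self.gamma * bce      # down-weight easy positions
        return (self.weights * focal).mean()
```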
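Finally, a row-wise overlap metric that ignores 0-vs-0 matches, so agreement on the thousands of absent labels doesn't inflate the score. The 0.5 decision threshold is an assumption:

```python
import torch

def rowwise_label_overlap(logits, targets, threshold=0.5):
    """Per-record Jaccard overlap between predicted and true label sets."""
    preds = (torch.sigmoid(logits) >= threshold).float()
    intersection = (preds * targets).sum(dim=1)
    # The union counts a position only if *either* side is 1, which is
    # exactly what drops the uninformative 0-vs-0 matches.
    union = ((preds + targets) > 0).float().sum(dim=1)
    return (intersection / union.clamp(min=1)).mean()
```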
Key Observations:
- High-Frequency Label Bias: The model frequently predicted common labels (e.g., “Surgery,” “MRI”), indicating a need to counterbalance label frequency.
- Sparse Predictions: The model generated fewer predictions for low-frequency labels. The custom loss functions helped, but this remains a focus for future work.
- Reduced Overfitting: Dropout adjustments helped control overfitting on high-frequency labels, while gradient checkpointing kept memory usage in check during fine-tuning.
Results (Sample):
- Row-wise accuracy improvements: Incorporating label masks and adaptive learning rates pushed the model toward meaningful, diverse predictions. Initial results showed solid accuracy on frequent labels such as “Orthopedic,” “Cardiovascular / Pulmonary,” “Radiology,” “MRI,” and “Surgery,” but sparsely occurring labels still need more tuning.
- Early experiments show an F1-score above baseline, with precision on top labels exceeding 70%.
Next Steps & Reflections:
This project has been incredibly insightful, revealing both the potential and the limitations of generative LLMs in the medical NLP domain.
- GenAI and Healthcare: These applications demonstrate GenAI’s potential to transform healthcare, especially through automated tagging and diagnosis.
- Challenges: Sparse, multi-label data requires specialized loss functions and careful balancing to prevent overfitting on frequent terms.
- Further Optimizations: To improve model performance, I propose:
  - Experimenting with adaptive learning rates that adjust dynamically during training (a minimal scheduler sketch follows this list).
  - Investigating alternative architectures or hybrid approaches (e.g., LLaMA with a dense-layer pipeline for specific labels).
  - Trying curriculum learning, progressively adding less frequent classes to improve generalization.
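For the adaptive learning-rate idea, one simple starting point is PyTorch's ReduceLROnPlateau, decaying the LR whenever validation F1 stalls. The optimizer settings and the stand-in model below are placeholders, not our training config:

```python
import torch

model = torch.nn.Linear(8, 8)   # stand-in for the fine-tuned model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

# Halve the LR whenever validation F1 fails to improve for two epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2,
)

# ... inside the training loop, after each validation pass:
# scheduler.step(val_f1)
```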
Let’s connect; I’d love to hear your views on GenAI and open-source LLM applications in healthcare!
#NLP #AI #MachineLearning #MedicalAI #TransformerModels #HealthcareInnovation #DataScience