MoE vs Ensemble (Part 2 for technical folks and AI folks)
The core idea of an ensemble model is this: say you are training a very simple model, for example a decision tree from sklearn, to learn a function on a particular data distribution. Let's say you are playing with the model hyperparameters and have arrived at a set of hyperparameter tuples that all give identical, plateauing loss on the training and validation datasets. How do you pick which of these hyperparameter values makes the most sense? In ensemble models you do not. You just take a weighted average of the outputs of all these models and use it as the final response.
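Here is a minimal sketch of that idea with sklearn decision trees. The dataset, the hyperparameter tuples, and the uniform weights are all illustrative assumptions, not tuned values; the point is only that we average instead of choosing.

```python
# Minimal sketch of ensemble averaging over equally performing decision trees.
# The hyperparameter tuples below are illustrative, not tuned values.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=1000, n_features=10, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Several hyperparameter tuples that (hypothetically) reached the same plateauing loss.
hyperparams = [
    {"max_depth": 4, "min_samples_leaf": 5},
    {"max_depth": 6, "min_samples_leaf": 10},
    {"max_depth": 8, "min_samples_leaf": 20},
]
models = [
    DecisionTreeRegressor(random_state=0, **hp).fit(X_train, y_train)
    for hp in hyperparams
]

# Instead of picking one model, take a (here uniform) weighted average of all predictions.
weights = np.ones(len(models)) / len(models)
y_pred = sum(w * m.predict(X_test) for w, m in zip(weights, models))
```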
What does this do?
Ensemble models make it super easy to avoid picking between equally performing models on the training data, and they give you some label diversity in the final output on the test distribution. This allows you to better account for slight differences between the training and test data distributions.
What does this not do?
Ensemble models do not increase the accuracy of the modelled relationship between input and output sequences or reduce its dimensionality.
Where does intuition collide with actionability?
So obviously the main issue is that by averaging over n models, you are moving the prediction of the model that was actually correct in its label/token prediction farther away from the true label, which is the opposite of what we want to do.
In summary: the original idea of Mixture of Experts (first proposed in 1991) was to resolve this dilemma by splitting the training data distribution into many sub-distributions, each of which is much simpler to model (the "local experts"). You would then have a manager model that delegates tasks to the different local experts, basically using softmax probabilities (as a starting point, of course). The loss function would then basically be a probability-weighted mean squared error.
As a result, if a particular expert is best at predicting a given token, we essentially increase its probability of getting picked, and we only change an expert's weights significantly when it has a high probability of being picked to predict that next token.
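A rough numpy sketch of that classic setup is below: a softmax gate assigns probabilities to simple "local experts", and the loss is the gate-probability-weighted squared error of each expert. The shapes and the linear experts are my own illustrative assumptions, not the 1991 paper's exact architecture.

```python
# Sketch of the classic mixture-of-experts objective: a softmax "manager"
# delegates to local experts, and the loss is a probability-weighted MSE.
import numpy as np

rng = np.random.default_rng(0)
d, k = 10, 4                          # input dim, number of local experts
x = rng.normal(size=d)                # one input example
y = rng.normal()                      # its scalar target

W_gate = rng.normal(size=(k, d))      # gating (manager) parameters
W_experts = rng.normal(size=(k, d))   # one linear local expert per row

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

p = softmax(W_gate @ x)               # probability of delegating to each expert
o = W_experts @ x                     # each expert's prediction for this input

# Probability-weighted squared error: experts with a high gate probability
# receive most of the gradient, so the best expert for this region specializes.
loss = np.sum(p * (y - o) ** 2)
```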
The idea of a mixture-of-experts model is that it is a sparse approach. What this means is that, like say a convolutional neural net, it stores relationships between distributions at a much lower dimensionality, which makes them easier to compute. For reference, compare the figure below with figure 1: the dimensionality of these regression lines is 2 in the MoE setup vs 10 in the ensemble model setup.
Okay, so if these models are so good, what is the issue with just scaling them to, say, a thousand experts? Good question. Google DeepMind actually did this for us and saved us something like ~$30mn in compute. There are a lot of problems with sparse MoE models in general:
- A lot more hyperparameters: This makes them super unstable to pre-train, similar to how CNNs needed a lot of new hyperparameters around pooling logic, window size, and so on.
- There is a magic number of experts: For example, scaling the number of experts to 1024 instead of 8 gives you only marginally improved returns. In my opinion, that's because the underlying dataset (most likely Common Crawl) does not have that many discretely partition-able sections where a local expert can specialize its relationship modelling enough to justify the reduction in model size.
- Complex "Routing" logic is unstable: This is the part of MoE that confuses people the most especially your non-tech folks. The router is a purely softmax probability based expert selector, all complex logic that has been tried here has been subpar compared to softmax.
- Expert learning is incomprehensible to humans: No matter which layer's expert-wise task delegation you look at, no major conclusions about the level of specialization can be drawn from the experts or the "router", whether grouped by source of the original training data (as seen in the Mixtral paper figure) or in the original sparse transformer paper by Google. But overall this does line up, because the model is drunk on its own token kool-aid.
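To make the routing point concrete, here is a minimal sketch of plain softmax top-k routing for a single token, in the style of modern sparse MoE layers such as Mixtral. The dimensions, the tiny linear "experts", and the variable names are all illustrative assumptions.

```python
# Minimal sketch of softmax top-k routing for one token in a sparse MoE layer.
import torch
import torch.nn.functional as F

d_model, n_experts, top_k = 16, 8, 2
x = torch.randn(d_model)                              # one token's hidden state

router = torch.nn.Linear(d_model, n_experts)          # the "router" is just a linear layer
experts = torch.nn.ModuleList(
    [torch.nn.Linear(d_model, d_model) for _ in range(n_experts)]
)

logits = router(x)                                    # one score per expert
weights, idx = torch.topk(F.softmax(logits, dim=-1), top_k)
weights = weights / weights.sum()                     # renormalize over the chosen experts

# Only the top-k experts run; their outputs are combined using the router weights.
y = sum(w * experts[i](x) for w, i in zip(weights, idx))
```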
Conclusion: Despite all these uncertainties and instabilities, the Mixtral model comfortably outperforms ChatGPT on the LMSYS leaderboard, showcasing that the architecture choice has spine.
But overall, I am very excited by the possibility that we can train sparse models to be stable (given a large enough dataset, of course). This would bring parameter counts way down and also enable a level of parallelization and federation in machine learning that was unforeseeable for dense transformers.
My 2 favorite parts of MoE that I am actively exploring are:
- Explainability of expert specialization in encoder-only models. Super useful if you are trying to train a sparse embedding model that creates multiple orthogonal axes for lossless dimension compression.
- Sparse embedding models cannot be parallelized due to the bottleneck of many-to-many communication. Soft MoE solves this for continuous inputs like images or videos, as shown here by Google. But applying it to text I/O is still an open problem. A rough sketch of the Soft MoE dispatch idea follows below.
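As context for that second point, here is a toy sketch of the Soft MoE dispatch/combine idea: instead of hard token-to-expert routing, every expert slot receives a softmax-weighted average of all input tokens, and the slot outputs are softly mixed back into per-token outputs. This is a simplified single-expert version under my own assumed shapes, not Google's implementation.

```python
# Toy sketch of Soft MoE style dispatch/combine (no hard routing).
import torch
import torch.nn.functional as F

n_tokens, d_model, n_slots = 64, 16, 8                # e.g. image patch tokens
x = torch.randn(n_tokens, d_model)

phi = torch.randn(d_model, n_slots)                   # learnable dispatch/combine parameters
logits = x @ phi                                      # (n_tokens, n_slots)

dispatch = F.softmax(logits, dim=0)                   # normalize over tokens -> slot inputs
combine = F.softmax(logits, dim=1)                    # normalize over slots  -> token outputs

slots = dispatch.T @ x                                # each slot is a weighted average of all tokens
expert = torch.nn.Linear(d_model, d_model)            # one shared toy "expert" for brevity
slot_out = expert(slots)

y = combine @ slot_out                                # per-token outputs, mixed from all slots
```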
Have you worked with MoE models before? What was your experience? Share it in the comments.