Getting started with Quantization in PyTorch
The current state of the AI world is very exciting. Never have the tools for AI been so accessible and easy to use for everyone. With large language models riding the hype train, the push to build ever bigger and better models has been ongoing for quite some time. We are looking at models that easily consume 80GB of VRAM, and most companies and startups are focused on throwing more and more GPUs at deploying and running them.
While one part of the field keeps pushing the boundaries of computation, another part of the community is focused on running these models at the lowest possible requirements, and surprisingly, in some cases the tradeoff between accuracy and computation has been negligible even compared to the full-sized models.
There are different aspects to improving the efficiency of these models.
On the model optimization side, quantization is a technique to decrease the memory and computation requirements of these models while also reducing their latency on the hardware. While there are different quantization techniques, some hardware-specific and some global, most of them follow the basic principle of changing the precision of the numbers involved. Before looking at the techniques used to run some of these large LLMs on the lowest common denominator of hardware, let's cover the basics: the underlying principles and the tradeoffs with respect to accuracy and latency.
Quantization Principle
Traditionally, models are loaded in FP32, i.e., 32-bit floating point precision. Precision defines how much floating point or integer information each weight of the model can store. As a toy example, a value stored with 3 bits of precision can only take 2^3 = 8 distinct values. But does precision matter so much? Let's look at some numbers.
2^64 = 18,446,744,073,709,551,616
2^32 = 4,294,967,296
2^16 = 65,536
2^8 = 256
While 2^8, 2^16 and 2^32 at a glance may not seem that different from each other, in reality the gap is massive. A weight stored at 8-bit precision can take one of only 256 possible values, a 16-bit weight one of 65,536, and a 32-bit weight one of over four billion. Wider data types mean larger numbers to multiply and add, so even for a model with just two layers the cost of the matrix computations grows quickly as the precision increases; from the calculation above, the jump for 64-bit values hardly needs mentioning.
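As a rough illustration of what this means for memory (not tied to any particular model; the parameter count below is hypothetical), the snippet compares how much space the same number of values occupies at different precisions:

```python
import torch

# Hypothetical parameter count, just to make the scaling visible.
n_params = 10_000_000

for dtype in (torch.float32, torch.float16, torch.int8):
    t = torch.zeros(n_params, dtype=dtype)
    size_mb = t.element_size() * t.nelement() / 1024 ** 2
    print(f"{str(dtype):15s} -> {size_mb:.1f} MB")

# Roughly: float32 ~38.1 MB, float16 ~19.1 MB, int8 ~9.5 MB for the same count.
```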
So if higher precision leads to higher computation requirements, lower inference speed and more complexity, why are these models the norm? The answer is accuracy.
Unlike traditional statistics and classical ML models, DL models are built on the concept of perceptrons, where each layer, in layman's terms, can be thought of as a set of nodes in a larger neural network, all connected to each other. Each weight starts out random. As the input data flows from one layer to the next, different operations are performed on it, and with the help of a loss/cost function and an optimizer the weights are updated according to the task and the data.
Basically, the overall concept is an attempt to replicate the logic of interconnected neurons in a human brain. The computations are very complex, and even small differences in the floating point values, down to the 10th decimal place, can make a big difference in the results. Loss computations in particular are very sensitive to changes in precision. Deep learning is still a relatively new field, with new learnings coming up every day, and replicating the flow of perceptrons has been found to be the most accurate approach, yet also the heaviest.
But with maturity in any field come the concepts and techniques to make it viable for day-to-day devices and applications. Over the years, researchers and ML practitioners have been tinkering with the question of whether we need the full model as it is for inference. The intuition is that, for some tasks, once a model has been trained, most of the trailing digits of the weights are not needed to do the same job, or that while high precision might be required for training, it is not required for inference.
Personally, I have found optimization through quantization most useful and practical in real-time vision tasks, where the bar for accuracy may be high but not as unforgiving as in language tasks. For sensitive domains such as healthcare and finance, optimization may still be far off, since the accuracy requirements there are too high and the consequences of mistakes are not forgiving.
Before we move to the next section, let me describe the reference model: a VGG16 model fine-tuned for image classification on a mixed dataset. The starting weights are the ImageNet-pretrained ones taken from the torchvision module of PyTorch.
On the surface it can seem that to perform quantization one can just take the model and convert the weights from higher precision to lower precision, and while that may be the definition, in practice it is much more than that. We have to consider the accuracy, the data type of the model's input and whether quantization affects it, whether the target hardware supports the method we are using, and so on.
I will be exploring this topic little by little, but let's get started with the basics first.
After training the model, I have kept the image size constant at 128 x 128 RGB. Also, to keep the comparisons simple, all inference is run on the CPU only.
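For reference, here is a minimal sketch of that kind of setup, assuming a recent torchvision; the actual fine-tuned checkpoint and dataset are not shown, so the pretrained ImageNet weights and a random tensor stand in for them:

```python
import torch
from torchvision import models, transforms

# Load VGG16 with ImageNet-pretrained weights as the starting point.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
model.eval()

# Keep inputs at 128 x 128 RGB, as described above.
preprocess = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Stand-in for a preprocessed image batch; inference is CPU-only.
dummy_input = torch.randn(1, 3, 128, 128)
with torch.inference_mode():
    logits = model(dummy_input)
print(logits.shape)
```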
Dynamic Quantization
Dynamic quantization is a post-training quantization method.
It is only applied to the linear layers of the model. For context, VGG16 is one of the earlier CNNs, and most of its parameters sit in linear layers rather than in complex batch-norm and activation functions, making it one of the models most likely to benefit from this method. The effects of dynamic quantization might be far less visible on models that have fewer linear layers relative to the other kinds.
It is the simplest of the quantization modules offered by PyTorch.
On applying this type of quantization, all the linear layers are converted to the INT8 precision type, which also shrinks the model when you save it locally for storage.
During inference the input tensors (input data) do not require any conversion beforehand. The model observes the distribution of each incoming batch on the fly, determines a scaling value from it, quantizes the activations accordingly, runs the linear operations against the INT8 weights, and dequantizes the result back to FP32. The weights are not restored to their original FP32 values; only the outputs are. The activations of the other function layers remain in FP32 throughout the inference process.
Since the dequantization and the analysis of the input distribution happen on the fly for every batch, it is computationally more expensive, but this also makes it more accurate. It is a good fit for models whose activation distributions are spread out.
It can be a good starting point for reducing a model's memory footprint, and depending on the distribution of the layers and activations it can also reduce inference time heavily.
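As a concrete illustration, here is a minimal sketch of applying PyTorch's dynamic quantization to the VGG16 reference model; the weights argument and the file name are illustrative:

```python
import torch
import torch.ao.quantization as quant
from torchvision import models

# Start from the FP32 reference model.
model_fp32 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
model_fp32.eval()

# Quantize only the nn.Linear layers; everything else stays in FP32.
model_int8 = quant.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},   # layer types to quantize
    dtype=torch.qint8,   # weights stored as signed 8-bit integers
)

# Inference is unchanged from the caller's point of view: inputs stay FP32.
x = torch.randn(1, 3, 128, 128)
with torch.inference_mode():
    out = model_int8(x)

# The saved state dict is noticeably smaller because the linear weights are INT8.
torch.save(model_int8.state_dict(), "vgg16_dynamic_int8.pth")
```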
Static Quantization
Static quantization is another post-training quantization method, and it consists of multiple steps:
First, we identify the layers that support quantization. We can then fuse sequences of operations, like Conv and BatchNorm, so they are quantized together, reducing computational overhead.
Using PyTorch's QuantStub (quantization) and DeQuantStub (dequantization), we mark the points in the flow of layers where quantization and dequantization should happen. This ensures that the input tensors do not need to be compressed beforehand and that the format of the output tensors does not change.
Next, we choose a representative dataset matching the type of input data the model will be exposed to. This dataset is used for calibration, i.e., collecting activation statistics to determine optimal quantization parameters (scale and zero-point values).
Using these collected statistics, the whole model is converted to INT8 based on the calibrated values and a quantization config.
This config value is what targets the architecture of the intended hardware, such as ARM or x86 (a minimal sketch of the whole workflow appears at the end of this section).
During inference there is no need to manually or separately convert the data into INT8 precision; the quantization stubs do it for us before feeding it to the hidden layers. All the operations are then performed in INT8 and the output is dequantized back to FP32.
This ahead-of-time quantization makes the computational performance better than the dynamic method, but since the activations are analyzed once during calibration and not on the fly for each batch of input tensors, the accuracy will usually be lower than with the dynamic one.
The latency is consistent for the same reason.
The ability to configure for each hardware target separately makes this type of quantization a very good fit for embedded machines.
One major con for this type of quantization, though, is that the quantization steps can be quite complex.
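Here is a minimal sketch of that static workflow on a toy CNN, assuming PyTorch's eager-mode quantization API in torch.ao.quantization; the layer names, calibration data and "fbgemm" backend choice are illustrative, not the article's actual VGG16 setup:

```python
import torch
import torch.nn as nn
import torch.ao.quantization as quant

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = quant.QuantStub()      # FP32 -> INT8 at the model input
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)
        self.dequant = quant.DeQuantStub()  # INT8 -> FP32 at the model output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return self.dequant(x)

model = ToyModel().eval()

# 1. Fuse supported op sequences (Conv + BN + ReLU) to cut overhead.
model = quant.fuse_modules(model, [["conv", "bn", "relu"]])

# 2. Pick a backend config: "fbgemm" for x86 CPUs, "qnnpack" for ARM.
model.qconfig = quant.get_default_qconfig("fbgemm")

# 3. Insert observers, then calibrate with a representative dataset.
prepared = quant.prepare(model)
for _ in range(32):                          # stand-in for a real calibration loader
    prepared(torch.randn(8, 3, 128, 128))

# 4. Convert: weights and activations of supported layers become INT8.
quantized = quant.convert(prepared)
with torch.inference_mode():
    out = quantized(torch.randn(1, 3, 128, 128))
```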
Quantization Aware Training (QAT)
Unlike the two post-training methods above, QAT brings quantization into the training (or fine-tuning) loop itself.
Theoretically this allows the model to "learn" and adapt to the lower INT8 precision and mitigate the accuracy loss associated with post-training methods like dynamic and static quantization.
The training happens as if the model were already quantized: the forward pass through the model includes simulated quantization of the weights and activations.
The values are "fake quantized": they are clipped and rounded to mimic INT8 representations, but the underlying computations still happen in floating point. This ensures that the gradients are calculated correctly during backpropagation.
The model is then trained with these fake-quantized weights, which teaches it to compensate for the loss introduced by INT8 precision.
After training, the model is fully quantized and exported to INT8 for inference, just like static quantization. But since the model was trained while seeing quantized values, it can handle them and retains its accuracy, as opposed to the former.
QAT offers higher accuracy and more customizability (fine-tuning) and configuration. Whatever workflow we use for training our custom models can be applied here. Unlike static quantization, we do not need to check or work around whether each function or layer's quantization is supported by PyTorch.
Like dynamic quantization, it is less sensitive to changes in the input tensors.
The biggest con of a method like this is that it is computationally very expensive, and without knowledge of how the original model was trained it is not possible to perform this kind of quantization at all. Also, most of the time quantization is done precisely because it is not possible to run the uncompressed model on the target hardware, so fine-tuning and retraining the model for QAT makes the whole process much more resource intensive.
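A minimal sketch of that QAT workflow, under the assumption that we reuse the ToyModel (with its QuantStub/DeQuantStub) defined in the static sketch above; the optimizer, loss and random data are placeholders rather than the article's actual fine-tuning setup:

```python
import torch
import torch.ao.quantization as quant

# Assumes ToyModel from the static quantization sketch is in scope.
model = ToyModel()
model.train()

# Fuse in training mode, attach a QAT qconfig, and insert fake-quant modules.
model = quant.fuse_modules_qat(model, [["conv", "bn", "relu"]])
model.qconfig = quant.get_default_qat_qconfig("fbgemm")
prepared = quant.prepare_qat(model)

optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(3):                    # placeholder fine-tuning loop
    x = torch.randn(8, 3, 128, 128)       # stand-in for real images
    y = torch.randint(0, 10, (8,))        # stand-in for real labels
    optimizer.zero_grad()
    loss = criterion(prepared(x), y)      # forward pass sees fake-quantized values
    loss.backward()                       # gradients still flow in FP32
    optimizer.step()

# After (re)training, convert to a real INT8 model for inference, like static.
prepared.eval()
quantized = quant.convert(prepared)
```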
Mixed Precision Training
Here, different parts of the model are trained at different precisions. I cannot say much about it from hands-on experience, since I have only trained entirely in half precision.
Certain operations, like loss calculations, are very sensitive to precision. During backpropagation, gradients computed in FP16 can be very small and prone to underflow. To counter this, loss scaling is applied to provide numerical stability and ensure that smaller gradients are not lost during training.
After loss scaling, the gradients are often accumulated in FP32 to maintain precision before being applied to the model weights.
'Automatic Mixed Precision' (AMP) tools help automate the mixed-precision training process. Personally, I have used PyTorch's autocast to account for the precision difference between the training data and labels and the half-precision model.
Normally, computations like matrix multiplications are performed in FP16 to reduce computation and latency, while precision-sensitive operations, whose gradients are hard to compute accurately at lower precision, are kept in FP32.
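To make the AMP flow concrete, here is a minimal training-loop sketch using PyTorch's torch.cuda.amp autocast and GradScaler; the linear model, random data and hyperparameters are placeholders, and a CUDA device is assumed:

```python
import torch

model = torch.nn.Linear(512, 10).cuda()        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()           # handles loss scaling automatically

for step in range(10):
    x = torch.randn(32, 512, device="cuda")    # stand-in for real data
    y = torch.randint(0, 10, (32,), device="cuda")
    optimizer.zero_grad()
    # Matrix multiplications run in FP16 inside autocast; precision-sensitive
    # ops are kept in FP32 by the autocast policy.
    with torch.cuda.amp.autocast():
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()              # scale the loss to avoid FP16 underflow
    scaler.step(optimizer)                     # unscale gradients, then step
    scaler.update()                            # adjust the scale factor for next step
```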
Theoretically this does not seem as complex as static quantization, with no specific need for model fusing, but the path to reasonable accuracy is not that straightforward. It is possible to train the model with the quant/de-quant stubs in place just like the uncompressed model, but getting the right number of epochs before disabling the quantization observers and freezing the batch-norm stats can be very tricky, and is not something I would recommend for beginners.
Conclusion
Just for reference, I have included the GPU latency alongside all the other runs on the CPU, to show how much more computationally effective a GPU makes running these DL models, but that is a topic for another day.
The uncompressed model (full FP32 precision) is the slowest, since no quantization has been applied to it, and any other result would have been contrary to the point I was trying to make with this write-up. Out of all the quantized methods, dynamic is the slowest, as expected, though it still shows roughly a 20% decrease in latency compared to the uncompressed model. Static and QAT have the highest inference speeds, with nearly identical numbers. So while static and QAT are complex to prepare and convert, the end result, as far as latency is concerned, is more stable and faster.
I have tested the mixed precision model on two different machines and the results are very odd: at most it should have been close to the actual FP32 latency, but it is abysmally high. Since I did not use any particular module for this quantization and simply trained directly in half precision, using PyTorch's autocast to handle the precision differences during computation, that naive approach might be the reason the latency is so far off, or maybe the autocast module has issues.
Comparison on the basis of latency is not the whole picture; we have to consider both variables in the tradeoff, which brings us to accuracy. More than latency, accuracy is usually the deciding factor in choosing or deploying a model. There are cases where a smaller model might be 3x faster than a bigger one, but even a 10% drop in accuracy would lead us to choose the bigger, slower, more accurate model. Where there is more leeway on accuracy and the task is not that critical, faster models are preferred, but that is entirely a case-by-case business decision.
The full precision model, at 94%, can be considered the baseline. While dynamic quantization gave around a 20% increase in inference speed, the accuracy remained the same. Whether that is specific to my choice of a model with a high number of linear layers, I am not sure. Mixed precision performed well too, but at that abysmal latency it is not something I would take to production.
Static and QAT both have very low accuracies but very high inference speeds. This brings us back to the problem of complex quantization techniques: for my target hardware I could have played around with different config options and tuned the calibration datasets along with the fused layers to find a model with higher accuracy, but for this article I wanted to keep a beginner's perspective on quantization. Achieving high accuracy with low latency can be a bit tricky and requires some patience, but it is achievable.
Quantization is a very valuable and crucial technique. Along with reducing memory footprint and computational cost, it is key to deploying models on resource-constrained platforms like mobile devices, embedded systems and even CPU-bound machines. It has had a transformative impact on the application of deep learning models in real life, with AI models running locally on our phones and in so many places we are not even aware of day to day.
Today, to leverage the best-in-class LLMs we might need to offload the whole computation to a cloud of GPUs, but as the industry matures and understands the nuances, the day is not far when we might run the next GPT clone on our Androids and iPhones.
Here is my notebook gist with the complete steps and flow for the methods I have discussed.