How can Explainable AI drive AI adoption in the enterprise?
Machine learning and deep learning based solutions are all over the news. They are helping create new products and drive innovative services. Big-name companies like Google, Amazon, Facebook, Twitter, LinkedIn, etc. are using these technologies in every area of their businesses. This has created a buzz across the industry. Other enterprises also want to tap into these technologies to create value for their customers, departments, and employees, and to bring new products and services to the market. The top management at these enterprises is urging departments to look for ways of using these technologies to drive innovation.

We, the data scientists, get excited by this news: more projects mean more opportunities to apply the coolest algorithms to the problems at hand. We go to our department head and tell her how excited we are about this new push from the top management. We start pitching a few ideas for new projects where machine learning can help solve the problems we are having. The department head listens intently to these ideas. Then she says, "This is all fine and dandy. But, you know what, I am a little sceptical about the results these are going to produce. Let us say that we invest in one of your ideas. The project is completed and you have great results to show in terms of numbers. But can the models used in the project help me understand what is really going on? I have heard that these models are very good at what they do but cannot reason about why they arrived at a result. What are we going to do if someone asks us this question? I do not trust a system which cannot give us a reason why it arrived at a particular result. I am sorry, but I don't think I can go ahead with this."

We are disheartened. What should we do? How can we convince the stakeholders to invest in these kinds of projects? I suggest that interpretable machine learning can help. In this post, we are going to take a look at the idea of explainable AI (interpretable machine learning). The rest of the post is organized into the following two sections:
- What is Machine learning Interpretability?
- Brief Deep-dive into the techniques.
What is Machine Learning Interpretability?
Interpretability is the degree to which a human can understand the cause of a decision. It follows from this definition that the higher the interpretability of a model, the easier it is for a person to comprehend why a certain prediction was made. When comparing two models, the one whose decisions are easier to comprehend can be said to be more interpretable.
We seek interpretability due to the following reasons:
- Curiosity and Learning - We humans are curious by nature. We seek to identify reasons for things that happen unexpectedly in all areas of our life. If our phone suddenly stops working, we try to identify what happened, and those of us who are technical enough often dig deeper to find out why it happened. Based on the identified reasons, we build a new mental model or update our existing model of the world around us. Similarly, in a business setting, when we learn something about how a model makes its predictions, we can modify our products, services, and processes.
- Detect Bias - Most machine learning models learn from historical data, and this data has been created by us. However, people sometimes cannot make decisions objectively, and this bias is passed on to the data in our systems. When we use this data to automate tasks, the bias can inadvertently be picked up by the machine learning models. To prevent this from happening and to make the systems fair and accountable, we need methods to understand the predictions so that appropriate steps can be taken.
- Social Acceptance - When machine learning models are integrated into our lives as products and services, we need models that can be trusted in order to drive social acceptance. For example, when we look at product recommendations from Amazon, we accept them because the site clearly shows us the reason for each recommendation, be it "Similar products" or "Frequently Bought Together".
A brief deep-dive into the techniques
We can attain model interpretability in many ways. Broadly, there are two categories of approaches that can be used to make a model interpretable:
- Interpretable models - Some machine learning models are more interpretable than others. Linear models like linear regression, logistic regression, and their variants are among the most interpretable, and decision trees also belong to this category. However, models that are inherently interpretable often sacrifice accuracy or predictive power in exchange for being easier to understand. For example, a linear regression model can only capture linear relationships between the features and the target variable, and it cannot identify interactions among the features unless they are explicitly specified.
- Model-Agnostic Methods - Model-agnostic methods are flexible. They can be applied to any model and can be thought of as a layer on top of the machine learning model. Below are some common model-agnostic methods:
I. Partial Dependence Plots - Partial dependence plots (PDPs) show the marginal effect of a feature on the predicted outcome of a model. A PDP can only be calculated after we already have a fitted model. Let us try to understand how it is calculated. Suppose we have a housing price prediction problem with different features - size, number of bedrooms, number of bathrooms, number of balconies, distance from the nearest city center, etc. - and we want to understand how the distance to the nearest city center affects the predicted prices. Suppose this distance takes different values d1, d2, d3, d4, and so on. We first set the distance of every record to d1 and calculate the average price prediction from the model. We then set the distance of every record to d2 and calculate the average prediction, and so on for all the distance values. Finally, we plot these average predictions against the distance values. This curve tells us the relative effect of the distance from the city center on the predicted price. A minimal sketch of this calculation is shown below.
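The snippet below is a rough illustration of this manual PDP calculation; the synthetic data, the choice of model, and column names such as `distance_to_city_center` are assumptions made purely for the example.

```python
# Minimal sketch of a manual partial dependence calculation.
# The data, model choice and column names are illustrative assumptions.
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
X = pd.DataFrame({
    "size_sqft": rng.uniform(500, 3000, 200),
    "bedrooms": rng.integers(1, 5, 200),
    "distance_to_city_center": rng.uniform(0, 30, 200),
})
y = 100 * X["size_sqft"] - 5000 * X["distance_to_city_center"] + rng.normal(0, 10000, 200)

model = GradientBoostingRegressor().fit(X, y)

# Grid of distance values d1, d2, ... to evaluate.
grid = np.linspace(X["distance_to_city_center"].min(),
                   X["distance_to_city_center"].max(), 20)

pdp = []
for d in grid:
    X_mod = X.copy()
    X_mod["distance_to_city_center"] = d        # force every record to distance d
    pdp.append(model.predict(X_mod).mean())     # average prediction at that distance

# Plotting `pdp` against `grid` gives the partial dependence curve.
```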
II. Individual Conditional Expectation - For a chosen feature, Individual Conditional Expectation (ICE) plots draw one line per instance, representing how that instance's prediction changes as the feature changes. To calculate the ICE curves for a feature, we take each instance, try out different values of that feature to get predictions, and plot those predictions against the feature values, repeating this for every instance in the dataset. ICE helps identify instances that deviate from the overall trend shown by the PDP. A short sketch follows below.
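Reusing the illustrative `X`, `model`, and `grid` from the PDP sketch above, the ICE curves can be computed like this:

```python
# One ICE curve per instance: vary the feature for a single row, keep the rest fixed.
# Reuses the illustrative X, model and grid from the PDP sketch above.
ice_curves = []
for i in range(len(X)):
    row = X.iloc[[i]].copy()
    preds = []
    for d in grid:
        row["distance_to_city_center"] = d
        preds.append(model.predict(row)[0])
    ice_curves.append(preds)                    # ice_curves[i] is instance i's line

# Averaging the ICE curves point-wise recovers the PDP from the previous sketch.
```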
III. Feature Importance - Feature importance is a very straightforward technique. After we have trained a model, we pass the instances through it to get predictions and evaluate them against the actual ground truth. Then we take a feature, randomly shuffle its values, recalculate the predictions, and evaluate them against the actual values again. If the model performs worse than before, we conclude that this feature is important for predicting the target. We repeat this for each feature in the dataset, and the relative drop in performance can be thought of as the relative importance of each feature. Feature importances give us an overall idea of which features are predictive of the target, but they do not tell us how the target varies when those features change. A short sketch of this procedure follows below.
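scikit-learn ships a helper, `permutation_importance`, that implements this shuffle-and-compare procedure; the snippet below applies it to the illustrative model and data from the earlier sketches.

```python
# Permutation feature importance via scikit-learn's helper, applied to the
# illustrative model and data from the sketches above.
from sklearn.inspection import permutation_importance

result = permutation_importance(model, X, y, n_repeats=10, random_state=0)
for name, drop in zip(X.columns, result.importances_mean):
    # Larger drop in score after shuffling a feature => more important feature.
    print(f"{name}: {drop:.3f}")
```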
IV. LIME (Local Interpretable Model-Agnostic Explanations) - LIME is a method of fitting local surrogate models that can explain individual predictions of any machine learning model. LIME follows the steps below to arrive at an explanation (a minimal sketch follows the list):
- Choose your instance of interest for which you want to have an explanation of its black box prediction.
- Perturb your dataset and get the black box predictions for these new points.
- Weight the new samples by their proximity to the instance of interest.
- Fit a weighted, interpretable model on the dataset with the variations.
- Explain the prediction by interpreting the local model.
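As a rough illustration, the open-source `lime` package wraps these steps for tabular data. The snippet below reuses the illustrative housing model and data from the earlier sketches; the instance index and `num_features` are arbitrary choices for the example.

```python
# Explaining a single prediction with the `lime` package (pip install lime).
# Reuses the illustrative housing model and data from the sketches above.
from lime.lime_tabular import LimeTabularExplainer

explainer = LimeTabularExplainer(
    training_data=X.values,
    feature_names=list(X.columns),
    mode="regression",
)
instance = X.values[0]                          # instance of interest
explanation = explainer.explain_instance(
    instance,
    model.predict,                              # black-box prediction function
    num_features=3,                             # size of the local explanation
)
print(explanation.as_list())                    # (feature condition, weight) pairs
```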
V. SHAP (SHapley Additive exPlanations) - SHAP is a technique based on Shapley values, which come from game theory. At a high level, a Shapley value is a player's marginal contribution, averaged over every possible sequence in which the players could have been added to the group. The same idea is applied to the features of a machine learning model, giving us each feature's marginal contribution to the prediction. SHAP values can also be calculated for individual instances; with instance-level SHAP values, we can pinpoint exactly which features drive a particular prediction and take appropriate action based on them. A short sketch follows below.
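The `shap` package provides efficient estimators of these values; below is a brief sketch using its `TreeExplainer`, again on the illustrative tree-based housing model from the earlier sketches.

```python
# Per-instance Shapley value estimates with the `shap` package (pip install shap).
# Reuses the illustrative tree-based housing model from the earlier sketches.
import shap

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)          # one row of contributions per instance

# Which features push the first prediction up or down, and by how much:
for name, contribution in zip(X.columns, shap_values[0]):
    print(f"{name}: {contribution:+.2f}")
```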
These are some of the techniques that can be used to understand black-box models. Explainable AI is a rapidly evolving research area where different groups are continuously seeking new methods to make machine learning and deep learning based models explainable. Today, machine learning models are themselves part of our knowledge repository. As we keep improving the methods to understand these models, we will keep adding more value to that knowledge. Businesses will have more insights than they previously had and can take better decisions based on those insights to further improve their products and services, or to drive business growth.
Please share this post so that more and more people are aware that there exist methods that can explain the machine-learning black-boxes and help create value within their organizations.
P.S. If you would like to talk more about these methods or machine learning in general, you can reach out to me. Also, if you have a project that you would like to discuss, feel free to reach out.