Interpretable Machine Learning with Decision Trees


In this post I explain how decision-tree based methods are powerful and yet very interpretable, and how they act as a bridge between easily interpretable methods (like linear regression) and black-box methods (like neural networks). So let’s dive in.

Let’s look at a basic decision tree and how it is built. A decision tree partitions the feature space recursively; in a 2D feature space, each partition is a horizontal or vertical line. Within each of the resulting boxes, the output of the model is the mean of the target values in that box (regression) or the dominant class in the box (classification).
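A minimal scikit-learn sketch of this behaviour (the toy 2D data here is made up for illustration): the target jumps at x₀ = 5, and each leaf of the fitted tree predicts the mean target of the training points that fall in its box.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy 2D regression data: the target depends on which "box" a point falls in.
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(200, 2))
y = np.where(X[:, 0] < 5, 1.0, 3.0) + rng.normal(0, 0.1, 200)

tree = DecisionTreeRegressor(max_depth=2).fit(X, y)

# Each leaf predicts the mean target of the training points in its box:
# a point left of x0 = 5 should predict ~1.0, a point right of it ~3.0.
print(tree.predict([[2.0, 4.0], [8.0, 4.0]]))
```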

The way the decision tree finds the optimal “cut” to make comes from a split criterion: the “Gini index” or “entropy” in the case of classification, and simple mean squared error in the case of regression.
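These criteria are simple enough to compute by hand. A sketch of the two classification criteria (Gini index and entropy) over a node’s class labels:

```python
import numpy as np

def gini(labels):
    """Gini index: 1 - sum of squared class proportions. 0 for a pure node."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - float(np.sum(p ** 2))

def entropy(labels):
    """Shannon entropy in bits: -sum(p * log2(p)). 0 for a pure node."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

print(gini([0, 0, 1, 1]))     # maximally impure two-class node -> 0.5
print(gini([0, 0, 0, 0]))     # pure node -> 0.0
print(entropy([0, 0, 1, 1]))  # -> 1.0 bit
```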

This is a pretty straightforward greedy algorithm, i.e. for each feature, scan the candidate thresholds and find the best (minimum total error) split; then recurse into each of the two resulting regions to find the best split there, and so on. The scheme is simple, and it is easy to interpret what the tree is learning.
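The greedy search described above can be sketched as an exhaustive scan over features and thresholds (this helper and the toy data are illustrative, not from the original post):

```python
import numpy as np

def best_split(X, y):
    """Greedy search: for every feature and threshold, pick the split
    that minimizes the total squared error of the two halves."""
    def sse(t):  # sum of squared errors around the mean
        return float(np.sum((t - t.mean()) ** 2)) if len(t) else 0.0

    best = (None, None, np.inf)  # (feature index, threshold, total error)
    for j in range(X.shape[1]):
        for thr in np.unique(X[:, j])[:-1]:  # candidate thresholds
            left, right = y[X[:, j] <= thr], y[X[:, j] > thr]
            err = sse(left) + sse(right)
            if err < best[2]:
                best = (j, thr, err)
    return best

X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([0.0, 0.0, 10.0, 10.0])
print(best_split(X, y))  # best split at x <= 2, with zero total error
```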

If only we could keep this interpretability and yet have state-of-the-art predictive power, that would be a desirable combination. Fortunately, there are Random Forests and Boosting, which are meta-algorithms built on top of decision trees.

I won’t be going in-depth into Random Forests and Boosting, but the general idea is that a Random Forest grows many trees in parallel and takes a majority vote (classification) or an average (regression) of their results, whereas in Boosting the trees are fitted sequentially. In both cases, the underlying decision-tree algorithm is the same.
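Both ensembles are available off the shelf; a quick sketch with scikit-learn on a built-in dataset (the dataset choice here is mine, for illustration):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

# Random Forest: many trees fitted independently, predictions aggregated.
rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_tr, y_tr)

# Boosting: shallow trees fitted sequentially, each correcting the previous ones.
gb = GradientBoostingClassifier(n_estimators=200, random_state=0).fit(X_tr, y_tr)

print(rf.score(X_te, y_te), gb.score(X_te, y_te))
```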

Now coming to interpretability: we have seen that individual tree splits are pretty intuitive, but with the thousands of trees we fit in Random Forests and Boosted Trees it’s hard to infer anything concrete (though in Boosting the first tree has the highest predictive power). The main inference we would like to draw is which features of our input space the algorithm has relied on most, and to assign each feature a value according to its importance in prediction.

Finding relative feature importance

Feature importance is based on the idea that a feature which leads to the maximum decrease in squared error at a split is given a high importance score. Let’s put it more formally.
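One standard way to write this, following Hastie et al., is

$$
\mathcal{I}_k^2(T) \;=\; \sum_{t=1}^{J-1} \hat{i}_t^{\,2}\,\mathbb{1}\!\left(v(t) = k\right)
$$

where the sum runs over the $J-1$ internal nodes of tree $T$, $v(t)$ is the variable chosen to split node $t$, and $\hat{i}_t^{\,2}$ is the improvement in squared error produced by that split.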

For a given tree (T) with J terminal nodes (hence J − 1 internal nodes) and a given variable ‘k’, we sum, over all the internal nodes where variable ‘k’ was used to make the split, the decrease in MSE achieved at that node. This makes intuitive sense: the most importance is given to the variable with the most discriminative power. The same formula can be used in the classification setting by replacing the squared-error gain with the information gain (decrease in entropy) or an equivalent criterion.

A natural extension of this applies to ensemble tree algorithms like Random Forest and Gradient Boosting, where we average the contribution of a variable ‘k’ over all nodes of all trees.
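For an ensemble of $M$ trees $T_1, \ldots, T_M$, this can be written (again following Hastie et al.) as

$$
\mathcal{I}_k^2 \;=\; \frac{1}{M} \sum_{m=1}^{M} \mathcal{I}_k^2(T_m)
$$

i.e. the single-tree importances are simply averaged across the ensemble.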

But there are some drawbacks. The tree is built entirely on the basis of training error, so a variable that shows promising importance on the training set may turn out not to matter much in the real world. There is, however, an alternative approach that uses a validation set to give us more confidence in our feature-importance estimates.

Random permutation-based approach

This method is intuitive and model-agnostic. To compute the importance of a feature, randomly permute that feature’s column while keeping the rest of the columns untouched, and make predictions on, say, a validation set before and after permuting. A small difference in loss before and after permuting clearly says that the feature does not have strong predictive power with respect to the target variable. Note that here we are able to use a validation set, and in the case of bagging (as in Random Forests) we can do it on the out-of-bag (OOB) samples; all of this reduces dependence on the training set.
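scikit-learn ships this method as `permutation_importance`; a sketch on the iris dataset (dataset and hyperparameters chosen here for illustration):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=0)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_tr, y_tr)

# Importance = drop in validation score when a single feature is shuffled,
# averaged over n_repeats independent permutations.
result = permutation_importance(rf, X_val, y_val, n_repeats=10, random_state=0)
for name, imp in zip(load_iris().feature_names, result.importances_mean):
    print(f"{name}: {imp:.3f}")
```

On iris, the petal measurements typically dominate the sepal measurements.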

Visualizing trees for inference

As we know, decision trees are very intuitive data structures. Visualizing a decision tree after fitting (it is hard to pick a good tree from a Random Forest, but in Boosting the first tree is ideal) can give valuable insights into what the model has actually learned and how domain-relevant that learning is. Extracting information from these trees can also help in making features that can be fed into more powerful predictive black-box models. Here is an example.

This is an example on the iris dataset: the top branches of the tree, starting from the root, can be used to generate features that combine, say, petal_length and petal_width — interactions that other classes of models will not capture on their own.
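A quick way to inspect those top branches is scikit-learn’s `export_text`, which prints the fitted tree as nested if/else rules:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# Print the tree as human-readable split rules; on iris, the top splits
# are on the petal measurements.
print(export_text(tree, feature_names=list(iris.feature_names)))
```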

So, in conclusion, the inference power of single trees combined with the predictive power of tree ensembles makes them a good choice for many machine learning tasks — especially for understanding the task and gaining valuable insights before moving on to more black-box methods.

References

Trevor Hastie, Robert Tibshirani, and Jerome Friedman. The Elements of Statistical Learning. p. 593.