Distilling a Neural Network into a soft decision tree

As part of the commitment to continuous (and cutting-edge) research at Razorthink Inc, we are starting a series of review papers that will screen through the best research done in the fields of deep learning, machine learning, data science and artificial intelligence in general, across the globe.

Each week, we will pick one research paper, break it down to make it easier to understand, take you through the entire research approach and major takeaways, and finally discuss its applicability to real use cases.

Our first pick in the series is “Distilling a Neural Network into a soft decision tree” (download link at the bottom) originally written by Nicholas Frosst & Geoffrey Hinton (Google Brain Team).

Introduction:

Deep neural networks have proven to be very effective at tasks that involve classification and prediction, even when the data is complex.

Most importantly, they are highly useful in situations where the input data has a complex relationship with the target variable and the input is very high-dimensional.

However, it is hard to explain how a neural network arrived at its classification of a particular test case.

This is primarily due to the fully connected representation of the network, which tells us where activations have occurred but does not usually help us understand the ‘why’ behind them.

It would be far easier to explain a model’s predictions if the model were represented as a decision tree, which considers a series of variables and provides a decision based on them.

However, decision trees are not a drop-in alternative to neural networks, mainly because of the trade-off between generalization and interpretability: trees that make great decisions are often overfitted to the training data.

Hence, they generalize poorly, and the trees that are easily interpretable are not always accurate.

The main aim of this paper is to take a neural network and attempt to explain the way it works in the form of a decision tree.

The central idea of the paper is to mitigate this trade-off that occurs in a decision tree.

This is achieved by using the neural network to help build the decision tree: the network takes the input dataset used for training and generates additional training data with the same nature and properties as the existing dataset.

Subsequently, this data is fed to the decision tree.

Finally, it is shown that these trees not only perform better than vanilla decision trees, but can also explain their decisions in a much more comprehensible manner than a neural network, where explainability is a huge constraint.

Summary:

The excellent generalization abilities of deep neural nets depend on their use of distributed representations in their hidden layers, but these representations are hard to understand.

For the first hidden layer we can understand what causes a unit to activate, and for the last hidden layer we can understand the effects of activating a unit, but for the remaining hidden layers it is much harder to understand the causes and effects of a feature activation in terms of meaningful variables such as the inputs and outputs.

Also, deep neural nets can make reliable decisions by modeling a very large number of weak statistical regularities in the relationship between the inputs and outputs of the training data, and there is nothing in the network to distinguish the weak regularities that are true properties of the data from the spurious regularities created by the sampling peculiarities of the training set.

In contrast to a hidden-layer representation, it is easy to explain how a decision tree makes any particular classification, because the classification depends on a relatively short sequence of decisions, each of which is based directly on the input data.

Decision trees, however, do not usually generalize as well as deep neural nets.

Unlike the hidden units in a neural net, a typical node at the lower levels of a decision tree is used by only a very small fraction of the training data (a binary tree of depth d has 2^d leaves, so on average each leaf sees only about 1/2^d of the training examples), so the lower parts of the tree tend to overfit unless the training set is exponentially large compared with the depth of the tree.

In this paper, the authors seek to combine the advantages of both the neural net and the decision tree.

This is done by using a neural network to train a soft decision tree.

This decision tree is then used for inference.

Since the representation is hierarchical, the decisions made at every node on the path to the final leaf can be tracked.

Hence it provides interpretability to the model.

Research Approach:

Let’s first describe the representation of the soft decision tree.

We will look into the loss function and the training of this decision tree a little later.

The diagram represents a soft binary decision tree trained with mini-batch gradient descent, where each inner node i has a learned filter w_i and a bias b_i, and each leaf node l has a learned distribution Q^l.

At each inner node, the probability of taking the rightmost branch is:

p_i(x) = σ(x·w_i + b_i)

where x is the input to the model and σ is the sigmoid logistic function.

The model learns a hierarchy of filters that are used to assign each example to a particular bigot (the paper’s term for a leaf node with a static output distribution) with a particular path probability, and each bigot learns a simple, static distribution over the possible output classes k.

Q^l_k = exp(ɸ^l_k) / ∑_k’ exp(ɸ^l_k’)

where Q^l denotes the probability distribution at the l-th leaf, and each ɸ^l_k is a learned parameter at that leaf.
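
To make these two formulas concrete, here is a minimal sketch in PyTorch. It is purely illustrative and not the authors’ code; the names routing_probability, leaf_distribution, w_i, b_i and phi_l are our own.

import torch

def routing_probability(x, w_i, b_i):
    # p_i(x) = sigmoid(x . w_i + b_i): probability of taking the rightmost
    # branch at inner node i, for a batch of inputs x of shape (batch, d)
    return torch.sigmoid(x @ w_i + b_i)

def leaf_distribution(phi_l):
    # Q^l = softmax(phi^l): the static distribution over output classes at
    # leaf l, where phi_l is a learned parameter vector of shape (num_classes,)
    return torch.softmax(phi_l, dim=-1)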

This model gives a predictive distribution over classes by using the distribution from the leaf with the greatest path probability.

It is trained using a loss function that seeks to minimize the cross-entropy between each leaf, weighted by its path probability, and the target distribution.

For a single training case with input vector x and target distribution T, the loss is:

L(x) = −log( ∑_{l ∈ LeafNodes} P^l(x) ∑_k T_k log(Q^l_k) )

where T is the target distribution and P^l(x) is the probability of arriving at leaf node l given the input x.
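
Below is a minimal sketch of this loss, assuming the reading described in the prose: the cross-entropy between each leaf’s static distribution and the target, weighted by the path probability of reaching that leaf, and averaged over a mini-batch (the outer logarithm in the displayed formula is dropped in this sketch for simplicity). All names are hypothetical and this is not the authors’ implementation.

import torch

def soft_tree_loss(path_probs, leaf_dists, targets, eps=1e-8):
    # path_probs : (batch, num_leaves)       P^l(x), probability of reaching leaf l
    # leaf_dists : (num_leaves, num_classes) Q^l, static class distribution at leaf l
    # targets    : (batch, num_classes)      target distribution T (one-hot or soft)
    per_leaf = targets @ torch.log(leaf_dists + eps).t()  # sum_k T_k log(Q^l_k) for every leaf
    return -(path_probs * per_leaf).sum(dim=1).mean()     # weight by P^l(x), sum over leaves, negate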

Unlike most decision trees, soft decision trees use decision boundaries that are not aligned with the axes defined by the components of the input vector.

Also, they are trained by first picking the size of the tree and then using mini-batch gradient descent to update all of their parameters simultaneously, rather than the more standard greedy approach that decides the splits one node at a time.
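
The sketch below puts these pieces together: a soft tree of fixed depth whose inner-node filters, biases and leaf logits all live in one module and are updated together by mini-batch gradient descent, using the soft_tree_loss helper from the previous sketch. Everything here (the class name, depth, optimizer, learning rate and the loader that yields mini-batches) is an assumption for illustration, not the authors’ code.

import torch
import torch.nn as nn

class SoftDecisionTree(nn.Module):
    def __init__(self, in_dim, num_classes, depth=4):
        super().__init__()
        self.depth = depth
        num_inner, num_leaves = 2 ** depth - 1, 2 ** depth
        self.inner = nn.Linear(in_dim, num_inner)  # one filter w_i and bias b_i per inner node
        self.leaf_logits = nn.Parameter(torch.zeros(num_leaves, num_classes))  # one phi^l per leaf

    def forward(self, x):
        batch = x.size(0)
        p_right = torch.sigmoid(self.inner(x))      # routing probability at every inner node
        path_probs = torch.ones(batch, 1, device=x.device)
        start = 0
        for _ in range(self.depth):                 # descend the tree one level at a time
            width = path_probs.size(1)
            p = p_right[:, start:start + width]     # routing probabilities for this level
            path_probs = torch.stack([path_probs * (1 - p), path_probs * p], dim=2).reshape(batch, 2 * width)
            start += width
        return path_probs, torch.softmax(self.leaf_logits, dim=-1)  # P^l(x) and Q^l

    def predict(self, x):
        path_probs, leaf_dists = self.forward(x)
        return leaf_dists[path_probs.argmax(dim=1)]  # distribution of the leaf with the greatest path probability

model = SoftDecisionTree(in_dim=784, num_classes=10, depth=4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for x, targets in loader:  # `loader` is assumed to yield (input, target-distribution) mini-batches
    path_probs, leaf_dists = model(x)
    loss = soft_tree_loss(path_probs, leaf_dists, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()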

The diagram below depicts the soft decision tree created using soft targets of a network trained on the MNIST data.

Results:

The soft decision tree trained in this way achieved a test accuracy of 96.76%, which is about halfway between the neural net and the soft decision tree trained directly on the data.

Possible Areas of Application:

One of the important use cases of this approach is explainability.

As the model is based on a decision tree, it is easier to track down the decisions made to reach a particular output.

Another use case is distilling a heavier model into a lighter one without much of an accuracy trade-off.
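
As an illustration of that second use case, the heavy model’s softened predictions can stand in for the hard labels as the target distribution T when training the soft tree. The helper below is an assumption on our part; the temperature-based softening is borrowed from standard knowledge distillation and is not a detail spelled out in this paper.

import torch

def soften(teacher_logits, temperature=2.0):
    # Turn a trained teacher network's logits into soft targets; these
    # probabilities replace the one-hot labels as T in soft_tree_loss above.
    # The temperature value is an assumed hyperparameter, not from the paper.
    return torch.softmax(teacher_logits / temperature, dim=-1)

For example, soften(teacher(x)), where teacher is the trained neural net, would be passed as the targets when training the tree, so the lighter model learns from the heavier one’s predictions rather than from the raw labels.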

You can download the original paper here (arXiv:1711.09784).

Visit Razorthink & Razorthink Guide for additional resources.
