Deep Neural Decision Forest in Keras

Kushal Mukherjee
3 min read · Oct 26, 2021


An integration of decision forests and neural networks to get the best of both worlds

Decision forest and neural network integration
Painting by Kushal Mukherjee

This article explains the paper Deep Neural Decision Forests by Kontschieder et al. We will first explain the concept for a single decision tree; a forest model is then simply an aggregation of multiple such trees. We also introduce sample code in Keras (inspired by the code here). Please note that although we demonstrate the method on tabular data, it is more relevant for tasks like image classification; for most tabular data, good old random forest with hyperparameter tuning works better.

The method differs from a random forest in that it uses a principled, joint, and global optimization of the split and leaf node parameters, and from conventional deep networks in that a decision forest provides the final predictions.

Now, let’s see how we can build and train a single decision tree in this framework. The decision tree has to be stochastic in order to be differentiable, so that back-propagation can work. In a conventional decision tree, the decision functions that route an observation from the root to the leaves are deterministic. That, however, will not work here. Instead, we will have two sets of probabilities that decide the final output:

  1. The probability of an observation reaching each leaf. These probabilities are associated with the decision (split) nodes, which decide whether an observation goes left or right.
  2. Once an observation reaches a leaf node, the probability that it takes a specific class.

The final probability of an observation belonging to a class is the aggregate, over all leaf nodes, of the probability of that observation taking that class at each leaf (the probabilities from point 2 above). The aggregation is a weighted sum, where the probability of the observation reaching the corresponding leaf (from point 1 above) is taken as the weight. From the paper, the actual formula is as below:

Fig 1: Probability of an observation x belonging to class y. Source: main paper
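
In the paper’s notation, L is the set of leaf nodes, π_ℓy is the probability of class y stored at leaf ℓ, and μ_ℓ(x|Θ) is the probability that observation x is routed to leaf ℓ given the decision-node parameters Θ. Written out, the formula in Fig 1 is:

```latex
\mathbb{P}_T[y \mid x, \Theta, \pi] = \sum_{\ell \in \mathcal{L}} \pi_{\ell y} \, \mu_{\ell}(x \mid \Theta)
```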

For a single decision tree, the model is structured by connecting all the features to a dense layer and then connecting each unit of the dense layer to an individual decision node of the tree. For a forest, instead of all the features, a predefined ratio of the features is selected randomly for each individual tree. For other tasks, e.g. image classification, just replace the raw features with the output neurons of the convolutional layers. Refer to the picture below from the main paper for a better understanding.

Fig 2: Model structure for tabular data. Source: main paper

The model is trained in two alternating stages: starting from a randomly initialized class distribution at each leaf (point 2 above), we iteratively update the leaf distributions π and the decision-node parameters (which determine the routing probabilities μ) for a pre-defined number of epochs.
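
For reference, the paper’s update rule for the leaf distributions is reproduced below, where T is the training set, t indexes the iteration, and Z is a normalizing factor that keeps each leaf’s class distribution summing to one. (The Keras sketch further down takes a simpler route and learns π directly by gradient descent.)

```latex
\pi^{(t+1)}_{\ell y} = \frac{1}{Z^{(t)}_{\ell}}
\sum_{(x, y') \in \mathcal{T} :\, y' = y}
\frac{\pi^{(t)}_{\ell y} \, \mu_{\ell}(x \mid \Theta)}
     {\mathbb{P}_T[y \mid x, \Theta, \pi^{(t)}]}
```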

Next, let’s jump into the code. As mentioned before, we will demonstrate the Keras code on tabular data to keep the explanation simple.

First, let’s import all the necessary packages and define the required variables.
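
Here is a minimal sketch of this step. The hyperparameter values below (tree depth, number of trees, feature-sampling ratio, and so on) are illustrative defaults, not tuned choices:

```python
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

depth = 5                 # depth of each tree; a tree has 2**depth leaves
num_trees = 10            # number of trees in the forest
used_features_rate = 0.5  # ratio of features randomly assigned to each tree
num_classes = 2           # number of target classes
learning_rate = 0.01
batch_size = 256
num_epochs = 10
```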

Next, let’s import the data, encode the categorical features (if required), and reshape the feature inputs as needed.
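
The sketch below assumes a hypothetical CSV file data.csv with a label column named target; swap in your own data source and column names. Since the modules below consume a flat [batch_size, num_features] matrix, no further dimension expansion is needed in this version:

```python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

data = pd.read_csv("data.csv")  # hypothetical file name
target_col = "target"           # hypothetical label column name

# Integer-encode categorical (object-typed) columns, including the target.
for col in data.select_dtypes(include="object").columns:
    data[col] = LabelEncoder().fit_transform(data[col])

x = data.drop(columns=[target_col]).to_numpy(dtype="float32")
y = data[target_col].to_numpy()

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42
)
num_features = x_train.shape[1]
```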

Then, create the structure of the neural decision tree module.
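
A sketch of the tree module follows, closely modeled on the Keras example this article draws on. The decision nodes are a single sigmoid dense layer (point 1 above), the leaf class distributions π are a trainable variable (point 2), and call computes μ by walking the tree level by level before taking the weighted sum from Fig 1:

```python
class NeuralDecisionTree(keras.Model):
    def __init__(self, depth, num_features, used_features_rate, num_classes):
        super().__init__()
        self.depth = depth
        self.num_leaves = 2 ** depth
        self.num_classes = num_classes

        # Randomly select a subset of the features for this tree.
        num_used_features = int(num_features * used_features_rate)
        sampled = np.random.choice(
            np.arange(num_features), num_used_features, replace=False
        )
        self.used_features_mask = tf.constant(
            np.eye(num_features)[sampled], dtype="float32"
        )

        # pi: learnable class distribution at each leaf (point 2 above).
        self.pi = tf.Variable(
            initial_value=tf.random_normal_initializer()(
                shape=[self.num_leaves, self.num_classes]
            ),
            dtype="float32",
            trainable=True,
        )

        # Sigmoid decision function for the inner nodes (point 1 above).
        self.decision_fn = layers.Dense(units=self.num_leaves, activation="sigmoid")

    def call(self, features):
        batch_size = tf.shape(features)[0]

        # Keep only this tree's features, then compute routing probabilities.
        features = tf.matmul(features, self.used_features_mask, transpose_b=True)
        decisions = tf.expand_dims(self.decision_fn(features), axis=2)
        # Probability of going left vs. right at each node.
        decisions = layers.concatenate([decisions, 1 - decisions], axis=2)

        # Walk the tree level by level, multiplying routing probabilities
        # to obtain mu, the probability of reaching each leaf.
        mu = tf.ones([batch_size, 1, 1])
        begin_idx, end_idx = 1, 2
        for level in range(self.depth):
            mu = tf.reshape(mu, [batch_size, -1, 1])
            mu = tf.tile(mu, (1, 1, 2))
            mu = mu * decisions[:, begin_idx:end_idx, :]
            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (level + 1)

        mu = tf.reshape(mu, [batch_size, self.num_leaves])
        leaf_probabilities = keras.activations.softmax(self.pi)
        # Weighted sum over the leaves: the formula from Fig 1.
        return tf.matmul(mu, leaf_probabilities)
```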

Next, building on the single neural decision tree module, create the structure of the neural forest module.
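
The forest module simply holds an ensemble of such trees, each drawing its own random feature subset, and averages their class probabilities:

```python
class NeuralDecisionForest(keras.Model):
    def __init__(self, num_trees, depth, num_features, used_features_rate, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.ensemble = [
            NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)
            for _ in range(num_trees)
        ]

    def call(self, inputs):
        # Average the per-tree class probabilities to get the forest output.
        batch_size = tf.shape(inputs)[0]
        outputs = tf.zeros([batch_size, self.num_classes])
        for tree in self.ensemble:
            outputs += tree(inputs)
        return outputs / len(self.ensemble)
```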

Finally, let’s create the model, compile it, train it, and evaluate it on the test sample.
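
A sketch of the final step, under the same assumptions as above. Because the forest already outputs class probabilities (not logits), the loss is used with its default from_logits=False:

```python
# Assemble the full model: raw features -> batch normalization -> forest.
inputs = keras.Input(shape=(num_features,))
features = layers.BatchNormalization()(inputs)
outputs = NeuralDecisionForest(
    num_trees, depth, num_features, used_features_rate, num_classes
)(features)
model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(x_train, y_train, batch_size=batch_size, epochs=num_epochs)
loss, accuracy = model.evaluate(x_test, y_test, batch_size=batch_size)
print(f"Test accuracy: {accuracy:.2%}")
```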

The advantage of this model is that it integrates the generalization power of a random forest into a neural network.

Conclusion:

In this article, we introduced how to integrate the power of a random forest with a neural network. We also demonstrated the code in Keras for tabular data, which can easily be extended to any other classification task. The model is also straightforward to train, with its capacity controlled by the pre-defined parameters (tree depth, number of trees, and feature-sampling ratio).



Written by Kushal Mukherjee

Kushal writes about Machine Learning & Data Science. He has 12 years of domain experience. Connect on LinkedIn: www.linkedin.com/in/kushal-mukherjee-a2583b91
