Decision Trees. Implemented

Simple, module-wise implementation of Decision Trees using the Gini index

Hussain Safwan
3 min read · Jun 12, 2021

I’ve talked quite a bit about decision trees. This is the article where we actually implement one. As for the very basics, decision trees are a class of supervised learning algorithms in ML that rely on repeatedly splitting the dataset with respect to the “most plausible feature” to construct a tree structure that holds the decisions at its leaves. Since this is an implementation-focused piece, there won’t be much theoretical discussion.

Inner mechanisms

For this article, I’ll refer to the target column of the dataset as Y and the rest as X. At the very core of the system, we’ve got a three-level nested loop that performs the following on each iteration:

  1. The innermost loop goes through each pair of data points within a column of X to determine the best split point for that column.
  2. The middle loop goes through each column of X to find the one with the best split point, dubbed the best attribute. At this level, we’re gonna make the split. With the split made, we’ve got two new datasets, shrunk in size and a little more homogeneous. We apply the same rules to the new datasets until either there are no more features to split on or the dataset is absolutely pure, that is, every label of Y is the same.
  3. This brings us to the outermost loop, which iterates through each of the child datasets to execute the inner loops.
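The “best” split in the steps above is judged by Gini impurity, which for a set of labels is 1 minus the sum of the squared class proportions. Here is a minimal sketch of how a candidate split could be scored. The function names `gini` and `split_gini` are my own for illustration and use plain Python lists, so they are not necessarily what the full code base uses.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_k ** 2)."""
    n = len(labels)
    if n == 0:
        return 0.0  # an empty side contributes no impurity
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def split_gini(left_labels, right_labels):
    """Size-weighted average impurity of the two sides of a split."""
    n = len(left_labels) + len(right_labels)
    return (len(left_labels) * gini(left_labels)
            + len(right_labels) * gini(right_labels)) / n


# A 50/50 mix is as impure as two classes can get; a clean split scores 0.
print(gini(["a", "a", "b", "b"]))          # 0.5
print(split_gini(["a", "a"], ["b", "b"]))  # 0.0
```

The lower the weighted impurity, the more homogeneous the two child datasets, which is exactly what the inner loops are hunting for.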

Codification

Let's kick off with the imports,

Next, we’re gonna create a Node class. You start with the root node, and each split gives you two more nodes. Every internal node stores information about its split, such as the best split point, the attribute, the Gini value and so on, until you reach the terminal nodes (also called the leaves), which store the decision.
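A node along these lines might look like the sketch below. The field names are assumptions chosen to match the description, not necessarily the ones in the full code base.

```python
class Node:
    """One node of the tree; field names are illustrative assumptions."""

    def __init__(self, attribute=None, split_point=None, gini=None,
                 left=None, right=None, decision=None):
        self.attribute = attribute      # column the split is made on
        self.split_point = split_point  # threshold within that column
        self.gini = gini                # impurity recorded for the split
        self.left = left                # child for rows at or below the threshold
        self.right = right              # child for rows above the threshold
        self.decision = decision        # set only on terminal nodes (leaves)

    def is_leaf(self):
        # Compare against None explicitly: a label-encoded decision can be 0.
        return self.decision is not None
```

An internal node would be created with `attribute` and `split_point` filled in, while a leaf only carries a `decision`.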

Let’s move on to the DecisionTree class. The constructor takes in a dataset (X and Y fused) and an optional parameter rows to specify what portion of the dataset is intended for the training phase. The class also features built-in label encoding.
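To make that concrete, here is a rough sketch of such a constructor. It assumes the dataset is a list of rows with Y in the last column, that `rows` is the number of leading rows kept for training, and that only non-numeric columns get encoded. The attribute names (`encoders`, `train`, `test`, `root`) and the helper `_encode` are hypothetical.

```python
class DecisionTree:
    def __init__(self, dataset, rows=None):
        # Assumption: dataset is a list of rows, with Y as the last column.
        self.encoders = {}  # column index -> {raw value: integer code}
        encoded = self._encode(dataset)
        # Assumption: `rows` counts the leading rows used for training.
        cut = len(encoded) if rows is None else rows
        self.train = encoded[:cut]
        self.test = encoded[cut:]
        self.root = None  # filled in later by fit()

    def _encode(self, dataset):
        """Replace each non-numeric column's values with integer codes."""
        encoded = [list(row) for row in dataset]
        for col in range(len(dataset[0])):
            column = [row[col] for row in dataset]
            if all(isinstance(value, (int, float)) for value in column):
                continue  # numeric columns are left untouched
            mapping = {value: code for code, value
                       in enumerate(sorted(set(column), key=str))}
            self.encoders[col] = mapping
            for row in encoded:
                row[col] = mapping[row[col]]
        return encoded
```

Keeping the mappings around in `encoders` matters later on, because a row handed over for prediction has to be encoded the same way as the training data.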

The model calls the fit function on the encoded dataset. The fit function in turn calls a recursive function, recur_fit(). This is where the splits actually occur. Every time it’s invoked, it creates a Node, gets the best split attribute, splits the dataset into two, and stores the two parts as left and right, respectively, of the fresh node. Note that the snippets from here on all live within the scope of the DecisionTree class and are shown separately only for clarity.
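The recursion can be sketched roughly as follows. To keep the example runnable on its own, it is written as standalone functions rather than class methods, works on plain lists with Y in the last column, and bundles a condensed stand-in for find_best_attribute(). All of that is an assumption for illustration.

```python
from collections import Counter


class Node:
    """Bare-bones node; field names are illustrative assumptions."""
    def __init__(self):
        self.attribute = None    # column index used for the split
        self.split_point = None  # threshold within that column
        self.left = None         # subtree for values <= split_point
        self.right = None        # subtree for values > split_point
        self.decision = None     # set only on leaves


def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def find_best_attribute(table):
    """Condensed stand-in: (column, split_point) with the lowest weighted Gini."""
    best_col, best_point, best_score = None, None, float("inf")
    for col in range(len(table[0]) - 1):
        distinct = sorted({row[col] for row in table})
        for lo, hi in zip(distinct, distinct[1:]):
            point = (lo + hi) / 2  # midpoint, so neither side is ever empty
            left = [row[-1] for row in table if row[col] <= point]
            right = [row[-1] for row in table if row[col] > point]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(table)
            if score < best_score:
                best_col, best_point, best_score = col, point, score
    return best_col, best_point


def recur_fit(table):
    node = Node()
    labels = [row[-1] for row in table]
    pure = len(set(labels)) == 1
    attribute, point = (None, None) if pure else find_best_attribute(table)
    if attribute is None:
        # Pure dataset or nothing left to split on: make a leaf (majority label).
        node.decision = Counter(labels).most_common(1)[0][0]
        return node
    node.attribute, node.split_point = attribute, point
    node.left = recur_fit([row for row in table if row[attribute] <= point])
    node.right = recur_fit([row for row in table if row[attribute] > point])
    return node
```

Each call hands strictly smaller tables to its children, so the recursion always bottoms out in leaves.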

find_best_attribute() is a function that takes in a table (a subset of the entire dataset) and spits out the attribute on which the split is to be made, the point within that column where the split is to be made, and the gain should the split be made. The find_best_split() function is employed to find the best split point within a column.
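Below is a hedged sketch of how these two functions could fit together. It assumes numeric (already label-encoded) columns, Y in the last column, candidate thresholds at the midpoints between neighbouring distinct values, and gain measured as the drop in Gini impurity from parent to children. The signatures are my own guess at a reasonable shape, not a copy of the full code base.

```python
from collections import Counter


def gini(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def find_best_split(table, col):
    """Best threshold within one column: returns (split_point, weighted_gini)."""
    distinct = sorted({row[col] for row in table})
    best_point, best_score = None, float("inf")
    for lo, hi in zip(distinct, distinct[1:]):
        point = (lo + hi) / 2
        left = [row[-1] for row in table if row[col] <= point]
        right = [row[-1] for row in table if row[col] > point]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(table)
        if score < best_score:
            best_point, best_score = point, score
    return best_point, best_score


def find_best_attribute(table):
    """Returns (attribute, split_point, gain); attribute is None if no split exists."""
    parent = gini([row[-1] for row in table])
    best_attr, best_point, best_gain = None, None, float("-inf")
    for col in range(len(table[0]) - 1):  # every column except Y
        point, score = find_best_split(table, col)
        if point is None:
            continue  # column holds a single value, nothing to split
        gain = parent - score
        if gain > best_gain:
            best_attr, best_point, best_gain = col, point, gain
    if best_attr is None:
        return None, None, 0.0
    return best_attr, best_point, best_gain
```

For the table `[[1.0, 10.0, 0], [2.0, 10.0, 0], [3.0, 20.0, 1], [4.0, 10.0, 1]]`, the first column separates the labels perfectly at 2.5, so it wins over the second column.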

Here are some auxiliary methods needed for the training phase,

Next, the prediction module: the system takes in a row, label-encodes it, passes it down the tree until it reaches the corresponding leaf, and returns the associated decision.
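The traversal could look something like this sketch. It is written as a standalone function over a hand-built two-leaf tree so it runs on its own. The `encoders` argument (a mapping from column index to value codes) and the node fields are assumptions for illustration.

```python
class Node:
    """Bare-bones node; field names are illustrative assumptions."""
    def __init__(self, attribute=None, split_point=None,
                 left=None, right=None, decision=None):
        self.attribute = attribute
        self.split_point = split_point
        self.left = left
        self.right = right
        self.decision = decision


def predict(root, row, encoders=None):
    """Encode one row of X, then walk from the root down to a leaf."""
    encoders = encoders or {}
    # Apply the training-time encoding; a value never seen in training
    # would raise KeyError here.
    encoded = [encoders[i][value] if i in encoders else value
               for i, value in enumerate(row)]
    node = root
    while node.decision is None:
        if encoded[node.attribute] <= node.split_point:
            node = node.left
        else:
            node = node.right
    return node.decision


# Hand-built tree: split on column 1 (a colour encoded as blue=0, red=1).
root = Node(attribute=1, split_point=0.5,
            left=Node(decision="no"), right=Node(decision="yes"))
colour_codes = {1: {"blue": 0, "red": 1}}
print(predict(root, [5.0, "red"], colour_codes))   # yes
print(predict(root, [5.0, "blue"], colour_codes))  # no
```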

For testing, with say the iris.csv dataset, run the following snippet,

Visit this link for the entire code base on my Colab.
