Ekta Aggarwal

Decision Trees in Python

In this tutorial we will learn how to implement the Decision Tree algorithm in Python.


If you wish to understand the theory behind Decision Trees then you can refer to this tutorial: Working of Decision Trees


To run Decision Trees in Python we will use the iris dataset, which is built into the sklearn library. The iris dataset comprises data for 150 flowers belonging to 3 different species: Setosa, Versicolor and Virginica. For each of these 150 flowers, the Sepal Length, Sepal Width, Petal Length and Petal Width are available.


Let us first load the pandas library:

import pandas as pd

Now we will load the iris dataset from the sklearn library:

from sklearn.datasets import load_iris
iris = load_iris()

Following are the variable names in the iris dataset:

iris.feature_names

Output: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
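
Similarly, the species names corresponding to the integer labels 0, 1 and 2 are stored in iris.target_names:

iris.target_names

Output: array(['setosa', 'versicolor', 'virginica'], dtype='<U10')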

Now we store the independent variables from the iris dataset in X and the dependent variable in y:

X = iris.data
y = iris.target
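
Optionally, we can eyeball the raw data by wrapping it in a pandas DataFrame (a quick sketch; pandas and iris have already been loaded above):

pd.DataFrame(X, columns=iris.feature_names).head()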

We can see from the shape that X has 150 rows and 4 columns

X.shape

Output:

(150, 4)

We can see the number of occurrences of each species:

pd.Series(y).value_counts()

Output:

2    50
1    50
0    50
dtype: int64


Now we split the data into a training set and a test set. Note that we will build our model using the training set and use the test set to check the performance of the algorithm. We are splitting our data into an 80% training set and a 20% test set. We can see that the training set has 120 rows and the test set has 30 rows.

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

Output:

(120, 4)
(30, 4)
(120,)
(30,)
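
Note: the three species are equally represented here, so a plain random split works well. For imbalanced data you may want each class's proportions preserved in both sets; train_test_split supports this through its stratify argument (a minimal sketch):

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)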


Let us build our Decision Tree model with default parameters. For this we are loading our libraries:



from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

Now we are building a model for decision tree classifier with default parameters:

dt = DecisionTreeClassifier()

We are fitting our decision tree on the training set:

dt.fit(X_train,y_train)

Making predictions on the training and test sets:

pred_train = dt.predict(X_train)
pred_test = dt.predict(X_test)

Using the plot_tree function we can visualise what our decision tree looks like:

import matplotlib.pyplot as plt
plt.subplots(nrows=1, ncols=1, figsize=(14, 7))
tree.plot_tree(dt);
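
By default the plot shows only split conditions and node statistics. plot_tree also accepts feature_names and class_names arguments, which make the chart much easier to read (a sketch reusing the iris object loaded earlier):

plt.subplots(nrows=1, ncols=1, figsize=(14, 7))
tree.plot_tree(dt, feature_names=iris.feature_names, class_names=list(iris.target_names), filled=True);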


We can also see how our decision tree is created, in text form, using the export_text function:

from sklearn.tree import export_text
r = export_text(dt, feature_names=iris['feature_names'])
print(r)

Output:

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal length (cm) <= 4.75
|   |   |--- petal width (cm) <= 1.65
|   |   |   |--- class: 1
|   |   |--- petal width (cm) >  1.65
|   |   |   |--- class: 2
|   |--- petal length (cm) >  4.75
|   |   |--- petal width (cm) <= 1.75
|   |   |   |--- petal length (cm) <= 4.95
|   |   |   |   |--- class: 1
|   |   |   |--- petal length (cm) >  4.95
|   |   |   |   |--- petal width (cm) <= 1.55
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- petal width (cm) >  1.55
|   |   |   |   |   |--- sepal length (cm) <= 6.95
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- sepal length (cm) >  6.95
|   |   |   |   |   |   |--- class: 2
|   |   |--- petal width (cm) >  1.75
|   |   |   |--- petal length (cm) <= 4.85
|   |   |   |   |--- sepal length (cm) <= 5.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- sepal length (cm) >  5.95
|   |   |   |   |   |--- class: 2
|   |   |   |--- petal length (cm) >  4.85
|   |   |   |   |--- class: 2


Now we can calculate our accuracy using the accuracy_score function:

from sklearn.metrics import accuracy_score

Getting the accuracy for the training set:

accuracy_score(y_train,pred_train)

Output: 1.0


Getting the accuracy for the test set:

accuracy_score(y_test,pred_test)

Output: 1.0
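
A training accuracy of 1.0 is expected here: with no depth limit a decision tree can keep splitting until every training sample is classified correctly, which in general is a sign of overfitting. To see exactly where any misclassifications fall on the test set, one option is sklearn's confusion_matrix (a minimal sketch):

from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, pred_test))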





Decision Trees Using Grid Search


Earlier we had built our decision tree using default parameters, but there are 3 main parameters in a decision tree to be tuned (each can be passed directly to DecisionTreeClassifier, as sketched after this list):

  • max_depth: The maximum depth of the tree.

  • min_samples_split: The minimum number of samples required to split a node.

  • min_samples_leaf: The minimum number of samples required to be at a leaf node. Say min_samples_leaf = 5; then if a split would leave fewer than 5 observations in a child node, the parent node will not be split.
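
For instance, these parameters can be set manually when constructing the classifier (the values below are purely illustrative):

dt_manual = DecisionTreeClassifier(max_depth=3, min_samples_split=5, min_samples_leaf=5)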

A systematic way to find the best parameters for a Decision Tree is grid search, available in sklearn as GridSearchCV:

from sklearn import metrics
from sklearn.model_selection import GridSearchCV

We are creating a parameter grid, i.e., mapping the parameter names to the values that should be searched. In this grid we are only tuning max_depth and min_samples_split. We are trying the values 3, 4, and 5 for max_depth and 5 and 6 for min_samples_split.

param_grid = {'max_depth': [3, 4, 5],
              'min_samples_split': [5, 6]}
print(param_grid)

Output: {'max_depth': [3, 4, 5], 'min_samples_split': [5, 6]}
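
If we also wanted to tune min_samples_leaf, it could be added to the grid in exactly the same way. A sketch (param_grid_full is a hypothetical name and its values are illustrative; we continue with the smaller grid below):

param_grid_full = {'max_depth': [3, 4, 5],
                   'min_samples_split': [5, 6],
                   'min_samples_leaf': [3, 5]}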


We are now defining the grid search for our model, specifying our parameter grid. cv=5 implies 5-fold cross validation, with 'accuracy' as the scoring mechanism.

grid = GridSearchCV(dt, param_grid, cv=5, scoring='accuracy')

We are now fitting the model to our training data to obtain the best parameters

grid.fit(X_train, y_train)

The best estimator of our model is given as:

grid.best_estimator_

Output: DecisionTreeClassifier(max_depth=4, min_samples_split=6)


The best 5-fold cross validation accuracy for our model is:

grid.best_score_

Output: 0.9416666666666668


The best parameters of our model are given as:

grid.best_params_

Output: {'max_depth': 4, 'min_samples_split': 6}
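
GridSearchCV also records the cross validation score of every parameter combination in its cv_results_ attribute. One way to inspect them (a sketch, reusing the pandas import from earlier):

pd.DataFrame(grid.cv_results_)[['params', 'mean_test_score']]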



Now we fit our final Decision Tree model using our best estimator, i.e.,

dt = grid.best_estimator_
dt.fit(X_train,y_train)

We now make predictions on the training and test sets:


pred_train = dt.predict(X_train)
pred_test = dt.predict(X_test)

Accuracy for our test set is given by:

accuracy_score(y_test,pred_test)

Output: 1.0
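
For a per-class breakdown rather than a single accuracy number, sklearn's classification_report can be used (a sketch; the target_names argument is optional):

from sklearn.metrics import classification_report
print(classification_report(y_test, pred_test, target_names=iris.target_names))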


We are now plotting our decision tree:

import matplotlib.pyplot as plt
plt.subplots(nrows=1, ncols=1, figsize=(14, 7))
tree.plot_tree(dt);
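
Finally, the fitted tree records how much each feature contributed to its splits in the feature_importances_ attribute; a quick sketch to print them:

for name, importance in zip(iris.feature_names, dt.feature_importances_):
    print(name, round(importance, 3))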
