A gentle guide into Decision Trees with Python

This is based on his/her history and other measures.

Which drug is best for a particular patient?

Is a cancerous cell malignant or benign? Is an email message spam or not? And many more scenarios in real life.

Understanding the decision tree algorithm.

Decision trees are built using recursive partitioning to classify the data into two or more groups.

Real life example.

Let's say we have data on patients who have gone through cancer screening over time.

Based on the screening tests, the screened cells are classified as benign or malignant.

Based on this data, a decision tree model can be built to predict these cases for future patients as accurately as possible, potentially even better than the doctors.

A decision tree splits the data variable by variable, starting with the variable that has the highest predictive power, that is, the one whose split leaves the least impurity (the lowest entropy).

The main aim of this method is to minimize the impurity at each node.

The impurity of a node is measured here by the entropy of the data in that node.

Entropy

Entropy is the amount of information disorder or, simply put, the amount of randomness or uncertainty in the data.

The entropy of a node depends on how much randomness (class mixing) it contains.

It should be noted that the lower the entropy, the less uniform the distribution and the purer the node.

If a sample is completely homogeneous, its entropy is zero; if a sample is equally divided between two classes, it has an entropy of 1.

Using the screening data above, suppose one node has 7 malignant and 1 benign cell while another has 3 malignant and 5 benign cells: the former has a lower entropy than the latter.
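This comparison can be checked numerically with the usual entropy formula, $E = -\sum_i p_i \log_2 p_i$, where $p_i$ is the share of class $i$ in the node. A minimal sketch in plain Python (the helper name `entropy` is just for illustration):

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 (Shannon) entropy of a list of class labels."""
    total = len(labels)
    # Sum -p * log2(p) over the classes present in the node
    return sum(-(n / total) * math.log2(n / total) for n in Counter(labels).values())


node_a = ["malignant"] * 7 + ["benign"] * 1
node_b = ["malignant"] * 3 + ["benign"] * 5

print(round(entropy(node_a), 3))  # 0.544
print(round(entropy(node_b), 3))  # 0.954
print(entropy(["benign"] * 4))    # 0.0 (homogeneous node)
print(entropy(["malignant"] * 4 + ["benign"] * 4))  # 1.0 (equal split)
```

The 7/1 node comes out at about 0.544 and the 3/5 node at about 0.954, confirming that the first node is the purer of the two.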

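This is how entropy is calculated mathematically:

$$E(S) = -\sum_{i=1}^{c} p_i \log_2 p_i$$

where $p_i$ is the proportion of samples in node $S$ that belong to class $i$, and $c$ is the number of classes.

At each node, the tree is grown with the split that gives the highest Information Gain.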

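Information Gain

Information gain is the increase in certainty, i.e. the reduction in entropy, achieved by a split. It is calculated as follows:

$$IG(S, A) = E(S) - \sum_{v \in \mathrm{values}(A)} \frac{|S_v|}{|S|}\, E(S_v)$$

where $E(\cdot)$ is the entropy, $A$ is the variable used for the split, and $S_v$ is the subset of $S$ for which $A$ takes the value $v$.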
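As a worked example, assume a hypothetical parent node of 16 cells (10 malignant, 6 benign) that a split divides into the two nodes from the entropy example (7/1 and 3/5). The sketch below, with illustrative helper names, computes the information gain of that split:

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 entropy of a list of class labels."""
    total = len(labels)
    return sum(-(n / total) * math.log2(n / total) for n in Counter(labels).values())


def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of its children."""
    total = len(parent)
    weighted = sum(len(child) / total * entropy(child) for child in children)
    return entropy(parent) - weighted


left = ["malignant"] * 7 + ["benign"] * 1
right = ["malignant"] * 3 + ["benign"] * 5
parent = left + right  # hypothetical parent: 10 malignant, 6 benign

print(round(information_gain(parent, [left, right]), 3))  # 0.205
```

The parent's entropy is about 0.954 and the size-weighted entropy of the children about 0.749, so this split gains about 0.205 bits. A split that separated the classes perfectly would gain the full parent entropy.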

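This process is repeated recursively on every child node, until the nodes are pure or a stopping condition (such as a maximum depth) is met, to build a basic decision tree.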
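To make the recursion concrete, here is a toy ID3-style builder in plain Python. It is an illustrative sketch only, not how scikit-learn works internally (scikit-learn builds binary trees with threshold splits). The features `size` and `shape` and the seven rows are hypothetical:

```python
import math
from collections import Counter


def entropy(labels):
    """Base-2 entropy of a list of class labels."""
    total = len(labels)
    return sum(-(n / total) * math.log2(n / total) for n in Counter(labels).values())


def split(rows, labels, feature):
    """Partition rows and labels by each row's value for `feature`."""
    groups = {}
    for row, label in zip(rows, labels):
        sub_rows, sub_labels = groups.setdefault(row[feature], ([], []))
        sub_rows.append(row)
        sub_labels.append(label)
    return groups


def build_tree(rows, labels, features, depth=0, max_depth=3):
    """Recursively partition the data; leaves are class labels, inner nodes are dicts."""
    majority = Counter(labels).most_common(1)[0][0]
    # Stop on a pure node, when no features are left, or at the depth limit
    if len(set(labels)) == 1 or not features or depth == max_depth:
        return majority

    def gain(feature):
        groups = split(rows, labels, feature).values()
        weighted = sum(len(ls) / len(labels) * entropy(ls) for _, ls in groups)
        return entropy(labels) - weighted

    best = max(features, key=gain)  # variable with the highest information gain
    if gain(best) <= 0:
        return majority  # no split reduces entropy
    remaining = [f for f in features if f != best]
    return {best: {value: build_tree(r, l, remaining, depth + 1, max_depth)
                   for value, (r, l) in split(rows, labels, best).items()}}


def predict(tree, row):
    """Follow the branches matching `row` until a leaf label is reached."""
    while isinstance(tree, dict):
        feature, branches = next(iter(tree.items()))
        tree = branches[row[feature]]
    return tree


# Hypothetical toy data: two categorical features per screened cell
rows = [
    {"size": "large", "shape": "irregular"},
    {"size": "large", "shape": "regular"},
    {"size": "large", "shape": "irregular"},
    {"size": "small", "shape": "irregular"},
    {"size": "small", "shape": "regular"},
    {"size": "small", "shape": "regular"},
    {"size": "small", "shape": "irregular"},
]
labels = ["malignant", "malignant", "malignant", "malignant", "benign", "benign", "malignant"]

tree = build_tree(rows, labels, ["size", "shape"])
print(tree)
# {'shape': {'irregular': 'malignant', 'regular': {'size': {'large': 'malignant', 'small': 'benign'}}}}
print(predict(tree, {"size": "small", "shape": "regular"}))  # benign
```

The root splits on `shape` because it has the higher information gain; the `regular` branch is still mixed, so it is split again on `size`. A feature value never seen during training would raise a `KeyError` in `predict`.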

Below is a step-by-step process in Python.

Implementation with Python.

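Importing Libraries

```python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
```

Data Ingestion

```python
loan_data = pd.read_csv("…")  # path to the loan data CSV file
loan_data.sample(5)           # show five random rows
```

Exploratory Data Analysis.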

This is a brief EDA, because the aim of the project is to illustrate decision trees.





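```python
# Overlaid FICO histograms for the two credit.policy groups
loan_data[loan_data['credit.policy'] == 1]['fico'].hist(alpha=0.5, color='blue', bins=30, label='Credit.Policy=1')
loan_data[loan_data['credit.policy'] == 0]['fico'].hist(alpha=0.5, color='red', bins=30, label='Credit.Policy=0')
plt.legend()
plt.xlabel('FICO')
```

Setting Up the Data.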

My data has a categorical variable, `purpose`, which I have to encode first, because scikit-learn's decision tree cannot work with text categories directly.
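The effect of `pd.get_dummies` with `drop_first=True` is easiest to see on a tiny frame. The rows below are made up for illustration and are not taken from the loan data:

```python
import pandas as pd

# Hypothetical toy frame with one categorical and one numeric column
toy = pd.DataFrame({
    "purpose": ["credit_card", "debt_consolidation", "educational", "credit_card"],
    "fico": [737, 707, 682, 712],
})

# One-hot encode 'purpose'; drop_first=True drops the first category's column
encoded = pd.get_dummies(toy, columns=["purpose"], drop_first=True)
print(encoded.columns.tolist())
# ['fico', 'purpose_debt_consolidation', 'purpose_educational']
```

The first category in sorted order (`credit_card`) gets no column of its own: a row with zeros in every `purpose_` column belongs to it, which avoids a redundant, perfectly collinear column.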

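```python
categorical_var = ['purpose']
# One-hot encode 'purpose'; drop_first=True drops one redundant dummy column
loan_data2 = pd.get_dummies(data=loan_data, columns=categorical_var, drop_first=True)
loan_data2.columns
```

Train Test Split

```python
from sklearn.model_selection import train_test_split

X = loan_data2.drop('not.fully.paid', axis=1)  # features
y = loan_data2['not.fully.paid']               # target
X_trainset, X_testset, y_trainset, y_testset = train_test_split(X, y, test_size=0.30, random_state=2)
```

Training a Decision Tree Model

```python
from sklearn.tree import DecisionTreeClassifier

# Split on information gain (entropy) and limit the tree to 4 levels
loanTree = DecisionTreeClassifier(criterion="entropy", max_depth=4)
loanTree  # in a notebook, this displays the model's parameters
loanTree.fit(X_trainset, y_trainset)
```

Model Evaluation

```python
from sklearn.metrics import confusion_matrix, classification_report

predLoan = loanTree.predict(X_testset)
print(classification_report(y_testset, predLoan))
```

This produces a 78% accuracy level, which is encouraging given that no feature engineering or parameter tuning has been done to improve the model. Let's see the confusion matrix visualization:

```python
# fmt="d" prints the cell counts as integers instead of scientific notation
sns.heatmap(confusion_matrix(y_testset, predLoan), cmap="viridis", lw=2, annot=True, fmt="d", cbar=False)
```

Keep in mind that accuracy alone can flatter a model when one class dominates: it should be compared with the share of the most common class, and the per-class precision and recall in the report show how well the less frequent class is caught.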
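The accuracy in the classification report is simply the share of predictions that sit on the diagonal of the confusion matrix. A small plain-Python sketch with a hypothetical matrix (the counts are invented; the layout follows scikit-learn's convention of true classes in rows and predicted classes in columns) shows how accuracy, precision and recall for class 1, and the majority-class baseline are derived:

```python
# Hypothetical confusion matrix, scikit-learn layout:
# rows = true class (0, 1), columns = predicted class (0, 1)
cm = [[50, 10],
      [10, 30]]


def summarize(cm):
    """Accuracy, precision and recall for class 1, plus the majority-class baseline."""
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    accuracy = (tn + tp) / total              # correct predictions lie on the diagonal
    precision = tp / (tp + fp)                # predicted 1s that really are 1
    recall = tp / (tp + fn)                   # actual 1s that the model found
    baseline = max(tn + fp, fn + tp) / total  # always predicting the most common class
    return accuracy, precision, recall, baseline


acc, prec, rec, base = summarize(cm)
print(f"accuracy={acc:.2f} precision={prec:.2f} recall={rec:.2f} baseline={base:.2f}")
# accuracy=0.80 precision=0.75 recall=0.75 baseline=0.60
```

A model is only useful if its accuracy clearly beats the baseline; here 0.80 versus 0.60.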

Precision may be improved later by applying a random forest classifier, which usually performs better than a single decision tree like this one.

More on that in another kernel.
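Besides the graphical rendering, the splits a fitted tree has learned can be inspected as plain text with scikit-learn's `export_text`, which needs neither Graphviz nor pydot. A minimal sketch, using the bundled Iris dataset as a stand-in for the loan data (an assumption made only to keep the example self-contained):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Stand-in data: the Iris dataset that ships with scikit-learn
iris = load_iris()
clf = DecisionTreeClassifier(criterion="entropy", max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# One line per node: split conditions for inner nodes, "class: ..." for leaves
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

For the loan model, the same call would be `export_text(loanTree, feature_names=list(X.columns))`.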

Visualization of the decision tree

```python
from io import StringIO  # replaces sklearn.externals.six, which newer scikit-learn releases removed

import pydot
from IPython.display import Image
from sklearn.tree import export_graphviz

features = list(X.columns)  # feature names used to label the nodes

dot_data = StringIO()
export_graphviz(loanTree, out_file=dot_data, feature_names=features, filled=True, rounded=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())  # returns a list of graphs
Image(graph[0].create_png())  # rendering requires Graphviz to be installed
```

All comments, suggestions and improvements are welcome! See the code above in action in my GitHub repository.

Cheers!
