Handy Guide to Visualizing Decision Trees in Python + Code

Devesh Surve
3 min read · Feb 17, 2022


The decision tree is one of the most popular machine learning algorithms, so let’s get started!

Decision trees can be used for both classification and regression problems; in this story, we focus on classification.

Before we dive in, let me ask you this:

Why decision trees?

There are plenty of other algorithms out there, so why choose decision trees?

Well, there may be many reasons, but here are a few I find most convincing:

  1. Decision trees often mimic human decision-making, so they are easy to understand and make it straightforward to draw good interpretations from the data.
  2. Decision trees let you see the logic behind every prediction (unlike more opaque models such as SVMs and neural networks).
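To see point 2 in action, scikit-learn’s `export_text` prints a fitted tree as plain if/else-style rules. This is a minimal sketch; `max_depth=2` and `random_state=0` are arbitrary choices added here (not part of the original setup) so the printout stays short and reproducible.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_breast_cancer()

# A deliberately shallow tree (max_depth=2 is an illustrative choice).
model = DecisionTreeClassifier(criterion='entropy', max_depth=2, random_state=0)
model.fit(data['data'], data['target'])

# export_text writes each split as "feature <= threshold" and each leaf as "class: ...".
rules = export_text(model, feature_names=list(data['feature_names']))
print(rules)
```

Every line of the output is a condition a human can check by hand, which is exactly the kind of transparency the list above describes.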

How to visualize them?

One problem we often face when visualizing a decision tree is the lack of interactivity. Once a tree grows past roughly 100 nodes, a static PNG just doesn’t cut it.

Fret not, Plotly is here to help. Well, not directly.

Plotly has no dedicated decision-tree chart, but it does have treemaps:

Treemap charts visualize hierarchical data using nested rectangles. The input data format is the same as for Sunburst Charts and Icicle Charts: the hierarchy is defined by labels (names for px.treemap) and parents attributes.

However, we can adapt it for our DecisionTreeClassifier. First, let me show you the usual matplotlib approach, using a classifier trained on the breast cancer dataset.

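import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Load the breast cancer dataset (569 samples, 30 numeric features, 2 classes).
data = load_breast_cancer()
X, y = data['data'], data['target']
feature_names = list(data['feature_names'])  # plain list of str

# Fit a fully grown tree that uses entropy (information gain) to choose splits.
model = DecisionTreeClassifier(criterion='entropy').fit(X, y)

# Draw the tree as a static image; filled=True shades nodes by majority class.
plt.figure(figsize=(12, 4), dpi=200)
plot_tree(model, feature_names=feature_names, filled=True)
plt.show()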

Now you can see the problem I am talking about: with this many nodes, the boxes and their text are far too small to read.
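A quick way to see how large the tree has actually grown is to ask the fitted model. This sketch refits the same classifier (`random_state=0` is an addition for reproducibility); the exact counts can vary with your scikit-learn version.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
model = DecisionTreeClassifier(criterion='entropy', random_state=0)
model.fit(data['data'], data['target'])

# With no max_depth, the tree keeps splitting until its leaves are pure
# (or cannot be split further), so these numbers grow with the data.
print('nodes: ', model.tree_.node_count)
print('leaves:', model.get_n_leaves())
print('depth: ', model.get_depth())
```

Because every split node has exactly two children, `node_count` always equals `2 * get_n_leaves() - 1`, so each extra leaf adds two boxes to an already crowded drawing.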

Now, let’s try Plotly with our modified code.

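import plotly.graph_objects as go

# Reuses `model` and `feature_names` from the matplotlib example above.
tree = model.tree_
n_nodes = tree.node_count

# Node indices serve as unique ids, so identical split labels appearing
# in different branches cannot confuse Plotly's hierarchy.
ids = [str(i) for i in range(n_nodes)]
labels = [''] * n_nodes
parents = [''] * n_nodes  # '' marks the root, which has no parent
labels[0] = 'root'

# Visit every split node and label its two children with the split condition.
for i, (f, t, l, r) in enumerate(zip(
    tree.feature,
    tree.threshold,
    tree.children_left,
    tree.children_right,
)):
    if l != r:  # leaves have children_left == children_right == -1
        labels[l] = f'{feature_names[f]} <= {t:g}'
        labels[r] = f'{feature_names[f]} > {t:g}'
        parents[l] = parents[r] = ids[i]

fig = go.Figure(go.Treemap(
    ids=ids,
    labels=labels,
    parents=parents,  # refer to ids, because ids is set
    values=tree.n_node_samples,  # a parent's count is the sum of its children's
    branchvalues='total',
    textinfo='label+value+percent root',
    marker=dict(colors=tree.impurity),  # color each node by its entropy
    customdata=[str(v) for v in tree.value],  # class distribution as text
    hovertemplate=(
        '<b>%{label}</b><br>'
        'impurity: %{color}<br>'
        'samples: %{value} (%{percentRoot:.2%})<br>'
        'value: %{customdata}'
    ),
))
fig.show()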

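The treemap mapping leans on two properties of scikit-learn’s low-level `tree_` arrays: a leaf has `children_left == children_right` (both are -1), which is why the loop only labels children when `l != r`, and each split node’s `n_node_samples` is exactly the sum of its two children’s counts, which is what `branchvalues='total'` expects. This sketch refits the same classifier (`random_state=0` is an addition) and checks both properties.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
model = DecisionTreeClassifier(criterion='entropy', random_state=0)
model.fit(data['data'], data['target'])
tree = model.tree_

# Leaves are the nodes whose left and right child ids are equal (both -1).
is_leaf = tree.children_left == tree.children_right

# For every split node, compare its sample count with its children's total.
internal = np.flatnonzero(~is_leaf)
left = tree.children_left[internal]
right = tree.children_right[internal]
counts_match = np.array_equal(
    tree.n_node_samples[internal],
    tree.n_node_samples[left] + tree.n_node_samples[right],
)
print('leaves:', int(is_leaf.sum()), '| totals consistent:', counts_match)
```

If the totals did not add up, Plotly would fail to draw the affected branches, so this check is a cheap sanity test before plotting a tree fitted with custom settings.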
The major advantages it gives are:

  1. Interactivity: does a node look too small to read? Click on it and it expands.
  2. Easy color customization.
  3. A consistent rectangular layout that is easy to download in PNG and HTML formats.

One thing to keep in mind, however:

Unlike plot_tree, this approach can’t color nodes by class (the color encodes impurity instead), so it is hard to use for anything other than binary classification or regression. 😅
