SK Learn – Decision Trees and Penguins

Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features. We’re going to try to predict the species of penguins based on their body measurements.

Firstly we’ll import the .csv file and take a look at the data.

In [129]:
import pandas as pd

df = pd.read_csv(r'C:\Users\hanna\Downloads\penguins.csv')
df.head(5)
Out[129]:
   species     island  culmen_length_mm  culmen_depth_mm  flipper_length_mm  body_mass_g     sex
0   Adelie  Torgersen              39.1             18.7              181.0       3750.0    MALE
1   Adelie  Torgersen              39.5             17.4              186.0       3800.0  FEMALE
2   Adelie  Torgersen              40.3             18.0              195.0       3250.0  FEMALE
3   Adelie  Torgersen               NaN              NaN                NaN          NaN     NaN
4   Adelie  Torgersen              36.7             19.3              193.0       3450.0  FEMALE

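df.info() and df.describe() are simple and awesome ways of checking the data available. info() lists each column's type and non-null count, and describe() summarises the numerical columns.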

In [130]:
print(df.info())
print(df.describe())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   species            344 non-null    object 
 1   island             344 non-null    object 
 2   culmen_length_mm   342 non-null    float64
 3   culmen_depth_mm    342 non-null    float64
 4   flipper_length_mm  342 non-null    float64
 5   body_mass_g        342 non-null    float64
 6   sex                334 non-null    object 
dtypes: float64(4), object(3)
memory usage: 18.9+ KB
None
       culmen_length_mm  culmen_depth_mm  flipper_length_mm  body_mass_g
count        342.000000       342.000000         342.000000   342.000000
mean          43.921930        17.151170         200.915205  4201.754386
std            5.459584         1.974793          14.061714   801.954536
min           32.100000        13.100000         172.000000  2700.000000
25%           39.225000        15.600000         190.000000  3550.000000
50%           44.450000        17.300000         197.000000  4050.000000
75%           48.500000        18.700000         213.000000  4750.000000
max           59.600000        21.500000         231.000000  6300.000000

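Thankfully I already know this dataset, so I’m going to move straight to dropping the rows with null values, as we won’t need those. df.info() shows 2 missing values in each measurement column and 10 in sex. Dropping every row with a missing value leaves 334 rows.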

In [131]:
df.dropna(inplace=True)
df.isna().sum()
Out[131]:
species              0
island               0
culmen_length_mm     0
culmen_depth_mm      0
flipper_length_mm    0
body_mass_g          0
sex                  0
dtype: int64
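The sketch below shows exactly what dropna() removes. It uses a hypothetical mini DataFrame (mini_head) built from a few columns of the five rows shown by df.head(5) above. By default dropna() uses how="any", so a row is dropped as soon as any one of its columns is missing. That is why row 3 disappears.

```python
import numpy as np
import pandas as pd

# A few columns of the five rows shown by df.head(5); row 3 has no measurements.
mini_head = pd.DataFrame(
    {
        "species": ["Adelie"] * 5,
        "island": ["Torgersen"] * 5,
        "culmen_length_mm": [39.1, 39.5, 40.3, np.nan, 36.7],
        "body_mass_g": [3750.0, 3800.0, 3250.0, np.nan, 3450.0],
        "sex": ["MALE", "FEMALE", "FEMALE", np.nan, "FEMALE"],
    }
)

# Missing values per column before dropping anything.
print(mini_head.isna().sum())

# dropna() defaults to how="any": a row goes if ANY of its columns is missing.
cleaned = mini_head.dropna()
print(len(mini_head), "->", len(cleaned))  # 5 -> 4
print(cleaned.index.tolist())              # [0, 1, 2, 4]
```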

Seaborn is a fantastic library for visualising datasets. Here’s a box plot of culmen length:

In [132]:
import seaborn as sns

sns.boxplot(x="culmen_length_mm", data=df, color="orange");

Species is what we want to predict, so let’s check how many of each species we have in the dataset.

In [133]:
df.species.value_counts()
Out[133]:
Adelie       146
Gentoo       120
Chinstrap     68
Name: species, dtype: int64

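.corr() is a great way to instantly see the correlation between numerical columns (Pearson correlation by default). Values range from -1 to 1. A value of 1 means the columns are perfectly positively correlated, 0 means no linear correlation, and anything negative means they are negatively correlated.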

In [134]:
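# Only correlate the numerical columns (pandas >= 2.0 raises an error on text columns otherwise)
corr = df.corr(numeric_only=True)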
corr
Out[134]:
                   culmen_length_mm  culmen_depth_mm  flipper_length_mm  body_mass_g
culmen_length_mm           1.000000        -0.228640           0.652126     0.589066
culmen_depth_mm           -0.228640         1.000000          -0.578730    -0.472987
flipper_length_mm          0.652126        -0.578730           1.000000     0.873211
body_mass_g                0.589066        -0.472987           0.873211     1.000000
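The numbers in the table are Pearson correlation coefficients: the covariance of two columns divided by the product of their standard deviations. The sketch below checks a hand calculation against pandas. It assumes a hypothetical mini DataFrame (mini) built from two columns of the ten rows sampled later in the post, plus the species column. numeric_only=True (available since pandas 1.5) keeps text columns such as species out of the calculation.

```python
import numpy as np
import pandas as pd

# Two measurement columns from the ten rows sampled later in the post, plus species.
mini = pd.DataFrame(
    {
        "culmen_length_mm": [44.4, 42.7, 36.4, 37.7, 45.2, 50.0, 36.0, 42.4, 39.7, 40.8],
        "flipper_length_mm": [219.0, 208.0, 195.0, 183.0, 198.0, 230.0, 187.0, 181.0, 184.0, 208.0],
        "species": ["Gentoo", "Gentoo", "Adelie", "Adelie", "Chinstrap",
                    "Gentoo", "Adelie", "Chinstrap", "Adelie", "Adelie"],
    }
)

# Pearson correlation matrix of the numerical columns only (species is skipped).
mini_corr = mini.corr(numeric_only=True)

# Pearson's r by hand: covariance divided by the product of the standard deviations.
length = mini["culmen_length_mm"].to_numpy()
flipper = mini["flipper_length_mm"].to_numpy()
d_len, d_flip = length - length.mean(), flipper - flipper.mean()
r = (d_len * d_flip).sum() / np.sqrt((d_len ** 2).sum() * (d_flip ** 2).sum())

print(mini_corr)
print(round(r, 6), round(mini_corr.loc["culmen_length_mm", "flipper_length_mm"], 6))
```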

If you only have a few columns to view then the seaborn heatmap is a great visual tool for correlation:

In [135]:
sns.heatmap(corr, vmin=-1, vmax=1, center=0,
    cmap=sns.diverging_palette(220, 20, as_cmap=True),
    square=True, annot=True
);

Next we split out the features we’ll use to predict (X) and the target we want to predict (y), which is the species. We also keep the species names as class_names so the tree can label its nodes with them. In this case we’re going to use all four numerical columns as features.

In [136]:
feature_names = ['culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g']
X = df[feature_names]
y = df['species']
# plot_tree expects class names in sorted order (the order of the model's classes_)
class_names = sorted(df['species'].unique())
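One detail is worth checking here. scikit-learn's plot_tree expects class_names in the same order as the fitted model's classes_, which are sorted. unique() instead returns the species in the order they first appear. In this CSV the species appear to be stored alphabetically, so the two orders coincide here, but that is luck rather than a guarantee.

The sketch below uses toy data to show the mismatch: the ten rows sampled later in the post, which start with a Gentoo. toy, toy_X, toy_y and toy_clf are hypothetical names, chosen so nothing in the notebook is overwritten.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

feature_names = ["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm", "body_mass_g"]

# Toy data: the ten rows sampled later in the post (first species listed is Gentoo).
toy = pd.DataFrame(
    [
        [44.4, 17.3, 219.0, 5250.0, "Gentoo"],
        [42.7, 13.7, 208.0, 3950.0, "Gentoo"],
        [36.4, 17.0, 195.0, 3325.0, "Adelie"],
        [37.7, 16.0, 183.0, 3075.0, "Adelie"],
        [45.2, 17.8, 198.0, 3950.0, "Chinstrap"],
        [50.0, 16.3, 230.0, 5700.0, "Gentoo"],
        [36.0, 17.1, 187.0, 3700.0, "Adelie"],
        [42.4, 17.3, 181.0, 3600.0, "Chinstrap"],
        [39.7, 18.9, 184.0, 3550.0, "Adelie"],
        [40.8, 18.9, 208.0, 4300.0, "Adelie"],
    ],
    columns=feature_names + ["species"],
)
toy_X, toy_y = toy[feature_names], toy["species"]
toy_clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(toy_X, toy_y)

# unique() lists species in order of first appearance...
print(list(toy_y.unique()))                # ['Gentoo', 'Adelie', 'Chinstrap']
# ...whereas the fitted model keeps its classes sorted; plot_tree reads
# class_names in this order.
print([str(c) for c in toy_clf.classes_])  # ['Adelie', 'Chinstrap', 'Gentoo']

# Sorting the unique names lines them up with classes_.
print(sorted(toy_y.unique()) == [str(c) for c in toy_clf.classes_])  # True
```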

Training the Model

Now it’s time to train the model. We import DecisionTreeClassifier and set two parameters:

• min_samples_leaf=10: every leaf must contain at least 10 penguins.
• max_depth=2: the tree can have at most two levels of splits below the root.

Later on we can tweak min_samples_leaf and max_depth if we want to change the tree size to better suit the dataset. We then fit the model on X and y as defined above.

In [137]:
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    min_samples_leaf=10,
    max_depth=2
)

clf.fit(X, y)
Out[137]:
DecisionTreeClassifier(max_depth=2, min_samples_leaf=10)
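The sketch below illustrates both parameters. It fits a small tree on hypothetical toy data (the ten rows sampled later in the post) with min_samples_leaf=2, an arbitrary value that suits only ten rows. It then uses apply() to count how many rows land in each leaf.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

feature_names = ["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm", "body_mass_g"]

# Toy data: the ten rows sampled later in the post.
toy = pd.DataFrame(
    [
        [44.4, 17.3, 219.0, 5250.0, "Gentoo"],
        [42.7, 13.7, 208.0, 3950.0, "Gentoo"],
        [36.4, 17.0, 195.0, 3325.0, "Adelie"],
        [37.7, 16.0, 183.0, 3075.0, "Adelie"],
        [45.2, 17.8, 198.0, 3950.0, "Chinstrap"],
        [50.0, 16.3, 230.0, 5700.0, "Gentoo"],
        [36.0, 17.1, 187.0, 3700.0, "Adelie"],
        [42.4, 17.3, 181.0, 3600.0, "Chinstrap"],
        [39.7, 18.9, 184.0, 3550.0, "Adelie"],
        [40.8, 18.9, 208.0, 4300.0, "Adelie"],
    ],
    columns=feature_names + ["species"],
)
toy_X, toy_y = toy[feature_names], toy["species"]

# min_samples_leaf is a lower bound on rows per leaf, not a number of leaves.
toy_clf = DecisionTreeClassifier(min_samples_leaf=2, max_depth=2, random_state=0)
toy_clf.fit(toy_X, toy_y)

# apply() returns the id of the leaf each row ends up in; count rows per leaf.
leaf_ids, leaf_sizes = np.unique(toy_clf.apply(toy_X), return_counts=True)

print("depth:", toy_clf.get_depth())          # never more than max_depth
print("leaves:", toy_clf.get_n_leaves())      # at most 2 ** max_depth
print("rows per leaf:", leaf_sizes.tolist())  # every entry is >= min_samples_leaf
```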

Now we can import scikit-learn’s tree module and plot a little version of the fitted tree:

In [138]:
from sklearn import tree

tree.plot_tree(clf);

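The gini value is the node’s Gini impurity. It is the probability that a penguin picked at random from the node would be misclassified if it were labelled at random according to the node’s class distribution. It equals 1 minus the sum of the squared class proportions, so a pure node (one species only) has gini = 0.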

samples is the number of penguins that reach each node.

value is how many of those samples fall into each class, in the order Adelie, Chinstrap, Gentoo.
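The sketch below turns the Gini formula into a small helper. gini_impurity is a hypothetical name, not a scikit-learn function. Applied to the class counts from df.species.value_counts(), it gives the impurity of a node holding the whole cleaned dataset. That is the root of clf, which was fitted on all 334 rows, so the root should show gini = 0.638.

```python
import numpy as np

def gini_impurity(counts):
    """Gini impurity of a node, given the number of samples in each class."""
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()          # class proportions in the node
    return 1.0 - np.sum(p ** 2)

# Adelie, Chinstrap, Gentoo counts of the cleaned dataset.
print(round(gini_impurity([146, 68, 120]), 3))  # 0.638
print(gini_impurity([50, 0, 0]))                # 0.0, a pure node
print(round(gini_impurity([1, 1, 1]), 3))       # 0.667, the most mixed a 3-class node can be
```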

Samples that satisfy a node’s condition (true) go to the left child, as you look at the tree, and samples that don’t (false) go to the right. You can see this more easily in the coloured tree below, where the Adelie node that begins the tree grows down to the left. The false values from the first node go to the right, and that’s where we can see the Gentoo node. We don’t see Chinstrap until the next layer down.
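The rule scikit-learn applies is simple. A sample goes to the left child when its feature value is <= the node’s threshold (condition true), and to the right child otherwise. The sketch below fits a one-split tree on hypothetical toy data (the ten rows sampled later in the post). It prints the rules with export_text, then checks the routing of the root split through the fitted tree_ arrays.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

feature_names = ["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm", "body_mass_g"]

# Toy data: the ten rows sampled later in the post.
toy = pd.DataFrame(
    [
        [44.4, 17.3, 219.0, 5250.0, "Gentoo"],
        [42.7, 13.7, 208.0, 3950.0, "Gentoo"],
        [36.4, 17.0, 195.0, 3325.0, "Adelie"],
        [37.7, 16.0, 183.0, 3075.0, "Adelie"],
        [45.2, 17.8, 198.0, 3950.0, "Chinstrap"],
        [50.0, 16.3, 230.0, 5700.0, "Gentoo"],
        [36.0, 17.1, 187.0, 3700.0, "Adelie"],
        [42.4, 17.3, 181.0, 3600.0, "Chinstrap"],
        [39.7, 18.9, 184.0, 3550.0, "Adelie"],
        [40.8, 18.9, 208.0, 4300.0, "Adelie"],
    ],
    columns=feature_names + ["species"],
)
toy_X, toy_y = toy[feature_names], toy["species"]

# A one-split tree ("stump") keeps the example small.
stump = DecisionTreeClassifier(max_depth=1, random_state=0).fit(toy_X, toy_y)

# Text version of the tree: the "<=" branch (condition true) is listed first.
rules = export_text(stump, feature_names=feature_names)
print(rules)

# Check the routing of the root split directly.
t = stump.tree_
root_feature = feature_names[t.feature[0]]
root_threshold = t.threshold[0]
is_true = (toy_X[root_feature] <= root_threshold).to_numpy()
leaf = stump.apply(toy_X)  # with depth 1, each leaf is a direct child of the root
print((leaf[is_true] == t.children_left[0]).all())    # True: condition true -> left
print((leaf[~is_true] == t.children_right[0]).all())  # True: condition false -> right
```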

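The default plot works, but matplotlib lets us set the figure size and resolution so the tree is easier to read. With filled=True, each node is coloured by its majority class. The deeper the colour, the purer the node (the lower its gini). These colours look pretty good to me.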

In [139]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(
    nrows=1,
    ncols=1,
    figsize=(3, 4),
    dpi=300
)

tree.plot_tree(
    clf,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    fontsize=4,
    ax=axes
);

Now we can make our own predictions if we wish by passing our own measurements into the model and seeing what it comes out with. The values must be given in the same order as feature_names: ‘culmen_length_mm’, ‘culmen_depth_mm’, ‘flipper_length_mm’, ‘body_mass_g’.

The model predicts that these measurements belong to a ‘Gentoo’ penguin:

In [140]:
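# Wrap the measurements in a DataFrame with the training column names
# (a plain list also works, but scikit-learn then warns about missing feature names)
clf.predict(pd.DataFrame(
    [[40, 16, 220, 3500]],
    columns=feature_names
))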
Out[140]:
array(['Gentoo'], dtype=object)
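Besides predict, the classifier’s predict_proba method returns the share of each class among the training samples in the leaf where the new penguin lands. predict picks the class with the largest share. The sketch below uses hypothetical toy data (the ten rows sampled later in the post), so its numbers will differ from clf’s.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

feature_names = ["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm", "body_mass_g"]

# Toy data: the ten rows sampled later in the post.
toy = pd.DataFrame(
    [
        [44.4, 17.3, 219.0, 5250.0, "Gentoo"],
        [42.7, 13.7, 208.0, 3950.0, "Gentoo"],
        [36.4, 17.0, 195.0, 3325.0, "Adelie"],
        [37.7, 16.0, 183.0, 3075.0, "Adelie"],
        [45.2, 17.8, 198.0, 3950.0, "Chinstrap"],
        [50.0, 16.3, 230.0, 5700.0, "Gentoo"],
        [36.0, 17.1, 187.0, 3700.0, "Adelie"],
        [42.4, 17.3, 181.0, 3600.0, "Chinstrap"],
        [39.7, 18.9, 184.0, 3550.0, "Adelie"],
        [40.8, 18.9, 208.0, 4300.0, "Adelie"],
    ],
    columns=feature_names + ["species"],
)
toy_clf = DecisionTreeClassifier(max_depth=2, random_state=0)
toy_clf.fit(toy[feature_names], toy["species"])

# Same measurements as above, with column names matching the training data.
new_penguin = pd.DataFrame([[40, 16, 220, 3500]], columns=feature_names)

print(toy_clf.predict(new_penguin))
# Share of each class in the leaf this penguin lands in, ordered like classes_.
proba = toy_clf.predict_proba(new_penguin)[0]
print(dict(zip([str(c) for c in toy_clf.classes_], proba.round(2).tolist())))
```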
Now it’s time to evaluate the model:

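We’ll take a random sample of 10 rows. sample_X holds the feature columns (feature_names, as defined earlier) for those rows, and sample_y holds their species. Bear in mind that clf was trained on every row, so these penguins aren’t new to the model. Let’s have a look at them…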

In [141]:
sample = df.sample(10)
sample_X = sample[feature_names]
sample_y = sample['species']
sample_X
Out[141]:
     culmen_length_mm  culmen_depth_mm  flipper_length_mm  body_mass_g
257              44.4             17.3              219.0       5250.0
260              42.7             13.7              208.0       3950.0
34               36.4             17.0              195.0       3325.0
102              37.7             16.0              183.0       3075.0
157              45.2             17.8              198.0       3950.0
221              50.0             16.3              230.0       5700.0
150              36.0             17.1              187.0       3700.0
172              42.4             17.3              181.0       3600.0
105              39.7             18.9              184.0       3550.0
95               40.8             18.9              208.0       4300.0
In [142]:
In [142]:
sample_y
Out[142]:
257       Gentoo
260       Gentoo
34        Adelie
102       Adelie
157    Chinstrap
221       Gentoo
150       Adelie
172    Chinstrap
105       Adelie
95        Adelie
Name: species, dtype: object

Now we’re going to predict the species from the measurements in sample_X. We’ll compare the predictions against the known species in sample_y and see whether the model gets them right.

In [143]:
predictions = clf.predict(sample_X)
predictions
Out[143]:
array(['Chinstrap', 'Gentoo', 'Adelie', 'Adelie', 'Chinstrap', 'Gentoo',
       'Adelie', 'Adelie', 'Adelie', 'Chinstrap'], dtype=object)

It’s only a small sample, but comparing the predictions with the known species row by row shows where the model got it right:

In [144]:
predictions == sample_y
Out[144]:
257    False
260     True
34      True
102     True
157     True
221     True
150     True
172    False
105     True
95     False
Name: species, dtype: bool

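7 of the 10 predictions were correct; the model got rows 257, 172 and 95 wrong. It’s a good start. However, these rows were part of the training data, so this doesn’t tell us how the model copes with penguins it hasn’t seen.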
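The tally can also be done in code. This sketch copies the predictions and the known species from the outputs above into two hypothetical variables, known_species and predicted_species. Comparing them gives a boolean Series. Because True counts as 1, its mean is the fraction of correct predictions, the same number sklearn.metrics.accuracy_score reports.

```python
import pandas as pd
from sklearn.metrics import accuracy_score

# Known species (sample_y) and predictions copied from the outputs above.
known_species = pd.Series(
    ["Gentoo", "Gentoo", "Adelie", "Adelie", "Chinstrap",
     "Gentoo", "Adelie", "Chinstrap", "Adelie", "Adelie"],
    index=[257, 260, 34, 102, 157, 221, 150, 172, 105, 95],
    name="species",
)
predicted_species = ["Chinstrap", "Gentoo", "Adelie", "Adelie", "Chinstrap",
                     "Gentoo", "Adelie", "Adelie", "Adelie", "Chinstrap"]

correct = known_species == predicted_species  # True where the model is right
print(correct[~correct].index.tolist())       # [257, 172, 95]
print(correct.mean())                         # 0.7 (True counts as 1, False as 0)
print(accuracy_score(known_species, predicted_species))  # 0.7
```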

train_test_split

Now it’s time to split the data, so we can check how well the decision tree works on penguins it wasn’t trained on.

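We’ll train the model on 80% of the data and test it on the remaining 20%. With 334 rows, that gives 267 training rows and 67 test rows, because scikit-learn rounds the test size up. This matches the Length: 67 in the output below.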

In [145]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    df[feature_names],
    df['species'],
    test_size=0.2
)
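The split above has no random_state, so a different 20% of rows ends up in the test set on every run. The sketch below shows two optional train_test_split arguments, using stand-in data with the same size and species counts as the cleaned dataset:

• random_state fixes the shuffle (42 here is arbitrary).
• stratify keeps the species proportions the same in the training and test sets.

The demo_ variable names are hypothetical, chosen so X_train and friends are not overwritten.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in data: 334 rows with the species counts from df.species.value_counts().
demo_y = np.array(["Adelie"] * 146 + ["Chinstrap"] * 68 + ["Gentoo"] * 120)
demo_X = np.arange(len(demo_y)).reshape(-1, 1)

demo_X_train, demo_X_test, demo_y_train, demo_y_test = train_test_split(
    demo_X,
    demo_y,
    test_size=0.2,
    random_state=42,  # fixes the shuffle so the split is reproducible
    stratify=demo_y,  # keeps species proportions equal in both parts
)

print(len(demo_X_train), len(demo_X_test))  # 267 67
species, counts = np.unique(demo_y_test, return_counts=True)
print(dict(zip(species.tolist(), counts.tolist())))
```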
In [146]:
model = DecisionTreeClassifier(
    max_depth=2,
    min_samples_leaf=12,
    random_state=1
)

model.fit(X_train, y_train)
Out[146]:
DecisionTreeClassifier(max_depth=2, min_samples_leaf=12, random_state=1)
In [147]:
ground_truth = y_test

predictions = model.predict(X_test)

predictions == ground_truth
Out[147]:
310     True
336     True
248     True
57      True
305    False
       ...  
174     True
1       True
110     True
79      True
90      True
Name: species, Length: 67, dtype: bool

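Time to score the model. The score is its accuracy, the fraction of test penguins it classifies correctly, so it is a number out of 1.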