Multiclass Classification

This example uses the ‘iris’ dataset to perform multiclass classification with a Support Vector Machine (SVM) classifier. It plots a heatmap of the cross-validation accuracies and an annotated confusion matrix for the held-out test data.

# Authors: Shammi More <s.more@fz-juelich.de>
#          Federico Raimondo <f.raimondo@fz-juelich.de>
#
# License: AGPL

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from seaborn import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from julearn import run_cross_validation
from julearn.utils import configure_logging

Set the logging level to INFO to see extra information.

configure_logging(level='INFO')

df_iris = load_dataset('iris')
X = ['sepal_length', 'sepal_width', 'petal_length']
y = 'species'
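Iris has 150 samples, four numeric measurements and three equally sized species; this example uses only three of the measurements and leaves out petal_width. The sketch below checks this. It assumes scikit-learn's bundled copy of iris so that it runs offline (seaborn's load_dataset may download the data), and renames the columns to the seaborn names used above.

```python
from sklearn.datasets import load_iris

# Assumption: scikit-learn's bundled iris stands in for seaborn.load_dataset
bunch = load_iris(as_frame=True)
df_iris = bunch.frame.rename(columns={
    'sepal length (cm)': 'sepal_length',
    'sepal width (cm)': 'sepal_width',
    'petal length (cm)': 'petal_length',
    'petal width (cm)': 'petal_width',
})
# Replace the integer target codes with the species names used by seaborn
species_names = dict(enumerate(bunch.target_names))
df_iris['species'] = df_iris.pop('target').map(species_names)

X = ['sepal_length', 'sepal_width', 'petal_length']
y = 'species'

print(df_iris.shape)              # (150, 5)
print(df_iris[y].value_counts())  # 50 samples per species
# Columns that are neither features nor target
unused = [c for c in df_iris.columns if c not in X + [y]]
print(unused)                     # ['petal_width']
```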

Out:

2021-01-28 20:09:00,853 - julearn - INFO - ===== Lib Versions =====
2021-01-28 20:09:00,853 - julearn - INFO - numpy: 1.19.5
2021-01-28 20:09:00,853 - julearn - INFO - scipy: 1.6.0
2021-01-28 20:09:00,853 - julearn - INFO - sklearn: 0.24.1
2021-01-28 20:09:00,853 - julearn - INFO - pandas: 1.2.1
2021-01-28 20:09:00,853 - julearn - INFO - julearn: 0.2.5.dev19+g9c15c5f
2021-01-28 20:09:00,853 - julearn - INFO - ========================

Split the dataset into train and test sets, keeping the species proportions equal in both.

train_iris, test_iris = train_test_split(df_iris, test_size=0.2,
                                         stratify=df_iris[y])
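With 150 balanced samples and test_size=0.2, stratification yields 120 training and 30 test samples, 40 and 10 per species respectively; the 10 per species matches the row totals of the confusion matrix computed later. A self-contained check, again assuming scikit-learn's bundled iris; random_state=0 is added here only to make the sketch reproducible and is not part of the original example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: scikit-learn's bundled iris stands in for seaborn's copy
df = load_iris(as_frame=True).frame
train, test = train_test_split(df, test_size=0.2, stratify=df['target'],
                               random_state=0)

# Stratification preserves the equal class balance in both splits
print(len(train), len(test))                                  # 120 30
print(test['target'].value_counts().sort_index().tolist())    # [10, 10, 10]
print(train['target'].value_counts().sort_index().tolist())   # [40, 40, 40]
```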

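Perform multiclass classification, since the iris dataset contains three species. The features are z-scored before fitting the SVM, and only the final model fitted on the whole training set is returned.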

scores, model_iris = run_cross_validation(
    X=X, y=y, data=train_iris, model='svm', preprocess_X='zscore',
    problem_type='multiclass_classification', scoring=['accuracy'],
    return_estimator='final')

Out:

2021-01-28 20:09:00,858 - julearn - INFO - Using default CV
2021-01-28 20:09:00,858 - julearn - INFO - ==== Input Data ====
2021-01-28 20:09:00,858 - julearn - INFO - Using dataframe as input
2021-01-28 20:09:00,858 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
2021-01-28 20:09:00,858 - julearn - INFO - Target: species
2021-01-28 20:09:00,858 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
2021-01-28 20:09:00,858 - julearn - INFO - Expanded Confounds: []
2021-01-28 20:09:00,859 - julearn - INFO - ====================
2021-01-28 20:09:00,859 - julearn - INFO -
2021-01-28 20:09:00,859 - julearn - INFO - ====== Model ======
2021-01-28 20:09:00,859 - julearn - INFO - Obtaining model by name: svm
2021-01-28 20:09:00,859 - julearn - INFO - ===================
2021-01-28 20:09:00,859 - julearn - INFO -
2021-01-28 20:09:00,859 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
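Conceptually, preprocess_X='zscore' with model='svm' is a z-scoring step followed by an SVM, and the log shows that the default CV is a RepeatedKFold with 5 repetitions of 5 folds. The following is a minimal plain scikit-learn sketch under that reading. It is not julearn's internal code: the SVC defaults, the seeds (random_state=0 and 42, chosen here) and the bundled iris data are assumptions, so the scores will not match the output below exactly.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import (RepeatedKFold, cross_validate,
                                     train_test_split)
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Assumption: bundled iris, first three columns
# (sepal length, sepal width, petal length) as in the example
iris = load_iris()
X, y = iris.data[:, :3], iris.target
# Hold out 20% as in the example; random_state is our choice
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# z-score the features, then fit an SVM; the scaler is refit in every fold
model = make_pipeline(StandardScaler(), SVC())
cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=42)
res = cross_validate(model, X_train, y_train, cv=cv, scoring=['accuracy'])

scores = pd.DataFrame(res)
# RepeatedKFold yields all 5 folds of repeat 0, then repeat 1, and so on
scores['repeat'] = scores.index // 5
scores['fold'] = scores.index % 5
print(scores.head())
```

Putting the scaler inside the pipeline means its mean and standard deviation are estimated on each training fold only, so no statistics from the validation fold leak into preprocessing.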

The scores DataFrame has one row per CV split, holding the fit and score times, the test accuracy, and the repeat and fold indices.

print(scores.head())

Out:

   fit_time  score_time  test_accuracy  repeat  fold
0  0.011665    0.007552       0.875000       0     0
1  0.011276    0.007438       0.916667       0     1
2  0.011277    0.007473       0.875000       0     2
3  0.011434    0.007445       0.875000       0     3
4  0.011578    0.007531       1.000000       0     4

Now we can get the accuracy per fold and repetition:

df_accuracy = scores.set_index(
    ['repeat', 'fold'])['test_accuracy'].unstack()
df_accuracy.index.name = 'Repeats'
df_accuracy.columns.name = 'K-fold splits'
print(df_accuracy)

Out:

K-fold splits         0         1         2         3         4
Repeats
0              0.875000  0.916667  0.875000  0.875000  1.000000
1              0.958333  0.916667  0.833333  0.875000  0.958333
2              0.958333  0.958333  0.916667  0.916667  0.916667
3              0.833333  0.916667  0.875000  0.958333  0.958333
4              1.000000  0.875000  0.833333  0.916667  0.875000
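Each validation fold holds 24 of the 120 training samples, so every accuracy in the table is a multiple of 1/24 (0.875 = 21/24, for example). A common summary is the mean accuracy per repeat, plus the mean and standard deviation over all 25 splits. The sketch below rebuilds the table from the printed values, written as correct predictions out of 24 (a reconstruction for illustration; in the example itself the same methods can be called directly on df_accuracy).

```python
import pandas as pd

# Correct predictions out of 24 per fold, read off the printed table
correct = [
    [21, 22, 21, 21, 24],
    [23, 22, 20, 21, 23],
    [23, 23, 22, 22, 22],
    [20, 22, 21, 23, 23],
    [24, 21, 20, 22, 21],
]
df_accuracy = pd.DataFrame(correct) / 24
df_accuracy.index.name = 'Repeats'
df_accuracy.columns.name = 'K-fold splits'

# Mean accuracy of each repetition (averaged over its 5 folds)
print(df_accuracy.mean(axis=1).round(4).tolist())
# [0.9083, 0.9083, 0.9333, 0.9083, 0.9]

# Overall mean and sample standard deviation over all 25 splits
values = df_accuracy.to_numpy().ravel()
print(round(values.mean(), 4), round(values.std(ddof=1), 4))
```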

Plot a heatmap of the accuracy over all repeats and CV splits.

sns.set(font_scale=1.2)
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
sns.heatmap(df_accuracy, cmap="YlGnBu", ax=ax)
plt.title('Cross-validation Accuracy')
[Figure: Cross-validation Accuracy heatmap]


We can also evaluate the final model on the held-out test data and plot its confusion matrix as an annotated heatmap.

y_true = test_iris[y]
y_pred = model_iris.predict(test_iris[X])
cm = confusion_matrix(y_true, y_pred, labels=np.unique(y_true))

print(cm)

Out:

[[10  0  0]
 [ 0 10  0]
 [ 0  1  9]]
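The code above computes only the confusion matrix; the test accuracy follows from it as the sum of the diagonal (correct predictions) divided by the number of test samples, and the recall of each class as its diagonal entry divided by its row total. Rows and columns are ordered as np.unique(y_true), i.e. setosa, versicolor, virginica, so the single error is a virginica sample predicted as versicolor. A sketch using the matrix printed above:

```python
import numpy as np

# Confusion matrix printed above (rows: actual, columns: predicted)
cm = np.array([[10, 0, 0],
               [0, 10, 0],
               [0, 1, 9]])

# Accuracy = correct predictions / all predictions
accuracy = np.trace(cm) / cm.sum()
# Recall per class = correct predictions / true samples of that class
recall = np.diag(cm) / cm.sum(axis=1)

print(round(accuracy, 4))   # 0.9667
print(recall.tolist())      # [1.0, 1.0, 0.9]
```

In the example itself, sklearn.metrics.accuracy_score(y_true, y_pred) returns the same accuracy directly from the labels.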

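Now that we have the confusion matrix, let’s build a matching matrix of annotations: each nonzero cell shows the percentage of its actual class, followed by the raw count over the class total.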

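# Row sums: number of true samples per class, shape (n_classes, 1)
cm_sum = np.sum(cm, axis=1, keepdims=True)
# Percentage of each actual class assigned to each predicted class
cm_perc = cm / cm_sum.astype(float) * 100
annot = np.empty_like(cm).astype(str)
nrows, ncols = cm.shape
for i in range(nrows):
    for j in range(ncols):
        c = cm[i, j]
        p = cm_perc[i, j]
        if c == 0:
            # Leave empty cells unannotated
            annot[i, j] = ''
        else:
            # Take the scalar row total; formatting a 1-element array
            # with %d is deprecated in recent NumPy versions
            s = cm_sum[i, 0]
            annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)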
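The same annotation matrix can be built without explicit index bookkeeping, for example with a nested list comprehension over the counts, percentages and row totals. A sketch using the confusion matrix printed above; the helper name make_annotations is hypothetical and not part of julearn or seaborn.

```python
import numpy as np


def make_annotations(cm):
    """Return 'pct%\\ncount/total' labels per cell, '' where the count is 0."""
    cm_sum = cm.sum(axis=1, keepdims=True)   # true samples per class
    cm_perc = cm / cm_sum * 100              # row-normalized percentages
    labels = [
        ['' if c == 0 else '%.1f%%\n%d/%d' % (p, c, s)
         for c, p in zip(row_c, row_p)]
        for row_c, row_p, s in zip(cm, cm_perc, cm_sum[:, 0])
    ]
    return np.array(labels, dtype=object)


cm = np.array([[10, 0, 0],
               [0, 10, 0],
               [0, 1, 9]])
annot = make_annotations(cm)
print(repr(annot[2, 2]))  # '90.0%\n9/10'
```

The object dtype avoids the fixed-width string dtype of the original approach, and seaborn's heatmap accepts such an array through annot=annot together with fmt=''.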

Finally, we wrap the confusion matrix in a DataFrame labeled with the class names and plot it as an annotated heatmap.

cm = pd.DataFrame(cm, index=np.unique(y_true), columns=np.unique(y_true))
cm.index.name = 'Actual'
cm.columns.name = 'Predicted'
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
sns.heatmap(cm, cmap="YlGnBu", annot=annot, fmt='', ax=ax)
plt.title('Confusion matrix')
[Figure: Confusion matrix heatmap]


Total running time of the script: (0 minutes 1.002 seconds)

Gallery generated by Sphinx-Gallery