Multiclass Classification

This example uses the ‘iris’ dataset to perform multiclass classification with a Support Vector Machine classifier. It plots a heatmap of the cross-validation accuracies and the confusion matrix for the test data.

# Authors: Shammi More <>
#          Federico Raimondo <>
# License: AGPL

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from seaborn import load_dataset
from sklearn.model_selection import train_test_split, RepeatedKFold
from sklearn.metrics import confusion_matrix

from julearn import run_cross_validation
from julearn.utils import configure_logging

Set the logging level to info to see extra information:

configure_logging(level="INFO")

2023-07-19 12:41:46,649 - julearn - INFO - ===== Lib Versions =====
2023-07-19 12:41:46,649 - julearn - INFO - numpy: 1.25.1
2023-07-19 12:41:46,649 - julearn - INFO - scipy: 1.11.1
2023-07-19 12:41:46,649 - julearn - INFO - sklearn: 1.3.0
2023-07-19 12:41:46,649 - julearn - INFO - pandas: 2.0.3
2023-07-19 12:41:46,649 - julearn - INFO - julearn: 0.3.1.dev1
2023-07-19 12:41:46,650 - julearn - INFO - ========================

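Load the iris data from seaborn. We use three of the four measurements as features (petal width is left out) and the species as the target.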

df_iris = load_dataset("iris")
X = ["sepal_length", "sepal_width", "petal_length"]
y = "species"

Split the dataset into training and test sets, stratified by species so that both keep the same class proportions.

train_iris, test_iris = train_test_split(
    df_iris, test_size=0.2, stratify=df_iris[y], random_state=200
)
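Passing `stratify=df_iris[y]` keeps the share of each species identical in both parts. As a quick check, here is a minimal sketch that counts the classes after such a split. It assumes scikit-learn's bundled copy of the iris data (`sklearn.datasets.load_iris`), swapped in for the seaborn download only so that the sketch runs offline; the 120 training samples, 40 per class, match the class distribution reported in the log below.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Bundled iris data (150 rows, 50 per species), used instead of
# seaborn.load_dataset("iris") so no network access is needed.
iris = load_iris(as_frame=True)
df = iris.frame.copy()
df["species"] = iris.target.map(dict(enumerate(iris.target_names)))

train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df["species"], random_state=200
)

# With stratification, every species keeps its 1/3 share in both splits.
train_counts = train_df["species"].value_counts()
test_counts = test_df["species"].value_counts()
print(train_counts.to_dict())
print(test_counts.to_dict())
```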

We want to perform multiclass classification, as the iris dataset contains three species. We first z-score all the features and then train a support vector machine classifier, evaluating the pipeline on the training set with 5-fold cross-validation repeated 5 times.

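cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=200)
scores, model_iris = run_cross_validation(
    X=X,
    y=y,
    data=train_iris,
    preprocess="zscore",
    model="svm",
    problem_type="classification",
    cv=cv,
    scoring=["accuracy"],
    return_estimator="final",
)
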
2023-07-19 12:41:46,654 - julearn - INFO - ==== Input Data ====
2023-07-19 12:41:46,654 - julearn - INFO - Using dataframe as input
2023-07-19 12:41:46,654 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:46,654 - julearn - INFO -      Target: species
2023-07-19 12:41:46,654 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:46,654 - julearn - INFO -      X_types:{}
2023-07-19 12:41:46,654 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/utils/ RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn(msg, category=category)
2023-07-19 12:41:46,655 - julearn - INFO - ====================
2023-07-19 12:41:46,655 - julearn - INFO -
2023-07-19 12:41:46,655 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:46,655 - julearn - INFO - Step added
2023-07-19 12:41:46,655 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:46,656 - julearn - INFO - Step added
2023-07-19 12:41:46,656 - julearn - INFO - = Model Parameters =
2023-07-19 12:41:46,656 - julearn - INFO - ====================
2023-07-19 12:41:46,656 - julearn - INFO -
2023-07-19 12:41:46,656 - julearn - INFO - = Data Information =
2023-07-19 12:41:46,656 - julearn - INFO -      Problem type: classification
2023-07-19 12:41:46,656 - julearn - INFO -      Number of samples: 120
2023-07-19 12:41:46,656 - julearn - INFO -      Number of features: 3
2023-07-19 12:41:46,656 - julearn - INFO - ====================
2023-07-19 12:41:46,657 - julearn - INFO -
2023-07-19 12:41:46,657 - julearn - INFO -      Number of classes: 3
2023-07-19 12:41:46,657 - julearn - INFO -      Target type: object
2023-07-19 12:41:46,657 - julearn - INFO -      Class distributions: species
versicolor    40
virginica     40
setosa        40
Name: count, dtype: int64
2023-07-19 12:41:46,658 - julearn - INFO - Using outer CV scheme RepeatedKFold(n_repeats=5, n_splits=5, random_state=200)
2023-07-19 12:41:46,658 - julearn - INFO - Multi-class classification problem detected #classes = 3.
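For readers more used to plain scikit-learn, the setup above (z-scoring, then an SVM, evaluated with 5 × 5 repeated K-fold) can be approximated as follows. This is a sketch, not julearn's internals: it assumes that julearn's `"zscore"` and `"svm"` steps correspond to `StandardScaler` and `SVC` with default parameters, and it again uses the bundled `load_iris` data. Because the scaler sits inside the pipeline, it is re-fitted on the training part of every fold, so no test-fold statistics leak into the scaling.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedKFold, cross_validate, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

iris = load_iris(as_frame=True)
# Same three features as in the example (petal width is left out).
X = iris.data[["sepal length (cm)", "sepal width (cm)", "petal length (cm)"]]
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=200
)

# Assumed plain scikit-learn counterpart of preprocess="zscore" + model="svm".
pipe = make_pipeline(StandardScaler(), SVC())
cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=200)

# 5 folds x 5 repeats = 25 fits; each fit re-scales using its own training fold.
cv_results = cross_validate(pipe, X_train, y_train, cv=cv, scoring="accuracy")
test_acc = cv_results["test_score"]
print(f"{len(test_acc)} splits, mean accuracy {test_acc.mean():.3f}")

# Refit on the whole training set (the "final" model), then score held-out data.
pipe.fit(X_train, y_train)
print(f"held-out accuracy {pipe.score(X_test, y_test):.3f}")
```

Note that `RepeatedKFold` does not stratify the folds; scikit-learn's `RepeatedStratifiedKFold` is the stratified variant if balanced folds matter.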

The scores dataframe has one row per CV split (5 folds × 5 repeats = 25 rows); the first five rows are shown below.

   fit_time  score_time  ...  fold                          cv_mdsum
0  0.006091    0.003277  ...     0  fa5ab7a2b930761687a8e82d9971ebca
1  0.005671    0.003219  ...     1  fa5ab7a2b930761687a8e82d9971ebca
2  0.005698    0.003213  ...     2  fa5ab7a2b930761687a8e82d9971ebca
3  0.005677    0.003239  ...     3  fa5ab7a2b930761687a8e82d9971ebca
4  0.005594    0.003202  ...     4  fa5ab7a2b930761687a8e82d9971ebca

[5 rows x 8 columns]
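`RepeatedKFold` yields its splits repeat by repeat, so split number k belongs to repeat k // 5 and fold k % 5. The sketch below rebuilds a frame with the same `repeat`, `fold` and `test_accuracy` column names from scikit-learn's `cross_validate` (an assumption for illustration: it mimics, but is not, julearn's output) and summarizes it the way one might summarize `scores`.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X = X[:, :3]  # sepal length, sepal width, petal length, as in the example

n_splits, n_repeats = 5, 5
cv = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=200)
res = cross_validate(
    make_pipeline(StandardScaler(), SVC()), X, y, cv=cv, scoring="accuracy"
)

# Splits come out as all folds of repeat 0, then all folds of repeat 1, etc.
k = np.arange(n_splits * n_repeats)
scores_like = pd.DataFrame(
    {"repeat": k // n_splits, "fold": k % n_splits, "test_accuracy": res["test_score"]}
)

print(len(scores_like))  # one row per split
print(scores_like["test_accuracy"].agg(["mean", "std"]))  # overall estimate and spread
print(scores_like.groupby("repeat")["test_accuracy"].mean())  # one mean per repetition
```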

Now we can get the accuracy per fold and repetition:

df_accuracy = scores.set_index(["repeat", "fold"])["test_accuracy"].unstack()
df_accuracy.index.name = "Repeats"
df_accuracy.columns.name = "K-fold splits"
print(df_accuracy)

K-fold splits         0         1         2         3         4
0              0.916667  0.833333  0.958333  0.916667  0.833333
1              0.875000  0.833333  0.916667  0.833333  0.833333
2              0.750000  0.916667  0.916667  0.958333  0.916667
3              1.000000  0.791667  0.875000  1.000000  0.791667
4              0.875000  0.833333  0.875000  0.916667  0.958333

Plot a heatmap of the accuracy over all repeats and CV splits.

fig, ax = plt.subplots(1, 1, figsize=(10, 7))
sns.heatmap(df_accuracy, cmap="YlGnBu", ax=ax)
plt.title("Cross-validation Accuracy")
[Figure: Cross-validation Accuracy]

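We can also evaluate our final model on the held-out test data and plot its confusion matrix as an annotated heatmap.

y_true = test_iris[y]
y_pred = model_iris.predict(test_iris[X])
cm = confusion_matrix(y_true, y_pred, labels=np.unique(y_true))
print(cm)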

[[9 1 0]
 [0 9 1]
 [0 2 8]]
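In scikit-learn's `confusion_matrix`, row i counts the samples whose true class is the i-th label and column j those predicted as the j-th label; the labels are in sorted order (setosa, versicolor, virginica). Accuracy and per-class recall and precision therefore follow directly from the matrix printed above, for example 26 of the 30 test samples lie on the diagonal:

```python
import numpy as np

# Confusion matrix printed above (rows = true class, columns = predicted class),
# class order: setosa, versicolor, virginica.
cm = np.array([[9, 1, 0],
               [0, 9, 1],
               [0, 2, 8]])

accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions
recall = np.diag(cm) / cm.sum(axis=1)     # diagonal over row sums (per true class)
precision = np.diag(cm) / cm.sum(axis=0)  # diagonal over column sums (per predicted class)

print(f"accuracy = {accuracy:.4f}")  # 26 / 30
print(recall)
print(precision)
```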

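Next, we build a matching matrix of annotation strings: each non-empty cell shows the percentage of its true class, followed by the raw count over the class total. Cells with a count of zero are left blank.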

cm_sum = np.sum(cm, axis=1, keepdims=True)  # samples per true class, shape (n, 1)
cm_perc = cm / cm_sum.astype(float) * 100  # row-normalized percentages
annot = np.empty_like(cm).astype(str)
nrows, ncols = cm.shape
for i in range(nrows):
    for j in range(ncols):
        c = cm[i, j]
        p = cm_perc[i, j]
        if c == 0:
            annot[i, j] = ""  # leave empty cells blank
        else:
            s = cm_sum[i, 0]  # extract a scalar, not a 1-element array
            annot[i, j] = "%.1f%%\n%d/%d" % (p, c, s)
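`cm_perc` is the confusion matrix normalized over each row, so each true class sums to 100%. scikit-learn can return the same fractions directly via `confusion_matrix(..., normalize="true")`. The sketch below checks this on a few hypothetical labels and builds the same "percent, count/total" strings with a list comprehension instead of index loops.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels, only to exercise the functions.
y_true = ["a", "a", "a", "b", "b", "c", "c", "c", "c"]
y_pred = ["a", "a", "b", "b", "b", "c", "c", "a", "c"]
labels = np.unique(y_true)

cm = confusion_matrix(y_true, y_pred, labels=labels)
cm_sum = cm.sum(axis=1, keepdims=True)  # true samples per class, shape (n, 1)
cm_perc = cm / cm_sum * 100             # manual row normalization, in percent

# Same numbers from scikit-learn, as fractions instead of percent.
cm_norm = confusion_matrix(y_true, y_pred, labels=labels, normalize="true")

# Annotation strings; zero-count cells stay blank.
annot = np.array([
    ["" if c == 0 else f"{p:.1f}%\n{c}/{s}" for c, p in zip(row_c, row_p)]
    for row_c, row_p, s in zip(cm, cm_perc, cm_sum[:, 0])
])
print(annot)
```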

Finally, we wrap the confusion matrix in a dataframe labelled with the class names and plot it as a heatmap with these annotations.

cm = pd.DataFrame(cm, index=np.unique(y_true), columns=np.unique(y_true))
cm.index.name = "Actual"
cm.columns.name = "Predicted"
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
sns.heatmap(cm, cmap="YlGnBu", annot=annot, fmt="", ax=ax)
plt.title("Confusion matrix")
[Figure: Confusion matrix]

Total running time of the script: ( 0 minutes 0.602 seconds)
