.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/00_starting/plot_cm_acc_multiclass.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_00_starting_plot_cm_acc_multiclass.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_00_starting_plot_cm_acc_multiclass.py:

Multiclass Classification
=========================

This example uses the ``iris`` dataset to perform multiclass classification
with a Support Vector Machine (SVM) classifier. It plots a heatmap of the
cross-validation accuracies and a confusion matrix for the test data.

.. GENERATED FROM PYTHON SOURCE LINES 11-26

.. code-block:: Python

    # Authors: Shammi More
    #          Federico Raimondo
    # License: AGPL

    import pandas as pd
    import seaborn as sns
    import numpy as np
    import matplotlib.pyplot as plt

    from seaborn import load_dataset
    from sklearn.model_selection import train_test_split, RepeatedKFold
    from sklearn.metrics import confusion_matrix

    from julearn import run_cross_validation
    from julearn.utils import configure_logging

.. GENERATED FROM PYTHON SOURCE LINES 27-28

Set the logging level to info to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 28-30

.. code-block:: Python

    configure_logging(level="INFO")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:53:54,082 - julearn - INFO - ===== Lib Versions =====
    2026-01-16 10:53:54,082 - julearn - INFO - numpy: 1.26.4
    2026-01-16 10:53:54,082 - julearn - INFO - scipy: 1.17.0
    2026-01-16 10:53:54,082 - julearn - INFO - sklearn: 1.7.2
    2026-01-16 10:53:54,082 - julearn - INFO - pandas: 2.3.3
    2026-01-16 10:53:54,083 - julearn - INFO - julearn: 0.3.5.dev123
    2026-01-16 10:53:54,083 - julearn - INFO - ========================

.. GENERATED FROM PYTHON SOURCE LINES 31-32

Load the iris dataset from seaborn.

.. GENERATED FROM PYTHON SOURCE LINES 32-36

.. code-block:: Python

    df_iris = load_dataset("iris")
    X = ["sepal_length", "sepal_width", "petal_length"]  # feature columns
    y = "species"  # target column

.. GENERATED FROM PYTHON SOURCE LINES 37-38

Split the dataset into training and test sets.

.. GENERATED FROM PYTHON SOURCE LINES 38-42

.. code-block:: Python

    # Hold out 20% of the samples, keeping the class proportions equal.
    train_iris, test_iris = train_test_split(
        df_iris, test_size=0.2, stratify=df_iris[y], random_state=200
    )

.. GENERATED FROM PYTHON SOURCE LINES 43-46

We want to perform multiclass classification because the iris dataset
contains three species. We first z-score all the features and then train a
support vector machine classifier.

.. GENERATED FROM PYTHON SOURCE LINES 46-60

.. code-block:: Python

    # 5-fold cross-validation repeated 5 times (25 splits in total).
    cv = RepeatedKFold(n_splits=5, n_repeats=5, random_state=200)
    scores, model_iris = run_cross_validation(
        X=X,
        y=y,
        data=train_iris,
        model="svm",
        preprocess="zscore",
        problem_type="classification",
        cv=cv,
        scoring=["accuracy"],
        return_estimator="final",
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:53:54,086 - julearn - INFO - ==== Input Data ====
    2026-01-16 10:53:54,086 - julearn - INFO - Using dataframe as input
    2026-01-16 10:53:54,086 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:53:54,086 - julearn - INFO - Target: species
    2026-01-16 10:53:54,086 - julearn - INFO - Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:53:54,087 - julearn - INFO - X_types:{}
    2026-01-16 10:53:54,087 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
    /home/runner/work/julearn/julearn/julearn/prepare.py:576: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
      warn_with_log(
    2026-01-16 10:53:54,087 - julearn - INFO - ====================
    2026-01-16 10:53:54,087 - julearn - INFO -
    2026-01-16 10:53:54,088 - julearn - INFO - Adding step zscore that applies to ColumnTypes
    2026-01-16 10:53:54,088 - julearn - INFO - Step added
    2026-01-16 10:53:54,088 - julearn - INFO - Adding step svm that applies to ColumnTypes
    2026-01-16 10:53:54,088 - julearn - INFO - Step added
    2026-01-16 10:53:54,089 - julearn - INFO - = Model Parameters =
    2026-01-16 10:53:54,089 - julearn - INFO - ====================
    2026-01-16 10:53:54,089 - julearn - INFO -
    2026-01-16 10:53:54,089 - julearn - INFO - = Data Information =
    2026-01-16 10:53:54,089 - julearn - INFO - Problem type: classification
    2026-01-16 10:53:54,089 - julearn - INFO - Number of samples: 120
    2026-01-16 10:53:54,089 - julearn - INFO - Number of features: 3
    2026-01-16 10:53:54,089 - julearn - INFO - ====================
    2026-01-16 10:53:54,089 - julearn - INFO -
    2026-01-16 10:53:54,090 - julearn - INFO - Number of classes: 3
    2026-01-16 10:53:54,090 - julearn - INFO - Target type: object
    2026-01-16 10:53:54,090 - julearn - INFO - Class distributions: species
    versicolor    40
    virginica     40
    setosa        40
    Name: count, dtype: int64
    2026-01-16 10:53:54,091 - julearn - INFO - Using outer CV scheme RepeatedKFold(n_repeats=5, n_splits=5, random_state=200) (incl. final model)
    2026-01-16 10:53:54,091 - julearn - INFO - Multi-class classification problem detected #classes = 3.

.. GENERATED FROM PYTHON SOURCE LINES 61-62

The ``scores`` dataframe contains one row per CV split.

.. GENERATED FROM PYTHON SOURCE LINES 62-65

.. code-block:: Python

    scores.head()

.. raw:: html

    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;"><th></th><th>fit_time</th><th>score_time</th><th>test_accuracy</th><th>n_train</th><th>n_test</th><th>repeat</th><th>fold</th><th>cv_mdsum</th></tr>
      </thead>
      <tbody>
        <tr><th>0</th><td>0.004745</td><td>0.002918</td><td>0.916667</td><td>96</td><td>24</td><td>0</td><td>0</td><td>fa5ab7a2b930761687a8e82d9971ebca</td></tr>
        <tr><th>1</th><td>0.004689</td><td>0.002916</td><td>0.833333</td><td>96</td><td>24</td><td>0</td><td>1</td><td>fa5ab7a2b930761687a8e82d9971ebca</td></tr>
        <tr><th>2</th><td>0.004912</td><td>0.003066</td><td>0.958333</td><td>96</td><td>24</td><td>0</td><td>2</td><td>fa5ab7a2b930761687a8e82d9971ebca</td></tr>
        <tr><th>3</th><td>0.005040</td><td>0.003126</td><td>0.916667</td><td>96</td><td>24</td><td>0</td><td>3</td><td>fa5ab7a2b930761687a8e82d9971ebca</td></tr>
        <tr><th>4</th><td>0.004966</td><td>0.003024</td><td>0.833333</td><td>96</td><td>24</td><td>0</td><td>4</td><td>fa5ab7a2b930761687a8e82d9971ebca</td></tr>
      </tbody>
    </table>

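A single summary number is often easier to report than the per-split values.
The following is a minimal sketch of one way to summarise such a dataframe
with plain pandas. ``scores_demo`` is a hypothetical stand-in that copies only
the ``repeat``, ``fold`` and ``test_accuracy`` columns of the five rows shown
above; it is not produced by julearn.

.. code-block:: Python

    import pandas as pd

    # Hypothetical stand-in for ``scores``: the five rows shown above.
    scores_demo = pd.DataFrame(
        {
            "repeat": [0, 0, 0, 0, 0],
            "fold": [0, 1, 2, 3, 4],
            "test_accuracy": [0.916667, 0.833333, 0.958333, 0.916667, 0.833333],
        }
    )

    # Mean and standard deviation of the accuracy within each repeat.
    summary = scores_demo.groupby("repeat")["test_accuracy"].agg(["mean", "std"])

    # Mean accuracy over all rows in the stand-in dataframe.
    overall_mean = scores_demo["test_accuracy"].mean()

    print(summary)
    print(f"Overall mean accuracy: {overall_mean:.3f}")

The same two calls should apply unchanged to the full ``scores`` dataframe,
since they rely only on the ``repeat`` and ``test_accuracy`` columns.
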
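.. GENERATED FROM PYTHON SOURCE LINES 66-67

Now we can get the accuracy per fold and repetition:

.. GENERATED FROM PYTHON SOURCE LINES 67-73

.. code-block:: Python

    df_accuracy = scores.set_index(["repeat", "fold"])["test_accuracy"].unstack()
    df_accuracy.index.name = "Repeats"
    df_accuracy.columns.name = "K-fold splits"
    df_accuracy

.. raw:: html

    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;"><th>K-fold splits</th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th></tr>
        <tr><th>Repeats</th><th></th><th></th><th></th><th></th><th></th></tr>
      </thead>
      <tbody>
        <tr><th>0</th><td>0.916667</td><td>0.833333</td><td>0.958333</td><td>0.916667</td><td>0.833333</td></tr>
        <tr><th>1</th><td>0.875000</td><td>0.833333</td><td>0.916667</td><td>0.833333</td><td>0.833333</td></tr>
        <tr><th>2</th><td>0.750000</td><td>0.916667</td><td>0.916667</td><td>0.958333</td><td>0.916667</td></tr>
        <tr><th>3</th><td>1.000000</td><td>0.791667</td><td>0.875000</td><td>1.000000</td><td>0.791667</td></tr>
        <tr><th>4</th><td>0.875000</td><td>0.833333</td><td>0.875000</td><td>0.916667</td><td>0.958333</td></tr>
      </tbody>
    </table>
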
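The reshaping step can be easier to follow on a tiny input. The sketch below
applies the same ``set_index(...).unstack()`` pattern to a hypothetical
four-row frame, ``long_demo``, whose values are copied from the first two
folds of repeats 0 and 1 in the table above, and then averages across folds.

.. code-block:: Python

    import pandas as pd

    # Hypothetical long-format input: one row per (repeat, fold) pair.
    long_demo = pd.DataFrame(
        {
            "repeat": [0, 0, 1, 1],
            "fold": [0, 1, 0, 1],
            "test_accuracy": [0.916667, 0.833333, 0.875000, 0.833333],
        }
    )

    # Move "fold" from the row index into the columns: repeats x folds.
    wide_demo = long_demo.set_index(["repeat", "fold"])["test_accuracy"].unstack()

    # Average over the fold columns to get one accuracy per repeat.
    per_repeat_mean = wide_demo.mean(axis=1)

    print(wide_demo)
    print(per_repeat_mean)

Each row of ``wide_demo`` is one repeat and each column one fold, which is
the layout a heatmap of repeats against folds expects.
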
.. GENERATED FROM PYTHON SOURCE LINES 74-75

Plot a heatmap of the accuracy over all repeats and CV splits.

.. GENERATED FROM PYTHON SOURCE LINES 75-80

.. code-block:: Python

    sns.set(font_scale=1.2)
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    sns.heatmap(df_accuracy, cmap="YlGnBu")
    plt.title("Cross-validation Accuracy")

.. image-sg:: /auto_examples/00_starting/images/sphx_glr_plot_cm_acc_multiclass_001.png
   :alt: Cross-validation Accuracy
   :srcset: /auto_examples/00_starting/images/sphx_glr_plot_cm_acc_multiclass_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'Cross-validation Accuracy')

.. GENERATED FROM PYTHON SOURCE LINES 81-83

We can also test our final model's accuracy and plot the confusion matrix
for the test data as an annotated heatmap.

.. GENERATED FROM PYTHON SOURCE LINES 83-89

.. code-block:: Python

    y_true = test_iris[y]
    y_pred = model_iris.predict(test_iris[X])
    # Rows are the actual classes, columns the predicted classes.
    cm = confusion_matrix(y_true, y_pred, labels=np.unique(y_true))

    print(cm)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [[9 1 0]
     [0 9 1]
     [0 2 8]]

.. GENERATED FROM PYTHON SOURCE LINES 90-92

Now that we have our confusion matrix, let's build another matrix with
annotations.

.. GENERATED FROM PYTHON SOURCE LINES 92-106

.. code-block:: Python

    # Number of test samples per actual class, shape (n_classes, 1).
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    # Row-wise percentages: share of each actual class per predicted class.
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if c == 0:
                annot[i, j] = ""
            else:
                # Index the single column so that ``s`` is a scalar, not a
                # length-1 array (avoids a NumPy DeprecationWarning).
                s = cm_sum[i, 0]
                annot[i, j] = "%.1f%%\n%d/%d" % (p, c, s)

.. GENERATED FROM PYTHON SOURCE LINES 107-109

Finally, we create another dataframe with the confusion matrix and plot the
heatmap with annotations.

.. GENERATED FROM PYTHON SOURCE LINES 109-116

.. code-block:: Python

    cm = pd.DataFrame(cm, index=np.unique(y_true), columns=np.unique(y_true))
    cm.index.name = "Actual"
    cm.columns.name = "Predicted"
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    sns.heatmap(cm, cmap="YlGnBu", annot=annot, fmt="", ax=ax)
    plt.title("Confusion matrix")

.. image-sg:: /auto_examples/00_starting/images/sphx_glr_plot_cm_acc_multiclass_002.png
   :alt: Confusion matrix
   :srcset: /auto_examples/00_starting/images/sphx_glr_plot_cm_acc_multiclass_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'Confusion matrix')

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.530 seconds)

.. _sphx_glr_download_auto_examples_00_starting_plot_cm_acc_multiclass.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_cm_acc_multiclass.ipynb <plot_cm_acc_multiclass.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_cm_acc_multiclass.py <plot_cm_acc_multiclass.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_cm_acc_multiclass.zip <plot_cm_acc_multiclass.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_