.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basic/plot_cm_acc_multiclass.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_basic_plot_cm_acc_multiclass.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basic_plot_cm_acc_multiclass.py:


Multiclass Classification
============================

This example uses the 'iris' dataset to perform multiclass classification
with a Support Vector Machine classifier. It plots the cross-validation
accuracies and the confusion matrix for the test data as heatmaps.

.. GENERATED FROM PYTHON SOURCE LINES 11-27

.. code-block:: default

    # Authors: Shammi More
    #          Federico Raimondo
    #
    # License: AGPL

    import pandas as pd
    import seaborn as sns
    import numpy as np
    import matplotlib.pyplot as plt
    from seaborn import load_dataset
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import confusion_matrix

    from julearn import run_cross_validation
    from julearn.utils import configure_logging

.. GENERATED FROM PYTHON SOURCE LINES 28-29

Set the logging level to info to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 29-35

.. code-block:: default

    configure_logging(level='INFO')

    df_iris = load_dataset('iris')
    X = ['sepal_length', 'sepal_width', 'petal_length']
    y = 'species'

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2021-01-28 20:07:58,068 - julearn - INFO - ===== Lib Versions =====
    2021-01-28 20:07:58,068 - julearn - INFO - numpy: 1.19.5
    2021-01-28 20:07:58,068 - julearn - INFO - scipy: 1.6.0
    2021-01-28 20:07:58,068 - julearn - INFO - sklearn: 0.24.1
    2021-01-28 20:07:58,068 - julearn - INFO - pandas: 1.2.1
    2021-01-28 20:07:58,068 - julearn - INFO - julearn: 0.2.5.dev19+g9c15c5f
    2021-01-28 20:07:58,068 - julearn - INFO - ========================

..
.. GENERATED FROM PYTHON SOURCE LINES 36-37

Split the dataset into train and test sets.

.. GENERATED FROM PYTHON SOURCE LINES 37-40

.. code-block:: default

    train_iris, test_iris = train_test_split(df_iris, test_size=0.2,
                                             stratify=df_iris[y])

.. GENERATED FROM PYTHON SOURCE LINES 41-42

Perform multiclass classification, since the iris dataset contains three
species.

.. GENERATED FROM PYTHON SOURCE LINES 42-47

.. code-block:: default

    scores, model_iris = run_cross_validation(
        X=X, y=y, data=train_iris, model='svm', preprocess_X='zscore',
        problem_type='multiclass_classification', scoring=['accuracy'],
        return_estimator='final')

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2021-01-28 20:07:58,073 - julearn - INFO - Using default CV
    2021-01-28 20:07:58,073 - julearn - INFO - ==== Input Data ====
    2021-01-28 20:07:58,073 - julearn - INFO - Using dataframe as input
    2021-01-28 20:07:58,073 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2021-01-28 20:07:58,073 - julearn - INFO - Target: species
    2021-01-28 20:07:58,073 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2021-01-28 20:07:58,073 - julearn - INFO - Expanded Confounds: []
    2021-01-28 20:07:58,074 - julearn - INFO - ====================
    2021-01-28 20:07:58,074 - julearn - INFO -
    2021-01-28 20:07:58,074 - julearn - INFO - ====== Model ======
    2021-01-28 20:07:58,074 - julearn - INFO - Obtaining model by name: svm
    2021-01-28 20:07:58,074 - julearn - INFO - ===================
    2021-01-28 20:07:58,074 - julearn - INFO -
    2021-01-28 20:07:58,074 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds

.. GENERATED FROM PYTHON SOURCE LINES 48-49

The scores dataframe has all the values for each CV split.

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: default

    print(scores.head())

.. rst-class:: sphx-glr-script-out

 Out:

 ..
 .. code-block:: none

       fit_time  score_time  test_accuracy  repeat  fold
    0  0.012808    0.007851       0.875000       0     0
    1  0.011702    0.007467       0.916667       0     1
    2  0.011454    0.007438       0.875000       0     2
    3  0.011553    0.007664       0.875000       0     3
    4  0.011490    0.007525       1.000000       0     4

.. GENERATED FROM PYTHON SOURCE LINES 53-54

Now we can get the accuracy per fold and repetition:

.. GENERATED FROM PYTHON SOURCE LINES 54-61

.. code-block:: default

    df_accuracy = scores.set_index(
        ['repeat', 'fold'])['test_accuracy'].unstack()
    df_accuracy.index.name = 'Repeats'
    df_accuracy.columns.name = 'K-fold splits'

    print(df_accuracy)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    K-fold splits         0         1         2         3         4
    Repeats
    0              0.875000  0.916667  0.875000  0.875000  1.000000
    1              0.958333  0.916667  0.833333  0.875000  0.958333
    2              0.958333  0.958333  0.916667  0.916667  0.916667
    3              0.833333  0.916667  0.875000  0.958333  0.958333
    4              1.000000  0.875000  0.833333  0.916667  0.875000

.. GENERATED FROM PYTHON SOURCE LINES 62-63

Plot a heatmap of the accuracy over all repeats and CV splits.

.. GENERATED FROM PYTHON SOURCE LINES 63-68

.. code-block:: default

    sns.set(font_scale=1.2)
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    sns.heatmap(df_accuracy, cmap="YlGnBu")
    plt.title('Cross-validation Accuracy')

.. image:: /auto_examples/basic/images/sphx_glr_plot_cm_acc_multiclass_001.png
    :alt: Cross-validation Accuracy
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Text(0.5, 1.0, 'Cross-validation Accuracy')

.. GENERATED FROM PYTHON SOURCE LINES 69-71

We can also test the final model on the held-out test data and plot the
confusion matrix for the test data as an annotated heatmap.

.. GENERATED FROM PYTHON SOURCE LINES 71-76

.. code-block:: default

    y_true = test_iris[y]
    y_pred = model_iris.predict(test_iris[X])
    cm = confusion_matrix(y_true, y_pred, labels=np.unique(y_true))

    print(cm)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [[10  0  0]
     [ 0 10  0]
     [ 0  1  9]]

..
.. GENERATED FROM PYTHON SOURCE LINES 77-79

Now that we have our confusion matrix, let's build another matrix with
annotations.

.. GENERATED FROM PYTHON SOURCE LINES 79-92

.. code-block:: default

    # Row totals: the number of test samples per actual class
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    # Each cell as a percentage of its row total
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if c == 0:
                # Leave empty cells without annotation
                annot[i, j] = ''
            else:
                # Index the scalar row total so that '%d' receives a number
                s = cm_sum[i, 0]
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)

.. GENERATED FROM PYTHON SOURCE LINES 93-95

Finally, we create a dataframe from the confusion matrix and plot the
heatmap with annotations.

.. GENERATED FROM PYTHON SOURCE LINES 95-101

.. code-block:: default

    cm = pd.DataFrame(cm, index=np.unique(y_true), columns=np.unique(y_true))
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    sns.heatmap(cm, cmap="YlGnBu", annot=annot, fmt='', ax=ax)
    plt.title('Confusion matrix')

.. image:: /auto_examples/basic/images/sphx_glr_plot_cm_acc_multiclass_002.png
    :alt: Confusion matrix
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Text(0.5, 1.0, 'Confusion matrix')

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  1.002 seconds)


.. _sphx_glr_download_auto_examples_basic_plot_cm_acc_multiclass.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_cm_acc_multiclass.py <plot_cm_acc_multiclass.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_cm_acc_multiclass.ipynb <plot_cm_acc_multiclass.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
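
The example mentions the final model's test accuracy but only prints the
confusion matrix. The accuracy can be read off that matrix directly. The
following is a minimal sketch, not part of the original example: it assumes
the matrix values shown in the output above (they will differ for another
random train/test split) and uses plain Python lists so that it runs without
any dependency.

.. code-block:: python

    # Confusion matrix copied from the printed output above (an assumption:
    # a different random split gives different counts). Rows are the actual
    # classes, columns the predicted classes.
    cm = [[10, 0, 0],
          [0, 10, 0],
          [0, 1, 9]]

    n_classes = len(cm)
    total = sum(sum(row) for row in cm)                 # all test samples
    correct = sum(cm[i][i] for i in range(n_classes))   # diagonal entries

    # Overall accuracy: correctly classified samples over all samples.
    # With a NumPy array this is np.trace(cm) / cm.sum().
    accuracy = correct / total

    # Per-class recall: diagonal entry over its row total. These are the
    # fractions behind the percentages annotated on the heatmap diagonal.
    recall = [cm[i][i] / sum(cm[i]) for i in range(n_classes)]

    print(f"Test accuracy: {accuracy:.3f} ({correct}/{total})")
    for i, r in enumerate(recall):
        print(f"Class {i} recall: {r:.1%}")

The same numbers could also be obtained with
``sklearn.metrics.accuracy_score(y_true, y_pred)``; the manual version is
shown only to make the link to the confusion matrix explicit.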