Simple Model Comparison

This example uses the ‘iris’ dataset and performs binary classification with several models. It then compares the models' performance using different scoring functions and runs a statistical test to assess whether the differences in performance are significant.

# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
#
# License: AGPL

from seaborn import load_dataset
from sklearn.model_selection import RepeatedStratifiedKFold
from julearn import run_cross_validation
from julearn.utils import configure_logging
from julearn.stats.corrected_ttest import corrected_ttest

Set the logging level to INFO to see extra information.

configure_logging(level="INFO")
2023-07-19 12:41:49,307 - julearn - INFO - ===== Lib Versions =====
2023-07-19 12:41:49,308 - julearn - INFO - numpy: 1.25.1
2023-07-19 12:41:49,308 - julearn - INFO - scipy: 1.11.1
2023-07-19 12:41:49,308 - julearn - INFO - sklearn: 1.3.0
2023-07-19 12:41:49,308 - julearn - INFO - pandas: 2.0.3
2023-07-19 12:41:49,308 - julearn - INFO - julearn: 0.3.1.dev1
2023-07-19 12:41:49,308 - julearn - INFO - ========================
df_iris = load_dataset("iris")

The dataset contains three species. We will keep two of them to perform a binary classification.

df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]

As features, we will use the sepal length, sepal width, and petal length. We will try to predict the species.

X = ["sepal_length", "sepal_width", "petal_length"]
y = "species"
scores = run_cross_validation(
    X=X,
    y=y,
    data=df_iris,
    model="svm",
    problem_type="classification",
    preprocess="zscore",
)

print(scores["test_score"])
2023-07-19 12:41:49,311 - julearn - INFO - ==== Input Data ====
2023-07-19 12:41:49,311 - julearn - INFO - Using dataframe as input
2023-07-19 12:41:49,311 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:49,311 - julearn - INFO -      Target: species
2023-07-19 12:41:49,311 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:49,311 - julearn - INFO -      X_types:{}
2023-07-19 12:41:49,311 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/utils/logging.py:238: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn(msg, category=category)
2023-07-19 12:41:49,312 - julearn - INFO - ====================
2023-07-19 12:41:49,312 - julearn - INFO -
2023-07-19 12:41:49,312 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:49,313 - julearn - INFO - Step added
2023-07-19 12:41:49,313 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:49,313 - julearn - INFO - Step added
2023-07-19 12:41:49,313 - julearn - INFO - = Model Parameters =
2023-07-19 12:41:49,313 - julearn - INFO - ====================
2023-07-19 12:41:49,313 - julearn - INFO -
2023-07-19 12:41:49,314 - julearn - INFO - = Data Information =
2023-07-19 12:41:49,314 - julearn - INFO -      Problem type: classification
2023-07-19 12:41:49,314 - julearn - INFO -      Number of samples: 100
2023-07-19 12:41:49,314 - julearn - INFO -      Number of features: 3
2023-07-19 12:41:49,314 - julearn - INFO - ====================
2023-07-19 12:41:49,314 - julearn - INFO -
2023-07-19 12:41:49,314 - julearn - INFO -      Number of classes: 2
2023-07-19 12:41:49,314 - julearn - INFO -      Target type: object
2023-07-19 12:41:49,315 - julearn - INFO -      Class distributions: species
versicolor    50
virginica     50
Name: count, dtype: int64
2023-07-19 12:41:49,315 - julearn - INFO - Using outer CV scheme KFold(n_splits=5, random_state=None, shuffle=False)
2023-07-19 12:41:49,315 - julearn - INFO - Binary classification problem detected.
0    0.90
1    0.75
2    0.95
3    0.70
4    0.90
Name: test_score, dtype: float64

Additionally, we can choose to assess the performance of the model using different scoring functions.

For example, we might have an unbalanced dataset:

df_unbalanced = df_iris[20:]  # drop the first 20 versicolor samples
print(df_unbalanced["species"].value_counts())
species
virginica     50
versicolor    30
Name: count, dtype: int64

So we will choose to use the balanced_accuracy and roc_auc metrics.

scoring = ["balanced_accuracy", "roc_auc"]
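To illustrate why plain accuracy can mislead on this 50/30 split, here is a minimal sketch in pure Python. The predictions are hypothetical (a trivial classifier that always answers "virginica"), and the two helper functions are written for illustration only; they are not part of julearn or scikit-learn.

```python
# Hypothetical example: same class counts as df_unbalanced (50 vs. 30).
y_true = ["virginica"] * 50 + ["versicolor"] * 30
# A trivial classifier that always predicts the majority class.
y_pred = ["virginica"] * 80


def accuracy(y_true, y_pred):
    """Fraction of samples predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Mean of the per-class recalls, so each class counts equally."""
    recalls = []
    for label in sorted(set(y_true)):
        idx = [i for i, t in enumerate(y_true) if t == label]
        recalls.append(sum(y_pred[i] == label for i in idx) / len(idx))
    return sum(recalls) / len(recalls)


acc = accuracy(y_true, y_pred)
bal_acc = balanced_accuracy(y_true, y_pred)
print(acc, bal_acc)
```

The trivial classifier looks better than chance under accuracy, while balanced accuracy places it exactly at chance level for two classes.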

Since we are comparing the performance of different models, every model must be evaluated on the same data splits. We therefore fix the random seed of the cross-validation scheme.

cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=42)
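As a quick check of this idea, the sketch below splits a toy label vector twice with the same seeded scheme and verifies that both passes produce identical test folds. The toy arrays (`X_toy`, `y_toy`) and the name `cv_demo` are assumptions made for illustration; only the class balance mirrors the example.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# Toy labels with the same 50/30 class balance as df_unbalanced.
y_toy = np.array([0] * 50 + [1] * 30)
# Feature values do not influence how the folds are drawn.
X_toy = np.zeros((len(y_toy), 1))

cv_demo = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=42)

# Generate the test indices twice from the same seeded splitter.
first = [test for _, test in cv_demo.split(X_toy, y_toy)]
second = [test for _, test in cv_demo.split(X_toy, y_toy)]

n_folds = len(first)  # 5 splits x 5 repeats
same_splits = all(np.array_equal(a, b) for a, b in zip(first, second))
print(n_folds, same_splits)
```

Because the folds match, the score at position i of one model can be paired with the score at position i of another model, which is what a paired test relies on.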

First we will use a default SVM model.

scores1 = run_cross_validation(
    X=X,
    y=y,
    data=df_unbalanced,
    model="svm",
    preprocess="zscore",
    problem_type="classification",
    scoring=scoring,
    cv=cv,
)

scores1["model"] = "svm"
2023-07-19 12:41:49,365 - julearn - INFO - ==== Input Data ====
2023-07-19 12:41:49,365 - julearn - INFO - Using dataframe as input
2023-07-19 12:41:49,365 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:49,365 - julearn - INFO -      Target: species
2023-07-19 12:41:49,365 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:49,365 - julearn - INFO -      X_types:{}
2023-07-19 12:41:49,365 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/utils/logging.py:238: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn(msg, category=category)
2023-07-19 12:41:49,366 - julearn - INFO - ====================
2023-07-19 12:41:49,366 - julearn - INFO -
2023-07-19 12:41:49,366 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:49,366 - julearn - INFO - Step added
2023-07-19 12:41:49,366 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:49,367 - julearn - INFO - Step added
2023-07-19 12:41:49,367 - julearn - INFO - = Model Parameters =
2023-07-19 12:41:49,367 - julearn - INFO - ====================
2023-07-19 12:41:49,367 - julearn - INFO -
2023-07-19 12:41:49,367 - julearn - INFO - = Data Information =
2023-07-19 12:41:49,367 - julearn - INFO -      Problem type: classification
2023-07-19 12:41:49,367 - julearn - INFO -      Number of samples: 80
2023-07-19 12:41:49,367 - julearn - INFO -      Number of features: 3
2023-07-19 12:41:49,367 - julearn - INFO - ====================
2023-07-19 12:41:49,368 - julearn - INFO -
2023-07-19 12:41:49,368 - julearn - INFO -      Number of classes: 2
2023-07-19 12:41:49,368 - julearn - INFO -      Target type: object
2023-07-19 12:41:49,368 - julearn - INFO -      Class distributions: species
virginica     50
versicolor    30
Name: count, dtype: int64
2023-07-19 12:41:49,369 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=5, n_splits=5, random_state=42)
2023-07-19 12:41:49,369 - julearn - INFO - Binary classification problem detected.

Second we will use a default Random Forest model.

scores2 = run_cross_validation(
    X=X,
    y=y,
    data=df_unbalanced,
    model="rf",
    preprocess="zscore",
    problem_type="classification",
    scoring=scoring,
    cv=cv,
)

scores2["model"] = "rf"
2023-07-19 12:41:49,715 - julearn - INFO - ==== Input Data ====
2023-07-19 12:41:49,715 - julearn - INFO - Using dataframe as input
2023-07-19 12:41:49,715 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:49,715 - julearn - INFO -      Target: species
2023-07-19 12:41:49,715 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:49,715 - julearn - INFO -      X_types:{}
2023-07-19 12:41:49,715 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/utils/logging.py:238: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn(msg, category=category)
2023-07-19 12:41:49,716 - julearn - INFO - ====================
2023-07-19 12:41:49,716 - julearn - INFO -
2023-07-19 12:41:49,716 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:49,716 - julearn - INFO - Step added
2023-07-19 12:41:49,716 - julearn - INFO - Adding step rf that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:49,716 - julearn - INFO - Step added
2023-07-19 12:41:49,717 - julearn - INFO - = Model Parameters =
2023-07-19 12:41:49,717 - julearn - INFO - ====================
2023-07-19 12:41:49,717 - julearn - INFO -
2023-07-19 12:41:49,717 - julearn - INFO - = Data Information =
2023-07-19 12:41:49,717 - julearn - INFO -      Problem type: classification
2023-07-19 12:41:49,717 - julearn - INFO -      Number of samples: 80
2023-07-19 12:41:49,717 - julearn - INFO -      Number of features: 3
2023-07-19 12:41:49,717 - julearn - INFO - ====================
2023-07-19 12:41:49,717 - julearn - INFO -
2023-07-19 12:41:49,717 - julearn - INFO -      Number of classes: 2
2023-07-19 12:41:49,717 - julearn - INFO -      Target type: object
2023-07-19 12:41:49,718 - julearn - INFO -      Class distributions: species
virginica     50
versicolor    30
Name: count, dtype: int64
2023-07-19 12:41:49,718 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=5, n_splits=5, random_state=42)
2023-07-19 12:41:49,718 - julearn - INFO - Binary classification problem detected.

The third model will be an SVM with a linear kernel.

scores3 = run_cross_validation(
    X=X,
    y=y,
    data=df_unbalanced,
    model="svm",
    model_params={"svm__kernel": "linear"},
    preprocess="zscore",
    problem_type="classification",
    scoring=scoring,
    cv=cv,
)

scores3["model"] = "svm_linear"
2023-07-19 12:41:52,782 - julearn - INFO - ==== Input Data ====
2023-07-19 12:41:52,782 - julearn - INFO - Using dataframe as input
2023-07-19 12:41:52,783 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:52,783 - julearn - INFO -      Target: species
2023-07-19 12:41:52,783 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:52,783 - julearn - INFO -      X_types:{}
2023-07-19 12:41:52,783 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/utils/logging.py:238: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn(msg, category=category)
2023-07-19 12:41:52,784 - julearn - INFO - ====================
2023-07-19 12:41:52,784 - julearn - INFO -
2023-07-19 12:41:52,784 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:52,784 - julearn - INFO - Step added
2023-07-19 12:41:52,784 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:52,784 - julearn - INFO - Setting hyperparameter kernel = linear
2023-07-19 12:41:52,784 - julearn - INFO - Step added
2023-07-19 12:41:52,785 - julearn - INFO - = Model Parameters =
2023-07-19 12:41:52,785 - julearn - INFO - ====================
2023-07-19 12:41:52,785 - julearn - INFO -
2023-07-19 12:41:52,785 - julearn - INFO - = Data Information =
2023-07-19 12:41:52,785 - julearn - INFO -      Problem type: classification
2023-07-19 12:41:52,785 - julearn - INFO -      Number of samples: 80
2023-07-19 12:41:52,785 - julearn - INFO -      Number of features: 3
2023-07-19 12:41:52,785 - julearn - INFO - ====================
2023-07-19 12:41:52,785 - julearn - INFO -
2023-07-19 12:41:52,785 - julearn - INFO -      Number of classes: 2
2023-07-19 12:41:52,785 - julearn - INFO -      Target type: object
2023-07-19 12:41:52,786 - julearn - INFO -      Class distributions: species
virginica     50
versicolor    30
Name: count, dtype: int64
2023-07-19 12:41:52,786 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=5, n_splits=5, random_state=42)
2023-07-19 12:41:52,786 - julearn - INFO - Binary classification problem detected.

We can now compare the performance of the models using corrected statistics.

stats_df = corrected_ttest(scores1, scores2, scores3)
print(stats_df)
                   metric    t-stat  ...     model_2 p-val-corrected
0  test_balanced_accuracy -0.175075  ...          rf        1.000000
2  test_balanced_accuracy -1.062567  ...  svm_linear        0.895662
4  test_balanced_accuracy -1.151390  ...  svm_linear        0.782741
1            test_roc_auc  1.108944  ...          rf        0.835331
3            test_roc_auc -1.236153  ...  svm_linear        0.685092
5            test_roc_auc -1.669010  ...  svm_linear        0.324331

[6 rows x 6 columns]
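For intuition about what "corrected" means here, the following is a minimal sketch of the variance correction commonly attributed to Nadeau and Bengio for cross-validated scores. It is an assumption that this matches the spirit of `corrected_ttest`; the function names and the score values below are hypothetical, and julearn's implementation (including its p-value correction for multiple comparisons) may differ in detail.

```python
import math
import statistics


def naive_t_statistic(scores_a, scores_b):
    """Standard paired t statistic, ignoring the overlap between folds."""
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    n = len(diffs)
    return statistics.mean(diffs) / math.sqrt(statistics.variance(diffs) / n)


def corrected_t_statistic(scores_a, scores_b, n_train, n_test):
    """Paired t statistic with an inflated variance term (sketch)."""
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    n = len(diffs)
    var_diff = statistics.variance(diffs)  # sample variance (n - 1 denominator)
    # The n_test / n_train term accounts for training sets that overlap
    # across folds, which makes the fold scores non-independent.
    corrected_var = (1 / n + n_test / n_train) * var_diff
    return statistics.mean(diffs) / math.sqrt(corrected_var)


# Hypothetical per-fold scores for two models evaluated on the same folds.
scores_a = [0.90, 0.85, 0.95, 0.80, 0.90]
scores_b = [0.85, 0.80, 0.90, 0.85, 0.80]

# With 80 samples and 5 folds, each fold trains on 64 and tests on 16.
t_naive = naive_t_statistic(scores_a, scores_b)
t_corrected = corrected_t_statistic(scores_a, scores_b, n_train=64, n_test=16)
print(t_naive, t_corrected)
```

The corrected statistic is always smaller in magnitude than the naive one, so the test is more conservative about declaring one model better than another.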

We can also plot the performance of the models using the Julearn Score Viewer.

from julearn.viz import plot_scores
panel = plot_scores(scores1, scores2, scores3)
# panel.show()
# uncomment the previous line to show the plot
# read the documentation for more information
#  https://panel.holoviz.org/getting_started/build_app.html#deploying-panels
WARNING:param.Metric: Param pane was given unknown keyword argument(s) for 'metric' parameter with a widget of type <class 'panel.widgets.select.Select'>. The following keyword arguments could not be applied: 'button_type'.

This is what the plot looks like.

Note

The plot is interactive. You can zoom in and out, and hover over the data points. However, the buttons will not work in this documentation.

Total running time of the script: ( 0 minutes 3.968 seconds)
