Inspecting the fold-wise predictions

This example uses the iris dataset and performs a simple binary classification using a Support Vector Machine classifier.

We later inspect the predictions of the model for each fold.

# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL

from seaborn import load_dataset

from sklearn.model_selection import RepeatedStratifiedKFold, ShuffleSplit

from julearn import run_cross_validation
from julearn.pipeline import PipelineCreator
from julearn.utils import configure_logging

Set the logging level to info to see extra information.

configure_logging(level="INFO")
2024-05-16 08:52:29,327 - julearn - INFO - ===== Lib Versions =====
2024-05-16 08:52:29,327 - julearn - INFO - numpy: 1.26.4
2024-05-16 08:52:29,327 - julearn - INFO - scipy: 1.13.0
2024-05-16 08:52:29,327 - julearn - INFO - sklearn: 1.4.2
2024-05-16 08:52:29,327 - julearn - INFO - pandas: 2.1.4
2024-05-16 08:52:29,327 - julearn - INFO - julearn: 0.3.3
2024-05-16 08:52:29,327 - julearn - INFO - ========================
df_iris = load_dataset("iris")

The dataset contains three species. We will keep two of them to perform a binary classification.

df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]

As features, we will use the sepal length, sepal width, and petal length. The target to predict is the species.

X = ["sepal_length", "sepal_width", "petal_length"]
y = "species"
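# Note: X_types is defined here but not passed to run_cross_validation below,
# so julearn treats all columns as continuous by default (see the warning in the log).
X_types = {"continuous": X}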

creator = PipelineCreator(problem_type="classification")
creator.add("zscore")
creator.add("svm")
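
For orientation, the two steps added above correspond roughly to z-scoring the features and then fitting a support vector classifier. The sketch below builds a comparable pipeline directly in scikit-learn. It is an approximation and not a description of julearn's internals; the use of `load_iris` in place of the seaborn data frame and the names `X_arr`, `y_arr`, and `sk_pipeline` are assumptions made to keep the snippet self-contained.

```python
from sklearn.datasets import load_iris
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Load iris from scikit-learn (assumption: used here instead of seaborn's copy).
iris = load_iris()

# Keep versicolor (1) and virginica (2), and the first three features:
# sepal length, sepal width, petal length.
mask = iris.target != 0
X_arr = iris.data[mask][:, :3]
y_arr = iris.target[mask]

# z-score the features, then fit a support vector classifier with default settings.
sk_pipeline = make_pipeline(StandardScaler(), SVC())
sk_pipeline.fit(X_arr, y_arr)

print(X_arr.shape)
print(sk_pipeline.score(X_arr, y_arr))  # training accuracy, not a CV estimate
```

The score printed here is computed on the training data, so it is only a sanity check; the cross-validated estimate is what the rest of this example is about.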

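# Two candidate CV schemes; the second assignment overrides the first,
# so RepeatedStratifiedKFold is the one actually used.
cv = ShuffleSplit(n_splits=5, train_size=0.7, random_state=200)
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)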

scores, model, inspector = run_cross_validation(
    X=X,
    y=y,
    data=df_iris,
    model=creator,
    return_inspector=True,
    cv=cv,
)

print(scores)
2024-05-16 08:52:29,330 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2024-05-16 08:52:29,331 - julearn - INFO - Step added
2024-05-16 08:52:29,331 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2024-05-16 08:52:29,331 - julearn - INFO - Step added
2024-05-16 08:52:29,331 - julearn - INFO - Inspector requested: setting return_estimator='all'
2024-05-16 08:52:29,331 - julearn - INFO - ==== Input Data ====
2024-05-16 08:52:29,331 - julearn - INFO - Using dataframe as input
2024-05-16 08:52:29,331 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2024-05-16 08:52:29,331 - julearn - INFO -      Target: species
2024-05-16 08:52:29,331 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2024-05-16 08:52:29,331 - julearn - INFO -      X_types:{}
2024-05-16 08:52:29,331 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/prepare.py:505: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn_with_log(
2024-05-16 08:52:29,332 - julearn - INFO - ====================
2024-05-16 08:52:29,332 - julearn - INFO -
2024-05-16 08:52:29,333 - julearn - INFO - = Model Parameters =
2024-05-16 08:52:29,333 - julearn - INFO - ====================
2024-05-16 08:52:29,333 - julearn - INFO -
2024-05-16 08:52:29,333 - julearn - INFO - = Data Information =
2024-05-16 08:52:29,333 - julearn - INFO -      Problem type: classification
2024-05-16 08:52:29,333 - julearn - INFO -      Number of samples: 100
2024-05-16 08:52:29,333 - julearn - INFO -      Number of features: 3
2024-05-16 08:52:29,333 - julearn - INFO - ====================
2024-05-16 08:52:29,333 - julearn - INFO -
2024-05-16 08:52:29,333 - julearn - INFO -      Number of classes: 2
2024-05-16 08:52:29,333 - julearn - INFO -      Target type: object
2024-05-16 08:52:29,334 - julearn - INFO -      Class distributions: species
versicolor    50
virginica     50
Name: count, dtype: int64
2024-05-16 08:52:29,334 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=4, n_splits=5, random_state=200)
2024-05-16 08:52:29,334 - julearn - INFO - Binary classification problem detected.
2024-05-16 08:52:29,491 - julearn - INFO - Fitting final model
    fit_time  score_time  ... fold                          cv_mdsum
0   0.004957    0.002593  ...    0  42489ff0163b2f12752440a6b7ef74c7
1   0.004621    0.002486  ...    1  42489ff0163b2f12752440a6b7ef74c7
2   0.005030    0.002931  ...    2  42489ff0163b2f12752440a6b7ef74c7
3   0.004995    0.002652  ...    3  42489ff0163b2f12752440a6b7ef74c7
4   0.004777    0.002530  ...    4  42489ff0163b2f12752440a6b7ef74c7
5   0.004565    0.002483  ...    0  42489ff0163b2f12752440a6b7ef74c7
6   0.004510    0.002496  ...    1  42489ff0163b2f12752440a6b7ef74c7
7   0.004505    0.002453  ...    2  42489ff0163b2f12752440a6b7ef74c7
8   0.004471    0.002436  ...    3  42489ff0163b2f12752440a6b7ef74c7
9   0.004444    0.002454  ...    4  42489ff0163b2f12752440a6b7ef74c7
10  0.004508    0.002625  ...    0  42489ff0163b2f12752440a6b7ef74c7
11  0.004550    0.002476  ...    1  42489ff0163b2f12752440a6b7ef74c7
12  0.004458    0.002496  ...    2  42489ff0163b2f12752440a6b7ef74c7
13  0.004530    0.002470  ...    3  42489ff0163b2f12752440a6b7ef74c7
14  0.004498    0.002519  ...    4  42489ff0163b2f12752440a6b7ef74c7
15  0.004927    0.002717  ...    0  42489ff0163b2f12752440a6b7ef74c7
16  0.004647    0.002544  ...    1  42489ff0163b2f12752440a6b7ef74c7
17  0.004533    0.002465  ...    2  42489ff0163b2f12752440a6b7ef74c7
18  0.004511    0.002493  ...    3  42489ff0163b2f12752440a6b7ef74c7
19  0.004488    0.002538  ...    4  42489ff0163b2f12752440a6b7ef74c7

[20 rows x 9 columns]
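
The 20 rows in the scores table follow from the CV scheme: 5 splits repeated 4 times. As a minimal sketch using only scikit-learn, the snippet below counts the splits and checks how often each sample lands in a test set. The arrays `X_demo` and `y_demo` are synthetic stand-ins that only mimic the class balance of the data, not the real features.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# Synthetic stand-in for the target: 50 samples per class, as in the data above.
y_demo = np.array(["versicolor"] * 50 + ["virginica"] * 50)
X_demo = np.zeros((len(y_demo), 1))  # feature values are irrelevant for splitting

cv_demo = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)

n_folds = 0
test_counts = np.zeros(len(y_demo), dtype=int)
for train_idx, test_idx in cv_demo.split(X_demo, y_demo):
    n_folds += 1
    test_counts[test_idx] += 1

print(n_folds)                 # total number of train/test splits
print(np.unique(test_counts))  # how often each sample was in a test set
```

Every sample is held out exactly once per repeat, which is why fold-wise predictions can be collected into one column per repeat.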

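We can now inspect the model's out-of-fold predictions. Because the CV scheme repeats the 5-fold split four times, each sample is predicted once per repeat, which gives one prediction column per repeat next to the true target.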

cv_predictions = inspector.folds.predict()

print(cv_predictions)
    repeat0_p0  repeat1_p0  repeat2_p0  repeat3_p0      target
0   versicolor  versicolor  versicolor  versicolor  versicolor
1   versicolor  versicolor  versicolor  versicolor  versicolor
2   versicolor  versicolor  versicolor  versicolor  versicolor
3   versicolor  versicolor  versicolor  versicolor  versicolor
4   versicolor  versicolor  versicolor  versicolor  versicolor
..         ...         ...         ...         ...         ...
95   virginica   virginica   virginica   virginica   virginica
96   virginica   virginica   virginica   virginica   virginica
97   virginica   virginica   virginica   virginica   virginica
98   virginica   virginica   virginica   virginica   virginica
99   virginica   virginica   virginica   virginica   virginica

[100 rows x 5 columns]
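
A table in this layout makes it easy to compare predictions with the target per repeat. The sketch below computes the fraction of correct predictions for each repeat column. `demo_predictions` is a small hypothetical data frame that only imitates the column layout shown above; it is not the actual output of the inspector.

```python
import pandas as pd

# Hypothetical stand-in with the same column layout as the predictions above.
demo_predictions = pd.DataFrame(
    {
        "repeat0_p0": ["versicolor", "versicolor", "virginica", "virginica"],
        "repeat1_p0": ["versicolor", "virginica", "virginica", "virginica"],
        "target": ["versicolor", "versicolor", "virginica", "virginica"],
    }
)

# Compare each repeat's predictions with the target, row by row.
repeat_cols = [c for c in demo_predictions.columns if c.startswith("repeat")]
correct = demo_predictions[repeat_cols].eq(demo_predictions["target"], axis=0)

# Fraction of correctly predicted samples per repeat.
accuracy_per_repeat = correct.mean()
print(accuracy_per_repeat)
```

Assuming the real predictions keep the `repeat*` and `target` column names shown above, the same comparison should apply to them unchanged.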
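
Each fold can also be accessed individually. Its model attribute returns an inspector for the pipeline fitted on that fold.

inspector.folds[0].model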
<julearn.inspect._pipeline.PipelineInspector object at 0x7f9204f59ab0>

