Inspecting the fold-wise predictions

This example uses the ‘iris’ dataset and performs a simple binary classification using a Support Vector Machine classifier.

We then inspect the model's predictions for each fold.

# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
#
# License: AGPL
from seaborn import load_dataset

from sklearn.model_selection import RepeatedStratifiedKFold, ShuffleSplit

from julearn import run_cross_validation
from julearn.pipeline import PipelineCreator
from julearn.utils import configure_logging

Set the logging level to INFO to see extra information.

configure_logging(level="INFO")
2023-07-19 12:41:53,396 - julearn - INFO - ===== Lib Versions =====
2023-07-19 12:41:53,396 - julearn - INFO - numpy: 1.25.1
2023-07-19 12:41:53,396 - julearn - INFO - scipy: 1.11.1
2023-07-19 12:41:53,396 - julearn - INFO - sklearn: 1.3.0
2023-07-19 12:41:53,396 - julearn - INFO - pandas: 2.0.3
2023-07-19 12:41:53,396 - julearn - INFO - julearn: 0.3.1.dev1
2023-07-19 12:41:53,396 - julearn - INFO - ========================
df_iris = load_dataset("iris")

The dataset contains three species. We keep two of them to perform a binary classification.

df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]

As features, we will use the sepal length, sepal width, and petal length. We will try to predict the species.

X = ["sepal_length", "sepal_width", "petal_length"]
y = "species"
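# Declare all features as continuous. X_types only takes effect when it is
# passed to run_cross_validation; it is not passed below, so julearn falls back
# to treating the undeclared columns as continuous (see the warning in the log).
X_types = {"continuous": X}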

creator = PipelineCreator(problem_type="classification")
creator.add("zscore")
creator.add("svm")

# 5-fold stratified CV repeated 4 times (20 train/test splits in total)
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)

scores, model, inspector = run_cross_validation(
    X=X,
    y=y,
    data=df_iris,
    model=creator,
    return_inspector=True,
    cv=cv,
)

print(scores)
2023-07-19 12:41:53,400 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:53,400 - julearn - INFO - Step added
2023-07-19 12:41:53,400 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2023-07-19 12:41:53,400 - julearn - INFO - Step added
2023-07-19 12:41:53,400 - julearn - INFO - Inspector requested: setting return_estimator='all'
2023-07-19 12:41:53,400 - julearn - INFO - ==== Input Data ====
2023-07-19 12:41:53,400 - julearn - INFO - Using dataframe as input
2023-07-19 12:41:53,400 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:53,401 - julearn - INFO -      Target: species
2023-07-19 12:41:53,401 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2023-07-19 12:41:53,401 - julearn - INFO -      X_types:{}
2023-07-19 12:41:53,401 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/utils/logging.py:238: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
  warn(msg, category=category)
2023-07-19 12:41:53,402 - julearn - INFO - ====================
2023-07-19 12:41:53,402 - julearn - INFO -
2023-07-19 12:41:53,402 - julearn - INFO - = Model Parameters =
2023-07-19 12:41:53,402 - julearn - INFO - ====================
2023-07-19 12:41:53,402 - julearn - INFO -
2023-07-19 12:41:53,402 - julearn - INFO - = Data Information =
2023-07-19 12:41:53,403 - julearn - INFO -      Problem type: classification
2023-07-19 12:41:53,403 - julearn - INFO -      Number of samples: 100
2023-07-19 12:41:53,403 - julearn - INFO -      Number of features: 3
2023-07-19 12:41:53,403 - julearn - INFO - ====================
2023-07-19 12:41:53,403 - julearn - INFO -
2023-07-19 12:41:53,403 - julearn - INFO -      Number of classes: 2
2023-07-19 12:41:53,403 - julearn - INFO -      Target type: object
2023-07-19 12:41:53,404 - julearn - INFO -      Class distributions: species
versicolor    50
virginica     50
Name: count, dtype: int64
2023-07-19 12:41:53,404 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=4, n_splits=5, random_state=200)
2023-07-19 12:41:53,404 - julearn - INFO - Binary classification problem detected.
    fit_time  score_time  ... fold                          cv_mdsum
0   0.005888    0.003176  ...    0  42489ff0163b2f12752440a6b7ef74c7
1   0.005551    0.003150  ...    1  42489ff0163b2f12752440a6b7ef74c7
2   0.005497    0.003109  ...    2  42489ff0163b2f12752440a6b7ef74c7
3   0.005456    0.003121  ...    3  42489ff0163b2f12752440a6b7ef74c7
4   0.005478    0.003097  ...    4  42489ff0163b2f12752440a6b7ef74c7
5   0.005495    0.003125  ...    0  42489ff0163b2f12752440a6b7ef74c7
6   0.005445    0.003120  ...    1  42489ff0163b2f12752440a6b7ef74c7
7   0.005491    0.003095  ...    2  42489ff0163b2f12752440a6b7ef74c7
8   0.005523    0.003121  ...    3  42489ff0163b2f12752440a6b7ef74c7
9   0.005480    0.003109  ...    4  42489ff0163b2f12752440a6b7ef74c7
10  0.005518    0.003127  ...    0  42489ff0163b2f12752440a6b7ef74c7
11  0.005485    0.003061  ...    1  42489ff0163b2f12752440a6b7ef74c7
12  0.005431    0.003091  ...    2  42489ff0163b2f12752440a6b7ef74c7
13  0.005466    0.003090  ...    3  42489ff0163b2f12752440a6b7ef74c7
14  0.005455    0.003095  ...    4  42489ff0163b2f12752440a6b7ef74c7
15  0.005512    0.003107  ...    0  42489ff0163b2f12752440a6b7ef74c7
16  0.005486    0.003141  ...    1  42489ff0163b2f12752440a6b7ef74c7
17  0.005490    0.003097  ...    2  42489ff0163b2f12752440a6b7ef74c7
18  0.005507    0.003081  ...    3  42489ff0163b2f12752440a6b7ef74c7
19  0.005481    0.003107  ...    4  42489ff0163b2f12752440a6b7ef74c7

[20 rows x 9 columns]
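
The 20 rows above follow from the cross-validation scheme: 5 folds repeated 4 times. The sketch below uses scikit-learn directly to show how many splits are produced and that each sample is held out exactly once per repeat. The names `X_demo`, `y_demo`, and `cv_demo` are illustrative, and the synthetic balanced target is an assumption standing in for the 100 filtered iris samples.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# Synthetic stand-in for the filtered iris data: 100 samples, 2 balanced classes.
y_demo = np.array(["versicolor"] * 50 + ["virginica"] * 50)
X_demo = np.zeros((len(y_demo), 3))  # feature values do not affect the splits

cv_demo = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)

# Total number of train/test splits: n_splits * n_repeats.
n_total_splits = cv_demo.get_n_splits(X_demo, y_demo)

# Count how often each sample lands in a test set across all splits.
test_counts = np.zeros(len(y_demo), dtype=int)
test_sizes = []
for train_idx, test_idx in cv_demo.split(X_demo, y_demo):
    test_counts[test_idx] += 1
    test_sizes.append(len(test_idx))

print(n_total_splits)          # number of splits (rows in the scores table)
print(set(test_sizes))         # distinct test-fold sizes
print(np.unique(test_counts))  # how many times each sample was held out
```

Because every sample is tested once per repeat, the fold-wise predictions can be arranged as one column per repeat.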

We can now inspect the predictions of the model for each fold.

cv_predictions = inspector.folds.predict()

print(cv_predictions)
    repeat0_p0  repeat1_p0  repeat2_p0  repeat3_p0      target
0   versicolor  versicolor  versicolor  versicolor  versicolor
1   versicolor  versicolor  versicolor  versicolor  versicolor
2   versicolor  versicolor  versicolor  versicolor  versicolor
3   versicolor  versicolor  versicolor  versicolor  versicolor
4   versicolor  versicolor  versicolor  versicolor  versicolor
..         ...         ...         ...         ...         ...
95   virginica   virginica   virginica   virginica   virginica
96   virginica   virginica   virginica   virginica   virginica
97   virginica   virginica   virginica   virginica   virginica
98   virginica   virginica   virginica   virginica   virginica
99   virginica   virginica   virginica   virginica   virginica

[100 rows x 5 columns]
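
Since the predictions table holds one column per repeat plus the true target, per-repeat accuracy can be computed by comparing each prediction column with `target`. The sketch below does this on a small hand-made frame that mimics the column layout shown above. The frame `toy_predictions` and its values are invented for illustration and are not the output of the model.

```python
import pandas as pd

# Toy frame mimicking the layout of cv_predictions (values are made up).
toy_predictions = pd.DataFrame(
    {
        "repeat0_p0": ["versicolor", "versicolor", "virginica", "virginica"],
        "repeat1_p0": ["versicolor", "virginica", "virginica", "virginica"],
        "target": ["versicolor", "versicolor", "virginica", "virginica"],
    }
)

# Every column except the target holds the predictions of one repeat.
prediction_columns = [c for c in toy_predictions.columns if c != "target"]

# Fraction of samples per repeat whose prediction matches the target.
accuracy_per_repeat = (
    toy_predictions[prediction_columns]
    .eq(toy_predictions["target"], axis=0)
    .mean()
)

# Samples that were misclassified in at least one repeat.
misclassified = toy_predictions[
    toy_predictions[prediction_columns]
    .ne(toy_predictions["target"], axis=0)
    .any(axis=1)
]

print(accuracy_per_repeat)
print(misclassified)
```

Applied to the real output, the same comparison on `cv_predictions` should give one accuracy value per repeat, assuming the column layout matches the table printed above.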
inspector.folds[0].model
<julearn.inspect._pipeline.PipelineInspector object at 0x7f7f629e72b0>
