Inspecting the fold-wise predictions

This example uses the iris dataset and performs a simple binary classification using a Support Vector Machine classifier.

We later inspect the predictions of the model for each fold.

# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL

from seaborn import load_dataset

from sklearn.model_selection import RepeatedStratifiedKFold, ShuffleSplit

from julearn import run_cross_validation
from julearn.pipeline import PipelineCreator
from julearn.utils import configure_logging

Set the logging level to info to see extra information.

configure_logging(level="INFO")
2026-05-29 20:45:58,115 - julearn - INFO - ===== Lib Versions =====
2026-05-29 20:45:58,115 - julearn - INFO - numpy: 1.26.4
2026-05-29 20:45:58,115 - julearn - INFO - scipy: 1.17.1
2026-05-29 20:45:58,115 - julearn - INFO - sklearn: 1.4.2
2026-05-29 20:45:58,115 - julearn - INFO - pandas: 2.1.4
2026-05-29 20:45:58,115 - julearn - INFO - julearn: 0.3.2
2026-05-29 20:45:58,116 - julearn - INFO - ========================

Load the iris dataset.

df_iris = load_dataset("iris")

The dataset contains three species. We keep only two of them (versicolor and virginica) to obtain a binary classification problem.

df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]

As features, we use the sepal length, the sepal width and the petal length. The target to predict is the species.

X = ["sepal_length", "sepal_width", "petal_length"]
y = "species"
X_types = {"continuous": X}

# Pipeline: z-score the features, then fit a support vector machine.
creator = PipelineCreator(problem_type="classification")
creator.add("zscore")
creator.add("svm")

# Stratified 5-fold CV, repeated 4 times with different shuffles.
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)
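
RepeatedStratifiedKFold is what makes the fold-wise predictions inspected later well defined. Within each repeat, every sample lands in exactly one test fold, so each repeat yields exactly one prediction per sample. The following minimal check uses plain scikit-learn and a hypothetical toy target (y_toy, with 50 + 50 labels matching the class distribution logged further down). Features are irrelevant for splitting, so X_toy is a dummy array.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# Hypothetical toy target with the same 50/50 class balance as the iris subset.
y_toy = np.array(["versicolor"] * 50 + ["virginica"] * 50)
X_toy = np.zeros((len(y_toy), 1))  # dummy features; only y matters for stratification

n_splits, n_repeats = 5, 4
cv_demo = RepeatedStratifiedKFold(
    n_splits=n_splits, n_repeats=n_repeats, random_state=200
)

# Count how often each sample appears in a test fold within each repeat.
test_counts = np.zeros((n_repeats, len(y_toy)), dtype=int)
fold_class_counts = []
for i, (_, test_idx) in enumerate(cv_demo.split(X_toy, y_toy)):
    repeat = i // n_splits  # splits are yielded one repeat after another
    test_counts[repeat, test_idx] += 1
    # Stratification: number of test samples of each class in this fold.
    fold_class_counts.append(np.unique(y_toy[test_idx], return_counts=True)[1])

print(test_counts.min(), test_counts.max())  # → 1 1
```

Each test fold here also contains 10 samples of each class, which is the stratification property.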

scores, model, inspector = run_cross_validation(
    X=X,
    y=y,
    X_types=X_types,
    data=df_iris,
    model=creator,
    return_inspector=True,  # also return an inspector for fold-wise analysis
    cv=cv,
)

print(scores)
2026-05-29 20:45:58,120 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2026-05-29 20:45:58,121 - julearn - INFO - Step added
2026-05-29 20:45:58,121 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2026-05-29 20:45:58,121 - julearn - INFO - Step added
2026-05-29 20:45:58,122 - julearn - INFO - Inspector requested: setting return_estimator='all'
2026-05-29 20:45:58,122 - julearn - INFO - ==== Input Data ====
2026-05-29 20:45:58,122 - julearn - INFO - Using dataframe as input
2026-05-29 20:45:58,122 - julearn - INFO -      Features: ['sepal_length', 'sepal_width', 'petal_length']
2026-05-29 20:45:58,123 - julearn - INFO -      Target: species
2026-05-29 20:45:58,123 - julearn - INFO -      Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2026-05-29 20:45:58,123 - julearn - INFO -      X_types:{}
2026-05-29 20:45:58,124 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
2026-05-29 20:45:58,125 - julearn - INFO - ====================
2026-05-29 20:45:58,125 - julearn - INFO -
2026-05-29 20:45:58,126 - julearn - INFO - = Model Parameters =
2026-05-29 20:45:58,126 - julearn - INFO - ====================
2026-05-29 20:45:58,126 - julearn - INFO -
2026-05-29 20:45:58,126 - julearn - INFO - = Data Information =
2026-05-29 20:45:58,127 - julearn - INFO -      Problem type: classification
2026-05-29 20:45:58,127 - julearn - INFO -      Number of samples: 100
2026-05-29 20:45:58,127 - julearn - INFO -      Number of features: 3
2026-05-29 20:45:58,127 - julearn - INFO - ====================
2026-05-29 20:45:58,127 - julearn - INFO -
2026-05-29 20:45:58,127 - julearn - INFO -      Number of classes: 2
2026-05-29 20:45:58,128 - julearn - INFO -      Target type: object
2026-05-29 20:45:58,128 - julearn - INFO -      Class distributions: species
versicolor    50
virginica     50
Name: count, dtype: int64
2026-05-29 20:45:58,129 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=4, n_splits=5, random_state=200)
2026-05-29 20:45:58,129 - julearn - INFO - Binary classification problem detected.
2026-05-29 20:45:58,342 - julearn - INFO - Fitting final model
    fit_time  score_time                                          estimator  test_score  n_train  n_test  repeat  fold                          cv_mdsum
0   0.007778    0.004338  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       0     0  42489ff0163b2f12752440a6b7ef74c7
1   0.006071    0.002680  (SetColumnTypes(X_types={}), StandardScaler(),...        1.00       80      20       0     1  42489ff0163b2f12752440a6b7ef74c7
2   0.005316    0.002563  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       0     2  42489ff0163b2f12752440a6b7ef74c7
3   0.006554    0.004018  (SetColumnTypes(X_types={}), StandardScaler(),...        0.95       80      20       0     3  42489ff0163b2f12752440a6b7ef74c7
4   0.006655    0.003521  (SetColumnTypes(X_types={}), StandardScaler(),...        0.85       80      20       0     4  42489ff0163b2f12752440a6b7ef74c7
5   0.006926    0.002734  (SetColumnTypes(X_types={}), StandardScaler(),...        0.80       80      20       1     0  42489ff0163b2f12752440a6b7ef74c7
6   0.005601    0.002641  (SetColumnTypes(X_types={}), StandardScaler(),...        0.95       80      20       1     1  42489ff0163b2f12752440a6b7ef74c7
7   0.007343    0.003350  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       1     2  42489ff0163b2f12752440a6b7ef74c7
8   0.005713    0.004182  (SetColumnTypes(X_types={}), StandardScaler(),...        0.95       80      20       1     3  42489ff0163b2f12752440a6b7ef74c7
9   0.007025    0.003404  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       1     4  42489ff0163b2f12752440a6b7ef74c7
10  0.006288    0.002528  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       2     0  42489ff0163b2f12752440a6b7ef74c7
11  0.010092    0.004568  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       2     1  42489ff0163b2f12752440a6b7ef74c7
12  0.006411    0.003786  (SetColumnTypes(X_types={}), StandardScaler(),...        0.85       80      20       2     2  42489ff0163b2f12752440a6b7ef74c7
13  0.005721    0.002578  (SetColumnTypes(X_types={}), StandardScaler(),...        1.00       80      20       2     3  42489ff0163b2f12752440a6b7ef74c7
14  0.006826    0.003406  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       2     4  42489ff0163b2f12752440a6b7ef74c7
15  0.004510    0.002623  (SetColumnTypes(X_types={}), StandardScaler(),...        0.85       80      20       3     0  42489ff0163b2f12752440a6b7ef74c7
16  0.006739    0.003168  (SetColumnTypes(X_types={}), StandardScaler(),...        0.95       80      20       3     1  42489ff0163b2f12752440a6b7ef74c7
17  0.006409    0.003312  (SetColumnTypes(X_types={}), StandardScaler(),...        0.90       80      20       3     2  42489ff0163b2f12752440a6b7ef74c7
18  0.004705    0.002569  (SetColumnTypes(X_types={}), StandardScaler(),...        1.00       80      20       3     3  42489ff0163b2f12752440a6b7ef74c7
19  0.005267    0.003956  (SetColumnTypes(X_types={}), StandardScaler(),...        0.85       80      20       3     4  42489ff0163b2f12752440a6b7ef74c7
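
The scores table has one row per fold and repeat, so a per-repeat summary is a single groupby. The sketch below copies the repeat, fold and test_score values printed above into a new DataFrame (scores_df, a name chosen here) so that it runs on its own. With the scores object returned by run_cross_validation, the same groupby call applies directly.

```python
import pandas as pd

# repeat, fold and test_score values copied from the table printed above.
scores_df = pd.DataFrame(
    {
        "repeat": [0] * 5 + [1] * 5 + [2] * 5 + [3] * 5,
        "fold": list(range(5)) * 4,
        "test_score": [
            0.90, 1.00, 0.90, 0.95, 0.85,
            0.80, 0.95, 0.90, 0.95, 0.90,
            0.90, 0.90, 0.85, 1.00, 0.90,
            0.85, 0.95, 0.90, 1.00, 0.85,
        ],
    }
)

# Mean test score of each repeat, then over all 20 folds.
per_repeat = scores_df.groupby("repeat")["test_score"].mean()
print(per_repeat.round(3))
print(round(scores_df["test_score"].mean(), 3))
```

For these values, the per-repeat means are 0.92, 0.90, 0.91 and 0.91, and the overall mean is 0.91.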

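We can now inspect the predictions of the model for each fold. inspector.folds.predict() returns one row per sample. For each repeat, a column holds the prediction made while that sample was in the test fold, and a final column holds the true target.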

cv_predictions = inspector.folds.predict()

print(cv_predictions)
    repeat0_p0  repeat1_p0  repeat2_p0  repeat3_p0      target
0   versicolor  versicolor  versicolor  versicolor  versicolor
1   versicolor  versicolor  versicolor  versicolor  versicolor
2   versicolor  versicolor  versicolor  versicolor  versicolor
3   versicolor  versicolor  versicolor  versicolor  versicolor
4   versicolor  versicolor  versicolor  versicolor  versicolor
..         ...         ...         ...         ...         ...
95   virginica   virginica   virginica   virginica   virginica
96   virginica   virginica   virginica   virginica   virginica
97   virginica   virginica   virginica   virginica   virginica
98   virginica   virginica   virginica   virginica   virginica
99   virginica   virginica   virginica   virginica   virginica

[100 rows x 5 columns]
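
To make the structure of cv_predictions concrete, the sketch below rebuilds a table with the same column layout using plain scikit-learn. It then derives the accuracy of each repeat and lists the samples misclassified in at least one repeat. It assumes that the "zscore" and "svm" steps correspond to scikit-learn's StandardScaler and SVC with default parameters. It also assumes that scikit-learn's bundled iris data matches the seaborn copy for these two species. This is not julearn's implementation of inspector.folds.predict(), and its numbers are not guaranteed to reproduce the output above.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Keep versicolor and virginica (classes 1 and 2); the first three columns are
# sepal length, sepal width and petal length.
iris = load_iris()
mask = iris.target != 0
X = iris.data[mask, :3]
y = iris.target_names[iris.target[mask]]

n_splits, n_repeats = 5, 4
cv = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=200)

# Fill one row of out-of-fold predictions per repeat.
preds = np.empty((n_repeats, len(y)), dtype=object)
for i, (train_idx, test_idx) in enumerate(cv.split(X, y)):
    model = make_pipeline(StandardScaler(), SVC())  # assumed equivalent pipeline
    model.fit(X[train_idx], y[train_idx])
    preds[i // n_splits, test_idx] = model.predict(X[test_idx])

# Same column layout as the inspector output: repeatN_p0 columns plus target.
cv_predictions = pd.DataFrame({f"repeat{r}_p0": preds[r] for r in range(n_repeats)})
cv_predictions["target"] = y

# Per-repeat accuracy, and samples misclassified in at least one repeat.
pred_cols = [c for c in cv_predictions.columns if c.startswith("repeat")]
correct = cv_predictions[pred_cols].eq(cv_predictions["target"], axis=0)
print(correct.mean())
print(cv_predictions[~correct.all(axis=1)])
```

Every test fold holds 20 samples, so each per-repeat accuracy in correct.mean() equals the mean of that repeat's five fold accuracies.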
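Each fold can also be inspected individually. For example, the following returns an inspector for the pipeline fitted on the first fold:

inspector.folds[0].model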
<julearn.inspect._pipeline.PipelineInspector object at 0x124067a90>

Total running time of the script: (0 minutes 0.354 seconds)
