Inspecting the fold-wise predictions
This example uses the iris dataset to perform a simple binary classification with a Support Vector Machine (SVM) classifier. We then inspect the model's predictions for each cross-validation fold.
# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
# License: AGPL
from seaborn import load_dataset
from sklearn.model_selection import RepeatedStratifiedKFold
from julearn import run_cross_validation
from julearn.pipeline import PipelineCreator
from julearn.utils import configure_logging
Set the logging level to info to see extra information.
configure_logging(level="INFO")
2024-10-23 11:29:14,160 - julearn - INFO - ===== Lib Versions =====
2024-10-23 11:29:14,160 - julearn - INFO - numpy: 1.26.4
2024-10-23 11:29:14,160 - julearn - INFO - scipy: 1.14.1
2024-10-23 11:29:14,160 - julearn - INFO - sklearn: 1.5.2
2024-10-23 11:29:14,160 - julearn - INFO - pandas: 2.2.3
2024-10-23 11:29:14,160 - julearn - INFO - julearn: 0.3.5.dev16
2024-10-23 11:29:14,160 - julearn - INFO - ========================
df_iris = load_dataset("iris")
The dataset has three species. We keep two of them, versicolor and virginica, to perform a binary classification.
df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]
As features, we use the sepal length, sepal width, and petal length. The target to predict is the species.
X = ["sepal_length", "sepal_width", "petal_length"]
y = "species"
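# X_types is not passed to run_cross_validation below, so julearn warns and
# treats these columns as continuous by default (see the log output).
X_types = {"continuous": X}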
creator = PipelineCreator(problem_type="classification")
creator.add("zscore")
creator.add("svm")
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=200)
scores, model, inspector = run_cross_validation(
X=X,
y=y,
data=df_iris,
model=creator,
return_inspector=True,
cv=cv,
)
print(scores)
2024-10-23 11:29:14,162 - julearn - INFO - Adding step zscore that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2024-10-23 11:29:14,163 - julearn - INFO - Step added
2024-10-23 11:29:14,163 - julearn - INFO - Adding step svm that applies to ColumnTypes<types={'continuous'}; pattern=(?:__:type:__continuous)>
2024-10-23 11:29:14,163 - julearn - INFO - Step added
2024-10-23 11:29:14,163 - julearn - INFO - ==== Input Data ====
2024-10-23 11:29:14,163 - julearn - INFO - Using dataframe as input
2024-10-23 11:29:14,163 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
2024-10-23 11:29:14,163 - julearn - INFO - Target: species
2024-10-23 11:29:14,163 - julearn - INFO - Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
2024-10-23 11:29:14,163 - julearn - INFO - X_types:{}
2024-10-23 11:29:14,163 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
/home/runner/work/julearn/julearn/julearn/prepare.py:509: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
warn_with_log(
2024-10-23 11:29:14,164 - julearn - INFO - ====================
2024-10-23 11:29:14,164 - julearn - INFO -
2024-10-23 11:29:14,164 - julearn - INFO - = Model Parameters =
2024-10-23 11:29:14,165 - julearn - INFO - ====================
2024-10-23 11:29:14,165 - julearn - INFO -
2024-10-23 11:29:14,165 - julearn - INFO - = Data Information =
2024-10-23 11:29:14,165 - julearn - INFO - Problem type: classification
2024-10-23 11:29:14,165 - julearn - INFO - Number of samples: 100
2024-10-23 11:29:14,165 - julearn - INFO - Number of features: 3
2024-10-23 11:29:14,165 - julearn - INFO - ====================
2024-10-23 11:29:14,165 - julearn - INFO -
2024-10-23 11:29:14,165 - julearn - INFO - Number of classes: 2
2024-10-23 11:29:14,165 - julearn - INFO - Target type: object
2024-10-23 11:29:14,166 - julearn - INFO - Class distributions: species
versicolor 50
virginica 50
Name: count, dtype: int64
2024-10-23 11:29:14,166 - julearn - INFO - Using outer CV scheme RepeatedStratifiedKFold(n_repeats=4, n_splits=5, random_state=200) (incl. final model)
2024-10-23 11:29:14,166 - julearn - INFO - Binary classification problem detected.
fit_time score_time ... fold cv_mdsum
0 0.004322 0.002422 ... 0 42489ff0163b2f12752440a6b7ef74c7
1 0.004309 0.002388 ... 1 42489ff0163b2f12752440a6b7ef74c7
2 0.004302 0.002392 ... 2 42489ff0163b2f12752440a6b7ef74c7
3 0.004300 0.002419 ... 3 42489ff0163b2f12752440a6b7ef74c7
4 0.004380 0.002409 ... 4 42489ff0163b2f12752440a6b7ef74c7
5 0.004330 0.002444 ... 0 42489ff0163b2f12752440a6b7ef74c7
6 0.004276 0.002380 ... 1 42489ff0163b2f12752440a6b7ef74c7
7 0.004290 0.002391 ... 2 42489ff0163b2f12752440a6b7ef74c7
8 0.004298 0.002639 ... 3 42489ff0163b2f12752440a6b7ef74c7
9 0.004912 0.002429 ... 4 42489ff0163b2f12752440a6b7ef74c7
10 0.004339 0.002437 ... 0 42489ff0163b2f12752440a6b7ef74c7
11 0.004325 0.002551 ... 1 42489ff0163b2f12752440a6b7ef74c7
12 0.004305 0.002441 ... 2 42489ff0163b2f12752440a6b7ef74c7
13 0.004268 0.002380 ... 3 42489ff0163b2f12752440a6b7ef74c7
14 0.004270 0.002407 ... 4 42489ff0163b2f12752440a6b7ef74c7
15 0.004259 0.002405 ... 0 42489ff0163b2f12752440a6b7ef74c7
16 0.004250 0.002406 ... 1 42489ff0163b2f12752440a6b7ef74c7
17 0.004339 0.002381 ... 2 42489ff0163b2f12752440a6b7ef74c7
18 0.004298 0.002383 ... 3 42489ff0163b2f12752440a6b7ef74c7
19 0.004283 0.002396 ... 4 42489ff0163b2f12752440a6b7ef74c7
[20 rows x 9 columns]
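The scores table has 20 rows because `RepeatedStratifiedKFold` produces 5 folds for each of 4 repeats. The rows are listed repeat by repeat, which is why the `fold` column restarts at 0 every five rows. The `...` in the printout hides several other columns.

The stdlib-only sketch below illustrates this idea under a few assumptions:

- It uses a simplified stratified k-fold: each class is shuffled separately and dealt round-robin into folds. This is not scikit-learn's exact algorithm.
- The helper names are hypothetical.
- Seeding it with 200 does not reproduce the splits of `random_state=200`.

```python
import random
from collections import defaultdict


def stratified_kfold_assign(labels, n_splits, seed):
    """Hypothetical helper: give each sample a test-fold index.

    Simplified stratification: shuffle the indices of each class separately
    and deal them round-robin into n_splits folds, so every fold gets a
    similar share of each class.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    fold_of = [None] * len(labels)
    for indices in by_class.values():
        rng.shuffle(indices)
        for position, idx in enumerate(indices):
            fold_of[idx] = position % n_splits
    return fold_of


def repeated_stratified_kfold(labels, n_splits, n_repeats, seed):
    """Hypothetical helper: yield (repeat, fold, test_indices), repeat by repeat."""
    for repeat in range(n_repeats):
        # A different shuffle per repeat gives a different partition.
        fold_of = stratified_kfold_assign(labels, n_splits, seed + repeat)
        for fold in range(n_splits):
            test = [i for i, f in enumerate(fold_of) if f == fold]
            yield repeat, fold, test


# Toy labels mirroring the filtered iris data: 50 samples per class.
labels = ["versicolor"] * 50 + ["virginica"] * 50
splits = list(repeated_stratified_kfold(labels, n_splits=5, n_repeats=4, seed=200))

# Count how often each sample is used for testing within each repeat.
test_counts = defaultdict(int)
for repeat, fold, test in splits:
    for idx in test:
        test_counts[(repeat, idx)] += 1

print(len(splits))               # 20 splits (5 folds x 4 repeats)
print(set(test_counts.values()))  # {1}: each sample is tested once per repeat
print(splits[7][:2])             # (1, 2): row i maps to divmod(i, 5)
```

In the real output, each of the 100 samples is therefore predicted exactly once per repeat. That is what `inspector.folds.predict()` collects below.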
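We can now inspect the predictions of the model for each fold. Every sample falls into the test set exactly once per repetition. The resulting table therefore has one prediction column per repeat (four here), plus the true target.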
cv_predictions = inspector.folds.predict()
print(cv_predictions)
repeat0_p0 repeat1_p0 repeat2_p0 repeat3_p0 target
0 versicolor versicolor versicolor versicolor versicolor
1 versicolor versicolor versicolor versicolor versicolor
2 versicolor versicolor versicolor versicolor versicolor
3 versicolor versicolor versicolor versicolor versicolor
4 versicolor versicolor versicolor versicolor versicolor
.. ... ... ... ... ...
95 virginica virginica virginica virginica virginica
96 virginica virginica virginica virginica virginica
97 virginica virginica virginica virginica virginica
98 virginica virginica virginica virginica virginica
99 virginica virginica virginica virginica virginica
[100 rows x 5 columns]
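With the fold-wise predictions you can compute the accuracy of each repeat. You can also spot samples whose prediction changes from one repeat to the next, which means they depend on how the data was split.

The sketch below does this in plain Python. It uses hypothetical toy rows shaped like `cv_predictions`: the values are made up for illustration and are not taken from the output above.

```python
# Hypothetical toy rows shaped like cv_predictions: one prediction column per
# repeat ("repeat{r}_p0") plus the true "target". Values are made up.
rows = [
    {"repeat0_p0": "versicolor", "repeat1_p0": "versicolor", "target": "versicolor"},
    {"repeat0_p0": "virginica", "repeat1_p0": "versicolor", "target": "versicolor"},
    {"repeat0_p0": "virginica", "repeat1_p0": "virginica", "target": "virginica"},
    {"repeat0_p0": "virginica", "repeat1_p0": "virginica", "target": "virginica"},
]
pred_cols = [c for c in rows[0] if c.startswith("repeat")]

# Accuracy per repeat: fraction of samples whose prediction matches the target.
accuracy = {
    col: sum(r[col] == r["target"] for r in rows) / len(rows) for col in pred_cols
}

# Indices of samples whose predicted label differs between repeats.
unstable = [i for i, r in enumerate(rows) if len({r[c] for c in pred_cols}) > 1]

print(accuracy)  # {'repeat0_p0': 0.75, 'repeat1_p0': 1.0}
print(unstable)  # [1]
```

On the actual pandas DataFrame, the per-repeat accuracy should be obtainable in one line, for example:

`cv_predictions.filter(like="repeat").eq(cv_predictions["target"], axis=0).mean()`

This assumes the prediction columns keep the `repeat` prefix shown above.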
We can also inspect the model fitted in an individual fold, for example the first one:
inspector.folds[0].model
<julearn.inspect._pipeline.PipelineInspector object at 0x7f90ad8f5060>
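The output is a `PipelineInspector` for the first fold. Each fold's pipeline is fitted only on that fold's training data, so learned parameters such as the z-scoring means and standard deviations differ between folds.

The plain-Python sketch below illustrates this. It assumes the `zscore` step behaves like scikit-learn's `StandardScaler`, which uses the population standard deviation. The helper names and numbers are hypothetical.

```python
from statistics import fmean, pstdev


def fit_zscore(train_values):
    """Hypothetical helper: learn mean and population std from training data only."""
    return fmean(train_values), pstdev(train_values)


def apply_zscore(values, mean, std):
    """Hypothetical helper: standardize values with training-split parameters."""
    return [(v - mean) / std for v in values]


# Hypothetical sepal lengths from the training splits of two different folds.
fold_a_train = [5.0, 6.0, 7.0]
fold_b_train = [6.0, 6.5, 7.0]
test_value = [6.5]

for name, train in (("fold A", fold_a_train), ("fold B", fold_b_train)):
    mean, std = fit_zscore(train)
    # The same test value is mapped differently depending on the fold.
    print(name, round(mean, 3), round(std, 3), apply_zscore(test_value, mean, std))
```

Fitting the scaling only on the training split keeps test-set statistics out of the model. This is also why fold-specific models can disagree on borderline samples.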
Total running time of the script: (0 minutes 0.234 seconds)