.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/04_confounds/plot_confound_removal_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_04_confounds_plot_confound_removal_classification.py:


Confound Removal (model comparison)
===================================

This example uses the ``iris`` dataset to perform a simple binary
classification with a Random Forest classifier, both with and without
confound removal, and compares the two models.

.. GENERATED FROM PYTHON SOURCE LINES 9-24

.. code-block:: Python


    # Authors: Shammi More
    #          Federico Raimondo
    #          Leonard Sasse
    # License: AGPL

    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
    from seaborn import load_dataset

    from julearn import run_cross_validation
    from julearn.model_selection import StratifiedBootstrap
    from julearn.pipeline import PipelineCreator
    from julearn.utils import configure_logging

.. GENERATED FROM PYTHON SOURCE LINES 25-26

Set the logging level to info to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 26-28

.. code-block:: Python

    configure_logging(level="INFO")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:54:20,626 - julearn - INFO - ===== Lib Versions =====
    2026-01-16 10:54:20,626 - julearn - INFO - numpy: 1.26.4
    2026-01-16 10:54:20,626 - julearn - INFO - scipy: 1.17.0
    2026-01-16 10:54:20,626 - julearn - INFO - sklearn: 1.7.2
    2026-01-16 10:54:20,626 - julearn - INFO - pandas: 2.3.3
    2026-01-16 10:54:20,627 - julearn - INFO - julearn: 0.3.5.dev123
    2026-01-16 10:54:20,627 - julearn - INFO - ========================

.. GENERATED FROM PYTHON SOURCE LINES 29-30

Load the iris data from seaborn.

.. GENERATED FROM PYTHON SOURCE LINES 30-32

.. code-block:: Python

    df_iris = load_dataset("iris")

.. GENERATED FROM PYTHON SOURCE LINES 33-35

The dataset contains three species. We keep two of them to perform a binary
classification.

.. GENERATED FROM PYTHON SOURCE LINES 35-37

.. code-block:: Python

    df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]

.. GENERATED FROM PYTHON SOURCE LINES 38-40

As features, we use the sepal length, sepal width and petal length. The
petal width is used as the confound.

.. GENERATED FROM PYTHON SOURCE LINES 40-45

.. code-block:: Python

    X = ["sepal_length", "sepal_width", "petal_length"]
    y = "species"
    confounds = ["petal_width"]

.. GENERATED FROM PYTHON SOURCE LINES 46-59

Hypothesis testing in machine learning is not straightforward. With
classical frequentist statistics, cross-validation poses a problem: the
samples are not independent across folds, and the population
(train + test) is always the same.

To compare two models, an alternative is to contrast, for each fold, the
performance gap between the models. Combined with bootstrapping, this gives
a distribution of differences from which a confidence interval (CI) can be
computed. If the 95% CI lies entirely above 0 (or entirely below 0), we can
claim that the models differ with p < 0.05.

Let's use a bootstrap CV. In the interest of time, we only run 20
iterations; for a valid test, increase the number of bootstrap iterations
to at least 2000.

.. GENERATED FROM PYTHON SOURCE LINES 59-63

.. code-block:: Python

    n_bootstrap = 20
    n_elements = len(df_iris)
    cv = StratifiedBootstrap(n_splits=n_bootstrap, random_state=42)

.. GENERATED FROM PYTHON SOURCE LINES 64-66

First, we train a model without performing confound removal on the
features. Note: no confound removal is applied by default.

.. GENERATED FROM PYTHON SOURCE LINES 66-79

.. code-block:: Python

    scores_ncr = run_cross_validation(
        X=X,
        y=y,
        data=df_iris,
        model="rf",
        cv=cv,
        problem_type="classification",
        preprocess="zscore",
        scoring=["accuracy", "roc_auc"],
        return_estimator="cv",
        seed=200,
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:54:20,629 - julearn - INFO - Setting random seed to 200
    2026-01-16 10:54:20,630 - julearn - INFO - ==== Input Data ====
    2026-01-16 10:54:20,630 - julearn - INFO - Using dataframe as input
    2026-01-16 10:54:20,630 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:54:20,630 - julearn - INFO - Target: species
    2026-01-16 10:54:20,630 - julearn - INFO - Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:54:20,630 - julearn - INFO - X_types:{}
    2026-01-16 10:54:20,630 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
    /home/runner/work/julearn/julearn/julearn/prepare.py:576: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
      warn_with_log(
    2026-01-16 10:54:20,631 - julearn - INFO - ====================
    2026-01-16 10:54:20,631 - julearn - INFO -
    2026-01-16 10:54:20,631 - julearn - INFO - Adding step zscore that applies to ColumnTypes
    2026-01-16 10:54:20,631 - julearn - INFO - Step added
    2026-01-16 10:54:20,631 - julearn - INFO - Adding step rf that applies to ColumnTypes
    2026-01-16 10:54:20,632 - julearn - INFO - Step added
    2026-01-16 10:54:20,632 - julearn - INFO - = Model Parameters =
    2026-01-16 10:54:20,632 - julearn - INFO - ====================
    2026-01-16 10:54:20,632 - julearn - INFO -
    2026-01-16 10:54:20,633 - julearn - INFO - = Data Information =
    2026-01-16 10:54:20,633 - julearn - INFO - Problem type: classification
    2026-01-16 10:54:20,633 - julearn - INFO - Number of samples: 100
    2026-01-16 10:54:20,633 - julearn - INFO - Number of features: 3
    2026-01-16 10:54:20,633 - julearn - INFO - ====================
    2026-01-16 10:54:20,633 - julearn - INFO -
    2026-01-16 10:54:20,633 - julearn - INFO - Number of classes: 2
    2026-01-16 10:54:20,633 - julearn - INFO - Target type: object
    2026-01-16 10:54:20,634 - julearn - INFO - Class distributions: species
    versicolor    50
    virginica     50
    Name: count, dtype: int64
    2026-01-16 10:54:20,634 - julearn - INFO - Using outer CV scheme StratifiedBootstrap(n_splits=20, random_state=42)
    2026-01-16 10:54:20,634 - julearn - INFO - Binary classification problem detected.

.. GENERATED FROM PYTHON SOURCE LINES 80-82

Next, we train a model after performing confound removal on the features.
Note: we create the CV object again, with the same ``random_state``, so
that both models are evaluated on the same folds.

.. GENERATED FROM PYTHON SOURCE LINES 82-88

.. code-block:: Python

    cv = StratifiedBootstrap(n_splits=n_bootstrap, random_state=42)

    # In order to tell ``run_cross_validation`` which columns are confounds,
    # and which columns are features, we have to define the X_types:
    X_types = {"features": X, "confound": confounds}

.. GENERATED FROM PYTHON SOURCE LINES 89-101

We can now define a pipeline creator and add a confound removal step.
The pipeline creator is configured to apply all steps to the "features"
type by default.

The first step z-scores both the features and the confounds.

The second step removes the confounds (type "confound") from the
"features".

Finally, a random forest is trained. Given the default ``apply_to`` of the
pipeline creator, the random forest is trained using only the "features".

.. GENERATED FROM PYTHON SOURCE LINES 101-118

.. code-block:: Python

    creator = PipelineCreator(problem_type="classification", apply_to="features")
    creator.add("zscore", apply_to=["features", "confound"])
    creator.add("confound_removal", apply_to="features", confounds="confound")
    creator.add("rf")

    scores_cr = run_cross_validation(
        X=X + confounds,
        y=y,
        data=df_iris,
        model=creator,
        cv=cv,
        X_types=X_types,
        scoring=["accuracy", "roc_auc"],
        return_estimator="cv",
        seed=200,
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:54:23,179 - julearn - INFO - Adding step zscore that applies to ColumnTypes
    2026-01-16 10:54:23,179 - julearn - INFO - Step added
    2026-01-16 10:54:23,179 - julearn - INFO - Adding step confound_removal that applies to ColumnTypes
    2026-01-16 10:54:23,179 - julearn - INFO - Setting hyperparameter confounds = confound
    2026-01-16 10:54:23,179 - julearn - INFO - Step added
    2026-01-16 10:54:23,180 - julearn - INFO - Adding step rf that applies to ColumnTypes
    2026-01-16 10:54:23,180 - julearn - INFO - Step added
    2026-01-16 10:54:23,180 - julearn - INFO - Setting random seed to 200
    2026-01-16 10:54:23,180 - julearn - INFO - ==== Input Data ====
    2026-01-16 10:54:23,180 - julearn - INFO - Using dataframe as input
    2026-01-16 10:54:23,180 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    2026-01-16 10:54:23,180 - julearn - INFO - Target: species
    2026-01-16 10:54:23,180 - julearn - INFO - Expanded features: ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    2026-01-16 10:54:23,181 - julearn - INFO - X_types:{'features': ['sepal_length', 'sepal_width', 'petal_length'], 'confound': ['petal_width']}
    2026-01-16 10:54:23,181 - julearn - INFO - ====================
    2026-01-16 10:54:23,181 - julearn - INFO -
    2026-01-16 10:54:23,183 - julearn - INFO - = Model Parameters =
    2026-01-16 10:54:23,183 - julearn - INFO - ====================
    2026-01-16 10:54:23,183 - julearn - INFO -
    2026-01-16 10:54:23,183 - julearn - INFO - = Data Information =
    2026-01-16 10:54:23,183 - julearn - INFO - Problem type: classification
    2026-01-16 10:54:23,183 - julearn - INFO - Number of samples: 100
    2026-01-16 10:54:23,184 - julearn - INFO - Number of features: 4
    2026-01-16 10:54:23,184 - julearn - INFO - ====================
    2026-01-16 10:54:23,184 - julearn - INFO -
    2026-01-16 10:54:23,184 - julearn - INFO - Number of classes: 2
    2026-01-16 10:54:23,184 - julearn - INFO - Target type: object
    2026-01-16 10:54:23,185 - julearn - INFO - Class distributions: species
    versicolor    50
    virginica     50
    Name: count, dtype: int64
    2026-01-16 10:54:23,185 - julearn - INFO - Using outer CV scheme StratifiedBootstrap(n_splits=20, random_state=42)
    2026-01-16 10:54:23,185 - julearn - INFO - Binary classification problem detected.

.. GENERATED FROM PYTHON SOURCE LINES 119-121

Now we can compare the scores. First, we label the two outputs so that they
can be combined into a single ``pandas.DataFrame``.

.. GENERATED FROM PYTHON SOURCE LINES 121-124

.. code-block:: Python

    scores_ncr["confounds"] = "Not Removed"
    scores_cr["confounds"] = "Removed"

.. GENERATED FROM PYTHON SOURCE LINES 125-127

Next, we stack the metrics into a single column (long format), which makes
plotting with seaborn easier.

.. GENERATED FROM PYTHON SOURCE LINES 127-144

.. code-block:: Python

    index = ["fold", "confounds"]
    scorings = ["test_accuracy", "test_roc_auc"]

    df_ncr_metrics = scores_ncr.set_index(index)[scorings].stack()
    df_ncr_metrics.index.names = ["fold", "confounds", "metric"]
    df_ncr_metrics.name = "value"

    df_cr_metrics = scores_cr.set_index(index)[scorings].stack()
    df_cr_metrics.index.names = ["fold", "confounds", "metric"]
    df_cr_metrics.name = "value"

    df_metrics = pd.concat((df_ncr_metrics, df_cr_metrics))

    df_metrics = df_metrics.reset_index()
    df_metrics.head()

.. raw:: html
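
    <div>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;">
          <th></th>
          <th>fold</th>
          <th>confounds</th>
          <th>metric</th>
          <th>value</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th>0</th>
          <td>0</td>
          <td>Not Removed</td>
          <td>test_accuracy</td>
          <td>0.909091</td>
        </tr>
        <tr>
          <th>1</th>
          <td>0</td>
          <td>Not Removed</td>
          <td>test_roc_auc</td>
          <td>0.951852</td>
        </tr>
        <tr>
          <th>2</th>
          <td>1</td>
          <td>Not Removed</td>
          <td>test_accuracy</td>
          <td>0.900000</td>
        </tr>
        <tr>
          <th>3</th>
          <td>1</td>
          <td>Not Removed</td>
          <td>test_roc_auc</td>
          <td>0.943182</td>
        </tr>
        <tr>
          <th>4</th>
          <td>2</td>
          <td>Not Removed</td>
          <td>test_accuracy</td>
          <td>0.838710</td>
        </tr>
      </tbody>
    </table>
    </div>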


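The ``stack``-based reshaping above can also be written with
``pandas.DataFrame.melt``. The following is a minimal, self-contained sketch
on a small hypothetical score table (the values are made up for
illustration and are not results of this example); it produces the same
``fold``/``confounds``/``metric``/``value`` columns.

.. code-block:: Python

    import pandas as pd

    # Hypothetical per-fold scores, shaped like the labelled outputs of
    # run_cross_validation above (values are illustrative only).
    scores = pd.DataFrame(
        {
            "fold": [0, 1, 0, 1],
            "confounds": ["Not Removed", "Not Removed", "Removed", "Removed"],
            "test_accuracy": [0.90, 0.85, 0.80, 0.75],
            "test_roc_auc": [0.95, 0.92, 0.88, 0.84],
        }
    )

    # Melt the two score columns into a single "value" column, keeping
    # "fold" and "confounds" as identifier columns.
    df_long = scores.melt(
        id_vars=["fold", "confounds"],
        value_vars=["test_accuracy", "test_roc_auc"],
        var_name="metric",
        value_name="value",
    )
    print(df_long.head())

Unlike ``stack``, which interleaves the metrics fold by fold, ``melt`` lists
all rows of one metric before the next; seaborn does not depend on the row
order, so either form works for the plots.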
.. GENERATED FROM PYTHON SOURCE LINES 145-146

And finally, we plot the results.

.. GENERATED FROM PYTHON SOURCE LINES 146-151

.. code-block:: Python

    sns.catplot(
        x="confounds", y="value", col="metric", data=df_metrics, kind="swarm"
    )
    plt.tight_layout()

.. image-sg:: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_001.png
   :alt: metric = test_accuracy, metric = test_roc_auc
   :srcset: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 152-159

While this plot lets us see the distribution of performance values for each
model and compare them, it ignores that the samples are paired: both models
were evaluated on the same folds. To check whether there is a systematic
difference, we need to look at the distribution of the per-fold differences
between the models.

First, we remove the column "confounds" from the index and compute the
difference between the metrics.

.. GENERATED FROM PYTHON SOURCE LINES 159-165

.. code-block:: Python

    df_cr_metrics = df_cr_metrics.reset_index().set_index(["fold", "metric"])
    df_ncr_metrics = df_ncr_metrics.reset_index().set_index(["fold", "metric"])

    df_diff_metrics = df_ncr_metrics["value"] - df_cr_metrics["value"]
    df_diff_metrics = df_diff_metrics.reset_index()

.. GENERATED FROM PYTHON SOURCE LINES 166-168

Now we can finally plot the differences, setting the whiskers of the box plot
at the 2.5th and 97.5th percentiles to see the 95% CI.

.. GENERATED FROM PYTHON SOURCE LINES 168-174

.. code-block:: Python

    sns.boxplot(x="metric", y="value", data=df_diff_metrics, whis=[2.5, 97.5])
    plt.axhline(0, color="k", ls=":")
    plt.tight_layout()

.. image-sg:: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_002.png
   :alt: plot confound removal classification
   :srcset: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_002.png
   :class: sphx-glr-single-img
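The box-plot whiskers only let us read the interval off the plot. As a
minimal sketch, the same 2.5th/97.5th percentile interval can be computed
per metric with pandas. The data below are simulated paired differences in
the long format of ``df_diff_metrics`` (hypothetical numbers, not the results
of this example), and the names ``df_diff`` and ``ci`` are ours.

.. code-block:: Python

    import numpy as np
    import pandas as pd

    # Simulated paired differences ("Not Removed" minus "Removed"), one row
    # per bootstrap fold and metric, in the same long format as
    # df_diff_metrics. The numbers are illustrative only.
    rng = np.random.default_rng(0)
    n_folds = 2000
    df_diff = pd.DataFrame(
        {
            "fold": np.tile(np.arange(n_folds), 2),
            "metric": np.repeat(["test_accuracy", "test_roc_auc"], n_folds),
            "value": rng.normal(loc=0.02, scale=0.03, size=2 * n_folds),
        }
    )

    # 2.5th and 97.5th percentiles of the differences for each metric,
    # i.e. the same 95% interval shown by the box-plot whiskers.
    ci = df_diff.groupby("metric")["value"].quantile([0.025, 0.975]).unstack()
    ci.columns = ["ci_low", "ci_high"]

    # Following the test described above, the models differ only if the
    # interval lies entirely above or entirely below 0.
    ci["excludes_zero"] = (ci["ci_low"] > 0) | (ci["ci_high"] < 0)
    print(ci)

With the real results, ``df_diff_metrics`` could be used in place of the
simulated ``df_diff``; as noted above, the interval is only meaningful with
a sufficient number of bootstrap iterations.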
.. GENERATED FROM PYTHON SOURCE LINES 175-187

The accuracy and ROC AUC scores seem to be higher when confounds are not
removed. However, since the 95% intervals of the differences include 0, we
cannot really claim (using this test) that the models differ in terms of
these metrics. The percentiles might be more accurate with the proper number
of bootstrap iterations.

However, the main point of confound removal is interpretability. Let's see
whether the feature importances change.

First, we need to collect the feature importances for each model and each
fold.

.. GENERATED FROM PYTHON SOURCE LINES 187-216

.. code-block:: Python

    ncr_fi = []
    for i_fold, estimator in enumerate(scores_ncr["estimator"]):
        this_importances = pd.DataFrame(
            {
                "feature": [x.replace("_", " ") for x in X],
                "importance": estimator["rf"].feature_importances_,
                "confounds": "Not Removed",
                "fold": i_fold,
            }
        )
        ncr_fi.append(this_importances)
    ncr_fi = pd.concat(ncr_fi)

    cr_fi = []
    for i_fold, estimator in enumerate(scores_cr["estimator"]):
        this_importances = pd.DataFrame(
            {
                "feature": [x.replace("_", " ") for x in X],
                # Here the "rf" step is wrapped by julearn (it only uses the
                # "features" columns), so the fitted forest is in ``.model``.
                "importance": estimator["rf"].model.feature_importances_,
                "confounds": "Removed",
                "fold": i_fold,
            }
        )
        cr_fi.append(this_importances)
    cr_fi = pd.concat(cr_fi)

    feature_importance = pd.concat([cr_fi, ncr_fi])

.. GENERATED FROM PYTHON SOURCE LINES 217-218

We can now plot the importances.

.. GENERATED FROM PYTHON SOURCE LINES 218-229

.. code-block:: Python

    sns.catplot(
        x="feature",
        y="importance",
        hue="confounds",
        dodge=True,
        data=feature_importance,
        kind="swarm",
        s=3,
    )
    plt.tight_layout()

.. image-sg:: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_003.png
   :alt: plot confound removal classification
   :srcset: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 230-232

Finally, we check the paired differences in importances. Unlike the scores,
the importances do show a difference between the two models.
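The importances change because, in the second model, the random forest never
sees the raw features: the confound removal step replaces each feature by the
residual of a model that predicts it from the confound, fitted on the
training fold only. The following is a minimal sketch of that idea with
scikit-learn on simulated data. It assumes a linear confound model (we
believe this matches julearn's default, but this is an illustration, not
julearn's implementation), and all variable names are hypothetical.

.. code-block:: Python

    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import train_test_split

    # Simulated data: one confound that drives both features.
    rng = np.random.default_rng(42)
    n_samples = 200
    confound = rng.normal(size=(n_samples, 1))
    features = np.hstack(
        [
            2.0 * confound + rng.normal(scale=0.5, size=(n_samples, 1)),
            -1.0 * confound + rng.normal(scale=0.5, size=(n_samples, 1)),
        ]
    )

    feat_train, feat_test, conf_train, conf_test = train_test_split(
        features, confound, test_size=0.25, random_state=0
    )

    # Fit the confound model on the TRAINING data only. With a 2-D target,
    # LinearRegression fits an independent linear model per feature column.
    confound_model = LinearRegression().fit(conf_train, feat_train)

    # Residualize: keep only the part of each feature not explained by the
    # confound, using the model fitted on the training data for both splits.
    feat_train_clean = feat_train - confound_model.predict(conf_train)
    feat_test_clean = feat_test - confound_model.predict(conf_test)

    # On the training data, the residuals are uncorrelated with the confound.
    corr_before = np.corrcoef(conf_train[:, 0], feat_train[:, 0])[0, 1]
    corr_after = np.corrcoef(conf_train[:, 0], feat_train_clean[:, 0])[0, 1]
    print(f"corr before: {corr_before:.2f}, after: {corr_after:.2e}")

In the iris example, the petal width is correlated with the other
measurements, so after residualization the forest can no longer use the
information they share with the confound, which is consistent with the
shifted importances.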
.. GENERATED FROM PYTHON SOURCE LINES 232-241

.. code-block:: Python

    # Note the sign: here it is "Removed" minus "Not Removed".
    diff_fi = (
        cr_fi.set_index(["feature", "fold"])["importance"]
        - ncr_fi.set_index(["feature", "fold"])["importance"]
    )

    sns.boxplot(
        x="importance", y="feature", data=diff_fi.reset_index(), whis=[2.5, 97.5]
    )
    plt.axvline(0, color="k", ls=":")
    plt.tight_layout()

.. image-sg:: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_004.png
   :alt: plot confound removal classification
   :srcset: /auto_examples/04_confounds/images/sphx_glr_plot_confound_removal_classification_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 6.570 seconds)


.. _sphx_glr_download_auto_examples_04_confounds_plot_confound_removal_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_confound_removal_classification.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_confound_removal_classification.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_confound_removal_classification.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_