.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basic/plot_confound_removal_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_basic_plot_confound_removal_classification.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basic_plot_confound_removal_classification.py:

Confound Removal (model comparison)
===================================

This example uses the 'iris' dataset and performs a simple binary
classification with and without confound removal, using a Random Forest
classifier.

.. GENERATED FROM PYTHON SOURCE LINES 9-24

.. code-block:: default


    # Authors: Shammi More
    #          Federico Raimondo
    #
    # License: AGPL

    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from seaborn import load_dataset

    from julearn import run_cross_validation
    from julearn.utils import configure_logging
    from julearn.model_selection import StratifiedBootstrap
.. GENERATED FROM PYTHON SOURCE LINES 25-26

Set the logging level to info to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 26-28

.. code-block:: default


    configure_logging(level='INFO')


.. rst-class:: sphx-glr-script-out

 Out:
 .. code-block:: none

    2022-07-21 09:54:57,048 - julearn - INFO - ===== Lib Versions =====
    2022-07-21 09:54:57,048 - julearn - INFO - numpy: 1.23.1
    2022-07-21 09:54:57,048 - julearn - INFO - scipy: 1.8.1
    2022-07-21 09:54:57,048 - julearn - INFO - sklearn: 1.0.2
    2022-07-21 09:54:57,048 - julearn - INFO - pandas: 1.4.3
    2022-07-21 09:54:57,048 - julearn - INFO - julearn: 0.2.5
    2022-07-21 09:54:57,048 - julearn - INFO - ========================

.. GENERATED FROM PYTHON SOURCE LINES 29-30

Load the iris data from seaborn.

.. GENERATED FROM PYTHON SOURCE LINES 30-32

.. code-block:: default


    df_iris = load_dataset('iris')


.. GENERATED FROM PYTHON SOURCE LINES 33-35

The dataset has three kinds of species. We will keep two of them to perform
a binary classification.

.. GENERATED FROM PYTHON SOURCE LINES 35-38

.. code-block:: default


    df_iris = df_iris[df_iris['species'].isin(['versicolor', 'virginica'])]


.. GENERATED FROM PYTHON SOURCE LINES 39-41

As features, we will use sepal length, sepal width and petal length, and we
will use petal width as a confound.

.. GENERATED FROM PYTHON SOURCE LINES 41-46

.. code-block:: default


    X = ['sepal_length', 'sepal_width', 'petal_length']
    y = 'species'
    confound = 'petal_width'


.. GENERATED FROM PYTHON SOURCE LINES 47-60

Hypothesis testing in machine learning is not straightforward. If we were to
use classical frequentist statistics, we would face the problem that, under
cross-validation, the samples are not independent and the population
(train + test) is always the same. If we want to compare two models, an
alternative is to contrast, for each fold, the performance gap between the
models. If we combine that approach with bootstrapping, we can then inspect
the confidence interval of the difference. If the 95% CI lies entirely above
(or below) 0, we can claim that the models are different with p < 0.05.

Let's use a bootstrap CV. To keep the running time short we do only 20
iterations; increase the number of bootstrap iterations to at least 2000
for a valid test.

.. GENERATED FROM PYTHON SOURCE LINES 60-64
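The decision rule described above can be sketched with plain NumPy: take the
per-fold scores of two models, compute the paired differences, and check
whether the 95% percentile interval excludes 0. The scores below are made up
for illustration; in this example they will come from ``run_cross_validation``.

```python
import numpy as np

# Hypothetical per-fold accuracies of two models, paired by fold
scores_a = np.array([0.90, 0.93, 0.88, 0.95, 0.91, 0.92, 0.94, 0.89])
scores_b = np.array([0.85, 0.90, 0.86, 0.91, 0.88, 0.87, 0.92, 0.84])

# Paired differences: one value per fold
diff = scores_a - scores_b

# Bounds of the 95% percentile interval of the differences
ci_low, ci_high = np.percentile(diff, [2.5, 97.5])

# If the whole interval lies above (or below) 0, we can claim the
# models differ with p < 0.05
models_differ = (ci_low > 0) or (ci_high < 0)
print(ci_low, ci_high, models_differ)
```

With enough bootstrap iterations, the empirical percentiles of the paired
differences approximate the CI this test relies on.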
.. code-block:: default


    n_bootstrap = 20
    n_elements = len(df_iris)
    cv = StratifiedBootstrap(n_splits=n_bootstrap, test_size=.3,
                             random_state=42)


.. GENERATED FROM PYTHON SOURCE LINES 65-67

First, we will train a model without performing confound removal on the
features. Note: ``confounds=None`` by default.

.. GENERATED FROM PYTHON SOURCE LINES 67-72

.. code-block:: default


    scores_ncr = run_cross_validation(
        X=X, y=y, data=df_iris, model='rf', cv=cv, preprocess_X='zscore',
        scoring=['accuracy', 'roc_auc'], return_estimator='cv', seed=200)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2022-07-21 09:54:57,052 - julearn - INFO - Setting random seed to 200
    2022-07-21 09:54:57,052 - julearn - INFO - ==== Input Data ====
    2022-07-21 09:54:57,052 - julearn - INFO - Using dataframe as input
    2022-07-21 09:54:57,052 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:57,052 - julearn - INFO - Target: species
    2022-07-21 09:54:57,052 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:57,052 - julearn - INFO - Expanded Confounds: []
    2022-07-21 09:54:57,053 - julearn - INFO - ====================
    2022-07-21 09:54:57,053 - julearn - INFO -
    2022-07-21 09:54:57,053 - julearn - INFO - ====== Model ======
    2022-07-21 09:54:57,053 - julearn - INFO - Obtaining model by name: rf
    2022-07-21 09:54:57,053 - julearn - INFO - ===================
    2022-07-21 09:54:57,053 - julearn - INFO -
    2022-07-21 09:54:57,054 - julearn - INFO - Using scikit-learn CV scheme StratifiedBootstrap(n_splits=20, random_state=42, test_size=0.3, train_size=None)

.. GENERATED FROM PYTHON SOURCE LINES 73-75

Next, we train a model after performing confound removal on the features.
Note: we initialize the CV again to use the same folds as before.

.. GENERATED FROM PYTHON SOURCE LINES 75-81
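Conceptually, confound removal fits a regression from the confound to each
feature and keeps only the residuals, so that the cleaned features carry no
linear trace of the confound; julearn's ``remove_confound`` step applies this
idea within the cross-validation. A minimal sketch of the idea on synthetic
data (all variable names below are made up for illustration):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)

# Synthetic data: three features that are partly driven by one confound
confound = rng.normal(size=(100, 1))
features = 2.0 * confound + 0.5 * rng.normal(size=(100, 3))

# Fit confound -> features and keep the residuals as "cleaned" features
cleaner = LinearRegression().fit(confound, features)
features_clean = features - cleaner.predict(confound)

# OLS residuals are orthogonal to the regressor, so the cleaned features
# are linearly uncorrelated with the confound (up to floating-point error)
corr = np.corrcoef(confound.ravel(), features_clean[:, 0])[0, 1]
print(abs(corr) < 1e-8)
```

In a CV setting the regression must be fit on the training folds only and then
applied to the test fold, which is why it belongs inside the pipeline rather
than being done once on the whole dataset.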
.. code-block:: default


    cv = StratifiedBootstrap(n_splits=n_bootstrap, test_size=.3,
                             random_state=42)
    scores_cr = run_cross_validation(
        X=X, y=y, confounds=confound, data=df_iris, model='rf',
        preprocess_X='remove_confound', preprocess_confounds='zscore', cv=cv,
        scoring=['accuracy', 'roc_auc'], return_estimator='cv', seed=200)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2022-07-21 09:54:59,674 - julearn - INFO - Setting random seed to 200
    2022-07-21 09:54:59,674 - julearn - INFO - ==== Input Data ====
    2022-07-21 09:54:59,674 - julearn - INFO - Using dataframe as input
    2022-07-21 09:54:59,674 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:59,674 - julearn - INFO - Target: species
    2022-07-21 09:54:59,674 - julearn - INFO - Confounds: petal_width
    2022-07-21 09:54:59,674 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:59,674 - julearn - INFO - Expanded Confounds: ['petal_width']
    2022-07-21 09:54:59,675 - julearn - INFO - ====================
    2022-07-21 09:54:59,675 - julearn - INFO -
    2022-07-21 09:54:59,675 - julearn - INFO - ====== Model ======
    2022-07-21 09:54:59,675 - julearn - INFO - Obtaining model by name: rf
    2022-07-21 09:54:59,675 - julearn - INFO - ===================
    2022-07-21 09:54:59,675 - julearn - INFO -
    2022-07-21 09:54:59,676 - julearn - INFO - Using scikit-learn CV scheme StratifiedBootstrap(n_splits=20, random_state=42, test_size=0.3, train_size=None)

.. GENERATED FROM PYTHON SOURCE LINES 82-84

Now we can compare the accuracies. We can combine the two outputs as pandas
dataframes.

.. GENERATED FROM PYTHON SOURCE LINES 84-87

.. code-block:: default


    scores_ncr['confounds'] = 'Not Removed'
    scores_cr['confounds'] = 'Removed'


.. GENERATED FROM PYTHON SOURCE LINES 88-90

Now we convert the metrics to a column for easier seaborn plotting
(i.e. convert to long format).

.. GENERATED FROM PYTHON SOURCE LINES 90-107
.. code-block:: default


    index = ['fold', 'confounds']
    scorings = ['test_accuracy', 'test_roc_auc']

    df_ncr_metrics = scores_ncr.set_index(index)[scorings].stack()
    df_ncr_metrics.index.names = ['fold', 'confounds', 'metric']
    df_ncr_metrics.name = 'value'

    df_cr_metrics = scores_cr.set_index(index)[scorings].stack()
    df_cr_metrics.index.names = ['fold', 'confounds', 'metric']
    df_cr_metrics.name = 'value'

    df_metrics = pd.concat((df_ncr_metrics, df_cr_metrics))
    df_metrics = df_metrics.reset_index()
    print(df_metrics.head())


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

       fold    confounds         metric     value
    0     0  Not Removed  test_accuracy  0.933333
    1     0  Not Removed   test_roc_auc  0.968889
    2     1  Not Removed  test_accuracy  0.933333
    3     1  Not Removed   test_roc_auc  0.948889
    4     2  Not Removed  test_accuracy  1.000000

.. GENERATED FROM PYTHON SOURCE LINES 108-109

And finally plot the results.

.. GENERATED FROM PYTHON SOURCE LINES 109-113

.. code-block:: default


    sns.catplot(x='confounds', y='value', col='metric', data=df_metrics,
                kind='swarm')
    plt.tight_layout()


.. image-sg:: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_001.png
   :alt: metric = test_accuracy, metric = test_roc_auc
   :srcset: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 114-121

While this plot lets us see and compare the mean performance values, these
samples are paired. In order to see whether there is a systematic difference,
we need to check the distribution of differences between the models.

First we remove the column "confounds" from the index and compute the
difference between the metrics.

.. GENERATED FROM PYTHON SOURCE LINES 121-128
.. code-block:: default


    df_cr_metrics = df_cr_metrics.reset_index().set_index(['fold', 'metric'])
    df_ncr_metrics = df_ncr_metrics.reset_index().set_index(['fold', 'metric'])

    df_diff_metrics = df_ncr_metrics['value'] - df_cr_metrics['value']
    df_diff_metrics = df_diff_metrics.reset_index()


.. GENERATED FROM PYTHON SOURCE LINES 129-131

Now we can finally plot the difference, setting the whiskers of the box plot
at 2.5 and 97.5 to see the 95% CI.

.. GENERATED FROM PYTHON SOURCE LINES 131-136

.. code-block:: default


    sns.boxplot(x='metric', y='value', data=df_diff_metrics.reset_index(),
                whis=[2.5, 97.5])
    plt.axhline(0, color='k', ls=':')
    plt.tight_layout()


.. image-sg:: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_002.png
   :alt: plot confound removal classification
   :srcset: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 137-150

While it seems that the accuracy and ROC AUC scores are higher when confounds
are not removed, we cannot really claim (using this test) that the models
differ in terms of these metrics. The percentile estimates might become more
accurate with the proper number of bootstrap iterations.

The main point of confound removal, however, is interpretability. Let's see
whether there is a change in the feature importances.

First, we need to collect the feature importances for each model, for each
fold.

.. GENERATED FROM PYTHON SOURCE LINES 150-175
.. code-block:: default


    ncr_fi = []
    for i_fold, estimator in enumerate(scores_ncr['estimator']):
        this_importances = pd.DataFrame({
            'feature': [x.replace('_', ' ') for x in X],
            'importance': estimator['rf'].feature_importances_,
            'confounds': 'Not Removed',
            'fold': i_fold
        })
        ncr_fi.append(this_importances)
    ncr_fi = pd.concat(ncr_fi)

    cr_fi = []
    for i_fold, estimator in enumerate(scores_cr['estimator']):
        this_importances = pd.DataFrame({
            'feature': [x.replace('_', ' ') for x in X],
            'importance': estimator['rf'].feature_importances_,
            'confounds': 'Removed',
            'fold': i_fold
        })
        cr_fi.append(this_importances)
    cr_fi = pd.concat(cr_fi)

    feature_importance = pd.concat([cr_fi, ncr_fi])


.. GENERATED FROM PYTHON SOURCE LINES 176-177

We can now plot the importances.

.. GENERATED FROM PYTHON SOURCE LINES 177-181

.. code-block:: default


    sns.catplot(x='feature', y='importance', hue='confounds', dodge=True,
                data=feature_importance, kind='swarm', s=3)
    plt.tight_layout()


.. image-sg:: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_003.png
   :alt: plot confound removal classification
   :srcset: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 182-184

And check the differences in importances. We can now see that there is a
difference in importances.

.. GENERATED FROM PYTHON SOURCE LINES 184-190

.. code-block:: default


    diff_fi = (cr_fi.set_index(['feature', 'fold'])['importance'] -
               ncr_fi.set_index(['feature', 'fold'])['importance'])
    sns.boxplot(x='importance', y='feature', data=diff_fi.reset_index(),
                whis=[2.5, 97.5])
    plt.axvline(0, color='k', ls=':')
    plt.tight_layout()


.. image-sg:: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_004.png
   :alt: plot confound removal classification
   :srcset: /auto_examples/basic/images/sphx_glr_plot_confound_removal_classification_004.png
   :class: sphx-glr-single-img
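The whiskers above give a visual 95% CI; the same bounds can be read off
numerically by grouping the per-fold differences by feature and taking the
2.5% and 97.5% quantiles. A sketch on made-up differences (the data and
feature names below are invented for illustration; here you would group
``diff_fi`` the same way):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Hypothetical per-fold importance differences for two features:
# one centered at 0, one clearly shifted away from 0
toy = pd.DataFrame({
    'feature': ['sepal length'] * 20 + ['petal length'] * 20,
    'importance': np.concatenate([
        rng.normal(0.00, 0.02, size=20),
        rng.normal(0.15, 0.02, size=20),
    ]),
})

# 95% CI bounds per feature: an interval that excludes 0 indicates a
# systematic change in that feature's importance
ci = toy.groupby('feature')['importance'].quantile([0.025, 0.975]).unstack()
print(ci)
```

This makes the visual reading of the box plot explicit: the shifted feature's
interval excludes 0, while the centered one straddles it.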
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  6.760 seconds)


.. _sphx_glr_download_auto_examples_basic_plot_confound_removal_classification.py:

.. only:: html

  .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_confound_removal_classification.py <plot_confound_removal_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_confound_removal_classification.ipynb <plot_confound_removal_classification.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_