
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basic/plot_inspect_random_forest.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basic_plot_inspect_random_forest.py:


Inspecting Random Forest models
===============================

This example uses the 'iris' dataset, performs a simple binary classification
using a Random Forest classifier and analyses the model.

.. include:: ../../links.inc

.. GENERATED FROM PYTHON SOURCE LINES 10-22

.. code-block:: default

    # Authors: Federico Raimondo
    #
    # License: AGPL
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from seaborn import load_dataset

    from julearn import run_cross_validation
    from julearn.utils import configure_logging


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1582: UserWarning: Trying to register the cmap 'rocket' which already exists.
      mpl_cm.register_cmap(_name, _cmap)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1583: UserWarning: Trying to register the cmap 'rocket_r' which already exists.
      mpl_cm.register_cmap(_name + "_r", _cmap_r)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1582: UserWarning: Trying to register the cmap 'mako' which already exists.
      mpl_cm.register_cmap(_name, _cmap)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1583: UserWarning: Trying to register the cmap 'mako_r' which already exists.
      mpl_cm.register_cmap(_name + "_r", _cmap_r)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1582: UserWarning: Trying to register the cmap 'icefire' which already exists.
      mpl_cm.register_cmap(_name, _cmap)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1583: UserWarning: Trying to register the cmap 'icefire_r' which already exists.
      mpl_cm.register_cmap(_name + "_r", _cmap_r)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1582: UserWarning: Trying to register the cmap 'vlag' which already exists.
      mpl_cm.register_cmap(_name, _cmap)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1583: UserWarning: Trying to register the cmap 'vlag_r' which already exists.
      mpl_cm.register_cmap(_name + "_r", _cmap_r)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1582: UserWarning: Trying to register the cmap 'flare' which already exists.
      mpl_cm.register_cmap(_name, _cmap)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1583: UserWarning: Trying to register the cmap 'flare_r' which already exists.
      mpl_cm.register_cmap(_name + "_r", _cmap_r)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1582: UserWarning: Trying to register the cmap 'crest' which already exists.
      mpl_cm.register_cmap(_name, _cmap)
    /opt/hostedtoolcache/Python/3.8.13/x64/lib/python3.8/site-packages/seaborn/cm.py:1583: UserWarning: Trying to register the cmap 'crest_r' which already exists.
      mpl_cm.register_cmap(_name + "_r", _cmap_r)


.. GENERATED FROM PYTHON SOURCE LINES 23-24

Set the logging level to info to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 24-26

.. code-block:: default

    configure_logging(level='INFO')


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2022-07-21 09:54:40,266 - julearn - INFO - ===== Lib Versions =====
    2022-07-21 09:54:40,266 - julearn - INFO - numpy: 1.23.1
    2022-07-21 09:54:40,266 - julearn - INFO - scipy: 1.8.1
    2022-07-21 09:54:40,266 - julearn - INFO - sklearn: 1.0.2
    2022-07-21 09:54:40,266 - julearn - INFO - pandas: 1.4.3
    2022-07-21 09:54:40,266 - julearn - INFO - julearn: 0.2.5
    2022-07-21 09:54:40,266 - julearn - INFO - ========================


.. GENERATED FROM PYTHON SOURCE LINES 27-30

Random Forest variable importance
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 30-33

.. code-block:: default

    df_iris = load_dataset('iris')


.. GENERATED FROM PYTHON SOURCE LINES 34-36

The dataset has three species. We will keep two of them to perform a binary
classification.

.. GENERATED FROM PYTHON SOURCE LINES 36-42

.. code-block:: default

    df_iris = df_iris[df_iris['species'].isin(['versicolor', 'virginica'])]

    X = ['sepal_length', 'sepal_width', 'petal_length']
    y = 'species'


.. GENERATED FROM PYTHON SOURCE LINES 43-46

We will use a Random Forest classifier. By setting
``return_estimator='final'``, the :func:`.run_cross_validation` function
returns the estimator fitted with all the data.

.. GENERATED FROM PYTHON SOURCE LINES 46-51

.. code-block:: default

    scores, model_iris = run_cross_validation(
        X=X, y=y, data=df_iris, model='rf', preprocess_X='zscore',
        return_estimator='final')


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2022-07-21 09:54:40,269 - julearn - INFO - Using default CV
    2022-07-21 09:54:40,269 - julearn - INFO - ==== Input Data ====
    2022-07-21 09:54:40,269 - julearn - INFO - Using dataframe as input
    2022-07-21 09:54:40,269 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:40,269 - julearn - INFO - Target: species
    2022-07-21 09:54:40,270 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:40,270 - julearn - INFO - Expanded Confounds: []
    2022-07-21 09:54:40,270 - julearn - INFO - ====================
    2022-07-21 09:54:40,270 - julearn - INFO -
    2022-07-21 09:54:40,270 - julearn - INFO - ====== Model ======
    2022-07-21 09:54:40,270 - julearn - INFO - Obtaining model by name: rf
    2022-07-21 09:54:40,271 - julearn - INFO - ===================
    2022-07-21 09:54:40,271 - julearn - INFO -
    2022-07-21 09:54:40,271 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds


.. GENERATED FROM PYTHON SOURCE LINES 52-55

This type of classifier has an internal attribute (``feature_importances_``)
that can inform us about how *important* each of the features is.
Caution: read the relevant scikit-learn documentation (`Random Forest`_)
before interpreting these values.

.. GENERATED FROM PYTHON SOURCE LINES 55-67

.. code-block:: default

    rf = model_iris['rf']

    to_plot = pd.DataFrame({
        'variable': [x.replace('_', ' ') for x in X],
        'importance': rf.feature_importances_
    })

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    sns.barplot(x='importance', y='variable', data=to_plot, ax=ax)
    ax.set_title('Variable Importances for Random Forest Classifier')
    fig.tight_layout()


.. image-sg:: /auto_examples/basic/images/sphx_glr_plot_inspect_random_forest_001.png
   :alt: Variable Importances for Random Forest Classifier
   :srcset: /auto_examples/basic/images/sphx_glr_plot_inspect_random_forest_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 68-75

However, some reviewers (including myself) might wonder about the
variability of the importance of these features. In the previous example,
all the feature importances were obtained by fitting on the entire dataset,
while the performance was estimated using cross-validation.

By specifying ``return_estimator='cv'``, we can get the fitted estimator for
each fold.

.. GENERATED FROM PYTHON SOURCE LINES 75-80

.. code-block:: default

    scores = run_cross_validation(
        X=X, y=y, data=df_iris, model='rf', preprocess_X='zscore',
        return_estimator='cv')


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    2022-07-21 09:54:43,363 - julearn - INFO - Using default CV
    2022-07-21 09:54:43,363 - julearn - INFO - ==== Input Data ====
    2022-07-21 09:54:43,363 - julearn - INFO - Using dataframe as input
    2022-07-21 09:54:43,363 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:43,363 - julearn - INFO - Target: species
    2022-07-21 09:54:43,363 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2022-07-21 09:54:43,363 - julearn - INFO - Expanded Confounds: []
    2022-07-21 09:54:43,364 - julearn - INFO - ====================
    2022-07-21 09:54:43,364 - julearn - INFO -
    2022-07-21 09:54:43,364 - julearn - INFO - ====== Model ======
    2022-07-21 09:54:43,364 - julearn - INFO - Obtaining model by name: rf
    2022-07-21 09:54:43,365 - julearn - INFO - ===================
    2022-07-21 09:54:43,365 - julearn - INFO -
    2022-07-21 09:54:43,365 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds


.. GENERATED FROM PYTHON SOURCE LINES 81-82

Now we can obtain the feature importance for each estimator (CV fold).

.. GENERATED FROM PYTHON SOURCE LINES 82-93

.. code-block:: default

    to_plot = []
    for i_fold, estimator in enumerate(scores['estimator']):
        this_importances = pd.DataFrame({
            'variable': [x.replace('_', ' ') for x in X],
            'importance': estimator['rf'].feature_importances_,
            'fold': i_fold
        })
        to_plot.append(this_importances)

    to_plot = pd.concat(to_plot)


.. GENERATED FROM PYTHON SOURCE LINES 94-95

Finally, we can plot the variable importances for each fold.

.. GENERATED FROM PYTHON SOURCE LINES 95-101

.. code-block:: default

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    sns.swarmplot(x='importance', y='variable', data=to_plot, ax=ax)
    ax.set_title('Distribution of variable Importances for Random Forest '
                 'Classifier across folds')
    fig.tight_layout()


.. image-sg:: /auto_examples/basic/images/sphx_glr_plot_inspect_random_forest_002.png
   :alt: Distribution of variable Importances for Random Forest Classifier across folds
   :srcset: /auto_examples/basic/images/sphx_glr_plot_inspect_random_forest_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 6.301 seconds)


.. _sphx_glr_download_auto_examples_basic_plot_inspect_random_forest.py:

.. only :: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

     .. container:: sphx-glr-download sphx-glr-download-python

        :download:`Download Python source code: plot_inspect_random_forest.py `

     .. container:: sphx-glr-download sphx-glr-download-jupyter

        :download:`Download Jupyter notebook: plot_inspect_random_forest.ipynb `

.. only:: html

  .. rst-class:: sphx-glr-signature

     `Gallery generated by Sphinx-Gallery `_
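
The per-fold importances collected in ``to_plot`` can also be summarised
numerically, for instance as a mean and standard deviation per variable.
The following is only a minimal sketch and not part of the original example:
the helper ``summarize_importances`` and the ``toy_records`` list are
hypothetical, and the importance values are made-up placeholders rather than
results from the iris model. It uses only the standard library and assumes
records with the same ``variable`` and ``importance`` keys as the dataframe
columns used above.

.. code-block:: python

    from collections import defaultdict
    from statistics import mean, stdev


    def summarize_importances(records):
        """Return {variable: (mean, std)} from per-fold importance records.

        Each record is a mapping with 'variable' and 'importance' keys
        (hypothetical helper, written for illustration only).
        """
        grouped = defaultdict(list)
        for record in records:
            grouped[record['variable']].append(float(record['importance']))
        # The sample standard deviation needs at least two values per variable.
        return {
            variable: (mean(values),
                       stdev(values) if len(values) > 1 else 0.0)
            for variable, values in grouped.items()
        }


    # Made-up placeholder values for illustration only (not real results).
    toy_records = [
        {'variable': 'sepal length', 'importance': 0.2, 'fold': 0},
        {'variable': 'petal length', 'importance': 0.8, 'fold': 0},
        {'variable': 'sepal length', 'importance': 0.4, 'fold': 1},
        {'variable': 'petal length', 'importance': 0.6, 'fold': 1},
    ]

    summary = summarize_importances(toy_records)
    for variable, (avg, std) in summary.items():
        print(f'{variable}: mean={avg:.3f}, std={std:.3f}')

With the dataframe built in the example, the records could be obtained with
``to_plot.to_dict('records')`` and passed to the same helper.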