.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/02_inspection/plot_inspect_random_forest.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_02_inspection_plot_inspect_random_forest.py:


Inspecting Random Forest models
===============================

This example uses the ``iris`` dataset, performs simple binary classification
using a Random Forest classifier and analyses the model.

.. include:: ../../links.inc

.. GENERATED FROM PYTHON SOURCE LINES 10-22

.. code-block:: Python

    # Authors: Federico Raimondo
    # License: AGPL

    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from seaborn import load_dataset

    from julearn import run_cross_validation
    from julearn.utils import configure_logging

.. GENERATED FROM PYTHON SOURCE LINES 23-24

Set the logging level to info to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 24-26

.. code-block:: Python

    configure_logging(level="INFO")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:54:01,129 - julearn - INFO - ===== Lib Versions =====
    2026-01-16 10:54:01,129 - julearn - INFO - numpy: 1.26.4
    2026-01-16 10:54:01,129 - julearn - INFO - scipy: 1.17.0
    2026-01-16 10:54:01,129 - julearn - INFO - sklearn: 1.7.2
    2026-01-16 10:54:01,129 - julearn - INFO - pandas: 2.3.3
    2026-01-16 10:54:01,129 - julearn - INFO - julearn: 0.3.5.dev123
    2026-01-16 10:54:01,129 - julearn - INFO - ========================

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Random Forest variable importance
---------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 29-32

.. code-block:: Python

    df_iris = load_dataset("iris")

.. GENERATED FROM PYTHON SOURCE LINES 33-35

The dataset contains three species.
We will keep two to perform a binary classification.

.. GENERATED FROM PYTHON SOURCE LINES 35-41

.. code-block:: Python

    df_iris = df_iris[df_iris["species"].isin(["versicolor", "virginica"])]

    X = ["sepal_length", "sepal_width", "petal_length"]
    y = "species"

.. GENERATED FROM PYTHON SOURCE LINES 42-45

We will use a Random Forest classifier. By setting
``return_estimator='final'``, the :func:`.run_cross_validation` function
returns the estimator fitted on all the data.

.. GENERATED FROM PYTHON SOURCE LINES 45-56

.. code-block:: Python

    scores, model_iris = run_cross_validation(
        X=X,
        y=y,
        data=df_iris,
        model="rf",
        preprocess="zscore",
        problem_type="classification",
        return_estimator="final",
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:54:01,132 - julearn - INFO - ==== Input Data ====
    2026-01-16 10:54:01,132 - julearn - INFO - Using dataframe as input
    2026-01-16 10:54:01,132 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:54:01,132 - julearn - INFO - Target: species
    2026-01-16 10:54:01,132 - julearn - INFO - Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:54:01,132 - julearn - INFO - X_types:{}
    2026-01-16 10:54:01,132 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
    /home/runner/work/julearn/julearn/julearn/prepare.py:576: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
      warn_with_log(
    2026-01-16 10:54:01,133 - julearn - INFO - ====================
    2026-01-16 10:54:01,133 - julearn - INFO -
    2026-01-16 10:54:01,133 - julearn - INFO - Adding step zscore that applies to ColumnTypes
    2026-01-16 10:54:01,133 - julearn - INFO - Step added
    2026-01-16 10:54:01,134 - julearn - INFO - Adding step rf that applies to ColumnTypes
    2026-01-16 10:54:01,134 - julearn - INFO - Step added
    2026-01-16 10:54:01,134 - julearn - INFO - = Model Parameters =
    2026-01-16 10:54:01,134 - julearn - INFO - ====================
    2026-01-16 10:54:01,135 - julearn - INFO -
    2026-01-16 10:54:01,135 - julearn - INFO - = Data Information =
    2026-01-16 10:54:01,135 - julearn - INFO - Problem type: classification
    2026-01-16 10:54:01,135 - julearn - INFO - Number of samples: 100
    2026-01-16 10:54:01,135 - julearn - INFO - Number of features: 3
    2026-01-16 10:54:01,135 - julearn - INFO - ====================
    2026-01-16 10:54:01,135 - julearn - INFO -
    2026-01-16 10:54:01,135 - julearn - INFO - Number of classes: 2
    2026-01-16 10:54:01,135 - julearn - INFO - Target type: object
    2026-01-16 10:54:01,136 - julearn - INFO - Class distributions: species
    versicolor    50
    virginica     50
    Name: count, dtype: int64
    2026-01-16 10:54:01,136 - julearn - INFO - Using outer CV scheme KFold(n_splits=5, random_state=None, shuffle=False) (incl. final model)
    2026-01-16 10:54:01,136 - julearn - INFO - Binary classification problem detected.

.. GENERATED FROM PYTHON SOURCE LINES 57-61

This type of classifier exposes an attribute, ``feature_importances_``, that
tells us how *important* each feature is. Caution: read the ``scikit-learn``
documentation for :class:`~sklearn.ensemble.RandomForestClassifier` to
understand how this learning algorithm works.

.. GENERATED FROM PYTHON SOURCE LINES 61-75

.. code-block:: Python

    rf = model_iris["rf"]

    to_plot = pd.DataFrame(
        {
            "variable": [x.replace("_", " ") for x in X],
            "importance": rf.feature_importances_,
        }
    )

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    sns.barplot(x="importance", y="variable", data=to_plot, ax=ax)
    ax.set_title("Variable Importances for Random Forest Classifier")
    fig.tight_layout()

.. image-sg:: /auto_examples/02_inspection/images/sphx_glr_plot_inspect_random_forest_001.png
   :alt: Variable Importances for Random Forest Classifier
   :srcset: /auto_examples/02_inspection/images/sphx_glr_plot_inspect_random_forest_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 76-83

However, some reviewers (including us) might wonder about the variability of
these feature importances. In the previous example, all the feature
importances were obtained by fitting on the entire dataset, while the
performance was estimated using cross-validation.

By specifying ``return_estimator='cv'``, we can get the fitted estimator for
each fold.

.. GENERATED FROM PYTHON SOURCE LINES 83-94

.. code-block:: Python

    scores = run_cross_validation(
        X=X,
        y=y,
        data=df_iris,
        model="rf",
        preprocess="zscore",
        problem_type="classification",
        return_estimator="cv",
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2026-01-16 10:54:02,009 - julearn - INFO - ==== Input Data ====
    2026-01-16 10:54:02,009 - julearn - INFO - Using dataframe as input
    2026-01-16 10:54:02,009 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:54:02,009 - julearn - INFO - Target: species
    2026-01-16 10:54:02,009 - julearn - INFO - Expanded features: ['sepal_length', 'sepal_width', 'petal_length']
    2026-01-16 10:54:02,009 - julearn - INFO - X_types:{}
    2026-01-16 10:54:02,010 - julearn - WARNING - The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
    /home/runner/work/julearn/julearn/julearn/prepare.py:576: RuntimeWarning: The following columns are not defined in X_types: ['sepal_length', 'sepal_width', 'petal_length']. They will be treated as continuous.
      warn_with_log(
    2026-01-16 10:54:02,011 - julearn - INFO - ====================
    2026-01-16 10:54:02,011 - julearn - INFO -
    2026-01-16 10:54:02,011 - julearn - INFO - Adding step zscore that applies to ColumnTypes
    2026-01-16 10:54:02,011 - julearn - INFO - Step added
    2026-01-16 10:54:02,011 - julearn - INFO - Adding step rf that applies to ColumnTypes
    2026-01-16 10:54:02,011 - julearn - INFO - Step added
    2026-01-16 10:54:02,012 - julearn - INFO - = Model Parameters =
    2026-01-16 10:54:02,012 - julearn - INFO - ====================
    2026-01-16 10:54:02,012 - julearn - INFO -
    2026-01-16 10:54:02,012 - julearn - INFO - = Data Information =
    2026-01-16 10:54:02,012 - julearn - INFO - Problem type: classification
    2026-01-16 10:54:02,013 - julearn - INFO - Number of samples: 100
    2026-01-16 10:54:02,013 - julearn - INFO - Number of features: 3
    2026-01-16 10:54:02,013 - julearn - INFO - ====================
    2026-01-16 10:54:02,013 - julearn - INFO -
    2026-01-16 10:54:02,013 - julearn - INFO - Number of classes: 2
    2026-01-16 10:54:02,013 - julearn - INFO - Target type: object
    2026-01-16 10:54:02,014 - julearn - INFO - Class distributions: species
    versicolor    50
    virginica     50
    Name: count, dtype: int64
    2026-01-16 10:54:02,014 - julearn - INFO - Using outer CV scheme KFold(n_splits=5, random_state=None, shuffle=False)
    2026-01-16 10:54:02,014 - julearn - INFO - Binary classification problem detected.

.. GENERATED FROM PYTHON SOURCE LINES 95-96

Now we can obtain the feature importance for each estimator (CV fold).

.. GENERATED FROM PYTHON SOURCE LINES 96-109

.. code-block:: Python

    to_plot = []
    for i_fold, estimator in enumerate(scores["estimator"]):
        this_importances = pd.DataFrame(
            {
                "variable": [x.replace("_", " ") for x in X],
                "importance": estimator["rf"].feature_importances_,
                "fold": i_fold,
            }
        )
        to_plot.append(this_importances)

    to_plot = pd.concat(to_plot)

.. GENERATED FROM PYTHON SOURCE LINES 110-111

Finally, we can plot the variable importances for each fold.

.. GENERATED FROM PYTHON SOURCE LINES 111-119

.. code-block:: Python

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    sns.swarmplot(x="importance", y="variable", data=to_plot, ax=ax)
    ax.set_title(
        "Distribution of variable Importances for Random Forest "
        "Classifier across folds"
    )
    fig.tight_layout()

.. image-sg:: /auto_examples/02_inspection/images/sphx_glr_plot_inspect_random_forest_002.png
   :alt: Distribution of variable Importances for Random Forest Classifier across folds
   :srcset: /auto_examples/02_inspection/images/sphx_glr_plot_inspect_random_forest_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.644 seconds)


.. _sphx_glr_download_auto_examples_02_inspection_plot_inspect_random_forest.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_inspect_random_forest.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_inspect_random_forest.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_inspect_random_forest.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_