.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basic/run_simple_binary_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_basic_run_simple_binary_classification.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basic_run_simple_binary_classification.py:


Simple Binary Classification
============================

This example uses the ``iris`` dataset and performs a simple binary
classification using a Support Vector Machine (SVM) classifier.

.. include:: ../../links.inc

.. GENERATED FROM PYTHON SOURCE LINES 10-17

.. code-block:: default


    # Authors: Federico Raimondo
    #
    # License: AGPL

    from seaborn import load_dataset
    from julearn import run_cross_validation
    from julearn.utils import configure_logging

.. GENERATED FROM PYTHON SOURCE LINES 18-19

Set the logging level to ``INFO`` to see extra information.

.. GENERATED FROM PYTHON SOURCE LINES 19-21

.. code-block:: default

    configure_logging(level='INFO')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2023-04-06 09:50:44,633 - julearn - INFO - ===== Lib Versions =====
    2023-04-06 09:50:44,633 - julearn - INFO - numpy: 1.23.5
    2023-04-06 09:50:44,633 - julearn - INFO - scipy: 1.10.1
    2023-04-06 09:50:44,633 - julearn - INFO - sklearn: 1.0.2
    2023-04-06 09:50:44,633 - julearn - INFO - pandas: 1.4.4
    2023-04-06 09:50:44,633 - julearn - INFO - julearn: 0.3.1.dev2
    2023-04-06 09:50:44,633 - julearn - INFO - ========================

.. GENERATED FROM PYTHON SOURCE LINES 22-24

.. code-block:: default

    df_iris = load_dataset('iris')

.. GENERATED FROM PYTHON SOURCE LINES 25-27

The dataset contains three species. We will keep only two of them to perform
a binary classification.

.. GENERATED FROM PYTHON SOURCE LINES 27-29

.. code-block:: default

    df_iris = df_iris[df_iris['species'].isin(['versicolor', 'virginica'])]
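
A quick sanity check before fitting a binary classifier is to confirm that the
filtered target really contains exactly two classes. The following minimal
sketch applies the same ``isin`` filter to a small hypothetical DataFrame (the
values are made up and only stand in for the iris data) and checks the result
with pandas.

.. code-block:: python

    import pandas as pd

    # Hypothetical toy frame standing in for the iris data; values are made up.
    df = pd.DataFrame({
        "sepal_length": [5.1, 7.0, 6.3, 4.9, 6.4, 5.8],
        "species": ["setosa", "versicolor", "virginica",
                    "setosa", "versicolor", "virginica"],
    })

    # Same kind of filter as above: keep only the two species of interest.
    df_binary = df[df["species"].isin(["versicolor", "virginica"])]

    # Sanity check: exactly two classes should remain for a binary problem.
    n_classes = df_binary["species"].nunique()
    assert n_classes == 2, "expected a binary target"

    print(n_classes)  # 2
    print(len(df_binary))  # 4
    print(df_binary["species"].value_counts())

The filter keeps the original row labels (here 1, 2, 4 and 5), so a later
positional selection such as ``df_iris[20:]`` counts rows of the filtered frame,
not of the full dataset.
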

.. GENERATED FROM PYTHON SOURCE LINES 30-32

As features, we will use the sepal length, sepal width, and petal length. We
will try to predict the species.

.. GENERATED FROM PYTHON SOURCE LINES 32-40

.. code-block:: default

    X = ['sepal_length', 'sepal_width', 'petal_length']
    y = 'species'
    scores = run_cross_validation(
        X=X, y=y, data=df_iris, model='svm', preprocess_X='zscore')

    print(scores['test_score'])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2023-04-06 09:50:44,638 - julearn - INFO - Using default CV
    2023-04-06 09:50:44,638 - julearn - INFO - ==== Input Data ====
    2023-04-06 09:50:44,638 - julearn - INFO - Using dataframe as input
    2023-04-06 09:50:44,638 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2023-04-06 09:50:44,638 - julearn - INFO - Target: species
    2023-04-06 09:50:44,639 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2023-04-06 09:50:44,639 - julearn - INFO - Expanded Confounds: []
    2023-04-06 09:50:44,640 - julearn - INFO - ====================
    2023-04-06 09:50:44,640 - julearn - INFO -
    2023-04-06 09:50:44,640 - julearn - INFO - ====== Model ======
    2023-04-06 09:50:44,640 - julearn - INFO - Obtaining model by name: svm
    2023-04-06 09:50:44,640 - julearn - INFO - ===================
    2023-04-06 09:50:44,640 - julearn - INFO -
    2023-04-06 09:50:44,640 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
    0     0.95
    1     0.90
    2     0.85
    3     0.90
    4     0.90
    5     0.90
    6     0.90
    7     0.90
    8     0.95
    9     1.00
    10    0.80
    11    0.80
    12    0.85
    13    1.00
    14    0.95
    15    0.80
    16    1.00
    17    0.90
    18    0.95
    19    0.95
    20    0.95
    21    0.90
    22    0.95
    23    0.95
    24    0.90
    Name: test_score, dtype: float64

.. GENERATED FROM PYTHON SOURCE LINES 41-45

Additionally, we can choose to assess the performance of the model using
different scoring functions. For example, we might have an unbalanced
dataset:

.. GENERATED FROM PYTHON SOURCE LINES 45-49

.. code-block:: default

    df_unbalanced = df_iris[20:]  # drop the first 20 versicolor samples

    print(df_unbalanced['species'].value_counts())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    virginica     50
    versicolor    30
    Name: species, dtype: int64

.. GENERATED FROM PYTHON SOURCE LINES 50-55

If we compute the ``accuracy``, the score does not account for this
imbalance: a model biased towards the majority class (``virginica``) can still
obtain a high accuracy. A more suitable metric is ``balanced_accuracy``, the
average of the recall obtained on each class. More information in
scikit-learn: `Balanced Accuracy`_

We will also set the random seed so that the data are always split in the
same way.

.. GENERATED FROM PYTHON SOURCE LINES 55-63

.. code-block:: default

    scores = run_cross_validation(
        X=X, y=y, data=df_unbalanced, model='svm', seed=42,
        preprocess_X='zscore', scoring=['accuracy', 'balanced_accuracy'])

    print(scores['test_accuracy'].mean())
    print(scores['test_balanced_accuracy'].mean())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2023-04-06 09:50:45,198 - julearn - INFO - Setting random seed to 42
    2023-04-06 09:50:45,199 - julearn - INFO - Using default CV
    2023-04-06 09:50:45,199 - julearn - INFO - ==== Input Data ====
    2023-04-06 09:50:45,199 - julearn - INFO - Using dataframe as input
    2023-04-06 09:50:45,199 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2023-04-06 09:50:45,199 - julearn - INFO - Target: species
    2023-04-06 09:50:45,199 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2023-04-06 09:50:45,199 - julearn - INFO - Expanded Confounds: []
    2023-04-06 09:50:45,200 - julearn - INFO - ====================
    2023-04-06 09:50:45,200 - julearn - INFO -
    2023-04-06 09:50:45,200 - julearn - INFO - ====== Model ======
    2023-04-06 09:50:45,200 - julearn - INFO - Obtaining model by name: svm
    2023-04-06 09:50:45,200 - julearn - INFO - ===================
    2023-04-06 09:50:45,200 - julearn - INFO -
    2023-04-06 09:50:45,201 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
    0.895
    0.8708886668886668
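
To see why balanced accuracy is more informative on unbalanced data, the
following minimal sketch computes both metrics with scikit-learn on small
hypothetical label vectors (they are made up and are not the iris results). It
assumes that the scorer names passed to ``run_cross_validation`` correspond to
these scikit-learn metric functions. It also computes the precision for the
``versicolor`` class, the metric discussed next.

.. code-block:: python

    from sklearn.metrics import (
        accuracy_score,
        balanced_accuracy_score,
        precision_score,
    )

    # Hypothetical labels: 8 "virginica" and 2 "versicolor" samples.
    y_true = ["virginica"] * 8 + ["versicolor"] * 2

    # A useless baseline that always predicts the majority class.
    y_majority = ["virginica"] * 10

    # A hypothetical model: one virginica is labelled versicolor (a false
    # positive for versicolor) and one versicolor is missed (a false negative).
    y_model = ["virginica"] * 7 + ["versicolor"] + ["virginica", "versicolor"]

    for name, y_pred in [("majority", y_majority), ("model", y_model)]:
        acc = accuracy_score(y_true, y_pred)
        bacc = balanced_accuracy_score(y_true, y_pred)
        print(f"{name}: accuracy={acc:.4f}, balanced_accuracy={bacc:.4f}")

    # Precision for "versicolor": tp / (tp + fp) = 1 / (1 + 1).
    prec = precision_score(y_true, y_model, pos_label="versicolor")
    print(f"model: precision(versicolor)={prec:.4f}")

    # Expected output:
    # majority: accuracy=0.8000, balanced_accuracy=0.5000
    # model: accuracy=0.8000, balanced_accuracy=0.6875
    # model: precision(versicolor)=0.5000

Both predictions reach the same accuracy, but balanced accuracy, the mean of
the per-class recalls (8/8 and 0/2 for the baseline, 7/8 and 1/2 for the
model), separates the majority-class baseline from a model that actually
detects some ``versicolor`` samples.
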

.. GENERATED FROM PYTHON SOURCE LINES 64-74

Other kinds of metrics allow us to evaluate how well our model detects
specific targets. Suppose we want to create a model that correctly identifies
the ``versicolor`` samples.

Now we might want to evaluate the precision score, the ratio of true positives
(tp) over all predicted positives, that is, true and false positives
(tp + fp). More information in scikit-learn: `Precision`_

For this metric to work, we need to define which label is our positive class.
In this example, we are interested in detecting ``versicolor``, so we pass it
as ``pos_labels``.

.. GENERATED FROM PYTHON SOURCE LINES 74-78

.. code-block:: default

    precision_scores = run_cross_validation(
        X=X, y=y, data=df_unbalanced, model='svm', preprocess_X='zscore',
        seed=42, scoring='precision', pos_labels='versicolor')
    print(precision_scores['test_score'].mean())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2023-04-06 09:50:45,992 - julearn - INFO - Setting random seed to 42
    2023-04-06 09:50:45,992 - julearn - INFO - Using default CV
    2023-04-06 09:50:45,992 - julearn - INFO - ==== Input Data ====
    2023-04-06 09:50:45,992 - julearn - INFO - Using dataframe as input
    2023-04-06 09:50:45,992 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
    2023-04-06 09:50:45,992 - julearn - INFO - Target: species
    2023-04-06 09:50:45,993 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
    2023-04-06 09:50:45,993 - julearn - INFO - Expanded Confounds: []
    2023-04-06 09:50:45,994 - julearn - INFO - Setting the following as positive labels ['versicolor']
    2023-04-06 09:50:45,994 - julearn - INFO - ====================
    2023-04-06 09:50:45,994 - julearn - INFO -
    2023-04-06 09:50:45,994 - julearn - INFO - ====== Model ======
    2023-04-06 09:50:45,995 - julearn - INFO - Obtaining model by name: svm
    2023-04-06 09:50:45,995 - julearn - INFO - ===================
    2023-04-06 09:50:45,995 - julearn - INFO -
    2023-04-06 09:50:45,995 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
    0.9223333333333333

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.935 seconds)


.. _sphx_glr_download_auto_examples_basic_run_simple_binary_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: run_simple_binary_classification.py <run_simple_binary_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: run_simple_binary_classification.ipynb <run_simple_binary_classification.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery