Simple Binary Classification

This example uses the ‘iris’ dataset and performs a simple binary classification using a Support Vector Machine classifier.

# Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
#
# License: AGPL
from seaborn import load_dataset
from julearn import run_cross_validation
from julearn.utils import configure_logging


Set the logging level to info to see extra information.

configure_logging(level='INFO')

Out:

2022-07-21 09:54:38,941 - julearn - INFO - ===== Lib Versions =====
2022-07-21 09:54:38,941 - julearn - INFO - numpy: 1.23.1
2022-07-21 09:54:38,941 - julearn - INFO - scipy: 1.8.1
2022-07-21 09:54:38,941 - julearn - INFO - sklearn: 1.0.2
2022-07-21 09:54:38,941 - julearn - INFO - pandas: 1.4.3
2022-07-21 09:54:38,941 - julearn - INFO - julearn: 0.2.5
2022-07-21 09:54:38,941 - julearn - INFO - ========================
df_iris = load_dataset('iris')

The dataset has three kinds of species. We will keep only two of them to perform a binary classification.

df_iris = df_iris[df_iris['species'].isin(['versicolor', 'virginica'])]

As features, we will use the sepal length, sepal width, and petal length. The target to predict is the species.

X = ['sepal_length', 'sepal_width', 'petal_length']
y = 'species'
scores = run_cross_validation(
    X=X, y=y, data=df_iris, model='svm', preprocess_X='zscore')

print(scores['test_score'])

Out:

2022-07-21 09:54:38,945 - julearn - INFO - Using default CV
2022-07-21 09:54:38,945 - julearn - INFO - ==== Input Data ====
2022-07-21 09:54:38,945 - julearn - INFO - Using dataframe as input
2022-07-21 09:54:38,945 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
2022-07-21 09:54:38,945 - julearn - INFO - Target: species
2022-07-21 09:54:38,945 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
2022-07-21 09:54:38,945 - julearn - INFO - Expanded Confounds: []
2022-07-21 09:54:38,946 - julearn - INFO - ====================
2022-07-21 09:54:38,946 - julearn - INFO -
2022-07-21 09:54:38,946 - julearn - INFO - ====== Model ======
2022-07-21 09:54:38,946 - julearn - INFO - Obtaining model by name: svm
2022-07-21 09:54:38,946 - julearn - INFO - ===================
2022-07-21 09:54:38,946 - julearn - INFO -
2022-07-21 09:54:38,946 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
0     0.90
1     0.95
2     0.90
3     0.80
4     1.00
5     1.00
6     0.95
7     0.90
8     0.90
9     0.80
10    0.90
11    1.00
12    0.95
13    0.80
14    1.00
15    0.90
16    0.95
17    0.95
18    0.95
19    0.90
20    0.95
21    0.95
22    0.80
23    0.95
24    0.95
Name: test_score, dtype: float64
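
As a reference, the call above can be expressed directly with scikit-learn. The following is only a rough sketch of what it does (z-scoring the features within each fold, then fitting a default SVM), not julearn's exact internals:

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.model_selection import RepeatedKFold, cross_val_score

# Roughly equivalent pipeline: scale inside each training fold, then fit an SVM.
sk_pipeline = make_pipeline(StandardScaler(), SVC())
cv = RepeatedKFold(n_splits=5, n_repeats=5)
sk_scores = cross_val_score(sk_pipeline, df_iris[X], df_iris[y], cv=cv)
print(sk_scores.mean())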

Additionally, we can choose to assess the performance of the model using different scoring functions.

For example, we might have an unbalanced dataset:

df_unbalanced = df_iris[20:]  # drop the first 20 versicolor samples
print(df_unbalanced['species'].value_counts())

Out:

virginica     50
versicolor    30
Name: species, dtype: int64

If we only compute the accuracy, we might not account for this imbalance. A more suitable metric is the balanced_accuracy, which averages the recall obtained on each class. More information in scikit-learn: Balanced Accuracy
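
To see why this matters, here is a small standalone sketch (using scikit-learn's metric functions directly, with made-up predictions) where a majority-class predictor looks good under plain accuracy but is exposed by balanced accuracy:

from sklearn.metrics import accuracy_score, balanced_accuracy_score

y_true = ['virginica'] * 8 + ['versicolor'] * 2
y_pred = ['virginica'] * 10  # always predict the majority class
print(accuracy_score(y_true, y_pred))           # 0.8, looks decent
print(balanced_accuracy_score(y_true, y_pred))  # 0.5, chance level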

We will also set the random seed so we always split the data in the same way.

scores = run_cross_validation(
    X=X, y=y, data=df_unbalanced, model='svm', seed=42, preprocess_X='zscore',
    scoring=['accuracy', 'balanced_accuracy'])

print(scores['test_accuracy'].mean())
print(scores['test_balanced_accuracy'].mean())

Out:

2022-07-21 09:54:39,312 - julearn - INFO - Setting random seed to 42
2022-07-21 09:54:39,312 - julearn - INFO - Using default CV
2022-07-21 09:54:39,312 - julearn - INFO - ==== Input Data ====
2022-07-21 09:54:39,312 - julearn - INFO - Using dataframe as input
2022-07-21 09:54:39,312 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
2022-07-21 09:54:39,312 - julearn - INFO - Target: species
2022-07-21 09:54:39,312 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
2022-07-21 09:54:39,312 - julearn - INFO - Expanded Confounds: []
2022-07-21 09:54:39,313 - julearn - INFO - ====================
2022-07-21 09:54:39,313 - julearn - INFO -
2022-07-21 09:54:39,313 - julearn - INFO - ====== Model ======
2022-07-21 09:54:39,313 - julearn - INFO - Obtaining model by name: svm
2022-07-21 09:54:39,313 - julearn - INFO - ===================
2022-07-21 09:54:39,313 - julearn - INFO -
2022-07-21 09:54:39,313 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
0.895
0.8708886668886668

Other kinds of metrics allow us to evaluate how well our model detects specific targets. Suppose we want to create a model that correctly identifies the versicolor samples.

Now we might want to evaluate the precision score, i.e. the ratio of true positives (tp) over all predicted positives (true plus false positives). More information in scikit-learn: Precision

For this metric to work, we need to define which label is considered positive. In this example, we are interested in detecting versicolor.
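
The following small sketch (again using scikit-learn's metric functions with made-up predictions) illustrates how the choice of positive label drives the precision computation:

from sklearn.metrics import precision_score

y_true = ['versicolor', 'versicolor', 'virginica', 'virginica']
y_pred = ['versicolor', 'virginica', 'versicolor', 'virginica']
# One true positive and one false positive for 'versicolor': 1 / (1 + 1) = 0.5
print(precision_score(y_true, y_pred, pos_label='versicolor'))

Back to julearn, the same metric can be requested through run_cross_validation by passing the positive label explicitly: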

precision_scores = run_cross_validation(
    X=X, y=y, data=df_unbalanced, model='svm', preprocess_X='zscore', seed=42,
    scoring='precision', pos_labels='versicolor')
print(precision_scores['test_score'].mean())

Out:

2022-07-21 09:54:39,815 - julearn - INFO - Setting random seed to 42
2022-07-21 09:54:39,815 - julearn - INFO - Using default CV
2022-07-21 09:54:39,815 - julearn - INFO - ==== Input Data ====
2022-07-21 09:54:39,816 - julearn - INFO - Using dataframe as input
2022-07-21 09:54:39,816 - julearn - INFO - Features: ['sepal_length', 'sepal_width', 'petal_length']
2022-07-21 09:54:39,816 - julearn - INFO - Target: species
2022-07-21 09:54:39,816 - julearn - INFO - Expanded X: ['sepal_length', 'sepal_width', 'petal_length']
2022-07-21 09:54:39,816 - julearn - INFO - Expanded Confounds: []
2022-07-21 09:54:39,816 - julearn - INFO - Setting the following as positive labels ['versicolor']
2022-07-21 09:54:39,817 - julearn - INFO - ====================
2022-07-21 09:54:39,817 - julearn - INFO -
2022-07-21 09:54:39,817 - julearn - INFO - ====== Model ======
2022-07-21 09:54:39,817 - julearn - INFO - Obtaining model by name: svm
2022-07-21 09:54:39,817 - julearn - INFO - ===================
2022-07-21 09:54:39,817 - julearn - INFO -
2022-07-21 09:54:39,817 - julearn - INFO - CV interpreted as RepeatedKFold with 5 repetitions of 5 folds
0.9223333333333333

Total running time of the script: ( 0 minutes 1.266 seconds)
