Note
Click here to download the full example code
Rank Selection
Authors: Federico Raimondo, Kaustubh Patil
License: BSD 3-Clause
from opnmf.selection import rank_permute
from opnmf.logging import configure_logging
import matplotlib.pyplot as plt
import seaborn as sns
Set up logging.
# Send opnmf's log records to the console at INFO verbosity; this is what
# produces the library-version banner and the per-iteration progress shown
# in the output below.
log_level = 'INFO'
configure_logging(log_level)
Out:
2021-11-09 12:44:33,906 - opnmf - INFO - ===== Lib Versions =====
2021-11-09 12:44:33,906 - opnmf - INFO - numpy: 1.19.5
2021-11-09 12:44:33,906 - opnmf - INFO - scipy: 1.7.2
2021-11-09 12:44:33,906 - opnmf - INFO - sklearn: 0.24.2
2021-11-09 12:44:33,906 - opnmf - INFO - opnmf: 0.0.3.dev1+g88d7273.d20211109
2021-11-09 12:44:33,906 - opnmf - INFO - ========================
Load the IRIS dataset.
Find the rank. In this example, the rank is bounded by the number of features (4).
# Search every rank from 1 up to the number of features (4 for IRIS).
# NOTE(review): X is never defined in this example's visible code. The
# "Load IRIS dataset" step appears to have lost its code (e.g.
# `X = datasets.load_iris().data` with `from sklearn import datasets`).
# Confirm that X is the IRIS feature matrix before running.
min_components, max_components = 1, 4
(good_ranks, tested_ranks, errors, random_errors,
 estimators) = rank_permute(X, min_components, max_components)
Out:
2021-11-09 12:44:34,448 - opnmf - INFO - Choosing ranks between: [1 2 3 4]
2021-11-09 12:44:34,448 - opnmf - INFO - Fitting estimators with random permutations
2021-11-09 12:44:34,448 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,450 - opnmf - INFO - iter=0 diff=0.8927946343839875, obj=44.36921300841488
2021-11-09 12:44:34,451 - opnmf - INFO - Converged in 1 iterations
2021-11-09 12:44:34,451 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,452 - opnmf - INFO - iter=0 diff=0.8954233634414889, obj=42.46008641379486
2021-11-09 12:44:34,484 - opnmf - INFO - iter=100 diff=0.003230190158376143, obj=35.50271365161412
2021-11-09 12:44:34,523 - opnmf - INFO - iter=200 diff=0.0015813875310213721, obj=35.114281740365584
2021-11-09 12:44:34,561 - opnmf - INFO - iter=300 diff=0.0001971828170329425, obj=34.99537206006796
2021-11-09 12:44:34,583 - opnmf - INFO - iter=400 diff=3.447297178623933e-05, obj=34.98302190035095
2021-11-09 12:44:34,600 - opnmf - INFO - Converged in 475 iterations
2021-11-09 12:44:34,600 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,603 - opnmf - INFO - iter=0 diff=0.8972933734621308, obj=40.95256384555148
2021-11-09 12:44:34,623 - opnmf - INFO - iter=100 diff=0.0007537736034581833, obj=25.00032880051592
2021-11-09 12:44:34,643 - opnmf - INFO - iter=200 diff=0.00028154829833711733, obj=25.136553554295514
2021-11-09 12:44:34,664 - opnmf - INFO - iter=300 diff=0.00011940889734639332, obj=25.203972701557213
2021-11-09 12:44:34,684 - opnmf - INFO - iter=400 diff=5.965971048615933e-05, obj=25.23532574982238
2021-11-09 12:44:34,704 - opnmf - INFO - iter=500 diff=3.310902773006408e-05, obj=25.251857601095065
2021-11-09 12:44:34,723 - opnmf - INFO - iter=600 diff=1.982751307613487e-05, obj=25.261415948935383
2021-11-09 12:44:34,741 - opnmf - INFO - iter=700 diff=1.2637275233539403e-05, obj=25.267351122564527
2021-11-09 12:44:34,751 - opnmf - INFO - Converged in 758 iterations
2021-11-09 12:44:34,751 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,753 - opnmf - INFO - iter=0 diff=0.8998379823515824, obj=40.1460266459101
2021-11-09 12:44:34,774 - opnmf - INFO - iter=100 diff=0.0002571011848116155, obj=1.9133110087811631
2021-11-09 12:44:34,795 - opnmf - INFO - iter=200 diff=5.474576895304248e-05, obj=0.8848600987713051
2021-11-09 12:44:34,817 - opnmf - INFO - iter=300 diff=2.2889453491382444e-05, obj=0.5720309643423076
2021-11-09 12:44:34,837 - opnmf - INFO - iter=400 diff=1.2458072789557783e-05, obj=0.4217661096072306
2021-11-09 12:44:34,846 - opnmf - INFO - Converged in 445 iterations
2021-11-09 12:44:34,846 - opnmf - INFO - Fitting estimators with original data
2021-11-09 12:44:34,846 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,848 - opnmf - INFO - iter=0 diff=0.8979166119932412, obj=18.19299122423655
2021-11-09 12:44:34,849 - opnmf - INFO - Converged in 1 iterations
2021-11-09 12:44:34,849 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,851 - opnmf - INFO - iter=0 diff=0.9005345540236736, obj=17.592898608366433
2021-11-09 12:44:34,881 - opnmf - INFO - iter=100 diff=0.004782334046616283, obj=12.374361231854794
2021-11-09 12:44:34,903 - opnmf - INFO - iter=200 diff=0.0017090250985303454, obj=9.236112536836712
2021-11-09 12:44:34,913 - opnmf - INFO - iter=300 diff=0.0004670349054987102, obj=8.274114227196149
2021-11-09 12:44:34,923 - opnmf - INFO - iter=400 diff=0.00015084168501738635, obj=8.021475182985375
2021-11-09 12:44:34,933 - opnmf - INFO - iter=500 diff=5.3279572748436734e-05, obj=7.941750994725744
2021-11-09 12:44:34,943 - opnmf - INFO - iter=600 diff=1.9508155278104865e-05, obj=7.913889310596234
2021-11-09 12:44:34,949 - opnmf - INFO - Converged in 668 iterations
2021-11-09 12:44:34,950 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:34,951 - opnmf - INFO - iter=0 diff=0.9008854436852691, obj=17.589994890219803
2021-11-09 12:44:34,962 - opnmf - INFO - iter=100 diff=0.006490351920417269, obj=12.281892360977254
2021-11-09 12:44:34,972 - opnmf - INFO - iter=200 diff=0.0029083573435830685, obj=7.428428494714522
2021-11-09 12:44:34,982 - opnmf - INFO - iter=300 diff=0.0008369396656010743, obj=4.692345362790703
2021-11-09 12:44:34,993 - opnmf - INFO - iter=400 diff=0.00033340595841189986, obj=3.746048776589669
2021-11-09 12:44:35,003 - opnmf - INFO - iter=500 diff=0.00016958404516681999, obj=3.3866556678528785
2021-11-09 12:44:35,014 - opnmf - INFO - iter=600 diff=0.00010053594640454818, obj=3.224643627824671
2021-11-09 12:44:35,024 - opnmf - INFO - iter=700 diff=6.58747456302156e-05, obj=3.140665500006886
2021-11-09 12:44:35,034 - opnmf - INFO - iter=800 diff=4.6260574718618346e-05, obj=3.092338617378829
2021-11-09 12:44:35,044 - opnmf - INFO - iter=900 diff=3.416646735891813e-05, obj=3.062256256463681
2021-11-09 12:44:35,057 - opnmf - INFO - iter=1000 diff=2.6215926734249445e-05, obj=3.0423680638944566
2021-11-09 12:44:35,068 - opnmf - INFO - iter=1100 diff=2.0724385737918712e-05, obj=3.0285824848281906
2021-11-09 12:44:35,078 - opnmf - INFO - iter=1200 diff=1.6779513269325736e-05, obj=3.0186571197072434
2021-11-09 12:44:35,088 - opnmf - INFO - iter=1300 diff=1.3853983874214158e-05, obj=3.011285535705444
2021-11-09 12:44:35,097 - opnmf - INFO - iter=1400 diff=1.1626484504865934e-05, obj=3.0056673153828557
2021-11-09 12:44:35,106 - opnmf - INFO - Converged in 1494 iterations
2021-11-09 12:44:35,107 - opnmf - INFO - Initializing using nndsvd
2021-11-09 12:44:35,109 - opnmf - INFO - iter=0 diff=0.9009726594607051, obj=17.579567022097375
2021-11-09 12:44:35,118 - opnmf - INFO - iter=100 diff=0.004736622724118709, obj=11.56302516543663
2021-11-09 12:44:35,128 - opnmf - INFO - iter=200 diff=0.002601350330340558, obj=7.44475231384076
2021-11-09 12:44:35,139 - opnmf - INFO - iter=300 diff=0.0008629189126871222, obj=4.139826452234598
2021-11-09 12:44:35,149 - opnmf - INFO - iter=400 diff=0.00034684015663308737, obj=2.5709506761378043
2021-11-09 12:44:35,159 - opnmf - INFO - iter=500 diff=0.00017414127196746965, obj=1.8021428845057668
2021-11-09 12:44:35,171 - opnmf - INFO - iter=600 diff=0.0001019585812450138, obj=1.3696389513039646
2021-11-09 12:44:35,182 - opnmf - INFO - iter=700 diff=6.617752710378401e-05, obj=1.098089889540354
2021-11-09 12:44:35,192 - opnmf - INFO - iter=800 diff=4.615562611253175e-05, obj=0.9135682660008669
2021-11-09 12:44:35,203 - opnmf - INFO - iter=900 diff=3.3922841629715706e-05, obj=0.7807271391302854
2021-11-09 12:44:35,213 - opnmf - INFO - iter=1000 diff=2.594008726205868e-05, obj=0.6808413393286783
2021-11-09 12:44:35,224 - opnmf - INFO - iter=1100 diff=2.0458660499052915e-05, obj=0.6031603926541143
2021-11-09 12:44:35,234 - opnmf - INFO - iter=1200 diff=1.65395753849674e-05, obj=0.5411070811307666
2021-11-09 12:44:35,245 - opnmf - INFO - iter=1300 diff=1.3644125708837993e-05, obj=0.49044638273710617
2021-11-09 12:44:35,257 - opnmf - INFO - iter=1400 diff=1.1446172057175162e-05, obj=0.44833495441646
2021-11-09 12:44:35,268 - opnmf - INFO - Converged in 1484 iterations
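Plot the results: the error for each tested rank on the original and on the permuted data, with the selected ranks highlighted.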
# Compare the error obtained on the permuted data with the error on the
# original data for every tested rank, then mark the ranks selected by
# rank_permute in red.
fig, ax = plt.subplots()
ax.set_title('Rank selection on IRIS dataset')
for curve_label, curve in (('permuted', random_errors), ('original', errors)):
    ax.plot(tested_ranks, curve, label=curve_label)
# Assumes tested_ranks runs contiguously from min_components, so subtracting
# min_components turns a rank into its index in `errors` (and assumes
# good_ranks supports elementwise subtraction, e.g. a numpy array) -- confirm.
good_errors = errors[good_ranks - min_components]
ax.plot(good_ranks, good_errors, label='selected', marker='o', c='r',
        ls='None')
ax.set_xticks(tested_ranks)
ax.set_xlabel('# Components')
ax.set_ylabel('Error')
ax.legend()
plt.show()
Total running time of the script: (0 minutes 1.882 seconds)