import numpy as np
import pandas as pd
import seaborn as sns
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.datasets import make_classification
Т.к. в задании сказано измерять AUC, речь должна идти о бинарной классификации. Т.к. мы хотим перебирать число признаков как параметр, создаём данные с 10 признаками, из которых информативными являются только 4.
# AUC is the metric we report, so the task must be binary: n_classes=2
# (the roc_auc grid search below fails on a multiclass target).
# 10 features, only 4 of them informative, so the number of selected
# features can later be tuned as a hyper-parameter.
X, y = make_classification(n_samples=2000, n_features=10, n_informative=4,
                           n_classes=2, random_state=42)
# Features followed by the target as the last column. reshape(-1, 1) avoids
# hard-coding the sample count. NOTE: the first feature is labelled 'x' and
# the other nine '0'..'8'; the labels are kept as-is for compatibility with
# the rest of the notebook.
data = pd.DataFrame(np.concatenate((X, y.reshape(-1, 1)), axis=1),
                    columns=['x', '0', '1', '2', '3', '4', '5', '6', '7', '8', 'y'],
                    dtype=float)
data.head()
| | x | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | y |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | -1.765310 | -1.024146 | -0.080752 | -1.371603 | -1.693012 | 1.203602 | 0.485726 | -0.805957 | -2.294698 | -0.872444 | 0.0 |
1 | -0.858258 | 0.148367 | -0.676983 | -2.565923 | -2.296268 | -0.160094 | 1.177521 | -0.901972 | -0.488819 | -2.252863 | 0.0 |
2 | -1.070693 | -1.402868 | -1.964698 | -2.615374 | -1.870387 | -0.723574 | -0.190782 | -3.305596 | 1.029794 | -2.143583 | 0.0 |
3 | 0.330069 | 2.002894 | -1.453344 | 0.084383 | 0.248392 | -0.676807 | -1.160034 | 0.004808 | 0.935594 | -0.138064 | 1.0 |
4 | 1.427064 | 0.000755 | 0.745855 | 0.588909 | 1.038847 | -0.383004 | -1.143316 | -0.207470 | 1.574944 | 0.663299 | 1.0 |
# Scatter-matrix of every pair of columns, points coloured by the target label.
sns.pairplot(data=data, hue='y')
<seaborn.axisgrid.PairGrid at 0x1a205764e0>
# Split the frame by target label (kept as named variables for later cells).
class2 = data[data.y == 2]
class1 = data[data.y == 1]
class0 = data[data.y == 0]

# One histogram per feature column for every class that has samples.
# Feature columns are derived from the frame instead of a hard-coded [:10].
feature_cols = data.columns.drop('y')
for class_name, subset in (('class0', class0), ('class1', class1), ('class2', class2)):
    if subset.empty:
        # e.g. class2 has no rows when the data set is binary
        continue
    for col in feature_cols:
        plt.hist(subset[col])
        plt.title(class_name)
        plt.xlabel(col)
        plt.show()
# Pairwise correlation matrix of all columns (features and target).
corr = data.corr()
# Correlations lie in [-1, 1]: pin the colour range and centre a diverging
# palette at 0 so positive and negative correlations are equally visible
# (otherwise the scale is autoscaled and dominated by the diagonal of 1s).
sns.heatmap(corr,
            xticklabels=corr.columns.values,
            yticklabels=corr.columns.values,
            vmin=-1, vmax=1, center=0, cmap='coolwarm');
Сильно коррелируют признаки 3 и 8. В целом же независимо сгенерированные случайные признаки слабо коррелируют друг с другом.
Здесь мы будем использовать Pipeline, который сначала масштабирует данные, а потом применяет нужный классификатор к уже отмасштабированным данным.
Метод ближайших соседей
# Standardize the features, then classify with k-nearest neighbours.
# Steps are (name, estimator) tuples, as documented for Pipeline.
knn = Pipeline([('scaler', StandardScaler()),
                ('knn', KNeighborsClassifier())])
Логистическая регрессия
# Standardize the features, then classify with logistic regression.
# Steps are (name, estimator) tuples, as documented for Pipeline.
lr = Pipeline([('scaler', StandardScaler()),
               ('logreg', LogisticRegression())])
SVM
# Standardize the features, then classify with a support vector machine.
# Steps are (name, estimator) tuples, as documented for Pipeline.
svm = Pipeline([('scaler', StandardScaler()),
                ('svm', SVC())])
ROC_AUC
from sklearn.metrics import roc_auc_score
F1-score
from sklearn.metrics import f1_score
Разбиваем выборку на тренировочную и валидационную части, оставляя на валидацию треть.
from sklearn.model_selection import train_test_split
# Hold out a stratified third of the samples for validation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    stratify=y,
    test_size=0.33,
    random_state=42,
)
Для метода ближайших соседей подберём число ближайших соседей
# Tune the number of neighbours (5..49) by ROC AUC, using a shuffled,
# stratified 3-fold split of the training part.
# NOTE(review): the optimum reported below (48) sits at the upper edge of
# this grid; consider widening the range.
knn_params = {'knn__n_neighbors': np.arange(5, 50)}
knn_gs = GridSearchCV(
    estimator=knn,
    param_grid=knn_params,
    scoring='roc_auc',
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42),
)
knn_gs.fit(X_train, y_train)
GridSearchCV(cv=StratifiedKFold(n_splits=3, random_state=42, shuffle=True), error_score='raise', estimator=Pipeline(memory=None, steps=[['scaler', StandardScaler(copy=True, with_mean=True, with_std=True)], ['knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=5, p=2, weights='uniform')]]), fit_params=None, iid=True, n_jobs=1, param_grid={'knn__n_neighbors': array([ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])}, pre_dispatch='2*n_jobs', refit=True, return_train_score='warn', scoring='roc_auc', verbose=0)
# Hyper-parameter combination with the highest mean cross-validated ROC AUC.
knn_gs.best_params_
{'knn__n_neighbors': 48}
# Compact, ranked view of the grid search: one row per n_neighbors candidate.
# Selecting only the test-score keys avoids touching the deprecated
# train-score entries (the source of the FutureWarnings in the raw dump)
# and is far easier to read than the raw dict of arrays.
knn_cv_columns = ['param_knn__n_neighbors', 'mean_test_score',
                  'std_test_score', 'rank_test_score']
pd.DataFrame({key: knn_gs.cv_results_[key] for key in knn_cv_columns},
             columns=knn_cv_columns).sort_values('rank_test_score')
/anaconda3/lib/python3.6/site-packages/sklearn/utils/deprecation.py:122: FutureWarning: You are accessing a training score ('mean_train_score'), which will not be available by default any more in 0.21. If you need training scores, please set return_train_score=True warnings.warn(*warn_args, **warn_kwargs) /anaconda3/lib/python3.6/site-packages/sklearn/utils/deprecation.py:122: FutureWarning: You are accessing a training score ('split0_train_score'), which will not be available by default any more in 0.21. If you need training scores, please set return_train_score=True warnings.warn(*warn_args, **warn_kwargs) /anaconda3/lib/python3.6/site-packages/sklearn/utils/deprecation.py:122: FutureWarning: You are accessing a training score ('split1_train_score'), which will not be available by default any more in 0.21. If you need training scores, please set return_train_score=True warnings.warn(*warn_args, **warn_kwargs) /anaconda3/lib/python3.6/site-packages/sklearn/utils/deprecation.py:122: FutureWarning: You are accessing a training score ('split2_train_score'), which will not be available by default any more in 0.21. If you need training scores, please set return_train_score=True warnings.warn(*warn_args, **warn_kwargs) /anaconda3/lib/python3.6/site-packages/sklearn/utils/deprecation.py:122: FutureWarning: You are accessing a training score ('std_train_score'), which will not be available by default any more in 0.21. If you need training scores, please set return_train_score=True warnings.warn(*warn_args, **warn_kwargs)
{'mean_fit_time': array([0.00115665, 0.00097028, 0.00097569, 0.00096703, 0.00094994, 0.00096329, 0.00097068, 0.00097203, 0.00095789, 0.00096544, 0.00095566, 0.00107876, 0.00097434, 0.00098109, 0.00097569, 0.00097307, 0.00099095, 0.00097529, 0.00097966, 0.00098872, 0.00099031, 0.0009807 , 0.00098403, 0.00098165, 0.00101471, 0.00100501, 0.00098554, 0.00098268, 0.00099071, 0.00098848, 0.00099651, 0.00099476, 0.00098729, 0.00100629, 0.00099277, 0.00099166, 0.00098936, 0.00104698, 0.00099206, 0.00099699, 0.00099174, 0.00098944, 0.00099134, 0.00099564, 0.00100334]), 'mean_score_time': array([0.01225479, 0.01105698, 0.01113828, 0.01116832, 0.01126083, 0.01140006, 0.01147779, 0.01159732, 0.0116117 , 0.01182294, 0.01190972, 0.01216594, 0.01197378, 0.01220926, 0.01219622, 0.0124704 , 0.01240126, 0.01247764, 0.01255433, 0.01276159, 0.01280602, 0.01283026, 0.0128266 , 0.0129927 , 0.01319464, 0.01354504, 0.01324463, 0.01389035, 0.01349274, 0.01360059, 0.0136377 , 0.01383964, 0.01380237, 0.01398643, 0.01391912, 0.0140907 , 0.01419202, 0.01433229, 0.014292 , 0.01444356, 0.01444626, 0.01451818, 0.0146807 , 0.01481994, 0.01481231]), 'mean_test_score': array([0.78717427, 0.79790753, 0.80784872, 0.81614059, 0.81983185, 0.82291085, 0.82367182, 0.82677032, 0.82847411, 0.83028947, 0.83063764, 0.83209637, 0.83605537, 0.83757122, 0.83902066, 0.83793942, 0.84081019, 0.84113348, 0.84541762, 0.84529759, 0.84700625, 0.84741373, 0.84882811, 0.84817721, 0.84799437, 0.84934513, 0.84840144, 0.8483664 , 0.84708814, 0.84732238, 0.84826568, 0.84892318, 0.84937754, 0.85071211, 0.84945926, 0.84927628, 0.85203584, 0.851488 , 0.85238355, 0.85214803, 0.85345139, 0.85476754, 0.85560436, 0.85728558, 0.85487683]), 'mean_train_score': array([0.9095229 , 0.90626379, 0.90426379, 0.90192094, 0.90051158, 0.90043894, 0.89722272, 0.89701481, 0.89606581, 0.89668302, 0.89457529, 0.89453688, 0.89458858, 0.89474148, 0.89456306, 0.89373978, 0.89416187, 0.89300812, 0.89307392, 0.89336571, 0.89039311, 0.89059043, 
0.89134234, 0.89049545, 0.89051628, 0.88933423, 0.88880824, 0.88798793, 0.8868775 , 0.88746096, 0.88721824, 0.88792791, 0.88809175, 0.888937 , 0.88784656, 0.88624554, 0.88630341, 0.88512573, 0.88402478, 0.88480251, 0.88385804, 0.88480886, 0.88465828, 0.88416202, 0.88418234]), 'param_knn__n_neighbors': masked_array(data=[5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], mask=[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], fill_value='?', dtype=object), 'params': [{'knn__n_neighbors': 5}, {'knn__n_neighbors': 6}, {'knn__n_neighbors': 7}, {'knn__n_neighbors': 8}, {'knn__n_neighbors': 9}, {'knn__n_neighbors': 10}, {'knn__n_neighbors': 11}, {'knn__n_neighbors': 12}, {'knn__n_neighbors': 13}, {'knn__n_neighbors': 14}, {'knn__n_neighbors': 15}, {'knn__n_neighbors': 16}, {'knn__n_neighbors': 17}, {'knn__n_neighbors': 18}, {'knn__n_neighbors': 19}, {'knn__n_neighbors': 20}, {'knn__n_neighbors': 21}, {'knn__n_neighbors': 22}, {'knn__n_neighbors': 23}, {'knn__n_neighbors': 24}, {'knn__n_neighbors': 25}, {'knn__n_neighbors': 26}, {'knn__n_neighbors': 27}, {'knn__n_neighbors': 28}, {'knn__n_neighbors': 29}, {'knn__n_neighbors': 30}, {'knn__n_neighbors': 31}, {'knn__n_neighbors': 32}, {'knn__n_neighbors': 33}, {'knn__n_neighbors': 34}, {'knn__n_neighbors': 35}, {'knn__n_neighbors': 36}, {'knn__n_neighbors': 37}, {'knn__n_neighbors': 38}, {'knn__n_neighbors': 39}, {'knn__n_neighbors': 40}, {'knn__n_neighbors': 41}, {'knn__n_neighbors': 42}, {'knn__n_neighbors': 43}, {'knn__n_neighbors': 44}, {'knn__n_neighbors': 45}, {'knn__n_neighbors': 46}, {'knn__n_neighbors': 47}, {'knn__n_neighbors': 
48}, {'knn__n_neighbors': 49}], 'rank_test_score': array([45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 30, 31, 29, 28, 26, 27, 25, 22, 16, 20, 21, 13, 17, 18, 24, 23, 19, 15, 12, 10, 11, 14, 8, 9, 6, 7, 5, 4, 2, 1, 3], dtype=int32), 'split0_test_score': array([0.77304212, 0.78662516, 0.80402186, 0.81520259, 0.81934657, 0.82446148, 0.82769459, 0.82413117, 0.82269979, 0.82594291, 0.8267737 , 0.83206879, 0.83850496, 0.83657311, 0.8390655 , 0.83704356, 0.83790439, 0.83892537, 0.84422045, 0.84246877, 0.84665279, 0.84703315, 0.84913517, 0.84483104, 0.84147782, 0.84293922, 0.84132767, 0.84061699, 0.83743394, 0.83825472, 0.84151786, 0.84143778, 0.84179813, 0.84359986, 0.84339966, 0.84480101, 0.84858464, 0.85019619, 0.85430013, 0.85522101, 0.85840407, 0.86008568, 0.85983544, 0.85939502, 0.85662236]), 'split0_train_score': array([0.91904846, 0.91770666, 0.91739567, 0.91105036, 0.91083216, 0.90783256, 0.904532 , 0.9037821 , 0.90072231, 0.9030949 , 0.90148224, 0.90114617, 0.90241523, 0.89892406, 0.90083517, 0.89988965, 0.90243028, 0.90130668, 0.9021243 , 0.90134932, 0.89964888, 0.90027087, 0.90007775, 0.89710072, 0.89832715, 0.89866573, 0.89835223, 0.89880618, 0.89710574, 0.89986206, 0.90085022, 0.9021243 , 0.90308738, 0.90466493, 0.9023751 , 0.89954354, 0.90060945, 0.89791332, 0.89682735, 0.89619031, 0.89410112, 0.8961251 , 0.89780046, 0.89682735, 0.89688252]), 'split1_test_score': array([0.80165959, 0.80919683, 0.81594331, 0.81930653, 0.82405109, 0.82761451, 0.82710402, 0.83566224, 0.83212884, 0.83130806, 0.83030709, 0.82995676, 0.83238909, 0.83451113, 0.83595251, 0.83838485, 0.8471833 , 0.847914 , 0.85524103, 0.85640215, 0.85411996, 0.85481062, 0.85432015, 0.85639214, 0.8581238 , 0.85832399, 0.85544122, 0.85824391, 0.85680253, 0.85601177, 0.85776345, 0.86019579, 0.86086643, 0.86444987, 0.86184737, 0.85900464, 0.86146701, 0.85853419, 0.86064622, 0.86139694, 0.86059617, 0.86210762, 0.86211763, 0.86535074, 0.86264814]), 'split1_train_score': array([0.89871338, 
0.89650883, 0.89268911, 0.89488363, 0.89442215, 0.89286216, 0.88750251, 0.887969 , 0.88912771, 0.88891703, 0.88623596, 0.88428471, 0.88299308, 0.88460072, 0.88391352, 0.88152337, 0.88235102, 0.88440008, 0.88345706, 0.88402388, 0.88048756, 0.87936146, 0.8811572 , 0.88269211, 0.88253913, 0.87986055, 0.88064306, 0.87860403, 0.87669041, 0.87557434, 0.87238162, 0.87197783, 0.87036768, 0.8724995 , 0.8733748 , 0.87243178, 0.87402689, 0.87352779, 0.87028993, 0.8724694 , 0.87201294, 0.87433287, 0.87319422, 0.87311146, 0.87346509]), 'split2_test_score': array([0.7868203 , 0.79790058, 0.80357143, 0.81390766, 0.81608953, 0.81664254, 0.81620013, 0.82050354, 0.83059846, 0.83362492, 0.83484154, 0.83426842, 0.83727477, 0.84163851, 0.84205076, 0.83839085, 0.8373351 , 0.83655084, 0.83677204, 0.8370033 , 0.84023086, 0.84038168, 0.84301601, 0.84329754, 0.84437339, 0.84676641, 0.84843549, 0.84623351, 0.84702783, 0.8477015 , 0.84550957, 0.84512749, 0.8454593 , 0.84407175, 0.84311655, 0.84401142, 0.84604247, 0.84572072, 0.84218147, 0.8397985 , 0.84132682, 0.84208092, 0.84483591, 0.84708816, 0.84533864]), 'split2_train_score': array([0.91080687, 0.90457589, 0.9027066 , 0.89982884, 0.89628043, 0.9006221 , 0.89963365, 0.89929332, 0.89834741, 0.89803712, 0.89600767, 0.89817975, 0.89835742, 0.90069967, 0.89894048, 0.89980631, 0.8977043 , 0.89331758, 0.89364039, 0.89472393, 0.8910429 , 0.89213895, 0.89279208, 0.89169353, 0.89068256, 0.8894764 , 0.88742943, 0.88655359, 0.88683636, 0.88694647, 0.88842289, 0.88968159, 0.89082019, 0.88964656, 0.88778978, 0.88676129, 0.8842739 , 0.88393608, 0.88495706, 0.88574782, 0.88546004, 0.88396861, 0.88298016, 0.88254725, 0.88219941]), 'std_fit_time': array([2.76031220e-04, 3.17095714e-06, 1.61014446e-05, 7.84248552e-06, 4.76969595e-06, 6.67950356e-06, 9.38520301e-06, 1.46000966e-05, 1.14380193e-05, 2.12191047e-05, 1.08758834e-05, 1.34637111e-04, 3.80640984e-06, 1.32588743e-05, 9.90003465e-06, 1.54480194e-05, 7.72400971e-06, 3.89984263e-06, 4.53622284e-06, 
7.42805573e-06, 8.32608595e-06, 2.06323544e-06, 2.33871761e-06, 3.01787655e-06, 2.08076278e-05, 1.91887051e-05, 9.16043220e-06, 1.42976070e-05, 4.10344780e-06, 8.48537942e-07, 1.22969966e-05, 1.06363074e-05, 1.70198138e-05, 6.33892587e-06, 5.76493298e-06, 1.10504432e-05, 4.47877337e-06, 3.90163990e-05, 7.04041627e-06, 2.88957897e-06, 8.03894591e-06, 2.14744432e-05, 5.40882731e-06, 2.54561383e-06, 9.23459583e-06]), 'std_score_time': array([1.50273262e-03, 4.13779701e-05, 3.87664667e-05, 2.08706685e-05, 6.72194771e-05, 3.93702804e-05, 5.07087914e-05, 7.43882099e-05, 3.15151727e-05, 9.95830767e-05, 1.03505704e-04, 4.01846262e-04, 3.80994249e-05, 3.29042593e-05, 1.10504432e-05, 2.66066612e-04, 2.00511143e-05, 4.81538642e-05, 2.07386099e-05, 7.64802308e-05, 8.93884922e-05, 6.72461564e-05, 1.01849361e-05, 7.98694728e-05, 1.12053407e-04, 2.60362331e-04, 1.72926076e-05, 6.32017399e-04, 1.00564481e-04, 1.07140085e-04, 3.56302632e-05, 1.27712020e-04, 2.44479495e-05, 9.59312804e-05, 9.56055529e-06, 5.15463443e-05, 9.97529076e-05, 1.58319392e-04, 6.60254473e-05, 1.59007656e-04, 7.55274034e-05, 1.06523282e-05, 3.96020524e-05, 4.21501452e-05, 6.69245477e-05]), 'std_test_score': array([0.01169007, 0.00921828, 0.00572989, 0.00230143, 0.00326777, 0.00461087, 0.00528288, 0.00646296, 0.00413283, 0.00321733, 0.00330136, 0.00176001, 0.00264205, 0.00299368, 0.00248933, 0.00063382, 0.00451497, 0.00489416, 0.00758596, 0.00816709, 0.00567465, 0.00589564, 0.00461914, 0.00584569, 0.0072634 , 0.00654182, 0.00576403, 0.00735479, 0.00791026, 0.00725693, 0.0069144 , 0.00811635, 0.00826465, 0.00972141, 0.00876538, 0.00689038, 0.00675277, 0.00530936, 0.0076578 , 0.00908001, 0.00861044, 0.00899872, 0.00766282, 0.0076022 , 0.00717237]), 'std_train_score': array([0.00835126, 0.00873589, 0.01014634, 0.0067638 , 0.00733708, 0.00611301, 0.00715823, 0.00665368, 0.00500087, 0.00586675, 0.00630614, 0.00734983, 0.00836493, 0.00720715, 0.00756998, 0.00863837, 0.00857149, 0.00690556, 0.00763139, 0.00713799, 
0.00783606, 0.00860617, 0.00779201, 0.00594298, 0.00644651, 0.00767784, 0.00729518, 0.00830962, 0.00833457, 0.00992209, 0.01165343, 0.01236956, 0.01349637, 0.01314106, 0.01183939, 0.01107434, 0.01094676, 0.00999083, 0.01085389, 0.00970706, 0.00908834, 0.00891646, 0.0101153 , 0.00974907, 0.0096624 ])}
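На кросс-валидации качество очень даже неплохое. Оказалось, что лучшим значением n_neighbors в данном случае является 48. Заметим, что это значение лежит у верхней границы перебираемого диапазона (5–49), поэтому сетку имеет смысл расширить. Будем использовать эту модель в эксперименте с отбором признаков.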
from sklearn.feature_selection import SelectFromModel
# Tuned KNN: n_neighbors is taken from the grid search above instead of
# being hard-coded.
best_k = knn_gs.best_params_['knn__n_neighbors']
knn_best = Pipeline([('scaler', StandardScaler()),
                     ('knn', KNeighborsClassifier(n_neighbors=best_k))])

# KNN exposes neither `coef_` nor `feature_importances_`, so it cannot drive
# SelectFromModel (fitting would raise). Instead, features are ranked by the
# absolute coefficients of a logistic regression fitted on standardized data.
# threshold=-np.inf disables the importance cut-off, so exactly
# `max_features` top-ranked features are passed to the tuned KNN.
knn_select = Pipeline([('scaler', StandardScaler()),
                       ('selector', SelectFromModel(LogisticRegression(solver='lbfgs'),
                                                    threshold=-np.inf)),
                       ('knn', knn_best)])

# max_features must not exceed the number of columns, so try 1..n_features.
knn_select_params = {'selector__max_features': np.arange(1, X_train.shape[1] + 1)}
# Same seeded CV splitter as the n_neighbors search, for reproducibility.
knn_select_gs = GridSearchCV(knn_select, knn_select_params, scoring='roc_auc',
                             cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42))