이 노트북의 코드에 대한 설명은 분류기 체인: ClassifierChain 글을 참고하세요.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
fetch yeast dataset (https://www.openml.org/d/40597)
사이킷런 0.20 버전에서 fetch_openml
이 추가되었습니다. fetch_mldata
는 0.22 버전에서 삭제될 예정입니다.
yeast = fetch_openml('yeast', version=4)
X = yeast['data']
Y = yeast['target']
X.shape, Y.shape
((2417, 103), (2417, 14))
Y
Class1 | Class2 | Class3 | Class4 | Class5 | Class6 | Class7 | Class8 | Class9 | Class10 | Class11 | Class12 | Class13 | Class14 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE |
1 | FALSE | FALSE | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE |
2 | FALSE | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE |
3 | FALSE | FALSE | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE |
4 | FALSE | FALSE | TRUE | TRUE | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2412 | FALSE | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE |
2413 | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE |
2414 | FALSE | FALSE | FALSE | FALSE | FALSE | TRUE | TRUE | TRUE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE |
2415 | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE |
2416 | FALSE | TRUE | TRUE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | FALSE | TRUE | TRUE | FALSE |
2417 rows × 14 columns
Y = Y == 'TRUE'
Y
Class1 | Class2 | Class3 | Class4 | Class5 | Class6 | Class7 | Class8 | Class9 | Class10 | Class11 | Class12 | Class13 | Class14 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | False | False | False | False | False | False | True | True | False | False | False | True | True | False |
1 | False | False | True | True | False | False | False | False | False | False | False | False | False | False |
2 | False | True | True | False | False | False | False | False | False | False | False | True | True | False |
3 | False | False | True | True | False | False | False | False | False | False | False | False | False | False |
4 | False | False | True | True | True | True | False | False | False | False | False | False | False | False |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
2412 | False | True | True | False | False | False | False | False | False | False | False | False | False | False |
2413 | True | True | False | False | False | False | False | False | False | False | False | True | True | False |
2414 | False | False | False | False | False | True | True | True | False | False | False | True | True | False |
2415 | False | False | False | False | False | False | False | False | False | False | False | True | True | False |
2416 | False | True | True | False | False | False | False | False | False | False | False | True | True | False |
2417 rows × 14 columns
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
for OVR model
ovr = OneVsRestClassifier(LogisticRegression(solver='liblinear'))
ovr.fit(X_train, Y_train)
pred_ovr = ovr.predict(X_test)
from sklearn.metrics import jaccard_score
ovr_score = jaccard_score(Y_test, pred_ovr, average='micro')
ovr_score
0.4739147749900438
for CC model (https://www.cs.waikato.ac.nz/ml/publications/2009/chains.pdf)
from sklearn.multioutput import ClassifierChain
cc = ClassifierChain(LogisticRegression(solver='liblinear'), order='random', random_state=42)
cc.fit(X_train, Y_train)
pred_cc = cc.predict(X_test)
cc_score = jaccard_score(Y_test, pred_cc, average='micro')
cc_score
0.4618533680684779
for CC ensemble
chains = [ClassifierChain(LogisticRegression(solver='liblinear'), order='random', random_state=42+i)
for i in range(10)]
for chain in chains:
chain.fit(X_train, Y_train)
pred_chains = np.array([chain.predict(X_test) for chain in chains])
chain_scores = [jaccard_score(Y_test, pred_chain, average='micro')
for pred_chain in pred_chains]
chain_scores
[0.4618533680684779, 0.4768186958160573, 0.47702001482579687, 0.46651785714285715, 0.4563857195436143, 0.4561075831154277, 0.4669636228656273, 0.44677769732078204, 0.46481954064892456, 0.4483288513253938]
plt.bar(np.arange(11), [ovr_score]+chain_scores)
plt.show()
proba_chains = np.array([chain.predict_proba(X_test) for chain in chains])
proba_ensemble = proba_chains.mean(axis=0)
ensemble_score = jaccard_score(Y_test, proba_ensemble >= 0.5, average='micro')
ensemble_score
0.4726443768996961
plt.bar(np.arange(12), [ovr_score]+chain_scores+[ensemble_score])
plt.show()