from datetime import datetime
print(f'Päivitetty {datetime.now()}')
Päivitetty 2022-09-11 11:17:51.409027
Kurjenmiekkojen (iris) luokittelu on klassinen esimerkki, joka yleensä otetaan ensimmäisenä esimerkkinä luokittelumalleista.
Tässä opetan mallin erottelemaan kurjenmiekan lajikkeita (setosa, versicolor ja virginica) toisistaan verholehtien (sepal) ja terälehtien (petal) pituuksien ja leveyksien perusteella.
Tässä käytän mallina päätöspuuta (decision tree), jonka rakenteen voin esittää havainnollisena kaaviona. Parempiakin malleja löytyy, mutta tarkoituksella tässä ensimmäisessä esimerkissä pidän asiat mahdollisimman yksinkertaisina.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# train_test_split osaa jakaa datan opetusdataan ja testidataan
from sklearn.model_selection import train_test_split
# Käytän mallina päätöspuuta; plot_tree osaa piirtää päätöspuun
from sklearn.tree import DecisionTreeClassifier, plot_tree
# Sekaannusmatriisin näyttämiseen
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Lataan datan seaborn-kirjastosta
df = sns.load_dataset('iris')
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 150 entries, 0 to 149 Data columns (total 5 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 sepal_length 150 non-null float64 1 sepal_width 150 non-null float64 2 petal_length 150 non-null float64 3 petal_width 150 non-null float64 4 species 150 non-null object dtypes: float64(4), object(1) memory usage: 6.0+ KB
df
sepal_length | sepal_width | petal_length | petal_width | species | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | virginica |
146 | 6.3 | 2.5 | 5.0 | 1.9 | virginica |
147 | 6.5 | 3.0 | 5.2 | 2.0 | virginica |
148 | 6.2 | 3.4 | 5.4 | 2.3 | virginica |
149 | 5.9 | 3.0 | 5.1 | 1.8 | virginica |
150 rows × 5 columns
# Parittaiset hajontakaaviot, joissa eri lajikkeet (species) eri väreillä
sns.pairplot(df, hue='species')
<seaborn.axisgrid.PairGrid at 0x237c489aaf0>
Hajontakaavioiden perusteella setosat ovat helposti eroteltavissa, mutta versicolor ja virginica menevät jossain määrin päällekkäin. Erityisesti terälehden pituus ja leveys (petal_length ja petal_width) näyttäisivät erottelevan lajikkeita hyvin.
# Selittävät muuttujat
X = df.drop('species', axis=1)
# Kohdemuuttuja
y = df['species']
# Datan jako opetus- ja testidataan
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=2)
Jako opetus- ja testidataan tapahtuu sattumanvaraisesti, joten eri suorituskerroilla saan erilaisia jakoja. Koska eri suorituskerroilla on erilaisia opetusdatoja, niin mallikin vaihtelee suorituskerroittain.
Tässä annan random_state-parametrille kiinteän arvon, jonka seurauksena jako tehdään jokaisella suorituskerralla samalla tavalla.
# Mallin sovitus
malli = DecisionTreeClassifier(max_depth=3, random_state=2)
malli.fit(X_train, y_train)
DecisionTreeClassifier(max_depth=3, random_state=2)
max_depth-parametri määrittää päätöpuun haarautumisten maksimi määrän. Voit Kokeilla myös muita arvoja.
Mallin tarkkuus kasvaa haarautumisten lisääntyessä, mutta samalla kasvaa ylisovituksen riski. Jos mallin tarkkuus opetusdatassa on selvästi suurempi kuin testidatassa, niin tämä kertoo ylisovituksesta.
Mallia laskeva algoritmi käyttää hyväkseen sattumanvaraisuutta ja eri suorituskerroilla voin saada hieman toisistaan poikkeavia malleja. Tässä kiinnitän mallin asettamalla random_state-parametrille kiinteän arvon.
print(f'Mallin tarkkuus opetusdatassa {malli.score(X_train, y_train):.3f}')
print(f'Mallin tarkkuus testidatassa {malli.score(X_test, y_test):.3f}')
Mallin tarkkuus opetusdatassa 0.982 Mallin tarkkuus testidatassa 0.974
Muotoiltua merkkijonoa edeltää f. Muotoiltuun merkkijonoon voin lisätä muuttujan arvon aaltosulkujen sisään ja voin samalla antaa muotoiluohjeen (esimerkiksi :.3f muotoilee desimaaliluvun kolmen desimaalin tarkkuuteen).
# Sekaannus-matriisi opetusdatalle
# Mallin antamat ennusteet opetusdatalle
y_train_malli = malli.predict(X_train)
cm = confusion_matrix(y_train, y_train_malli)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['setosa', 'versicolor', 'virginica']).plot()
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x237c608c7c0>
Opetusdatassa malli ennustaa kaksi versicoloria virheellisesti virginicaksi.
# Sekaannus-matriisi testidatalle
# Mallin antamat ennusteet testidatalle
y_test_malli = malli.predict(X_test)
cm = confusion_matrix(y_test, y_test_malli)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels = ['setosa', 'versicolor', 'virginica']).plot()
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x237c607b910>
Testidatassa malli ennustaa yhden versicolorin virheellisesti virginicaksi.
# Päätöspuumallin voin havainnollistaa kaaviona
plt.figure(figsize=(10, 6))
plot_tree(decision_tree=malli,
feature_names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])
[Text(0.375, 0.875, 'petal_width <= 0.8\ngini = 0.665\nsamples = 112\nvalue = [34, 39, 39]'), Text(0.25, 0.625, 'gini = 0.0\nsamples = 34\nvalue = [34, 0, 0]'), Text(0.5, 0.625, 'petal_width <= 1.65\ngini = 0.5\nsamples = 78\nvalue = [0, 39, 39]'), Text(0.25, 0.375, 'petal_length <= 4.95\ngini = 0.136\nsamples = 41\nvalue = [0, 38, 3]'), Text(0.125, 0.125, 'gini = 0.0\nsamples = 37\nvalue = [0, 37, 0]'), Text(0.375, 0.125, 'gini = 0.375\nsamples = 4\nvalue = [0, 1, 3]'), Text(0.75, 0.375, 'petal_length <= 4.85\ngini = 0.053\nsamples = 37\nvalue = [0, 1, 36]'), Text(0.625, 0.125, 'gini = 0.444\nsamples = 3\nvalue = [0, 1, 2]'), Text(0.875, 0.125, 'gini = 0.0\nsamples = 34\nvalue = [0, 0, 34]')]
Opetusdatassa on 34 setosaa, 39 versicoloria ja 39 virginicaa.
Ensimmäisessä haarautumisessa päätössääntönä on petal_width <= 0.8, jonka perusteella saadaan eroteltua kaikki setosa-lajikkeeseen kuuluvat omaan haaraansa.
Ensimmäisessä vaiheessa gini = (34/112)^2 + (39/112)^2 + (39/112)^2 (todennäköisyyksien neliöiden summa). Ginin arvosta voidaan johtaa gini impurity:
gini impurity = 1 - gini = 0.665.
Harhaanjohtavasti kaaviossa käytetään nimitystä gini vaikka kyseessä on gini impurity. Päätöspuu-algoritmi määrittää haarautumiskohdat siten että päästään mahdollisimman pieniin gini impurity -arvoihin.
Jos en käytä mallin sovituksessa kiinteää random_state-parametrin arvoa, niin eri suorituskerroilla päätöspuussa voi olla toisistaan poikkeavia päätössääntöjä.
# Avaan uuden datan, jossa lajikkeet (species) eivät ole tiedossa
# Datassa täytyy olla samat selittävät muuttujat kuin mallia sovitettaessa
Xnew = pd.read_excel('https://taanila.fi/irisnew.xlsx')
Xnew
sepal_length | sepal_width | petal_length | petal_width | |
---|---|---|---|---|
0 | 5.0 | 3.5 | 1.5 | 0.3 |
1 | 8.1 | 3.3 | 6.5 | 1.9 |
2 | 6.0 | 3.0 | 3.0 | 0.5 |
# Lasken ennusteet
ennuste = malli.predict(Xnew)
# Lasken todennäköisyydet
todnak = malli.predict_proba(Xnew).round(2)
# Lisään ennusteet ja todennäköisyydet dataan
Xnew['ennuste'] = ennuste
Xnew[['tn_setosa', 'tn_versicolor', 'tn_virginica']] = todnak
Xnew
sepal_length | sepal_width | petal_length | petal_width | ennuste | tn_setosa | tn_versicolor | tn_virginica | |
---|---|---|---|---|---|---|---|---|
0 | 5.0 | 3.5 | 1.5 | 0.3 | setosa | 1.0 | 0.0 | 0.0 |
1 | 8.1 | 3.3 | 6.5 | 1.9 | virginica | 0.0 | 0.0 | 1.0 |
2 | 6.0 | 3.0 | 3.0 | 0.5 | setosa | 1.0 | 0.0 | 0.0 |
Ensimmäisen ja viimeinen ennustetaan setosaksi, toinen virginicaksi.
Data-analytiikka Pythonilla: https://tilastoapu.wordpress.com/python/