%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
sns.set(font_scale=1.5)
import xgboost as xgb
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
Посмотрим на примере данных по оттоку клиентов из телеком-компании.### Load data
df = pd.read_csv("../../data/telecom_churn.csv")
df.head()
Штаты просто занумеруем, а признаки International plan (наличие международного роуминга), Voice mail plan (наличие голосовой почтыы) и целевой Churn сделаем бинарными.
state_enc = LabelEncoder()
df["State"] = state_enc.fit_transform(df["State"])
df["International plan"] = (df["International plan"] == "Yes").astype("int")
df["Voice mail plan"] = (df["Voice mail plan"] == "Yes").astype("int")
df["Churn"] = (df["Churn"]).astype("int")
Разделим данные на обучающую и тестовую выборки в отношении 7:3. Инициализируем соотв. объекты DMatrix dtrain и dtest.
X_train, X_test, y_train, y_test = train_test_split(
df.drop("Churn", axis=1),
df["Churn"],
test_size=0.3,
stratify=df["Churn"],
random_state=17,
)
dtrain = xgb.DMatrix(X_train, y_train)
dtest = xgb.DMatrix(X_test, y_test)
Обучим всего 50 деревьев решений глубины 3.
params = {"objective": "binary:logistic", "max_depth": 3, "silent": 1, "eta": 0.5}
num_rounds = 50
watchlist = [(dtest, "test"), (dtrain, "train")] # native interface only
xgb_model = xgb.train(params, dtrain, num_rounds, watchlist)
F score при оценке важности признаков в Xgboost (не путать с F1 score как метрики качества классификации) вычисляется на основе того, как часто разбиение делалось по данному признаку.
xgb.plot_importance(xgb_model);
Можно так, в виде словаря или DataFrame:
importances = xgb_model.get_fscore()
importances
# create df
importance_df = pd.DataFrame(
{"Splits": list(importances.values()), "Feature": list(importances.keys())}
)
importance_df.sort_values(by="Splits", inplace=True)
importance_df.plot(kind="barh", x="Feature", figsize=(8, 6), color="orange");