Please run these two cells before running the rest of the Notebook!
As these plotting settings are used as standard throughout the book, we do not show them every time we plot something.
# %matplotlib inline
%config InlineBackend.figure_format = "retina"
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from pandas.core.common import SettingWithCopyWarning
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)
# feel free to modify, for example, change the context to "notebook"
sns.set_theme(context="talk", style="whitegrid",
palette="colorblind", color_codes=True,
rc={"figure.figsize": [12, 8]})
fastai's Tabular Learner
from fastai.tabular.all import *
from sklearn.model_selection import train_test_split
from chapter_15_utils import performance_evaluation_report_fastai
import pandas as pd
df = pd.read_csv("../Datasets/credit_card_default.csv",
na_values="")
df.head()
limit_bal | sex | education | marriage | age | payment_status_sep | payment_status_aug | payment_status_jul | payment_status_jun | payment_status_may | ... | bill_statement_jun | bill_statement_may | bill_statement_apr | previous_payment_sep | previous_payment_aug | previous_payment_jul | previous_payment_jun | previous_payment_may | previous_payment_apr | default_payment_next_month | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 20000 | Female | University | Married | 24.0 | Payment delayed 2 months | Payment delayed 2 months | Payed duly | Payed duly | Unknown | ... | 0 | 0 | 0 | 0 | 689 | 0 | 0 | 0 | 0 | 1 |
1 | 120000 | Female | University | Single | 26.0 | Payed duly | Payment delayed 2 months | Unknown | Unknown | Unknown | ... | 3272 | 3455 | 3261 | 0 | 1000 | 1000 | 1000 | 0 | 2000 | 1 |
2 | 90000 | Female | University | Single | 34.0 | Unknown | Unknown | Unknown | Unknown | Unknown | ... | 14331 | 14948 | 15549 | 1518 | 1500 | 1000 | 1000 | 1000 | 5000 | 0 |
3 | 50000 | Female | University | Married | 37.0 | Unknown | Unknown | Unknown | Unknown | Unknown | ... | 28314 | 28959 | 29547 | 2000 | 2019 | 1200 | 1100 | 1069 | 1000 | 0 |
4 | 50000 | Male | University | Married | 57.0 | Payed duly | Unknown | Payed duly | Unknown | Unknown | ... | 20940 | 19146 | 19131 | 2000 | 36681 | 10000 | 9000 | 689 | 679 | 0 |
5 rows × 24 columns
# as a reminder, where the possible missing values are
df.isna().any()
limit_bal False sex True education True marriage True age True payment_status_sep False payment_status_aug False payment_status_jul False payment_status_jun False payment_status_may False payment_status_apr False bill_statement_sep False bill_statement_aug False bill_statement_jul False bill_statement_jun False bill_statement_may False bill_statement_apr False previous_payment_sep False previous_payment_aug False previous_payment_jul False previous_payment_jun False previous_payment_may False previous_payment_apr False default_payment_next_month False dtype: bool
TARGET = "default_payment_next_month"
cat_features = list(df.select_dtypes("object").columns)
num_features = list(df.select_dtypes("number").columns)
num_features.remove(TARGET)
preprocessing = [FillMissing, Categorify, Normalize]
splits = RandomSplitter(valid_pct=0.2, seed=42)(range_of(df))
splits
((#24000) [27362,16258,19716,9066,1258,23042,18939,24443,4328,4976...], (#6000) [7542,10109,19114,5209,9270,15555,12970,10207,13694,1745...])
Create the TabularPandas dataset:
tabular_df = TabularPandas(
df,
procs=preprocessing,
cat_names=cat_features,
cont_names=num_features,
y_names=TARGET,
y_block=CategoryBlock(),
splits=splits
)
PREVIEW_COLS = ["sex", "education", "marriage",
"payment_status_sep", "age_na", "limit_bal",
"age", "bill_statement_sep"]
tabular_df.xs.iloc[:5][PREVIEW_COLS]
sex | education | marriage | payment_status_sep | age_na | limit_bal | age | bill_statement_sep | |
---|---|---|---|---|---|---|---|---|
27362 | 2 | 4 | 3 | 10 | 1 | -0.290227 | -0.919919 | -0.399403 |
16258 | 1 | 4 | 1 | 10 | 1 | -0.443899 | -0.266960 | 0.731335 |
19716 | 1 | 1 | 3 | 2 | 1 | 2.014862 | -0.158134 | -0.493564 |
9066 | 1 | 2 | 3 | 3 | 1 | -0.674408 | -0.919919 | -0.646319 |
1258 | 2 | 1 | 3 | 1 | 1 | 0.324464 | -0.266960 | -0.692228 |
tabular_df.xs.columns
Index(['sex', 'education', 'marriage', 'payment_status_sep', 'payment_status_aug', 'payment_status_jul', 'payment_status_jun', 'payment_status_may', 'payment_status_apr', 'age_na', 'limit_bal', 'age', 'bill_statement_sep', 'bill_statement_aug', 'bill_statement_jul', 'bill_statement_jun', 'bill_statement_may', 'bill_statement_apr', 'previous_payment_sep', 'previous_payment_aug', 'previous_payment_jul', 'previous_payment_jun', 'previous_payment_may', 'previous_payment_apr'], dtype='object')
Create the DataLoaders object from the TabularPandas dataset:
data_loader = tabular_df.dataloaders(bs=64, drop_last=True)
data_loader.show_batch()
sex | education | marriage | payment_status_sep | payment_status_aug | payment_status_jul | payment_status_jun | payment_status_may | payment_status_apr | age_na | limit_bal | age | bill_statement_sep | bill_statement_aug | bill_statement_jul | bill_statement_jun | bill_statement_may | bill_statement_apr | previous_payment_sep | previous_payment_aug | previous_payment_jul | previous_payment_jun | previous_payment_may | previous_payment_apr | default_payment_next_month | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Female | University | Married | Unknown | Unknown | Unknown | Payment delayed 2 months | Payment delayed 2 months | Payment delayed 2 months | False | 49999.996248 | 44.999999 | 26615.000783 | 22241.000008 | 23406.000381 | 15379.000251 | 16715.000690 | 11132.999160 | 1999.999921 | 2000.000073 | -0.000163 | 1499.999924 | 0.000058 | 1000.000099 | 0 |
1 | Female | University | Married | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 250000.000277 | 41.000000 | 214301.005462 | 211761.997590 | 214629.002288 | 204092.997983 | 150752.002005 | 150862.997034 | 8971.999882 | 9700.000071 | 8959.999878 | 5630.000021 | 5776.999990 | 5458.999994 | 0 |
2 | Male | University | Single | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 119999.999598 | 38.000000 | 116398.997952 | 117513.999178 | 111002.998624 | 84839.998144 | 86954.000433 | 83447.999191 | 4999.999994 | 4999.999988 | 2999.999927 | 3499.999964 | 3000.000004 | 81999.999746 | 1 |
3 | Male | High school | Married | Payment delayed 1 month | Unknown | Payed duly | Unknown | Unknown | Unknown | False | 49999.996248 | 47.000000 | -2011.998149 | -2011.998079 | 47713.999998 | 48684.999863 | 18138.999176 | 18519.000666 | 0.000039 | 50223.999794 | 2195.000049 | 648.999802 | 672.000021 | 748.999839 | 0 |
4 | Female | University | Single | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 20000.005341 | 23.000000 | 4731.000802 | 6642.000937 | 11508.001605 | 16952.999205 | 17310.000648 | 18465.999694 | 1999.999921 | 4999.999988 | 6000.000017 | 626.000093 | 1440.999895 | -0.000013 | 0 |
5 | Male | Graduate school | Married | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 119999.999598 | 44.999999 | 118287.001734 | 117775.001568 | 117106.004112 | 118206.999340 | 116884.997508 | 123040.002858 | 5700.000000 | 5699.999990 | 6000.000017 | 4999.999993 | 9999.999951 | -0.000013 | 0 |
6 | Female | University | Married | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 40000.001865 | 51.000000 | 144396.996155 | 147924.001028 | 26974.000119 | 22709.999695 | 37976.999905 | 39346.999995 | 3988.000052 | 2013.999968 | 799.000257 | 32000.999700 | 1999.999947 | 1000.000099 | 1 |
7 | Female | Graduate school | Married | Payment delayed 1 month | Payed duly | Payment delayed 3 months | Payment delayed 2 months | Payed duly | Payed duly | False | 59999.998389 | 38.000000 | 0.000274 | 780.001227 | 780.001819 | 389.999246 | 390.000193 | 390.000078 | 779.999841 | -0.000043 | -0.000163 | 390.000095 | 389.999981 | 86.999973 | 0 |
8 | Male | Graduate school | Single | Payed duly | Payed duly | Payed duly | Payed duly | Unknown | Unknown | False | 499999.999495 | 33.000000 | 1024.998878 | 1691.000159 | 1316.000832 | 4941.000584 | 8539.000455 | 3841.999798 | 1999.999921 | 1316.000115 | 4999.999992 | 4999.999993 | 1999.999947 | 1222.999896 | 0 |
9 | Male | High school | Single | Payment delayed 1 month | Payment delayed 2 months | Payment delayed 3 months | Payment delayed 2 months | Payment delayed 2 months | Payment delayed 2 months | False | 49999.996248 | 30.000000 | 31217.000920 | 33422.999919 | 32600.000411 | 31777.000449 | 33966.000055 | 34759.000197 | 2999.999987 | -0.000043 | -0.000163 | 2700.000066 | 1500.000086 | -0.000013 | 0 |
recall = Recall()
precision = Precision()
learn = tabular_learner(
data_loader,
[500, 200],
metrics=[accuracy, recall, precision]
)
learn.model
TabularModel( (embeds): ModuleList( (0): Embedding(3, 3) (1): Embedding(5, 4) (2): Embedding(4, 3) (3): Embedding(11, 6) (4): Embedding(11, 6) (5): Embedding(11, 6) (6): Embedding(11, 6) (7): Embedding(10, 6) (8): Embedding(10, 6) (9): Embedding(3, 3) ) (emb_drop): Dropout(p=0.0, inplace=False) (bn_cont): BatchNorm1d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (layers): Sequential( (0): LinBnDrop( (0): Linear(in_features=63, out_features=500, bias=False) (1): ReLU(inplace=True) (2): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): LinBnDrop( (0): Linear(in_features=500, out_features=200, bias=False) (1): ReLU(inplace=True) (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): LinBnDrop( (0): Linear(in_features=200, out_features=2, bias=True) ) ) )
# we can also figure out the embedding sizes using the following snippet
emb_szs = get_emb_sz(tabular_df)
emb_szs
[(3, 3), (5, 4), (4, 3), (11, 6), (11, 6), (11, 6), (11, 6), (10, 6), (10, 6), (3, 3)]
Embedding(11, 6) means that a categorical embedding was created with 11 input values and 6 output latent features.
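As a quick sanity check, we can reproduce those sizes with fastai's rule-of-thumb formula. This is only a sketch, assuming the current fastai default of min(600, round(1.6 * n_cat**0.56)):
# a sketch - recompute the embedding output sizes with fastai's rule of thumb
for n_cat, emb_dim in emb_szs:
    print(n_cat, emb_dim, emb_sz_rule(n_cat))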
learn.lr_find()
# plt.savefig("images/figure_15_3")
learn.fit(n_epoch=25, lr=1e-3, wd=0.2)
learn.recorder.plot_loss()
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_5")
Create the DataLoaders for the validation data:
valid_data_loader = learn.dls.test_dl(df.loc[list(splits[1])])
valid_data_loader.show_batch()
sex | education | marriage | payment_status_sep | payment_status_aug | payment_status_jul | payment_status_jun | payment_status_may | payment_status_apr | age_na | limit_bal | age | bill_statement_sep | bill_statement_aug | bill_statement_jul | bill_statement_jun | bill_statement_may | bill_statement_apr | previous_payment_sep | previous_payment_aug | previous_payment_jul | previous_payment_jun | previous_payment_may | previous_payment_apr | default_payment_next_month | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Female | Graduate school | Single | Payment delayed 1 month | Unknown | Unknown | Unknown | Payed duly | Payed duly | False | 80000.002671 | 28.0 | 0.000274 | -0.001422 | -0.000269 | -0.000013 | 2284.000497 | -786.000871 | 0.000039 | -0.000043 | -0.000163 | 2283.999916 | 0.000058 | -0.000013 | 1 |
1 | Female | Graduate school | Single | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 360000.000553 | 44.0 | 347695.995104 | 329863.996200 | 322158.997157 | 289376.998417 | 146945.999005 | 130085.003447 | 19999.999860 | 20008.999765 | 30000.000210 | 9999.999766 | 9999.999951 | 10000.000060 | 0 |
2 | Female | University | Married | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 80000.002671 | 31.0 | 0.000274 | -0.001422 | -0.000269 | -0.000013 | 0.001374 | 0.001333 | 0.000039 | -0.000043 | -0.000163 | -0.000076 | 0.000058 | -0.000013 | 0 |
3 | Female | University | Married | Payed duly | Payed duly | Payed duly | Payed duly | Payed duly | Payed duly | False | 49999.996248 | 43.0 | 560.001359 | 1120.999531 | 194.999217 | 196.998524 | 196.998239 | 197.000509 | 1121.000175 | 194.999828 | 197.000070 | 197.000190 | 196.999852 | 196.999786 | 0 |
4 | Male | University | Single | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 179999.999837 | 26.0 | 145574.003103 | 108616.999401 | 102710.999113 | 95651.998160 | 97661.001694 | 99655.998778 | 3906.999985 | 3600.000028 | 3414.999937 | 3541.999967 | 3615.999970 | 1999.999949 | 0 |
5 | Male | University | Single | Payment delayed 2 months | Payment delayed 2 months | Payment delayed 2 months | Payed duly | Unknown | Payed duly | False | 20000.005341 | 37.0 | 3254.001134 | 2521.998928 | -0.000269 | 779.998506 | 390.000193 | 390.000078 | 1999.999921 | -0.000043 | 780.000096 | -0.000076 | 389.999981 | 1679.999945 | 1 |
6 | Female | University | Married | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 279999.998942 | 44.0 | 286706.004848 | 272243.004586 | 203743.992187 | 203520.005928 | 207877.993990 | 211811.996533 | 10480.000231 | 8041.000065 | 7200.000052 | 7499.999924 | 7092.999917 | 5701.999989 | 0 |
7 | Male | University | Single | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 80000.002671 | 26.0 | 76157.999269 | 81858.000996 | 80337.000365 | 61002.000665 | 58147.999258 | 52514.999924 | 8000.000067 | 4999.999988 | 4000.000000 | 4999.999993 | 3000.000004 | 5000.000007 | 0 |
8 | Male | Graduate school | Married | Unknown | Unknown | Unknown | Payed duly | Payed duly | Payed duly | False | 209999.999471 | 52.0 | 44933.000212 | 38541.999976 | 39331.999888 | 11140.000617 | 8462.999592 | 10406.000874 | 1792.999970 | 8242.000014 | 12000.000196 | 8533.999995 | 11000.000119 | 7500.000042 | 0 |
9 | Female | Graduate school | Single | Unknown | Unknown | Unknown | Unknown | Unknown | Unknown | False | 49999.996248 | 22.0 | 49459.000028 | 49281.000000 | 50071.000110 | 10104.000406 | 9208.000207 | 10075.000742 | 2299.999941 | 2000.000073 | 1000.000042 | 500.000082 | 999.999780 | 500.000043 | 0 |
learn.validate(dl=valid_data_loader)
(#4) [0.42411357164382935,0.824833333492279,0.3622848200312989,0.6623748211731044]
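The four values correspond to the validation loss, followed by the metrics in the order they were passed to tabular_learner. A small sketch (assuming that ordering) to label them explicitly:
# a sketch - unpack the values returned by learn.validate
valid_loss, valid_acc, valid_recall, valid_prec = learn.validate(dl=valid_data_loader)
print(f"Recall: {valid_recall:.4f} | Precision: {valid_prec:.4f}")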
preds, y_true = learn.get_preds(dl=valid_data_loader)
preds
tensor([[0.8092, 0.1908], [0.9339, 0.0661], [0.8631, 0.1369], ..., [0.9249, 0.0751], [0.8556, 0.1444], [0.8670, 0.1330]])
preds.argmax(dim=-1)
tensor([0, 0, 0, ..., 0, 0, 0])
y_true
tensor([[1], [0], [0], ..., [0], [0], [0]], dtype=torch.int8)
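As a cross-check, we can recompute the recall from the class probabilities with scikit-learn. A sketch, assuming the positive class is encoded as 1:
# a sketch - recompute recall from the predicted probabilities
from sklearn.metrics import recall_score
recall_score(y_true.numpy().ravel(), preds.argmax(dim=-1).numpy())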
perf = performance_evaluation_report_fastai(
learn, valid_data_loader, show_plot=True
)
sns.despine()
# plt.savefig("images/figure_15_6", dpi=200)
perf
We can also be more specific when creating the training/validation split. Below, we use scikit-learn's functionality and pass the resulting indices to the IndexSplitter class.
from sklearn.model_selection import StratifiedKFold
X = df.copy()
y = X.pop(TARGET)
strat_split = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
train_ind, test_ind = next(strat_split.split(X, y))
ind_splits = IndexSplitter(valid_idx=list(test_ind))(range_of(df))
tabular_df = TabularPandas(
df,
procs=preprocessing,
cat_names=cat_features,
cont_names=num_features,
y_names=TARGET,
y_block=CategoryBlock(),
splits=ind_splits
)
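A quick sanity check (a sketch) that the stratified split roughly preserves the share of defaults in the validation fold:
# a sketch - compare the default rate overall vs. in the validation indices
print(f"Overall default rate: {y.mean():.4f}")
print(f"Validation default rate: {y.iloc[test_ind].mean():.4f}")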
We can look at some example results.
learn.show_results()
Or create predictions for a single row:
row, clas, probs = learn.predict(df.iloc[0])
row
clas
probs
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import recall_score
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.metrics import Metric
import torch
import pandas as pd
import numpy as np
df = pd.read_csv("../Datasets/credit_card_default.csv",
na_values="")
df.head()
limit_bal | sex | education | marriage | age | payment_status_sep | payment_status_aug | payment_status_jul | payment_status_jun | payment_status_may | ... | bill_statement_jun | bill_statement_may | bill_statement_apr | previous_payment_sep | previous_payment_aug | previous_payment_jul | previous_payment_jun | previous_payment_may | previous_payment_apr | default_payment_next_month | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 20000 | Female | University | Married | 24.0 | Payment delayed 2 months | Payment delayed 2 months | Payed duly | Payed duly | Unknown | ... | 0 | 0 | 0 | 0 | 689 | 0 | 0 | 0 | 0 | 1 |
1 | 120000 | Female | University | Single | 26.0 | Payed duly | Payment delayed 2 months | Unknown | Unknown | Unknown | ... | 3272 | 3455 | 3261 | 0 | 1000 | 1000 | 1000 | 0 | 2000 | 1 |
2 | 90000 | Female | University | Single | 34.0 | Unknown | Unknown | Unknown | Unknown | Unknown | ... | 14331 | 14948 | 15549 | 1518 | 1500 | 1000 | 1000 | 1000 | 5000 | 0 |
3 | 50000 | Female | University | Married | 37.0 | Unknown | Unknown | Unknown | Unknown | Unknown | ... | 28314 | 28959 | 29547 | 2000 | 2019 | 1200 | 1100 | 1069 | 1000 | 0 |
4 | 50000 | Male | University | Married | 57.0 | Payed duly | Unknown | Payed duly | Unknown | Unknown | ... | 20940 | 19146 | 19131 | 2000 | 36681 | 10000 | 9000 | 689 | 679 | 0 |
5 rows × 24 columns
X = df.copy()
y = X.pop("default_payment_next_month")
cat_features = list(X.select_dtypes("object").columns)
num_features = list(X.select_dtypes("number").columns)
# as a reminder, where the possible missing values are
X.isna().any()
limit_bal False sex True education True marriage True age True payment_status_sep False payment_status_aug False payment_status_jul False payment_status_jun False payment_status_may False payment_status_apr False bill_statement_sep False bill_statement_aug False bill_statement_jul False bill_statement_jun False bill_statement_may False bill_statement_apr False previous_payment_sep False previous_payment_aug False previous_payment_jul False previous_payment_jun False previous_payment_may False previous_payment_apr False dtype: bool
Encode the categorical features using LabelEncoder and store the number of unique categories per feature:
cat_dims = {}
for col in cat_features:
label_encoder = LabelEncoder()
X[col] = X[col].fillna("Missing")
X[col] = label_encoder.fit_transform(X[col].values)
cat_dims[col] = len(label_encoder.classes_)
cat_dims
{'sex': 3, 'education': 5, 'marriage': 4, 'payment_status_sep': 10, 'payment_status_aug': 10, 'payment_status_jul': 10, 'payment_status_jun': 10, 'payment_status_may': 9, 'payment_status_apr': 9}
# create the initial split - training and temp
X_train, X_temp, y_train, y_temp = train_test_split(
X, y,
test_size=0.3,
stratify=y,
random_state=42
)
# create the valid and test sets
X_valid, X_test, y_valid, y_test = train_test_split(
X_temp, y_temp,
test_size=0.5,
stratify=y_temp,
random_state=42
)
print("Percentage of data in each set ----")
print(f"Train: {100 * len(X_train) / len(X):.2f}%")
print(f"Valid: {100 * len(X_valid) / len(X):.2f}%")
print(f"Test: {100 * len(X_test) / len(X):.2f}%")
print("")
print("Class distribution in each set ----")
print(f"Train: {y_train.value_counts(normalize=True).values}")
print(f"Valid: {y_valid.value_counts(normalize=True).values}")
print(f"Test: {y_test.value_counts(normalize=True).values}")
Percentage of data in each set ---- Train: 70.00% Valid: 15.00% Test: 15.00% Class distribution in each set ---- Train: [0.77880952 0.22119048] Valid: [0.77888889 0.22111111] Test: [0.77866667 0.22133333]
for col in num_features:
imp_mean = X_train[col].mean()
X_train[col] = X_train[col].fillna(imp_mean)
X_valid[col] = X_valid[col].fillna(imp_mean)
X_test[col] = X_test[col].fillna(imp_mean)
features = X.columns.to_list()
cat_ind = [features.index(feat) for feat in cat_features]
cat_dims = list(cat_dims.values())
cat_ind
[1, 2, 3, 5, 6, 7, 8, 9, 10]
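A quick sanity check (a sketch) that the categorical column indices line up with the cardinalities we pass to TabNet:
# a sketch - print each categorical feature together with its number of categories
for idx, dim in zip(cat_ind, cat_dims):
    print(f"{features[idx]}: {dim} categories")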
class Recall(Metric):
def __init__(self):
self._name = "recall"
self._maximize = True
def __call__(self, y_true, y_score):
y_pred = np.argmax(y_score, axis=1)
return recall_score(y_true, y_pred)
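The same pattern (subclass Metric, set _name and _maximize, implement __call__) can wrap any scikit-learn metric. For example, a hypothetical F1 metric could look as follows (a sketch, not used in training below):
# a sketch - a hypothetical custom F1 metric following the same pattern
from sklearn.metrics import f1_score

class F1(Metric):
    def __init__(self):
        self._name = "f1"
        self._maximize = True

    def __call__(self, y_true, y_score):
        y_pred = np.argmax(y_score, axis=1)
        return f1_score(y_true, y_pred)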
tabnet_params = {
"cat_idxs": cat_ind,
"cat_dims": cat_dims,
"optimizer_fn": torch.optim.Adam,
"optimizer_params": dict(lr=2e-2),
"scheduler_params": {
"step_size":20,
"gamma":0.9
},
"scheduler_fn": torch.optim.lr_scheduler.StepLR,
"mask_type": "sparsemax",
"seed": 42,
}
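With StepLR(step_size=20, gamma=0.9), the learning rate is multiplied by 0.9 every 20 epochs. A sketch of the implied schedule, assuming the scheduler is stepped once per epoch (the pytorch_tabnet default):
# a sketch - the implied learning rate schedule
for epoch in [0, 20, 40, 60]:
    print(f"epoch {epoch}: lr = {2e-2 * 0.9 ** (epoch // 20):.5f}")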
tabnet = TabNetClassifier(**tabnet_params)
Device used : cpu
tabnet.fit(
X_train=X_train.values,
y_train=y_train.values,
eval_set=[
(X_train.values, y_train.values),
(X_valid.values, y_valid.values)
],
eval_name=["train", "valid"],
eval_metric=["auc", Recall],
max_epochs=200,
patience=20,
batch_size=1024,
virtual_batch_size=128,
weights=1,
)
epoch 0 | loss: 0.69867 | train_auc: 0.61461 | train_recall: 0.3789 | valid_auc: 0.62232 | valid_recall: 0.37286 | 0:00:01s epoch 1 | loss: 0.62342 | train_auc: 0.70538 | train_recall: 0.51539 | valid_auc: 0.69053 | valid_recall: 0.48744 | 0:00:02s epoch 2 | loss: 0.59902 | train_auc: 0.71777 | train_recall: 0.51625 | valid_auc: 0.71667 | valid_recall: 0.48643 | 0:00:04s epoch 3 | loss: 0.59629 | train_auc: 0.73428 | train_recall: 0.5268 | valid_auc: 0.72767 | valid_recall: 0.49447 | 0:00:05s epoch 4 | loss: 0.58383 | train_auc: 0.75126 | train_recall: 0.52723 | valid_auc: 0.74693 | valid_recall: 0.49749 | 0:00:06s epoch 5 | loss: 0.58065 | train_auc: 0.75707 | train_recall: 0.52766 | valid_auc: 0.75177 | valid_recall: 0.49648 | 0:00:08s epoch 6 | loss: 0.58437 | train_auc: 0.76085 | train_recall: 0.53671 | valid_auc: 0.75355 | valid_recall: 0.50452 | 0:00:09s epoch 7 | loss: 0.5804 | train_auc: 0.76206 | train_recall: 0.54682 | valid_auc: 0.75346 | valid_recall: 0.52161 | 0:00:10s epoch 8 | loss: 0.57665 | train_auc: 0.76926 | train_recall: 0.52659 | valid_auc: 0.75885 | valid_recall: 0.49347 | 0:00:12s epoch 9 | loss: 0.57048 | train_auc: 0.76021 | train_recall: 0.53283 | valid_auc: 0.75155 | valid_recall: 0.50653 | 0:00:13s epoch 10 | loss: 0.57483 | train_auc: 0.76547 | train_recall: 0.55867 | valid_auc: 0.76387 | valid_recall: 0.53668 | 0:00:14s epoch 11 | loss: 0.57675 | train_auc: 0.76774 | train_recall: 0.55587 | valid_auc: 0.76215 | valid_recall: 0.53367 | 0:00:16s epoch 12 | loss: 0.57628 | train_auc: 0.7716 | train_recall: 0.53348 | valid_auc: 0.76371 | valid_recall: 0.51457 | 0:00:17s epoch 13 | loss: 0.57112 | train_auc: 0.77359 | train_recall: 0.5746 | valid_auc: 0.76409 | valid_recall: 0.53869 | 0:00:18s epoch 14 | loss: 0.56588 | train_auc: 0.77446 | train_recall: 0.57524 | valid_auc: 0.76292 | valid_recall: 0.55477 | 0:00:19s epoch 15 | loss: 0.56647 | train_auc: 0.77612 | train_recall: 0.55608 | valid_auc: 0.76377 | valid_recall: 0.53668 | 0:00:21s epoch 16 | loss: 0.56715 | train_auc: 0.77478 | train_recall: 0.58385 | valid_auc: 0.76256 | valid_recall: 0.55578 | 0:00:22s epoch 17 | loss: 0.57186 | train_auc: 0.77721 | train_recall: 0.62476 | valid_auc: 0.76724 | valid_recall: 0.60804 | 0:00:23s epoch 18 | loss: 0.56314 | train_auc: 0.7783 | train_recall: 0.55845 | valid_auc: 0.77098 | valid_recall: 0.5407 | 0:00:25s epoch 19 | loss: 0.56288 | train_auc: 0.77915 | train_recall: 0.55673 | valid_auc: 0.76882 | valid_recall: 0.52663 | 0:00:26s epoch 20 | loss: 0.558 | train_auc: 0.78156 | train_recall: 0.56448 | valid_auc: 0.76819 | valid_recall: 0.54271 | 0:00:27s epoch 21 | loss: 0.56744 | train_auc: 0.78057 | train_recall: 0.53735 | valid_auc: 0.76705 | valid_recall: 0.51859 | 0:00:29s epoch 22 | loss: 0.56238 | train_auc: 0.77992 | train_recall: 0.59462 | valid_auc: 0.76765 | valid_recall: 0.58894 | 0:00:30s epoch 23 | loss: 0.56762 | train_auc: 0.77899 | train_recall: 0.64069 | valid_auc: 0.76628 | valid_recall: 0.6191 | 0:00:31s epoch 24 | loss: 0.56032 | train_auc: 0.78407 | train_recall: 0.58105 | valid_auc: 0.77206 | valid_recall: 0.57085 | 0:00:33s epoch 25 | loss: 0.55865 | train_auc: 0.7855 | train_recall: 0.54575 | valid_auc: 0.77309 | valid_recall: 0.52663 | 0:00:35s epoch 26 | loss: 0.55971 | train_auc: 0.78605 | train_recall: 0.56211 | valid_auc: 0.77086 | valid_recall: 0.54975 | 0:00:36s epoch 27 | loss: 0.5653 | train_auc: 0.78738 | train_recall: 0.53477 | valid_auc: 0.77474 | valid_recall: 0.5206 | 0:00:38s epoch 28 | loss: 0.56239 | train_auc: 0.78801 | 
train_recall: 0.61119 | valid_auc: 0.77231 | valid_recall: 0.59497 | 0:00:39s epoch 29 | loss: 0.56089 | train_auc: 0.78539 | train_recall: 0.61787 | valid_auc: 0.77069 | valid_recall: 0.60201 | 0:00:40s epoch 30 | loss: 0.55615 | train_auc: 0.78558 | train_recall: 0.57094 | valid_auc: 0.76635 | valid_recall: 0.55477 | 0:00:42s epoch 31 | loss: 0.55757 | train_auc: 0.78605 | train_recall: 0.61636 | valid_auc: 0.77302 | valid_recall: 0.60302 | 0:00:43s epoch 32 | loss: 0.55584 | train_auc: 0.78387 | train_recall: 0.55285 | valid_auc: 0.7725 | valid_recall: 0.5397 | 0:00:45s epoch 33 | loss: 0.55765 | train_auc: 0.78776 | train_recall: 0.58062 | valid_auc: 0.77333 | valid_recall: 0.56482 | 0:00:46s epoch 34 | loss: 0.56327 | train_auc: 0.78804 | train_recall: 0.59203 | valid_auc: 0.77051 | valid_recall: 0.57588 | 0:00:47s epoch 35 | loss: 0.55651 | train_auc: 0.78608 | train_recall: 0.5761 | valid_auc: 0.77021 | valid_recall: 0.55377 | 0:00:49s epoch 36 | loss: 0.55767 | train_auc: 0.78749 | train_recall: 0.58837 | valid_auc: 0.77065 | valid_recall: 0.57487 | 0:00:50s epoch 37 | loss: 0.55989 | train_auc: 0.78954 | train_recall: 0.55759 | valid_auc: 0.773 | valid_recall: 0.54171 | 0:00:51s epoch 38 | loss: 0.55281 | train_auc: 0.79089 | train_recall: 0.60947 | valid_auc: 0.77575 | valid_recall: 0.59698 | 0:00:53s epoch 39 | loss: 0.55517 | train_auc: 0.78749 | train_recall: 0.5718 | valid_auc: 0.77657 | valid_recall: 0.55377 | 0:00:54s epoch 40 | loss: 0.56106 | train_auc: 0.78877 | train_recall: 0.57051 | valid_auc: 0.7724 | valid_recall: 0.54573 | 0:00:55s epoch 41 | loss: 0.55624 | train_auc: 0.78805 | train_recall: 0.57137 | valid_auc: 0.77056 | valid_recall: 0.55075 | 0:00:57s epoch 42 | loss: 0.56028 | train_auc: 0.78509 | train_recall: 0.6028 | valid_auc: 0.76955 | valid_recall: 0.58191 | 0:00:58s epoch 43 | loss: 0.56235 | train_auc: 0.7891 | train_recall: 0.55651 | valid_auc: 0.77126 | valid_recall: 0.5407 | 0:00:59s Early stopping occurred at epoch 43 with best_epoch = 23 and best_valid_recall = 0.6191 Best weights from best epoch are automatically used!
history_df = pd.DataFrame(tabnet.history.history)
history_df.head(10)
loss | lr | train_auc | train_recall | valid_auc | valid_recall | |
---|---|---|---|---|---|---|
0 | 0.698671 | 0.02 | 0.614612 | 0.378902 | 0.622316 | 0.372864 |
1 | 0.623422 | 0.02 | 0.705383 | 0.515393 | 0.690526 | 0.487437 |
2 | 0.599019 | 0.02 | 0.717768 | 0.516254 | 0.716672 | 0.486432 |
3 | 0.596287 | 0.02 | 0.734278 | 0.526803 | 0.727671 | 0.494472 |
4 | 0.583834 | 0.02 | 0.751256 | 0.527234 | 0.746931 | 0.497487 |
5 | 0.580650 | 0.02 | 0.757071 | 0.527664 | 0.751774 | 0.496482 |
6 | 0.584366 | 0.02 | 0.760852 | 0.536706 | 0.753551 | 0.504523 |
7 | 0.580403 | 0.02 | 0.762065 | 0.546825 | 0.753458 | 0.521608 |
8 | 0.576649 | 0.02 | 0.769265 | 0.526588 | 0.758850 | 0.493467 |
9 | 0.570478 | 0.02 | 0.760208 | 0.532831 | 0.751546 | 0.506533 |
history_df["loss"].plot(
title="Loss over epochs",
xlabel="epochs",
ylabel="loss"
)
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_7")
(
history_df[["train_auc", "valid_auc"]]
.plot(title="AUC over epochs",
xlabel="epochs",
ylabel="AUC")
);
plt.tight_layout()
sns.despine()
(
history_df[["train_recall", "valid_recall"]]
.plot(title="Recall over epochs",
xlabel="epochs",
ylabel="recall")
);
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_8")
y_pred = tabnet.predict(X_test.values)
print(f"Best validation score: {tabnet.best_cost:.4f}")
print(f"Test set score: {recall_score(y_test, y_pred):.4f}")
Best validation score: 0.6191 Test set score: 0.6275
tabnet_feat_imp = pd.Series(tabnet.feature_importances_,
index=X_train.columns)
(
tabnet_feat_imp
.nlargest(20)
.sort_values()
.plot(kind="barh",
title="TabNet's feature importances")
)
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_9")
np.sum(tabnet.feature_importances_)
1.0
explain_matrix, masks = tabnet.explain(X_test.values)
fig, axs = plt.subplots(1, 3)
for i in range(3):
axs[i].imshow(masks[i][:50])
axs[i].set_title(f"mask {i}")
explain_matrix.shape
(4500, 23)
X_test.shape
(4500, 23)
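One way to summarize the per-observation explanations into test-set-level importances is to normalize each row of the explanation matrix and average. This is only a sketch; note that tabnet.feature_importances_ is computed on the training data, so the numbers will not match exactly:
# a sketch - aggregate the local explanations into test-set-level importances
local_importances = explain_matrix / explain_matrix.sum(axis=1, keepdims=True)
pd.Series(local_importances.mean(axis=0), index=features).nlargest(5)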
# save tabnet model
MODEL_PATH = "models/tabnet_model"
saved_filepath = tabnet.save_model(MODEL_PATH)
# define new model with basic parameters and load state dict weights
loaded_tabnet = TabNetClassifier()
loaded_tabnet.load_model(saved_filepath)
Successfully saved model at models/tabnet_model.zip Device used : cpu Device used : cpu
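As a sanity check (a sketch), the reloaded model should reproduce the original model's predictions:
# a sketch - verify that the reloaded model produces identical predictions
np.allclose(tabnet.predict_proba(X_test.values),
            loaded_tabnet.predict_proba(X_test.values))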
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import pandas as pd
import torch
import yfinance as yf
from random import sample, seed
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_forecasting import DeepAR, TimeSeriesDataSet
df = pd.read_html(
"https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
)
df = df[0]
seed(44)
sampled_tickers = sample(df["Symbol"].to_list(), 100)
raw_df = yf.download(sampled_tickers,
start="2020-01-01",
end="2021-12-31")
[*********************100%***********************] 100 of 100 completed 2 Failed downloads: - BF.B: No data found for this date range, symbol may be delisted - CEG: Data doesn't exist for startDate = 1577833200, endDate = 1640905200
df = raw_df["Adj Close"]
df = df.loc[:, ~df.isna().any()]
selected_tickers = df.columns
df.head()
ABC | ABMD | ADM | AJG | ALB | ALGN | ALL | ANET | APD | AVB | ... | TDY | TFX | VRSK | WAB | WBD | WDC | WELL | WTW | XRAY | XYL | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Date | |||||||||||||||||||||
2019-12-31 00:00:00 | 81.503708 | 170.589996 | 43.140732 | 91.528305 | 70.897469 | 279.040009 | 105.060913 | 50.849998 | 220.630493 | 190.954117 | ... | 346.540009 | 372.275787 | 146.726288 | 76.436058 | 32.740002 | 62.155804 | 73.971069 | 194.330948 | 55.088299 | 76.240341 |
2020-01-02 00:00:00 | 81.561241 | 168.809998 | 42.917355 | 91.797424 | 70.480080 | 283.679993 | 105.406624 | 51.180000 | 217.015732 | 188.714035 | ... | 357.489990 | 374.303070 | 148.475128 | 79.530838 | 32.220001 | 64.771545 | 72.487686 | 196.582764 | 55.419270 | 77.266037 |
2020-01-03 00:00:00 | 80.535492 | 166.820007 | 42.833588 | 91.605194 | 69.470596 | 280.440002 | 105.415955 | 50.212502 | 212.189835 | 190.526108 | ... | 360.049988 | 370.475891 | 149.919418 | 78.921715 | 32.029999 | 63.774597 | 73.763054 | 196.630890 | 54.805992 | 77.720825 |
2020-01-06 00:00:00 | 81.714607 | 179.039993 | 42.498516 | 92.028084 | 69.392937 | 285.880005 | 105.724274 | 50.715000 | 212.095917 | 190.844849 | ... | 358.630005 | 374.797546 | 150.263290 | 78.597488 | 31.959999 | 62.550629 | 74.884636 | 196.871460 | 55.107769 | 77.217651 |
2020-01-07 00:00:00 | 81.129852 | 180.350006 | 41.986595 | 91.038124 | 70.305367 | 283.059998 | 104.818016 | 51.212502 | 212.997253 | 186.692459 | ... | 361.089996 | 374.223938 | 151.520905 | 78.568024 | 32.070000 | 66.785164 | 74.396202 | 196.467316 | 55.399799 | 76.927368 |
5 rows × 97 columns
df = df.reset_index(drop=False)
df = (
pd.melt(df,
id_vars=["Date"],
value_vars=selected_tickers,
value_name="price"
).rename(columns={"variable": "ticker"})
)
df["time_idx"] = df.groupby("ticker").cumcount()
df
Date | ticker | price | time_idx | |
---|---|---|---|---|
0 | 2019-12-31 | ABC | 81.503708 | 0 |
1 | 2020-01-02 | ABC | 81.561241 | 1 |
2 | 2020-01-03 | ABC | 80.535492 | 2 |
3 | 2020-01-06 | ABC | 81.714607 | 3 |
4 | 2020-01-07 | ABC | 81.129852 | 4 |
... | ... | ... | ... | ... |
48980 | 2021-12-23 | XYL | 116.300835 | 500 |
48981 | 2021-12-27 | XYL | 117.082787 | 501 |
48982 | 2021-12-28 | XYL | 118.300217 | 502 |
48983 | 2021-12-29 | XYL | 118.141861 | 503 |
48984 | 2021-12-30 | XYL | 117.884506 | 504 |
48985 rows × 4 columns
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 48985 entries, 0 to 48984 Data columns (total 4 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Date 48985 non-null datetime64[ns] 1 ticker 48985 non-null object 2 price 48985 non-null float64 3 time_idx 48985 non-null int64 dtypes: datetime64[ns](1), float64(1), int64(1), object(1) memory usage: 1.5+ MB
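A quick sanity check (a sketch) that, after dropping tickers with missing prices, every remaining ticker has the same number of observations:
# a sketch - the number of unique series lengths should be 1
df.groupby("ticker")["time_idx"].size().nunique()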
MAX_ENCODER_LENGTH = 40
MAX_PRED_LENGTH = 20
BATCH_SIZE = 128
MAX_EPOCHS = 30
training_cutoff = df["time_idx"].max() - MAX_PRED_LENGTH
train_set = TimeSeriesDataSet(
df[lambda x: x["time_idx"] <= training_cutoff],
time_idx="time_idx",
target="price",
group_ids=["ticker"],
time_varying_unknown_reals=["price"],
max_encoder_length=MAX_ENCODER_LENGTH,
max_prediction_length=MAX_PRED_LENGTH,
)
valid_set = TimeSeriesDataSet.from_dataset(
train_set, df, min_prediction_idx=training_cutoff+1
)
train_dataloader = train_set.to_dataloader(
train=True, batch_size=BATCH_SIZE
)
valid_dataloader = valid_set.to_dataloader(
train=False, batch_size=BATCH_SIZE
)
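To see what the dataloader actually yields, we can peek at a single batch. A sketch, assuming the standard pytorch_forecasting batch format of a feature dictionary plus a (target, weight) tuple:
# a sketch - inspect the shapes of the encoder/decoder targets in one batch
x_batch, y_batch = next(iter(train_dataloader))
print(x_batch["encoder_target"].shape, x_batch["decoder_target"].shape)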
pl.seed_everything(42)
deep_ar = DeepAR.from_dataset(
train_set,
learning_rate=1e-2,
hidden_size=30,
rnn_layers=4
)
trainer = pl.Trainer(gradient_clip_val=1e-1)
res = trainer.tuner.lr_find(
deep_ar,
train_dataloaders=train_dataloader,
val_dataloaders=valid_dataloader,
min_lr=1e-5,
max_lr=1e0,
early_stop_threshold=100,
)
print(f"Suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_11")
Global seed set to 42 GPU available: True (mps), used: False TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs Finding best initial lr: 100%|██████████| 100/100 [00:15<00:00, 6.69it/s]`Trainer.fit` stopped: `max_steps=100` reached. Finding best initial lr: 100%|██████████| 100/100 [00:15<00:00, 6.52it/s] Restoring states from the checkpoint path at /Users/eryk/Documents/eryk/python_for_finance_2nd_private/15_deep_learning_in_finance/.lr_find_843d0230-d9a9-4dff-a7ba-6ab740ca331b.ckpt
Suggested learning rate: 5.0118723362727245e-05
pl.seed_everything(42)
deep_ar.hparams.learning_rate = res.suggestion()
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=1e-4,
patience=10
)
trainer = pl.Trainer(
max_epochs=MAX_EPOCHS,
gradient_clip_val=0.1,
callbacks=[early_stop_callback]
)
trainer.fit(
deep_ar,
train_dataloaders=train_dataloader,
val_dataloaders=valid_dataloader,
)
Global seed set to 42 GPU available: True (mps), used: False TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs Missing logger folder: /Users/eryk/Documents/eryk/python_for_finance_2nd_private/15_deep_learning_in_finance/lightning_logs | Name | Type | Params ------------------------------------------------------------------ 0 | loss | NormalDistributionLoss | 0 1 | logging_metrics | ModuleList | 0 2 | embeddings | MultiEmbedding | 0 3 | rnn | LSTM | 26.3 K 4 | distribution_projector | Linear | 62 ------------------------------------------------------------------ 26.3 K Trainable params 0 Non-trainable params 26.3 K Total params 0.105 Total estimated model params size (MB)
Epoch 29: 100%|██████████| 323/323 [00:49<00:00, 6.55it/s, loss=2.5, v_num=0, train_loss_step=2.400, val_loss=2.380, train_loss_epoch=2.510]
`Trainer.fit` stopped: `max_epochs=30` reached.
best_model = DeepAR.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path
)
raw_predictions, x = best_model.predict(
valid_dataloader,
mode="raw",
return_x=True,
n_samples=100
)
tickers = valid_set.x_to_index(x)["ticker"]
for idx in range(5):
best_model.plot_prediction(
x, raw_predictions, idx=idx, add_loss_to_title=True
)
plt.suptitle(f"Ticker: {tickers.iloc[idx]}")
plt.tight_layout()
sns.despine()
# plt.savefig(f"images/figure_15_12_{idx}")
from pytorch_forecasting.metrics import MultivariateNormalDistributionLoss
import seaborn as sns
import numpy as np
# df = generate_ar_data(
# seasonality=10.0,
# timesteps=len(raw_df),
# n_series=len(selected_tickers),
# seed=42
# )
# df["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(df.time_idx, "D")
# df = df.astype(dict(series=str))
# df.columns = ["ticker", "time_idx", "price", "date"]
# df
For the DeepVAR variant, we recreate the dataset and the dataloaders, this time adding the ticker as a static categorical feature and using a synchronized batch_sampler:
train_set = TimeSeriesDataSet(
df[lambda x: x["time_idx"] <= training_cutoff],
time_idx="time_idx",
target="price",
group_ids=["ticker"],
static_categoricals=["ticker"],
time_varying_unknown_reals=["price"],
max_encoder_length=MAX_ENCODER_LENGTH,
max_prediction_length=MAX_PRED_LENGTH,
)
valid_set = TimeSeriesDataSet.from_dataset(
train_set, df, min_prediction_idx=training_cutoff+1
)
train_dataloader = train_set.to_dataloader(
train=True,
batch_size=BATCH_SIZE,
batch_sampler="synchronized"
)
valid_dataloader = valid_set.to_dataloader(
train=False,
batch_size=BATCH_SIZE,
batch_sampler="synchronized"
)
pl.seed_everything(42)
deep_var = DeepAR.from_dataset(
train_set,
learning_rate=1e-2,
hidden_size=30,
rnn_layers=4,
loss=MultivariateNormalDistributionLoss()
)
trainer = pl.Trainer(gradient_clip_val=1e-1)
res = trainer.tuner.lr_find(
deep_var,
train_dataloaders=train_dataloader,
val_dataloaders=valid_dataloader,
min_lr=1e-5,
max_lr=1e0,
early_stop_threshold=100,
)
print(f"Suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
Global seed set to 42 GPU available: True (mps), used: False TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs Finding best initial lr: 100%|██████████| 100/100 [00:12<00:00, 8.32it/s]`Trainer.fit` stopped: `max_steps=100` reached. Finding best initial lr: 100%|██████████| 100/100 [00:12<00:00, 8.25it/s] Restoring states from the checkpoint path at /Users/eryk/Documents/eryk/python_for_finance_2nd_private/15_deep_learning_in_finance/.lr_find_738ad47e-9c6c-4cf2-845a-96c7059437c6.ckpt
Suggested learning rate: 8.912509381337456e-05
pl.seed_everything(42)
deep_var.hparams.learning_rate = res.suggestion()
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=1e-4,
patience=10
)
trainer = pl.Trainer(
max_epochs=MAX_EPOCHS,
gradient_clip_val=0.1,
callbacks=[early_stop_callback]
)
trainer.fit(
deep_var,
train_dataloaders=train_dataloader,
val_dataloaders=valid_dataloader,
)
Global seed set to 42 GPU available: True (mps), used: False TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs | Name | Type | Params ------------------------------------------------------------------------------ 0 | loss | MultivariateNormalDistributionLoss | 0 1 | logging_metrics | ModuleList | 0 2 | embeddings | MultiEmbedding | 2.0 K 3 | rnn | LSTM | 28.8 K 4 | distribution_projector | Linear | 372 ------------------------------------------------------------------------------ 31.2 K Trainable params 0 Non-trainable params 31.2 K Total params 0.125 Total estimated model params size (MB)
Epoch 29: 100%|██████████| 427/427 [00:57<00:00, 7.43it/s, loss=222, v_num=1, train_loss_step=248.0, val_loss=196.0, train_loss_epoch=214.0]
`Trainer.fit` stopped: `max_epochs=30` reached.
best_model = DeepAR.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path
)
raw_predictions, x = best_model.predict(
valid_dataloader,
mode="raw",
return_x=True,
n_samples=100
)
tickers = valid_set.x_to_index(x)["ticker"]
for idx in range(5):
best_model.plot_prediction(
x, raw_predictions, idx=idx, add_loss_to_title=True
)
plt.suptitle(f"Ticker: {tickers.iloc[idx]}")
preds = best_model.predict(valid_dataloader,
mode=("raw", "prediction"),
n_samples=None)
cov_matrix = (
best_model
.loss
.map_x_to_distribution(preds)
.base_dist
.covariance_matrix
.mean(0)
)
# convert the covariance matrix into a correlation matrix (normalize so the diagonal is 1.0)
cov_diag_mult = (
torch.diag(cov_matrix)[None] * torch.diag(cov_matrix)[None].T
)
corr_matrix = cov_matrix / torch.sqrt(cov_diag_mult)
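A quick sanity check (a sketch) that the diagonal of the resulting correlation matrix is approximately all ones:
# a sketch - the diagonal of a correlation matrix should be ~1
torch.allclose(torch.diag(corr_matrix),
               torch.ones(corr_matrix.shape[0]), atol=1e-4)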
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
fig, ax = plt.subplots()
cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(
corr_matrix, mask=mask, cmap=cmap,
vmax=.3, center=0, square=True,
linewidths=.5, cbar_kws={"shrink": .5}
)
ax.set_title("Correlation matrix")
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_14")
# distribution of off-diagonal correlations
plt.hist(corr_matrix[corr_matrix < 1].numpy())
plt.xlabel("Correlation")
plt.ylabel("Count")
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_15")
import yfinance as yf
import pandas as pd
from neuralprophet import NeuralProphet
from neuralprophet.utils import set_random_seed
from neuralprophet.utils import set_log_level
df = yf.download("^GSPC",
start="2010-01-01",
end="2021-12-31")
df = df[["Adj Close"]].reset_index(drop=False)
df.columns = ["ds", "y"]
[*********************100%***********************] 1 of 1 completed
TEST_LENGTH = 60
df_train = df.iloc[:-TEST_LENGTH]
df_test = df.iloc[-TEST_LENGTH:]
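A quick sketch to inspect the date ranges covered by the training and test sets:
# a sketch - print the date ranges of the two sets
print(f"Train: {df_train['ds'].min().date()} - {df_train['ds'].max().date()}")
print(f"Test: {df_test['ds'].min().date()} - {df_test['ds'].max().date()}")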
set_random_seed(42)
set_log_level(log_level="ERROR")
model = NeuralProphet(changepoints_range=0.95)
metrics = model.fit(df_train, freq="B")
(
metrics
.drop(columns=["RegLoss"])
.plot(title="Evaluation metrics during training",
subplots=True,
xlabel="epochs",
ylabel="metric")
)
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_16")
pred_df = model.predict(df)
pred_df.plot(x="ds", y=["y", "yhat1"],
title="S&P 500 - forecast vs ground truth",
ylabel="value");
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_17")
(
pred_df
.iloc[-TEST_LENGTH:]
.plot(x="ds", y=["y", "yhat1"],
title="S&P 500 - forecast vs ground truth",
ylabel="value")
);
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_18")
model.test(df_test)
SmoothL1Loss | MAE | RMSE | |
---|---|---|---|
0 | 0.000336 | 65.007812 | 74.876572 |
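As a cross-check (a sketch), the reported MAE should roughly match the mean absolute error computed manually on the test-period rows of the forecast DataFrame:
# a sketch - recompute the MAE on the last TEST_LENGTH rows of pred_df
test_part = pred_df.iloc[-TEST_LENGTH:]
print((test_part["y"] - test_part["yhat1"]).abs().mean())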
set_random_seed(42)
set_log_level(log_level="ERROR")
model = NeuralProphet(
changepoints_range=0.95,
n_lags=10,
ar_reg=1,
)
metrics = model.fit(df_train, freq="B")
pred_df = model.predict(df)
pred_df.plot(x="ds", y=["y", "yhat1"],
title="S&P 500 - forecast vs ground truth",
ylabel="value");
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_19")
(
pred_df
.iloc[-TEST_LENGTH:]
.plot(x="ds", y=["y", "yhat1"],
title="S&P 500 - forecast vs ground truth",
ylabel="value")
);
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_20")
set_random_seed(42)
set_log_level(log_level="ERROR")
model = NeuralProphet(
changepoints_range=0.95,
n_lags=10,
ar_reg=1,
num_hidden_layers=3,
d_hidden=32,
)
metrics = model.fit(df_train, freq="B")
pred_df = model.predict(df)
(
pred_df
.iloc[-TEST_LENGTH:]
.plot(x="ds", y=["y", "yhat1"],
title="S&P 500 - forecast vs ground truth",
ylabel="value")
);
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_21")
model.test(df_test)
SmoothL1Loss | MAE | RMSE | |
---|---|---|---|
0 | 0.000311 | 65.459953 | 72.100792 |
# temporarily restore matplotlib's default settings for plotting only, as there is an issue with the AR plot
# after plotting the components, we can revert to the settings from the top of the Notebook
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)
model.plot_components(model.predict(df_train));
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_22")
model.plot_parameters();
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_23")
set_random_seed(42)
set_log_level(log_level="ERROR")
model = NeuralProphet(
changepoints_range=0.95,
n_lags=10,
ar_reg=1,
num_hidden_layers=3,
d_hidden=32,
)
model = model.add_country_holidays(
"US", lower_window=-1, upper_window=1
)
metrics = model.fit(df_train, freq="B")
pred_df = model.predict(df_train)
model.plot_components(pred_df)
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_24")
model.plot_parameters();
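Having generated the component and parameter plots, we can revert to the plotting settings from the top of the Notebook (a sketch simply repeating the earlier call):
# revert to the plotting settings defined at the top of the Notebook
sns.set_theme(context="talk", style="whitegrid",
              palette="colorblind", color_codes=True,
              rc={"figure.figsize": [12, 8]})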
set_random_seed(42)
set_log_level(log_level="ERROR")
model = NeuralProphet(
n_lags=10,
n_forecasts=10,
ar_reg=1,
learning_rate=0.01
)
metrics = model.fit(df_train, freq="B")
pred_df = model.predict(df)
pred_df.tail()
ds | y | yhat1 | residual1 | yhat2 | residual2 | yhat3 | residual3 | yhat4 | residual4 | ... | ar4 | ar5 | ar6 | ar7 | ar8 | ar9 | ar10 | trend | season_yearly | season_weekly | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
3126 | 2021-12-24 | 4758.489990 | 4650.537109 | -107.952881 | 4658.835938 | -99.654053 | 4628.593262 | -129.896729 | 4628.078125 | -130.411865 | ... | 3135.196533 | 3184.866943 | 3190.829102 | 3203.072998 | 3161.448975 | 3186.029541 | 3196.741699 | 1177.05127 | -3.653834 | 319.484009 |
3127 | 2021-12-27 | 4791.189941 | 4678.010742 | -113.179199 | 4684.404785 | -106.785156 | 4667.214355 | -123.975586 | 4644.814941 | -146.375 | ... | 3152.4646 | 3068.444092 | 3219.397217 | 3183.488525 | 3173.246338 | 3146.842773 | 3235.819336 | 1177.903931 | -3.819474 | 318.265717 |
3128 | 2021-12-28 | 4786.350098 | 4733.727051 | -52.623047 | 4726.821777 | -59.52832 | 4719.004395 | -67.345703 | 4654.171875 | -132.178223 | ... | 3160.530518 | 3120.293457 | 3135.982178 | 3160.213867 | 3240.533203 | 3186.236328 | 3165.561279 | 1178.18811 | -3.757542 | 319.211151 |
3129 | 2021-12-29 | 4793.060059 | 4743.061035 | -49.999023 | 4752.594238 | -40.46582 | 4732.162598 | -60.897461 | 4711.364746 | -81.695312 | ... | 3217.064453 | 3193.08374 | 3127.438965 | 3115.985596 | 3198.358643 | 3185.796875 | 3191.491699 | 1178.472412 | -3.634705 | 319.46283 |
3130 | 2021-12-30 | 4778.729980 | 4739.231934 | -39.498047 | 4743.820801 | -34.90918 | 4716.883789 | -61.846191 | 4736.578125 | -42.151855 | ... | 3242.431152 | 3233.673828 | 3179.129639 | 3142.590088 | 3121.921387 | 3162.443604 | 3235.914551 | 1178.756592 | -3.450272 | 318.840424 |
5 rows × 35 columns
# set_random_seed(42)
pred_df = model.predict(df, raw=True, decompose=False)
pred_df.tail().round(2)
ds | step0 | step1 | step2 | step3 | step4 | step5 | step6 | step7 | step8 | step9 | |
---|---|---|---|---|---|---|---|---|---|---|---|
3116 | 2021-12-24 | 4650.540039 | 4684.399902 | 4719.000000 | 4711.359863 | 4727.819824 | 4667.189941 | 4714.819824 | 4683.540039 | 4706.020020 | 4724.060059 |
3117 | 2021-12-27 | 4678.009766 | 4726.819824 | 4732.160156 | 4736.580078 | 4733.509766 | 4689.540039 | 4733.549805 | 4748.560059 | 4753.779785 | 4767.770020 |
3118 | 2021-12-28 | 4733.729980 | 4752.589844 | 4716.879883 | 4745.649902 | 4740.490234 | 4727.629883 | 4744.220215 | 4767.180176 | 4777.370117 | 4781.660156 |
3119 | 2021-12-29 | 4743.060059 | 4743.819824 | 4702.350098 | 4733.750000 | 4747.859863 | 4734.129883 | 4760.209961 | 4742.779785 | 4779.910156 | 4803.180176 |
3120 | 2021-12-30 | 4739.229980 | 4752.520020 | 4730.939941 | 4760.689941 | 4776.040039 | 4763.520020 | 4772.279785 | 4735.459961 | 4766.009766 | 4802.129883 |
pred_df = model.predict(df_test)
model.plot(pred_df)
ax = plt.gca()
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ax.set_title("10-day ahead multi-step forecast")
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_27")
model = model.highlight_nth_step_ahead_of_each_forecast(1)
model.plot(pred_df)
ax = plt.gca()
ax.set_title("Step 1 of the 10-day ahead multi-step forecast")
plt.tight_layout()
sns.despine()
# plt.savefig("images/figure_15_28")