import math
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from tensorflow.keras.callbacks import EarlyStopping
from tabtransformertf.models.tabtransformer import TabTransformer
from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep
# Column names for the UCI Adult census dataset — the raw files ship with
# no header row, so these are supplied to pandas explicitly.
CSV_HEADER = (
    "age workclass fnlwgt education education_num marital_status occupation "
    "relationship race gender capital_gain capital_loss hours_per_week "
    "native_country income_bracket"
).split()
# Download the Adult train/test splits straight from the UCI repository.
# header=None because the raw CSVs contain no header line.
_UCI_ADULT_BASE = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"

train_data_url = _UCI_ADULT_BASE + "adult.data"
train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)

test_data_url = _UCI_ADULT_BASE + "adult.test"
test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)

print(f"Train dataset shape: {train_data.shape}")
print(f"Test dataset shape: {test_data.shape}")
Train dataset shape: (32561, 15) Test dataset shape: (16282, 15)
# Sanity check: preview the first rows of the freshly loaded training frame.
train_data.head()
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_bracket | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
1 | 50 | Self-emp-not-inc | 83311 | Bachelors | 13 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 13 | United-States | <=50K |
2 | 38 | Private | 215646 | HS-grad | 9 | Divorced | Handlers-cleaners | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
3 | 53 | Private | 234721 | 11th | 7 | Married-civ-spouse | Handlers-cleaners | Husband | Black | Male | 0 | 0 | 40 | United-States | <=50K |
4 | 28 | Private | 338409 | Bachelors | 13 | Married-civ-spouse | Prof-specialty | Wife | Black | Female | 0 | 0 | 40 | Cuba | <=50K |
# Column information: infer feature roles from dtypes. Every non-numeric
# column except the last one is categorical — the final non-numeric column
# is the label ('income_bracket'), which must be excluded from the features.
NUMERIC_FEATURES = train_data.select_dtypes(include=np.number).columns
CATEGORICAL_FEATURES = train_data.select_dtypes(exclude=np.number).columns[:-1]
FEATURES = [*NUMERIC_FEATURES, *CATEGORICAL_FEATURES]
LABEL = 'income_bracket'

# Binarise the target with a vectorised comparison. NOTE: the raw values
# carry a leading space, and the test file spells the positive class with a
# trailing period — hence the two different match strings.
train_data[LABEL] = (train_data[LABEL] == ' >50K').astype(int)
test_data[LABEL] = (test_data[LABEL] == ' >50K.').astype(int)

# Positive-class rate in each split (~24%).
train_data[LABEL].mean(), test_data[LABEL].mean()
(0.2408095574460244, 0.23621176759611842)
# The first row of adult.test is not a valid record (a comment-like line in
# the raw file — TODO confirm against the UCI source); drop it before casting.
test_data = test_data.iloc[1:, :]

# Normalise dtypes on both frames: categoricals as str, numerics as float.
for _frame in (train_data, test_data):
    _frame[CATEGORICAL_FEATURES] = _frame[CATEGORICAL_FEATURES].astype(str)
    _frame[NUMERIC_FEATURES] = _frame[NUMERIC_FEATURES].astype(float)
# Train/validation split. Fix over the original: the split was unseeded
# (irreproducible run-to-run) and unstratified. A fixed random_state makes
# results repeatable, and stratifying on the label keeps the ~24% positive
# rate identical in both folds of this imbalanced dataset.
X_train, X_val = train_test_split(
    train_data,
    test_size=0.2,
    random_state=42,
    stratify=train_data[LABEL],
)

# Build the category -> integer lookup layers from the training fold only,
# so no validation-set categories influence preprocessing.
category_prep_layers = build_categorical_prep(X_train, CATEGORICAL_FEATURES)
0%| | 0/8 [00:00<?, ?it/s]2022-10-19 17:30:17.543008: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 130.63it/s]
# Convert DataFrames to tf.data pipelines consumable by the model.
train_dataset = df_to_dataset(X_train[FEATURES + [LABEL]], LABEL)
val_dataset = df_to_dataset(X_val[FEATURES + [LABEL]], LABEL, shuffle=False)  # keep order for validation
# Fix: the original passed the label column into the test feature dict even
# though no target was given; select FEATURES only so the label can never
# leak into the model inputs at prediction time. No target, no shuffle.
test_dataset = df_to_dataset(test_data[FEATURES], shuffle=False)
/Users/antonsruberts/personal/TabTransformerTF/tabtransformertf/utils/preprocessing.py:20: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. Convert to a numpy array before indexing instead. dataset[key] = value[:, tf.newaxis] /Users/antonsruberts/personal/TabTransformerTF/tabtransformertf/utils/preprocessing.py:26: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. Convert to a numpy array before indexing instead. dataset[key] = value[:, tf.newaxis]
# TabTransformer hyper-parameters gathered in one dict for easy tweaking.
_MODEL_KWARGS = dict(
    numerical_features=NUMERIC_FEATURES,
    categorical_features=CATEGORICAL_FEATURES,
    categorical_lookup=category_prep_layers,
    embedding_dim=32,           # width of each categorical embedding
    out_dim=1,                  # single output unit for binary classification
    out_activation='sigmoid',
    depth=4,                    # number of transformer blocks
    heads=8,                    # attention heads per block
    attn_dropout=0.2,
    ff_dropout=0.2,
    mlp_hidden_factors=[2, 4],
    use_column_embedding=True,
)
tabtransformer = TabTransformer(**_MODEL_KWARGS)
# Optimisation settings.
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0.0001
NUM_EPOCHS = 1000  # upper bound only — early stopping ends training far sooner

# AdamW decouples weight decay from the gradient update.
optimizer = tfa.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

tabtransformer.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.BinaryCrossentropy(),
    # Track area under the precision-recall curve during training.
    metrics=[tf.keras.metrics.AUC(name="PR AUC", curve='PR')],
)

# Halt when val_loss stops improving for 10 epochs; roll back to best weights.
early = EarlyStopping(monitor="val_loss", mode="min", patience=10, restore_best_weights=True)
callback_list = [early]
# Train the model; EarlyStopping (in callback_list) decides the real length.
history = tabtransformer.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=NUM_EPOCHS,
    callbacks=callback_list,
)
Epoch 1/1000 51/51 [==============================] - 14s 178ms/step - loss: 0.7065 - PR AUC: 0.4247 - val_loss: 0.5152 - val_PR AUC: 0.6368 Epoch 2/1000 51/51 [==============================] - 9s 174ms/step - loss: 0.5650 - PR AUC: 0.5547 - val_loss: 0.4535 - val_PR AUC: 0.6604 Epoch 3/1000 51/51 [==============================] - 9s 179ms/step - loss: 0.5011 - PR AUC: 0.5835 - val_loss: 0.3856 - val_PR AUC: 0.6733 Epoch 4/1000 51/51 [==============================] - 10s 188ms/step - loss: 0.4601 - PR AUC: 0.6010 - val_loss: 0.3788 - val_PR AUC: 0.6787 Epoch 5/1000 51/51 [==============================] - 10s 185ms/step - loss: 0.4330 - PR AUC: 0.6210 - val_loss: 0.3675 - val_PR AUC: 0.6877 Epoch 6/1000 51/51 [==============================] - 9s 183ms/step - loss: 0.4225 - PR AUC: 0.6281 - val_loss: 0.3606 - val_PR AUC: 0.6971 Epoch 7/1000 51/51 [==============================] - 9s 181ms/step - loss: 0.4145 - PR AUC: 0.6384 - val_loss: 0.3563 - val_PR AUC: 0.7069 Epoch 8/1000 51/51 [==============================] - 11s 219ms/step - loss: 0.4062 - PR AUC: 0.6485 - val_loss: 0.3548 - val_PR AUC: 0.7127 Epoch 9/1000 51/51 [==============================] - 11s 204ms/step - loss: 0.3958 - PR AUC: 0.6615 - val_loss: 0.3484 - val_PR AUC: 0.7232 Epoch 10/1000 51/51 [==============================] - 9s 169ms/step - loss: 0.3979 - PR AUC: 0.6554 - val_loss: 0.3526 - val_PR AUC: 0.7249 Epoch 11/1000 51/51 [==============================] - 9s 166ms/step - loss: 0.3904 - PR AUC: 0.6665 - val_loss: 0.3466 - val_PR AUC: 0.7339 Epoch 12/1000 51/51 [==============================] - 8s 158ms/step - loss: 0.3881 - PR AUC: 0.6713 - val_loss: 0.3427 - val_PR AUC: 0.7371 Epoch 13/1000 51/51 [==============================] - 8s 160ms/step - loss: 0.3841 - PR AUC: 0.6744 - val_loss: 0.3408 - val_PR AUC: 0.7390 Epoch 14/1000 51/51 [==============================] - 8s 158ms/step - loss: 0.3812 - PR AUC: 0.6764 - val_loss: 0.3416 - val_PR AUC: 0.7415 Epoch 15/1000 51/51 
[==============================] - 10s 188ms/step - loss: 0.3808 - PR AUC: 0.6787 - val_loss: 0.3386 - val_PR AUC: 0.7436 Epoch 16/1000 51/51 [==============================] - 8s 160ms/step - loss: 0.3782 - PR AUC: 0.6795 - val_loss: 0.3384 - val_PR AUC: 0.7431 Epoch 17/1000 51/51 [==============================] - 9s 175ms/step - loss: 0.3775 - PR AUC: 0.6829 - val_loss: 0.3373 - val_PR AUC: 0.7435 Epoch 18/1000 51/51 [==============================] - 9s 167ms/step - loss: 0.3749 - PR AUC: 0.6863 - val_loss: 0.3370 - val_PR AUC: 0.7445 Epoch 19/1000 51/51 [==============================] - 9s 174ms/step - loss: 0.3728 - PR AUC: 0.6910 - val_loss: 0.3364 - val_PR AUC: 0.7447 Epoch 20/1000 51/51 [==============================] - 10s 184ms/step - loss: 0.3731 - PR AUC: 0.6880 - val_loss: 0.3369 - val_PR AUC: 0.7449 Epoch 21/1000 51/51 [==============================] - 8s 158ms/step - loss: 0.3688 - PR AUC: 0.6929 - val_loss: 0.3371 - val_PR AUC: 0.7442 Epoch 22/1000 51/51 [==============================] - 8s 161ms/step - loss: 0.3694 - PR AUC: 0.6930 - val_loss: 0.3382 - val_PR AUC: 0.7450 Epoch 23/1000 51/51 [==============================] - 9s 179ms/step - loss: 0.3650 - PR AUC: 0.7002 - val_loss: 0.3377 - val_PR AUC: 0.7465 Epoch 24/1000 51/51 [==============================] - 9s 171ms/step - loss: 0.3661 - PR AUC: 0.6969 - val_loss: 0.3373 - val_PR AUC: 0.7450 Epoch 25/1000 51/51 [==============================] - 9s 171ms/step - loss: 0.3651 - PR AUC: 0.7005 - val_loss: 0.3354 - val_PR AUC: 0.7464 Epoch 26/1000 51/51 [==============================] - 9s 174ms/step - loss: 0.3616 - PR AUC: 0.7032 - val_loss: 0.3355 - val_PR AUC: 0.7460 Epoch 27/1000 51/51 [==============================] - 9s 166ms/step - loss: 0.3619 - PR AUC: 0.7015 - val_loss: 0.3381 - val_PR AUC: 0.7448 Epoch 28/1000 51/51 [==============================] - 8s 156ms/step - loss: 0.3573 - PR AUC: 0.7084 - val_loss: 0.3369 - val_PR AUC: 0.7460 Epoch 29/1000 51/51 
[==============================] - 8s 156ms/step - loss: 0.3567 - PR AUC: 0.7085 - val_loss: 0.3383 - val_PR AUC: 0.7442 Epoch 30/1000 51/51 [==============================] - 8s 161ms/step - loss: 0.3583 - PR AUC: 0.7066 - val_loss: 0.3362 - val_PR AUC: 0.7440 Epoch 31/1000 51/51 [==============================] - 9s 169ms/step - loss: 0.3541 - PR AUC: 0.7156 - val_loss: 0.3370 - val_PR AUC: 0.7460 Epoch 32/1000 51/51 [==============================] - 9s 177ms/step - loss: 0.3558 - PR AUC: 0.7107 - val_loss: 0.3362 - val_PR AUC: 0.7467 Epoch 33/1000 51/51 [==============================] - 8s 155ms/step - loss: 0.3524 - PR AUC: 0.7147 - val_loss: 0.3347 - val_PR AUC: 0.7468 Epoch 34/1000 51/51 [==============================] - 8s 164ms/step - loss: 0.3537 - PR AUC: 0.7144 - val_loss: 0.3340 - val_PR AUC: 0.7465 Epoch 35/1000 51/51 [==============================] - 8s 158ms/step - loss: 0.3521 - PR AUC: 0.7161 - val_loss: 0.3374 - val_PR AUC: 0.7432 Epoch 36/1000 51/51 [==============================] - 8s 155ms/step - loss: 0.3518 - PR AUC: 0.7147 - val_loss: 0.3342 - val_PR AUC: 0.7463 Epoch 37/1000 51/51 [==============================] - 8s 159ms/step - loss: 0.3517 - PR AUC: 0.7166 - val_loss: 0.3341 - val_PR AUC: 0.7463 Epoch 38/1000 51/51 [==============================] - 8s 154ms/step - loss: 0.3512 - PR AUC: 0.7159 - val_loss: 0.3349 - val_PR AUC: 0.7463 Epoch 39/1000 51/51 [==============================] - 8s 159ms/step - loss: 0.3502 - PR AUC: 0.7183 - val_loss: 0.3345 - val_PR AUC: 0.7469 Epoch 40/1000 51/51 [==============================] - 8s 163ms/step - loss: 0.3476 - PR AUC: 0.7225 - val_loss: 0.3334 - val_PR AUC: 0.7474 Epoch 41/1000 51/51 [==============================] - 8s 160ms/step - loss: 0.3487 - PR AUC: 0.7185 - val_loss: 0.3347 - val_PR AUC: 0.7465 Epoch 42/1000 51/51 [==============================] - 8s 160ms/step - loss: 0.3480 - PR AUC: 0.7199 - val_loss: 0.3405 - val_PR AUC: 0.7455 Epoch 43/1000 51/51 
[==============================] - 8s 161ms/step - loss: 0.3474 - PR AUC: 0.7230 - val_loss: 0.3350 - val_PR AUC: 0.7470 Epoch 44/1000 51/51 [==============================] - 8s 156ms/step - loss: 0.3476 - PR AUC: 0.7207 - val_loss: 0.3350 - val_PR AUC: 0.7457 Epoch 45/1000 51/51 [==============================] - 8s 160ms/step - loss: 0.3463 - PR AUC: 0.7234 - val_loss: 0.3334 - val_PR AUC: 0.7472 Epoch 46/1000 51/51 [==============================] - 8s 153ms/step - loss: 0.3464 - PR AUC: 0.7251 - val_loss: 0.3334 - val_PR AUC: 0.7464 Epoch 47/1000 51/51 [==============================] - 8s 155ms/step - loss: 0.3433 - PR AUC: 0.7290 - val_loss: 0.3353 - val_PR AUC: 0.7448 Epoch 48/1000 51/51 [==============================] - 8s 159ms/step - loss: 0.3471 - PR AUC: 0.7228 - val_loss: 0.3342 - val_PR AUC: 0.7472 Epoch 49/1000 51/51 [==============================] - 8s 161ms/step - loss: 0.3443 - PR AUC: 0.7261 - val_loss: 0.3347 - val_PR AUC: 0.7460 Epoch 50/1000 51/51 [==============================] - 8s 158ms/step - loss: 0.3444 - PR AUC: 0.7261 - val_loss: 0.3368 - val_PR AUC: 0.7440 Epoch 51/1000 51/51 [==============================] - 8s 152ms/step - loss: 0.3429 - PR AUC: 0.7279 - val_loss: 0.3364 - val_PR AUC: 0.7468 Epoch 52/1000 51/51 [==============================] - 8s 154ms/step - loss: 0.3425 - PR AUC: 0.7290 - val_loss: 0.3356 - val_PR AUC: 0.7473 Epoch 53/1000 51/51 [==============================] - 8s 158ms/step - loss: 0.3438 - PR AUC: 0.7282 - val_loss: 0.3355 - val_PR AUC: 0.7459 Epoch 54/1000 51/51 [==============================] - 8s 155ms/step - loss: 0.3431 - PR AUC: 0.7265 - val_loss: 0.3369 - val_PR AUC: 0.7461 Epoch 55/1000 51/51 [==============================] - 8s 157ms/step - loss: 0.3428 - PR AUC: 0.7267 - val_loss: 0.3361 - val_PR AUC: 0.7466 Epoch 56/1000 51/51 [==============================] - 9s 170ms/step - loss: 0.3408 - PR AUC: 0.7308 - val_loss: 0.3337 - val_PR AUC: 0.7456
# Final evaluation on the held-out test split.
test_preds = tabtransformer.predict(test_dataset)
_flat_preds = test_preds.ravel()        # (n, 1) -> (n,)
_true = test_data[LABEL]
print("Test ROC AUC:", np.round(roc_auc_score(_true, _flat_preds), 4))
print("Test PR AUC:", np.round(average_precision_score(_true, _flat_preds), 4))
# Threshold probabilities at 0.5 for hard-label accuracy.
print("Test Accuracy:", np.round(accuracy_score(_true, _flat_preds > 0.5), 4))
Test ROC AUC: 0.8959 Test PR AUC: 0.7355 Test Accuracy: 0.8488