pytorch-widedeep, deep learning for tabular data IV: Deep Learning vs LightGBM

A thorough comparison between DL algorithms and LightGBM on tabular data, for classification and regression problems

• author: Javier Rodriguez

Here we go with yet another post in the series. I started planning this post a few months ago, as soon as I released what was the last beta version (0.4.8) of the library pytorch-widedeep. However, since then a few things took priority, which meant that running the hundreds of experiments that I ran (probably over 1500) took me considerably more time than I expected. Nevertheless, here we are.

Let me start by thanking the guys at the AWS Community Builders, and especially Cameron, for making my life a lot easier around AWS.

All the Deep Learning models for this project were run on a p2.xlarge instance and all the LightGBM experiments were run on my Mac Mid 2015.

With the acknowledgments made, let me tell you a bit about the context of all those experiments and, eventually, of this post.

1. Introduction: why all this?

Over the last couple of years, and in particular during the last year, I have put a lot of effort into improving pytorch-widedeep. This has been really entertaining, and I have learned a lot. However, as I was adding models to the library, especially for the tabular component (see here), I wondered if there was a purpose to it other than learning those models themselves. You see, I am a scientist by training and I spent over a decade in academia. There we used to do a lot of things that were cool (sometimes), but not very useful. One of the aspects that drove me to the private sector, a few years back now, was the search for a sense of "usefulness", where I could build things that have a scientific aspect and at the same time are useful. With that in mind, I wanted the library to be, forgive the redundancy, useful. Here the adjective "useful" can mean a number of things: using the library directly, forking the repo and using the code, or just copying and pasting some portion of the code for a given project. Eventually, however, a question that I wanted to answer was: do these models compare well with, or even improve on, the performance of other, more "standard" models like GBMs? Note that I write "a question" and not "the question". More on this later in the post.

Of course, I am not the first to compare Deep Learning (hereafter DL) approaches with GBMs for tabular data, and I won't be the last. In fact, as I am writing these lines, a new paper, Tabular Data: Deep Learning is Not All You Need [1], has been published. This post and that paper are certainly very similar, and the conclusions are entirely consistent. However, there are some differences. They compare DL algorithms against XGBoost [2] and CatBoost [3], while I use LightGBM [4] (see Section 2.3 for an explanation on the use of this algorithm). Also, I would say that three of the four datasets that I use here are a bit more challenging than the datasets in their paper, but that might be just my perception. Finally, with the exception of TabNet, the DL models I use here and those in that paper are different. Nonetheless, in the Conclusion section I will share some thoughts on how to approach these benchmarking/testing exercises.

Aside from that paper, papers that release new models often include comprehensive comparisons between DL architectures and GBMs. My main caveats with some of these publications are the following: I often do not manage to reproduce the results in the paper (which of course might be my fault), and I sometimes find that the effort placed in optimizing the DL models is a bit more "intense" than that placed in the GBMs. Last but not least, the lack of consistency across the results tables of different papers is, sometimes, confusing. For example, Paper A will introduce DL Model A and find that it performs better than all GBMs, normally XGBoost, CatBoost and LightGBM. Then Paper B will come with a new DL Model B that also performs better than all GBMs, but in their paper it turns out that Model A does not beat GBMs anymore.

Considering all that, I decided to use pytorch-widedeep and run a sizeable set of experiments comprising different DL models for tabular data and LightGBM.

Before I move on, let me comment on the code "quality" in that repo. One has to bear in mind that the goal here is to test algorithms in a rigorous manner, not to write production code. If you want to see better code, you can go to pytorch-widedeep itself or to some of my other repos. Just saying in case some "purist" is tempted to waste the universe's time.

2. Datasets and Models

For the experiments here I have used four datasets and four DL models.

2.1 Datasets

1. Adult Census (binary classification)
2. Bank Marketing (binary classification)
3. NYC taxi ride duration (regression)
4. Facebook Comment Volume (regression)

The bash script get_data.sh in the repo has all the info you need to get those datasets in case you wanted to explore them yourself. Of course, all the code used to run the experiments and reproduce the results is also available in that repo.

Here is some basic information about the datasets:

In [1]:
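#collapse-hide
import pandas as pd

# `basic_info` is assumed to be already loaded: one row per dataset with its size,
# objective and class ratio (the airbnb dataset is excluded from this post)
basic_info[basic_info.Dataset != "airbnb"].reset_index(drop=True)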

Out[1]:
Dataset n_rows n_cols objective neg_pos_ratio
0 adult 45222 15 binary_classification 0.3295
1 bank_marketing 41188 20 binary_classification 0.1270
2 nyc_taxi 1458644 26 regression NaN

Table 1. Basic information for the datasets used in this post

There are reasons why I chose these datasets.

In general, I looked for binary classification, multi-class classification and regression datasets that had a good number of categorical features, or were even dominated by them. In my experience, DL models for tabular data become more useful and competitive on sizeable datasets where categorical features are present (although [5] suggests that better results are obtained by encoding numerical features as well), and even more so if these categorical features have many categories. This is because the embeddings become more valuable: we learn representations of those categorical features that encode relationships with all other features and with the target for a specific dataset. Note that this does not happen when using GBMs. Even if one uses target encoding, there is not much of a learning element there (although it is still useful, of course).
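To make the contrast concrete, here is a minimal pandas sketch of plain (unsmoothed) target encoding on a made-up toy column, with hypothetical column names. Each category is replaced by the mean of the target for that category: a fixed statistic computed once, rather than a representation learned jointly with the rest of the model.

```python
import pandas as pd

# Hypothetical toy data: one categorical feature and a binary target
df = pd.DataFrame({
    "occupation": ["tech", "sales", "tech", "admin", "sales", "tech"],
    "target":     [1, 0, 1, 0, 1, 0],
})

# Plain target encoding: replace each category by the mean target of that category.
# (In practice it should be computed on the training split only, ideally with
# smoothing or out-of-fold estimates to limit target leakage.)
category_means = df.groupby("occupation")["target"].mean()
df["occupation_te"] = df["occupation"].map(category_means)

print(df)
```

Each category ends up as a single number (here 2/3 for tech, 0.5 for sales and 0.0 for admin), whereas an embedding is a vector whose values are updated by gradient descent together with everything else.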

Of course, one could take datasets that are dominated by numerical features and bin them somehow to turn them into categorical ones. However, this seemed a bit too "forced" to me. With the idea of keeping the content of this post as close as possible to real use cases, it is hard for me to think of many "real world" scenarios where we are given datasets dominated by numerical features that are then binned into categorical features before being fed to an algorithm. In other words, I did not want to consider datasets where I had to bin the numerical features into categorical ones just to compare GBMs and DL models.

On the other hand, I also looked for datasets that were already familiar to me or did not require too much feature engineering before the data could be passed to a model. This way I could save some time on that aspect and focus a bit more on the experimentation, since I intended to run a large number of experiments. Finally, I looked for datasets that resemble, as much as possible, the datasets one would find in the "real world", but that had a tractable size so I could experiment within a reasonable time frame.

While I did manage to find suitable datasets for binary classification and regression, I did not find datasets that I particularly liked for multi-class classification (if anyone has any suggestion, please comment below and I am happy to give it a go). Perhaps I will include the CoverType dataset in the future, but the one at the UCI ML repository, not Kaggle's balanced version. For now, I will move on with the four datasets enumerated above. Let me briefly comment on each of them.

I would refer to the Adult Census dataset as the "easiest dataset", in the sense that simple models (e.g. a Naive Bayes classifier) already reach accuracies of $\sim 84\%$ without any feature engineering. Personally, I normally don't find such nice datasets in the real world. However, it is one of the most popular and well-known datasets for ML tutorials, posts, etc., and I eventually decided to include it.

The Bank Marketing dataset is also well known. The data is related to direct marketing campaigns based on phone calls, and the goal is to predict whether or not a client will subscribe to a product. In this case it is important to mention a couple of relevant aspects. In the first place, I used the original dataset, which is a bit imbalanced (the positive to negative class ratio is 0.127). Secondly, you might look around and find that some people obtained better results than those I will show later in the post. All such cases that I found use either a balanced dataset from Kaggle, a feature called duration, or both. The duration feature, which refers to the duration of the call, is only known after the call and highly affects the target. Therefore, I have not used it in my experiments. This dataset resembles a real use case more closely than the Adult dataset, in the sense that the data is imbalanced and the prediction is not an easy task at all. Still, the dataset is small and not that imbalanced.

The NYC taxi ride duration dataset is also well known and is the largest of all datasets I used. Here our goal is to predict the total ride duration of taxi trips in New York City. Instead of getting the dataset from the Kaggle site I manually downloaded an extended version from here, where all the feature engineering had already been done.

Finally the Facebook Comment Volume dataset was another ideal candidate, since it has a good size and all the feature engineering was done for me. Our goal here is to predict the comment volume that posts will receive. In fact this dataset was originally used to compare decision trees versus neural networks. A very detailed description of the dataset and the pre-processing can be found in the original publication [6]. In particular, I used their training Variant - 5 dataset for the experiments in this post, which has 199029 rows and 54 columns.

All the code for the data preparation steps, before the data is fed to the algorithms, can be found here.

2.2. The DL Models

As I mentioned earlier in the post, all DL models were run via pytorch-widedeep. This library offers four wide and deep model components: wide, deeptabular, deeptext, deepimage. Let me briefly comment on each one of them. For more details, please see the companion posts, the documentation or the source code itself.

1. wide: this is just a linear model implemented via an Embedding layer.
2. deeptabular: this component takes care of the "standard" tabular data (i.e. categorical and numerical columns) and has 4 alternatives:

2.1 TabMlp: a simple, standard MLP. Very similar to, for example, the tabular API implementation in the fastai library.

2.2 TabResnet: similar to the MLP, but instead of dense layers I use Resnet blocks.

2.3 Tabnet [7]: this is a very interesting implementation. It is hard to explain in a few sentences, therefore I strongly suggest reading the paper. Tabnet is meant to be competitive with GBMs and offers model interpretability via feature importance. pytorch-widedeep's implementation of Tabnet is fully based on the fantastic implementation by the guys at dreamquark-ai, therefore ALL credit goes to them. I have simply adapted it to work within a Wide and Deep framework and added a couple of extra features, such as internal dropout in the GLU blocks and the possibility of not using ghost batch normalization [8].

Note that the original implementation allows training in two stages: first, self-supervised training via a standard encoder-decoder approach, and then supervised training or fine-tuning using only the encoder. Only the supervised training (i.e. the encoder) is implemented in pytorch-widedeep. The authors showed that unsupervised pre-training improves the performance mostly in the low-data regime or when the unlabeled dataset is much larger than the labeled dataset. Therefore, if you are in one of those scenarios (or simply as a general statement), you are better off using dreamquark-ai's implementation.

2.4 TabTransformer [9]: this is similar to TabResnet, but instead of Resnet blocks the authors used Transformer [10] blocks. Similar to the case of Tabnet, the TabTransformer allows for a two-stage training process: unsupervised pre-training followed by supervised training or fine-tuning. pytorch-widedeep's implementation of the TabTransformer is designed to be used in a "standard" way, i.e. with supervised training. Note that, consistent with the results of Sercan Ö. Arık and Tomas Pfister for Tabnet, the authors found that unsupervised pre-training improves the performance mostly in the low-data regime or when the unlabeled dataset is much larger than the labeled dataset. The TabTransformer implementation available in pytorch-widedeep is partially based on the one in the autogluon library and on that from Phil Wang here.

3. deeptext: a standard text classifier/regressor comprised of a stack of RNNs (LSTMs or GRUs). In addition, there is the option to add a set of dense layers on top of the stack of RNNs, and some other extra features.
4. deepimage: a standard image classifier/regressor using a pretrained network (in particular ResNets) or a sequence of 4 convolution layers. In addition, there is the option to add a set of dense layers on top of the stack of CNNs, and some other extra features.
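The idea behind the wide component (a linear model implemented via an Embedding layer) can be illustrated without any deep learning framework. Below is a minimal NumPy sketch, assuming each categorical value has already been mapped to a global integer index; the layout and names are illustrative, not pytorch-widedeep's internals. Looking up a 1-dimensional "embedding" per active category and summing the lookups gives exactly the same output as a linear model applied to the one-hot encoded input, without ever building the one-hot matrix.

```python
import numpy as np

rng = np.random.default_rng(0)

n_total_categories = 10  # all categories of all wide columns, indexed 0..9 (hypothetical)
weights = rng.normal(size=(n_total_categories, 1))  # the "embedding" table, output dim 1
bias = 0.5

# Each row holds the global indices of the active category of each of 3 columns
X_idx = np.array([[0, 4, 7],
                  [2, 5, 9]])

# Embedding-style forward pass: look up one weight per active category and sum
out_embedding = weights[X_idx].sum(axis=1) + bias  # shape (2, 1)

# Equivalent linear model on the explicit one-hot (multi-hot) encoding
X_onehot = np.zeros((X_idx.shape[0], n_total_categories))
np.put_along_axis(X_onehot, X_idx, 1.0, axis=1)
out_linear = X_onehot @ weights + bias  # shape (2, 1)

assert np.allclose(out_embedding, out_linear)
```

In PyTorch, the `weights` table would typically be an `nn.Embedding` with `embedding_dim=1`, which is what makes this formulation cheap for columns with many categories.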

2.3. Why LightGBM?

If you have worked with me, or even had a chat with me about some ML project, you will know that one of my favorite algorithms is LightGBM. I have used it extensively. In fact, the last 3 ML systems that I have productionised all relied on LightGBM. It performs similarly to, if not better than, its brothers and sisters (e.g. XGBoost or CatBoost), is significantly faster and offers support for categorical features (see here; although, when it comes to support for categorical features, CatBoost is probably the superior solution). In addition, it offers the usual flexibility and performance of GBMs.

2.4. Experiments setup and other considerations

As I mentioned earlier in the post, I ran many experiments (not all were recorded and/or made it into the post) for the four datasets, focusing on the different models available for the deeptabular component. All the experiments can be found here in the repo.

The experiments not only considered different parameters for the models (i.e. number of units, layers, etc.) but also different optimizers, learning rate schedulers and training processes. For example, all experiments were run with early stopping, with a patience of 30 epochs in most cases. I used three different optimizers (Adam [11], AdamW [12] and RAdam [13]) and three different learning rate schedulers (ReduceLROnPlateau, OneCycleLR [14] and CyclicLR [15]). The following command corresponds to one of the experiments:

python adult/adult_tabmlp.py --mlp_hidden_dims [100,50] --mlp_dropout 0.2 --optimizer Adam --early_stop_patience 30 --lr_scheduler CyclicLR --base_lr 5e-4 --max_lr 0.01 --n_cycles 10 --n_epochs 100 --save_results


The command above runs a TabMlp model on the adult dataset. Most arguments are straightforward to understand. Perhaps the only aspect worth commenting on is that this particular experiment was run with a CyclicLR scheduler, where the learning rate oscillates between 0.0005 and 0.01, 10 times over 100 epochs (i.e. one cycle every 10 epochs).
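A plain-Python sketch of such a schedule, using the "triangular" policy from [15] (the same policy implemented by `torch.optim.lr_scheduler.CyclicLR` with `mode="triangular"`). How the experiment script converts `n_cycles` into scheduler steps is an assumption here: one cycle is taken as `n_epochs / n_cycles` epochs, split evenly into an increasing and a decreasing half, with a hypothetical 100 optimizer steps (batches) per epoch.

```python
import math

def triangular_lr(step: int, base_lr: float, max_lr: float, step_size_up: int) -> float:
    """Learning rate at a given optimizer step under the 'triangular' cyclical policy."""
    cycle = math.floor(1 + step / (2 * step_size_up))
    x = abs(step / step_size_up - 2 * cycle + 1)  # 1 at cycle boundaries, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

# Values from the command above; steps_per_epoch is a hypothetical value
base_lr, max_lr, n_epochs, n_cycles = 5e-4, 0.01, 100, 10
steps_per_epoch = 100
steps_per_cycle = n_epochs * steps_per_epoch // n_cycles  # 1000 steps = 10 epochs
step_size_up = steps_per_cycle // 2                       # 500 steps up, 500 down

for epoch in [0, 5, 10, 15, 20]:
    step = epoch * steps_per_epoch
    print(epoch, round(triangular_lr(step, base_lr, max_lr, step_size_up), 6))
```

Under these assumptions the learning rate peaks at 0.01 halfway through each 10-epoch cycle and returns to 0.0005 at the end of it.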

It is worth mentioning that, when running the experiments, I assumed that there is an inherent hierarchy in the DL model parameters and training setups. Therefore, rather than optimizing all parameters at once, I chose those that I considered more relevant and ran experiments that reproduced that hierarchy. For example, when running a simple MLP, I assume that the number of neurons in the layers is a more important parameter than whether or not I use BatchNorm in the last layer. It might be, or surely it is, that the best thing to do is to optimize all parameters at once, but following this "hierarchical" approach also gave me a sense of how changing some individual parameters affected the performance of the model. Nonetheless, around 100 experiments were run per model and per dataset on average, so the exploration was relatively exhaustive (just relatively).
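The "hierarchical" strategy can be read as a staged, coordinate-wise search: tune the parameter considered most important first, freeze its best value, then move on to the next one. In the minimal sketch below, `evaluate` is a hypothetical toy stand-in for a full training run returning a validation loss (its numbers mean nothing); the point is the structure of the search and how many runs it needs compared with a full grid.

```python
from itertools import product

def evaluate(config: dict) -> float:
    """Hypothetical toy stand-in for training a model and returning its validation loss."""
    toy_loss = {"[100,50]": 0.290, "[200,100]": 0.288, "[400,200]": 0.286}[config["mlp_hidden_dims"]]
    return toy_loss - (0.001 if config["mlp_batchnorm_last"] else 0.0)

# Parameters ordered from "more important" to "less important"
search_space = [
    ("mlp_hidden_dims", ["[100,50]", "[200,100]", "[400,200]"]),
    ("mlp_batchnorm_last", [False, True]),
]

def staged_search(search_space, evaluate):
    best = {name: values[0] for name, values in search_space}  # initial defaults
    n_runs = 0
    for name, values in search_space:
        scores = {}
        for v in values:
            scores[v] = evaluate({**best, name: v})
            n_runs += 1
        best[name] = min(scores, key=scores.get)  # freeze the best value for this parameter
    return best, n_runs

best, n_runs = staged_search(search_space, evaluate)
n_grid = len(list(product(*[values for _, values in search_space])))
print(best, n_runs, n_grid)  # 5 runs instead of the 6 needed by a full grid
```

The saving grows quickly with more parameters (a sum of the option counts instead of their product), at the price of possibly missing interactions between parameters, which is exactly the trade-off acknowledged above.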

On the other hand, LightGBM was optimized using Optuna [16], Hyperopt [17], or both, choosing the parameters that led to the best metrics. All the code can be found here. Note that the experiments, and the code in the repo, also serve as a very detailed and thorough tutorial on how to use pytorch-widedeep (if you want to use the library).

It is also worth mentioning that, when running the experiments, the early stopping criterion for both the DL models and LightGBM was based on the validation loss. Alternatively, one can monitor a metric, such as accuracy or the F1 score. Note that accuracy (or F1) and loss are not necessarily exactly inversely correlated. There might be edge cases where the algorithm is really unsure about some predictions (i.e. predictions are close to the decision threshold, leading to high loss values) yet ends up making the right prediction (higher accuracy). Of course, ideally we want the algorithm to be sure and make the right predictions, but you know, the real world is messy and noisy. Nonetheless, out of curiosity, I tried monitoring metrics in some experiments. Overall, I found that the results were consistent with those obtained monitoring loss values, although slightly better metrics could be achieved in some cases.
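A small numerical illustration of why loss and accuracy need not move together, using made-up predictions for four examples (mean binary cross-entropy, and accuracy at a 0.5 threshold, both computed by hand). Model A is barely confident but always right; model B is very confident, gets one example wrong, and still has the lower loss.

```python
import math

def log_loss(y_true, p_pred):
    """Mean binary cross-entropy."""
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for y, p in zip(y_true, p_pred)) / len(y_true)

def accuracy(y_true, p_pred, threshold=0.5):
    return sum((p > threshold) == bool(y) for y, p in zip(y_true, p_pred)) / len(y_true)

y_true = [1, 1, 1, 0]
p_a = [0.55, 0.55, 0.55, 0.45]  # unsure, but every prediction on the right side of 0.5
p_b = [0.99, 0.99, 0.99, 0.60]  # very confident, last example misclassified

for name, p in [("A", p_a), ("B", p_b)]:
    print(name, f"loss={log_loss(y_true, p):.3f}", f"acc={accuracy(y_true, p):.2f}")
```

Model A ends up with the higher loss (about 0.60 vs 0.24) but also the higher accuracy (1.00 vs 0.75), so early stopping on one or the other can select different epochs.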

Another relevant piece of information relates to the size of the embeddings used to represent the categorical features. As one can imagine, the number of possibilities here is endless, and I had to find a way to consistently automate the process across all experiments. To that end, I decided to use fastai's empirical rule of thumb. For a given categorical feature with $n_{cat}$ categories, the embedding dimension will be:

$$n_{embed} = \min\big(600,\ \mathrm{int}(1.6 \times n_{cat}^{0.56})\big)$$
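The rule of thumb translates directly into a one-line helper. A minimal sketch in plain Python (the function name is illustrative), following the formula as written, i.e. truncating with `int`:

```python
def embed_dim(n_cat: int) -> int:
    """Embedding dimension for a categorical feature with n_cat categories (fastai's rule of thumb)."""
    return min(600, int(1.6 * n_cat ** 0.56))

for n_cat in [2, 3, 10, 50, 1000, 1_000_000]:
    print(n_cat, embed_dim(n_cat))
```

The dimension grows sub-linearly with the number of categories (for instance, 50 categories give 14 dimensions and 1000 categories give 76) and is capped at 600 for very high-cardinality features.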

The exception is the TabTransformer. The TabTransformer treats the categorical features as if they were part of a sequence (i.e. contextual) where the sequence order is irrelevant, i.e. no positional encoding is needed. Therefore, rather than concatenating the embeddings side by side, they are stacked on top of each other. This means that all categorical features must have the same embedding dimension. Note that this is a bit of an inconvenience when the number of categories varies widely across the categorical features in the dataset.
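The difference in how the embeddings are arranged is easiest to see from the array shapes. A NumPy sketch with hypothetical sizes: MLP-style models can concatenate per-feature embeddings of any size along the feature axis, whereas the TabTransformer needs a `(batch, n_categorical_features, embed_dim)` array, i.e. a "sequence" of equal-sized embeddings, which only exists if every feature uses the same dimension.

```python
import numpy as np

batch_size = 4
rng = np.random.default_rng(0)

# Hypothetical embeddings for 3 categorical features with different dimensions
embeds = [rng.normal(size=(batch_size, d)) for d in (14, 5, 2)]

# MLP-style: concatenate side by side -> (batch, 14 + 5 + 2)
concatenated = np.concatenate(embeds, axis=1)
print(concatenated.shape)  # (4, 21)

# TabTransformer-style: stack as a sequence of "tokens" -> requires equal dimensions
try:
    np.stack(embeds, axis=1)
except ValueError as err:
    print("cannot stack:", err)

same_dim = [rng.normal(size=(batch_size, 16)) for _ in range(3)]
stacked = np.stack(same_dim, axis=1)
print(stacked.shape)  # (4, 3, 16)
```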

For example, let's say we have a dataset with just 2 categorical features having 50 and 3 different categories respectively. While using embeddings of 16 dimensions seems appropriate for the former, it certainly seems like an "over-representation" in the latter case. One could still use fastai's rule of thumb and pad the lower-dimensional embeddings with zeros, but that would imply that some of the attention heads will be attending to zeros/nothing throughout the entire training process, which seems like a waste to me. Despite this potential "waste", I am considering adding this as an option to pytorch-widedeep's TabTransformer implementation. In the meantime, "all" TabTransformer experiments were also run with an additional setup where categorical features with a small number of categories were passed through the wide component.
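A NumPy sketch of the padding option for the two-feature example (50 and 3 categories), assuming the rule-of-thumb dimensions are computed first and the smaller embedding is right-padded with zeros up to the largest one. It also reports how much of the stacked array is padding, which quantifies the "waste" mentioned above.

```python
import numpy as np

def rule_of_thumb(n_cat: int) -> int:
    return min(600, int(1.6 * n_cat ** 0.56))

n_cats = [50, 3]
dims = [rule_of_thumb(n) for n in n_cats]  # [14, 2]
d_max = max(dims)

batch_size = 4
rng = np.random.default_rng(0)
embeds = [rng.normal(size=(batch_size, d)) for d in dims]

# Right-pad every embedding with zeros up to d_max so that they can be stacked
padded = [np.pad(e, ((0, 0), (0, d_max - e.shape[1]))) for e in embeds]
stacked = np.stack(padded, axis=1)  # (batch, n_features, d_max)

padding_fraction = 1 - sum(dims) / (len(dims) * d_max)
print(stacked.shape, f"{padding_fraction:.0%} of the entries are padding zeros")
```

In this toy case 12 of the 28 entries per row (about 43%) carry no information at all.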

Finally, for all experiments I used 80% of the data for training and 10% for validation/parameter tuning. These two datasets were then combined for one last training run, and the algorithm was tested on the remaining 10% of the data. The datasets were split at random unless there was a temporal component, in which case I used a chronological train/test split (note that for the Facebook Comment Volume dataset I did not use the test set used in the paper: all train, validation and test datasets are splits of the Variant - 5 dataset described in the paper).
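A minimal pandas sketch of this protocol, assuming a DataFrame `df` and, for datasets with a temporal component, a hypothetical timestamp column: rows are shuffled (or sorted chronologically), cut 80/10/10, and the first two parts are concatenated for the final training run. The function and column names are illustrative, not taken from the experiments repo.

```python
import numpy as np
import pandas as pd

def split_80_10_10(df: pd.DataFrame, time_col=None, seed: int = 1):
    """Return train (80%), valid (10%) and test (remaining ~10%) splits."""
    if time_col is not None:
        df = df.sort_values(time_col)                # chronological split
    else:
        df = df.sample(frac=1.0, random_state=seed)  # random split
    n = len(df)
    n_train, n_valid = int(0.8 * n), int(0.1 * n)
    train = df.iloc[:n_train]
    valid = df.iloc[n_train:n_train + n_valid]
    test = df.iloc[n_train + n_valid:]
    return train, valid, test

# Toy example with a hypothetical timestamp column
df = pd.DataFrame({
    "ts": pd.date_range("2021-01-01", periods=100, freq="D"),
    "x": np.arange(100),
})
train, valid, test = split_80_10_10(df, time_col="ts")

# Hyperparameters are chosen on `valid`; the final model is refit on train + valid
final_train = pd.concat([train, valid])
print(len(train), len(valid), len(test), len(final_train))  # 80 10 10 90
```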

And that's all; without further ado, let's move on to the results.

3. Results

The previous sections provide context for this "project" and details on the experiments that I ran. In this section I will simply show the top 5 results for all dataset and model combinations, along with some comments where I consider them necessary. The complete tables with the results for all experiments can be found here.

In [2]:
#hide
from pathlib import Path

import pandas as pd



3.1 Adult Census Dataset

3.1.1 TabMlp

In [3]:
#collapse-hide

Out[3]:
mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 [400,200] relu 0.5 False False False 0.1 0.0010 128 0.0 AdamW ReduceLROnPlateau 0.0010 0.01 25 10000.0 5.0 0.2857
1 [400,200] relu 0.5 False False False 0.0 0.0005 128 0.0 Adam CyclicLR 0.0005 0.01 25 10000.0 10.0 0.2860
2 [100,50] relu 0.2 False False False 0.0 0.0004 128 0.0 Adam OneCycleLR 0.0010 0.01 25 1000.0 5.0 0.2860
3 [400,200] relu 0.5 False False False 0.1 0.0010 128 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 10000.0 5.0 0.2861
4 [400,200] relu 0.5 False False False 0.0 0.0005 128 0.0 RAdam CyclicLR 0.0005 0.01 25 10000.0 10.0 0.2862

Table 2. Results obtained for the Adult Census dataset using TabMlp.

Perhaps the first comment to make relates to the columns/parameters. It is straightforward to understand that not all parameters/columns apply to each experiment/row. For example, parameters/columns like base_lr, max_lr, div_factor or final_div_factor apply only when the learning rate scheduler is either CyclicLR or OneCycleLR.

On the other hand, the dense layers of the MLP are built using a very similar approach to that in the fastai library. This approach offers flexibility in terms of the operations that occur within each dense layer of the MLP (see here for details). In that context, the mlp_linear_first column sets the order in which these operations occur, while mlp_batchnorm_last indicates whether batch normalization is applied to the last dense layer. For example, if for a given dense layer we set mlp_linear_first = True, the implemented dense layer will look like this: [LIN -> ACT -> DP]. On the other hand, if mlp_linear_first = False, the dense layer will perform the operations in the following order: [DP -> LIN -> ACT].
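The effect of mlp_linear_first on a single dense layer boils down to an ordering rule. A plain-Python sketch that only lists the operation names, following the two orders given above (batch normalization is left out, and this is not the library's code; the function name is hypothetical):

```python
def dense_layer_ops(linear_first: bool, dropout: float = 0.1) -> list:
    """Order of operations inside one dense layer, as described in the text."""
    lin_act = ["LIN", "ACT"]
    dp = ["DP"] if dropout > 0 else []  # no dropout op when the dropout rate is 0
    return lin_act + dp if linear_first else dp + lin_act

print(dense_layer_ops(linear_first=True))   # ['LIN', 'ACT', 'DP']
print(dense_layer_ops(linear_first=False))  # ['DP', 'LIN', 'ACT']
```

With mlp_linear_first = False, as in all rows of Table 2, dropout is therefore applied to the input of each linear layer rather than to its activated output.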

In the case of the Adult Census dataset, cyclic learning rate schedulers produce very good results. In fact, a one-cycle learning rate schedule with adequate parameters already leads to an acceptable validation loss in just one epoch (provided that the batch size is small enough), which perhaps illustrates that this dataset is not particularly difficult. Nonetheless, the best result (by a negligible margin) was obtained with a ReduceLROnPlateau learning rate scheduler. This is actually common across the experiments for the different datasets, and is also consistent with my experience running DL models in many different scenarios, for tabular data or text. The ReduceLROnPlateau scheduler was run with a patience of 10 epochs. This, along with the EarlyStopping patience of 30 epochs, means that, when ReduceLROnPlateau is used, the learning rate can be reduced roughly 3 times (30 / 10) before the experiment is forced to stop.
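Whether that is exactly three reductions depends on how each patience counter is defined. Below is a plain-Python simulation of the worst case (the validation loss never improves again). It assumes PyTorch's `ReduceLROnPlateau` convention (the rate is reduced once the number of bad epochs *exceeds* `patience`, after which the counter resets, with no cooldown) and an early-stopping rule that stops once 30 consecutive epochs pass without improvement.

```python
def count_lr_reductions(plateau_patience: int = 10, early_stop_patience: int = 30) -> int:
    """Count LR reductions before early stopping, assuming no improvement ever again."""
    num_bad_epochs = 0  # ReduceLROnPlateau counter
    wait = 0            # early-stopping counter
    reductions = 0
    while True:
        num_bad_epochs += 1
        wait += 1
        if num_bad_epochs > plateau_patience:  # reduce when the counter exceeds patience
            reductions += 1
            num_bad_epochs = 0                 # counter restarts after each reduction
        if wait >= early_stop_patience:
            return reductions

print(count_lr_reductions(10, 30))  # 2
print(count_lr_reductions(9, 30))   # 3
```

Under these assumptions the rate is cut 11 and 22 epochs after the last improvement and training stops at 30, so two reductions; "roughly three" is the back-of-the-envelope figure.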

For full details on the experiment setup, the model implementation and the meaning behind each parameter/column, please have a look at pytorch-widedeep's documentation and the experiments repo.

3.1.2 TabResnet

In [4]:
#collapse-hide

Out[4]:
blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 same 0.5 None relu 0.1 False False False 0.1 0.0004 32 0.0 Adam OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2850
1 same 0.5 None relu 0.1 False False False 0.0 0.0004 32 0.0 Adam OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2853
2 same 0.5 None relu 0.1 False False False 0.1 0.0004 128 0.0 AdamW OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2854
3 same 0.5 None relu 0.1 False False False 0.1 0.0004 64 0.0 AdamW OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2855
4 same 0.5 None relu 0.1 False False False 0.1 0.0004 32 0.0 AdamW OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2856

Table 3. Results obtained for the Adult dataset using TabResnet.

blocks_dims = same in Table 3 indicates that the Resnet blocks, which are comprised of dense layers, have the same dimensions as the incoming embeddings (see here for details on the implementation).

On the other hand, the TabResnet model offers the possibility of using an MLP "on top" of the Resnet blocks. mlp_hidden_dims = None indicates that no MLP was used and the output of the last Resnet block was "plugged" directly into the output neuron. As shown in Table 3, the top 5 results obtained using TabResnet correspond to architectures that have no MLP. Consequently, all MLP-related parameters/columns are redundant for those experiments.

I find it interesting that, whether using Adam or AdamW, the best results are obtained with OneCycleLR. When using this scheduler, I normally set the number of epochs to be between 1 and 10. I normally obtain the best results for a small number of epochs ($\leq 5$) and a small batch size, which implies that the increase/decrease of the learning rate will be more gradual (i.e. spread over a higher number of steps) than when using large batch sizes. Finally, note that the parameter/column n_cycles only applies to the CyclicLR scheduler. Since it is not used in any of the top 5 experiments, it can be ignored in Table 3.
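To put numbers on this, the total number of optimizer steps over which OneCycleLR spreads its increase and decrease is `n_epochs * ceil(n_train / batch_size)` (assuming the last, partial batch is kept). A quick calculation, assuming the Adult training split is 80% of the 45222 rows in Table 1:

```python
import math

def total_steps(n_train: int, batch_size: int, n_epochs: int) -> int:
    """Optimizer steps over the whole run (the last, partial batch is kept)."""
    return n_epochs * math.ceil(n_train / batch_size)

n_train = int(0.8 * 45222)  # 36177 rows, assumed training split size
for batch_size in [32, 128, 512]:
    print(batch_size, total_steps(n_train, batch_size, n_epochs=5))
```

Over 5 epochs, a batch size of 32 gives 5655 steps against 355 for a batch size of 512, so the same learning rate curve is traced roughly 16 times more gradually.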

3.1.3 Tabnet

In [5]:
#collapse-hide

Out[5]:
n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma dropout embed_dropout lr batch_size weight_decay lambda_sparse optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 5 32 32 False 128 0.98 1.5 0.1 0.0 0.03 128 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2916
1 5 64 64 False 128 0.98 1.5 0.2 0.0 0.03 128 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2938
2 5 32 32 False 128 0.98 1.5 0.1 0.0 0.03 128 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2939
3 5 64 64 False 128 0.98 1.5 0.2 0.0 0.03 128 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2945
4 5 64 64 False 128 0.98 1.5 0.2 0.0 0.05 128 0.0 0.0001 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2962

Table 4. Results obtained for the Adult dataset using Tabnet.

Tabnet has received some attention lately for being competitive with GBMs, and even outperforming them. In addition, it is a very elegant implementation that offers model interpretability via feature importance obtained using attention mechanisms.

The reality is that for the Adult Census dataset I obtain the worst loss values on the validation set (but, as we will see later, not the worst metric). Maybe I simply missed "that precise" set of parameters that leads to better results. However, it is worth emphasizing that I explored Tabnet with the same level of detail as any of the other 3 model alternatives.

On the other hand, it is interesting that, across all the experiments I ran, the best results are consistently obtained without ghost batch normalization. Therefore, the parameter/column virtual_batch_size can be ignored in Table 4. Similarly, since the best results are all obtained using ReduceLROnPlateau, all the parameters related to cyclic learning rate schedulers can be ignored in Table 4.
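For reference, the core of ghost batch normalization [8] is simple: the batch is split into "virtual" batches of `virtual_batch_size` rows and each chunk is normalized with its own statistics. A NumPy sketch of the normalization step only (no learnable scale and shift, and no running statistics, which a real layer would also keep):

```python
import numpy as np

def ghost_batch_norm(x: np.ndarray, virtual_batch_size: int, eps: float = 1e-5) -> np.ndarray:
    """Normalize each virtual batch of `x` (shape: batch x features) with its own mean/var."""
    n_chunks = int(np.ceil(x.shape[0] / virtual_batch_size))
    chunks = np.array_split(x, n_chunks, axis=0)
    normed = [(c - c.mean(axis=0)) / np.sqrt(c.var(axis=0) + eps) for c in chunks]
    return np.concatenate(normed, axis=0)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(512, 4))  # one batch of 512 rows
out = ghost_batch_norm(x, virtual_batch_size=128)  # 4 virtual batches of 128 rows
print(out[:128].mean(axis=0).round(6), out[:128].std(axis=0).round(3))
```

Note that with batch_size = virtual_batch_size = 128, as in all rows of Table 4, there is a single chunk per batch, so the ghost variant would coincide with standard batch normalization during training anyway.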

Finally, consistent with some experiments I ran in the past, the best results obtained using RAdam normally involve relatively high learning rates.

3.1.4 TabTransformer

In [6]:
#collapse-hide

Out[6]:
embed_dropout full_embed_dropout shared_embed add_shared_embed frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm mlp_batchnorm_last mlp_linear_first with_wide lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 0.0 False False False 8 16 4 4 0.1 NaN relu None relu False False False False 0.010 128 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2879
1 0.0 False False False 8 16 4 4 0.1 NaN relu same relu False False False False 0.010 128 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2885
2 0.0 False False False 8 16 4 4 0.1 NaN relu None relu False False False True 0.010 128 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2888
3 0.0 False False False 8 16 4 8 0.2 NaN relu None relu False False False True 0.001 128 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2892
4 0.0 False False False 8 16 2 4 0.1 NaN relu None relu False False False False 0.010 128 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2894

Table 5. Results obtained for the Adult Census dataset using the TabTransformer.

As with all the previous models, if you want details on the meaning of each parameter/column, please have a look at the documentation or the source code itself.

It is perhaps worth mentioning that when the feed-forward hidden dimension (ff_hidden_dim) is set to NaN, the model defaults to an ff_hidden_dim value equal to 4 times the input embedding dimension (16 in all the experiments/rows shown in Table 5). This results in a feed-forward layer with dimensions [ff_input_dim -> 4 * ff_input_dim -> ff_input_dim]. Similarly, when mlp_hidden_dims = None, the model defaults to hidden layers of 4 and 2 times the input dimension, resulting in an MLP of dimensions [mlp_input_dim -> 4 * mlp_input_dim -> 2 * mlp_input_dim -> output_dim].
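The default sizes described above amount to the following. A plain-Python sketch (the function names and the MLP input size of 100 are hypothetical; the rules are taken from the paragraph above, not from the library source):

```python
def default_ff_dims(input_dim: int, ff_hidden_dim=None) -> list:
    """Feed-forward block dims: input -> hidden -> input; hidden defaults to 4 * input."""
    hidden = ff_hidden_dim if ff_hidden_dim is not None else 4 * input_dim
    return [input_dim, hidden, input_dim]

def default_mlp_dims(mlp_input_dim: int, output_dim: int, mlp_hidden_dims=None) -> list:
    """MLP dims on top of the transformer blocks: defaults to [in, 4*in, 2*in] + output."""
    if mlp_hidden_dims is None:
        mlp_hidden_dims = [4 * mlp_input_dim, 2 * mlp_input_dim]
    return [mlp_input_dim] + list(mlp_hidden_dims) + [output_dim]

print(default_ff_dims(16))       # [16, 64, 16]
print(default_mlp_dims(100, 1))  # [100, 400, 200, 1]
```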

In addition, and as mentioned before, the TabTransformer was also run with a setup that includes a wide component. This is specified by the with_wide parameter.

It is worth noting that the best loss values, which are similar to those of the rest of the DL models, are normally obtained using the RAdam optimizer.

3.1.5 DL vs LightGBM

After having gone through the results obtained for each of the DL models, this is the moment of truth: let's see how the DL results compare with those obtained with LightGBM.

In [7]:
#collapse-hide

Out[7]:
model acc runtime best_epoch_or_ntrees
0 lightgbm 0.8782 0.9086 408.0
1 tabmlp 0.8722 205.3576 62.0
2 tabtransformer 0.8718 288.6406 32.0
3 tabnet 0.8704 422.2967 26.0
4 tabresnet 0.8698 388.9325 25.0

Table 6. Results obtained for the Adult Census dataset using four DL models and LightGBM. Runtime is in seconds.

Let me emphasize again that the metrics shown in Table 6 are all obtained, of course, on the test dataset. The runtime column shows the training time, in seconds, on the final training dataset (i.e. a dataset comprising 90% of the data) using the best parameters obtained during validation. The DL models were run on a p2.xlarge instance on AWS and all the LightGBM experiments were run on my Mac Mid 2015.

There are a few aspects worth commenting on. In the first place, all DL models obtain results that are competitive with, but not better than, those of LightGBM. Secondly, the best performing DL model (by a rather marginal amount) is the simplest one, TabMlp. And finally, the training time when using LightGBM is simply "gigantically" better than with any of the DL models.
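To put the gaps in numbers, here is a short pandas computation using the values from Table 6 (nothing new is measured here):

```python
import pandas as pd

# Values copied from Table 6 (test accuracy, training runtime in seconds)
results = pd.DataFrame({
    "model": ["lightgbm", "tabmlp", "tabtransformer", "tabnet", "tabresnet"],
    "acc": [0.8782, 0.8722, 0.8718, 0.8704, 0.8698],
    "runtime": [0.9086, 205.3576, 288.6406, 422.2967, 388.9325],
}).set_index("model")

lgb = results.loc["lightgbm"]
results["acc_gap_vs_lgb"] = (lgb["acc"] - results["acc"]).round(4)
results["runtime_ratio_vs_lgb"] = (results["runtime"] / lgb["runtime"]).round(0)
print(results)
```

The DL models trail LightGBM by 0.0060 to 0.0084 in test accuracy while taking roughly 226 to 465 times longer to train. Keep in mind that the runtimes come from different hardware, as noted above, so the ratios are indicative rather than a like-for-like benchmark.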

3.2 Bank Marketing Dataset

Most of the comments in the previous section apply to the tables shown in this section.

Note that, as I mentioned earlier in the post, the Bank Marketing dataset is slightly imbalanced. Therefore, I also ran some experiments using the focal loss [18] (which is accessible in pytorch_widedeep via a parameter or as a loss function input; see here). Overall, the results obtained were similar to, but not better than, those without the focal loss. This is consistent with my experience with other datasets, where I find that the focal loss leads to notably better results when the dataset is highly imbalanced (for example, a positive to negative class ratio of around 2%).
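For reference, here is a NumPy sketch of the binary focal loss as defined in [18], $FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$, where $p_t$ is the predicted probability of the true class. This is an independent illustration, not pytorch-widedeep's implementation; alpha = 0.25 and gamma = 2 are the values recommended in that paper.

```python
import numpy as np

def binary_focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean binary focal loss; alpha=None disables the class-balancing weight."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    y = np.asarray(y)
    p_t = np.where(y == 1, p, 1 - p)            # probability assigned to the true class
    loss = -((1 - p_t) ** gamma) * np.log(p_t)  # (1 - p_t)^gamma down-weights easy examples
    if alpha is not None:
        loss = np.where(y == 1, alpha, 1 - alpha) * loss
    return loss.mean()

y = np.array([1, 0, 1, 0])
p = np.array([0.9, 0.2, 0.6, 0.7])

bce = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
print(bce, binary_focal_loss(p, y, alpha=None, gamma=0.0), binary_focal_loss(p, y))
```

With gamma = 0 and no alpha weighting it reduces to the ordinary binary cross-entropy; with gamma > 0, well-classified examples contribute much less, which is where the benefit for highly imbalanced data comes from.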

3.2.1 TabMlp¶

In [8]:
#collapse-hide
# focal loss values are on a different scale: keep only runs with val_loss_or_metric > 0.2
bank_marketing_tabmlp = bank_marketing_tabmlp[bank_marketing_tabmlp.val_loss_or_metric > 0.2]
(bank_marketing_tabmlp
    .sort_values("val_loss_or_metric", ascending=True)
    .reset_index(drop=True))

Out[8]:
mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 [100,50] relu 0.1 True True False 0.1 0.001 512 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2638
1 [100,50] relu 0.1 True False True 0.1 0.001 512 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2639
2 [100,50] relu 0.1 True True False 0.1 0.001 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2643
3 [100,50] relu 0.1 False False False 0.1 0.001 512 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2643
4 [100,50] relu 0.1 True False False 0.1 0.001 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2646

Table 7. Results obtained for the Bank Marketing dataset using TabMlp.

3.2.2 TabResnet¶

In [9]:
#collapse-hide
bank_marketing_tabresnet.round(4)

Out[9]:
blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 same 0.5 None relu 0.1 False False False 0.0 0.0004 64 0.0 Adam OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2660
1 [50,50,50,50] 0.2 None relu 0.1 False False False 0.0 0.0010 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2661
2 same 0.5 None relu 0.1 False False False 0.0 0.0004 64 0.0 RAdam OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2663
3 same 0.5 None relu 0.1 False False False 0.0 0.0004 128 0.0 RAdam OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2664
4 same 0.5 None relu 0.1 False False False 0.0 0.0004 128 0.0 Adam OneCycleLR 0.001 0.01 25 1000.0 5.0 0.2667

Table 8. Results obtained for the Bank Marketing dataset using TabResnet.

Again, and very interestingly, the RAdam optimizer and the OneCycleLR scheduler lead to some of the best results for this DL model.
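Incidentally, the `lr` of 0.0004 in the OneCycleLR rows of Table 8 matches `max_lr / div_factor` = 0.01 / 25, which is the starting learning rate of a one-cycle schedule (assuming the `OneCycleLR` in the tables is PyTorch's `torch.optim.lr_scheduler.OneCycleLR`). The plain-Python sketch below reproduces the learning-rate shape of that scheduler with its default cosine, two-phase setting, leaving out momentum cycling; `pct_start=0.3` is PyTorch's default and `total_steps=100` is an arbitrary illustrative value, not one taken from these experiments.

```python
import math


def one_cycle_lrs(max_lr, total_steps, div_factor=25.0, final_div_factor=1e4, pct_start=0.3):
    """Learning rate at every step of a two-phase, cosine one-cycle schedule.

    Mirrors the shape of torch.optim.lr_scheduler.OneCycleLR with
    anneal_strategy="cos" and three_phase=False (momentum is not modelled).
    """
    initial_lr = max_lr / div_factor  # lr at step 0
    min_lr = initial_lr / final_div_factor  # lr at the last step
    warmup_end = float(pct_start * total_steps) - 1  # step at which max_lr is reached
    last_step = total_steps - 1
    if warmup_end <= 0 or warmup_end >= last_step:
        raise ValueError("total_steps is too small for this pct_start")

    def cos_anneal(start, end, pct):
        # moves from `start` (pct=0) to `end` (pct=1) along half a cosine wave
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    lrs = []
    for step in range(total_steps):
        if step <= warmup_end:
            lrs.append(cos_anneal(initial_lr, max_lr, step / warmup_end))
        else:
            pct = (step - warmup_end) / (last_step - warmup_end)
            lrs.append(cos_anneal(max_lr, min_lr, pct))
    return lrs


# values from the OneCycleLR rows of Table 8: max_lr=0.01, div_factor=25, final_div_factor=1000
lrs = one_cycle_lrs(max_lr=0.01, total_steps=100, div_factor=25.0, final_div_factor=1000.0)
print(lrs[0], max(lrs), lrs[-1])
```

With these settings the learning rate warms up from 0.0004 to 0.01 over the first 30% of the steps and then decays along a cosine to 0.0004 / 1000 = 4e-7.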

3.2.3 Tabnet¶

In [10]:
#collapse-hide
bank_marketing_tabnet.round(4)

Out[10]:
n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma dropout embed_dropout lr batch_size weight_decay lambda_sparse optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 5 16 16 True 128 0.75 1.5 0.0 0.0 0.03 512 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2714
1 5 16 16 True 64 0.25 1.5 0.0 0.0 0.03 512 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2722
2 5 64 64 False 128 0.98 1.5 0.2 0.0 0.03 128 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2726
3 5 64 64 False 128 0.98 1.5 0.2 0.0 0.03 128 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2738
4 5 16 16 True 128 0.98 2.0 0.0 0.0 0.03 512 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2739

Table 9. Results obtained for the Bank Marketing dataset using Tabnet.

Note that the top 5 results obtained with Tabnet in this case all have relatively high learning rates (lr = 0.03). Also, similarly to the Adult Census dataset, Tabnet produces the worst validation loss values.

3.2.4 TabTransformer¶

In [11]:
#collapse-hide
bank_marketing_tabtransformer.round(4)

Out[11]:
embed_dropout full_embed_dropout shared_embed add_shared_embed frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm mlp_batchnorm_last mlp_linear_first with_wide lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 0.0 False False False 8 32 8 6 0.1 NaN relu None relu False False False False 0.001 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2646
1 0.0 False False False 8 32 8 6 0.1 NaN relu None relu False False False False 0.001 512 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2647
2 0.0 False True False 4 16 4 6 0.1 NaN relu None relu False False False False 0.010 128 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2668
3 0.0 False False False 8 32 8 6 0.1 NaN relu None relu False False False False 0.010 1024 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2672
4 0.0 False False False 8 32 8 6 0.1 NaN relu None relu False False False False 0.001 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 0.2672

Table 10. Results obtained for the Bank Marketing dataset using the TabTransformer.

It is perhaps worth noting that, consistent with some of the previous results, the best results obtained here using RAdam involve relatively high learning rates (10 times higher than those obtained using Adam or AdamW).
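One possible reading of this pattern, which these experiments do not test, lies in how RAdam [13] works: it multiplies the adaptive Adam step by a rectification term $r_t$ that is undefined at the very first steps (where the update falls back to plain momentum) and only grows towards 1 as the variance estimate becomes reliable, acting as a built-in warm-up; the paper reports that this makes RAdam more robust to the choice of learning rate. The sketch below computes $r_t$ from the formulas in the paper with the usual $\beta_2 = 0.999$; the threshold $\rho_t > 4$ follows the paper, and some implementations use a slightly different one.

```python
import math


def radam_rectification(t, beta2=0.999):
    """Variance rectification term r_t of RAdam at step t >= 1.

    Returns None when rho_t <= 4, i.e. when the paper falls back to an
    un-adapted (momentum-only) update.
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return None
    return math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )


for step in (1, 3, 6, 10, 100, 1000, 10000):
    print(step, radam_rectification(step))
```

The term stays well below 1 for the first few hundred steps (roughly 0.2 at step 100 with these settings), which damps early updates regardless of the nominal `lr`.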

3.2.5 DL vs LightGBM¶

In [12]:
#collapse-hide
lightgbm_vs_dl_bank_marketing.round(4)

Out[12]:
model f1 auc runtime best_epoch_or_ntrees
0 tabresnet 0.4298 0.6501 92.5175 11.0
1 tabtransformer 0.4200 0.6440 31.6938 4.0
2 tabmlp 0.3855 0.6281 9.5721 7.0
3 lightgbm 0.3852 0.6265 0.4614 57.0
4 tabnet 0.3087 0.5943 77.8781 13.0

Table 11. Results obtained for the Bank Marketing dataset using four DL models and LightGBM. runtime units are seconds

I must admit that the results shown in Table 11 were surprising to me at first (to say the least). I went back and re-ran a few DL models, and LightGBM multiple times, to double check, and finally concluded (spoiler alert) that this is going to be the only case among all the experiments in this post where DL models perform better than LightGBM. In fact, if I combine the experiments here with my experience at work, this is only the second time ever that I have found DL models performing better than LightGBM (more on this later). Furthermore, the improvement obtained using TabResnet or the TabTransformer (F1 scores of 0.4298 and 0.4200, respectively, versus 0.3852 for LightGBM) is large enough that, if this were a "real world" example, one might consider using a DL model and accepting the trade-off between running time and success metric.
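To put numbers on that trade-off, the short sketch below takes the F1 and runtime values from Table 11 and expresses each DL model relative to LightGBM. The dictionary layout and the helper name are just illustrative choices, and the runtime multiples should be read with care, since the DL models and LightGBM ran on different hardware.

```python
# F1 and training runtime (seconds) copied from Table 11
results = {
    "tabresnet": {"f1": 0.4298, "runtime": 92.5175},
    "tabtransformer": {"f1": 0.4200, "runtime": 31.6938},
    "tabmlp": {"f1": 0.3855, "runtime": 9.5721},
    "lightgbm": {"f1": 0.3852, "runtime": 0.4614},
    "tabnet": {"f1": 0.3087, "runtime": 77.8781},
}


def relative_to_baseline(results, baseline="lightgbm"):
    """Relative F1 change (%) and runtime multiple of each model w.r.t. the baseline."""
    base = results[baseline]
    summary = {}
    for model, r in results.items():
        if model == baseline:
            continue
        f1_gain_pct = 100.0 * (r["f1"] - base["f1"]) / base["f1"]
        runtime_multiple = r["runtime"] / base["runtime"]
        summary[model] = (round(f1_gain_pct, 1), round(runtime_multiple, 1))
    return summary


for model, (gain, slower) in relative_to_baseline(results).items():
    print(f"{model:>15}: F1 {gain:+.1f}% vs LightGBM, {slower:.1f}x the training time")
```

With these numbers, TabResnet improves F1 by about 11.6% relative to LightGBM at roughly 200 times the training time.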

Of course, one could dive a bit deeper into LightGBM, setting sample weights or even using a custom loss, but the same can be said about the DL models. Therefore, overall, I consider the comparison fair. However, I am so surprised that I cannot rule out a bug in the code that I have not been able to find. If anyone goes through the code at some point and does find a bug, please let me know 🙂.
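As an illustration of the first option (not something run for this post), a common heuristic is to weight each row inversely to its class frequency, $w_c = n / (k \, n_c)$ for $n$ rows, $k$ classes and $n_c$ rows in class $c$, which is the rule scikit-learn uses for `class_weight="balanced"`. Such weights could be passed to LightGBM through the `weight` argument of `lightgbm.Dataset`; for binary problems, LightGBM's `scale_pos_weight` parameter is also often set to the negative-to-positive ratio. The plain-Python sketch below only computes those numbers, on made-up labels.

```python
from collections import Counter


def balanced_sample_weights(labels):
    """Per-row weights w_c = n / (k * n_c), inversely proportional to class frequency."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    class_weight = {c: n / (k * n_c) for c, n_c in counts.items()}
    return [class_weight[y] for y in labels]


def negative_to_positive_ratio(labels):
    """Candidate value for LightGBM's `scale_pos_weight` in a binary 0/1 problem."""
    n_pos = sum(1 for y in labels if y == 1)
    return (len(labels) - n_pos) / n_pos


labels = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]  # toy, imbalanced labels
weights = balanced_sample_weights(labels)
print(weights)
print(negative_to_positive_ratio(labels))
```

A convenient property of this rule is that the weights sum to the number of rows, so the overall scale of the loss is unchanged.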

Finally, some readers might be disappointed by Tabnet's performance, as I was. It is possible that I have not implemented it correctly, although the code is fully based on dreamquark-ai's implementation (ALL credit to them) and, when tested on easier datasets, I obtain results similar to those of GBMs. I find Tabnet to be a very elegant architecture and somehow I believe it should perform better. I will come back to this point in the Summary section.

3.3 NYC Taxi trip duration¶

As I mentioned earlier, this is the largest dataset and, consequently, I experimented with larger batch sizes. While this might slightly change some of the individual results, I believe it does not change the overall conclusions of this section.

3.3.1 TabMlp¶

In [13]:
#collapse-hide
nyc_taxi_tabmlp.round(4)

Out[13]:
mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 auto relu 0.1 False False True 0.0 0.01 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 79252.7786
1 auto relu 0.1 False False True 0.0 0.01 1024 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 79440.6025
2 auto relu 0.1 False False False 0.1 0.01 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 79477.5653
3 auto relu 0.1 False False False 0.1 0.01 1024 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 79710.8551
4 auto relu 0.1 False False False 0.0 0.01 1024 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 80214.7197

Table 12. Results obtained for the NYC Taxi trip duration dataset using the TabMlp.

The validation loss in this case is the MSE. The standard deviation (std hereafter) of the target variable in the validation set is $\sim$599. Given that the std is the RMSE we would obtain if we always predicted the mean value, and that the best validation MSE in Table 12 (79252.7786) corresponds to an RMSE of $\sim$282, we can see that this is not a very powerful model, i.e. the task of predicting taxi trip duration is, indeed, relatively challenging.
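The reasoning above can be made explicit with two lines of arithmetic: the square root of the best validation MSE in Table 12 gives the model's RMSE, and $1 - \mathrm{MSE}/\sigma^2$ gives an approximate $R^2$ relative to always predicting the mean. This is a back-of-the-envelope check that treats the quoted std ($\sim$599) as exact, so the results are only indicative.

```python
import math


def compare_to_mean_baseline(mse, target_std):
    """Return (RMSE, approximate R^2) of a model vs. always predicting the mean.

    The RMSE of the constant-mean predictor equals the (population) std of
    the target, so R^2 ~= 1 - MSE / std**2.
    """
    rmse = math.sqrt(mse)
    r2 = 1.0 - mse / target_std ** 2
    return rmse, r2


# best TabMlp validation MSE (Table 12) and the approximate target std quoted above
rmse, r2 = compare_to_mean_baseline(mse=79252.7786, target_std=599.0)
print(f"RMSE ~ {rmse:.1f} vs. mean-baseline RMSE ~ 599, R^2 ~ {r2:.3f}")
```

An RMSE of roughly 282 against a baseline of 599 corresponds to explaining about 78% of the variance of the target.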

Let's see how the other DL models perform.

3.3.2 TabResnet¶

In [14]:
#hide
nyc_taxi_tabresnet.round(4)

Out[14]:
blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 same 0.5 auto relu 0.2 False False False 0.0 0.01 2048 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 97015.1182
1 same 0.2 auto relu 0.1 False False False 0.0 0.01 1024 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5 98266.4310
2 same 0.5 auto relu 0.2 False False False 0.0 0.04 2048 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 100332.3569
3 same 0.2 auto relu 0.1 False False False 0.0 0.01 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 103006.5603
4 same 0.5 auto relu 0.2 False False False 0.0 0.01 2048 0.0 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5 105967.2627

Table 13. Results obtained for the NYC Taxi trip duration dataset using the TabResnet.

3.3.3 Tabnet¶

In [15]:
#collapse-hide
nyc_taxi_tabnet.round(4)

Out[15]:
n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma dropout embed_dropout lr batch_size weight_decay lambda_sparse optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 5 8 8 False 128 0.75 1.5 0.0 0.0 0.01 1024 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 144819.1190
1 5 8 8 False 128 0.98 1.5 0.0 0.0 0.01 1024 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 146057.8078
2 5 8 8 False 128 0.50 1.5 0.0 0.0 0.01 1024 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 146201.3771
3 5 16 16 False 128 0.98 1.5 0.0 0.0 0.01 1024 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 146461.7343
4 5 8 8 False 128 0.25 1.5 0.0 0.0 0.01 1024 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 148636.8888

Table 14. Results obtained for the NYC Taxi trip duration dataset using the Tabnet.

3.3.4 TabTransformer¶

In [16]:
#collapse-hide
nyc_taxi_tabtransformer.round(4)

Out[16]:
embed_dropout full_embed_dropout shared_embed add_shared_embed frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm mlp_batchnorm_last mlp_linear_first with_wide lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 0.0 False False False 8 16 4 4 0.1 NaN relu None relu False False False False 0.01 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 180162.4087
1 0.0 False False False 8 16 4 4 0.1 NaN relu None relu False False False False 0.01 256 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 186017.1888
2 0.0 False False False 8 16 4 4 0.1 NaN relu None relu False False False False 0.01 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 196144.0674
3 0.0 False False False 8 32 8 4 0.4 NaN relu None relu False False False False 0.01 1024 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 357869.3703
4 0.0 False False False 8 64 16 4 0.4 NaN relu None relu False False False False 0.01 512 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 357884.9043

Table 15. Results obtained for the NYC Taxi trip duration dataset using the TabTransformer.

3.3.5 DL vs LightGBM¶

In [17]:
#collapse-hide
lightgbm_vs_dl_nyc_taxy.round(4)

Out[17]:
model rmse r2 runtime best_epoch_or_ntrees
0 lightgbm 262.7099 0.8044 42.7211 504.0
1 tabmlp 271.3422 0.7913 568.4309 24.0
2 tabresnet 292.8908 0.7569 471.2650 24.0
3 tabtransformer 336.5826 0.6789 5779.0314 54.0
4 tabnet 376.0530 0.5992 1844.4723 15.0

Table 16. Results obtained for the NYC Taxi trip duration dataset using four DL models and LightGBM. runtime units are seconds

The TabTransformer and Tabnet are, in this case, the models with the worst performance. As I mentioned earlier, I will reflect on potential reasons later, in the Summary section.

3.4 Facebook Comments Volume¶

This is the last of the four datasets discussed in this post, and a second regression problem.

3.4.1 TabMlp¶

In [18]:
#collapse-hide

Out[18]:
mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 [100,50] relu 0.1 False False True 0.0 0.001 512 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 32.5931
1 [100,50] relu 0.1 False False False 0.0 0.001 512 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 33.3515
2 [200, 100] relu 0.1 False False False 0.0 0.001 256 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 33.4140
3 [200, 100] relu 0.1 False False False 0.1 0.001 256 0.0 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 33.5679
4 [200, 100] relu 0.1 False False False 0.0 0.001 512 0.0 RAdam ReduceLROnPlateau 0.001 0.01 25 10000.0 5.0 33.6284

Table 17. Results obtained for the Facebook comments volume dataset using TabMlp.

As in the case of the NYC Taxi trip duration, the validation loss is the MSE. The std of the target variable is ~13 in the case of the Facebook comments volume dataset, while the best validation MSE in Table 17 (32.5931) corresponds to an RMSE of ~5.7. Therefore, following the same reasoning, we can see that the task of predicting the volume of Facebook comments using this particular dataset is challenging.

Let's see how the other DL models perform.

3.4.2 TabResnet¶

In [19]:
#collapse-hide

Out[19]:
blocks_dims blocks_dropout mlp_hidden_dims mlp_activation mlp_dropout mlp_batchnorm mlp_batchnorm_last mlp_linear_first embed_dropout lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 [100, 100, 100] 0.1 None relu 0.1 False False False 0.0 0.0005 512 0.0 Adam CyclicLR 0.0005 0.03 25 10000.0 10.0 34.4972
1 [100, 100, 100] 0.1 None relu 0.1 False False False 0.0 0.0005 512 0.0 AdamW CyclicLR 0.0005 0.03 25 10000.0 10.0 34.8520
2 [100, 100, 100] 0.1 None relu 0.1 False False False 0.0 0.0005 512 0.0 Adam CyclicLR 0.0005 0.03 25 10000.0 10.0 34.9504
3 [100, 100, 100] 0.1 None relu 0.1 False False False 0.0 0.0005 512 0.0 Adam CyclicLR 0.0005 0.01 25 10000.0 10.0 35.1668
4 [100, 100, 100] 0.1 None relu 0.1 False False False 0.0 0.0005 512 0.0 AdamW CyclicLR 0.0005 0.01 25 10000.0 10.0 35.2503

Table 18. Results obtained for the Facebook comments volume dataset using TabResnet.

3.4.3 Tabnet¶

In [20]:
#collapse-hide

Out[20]:
n_steps step_dim attn_dim ghost_bn virtual_batch_size momentum gamma dropout embed_dropout lr batch_size weight_decay lambda_sparse optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 5 16 16 False 128 0.98 1.5 0.0 0.0 0.03 512 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5 35.8122
1 3 16 16 False 128 0.98 1.5 0.2 0.0 0.03 512 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 37.6417
2 5 16 16 False 128 0.98 1.5 0.0 0.0 0.03 512 0.0 0.0001 AdamW ReduceLROnPlateau 0.001 0.01 25 10000.0 5 38.9771
3 5 16 16 False 128 0.98 1.5 0.2 0.0 0.03 512 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 39.5899
4 5 16 16 False 128 0.98 1.5 0.0 0.0 0.03 256 0.0 0.0001 Adam ReduceLROnPlateau 0.001 0.01 25 10000.0 5 40.9462

Table 19. Results obtained for the Facebook comments volume dataset using Tabnet.

3.4.4 TabTransformer¶

In [21]:
#collapse-hide

Out[21]:
embed_dropout full_embed_dropout shared_embed add_shared_embed frac_shared_embed input_dim n_heads n_blocks dropout ff_hidden_dim transformer_activation mlp_hidden_dims mlp_activation mlp_batchnorm mlp_batchnorm_last mlp_linear_first with_wide lr batch_size weight_decay optimizer lr_scheduler base_lr max_lr div_factor final_div_factor n_cycles val_loss_or_metric
0 0.0 False False False 8 16 2 4 0.1 NaN relu None relu False False False False 0.0005 1024 0.0 Adam CyclicLR 0.0005 0.01 25 10000.0 10.0 33.0946
1 0.0 False False False 8 16 2 4 0.1 NaN relu None relu False False False False 0.0005 4096 0.0 AdamW OneCycleLR 0.0010 0.01 25 1000.0 5.0 33.1283
2 0.0 False False False 8 16 2 4 0.1 NaN relu None relu False False False False 0.0010 1024 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 10000.0 5.0 33.2175
3 0.0 False False False 8 16 2 4 0.1 NaN relu same relu False False False False 0.0010 1024 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 10000.0 5.0 33.4698
4 0.0 False False False 8 16 4 4 0.1 NaN relu None relu False False False False 0.0010 1024 0.0 Adam ReduceLROnPlateau 0.0010 0.01 25 10000.0 5.0 33.7950

Table 20. Results obtained for the Facebook comments volume dataset using the TabTransformer.

3.4.5 DL vs LightGBM¶

In [22]:
#collapse-hide

Out[22]:
model rmse r2 runtime best_epoch_or_ntrees
0 lightgbm 5.5290 0.8232 6.5259 687.0
1 tabmlp 5.9085 0.7981 250.4768 43.0
2 tabtransformer 5.9256 0.7969 533.3908 27.0
3 tabresnet 6.2138 0.7767 70.4661 9.0
4 tabnet 6.4285 0.7610 935.0205 59.0

Table 21. Results obtained for the Facebook comments volume dataset using four DL models and LightGBM. runtime units are seconds

4. Summary¶

I have used four datasets and run over 1500 experiments (i.e. runs with a given parameter setup) comparing four DL models with LightGBM. This is a summary of some of the results.

• LightGBM wins, and there was never a fight

With one exception, LightGBM performs better than the DL models, and that one exception is precisely that, exceptional. To the experiments run and discussed here, I could add two occasions on which I used DL for tabular data at companies I have worked for. In particular, I used the model referred to here as TabMlp, with a wide component in one case and on its own in the other.

The Wide & Deep model was used in the context of a recommendation algorithm, shortly after the popular Wide and Deep [19] paper was published in 2016. Back then I was using XGBoost to predict a measure of interest and rank offers based on that measure. The Wide and Deep model, implemented at the time with Keras, obtained slightly better MAP and NDCG than XGBoost (almost identical, although slightly lower, metrics were obtained when using just the deep component). Given the number of additional considerations one needs to take into account when going to production, we eventually used XGBoost.

On the second occasion, a more recent project, TabMlp on its own obtained results very similar to, but still slightly worse than, those obtained using LightGBM (a higher RMSE and a lower R2). Even though TabMlp's predictions were not directly used, we found the embeddings useful for a number of additional projects and we built a production system around TabMlp.

Up to this point, I have focused on performance as measured by success metrics. However, when it comes to training (and prediction) time, the difference is so large that it makes some of these algorithms, at this stage, useful only for research purposes and/or Kaggle competitions. Don't get me wrong, you only push an industry forward technologically by challenging current solutions and established concepts. I am simply stating that, at this stage, it would be hard to envision a robust production system built around some of these algorithms. This is the reason why I wrote "there was never a fight". When you go live, it is quite often not only about success metrics but also about speed and resilience. Considering all of this, it seems to me that DL models for tabular data are still some way from being routinely deployed in production systems (but read below).

Finally, you might read here and there that, with proper feature engineering, noise removal, balancing and "who-knows-what-else", DL models outperform GBMs. The truth is that, in my experience, it is actually the opposite: when one manages to engineer good, powerful features, GBMs perform even better than DL models. This is also consistent with the results of some recent competitions. For example, in the RecSys Challenge 2020 the guys at NVIDIA won using clever feature engineering (e.g. target encoding) "plugged" into XGBoost on steroids (or better, GPUs). I am not sure that using those features with a DL model would actually improve their results.
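For readers unfamiliar with it, target encoding replaces each category with a statistic of the target computed over the rows in that category. A minimal smoothed version is sketched below; it is a common textbook formulation rather than the exact recipe used by the NVIDIA team, and `smoothing` is an illustrative parameter that pulls rare categories towards the global mean. In practice the encodings should be computed out-of-fold (or on a separate split) to avoid leaking the target into the features.

```python
from collections import defaultdict


def target_encode(categories, targets, smoothing=10.0):
    """Map each category to a smoothed mean of the target.

    enc(c) = (sum of targets in c + smoothing * global_mean) / (count(c) + smoothing)
    """
    global_mean = sum(targets) / len(targets)
    sums = defaultdict(float)
    counts = defaultdict(int)
    for cat, y in zip(categories, targets):
        sums[cat] += y
        counts[cat] += 1
    # categories with few rows are pulled towards the global mean
    return {
        cat: (sums[cat] + smoothing * global_mean) / (counts[cat] + smoothing)
        for cat in counts
    }


cats = ["a", "a", "a", "b"]
ys = [1, 1, 0, 0]
print(target_encode(cats, ys, smoothing=1.0))
print(target_encode(cats, ys, smoothing=0.0))  # plain per-category means
```

Larger `smoothing` values trust the per-category mean less, which mainly matters for high-cardinality features with many rare categories.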

Overall, combining the results in this post with what I have found trying DL models on real-world tabular datasets in industry, I can only conclude that DL models for tabular data "are not quite there yet" in terms of overall performance.

• TabNet and the TabTransformer

One rather surprising result was the poor performance of Tabnet and, perhaps to a lesser extent, the TabTransformer.

One possibility is that I have not found the right set of parameters to obtain good metrics. In fact, the amount of overfitting when using Tabnet and the TabTransformer was very significant, higher than for TabResnet and, even more so, TabMlp. This makes me believe that if I found a better set of regularization parameters, or simply used a different embedding dimension per categorical feature, I might be able to improve the results shown in the tables above. However, I should also say that, given the good reception these algorithms are having and the poor results I obtained, I placed a bit more emphasis on trying additional parameters. Unfortunately, none of my attempts led to a significant improvement.

A second possibility is, of course, that the implementation in pytorch-widedeep is wrong. I guess I will find this out as I keep releasing versions and using the package.

Overall, I find that TabNet is the worst performer (and often the slowest), and I will certainly devote some extra time in the coming weeks to see if this is related to the input parameters.

• Simplicity over complexity.

It is interesting to see that, overall, the DL algorithm whose performance comes closest to that of LightGBM is a simple MLP. As I write this, I wonder whether this is somehow related to the emerging trend that is bringing MLPs back (e.g. [20], [21] or [22]), and whether the advent of more complex models is simply the result of hype rather than a proper exploration of existing solutions.

Of course, more complex models leave more room for exploration and hyperparameter optimization. While this is something I intend to keep exploring, there comes a moment in space and time when one wonders: "is this really worth it?".

Let's see if I manage to answer this question in the next section.

5. Conclusion¶

When I started thinking about this post, part of me already knew that DL models were, overall, not a real challenge for LightGBM. If we focus only on performance metrics and running time, the only possible conclusion is that DL models for tabular data are still no match for GBMs in real-world environments. However, at this stage in the industry/market, is that really the question to answer? I don't think so.

This is not a competition, nor should it be; it should be a coalition. The question to answer is: "how can DL models for tabular data help in industry and complement current systems?". Let's reflect a bit on this question.

In my experience, DL models on tabular data perform best on sizeable datasets that involve many categorical features, each of them with many categories. In those scenarios, one could simply try DL models with the initial aim of using the predictions directly. However, even if the predictions are eventually not used, the embeddings contain a wealth of useful information: information on how the categorical features interact with each other and on how each of them relates to the target variable. These embeddings can be used for a number of additional products.

For example, let's assume that you have a dataset with metadata for thousands of brands and prices for their corresponding products. Your task is to predict how the price changes over time (i.e. price forecasting). The embeddings for the categorical feature brand will give you information about how a particular brand relates to the rest of the columns in the dataset and to the target (price). In other words, if, given a brand, you find the closest brands as defined by embedding proximity, you would be "naturally" and directly finding competitors within a given space (assuming that the dataset is representative of the market).
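Finding those "closest brands" amounts to a nearest-neighbour search over the rows of the learned embedding matrix for the brand column. The sketch below assumes the embeddings have already been extracted into a plain dictionary mapping brand name to vector (the brand names and values are made up for illustration) and ranks the other brands by cosine similarity.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def closest_brands(embeddings, brand, k=3):
    """Return the k brands whose embedding is most similar (cosine) to `brand`."""
    query = embeddings[brand]
    scores = [
        (other, cosine_similarity(query, vec))
        for other, vec in embeddings.items()
        if other != brand
    ]
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores[:k]


# hypothetical 4-dimensional brand embeddings
brand_embeddings = {
    "brand_a": [0.9, 0.1, 0.3, 0.0],
    "brand_b": [0.8, 0.2, 0.4, 0.1],
    "brand_c": [-0.5, 0.9, 0.0, 0.2],
    "brand_d": [0.1, -0.7, 0.6, 0.9],
}
print(closest_brands(brand_embeddings, "brand_a", k=2))
```

Cosine similarity ignores the norm of the vectors and compares only their direction; with thousands of brands, one would normally normalise the embedding matrix once and use a vectorised or approximate nearest-neighbour library instead of this loop.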

In addition, GBMs do not lend themselves naturally to transfer learning, whereas DL models do. Furthermore, as mentioned in the TabNet and TabTransformer papers, self-supervised training leads to better performance in regimes where labeled data is scarce or the unlabeled dataset is much larger than the labeled one. Therefore, there are scenarios where DL models can be extremely useful.

For example, let's assume you have a large dataset for a given problem in one country but a much smaller dataset for the exact same problem in another country. Let's also assume that the datasets are, column-wise, rather similar. One could train a DL model using the large dataset and "transfer the learnings" to the second, much smaller dataset, in the hope of obtaining much better performance than by using the small dataset alone.

There are other scenarios that I can think of, but I will leave it here. In general, I simply wanted to illustrate that, if you came here to enjoy the fact that GBMs perform better than DL models, I hope you enjoyed the ride (and that you start thinking about finding a good therapist), but, in my opinion, that is not the point.

In terms of metrics, GBMs perform better than DL models, that is correct, but the latter bring functionality to the table that GBMs lack and, therefore, complement them perfectly.

6. Future Work¶

I started thinking about this post months ago. Then some other things took priority in my life (plus a lot of work) and it became a bit of a longer journey. I now hope I can get a bit of help from very clever people in my team and improve the Tabular vs DL code in the repo, perhaps automating some processes so I can easily add more datasets in the future.

Also, this has been a good test for the pytorch-widedeep library (if you like it, or find it useful, please give it a star 😊). All the links in this post point towards the tabnet branch in the repo, which is the most up to date. During the next few days I will merge and release v1 of the package and then update the links and the post. From there, there is a series of algorithms we would like to add (such as SAINT [5]), as well as some different forms of training.

Beyond adding more algorithms to the library or improving the benchmark code, I wanted to close this with one final thought. As I mentioned at the beginning of the post, there is an element of inconsistency between papers. Different papers will find different results for all the algorithms considered, GBMs or DL-based. When reading them, one gets the feeling that there is some rush, some urgency to publish something that obtains SoTA. For someone like me, coming from a background other than computer science, this reminds me, in a sense, of my days as an astronomer. For years back then, I found that most of the publications in my field were not very good, but since all you are judged on is publications and citations, one would publish anything, and the faster, the better.

At this stage, leaving publications and citations aside, I think there is an opportunity for some of us, and some companies as well (so that we can use actual real-world data), to collaborate and properly benchmark DL algorithms for tabular data. I believe the potential of these algorithms in the industry is enormous and with proper benchmarks we could learn not only where they perform better, but how to use them more efficiently.

And that's it! If you made it this far, I hope you enjoyed it and/or found it useful.

References¶

[1] Tabular Data: Deep Learning is Not All You Need. Ravid Shwartz-Ziv, Amitai Armon, 2021, arXiv:2106.03253

[2] XGBoost: A Scalable Tree Boosting System. Tianqi Chen, Carlos Guestrin, 2016, arXiv:1603.02754

[3] CatBoost: unbiased boosting with categorical features. Liudmila Prokhorenkova, Gleb Gusev, Aleksandr Vorobev, Anna Veronika Dorogush, Andrey Gulin, 2017, arXiv:1706.09516

[4] LightGBM: A Highly Efficient Gradient Boosting Decision Tree. Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, et al., 2017, 31st Conference on Neural Information Processing Systems

[5] SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training. Gowthami Somepalli, Micah Goldblum, Avi Schwarzschild, C. Bayan Bruss, Tom Goldstein, 2021, arXiv:2106.01342

[6] Comment Volume Prediction using Neural Networks and Decision Trees. Kamaljot Singh, Ranjeet Kaur, 2015, 17th UKSIM-AMSS International Conference on Modelling and Simulation

[7] TabNet: Attentive Interpretable Tabular Learning. Sercan O. Arik, Tomas Pfister, 2019, arXiv:1908.07442v5

[8] Train longer, generalize better: closing the generalization gap in large batch training of neural networks. Elad Hoffer, Itay Hubara, Daniel Soudry, 2017, arXiv:1705.08741

[9] TabTransformer: Tabular Data Modeling Using Contextual Embeddings. Xin Huang, Ashish Khetan, Milan Cvitkovic, Zohar Karnin, 2020, arXiv:2012.06678v1

[10] Attention Is All You Need. Ashish Vaswani, Noam Shazeer, Niki Parmar, et al., 2017, arXiv:1706.03762v5

[11] Adam: A Method for Stochastic Optimization. Diederik P. Kingma, Jimmy Ba, 2014, arXiv:1412.6980

[12] Decoupled Weight Decay Regularization. Ilya Loshchilov, Frank Hutter, 2017, arXiv:1711.05101

[13] On the Variance of the Adaptive Learning Rate and Beyond. Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, Jiawei Han, 2019, arXiv:1908.03265

[14] Cyclical Learning Rates for Training Neural Networks. Leslie N. Smith, 2017, arXiv:1506.01186

[15] Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates. Leslie N. Smith, Nicholay Topin, 2017, arXiv:1708.0712

[16] Optuna: A Next-generation Hyperparameter Optimization Framework. Takuya Akiba, Shotaro Sano, Toshihiko Yanase, Takeru Ohta, Masanori Koyama, 2019, arXiv:1907.10902

[17] Algorithms for Hyper-Parameter Optimization. James Bergstra, Rémi Bardenet, Yoshua Bengio, Balázs Kégl, 2011, 25th Conference on Neural Information Processing Systems

[18] Focal Loss for Dense Object Detection. Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár, 2017, arXiv:1708.02002

[19] Wide & Deep Learning for Recommender Systems. Heng-Tze Cheng, Levent Koc, Jeremiah Harmsen, et al., 2016, arXiv:1606.07792

[20] FNet: Mixing Tokens with Fourier Transforms. James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon, 2021, arXiv:2105.03824

[21] Pay Attention to MLPs. Hanxiao Liu, Zihang Dai, David R. So, Quoc V. Le, 2021, arXiv:2105.08050

[22] ResMLP: Feedforward networks for image classification with data-efficient training. Hugo Touvron, Piotr Bojanowski, Mathilde Caron, et al., 2021, arXiv:2105.03404