#export
from fastai2.torch_basics import *
from fastai2.tabular.core import *
#hide
from nbdev.showdoc import *
#default_exp tabular.model
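A basic model for tabular data that embeds the categorical variables and feeds them, together with the continuous variables, through fully connected layers.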
#export
def emb_sz_rule(n_cat):
    "Heuristic embedding width for a categorical variable with `n_cat` levels, capped at 600"
    suggested = round(1.6 * n_cat**0.56)
    # never go above the 600 ceiling
    return suggested if suggested < 600 else 600
#export
def _one_emb_sz(classes, n, sz_dict=None):
    "Return `(n_cat, sz)` for variable `n`: the length of `classes[n]` and its embedding size, taken from `sz_dict` when `n` is a key there."
    overrides = ifnone(sz_dict, {})
    n_cat = len(classes[n])
    # rule-of-thumb size, used only when `n` has no entry in `sz_dict`
    default_sz = int(emb_sz_rule(n_cat))
    return n_cat, overrides.get(n, default_sz)
#export
def get_emb_sz(to, sz_dict=None):
    "Get an `(n_cat, sz)` pair for each name in `to.cat_names`, using the sizes in `sz_dict` where given"
    emb_szs = []
    for name in to.cat_names:
        # `to.classes` is indexed by name inside `_one_emb_sz`
        emb_szs.append(_one_emb_sz(to.classes, name, sz_dict))
    return emb_szs
#export
class TabularModel(Module):
    "Basic model for tabular data."
    # Parameters of `__init__`:
    #   emb_szs  : iterable of `(ni, nf)` pairs, each passed to `Embedding(ni, nf)` (one per categorical column)
    #   n_cont   : number of continuous features (0 means `x_cont` is ignored in `forward`)
    #   out_sz   : size of the final linear layer's output
    #   layers   : sizes of the hidden layers (list or tuple)
    #   ps       : dropout probability per hidden layer; a single value is repeated, `None` means 0 everywhere
    #   embed_p  : dropout probability applied to the concatenated embeddings
    #   y_range  : if given, `SigmoidRange(*y_range)` is appended after the last layer
    #   use_bn   : pass `bn=True` to every `LinBnDrop` except the last one
    #   bn_final : with `use_bn`, also pass `bn=True` to the last `LinBnDrop`
    #   bn_cont  : apply `nn.BatchNorm1d(n_cont)` to the continuous inputs
    def __init__(self, emb_szs, n_cont, out_sz, layers, ps=None, embed_p=0.,
                 y_range=None, use_bn=True, bn_final=False, bn_cont=True):
        # Coerce to a list so tuples work with the list concatenations below
        layers = list(layers)
        ps = ifnone(ps, [0]*len(layers))
        if not is_listy(ps): ps = [ps]*len(layers)
        # `is_listy` also accepts tuples, and `ps+[0.]` below needs a real list
        ps = list(ps)
        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(embed_p)
        self.bn_cont = nn.BatchNorm1d(n_cont) if bn_cont else None
        # Total width of all embeddings once concatenated
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont = n_emb,n_cont
        sizes = [n_emb + n_cont] + layers + [out_sz]
        # ReLU after every hidden layer, no activation after the output layer
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        # The output layer gets dropout 0. and batchnorm only when `bn_final` is set
        _layers = [LinBnDrop(sizes[i], sizes[i+1], bn=use_bn and (i!=len(actns)-1 or bn_final), p=p, act=a)
                   for i,(p,a) in enumerate(zip(ps+[0.],actns))]
        if y_range is not None: _layers.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*_layers)

    def forward(self, x_cat, x_cont=None):
        # Embed column `i` of `x_cat` with the i-th embedding, concatenate along dim 1, then apply dropout
        if self.n_emb != 0:
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        # Optionally batch-normalize the continuous inputs, then append them to the embeddings
        # (or use them alone when there are no embeddings)
        if self.n_cont != 0:
            if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        # NOTE(review): `x` is never assigned when both `n_emb` and `n_cont` are 0
        return self.layers(x)
#export
@delegates(TabularModel.__init__)
def tabular_config(**kwargs):
    "Bundle keyword arguments (signature delegated to `TabularModel.__init__`) into a config dict"
    config = kwargs
    return config
#hide
from nbdev.export import notebook2script
# Run nbdev's notebook export; the output recorded after this call lists the notebooks it converted.
notebook2script()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback_core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.transfer_learning.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.ulmfit.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 45_collab.ipynb. Converted 50_datablock_examples.ipynb. Converted 60_medical.imaging.ipynb. Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. Converted 72_callback.neptune.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted index.ipynb.