#!/usr/bin/env python
# coding: utf-8

# ## Get your data ready for training

# This module defines the basic [`DataBunch`](/basic_data.html#DataBunch) object that is used inside [`Learner`](/basic_train.html#Learner) to train a model. This is the generic class: it can take any kind of pytorch [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) or [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). You'll find helpful functions in the data module of every application to directly create this [`DataBunch`](/basic_data.html#DataBunch) for you.

# In[1]:


from fastai.gen_doc.nbdoc import *
from fastai.basics import *


# In[2]:


show_doc(DataBunch)


# It also ensures all the dataloaders are on `device` and applies `dl_tfms` to them as batches are drawn (e.g. normalization). `path` is used internally to store temporary files; `collate_fn` is passed to the pytorch `DataLoader` (replacing the default one) to explain how to collate the samples picked for a batch. By default, it grabs the `data` attribute of the objects sent (see in [`vision.image`](/vision.image.html#vision.image) or the [data block API](/data_block.html) why this can be important).
#
# `train_dl`, `valid_dl` and optionally `test_dl` will be wrapped in [`DeviceDataLoader`](/basic_data.html#DeviceDataLoader).

# ### Factory method

# In[3]:


show_doc(DataBunch.create)


# `num_workers` is the number of CPUs to use; `tfms`, `device` and `collate_fn` are passed to the init method.

# In[ ]:


jekyll_warn("You can pass regular pytorch Datasets here, but they'll require more attributes than the basic ones to work with the library. See below for more details.")


# ### Visualization

# In[4]:


show_doc(DataBunch.show_batch)


# ### Grabbing some data

# In[5]:


show_doc(DataBunch.dl)


# In[6]:


show_doc(DataBunch.one_batch)


# In[7]:


show_doc(DataBunch.one_item)


# In[8]:


show_doc(DataBunch.sanity_check)


# ### Load and save

# You can save your [`DataBunch`](/basic_data.html#DataBunch) object for future use with this method.

# In[9]:


show_doc(DataBunch.save)


# In[10]:


show_doc(load_data)


# In[ ]:


jekyll_important("The arguments you passed when you created your first `DataBunch` aren't saved, so you should pass them here if you don't want the defaults.")


# In[ ]:


jekyll_note("Data cannot be serialized on Windows and then loaded on Linux (or vice versa) because `Path` objects don't support this. We will find a workaround for that in v2.")


# This is to allow you to easily create a new [`DataBunch`](/basic_data.html#DataBunch) with, for instance, a different batch size. You will also need to reapply any normalization (in vision) you might have done on your original [`DataBunch`](/basic_data.html#DataBunch).
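# For instance, a save/load round trip could look like the sketch below. It assumes `data` is an existing [`DataBunch`](/basic_data.html#DataBunch) (such as the toy one built further down); the filename `'tmp_data.pkl'` is arbitrary.

# In[ ]:


# Serialize the DataBunch to <data.path>/tmp_data.pkl
data.save('tmp_data.pkl')
# Reload it with a different batch size: bs isn't saved, so pass it again
data = load_data(data.path, 'tmp_data.pkl', bs=16)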
# ### Empty [`DataBunch`](/basic_data.html#DataBunch) for inference

# In[11]:


show_doc(DataBunch.export)


# In[12]:


show_doc(DataBunch.load_empty, full_name='load_empty')


# This method should be used to create a [`DataBunch`](/basic_data.html#DataBunch) at inference time; see the corresponding [tutorial](/tutorial.inference.html).

# In[13]:


show_doc(DataBunch.add_test)


# ### Dataloader transforms

# In[14]:


show_doc(DataBunch.add_tfm)


# Adds a transform to all the dataloaders.

# ## Using a custom Dataset in fastai

# If you want to use your pytorch [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) in fastai, you may need to implement more attributes/methods to use the full functionality of the library. Some functions can easily be used with your pytorch [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) if you just add an attribute; for others, the best option is to create your own [`ItemList`](/data_block.html#ItemList) by following [this tutorial](/tutorial.itemlist.html). Here is a full list of what the library will expect.

# ### Basics

# First of all, you obviously need to implement the methods `__len__` and `__getitem__`, as indicated by the pytorch docs. Then the most needed things would be:
# - `c` attribute: it's used in most functions that directly create a [`Learner`](/basic_train.html#Learner) ([`tabular_learner`](/tabular.learner.html#tabular_learner), [`text_classifier_learner`](/text.learner.html#text_classifier_learner), [`unet_learner`](/vision.learner.html#unet_learner), [`cnn_learner`](/vision.learner.html#cnn_learner)) and represents the number of outputs of the final layer of your model (also the number of classes if applicable).
# - `classes` attribute: it's used by [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) and also in [`collab_learner`](/collab.html#collab_learner) (where it's best to use [`CollabDataBunch.from_df`](/collab.html#CollabDataBunch.from_df) rather than a pytorch [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)) and represents the unique tags that appear in your data.
# - maybe a `loss_func` attribute: it is going to be used by [`Learner`](/basic_train.html#Learner) as a default loss function, so if you know your custom [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) requires a particular loss, you can set it there.

# Toy example with image-like numpy arrays and binary labels:

# In[ ]:


class ArrayDataset(Dataset):
    "Sample numpy array dataset"
    def __init__(self, x, y):
        self.x, self.y = x, y
        self.c = 2  # binary label

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], self.y[i]


# In[ ]:


train_x = np.random.rand(10, 3, 3)       # 10 images (3x3)
train_y = np.random.rand(10, 1).round()  # binary labels
valid_x = np.random.rand(10, 3, 3)
valid_y = np.random.rand(10, 1).round()
train_ds, valid_ds = ArrayDataset(train_x, train_y), ArrayDataset(valid_x, valid_y)
data = DataBunch.create(train_ds, valid_ds, bs=2, num_workers=1)
data.one_batch()


# ### For a specific application

# In text, your dataset will need to have a `vocab` attribute that should be an instance of [`Vocab`](/text.transform.html#Vocab). It's used by [`text_classifier_learner`](/text.learner.html#text_classifier_learner) and [`language_model_learner`](/text.learner.html#language_model_learner) when building the model.
#
# In tabular, your dataset will need to have a `cont_names` attribute (for the names of continuous variables) and a `get_emb_szs` method that returns a list of tuples `(n_classes, emb_sz)` representing, for each categorical variable, the number of different codes (don't forget to add 1 for nan) and the corresponding embedding size. Those two are used with the `c` attribute by [`tabular_learner`](/tabular.learner.html#tabular_learner).
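# As an illustration, here is a minimal sketch of those tabular hooks. This is hypothetical code, not fastai's implementation: the class name, column names and sizes are made up, and the embedding-size rule is only a simple heuristic for the example.

# In[ ]:


class TabularLikeDataset(Dataset):
    "Hypothetical dataset exposing the attributes expected by `tabular_learner`"
    def __init__(self, cats, conts, y):
        self.cats, self.conts, self.y = cats, conts, y
        self.c = 2                          # two output classes
        self.cont_names = ['age', 'hours']  # names of the continuous variables
        self.n_codes = [10, 17]             # codes per categorical variable, nan included

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        return (self.cats[i], self.conts[i]), self.y[i]

    def get_emb_szs(self):
        # One (n_classes, emb_sz) tuple per categorical variable; the size rule
        # below (half the cardinality, capped at 50) is just for illustration.
        return [(n, min(50, (n + 1) // 2)) for n in self.n_codes]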
# ### Functions that really won't work

# To make those last functions work, you really need to use the [data block API](/data_block.html) and maybe write your own [custom ItemList](/tutorial.itemlist.html).
# - [`DataBunch.show_batch`](/basic_data.html#DataBunch.show_batch) (requires `x.reconstruct`, `y.reconstruct` and `x.show_xys`)
# - [`Learner.predict`](/basic_train.html#Learner.predict) (requires `x.set_item`, `y.analyze_pred`, `y.reconstruct` and maybe `x.reconstruct`)
# - [`Learner.show_results`](/basic_train.html#Learner.show_results) (requires `x.reconstruct`, `y.analyze_pred`, `y.reconstruct` and `x.show_xyzs`)
# - `DataBunch.set_item` (requires `x.set_item`)
# - [`Learner.backward`](/basic_train.html#Learner.backward) (uses `DataBunch.set_item`)
# - [`DataBunch.export`](/basic_data.html#DataBunch.export) (requires `export`)

# In[15]:


show_doc(DeviceDataLoader)


# Put the batches of `dl` on `device` after applying an optional list of `tfms`. `collate_fn` will replace the one from `dl`. All dataloaders of a [`DataBunch`](/basic_data.html#DataBunch) are of this type.

# ### Factory method

# In[16]:


show_doc(DeviceDataLoader.create)


# The given `collate_fn` will be used to put the samples together in one batch (by default it grabs their `data` attribute). If `shuffle` is `True`, the dataloader will draw the samples randomly; otherwise it will take them in order. `tfms` are passed to the init method. All `kwargs` are passed to the pytorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) class initialization.

# ### Methods

# In[17]:


show_doc(DeviceDataLoader.add_tfm)


# In[18]:


show_doc(DeviceDataLoader.remove_tfm)


# In[19]:


show_doc(DeviceDataLoader.new)


# In[20]:


show_doc(DeviceDataLoader.proc_batch)


# In[21]:


show_doc(DatasetType, doc_string=False)


# Internal enumerator to name the training, validation and test dataset/dataloader.

# ## Undocumented Methods - Methods moved below this line will intentionally be hidden

# In[22]:


show_doc(DeviceDataLoader.collate_fn)


# ## New Methods - Please document or move to the undocumented section