In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from pathlib import Path
from fastai.vision import *
from fastai import __version__ as fastai_version


This notebook is a work-in-progress effort to implement multitask learning in fastai. The goal is to fit several classification or regression problems with a single neural network.

To illustrate our work, we use the UTKFace dataset and predict age (regression), gender (classification) and ethnicity (classification) with a single model.

Dataset Preparation

We download a sample dataset from Kaggle:

In [2]:
! kaggle datasets download jangedoo/utkface-new -p data/utkface
In [3]:
!unzip data/utkface/ -d data/utkface
In [3]:
pdata = Path('data/utkface/crop_part1/')
In [4]:
filenames = [os.path.basename(f) for f in list(pdata.glob('*'))]
In [5]:
# Remove badly encoded files from list:
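A possible filter for this step, assuming that "badly encoded" means a name that does not split into exactly four underscore-separated fields (age, gender, ethnicity, timestamp) with the expected codes, is sketched below. `is_well_formed` is a hypothetical helper, and the malformed name in the example is illustrative rather than taken from the dataset.

```python
# Hypothetical filter: keep only names shaped like
# '<age>_<gender>_<ethnicity>_<timestamp>.jpg.chip.jpg'.
VALID_GENDERS = {'0', '1'}
VALID_ETHNICITIES = {'0', '1', '2', '3', '4'}

def is_well_formed(filename: str) -> bool:
    """Hypothetical check on the UTKFace naming scheme."""
    parts = filename.split('_')
    if len(parts) != 4:
        return False
    age, gender, ethnicity = parts[:3]
    return age.isdigit() and gender in VALID_GENDERS and ethnicity in VALID_ETHNICITIES

sample_names = [
    '62_0_3_20170109015431667.jpg.chip.jpg',
    '61_1_20170109150557335.jpg.chip.jpg',  # illustrative name with a missing field
    '18_1_0_20170109214216731.jpg.chip.jpg',
]
print([f for f in sample_names if is_well_formed(f)])
```

Applied as `filenames = [f for f in filenames if is_well_formed(f)]`, such a filter ensures that the `split('_')[:3]` unpacking in the next cell yields one age, one gender and one ethnicity code per file.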
In [6]:
enc_age, enc_gender, enc_ethnicity = zip(*[f.split('_')[:3] for f in filenames])
In [7]:
age = [float(o) for o in enc_age]
In [8]:
gender_map = {'0': 'male', '1': 'female'}
gender = [gender_map[o] for o in enc_gender]
In [9]:
ethnicity_map = {'0': 'White', '1': 'Black', '2': 'Asian', '3': 'Indian', '4': 'Others'}
ethnicity = [ethnicity_map[o] for o in enc_ethnicity]
In [10]:
df = pd.DataFrame(list(zip(filenames, age, gender, ethnicity)), columns=['filename', 'age', 'gender', 'ethnicity'])
In [11]:
df['is_valid'] = df.apply(lambda row: np.random.random() < 0.2, axis=1)
In [12]:
df.head()
filename age gender ethnicity is_valid
0 62_0_3_20170109015431667.jpg.chip.jpg 62.0 male Indian False
1 18_1_0_20170109214216731.jpg.chip.jpg 18.0 female White False
2 13_0_0_20170110224337867.jpg.chip.jpg 13.0 male White False
3 56_1_0_20170109220607828.jpg.chip.jpg 56.0 female White True
4 29_0_0_20170104184656046.jpg.chip.jpg 29.0 male White True
In [13]:
open_image(pdata / df.iloc[1].filename)

We replace about 10% of the target labels with NaN (age) or "NA" (gender, ethnicity) to demonstrate how to handle missing data. We want our model to be robust to datasets that are not fully annotated.

In [14]:
df['age'] = df.apply(lambda row: np.nan if np.random.random() < 0.1 else row['age'], axis=1)
df['gender'] = df.apply(lambda row: 'NA' if np.random.random() < 0.1 else row['gender'], axis=1)
df['ethnicity'] = df.apply(lambda row: 'NA' if np.random.random() < 0.1 else row['ethnicity'], axis=1)
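The row-wise `apply` calls above draw one random number per row. A vectorized sketch of the same corruption, assuming NumPy's `Generator` API (`np.random.default_rng`) and a hypothetical helper `mask_labels`, builds the whole mask at once and makes the result reproducible through a seed:

```python
import numpy as np

def mask_labels(values, frac, missing, rng):
    """Hypothetical helper: return a copy of `values` where each entry is
    replaced by `missing` with probability `frac`."""
    values = np.asarray(values)
    mask = rng.random(values.shape[0]) < frac  # one independent draw per row, as in the apply() version
    return np.where(mask, missing, values)

rng = np.random.default_rng(42)  # fixed seed so the corruption is reproducible
ages = np.array([62.0, 18.0, 13.0, 56.0, 29.0])
genders = np.array(['male', 'female', 'male', 'female', 'male'])
print(mask_labels(ages, 0.1, np.nan, rng))
print(mask_labels(genders, 0.1, 'NA', rng))
```

On the dataframe, the helper would be applied column by column, e.g. `df['age'] = mask_labels(df['age'].to_numpy(), 0.1, np.nan, rng)`; the 80/20 split from the earlier cell can be drawn the same way with `rng.random(len(df)) < 0.2`.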
In [15]:
df.head(10)
filename age gender ethnicity is_valid
0 62_0_3_20170109015431667.jpg.chip.jpg 62.0 NA Indian False
1 18_1_0_20170109214216731.jpg.chip.jpg 18.0 female White False
2 13_0_0_20170110224337867.jpg.chip.jpg 13.0 male White False
3 56_1_0_20170109220607828.jpg.chip.jpg 56.0 female White True
4 29_0_0_20170104184656046.jpg.chip.jpg 29.0 male White True
5 33_1_0_20170104165723873.jpg.chip.jpg 33.0 female White True
6 47_0_0_20170111181750442.jpg.chip.jpg 47.0 male White False
7 20_1_0_20170103163040136.jpg.chip.jpg NaN female NA False
8 7_1_4_20161223232252196.jpg.chip.jpg 7.0 female NA False
9 32_1_0_20170104185215238.jpg.chip.jpg 32.0 female White True

Multitask DataBunch

To set up our multitask DataBunch, we first define one label list per sub-task, then combine them into a custom multitask LabelLists instance.

  1. First let's subclass ImageList to prevent it from throwing an error on NaN labels.
In [16]:
from fastai.data_block import _maybe_squeeze
class NanLabelImageList(ImageList):
    def label_from_df(self, cols:IntsOrStrs=1, label_cls:Callable=None, **kwargs):
        "Label `self.items` from the values in `cols` in `self.inner_df`."
        labels = self.inner_df.iloc[:,df_names_to_idx(cols, self.inner_df)]
        # Unlike the parent class, do not reject NaN labels; the original check was:
        #assert labels.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
        if is_listy(cols) and len(cols) > 1 and (label_cls is None or label_cls == MultiCategoryList):
            new_kwargs,label_cls = dict(one_hot=True, classes= cols),MultiCategoryList
            kwargs = {**new_kwargs, **kwargs}
        return self._label_from_list(_maybe_squeeze(labels), label_cls=label_cls, **kwargs)
  1. We also need to normalize float targets in regression sub-tasks. Otherwise the regression sub-losses could take much larger values than the cross-entropy losses, which would make the multitask model slow and hard to fit. To do that, we define a NormalizationProcessor:
In [17]:
class NormalizationProcessor(PreProcessor):
    "`PreProcessor` that computes mean and std from `ds.items` and normalizes them."
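    def __init__(self, ds:ItemList):
        super().__init__(ds)
        # Statistics are computed lazily, on the first (training) set passed to `process`.
        self.mean, self.std = None, None
        self.state_attrs = ['mean', 'std']
    def compute_stats(self, ds:ItemList):
        # Ignore missing (NaN) targets when computing the statistics.
        items = ds.items[~np.isnan(ds.items)]
        self.mean = items.mean()
        self.std = items.std()

    def process_one(self,item):
        if isinstance(item, EmptyLabel): return item
        return (item - self.mean) / self.std
    def unprocess_one(self, item):
        if isinstance(item, EmptyLabel): return item
        return item * self.std + self.mean
    def process(self, ds):
        # The same processor instance is reused for the validation set, so it keeps the training statistics.
        if self.mean is None: self.compute_stats(ds)
        ds.mean = self.mean
        ds.std = self.std
        # PreProcessor.process applies `process_one` to every item of `ds`.
        super().process(ds)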

    def __getstate__(self): 
        return {n:getattr(self,n) for n in self.state_attrs}

    def __setstate__(self, state:dict):
        self.state_attrs = state.keys()
        for n in state.keys():
            setattr(self, n, state[n])
class NormalizedFloatList(FloatList):
    _processor = NormalizationProcessor
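The arithmetic of `NormalizationProcessor` can be checked in isolation. The sketch below uses a hypothetical, fastai-free stand-in named `NanAwareStandardizer`: it computes the statistics on the observed (non-NaN) training targets only, standardizes with them, and inverts the transform, while missing targets stay NaN throughout.

```python
import numpy as np

class NanAwareStandardizer:
    """Hypothetical stand-in for the arithmetic of NormalizationProcessor."""

    def fit(self, values):
        values = np.asarray(values, dtype=np.float64)
        observed = values[~np.isnan(values)]  # missing targets must not bias the statistics
        self.mean, self.std = observed.mean(), observed.std()
        return self

    def transform(self, values):
        # NaN in, NaN out: missing targets stay missing after scaling.
        return (np.asarray(values, dtype=np.float64) - self.mean) / self.std

    def inverse_transform(self, values):
        return np.asarray(values, dtype=np.float64) * self.std + self.mean

train_ages = np.array([62.0, 18.0, np.nan, 56.0, 29.0])
scaler = NanAwareStandardizer().fit(train_ages)
scaled = scaler.transform(train_ages)
print(scaled)
print(scaler.inverse_transform(scaled))
```

Validation targets should be transformed with the statistics fitted on the training targets, never refitted, which is what the `if self.mean is None` guard in `process` is meant to ensure.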
  1. Prepare the LabelLists for age, gender and ethnicity as usual, making sure they all use the same split.
In [18]:
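# All three use the same dataframe and the same `is_valid` column, hence the same split.
gender_labels = (
    NanLabelImageList.from_df(df, path=pdata, cols='filename')
    .split_from_df(col='is_valid').label_from_df(cols='gender'))
ethnicity_labels = (
    NanLabelImageList.from_df(df, path=pdata, cols='filename')
    .split_from_df(col='is_valid').label_from_df(cols='ethnicity'))
age_labels = (
    NanLabelImageList.from_df(df, path=pdata, cols='filename')
    .split_from_df(col='is_valid').label_from_df(cols='age', label_cls=NormalizedFloatList))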
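The requirement that all label lists share the same split comes from the way the per-task targets are later combined index by index: the i-th gender, ethnicity and age targets must all describe the same image. A small check with a hypothetical helper, `assert_same_split`, can make that assumption explicit:

```python
def assert_same_split(items_per_task):
    """Hypothetical helper: raise if the tasks do not list the same items,
    in the same order, for one split.

    items_per_task maps a task name to a sequence of item identifiers (e.g. file paths).
    """
    names = list(items_per_task)
    reference = list(items_per_task[names[0]])
    for name in names[1:]:
        if list(items_per_task[name]) != reference:
            raise ValueError(f"task '{name}' does not share the split of task '{names[0]}'")

train_items = {
    'gender': ['a.jpg', 'b.jpg', 'c.jpg'],
    'ethnicity': ['a.jpg', 'b.jpg', 'c.jpg'],
    'age': ['a.jpg', 'c.jpg', 'b.jpg'],  # same files, different order
}
try:
    assert_same_split(train_items)
except ValueError as err:
    print(err)
```

In the notebook it could be called as `assert_same_split({name: task['label_lists'].train.x.items for name, task in multitask_project.items()})`, and likewise for `.valid`, assuming the fastai v1 convention that `ItemList.items` holds the underlying file paths.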
  1. We store them in a dict for convenience; each key is a printable string used to describe the task. We also attach an optional metric to each task, which we will use later.
In [19]:
multitask_project = {
    'gender': {
        'label_lists': gender_labels,
        'metric': accuracy
    },
    'ethnicity': {
        'label_lists': ethnicity_labels,
        'metric': accuracy
    },
    'age': {
        'label_lists': age_labels,
        'metric': rmse
    }
}
  1. We now define the MultitaskItem and MultitaskItemList classes, subclasses of MixedItem and MixedItemList respectively. The goals are to:
    • define a MultitaskItem as a list of other items (FloatItem or Category), so that each target is converted into a tensor concatenating the encodings of its sub-targets;
    • given a predicted tensor, split it according to the length of each task, analyze each part and reconstruct the predicted output. Reconstruction should also de-normalize regression values;
    • adapt string formatting (in particular for MultitaskItem, serialized as "key:value|key2:value2").
In [20]:
# Monkey patch FloatItem with a better default string formatting.
def float_str(self):
    return "{:.4g}".format(self.obj)
FloatItem.__str__ = float_str

class MultitaskItem(MixedItem):
    def __init__(self, *args, mt_names=None, **kwargs):
        super().__init__(*args, **kwargs)  # MixedItem stores the sub-items in self.obj
        self.mt_names = mt_names
    def __repr__(self):
        return '|'.join([f'{self.mt_names[i]}:{item}' for i, item in enumerate(self.obj)])

class MultitaskItemList(MixedItemList):
    def __init__(self, *args, mt_names=None, **kwargs):
        super().__init__(*args, **kwargs)  # sets self.item_lists
        # Per-task metadata used to split and decode predictions.
        self.mt_classes = [getattr(il, 'classes', None) for il in self.item_lists]
        self.mt_types = [type(il) for il in self.item_lists]
        self.mt_lengths = [len(i) if i else 1 for i in self.mt_classes]
        self.mt_names = mt_names
    def get(self, i):
        return MultitaskItem([il.get(i) for il in self.item_lists], mt_names=self.mt_names)
    def reconstruct(self, t_list):
        items = []
        t_list = self.unprocess_one(t_list)
        for i,t in enumerate(t_list):
            if self.mt_types[i] == CategoryList:
                items.append(Category(t, self.mt_classes[i][t]))
            elif issubclass(self.mt_types[i], FloatList):
                items.append(FloatItem(t))
        return MultitaskItem(items, mt_names=self.mt_names)
    def analyze_pred(self, pred, thresh:float=0.5):         
        predictions = []
        start = 0
        for length, mt_type in zip(self.mt_lengths, self.mt_types):
            if mt_type == CategoryList:
                predictions.append(pred[start: start + length].argmax())
            elif issubclass(mt_type, FloatList):
                predictions.append(pred[start: start + length][0])
            start += length
        return predictions

    def unprocess_one(self, item, processor=None):
        if processor is not None: self.processor = processor
        self.processor = listify(self.processor)
        for p in self.processor: 
            item = _processor_unprocess_one(p, item)
        return item

def _processor_unprocess_one(self, item:Any): # TODO: global function to avoid subclassing MixedProcessor. To be cleaned.
    res = []
    for procs, i in zip(self.procs, item):
        # Undo each sub-task's processors (e.g. de-normalize regression values).
        for p in procs:
            if hasattr(p, 'unprocess_one'):
                i = p.unprocess_one(i)
        res.append(i)
    return res
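The slicing performed by `analyze_pred` and the decoding done by `reconstruct` can be illustrated without fastai. This sketch assumes a flat prediction vector laid out task after task, as described by `mt_lengths`; `split_prediction`, `format_multitask`, the class lists and the age statistics are all hypothetical.

```python
import numpy as np

def split_prediction(pred, lengths, kinds):
    """Hypothetical helper mirroring analyze_pred: cut a flat prediction vector
    into one value per task. A 'category' task keeps the argmax of its slice,
    a 'float' task keeps its single (still normalized) value."""
    values, start = [], 0
    for length, kind in zip(lengths, kinds):
        chunk = pred[start:start + length]
        values.append(int(np.argmax(chunk)) if kind == 'category' else float(chunk[0]))
        start += length
    return values

def format_multitask(names, values):
    """Hypothetical helper: serialize per-task values as 'key:value|key2:value2'."""
    return '|'.join(f'{name}:{value}' for name, value in zip(names, values))

# Illustrative layout: 2 gender logits, 3 ethnicity logits, 1 normalized age.
gender_classes = ['female', 'male']
ethnicity_classes = ['Asian', 'Black', 'White']
age_mean, age_std = 33.0, 20.0  # made-up training statistics

pred = np.array([0.2, 1.5, -0.3, 0.1, 2.4, 0.25])
g, e, a = split_prediction(pred, lengths=[2, 3, 1], kinds=['category', 'category', 'float'])
# Decode class indices and de-normalize the age, as reconstruct does.
decoded = [gender_classes[g], ethnicity_classes[e], f'{a * age_std + age_mean:.4g}']
print(format_multitask(['gender', 'ethnicity', 'age'], decoded))
```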
  1. We also need to subclass LabelList and LabelLists to store and load all the necessary state, so that the model can later be exported and loaded again (to run it in production, for example).
In [21]:
class MultitaskLabelList(LabelList):
    def get_state(self, **kwargs):
        # Export the multitask metadata together with the default LabelList state.
        kwargs.update({
            'mt_classes': self.mt_classes,
            'mt_types': self.mt_types,
            'mt_lengths': self.mt_lengths,
            'mt_names': self.mt_names
        })
        return super().get_state(**kwargs)

    @classmethod
    def load_state(cls, path:PathOrStr, state:dict) -> 'LabelList':
        res = super().load_state(path, state)
        # Restore the metadata on `res.y`, where `analyze_pred` and `reconstruct` read it.
        res.y.mt_classes = state['mt_classes']
        res.y.mt_types = state['mt_types']
        res.y.mt_lengths = state['mt_lengths']
        res.y.mt_names = state['mt_names']
        return res
class MultitaskLabelLists(LabelLists):
    @classmethod
    def load_state(cls, path:PathOrStr, state:dict):
        path = Path(path)
        train_ds = MultitaskLabelList.load_state(path, state)
        valid_ds = MultitaskLabelList.load_state(path, state)
        return MultitaskLabelLists(path, train=train_ds, valid=valid_ds)
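When the model is exported, processors such as `NormalizationProcessor` are serialized along with the rest of the state (fastai v1 saves the export with `torch.save`, which relies on pickle). Their `__getstate__`/`__setstate__` pair means that only the attributes listed in `state_attrs` survive the round trip. A fastai-free sketch of that pattern, using a hypothetical `StatsHolder` class:

```python
import pickle

class StatsHolder:
    """Hypothetical class reusing the state_attrs pattern of NormalizationProcessor."""

    def __init__(self):
        self.state_attrs = ['mean', 'std']
        self.mean, self.std = None, None

    def __getstate__(self):
        # Only the attributes named in state_attrs are serialized.
        return {n: getattr(self, n) for n in self.state_attrs}

    def __setstate__(self, state):
        self.state_attrs = list(state.keys())
        for n, value in state.items():
            setattr(self, n, value)

holder = StatsHolder()
holder.mean, holder.std = 33.0, 20.0
holder.scratch = list(range(1000))  # transient attribute, deliberately not in state_attrs
restored = pickle.loads(pickle.dumps(holder))
print(restored.mean, restored.std, hasattr(restored, 'scratch'))
```

Because the restored object already has `mean` set, a `process` call guarded by `if self.mean is None` reuses the exported statistics instead of recomputing them.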
  1. Now we define a utility function that builds the MultitaskItemLists from our predefined label lists and groups them into a final MultitaskLabelLists instance.
In [22]:
def label_from_mt_project(self, multitask_project):
    # One MultitaskItemList per split, combining the targets of every task.
    mt_train_list = MultitaskItemList(
        [task['label_lists'].train.y for task in multitask_project.values()],
        mt_names=list(multitask_project.keys())
    )
    mt_valid_list = MultitaskItemList(
        [task['label_lists'].valid.y for task in multitask_project.values()],
        mt_names=list(multitask_project.keys())
    )
    self.train = self.train._label_list(x=self.train, y=mt_train_list)
    self.valid = self.valid._label_list(x=self.valid, y=mt_valid_list)
    self.__class__ = MultitaskLabelLists # TODO: Class morphing should be avoided, to be improved.
    self.train.__class__ = MultitaskLabelList
    self.valid.__class__ = MultitaskLabelList
    return self

ItemLists.label_from_mt_project = label_from_mt_project
image_lists = ImageList.from_df(df, path=pdata, cols='filename').split_from_df(col='is_valid')
In [23]:
mt_label_lists = image_lists.label_from_mt_project(multitask_project)
In [24]:
mt_label_lists

Train: MultitaskLabelList (7780 items)
x: ImageList
Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200)
y: MultitaskItemList
Path: data/utkface/crop_part1;

Valid: MultitaskLabelList (1998 items)
x: ImageList
Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200)
y: MultitaskItemList
Path: data/utkface/crop_part1;

Test: None
  1. Create the DataBunch as usual, then display a sample batch to sanity-check the result.
In [25]:
tfms = get_transforms()
data = mt_label_lists.transform(tfms, size=128).databunch(bs=48).normalize(imagenet_stats)