%load_ext autoreload
%autoreload 2
%matplotlib inline
from pathlib import Path
from fastai.vision import *
from fastai import __version__ as fastai_version
print(fastai_version)
This notebook is a work-in-progress effort to implement multitask learning in fastai. The goal is to be able to fit several classification or regression problems with a single neural network.
To illustrate our work, we are going to use the UTKFace dataset (https://susanqq.github.io/UTKFace/), and predict the age (regression), gender (classification) and ethnicity (classification) with a single model.
We download a sample dataset from Kaggle:
! kaggle datasets download jangedoo/utkface-new -p data/utkface
!unzip data/utkface/utkface-new.zip -d data/utkface
pdata = Path('data/utkface/crop_part1/')
filenames = [os.path.basename(f) for f in list(pdata.glob('*'))]
# Remove badly encoded files from list:
filenames.remove('61_3_20170109150557335.jpg.chip.jpg')
filenames.remove('61_1_20170109142408075.jpg.chip.jpg')
# UTKFace filenames encode the labels as [age]_[gender]_[ethnicity]_[date&time].jpg.chip.jpg
enc_age, enc_gender, enc_ethnicity = zip(*[f.split('_')[:3] for f in filenames])
age = [float(o) for o in enc_age]
gender_map = {'0': 'male', '1': 'female'}
gender = [gender_map[o] for o in enc_gender]
ethnicity_map = {'0': 'White', '1': 'Black', '2': 'Asian', '3': 'Indian', '4': 'Others'}
ethnicity = [ethnicity_map[o] for o in enc_ethnicity]
df = pd.DataFrame(list(zip(filenames, age, gender, ethnicity)), columns=['filename', 'age', 'gender', 'ethnicity'])
np.random.seed(42)
df['is_valid'] = np.random.random(len(df)) < 0.2
df.head()
open_image(pdata / df.iloc[1].filename)
We introduce about 10% of NaN / "NA" values in the target labels to demonstrate how to handle missing data: we want our model to be robust to datasets that are not fully annotated.
df['age'] = df.apply(lambda row: np.NaN if np.random.random() < 0.1 else row['age'], axis=1)
df['gender'] = df.apply(lambda row: 'NA' if np.random.random() < 0.1 else row['gender'], axis=1)
df['ethnicity'] = df.apply(lambda row: 'NA' if np.random.random() < 0.1 else row['ethnicity'], axis=1)
df.head(10)
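As a quick sanity check (a minimal sketch, the exact fractions depend on the random draws), roughly 10% of each target should now be missing:
print(df['age'].isna().mean(), (df['gender'] == 'NA').mean(), (df['ethnicity'] == 'NA').mean())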
To set up our multitask databunch, we are going to define a separate label list for each sub-task, then combine them into a custom multitask LabelLists instance.
from fastai.data_block import _maybe_squeeze
class NanLabelImageList(ImageList):
def label_from_df(self, cols:IntsOrStrs=1, label_cls:Callable=None, **kwargs):
"Label `self.items` from the values in `cols` in `self.inner_df`."
labels = self.inner_df.iloc[:,df_names_to_idx(cols, self.inner_df)]
# The only change from fastai's label_from_df: the NaN assertion below is disabled so rows with missing labels are kept.
#assert labels.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
if is_listy(cols) and len(cols) > 1 and (label_cls is None or label_cls == MultiCategoryList):
new_kwargs,label_cls = dict(one_hot=True, classes= cols),MultiCategoryList
kwargs = {**new_kwargs, **kwargs}
return self._label_from_list(_maybe_squeeze(labels), label_cls=label_cls, **kwargs)
We define a NormalizationProcessor to normalize the regression target (age), ignoring NaN values when computing the statistics:
class NormalizationProcessor(PreProcessor):
"`PreProcessor` that computes mean and std from `ds.items` and normalizes them."
def __init__(self, ds:ItemList):
self.compute_stats(ds)
self.state_attrs = ['mean', 'std']
def compute_stats(self, ds:ItemList):
items = ds.items[~np.isnan(ds.items)]
self.mean = items.mean()
self.std = items.std()
def process_one(self,item):
if isinstance(item, EmptyLabel): return item
return (item - self.mean) / self.std
def unprocess_one(self, item):
if isinstance(item, EmptyLabel): return item
return item * self.std + self.mean
def process(self, ds):
if self.mean is None:
self.compute_stats(ds)
ds.mean = self.mean
ds.std = self.std
super().process(ds)
def __getstate__(self):
return {n:getattr(self,n) for n in self.state_attrs}
def __setstate__(self, state:dict):
self.state_attrs = state.keys()
for n in state.keys():
setattr(self, n, state[n])
class NormalizedFloatList(FloatList):
_processor = NormalizationProcessor
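As a rough illustration of what the processor does (with hypothetical statistics, not the dataset's actual mean and std):
mean, std = 30.0, 10.0            # hypothetical training statistics
encoded = (50.0 - mean) / std     # what process_one returns for an age of 50: 2.0
decoded = encoded * std + mean    # what unprocess_one returns: 50.0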
We create LabelLists for age, gender and ethnicity as usual. Make sure the split is the same for all of them.
gender_labels = (
NanLabelImageList.from_df(df, path=pdata, cols='filename')
.split_from_df(col='is_valid')
.label_from_df(cols='gender')
)
ethnicity_labels = (
NanLabelImageList.from_df(df, path=pdata, cols='filename')
.split_from_df(col='is_valid')
.label_from_df(cols='ethnicity')
)
age_labels = (
NanLabelImageList.from_df(df, path=pdata, cols='filename')
.split_from_df(col='is_valid')
.label_from_df(cols='age',label_cls=NormalizedFloatList)
)
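Note that with NanLabelImageList, rows labelled 'NA' are simply kept and become an extra class of the categorical targets (a quick check, the exact classes depend on the sampled rows):
gender_labels.train.y.classes     # e.g. ['NA', 'female', 'male']
ethnicity_labels.train.y.classes  # e.g. ['Asian', 'Black', 'Indian', 'NA', 'Others', 'White']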
multitask_project = {
'gender': {
'label_lists': gender_labels,
'metric': accuracy
},
'ethnicity': {
'label_lists': ethnicity_labels,
'metric': accuracy
},
'age': {
'label_lists': age_labels,
'metric': rmse,
}
}
We create MultitaskItem and MultitaskItemList classes, which are respectively sub-classes of MixedItem and MixedItemList. The goal is for a MultitaskItem to be a list of other items (FloatItem or Category), so that they are converted to a tensor concatenating the encodings of the sub-items.
# Monkey patch FloatItem with a better default string formatting.
def float_str(self):
return "{:.4g}".format(self.obj)
FloatItem.__str__ = float_str
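# e.g. str(FloatItem(23.456789)) now renders as '23.46' instead of the full float repr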
class MultitaskItem(MixedItem):
def __init__(self, *args, mt_names=None, **kwargs):
super().__init__(*args,**kwargs)
self.mt_names = mt_names
def __repr__(self):
return '|'.join([f'{self.mt_names[i]}:{item}' for i, item in enumerate(self.obj)])
class MultitaskItemList(MixedItemList):
def __init__(self, *args, mt_names=None, **kwargs):
super().__init__(*args,**kwargs)
self.mt_classes = [getattr(il, 'classes', None) for il in self.item_lists]
self.mt_types = [type(il) for il in self.item_lists]
self.mt_lengths = [len(i) if i else 1 for i in self.mt_classes]
self.mt_names = mt_names
def get(self, i):
return MultitaskItem([il.get(i) for il in self.item_lists], mt_names=self.mt_names)
def reconstruct(self, t_list):
items = []
t_list = self.unprocess_one(t_list)
for i,t in enumerate(t_list):
if self.mt_types[i] == CategoryList:
items.append(Category(t, self.mt_classes[i][t]))
elif issubclass(self.mt_types[i], FloatList):
items.append(FloatItem(t))
return MultitaskItem(items, mt_names=self.mt_names)
def analyze_pred(self, pred, thresh:float=0.5):
predictions = []
start = 0
for length, mt_type in zip(self.mt_lengths, self.mt_types):
if mt_type == CategoryList:
predictions.append(pred[start: start + length].argmax())
elif issubclass(mt_type, FloatList):
predictions.append(pred[start: start + length][0])
start += length
return predictions
def unprocess_one(self, item, processor=None):
if processor is not None: self.processor = processor
self.processor = listify(self.processor)
for p in self.processor:
item = _processor_unprocess_one(p, item)
return item
def _processor_unprocess_one(self, item:Any): # TODO: global function to avoid subclassing MixedProcessor. To be cleaned.
res = []
for procs, i in zip(self.procs, item):
for p in procs:
if hasattr(p, 'unprocess_one'):
i = p.unprocess_one(i)
res.append(i)
return res
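For instance, in this notebook gender has 3 classes (including 'NA'), ethnicity has 6 and age is a single float, so mt_lengths would be [3, 6, 1] and the model head is expected to output 3 + 6 + 1 = 10 values per image. A minimal sketch of how analyze_pred slices such a prediction (illustrative values only, not an actual model output):
raw = torch.randn(10)
gender_pred = raw[0:3].argmax()    # index into the 3 gender classes
ethnicity_pred = raw[3:9].argmax() # index into the 6 ethnicity classes
age_pred = raw[9]                  # normalized age, rescaled later by unprocess_one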
We subclass LabelList and LabelLists to store and load all the necessary state, in order to later be able to export and load the model (to run it in production, for example).
class MultitaskLabelList(LabelList):
def get_state(self, **kwargs):
kwargs.update({
'mt_classes': self.mt_classes,
'mt_types': self.mt_types,
'mt_lengths': self.mt_lengths,
'mt_names': self.mt_names
})
return super().get_state(**kwargs)
@classmethod
def load_state(cls, path:PathOrStr, state:dict) -> 'LabelList':
res = super().load_state(path, state)
res.mt_classes = state['mt_classes']
res.mt_types = state['mt_types']
res.mt_lengths = state['mt_lengths']
res.mt_names = state['mt_names']
return res
class MultitaskLabelLists(LabelLists):
@classmethod
def load_state(cls, path:PathOrStr, state:dict):
path = Path(path)
train_ds = MultitaskLabelList.load_state(path, state)
valid_ds = MultitaskLabelList.load_state(path, state)
return MultitaskLabelLists(path, train=train_ds, valid=valid_ds)
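This state is what will later make export/load work (a sketch, assuming a learn Learner is built on the databunch further down the notebook, so not run here):
# learn.export('multitask_export.pkl')          # the mt_* attributes are saved with the rest of the state
# load_learner(pdata, 'multitask_export.pkl')   # reloads it for inference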
We define a label_from_mt_project method that creates MultitaskItemLists given our predefined label lists, and groups them into a final MultitaskLabelLists instance.
def label_from_mt_project(self, multitask_project):
mt_train_list = MultitaskItemList(
[task['label_lists'].train.y for task in multitask_project.values()],
mt_names=list(multitask_project.keys()),
)
mt_valid_list = MultitaskItemList(
[task['label_lists'].valid.y for task in multitask_project.values()],
mt_names=list(multitask_project.keys())
)
self.train = self.train._label_list(x=self.train, y=mt_train_list)
self.valid = self.valid._label_list(x=self.valid, y=mt_valid_list)
self.__class__ = MultitaskLabelLists # TODO: Class morphing should be avoided, to be improved.
self.train.__class__ = MultitaskLabelList
self.valid.__class__ = MultitaskLabelList
return self
ItemLists.label_from_mt_project = label_from_mt_project
image_lists = ImageList.from_df(df, path=pdata, cols='filename').split_from_df(col='is_valid')
mt_label_lists = image_lists.label_from_mt_project(multitask_project)
mt_label_lists
tfms = get_transforms()
data = mt_label_lists.transform(tfms, size=128).databunch(bs=48).normalize(imagenet_stats)
data.show_batch()