%load_ext autoreload
%autoreload 2
%matplotlib inline
from pathlib import Path
from fastai.vision import *
from fastai import __version__ as fastai_version
print(fastai_version)
1.0.52
This Notebook is a work-in-progress effort to implement multitask learning in fastai. The goal is to be able to fit several categorization or regression problems with a single neural network.
To illustrate our work, we are going to use the UTKFace dataset (https://susanqq.github.io/UTKFace/), and predict the age (regression), gender (classification) and ethnicity (classification) with a single model.
We download a sample dataset from Kaggle:
! kaggle datasets download jangedoo/utkface-new -p data/utkface
^C User cancelled operation
!unzip data/utkface/utkface-new.zip -d data/utkface
Archive: data/utkface/utkface-new.zip replace data/utkface/crop_part1/100_1_0_20170110183726390.jpg.chip.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C
# Build a dataframe of (filename, age, gender, ethnicity) parsed from the
# UTKFace file names, which are encoded as "<age>_<gender>_<ethnicity>_<timestamp>".
pdata = Path('data/utkface/crop_part1/')
filenames = [os.path.basename(f) for f in list(pdata.glob('*'))]
# Remove badly encoded files from list:
filenames.remove('61_3_20170109150557335.jpg.chip.jpg')
filenames.remove('61_1_20170109142408075.jpg.chip.jpg')
# Split every filename into its encoded (age, gender, ethnicity) prefix.
enc_age, enc_gender, enc_ethnicity = zip(*[f.split('_')[:3] for f in filenames])
age = [float(o) for o in enc_age]
gender_map = {'0': 'male', '1': 'female'}
gender = [gender_map[o] for o in enc_gender]
ethnicity_map = {'0': 'White', '1': 'Black', '2': 'Asian', '3': 'Indian', '4': 'Others'}
ethnicity = [ethnicity_map[o] for o in enc_ethnicity]
df = pd.DataFrame(list(zip(filenames, age, gender, ethnicity)), columns=['filename', 'age', 'gender', 'ethnicity'])
np.random.seed(42)
np.random.random()
# Reserve ~20% of the rows for the validation set.
df['is_valid'] = df.apply(lambda row: np.random.random() < 0.2, axis=1)
df.head()
filename | age | gender | ethnicity | is_valid | |
---|---|---|---|---|---|
0 | 62_0_3_20170109015431667.jpg.chip.jpg | 62.0 | male | Indian | False |
1 | 18_1_0_20170109214216731.jpg.chip.jpg | 18.0 | female | White | False |
2 | 13_0_0_20170110224337867.jpg.chip.jpg | 13.0 | male | White | False |
3 | 56_1_0_20170109220607828.jpg.chip.jpg | 56.0 | female | White | True |
4 | 29_0_0_20170104184656046.jpg.chip.jpg | 29.0 | male | White | True |
open_image(pdata / df.iloc[1].filename)
We introduce 10% of NaN / "NA" values in the target labels to demonstrate how to handle missing data. Indeed, we want our model to be robust to datasets that are not fully annotated.
# Randomly blank out ~10% of each target column (NaN for the regression
# target, an explicit 'NA' category for the classification targets).
df['age'] = df.apply(lambda row: np.NaN if np.random.random() < 0.1 else row['age'], axis=1)
df['gender'] = df.apply(lambda row: 'NA' if np.random.random() < 0.1 else row['gender'], axis=1)
df['ethnicity'] = df.apply(lambda row: 'NA' if np.random.random() < 0.1 else row['ethnicity'], axis=1)
df.head(10)
filename | age | gender | ethnicity | is_valid | |
---|---|---|---|---|---|
0 | 62_0_3_20170109015431667.jpg.chip.jpg | 62.0 | NA | Indian | False |
1 | 18_1_0_20170109214216731.jpg.chip.jpg | 18.0 | female | White | False |
2 | 13_0_0_20170110224337867.jpg.chip.jpg | 13.0 | male | White | False |
3 | 56_1_0_20170109220607828.jpg.chip.jpg | 56.0 | female | White | True |
4 | 29_0_0_20170104184656046.jpg.chip.jpg | 29.0 | male | White | True |
5 | 33_1_0_20170104165723873.jpg.chip.jpg | 33.0 | female | White | True |
6 | 47_0_0_20170111181750442.jpg.chip.jpg | 47.0 | male | White | False |
7 | 20_1_0_20170103163040136.jpg.chip.jpg | NaN | female | NA | False |
8 | 7_1_4_20161223232252196.jpg.chip.jpg | 7.0 | female | NA | False |
9 | 32_1_0_20170104185215238.jpg.chip.jpg | 32.0 | female | White | True |
To set up our multitask databunch, we are going to define several label lists for each sub-task, then combine them into a custom multitask LabelLists instance.
from fastai.data_block import _maybe_squeeze
class NanLabelImageList(ImageList):
    "An `ImageList` that tolerates NaN values in its labels."
    def label_from_df(self, cols:IntsOrStrs=1, label_cls:Callable=None, **kwargs):
        "Label `self.items` from the values in `cols` in `self.inner_df`."
        labels = self.inner_df.iloc[:, df_names_to_idx(cols, self.inner_df)]
        # Unlike the stock fastai implementation, we deliberately do NOT assert
        # that `labels` is free of NaN values: missing targets are handled
        # downstream by the multitask loss and metrics.
        if is_listy(cols) and len(cols) > 1 and label_cls in (None, MultiCategoryList):
            label_cls = MultiCategoryList
            kwargs = {'one_hot': True, 'classes': cols, **kwargs}
        return self._label_from_list(_maybe_squeeze(labels), label_cls=label_cls, **kwargs)
We define a custom `NormalizationProcessor` that normalizes the regression targets while ignoring NaN values:
class NormalizationProcessor(PreProcessor):
    "`PreProcessor` that computes mean and std from `ds.items` and normalizes them."
    def __init__(self, ds:ItemList):
        self.compute_stats(ds)
        # Attributes persisted by __getstate__ when the learner is exported.
        self.state_attrs = ['mean', 'std']

    def compute_stats(self, ds:ItemList):
        "Compute normalization statistics, ignoring NaN (missing) targets."
        items = ds.items[~np.isnan(ds.items)]
        self.mean = items.mean()
        self.std = items.std()

    def process_one(self, item):
        "Normalize one target; empty labels pass through untouched."
        if isinstance(item, EmptyLabel): return item
        return (item - self.mean) / self.std

    def unprocess_one(self, item):
        "Inverse of `process_one`: map a normalized value back to raw units."
        if isinstance(item, EmptyLabel): return item
        return item * self.std + self.mean

    def process(self, ds):
        # Recompute the stats only if they were not restored from state.
        if self.mean is None:
            self.compute_stats(ds)
        # Expose the stats on the dataset so they can be inspected later.
        ds.mean = self.mean
        ds.std = self.std
        super().process(ds)

    def __getstate__(self):
        return {n: getattr(self, n) for n in self.state_attrs}

    def __setstate__(self, state:dict):
        # Materialize the keys so we don't keep a view tied to `state`,
        # allowing a restored processor to be re-pickled safely.
        self.state_attrs = list(state.keys())
        for n, v in state.items():
            setattr(self, n, v)
class NormalizedFloatList(FloatList):
    "A `FloatList` whose targets are normalized by `NormalizationProcessor`."
    _processor = NormalizationProcessor
We now create label lists
# Create label lists for age, gender and ethnicity as usual. Make sure the
# train/valid split is the same for all of them.
def _make_labels(col, **label_kwargs):
    # Build the item list from the dataframe, split it, and label it from `col`.
    return (NanLabelImageList.from_df(df, path=pdata, cols='filename')
            .split_from_df(col='is_valid')
            .label_from_df(cols=col, **label_kwargs))

gender_labels = _make_labels('gender')
ethnicity_labels = _make_labels('ethnicity')
age_labels = _make_labels('age', label_cls=NormalizedFloatList)
# Declare the tasks: each entry maps a task name to its label lists and to the
# metric used to report that task during training.
multitask_project = {
    'gender': {
        'label_lists': gender_labels,
        'metric': accuracy
    },
    'ethnicity': {
        'label_lists': ethnicity_labels,
        'metric': accuracy
    },
    'age': {
        'label_lists': age_labels,
        'metric': rmse,
    }
}
We create `MultitaskItem` and `MultitaskItemList` classes, which are respectively sub-classes of `MixedItem` and `MixedItemList`. The goal is for a `MultitaskItem` to be a list of other items (`FloatItem` or `Category`), so that inputs are converted to a tensor concatenating the encodings of sub-inputs.
# Monkey patch FloatItem with a better default string formatting.
def float_str(self):
    "Render the wrapped float with 4 significant digits."
    return f"{self.obj:.4g}"

FloatItem.__str__ = float_str
class MultitaskItem(MixedItem):
    "A `MixedItem` that knows the names of its sub-tasks, for nicer display."
    def __init__(self, *args, mt_names=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.mt_names = mt_names

    def __repr__(self):
        # Render as "task:value|task:value|...".
        pairs = zip(self.mt_names, self.obj)
        return '|'.join(f'{name}:{item}' for name, item in pairs)
class MultitaskItemList(MixedItemList):
    """A `MixedItemList` that groups one label `ItemList` per task.

    Stores per-task metadata (classes, item-list types, encoded lengths and
    display names) so that predictions over the concatenated output tensor
    can be split back into per-task items.
    """
    def __init__(self, *args, mt_names=None, **kwargs):
        super().__init__(*args,**kwargs)
        # Per-task metadata, aligned with `self.item_lists`:
        self.mt_classes = [getattr(il, 'classes', None) for il in self.item_lists]
        self.mt_types = [type(il) for il in self.item_lists]
        # Encoded length of each task: number of classes for categorical
        # tasks, 1 for regression tasks (which have no `classes`).
        self.mt_lengths = [len(i) if i else 1 for i in self.mt_classes]
        self.mt_names = mt_names

    def get(self, i):
        "Return the `i`-th target of every sub-list wrapped in a `MultitaskItem`."
        return MultitaskItem([il.get(i) for il in self.item_lists], mt_names=self.mt_names)

    def reconstruct(self, t_list):
        "Turn a list of per-task tensors back into a displayable `MultitaskItem`."
        items = []
        t_list = self.unprocess_one(t_list)  # undo e.g. target normalization
        for i,t in enumerate(t_list):
            if self.mt_types[i] == CategoryList:
                items.append(Category(t, self.mt_classes[i][t]))
            elif issubclass(self.mt_types[i], FloatList):
                items.append(FloatItem(t))
        return MultitaskItem(items, mt_names=self.mt_names)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Split the concatenated `pred` and reduce each slice to one value per task (argmax for categories, raw value for floats)."
        predictions = []
        start = 0
        for length, mt_type in zip(self.mt_lengths, self.mt_types):
            if mt_type == CategoryList:
                predictions.append(pred[start: start + length].argmax())
            elif issubclass(mt_type, FloatList):
                predictions.append(pred[start: start + length][0])
            start += length
        return predictions

    def unprocess_one(self, item, processor=None):
        "Reverse the processors (e.g. normalization) on one item."
        if processor is not None: self.processor = processor
        self.processor = listify(self.processor)
        for p in self.processor:
            # NOTE(review): relies on the module-level helper
            # `_processor_unprocess_one` defined right below this class.
            item = _processor_unprocess_one(p, item)
        return item
def _processor_unprocess_one(self, item:Any):  # TODO: global function to avoid subclassing MixedProcessor. To be cleaned.
    "Apply `unprocess_one` of every sub-processor that defines it to the matching element of `item`; `self` is a mixed processor exposing `procs`."
    res = []
    for procs, sub_item in zip(self.procs, item):
        for proc in procs:
            if hasattr(proc, 'unprocess_one'):
                sub_item = proc.unprocess_one(sub_item)
        res.append(sub_item)
    return res
We subclass `LabelList` and `LabelLists`
# Subclass LabelList to store and load all necessary multitask state, in order
# to later be able to export and load the model (to run it in production for
# example).
class MultitaskLabelList(LabelList):
    "A `LabelList` that persists the multitask metadata in its exported state."
    def get_state(self, **kwargs):
        "Add the multitask attributes to the serialized state."
        extra = dict(mt_classes=self.mt_classes, mt_types=self.mt_types,
                     mt_lengths=self.mt_lengths, mt_names=self.mt_names)
        return super().get_state(**{**kwargs, **extra})

    @classmethod
    def load_state(cls, path:PathOrStr, state:dict) -> 'LabelList':
        "Restore a `LabelList` and re-attach the multitask attributes."
        res = super().load_state(path, state)
        for attr in ('mt_classes', 'mt_types', 'mt_lengths', 'mt_names'):
            setattr(res, attr, state[attr])
        return res
class MultitaskLabelLists(LabelLists):
    "A `LabelLists` whose train/valid sets are `MultitaskLabelList`s."
    @classmethod
    def load_state(cls, path:PathOrStr, state:dict):
        "Recreate empty train and valid `MultitaskLabelList`s from an exported `state`."
        path = Path(path)
        datasets = [MultitaskLabelList.load_state(path, state) for _ in range(2)]
        return MultitaskLabelLists(path, train=datasets[0], valid=datasets[1])
We define a `label_from_mt_project` method that creates `MultitaskItemList`s given our predefined label lists, and groups them in a final `MultitaskLabelLists`
# Create the MultitaskItemLists given our predefined label lists, and group
# them in a final MultitaskLabelLists instance.
def label_from_mt_project(self, multitask_project):
    "Label this `ItemLists` with the combined targets of every task in `multitask_project`."
    names = list(multitask_project.keys())
    train_targets = [task['label_lists'].train.y for task in multitask_project.values()]
    valid_targets = [task['label_lists'].valid.y for task in multitask_project.values()]
    mt_train_list = MultitaskItemList(train_targets, mt_names=names)
    mt_valid_list = MultitaskItemList(valid_targets, mt_names=names)
    self.train = self.train._label_list(x=self.train, y=mt_train_list)
    self.valid = self.valid._label_list(x=self.valid, y=mt_valid_list)
    self.__class__ = MultitaskLabelLists  # TODO: Class morphing should be avoided, to be improved.
    self.train.__class__ = MultitaskLabelList
    self.valid.__class__ = MultitaskLabelList
    return self

ItemLists.label_from_mt_project = label_from_mt_project
# Build the image inputs once, then attach the combined multitask targets.
image_lists = ImageList.from_df(df, path=pdata, cols='filename').split_from_df(col='is_valid')
mt_label_lists = image_lists.label_from_mt_project(multitask_project)
mt_label_lists
MultitaskLabelLists; Train: MultitaskLabelList (7780 items) x: ImageList Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200) y: MultitaskItemList gender:NA|ethnicity:Indian|age:1.325,gender:female|ethnicity:White|age:-0.4539,gender:male|ethnicity:White|age:-0.656,gender:male|ethnicity:White|age:0.7185,gender:female|ethnicity:NA|age:nan Path: data/utkface/crop_part1; Valid: MultitaskLabelList (1998 items) x: ImageList Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200),Image (3, 200, 200) y: MultitaskItemList gender:female|ethnicity:White|age:1.082,gender:male|ethnicity:White|age:-0.009168,gender:female|ethnicity:White|age:0.1525,gender:female|ethnicity:White|age:0.1121,gender:male|ethnicity:Asian|age:nan Path: data/utkface/crop_part1; Test: None
tfms = get_transforms()
# 128px images, batch size 48, ImageNet statistics for the pretrained body.
data = mt_label_lists.transform(tfms, size=128).databunch(bs=48).normalize(imagenet_stats)
data.show_batch()
The model outputs a single concatenated tensor (named `inputs_concat` for clarity). We need to split it into sub-tensors given the lengths of each individual task. The loss function must also support the `reduction`
# The loss supports fastai's `reduction` parameter: by default
# reduction="mean" and the loss function outputs a tensor of size 1, but
# reduction may also be set to "none", in which case the loss function should
# return the loss of each item in the batch.
def _clean_nan_values(input, target, mt_type, mt_classes):
    "Neutralize missing targets in-place so they contribute ~0 to the loss."
    if mt_type == CategoryList and 'NA' in mt_classes:
        index = mt_classes.index('NA')
        nan_mask = target == index
        # Force the prediction of 'NA' rows onto the 'NA' class so their
        # cross-entropy is ~0.
        input[nan_mask] = 0.
        # Fix: the original `input[nan_mask][:, index] = 1e5` assigned into a
        # temporary copy created by the boolean-mask indexing, so the logit
        # was never actually written. Single-step advanced indexing writes
        # into `input` itself.
        input[nan_mask, index] = 1e5
    elif issubclass(mt_type, FloatList):
        # NOTE(review): `target < 0` also zeroes out negative normalized
        # targets, not only NaN ones — presumably intended for NaN encoded as
        # a negative sentinel; confirm.
        nan_mask = (torch.isnan(target)) | (target < 0)
        input[nan_mask] = 0.
        target[nan_mask] = 0.
    return input, target
def multitask_loss(inputs_concat, *targets, **kwargs):
    """Sum the per-task losses computed on slices of `inputs_concat`.

    Each task's slice is cross-entropy for categorical tasks and MSE for
    regression tasks; missing targets are neutralized by `_clean_nan_values`.
    With `reduction='none'` a per-item loss tensor is returned instead of the
    summed scalar.
    """
    mt_lengths, mt_types, mt_classes = data.mt_lengths, data.mt_types, data.mt_classes  # TODO: avoid global variable
    # Fix: allocate on the inputs' device instead of hard-coding `.cuda()`,
    # so the loss also works on CPU.
    loss_size = targets[0].shape[0] if kwargs.get('reduction') == 'none' else 1
    losses = torch.zeros([loss_size], device=inputs_concat.device)
    start = 0
    for i, length in enumerate(mt_lengths):
        input = inputs_concat[:, start: start + length]
        target = targets[i]
        input, target = _clean_nan_values(input, target, mt_types[i], mt_classes[i])
        if mt_types[i] == CategoryList:
            losses += CrossEntropyFlat(**kwargs)(input, target)
        elif issubclass(mt_types[i], FloatList):
            losses += MSELossFlat(**kwargs)(input, target)
        start += length
    if kwargs.get('reduction') == 'none':
        return losses
    return losses.sum()
We want to compute each metric on its own task only, ignoring missing targets. To achieve this automatically given our `multitask_project` definition, we need a `mt_metrics_generator` that will use `partial` to customize the metrics with the relevant parameters, and a subclass of `AverageMetric`
# ...to customize its display name.
# Helpers to compute each metric on its own slice of the concatenated output,
# ignoring rows whose target is missing.
def _remove_nan_values(input, target, mt_type, mt_classes):
    "Drop the rows of `input`/`target` whose target is missing ('NA' category or NaN)."
    if mt_type == CategoryList and 'NA' in mt_classes:
        nan_mask = target == mt_classes.index('NA')
    elif issubclass(mt_type, FloatList):
        # NOTE(review): `target < 0` also discards negative normalized values;
        # presumably intended for NaN encoded as a negative sentinel — confirm.
        nan_mask = (torch.isnan(target)) | (target < 0)
    else:
        # Fix: categorical task without an 'NA' class (or any other type)
        # previously hit an UnboundLocalError; keep every row instead.
        return input, target
    # Fix: keep the *valid* rows — the original returned `input[nan_mask]`,
    # i.e. only the missing rows.
    return input[~nan_mask], target[~nan_mask]

class MultitaskAverageMetric(AverageMetric):
    "An `AverageMetric` with a customizable display name."
    def __init__(self, func, name=None):
        super().__init__(func)
        self.name = name  # subclass uses this attribute in the __repr__ method.

def _mt_parametrable_metric(inputs, *targets, func, start=0, length=1, i=0):
    "Compute `func` on the slice [start, start+length) of `inputs` against `targets[i]`."
    input = inputs[:, start: start + length]
    target = targets[i]
    # Fix: the filtered tensors were previously discarded (the call's return
    # value was unused), so the metric included missing-target rows.
    input, target = _remove_nan_values(input, target, data.mt_types[i], data.mt_classes[i])  # TODO: Avoid data global reference.
    if func.__name__ == 'root_mean_squared_error':
        # De-normalize predictions and targets so RMSE is reported in raw units.
        processor = listify(learn.data.y.processor)
        input = processor[0].procs[i][0].unprocess_one(input)  # TODO: support multi-processors
        target = processor[0].procs[i][0].unprocess_one(target.float())
    return func(input, target)
def _format_metric_name(field_name, metric_func):
    "Build the display name `<field> <metric>`, abbreviating RMSE."
    metric_name = metric_func.__name__.replace('root_mean_squared_error', 'RMSE')
    return f"{field_name} {metric_name}"
def mt_metrics_generator(multitask_project, mt_lengths):
    "Build one named metric per task that declares a 'metric' entry."
    metrics = []
    start = 0
    task_items = zip(multitask_project.items(), mt_lengths)
    for i, ((name, task), length) in enumerate(task_items):
        metric_func = task.get('metric')
        if metric_func:
            # Bind the metric to its own slice of the concatenated output.
            bound = partial(_mt_parametrable_metric, start=start, length=length, i=i, func=metric_func)
            metrics.append(MultitaskAverageMetric(bound, _format_metric_name(name, metric_func)))
        start += length
    return metrics
metrics = mt_metrics_generator(multitask_project, data.mt_lengths)
There's a minor trick here: the "cnn_learner" builder method expects "data.c" to contain the size of the expected output vector to automatically create the head. Given our custom LabelLists, this is not automatically set so we need to do it manually:
data.c = sum(data.mt_lengths)
Now we have everything to create our model. Since we are asking our model to solve several problems at a time, it's probably best to pick a complex and deep network to allow it to adjust its weights to optimize for all sub tasks and offer good performance. Here we choose Densenet169 but others might work too.
arch = models.densenet169
# `path='.'` so that export/save write relative to the current directory.
learn = cnn_learner(data, arch, loss_func=multitask_loss, metrics=metrics, path= '.')
/opt/conda/envs/fastai/lib/python3.6/site-packages/torchvision/models/densenet.py:212: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_. nn.init.kaiming_normal(m.weight.data)
We can now train our network as usual:
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
lr = 1e-2
learn.fit_one_cycle(10, slice(lr), callbacks=[ShowGraph(learn)])
epoch | train_loss | valid_loss | gender accuracy | ethnicity accuracy | age RMSE | time |
---|---|---|---|---|---|---|
0 | 0.927583 | 1.206344 | 0.897397 | 0.758258 | 5.598151 | 00:37 |
1 | 0.916619 | 1.046868 | 0.883383 | 0.763263 | 5.536002 | 00:36 |
2 | 0.960678 | 1.259005 | 0.893894 | 0.757257 | 5.699761 | 00:35 |
3 | 0.953147 | 1.169957 | 0.899399 | 0.757758 | 6.292777 | 00:35 |
4 | 0.903761 | 1.056455 | 0.904905 | 0.759259 | 6.414293 | 00:35 |
5 | 0.867749 | 1.162231 | 0.906406 | 0.765766 | 5.666295 | 00:36 |
6 | 0.822735 | 1.005141 | 0.902402 | 0.765265 | 5.387472 | 00:36 |
7 | 0.797436 | 1.126028 | 0.894895 | 0.762262 | 5.466666 | 00:36 |
8 | 0.762927 | 1.037011 | 0.902402 | 0.767768 | 5.952827 | 00:36 |
9 | 0.730710 | 1.030611 | 0.900400 | 0.770270 | 5.247694 | 00:35 |
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
lr = 5 * 1e-5
learn.fit_one_cycle(10, slice(lr/10, lr), callbacks=[ShowGraph(learn)])
epoch | train_loss | valid_loss | gender accuracy | ethnicity accuracy | age RMSE | time |
---|---|---|---|---|---|---|
0 | 0.750827 | 1.007371 | 0.900901 | 0.769269 | 5.388041 | 00:44 |
1 | 0.737068 | 1.005097 | 0.901401 | 0.767768 | 5.189569 | 00:44 |
2 | 0.761702 | 0.987824 | 0.903403 | 0.774775 | 5.338120 | 00:43 |
3 | 0.741460 | 1.090896 | 0.899399 | 0.770270 | 5.172823 | 00:43 |
4 | 0.712070 | 1.143334 | 0.898899 | 0.771772 | 5.459701 | 00:44 |
5 | 0.720761 | 1.105154 | 0.899900 | 0.769770 | 5.161880 | 00:44 |
6 | 0.693091 | 1.046412 | 0.905405 | 0.771772 | 6.086848 | 00:43 |
7 | 0.677825 | 1.103123 | 0.900901 | 0.767768 | 5.830963 | 00:44 |
8 | 0.642743 | 1.368310 | 0.899900 | 0.772272 | 6.906080 | 00:44 |
9 | 0.652688 | 1.127869 | 0.899900 | 0.769269 | 5.159579 | 00:43 |
(Note: further work and fine-tuning would be necessary to better fit this model, but that's beyond the scope of this Notebook, which is more a general proof of concept of Multi-task learning).
plt.rcParams['figure.subplot.wspace'] = 0.7 # increase default margin between subplots.
learn.show_results()
We want to be able to plot top losses, similarly to what is done in the ClassificationInterpreter
.
The first step is to fix the learn.get_preds
method. Given that it calls several global functions we have to reimplement a few of them.
The only significant changes here are:
loss_batch
, return all ybs, not just yb[0]
get_preds
, call the loss function with (res[0], *res[1:])
instead of res[0], res[1]
import types
from fastai.basic_train import _loss_func2activ
def mt_loss_batch(model:nn.Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None,
                  cb_handler:Optional[CallbackHandler]=None)->Tuple[Union[Tensor,int,float,str]]:
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    # Multitask variant of fastai's `loss_batch`: when no `loss_func` is given
    # it must return *all* targets (one per task), not just `yb[0]`.
    cb_handler = ifnone(cb_handler, CallbackHandler())
    if not is_listy(xb): xb = [xb]
    if not is_listy(yb): yb = [yb]
    out = model(*xb)
    out = cb_handler.on_loss_begin(out)
    if not loss_func:
        #### Change return to give all values of yb, not just the first.###########
        # NOTE(review): `.long()` casts float targets (e.g. normalized age) to
        # integers on this path — confirm this is intended.
        ybs = torch.stack([ybi.detach().long() for ybi in yb])
        return (to_detach(out), *ybs)
        ###########################################################################
    # The multitask loss takes every per-task target, hence the unpacking.
    loss = loss_func(out, *yb)
    if opt is not None:
        loss,skip_bwd = cb_handler.on_backward_begin(loss)
        if not skip_bwd: loss.backward()
        if not cb_handler.on_backward_end(): opt.step()
        if not cb_handler.on_step_end(): opt.zero_grad()
    return loss.detach().cpu()
def mt_validate(model:nn.Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None,
                pbar:Optional[PBar]=None, average=True, n_batch:Optional[int]=None)->Iterator[Tuple[Union[Tensor,int],...]]:
    "Calculate `loss_func` of `model` on `dl` in evaluation mode."
    # Multitask variant of fastai's `validate`: delegates to `mt_loss_batch`
    # so that every per-task target flows through.
    model.eval()
    with torch.no_grad():
        val_losses,nums = [],[]
        if cb_handler: cb_handler.set_dl(dl)
        for xb,yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
            if cb_handler: xb, yb = cb_handler.on_batch_begin(xb, yb, train=False)
            #### change call to custom mt_loss_batch function ##########
            val_loss = mt_loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler)
            ##########################################################
            val_losses.append(val_loss)
            if not is_listy(yb): yb = [yb]
            # Record the batch size to weight the average loss.
            nums.append(yb[0].shape[0])
            if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
            if n_batch and (len(nums)>=n_batch): break
        nums = np.array(nums, dtype=np.float32)
        # Either a batch-size-weighted mean loss, or the raw per-batch values.
        if average: return (to_np(torch.stack(val_losses)) * nums).sum() / nums.sum()
        else: return val_losses
def mt_get_preds(model:nn.Module, dl:DataLoader, pbar:Optional[PBar]=None, cb_handler:Optional[CallbackHandler]=None,
                 activ:nn.Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) -> List[Tensor]:
    "Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`."
    # Multitask variant of fastai's `get_preds`: the loss is called with every
    # target, i.e. `lf(res[0], *res[1:])` instead of `lf(res[0], res[1])`.
    res = [torch.cat(o).cpu() for o in
           #### change call to custom mt_validate function ##########
           zip(*mt_validate(model, dl, cb_handler=cb_handler, pbar=pbar, average=False, n_batch=n_batch))]
           ##########################################################
    if loss_func is not None:
        with NoneReduceOnCPU(loss_func) as lf:
            #### Replace first target by list of targets: ###########
            res.append(lf(res[0], *res[1:]))
            #########################################################
    if activ is not None: res[0] = activ(res[0])
    return res
def mt_get_preds_method(self, ds_type:DatasetType=DatasetType.Valid, with_loss:bool=False, n_batch:Optional[int]=None,
                        pbar:Optional[PBar]=None) -> List[Tensor]:
    "Return predictions and targets on `ds_type` dataset."
    # Bound onto the learner (via types.MethodType) to replace `Learner.get_preds`.
    lf = self.loss_func if with_loss else None
    #### change call to custom mt_get_preds function ##########
    return mt_get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
                        activ=_loss_func2activ(self.loss_func), loss_func=lf, n_batch=n_batch, pbar=pbar)
# Replace the learner's `get_preds` with the multitask-aware version.
learn.get_preds = types.MethodType(mt_get_preds_method, learn)
pred_groups = learn.get_preds(with_loss=True)
# First element: predictions; middle elements: one target tensor per task; last: losses.
preds,y,losses = pred_groups[0], pred_groups[1:-1], pred_groups[-1]
Then we can create a MultitaskInterpretation
class similar to ClassificationInterpretation
from fastai.callbacks.hooks import hook_output
class MultitaskInterpretation():
    "Interpretation methods (e.g. `plot_top_losses`) for a multitask learner."
    def __init__(self, learn:Learner, probs:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
        self.data,self.probs,self.y_true,self.losses,self.ds_type, self.learn= learn.data,probs,y_true,losses,ds_type,learn

    def plot_top_losses(self, k, largest=True, figsize=(12,12), heatmap:bool=True, heatmap_thresh:int=16,
                        return_fig:bool=None)->Optional[plt.Figure]:
        "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class."
        tl_val,tl_idx = self.top_losses(k, largest)
        cols = math.ceil(math.sqrt(k))
        rows = math.ceil(k/cols)
        fig,axes = plt.subplots(rows, cols, figsize=figsize)
        fig.suptitle('prediction/actual/loss', weight='bold', size=14)
        # Fix: honor `self.ds_type` (the dataset was hard-coded to Valid while
        # the loop below already used `self.ds_type`).
        ds = self.data.dl(self.ds_type).dataset
        for i,idx in enumerate(tl_idx):
            im,y = ds[idx]
            # Reconstruct the predicted and actual multitask items for display.
            y_hat = ds.y.reconstruct(ds.y.analyze_pred(self.probs[idx]))
            y = ds.y.reconstruct(y.data)
            # Fix: read `self.losses` — the original referenced the *global*
            # `losses` variable and only worked by accident.
            title = f'{y_hat}\n{y}\n{self.losses[idx]:.2f}'
            im.show(ax=axes.flat[i], title=title)
        if ifnone(return_fig, defaults.return_fig): return fig

    def top_losses(self, k:int=None, largest=True):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
# Inspect the worst predictions of the validation set.
interp = MultitaskInterpretation(learn, preds, y, losses)
interp.plot_top_losses(9)
Let's check that our model can be exported and re-loaded. To load it, we need to customize the load_learner
to change the LabelLists
class to MultitaskLabelLists
def mt_load_learner(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, **db_kwargs):
    "Load a `Learner` object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)"
    # Multitask variant of fastai's `load_learner`: restores the data with
    # `MultitaskLabelLists` instead of the stock `LabelLists`.
    source = Path(path)/file if is_pathlike(file) else file
    state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
    model = state.pop('model')
    #### change to MultitaskLabelList ###########
    src = MultitaskLabelLists.load_state(path, state.pop('data'))
    #############################################
    if test is not None: src.add_test(test)
    data = src.databunch(**db_kwargs)
    #### TODO: find better way to initialize state
    # Copy the multitask metadata onto the single-item dataset used by
    # `predict`. NOTE(review): reads `src.mt_*` directly — presumably relies
    # on `LabelLists` delegating these attributes to its train set; confirm.
    data.single_ds.y.mt_classes = src.mt_classes
    data.single_ds.y.mt_lengths = src.mt_lengths
    data.single_ds.y.mt_types = src.mt_types
    data.single_ds.y.mt_names = src.mt_names
    #############################################
    cb_state = state.pop('cb_state')
    clas_func = state.pop('cls')
    res = clas_func(data, model, **state)
    res.callback_fns = state['callback_fns'] #to avoid duplicates
    res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]
    return res
# Round-trip check: export the learner, then reload it with our custom loader.
learn.export('face.pkl')
learn_loaded = mt_load_learner('.', 'face.pkl')
Get a test image and use the model to predict its output:
!wget -O lenna.png https://hackage.haskell.org/package/JuicyPixels-extra-0.4.0/src/data-examples/lenna-cropped.png
--2019-05-16 08:28:43-- https://hackage.haskell.org/package/JuicyPixels-extra-0.4.0/src/data-examples/lenna-cropped.png Resolving hackage.haskell.org (hackage.haskell.org)... 151.101.0.68, 151.101.64.68, 151.101.128.68, ... Connecting to hackage.haskell.org (hackage.haskell.org)|151.101.0.68|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 62830 (61K) [image/png] Saving to: ‘lenna.png’ lenna.png 100%[===================>] 61.36K --.-KB/s in 0.02s 2019-05-16 08:28:43 (3.35 MB/s) - ‘lenna.png’ saved [62830/62830]
img = open_image('lenna.png')
img
learn_loaded.predict(img)
(gender:female|ethnicity:White|age:43.52, [tensor(1), tensor(5), tensor(0.5778)], tensor([-14.1486, 3.3912, -2.6906, -1.5196, -5.6586, -4.7217, -12.7841, 0.8719, 6.8403, 0.5778]))
(Note: prediction still works here because the loss function references the global `data` databunch.)