# The notebook originally began with the IPython shell escape below. It is
# not valid Python syntax, so run it in a notebook cell (or run
# `pip install fastai --upgrade` in a shell) before executing this script.
# !pip install fastai --upgrade
from functools import partial

import fastai
import fastai.vision.all as vision
import matplotlib.pyplot as plt
import pandas as pd
import torch

# --- Data --------------------------------------------------------------------
# Download (once) and extract PASCAL VOC 2007; `path` is the dataset root.
path = vision.untar_data(vision.URLs.PASCAL_2007)
path.ls()

# Derive the image folder from the dataset root rather than hard-coding
# '/root/.fastai/...', so the script works for any user / data location.
train_img = path / 'train'
# Name the label file explicitly: `path.ls()` is built from a directory
# listing whose order is arbitrary, so `path.ls()[2]` may not be this file.
csv = pd.read_csv(path / 'train.csv')
csv.info()


def get_x(df):
    """Return the image path for one row of the labels table.

    `df` is a single row: its 'fname' entry is joined onto the train folder.
    """
    return train_img / df['fname']


def get_y(df):
    """Return the labels of one row as a list (the 'labels' string split on whitespace)."""
    return df['labels'].split()


def train_test(df):
    """Split the labels DataFrame into (train, valid) lists of row indices.

    Rows whose boolean 'is_valid' column is True go to the validation set.
    """
    train = df.index[~df['is_valid']].tolist()
    valid = df.index[df['is_valid']].tolist()
    return train, valid


db = vision.DataBlock(
    blocks=(vision.ImageBlock, vision.MultiCategoryBlock),
    get_x=get_x,
    get_y=get_y,
    splitter=train_test,
    item_tfms=vision.RandomResizedCrop(128, min_scale=0.35),
)
db.summary(csv)

dataset = db.datasets(csv)
# Decode the one-hot target of sample 171 back into label names.
print(dataset.vocab[torch.where(dataset[171][1] == 1)[0]])
dataset[171][0]

dls = db.dataloaders(csv, bs=32)
dls.show_batch(nrows=1, ncols=3)

# --- Model exploration -------------------------------------------------------
model = vision.cnn_learner(dls, vision.resnet18)
data, label = dls.train.one_batch()
print(data.shape, label.shape)
model.model

# Push one CPU batch through the (not yet trained) network to inspect the
# raw activations it produces.
x, y = vision.to_cpu(dls.train.one_batch())
out = model.model(x)
print(out.shape, model.loss_func)
out[0]


def idx2vocab(tens, thresh=0.5, sigmoid=True):
    """Print the vocab labels whose score in `tens` exceeds `thresh`.

    `tens` is a vector of per-class scores (e.g. one row of model output).
    With ``sigmoid=True`` (the default) it is treated as raw activations and
    squashed with a sigmoid first; rounding raw logits, as before, does not
    give meaningful 0/1 decisions. Pass ``sigmoid=False`` for probabilities
    or 0/1 targets.
    """
    if sigmoid:
        tens = tens.sigmoid()
    print(dataset.vocab[torch.where(tens > thresh)[0]])


idx2vocab(out[10])


def binary_cross_entropy(pred, target):
    """Mean binary cross-entropy between raw activations `pred` and 0/1 `target`.

    Works in log space: log(sigmoid(x)) is `logsigmoid(x)` and
    log(1 - sigmoid(x)) is `logsigmoid(-x)`. This is mathematically the same as
    `-where(target == 1, p, 1 - p).log().mean()` with p = sigmoid(pred), but it
    avoids `log(0) == -inf` when the sigmoid saturates for large activations.
    """
    log_p = torch.nn.functional.logsigmoid(pred)
    log_not_p = torch.nn.functional.logsigmoid(-pred)
    return -torch.where(target == 1, log_p, log_not_p).mean()


def acc_multi(pred, target, thresh=0.5, sigmoid=True):
    """Element-wise accuracy for multi-label predictions.

    Each score is thresholded at `thresh` and compared with the 0/1 target;
    the result is the fraction of matching (sample, class) entries. Set
    ``sigmoid=False`` when `pred` already holds probabilities.
    """
    if sigmoid:
        pred = pred.sigmoid()
    return ((pred > thresh) == target.bool()).float().mean()


# Sanity check: scoring activations against their own rounded sigmoid.
acc_multi(out[0], out[0].sigmoid().round())

# --- Training ----------------------------------------------------------------
model = vision.cnn_learner(
    dls, vision.resnet18, metrics=partial(acc_multi, thresh=0.57)
)
model.lr_find()

# Recreate the learner before training.
model = vision.cnn_learner(
    dls, vision.resnet18, metrics=partial(acc_multi, thresh=0.57)
)
# `fine_tune` forwards extra keyword arguments to `fit_one_cycle`, which has
# no `max_lr` parameter, so the former `max_lr=slice(0.002, 0.1)` raised
# TypeError. fine_tune derives the unfrozen-phase learning-rate range itself
# from `base_lr` (and its `lr_mult` argument).
model.fine_tune(3, base_lr=0.003, freeze_epochs=4)

pred, targ = model.get_preds()
# Compare prediction and target of the same validation sample.
pred[0], targ[0]
model.metrics = partial(acc_multi, thresh=0.5)
loss, acc = model.validate()
# `validate()` returns the validation loss followed by the metric values.
# Print each value under its own label (they were previously swapped).
print("val_acc: {}\tval_loss: {}".format(acc, loss))

# `pred` comes from `get_preds`, which applies the loss function's activation
# by default, so no extra sigmoid is applied here.
acc_multi(pred, targ, thresh=0.55, sigmoid=False)

# Sweep the decision threshold and plot validation accuracy for each value.
# `steps` is a required argument of torch.linspace in current PyTorch; 100
# matches the default that older releases used implicitly.
x = torch.linspace(0.05, 0.95, steps=100)
y = [acc_multi(pred, targ, thresh=i, sigmoid=False) for i in x]
plt.plot(x, y, 'b-')

model.metrics = partial(acc_multi, thresh=0.57)
model.validate()

output = model.model(dls.one_batch()[0])
dls.one_batch()[0][0].shape

model.metrics = partial(acc_multi, thresh=0.55)
model.show_results()