"""CamVid semantic-segmentation walkthrough (exported from a notebook).

Steps, in order:
  1. Download and unpack the CamVid archive into ``data/`` if it is missing.
  2. Load one image and its label mask and display them.
  3. Build fastai ``DataLoaders`` from a ``DataBlock`` (image -> mask).
  4. Train a ResNet-18 U-Net with ``pretrained=False``, then train a second
     one created with the library-default weights via ``fine_tune``.

Bare expressions such as ``img.shape`` are notebook cell outputs; they are
kept for fidelity with the notebook but display nothing when run as a script.

NOTE: the import order below is intentional. ``from fastai.vision.all import *``
appears late because it rebinds names that earlier cells take from
``miniai.datasets`` / ``torch`` (e.g. ``show_image``, ``tensor``).
"""
import fastcore.all as fc

from pathlib import Path
import pickle
import gzip
import math
import os
import time
import shutil
import matplotlib as mpl
import matplotlib.pyplot as plt
from glob import glob

from torch import tensor

from miniai.datasets import *

url = 'https://s3.amazonaws.com/fast-ai-imagelocal/camvid.tgz'

# --- Download ---------------------------------------------------------------
path = Path('data/camvid')
if not path.exists():
    # Make sure the download directory exists first: `urlsave` is expected to
    # append the archive's file name only when the destination is an existing
    # directory; otherwise the archive would be written to a file called
    # `data` and unpacking would fail.
    Path('data').mkdir(parents=True, exist_ok=True)
    path_tgz = fc.urlsave(url, 'data')
    shutil.unpack_archive(str(path_tgz), 'data')
path.ls()

# --- Locate images and labels ----------------------------------------------
path_lbl = path/'labels'
path_img = path/'images'

# glob() yields files in arbitrary order; sort so the sample image picked
# below is the same on every machine.
fnames = sorted(glob(str(path_img/'*.png')))
lbl_names = sorted(glob(str(path_lbl/'*.png')))

img_f = Path(fnames[0])
img_f

# --- Inspect one image ------------------------------------------------------
from torchvision.io import read_image

img = read_image(str(img_f))/255  # scale the loaded pixel values by 1/255
img.shape

show_image(img)
# Two brightened variants: linear scaling (clamped at 1) and a +1 shift in
# logit space mapped back through a sigmoid.
show_image((img*2).clamp_max(1))
show_image((img.clamp(0.001,0.999).logit()+1).sigmoid())


def get_y_fn(x):
    """Return the label-mask path for image path `x`.

    The mask lives in `path_lbl` and is named `<stem>_P<suffix>`.
    """
    return path_lbl/f'{x.stem}_P{x.suffix}'


# --- Inspect the matching mask ---------------------------------------------
mask_f = get_y_fn(img_f)
mask = read_image(str(mask_f))
mask.shape

show_image(mask, cmap='tab20')
ax = show_image(img)
show_image(mask, ax=ax, cmap='tab20', alpha=0.5)  # overlay mask on the image

# --- Class names and validation split --------------------------------------
codes = (path/'codes.txt').read_text().splitlines(False)
' '.join(codes)
codes[26]

src_size = tensor(mask.shape[1:])

valids = set((path/'valid.txt').read_text().splitlines(False))
list(valids)[:5]

# Reload the sample image, this time without the 1/255 scaling.
img = read_image(str(img_f))
img.shape

# --- fastai data pipeline ---------------------------------------------------
# Late star import on purpose; see the module docstring.
from fastai.vision.all import *

bs = 8
dblock = DataBlock(
    blocks=(ImageBlock(), MaskBlock(codes=codes)),
    # A file goes to the validation set when its name is listed in valid.txt.
    splitter=FuncSplitter(lambda o: o.name in valids),
    get_y=get_y_fn,
    item_tfms=Resize((720//4, 960//4)),
)
img_files = get_image_files(path/"images")
dls = dblock.dataloaders(img_files, path=path, num_workers=4)

xb, yb = next(iter(dls.train))
xb.shape
dls.show_batch()

# --- Metric -----------------------------------------------------------------
name2id = {v: k for k, v in enumerate(codes)}
void_code = name2id['Void']


def acc_camvid(input, target):
    """Accuracy over the pixels whose target class is not 'Void'.

    `input` holds per-class scores along dim 1 (the prediction is the argmax
    over that dim); `target` has dim 1 squeezed before comparison. Returns the
    mean of the element-wise match as a float tensor.
    """
    target = target.squeeze(1)
    not_void = target != void_code
    return (input.argmax(dim=1)[not_void] == target[not_void]).float().mean()


# --- Training ---------------------------------------------------------------
# Run 1: no pretrained weights, one-cycle schedule.
learn = unet_learner(dls, resnet18, metrics=acc_camvid, pretrained=False)
learn.fit_one_cycle(8)
learn.show_results(max_n=3, figsize=(7,8))

# Run 2: library-default weights, trained with fine_tune.
learn = unet_learner(dls, resnet18, metrics=acc_camvid)
learn.fine_tune(8)
learn.show_results(max_n=3, figsize=(7,8))