#hide
# Notebook setup. On Colab (detected by the existence of /content) install or
# upgrade fastbook first. This replaces the IPython shell escape
# `! [ -e /content ] && pip install -Uqq fastbook`, which is not valid Python
# outside a notebook. Like the shell escape, a failed install is not fatal here.
import os
import random
import re
import subprocess
import sys

if os.path.exists('/content'):
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-Uqq', 'fastbook'])
import fastbook
fastbook.setup_book()

#hide
from fastbook import *

# [[chapter_arch_details]]
# (Chapter cross-reference anchor from the notebook's markdown cell. As bare
# code it would evaluate an undefined name and raise NameError.)

# Inspect fastai's `model_meta` entry for resnet50.
model_meta[resnet50]

#hide_output
# Display an example head built by `create_head`.
create_head(20,2)

#hide
from fastai.vision.all import *
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")


class SiameseImage(fastuple):
    """An `(img1, img2, same_breed)` triple that knows how to display itself.

    `show` draws the two images side by side, separated by a 10-pixel
    zero-filled strip, and uses `same_breed` as the title.
    """
    def show(self, ctx=None, **kwargs):
        img1,img2,same_breed = self
        if not isinstance(img1, Tensor):
            # Raw PIL images: make the sizes match, then convert to tensors.
            # The permute assumes `tensor(img)` is (H, W, C) and yields (C, H, W).
            if img2.size != img1.size: img2 = img2.resize(img1.size)
            t1,t2 = tensor(img1),tensor(img2)
            t1,t2 = t1.permute(2,0,1),t2.permute(2,0,1)
        else: t1,t2 = img1,img2
        # Separator strip with the same channel/height dims, 10 columns wide.
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(torch.cat([t1,line,t2], dim=2),
                          title=same_breed, ctx=ctx)


def label_func(fname):
    """Return the label of `fname`: its name minus the trailing `_<digits>.jpg`.

    Raises AttributeError if `fname.name` does not match that pattern.
    """
    # The dot before `jpg` is escaped so it matches a literal '.', not any character.
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]


class SiameseTransform(Transform):
    """Map a filename to a `SiameseImage` pair plus a same-label flag.

    Pairs for the validation items (`files[splits[1]]`) are drawn once, at
    construction, so the validation set is fixed across epochs. Every other
    item gets a freshly drawn pair on each access.
    """
    def __init__(self, files, label_func, splits):
        lbls = files.map(label_func)
        self.labels = lbls.unique()
        # Group files by label in a single pass, preserving file order within
        # each label (rather than rescanning every file once per label).
        groups = {l: [] for l in self.labels}
        for f, l in zip(files, lbls): groups[l].append(f)
        self.lbl2files = {l: L(fs) for l, fs in groups.items()}
        self.label_func = label_func
        # Must come last: `_draw` reads `label_func`, `labels` and `lbl2files`.
        self.valid = {f: self._draw(f) for f in files[splits[1]]}

    def encodes(self, f):
        # Only draw a new pair when `f` has no fixed validation pair.
        # `self.valid.get(f, self._draw(f))` would evaluate `_draw` eagerly,
        # wasting work and RNG draws for every validation item.
        f2,t = self.valid[f] if f in self.valid else self._draw(f)
        img1,img2 = PILImage.create(f),PILImage.create(f2)
        return SiameseImage(img1, img2, t)

    def _draw(self, f):
        """Pick a partner file for `f` and return `(partner, same)`.

        With probability 0.5 the partner comes from the same label as `f`;
        otherwise it comes from a uniformly chosen different label.
        """
        same = random.random() < 0.5
        cls = self.label_func(f)
        if not same: cls = random.choice(L(l for l in self.labels if l != cls))
        return random.choice(self.lbl2files[cls]),same


splits = RandomSplitter()(files)
tfm = SiameseTransform(files, label_func, splits)
tls = TfmdLists(files, tfm, splits=splits)
# Per item: resize and convert to tensors. Per batch: convert to float and
# normalize with ImageNet statistics.
dls = tls.dataloaders(after_item=[Resize(224), ToTensor],
    after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])


class SiameseModel(Module):
    """Run a shared `encoder` on both images, concatenate the features along
    dim 1, and classify the result with `head`."""
    def __init__(self, encoder, head):
        self.encoder,self.head = encoder,head

    def forward(self, x1, x2):
        ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)
        return self.head(ftrs)


encoder = create_body(resnet34, cut=-2)
# Input size is 512*2 because the two encoders' features are concatenated.
head = create_head(512*2, 2, ps=0.5)
model = SiameseModel(encoder, head)


def loss_func(out, targ):
    """Cross-entropy over the 2 outputs. The target flag is cast to long for class indices."""
    return nn.CrossEntropyLoss()(out, targ.long())


def siamese_splitter(model):
    """Two parameter groups, encoder then head, for freezing and discriminative LRs."""
    return [params(model.encoder), params(model.head)]


learn = Learner(dls, model, loss_func=loss_func,
                splitter=siamese_splitter, metrics=accuracy)
# Train with the encoder group frozen first, then fine-tune all groups with
# smaller learning rates.
learn.freeze()
learn.fit_one_cycle(4, 3e-3)
learn.unfreeze()
learn.fit_one_cycle(4, slice(1e-6,1e-4))