from fastai.gen_doc.nbdoc import *
from fastai.vision import *


class ImageTuple(ItemBase):
    """A pair of `Image`s handled as a single item.

    `obj` holds the two original images; `data` holds both image tensors
    rescaled with `-1 + 2*x`, which maps [0, 1] to [-1, 1]
    (assumes `Image.data` is in [0, 1] -- TODO confirm).
    """

    def __init__(self, img1, img2):
        self.img1, self.img2 = img1, img2
        self.obj, self.data = (img1, img2), [-1 + 2 * img1.data, -1 + 2 * img2.data]

    def apply_tfms(self, tfms, **kwargs):
        """Apply `tfms` to each image, refresh `data` and return `self`."""
        # Each image gets its own `apply_tfms` call, so random transforms may
        # draw different parameters for the two images.
        self.img1 = self.img1.apply_tfms(tfms, **kwargs)
        self.img2 = self.img2.apply_tfms(tfms, **kwargs)
        self.data = [-1 + 2 * self.img1.data, -1 + 2 * self.img2.data]
        return self

    def to_one(self):
        """Return one `Image` made by concatenating both tensors along dim 2.

        `0.5 + x/2` undoes the [-1, 1] rescaling. Presumably the tensors are
        (channels, height, width), so the two images end up side by side --
        NOTE(review): confirm.
        """
        return Image(0.5 + torch.cat(self.data, 2) / 2)


jekyll_note("""If you just want to customize the way an `Image` is opened, subclass `Image` and just change the `open` method.""")


# First version of the list (inputs only). The name `ImageTupleList` is
# redefined further down with `_label_cls` and display methods, so this class
# object is shadowed by that later definition.
class ImageTupleList(ImageList):
    """`ImageList` pairing each image of `items` with a random image of `itemsB`."""

    def __init__(self, items, itemsB=None, **kwargs):
        super().__init__(items, **kwargs)
        self.itemsB = itemsB
        # Presumably makes fastai carry `itemsB` over to lists derived from
        # this one -- confirm against `ItemList.new`.
        self.copy_new.append('itemsB')

    def get(self, i):
        """Return item `i` paired with a uniformly random image from `itemsB`."""
        img1 = super().get(i)
        fn = self.itemsB[random.randint(0, len(self.itemsB) - 1)]
        return ImageTuple(img1, open_image(fn))

    @classmethod
    def from_folders(cls, path, folderA, folderB, **kwargs):
        """Build from images in `path/folderA`; `itemsB` are the images in `path/folderB`."""
        itemsB = ImageList.from_folder(path / folderB).items
        res = super().from_folder(path / folderA, itemsB=itemsB, **kwargs)
        # Use the common parent directory as the list's path.
        res.path = path
        return res

    def reconstruct(self, t: Tensor):
        """Rebuild an `ImageTuple` from a pair of tensors scaled to [-1, 1]."""
        return ImageTuple(Image(t[0] / 2 + 0.5), Image(t[1] / 2 + 0.5))


class TargetTupleList(ItemList):
    """Label list for `ImageTupleList`: reconstructs image pairs from tensors."""

    def reconstruct(self, t: Tensor):
        # A 0-dim tensor is returned unchanged (presumably a dummy target --
        # confirm against the training setup).
        if len(t.size()) == 0:
            return t
        return ImageTuple(Image(t[0] / 2 + 0.5), Image(t[1] / 2 + 0.5))


class ImageTupleList(ImageList):
    """`ImageList` of `ImageTuple`s labelled with a `TargetTupleList`.

    Each item pairs an image of `items` with a random image of `itemsB`.
    """

    _label_cls = TargetTupleList

    def __init__(self, items, itemsB=None, **kwargs):
        super().__init__(items, **kwargs)
        self.itemsB = itemsB
        # Presumably makes fastai carry `itemsB` over to lists derived from
        # this one -- confirm against `ItemList.new`.
        self.copy_new.append('itemsB')

    def get(self, i):
        """Return item `i` paired with a uniformly random image from `itemsB`."""
        img1 = super().get(i)
        fn = self.itemsB[random.randint(0, len(self.itemsB) - 1)]
        return ImageTuple(img1, open_image(fn))

    def reconstruct(self, t: Tensor):
        """Rebuild an `ImageTuple` from a pair of tensors scaled to [-1, 1]."""
        return ImageTuple(Image(t[0] / 2 + 0.5), Image(t[1] / 2 + 0.5))

    @classmethod
    def from_folders(cls, path, folderA, folderB, **kwargs):
        """Build from images in `path/folderA`; `itemsB` are the images in `path/folderB`."""
        itemsB = ImageList.from_folder(path / folderB).items
        res = super().from_folder(path / folderA, itemsB=itemsB, **kwargs)
        # Use the common parent directory as the list's path.
        res.path = path
        return res

    def show_xys(self, xs, ys, figsize: Tuple[int, int] = (12, 6), **kwargs):
        """Show the `xs` on a square grid in a figure of `figsize`.

        The grid is `int(sqrt(len(xs)))` on each side, so only that many
        squared items are drawn; each `ImageTuple` is drawn via `to_one()`.
        `ys` is accepted for interface compatibility but not used.
        `kwargs` are passed to the show method.
        """
        rows = int(math.sqrt(len(xs)))
        # squeeze=False always yields a 2-D array of axes, so a 1x1 grid
        # needs no special case.
        fig, axs = plt.subplots(rows, rows, figsize=figsize, squeeze=False)
        for i, ax in enumerate(axs.flatten()):
            xs[i].to_one().show(ax=ax, **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, figsize: Tuple[int, int] = None, **kwargs):
        """Show `xs` (left column) next to `zs` (predictions, right column).

        One row per item, on a figure of `figsize` (default `(12, 3*len(xs))`).
        `ys` is accepted for interface compatibility but not used.
        `kwargs` are passed to the show method.
        """
        figsize = ifnone(figsize, (12, 3 * len(xs)))
        # squeeze=False keeps `axs` 2-D even when len(xs) == 1; otherwise
        # plt.subplots(1, 2) returns a 1-D array and axs[0, 0] raises IndexError.
        fig, axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        for i, (x, z) in enumerate(zip(xs, zs)):
            x.to_one().show(ax=axs[i, 0], **kwargs)
            z.to_one().show(ax=axs[i, 1], **kwargs)