#export
from pathlib import Path, PosixPath
import mimetypes
import pandas as pd
import os
import PIL.Image
from collections.abc import Iterable
from torch.utils.data import DataLoader
from functools import partial
import numpy as np
import torch
import matplotlib.pyplot as plt
# Root directory of the imagenette-160 dataset. NOTE(review): this is an
# absolute path under one user's home directory, so it only resolves on that machine.
path = PosixPath('/home/eross/.fastai/data/imagenette-160')
#export
def _path_ls(self):
    """Return the entries of this directory as a list of paths."""
    return list(self.iterdir())


# Monkey-patch `Path` so any path object can list its directory with `.ls()`.
Path.ls = _path_ls
path.ls()
[PosixPath('/home/eross/.fastai/data/imagenette-160/train'), PosixPath('/home/eross/.fastai/data/imagenette-160/val')]
#export
def _get_files(p, fs, extensions=None):
p = Path(p)
res = [p/f for f in fs if not f.startswith('.')
and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
return res
def get_files(path, recurse=False, extensions=None, include=None):
path = Path(path)
extensions = set(extensions)
extensions = {e.lower() for e in extensions}
if recurse:
res = []
for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
if include is not None and i==0: d[:] = [o for o in d if o in include]
else: d[:] = [o for o in d if not o.startswith('.')]
res += _get_files(p, f, extensions)
return res
else:
f = [o.name for o in os.scandir(path) if o.is_file()]
return _get_files(path, f, extensions)
def extensions_mime(ext):
    """Return the file extensions whose registered MIME type starts with `ext`.

    Extensions come from `mimetypes.types_map`, so each one includes its
    leading dot. The result is a frozenset.
    """
    matching = set()
    for suffix, mime_type in mimetypes.types_map.items():
        if mime_type.startswith(ext):
            matching.add(suffix)
    return frozenset(matching)


# Extension sets for the three broad media families.
EXTENSIONS_IMAGE, EXTENSIONS_AUDIO, EXTENSIONS_VIDEO = (
    extensions_mime(prefix) for prefix in ('image/', 'audio/', 'video/')
)
def img_get_files(path, recurse=False, extensions=None, include=None):
    """List files under `path`, keeping image extensions unless told otherwise.

    Thin wrapper around `get_files`: when `extensions` is None it substitutes
    `EXTENSIONS_IMAGE`; every other argument is forwarded unchanged.
    """
    wanted = EXTENSIONS_IMAGE if extensions is None else extensions
    return get_files(path, recurse, wanted, include)
# Collect every image file below the dataset root (second argument is recurse=True).
image_paths = img_get_files(path, True)
# Peek at the first few results.
image_paths[:5]
[PosixPath('/home/eross/.fastai/data/imagenette-160/train/n01440764/n01440764_10026.JPEG'), PosixPath('/home/eross/.fastai/data/imagenette-160/train/n01440764/n01440764_10027.JPEG'), PosixPath('/home/eross/.fastai/data/imagenette-160/train/n01440764/n01440764_10029.JPEG'), PosixPath('/home/eross/.fastai/data/imagenette-160/train/n01440764/n01440764_10040.JPEG'), PosixPath('/home/eross/.fastai/data/imagenette-160/train/n01440764/n01440764_10042.JPEG')]
# One row per image. The label is the name of the file's parent directory and
# the split is the name of the directory above that.
df = pd.DataFrame(image_paths, columns=['path'])
df = df.assign(
    label=df['path'].map(lambda p: p.parent.name),
    split=df['path'].map(lambda p: p.parent.parent.name),
)
# Boolean flag: True for rows whose split directory is 'train'.
df = df.assign(train=df['split'] == 'train')
df.head()
| | path | label | split | train |
---|---|---|---|---|
0 | /home/eross/.fastai/data/imagenette-160/train/... | n01440764 | train | True |
1 | /home/eross/.fastai/data/imagenette-160/train/... | n01440764 | train | True |
2 | /home/eross/.fastai/data/imagenette-160/train/... | n01440764 | train | True |
3 | /home/eross/.fastai/data/imagenette-160/train/... | n01440764 | train | True |
4 | /home/eross/.fastai/data/imagenette-160/train/... | n01440764 | train | True |
# Count the images per label in each split: 'sum' counts the True (train)
# rows, and the lambda counts the False (non-train) rows.
# `set_axis` returns a new object by default; its `inplace` keyword was
# removed in pandas 2.0, so it is no longer passed.
(df
 .groupby('label')
 .train
 .agg(['sum', lambda x: (~x).sum()])
 .set_axis(['train', 'val'], axis=1)
 .T
)
label | n01440764 | n02102040 | n02979186 | n03000684 | n03028079 | n03394916 | n03417042 | n03425413 | n03445777 | n03888257 |
---|---|---|---|---|---|---|---|---|---|---|
train | 1300.0 | 1300.0 | 1300.0 | 1194.0 | 1300.0 | 1300.0 | 1300.0 | 1300.0 | 1300.0 | 1300.0 |
val | 50.0 | 50.0 | 50.0 | 50.0 | 50.0 | 50.0 | 50.0 | 50.0 | 50.0 | 50.0 |
#export
# Compose and apply: Called compose in dl2v3 lectures
def comply(functions, x, reverse=False):
    """Feed `x` through each callable in `functions` and return the result.

    The callables are applied in the given order, or last-to-first when
    `reverse` is true. Each one receives the previous one's output.
    """
    pipeline = reversed(functions) if reverse else functions
    result = x
    for step in pipeline:
        result = step(result)
    return result
It would probably be "nicer" to make all of these classes, but it's good to know you can use quick-and-dirty functions for prototyping.
#export
# Alias so that opening an image reads like the other img_* transforms.
img_open = PIL.Image.open
def img_rgb(item):
    """Return `item` converted to 'RGB' mode via its `convert` method."""
    return item.convert('RGB')
def img_resize(item, size, method=PIL.Image.LANCZOS):
    """Resize `item` with its `resize` method and return the result.

    `size` is forwarded unchanged when it is iterable; a non-iterable value is
    expanded to the square `(size, size)`. `method` is passed straight through
    as the second argument to `resize`.
    """
    if not isinstance(size, Iterable):
        size = (size, size)
    return item.resize(size, method)
def img_to_float(x):
    """Convert an image to a channels-first float32 tensor scaled by 1/255.

    Args:
        x: Anything NumPy can turn into a 3-D height x width x channel array.

    Returns:
        A contiguous float32 tensor of shape (channels, height, width) whose
        values are the input values divided by 255.

    Raises:
        RuntimeError: if the array is not 3-D (the permute needs three axes).
    """
    # np.asarray copies only when it has to. np.array(..., copy=False) raises
    # ValueError under NumPy >= 2 whenever the dtype conversion needs a copy.
    pixels = np.asarray(x, dtype=np.float32)
    # HWC -> CHW; the division allocates a new tensor, so the input is never modified.
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous() / 255.
def img_from_float(x, mode="RGB"):
    """Build a PIL image from a (channels, height, width) float tensor.

    Values are multiplied by 255 and cast to bytes. The number of channels
    must equal the number of characters in `mode`, otherwise the assertion fails.
    """
    c, h, w = x.shape
    assert c == len(mode), f"Mode {mode} doesn't agree with channels {c}"
    # CHW -> HWC, so the flattened buffer is in pixel-interleaved order.
    interleaved = (x*255.).permute(1, 2, 0).contiguous()
    # NOTE(review): the byte cast drops the fractional part rather than rounding.
    raw = bytes(interleaved.view(-1).byte())
    return PIL.Image.frombytes(mode, (w, h), raw)
Check that the functions work.
# Open one sample image, print its size, and display it.
img = img_open(image_paths[1])
print(img.size)
img
(213, 160)
img_resize(img, 128)
img_to_float(img).shape
torch.Size([3, 160, 213])
# Round trip: image -> tensor -> image should display the same picture.
img_from_float(img_to_float(img))
# The tensor must survive a tensor -> image -> tensor round trip unchanged.
assert torch.allclose(img_to_float(img), img_to_float(img_from_float(img_to_float(img))))
# For this image, converting to RGB first must not change the tensor.
assert torch.allclose(img_to_float(img), img_to_float(img_rgb(img)))
# A second sample: img_to_float fails on it directly (see the traceback below)
# and only succeeds after img_rgb. Presumably it is not stored as RGB.
img = img_open(image_paths[407])
print(img.size)
img
(240, 160)
img_to_float(img).shape
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-18-11efb45e1717> in <module> ----> 1 img_to_float(img).shape <ipython-input-10-91fceb8ad6e2> in img_to_float(x) 8 method) 9 ---> 10 def img_to_float(x): return torch.from_numpy(np.array(x, dtype=np.float32, copy=False)).permute(2,0,1).contiguous()/255. 11 12 def img_from_float(x, mode="RGB"): RuntimeError: number of dims don't match in permute
img_to_float(img_rgb(img)).shape
torch.Size([3, 160, 240])
img_from_float(img_to_float(img_rgb(img)))
Looks good
Let's make sure we can plot a batch and normalize it.
# Transform pipeline, applied left to right by `comply`:
# open -> resize to 128 (square) -> force RGB -> float tensor.
img_tfms = [img_open, partial(img_resize, size=128), img_rgb, img_to_float]
# Run the pipeline on the first four paths and stack the results into one batch tensor.
batch = torch.stack([comply(img_tfms, p) for p in image_paths[:4]])
batch.shape
torch.Size([4, 3, 128, 128])
def _img_plot(ax, img, title=None):
ax.axis('off')
ax.imshow(img.permute(1,2,0))
if title is not None:
ax.set_title(title)
#export
def img_batch_plot(batch, titles=None, nrows=2, ncols=2):
    """Plot the images of `batch` on an `nrows` x `ncols` grid of axes.

    Args:
        batch: Iterable of channels-first image tensors.
        titles: Optional sequence of titles, one per image; falsy means no titles.
        nrows: Number of grid rows.
        ncols: Number of grid columns.

    Images are placed row by row. If the batch, the grid, and the titles
    differ in length, plotting stops at the shortest of the three.
    """
    # squeeze=False always returns a 2-D array of axes, so a single-row or
    # single-column grid flattens the same way as a full grid.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    axs = list(axs.flat)  # row-major flatten
    titles = titles or [None for _ in axs]
    for img, ax, title in zip(batch, axs, titles):
        _img_plot(ax, img, title)
    # Lay out after drawing so the titles are taken into account.
    fig.tight_layout()
# The batch built above holds four images, so only four of the nine axes receive a picture.
img_batch_plot(batch, nrows=3, ncols=3)
#export
class ImgNormalize:
def __init__(self, mean=None, sd=None, batch=True):
self.shape = (1, -1, 1, 1) if batch else (-1, 1, 1)
self.mean = mean
self.sd = sd
def get_stats(self, items):
assert self.mean is None and self.sd is None, "Stats already calculated"
n = 0
mean = 0
sd = 0
for item in items:
n += 1
item = item.view(item.shape[0], -1)
mean += item.mean(1)
sd += item.std(1)
self.mean = (mean / n).view(self.shape)
self.sd = (sd / n).view(self.shape)
return self
def __call__(self, item):
return (item - self.mean) / self