import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
os.environ['OMP_NUM_THREADS']='1'
import pickle,gzip
from glob import glob
from torcheval.metrics import MulticlassAccuracy
from miniai.imports import *
from fastprogress import progress_bar
from diffusers import AutoencoderKL
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['figure.dpi'] = 70
set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
path_data = Path('data')/'ILSVRC'
path = path_data/'Data'/'CLS-LOC'
dest = path_data/'latents'
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").cuda().requires_grad_(False)
class ImagesDS:
    def __init__(self, path, spec):
        # Globbing over a million ImageNet files is slow, so the file list
        # is cached to a gzipped pickle after the first scan.
        cache = path/'files.zpkl'
        if cache.exists():
            with gzip.open(cache) as f: self.files = pickle.load(f)
        else:
            self.files = glob(str(path/spec), recursive=True)
            with gzip.open(cache, 'wb', compresslevel=1) as f: pickle.dump(self.files, f)
    def __len__(self): return len(self.files)
    def __getitem__(self, i):
        f = self.files[i]
        im = read_image(f, mode=ImageReadMode.RGB)/255
        # Center-crop to a square on the shorter side, then resize to 256x256.
        im = TF.resize(TF.center_crop(im, min(im.shape[1:])), 256)
        return im,f
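Note that the cache is never invalidated automatically; if the files on disk change, it has to be removed by hand before re-instantiating the dataset (an optional step, not part of the original pipeline):

(path/'files.zpkl').unlink(missing_ok=True)  # force a re-scan on the next ImagesDS(...)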
ds = ImagesDS(path, '**/*.JPEG')
dl = DataLoader(ds, batch_size=64, num_workers=fc.defaults.cpus)
xb,yb = next(iter(dl))
xe = vae.encode(xb.cuda())
xs = xe.latent_dist.mean
xs.shape
torch.Size([64, 4, 32, 32])
show_images(((xs[:16,:3])/4).sigmoid(), imsize=2)  # peek at the first 3 of the 4 latent channels, squashed into [0,1]
xd = to_cpu(vae.decode(xs))
show_images(xd['sample'][:16].clamp(0,1), imsize=2)
# Encode the whole dataset once, mirroring the ImageNet directory tree under
# `dest`; with_suffix('') strips '.JPEG' and np.save appends '.npy'.
if not dest.exists():
    dest.mkdir()
    for xb,yb in progress_bar(dl):
        eb = to_cpu(vae.encode(xb.cuda()).latent_dist.mean).numpy()
        for ebi,ybi in zip(eb,yb):
            ybi = dest/Path(ybi).relative_to(path).with_suffix('')
            ybi.parent.mkdir(parents=True, exist_ok=True)
            np.save(ybi, ebi)
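A quick spot-check that the saved latents decode back to plausible images; this cell is illustrative and was not part of the original run:

f = next((dest/'train').glob('**/*.npy'))
lat = torch.from_numpy(np.load(f))[None].cuda()
xd = to_cpu(vae.decode(lat))
show_images(xd['sample'].clamp(0,1), imsize=4)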
class NumpyDS(ImagesDS):
def __getitem__(self, i):
f = self.files[i]
im = np.load(f)
return im,f
bs = 128
tds = NumpyDS(dest/'train', '**/*.npy')
vds = NumpyDS(dest/'val', '**/*.npy')
tdl = DataLoader(tds, batch_size=bs, num_workers=0)
xb,yb = next(iter(tdl))
xb.mean((0,2,3)),xb.std((0,2,3))
(tensor([ 5.23983, 2.59586, 0.45112, -2.28669]), tensor([3.94172, 4.42124, 3.24268, 3.09760]))
xmean,xstd = (tensor([ 5.37007, 2.65468, 0.44876, -2.39154]),
tensor([3.99512, 4.44317, 3.21629, 3.10339]))
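The hard-coded statistics differ slightly from the single-batch estimate above; they were presumably computed over the full training set, along these lines (a sketch, accumulating channelwise sums over `tdl`; the `_full` names are illustrative):

tot,tot2,n = torch.zeros(4),torch.zeros(4),0
for xb,_ in progress_bar(tdl):
    tot  += xb.sum((0,2,3))
    tot2 += (xb*xb).sum((0,2,3))
    n    += xb.shape[0]*xb.shape[2]*xb.shape[3]
xmean_full = tot/n
xstd_full  = (tot2/n - xmean_full*xmean_full).sqrt()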
class TfmDS:
def __init__(self, ds, tfmx=fc.noop, tfmy=fc.noop): self.ds,self.tfmx,self.tfmy = ds,tfmx,tfmy
def __len__(self): return len(self.ds)
def __getitem__(self, i):
x,y = self.ds[i]
return self.tfmx(x),self.tfmy(y)
id2str = (path_data/'imagenet_lsvrc_2015_synsets.txt').read_text().splitlines()
str2id = {v:k for k,v in enumerate(id2str)}
aug_tfms = nn.Sequential(T.Pad(2), T.RandomCrop(32), RandErase())
norm_tfm = T.Normalize(xmean, xstd)
def tfmx(x, aug=False):
x = norm_tfm(tensor(x))
if aug: x = aug_tfms(x[None])[0]
return x
def tfmy(y): return tensor(str2id[Path(y).parent.name])
tfm_tds = TfmDS(tds, partial(tfmx, aug=True), tfmy)
tfm_vds = TfmDS(vds, tfmx, tfmy)
def denorm(x): return (x*xstd[:,None,None]+xmean[:,None,None])
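denorm has to invert norm_tfm exactly for the decoded previews below to make sense; a one-line check (illustrative):

x = torch.randn(4, 32, 32)
torch.allclose(denorm(norm_tfm(x)), x, atol=1e-5)  # True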
dls = DataLoaders(*get_dls(tfm_tds, tfm_vds, bs=bs, num_workers=8))
all_synsets = [o.split('\t') for o in (path_data/'words.txt').read_text().splitlines()]
synsets = {k:v.split(',', maxsplit=1)[0] for k,v in all_synsets if k in id2str}
xb,yb = next(iter(dls.train))
titles = [synsets[id2str[o]] for o in yb]
xb.mean(),xb.std()
(tensor(-0.02974), tensor(0.97078))
xd = to_cpu(vae.decode(denorm(xb[:9]).cuda()))
show_images(xd['sample'].clamp(0,1), imsize=4, titles=titles[:9])
act_gr = partial(GeneralRelu, leak=0.1, sub=0.4)
iw = partial(init_weights, leaky=0.1)
opt_func = partial(optim.AdamW, eps=1e-5)
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [DeviceCB(), metrics, ProgressCB(plot=True), MixedPrecision()]
def conv(ni, nf, ks=3, stride=1, act=nn.ReLU, norm=None, bias=True):
    # Pre-activation ordering: norm -> act -> conv.
    layers = []
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias))
    return nn.Sequential(*layers)
def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):
return nn.Sequential(conv(ni, nf, stride=1 , act=act, norm=norm, ks=ks),
conv(nf, nf, stride=stride, act=act, norm=norm, ks=ks))
class ResBlock(nn.Module):
    def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):
        super().__init__()
        self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)
        # Identity path: a 1x1 conv when channel counts differ, avg-pooling when striding.
        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None, norm=norm)
        self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
    def forward(self, x): return self.convs(x) + self.idconv(self.pool(x))
def res_blocks(n_bk, ni, nf, stride=1, ks=3, act=act_gr, norm=None):
return nn.Sequential(*[
ResBlock(ni if i==0 else nf, nf, stride=stride if i==n_bk-1 else 1, ks=ks, act=act, norm=norm)
for i in range(n_bk)])
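A quick shape check (illustrative): with stride 2 the block halves the spatial dims while the identity path is pooled and projected to match.

rb = ResBlock(4, 8, stride=2)
rb(torch.randn(1, 4, 32, 32)).shape  # torch.Size([1, 8, 16, 16])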
def get_dropmodel(nfs, nbks, act=act_gr, norm=nn.BatchNorm2d, drop=0.2):
    # Stem: 5x5 conv on the 4 latent channels, then one stride-2 residual
    # stage per adjacent pair in nfs, then a pooled head with dropout.
    layers = [nn.Conv2d(4, nfs[0], 5, padding=2)]
    layers += [res_blocks(nbks[i], nfs[i], nfs[i+1], act=act, norm=norm, stride=2)
               for i in range(len(nfs)-1)]
    layers += [act_gr(), norm(nfs[-1]), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Dropout(drop)]
    layers += [nn.Linear(nfs[-1], 1000, bias=False), nn.BatchNorm1d(1000)]
    return nn.Sequential(*layers).apply(iw)
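An optional sanity check on model size before committing to 40 epochs (illustrative):

m = get_dropmodel(nbks=(1,2,4,3), nfs=(32,64,128,512,1024), drop=0.1)
sum(p.numel() for p in m.parameters())  # total parameter count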
epochs = 40
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
model = get_dropmodel(nbks=(1,2,4,3), nfs=(32, 64, 128, 512, 1024), drop=0.1)
learn = Learner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=opt_func)
learn.fit(epochs)
accuracy | loss | epoch | train |
---|---|---|---|
0.134 | 4.711 | 0 | train |
0.253 | 3.666 | 0 | eval |
0.281 | 3.468 | 1 | train |
0.339 | 3.127 | 1 | eval |
0.348 | 3.022 | 2 | train |
0.354 | 3.063 | 2 | eval |
0.385 | 2.799 | 3 | train |
0.377 | 2.926 | 3 | eval |
0.406 | 2.675 | 4 | train |
0.400 | 2.807 | 4 | eval |
0.416 | 2.617 | 5 | train |
0.408 | 2.757 | 5 | eval |
0.419 | 2.600 | 6 | train |
0.400 | 2.793 | 6 | eval |
0.417 | 2.608 | 7 | train |
0.394 | 2.874 | 7 | eval |
0.414 | 2.629 | 8 | train |
0.374 | 2.997 | 8 | eval |
0.411 | 2.645 | 9 | train |
0.407 | 2.734 | 9 | eval |
0.409 | 2.655 | 10 | train |
0.376 | 2.934 | 10 | eval |
0.410 | 2.652 | 11 | train |
0.398 | 2.797 | 11 | eval |
0.412 | 2.636 | 12 | train |
0.371 | 3.026 | 12 | eval |
0.415 | 2.619 | 13 | train |
0.385 | 2.958 | 13 | eval |
0.418 | 2.602 | 14 | train |
0.417 | 2.706 | 14 | eval |
0.422 | 2.581 | 15 | train |
0.398 | 2.831 | 15 | eval |
0.426 | 2.559 | 16 | train |
0.409 | 2.808 | 16 | eval |
0.431 | 2.532 | 17 | train |
0.424 | 2.712 | 17 | eval |
0.436 | 2.502 | 18 | train |
0.415 | 2.718 | 18 | eval |
0.442 | 2.469 | 19 | train |
0.436 | 2.626 | 19 | eval |
0.449 | 2.433 | 20 | train |
0.426 | 2.673 | 20 | eval |
0.457 | 2.393 | 21 | train |
0.415 | 2.747 | 21 | eval |
0.465 | 2.346 | 22 | train |
0.443 | 2.585 | 22 | eval |
0.474 | 2.299 | 23 | train |
0.465 | 2.494 | 23 | eval |
0.485 | 2.242 | 24 | train |
0.468 | 2.482 | 24 | eval |
0.496 | 2.185 | 25 | train |
0.481 | 2.378 | 25 | eval |
0.508 | 2.120 | 26 | train |
0.496 | 2.268 | 26 | eval |
0.522 | 2.046 | 27 | train |
0.498 | 2.312 | 27 | eval |
0.537 | 1.969 | 28 | train |
0.512 | 2.205 | 28 | eval |
0.554 | 1.882 | 29 | train |
0.517 | 2.217 | 29 | eval |
0.573 | 1.787 | 30 | train |
0.541 | 2.060 | 30 | eval |
0.594 | 1.680 | 31 | train |
0.573 | 1.930 | 31 | eval |
0.617 | 1.561 | 32 | train |
0.593 | 1.828 | 32 | eval |
0.644 | 1.431 | 33 | train |
0.609 | 1.765 | 33 | eval |
0.674 | 1.288 | 34 | train |
0.620 | 1.710 | 34 | eval |
0.707 | 1.136 | 35 | train |
0.638 | 1.650 | 35 | eval |
0.742 | 0.986 | 36 | train |
0.649 | 1.624 | 36 | eval |
0.771 | 0.858 | 37 | train |
0.656 | 1.608 | 37 | eval |
0.794 | 0.767 | 38 | train |
0.659 | 1.601 | 38 | eval |
0.805 | 0.724 | 39 | train |
0.659 | 1.597 | 39 | eval |
Path('models').mkdir(exist_ok=True)
torch.save(learn.model, 'models/imgnet-latents')
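Since torch.save stored the whole module rather than just a state_dict, reloading requires the same class definitions (ResBlock, GeneralRelu, etc.) to be importable; a sketch (on recent PyTorch versions torch.load may also need weights_only=False):

model = torch.load('models/imgnet-latents', map_location='cuda')
model.eval()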