#hide
from utils import *
from fastai2.vision.all import *
# Baseline run: Imagenette, no input normalization in the pipeline.
path = untar_data(URLs.IMAGENETTE)

# Items are resized to 460 individually; the batch-level augmentations then
# bring them to 224 with a minimum crop scale of 0.75.
dblock = DataBlock(
    blocks=(ImageBlock(), CategoryBlock()),
    get_items=get_image_files,
    get_y=parent_label,
    item_tfms=Resize(460),
    batch_tfms=aug_transforms(size=224, min_scale=0.75),
)
dls = dblock.dataloaders(path, bs=64)

# xresnet50 is built with its default arguments and trained for 5 epochs
# with the one-cycle schedule at a peak learning rate of 3e-3.
model = xresnet50()
learn = Learner(
    dls,
    model,
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy,
)
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.583403 | 2.064317 | 0.401792 | 01:03 |
1 | 1.208877 | 1.260106 | 0.601568 | 01:02 |
2 | 0.925265 | 1.036154 | 0.664302 | 01:03 |
3 | 0.730190 | 0.700906 | 0.777819 | 01:03 |
4 | 0.585707 | 0.541810 | 0.825243 | 01:03 |
# Pull one batch and summarise its inputs. Reducing over axes 0, 2 and 3
# leaves one mean and one standard deviation per entry of axis 1.
x, y = dls.one_batch()
reduce_axes = [0, 2, 3]
x.mean(dim=reduce_axes), x.std(dim=reduce_axes)
(TensorImage([0.4842, 0.4711, 0.4511], device='cuda:5'), TensorImage([0.2873, 0.2893, 0.3110], device='cuda:5'))
def get_dls(bs, size):
    """Build the Imagenette data loaders with ImageNet-statistics normalization.

    Args:
        bs: Batch size forwarded to ``DataBlock.dataloaders``.
        size: Final image size forwarded to ``aug_transforms``.

    Returns:
        The loaders produced by ``dblock.dataloaders`` for the module-level
        ``path``.
    """
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        # Every item is first resized to 460; the batch-level augmentations
        # below then produce the requested ``size``.
        item_tfms=Resize(460),
        batch_tfms=[
            *aug_transforms(size=size, min_scale=0.75),
            # Normalize with the ImageNet statistics exposed by fastai.
            Normalize.from_stats(*imagenet_stats),
        ],
    )
    return dblock.dataloaders(path, bs=bs)
# Rebuild the loaders through get_dls, which appends Normalize.from_stats to
# the batch transforms, then recompute the same statistics over axes 0, 2 and 3
# on a fresh batch.
dls = get_dls(bs=64, size=224)
x, y = dls.one_batch()
stat_dims = [0, 2, 3]
x.mean(dim=stat_dims), x.std(dim=stat_dims)
(TensorImage([-0.0787, 0.0525, 0.2136], device='cuda:5'), TensorImage([1.2330, 1.2112, 1.3031], device='cuda:5'))
# Same architecture and schedule as the first run (5 epochs, peak learning
# rate 3e-3), this time on the loaders currently bound to `dls`.
model = xresnet50()
learn = Learner(
    dls,
    model,
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy,
)
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.632865 | 2.250024 | 0.391337 | 01:02 |
1 | 1.294041 | 1.579932 | 0.517177 | 01:02 |
2 | 0.960535 | 1.069164 | 0.657207 | 01:04 |
3 | 0.730220 | 0.767433 | 0.771845 | 01:05 |
4 | 0.577889 | 0.550673 | 0.824496 | 01:06 |
# First stage: batch size 128 and image size 128, with a freshly built
# xresnet50 trained for 4 one-cycle epochs.
dls = get_dls(bs=128, size=128)
learn = Learner(
    dls,
    xresnet50(),
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy,
)
learn.fit_one_cycle(4, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.902943 | 2.447006 | 0.401419 | 00:30 |
1 | 1.315203 | 1.572992 | 0.525765 | 00:30 |
2 | 1.001199 | 0.767886 | 0.759149 | 00:30 |
3 | 0.765864 | 0.665562 | 0.797984 | 00:30 |
# Second stage: keep the existing learner (and therefore its weights) but swap
# in loaders built with batch size 64 and image size 224, then fine-tune.
# The recorded output shows one epoch followed by five more.
larger_dls = get_dls(bs=64, size=224)
learn.dls = larger_dls
learn.fine_tune(5, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.985213 | 1.654063 | 0.565721 | 01:06 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.706869 | 0.689622 | 0.784541 | 01:07 |
1 | 0.739217 | 0.928541 | 0.712472 | 01:07 |
2 | 0.629462 | 0.788906 | 0.764003 | 01:07 |
3 | 0.491912 | 0.502622 | 0.836445 | 01:06 |
4 | 0.414880 | 0.431332 | 0.863331 | 01:06 |
# Score predictions obtained with test-time augmentation (learn.tta) and
# show the accuracy as a plain Python number.
preds, targs = learn.tta()
tta_accuracy = accuracy(preds, targs)
tta_accuracy.item()
0.8737863898277283
# Load the first (sorted) training image from each of two class folders and
# bring both to 256x256 so they can be combined element-wise.
church = PILImage.create(get_image_files_sorted(path/'train'/'n03028079')[0])
church = church.resize((256,256))
gas = PILImage.create(get_image_files_sorted(path/'train'/'n03425413')[0])
gas = gas.resize((256,256))

# Convert to float tensors and divide by 255.
tchurch = tensor(church).float() / 255.
tgas = tensor(gas).float() / 255.

# Three panels side by side: each source image, then the 0.3/0.7 weighted sum.
_,axs = plt.subplots(1, 3, figsize=(12,4))
panels = (tchurch, tgas, (0.3*tchurch + 0.7*tgas))
for panel_ax, panel_img in zip(axs, panels):
    show_image(panel_img, ax=panel_ax)