import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
import torch.nn.functional as F
from datasets import load_dataset,load_dataset_builder
import torchvision.transforms.functional as TF
from fastcore.test import test_close
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'
import logging
logging.disable(logging.WARNING)
x,y = 'image','label'  # batch dict keys (note: the dataset uses 'label', not 'labels')
def data_loader(ds, batch_size, as_tuple=True):
    # collate into tuples by default; with as_tuple=False fall back to PyTorch's
    # default collation, which keeps each batch as a dict of tensors
    kw = {'collate_fn':collate_dict(ds)} if as_tuple else {}
    return DataLoader(ds, batch_size=batch_size, **kw)
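# `collate_dict` and `DataLoaders` were built earlier in the notebook (the miniai library
# developed in this course). Minimal sketches are reproduced here for reference only --
# these are assumptions about the earlier definitions, not the exact exported code:
from operator import itemgetter

def collate_dict(ds):
    # turn a list of dicts into a tuple of batched tensors, ordered by the dataset's features
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls[:2]
    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True):
        # one DataLoader per split of a DatasetDict; as_tuple=False keeps batches as dicts
        return cls(*[data_loader(ds, batch_size, as_tuple) for ds in dd.values()])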
dls = DataLoaders.from_dd(tds, bs, as_tuple=False)  # tds (transformed DatasetDict) and bs defined earlier
b = next(iter(dls.train))
b
{'image': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'label': tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8, 4, 3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6,
         6, 7, 9, 5, 9, 2, 7, 3, 0, 3, 3, 3, 7, 2, 2, 6, 6, 8, 3, 3, 5, 0, 5, 5, 0, 2, 0, 0, 4, 1, 3, 1, 6, 3, 1, 4, 4, 6, 1, 9,
         1, 3, 5, 7, 9, 7, 1, 7, 9, 9, 9, 3, 2, 9, 3, 6, 4, 1, 1, 8, 8, 0, 1, 1, 6, 8, 1, 9, 7, 8, 8, 9, 6, 6, 3, 1, 5, 4, 6, 7,
         5, 5, 9, 2, 2, 2, 7, 6])}
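# sanity check (hypothetical cell): with as_tuple=False each batch stays a plain dict;
# the images arrive flattened to 28*28=784 pixels (via the earlier transform), 128 per batch
b[x].shape, b[y].shape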
class FashionMLP(nn.Module):
    "A Hugging Face-style model: takes the whole batch dict, returns preds and loss in a dict."
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(m, nh)   # m (input size) and nh (hidden size) defined earlier in the notebook
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(nh, 10)  # 10 Fashion-MNIST classes

    def forward(self, b):
        xb,yb = b[x],b[y]            # pull inputs and targets out of the batch dict
        pred = self.l2(self.relu(self.l1(xb)))
        return {'preds':pred, 'loss':F.cross_entropy(pred, yb)}
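# Because forward() takes the whole batch dict and computes its own loss, the usual
# training logic (preds = model(xb); loss = loss_func(preds, yb)) no longer applies.
# One way to adapt it -- a hypothetical sketch assuming the miniai TrainCB from earlier
# in the notebook, not the exact lesson code:
class DictTrainCB(TrainCB):
    def predict(self, learn): learn.preds = learn.model(learn.batch)  # model consumes the dict
    def get_loss(self, learn): learn.loss = learn.preds['loss']       # loss already computed in forward()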
model = FashionMLP()
learn = Learner(model, dls, identity, lr=0.001, cbs=cbs)  # identity loss: the model returns its own loss; cbs defined earlier
learn.fit(1)
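# output columns (assumed format from the metrics callback): epoch, model.training flag, loss, accuracy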
0 True 1.5404933774903384 0.6351112739872068
0 False 2.156612927400613 0.7089596518987342