from fastai.vision.all import *
from fastai.vision.widgets import *
Do you need to know whether you're being chased by a dangerous grizzly or a sweet teddy bear, and do you need an answer fast? Then you've come to the right place. Take a picture of the potentially vicious killer, and click 'upload' to classify it. (Important: this only handles grizzly bears, black bears, and teddy bears. It will not give a sensible answer for polar bears, a bear market, a bear of a man, or hot dogs.)
# Load the exported fastai Learner from the current directory, mapped to CPU.
# NOTE(review): the recorded run of this cell failed with
# "UnpicklingError: invalid load key, '\x0a'" inside torch.load, i.e. the file
# did not start with pickle data. Possibly export.pkl is a text file (e.g. a
# Git LFS pointer) rather than the real export; confirm the file contents.
path = Path()
learn_inf = load_learner(path / 'export.pkl', cpu=True)

# UI widgets, created in this order: upload button, thumbnail output area,
# and the label that will show the prediction text.
btn_upload, out_pl, lbl_pred = (
    widgets.FileUpload(),
    widgets.Output(),
    widgets.Label(),
)
/home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:649: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:649: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:649: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:649: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:649: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes. 
warnings.warn(msg, SourceChangeWarning) /home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:649: SourceChangeWarning: source code of class 'torchvision.models.resnet.BasicBlock' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning)
--------------------------------------------------------------------------- UnpicklingError Traceback (most recent call last) <ipython-input-3-87ace48b76b0> in <module> 1 path = Path() ----> 2 learn_inf = load_learner(path/'export.pkl', cpu=True) 3 btn_upload = widgets.FileUpload() 4 out_pl = widgets.Output() 5 lbl_pred = widgets.Label() ~/git/fastai/fastai/learner.py in load_learner(fname, cpu) 534 "Load a `Learner` object in `fname`, optionally putting it on the `cpu`" 535 distrib_barrier() --> 536 res = torch.load(fname, map_location='cpu' if cpu else None) 537 if hasattr(res, 'to_fp32'): res = res.to_fp32() 538 if cpu: res.dls.cpu() ~/anaconda3/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args) 583 return torch.jit.load(opened_file) 584 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) --> 585 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) 586 587 ~/anaconda3/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args) 763 unpickler = pickle_module.Unpickler(f, **pickle_load_args) 764 unpickler.persistent_load = persistent_load --> 765 result = unpickler.load() 766 767 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) UnpicklingError: invalid load key, '\x0a'.
def on_data_change(change):
    """Classify the most recently uploaded image and show the result.

    Registered as an observer on ``btn_upload``'s ``data`` trait. The
    ``change`` notification itself is not inspected; the upload is read
    directly from ``btn_upload.data``.

    Steps: clear the prediction label, build a ``PILImage`` from the last
    uploaded file, render a thumbnail (``img.to_thumb(128, 128)``) in
    ``out_pl``, run ``learn_inf.predict`` and write the predicted class and
    its probability (4 decimals) into ``lbl_pred``.
    """
    lbl_pred.value = ''
    uploads = btn_upload.data
    # The trait can change to an empty list (e.g. when the upload is reset);
    # indexing [-1] would then raise IndexError inside the widget callback.
    if not uploads:
        out_pl.clear_output()
        return
    img = PILImage.create(uploads[-1])
    out_pl.clear_output()
    with out_pl:
        display(img.to_thumb(128, 128))
    pred, pred_idx, probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
# Re-run classification every time the upload widget's `data` trait changes.
btn_upload.observe(on_data_change, names=['data'])

# Stack the title label, upload button, thumbnail area and prediction label
# vertically and render the combined widget in the notebook.
display(VBox(children=[
    widgets.Label('Select your bear!'),
    btn_upload,
    out_pl,
    lbl_pred,
]))
VBox(children=(Label(value='Select your bear!'), FileUpload(value={}, description='Upload'), Output(), Label(v…