#default_exp xtras

#export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.basics import *
from functools import wraps

import mimetypes,pickle,random,json,subprocess,shlex,bz2,gzip,zipfile,tarfile
import imghdr,struct,distutils.util,tempfile,time,string,collections
from contextlib import contextmanager,ExitStack
from pdb import set_trace
from datetime import datetime, timezone
from timeit import default_timer

from fastcore.test import *
from nbdev.showdoc import *
from fastcore.nb_imports import *
from time import sleep

#export
def dict2obj(d):
    "Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`"
    if isinstance(d, (L,list)): return L(d).map(dict2obj)
    if not isinstance(d, dict): return d
    return AttrDict(**{k:dict2obj(v) for k,v in d.items()})

d1 = dict(a=1, b=dict(c=2,d=3))
d2 = dict2obj(d1)
test_eq(d2.b.c, 2)
test_eq(d2.b['c'], 2)

_list_of_dicts = [d1, d1]
ds = dict2obj(_list_of_dicts)
test_eq(ds[0].b.c, 2)

#export
def obj2dict(d):
    "Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`"
    if isinstance(d, (L,list)): return list(L(d).map(obj2dict))
    if not isinstance(d, dict): return d
    return dict(**{k:obj2dict(v) for k,v in d.items()})

test_eq(obj2dict(d2), d1)
test_eq(obj2dict(ds), _list_of_dicts)

#export
def _repr_dict(d, lvl):
    if isinstance(d,dict):
        its = [f"{k}: {_repr_dict(v,lvl+1)}" for k,v in d.items()]
    elif isinstance(d,(list,L)): its = [_repr_dict(o,lvl+1) for o in d]
    else: return str(d)
    return '\n' + '\n'.join([" "*(lvl*2) + "- " + o for o in its])

#export
def repr_dict(d):
    "Print nested dicts and lists, such as returned by `dict2obj`"
    return _repr_dict(d,0).strip()

print(repr_dict(d2))

#export
@patch
def __repr__(self:AttrDict): return repr_dict(self)
AttrDict._repr_markdown_ = AttrDict.__repr__

print(repr(d2))
d2

#export
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    return isinstance(x, (tuple,list,L,slice,Generator))

assert is_listy((1,))
assert is_listy([1])
assert is_listy(L([1]))
assert is_listy(slice(2))
assert not is_listy(array([1]))

#export
def shufflish(x, pct=0.04):
    "Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
    n = len(x)
    return L(x[i] for i in sorted(range_of(x), key=lambda o: o+n*(1+random.random()*pct)))
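Since `shufflish` isn't demonstrated above, here is a minimal sketch (not part of the original tests). The relocation is random, so rather than asserting an exact output we check the two properties that should always hold: the result is a permutation of the input, and with the default `pct=0.04` over 100 items no element moves more than about 4 positions.

t = list(range(100))
s = shufflish(t)
test_eq(sorted(s), t)                            # a permutation of the input...
assert all(abs(v-i)<=4 for i,v in enumerate(s))  # ...with bounded displacement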
#export
def mapped(f, it):
    "map `f` over `it`, unless it's not listy, in which case return `f(it)`"
    return L(it).map(f) if is_listy(it) else f(it)

def _f(x,a=1): return x-a
test_eq(mapped(_f,1),0)
test_eq(mapped(_f,[1,2]),[0,1])
test_eq(mapped(_f,(1,)),(0,))

#export
#hide
class IterLen:
    "Base class to add iteration to anything supporting `__len__` and `__getitem__`"
    def __iter__(self): return (self[i] for i in range_of(self))

#export
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    _default='coll'
    def __init__(self, coll, idxs=None, cache=None, tfm=noop):
        if idxs is None: idxs = L.range(coll)
        store_attr()
        if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)

    def _get(self, i): return self.tfm(self.coll[i])
    def __getitem__(self, i): return self._get(self.idxs[i])
    def __len__(self): return len(self.coll)
    def reindex(self, idxs): self.idxs = idxs
    def shuffle(self): random.shuffle(self.idxs)
    def cache_clear(self): self._get.cache_clear()
    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
    def __setstate__(self, s): self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']

    _docs = dict(reindex="Replace `self.idxs` with idxs",
                 shuffle="Randomly shuffle indices",
                 cache_clear="Clear LRU cache")

show_doc(ReindexCollection, title_level=4)

rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'], idxs=[4,3,2,1,0])
list(rc)

show_doc(ReindexCollection.reindex, title_level=6)

rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'])
rc.reindex([4,3,2,1,0])
list(rc)

sz = 50
t = ReindexCollection(L.range(sz), cache=2)

# trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t._get.cache_info()

show_doc(ReindexCollection.cache_clear, title_level=5)

sz = 50
t = ReindexCollection(L.range(sz), cache=2)

# trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t.cache_clear()
t._get.cache_info()

show_doc(ReindexCollection.shuffle, title_level=5)

rc=ReindexCollection(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
rc.shuffle()
list(rc)

sz = 50
t = ReindexCollection(L.range(sz), cache=2)
test_eq(list(t), range(sz))
test_eq(t[sz-1], sz-1)
test_eq(t._get.cache_info().hits, 1)
t.shuffle()
test_eq(t._get.cache_info().hits, 1)
test_ne(list(t), range(sz))
test_eq(set(t), set(range(sz)))
t.cache_clear()
test_eq(t._get.cache_info().hits, 0)
test_eq(t.count(0), 1)

#hide
# Test that ReindexCollection pickles
t1 = pickle.loads(pickle.dumps(t))
test_eq(list(t), list(t1))
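Note that the `tfm` argument isn't exercised by the tests above: it's applied lazily each time an item is fetched, and so composes with the LRU cache (cached values are post-transform). A minimal sketch, using `str.upper` purely as an illustration:

rc = ReindexCollection(['a','b','c'], idxs=[2,1,0], tfm=str.upper)
test_eq(list(rc), ['C','B','A'])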
#export
@contextmanager
def maybe_open(f, mode='r', **kwargs):
    "Context manager: open `f` if it is a path (and close on exit)"
    if isinstance(f, (str,os.PathLike)):
        with open(f, mode, **kwargs) as f: yield f
    else: yield f

def _f(fn):
    with maybe_open(fn) as f: return f.encoding

fname = '00_test.ipynb'
sys_encoding = 'cp1252' if sys.platform == 'win32' else 'UTF-8'
test_eq(_f(fname), sys_encoding)
with open(fname) as fh: test_eq(_f(fh), sys_encoding)

# For comparison, this is (essentially) the stdlib's `imghdr.what`, which duplicates
# the path-vs-file-object handling that `maybe_open` factors out:
def what(file, h=None):
    f = None
    try:
        if h is None:
            if isinstance(file, (str,os.PathLike)):
                f = open(file, 'rb')
                h = f.read(32)
            else:
                location = file.tell()
                h = file.read(32)
                file.seek(location)
        for tf in imghdr.tests:
            res = tf(h, f)
            if res: return res
    finally:
        if f: f.close()
    return None

fname = 'images/puppy.jpg'
what(fname)

# Using `maybe_open` (and `L.map_first`), the same logic is far more concise:
def what(file, h=None):
    if h is None:
        with maybe_open(file, 'rb') as f: h = f.peek(32)
    return L(imghdr.tests).map_first(Self(h,file))

test_eq(what(fname), 'jpeg')
with open(fname,'rb') as f: test_eq(what(f), 'jpeg')
with open(fname,'rb') as f: test_eq(what(None, h=f.read(32)), 'jpeg')

#export
def _jpg_size(f):
    size,ftype = 2,0
    while not 0xc0 <= ftype <= 0xcf:
        f.seek(size, 1)
        byte = f.read(1)
        while ord(byte) == 0xff: byte = f.read(1)
        ftype = ord(byte)
        size = struct.unpack('>H', f.read(2))[0] - 2
    f.seek(1, 1)  # `precision'
    h,w = struct.unpack('>HH', f.read(4))
    return w,h

def _gif_size(f):
    # GIF: width and height are little-endian uint16s at bytes 6-9
    head = f.read(10)
    return struct.unpack('<HH', head[6:10])

def _png_size(f):
    # PNG: verify the signature, then read big-endian int32 width/height at bytes 16-23
    head = f.read(24)
    assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
    return struct.unpack('>ii', head[16:24])

#export
def image_size(fn):
    "Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
    with maybe_open(fn, 'rb') as f: return d[imghdr.what(f)](f)

test_eq(image_size(fname), (1200,803))

#export
def bunzip(fn):
    "bunzip `fn`, raising exception if output already exists"
    fn = Path(fn)
    assert fn.exists(), f"{fn} doesn't exist"
    out_fn = fn.with_suffix('')
    assert not out_fn.exists(), f"{out_fn} already exists"
    with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:
        for d in iter(lambda: src.read(1024*1024), b''): dst.write(d)

f = Path('files/test.txt')
if f.exists(): f.unlink()
bunzip('files/test.txt.bz2')
t = f.open().readlines()
test_eq(len(t),1)
test_eq(t[0], 'test\n')
f.unlink()

#export
def join_path_file(file, path, ext=''):
    "Return `path/file` if file is a string or a `Path`, file otherwise"
    if not isinstance(file, (str, Path)): return file
    path.mkdir(parents=True, exist_ok=True)
    return path/f'{file}{ext}'

path = Path.cwd()/'_tmp'/'tst'
f = join_path_file('tst.txt', path)
assert path.exists()
test_eq(f, path/'tst.txt')
with open(f, 'w') as f_: assert join_path_file(f_, path) == f_
shutil.rmtree(Path.cwd()/'_tmp')

#export
def loads(s, cls=None, object_hook=None, parse_float=None, parse_int=None, parse_constant=None,
          object_pairs_hook=None, **kw):
    "Same as `json.loads`, but handles `None`"
    if not s: return {}
    return json.loads(s, cls=cls, object_hook=object_hook, parse_float=parse_float, parse_int=parse_int,
                      parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, **kw)

#export
def loads_multi(s:str):
    "Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end"
    _dec = json.JSONDecoder()
    while s.find('{')>=0:
        s = s[s.find('{'):]
        obj,pos = _dec.raw_decode(s)
        if not pos: raise ValueError(f'no JSON object found at {pos}')
        yield obj
        s = s[pos:]

tst = """
# ignored
{ "a":1 }
hello
{ "b":2 }
"""

test_eq(list(loads_multi(tst)), [{'a': 1}, {'b': 2}])

#export
def untar_dir(file, dest):
    "untar gzipped tarball `file` into `dest`, assuming the archive has a single top-level item"
    with tempfile.TemporaryDirectory(dir='.') as d:
        d = Path(d)
        with tarfile.open(mode='r:gz', fileobj=file) as t: t.extractall(d)
        next(d.iterdir()).rename(dest)

#export
def repo_details(url):
    "Tuple of `owner,name` from ssh or https git repo `url`"
    res = remove_suffix(url.strip(), '.git')
    res = res.split(':')[-1]
    return res.split('/')[-2:]

test_eq(repo_details('https://github.com/fastai/fastai.git'), ['fastai', 'fastai'])
test_eq(repo_details('git@github.com:fastai/nbdev.git\n'), ['fastai', 'nbdev'])

#export
def run(cmd, *rest, ignore_ex=False, as_bytes=False, stderr=False):
    "Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    if rest: cmd = (cmd,)+rest
    elif isinstance(cmd,str): cmd = shlex.split(cmd)
    res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout = res.stdout
    if stderr and res.stderr: stdout += b' ;; ' + res.stderr
    if not as_bytes: stdout = stdout.decode().strip()
    if ignore_ex: return (res.returncode, stdout)
    if res.returncode: raise IOError(stdout)
    return stdout

if sys.platform == 'win32':
    assert 'ipynb' in run('cmd /c dir /p')
    assert 'ipynb' in run(['cmd', '/c', 'dir', '/p'])
    assert 'ipynb' in run('cmd', '/c', 'dir', '/p')
else:
    assert 'ipynb' in run('ls -ls')
    assert 'ipynb' in run(['ls', '-l'])
    assert 'ipynb' in run('ls', '-l')

if sys.platform == 'win32':
    test_eq(run('cmd /c findstr asdfds 00_test.ipynb', ignore_ex=True)[0], 1)
else:
    test_eq(run('grep asdfds 00_test.ipynb', ignore_ex=True)[0], 1)

if sys.platform == 'win32':
    # `as_bytes` isn't tested on Windows, since `nbdev_clean_nbs` keeps rewriting the
    # `\r\n` line endings that the expected bytes would contain
    test_eq(run('cmd /c echo hi'), 'hi')
else:
    test_eq(run('echo hi', as_bytes=True), b'hi\n')

#export
def open_file(fn, mode='r', **kwargs):
    "Open a file, with optional compression if gz or bz2 suffix"
    if isinstance(fn, io.IOBase): return fn
    fn = Path(fn)
    if   fn.suffix=='.bz2': return bz2.BZ2File(fn, mode, **kwargs)
    elif fn.suffix=='.gz' : return gzip.GzipFile(fn, mode, **kwargs)
    elif fn.suffix=='.zip': return zipfile.ZipFile(fn, mode, **kwargs)
    else: return open(fn,mode, **kwargs)

#export
def save_pickle(fn, o):
    "Save a pickle file, to a file name or opened file"
    with open_file(fn, 'wb') as f: pickle.dump(o, f)

#export
def load_pickle(fn):
    "Load a pickle file from a file name or opened file"
    with open_file(fn, 'rb') as f: return pickle.load(f)

for suf in '.pkl','.bz2','.gz':
    # delete=False is added for Windows. https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.NamedTemporaryFile(suffix=suf, delete=False) as f:
        fn = Path(f.name)
        save_pickle(fn, 't')
        t = load_pickle(fn)
    f.close()
    test_eq(t,'t')

#export
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
    "Read the content of `self`"
    with self.open(encoding=encoding) as f: return f.readlines(hint)

#export
@patch
def read_json(self:Path, encoding=None, errors=None):
    "Same as `read_text` followed by `loads`"
    return loads(self.read_text(encoding=encoding, errors=errors))

#export
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
    "Make all parent dirs of `self`, and write `data`"
    self.parent.mkdir(exist_ok=True, parents=True, mode=mode)
    self.write_text(data, encoding=encoding, errors=errors)

#export
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list"
    extns = L(file_exts)
    if file_type: extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
    no_extns = len(extns)==0
    res = (o for o in self.iterdir() if no_extns or o.suffix in extns)
    if n_max is not None: res = itertools.islice(res, n_max)
    return L(res)

path = Path()
t = path.ls()
assert len(t)>0
t1 = path.ls(10)
test_eq(len(t1), 10)
t2 = path.ls(file_exts='.ipynb')
assert len(t)>len(t2)
t[0]

lib_path = (path/'../fastcore')
txt_files = lib_path.ls(file_type='text')
assert len(txt_files) > 0 and txt_files[0].suffix=='.py'
ipy_files = path.ls(file_exts=['.ipynb'])
assert len(ipy_files) > 0 and ipy_files[0].suffix=='.ipynb'
txt_files[0],ipy_files[0]

#hide
path = Path()
pkl = pickle.dumps(path)
p2 = pickle.loads(pkl)
test_eq(path.ls()[0], p2.ls()[0])

#export
@patch
def __repr__(self:Path):
    b = getattr(Path, 'BASE_PATH', None)
    if b:
        try: self = self.relative_to(b)
        except: pass
    return f"Path({self.as_posix()!r})"

t = ipy_files[0].absolute()
try:
    Path.BASE_PATH = t.parent.parent
    test_eq(repr(t), f"Path('nbs/{t.name}')")
finally: Path.BASE_PATH = None

#export
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
    "Truncate `s` to length `maxlen`, adding suffix `suf` if truncated"
    return s[:maxlen-len(suf)]+suf if len(s)+len(space)>maxlen else s+space

w = 'abacadabra'
test_eq(truncstr(w, 10), w)
test_eq(truncstr(w, 5), 'abac…')
test_eq(truncstr(w, 5, suf=''), 'abaca')
test_eq(truncstr(w, 11, space='_'), w+"_")
test_eq(truncstr(w, 10, space='_'), w[:-1]+'…')
test_eq(truncstr(w, 5, suf='!!'), 'aba!!')

#export
spark_chars = '▁▂▃▅▆▇'

#export
def _ceil(x, lim=None): return x if (not lim or x <= lim) else lim

def _sparkchar(x, mn, mx, incr, empty_zero):
    if x is None or (empty_zero and not x): return ' '
    if incr == 0: return spark_chars[0]
    res = int((_ceil(x,mx)-mn)/incr-0.5)
    return spark_chars[res]

#export
def sparkline(data, mn=None, mx=None, empty_zero=False):
    "Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
    valid = [o for o in data if o is not None]
    if not valid: return ' '
    mn,mx,n = ifnone(mn,min(valid)),ifnone(mx,max(valid)),len(spark_chars)
    res = [_sparkchar(x=o, mn=mn, mx=mx, incr=(mx-mn)/n, empty_zero=empty_zero) for o in data]
    return ''.join(res)

data = [9,6,None,1,4,0,8,15,10]
print(f'without "empty_zero": {sparkline(data, empty_zero=False)}')
print(f'   with "empty_zero": {sparkline(data, empty_zero=True )}')

sparkline([1,2,3,400], mn=0, mx=3)

#export
def autostart(g):
    "Decorator that automatically starts a generator"
    @functools.wraps(g)
    def f():
        r = g()
        next(r)
        return r
    return f
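A generator has to be advanced to its first `yield` before `send` can be used on it; `autostart` removes that priming step. A minimal sketch (the `_printer` coroutine is just an illustration; note that `autostart` assumes a generator function taking no arguments):

@autostart
def _printer():
    while True:
        x = yield
        print(x)

p = _printer()  # already primed, so no explicit `next(p)` is needed
p.send('hi')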
#export
class EventTimer:
    "An event timer with history of `store` items of time `span`"
    def __init__(self, store=5, span=60):
        self.hist,self.span,self.last = collections.deque(maxlen=store),span,default_timer()
        self._reset()

    def _reset(self): self.start,self.events = self.last,0

    def add(self, n=1):
        "Record `n` events"
        if self.duration>self.span:
            self.hist.append(self.freq)
            self._reset()
        self.events += n
        self.last = default_timer()

    @property
    def duration(self): return default_timer()-self.start
    @property
    def freq(self): return self.events/self.duration

show_doc(EventTimer, title_level=4)

# Random wait function for testing
def _randwait(): yield from (sleep(random.random()/200) for _ in range(100))

c = EventTimer(store=5, span=0.03)
for o in _randwait(): c.add(1)
print(f'Num Events: {c.events}, Freq/sec: {c.freq:.01f}')
print('Most recent: ', sparkline(c.hist), *L(c.hist).map('{:.01f}'))

#export
_fmt = string.Formatter()

#export
def stringfmt_names(s:str)->list:
    "Unique brace-delimited names in `s`"
    return uniqueify(o[1] for o in _fmt.parse(s) if o[1])

s = '/pulls/{pull_number}/reviews/{review_id}'
test_eq(stringfmt_names(s), ['pull_number','review_id'])

#export
class PartialFormatter(string.Formatter):
    "A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
    def __init__(self):
        self.missing = set()
        super().__init__()

    def get_field(self, nm, args, kwargs):
        try: return super().get_field(nm, args, kwargs)
        except KeyError:
            self.missing.add(nm)
            return '{'+nm+'}',nm

    def check_unused_args(self, used, args, kwargs):
        self.xtra = filter_keys(kwargs, lambda o: o not in used)

show_doc(PartialFormatter, title_level=4)

#export
def partial_format(s:str, **kwargs):
    "string format `s`, ignoring missing field errors, returning missing and extra fields"
    fmt = PartialFormatter()
    res = fmt.format(s, **kwargs)
    return res,list(fmt.missing),fmt.xtra

res,missing,xtra = partial_format(s, pull_number=1, foo=2)
test_eq(res, '/pulls/1/reviews/{review_id}')
test_eq(missing, ['review_id'])
test_eq(xtra, {'foo':2})

#export
def utc2local(dt:datetime)->datetime:
    "Convert `dt` from UTC to local time"
    return dt.replace(tzinfo=timezone.utc).astimezone(tz=None)

dt = datetime(2000,1,1,12)
print(f'{dt} UTC is {utc2local(dt)} local time')

#export
def local2utc(dt:datetime)->datetime:
    "Convert `dt` from local to UTC time"
    return dt.replace(tzinfo=None).astimezone(tz=timezone.utc)

print(f'{dt} local is {local2utc(dt)} UTC time')

#export
def trace(f):
    "Add `set_trace` to an existing function `f`"
    if getattr(f, '_traced', False): return f
    def _inner(*args,**kwargs):
        set_trace()
        return f(*args,**kwargs)
    _inner._traced = True
    return _inner

#export
def round_multiple(x, mult, round_down=False):
    "Round `x` to nearest multiple of `mult`"
    def _f(x_): return (int if round_down else round)(x_/mult)*mult
    res = L(x).map(_f)
    return res if is_listy(x) else res[0]

test_eq(round_multiple(63,32), 64)
test_eq(round_multiple(50,32), 64)
test_eq(round_multiple(40,32), 32)
test_eq(round_multiple( 0,32),  0)
test_eq(round_multiple(63,32, round_down=True), 32)
test_eq(round_multiple((63,40),32), (64,32))

#export
@contextmanager
def modified_env(*delete, **replace):
    "Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
    prev = dict(os.environ)
    try:
        os.environ.update(replace)
        for k in delete: os.environ.pop(k, None)
        yield
    finally:
        os.environ.clear()
        os.environ.update(prev)

# USER isn't always set in cloud Linux environments, so test with SHELL (or USERNAME on Windows) instead
env_test = 'USERNAME' if sys.platform == "win32" else 'SHELL'
oldusr = os.environ[env_test]

replace_param = {env_test: 'a'}
with modified_env('PATH', **replace_param):
    test_eq(os.environ[env_test], 'a')
    assert 'PATH' not in os.environ

assert 'PATH' in os.environ
test_eq(os.environ[env_test], oldusr)

#export
class ContextManagers(GetAttr):
    "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()
    def __enter__(self): self.default.map(self.stack.enter_context)
    def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs)

show_doc(ContextManagers, title_level=4)
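`ContextManagers` only appears via `show_doc` above, so here is a minimal usage sketch (the temporary files are just an illustration): every manager in the collection is entered when the `with` block starts, and all are exited together when it ends.

f1,f2 = tempfile.TemporaryFile(),tempfile.TemporaryFile()
with ContextManagers([f1,f2]): f1.write(b'x')
assert f1.closed and f2.closed  # both were exited on leaving the block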
#export
def str2bool(s):
    "Case-insensitive convert string `s` to a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
    if not isinstance(s,str): return bool(s)
    return bool(distutils.util.strtobool(s)) if s else False

for o in "y YES t True on 1".split(): assert str2bool(o)
for o in "n no FALSE off 0".split(): assert not str2bool(o)
for o in 0,None,'',False: assert not str2bool(o)
for o in 1,True: assert str2bool(o)

#export
def _is_instance(f, gs):
    tst = [g if type(g) in [type, 'function'] else g.__class__ for g in gs]
    for g in tst:
        if isinstance(f, g) or f==g: return True
    return False

def _is_first(f, gs):
    for o in L(getattr(f, 'run_after', None)):
        if _is_instance(o, gs): return False
    for g in gs:
        if _is_instance(f, L(getattr(g, 'run_before', None))): return False
    return True

#export
def sort_by_run(fs):
    "Sort `fs` so that `run_after`/`run_before`/`toward_end` constraints are satisfied"
    end = L(fs).attrgot('toward_end')
    inp,res = L(fs)[~end] + L(fs)[end], L()
    while len(inp):
        for i,o in enumerate(inp):
            if _is_first(o, inp):
                res.append(inp.pop(i))
                break
        else: raise Exception("Impossible to sort")
    return res
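`sort_by_run` isn't demonstrated above, so here is a minimal sketch with hypothetical classes: `_B` declares that it must run after `_A`, and `_C` that it must run before `_B`; the sort respects both constraints (with this input order the result is `_A,_C,_B`).

class _A: pass
class _B: run_after = _A
class _C: run_before = _B

test_eq(sort_by_run([_B,_A,_C]), [_A,_C,_B])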
#hide
from nbdev.export import notebook2script
notebook2script()