#|default_exp xtras
#|export
from __future__ import annotations
#|export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.basics import *
from importlib import import_module
from functools import wraps
import string,time
from contextlib import contextmanager,ExitStack
from datetime import datetime, timezone
from time import sleep,time,perf_counter
from os.path import getmtime
#|hide
from fastcore.test import *
from nbdev.showdoc import *
from fastcore.nb_imports import *
import shutil,tempfile,pickle,random
from dataclasses import dataclass
Utility functions used in the fastai library
Utilities (other than extensions to `pathlib.Path`) for dealing with IO.
#|export
def walk(
path:Path|str, # path to start searching
symlinks:bool=True, # follow symlinks?
keep_file:callable=ret_true, # function that returns True for wanted files
keep_folder:callable=ret_true, # function that returns True for folders to enter
skip_folder:callable=ret_false, # function that returns True for folders to skip
func:callable=os.path.join, # function to apply to each matched file
ret_folders:bool=False # return folders, not just files
):
"Generator version of `os.walk`, using functions to filter files and folders"
from copy import copy
for root,dirs,files in os.walk(path, followlinks=symlinks):
if keep_folder(root,''):
if ret_folders: yield func(root, '')
yield from (func(root, name) for name in files if keep_file(root,name))
for name in copy(dirs):
if skip_folder(root,name): dirs.remove(name)
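For example, here's a minimal sketch using the filter arguments (the lambdas shown are illustrative) that collects Python files while skipping hidden folders:
py_files = L(walk('.', keep_file=lambda root,name: name.endswith('.py'),
                  skip_folder=lambda root,name: name.startswith('.')))
assert all(f.endswith('.py') for f in py_files)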
#|export
def globtastic(
path:Path|str, # path to start searching
recursive:bool=True, # search subfolders
symlinks:bool=True, # follow symlinks?
file_glob:str=None, # Only include files matching glob
file_re:str=None, # Only include files matching regex
folder_re:str=None, # Only enter folders matching regex
skip_file_glob:str=None, # Skip files matching glob
skip_file_re:str=None, # Skip files matching regex
    skip_folder_re:str=None, # Skip folders matching regex
func:callable=os.path.join, # function to apply to each matched file
ret_folders:bool=False # return folders, not just files
)->L: # Paths to matched files
"A more powerful `glob`, including regex matches, symlink handling, and skip parameters"
from fnmatch import fnmatch
path = Path(path)
if path.is_file(): return L([path])
if not recursive: skip_folder_re='.'
file_re,folder_re = compile_re(file_re),compile_re(folder_re)
skip_file_re,skip_folder_re = compile_re(skip_file_re),compile_re(skip_folder_re)
def _keep_file(root, name):
return (not file_glob or fnmatch(name, file_glob)) and (
not file_re or file_re.search(name)) and (
not skip_file_glob or not fnmatch(name, skip_file_glob)) and (
not skip_file_re or not skip_file_re.search(name))
def _keep_folder(root, name): return not folder_re or folder_re.search(os.path.join(root,name))
def _skip_folder(root, name): return skip_folder_re and skip_folder_re.search(name)
return L(walk(path, symlinks=symlinks, keep_file=_keep_file, keep_folder=_keep_folder, skip_folder=_skip_folder,
func=func, ret_folders=ret_folders))
globtastic('.', skip_folder_re='^[_.]', folder_re='core', file_glob='*.*py*', file_re='c')
(#5) ['./fastcore/docments.py','./fastcore/dispatch.py','./fastcore/basics.py','./fastcore/docscrape.py','./fastcore/script.py']
#|export
@contextmanager
def maybe_open(f, mode='r', **kwargs):
"Context manager: open `f` if it is a path (and close on exit)"
if isinstance(f, (str,os.PathLike)):
with open(f, mode, **kwargs) as f: yield f
else: yield f
This is useful for functions where you want to accept a path or file. `maybe_open` will not close your file handle if you pass one in.
def _f(fn):
with maybe_open(fn) as f: return f.encoding
fname = '00_test.ipynb'
sys_encoding = 'cp1252' if sys.platform == 'win32' else 'UTF-8'
test_eq(_f(fname), sys_encoding)
with open(fname) as fh: test_eq(_f(fh), sys_encoding)
For example, we can use this to reimplement `imghdr.what` from the Python standard library, which is written in Python 3.9 as:
from fastcore import imghdr
def what(file, h=None):
f = None
try:
if h is None:
if isinstance(file, (str,os.PathLike)):
f = open(file, 'rb')
h = f.read(32)
else:
location = file.tell()
h = file.read(32)
file.seek(location)
for tf in imghdr.tests:
res = tf(h, f)
if res: return res
finally:
if f: f.close()
return None
Here's an example of the use of this function:
fname = 'images/puppy.jpg'
what(fname)
'jpeg'
With `maybe_open`, `Self`, and `L.map_first`, we can rewrite this in a much more concise and (in our opinion) clear way:
def what(file, h=None):
if h is None:
with maybe_open(file, 'rb') as f: h = f.peek(32)
return L(imghdr.tests).map_first(Self(h,file))
...and we can check that it still works:
test_eq(what(fname), 'jpeg')
...along with the version passing a file handle:
with open(fname,'rb') as f: test_eq(what(f), 'jpeg')
...along with the `h` parameter version:
with open(fname,'rb') as f: test_eq(what(None, h=f.read(32)), 'jpeg')
#|export
def mkdir(path, exist_ok=False, parents=False, overwrite=False, **kwargs):
"Creates and returns a directory defined by `path`, optionally removing previous existing directory if `overwrite` is `True`"
import shutil
path = Path(path)
if path.exists() and overwrite: shutil.rmtree(path)
path.mkdir(exist_ok=exist_ok, parents=parents, **kwargs)
return path
with tempfile.TemporaryDirectory() as d:
path = Path(os.path.join(d, 'new_dir'))
new_dir = mkdir(path)
assert new_dir.exists()
test_eq(new_dir, path)
# test overwrite
with open(new_dir/'test.txt', 'w') as f: f.writelines('test')
test_eq(len(list(walk(new_dir))), 1) # assert file is present
new_dir = mkdir(new_dir, overwrite=True)
test_eq(len(list(walk(new_dir))), 0) # assert file was deleted
#|export
def image_size(fn):
"Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
from fastcore import imghdr
import struct
def _jpg_size(f):
size,ftype = 2,0
while not 0xc0 <= ftype <= 0xcf:
f.seek(size, 1)
byte = f.read(1)
while ord(byte) == 0xff: byte = f.read(1)
ftype = ord(byte)
size = struct.unpack('>H', f.read(2))[0] - 2
f.seek(1, 1) # `precision'
h,w = struct.unpack('>HH', f.read(4))
return w,h
    def _gif_size(f): return struct.unpack('<HH', f.read(10)[6:10])
    def _png_size(f):
        head = f.read(24)  # `imghdr.what` rewinds `f`, so read the header here
        assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
        return struct.unpack('>ii', head[16:24])
d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
with maybe_open(fn, 'rb') as f: return d[imghdr.what(f)](f)
test_eq(image_size(fname), (1200,803))
#|export
def bunzip(fn):
"bunzip `fn`, raising exception if output already exists"
fn = Path(fn)
assert fn.exists(), f"{fn} doesn't exist"
out_fn = fn.with_suffix('')
assert not out_fn.exists(), f"{out_fn} already exists"
import bz2
with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:
for d in iter(lambda: src.read(1024*1024), b''): dst.write(d)
f = Path('files/test.txt')
if f.exists(): f.unlink()
bunzip('files/test.txt.bz2')
t = f.open().readlines()
test_eq(len(t),1)
test_eq(t[0], 'test\n')
f.unlink()
#|export
def loads(s, **kw):
"Same as `json.loads`, but handles `None`"
if not s: return {}
try: import ujson as json
except ModuleNotFoundError: import json
return json.loads(s, **kw)
#|export
def loads_multi(s:str):
"Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end"
import json
_dec = json.JSONDecoder()
while s.find('{')>=0:
s = s[s.find('{'):]
obj,pos = _dec.raw_decode(s)
if not pos: raise ValueError(f'no JSON object found at {pos}')
yield obj
s = s[pos:]
tst = """
# ignored
{ "a":1 }
hello
{
"b":2
}
"""
test_eq(list(loads_multi(tst)), [{'a': 1}, {'b': 2}])
#|export
def dumps(obj, **kw):
"Same as `json.dumps`, but uses `ujson` if available"
try: import ujson as json
except ModuleNotFoundError: import json
else: kw['escape_forward_slashes']=False
return json.dumps(obj, **kw)
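A quick round-trip sanity check (behavior is the same whether or not `ujson` is installed):
test_eq(loads(dumps(dict(a=1, b=[2,3]))), dict(a=1, b=[2,3]))
test_eq(loads(None), {})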
#|export
def _unpack(fname, out):
import shutil
shutil.unpack_archive(str(fname), str(out))
ls = out.ls()
return ls[0] if len(ls) == 1 else out
#|export
def untar_dir(fname, dest, rename=False, overwrite=False):
"untar `file` into `dest`, creating a directory if the root contains more than one item"
import tempfile,shutil
with tempfile.TemporaryDirectory() as d:
out = Path(d)/remove_suffix(Path(fname).stem, '.tar')
out.mkdir()
if rename: dest = dest/out.name
else:
src = _unpack(fname, out)
dest = dest/src.name
if dest.exists():
if overwrite: shutil.rmtree(dest) if dest.is_dir() else dest.unlink()
else: return dest
if rename: src = _unpack(fname, out)
shutil.move(str(src), dest)
return dest
def test_untar(foldername, rename=False, **kwargs):
with tempfile.TemporaryDirectory() as d:
nm = os.path.join(d, 'a')
shutil.make_archive(nm, 'gztar', **kwargs)
with tempfile.TemporaryDirectory() as d2:
d2 = Path(d2)
untar_dir(nm+'.tar.gz', d2, rename=rename)
test_eq(d2.ls(), [d2/foldername])
If the contents of `fname` contain just one file or directory, it is placed directly in `dest`:
# using `base_dir` in `make_archive` results in `images` directory included in file names
test_untar('images', base_dir='images')
If `rename`, then the directory created is named based on the archive, without extension:
test_untar('a', base_dir='images', rename=True)
If the contents of `fname` contain multiple files and directories, a new folder in `dest` is created with the same name as `fname` (but without extension):
# using `root_dir` in `make_archive` results in `images` directory *not* included in file names
test_untar('a', root_dir='images')
#|export
def repo_details(url):
"Tuple of `owner,name` from ssh or https git repo `url`"
res = remove_suffix(url.strip(), '.git')
res = res.split(':')[-1]
return res.split('/')[-2:]
test_eq(repo_details('https://github.com/fastai/fastai.git'), ['fastai', 'fastai'])
test_eq(repo_details('git@github.com:fastai/nbdev.git\n'), ['fastai', 'nbdev'])
#|export
def run(cmd, *rest, same_in_win=False, ignore_ex=False, as_bytes=False, stderr=False):
"Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    # Even if the command is the same on Windows, we still have to prefix it with `cmd /c`
import subprocess
if rest:
if sys.platform == 'win32' and same_in_win:
cmd = ('cmd', '/c', cmd, *rest)
else:
cmd = (cmd,)+rest
elif isinstance(cmd, str):
if sys.platform == 'win32' and same_in_win: cmd = 'cmd /c ' + cmd
import shlex
cmd = shlex.split(cmd)
elif isinstance(cmd, list):
if sys.platform == 'win32' and same_in_win: cmd = ['cmd', '/c'] + cmd
res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout = res.stdout
if stderr and res.stderr: stdout += b' ;; ' + res.stderr
if not as_bytes: stdout = stdout.decode().strip()
if ignore_ex: return (res.returncode, stdout)
if res.returncode: raise IOError(stdout)
return stdout
You can pass a string (which will be split based on standard shell rules), a list, or pass args directly:
run('echo', same_in_win=True)
run('pip', '--version', same_in_win=True)
run(['pip', '--version'], same_in_win=True)
'pip 23.3.1 from /Users/jhoward/miniconda3/lib/python3.11/site-packages/pip (python 3.11)'
if sys.platform == 'win32':
assert 'ipynb' in run('cmd /c dir /p')
assert 'ipynb' in run(['cmd', '/c', 'dir', '/p'])
assert 'ipynb' in run('cmd', '/c', 'dir', '/p')
else:
assert 'ipynb' in run('ls -ls')
assert 'ipynb' in run(['ls', '-l'])
assert 'ipynb' in run('ls', '-l')
Some commands fail in non-error situations, like `grep`. Use `ignore_ex` in those cases, which will return a tuple of the return code and stdout:
if sys.platform == 'win32':
test_eq(run('cmd /c findstr asdfds 00_test.ipynb', ignore_ex=True)[0], 1)
else:
test_eq(run('grep asdfds 00_test.ipynb', ignore_ex=True)[0], 1)
`run` automatically decodes returned bytes to a `str`. Use `as_bytes` to skip that:
if sys.platform == 'win32':
test_eq(run('cmd /c echo hi'), 'hi')
else:
test_eq(run('echo hi', as_bytes=True), b'hi\n')
#|export
def open_file(fn, mode='r', **kwargs):
"Open a file, with optional compression if gz or bz2 suffix"
if isinstance(fn, io.IOBase): return fn
import bz2,gzip,zipfile
fn = Path(fn)
if fn.suffix=='.bz2': return bz2.BZ2File(fn, mode, **kwargs)
elif fn.suffix=='.gz' : return gzip.GzipFile(fn, mode, **kwargs)
elif fn.suffix=='.zip': return zipfile.ZipFile(fn, mode, **kwargs)
else: return open(fn,mode, **kwargs)
#|export
def save_pickle(fn, o):
"Save a pickle file, to a file name or opened file"
import pickle
with open_file(fn, 'wb') as f: pickle.dump(o, f)
#|export
def load_pickle(fn):
"Load a pickle file from a file name or opened file"
import pickle
with open_file(fn, 'rb') as f: return pickle.load(f)
for suf in '.pkl','.bz2','.gz':
# delete=False is added for Windows
# https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
with tempfile.NamedTemporaryFile(suffix=suf, delete=False) as f:
fn = Path(f.name)
save_pickle(fn, 't')
t = load_pickle(fn)
f.close()
test_eq(t,'t')
#| export
def parse_env(s:str=None, fn:Union[str,Path]=None) -> dict:
"Parse a shell-style environment string or file"
assert bool(s)^bool(fn), "Must pass exactly one of `s` or `fn`"
if fn: s = Path(fn).read_text()
def _f(line):
m = re.match(r'^\s*(?:export\s+)?(\w+)\s*=\s*(["\']?)(.*?)(\2)\s*(?:#.*)?$', line).groups()
return m[0], m[2]
return dict(_f(o.strip()) for o in s.splitlines() if o.strip() and not re.match(r'\s*#', o))
testf = """# comment
# another comment
export FOO="bar#baz"
BAR=thing # comment "ok"
baz='thong'
QUX=quux
export ZAP = "zip" # more comments
FOOBAR = 42 # trailing space and comment"""
exp = dict(FOO='bar#baz', BAR='thing', baz='thong', QUX='quux', ZAP='zip', FOOBAR='42')
test_eq(parse_env(testf), exp)
#| export
def expand_wildcards(code):
"Expand all wildcard imports in the given code string."
import ast,importlib
tree = ast.parse(code)
def _replace_node(code, old_node, new_node):
"Replace `old_node` in the source `code` with `new_node`."
lines = code.splitlines()
lnum = old_node.lineno
indent = ' ' * (len(lines[lnum-1]) - len(lines[lnum-1].lstrip()))
new_lines = [indent+line for line in ast.unparse(new_node).splitlines()]
lines[lnum-1 : old_node.end_lineno] = new_lines
return '\n'.join(lines)
def _expand_import(node, mod, existing):
"Create expanded import `node` in `tree` from wildcard import of `mod`."
mod_all = getattr(mod, '__all__', None)
available_names = set(mod_all) if mod_all is not None else set(dir(mod))
used_names = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name) and n.id in available_names} - existing
if not used_names: return node
names = [ast.alias(name=name, asname=None) for name in sorted(used_names)]
return ast.ImportFrom(module=node.module, names=names, level=node.level)
existing = set()
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and node.names[0].name != '*': existing.update(n.name for n in node.names)
elif isinstance(node, ast.Import): existing.update(n.name.split('.')[0] for n in node.names)
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and any(n.name == '*' for n in node.names):
new_import = _expand_import(node, importlib.import_module(node.module), existing)
code = _replace_node(code, node, new_import)
return code
inp = """from math import *
from os import *
from random import *
def func(): return sin(pi) + path.join('a', 'b') + randint(1, 10)"""
exp = """from math import pi, sin
from os import path
from random import randint
def func(): return sin(pi) + path.join('a', 'b') + randint(1, 10)"""
test_eq(expand_wildcards(inp), exp)
inp = """from itertools import *
def func(): pass"""
test_eq(expand_wildcards(inp), inp)
inp = """def outer():
from math import *
def inner():
from os import *
return sin(pi) + path.join('a', 'b')"""
exp = """def outer():
from math import pi, sin
def inner():
from os import path
return sin(pi) + path.join('a', 'b')"""
test_eq(expand_wildcards(inp), exp)
#|export
def dict2obj(d, list_func=L, dict_func=AttrDict):
"Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`"
if isinstance(d, (L,list)): return list_func(d).map(dict2obj)
if not isinstance(d, dict): return d
return dict_func(**{k:dict2obj(v) for k,v in d.items()})
This is a convenience to give you "dotted" access to (possibly nested) dictionaries, e.g.:
d1 = dict(a=1, b=dict(c=2,d=3))
d2 = dict2obj(d1)
test_eq(d2.b.c, 2)
test_eq(d2.b['c'], 2)
It can also be used on lists of dicts.
_list_of_dicts = [d1, d1]
ds = dict2obj(_list_of_dicts)
test_eq(ds[0].b.c, 2)
#|export
def obj2dict(d):
"Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`"
if isinstance(d, (L,list)): return list(L(d).map(obj2dict))
if not isinstance(d, dict): return d
return dict(**{k:obj2dict(v) for k,v in d.items()})
`obj2dict` can be used to reverse what is done by `dict2obj`:
test_eq(obj2dict(d2), d1)
test_eq(obj2dict(ds), _list_of_dicts)
#|export
def _repr_dict(d, lvl):
if isinstance(d,dict):
its = [f"{k}: {_repr_dict(v,lvl+1)}" for k,v in d.items()]
elif isinstance(d,(list,L)): its = [_repr_dict(o,lvl+1) for o in d]
else: return str(d)
return '\n' + '\n'.join([" "*(lvl*2) + "- " + o for o in its])
#|export
def repr_dict(d):
"Print nested dicts and lists, such as returned by `dict2obj`"
return _repr_dict(d,0).strip()
print(repr_dict(d2))
- a: 1
- b: 
  - c: 2
  - d: 3
#|export
def is_listy(x):
"`isinstance(x, (tuple,list,L,slice,Generator))`"
return isinstance(x, (tuple,list,L,slice,Generator))
assert is_listy((1,))
assert is_listy([1])
assert is_listy(L([1]))
assert is_listy(slice(2))
assert not is_listy(array([1]))
#|export
def mapped(f, it):
"map `f` over `it`, unless it's not listy, in which case return `f(it)`"
return L(it).map(f) if is_listy(it) else f(it)
def _f(x,a=1): return x-a
test_eq(mapped(_f,1),0)
test_eq(mapped(_f,[1,2]),[0,1])
test_eq(mapped(_f,(1,)),(0,))
The following methods are added to the standard Python library's `pathlib.Path`.
#|export
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
"Read the content of `self`"
with self.open(encoding=encoding) as f: return f.readlines(hint)
#|export
@patch
def read_json(self:Path, encoding=None, errors=None):
"Same as `read_text` followed by `loads`"
return loads(self.read_text(encoding=encoding, errors=errors))
#|export
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
"Make all parent dirs of `self`, and write `data`"
self.parent.mkdir(exist_ok=True, parents=True, mode=mode)
self.write_text(data, encoding=encoding, errors=errors)
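Here's a small sketch combining these patches: write JSON into a nested path that doesn't exist yet, then read it back:
with tempfile.TemporaryDirectory() as d:
    p = Path(d)/'sub'/'conf.json'
    p.mk_write('{"x": 1}')  # parent dirs are created automatically
    test_eq(p.read_json(), {'x': 1})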
#|export
@patch
def relpath(self:Path, start=None):
"Same as `os.path.relpath`, but returns a `Path`, and resolves symlinks"
return Path(os.path.relpath(self.resolve(), Path(start).resolve()))
p = Path('../fastcore/').resolve()
p
Path('/Users/jhoward/Documents/GitHub/fastcore/fastcore')
p.relpath(Path.cwd())
Path('../fastcore')
#|export
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
"Contents of path as a list"
import mimetypes
extns=L(file_exts)
if file_type: extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
has_extns = len(extns)==0
res = (o for o in self.iterdir() if has_extns or o.suffix in extns)
if n_max is not None: res = itertools.islice(res, n_max)
return L(res)
We add an `ls()` method to `pathlib.Path`, which is simply defined as `list(Path.iterdir())`, mainly for convenience in REPL environments such as notebooks.
path = Path()
t = path.ls()
assert len(t)>0
t1 = path.ls(10)
test_eq(len(t1), 10)
t2 = path.ls(file_exts='.ipynb')
assert len(t)>len(t2)
t[0]
Path('000_tour.ipynb')
You can also pass an optional `file_type` MIME prefix and/or a list of file extensions.
lib_path = (path/'../fastcore')
txt_files=lib_path.ls(file_type='text')
assert len(txt_files) > 0 and txt_files[0].suffix=='.py'
ipy_files=path.ls(file_exts=['.ipynb'])
assert len(ipy_files) > 0 and ipy_files[0].suffix=='.ipynb'
txt_files[0],ipy_files[0]
(Path('../fastcore/shutil.py'), Path('000_tour.ipynb'))
#|hide
path = Path()
pkl = pickle.dumps(path)
p2 = pickle.loads(pkl)
test_eq(path.ls()[0], p2.ls()[0])
#|export
@patch
def __repr__(self:Path):
b = getattr(Path, 'BASE_PATH', None)
if b:
try: self = self.relative_to(b)
except: pass
return f"Path({self.as_posix()!r})"
fastcore also updates the `repr` of `Path` such that, if `Path.BASE_PATH` is defined, all paths are printed relative to that path (as long as they are contained in `Path.BASE_PATH`):
t = ipy_files[0].absolute()
try:
Path.BASE_PATH = t.parent.parent
test_eq(repr(t), f"Path('nbs/{t.name}')")
finally: Path.BASE_PATH = None
#|export
@patch
def delete(self:Path):
"Delete a file, symlink, or directory tree"
if not self.exists(): return
if self.is_dir():
import shutil
shutil.rmtree(self)
else: self.unlink()
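A quick check that it handles files, directory trees, and missing paths:
with tempfile.TemporaryDirectory() as d:
    p = Path(d)/'sub'
    p.mkdir()
    (p/'f.txt').write_text('x')
    p.delete()                     # removes the whole tree
    assert not p.exists()
    (Path(d)/'missing').delete()   # no error if it doesn't exist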
#|export
#|hide
class IterLen:
"Base class to add iteration to anything supporting `__len__` and `__getitem__`"
def __iter__(self): return (self[i] for i in range_of(self))
#|export
@docs
class ReindexCollection(GetAttr, IterLen):
"Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
_default='coll'
def __init__(self, coll, idxs=None, cache=None, tfm=noop):
if idxs is None: idxs = L.range(coll)
store_attr()
if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)
def _get(self, i): return self.tfm(self.coll[i])
def __getitem__(self, i): return self._get(self.idxs[i])
def __len__(self): return len(self.coll)
def reindex(self, idxs): self.idxs = idxs
def shuffle(self):
import random
random.shuffle(self.idxs)
def cache_clear(self): self._get.cache_clear()
def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
def __setstate__(self, s): self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']
_docs = dict(reindex="Replace `self.idxs` with idxs",
shuffle="Randomly shuffle indices",
cache_clear="Clear LRU cache")
show_doc(ReindexCollection, title_level=4)
This is useful when constructing batches or organizing data in a particular manner (e.g. for deep learning). This class is primarily used in organizing data for language models in fastai.
You can supply a custom index upon instantiation with the `idxs` argument, or you can call the `reindex` method to supply a new index for your collection.
Here is how you can reindex a list such that the elements are reversed:
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'], idxs=[4,3,2,1,0])
list(rc)
['e', 'd', 'c', 'b', 'a']
Alternatively, you can use the `reindex` method:
show_doc(ReindexCollection.reindex, title_level=6)
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'])
rc.reindex([4,3,2,1,0])
list(rc)
['e', 'd', 'c', 'b', 'a']
You can optionally specify an LRU cache (which uses `functools.lru_cache`) upon instantiation:
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
#trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t._get.cache_info()
CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)
You can optionally clear the LRU cache by calling the `cache_clear` method:
show_doc(ReindexCollection.cache_clear, title_level=5)
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
#trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t.cache_clear()
t._get.cache_info()
CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)
show_doc(ReindexCollection.shuffle, title_level=5)
Note that an ordered index is automatically constructed for the data structure even if one is not supplied.
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
rc.shuffle()
list(rc)
['a', 'h', 'f', 'b', 'c', 'g', 'e', 'd']
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
test_eq(list(t), range(sz))
test_eq(t[sz-1], sz-1)
test_eq(t._get.cache_info().hits, 1)
t.shuffle()
test_eq(t._get.cache_info().hits, 1)
test_ne(list(t), range(sz))
test_eq(set(t), set(range(sz)))
t.cache_clear()
test_eq(t._get.cache_info().hits, 0)
test_eq(t.count(0), 1)
#|hide
#Test ReindexCollection pickles
t1 = pickle.loads(pickle.dumps(t))
test_eq(list(t), list(t1))
#|export
def _is_type_dispatch(x): return type(x).__name__ == "TypeDispatch"
def _unwrapped_type_dispatch_func(x): return x.first() if _is_type_dispatch(x) else x
def _is_property(x): return type(x)==property
def _has_property_getter(x): return _is_property(x) and hasattr(x, 'fget') and hasattr(x.fget, 'func')
def _property_getter(x): return x.fget.func if _has_property_getter(x) else x
def _unwrapped_func(x):
x = _unwrapped_type_dispatch_func(x)
x = _property_getter(x)
return x
def get_source_link(func):
"Return link to `func` in source code"
import inspect
func = _unwrapped_func(func)
try: line = inspect.getsourcelines(func)[1]
except Exception: return ''
mod = inspect.getmodule(func)
module = mod.__name__.replace('.', '/') + '.py'
try:
nbdev_mod = import_module(mod.__package__.split('.')[0] + '._nbdev')
return f"{nbdev_mod.git_url}{module}#L{line}"
except: return f"{module}#L{line}"
`get_source_link` allows you to get a link to source code related to an object. For nbdev-related projects such as fastcore, we can get the full link to a GitHub repo. For nbdev projects, be sure to properly set the `git_url` in `settings.ini` (derived from `lib_name` and `branch`, on top of the prefix you will need to adapt) so that those links are correct.
For example, below we get the link to `fastcore.test.test_eq`:
from fastcore.test import test_eq
assert 'fastcore/test.py' in get_source_link(test_eq)
assert get_source_link(test_eq).startswith('https://github.com/fastai/fastcore')
get_source_link(test_eq)
'https://github.com/fastai/fastcore/tree/master/fastcore/test.py#L35'
#|export
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
"Truncate `s` to length `maxlen`, adding suffix `suf` if truncated"
return s[:maxlen-len(suf)]+suf if len(s)+len(space)>maxlen else s+space
w = 'abacadabra'
test_eq(truncstr(w, 10), w)
test_eq(truncstr(w, 5), 'abac…')
test_eq(truncstr(w, 5, suf=''), 'abaca')
test_eq(truncstr(w, 11, space='_'), w+"_")
test_eq(truncstr(w, 10, space='_'), w[:-1]+'…')
test_eq(truncstr(w, 5, suf='!!'), 'aba!!')
#|export
spark_chars = '▁▂▃▅▆▇'
#|export
def _ceil(x, lim=None): return x if (not lim or x <= lim) else lim
def _sparkchar(x, mn, mx, incr, empty_zero):
if x is None or (empty_zero and not x): return ' '
if incr == 0: return spark_chars[0]
res = int((_ceil(x,mx)-mn)/incr-0.5)
return spark_chars[res]
#|export
def sparkline(data, mn=None, mx=None, empty_zero=False):
"Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
valid = [o for o in data if o is not None]
if not valid: return ' '
mn,mx,n = ifnone(mn,min(valid)),ifnone(mx,max(valid)),len(spark_chars)
res = [_sparkchar(x=o, mn=mn, mx=mx, incr=(mx-mn)/n, empty_zero=empty_zero) for o in data]
return ''.join(res)
data = [9,6,None,1,4,0,8,15,10]
print(f'without "empty_zero": {sparkline(data, empty_zero=False)}')
print(f' with "empty_zero": {sparkline(data, empty_zero=True )}')
without "empty_zero": ▅▂ ▁▂▁▃▇▅ with "empty_zero": ▅▂ ▁▂ ▃▇▅
You can set a maximum and minimum for the y-axis of the sparkline with the arguments `mn` and `mx` respectively:
sparkline([1,2,3,400], mn=0, mx=3)
'▂▅▇▇'
#|export
def modify_exception(
e:Exception, # An exception
msg:str=None, # A custom message
replace:bool=False, # Whether to replace e.args with [msg]
) -> Exception:
"Modifies `e` with a custom message attached"
e.args = [f'{e.args[0]} {msg}'] if not replace and len(e.args) > 0 else [msg]
return e
msg = "This is my custom message!"
test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception(), None)), contains='')
test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception(), msg)), contains=msg)
test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception("The first message"), msg)), contains="The first message This is my custom message!")
test_fail(lambda: (_ for _ in ()).throw(modify_exception(Exception("The first message"), msg, True)), contains="This is my custom message!")
#|export
def round_multiple(x, mult, round_down=False):
"Round `x` to nearest multiple of `mult`"
def _f(x_): return (int if round_down else round)(x_/mult)*mult
res = L(x).map(_f)
return res if is_listy(x) else res[0]
test_eq(round_multiple(63,32), 64)
test_eq(round_multiple(50,32), 64)
test_eq(round_multiple(40,32), 32)
test_eq(round_multiple( 0,32), 0)
test_eq(round_multiple(63,32, round_down=True), 32)
test_eq(round_multiple((63,40),32), (64,32))
#|export
def set_num_threads(nt):
"Get numpy (and others) to use `nt` threads"
try: import mkl; mkl.set_num_threads(nt)
except: pass
try: import torch; torch.set_num_threads(nt)
except: pass
os.environ['IPC_ENABLE']='1'
for o in ['OPENBLAS_NUM_THREADS','NUMEXPR_NUM_THREADS','OMP_NUM_THREADS','MKL_NUM_THREADS']:
os.environ[o] = str(nt)
This sets the number of threads consistently for many tools, by:

1. Setting the following environment variables equal to `nt`: `OPENBLAS_NUM_THREADS`, `NUMEXPR_NUM_THREADS`, `OMP_NUM_THREADS`, `MKL_NUM_THREADS`
2. Setting `nt` threads for numpy and pytorch.
#|export
def join_path_file(file, path, ext=''):
"Return `path/file` if file is a string or a `Path`, file otherwise"
if not isinstance(file, (str, Path)): return file
path.mkdir(parents=True, exist_ok=True)
return path/f'{file}{ext}'
path = Path.cwd()/'_tmp'/'tst'
f = join_path_file('tst.txt', path)
assert path.exists()
test_eq(f, path/'tst.txt')
with open(f, 'w') as f_: assert join_path_file(f_, path) == f_
shutil.rmtree(Path.cwd()/'_tmp')
#|export
def autostart(g):
"Decorator that automatically starts a generator"
@functools.wraps(g)
def f():
r = g()
next(r)
return r
return f
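For example, a coroutine decorated with `autostart` is primed for you, so you can `send` to it immediately; here's a minimal sketch:
@autostart
def _printer():
    while True:
        x = yield
        print('got', x)

p = _printer()   # no initial `next(p)` needed
p.send('hello')
got hello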
#|export
class EventTimer:
"An event timer with history of `store` items of time `span`"
def __init__(self, store=5, span=60):
import collections
self.hist,self.span,self.last = collections.deque(maxlen=store),span,perf_counter()
self._reset()
def _reset(self): self.start,self.events = self.last,0
def add(self, n=1):
"Record `n` events"
if self.duration>self.span:
self.hist.append(self.freq)
self._reset()
self.events +=n
self.last = perf_counter()
@property
def duration(self): return perf_counter()-self.start
@property
def freq(self): return self.events/self.duration
show_doc(EventTimer, title_level=4)
Add events with `add`, and get the number of `events` and their frequency (`freq`).
# Random wait function for testing
def _randwait(): yield from (sleep(random.random()/200) for _ in range(100))
c = EventTimer(store=5, span=0.03)
for o in _randwait(): c.add(1)
print(f'Num Events: {c.events}, Freq/sec: {c.freq:.01f}')
print('Most recent: ', sparkline(c.hist), *L(c.hist).map('{:.01f}'))
Num Events: 3, Freq/sec: 316.2
Most recent:  ▇▁▂▃▁ 288.7 227.7 246.5 256.5 217.9
#|export
_fmt = string.Formatter()
#|export
def stringfmt_names(s:str)->list:
"Unique brace-delimited names in `s`"
return uniqueify(o[1] for o in _fmt.parse(s) if o[1])
s = '/pulls/{pull_number}/reviews/{review_id}'
test_eq(stringfmt_names(s), ['pull_number','review_id'])
#|export
class PartialFormatter(string.Formatter):
"A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
def __init__(self):
self.missing = set()
super().__init__()
def get_field(self, nm, args, kwargs):
try: return super().get_field(nm, args, kwargs)
except KeyError:
self.missing.add(nm)
return '{'+nm+'}',nm
def check_unused_args(self, used, args, kwargs):
self.xtra = filter_keys(kwargs, lambda o: o not in used)
show_doc(PartialFormatter, title_level=4)
#|export
def partial_format(s:str, **kwargs):
"string format `s`, ignoring missing field errors, returning missing and extra fields"
fmt = PartialFormatter()
res = fmt.format(s, **kwargs)
return res,list(fmt.missing),fmt.xtra
The result is a tuple of `(formatted_string,missing_fields,extra_fields)`, e.g.:
res,missing,xtra = partial_format(s, pull_number=1, foo=2)
test_eq(res, '/pulls/1/reviews/{review_id}')
test_eq(missing, ['review_id'])
test_eq(xtra, {'foo':2})
#|export
def utc2local(dt:datetime)->datetime:
"Convert `dt` from UTC to local time"
return dt.replace(tzinfo=timezone.utc).astimezone(tz=None)
dt = datetime(2000,1,1,12)
print(f'{dt} UTC is {utc2local(dt)} local time')
2000-01-01 12:00:00 UTC is 2000-01-01 22:00:00+10:00 local time
#|export
def local2utc(dt:datetime)->datetime:
"Convert `dt` from local to UTC time"
return dt.replace(tzinfo=None).astimezone(tz=timezone.utc)
print(f'{dt} local is {local2utc(dt)} UTC time')
2000-01-01 12:00:00 local is 2000-01-01 02:00:00+00:00 UTC time
#|export
def trace(f):
"Add `set_trace` to an existing function `f`"
from pdb import set_trace
if getattr(f, '_traced', False): return f
def _inner(*args,**kwargs):
set_trace()
return f(*args,**kwargs)
_inner._traced = True
return _inner
You can add a breakpoint to an existing function, e.g.:
Path.cwd = trace(Path.cwd)
Path.cwd()
Now, when the function is called it will drop you into the debugger. Note: you must issue the `s` command when you begin, to step into the function that is being traced.
#|export
@contextmanager
def modified_env(*delete, **replace):
"Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
prev = dict(os.environ)
try:
os.environ.update(replace)
for k in delete: os.environ.pop(k, None)
yield
finally:
os.environ.clear()
os.environ.update(prev)
# USER isn't in Cloud Linux Environments
env_test = 'USERNAME' if sys.platform == "win32" else 'SHELL'
oldusr = os.environ[env_test]
replace_param = {env_test: 'a'}
with modified_env('PATH', **replace_param):
test_eq(os.environ[env_test], 'a')
assert 'PATH' not in os.environ
assert 'PATH' in os.environ
test_eq(os.environ[env_test], oldusr)
#|export
class ContextManagers(GetAttr):
"Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
def __init__(self, mgrs): self.default,self.stack = L(mgrs),ExitStack()
def __enter__(self): self.default.map(self.stack.enter_context)
def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs)
show_doc(ContextManagers, title_level=4)
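For example, a sketch with two trivial context managers: they're entered in order and exited in reverse, just as with nested `with` blocks:
@contextmanager
def _ctx(nm):
    print('enter', nm)
    yield
    print('exit', nm)

with ContextManagers([_ctx('a'), _ctx('b')]): print('body')
enter a
enter b
body
exit b
exit a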
#|export
def shufflish(x, pct=0.04):
"Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
n = len(x)
import random
return L(x[i] for i in sorted(range_of(x), key=lambda o: o+n*(1+random.random()*pct)))
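For example, with the default `pct=0.04`, every item stays within `len(x)*pct` positions of where it started, and no items are lost:
t = list(range(100))
s = shufflish(t)
assert all(abs(i-v) < 100*0.04 for i,v in enumerate(s))
test_eq(sorted(s), t)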
#|export
def console_help(
libname:str): # name of library for console script listing
"Show help for all console scripts from `libname`"
from fastcore.style import S
from pkg_resources import iter_entry_points as ep
for e in ep('console_scripts'):
if e.module_name == libname or e.module_name.startswith(libname+'.'):
nm = S.bold.light_blue(e.name)
print(f'{nm:45}{e.load().__doc__}')
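For example, since `nbdev` is installed in this environment, we can list its console scripts (the exact output depends on the installed version):
console_help('nbdev')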
#| export
def hl_md(s, lang='xml', show=True):
"Syntax highlight `s` using `lang`."
md = f'```{lang}\n{s}\n```'
if not show: return md
try:
from IPython import display
return display.Markdown(md)
except ImportError: print(s)
When we display code in a notebook, it's nice to highlight it, so we create a function to simplify that:
hl_md('<test><xml foo="bar">a child</xml></test>')
<test><xml foo="bar">a child</xml></test>
#| export
def type2str(typ:type)->str:
"Stringify `typ`"
if typ is None or typ is NoneType: return 'None'
if hasattr(typ, '__origin__'):
args = ", ".join(type2str(arg) for arg in typ.__args__)
if typ.__origin__ is Union: return f"Union[{args}]"
return f"{typ.__origin__.__name__}[{args}]"
elif isinstance(typ, type): return typ.__name__
return str(typ)
test_eq(type2str(Optional[float]), 'Union[float, None]')
#| export
def dataclass_src(cls):
import dataclasses
src = f"@dataclass\nclass {cls.__name__}:\n"
for f in dataclasses.fields(cls):
d = "" if f.default is dataclasses.MISSING else f" = {f.default!r}"
src += f" {f.name}: {type2str(f.type)}{d}\n"
return src
from dataclasses import make_dataclass, dataclass
DC = make_dataclass('DC', [('x', int), ('y', Optional[float], None), ('z', float, None)])
print(dataclass_src(DC))
@dataclass
class DC:
    x: int
    y: Union[float, None] = None
    z: float = None
#| export
def nullable_dc(cls):
"Like `dataclass`, but default of `None` added to fields without defaults"
from dataclasses import dataclass, field
for k,v in get_annotations_ex(cls)[0].items():
if not hasattr(cls,k): setattr(cls, k, field(default=None))
return dataclass(cls)
@nullable_dc
class Person: name: str; age: int; city: str = "Unknown"
Person(name="Bob")
Person(name='Bob', age=None, city='Unknown')
#| export
def make_nullable(clas):
from dataclasses import dataclass, fields, MISSING
if hasattr(clas, '_nullable'): return
clas._nullable = True
original_init = clas.__init__
def __init__(self, *args, **kwargs):
flds = fields(clas)
dargs = {k.name:v for k,v in zip(flds, args)}
for f in flds:
nm = f.name
if nm not in dargs and nm not in kwargs and f.default is None and f.default_factory is MISSING:
kwargs[nm] = None
original_init(self, *args, **kwargs)
clas.__init__ = __init__
for f in fields(clas):
if f.default is MISSING and f.default_factory is MISSING: f.default = None
return clas
@dataclass
class Person: name: str; age: int; city: str = "Unknown"
make_nullable(Person)
Person("Bob", city='NY')
Person(name='Bob', age=None, city='NY')
Person(name="Bob")
Person(name='Bob', age=None, city='Unknown')
Person("Bob", 34)
Person(name='Bob', age=34, city='Unknown')
#| export
def mk_dataclass(cls):
from dataclasses import dataclass, field, is_dataclass, MISSING
if is_dataclass(cls): return make_nullable(cls)
for k,v in get_annotations_ex(cls)[0].items():
if not hasattr(cls,k) or getattr(cls,k) is MISSING:
setattr(cls, k, field(default=None))
dataclass(cls, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)
class Person: name: str; age: int; city: str = "Unknown"
mk_dataclass(Person)
Person(name="Bob")
Person(name='Bob', age=None, city='Unknown')
#| export
def flexicache(*funcs, maxsize=128):
"Like `lru_cache`, but customisable with policy `funcs`"
import asyncio
def _f(func):
cache,states = {}, [None]*len(funcs)
def _cache_logic(key, execute_func):
if key in cache:
result,states = cache[key]
if not any(f(state) for f,state in zip(funcs, states)):
cache[key] = cache.pop(key)
return result
del cache[key]
try: newres = execute_func()
except:
if key not in cache: raise
cache[key] = cache.pop(key)
return result
cache[key] = (newres, [f(None) for f in funcs])
if len(cache) > maxsize: cache.popitem()
return newres
@wraps(func)
def wrapper(*args, **kwargs):
return _cache_logic(f"{args} // {kwargs}", lambda: func(*args, **kwargs))
@wraps(func)
async def async_wrapper(*args, **kwargs):
return await _cache_logic(f"{args} // {kwargs}", lambda: asyncio.ensure_future(func(*args, **kwargs)))
return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper
return _f
This is a flexible lru cache function that you can pass a list of functions to. Those functions define the cache eviction policy. For instance, `time_policy` is provided for time-based cache eviction, and `mtime_policy` evicts based on a file's modified-time changing. Each policy function is passed the last value it returned (initially `None`), and returns a new value to indicate that the cache has expired. When the cache expires, all the policy functions are called with `None` to force getting new values.
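As a sketch of that protocol, here's a hypothetical policy (not part of fastcore) that expires entries whenever a global generation counter is bumped:
_gen = 0
def bump_gen():
    global _gen
    _gen += 1

def gen_policy():
    "Hypothetical `flexicache` policy: expire when `bump_gen` has been called"
    def policy(last): return _gen if last is None or _gen > last else None
    return policy

calls = 0
@flexicache(gen_policy())
def _cached(x):
    global calls
    calls += 1
    return x*2

_cached(1); _cached(1)
test_eq(calls, 1)   # second call hit the cache
bump_gen()
_cached(1)
test_eq(calls, 2)   # cache expired after the bump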
#| export
def time_policy(seconds):
"A `flexicache` policy that expires cached items after `seconds` have passed"
def policy(last_time):
now = time()
return now if last_time is None or now-last_time>seconds else None
return policy
#| export
def mtime_policy(filepath):
"A `flexicache` policy that expires cached items after `filepath` modified-time changes"
def policy(mtime):
current_mtime = getmtime(filepath)
return current_mtime if mtime is None or current_mtime>mtime else None
return policy
@flexicache(time_policy(10), mtime_policy('000_tour.ipynb'))
def cached_func(x, y): return x+y
cached_func(1,2)
3
@flexicache(time_policy(10), mtime_policy('000_tour.ipynb'))
async def cached_func(x, y): return x+y
await cached_func(1,2)
await cached_func(1,2)
3
#| export
def timed_cache(seconds=60, maxsize=128):
"Like `lru_cache`, but also with time-based eviction"
return flexicache(time_policy(seconds), maxsize=maxsize)
This function is a small convenience wrapper for using `flexicache` with `time_policy`.
@timed_cache(seconds=0.05, maxsize=2)
def cached_func(x): return x * 2, time()
# basic caching
result1, time1 = cached_func(2)
test_eq(result1, 4)
sleep(0.001)
result2, time2 = cached_func(2)
test_eq(result2, 4)
test_eq(time1, time2)
# caching different values
result3, _ = cached_func(3)
test_eq(result3, 6)
# maxsize
_, time4 = cached_func(4)
_, time2_new = cached_func(2)
test_close(time2, time2_new, eps=0.1)
_, time3_new = cached_func(3)
test_ne(time3_new, time())
# time expiration
sleep(0.05)
_, time4_new = cached_func(4)
test_ne(time4_new, time())
#|hide
import nbdev; nbdev.nbdev_export()