#default_exp xtras
#export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.basics import *
from functools import wraps
import mimetypes,pickle,random,json,subprocess,shlex,bz2,gzip,zipfile,tarfile
import imghdr,struct,distutils.util,tempfile,time,string,collections
from contextlib import contextmanager,ExitStack
from pdb import set_trace
from datetime import datetime, timezone
from timeit import default_timer
from fastcore.test import *
from nbdev.showdoc import *
from fastcore.nb_imports import *
from time import sleep
Utility functions used in the fastai library
#export
def dict2obj(d):
    "Convert (possibly nested) dicts (or lists of dicts) to `AttrDict`"
    if isinstance(d, (L, list)): return L(d).map(dict2obj)
    if isinstance(d, dict):
        # rebuild the dict with every value converted recursively
        return AttrDict(**{key: dict2obj(val) for key, val in d.items()})
    return d  # leaf value: leave untouched
This is a convenience to give you "dotted" access to (possibly nested) dictionaries, e.g:
d1 = dict(a=1, b=dict(c=2,d=3))
d2 = dict2obj(d1)
test_eq(d2.b.c, 2)
test_eq(d2.b['c'], 2)
It can also be used on lists of dicts.
_list_of_dicts = [d1, d1]
ds = dict2obj(_list_of_dicts)
test_eq(ds[0].b.c, 2)
#export
def obj2dict(d):
    "Convert (possibly nested) AttrDicts (or lists of AttrDicts) to `dict`"
    if isinstance(d, (L, list)): return list(L(d).map(obj2dict))
    if isinstance(d, dict):
        # rebuild as a plain dict, converting every value recursively
        return dict(**{key: obj2dict(val) for key, val in d.items()})
    return d  # leaf value: leave untouched
obj2dict
can be used to reverse what is done by dict2obj
:
test_eq(obj2dict(d2), d1)
test_eq(obj2dict(ds), _list_of_dicts)
#export
def _repr_dict(d, lvl):
    "Recursive helper for `repr_dict`: render `d` indented two spaces per nesting level"
    if isinstance(d, dict):
        items = [f"{key}: {_repr_dict(val, lvl+1)}" for key, val in d.items()]
    elif isinstance(d, (list, L)):
        items = [_repr_dict(item, lvl+1) for item in d]
    else:
        return str(d)  # scalar leaf
    indent = " " * (lvl*2)
    return '\n' + '\n'.join(indent + "- " + item for item in items)
#export
def repr_dict(d):
    "Print nested dicts and lists, such as returned by `dict2obj`"
    rendered = _repr_dict(d, 0)
    return rendered.strip()  # drop the leading newline the helper emits
print(repr_dict(d2))
- a: 1 - b: - c: 2 - d: 3
repr_dict
is used to display AttrDict
both with repr
and in Jupyter Notebooks:
#export
@patch
def __repr__(self:AttrDict): return repr_dict(self)

# Reuse the nested-list rendering as markdown so Jupyter shows the same view
AttrDict._repr_markdown_ = AttrDict.__repr__
print(repr(d2))
- a: 1 - b: - c: 2 - d: 3
d2
#export
def is_listy(x):
    "`isinstance(x, (tuple,list,L,slice,Generator))`"
    listy_types = (tuple, list, L, slice, Generator)
    return isinstance(x, listy_types)
assert is_listy((1,))
assert is_listy([1])
assert is_listy(L([1]))
assert is_listy(slice(2))
assert not is_listy(array([1]))
#export
def shufflish(x, pct=0.04):
    "Randomly relocate items of `x` up to `pct` of `len(x)` from their starting location"
    n = len(x)
    # Sort by the original index plus a random jitter bounded by `pct*n`;
    # the constant `n` offset doesn't change the ordering
    def _key(i): return i + n*(1 + random.random()*pct)
    return L(x[i] for i in sorted(range_of(x), key=_key))
#export
def mapped(f, it):
    "map `f` over `it`, unless it's not listy, in which case return `f(it)`"
    if is_listy(it): return L(it).map(f)
    return f(it)  # scalar: apply directly
def _f(x,a=1): return x-a
test_eq(mapped(_f,1),0)
test_eq(mapped(_f,[1,2]),[0,1])
test_eq(mapped(_f,(1,)),(0,))
#export
#hide
class IterLen:
    "Base class to add iteration to anything supporting `__len__` and `__getitem__`"
    def __iter__(self):
        # each call yields a fresh generator over all valid indices
        yield from (self[idx] for idx in range_of(self))
#export
@docs
class ReindexCollection(GetAttr, IterLen):
    "Reindexes collection `coll` with indices `idxs` and optional LRU cache of size `cache`"
    _default='coll'  # unknown attribute lookups delegate to `self.coll` via `GetAttr`
    def __init__(self, coll, idxs=None, cache=None, tfm=noop):
        if idxs is None: idxs = L.range(coll)  # identity mapping by default
        store_attr()
        # Wrap the bound `_get` so repeated lookups of the same raw index hit the cache
        if cache is not None: self._get = functools.lru_cache(maxsize=cache)(self._get)
    def _get(self, i): return self.tfm(self.coll[i])
    def __getitem__(self, i): return self._get(self.idxs[i])  # translate through `idxs` first
    def __len__(self): return len(self.coll)
    def reindex(self, idxs): self.idxs = idxs
    def shuffle(self): random.shuffle(self.idxs)
    def cache_clear(self): self._get.cache_clear()
    # The lru_cache wrapper isn't picklable, so pickle only the raw attributes.
    # NOTE(review): `__setstate__` does not re-apply the lru_cache wrapper even
    # when `cache` was set — confirm whether unpickled instances should re-cache.
    def __getstate__(self): return {'coll': self.coll, 'idxs': self.idxs, 'cache': self.cache, 'tfm': self.tfm}
    def __setstate__(self, s): self.coll,self.idxs,self.cache,self.tfm = s['coll'],s['idxs'],s['cache'],s['tfm']
    _docs = dict(reindex="Replace `self.idxs` with idxs",
                shuffle="Randomly shuffle indices",
                cache_clear="Clear LRU cache")
show_doc(ReindexCollection, title_level=4)
This is useful when constructing batches or organizing data in a particular manner (e.g. for deep learning). This class is primarily used in organizing data for language models in fastai.
You can supply a custom index upon instantiation with the idxs
argument, or you can call the reindex
method to supply a new index for your collection.
Here is how you can reindex a list such that the elements are reversed:
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'], idxs=[4,3,2,1,0])
list(rc)
['e', 'd', 'c', 'b', 'a']
Alternatively, you can use the reindex
method:
show_doc(ReindexCollection.reindex, title_level=6)
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e'])
rc.reindex([4,3,2,1,0])
list(rc)
['e', 'd', 'c', 'b', 'a']
You can optionally specify a LRU cache, which uses functools.lru_cache upon instantiation:
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
#trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t._get.cache_info()
CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)
You can optionally clear the LRU cache by calling the cache_clear
method:
show_doc(ReindexCollection.cache_clear, title_level=5)
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
#trigger a cache hit by indexing into the same element multiple times
t[0], t[0]
t.cache_clear()
t._get.cache_info()
CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)
show_doc(ReindexCollection.shuffle, title_level=5)
Note that an ordered index is automatically constructed for the data structure even if one is not supplied.
rc=ReindexCollection(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
rc.shuffle()
list(rc)
['c', 'f', 'e', 'g', 'h', 'b', 'd', 'a']
sz = 50
t = ReindexCollection(L.range(sz), cache=2)
test_eq(list(t), range(sz))
test_eq(t[sz-1], sz-1)
test_eq(t._get.cache_info().hits, 1)
t.shuffle()
test_eq(t._get.cache_info().hits, 1)
test_ne(list(t), range(sz))
test_eq(set(t), set(range(sz)))
t.cache_clear()
test_eq(t._get.cache_info().hits, 0)
test_eq(t.count(0), 1)
#hide
#Test ReindexCollection pickles
t1 = pickle.loads(pickle.dumps(t))
test_eq(list(t), list(t1))
Utilities (other than extensions to Pathlib.Path) for dealing with IO.
# export
@contextmanager
def maybe_open(f, mode='r', **kwargs):
    "Context manager: open `f` if it is a path (and close on exit)"
    if isinstance(f, (str, os.PathLike)):
        with open(f, mode, **kwargs) as opened: yield opened
    else:
        # already a file-like object: hand it back untouched, caller keeps ownership
        yield f
This is useful for functions where you want to accept a path or file. maybe_open
will not close your file handle if you pass one in.
def _f(fn):
with maybe_open(fn) as f: return f.encoding
fname = '00_test.ipynb'
sys_encoding = 'cp1252' if sys.platform == 'win32' else 'UTF-8'
test_eq(_f(fname), sys_encoding)
with open(fname) as fh: test_eq(_f(fh), sys_encoding)
For example, we can use this to reimplement imghdr.what
from the Python standard library, which is written in Python 3.9 as:
def what(file, h=None):
    "Reference copy of `imghdr.what` from the Python 3.9 stdlib, shown for comparison"
    f = None
    try:
        if h is None:
            if isinstance(file, (str,os.PathLike)):
                f = open(file, 'rb')
                h = f.read(32)
            else:
                # file-like object: peek at the header, then restore the position
                location = file.tell()
                h = file.read(32)
                file.seek(location)
        # first registered test that recognizes the header wins
        for tf in imghdr.tests:
            res = tf(h, f)
            if res: return res
    finally:
        if f: f.close()  # only close handles we opened ourselves
    return None
Here's an example of the use of this function:
fname = 'images/puppy.jpg'
what(fname)
'jpeg'
With maybe_open
, Self
, and L.map_first
, we can rewrite this in a much more concise and (in our opinion) clear way:
def what(file, h=None):
    "Concise reimplementation of `imghdr.what` using `maybe_open`, `Self` and `L.map_first`"
    if h is None:
        # `peek` reads without consuming, so the caller's stream position is preserved
        # (requires a buffered binary stream — TODO confirm for all callers)
        with maybe_open(file, 'rb') as f: h = f.peek(32)
    # return the result of the first imghdr test that matches the header
    return L(imghdr.tests).map_first(Self(h,file))
...and we can check that it still works:
test_eq(what(fname), 'jpeg')
...along with the version passing a file handle:
with open(fname,'rb') as f: test_eq(what(f), 'jpeg')
...along with the h
parameter version:
with open(fname,'rb') as f: test_eq(what(None, h=f.read(32)), 'jpeg')
def _jpg_size(f):
size,ftype = 2,0
while not 0xc0 <= ftype <= 0xcf:
f.seek(size, 1)
byte = f.read(1)
while ord(byte) == 0xff: byte = f.read(1)
ftype = ord(byte)
size = struct.unpack('>H', f.read(2))[0] - 2
f.seek(1, 1) # `precision'
h,w = struct.unpack('>HH', f.read(4))
return w,h
def _gif_size(f): return struct.unpack('<HH', head[6:10])
def _png_size(f):
assert struct.unpack('>i', head[4:8])[0]==0x0d0a1a0a
return struct.unpack('>ii', head[16:24])
#export
def image_size(fn):
    "Tuple of (w,h) for png, gif, or jpg; `None` otherwise"
    # NOTE(review): `imghdr` is deprecated and removed in Python 3.13 (PEP 594);
    # a replacement detection scheme will be needed on newer Pythons.
    d = dict(png=_png_size, gif=_gif_size, jpeg=_jpg_size)
    # NOTE(review): for formats other than png/gif/jpeg this raises KeyError
    # rather than returning `None` as the docstring suggests — confirm intent.
    with maybe_open(fn, 'rb') as f: return d[imghdr.what(f)](f)
test_eq(image_size(fname), (1200,803))
#export
def bunzip(fn):
    "bunzip `fn`, raising exception if output already exists"
    fn = Path(fn)
    assert fn.exists(), f"{fn} doesn't exist"
    out_fn = fn.with_suffix('')  # e.g. "x.txt.bz2" -> "x.txt"
    assert not out_fn.exists(), f"{out_fn} already exists"
    with bz2.BZ2File(fn, 'rb') as src, out_fn.open('wb') as dst:
        # stream in 1MB chunks so large archives aren't loaded into memory at once
        while True:
            chunk = src.read(1024*1024)
            if not chunk: break
            dst.write(chunk)
f = Path('files/test.txt')
if f.exists(): f.unlink()
bunzip('files/test.txt.bz2')
t = f.open().readlines()
test_eq(len(t),1)
test_eq(t[0], 'test\n')
f.unlink()
#export
def join_path_file(file, path, ext=''):
    "Return `path/file` if file is a string or a `Path`, file otherwise"
    if isinstance(file, (str, Path)):
        path.mkdir(parents=True, exist_ok=True)  # ensure the target directory exists
        return path/f'{file}{ext}'
    return file  # already a file object (or similar): pass straight through
path = Path.cwd()/'_tmp'/'tst'
f = join_path_file('tst.txt', path)
assert path.exists()
test_eq(f, path/'tst.txt')
with open(f, 'w') as f_: assert join_path_file(f_, path) == f_
shutil.rmtree(Path.cwd()/'_tmp')
#export
def loads(s, cls=None, object_hook=None, parse_float=None,
          parse_int=None, parse_constant=None, object_pairs_hook=None, **kw):
    "Same as `json.loads`, but handles `None`"
    if not s: return {}  # `None` or empty string decodes to an empty dict
    opts = dict(cls=cls, object_hook=object_hook, parse_float=parse_float, parse_int=parse_int,
                parse_constant=parse_constant, object_pairs_hook=object_pairs_hook)
    return json.loads(s, **opts, **kw)
#export
def loads_multi(s:str):
    "Generator of >=0 decoded json dicts, possibly with non-json ignored text at start and end"
    decoder = json.JSONDecoder()
    start = s.find('{')
    while start >= 0:
        s = s[start:]                         # skip any non-json prefix
        obj, consumed = decoder.raw_decode(s)
        if not consumed: raise ValueError(f'no JSON object found at {consumed}')
        yield obj
        s = s[consumed:]                      # continue after the decoded object
        start = s.find('{')
tst = """
# ignored
{ "a":1 }
hello
{
"b":2
}
"""
test_eq(list(loads_multi(tst)), [{'a': 1}, {'b': 2}])
#export
def untar_dir(file, dest):
    "Extract a gzipped tarball from file object `file`, renaming its top-level entry to `dest`"
    with tempfile.TemporaryDirectory(dir='.') as d:
        d = Path(d)
        # NOTE(review): `extractall` on untrusted archives is vulnerable to path
        # traversal — consider validating member names if input can be untrusted.
        with tarfile.open(mode='r:gz', fileobj=file) as t: t.extractall(d)
        # assumes the archive contains exactly one top-level entry — TODO confirm;
        # extra entries would be silently discarded with the temp dir
        next(d.iterdir()).rename(dest)
#export
def repo_details(url):
    "Tuple of `owner,name` from ssh or https git repo `url`"
    cleaned = remove_suffix(url.strip(), '.git')
    # ssh urls look like "git@host:owner/name" — keep only the part after the colon
    path = cleaned.split(':')[-1]
    # the final two path segments are the owner and the repo name
    return path.split('/')[-2:]
test_eq(repo_details('https://github.com/fastai/fastai.git'), ['fastai', 'fastai'])
test_eq(repo_details('git@github.com:fastai/nbdev.git\n'), ['fastai', 'nbdev'])
#export
def run(cmd, *rest, ignore_ex=False, as_bytes=False, stderr=False):
    "Pass `cmd` (splitting with `shlex` if string) to `subprocess.run`; return `stdout`; raise `IOError` if fails"
    # accept extra positional args, or a single shell-style string to split
    if rest: cmd = (cmd,) + rest
    elif isinstance(cmd, str): cmd = shlex.split(cmd)
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out = proc.stdout
    if stderr and proc.stderr: out = out + b' ;; ' + proc.stderr
    if not as_bytes: out = out.decode().strip()
    if ignore_ex: return (proc.returncode, out)  # caller handles failures itself
    if proc.returncode != 0: raise IOError(out)
    return out
You can pass a string (which will be split based on standard shell rules), a list, or pass args directly:
if sys.platform == 'win32':
assert 'ipynb' in run('cmd /c dir /p')
assert 'ipynb' in run(['cmd', '/c', 'dir', '/p'])
assert 'ipynb' in run('cmd', '/c', 'dir', '/p')
else:
assert 'ipynb' in run('ls -ls')
assert 'ipynb' in run(['ls', '-l'])
assert 'ipynb' in run('ls', '-l')
Some commands fail in non-error situations, like grep
. Use ignore_ex
in those cases, which will return a tuple of stdout and returncode:
if sys.platform == 'win32':
test_eq(run('cmd /c findstr asdfds 00_test.ipynb', ignore_ex=True)[0], 1)
else:
test_eq(run('grep asdfds 00_test.ipynb', ignore_ex=True)[0], 1)
run
automatically decodes returned bytes to a str
. Use as_bytes
to skip that:
if sys.platform == 'win32':
# `as_bytes` is skipped on Windows because `nbdev_clean_nbs` rewrites the `\n` escape in the expected output on every save
test_eq(run('cmd /c echo hi'), 'hi')
else:
test_eq(run('echo hi', as_bytes=True), b'hi\n')
#export
def open_file(fn, mode='r', **kwargs):
    "Open a file, with optional compression if gz or bz2 suffix"
    if isinstance(fn, io.IOBase): return fn  # already an open file: pass through
    fn = Path(fn)
    # dispatch on suffix; anything unrecognized falls back to a plain `open`
    opener = {'.bz2': bz2.BZ2File, '.gz': gzip.GzipFile, '.zip': zipfile.ZipFile}.get(fn.suffix, open)
    return opener(fn, mode, **kwargs)
#export
def save_pickle(fn, o):
    "Save a pickle file, to a file name or opened file"
    with open_file(fn, 'wb') as f: pickle.dump(o, f)
#export
def load_pickle(fn):
    "Load a pickle file from a file name or opened file"
    with open_file(fn, 'rb') as f: return pickle.load(f)
for suf in '.pkl','.bz2','.gz':
# delete=False is added for Windows. https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
with tempfile.NamedTemporaryFile(suffix=suf, delete=False) as f:
fn = Path(f.name)
save_pickle(fn, 't')
t = load_pickle(fn)
f.close()
test_eq(t,'t')
The following methods are added to the standard Python library's `pathlib.Path`.
#export
@patch
def readlines(self:Path, hint=-1, encoding='utf8'):
    "Read the content of `self`"
    with self.open(encoding=encoding) as stream:
        return stream.readlines(hint)
#export
@patch
def read_json(self:Path, encoding=None, errors=None):
    "Same as `read_text` followed by `loads`"
    raw = self.read_text(encoding=encoding, errors=errors)
    return loads(raw)
#export
@patch
def mk_write(self:Path, data, encoding=None, errors=None, mode=511):
    "Make all parent dirs of `self`, and write `data`"
    self.parent.mkdir(parents=True, exist_ok=True, mode=mode)  # 511 == 0o777, limited by the umask
    self.write_text(data, encoding=encoding, errors=errors)
#export
@patch
def ls(self:Path, n_max=None, file_type=None, file_exts=None):
    "Contents of path as a list"
    extns = L(file_exts)
    if file_type:
        # expand a MIME prefix (e.g. "text") into all matching extensions
        extns += L(k for k,v in mimetypes.types_map.items() if v.startswith(file_type+'/'))
    no_filter = len(extns)==0
    items = (o for o in self.iterdir() if no_filter or o.suffix in extns)
    if n_max is not None: items = itertools.islice(items, n_max)
    return L(items)
We add an ls()
method to pathlib.Path
which is simply defined as list(Path.iterdir())
, mainly for convenience in REPL environments such as notebooks.
path = Path()
t = path.ls()
assert len(t)>0
t1 = path.ls(10)
test_eq(len(t1), 10)
t2 = path.ls(file_exts='.ipynb')
assert len(t)>len(t2)
t[0]
Path('.gitattributes')
You can also pass an optional file_type
MIME prefix and/or a list of file extensions.
lib_path = (path/'../fastcore')
txt_files=lib_path.ls(file_type='text')
assert len(txt_files) > 0 and txt_files[0].suffix=='.py'
ipy_files=path.ls(file_exts=['.ipynb'])
assert len(ipy_files) > 0 and ipy_files[0].suffix=='.ipynb'
txt_files[0],ipy_files[0]
(Path('../fastcore/__init__.py'), Path('01_basics.ipynb'))
#hide
path = Path()
pkl = pickle.dumps(path)
p2 = pickle.loads(pkl)
test_eq(path.ls()[0], p2.ls()[0])
#export
@patch
def __repr__(self:Path):
    "Show paths relative to `Path.BASE_PATH`, when it is set and contains `self`"
    b = getattr(Path, 'BASE_PATH', None)
    if b:
        # `relative_to` raises ValueError for paths outside BASE_PATH; a bare
        # `except:` here would also swallow KeyboardInterrupt/SystemExit.
        try: self = self.relative_to(b)
        except ValueError: pass
    return f"Path({self.as_posix()!r})"
fastai also updates the repr
of Path
such that, if Path.BASE_PATH
is defined, all paths are printed relative to that path (as long as they are contained in Path.BASE_PATH
:
t = ipy_files[0].absolute()
try:
Path.BASE_PATH = t.parent.parent
test_eq(repr(t), f"Path('nbs/{t.name}')")
finally: Path.BASE_PATH = None
#export
def truncstr(s:str, maxlen:int, suf:str='…', space='')->str:
    "Truncate `s` to length `maxlen`, adding suffix `suf` if truncated"
    if len(s) + len(space) > maxlen:
        # leave room for the suffix inside the length budget
        return s[:maxlen - len(suf)] + suf
    return s + space
w = 'abacadabra'
test_eq(truncstr(w, 10), w)
test_eq(truncstr(w, 5), 'abac…')
test_eq(truncstr(w, 5, suf=''), 'abaca')
test_eq(truncstr(w, 11, space='_'), w+"_")
test_eq(truncstr(w, 10, space='_'), w[:-1]+'…')
test_eq(truncstr(w, 5, suf='!!'), 'aba!!')
#export
spark_chars = '▁▂▃▅▆▇'
#export
def _ceil(x, lim=None): return x if (not lim or x <= lim) else lim
def _sparkchar(x, mn, mx, incr, empty_zero):
if x is None or (empty_zero and not x): return ' '
if incr == 0: return spark_chars[0]
res = int((_ceil(x,mx)-mn)/incr-0.5)
return spark_chars[res]
#export
def sparkline(data, mn=None, mx=None, empty_zero=False):
    "Sparkline for `data`, with `None`s (and zero, if `empty_zero`) shown as empty column"
    valid = [o for o in data if o is not None]
    if not valid: return ' '
    mn = ifnone(mn, min(valid))
    mx = ifnone(mx, max(valid))
    incr = (mx - mn) / len(spark_chars)  # bucket width per spark character
    return ''.join(_sparkchar(x=o, mn=mn, mx=mx, incr=incr, empty_zero=empty_zero) for o in data)
data = [9,6,None,1,4,0,8,15,10]
print(f'without "empty_zero": {sparkline(data, empty_zero=False)}')
print(f' with "empty_zero": {sparkline(data, empty_zero=True )}')
without "empty_zero": ▅▂ ▁▂▁▃▇▅ with "empty_zero": ▅▂ ▁▂ ▃▇▅
You can set a maximum and minimum for the y-axis of the sparkline with the arguments mn
and mx
respectively:
sparkline([1,2,3,400], mn=0, mx=3)
'▂▅▇▇'
#export
def autostart(g):
    "Decorator that automatically starts a generator"
    @functools.wraps(g)
    def _started():
        gen = g()
        next(gen)  # advance to the first yield so callers can `send` immediately
        return gen
    return _started
#export
class EventTimer:
    "An event timer with history of `store` items of time `span`"
    def __init__(self, store=5, span=60):
        self.hist = collections.deque(maxlen=store)  # rolling history of completed-window frequencies
        self.span = span
        self.last = default_timer()
        self._reset()

    def _reset(self):
        # start a fresh window at the time of the most recent event
        self.start, self.events = self.last, 0

    def add(self, n=1):
        "Record `n` events"
        if self.duration > self.span:
            # current window expired: archive its frequency and start a new one
            self.hist.append(self.freq)
            self._reset()
        self.events += n
        self.last = default_timer()

    @property
    def duration(self):
        "Seconds since the current window started"
        return default_timer() - self.start

    @property
    def freq(self):
        "Events per second in the current window"
        return self.events / self.duration
show_doc(EventTimer, title_level=4)
class
EventTimer
[source]
EventTimer
(store
=5
,span
=60
)
An event timer with history of store
items of time span
Add events with add
, and get number of events
and their frequency (freq
).
# Random wait function for testing
def _randwait(): yield from (sleep(random.random()/200) for _ in range(100))
c = EventTimer(store=5, span=0.03)
for o in _randwait(): c.add(1)
print(f'Num Events: {c.events}, Freq/sec: {c.freq:.01f}')
print('Most recent: ', sparkline(c.hist), *L(c.hist).map('{:.01f}'))
Num Events: 8, Freq/sec: 423.0 Most recent: ▂▂▁▁▇ 318.5 319.0 266.9 275.6 427.7
#export
_fmt = string.Formatter()
#export
def stringfmt_names(s:str)->list:
    "Unique brace-delimited names in `s`"
    # Formatter.parse yields (literal, field_name, spec, conversion) tuples;
    # keep only non-empty field names, de-duplicated in order of appearance
    fields = (parsed[1] for parsed in _fmt.parse(s))
    return uniqueify(nm for nm in fields if nm)
s = '/pulls/{pull_number}/reviews/{review_id}'
test_eq(stringfmt_names(s), ['pull_number','review_id'])
#export
class PartialFormatter(string.Formatter):
    "A `string.Formatter` that doesn't error on missing fields, and tracks missing fields and unused args"
    def __init__(self):
        self.missing = set()  # field names that had no matching argument
        super().__init__()

    def get_field(self, nm, args, kwargs):
        try:
            return super().get_field(nm, args, kwargs)
        except KeyError:
            # leave the placeholder intact in the output and remember it was unresolved
            self.missing.add(nm)
            return '{' + nm + '}', nm

    def check_unused_args(self, used, args, kwargs):
        # stash the kwargs that `format` never consumed
        self.xtra = filter_keys(kwargs, lambda o: o not in used)
show_doc(PartialFormatter, title_level=4)
class
PartialFormatter
[source]
PartialFormatter
() ::Formatter
A string.Formatter
that doesn't error on missing fields, and tracks missing fields and unused args
#export
def partial_format(s:str, **kwargs):
    "string format `s`, ignoring missing field errors, returning missing and extra fields"
    formatter = PartialFormatter()
    formatted = formatter.format(s, **kwargs)
    # (formatted string, field names with no argument, kwargs that went unused)
    return formatted, list(formatter.missing), formatter.xtra
The result is a tuple of (formatted_string,missing_fields,extra_fields)
, e.g:
res,missing,xtra = partial_format(s, pull_number=1, foo=2)
test_eq(res, '/pulls/1/reviews/{review_id}')
test_eq(missing, ['review_id'])
test_eq(xtra, {'foo':2})
#export
def utc2local(dt:datetime)->datetime:
    "Convert `dt` from UTC to local time"
    # mark the naive datetime as UTC, then shift to the system timezone
    aware = dt.replace(tzinfo=timezone.utc)
    return aware.astimezone(tz=None)
dt = datetime(2000,1,1,12)
print(f'{dt} UTC is {utc2local(dt)} local time')
2000-01-01 12:00:00 UTC is 2000-01-01 12:00:00+00:00 local time
#export
def local2utc(dt:datetime)->datetime:
    "Convert `dt` from local to UTC time"
    # strip any tzinfo so the value is interpreted as local, then shift to UTC
    naive_local = dt.replace(tzinfo=None)
    return naive_local.astimezone(tz=timezone.utc)
print(f'{dt} local is {local2utc(dt)} UTC time')
2000-01-01 12:00:00 local is 2000-01-01 12:00:00+00:00 UTC time
#export
def trace(f):
    "Add `set_trace` to an existing function `f`"
    if getattr(f, '_traced', False): return f  # already wrapped: idempotent
    @wraps(f)  # preserve `f`'s name/doc so the wrapped function stays identifiable
    def _inner(*args,**kwargs):
        set_trace()
        return f(*args,**kwargs)
    _inner._traced = True
    return _inner
You can add a breakpoint to an existing function, e.g:
Path.cwd = trace(Path.cwd)
Path.cwd()
Now, when the function is called it will drop you into the debugger. Note, you must issue the s
command when you begin to step into the function that is being traced.
#export
def round_multiple(x, mult, round_down=False):
    "Round `x` to nearest multiple of `mult`"
    rnd = int if round_down else round  # `int` truncates toward zero, i.e. rounds down for positives
    def _f(x_): return rnd(x_/mult)*mult
    res = L(x).map(_f)
    # scalars come back as scalars; listy inputs keep their mapped form
    return res if is_listy(x) else res[0]
test_eq(round_multiple(63,32), 64)
test_eq(round_multiple(50,32), 64)
test_eq(round_multiple(40,32), 32)
test_eq(round_multiple( 0,32), 0)
test_eq(round_multiple(63,32, round_down=True), 32)
test_eq(round_multiple((63,40),32), (64,32))
#export
@contextmanager
def modified_env(*delete, **replace):
    "Context manager temporarily modifying `os.environ` by deleting `delete` and replacing `replace`"
    prev_environ = dict(os.environ)  # snapshot to restore on exit
    try:
        os.environ.update(replace)                        # apply replacements first...
        for key in delete: os.environ.pop(key, None)      # ...then deletions win any overlap
        yield
    finally:
        os.environ.clear()
        os.environ.update(prev_environ)
# USER isn't in Cloud Linux Environments
env_test = 'USERNAME' if sys.platform == "win32" else 'SHELL'
oldusr = os.environ[env_test]
replace_param = {env_test: 'a'}
with modified_env('PATH', **replace_param):
test_eq(os.environ[env_test], 'a')
assert 'PATH' not in os.environ
assert 'PATH' in os.environ
test_eq(os.environ[env_test], oldusr)
#export
class ContextManagers(GetAttr):
    "Wrapper for `contextlib.ExitStack` which enters a collection of context managers"
    def __init__(self, mgrs):
        self.default = L(mgrs)   # `GetAttr` delegates unknown attribute lookups here
        self.stack = ExitStack()
    def __enter__(self):
        # enter every manager; ExitStack unwinds them in reverse on exit
        self.default.map(self.stack.enter_context)
    def __exit__(self, *args, **kwargs): self.stack.__exit__(*args, **kwargs)
show_doc(ContextManagers, title_level=4)
#export
def str2bool(s):
    "Case-insensitive convert string `s` to a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
    if not isinstance(s,str): return bool(s)
    if not s: return False
    # Inlined from `distutils.util.strtobool`: distutils was deprecated and
    # removed from the stdlib in Python 3.12 (PEP 632), so we can't rely on it.
    v = s.lower()
    if v in ('y','yes','t','true','on','1'): return True
    if v in ('n','no','f','false','off','0'): return False
    raise ValueError(f'invalid truth value {s!r}')
for o in "y YES t True on 1".split(): assert str2bool(o)
for o in "n no FALSE off 0".split(): assert not str2bool(o)
for o in 0,None,'',False: assert not str2bool(o)
for o in 1,True: assert str2bool(o)
#export
def _is_instance(f, gs):
    "Whether `f` matches any item of `gs`: equal to it, or an instance of it"
    # NOTE(review): `type(g) in [type, 'function']` — the string 'function' can
    # never equal a type, so only classes keep their identity here and everything
    # else (including plain functions) is replaced by its class. `types.FunctionType`
    # may have been intended — confirm before changing, since the `f==g` comparison
    # below depends on exactly what ends up in `tst`.
    tst = [g if type(g) in [type, 'function'] else g.__class__ for g in gs]
    for g in tst:
        if isinstance(f, g) or f==g: return True
    return False
def _is_first(f, gs):
    "Whether `f` may be placed before every member of `gs`, per `run_after`/`run_before`"
    # `f` must not declare itself to run after any member of `gs`...
    for o in L(getattr(f, 'run_after', None)):
        if _is_instance(o, gs): return False
    # ...and no member of `gs` may declare itself to run before `f`
    for g in gs:
        if _is_instance(f, L(getattr(g, 'run_before', None))): return False
    return True
#export
def sort_by_run(fs):
    "Order `fs` so that `run_after`/`run_before`/`toward_end` constraints are satisfied"
    # items flagged `toward_end` are moved to the back before sorting
    end = L(fs).attrgot('toward_end')
    inp,res = L(fs)[~end] + L(fs)[end], L()
    # greedy topological sort: repeatedly pull out the first item that may
    # legally run before everything still remaining
    while len(inp):
        for i,o in enumerate(inp):
            if _is_first(o, inp):
                res.append(inp.pop(i))
                break
        else: raise Exception("Impossible to sort")  # constraints are cyclic
    return res
#hide
from nbdev.export import notebook2script
notebook2script()
Converted 00_test.ipynb. Converted 01_basics.ipynb. Converted 02_foundation.ipynb. Converted 03_xtras.ipynb. Converted 03a_parallel.ipynb. Converted 03b_net.ipynb. Converted 04_dispatch.ipynb. Converted 05_transform.ipynb. Converted 07_meta.ipynb. Converted 08_script.ipynb. Converted index.ipynb.