#|default_exp basics
#|export
from fastcore.imports import *
import builtins,types
import pprint
try: from types import UnionType
except ImportError: UnionType = None
#|hide
from __future__ import annotations
from fastcore.test import *
from nbdev.showdoc import *
from fastcore.nb_imports import *
Basic functionality used in the fastai library
#|export
# NOTE(review): module-level shared namespace; presumably populated with library-wide default settings elsewhere — confirm
defaults = SimpleNamespace()
#|export
def ifnone(a, b):
    "`b` if `a` is None else `a`"
    # Only `None` triggers the fallback; other falsy values (0, '', []) are returned as-is
    if a is None: return b
    return a
Since b if a is None else a
is such a common pattern, we wrap it in a function. However, be careful, because python will evaluate both a
and b
when calling ifnone
(which it doesn't do if using the if
version directly).
test_eq(ifnone(None,1), 1)
test_eq(ifnone(2 ,1), 2)
#|export
def maybe_attr(o, attr):
    "`getattr(o,attr,o)`"
    # Return the attribute if present, otherwise hand back the object itself
    try: return getattr(o, attr)
    except AttributeError: return o
Return the attribute attr
for object o
. If the attribute doesn't exist, then return the object o
instead.
class myobj: myattr='foo'
test_eq(maybe_attr(myobj, 'myattr'), 'foo')
test_eq(maybe_attr(myobj, 'another_attr'), myobj)
#|export
def basic_repr(flds=None):
    "Minimal `__repr__`"
    # Accept either a comma-separated string of field names or any iterable of them
    names = re.split(', *', flds) if isinstance(flds, str) else flds
    names = list(names or [])
    def _repr(self):
        cls = type(self)
        prefix = f'{cls.__module__}.{cls.__name__}'
        # No fields: deterministic `<module.Class>` (no memory address)
        if not names: return f'<{prefix}>'
        parts = [f'{n}={getattr(self, n)!r}' for n in names]
        return f"{prefix}({', '.join(parts)})"
    return _repr
In types which provide rich display functionality in Jupyter, their __repr__
is also called in order to provide a fallback text representation. Unfortunately, this includes a memory address which changes on every invocation, making it non-deterministic. This causes diffs to get messy and creates conflicts in git. To fix this, put __repr__=basic_repr()
inside your class.
class SomeClass: __repr__=basic_repr()
repr(SomeClass())
'<__main__.SomeClass>'
If you pass a list of attributes (flds
) of an object, then this will generate a string with the name of each attribute and its corresponding value. The format of this string is key=value
, where key
is the name of the attribute, and value
is the value of the attribute. Each value is rendered with its `__repr__` (a `__name__` attribute, if present, is not used) when constructing the string.
class SomeClass:
a=1
b='foo'
__repr__=basic_repr('a,b')
__name__='some-class'
repr(SomeClass())
"__main__.SomeClass(a=1, b='foo')"
class AnotherClass:
c=SomeClass()
d='bar'
__repr__=basic_repr(['c', 'd'])
repr(AnotherClass())
"__main__.AnotherClass(c=__main__.SomeClass(a=1, b='foo'), d='bar')"
#|export
def is_array(x):
    "`True` if `x` supports `__array__` or `iloc`"
    # Duck-typed check: numpy-like objects expose `__array__`, pandas-like ones expose `iloc`
    return any(hasattr(x, a) for a in ('__array__', 'iloc'))
is_array(np.array(1)),is_array([1])
(True, False)
#|export
def listify(o=None, *rest, use_list=False, match=None):
    "Convert `o` to a `list`"
    # Extra positional args are bundled together with `o`
    if rest: o = (o,) + rest
    if use_list: out = list(o)
    elif o is None: out = []
    elif isinstance(o, list): out = o
    # Strings and arrays are treated as single items, not iterated
    elif isinstance(o, str) or is_array(o): out = [o]
    elif is_iter(o): out = list(o)
    else: out = [o]
    if match is None: return out
    # `match` may be a length or a collection whose length is used
    n = len(match) if is_coll(match) else match
    if len(out) == 1: return out * n
    assert len(out) == n, 'Match length mismatch'
    return out
Conversion is designed to "do what you mean", e.g:
test_eq(listify('hi'), ['hi'])
test_eq(listify(array(1)), [array(1)])
test_eq(listify(1), [1])
test_eq(listify([1,2]), [1,2])
test_eq(listify(range(3)), [0,1,2])
test_eq(listify(None), [])
test_eq(listify(1,2), [1,2])
arr = np.arange(9).reshape(3,3)
listify(arr)
[array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])]
listify(array([1,2]))
[array([1, 2])]
Generators are turned into lists too:
gen = (o for o in range(3))
test_eq(listify(gen), [0,1,2])
Use match
to provide a length to match:
test_eq(listify(1,match=3), [1,1,1])
If match
is a sequence, its length is used:
test_eq(listify(1,match=range(3)), [1,1,1])
If the listified item is not of length 1
, it must be the same length as match
:
test_eq(listify([1,1,1],match=3), [1,1,1])
test_fail(lambda: listify([1,1],match=3))
#|export
def tuplify(o, use_list=False, match=None):
    "Make `o` a tuple"
    # Same conversion rules as `listify`, just materialised as a tuple
    items = listify(o, use_list=use_list, match=match)
    return tuple(items)
test_eq(tuplify(None),())
test_eq(tuplify([1,2,3]),(1,2,3))
test_eq(tuplify(1,match=[1,2,3]),(1,1,1))
#|export
def true(x):
    "Test whether `x` is truthy; collections with >0 elements are considered `True`"
    # Objects without a usable `len` (scalars, 0-d arrays, None) fall back to plain truthiness.
    # Catch `Exception` rather than using a bare `except` so KeyboardInterrupt/SystemExit still propagate.
    try: return bool(len(x))
    except Exception: return bool(x)
[(o,true(o)) for o in
(array(0),array(1),array([0]),array([0,1]),1,0,'',None)]
[(array(0), False), (array(1), True), (array([0]), True), (array([0, 1]), True), (1, True), (0, False), ('', False), (None, False)]
#|export
class NullType:
    "An object that is `False` and can be called, chained, and indexed"
    # Attribute access, calling and indexing all yield the shared `null` singleton,
    # so arbitrarily long chains like `null.a().b[0]` stay falsy
    def _chain(self, *args, **kwargs): return null
    __getattr__ = __call__ = __getitem__ = _chain
    del _chain
    def __bool__(self): return False

null = NullType()
bool(null.hi().there[3])
False
#|export
def tonull(x):
    "Convert `None` to `null`"
    # Only `None` is replaced; other falsy values pass through untouched
    if x is None: return null
    return x
bool(tonull(None).hi().there[3])
False
#|export
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
    "Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`"
    attrs = {f: None for f in fld_names}
    for f in listify(funcs): attrs[f.__name__] = f
    attrs.update(flds)
    sup = ifnone(sup, ())
    if not isinstance(sup, tuple): sup = (sup,)
    # Data fields in declaration order: positional `fld_names` first, then keyword `flds`
    all_flds = [*fld_names, *flds.keys()]
    def _init(self, *args, **kwargs):
        # Positional args map onto data fields only -- never onto `funcs` methods or the
        # generated dunders, which also live in `attrs`
        if len(args) > len(all_flds):
            raise TypeError(f"{nm}() takes at most {len(all_flds)} positional arguments ({len(args)} given)")
        for k,v in zip(all_flds, args): setattr(self, k, v)
        for k,v in kwargs.items(): setattr(self, k, v)
    def _eq(self, b): return all(getattr(self,k)==getattr(b,k) for k in all_flds)
    # Only add the field-listing repr when there is no base class that may supply its own
    if not sup: attrs['__repr__'] = basic_repr(all_flds)
    attrs['__init__'] = _init
    attrs['__eq__'] = _eq
    res = type(nm, sup, attrs)
    if doc is not None: res.__doc__ = doc
    return res
show_doc(get_class)
_t = get_class('_t', 'a', b=2)
t = _t()
test_eq(t.a, None)
test_eq(t.b, 2)
t = _t(1, b=3)
test_eq(t.a, 1)
test_eq(t.b, 3)
t = _t(1, 3)
test_eq(t.a, 1)
test_eq(t.b, 3)
test_eq(t, pickle.loads(pickle.dumps(t)))
repr(t)
'__main__._t(a=1, b=3)'
Most often you'll want to call mk_class
, since it adds the class to your module. See mk_class
for more details and examples of use (which also apply to get_class
).
#|export
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, **flds):
    "Create a class using `get_class` and add to the caller's module"
    # Without `mod`, target the calling frame's locals (the module namespace when called at top level)
    target = sys._getframe(1).f_locals if mod is None else mod
    target[nm] = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, **flds)
Any kwargs
will be added as class attributes, and sup
is an optional (tuple of) base classes.
mk_class('_t', a=1, sup=dict)
t = _t()
test_eq(t.a, 1)
assert(isinstance(t,dict))
A __init__
is provided that sets attrs for any kwargs
, and for any args
(matching by position to fields), along with a __repr__
which prints all attrs. The docstring is set to doc
. You can pass funcs
which will be added as attrs with the function names.
def foo(self): return 1
mk_class('_t', 'a', sup=dict, doc='test doc', funcs=foo)
t = _t(3, b=2)
test_eq(t.a, 3)
test_eq(t.b, 2)
test_eq(t.foo(), 1)
test_eq(t.__doc__, 'test doc')
t
{}
#|export
def wrap_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
    "Decorator: makes function a method of a new class `nm` passing parameters to `mk_class`"
    def _inner(f):
        # The decorated function is appended to any explicitly supplied `funcs`
        methods = [*listify(funcs), f]
        # Register the class in the module where `f` was defined
        mk_class(nm, *fld_names, sup=sup, doc=doc, funcs=methods, mod=f.__globals__, **flds)
        return f
    return _inner
@wrap_class('_t', a=2)
def bar(self,x): return x+1
t = _t()
test_eq(t.a, 2)
test_eq(t.bar(3), 4)
#|export
class ignore_exceptions:
    "Context manager to ignore exceptions"
    def __enter__(self): return None
    def __exit__(self, *args):
        # A truthy return tells the interpreter to suppress whatever was raised in the block
        return True
show_doc(ignore_exceptions, title_level=4)
with ignore_exceptions():
# Exception will be ignored
raise Exception
#|export
def exec_local(code, var_name):
    "Call `exec` on `code` and return the var `var_name`"
    # NOTE: `exec` runs arbitrary code with this module's globals -- never pass untrusted input
    loc = {}
    exec(code, globals(), loc)
    return loc[var_name]
test_eq(exec_local("a=1", "a"), 1)
#|export
def risinstance(types, obj=None):
    "Curried `isinstance` but with args reversed"
    types = tuplify(types)
    # Called with only `types`: return a predicate awaiting the object
    if obj is None: return partial(risinstance, types)
    if not any(isinstance(t, str) for t in types): return isinstance(obj, types)
    # String entries are matched against class names anywhere in `obj`'s MRO
    return any(c.__name__ in types for c in type(obj).__mro__)
assert risinstance(int, 1)
assert not risinstance(str, 0)
assert risinstance(int)(1)
types
can also be strings:
assert risinstance(('str','int'), 'a')
assert risinstance('str', 'a')
assert not risinstance('int', 'a')
These are used when you need a pass-through function.
show_doc(noop)
noop()
test_eq(noop(1),1)
show_doc(noops)
class _t: foo=noops
test_eq(_t().foo(1),1)
These lists are useful for things like padding an array or adding index column(s) to arrays.
#|export
#|hide
class _InfMeta(type):
    # Each property returns a *fresh* infinite iterator on every access
    count = property(lambda cls: itertools.count())
    zeros = property(lambda cls: itertools.cycle([0]))
    ones  = property(lambda cls: itertools.cycle([1]))
    nones = property(lambda cls: itertools.cycle([None]))
#|export
class Inf(metaclass=_InfMeta):
    "Infinite lists"
show_doc(Inf);
Inf
defines the following properties:
count: itertools.count()
zeros: itertools.cycle([0])
ones : itertools.cycle([1])
nones: itertools.cycle([None])
test_eq([o for i,o in zip(range(5), Inf.count)],
[0, 1, 2, 3, 4])
test_eq([o for i,o in zip(range(5), Inf.zeros)],
[0]*5)
test_eq([o for i,o in zip(range(5), Inf.ones)],
[1]*5)
test_eq([o for i,o in zip(range(5), Inf.nones)],
[None]*5)
#|export
_dumobj = object()
def _oper(op,a,b=_dumobj): return (lambda o:op(o,a)) if b is _dumobj else op(a,b)
def _mk_op(nm, mod):
"Create an operator using `oper` and add to the caller's module"
op = getattr(operator,nm)
def _inner(a, b=_dumobj): return _oper(op, a,b)
_inner.__name__ = _inner.__qualname__ = nm
_inner.__doc__ = f'Same as `operator.{nm}`, or returns partial if 1 arg'
mod[nm] = _inner
#|export
def in_(x, a):
    "`True` if `x in a`"
    # `operator.contains(a, x)` is exactly the `x in a` protocol
    return operator.contains(a, x)
# Expose on `operator` so `_mk_op` can look it up alongside the built-in operators
operator.in_ = in_
#|export
_all_ = ['lt','gt','le','ge','eq','ne','add','sub','mul','truediv','is_','is_not','in_']
#|export
# Generate one wrapper per exported name, so the export list and generated functions can't drift apart
for op in _all_: _mk_op(op, globals())
# test if element is in another
assert in_('c', ('b', 'c', 'a'))
assert in_(4, [2,3,4,5])
assert in_('t', 'fastai')
# `test_fail` needs a callable: passing the bool result would "pass" only because calling `False` raises
assert not in_('h', 'fastai')
# use in_ as a partial
assert in_('fastai')('t')
assert in_([2,3,4,5])(4)
assert not in_('fastai')('h')
In addition to in_
, the following functions are provided matching the behavior of the equivalent versions in operator
: lt gt le ge eq ne add sub mul truediv is_ is_not.
lt(3,5),gt(3,5),is_(None,None),in_(0,[1,2])
(True, False, True, False)
Similarly to in_
, they also have additional functionality: if you only pass one param, they return a partial function that passes that param as the second positional parameter.
lt(5)(3),gt(5)(3),is_(None)(None),in_([1,2])(0)
(True, False, True, False)
#|export
def ret_true(*args, **kwargs):
    "Predicate: always `True`"
    # Accepts and ignores any arguments so it can stand in for any predicate callable
    return True
assert ret_true(1,2,3)
assert ret_true(False)
#|export
def ret_false(*args, **kwargs):
    "Predicate: always `False`"
    # Accepts and ignores any arguments so it can stand in for any predicate callable
    return False
#|export
def stop(e=StopIteration):
    "Raises exception `e` (by default `StopIteration`)"
    # Usable inside expressions/lambdas, where a `raise` statement isn't allowed
    raise e
#|export
def gen(func, seq, cond=ret_true):
    "Like `(func(o) for o in seq)`, but stops at the first result failing `cond`, or when `func` raises `StopIteration`"
    # `map` ends when `func` raises `StopIteration`; `takewhile` then ends at the first
    # result for which `cond` is falsy (later items are never examined -- this is not a filter)
    return itertools.takewhile(cond, map(func,seq))
test_eq(gen(noop, Inf.count, lt(5)),
range(5))
test_eq(gen(operator.neg, Inf.count, gt(-5)),
[0,-1,-2,-3,-4])
test_eq(gen(lambda o:o if o<5 else stop(), Inf.count),
range(5))
#|export
def chunked(it, chunk_sz=None, drop_last=False, n_chunks=None):
    "Return batches from iterator `it` of size `chunk_sz` (or return `n_chunks` total)"
    # Exactly one of `chunk_sz` / `n_chunks` must be given
    assert bool(chunk_sz) ^ bool(n_chunks)
    # `n_chunks` needs `len(it)`, so it only works for sized inputs
    if n_chunks: chunk_sz = max(math.ceil(len(it)/n_chunks), 1)
    src = it if isinstance(it, Iterator) else iter(it)
    while True:
        batch = list(itertools.islice(src, chunk_sz))
        full = len(batch) == chunk_sz
        if batch and (full or not drop_last): yield batch
        # A short (or empty) batch means the source is exhausted
        if not full: return
Note that you must pass either chunk_sz
, or n_chunks
, but not both.
t = list(range(10))
test_eq(chunked(t,3), [[0,1,2], [3,4,5], [6,7,8], [9]])
test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8], ])
t = map(lambda o:stop() if o==6 else o, Inf.count)
test_eq(chunked(t,3), [[0, 1, 2], [3, 4, 5]])
t = map(lambda o:stop() if o==7 else o, Inf.count)
test_eq(chunked(t,3), [[0, 1, 2], [3, 4, 5], [6]])
t = np.arange(10)
test_eq(chunked(t,3), [[0,1,2], [3,4,5], [6,7,8], [9]])
test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8], ])
test_eq(chunked([], 3), [])
test_eq(chunked([], n_chunks=3), [])
#|export
def otherwise(x, tst, y):
    "`y if tst(x) else x`"
    # Substitute `y` only when the predicate fires on `x`
    if tst(x): return y
    return x
test_eq(otherwise(2+1, gt(3), 4), 3)
test_eq(otherwise(2+1, gt(2), 4), 4)
These functions reduce boilerplate when setting or manipulating attributes or properties of objects.
#|export
def custom_dir(c, add):
    "Implement custom `__dir__`, adding `add` to `cls`"
    # Default attribute listing for object `c`, followed by the extra name(s) in `add`
    base = object.__dir__(c)
    return [*base, *listify(add)]
custom_dir
returns the attribute names that `object.__dir__` reports for the object `c`, with the names in `add` appended to them.
class _T:
def f(): pass
s = custom_dir(_T(), add=['foo', 'bar'])
assert {'foo', 'bar', 'f'}.issubset(s)
#|export
class AttrDict(dict):
    "`dict` subclass that also provides access to keys as attrs"
    def __getattr__(self, k):
        if k in self: return self[k]
        # AttributeError (not KeyError) so `hasattr`/`getattr(..., default)` behave normally
        raise AttributeError(k)
    def __setattr__(self, k, v):
        # Names starting with `_` become real attributes; everything else is stored as a key
        (self.__setitem__, super().__setattr__)[k[0]=='_'](k, v)
    # Only string keys can be attribute names; non-str keys would also make `dir()` fail when sorting
    def __dir__(self): return super().__dir__() + [k for k in self.keys() if isinstance(k, str)]
    def _repr_markdown_(self): return f'```json\n{pprint.pformat(self, indent=2)}\n```'
    # Shallow copy; passing the mapping (not `**self`) also works for non-string keys
    def copy(self): return AttrDict(self)
d = AttrDict(a=1,b="two")
test_eq(d.a, 1)
test_eq(d['b'], 'two')
test_eq(d.get('c','nope'), 'nope')
d.b = 2
test_eq(d.b, 2)
test_eq(d['b'], 2)
d['b'] = 3
test_eq(d['b'], 3)
test_eq(d.b, 3)
assert 'a' in dir(d)
AttrDict
will pretty print in Jupyter Notebooks:
_test_dict = {'a':1, 'b': {'c':1, 'd':2}, 'c': {'c':1, 'd':2}, 'd': {'c':1, 'd':2},
'e': {'c':1, 'd':2}, 'f': {'c':1, 'd':2, 'e': 4, 'f':[1,2,3,4,5]}}
AttrDict(_test_dict)
{ 'a': 1,
'b': {'c': 1, 'd': 2},
'c': {'c': 1, 'd': 2},
'd': {'c': 1, 'd': 2},
'e': {'c': 1, 'd': 2},
'f': {'c': 1, 'd': 2, 'e': 4, 'f': [1, 2, 3, 4, 5]}}
#|export
def get_annotations_ex(obj, *, globals=None, locals=None):
    "Backport of py3.10 `get_annotations` that returns globals/locals"
    # Returns `(copy_of_annotations_dict, globals, locals)`. The explicit `globals`/`locals`
    # arguments replace the namespaces discovered from `obj` only when they are given (not None).
    # Raises TypeError if `obj` is not a class, module, or callable, and ValueError if its
    # `__annotations__` is neither a dict nor None.
    if isinstance(obj, type):
        # Class: read only the class's own `__dict__`, so base-class annotations are not included
        obj_dict = getattr(obj, '__dict__', None)
        if obj_dict and hasattr(obj_dict, 'get'):
            ann = obj_dict.get('__annotations__', None)
            # A getset descriptor found here is not an annotations dict; treat it as missing
            if isinstance(ann, types.GetSetDescriptorType): ann = None
        else: ann = None
        obj_globals = None
        # Globals come from the module the class was defined in, if that module is loaded
        module_name = getattr(obj, '__module__', None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module: obj_globals = getattr(module, '__dict__', None)
        # The class namespace is returned as locals, so annotations can name things defined in the class body
        obj_locals = dict(vars(obj))
        unwrap = obj
    elif isinstance(obj, types.ModuleType):
        ann = getattr(obj, '__annotations__', None)
        obj_globals = getattr(obj, '__dict__')
        obj_locals,unwrap = None,None
    elif callable(obj):
        # Functions and other callables: use their `__globals__` if they have one
        ann = getattr(obj, '__annotations__', None)
        obj_globals = getattr(obj, '__globals__', None)
        obj_locals,unwrap = None,obj
    else: raise TypeError(f"{obj!r} is not a module, class, or callable.")
    if ann is None: ann = {}
    if not isinstance(ann, dict): raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
    if not ann: ann = {}
    if unwrap is not None:
        # Peel off decorator layers (`__wrapped__`) and `functools.partial` wrappers so the
        # innermost function's `__globals__` are the ones returned
        while True:
            if hasattr(unwrap, '__wrapped__'):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func
                continue
            break
        if hasattr(unwrap, "__globals__"): obj_globals = unwrap.__globals__
    if globals is None: globals = obj_globals
    if locals is None: locals = obj_locals
    return dict(ann), globals, locals
In Python 3.10 inspect.get_annotations
was added. However previous versions of Python are unable to evaluate type annotations correctly if `from __future__ import annotations`
is used. Furthermore, all annotations are evaluated, even if only some subset are needed. get_annotations_ex
provides the same functionality as inspect.get_annotations
, but works on earlier versions of Python, and returns the globals
and locals
needed to evaluate types.
#|export
def _split_top_level_union(s):
    "Split `s` on `|` characters that are not nested inside brackets or parentheses"
    # NOTE: does not look inside string literals (e.g. `Literal['a|b']`)
    parts, depth, start = [], 0, 0
    for i,ch in enumerate(s):
        if ch in '[(': depth += 1
        elif ch in '])': depth -= 1
        elif ch == '|' and depth == 0:
            parts.append(s[start:i])
            start = i+1
    parts.append(s[start:])
    return parts

def eval_type(t, glb, loc):
    "`eval` a type or collection of types, if needed, for annotations in py3.10+"
    if isinstance(t,str):
        # Only a *top-level* `|` is turned into `typing.Union` here (so `A|B` works before py3.10);
        # nested ones such as `list[int|str]` are left to `eval` instead of being split mid-bracket
        parts = _split_top_level_union(t)
        if len(parts) > 1: return Union[eval_type(tuple(parts), glb, loc)]
        return eval(t, glb, loc)
    if isinstance(t,(tuple,list)): return type(t)([eval_type(c, glb, loc) for c in t])
    return t
If `from __future__ import annotations`
is used, a
is a str
:
class _T2a: pass
def func(a: _T2a): pass
ann,glb,loc = get_annotations_ex(func)
eval_type(ann['a'], glb, loc)
__main__._T2a
|
is supported for defining Union
types when using eval_type
even for python versions prior to 3.10 (where native `X | Y` union syntax was introduced):
class _T2b: pass
def func(a: _T2a|_T2b): pass
ann,glb,loc = get_annotations_ex(func)
eval_type(ann['a'], glb, loc)
typing.Union[__main__._T2a, __main__._T2b]
#|export
def _eval_type(t, glb, loc):
    res = eval_type(t, glb, loc)
    # Report a `None` annotation as `NoneType`, so it is distinguishable from "no annotation"
    return res if res is not None else NoneType
def type_hints(f):
    "Like `typing.get_type_hints` but returns `{}` if not allowed type"
    if isinstance(f, typing._allowed_types):
        ann, glb, loc = get_annotations_ex(f)
        return {k: _eval_type(v, glb, loc) for k, v in ann.items()}
    return {}
Below is a list of allowed types for type hints in python:
list(typing._allowed_types)
[function, builtin_function_or_method, method, module, wrapper_descriptor, method-wrapper, method_descriptor]
For example, type func
is allowed so type_hints
returns the same value as typing.get_type_hints
:
def f(a:int)->bool: ... # a function with type hints (allowed)
exp = {'a':int,'return':bool}
test_eq(type_hints(f), typing.get_type_hints(f))
test_eq(type_hints(f), exp)
However, class
is not an allowed type, so type_hints
returns {}
:
class _T:
def __init__(self, a:int=0)->bool: ...
assert not type_hints(_T)
#|export
def annotations(o):
    "Annotations for `o`, or `type(o)`"
    # Compare with `None` explicitly: falsy objects (empty containers, `0`, ...) can still have
    # annotated types, and arrays raise on truth-testing
    if o is None: return {}
    # Try, in order: `o` itself, its `__init__`, then its type
    return type_hints(o) or type_hints(getattr(o,'__init__',None)) or type_hints(type(o))
This supports a wider range of situations than type_hints
, by checking type()
and __init__
for annotations too:
for o in _T,_T(),_T.__init__,f: test_eq(annotations(o), exp)
assert not annotations(int)
assert not annotations(print)
#|export
def anno_ret(func):
    "Get the return annotation of `func`"
    # Falsy input (e.g. None) means "no function", hence no annotation
    if not func: return None
    return annotations(func).get('return', None)
def f(x) -> float: return x
test_eq(anno_ret(f), float)
def f(x) -> typing.Tuple[float,float]: return x
test_eq(anno_ret(f), typing.Tuple[float,float])
If your return annotation is None
, anno_ret
will return NoneType
(and not None
):
def f(x) -> None: return x
test_eq(anno_ret(f), NoneType)
assert anno_ret(f) is not None # returns NoneType instead of None
If your function does not have a return type, or if you pass in None
instead of a function, then anno_ret
returns None
:
def f(x): return x
test_eq(anno_ret(f), None)
test_eq(anno_ret(None), None) # instead of passing in a func, pass in None
#|export
def _ispy3_10(): return sys.version_info.major >=3 and sys.version_info.minor >=10
def signature_ex(obj, eval_str:bool=False):
"Backport of `inspect.signature(..., eval_str=True` to <py310"
from inspect import Signature, Parameter, signature
def _eval_param(ann, k, v):
if k not in ann: return v
return Parameter(v.name, v.kind, annotation=ann[k], default=v.default)
if not eval_str: return signature(obj)
if _ispy3_10(): return signature(obj, eval_str=eval_str)
sig = signature(obj)
if sig is None: return None
ann = type_hints(obj)
params = [_eval_param(ann,k,v) for k,v in sig.parameters.items()]
return Signature(params, return_annotation=sig.return_annotation)
#|export
def union2tuple(t):
if (getattr(t, '__origin__', None) is Union
or (UnionType and isinstance(t, UnionType))): return t.__args__
return t
test_eq(union2tuple(Union[int,str]), (int,str))
test_eq(union2tuple(int), int)
test_eq(union2tuple(Tuple[int,str]), Tuple[int,str])
test_eq(union2tuple((int,str)), (int,str))
if UnionType: test_eq(union2tuple(int|str), (int,str))
#|export
def argnames(f, frame=False):
    "Names of arguments to function or frame `f`"
    co = f.f_code if frame else f.__code__
    # `co_varnames` lists positional then keyword-only params first; `*args`/`**kwargs` and locals follow
    n = co.co_argcount + co.co_kwonlyargcount
    return co.co_varnames[:n]
test_eq(argnames(f), ['x'])
#|export
def with_cast(f):
    "Decorator which uses any parameter annotations as preprocessing functions"
    anno, out_anno, params = annotations(f), anno_ret(f), argnames(f)
    # `-> None` is reported by `anno_ret` as `NoneType`, which cannot be called on a value
    c_out = noop if (out_anno is None or out_anno is NoneType) else out_anno
    # Positional defaults align with the *last* positional params; keyword-only defaults
    # live separately in `__kwdefaults__` (pairing them via `reversed(params)` mismatched them)
    n_pos = f.__code__.co_argcount
    pos_dflts = f.__defaults__ or ()
    dflts = dict(zip(params[n_pos-len(pos_dflts):n_pos], pos_dflts))
    dflts.update(f.__kwdefaults__ or {})
    @functools.wraps(f)
    def _inner(*args, **kwargs):
        args = list(args)
        for i,v in enumerate(params):
            if v not in anno: continue
            c = anno[v]
            if v in kwargs: kwargs[v] = c(kwargs[v])
            # Only positional params can be filled positionally (extra args belong to `*args`)
            elif i<n_pos and i<len(args): args[i] = c(args[i])
            elif v in dflts: kwargs[v] = c(dflts[v])
        return c_out(f(*args, **kwargs))
    return _inner
@with_cast
def _f(a, b:Path, c:str='', d=0): return (a,b,c,d)
test_eq(_f(1, '.', 3), (1,Path('.'),'3',0))
test_eq(_f(1, '.'), (1,Path('.'),'',0))
@with_cast
def _g(a:int=0)->str: return a
test_eq(_g(4.0), '4')
test_eq(_g(4.4), '4')
test_eq(_g(2), '2')
#|export
def _store_attr(self, anno, **attrs):
stored = getattr(self, '__stored_args__', None)
for n,v in attrs.items():
if n in anno: v = anno[n](v)
setattr(self, n, v)
if stored is not None: stored[n] = v
#|export
def store_attr(names=None, self=None, but='', cast=False, store_args=None, **attrs):
    "Store params named in comma-separated `names` from calling context into attrs in `self`"
    fr = sys._getframe(1)
    args = argnames(fr, True)
    # Test `is not None`, not truthiness: an explicitly passed but falsy target (e.g. an empty
    # dict-like object) must still be used instead of the caller's first argument
    if self is not None: args = ('self', *args)
    else: self = fr.f_locals[args[0]]
    if store_args is None: store_args = not hasattr(self,'__slots__')
    if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}
    anno = annotations(self) if cast else {}
    if names and isinstance(names,str): names = re.split(', *', names)
    # Without `names`: use `__slots__` if defined, else every caller argument after the first (`self`)
    ns = names if names is not None else getattr(self, '__slots__', args[1:])
    added = {n:fr.f_locals[n] for n in ns}
    attrs = {**attrs, **added}
    if isinstance(but,str): but = re.split(', *', but)
    attrs = {k:v for k,v in attrs.items() if k not in but}
    return _store_attr(self, anno, **attrs)
In its most basic form, you can use store_attr
to shorten code like this:
class T:
def __init__(self, a,b,c): self.a,self.b,self.c = a,b,c
...to this:
class T:
def __init__(self, a,b,c): store_attr('a,b,c', self)
This class behaves as if we'd used the first form:
t = T(1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2
In addition, it stores the attrs as a dict
in __stored_args__
, which you can use for display, logging, and so forth.
test_eq(t.__stored_args__, {'a':1, 'b':3, 'c':2})
Since you normally want to use the first argument (often called self
) for storing attributes, it's optional:
class T:
def __init__(self, a,b,c:str): store_attr('a,b,c')
t = T(1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2
#|hide
class _T:
def __init__(self, a,b):
c = 2
store_attr('a,b,c')
t = _T(1,b=3)
assert t.a==1 and t.b==3 and t.c==2
With cast=True
any parameter annotations will be used as preprocessing functions for the corresponding arguments:
class T:
def __init__(self, a:listify, b, c:str): store_attr('a,b,c', cast=True)
t = T(1,c=2,b=3)
assert t.a==[1] and t.b==3 and t.c=='2'
You can inherit from a class using store_attr
, and just call it again to add in any new attributes added in the derived class:
class T2(T):
def __init__(self, d, **kwargs):
super().__init__(**kwargs)
store_attr('d')
t = T2(d=1,a=2,b=3,c=4)
assert t.a==2 and t.b==3 and t.c==4 and t.d==1
You can skip passing a list of attrs to store. In this case, all arguments passed to the method are stored:
class T:
def __init__(self, a,b,c): store_attr()
t = T(1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2
class T4(T):
def __init__(self, d, **kwargs):
super().__init__(**kwargs)
store_attr()
t = T4(4, a=1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2 and t.d==4
class T4:
def __init__(self, *, a: int, b: float = 1):
store_attr()
t = T4(a=3)
assert t.a==3 and t.b==1
t = T4(a=3, b=2)
assert t.a==3 and t.b==2
#|hide
# ensure that subclasses work with or without `store_attr`
class T4(T):
def __init__(self, **kwargs):
super().__init__(**kwargs)
store_attr()
t = T4(a=1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2
class T4(T): pass
t = T4(a=1,c=2,b=3)
assert t.a==1 and t.b==3 and t.c==2
#|hide
#ensure that kwargs work with names==None
class T:
def __init__(self, a,b,c,**kwargs): store_attr(**kwargs)
t = T(1,c=2,b=3,d=4,e=-1)
assert t.a==1 and t.b==3 and t.c==2 and t.d==4 and t.e==-1
#|hide
#ensure that kwargs work with names==''
class T:
def __init__(self, a, **kwargs):
self.a = a+1
store_attr('', **kwargs)
t = T(a=1, d=4)
test_eq(t.a, 2)
test_eq(t.d, 4)
You can skip some attrs by passing but
:
class T:
def __init__(self, a,b,c): store_attr(but='a')
t = T(1,c=2,b=3)
assert t.b==3 and t.c==2
assert not hasattr(t,'a')
You can also pass keywords to store_attr
, which is identical to setting the attrs directly, but also stores them in __stored_args__
.
class T:
def __init__(self): store_attr(a=1)
t = T()
assert t.a==1
You can also use store_attr inside functions.
def create_T(a, b):
t = SimpleNamespace()
store_attr(self=t)
return t
t = create_T(a=1, b=2)
assert t.a==1 and t.b==2
#|export
def attrdict(o, *ks, default=None):
    "Dict from each `k` in `ks` to `getattr(o,k)`"
    # Missing attributes map to `default` rather than raising
    res = {}
    for k in ks: res[k] = getattr(o, k, default)
    return res
class T:
def __init__(self, a,b,c): store_attr()
t = T(1,c=2,b=3)
test_eq(attrdict(t,'b','c'), {'b':3, 'c':2})
#|export
def properties(cls, *ps):
    "Change attrs in `cls` with names in `ps` to properties"
    # Each named method becomes the getter of a read-only property, in place on `cls`
    for name in ps:
        fget = getattr(cls, name)
        setattr(cls, name, property(fget))
class T:
def a(self): return 1
def b(self): return 2
properties(T,'a')
test_eq(T().a,1)
test_eq(T().b(),2)
#|export
_c2w_re = re.compile(r'((?<=[a-z])[A-Z]|(?<!\A)[A-Z](?=[a-z]))')
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
#|export
def camel2words(s, space=' '):
    "Convert CamelCase to 'spaced words'"
    # Put `space` before each capital that starts a new word (after a lowercase letter,
    # or a non-initial capital followed by a lowercase letter)
    return _c2w_re.sub(rf'{space}\1', s)
test_eq(camel2words('ClassAreCamel'), 'Class Are Camel')
#|export
def camel2snake(name):
    "Convert CamelCase to snake_case"
    # First split before capitalised words, then between a lowercase/digit and a capital
    step1 = _camel_re1.sub(r'\1_\2', name)
    step2 = _camel_re2.sub(r'\1_\2', step1)
    return step2.lower()
test_eq(camel2snake('ClassAreCamel'), 'class_are_camel')
test_eq(camel2snake('Already_Snake'), 'already__snake')
#|export
def snake2camel(s):
    "Convert snake_case to CamelCase"
    # `title()` capitalises each `_`-separated word (and lowercases the rest); then drop the underscores
    return s.title().replace('_', '')
test_eq(snake2camel('a_b_cc'), 'ABCc')
#|export
def class2attr(self, cls_name):
    "Return the snake-cased name of the class; strip ending `cls_name` if it exists."
    stripped = re.sub(rf'{cls_name}$', '', self.__class__.__name__)
    # If the class *is* `cls_name` (nothing left after stripping), fall back to its lowercased name
    return camel2snake(stripped or cls_name.lower())
class Parent:
@property
def name(self): return class2attr(self, 'Parent')
class ChildOfParent(Parent): pass
class ParentChildOf(Parent): pass
p = Parent()
cp = ChildOfParent()
cp2 = ParentChildOf()
test_eq(p.name, 'parent')
test_eq(cp.name, 'child_of')
test_eq(cp2.name, 'parent_child_of')
#|export
def getcallable(o, attr):
    "Calls `getattr` with a default of `noop`"
    # Missing attributes become a harmless `noop` callable
    try: return getattr(o, attr)
    except AttributeError: return noop
class Math:
def addition(self,a,b): return a+b
m = Math()
test_eq(getcallable(m, "addition")(a=1,b=2), 3)
test_eq(getcallable(m, "subtraction")(a=1,b=2), None)
#|export
def getattrs(o, *attrs, default=None):
    "List of all `attrs` in `o`"
    # Missing attributes are reported as `default`
    return list(map(lambda a: getattr(o, a, default), attrs))
from fractions import Fraction
getattrs(Fraction(1,2), 'numerator', 'denominator')
[1, 2]
#|export
def hasattrs(o,attrs):
    "Test whether `o` contains all `attrs`"
    # Stop at the first missing attribute
    for a in attrs:
        if not hasattr(o, a): return False
    return True
assert hasattrs(1,('imag','real'))
assert not hasattrs(1,('imag','foo'))
#|export
def setattrs(dest, flds, src):
    "Copy the comma-separated fields `flds` from `src` (a `dict` or any object) onto `dest` as attributes"
    # Dicts are read with `dict.get` (missing -> None); other objects with `getattr` (missing -> AttributeError)
    getter = dict.get if isinstance(src, dict) else getattr
    for fld in re.split(r",\s*", flds): setattr(dest, fld, getter(src, fld))
d = dict(a=1,bb="2",ignore=3)
o = SimpleNamespace()
setattrs(o, "a,bb", d)
test_eq(o.a, 1)
test_eq(o.bb, "2")
d = SimpleNamespace(a=1,bb="2",ignore=3)
o = SimpleNamespace()
setattrs(o, "a,bb", d)
test_eq(o.a, 1)
test_eq(o.bb, "2")
#|export
def try_attrs(obj, *attrs):
    "Return first attr that exists in `obj`"
    for att in attrs:
        # Any error while fetching (missing attr, or a property that raises) means "try the next one".
        # Catch `Exception`, not everything, so KeyboardInterrupt/SystemExit still propagate.
        try: return getattr(obj, att)
        except Exception: pass
    raise AttributeError(attrs)
test_eq(try_attrs(1, 'real'), 1)
test_eq(try_attrs(1, 'foobar', 'real'), 1)
#|export
class GetAttrBase:
    "Basic delegation of `__getattr__` and `__dir__`"
    # Subclasses are expected to set `_attr` to the name of an indexable attribute to delegate
    # lookups to, and to define `_getattr` to post-process each delegated value
    _attr=noop
    def __getattr__(self,k):
        # Private names and the container itself are never delegated (also prevents infinite recursion)
        if k.startswith('_') or k==self._attr: return super().__getattr__(k)
        try: val = getattr(self, self._attr)[k]
        # Missing keys must surface as AttributeError so `hasattr`/`getattr(..., default)` work
        except KeyError: raise AttributeError(k) from None
        return self._getattr(val)
    def __dir__(self): return custom_dir(self, getattr(self, self._attr))
#|export
class GetAttr:
    "Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`"
    _default='default'
    def _component_attr_filter(self,k):
        # Never delegate dunders, `_xtra` itself, or the component attribute (avoids infinite recursion)
        if k.startswith('__') or k in ('_xtra',self._default): return False
        xtra = getattr(self,'_xtra',None)
        # `_xtra` of None means "delegate everything"; otherwise only the listed names
        return xtra is None or k in xtra
    def _dir(self):
        component = getattr(self, self._default)
        return [k for k in dir(component) if self._component_attr_filter(k)]
    def __getattr__(self, k):
        component = getattr(self, self._default, None) if self._component_attr_filter(k) else None
        if component is None: raise AttributeError(k)
        return getattr(component, k)
    def __dir__(self): return custom_dir(self,self._dir())
    # Restore pickled state by updating `__dict__` directly
    def __setstate__(self,data): self.__dict__.update(data)
show_doc(GetAttr, title_level=4)
GetAttr ()
Inherit from this to have all attr accesses in self._xtra
passed down to self.default
Inherit from GetAttr
to have attr access passed down to an instance attribute.
This makes it easy to create composites that don't require callers to know about their components. For a more detailed discussion of how this works as well as relevant context, we suggest reading the delegated composition section of this blog article.
You can customise the behaviour of `GetAttr` in subclasses via:

- `_default`: defaults to `'default'`, so attr access is passed down to `self.default`. `_default` can be set to the name of any instance attribute that does not start with dunder `__`.
- `_xtra`: defaults to `None`, so all attr access is passed down. Set `_xtra` to a list of attribute names to pass down only those attributes.

To illuminate the utility of `GetAttr`, suppose we have the following two classes, `_WebPage` and `_ProductPage` (which holds a `_WebPage` instance as its `page`), which we wish to compose like so:
class _WebPage:
def __init__(self, title, author="Jeremy"):
self.title,self.author = title,author
class _ProductPage:
def __init__(self, page, price): self.page,self.price = page,price
page = _WebPage('Soap', author="Sylvain")
p = _ProductPage(page, 15.0)
How do we make it so we can just write p.author
, instead of p.page.author
to access the author
attribute? We can use GetAttr
, of course! First, we subclass GetAttr
when defining _ProductPage
. Next, we set self.default
to the object whose attributes we want to be able to access directly, which in this case is the page
argument passed on initialization:
class _ProductPage(GetAttr):
def __init__(self, page, price): self.default,self.price = page,price #self.default allows you to access page directly.
p = _ProductPage(page, 15.0)
Now, we can access the author
attribute directly from the instance:
test_eq(p.author, 'Sylvain')
If you wish to store the object you are composing in an attribute other than self.default
, you can set the class attribute _default (here to '_data')
as shown below. This is useful in the case where you might have a name collision with self.default
:
class _C(GetAttr):
_default = '_data' # use different component name; `self._data` rather than `self.default`
def __init__(self,a): self._data = a
def foo(self): noop
t = _C('Hi')
test_eq(t._data, 'Hi')
test_fail(lambda: t.default) # we no longer have self.default
test_eq(t.lower(), 'hi')
test_eq(t.upper(), 'HI')
assert 'lower' in dir(t)
assert 'upper' in dir(t)
By default, all attributes and methods of the object you are composing are retained. In the below example, we compose a str
object with the class _C
. This allows us to directly call string methods on instances of class _C
, such as str.lower()
or str.upper()
:
class _C(GetAttr):
# allow all attributes and methods to get passed to `self.default` (by leaving _xtra=None)
def __init__(self,a): self.default = a
def foo(self): noop
t = _C('Hi')
test_eq(t.lower(), 'hi')
test_eq(t.upper(), 'HI')
assert 'lower' in dir(t)
assert 'upper' in dir(t)
However, you can choose which attributes or methods to retain by defining a class attribute _xtra
, which is a list of allowed attribute and method names to delegate. In the below example, we only delegate the lower
method from the composed str
object when defining class _C
:
class _C(GetAttr):
_xtra = ['lower'] # specify which attributes get passed to `self.default`
def __init__(self,a): self.default = a
def foo(self): noop
t = _C('Hi')
test_eq(t.default, 'Hi')
test_eq(t.lower(), 'hi')
test_fail(lambda: t.upper()) # upper wasn't in _xtra, so it isn't available to be called
assert 'lower' in dir(t)
assert 'upper' not in dir(t)
You must be careful to properly set an instance attribute in __init__
that corresponds to the class attribute _default
. The below example sets the class attribute _default
to data
, but erroneously fails to define self.data
(and instead defines self.default
).
Failing to properly set instance attributes leads to errors when you try to access methods directly:
class _C(GetAttr):
_default = 'data' # use a bad component name; i.e. self.data does not exist
def __init__(self,a): self.default = a
def foo(self): noop
# TODO: should we raise an error when we create a new instance ...
t = _C('Hi')
test_eq(t.default, 'Hi')
# ... or is it enough for all GetAttr features to raise errors
test_fail(lambda: t.data)
test_fail(lambda: t.lower())
test_fail(lambda: t.upper())
test_fail(lambda: dir(t))
#|hide
# I don't think this test is essential to the docs but it probably makes sense to
# check that everything works when we set both _xtra and _default to non-default values
class _C(GetAttr):
_xtra = ['lower', 'upper']
_default = 'data'
def __init__(self,a): self.data = a
def foo(self): noop
t = _C('Hi')
test_eq(t.data, 'Hi')
test_eq(t.lower(), 'hi')
test_eq(t.upper(), 'HI')
assert 'lower' in dir(t)
assert 'upper' in dir(t)
#|hide
# when consolidating the filter logic, I choose the previous logic from
# __getattr__ k.startswith('__') rather than
# _dir k.startswith('_').
class _C(GetAttr):
def __init__(self): self.default = type('_D', (), {'_under': 1, '__dunder': 2})()
t = _C()
test_eq(t.default._under, 1)
test_eq(t._under, 1) # _ prefix attr access is allowed on component
assert '_under' in dir(t)
test_eq(t.default.__dunder, 2)
test_fail(lambda: t.__dunder) # __ prefix attr access is not allowed on component
assert '__dunder' not in dir(t)
assert t.__dir__ is not None # __ prefix attr access is allowed on composite
assert '__dir__' in dir(t)
#|hide
#Failing test. TODO: make GetAttr pickle-safe
# class B:
# def __init__(self): self.a = A()
# @funcs_kwargs
# class A(GetAttr):
# wif=after_iter= noops
# _methods = 'wif after_iter'.split()
# _default = 'dataset'
# def __init__(self, **kwargs): pass
# a = A()
# b = A(wif=a.wif)
# a = A()
# b = A(wif=a.wif)
# tst = pickle.dumps(b)
# c = pickle.loads(tst)
#|export
def delegate_attr(self, k, to):
    "Use in `__getattr__` to delegate to attr `to` without inheriting from `GetAttr`"
    # Private names are never forwarded, and neither is `to` itself: if `to` isn't set yet,
    # looking it up lands back in `__getattr__`, and this check stops the recursion
    if k.startswith('_') or k == to: raise AttributeError(k)
    try:
        component = getattr(self, to)
        return getattr(component, k)
    except AttributeError:
        # Report the name the caller asked for, not whichever internal lookup failed
        raise AttributeError(k) from None
delegate_attr
is a functional way to delegate attributes, and is an alternative to GetAttr
. We recommend reading the documentation of GetAttr
for more details around delegation.
You can achieve delegation when you define __getattr__
by using delegate_attr
:
#|hide
import pandas as pd
class _C:
def __init__(self, o): self.o = o # self.o corresponds to the `to` argument in delegate_attr.
def __getattr__(self, k): return delegate_attr(self, k, to='o')
t = _C('HELLO') # delegates to a string
test_eq(t.lower(), 'hello')
t = _C(np.array([5,4,3])) # delegates to a numpy array
test_eq(t.sum(), 12)
t = _C(pd.DataFrame({'a': [1,2], 'b': [3,4]})) # delegates to a pandas.DataFrame
test_eq(t.b.max(), 4)
ShowPrint
is a base class that defines a show
method, which is used primarily for callbacks in fastai that expect this method to be defined.
#|export
#|hide
class ShowPrint:
    "Base class that prints for `show`"
    def show(self, *args, **kwargs):
        # Any arguments are accepted and ignored; only the object's `str` is printed
        text = str(self)
        print(text)
Int
, Float
, and Str
extend int
, float
and str
respectively by adding an additional show
method by inheriting from ShowPrint
.
The code for Int
is shown below:
#|export
#|hide
class Int(int,ShowPrint):
    "An extensible `int`"
#|export
#|hide
class Str(str,ShowPrint):
    "An extensible `str`"

class Float(float,ShowPrint):
    "An extensible `float`"
Examples:
Int(0).show()
Float(2.0).show()
Str('Hello').show()
0 2.0 Hello
Functions that manipulate popular python collections.
#|export
def flatten(o):
    "Concatenate all collections and items as a generator"
    for item in o:
        # Strings are iterable but are kept whole (a 1-char string would otherwise recurse forever)
        if isinstance(item, str):
            yield item
            continue
        # Guard only the iterability check: a `TypeError` raised *while iterating* a nested
        # collection must propagate, not be mistaken for "this item is not iterable"
        try: it = iter(item)
        except TypeError:
            yield item
            continue
        yield from flatten(it)
#|export
def concat(colls)->list:
    "Concatenate all collections and items as a list"
    # Materialise the lazy `flatten` generator
    return [*flatten(colls)]
concat([(o for o in range(2)),[2,3,4], 5])
[0, 1, 2, 3, 4, 5]
concat([["abc", "xyz"], ["foo", "bar"]])
['abc', 'xyz', 'foo', 'bar']
#|export
def strcat(its, sep:str='')->str:
    "Concatenate stringified items `its`"
    return sep.join(str(item) for item in its)
test_eq(strcat(['a',2]), 'a2')
test_eq(strcat(['a',2], ';'), 'a;2')
#|export
def detuplify(x):
    "If `x` is a tuple with one thing, extract it"
    n = len(x)
    if n == 0: return None
    # Multi-dimensional arrays are returned whole even when their first axis has length 1
    if n == 1 and getattr(x, 'ndim', 1) == 1: return x[0]
    return x
test_eq(detuplify(()),None)
test_eq(detuplify([1]),1)
test_eq(detuplify([1,2]), [1,2])
test_eq(detuplify(np.array([[1,2]])), np.array([[1,2]]))
#|export
def replicate(item,match):
    "Create tuple of `item` copied `len(match)` times"
    # The same object is repeated (no copies are made)
    n_copies = len(match)
    return tuple([item] * n_copies)
t = [1,1]
test_eq(replicate([1,2], t),([1,2],[1,2]))
test_eq(replicate(1, t),(1,1))
#|export
def setify(o):
    "Turn any list like-object into a set."
    # Existing sets are returned as-is (not copied)
    if isinstance(o, set): return o
    return set(listify(o))
# test
test_eq(setify(None),set())
test_eq(setify('abc'),{'abc'})
test_eq(setify([1,2,2]),{1,2})
test_eq(setify(range(0,3)),{0,1,2})
test_eq(setify({1,2}),{1,2})
#|export
def merge(*ds):
    "Merge all dictionaries in `ds`"
    # Later dicts win on key clashes; `None` entries are skipped
    res = {}
    for d in ds:
        if d is None: continue
        for k, v in d.items(): res[k] = v
    return res
test_eq(merge(), {})
test_eq(merge(dict(a=1,b=2)), dict(a=1,b=2))
test_eq(merge(dict(a=1,b=2), dict(b=3,c=4), None), dict(a=1, b=3, c=4))
#|export
def range_of(x):
    "All indices of collection `x` (i.e. `list(range(len(x)))`)"
    n = len(x)
    return [*range(n)]
test_eq(range_of([1,1,1,1]), [0,1,2,3])
#|export
def groupby(x, key, val=noop):
    "Like `itertools.groupby` but doesn't need to be sorted, and isn't lazy, plus some extensions"
    # An int selects an item (`itemgetter`), a str selects an attribute (`attrgetter`);
    # anything else is used as the function itself
    def _as_getter(g):
        if isinstance(g, int): return itemgetter(g)
        if isinstance(g, str): return attrgetter(g)
        return g
    key_fn, val_fn = _as_getter(key), _as_getter(val)
    groups = {}
    for item in x:
        k = key_fn(item)
        if k not in groups: groups[k] = []
        groups[k].append(val_fn(item))
    return groups
test_eq(groupby('aa ab bb'.split(), itemgetter(0)), {'a':['aa','ab'], 'b':['bb']})
Here's an example of how to invert a grouping, using an int
as key
(which uses itemgetter
; passing a str
will use attrgetter
), and using a val
function:
d = {0: [1, 3, 7], 2: [3], 3: [5], 4: [8], 5: [4], 7: [5]}
groupby(((o,k) for k,v in d.items() for o in v), 0, 1)
{1: [0], 3: [0, 2], 7: [0], 5: [3, 7], 8: [4], 4: [5]}
#|export
def last_index(x, o):
    "Finds the last index of occurence of `x` in `o` (returns -1 if no occurence)"
    # Scan from the end so the first match found is the last occurrence
    for i in range(len(o)-1, -1, -1):
        if o[i] == x: return i
    return -1
test_eq(last_index(9, [1, 2, 9, 3, 4, 9, 10]), 5)
test_eq(last_index(6, [1, 2, 9, 3, 4, 9, 10]), -1)
#|export
def filter_dict(d, func):
    "Filter a `dict` using `func`, applied to keys and values"
    res = {}
    for k, v in d.items():
        if func(k, v): res[k] = v
    return res
letters = {o:chr(o) for o in range(65,73)}
letters
{65: 'A', 66: 'B', 67: 'C', 68: 'D', 69: 'E', 70: 'F', 71: 'G', 72: 'H'}
filter_dict(letters, lambda k,v: k<67 or v in 'FG')
{65: 'A', 66: 'B', 70: 'F', 71: 'G'}
#|export
def filter_keys(d, func):
    "Filter a `dict` using `func`, applied to keys"
    res = {}
    for k, v in d.items():
        if func(k): res[k] = v
    return res
filter_keys(letters, lt(67))
{65: 'A', 66: 'B'}
#|export
def filter_values(d, func):
    "Filter a `dict` using `func`, applied to values"
    res = {}
    for k, v in d.items():
        if func(v): res[k] = v
    return res
filter_values(letters, in_('FG'))
{70: 'F', 71: 'G'}
#|export
def cycle(o):
    "Like `itertools.cycle` except creates list of `None`s if `o` is empty"
    items = listify(o)
    if items is None or len(items) == 0: return itertools.cycle([None])
    return itertools.cycle(items)
test_eq(itertools.islice(cycle([1,2,3]),5), [1,2,3,1,2])
test_eq(itertools.islice(cycle([]),3), [None]*3)
test_eq(itertools.islice(cycle(None),3), [None]*3)
test_eq(itertools.islice(cycle(1),3), [1,1,1])
#|export
def zip_cycle(x, *args):
    "Like `itertools.zip_longest` but `cycle`s through elements of all but first argument"
    # Length is set by `x`; every other argument repeats as needed
    cycled = [cycle(a) for a in args]
    return zip(x, *cycled)
test_eq(zip_cycle([1,2,3,4],list('abc')), [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'a')])
#|export
def sorted_ex(iterable, key=None, reverse=False):
    "Like `sorted`, but if key is str use `getattr` (missing attributes sort as 0); if int use `itemgetter`"
    sort_key = key
    if isinstance(key, str):
        attr_name = key
        # NB: items lacking the attribute get a key of 0 rather than raising
        sort_key = lambda o: getattr(o, attr_name, 0)
    elif isinstance(key, int): sort_key = itemgetter(key)
    return sorted(iterable, key=sort_key, reverse=reverse)
#|export
def not_(f):
    "Create new function that negates result of `f`"
    def _negated(*args, **kwargs):
        res = f(*args, **kwargs)
        return not res
    return _negated
def f(a): return a>0
test_eq(f(1),True)
test_eq(not_(f)(1),False)
test_eq(not_(f)(a=-1),True)
#|export
def argwhere(iterable, f, negate=False, **kwargs):
    "Like `filter_ex`, but return indices for matching items"
    pred = partial(f, **kwargs) if kwargs else f
    if negate: pred = not_(pred)
    return [idx for idx, item in enumerate(iterable) if pred(item)]
#|export
def filter_ex(iterable, f=noop, negate=False, gen=False, **kwargs):
    "Like `filter`, but passing `kwargs` to `f`, defaulting `f` to `noop`, and adding `negate` and `gen`"
    # `f=None` keeps every item
    pred = (lambda _: True) if f is None else f
    if kwargs: pred = partial(pred, **kwargs)
    if negate: pred = not_(pred)
    res = filter(pred, iterable)
    return res if gen else list(res)
#|export
def range_of(a, b=None, step=None):
    "All indices of collection `a`, if `a` is a collection, otherwise `range`"
    if is_coll(a): a = len(a)
    if b is None:
        # A lone bound means `range(0, a)`; with `step`, step through that same range
        # (rather than calling `range(a, None, step)`, which raises TypeError)
        return list(range(0, a, step) if step is not None else range(a))
    return list(range(a, b, step) if step is not None else range(a, b))
test_eq(range_of([1,1,1,1]), [0,1,2,3])
test_eq(range_of(4), [0,1,2,3])
#|export
def renumerate(iterable, start=0):
    "Same as `enumerate`, but returns index as 2nd element instead of 1st"
    pairs = enumerate(iterable, start=start)
    return ((item, idx) for idx, item in pairs)
test_eq(renumerate('abc'), (('a',0),('b',1),('c',2)))
#|export
def first(x, f=None, negate=False, **kwargs):
    "First element of `x`, optionally filtered by `f`, or None if missing"
    it = iter(x)
    if f: it = filter_ex(it, f=f, negate=negate, gen=True, **kwargs)
    for item in it: return item
    return None
test_eq(first(['a', 'b', 'c', 'd', 'e']), 'a')
test_eq(first([False]), False)
test_eq(first([False], noop), None)
#|export
def only(o):
    "Return the only item of `o`, raise if `o` doesn't have exactly one item"
    it = iter(o)
    missing = object()
    res = next(it, missing)
    if res is missing: raise ValueError('iterable has 0 items')
    # Consumes at most one extra item to detect a second element
    if next(it, missing) is not missing: raise ValueError('iterable has more than 1 item')
    return res
#|hide
test_fail(lambda: only([]), contains='iterable has 0 items')
test_eq(only([0]), 0)
test_fail(lambda: only([0,1]), contains='iterable has more than 1 item')
#|export
def nested_attr(o, attr, default=None):
    "Same as `getattr`, but if `attr` includes a `.`, then looks inside nested objects"
    # Any AttributeError along the dotted path (or from `attr.split`) yields `default`
    try:
        for name in attr.split('.'): o = getattr(o, name)
        return o
    except AttributeError: return default
a = SimpleNamespace(b=(SimpleNamespace(c=1)))
test_eq(nested_attr(a, 'b.c'), getattr(getattr(a, 'b'), 'c'))
test_eq(nested_attr(a, 'b.d'), None)
#|export
def nested_setdefault(o, attr, default):
    "Same as `setdefault`, but if `attr` includes a `.`, then looks inside nested objects"
    *parents, leaf = attr.split('.')
    # Missing intermediate levels are created with the same mapping type as `o`
    for key in parents: o = o.setdefault(key, type(o)())
    return o.setdefault(leaf, default)
#|hide
o = {'e':'f'}
test_eq(nested_setdefault(o, 'a.b.c', 'd'), 'd')
test_eq(o, {'a':{'b':{'c':'d'}},'e':'f'})
#|hide
o = {'a':'b'}
test_eq(nested_setdefault(o, 'a', 'c'), 'b')
test_eq(o, {'a':'b'})
#|hide
o = {'a':{'b':'c'}}
test_eq(nested_setdefault(o, 'a.b', 'd'), 'c')
test_eq(o,{'a':{'b':'c'}})
#|export
def nested_callable(o, attr):
    "Same as `nested_attr` but if not found will return `noop`"
    return nested_attr(o, attr, default=noop)
a = SimpleNamespace(b=(SimpleNamespace(c=1)))
test_eq(nested_callable(a, 'b.c'), getattr(getattr(a, 'b'), 'c'))
test_eq(nested_callable(a, 'b.d'), noop)
#|export
def _access(coll, idx): return coll.get(idx, None) if hasattr(coll, 'get') else coll[idx] if idx<len(coll) else None
def _nested_idx(coll, *idxs):
*idxs,last_idx = idxs
for idx in idxs:
if isinstance(coll,str) or not isinstance(coll, typing.Collection): return None,None
coll = coll.get(idx, None) if hasattr(coll, 'get') else coll[idx] if idx<len(coll) else None
return coll,last_idx
#|export
def nested_idx(coll, *idxs):
"Index into nested collections, dicts, etc, with `idxs`"
if not coll or not idxs: return coll
coll,idx = _nested_idx(coll, *idxs)
if not coll or not idxs: return coll
return _access(coll, idx)
a = {'b':[1,{'c':2}]}
test_eq(nested_idx(a, 'nope'), None)
test_eq(nested_idx(a, 'nope', 'nup'), None)
test_eq(nested_idx(a, 'b', 3), None)
test_eq(nested_idx(a), a)
test_eq(nested_idx(a, 'b'), [1,{'c':2}])
test_eq(nested_idx(a, 'b', 1), {'c':2})
test_eq(nested_idx(a, 'b', 1, 'c'), 2)
#|export
def set_nested_idx(coll, value, *idxs):
    "Set value indexed like `nested_idx`"
    # Walk to the innermost container, then assign in place
    container, last = _nested_idx(coll, *idxs)
    container[last] = value
set_nested_idx(a, 3, 'b', 0)
test_eq(nested_idx(a, 'b', 0), 3)
#|export
def val2idx(x):
    "Dict from value to index"
    # For repeated values, the last index wins
    return dict((v, i) for i, v in enumerate(x))
test_eq(val2idx([1,2,3]), {3:2,1:0,2:1})
#|export
def uniqueify(x, sort=False, bidir=False, start=None):
    "Unique elements in `x`, optional `sort`, optional return reverse correspondence, optional prepend with elements."
    # Keep first-occurrence order (dicts preserve insertion order)
    seen = {}
    for o in x: seen.setdefault(o, None)
    res = list(seen)
    if start is not None: res = listify(start) + res
    if sort: res.sort()
    if bidir: return res, val2idx(res)
    return res
t = [1,1,0,5,0,3]
test_eq(uniqueify(t),[1,0,5,3])
test_eq(uniqueify(t, sort=True),[0,1,3,5])
test_eq(uniqueify(t, start=[7,8,6]), [7,8,6,1,0,5,3])
v,o = uniqueify(t, bidir=True)
test_eq(v,[1,0,5,3])
test_eq(o,{1:0, 0: 1, 5: 2, 3: 3})
v,o = uniqueify(t, sort=True, bidir=True)
test_eq(v,[0,1,3,5])
test_eq(o,{0:0, 1: 1, 3: 2, 5: 3})
#|export
# looping functions from https://github.com/willmcgugan/rich/blob/master/rich/_loop.py
def loop_first_last(values):
    "Iterate and generate a tuple with a flag for first and last value."
    it = iter(values)
    missing = object()
    prev = next(it, missing)
    if prev is missing: return
    is_first = True
    # Emit each value one step late, so we know whether another one follows it
    for cur in it:
        yield is_first, False, prev
        is_first, prev = False, cur
    yield is_first, True, prev
test_eq(loop_first_last(range(3)), [(True,False,0), (False,False,1), (False,True,2)])
#|export
def loop_first(values):
    "Iterate and generate a tuple with a flag for first value."
    return ((is_first, v) for is_first, _, v in loop_first_last(values))
test_eq(loop_first(range(3)), [(True,0), (False,1), (False,2)])
#|export
def loop_last(values):
    "Iterate and generate a tuple with a flag for last value."
    return ((is_last, v) for _, is_last, v in loop_first_last(values))
test_eq(loop_last(range(3)), [(False,0), (False,1), (True,2)])
A tuple with extended functionality.
#|export
# Numeric dunders (binary and unary); `fastuple` implements these elementwise where possible
num_methods = ('__add__ __sub__ __mul__ __matmul__ __truediv__ __floordiv__ __mod__ __divmod__ __pow__ '
               '__lshift__ __rshift__ __and__ __xor__ __or__ __neg__ __pos__ __abs__').split()
# Reflected ("right-hand") versions of the binary numeric dunders
rnum_methods = ('__radd__ __rsub__ __rmul__ __rmatmul__ __rtruediv__ __rfloordiv__ __rmod__ __rdivmod__ '
                '__rpow__ __rlshift__ __rrshift__ __rand__ __rxor__ __ror__').split()
# In-place (augmented-assignment) versions of the binary numeric dunders
inum_methods = ('__iadd__ __isub__ __imul__ __imatmul__ __itruediv__ '
                '__ifloordiv__ __imod__ __ipow__ __ilshift__ __irshift__ __iand__ __ixor__ __ior__').split()
#|export
class fastuple(tuple):
    "A `tuple` with elementwise ops and more friendly __init__ behavior"
    def __new__(cls, x=None, *rest):
        # `fastuple()` -> `()`; `fastuple(3,4)` -> `(3,4)`; `fastuple(3)` -> `(3,)`;
        # a single tuple is used as-is, and any other single iterable is expanded into its items
        if x is None: x = ()
        if not isinstance(x,tuple):
            if len(rest): x = (x,)
            else:
                try: x = tuple(iter(x))
                except TypeError: x = (x,)
        return super().__new__(cls, x+rest if rest else x)
    def _op(self,op,*args):
        # Apply `op` elementwise. Each extra argument goes through `cycle`, so a scalar is broadcast
        # against every element; plain tuples work too (e.g. `fastuple.add((1,1),(2,2))`)
        if not isinstance(self,fastuple): self = fastuple(self)
        return type(self)(map(op,self,*map(cycle, args)))
    def mul(self,*args):
        "`*` is already defined in `tuple` for replicating, so use `mul` instead"
        return fastuple._op(self, operator.mul,*args)
    def add(self,*args):
        "`+` is already defined in `tuple` for concat, so use `add` instead"
        return fastuple._op(self, operator.add,*args)
def _get_op(op):
    # Build a method applying `op` elementwise; `op` is a callable or the name of an `operator` function
    if isinstance(op,str): op = getattr(operator,op)
    def _f(self,*args): return self._op(op,*args)
    return _f
# Elementwise numeric dunders: names `fastuple` already has (e.g. tuple's `__add__`/`__mul__`) are
# left alone, and names the `operator` module doesn't provide are skipped.
# NOTE(review): the loop variable `n` is left behind as a module-level name.
for n in num_methods:
    if not hasattr(fastuple, n) and hasattr(operator,n): setattr(fastuple,n,_get_op(n))
# Elementwise comparisons as named methods; `==`, `<`, etc. keep plain `tuple` semantics
for n in 'eq ne lt le gt ge'.split(): setattr(fastuple,n,_get_op(n))
# `~` is elementwise logical not (`operator.__not__`), not bitwise invert
setattr(fastuple,'__invert__',_get_op('__not__'))
setattr(fastuple,'max',_get_op(max))
setattr(fastuple,'min',_get_op(min))
show_doc(fastuple, title_level=4)
Common failure modes when trying to initialize a tuple in python:
tuple(3)
> TypeError: 'int' object is not iterable
or
tuple(3, 4)
> TypeError: tuple expected at most 1 arguments, got 2
However, fastuple
allows you to define tuples like this and in the usual way:
test_eq(fastuple(3), (3,))
test_eq(fastuple(3,4), (3, 4))
test_eq(fastuple((3,4)), (3, 4))
show_doc(fastuple.add, title_level=5)
test_eq(fastuple.add((1,1),(2,2)), (3,3))
test_eq_type(fastuple(1,1).add(2), fastuple(3,3))
test_eq(fastuple('1','2').add('2'), fastuple('12','22'))
show_doc(fastuple.mul, title_level=5)
fastuple.mul (*args)
*
is already defined in tuple
for replicating, so use mul
instead
test_eq_type(fastuple(1,1).mul(2), fastuple(2,2))
Additionally, the following elementwise operations are available:
- `le`: less than or equal
- `eq`: equal
- `gt`: greater than
- `min`: elementwise minimum
test_eq(fastuple(3,1).le(1), (False, True))
test_eq(fastuple(3,1).eq(1), (False, True))
test_eq(fastuple(3,1).gt(1), (True, False))
test_eq(fastuple(3,1).min(2), (2,1))
You can also do other elementwise operations like negate a fastuple
, or subtract two fastuple
s:
test_eq(-fastuple(1,2), (-1,-2))
test_eq(~fastuple(1,0,1), (False,True,False))
test_eq(fastuple(1,1)-fastuple(2,2), (-1,-1))
test_eq(type(fastuple(1)), fastuple)
test_eq_type(fastuple(1,2), fastuple(1,2))
test_ne(fastuple(1,2), fastuple(1,3))
test_eq(fastuple(), ())
Utilities for functional programming or for defining, modifying, or debugging functions.
#|export
class _Arg:
    "Placeholder used with `bind`: stands for the `i`-th positional argument of the eventual call"
    def __init__(self,i): self.i = i
arg0 = _Arg(0)
arg1 = _Arg(1)
arg2 = _Arg(2)
arg3 = _Arg(3)
arg4 = _Arg(4)
#|export
class bind:
    "Same as `partial`, except you can use `arg0` `arg1` etc param placeholders"
    def __init__(self, func, *pargs, **pkwargs):
        self.func,self.pargs,self.pkwargs = func,pargs,pkwargs
        # Highest positional placeholder index; call args after it are appended to the bound ones
        self.maxi = max((x.i for x in pargs if isinstance(x, _Arg)), default=-1)
    def __call__(self, *args, **kwargs):
        args = list(args)
        kwargs = {**self.pkwargs,**kwargs}
        # Keyword placeholders all refer to the caller's *original* positional args (so e.g.
        # `d=arg0, e=arg1` works); the args they consume are then removed, and positional
        # placeholders index what remains (as before)
        used = {v.i for v in kwargs.values() if isinstance(v,_Arg)}
        kwargs = {k:(args[v.i] if isinstance(v,_Arg) else v) for k,v in kwargs.items()}
        for i in sorted(used, reverse=True): del args[i]
        fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
        return self.func(*fargs, **kwargs)
show_doc(bind, title_level=3)
bind (func, *pargs, **pkwargs)
Same as partial
, except you can use arg0
arg1
etc param placeholders
bind
is the same as partial
, but also allows you to reorder positional arguments using variable name(s) arg{i}
where i refers to the zero-indexed positional argument. bind
as implemented currently only supports reordering of up to the first 5 positional arguments.
Consider the function myfn
below, which has 3 positional arguments. These arguments can be referenced as arg0
, arg1
, and arg2
, respectively.
def myfn(a,b,c,d=1,e=2): return(a,b,c,d,e)
In the below example we bind the positional arguments of myfn
as follows:
- The second input, 14, referenced by arg1, is substituted for the first positional argument.
- We supply a default value of 17 for the second positional argument.
- The first input, 19, referenced by arg0, is substituted for the third positional argument.
- We set the named argument e to 3.
test_eq(bind(myfn, arg1, 17, arg0, e=3)(19,14), (14,17,19,1,3))
In this next example:
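- We supply a default value of 17 for the first positional argument.
- The first input, 19, referenced by arg0, becomes the second positional argument.
- The second input, 14, is passed through and becomes the third positional argument.
- We set the named argument e to 3.
test_eq(bind(myfn, 17, arg0, e=3)(19,14), (17,19,14,1,3))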
This is an example of using bind
like partial
without reordering any arguments:
test_eq(bind(myfn)(17,19,14), (17,19,14,1,2))
bind
can also be used to change default values. In the below example, we use the first input 3
to override the default value of the named argument e
, and supply default values for the first three positional arguments:
test_eq(bind(myfn, 17,19,14,e=arg0)(3), (17,19,14,1,3))
#|export
def mapt(func, *iterables):
    "Tuplified `map`"
    return (*map(func, *iterables),)
t = [0,1,2,3]
test_eq(mapt(operator.neg, t), (0,-1,-2,-3))
#|export
def map_ex(iterable, f, *args, gen=False, **kwargs):
    "Like `map`, but use `bind`, and supports `str` and indexing"
    # Callables are bound with the extra args; strings act as format templates; anything else is indexed
    if callable(f): mapper = bind(f, *args, **kwargs)
    elif isinstance(f, str): mapper = f.format
    else: mapper = f.__getitem__
    res = map(mapper, iterable)
    return res if gen else list(res)
test_eq(map_ex(t,operator.neg), [0,-1,-2,-3])
If f
is a string then it is treated as a format string to create the mapping:
test_eq(map_ex(t, '#{}#'), ['#0#','#1#','#2#','#3#'])
If f
is a dictionary (or anything supporting __getitem__
) then it is indexed to create the mapping:
test_eq(map_ex(t, list('abcd')), list('abcd'))
You can also pass the same arg
params that bind
accepts:
def f(a=None,b=None): return b
test_eq(map_ex(t, f, b=arg0), range(4))
#|export
def compose(*funcs, order=None):
    "Create a function that composes all functions in `funcs`, passing along remaining `*args` and `**kwargs` to all"
    funcs = listify(funcs)
    if not funcs: return noop
    if len(funcs) == 1: return funcs[0]
    if order is not None: funcs = sorted_ex(funcs, key=order)
    def _composed(x, *args, **kwargs):
        # Thread `x` through each function in turn; extra args go to every function
        res = x
        for fn in funcs: res = fn(res, *args, **kwargs)
        return res
    return _composed
f1 = lambda o,p=0: (o*2)+p
f2 = lambda o,p=1: (o+1)/p
test_eq(f2(f1(3)), compose(f1,f2)(3))
test_eq(f2(f1(3,p=3),p=3), compose(f1,f2)(3,p=3))
test_eq(f2(f1(3, 3), 3), compose(f1,f2)(3, 3))
f1.order = 1
test_eq(f1(f2(3)), compose(f1,f2, order="order")(3))
#|export
def maps(*args, retain=noop):
    "Like `map`, except funcs are composed first"
    # All but the last argument are functions; the last is the iterable
    f = compose(*args[:-1])
    seq = args[-1]
    # `retain` receives both the transformed and the original item
    return map(lambda b: retain(f(b), b), seq)
test_eq(maps([1]), [1])
test_eq(maps(operator.neg, [1,2]), [-1,-2])
test_eq(maps(operator.neg, operator.neg, [1,2]), [1,2])
#|export
def partialler(f, *args, order=None, **kwargs):
    "Like `functools.partial` but also copies over docstring"
    fnew = partial(f, *args, **kwargs)
    fnew.__doc__ = f.__doc__
    # An explicit `order` wins; otherwise copy `f.order` only if `f` has one
    if order is None:
        try: inherited = f.order
        except AttributeError: return fnew
        fnew.order = inherited
    else: fnew.order = order
    return fnew
def _f(x,a=1):
"test func"
return x-a
_f.order=1
f = partialler(_f, 2)
test_eq(f.order, 1)
test_eq(f(3), -1)
f = partialler(_f, a=2, order=3)
test_eq(f.__doc__, "test func")
test_eq(f.order, 3)
test_eq(f(3), _f(3,2))
class partial0:
    "Like `partialler`, but args passed to callable are inserted at start, instead of at end"
    def __init__(self, f, *args, order=None, **kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs
        # An explicit `order` wins; otherwise inherit `f.order` (None if `f` has none)
        self.order = ifnone(order, getattr(f,'order',None))
        self.__doc__ = f.__doc__
    def __call__(self, *args, **kwargs):
        # Call-time args come first and bound args after (the reverse of `functools.partial`)
        return self.f(*args, *self.args, **kwargs, **self.kwargs)
f = partial0(_f, 2)
test_eq(f.order, 1)
test_eq(f(3), 1) # NB: different to `partialler` example
#|export
def instantiate(t):
    "Instantiate `t` if it's a type, otherwise do nothing"
    if isinstance(t, type): return t()
    return t
test_eq_type(instantiate(int), 0)
test_eq_type(instantiate(1), 1)
#|export
def _using_attr(f, attr, x): return f(getattr(x,attr))
#|export
def using_attr(f, attr):
"Construct a function which applies `f` to the argument's attribute `attr`"
return partial(_using_attr, f, attr)
t = Path('/a/b.txt')
f = using_attr(str.upper, 'name')
test_eq(f(t), 'B.TXT')
A Concise Way To Create Lambdas
#|export
class _Self:
    "An alternative to `lambda` for calling methods on passed object."
    # Records a chain of attribute names (`nms`) plus, per name, the call args (`args`/`kwargs`),
    # or None for plain attribute access. `ready` is False right after a name is recorded and
    # True once its call args (if any) have been captured.
    # NOTE: a chain must end in a call (or `[i]`); with a trailing plain attribute (e.g. `Self.real`)
    # the next call's arguments would be recorded as that attribute's call args.
    def __init__(self): self.nms,self.args,self.kwargs,self.ready = [],[],[],True
    def __repr__(self): return f'self: {self.nms}({self.args}, {self.kwargs})'
    def __call__(self, *args, **kwargs):
        if self.ready:
            # Chain complete: replay it on the object passed in (`args[0]`)
            x = args[0]
            for n,a,k in zip(self.nms,self.args,self.kwargs):
                x = getattr(x,n)
                # `a is None` marks plain attribute access; recorded args are ignored if `x` isn't callable
                if callable(x) and a is not None: x = x(*a, **k)
            return x
        else:
            # Still building: these are the call args for the name just recorded
            self.args.append(args)
            self.kwargs.append(kwargs)
            self.ready = True
            return self
    def __getattr__(self,k):
        # Only reached for names not found normally. If the previous name was never called,
        # record it as plain attribute access before recording `k`
        if not self.ready:
            self.args.append(None)
            self.kwargs.append(None)
        self.nms.append(k)
        self.ready = False
        return self
    def _call(self, *args, **kwargs):
        # Used by `Self(...)`: a one-step chain that calls the passed-in object itself with these args
        self.args,self.kwargs,self.nms = [args],[kwargs],['__call__']
        self.ready = True
        return self
#|export
class _SelfCls:
    # Each access on `Self` starts a fresh `_Self` chain, so `Self` itself holds no state
    def __getattr__(self,k): return getattr(_Self(),k)
    # `Self[i]` records a call to `__getitem__` with `i`
    def __getitem__(self,i): return self.__getattr__('__getitem__')(i)
    # `Self(*args, **kwargs)` builds a function that calls its argument with those args
    def __call__(self,*args,**kwargs): return self.__getattr__('_call')(*args,**kwargs)
Self = _SelfCls()
#|export
_all_ = ['Self']
This is a concise way to create lambdas that are calling methods on an object (note the capitalization!)
Self.sum()
, for instance, is a shortcut for lambda o: o.sum()
.
f = Self.sum()
x = np.array([3.,1])
test_eq(f(x), 4.)
# This is equivalent to above
f = lambda o: o.sum()
x = np.array([3.,1])
test_eq(f(x), 4.)
f = Self.argmin()
arr = np.array([1,2,3,4,5])
test_eq(f(arr), arr.argmin())
f = Self.sum().is_integer()
x = np.array([3.,1])
test_eq(f(x), True)
f = Self.sum().real.is_integer()
x = np.array([3.,1])
test_eq(f(x), True)
f = Self.imag()
test_eq(f(3), 0)
f = Self[1]
test_eq(f(x), 1)
Self
is also callable, which creates a function which calls any function passed to it, using the arguments passed to Self
:
def f(a, b=3): return a+b+2
def g(a, b=3): return a*b
fg = Self(1,b=2)
list(map(fg, [f,g]))
[5, 2]
#|export
def copy_func(f):
    "Copy a non-builtin function (NB `copy.copy` does not work for this)"
    # Anything that isn't a Python function (builtins, callable objects, partials, ...) is shallow-copied
    if not isinstance(f,FunctionType): return copy(f)
    fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    # Give the copy its own keyword-only defaults dict, so mutating one can't change the other
    fn.__kwdefaults__ = None if f.__kwdefaults__ is None else dict(f.__kwdefaults__)
    fn.__dict__.update(f.__dict__)
    fn.__annotations__.update(f.__annotations__)
    # These may have been reassigned after `f` was defined, so they can't be rebuilt from `__code__`/globals
    fn.__doc__ = f.__doc__
    fn.__module__ = f.__module__
    fn.__qualname__ = f.__qualname__
    return fn
Sometimes it may be desirable to make a copy of a function that doesn't point to the original object. When you use Python's built in copy.copy
or copy.deepcopy
to copy a function, you get a reference to the original object:
import copy as cp
def foo(): pass
a = cp.copy(foo)
b = cp.deepcopy(foo)
a.someattr = 'hello' # since a and b point at the same object, updating a will update b
test_eq(b.someattr, 'hello')
assert a is foo and b is foo
However, with copy_func
, you can retrieve a copy of a function without a reference to the original object:
c = copy_func(foo) # c is an indpendent object
assert c is not foo
def g(x, *, y=3): return x+y
test_eq(copy_func(g)(4), 7)
#|export
def patch_to(cls, as_prop=False, cls_method=False):
    "Decorator: add `f` to `cls`"
    if not isinstance(cls, (tuple,list)): cls=(cls,)
    def _inner(f):
        nm = f.__name__
        for c_ in cls:
            nf = copy_func(f)
            # Copy the metadata `functools.update_wrapper` would, but by hand.
            # NOTE(review): the original note says `update_wrapper` was avoided because of an issue
            # when passing patched functions to `Pipeline`
            for o in functools.WRAPPER_ASSIGNMENTS: setattr(nf, o, getattr(f,o))
            nf.__qualname__ = f"{c_.__name__}.{nm}"
            if cls_method:
                setattr(c_, nm, MethodType(nf, c_))
            else:
                setattr(c_, nm, property(nf) if as_prop else nf)
        # Avoid clobbering existing functions: the decorated name keeps what it was bound to before in
        # the module that defined `f` (this module's `globals()` is the wrong namespace for user code)
        glb = getattr(f, '__globals__', globals())
        return glb.get(nm, builtins.__dict__.get(nm, None))
    return _inner
The @patch_to
decorator allows you to monkey patch a function into a class as a method:
class _T3(int): pass
@patch_to(_T3)
def func1(self, a): return self+a
t = _T3(1) # we initilized `t` to a type int = 1
test_eq(t.func1(2), 3) # we add 2 to `t`, so 2 + 1 = 3
You can access instance properties in the usual way via self
:
class _T4():
def __init__(self, g): self.g = g
@patch_to(_T4)
def greet(self, x): return self.g + x
t = _T4('hello ') # this sets self.g = 'helllo '
test_eq(t.greet('world'), 'hello world') #t.greet('world') will append 'world' to 'hello '
You can instead specify that the method should be a class method by setting cls_method=True
:
class _T5(int): attr = 3 # attr is a class attribute we will access in a later method
@patch_to(_T5, cls_method=True)
def func(cls, x): return cls.attr + x # you can access class attributes in the normal way
test_eq(_T5.func(4), 7)
Additionally, you can make the patched function a property (instead of a regular method) by setting as_prop=True:
@patch_to(_T5, as_prop=True)
def add_ten(self): return self + 10
t = _T5(4)
test_eq(t.add_ten, 14)
Instead of passing one class to the @patch_to
decorator, you can pass multiple classes in a tuple to simultaneously patch more than one class with the same method:
class _T6(int): pass
class _T7(int): pass
@patch_to((_T6,_T7))
def func_mult(self, a): return self*a
t = _T6(2)
test_eq(t.func_mult(4), 8)
t = _T7(2)
test_eq(t.func_mult(4), 8)
#|export
def patch(f=None, *, as_prop=False, cls_method=False):
    "Decorator: add `f` to the first parameter's class (based on f's type annotations)"
    # Bare `@patch(...)` with options: return a decorator awaiting `f`
    if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method)
    ann,glb,loc = get_annotations_ex(f)
    # Class methods name their target via the `cls` annotation; otherwise the first annotation is used
    target = ann.pop('cls') if cls_method else next(iter(ann.values()))
    # Resolve string annotations, and turn unions into a tuple of classes to patch
    cls = union2tuple(eval_type(target, glb, loc))
    return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)
@patch
is an alternative to @patch_to
that allows you similarly monkey patch class(es) by using type annotations:
class _T8(int): pass
@patch
def func(self:_T8, a): return self+a
t = _T8(1) # we initilized `t` to a type int = 1
test_eq(t.func(3), 4) # we add 3 to `t`, so 3 + 1 = 4
test_eq(t.func.__qualname__, '_T8.func')
Similarly to patch_to
, you can supply a union of classes instead of a single class in your type annotations to patch multiple classes:
class _T9(int): pass
@patch
def func2(x:_T8|_T9, a): return x*a # will patch both _T8 and _T9
t = _T8(2)
test_eq(t.func2(4), 8)
test_eq(t.func2.__qualname__, '_T8.func2')
t = _T9(2)
test_eq(t.func2(4), 8)
test_eq(t.func2.__qualname__, '_T9.func2')
Just like patch_to
decorator you can use as_prop
and cls_method
parameters with patch
decorator:
@patch(as_prop=True)
def add_ten(self:_T5): return self + 10
t = _T5(4)
test_eq(t.add_ten, 14)
class _T5(int): attr = 3 # attr is a class attribute we will access in a later method
@patch(cls_method=True)
def func(cls:_T5, x): return cls.attr + x # you can access class attributes in the normal way
test_eq(_T5.func(4), 7)
#|export
def patch_property(f):
    "Deprecated; use `patch(as_prop=True)` instead"
    # stacklevel=2 points the warning at the caller's decorated function, not this line
    warnings.warn("`patch_property` is deprecated and will be removed; use `patch(as_prop=True)` instead", stacklevel=2)
    # Delegate to `patch`, so string annotations (e.g. under `from __future__ import annotations`)
    # and unions resolve exactly as they do for the recommended replacement
    return patch(f, as_prop=True)
#|export
def compile_re(pat):
    "Compile `pat` if it's not None"
    if pat is None: return None
    return re.compile(pat)
assert compile_re(None) is None
assert compile_re('a').match('ab')
#|export
class ImportEnum(enum.Enum):
    "An `Enum` that can have its values imported"
    @classmethod
    def imports(cls):
        # Write every member into the *caller's* `f_locals`. At module/notebook top level that is the
        # global namespace; inside a function, such writes generally don't create real local variables.
        caller_ns = sys._getframe(1).f_locals
        for member in cls: caller_ns[member.name] = member
show_doc(ImportEnum, title_level=4)
ImportEnum (value, names=None, module=None, qualname=None, type=None, start=1)
An Enum
that can have its values imported
_T = ImportEnum('_T', {'foobar':1, 'goobar':2})
_T.imports()
test_eq(foobar, _T.foobar)
test_eq(goobar, _T.goobar)
#|export
class StrEnum(str,ImportEnum):
    "An `ImportEnum` that behaves like a `str`"
    # `str()` and f-strings give the bare member name
    def __str__(self):
        name = self.name
        return name
# Display `StrEnum`'s signature and docstring in the generated docs
show_doc(StrEnum, title_level=4)
StrEnum (value, names=None, module=None, qualname=None, type=None, start=1)
An ImportEnum
that behaves like a str
#|export
def str_enum(name, *vals):
    "Create a `StrEnum` called `name` that uses each of `vals` as both a member name and its value"
    members = dict(zip(vals, vals))
    return StrEnum(name, members)
# str_enum uses each value as both the member name and the member value
_T = str_enum('_T', 'a', 'b')
test_eq(f'{_T.a}', 'a') # formatting a member gives 'a'
test_eq(_T.a, 'a') # members compare equal to plain strings, since StrEnum subclasses str
test_eq(list(_T.__members__), ['a','b'])
print(_T.a, _T.a.upper()) # str methods such as `upper` work directly on members
a A
#|export
class Stateful:
    "A base class/mixin for objects that should not serialize all their state"
    # Names of attributes to leave out of the pickle, in addition to `_state`;
    # `_init_state` is expected to recreate them after unpickling.
    _stateattrs=()
    def __init__(self,*args,**kwargs):
        self._init_state()
        super().__init__(*args,**kwargs) # required for mixin usage
    def __getstate__(self):
        "Return the instance `__dict__` minus `_state` and every attribute named in `_stateattrs`"
        attrs = self._stateattrs
        # Accept any iterable of names (tuple, list, set). A bare string -- e.g. `_stateattrs=('file')`,
        # which lacks the trailing comma of a 1-tuple -- is treated as a single name instead of raising.
        skip = {attrs} if isinstance(attrs, str) else set(attrs)
        skip.add('_state')
        return {k:v for k,v in self.__dict__.items() if k not in skip}
    def __setstate__(self, state):
        self.__dict__.update(state)
        # Rebuild the transient state that `__getstate__` left out
        self._init_state()
    def _init_state(self):
        "Override for custom init and deserialization logic"
        self._state = {}
# Display `Stateful`'s signature and docstring in the generated docs
show_doc(Stateful, title_level=4)
Stateful (*args, **kwargs)
A base class/mixin for objects that should not serialize all their state
class _T(Stateful):
    def __init__(self):
        super().__init__() # runs `_init_state`, which creates `self._state`
        self.a=1
        self._state['test']=2
t = _T()
t2 = pickle.loads(pickle.dumps(t)) # round-trip through pickle
test_eq(t.a,1)
test_eq(t._state['test'],2)
test_eq(t2.a,1) # ordinary attributes survive pickling
test_eq(t2._state,{}) # `_state` is not pickled; `_init_state` recreates it empty on unpickling
Override `_init_state` to do any necessary setup steps that are required during `__init__` or during deserialization (e.g. `pickle.load`). Here's an example of how `Stateful` simplifies the official Python example for Handling Stateful Objects.
class TextReader(Stateful):
    """Read and number lines in a text file."""
    _stateattrs=('file',)
    def __init__(self, filename):
        self.filename = filename
        self.lineno = 0
        # `Stateful.__init__` calls `_init_state`, which needs `filename` and `lineno` already set
        super().__init__()
    def readline(self):
        "Return the next line prefixed with its number, or `None` once the file is exhausted"
        self.lineno += 1
        line = self.file.readline()
        if not line:
            return None
        return f"{self.lineno}: {line.strip()}"
    def _init_state(self):
        # (Re)open the file and skip the lines already consumed, so a restored reader resumes in place
        self.file = open(self.filename)
        for _ in range(self.lineno):
            self.file.readline()
reader = TextReader("00_test.ipynb")
print(reader.readline())
print(reader.readline())
# Round-trip through pickle: the open file is not pickled; `_init_state` reopens it and skips read lines
new_reader = pickle.loads(pickle.dumps(reader))
print(new_reader.readline()) # the restored reader resumes at line 3
1: { 2: "cells": [ 3: {
#|export
class PrettyString(str):
    "Little hack to get strings to show properly in Jupyter."
    def __repr__(self):
        # Hand back the string itself rather than a quoted, escaped literal,
        # so newlines and tabs show up as real line breaks and tabs when displayed.
        return self
# Display `PrettyString`'s signature and docstring in the generated docs
show_doc(PrettyString, title_level=4)
Allow strings with special characters to render properly in Jupyter. Without calling `print()`, strings with special characters are displayed like so:
with_special_chars='a string\nwith\nnew\nlines and\ttabs'
with_special_chars # the default repr shows escape sequences instead of line breaks and tabs
'a string\nwith\nnew\nlines and\ttabs'
We can correct this with `PrettyString`:
PrettyString(with_special_chars) # its repr is the raw string, so newlines and tabs render
a string with new lines and tabs
#|export
def even_mults(start, stop, n):
    """Build a list of `n` values spaced evenly on a log scale from `start` to `stop`.

    Each value is the previous one times the constant ratio `(stop/start)**(1/(n-1))`.
    When `n==1` the scalar `stop` itself is returned (not a list); `n<=0` gives `[]`.
    """
    if n==1: return stop
    step = (stop/start)**(1/(n-1))
    res = [start*(step**i) for i in range(n)]
    # `start*step**(n-1)` can drift from `stop` through floating-point rounding; pin the endpoint exactly
    if res: res[-1] = stop
    return res
test_eq(even_mults(2,8,3), [2,4,8]) # ratio 8/2 over 2 steps -> multiply by 2 each step
test_eq(even_mults(2,32,5), [2,4,8,16,32])
test_eq(even_mults(2,8,1), 8) # with n==1 the scalar `stop` is returned, not a list
#|export
def num_cpus():
    """Get number of cpus.

    Prefers the number of CPUs this process may run on (`os.sched_getaffinity`); where that
    call is unavailable, falls back to `os.cpu_count()`, or to 1 if that can't determine a count.
    """
    try: return len(os.sched_getaffinity(0))
    # `sched_getaffinity` is not provided on every platform; `cpu_count()` may return None
    except AttributeError: return os.cpu_count() or 1
defaults.cpus = num_cpus() # store the CPU count on the shared `defaults` namespace
num_cpus()
4
#|export
def add_props(f, g=None, n=2):
    """Create `n` properties, the i-th of which passes `i` (for `i` in `range(n)`) to f.

    `partial(f,i)` becomes each property's getter and, when `g` is given, `partial(g,i)` its setter.
    Returns a generator, so the result can be unpacked straight into class attributes.
    """
    return (property(partial(f,i), None if g is None else partial(g,i)) for i in range(n))
class _T(): a,b = add_props(lambda i,x:i*2) # `a` passes i=0 and `b` passes i=1; `x` is the instance
t = _T()
test_eq(t.a,0)
test_eq(t.b,2)
class _T():
    def __init__(self, v): self.v=v
    # setter with `i` first, then the instance and the new value, matching how `add_props` binds `i`
    def _set(i, self, v): self.v[i] = v
    a,b = add_props(lambda i,x: x.v[i], _set)
t = _T([0,2])
test_eq(t.a,0)
test_eq(t.b,2)
t.a = t.a+1 # assignments go through `_set`, updating the underlying list
t.b = 3
test_eq(t.a,1)
test_eq(t.b,3)
#|export
def _typeerr(arg, val, typ): return TypeError(f"{arg}=={val} not {typ}")
#|export
def typed(f):
    """Decorator to check param and return types at runtime.

    Each argument actually passed is checked with `isinstance` against its annotation (defaults are
    not checked); for `*args`/`**kwargs` every element/value is checked. If a `return` annotation
    exists, the result is checked too. A mismatch raises `TypeError`.
    """
    import inspect # local import keeps this fix self-contained
    P = inspect.Parameter
    sig = inspect.signature(f)
    anno = annotations(f)
    ret = anno.pop('return',None)
    def _check(name, val, typ):
        if not isinstance(val,typ): raise _typeerr(name, val, typ)
    def _f(*args,**kwargs):
        if anno:
            # `bind` maps every passed value to the parameter that receives it. Zipping positionals with
            # `co_varnames` is wrong once `f` has `*args` or keyword-only params, since those names sit
            # at different positions there (causing IndexErrors or checks against the wrong parameter).
            for k,v in sig.bind(*args,**kwargs).arguments.items():
                if k not in anno: continue
                kind = sig.parameters[k].kind
                if kind == P.VAR_POSITIONAL:
                    for o in v: _check(k, o, anno[k])
                elif kind == P.VAR_KEYWORD:
                    for name,o in v.items(): _check(name, o, anno[k])
                else: _check(k, v, anno[k])
        res = f(*args,**kwargs)
        if ret is not None and not isinstance(res,ret): raise _typeerr("return", res, ret)
        return res
    return functools.update_wrapper(_f, f)
`typed` validates argument types at runtime. This is in contrast to mypy, which only offers static type checking.
For example, a `TypeError` will be raised if we try to pass a float into the first argument (annotated as `int`) of the below function:
@typed
def discount(price:int, pct:float):
    return (1-pct) * price
# 100.0 is a float, but `price` is annotated as int
with ExceptionExpected(TypeError): discount(100.0, .1)
We can also allow multiple types by annotating a parameter with a union of types (here `int|float`), as illustrated below:
# NOTE(review): this `discount` is not decorated with `@typed`, so the call below performs no runtime
# type checking and the `int|float` annotation is never exercised -- confirm whether `@typed` was
# meant to be applied here (and that it works with union annotations on all supported Pythons).
def discount(price:int|float, pct:float):
    return (1-pct) * price
assert 90.0 == discount(100.0, .1)
@typed
def foo(a:int, b:str='a'): return a
test_eq(foo(1, '2'), 1)
with ExceptionExpected(TypeError): foo(1,2) # `b` must be a str
# return annotations are checked too
@typed
def foo()->str: return 1
with ExceptionExpected(TypeError): foo() # returns an int, not the annotated str
@typed
def foo()->str: return '1'
assert foo()
`typed` works with classes, too:
class Foo:
    @typed
    def __init__(self, a:int, b: int, c:str): pass
    # NOTE: despite its name, `cls` receives the instance here -- this is a regular method, not a classmethod
    @typed
    def test(cls, d:str): return d
with ExceptionExpected(TypeError): Foo(1, 2, 3) # `c` must be a str
with ExceptionExpected(TypeError): Foo(1,2, 'a string').test(10) # `d` must be a str
#|export
def exec_new(code):
    """Execute `code` in a new environment and return it.

    The fresh globals dict carries this module's `__name__`; `__package__` is `None` when running as
    `__main__`, otherwise the name of the current working directory.
    """
    env = dict(__name__=__name__, __package__=None if __name__=='__main__' else Path.cwd().name)
    exec(code, env)
    return env
g = exec_new('a=1')
test_eq(g['a'], 1) # names assigned by the code end up in the returned namespace
#|export
def exec_import(mod, sym):
    "Import `sym` from `mod` in a new environment, returning that environment's namespace"
    statement = f'from {mod} import {sym}'
    return exec_new(statement)
#|export
def str2bool(s):
    """Case-insensitive convert string `s` to a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`).

    `n`,`no`,`f`,`false`,`off`,`0` and the empty string give `False`. Non-string values are
    converted with `bool()`. Raises `ValueError` for any other string.
    """
    if not isinstance(s,str): return bool(s)
    if not s: return False
    low = s.lower()
    # return real bools (the docstring promises `True`); they still compare equal to 1/0
    if low in ('y', 'yes', 't', 'true', 'on', '1'): return True
    if low in ('n', 'no', 'f', 'false', 'off', '0'): return False
    raise ValueError(f"Invalid truth value {s!r}")
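True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0' (matching is case-insensitive). An empty string is false, and non-string values are converted with `bool()`. Raises `ValueError` if `s` is any other string.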
for o in "y YES t True on 1".split(): assert str2bool(o) # matching is case-insensitive
for o in "n no FALSE off 0".split(): assert not str2bool(o)
for o in 0,None,'',False: assert not str2bool(o) # '' is False; non-strings fall back to bool()
for o in 1,True: assert str2bool(o)
# Document the environment-detection helpers
show_doc(ipython_shell)
show_doc(in_ipython)
show_doc(in_colab)
show_doc(in_jupyter)
show_doc(in_notebook)
These values are also available as booleans in `fastcore.basics`, as `IN_IPYTHON`, `IN_JUPYTER`, `IN_COLAB` and `IN_NOTEBOOK`.
# the output below reflects the environment in which this notebook was executed
IN_IPYTHON, IN_JUPYTER, IN_COLAB, IN_NOTEBOOK
(True, True, False, True)
#|hide
# presumably writes this notebook's `#|export` cells out to the `basics` module (per `#|default_exp basics`)
import nbdev; nbdev.nbdev_export()