#|default_exp dispatch

#|export
from __future__ import annotations
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *
from collections import defaultdict

from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

#|export
def lenient_issubclass(cls, types):
    "If possible return whether `cls` is a subclass (or an instance) of `types`, otherwise return False."
    # Treat `object` as the highest level: it is only a "subclass" of itself
    if cls is object and types is not object:
        return False
    try:
        return isinstance(cls, types) or issubclass(cls, types)
    # The checks can raise for arguments that are not usable as classes; that means
    # "not a subclass". `Exception` (rather than a bare `except`) lets
    # KeyboardInterrupt/SystemExit propagate.
    except Exception:
        return False

assert not lenient_issubclass(typing.Collection, list)
assert lenient_issubclass(list, typing.Collection)
assert lenient_issubclass(typing.Collection, object)
assert lenient_issubclass(typing.List, typing.Collection)
assert not lenient_issubclass(typing.Collection, typing.List)
assert not lenient_issubclass(object, typing.Callable)

#|export
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
    "Return a new list containing all items from the iterable sorted topologically"
    # `cmp(a, b)` being truthy means `a` must come before `b`.
    l,res = L(list(iterable)),[]
    for _ in range(len(l)):
        # Fold over the remaining items, keeping whichever one `cmp` ranks first
        t = l.reduce(lambda x,y: y if cmp(y,x) else x)
        res.append(t)
        l.remove(t)
    return res[::-1] if reverse else res

td = [3, 1, 2, 5]
test_eq(sorted_topologically(td), [1, 2, 3, 5])
test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])

td = {int:1, numbers.Number:2, numbers.Integral:3}
test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])

td = [numbers.Integral, tuple, list, int, dict]
td = sorted_topologically(td, cmp=lenient_issubclass)
assert td.index(int) < td.index(numbers.Integral)

#|export
def _chk_defaults(f, ann):
    "Currently a no-op placeholder; see the disabled implementation below"
    pass
# Implementation removed until we can figure out how to do this without `inspect` module
#     try: # Some callables don't have signatures, so ignore those errors
#         params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
#         if any(p.default!=inspect.Parameter.empty for p in params):
#             warn(f"{f.__name__} has default params. These will be ignored.")
#     except ValueError: pass

#|export
def _p2_anno(f):
    "Get the 1st 2 annotations of `f`, defaulting to `object`"
    hints = type_hints(f)
    # The return annotation does not take part in dispatch
    ann = [o for n,o in hints.items() if n!='return']
    if callable(f):
        _chk_defaults(f, ann)
    # Pad so that callers can always unpack exactly two annotations
    while len(ann)<2:
        ann.append(object)
    return ann[:2]

#|hide
def _f(a): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a, b): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a:None, b)->str: pass
test_eq(_p2_anno(_f), (NoneType,object))
def _f(a:str, b)->float: pass
test_eq(_p2_anno(_f), (str,object))
def _f(a:None, b:str)->float: pass
test_eq(_p2_anno(_f), (NoneType,str))
def _f(a:int, b:int)->float: pass
test_eq(_p2_anno(_f), (int,int))
def _f(self, a:int, b:int): pass
test_eq(_p2_anno(_f), (int,int))
def _f(a:int, b:str)->float: pass
test_eq(_p2_anno(_f), (int,str))
test_eq(_p2_anno(attrgetter('foo')), (object,object))

#|hide
# Disabled until _chk_defaults fixed
# def _f(x:int,y:int=10): pass
# test_warns(lambda: _p2_anno(_f))
def _f(x:int,y=10): pass
_p2_anno(None),_p2_anno(_f)

#|export
class _TypeDict:
    "Dict of type -> value whose keys are kept sorted with subclasses before their super-classes"
    def __init__(self):
        # `d` holds the ordered type -> value mapping; `cache` memoises `all_matches`
        self.d,self.cache = {},{}

    def _reset(self):
        "Re-sort `d` so that subclasses precede their super-classes, and drop the lookup cache"
        self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}
        self.cache = {}

    def add(self, t, f):
        "Add type `t` (or each type in tuple/union `t`) mapped to `f`"
        if not isinstance(t, tuple):
            t = tuple(L(union2tuple(t)))
        for t_ in t:
            self.d[t_] = f
        self._reset()

    def all_matches(self, k):
        "Find the values of all types that are a super-class of `k`, in the order of `d`"
        if k not in self.cache:
            types = [f for f in self.d if lenient_issubclass(k,f)]
            self.cache[k] = [self.d[o] for o in types]
        return self.cache[k]

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`, or `None` if there is none"
        res = self.all_matches(k)
        return res[0] if len(res) else None

    def __repr__(self):
        return self.d.__repr__()

    def first(self):
        "Return the first value stored in `d`"
        return first(self.d.values())

#|export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=(), bases=()):
        # `funcs` maps 1st-arg type -> `_TypeDict` of 2nd-arg type -> function;
        # `bases` are fallback dispatchers consulted by `__getitem__` when nothing matches here
        self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
        for o in L(funcs):
            self.add(o)
        # Set by `__get__` when the dispatcher is accessed as a class attribute
        self.inst = None
        self.owner = None

    def add(self, f):
        "Add function `f`, keyed by the annotations of its first two parameters"
        # A `staticmethod` wrapper is stored as-is; only its annotations come from `__func__`
        if isinstance(f, staticmethod):
            a0,a1 = _p2_anno(f.__func__)
        else:
            a0,a1 = _p2_anno(f)
        # NOTE(review): the lookup uses the raw annotation `a0`, while `_TypeDict.add` stores
        # the expanded member types; presumably a union `a0` never hits here, so a fresh
        # `_TypeDict` is registered for every member type - confirm this is intended.
        t = self.funcs.d.get(a0)
        if t is None:
            t = _TypeDict()
            self.funcs.add(a0, t)
        t.add(a1, f)

    def first(self):
        "Get first function in ordered dict of type:func."
        return self.funcs.first().first()

    def returns(self, x):
        "Get the return annotation of the function that `type(x)` dispatches to."
        return anno_ret(self[type(x)])

    def _attname(self,k):
        return getattr(k,'__name__',str(k))

    def __repr__(self):
        r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}'
             for k in self.funcs.d for l,v in self.funcs[k].d.items()]
        r = r + [o.__repr__() for o in self.bases]
        return '\n'.join(r)

    def __call__(self, *args, **kwargs):
        "Dispatch on the types of the first two positional args; return `args[0]` if nothing matches"
        ts = L(args).map(type)[:2]
        f = self[tuple(ts)]
        if not f:
            return args[0]
        # Bind to the instance (or owner class) recorded by `__get__`, except for staticmethods
        if isinstance(f, staticmethod):
            f = f.__func__
        elif self.inst is not None:
            f = MethodType(f, self.inst)
        elif self.owner is not None:
            f = MethodType(f, self.owner)
        return f(*args, **kwargs)

    def __get__(self, inst, owner):
        "Record the instance and owner this dispatcher was accessed through, for use in `__call__`"
        self.inst = inst
        self.owner = owner
        return self

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        k = L(k)
        # A missing 1st/2nd type is treated as `object`
        while len(k)<2:
            k.append(object)
        r = self.funcs.all_matches(k[0])
        for t in r:
            o = t[k[1]]
            if o is not None:
                return o
        # Nothing registered here matches: fall back to the base dispatchers, in order
        for base in self.bases:
            res = base[k]
            if res is not None:
                return res
        return None

def f2(x:int, y:float): return x+y              #int and float for 2nd arg
def f_nin(x:numbers.Integral)->int:  return x+1 #integral numeric
def f_ni2(x:int): return x                      #integer
def f_bll(x:bool|list): return x                #bool or list
def f_num(x:numbers.Number): return x           #Number (root of numerics)

t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t

assert issubclass(float, numbers.Number)
test_eq(t[float], f_num)
test_eq(t[np.int32], f_nin)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)
test_eq(t[np.int32], f_nin)
test_eq(t[str], None)

show_doc(TypeDispatch.add)

def f_col(x:typing.Collection): return x
t.add(f_col)
test_eq(t[str], f_col)
t

t.add(f_ni2)
test_eq(t[int], f_ni2)

def f_ni3(z:int): return z # collides with f_ni2 with same type annotations
t.add(f_ni3)
test_eq(t[int], f_ni3)

def f_str(x:str): return x+'1'

t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.
t2

test_eq(t2[int], f_ni2)       # searches `t` b/c not found in `t2`
test_eq(t2[np.int32], f_nin)  # searches `t` b/c not found in `t2`
test_eq(t2[float], f_num)     # searches `t` b/c not found in `t2`
test_eq(t2[bool], f_bll)      # searches `t` b/c not found in `t2`
test_eq(t2[str], f_str)       # found in `t2`!
test_eq(t2('a'), 'a1')        # found in `t2`!, and uses __call__

o = np.int32(1)
test_eq(t2(o), 2)             # found in base `t` and uses __call__

def f1(x:numbers.Integral, y): return x+1  #Integral is a numeric type
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
t

test_eq(t[np.int32], f1)
test_eq(t[int,float], f2)

def f1(a:str, b:int, c:list): return a
def f2(a: str, b:int): return b
t = TypeDispatch([f1,f2])
test_eq(t[str, int], f2)
t

def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

assert not issubclass(np.int32, int)
assert issubclass(np.int32, numbers.Integral)
test_eq(t[np.int32,float], f1)

assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral
test_eq(t[int], f1)
test_eq(t[int,int], f1)

test_eq(t[float,float], None)

show_doc(TypeDispatch.__call__)

def f_arr(x:np.ndarray): return x.sum()
def f_int(x:np.int32): return x+1
t = TypeDispatch([f_arr, f_int])

arr = np.array([5,4,3,2,1])
test_eq(t(arr), 15) # dispatches to f_arr

o = np.int32(1)
test_eq(t(o), 2) # dispatches to f_int

assert t.first() is not None

def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)

test_eq(t('a'), 'a')

show_doc(TypeDispatch.returns)

def f1(x:int) -> np.ndarray: return np.array(x)
def f2(x:str) -> float: return List
def f3(x:float): return List # f3 has no return type annotation

t = TypeDispatch([f1, f2, f3])

test_eq(t.returns(1), np.ndarray)   # dispatched to f1
test_eq(t.returns('Hello'), float)  # dispatched to f2
test_eq(t.returns(1.0), None)       # dispatched to f3

class _Test: pass
_test = _Test()
test_eq(t.returns(_test), None) # type `_Test` not found, so None returned

def m_nin(self, x:str|numbers.Integral): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x*2

t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch instance

a = A()
test_eq(a.f(1), '11')  #dispatch to m_nin
test_eq(a.f(1.), 2.)   #dispatch to m_num
test_is(a.f.inst, a)

a.f(False) # this triggers t.m_bll to run, which sets self.foo to 'a'
test_eq(a.foo, 'a')

test_eq(a.f(()), ())

def m_tup(self, x:tuple): return x+(1,)
t2 = TypeDispatch(m_tup, bases=t)

class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 2.)
test_is(a2.f.inst, a2)
a2.f(False)
test_eq(a2.foo, 'a')
test_eq(a2.f(()), (1,))

def m_nin(cls, x:str|numbers.Integral): return str(x)+'1'
def m_bll(cls, x:bool): cls.foo='a'
def m_num(cls, x:numbers.Number): return x*2

t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch

test_eq(A.f(1), '11')  #dispatch to m_nin
test_eq(A.f(1.), 2.)
#dispatch to m_num
test_is(A.f.owner, A)

A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'
test_eq(A.foo, 'a')

#|export
class DispatchReg:
    "A global registry for `TypeDispatch` objects keyed by function name"
    def __init__(self):
        # Qualified function name -> `TypeDispatch`, created on first use
        self.d = defaultdict(TypeDispatch)

    def __call__(self, f):
        "Register `f` under its `__qualname__` and return the `TypeDispatch` for that name"
        # classmethod/staticmethod wrappers expose the real function through `__func__`
        if isinstance(f, (classmethod, staticmethod)):
            nm = f.__func__.__qualname__
        else:
            nm = f.__qualname__
        # Only classmethods are unwrapped; a staticmethod is passed on still wrapped,
        # and `TypeDispatch` checks for `staticmethod` itself
        if isinstance(f, classmethod):
            f = f.__func__
        self.d[nm].add(f)
        return self.d[nm]

typedispatch = DispatchReg()

@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral|int, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y
@typedispatch
def f_td_test(x:int, y:int): return x*y

test_eq(f_td_test(3,2.0), 5)
assert issubclass(int, numbers.Integral)
test_eq(f_td_test(3,2), 6)

test_eq(f_td_test('a','b'), 'ab')

class A:
    @typedispatch
    def f_td_test(self, x:numbers.Integral, y): return x+1
    @typedispatch
    @classmethod
    def f_td_test(cls, x:int, y:float): return x+y
    @typedispatch
    @staticmethod
    def f_td_test(x:int, y:int): return x*y

test_eq(A.f_td_test(3,2), 6)
test_eq(A.f_td_test(3,2.0), 5)
test_eq(A().f_td_test(3,'2.0'), 4)

#|export
_all_=['cast']

#|export
def retain_meta(x, res, as_copy=False):
    "Call `res.set_meta(x)`, if it exists"
    if hasattr(res,'set_meta'):
        res.set_meta(x, as_copy=as_copy)
    return res

#|export
def default_set_meta(self, x, as_copy=False):
    "Copy over `_meta` from `x` to `self`, if it's missing"
    # Only fill in `_meta` when `x` has one and `self` does not have one yet
    if hasattr(x, '_meta') and not hasattr(self, '_meta'):
        meta = x._meta
        if as_copy:
            meta = copy(meta)
        self._meta = meta
    return self

#|export
@typedispatch
def cast(x, typ):
    "cast `x` to type `typ` (may also change `x` inplace)"
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if risinstance('ndarray', res):
        res = res.view(typ)
    elif hasattr(res, 'as_subclass'):
        res = res.as_subclass(typ)
    else:
        # Try to retype the object in place; if that is not allowed, build a new `typ` from it.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
        try:
            res.__class__ = typ
        except Exception:
            res = typ(res)
    return retain_meta(x, res)

mk_class('_T1', 'a')
# mk_class is a fastai utility that constructs a class.
class _T2(_T1): pass

t = _T1(a=1)
t2 = cast(t, _T2)
assert t2 is t             # t2 refers to the same object as t
assert isinstance(t, _T2)  # t also changed in-place
assert isinstance(t2, _T2)

test_eq_type(_T2(a=1), t2)

class _T1(ndarray): pass

t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))

#|export
def retain_type(new, old=None, typ=None, as_copy=False):
    "Cast `new` to type of `old` or `typ` if it's a superclass"
    if new is None:
        return None
    assert old is not None or typ is not None
    if typ is None:
        # e.g. old is TensorImage, new is Tensor - if not subclass then do nothing
        if not isinstance(old, type(new)):
            return new
        # `old` may be given either as a type or as an instance of the target type
        typ = old if isinstance(old,type) else type(old)
    # Do nothing if `new` is already an instance of the requested type (i.e. same type)
    already_typed = typ==NoneType or isinstance(new, typ)
    if already_typed:
        return new
    converted = cast(new, typ)
    return retain_meta(old, converted, as_copy=as_copy)

class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
c = retain_type(b, typ=_T)
test_eq_type(c, a)

class _A():
    set_meta = default_set_meta
    def __init__(self, t): self.t=t

class _B1(_A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a':a}
        self.other_attr = 'Hello' # will not be kept after casting.
x = _B1(1, a=2)
b = _A(1)
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)

#|export
def retain_types(new, old=None, typs=None):
    "Cast each item of `new` to type of matching item in `old` if it's a superclass"
    # Non-listy values are handled directly by `retain_type`
    if not is_listy(new):
        return retain_type(new, old, typs)
    if typs is None:
        # No explicit types: the container type comes from `old` when `old` is an
        # instance of `new`'s type, otherwise `new` keeps its own container type
        t = type(old) if old is not None and isinstance(old,type(new)) else type(new)
    elif isinstance(typs, dict):
        # `{container_type: [item types...]}` - the format produced by `explode_types`
        t = first(typs.keys())
        typs = typs[t]
    else:
        t,typs = typs,None
    # Recurse item by item, then rebuild the container as type `t`
    return t(L(new, old, typs).map_zip(retain_types, cycled=True))

class T(tuple): pass

t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

#|export
def explode_types(o):
    "Return the type of `o`, potentially in nested dictionaries for things that are listy"
    if is_listy(o):
        # Listy values become `{container type: [exploded type of each item]}`
        return {type(o): [explode_types(item) for item in o]}
    return type(o)

test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})

#|hide
#|eval: false
from nbdev import nbdev_export
nbdev_export()