#default_exp dispatch
#export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *
from collections import defaultdict
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *
Basic single and dual parameter dispatch
#export
def lenient_issubclass(cls, types):
    """If possible return whether `cls` is a subclass of `types`, otherwise return False.

    `cls` also matches when it is an *instance* of `types` (the `isinstance` check below).
    Ordinary errors raised by either check (e.g. the `TypeError` that `issubclass` raises
    when `cls` is not a class) are treated as "not a subclass". `KeyboardInterrupt` and
    `SystemExit` are deliberately *not* swallowed, so they propagate to the caller.
    """
    if cls is object and types is not object: return False # treat `object` as highest level
    try: return isinstance(cls, types) or issubclass(cls, types)
    # `Exception` rather than a bare `except`, so interpreter-exit signals are not hidden
    except Exception: return False
# A `typing` alias is not reported as a subclass of a concrete class...
assert not lenient_issubclass(typing.Collection, list)
# ...while a concrete class can be a subclass of a `typing` alias
assert lenient_issubclass(list, typing.Collection)
# Anything checked against `object` matches
assert lenient_issubclass(typing.Collection, object)
# The relationship between two `typing` aliases is directional
assert lenient_issubclass(typing.List, typing.Collection)
assert not lenient_issubclass(typing.Collection, typing.List)
# `object` is only treated as a subclass of `object` itself
assert not lenient_issubclass(object, typing.Callable)
#export
def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):
    """Return a new list containing all items from the iterable sorted topologically.

    `cmp(a, b)` should return a truthy value when `a` must come before `b`. It only needs
    to define a partial order (e.g. `issubclass`), which is why a selection-style pass over
    the remaining items is used instead of `sorted`. With `reverse=True` the result is
    returned in the opposite order. The input iterable itself is not modified.
    """
    from functools import reduce
    remaining, res = list(iterable), []
    while remaining:
        # Keep the current candidate unless a later item must precede it
        smallest = reduce(lambda x, y: y if cmp(y, x) else x, remaining)
        res.append(smallest)
        remaining.remove(smallest)
    return res[::-1] if reverse else res
# With the default `cmp` (`operator.lt`) this behaves like an ascending sort
td = [3, 1, 2, 5]
test_eq(sorted_topologically(td), [1, 2, 3, 5])
test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])
# Iterating a dict yields its keys; with `lenient_issubclass` as `cmp`, subclasses come before superclasses
td = {int:1, numbers.Number:2, numbers.Integral:3}
test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])
# Mixed, mostly unrelated types: only the relative order of `int` and `numbers.Integral` is checked
td = [numbers.Integral, tuple, list, int, dict]
td = sorted_topologically(td, cmp=lenient_issubclass)
assert td.index(int) < td.index(numbers.Integral)
#export
def _chk_defaults(f, ann):
    "Placeholder for a check on `f`'s default parameter values; currently a no-op returning `None`"
    # Implementation removed until we can figure out how to do this without `inspect` module
    #     try: # Some callables don't have signatures, so ignore those errors
    #         params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]
    #         if any(p.default!=inspect.Parameter.empty for p in params):
    #             warn(f"{f.__name__} has default params. These will be ignored.")
    #     except ValueError: pass
    return None
#export
def _p2_anno(f):
    """Get the 1st 2 annotations of `f`, defaulting to `object`.

    Only entries reported by `type_hints(f)` are considered and the `return` entry is
    skipped. The result is always a list of exactly two items, padded with `object`.
    """
    annos = [typ for name, typ in type_hints(f).items() if name != 'return']
    if callable(f): _chk_defaults(f, annos)
    # Pad up to two entries; a non-positive multiplier adds nothing
    annos += [object] * (2 - len(annos))
    return annos[:2]
#hide
# No annotations at all: both slots default to `object`
def _f(a): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a, b): pass
test_eq(_p2_anno(_f), (object,object))
# A `None` annotation is reported as `NoneType`; the return annotation is ignored
def _f(a:None, b)->str: pass
test_eq(_p2_anno(_f), (NoneType,object))
def _f(a:str, b)->float: pass
test_eq(_p2_anno(_f), (str,object))
def _f(a:None, b:str)->float: pass
test_eq(_p2_anno(_f), (NoneType,str))
def _f(a:int, b:int)->float: pass
test_eq(_p2_anno(_f), (int,int))
# The unannotated `self` does not occupy a slot: the first two *annotations* are returned
def _f(self, a:int, b:int): pass
test_eq(_p2_anno(_f), (int,int))
def _f(a:int, b:str)->float: pass
test_eq(_p2_anno(_f), (int,str))
# A callable that is not a plain function (here an `attrgetter`) falls back to `object` for both slots
test_eq(_p2_anno(attrgetter('foo')), (object,object))
#hide
# Disabled until _chk_defaults fixed
# def _f(x:int,y:int=10): pass
# test_warns(lambda: _p2_anno(_f))
def _f(x:int,y=10): pass
_p2_anno(None),_p2_anno(_f)
([object, object], [int, object])
Type dispatch, or Multiple dispatch, allows you to change the way a function behaves based upon the input types it receives. This is a prominent feature in some programming languages like Julia. For example, this is a conceptual example of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:
collide_with(x::Asteroid, y::Asteroid) = ...
# deal with asteroid hitting asteroid
collide_with(x::Asteroid, y::Spaceship) = ...
# deal with asteroid hitting spaceship
collide_with(x::Spaceship, y::Asteroid) = ...
# deal with spaceship hitting asteroid
collide_with(x::Spaceship, y::Spaceship) = ...
# deal with spaceship hitting spaceship
Type dispatch can be especially useful in data science, where you might allow different input types (e.g. numpy arrays and pandas dataframes) to a function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.
The TypeDispatch
class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called when passed inputs.
#export
class _TypeDict:
    "Mapping from types to values, with keys kept in topological (subclass-first) order"
    def __init__(self):
        self.d = {}      # type -> value, re-sorted after every `add`
        self.cache = {}  # looked-up key -> list of matching values

    def _reset(self):
        "Re-sort `self.d` with `sorted_topologically` and drop all cached lookups"
        ordered = sorted_topologically(self.d, cmp=lenient_issubclass)
        self.d = {typ: self.d[typ] for typ in ordered}
        self.cache = {}

    def add(self, t, f):
        "Register `f` under type `t`; if `t` holds several types, `f` is registered under each"
        types = t if isinstance(t, tuple) else tuple(L(t))
        for typ in types: self.d[typ] = f
        self._reset()

    def all_matches(self, k):
        "Return (and cache) the values of every registered type that `k` is a lenient subclass of"
        if k in self.cache: return self.cache[k]
        found = [val for typ, val in self.d.items() if lenient_issubclass(k, typ)]
        self.cache[k] = found
        return found

    def __getitem__(self, k):
        "Return the value of the first registered type matching `k`, or `None` if nothing matches"
        matches = self.all_matches(k)
        return matches[0] if matches else None

    def __repr__(self): return repr(self.d)

    def first(self):
        "Return the result of `first` applied to the stored values"
        return first(self.d.values())
#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=(), bases=()):
        # Two-level lookup: first-arg type -> `_TypeDict` of (second-arg type -> function)
        self.funcs = _TypeDict()
        # `bases` is filtered with `is_not(None)` (presumably dropping `None` entries -- confirm)
        self.bases = L(bases).filter(is_not(None))
        for fn in L(funcs): self.add(fn)
        # Both are overwritten by `__get__` whenever this object is accessed as a class attribute
        self.inst = None
        self.owner = None

    def add(self, f):
        "Register `f` under the types of its first two annotations (see `_p2_anno`)"
        # A `staticmethod` wrapper is inspected via its `__func__`, but stored as the wrapper
        target = f.__func__ if isinstance(f, staticmethod) else f
        a0, a1 = _p2_anno(target)
        inner = self.funcs.d.get(a0)
        if inner is None:
            inner = _TypeDict()
            self.funcs.add(a0, inner)
        inner.add(a1, f)

    def first(self):
        "Get first function in ordered dict of type:func."
        return self.funcs.first().first()

    def returns(self, x):
        "Get the return type of annotation of `x`."
        return anno_ret(self[type(x)])

    def _attname(self, k):
        "Name used for `k` in `__repr__`: its `__name__` if it has one, else `str(k)`"
        return getattr(k, '__name__', str(k))

    def __repr__(self):
        lines = []
        for k in self.funcs.d:
            for l, v in self.funcs[k].d.items():
                fname = getattr(v, "__name__", type(v).__name__)
                lines.append(f'({self._attname(k)},{self._attname(l)}) -> {fname}')
        lines.extend(o.__repr__() for o in self.bases)
        return '\n'.join(lines)

    def __call__(self, *args, **kwargs):
        "Call the function matching the types of the first two positional args; if none matches, return `args[0]`"
        arg_types = tuple(L(args).map(type)[:2])
        f = self[arg_types]
        if not f: return args[0]
        if isinstance(f, staticmethod): f = f.__func__
        # Bind to the instance if accessed through one, otherwise to the owning class
        elif self.inst is not None: f = MethodType(f, self.inst)
        elif self.owner is not None: f = MethodType(f, self.owner)
        return f(*args, **kwargs)

    def __get__(self, inst, owner):
        "Descriptor hook: remember the accessing instance/class for `__call__` and return `self`"
        self.inst = inst
        self.owner = owner
        return self

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        k = L(k)
        # A missing second (or first) type is treated as `object`
        while len(k) < 2: k.append(object)
        first_typ, second_typ = k[0], k[1]
        for inner in self.funcs.all_matches(first_typ):
            found = inner[second_typ]
            if found is not None: return found
        # Nothing registered here matched: fall back to the bases, in order
        for base in self.bases:
            found = base[k]
            if found is not None: return found
        return None
To demonstrate how TypeDispatch
works, we define a set of functions that accept a variety of input types, specified with different type annotations:
def f2(x:int, y:float): return x+y #int and float for 2nd arg
def f_nin(x:numbers.Integral)->int: return x+1 #integral numeric
def f_ni2(x:int): return x #integer
def f_bll(x:(bool,list)): return x #bool or list
def f_num(x:numbers.Number): return x #Number (root of numerics)
We can optionally initialize TypeDispatch
with a list of functions we want to search. Printing an instance of TypeDispatch
will display convenient mapping of types -> functions:
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t
(bool,object) -> f_bll (int,object) -> f_ni2 (Integral,object) -> f_nin (Number,object) -> f_num (list,object) -> f_bll (object,object) -> NoneType
Note that only the first two arguments are used for TypeDispatch
. If your function only contains one argument, the second parameter will be shown as object
. If you pass None
into TypeDispatch
, then this will be displayed as (object, object) -> NoneType
.
TypeDispatch
is a dictionary-like object, which means that you can retrieve a function by the associated type annotation. For example, the statement:
t[float]
Will return f_num
because that is the matching function that has a type annotation that is a super-class of float
- numbers.Number
:
assert issubclass(float, numbers.Number)
test_eq(t[float], f_num)
The same is true for other types as well:
test_eq(t[np.int32], f_nin)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)
test_eq(t[np.int32], f_nin)
If you try to get a type that doesn't match, TypeDispatch
will return None
:
test_eq(t[str], None)
show_doc(TypeDispatch.add)
This method allows you to add an additional function to an existing TypeDispatch
instance :
def f_col(x:typing.Collection): return x
t.add(f_col)
test_eq(t[str], f_col)
t
(bool,object) -> f_bll (int,object) -> f_ni2 (Integral,object) -> f_nin (Number,object) -> f_num (list,object) -> f_bll (typing.Collection,object) -> f_col (object,object) -> NoneType
If you accidentally add the same function more than once things will still work as expected:
t.add(f_ni2)
test_eq(t[int], f_ni2)
However, if you add a function that has a type collision that raises an ambiguity, this will automatically resolve to the latest function added:
def f_ni3(z:int): return z # collides with f_ni2 with same type annotations
t.add(f_ni3)
test_eq(t[int], f_ni3)
bases
:
The argument bases
can optionally accept a single instance of TypeDispatch
or a collection (i.e. a tuple or list) of TypeDispatch
objects. This can provide functionality similar to multiple inheritance.
These are searched for matching functions if no match in your list of functions:
def f_str(x:str): return x+'1'
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.
t2
(str,object) -> f_str (bool,object) -> f_bll (int,object) -> f_ni2 (Integral,object) -> f_nin (Number,object) -> f_num (list,object) -> f_bll (object,object) -> NoneType
test_eq(t2[int], f_ni2) # searches `t` b/c not found in `t2`
test_eq(t2[np.int32], f_nin) # searches `t` b/c not found in `t2`
test_eq(t2[float], f_num) # searches `t` b/c not found in `t2`
test_eq(t2[bool], f_bll) # searches `t` b/c not found in `t2`
test_eq(t2[str], f_str) # found in `t2`!
test_eq(t2('a'), 'a1') # found in `t2`!, and uses __call__
o = np.int32(1)
test_eq(t2(o), 2) # not in `t2`, so found in base `t` (f_nin), and uses __call__
TypeDispatch
supports up to two arguments when searching for the appropriate function. The following functions f1
and f2
both have two parameters:
def f1(x:numbers.Integral, y): return x+1 #Integral is a numeric type
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
t
(int,float) -> f2 (Integral,object) -> f1
You can lookup functions from a TypeDispatch
instance with two parameters like this:
test_eq(t[np.int32], f1)
test_eq(t[int,float], f2)
Keep in mind that anything beyond the first two parameters is ignored, and any collisions will be resolved in favor of the most recent function added. In the below example, f1
is ignored in favor of f2
because the first two parameters have identical type hints:
def f1(a:str, b:int, c:list): return a
def f2(a: str, b:int): return b
t = TypeDispatch([f1,f2])
test_eq(t[str, int], f2)
t
(str,int) -> f2
Type Dispatch
matches types with functions according to whether the supplied class is a subclass or the same class of the type annotation(s) of associated functions.
Let's consider an example where we try to retrieve the function corresponding to types of [np.int32, float]
.
In this scenario, f2
will not be matched. This is because the first type annotation of f2
, int
, is not a superclass (or the same class) of np.int32
:
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
assert not issubclass(np.int32, int)
Instead, f1
is a valid match, as its first argument is annotated with the type numbers.Integral
, which np.int32
is a subclass of:
assert issubclass(np.int32, numbers.Integral)
test_eq(t[np.int32,float], f1)
In f1
, the 2nd parameter y
is not annotated, which means TypeDispatch
will match anything where the first argument matches int
that is not matched with anything else:
assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral
test_eq(t[int], f1)
test_eq(t[int,int], f1)
If no match is possible, None
is returned:
test_eq(t[float,float], None)
show_doc(TypeDispatch.__call__)
TypeDispatch
is also callable. When you call an instance of TypeDispatch
, it will execute the relevant function:
def f_arr(x:np.ndarray): return x.sum()
def f_int(x:np.int32): return x+1
t = TypeDispatch([f_arr, f_int])
arr = np.array([5,4,3,2,1])
test_eq(t(arr), 15) # dispatches to f_arr
o = np.int32(1)
test_eq(t(o), 2) # dispatches to f_int
assert t.first() is not None
You can also call an instance of TypeDispatch
when there are two parameters:
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)
When no match is found, a TypeDispatch
instance becomes an identity function. This default behavior is leveraged by fastai for data transformations to provide a sensible default when a matching function cannot be found.
test_eq(t('a'), 'a')
show_doc(TypeDispatch.returns)
You can optionally pass an object to TypeDispatch.returns
and get the return type annotation back:
def f1(x:int) -> np.ndarray: return np.array(x)
def f2(x:str) -> float: return List
def f3(x:float): return List # f3 has no return type annotation
t = TypeDispatch([f1, f2, f3])
test_eq(t.returns(1), np.ndarray) # dispatched to f1
test_eq(t.returns('Hello'), float) # dispatched to f2
test_eq(t.returns(1.0), None) # dispatched to f3
class _Test: pass
_test = _Test()
test_eq(t.returns(_test), None) # type `_Test` not found, so None returned
You can use TypeDispatch
when defining methods as well:
def m_nin(self, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x*2
t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch instance
a = A()
test_eq(a.f(1), '11') #dispatch to m_nin
test_eq(a.f(1.), 2.) #dispatch to m_num
test_is(a.f.inst, a)
a.f(False) # this triggers t.m_bll to run, which sets self.foo to 'a'
test_eq(a.foo, 'a')
As discussed in TypeDispatch.__call__
, when there is not a match, TypeDispatch.__call__
becomes an identity function. In the below example, a tuple does not match any type annotations so a tuple is returned:
test_eq(a.f(()), ())
We extend the previous example by using bases
to add an additional method that supports tuples:
def m_tup(self, x:tuple): return x+(1,)
t2 = TypeDispatch(m_tup, bases=t)
class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 2.)
test_is(a2.f.inst, a2)
a2.f(False)
test_eq(a2.foo, 'a')
test_eq(a2.f(()), (1,))
You can use TypeDispatch
when defining class methods too:
def m_nin(cls, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(cls, x:bool): cls.foo='a'
def m_num(cls, x:numbers.Number): return x*2
t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch
test_eq(A.f(1), '11') #dispatch to m_nin
test_eq(A.f(1.), 2.) #dispatch to m_num
test_is(A.f.owner, A)
A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'
test_eq(A.foo, 'a')
#export
class DispatchReg:
    "A global registry for `TypeDispatch` objects keyed by function name"
    def __init__(self):
        # Unknown names get a fresh, empty `TypeDispatch` on first access
        self.d = defaultdict(TypeDispatch)

    def __call__(self, f):
        "Add `f` to the `TypeDispatch` stored under its `__qualname__` and return that `TypeDispatch`"
        wrapped = isinstance(f, (classmethod, staticmethod))
        # `classmethod`/`staticmethod` wrappers are named after the function they wrap
        nm = f'{f.__func__.__qualname__}' if wrapped else f'{f.__qualname__}'
        # A `classmethod` is stored unwrapped; a `staticmethod` is stored as the wrapper
        if isinstance(f, classmethod): f = f.__func__
        dispatcher = self.d[nm]
        dispatcher.add(f)
        return dispatcher
# Module-level registry instance, used below as the `@typedispatch` decorator
typedispatch = DispatchReg()
# All four definitions share the name `f_td_test`, so they are added to the same `TypeDispatch`
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y
@typedispatch
def f_td_test(x:int, y:int): return x*y
# (int, float) -> x+y
test_eq(f_td_test(3,2.0), 5)
assert issubclass(int, numbers.Integral)
# (int, int) -> x*y, even though `int` also matches the `numbers.Integral` version
test_eq(f_td_test(3,2), 6)
# (str, str) only matches the unannotated version
test_eq(f_td_test('a','b'), 'ab')
You can use typedispatch
with classmethod
and staticmethod
decorator
class A:
@typedispatch
def f_td_test(self, x:numbers.Integral, y): return x+1
@typedispatch
@classmethod
def f_td_test(cls, x:int, y:float): return x+y
@typedispatch
@staticmethod
def f_td_test(x:int, y:int): return x*y
test_eq(A.f_td_test(3,2), 6)
test_eq(A.f_td_test(3,2.0), 5)
test_eq(A().f_td_test(3,'2.0'), 4)
Now that we can dispatch on types, let's make it easier to cast objects to a different type.
#export
_all_=['cast']
#export
def retain_meta(x, res, as_copy=False):
    """Call `res.set_meta(x)`, if it exists.

    `as_copy` is forwarded to `set_meta` as a keyword argument. `res` is always returned,
    whether or not it has a `set_meta` attribute.
    """
    if not hasattr(res, 'set_meta'): return res
    res.set_meta(x, as_copy=as_copy)
    return res
#export
def default_set_meta(self, x, as_copy=False):
    """Copy over `_meta` from `x` to `self`, if it's missing.

    Nothing happens when `x` has no `_meta` or `self` already has one. With `as_copy=True`
    a shallow `copy` of `x._meta` is stored instead of the same object. Returns `self`.
    """
    if not hasattr(x, '_meta') or hasattr(self, '_meta'): return self
    self._meta = copy(x._meta) if as_copy else x._meta
    return self
#export
@typedispatch
def cast(x, typ):
    """cast `x` to type `typ` (may also change `x` inplace)

    Steps, in order:
    - if `typ` defines `_before_cast`, it is applied to `x` first;
    - objects matched by `risinstance('ndarray', ...)` are converted with `.view(typ)`;
    - objects with an `as_subclass` method are converted with `.as_subclass(typ)`;
    - otherwise `__class__` is reassigned in place, falling back to `typ(res)` if that fails.
    The result is passed through `retain_meta(x, res)` before being returned.
    """
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if risinstance('ndarray', res): res = res.view(typ)
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        try: res.__class__ = typ
        # `Exception` rather than a bare `except`, so KeyboardInterrupt/SystemExit still propagate
        except Exception: res = typ(res)
    return retain_meta(x, res)
This works both for plain python classes:...
mk_class('_T1', 'a') # mk_class is a fastai utility that constructs a class.
class _T2(_T1): pass
t = _T1(a=1)
t2 = cast(t, _T2)
assert t2 is t # t2 refers to the same object as t
assert isinstance(t, _T2) # t also changed in-place
assert isinstance(t2, _T2)
test_eq_type(_T2(a=1), t2)
...as well as for arrays and tensors.
class _T1(ndarray): pass
t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))
To customize casting for other types, define a separate cast
function with typedispatch
for your type.
#export
def retain_type(new, old=None, typ=None, as_copy=False):
    """Cast `new` to type of `old` or `typ` if it's a superclass.

    Returns `None` when `new` is `None`. At least one of `old`/`typ` must be given.
    When `typ` is omitted it is taken from `old`, but only if `old` is an instance of
    `type(new)`; otherwise `new` is returned untouched. `as_copy` is forwarded to
    `retain_meta`.
    """
    # e.g. old is TensorImage, new is Tensor - if not subclass then do nothing
    if new is None: return None
    assert not (old is None and typ is None)
    if typ is None:
        if not isinstance(old, type(new)): return new
        typ = old if isinstance(old, type) else type(old)
    # Nothing to do when the target is `NoneType` or `new` is already an instance of it
    already_ok = typ==NoneType or isinstance(new, typ)
    if already_ok: return new
    casted = cast(new, typ)
    return retain_meta(old, casted, as_copy=as_copy)
class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
c = retain_type(b, typ=_T)
test_eq_type(c, a)
If old
has a _meta
attribute, its content is passed when casting new
to the type of old
. In the below example, only the attribute a
, but not other_attr
is kept, because other_attr
is not in _meta
:
class _A():
set_meta = default_set_meta
def __init__(self, t): self.t=t
class _B1(_A):
def __init__(self, t, a=1):
super().__init__(t)
self._meta = {'a':a}
self.other_attr = 'Hello' # will not be kept after casting.
x = _B1(1, a=2)
b = _A(1)
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)
#export
def retain_types(new, old=None, typs=None):
    """Cast each item of `new` to type of matching item in `old` if it's a superclass.

    Non-listy `new` is handed straight to `retain_type`. For listy `new` the container type is:
    - the first key of `typs` when `typs` is a dict (its value becomes the per-item `typs`);
    - `typs` itself when it is given but is not a dict (no per-item `typs` then);
    - otherwise `type(old)` if `old` is an instance of `type(new)`, else `type(new)`.
    Items are then processed recursively.
    """
    if not is_listy(new): return retain_type(new, old, typs)
    if typs is None:
        use_old = old is not None and isinstance(old, type(new))
        t = type(old) if use_old else type(new)
    elif isinstance(typs, dict):
        t = first(typs.keys())
        typs = typs[t]
    else: t,typs = typs,None
    return t(L(new, old, typs).map_zip(retain_types, cycled=True))
class T(tuple): pass
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
#export
def explode_types(o):
    "Return the type of `o`; for listy `o`, a dict `{type(o): [exploded types of its items]}`, nested as needed"
    if is_listy(o): return {type(o): [explode_types(item) for item in o]}
    return type(o)
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})
#hide
from nbdev.export import notebook2script
notebook2script()
Converted 00_test.ipynb. Converted 01_basics.ipynb. Converted 02_foundation.ipynb. Converted 03_xtras.ipynb. Converted 03a_parallel.ipynb. Converted 03b_net.ipynb. Converted 04_dispatch.ipynb. Converted 05_transform.ipynb. Converted 07_meta.ipynb. Converted 08_script.ipynb. Converted index.ipynb.