#|default_exp transform
#|export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *
from fastcore.dispatch import *
import inspect
from __future__ import annotations
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *
Definition of
Transform
and Pipeline
The classes here provide functionality for creating a composition of partially reversible functions. By "partially reversible" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).
Classes are also provided for composing transforms, and mapping them over collections. Pipeline
is a transform which composes several Transform
, knowing how to decode them or show an encoded item.
#|export
# Names of the methods that `Transform` classes dispatch on by type
_tfm_methods = 'encodes','decodes','setups'

def _is_tfm_method(n, f):
    "Whether `n` is one of the dispatched method names and `f` is callable"
    if n not in _tfm_methods: return False
    return callable(f)

class _TfmDict(dict):
    "Class-body namespace that gathers repeated transform-method definitions into one `TypeDispatch`"
    def __setitem__(self, k, v):
        if _is_tfm_method(k, v):
            # First definition of this name: create the dispatcher that later ones are added to
            if k not in self: super().__setitem__(k, TypeDispatch())
            self[k].add(v)
        else:
            return super().__setitem__(k, v)
#|export
class _TfmMeta(type):
    "Metaclass wiring up the `_tfm_methods` dispatchers and letting a class be used as a method decorator"
    def __new__(cls, name, bases, dict):
        res = super().__new__(cls, name, bases, dict)
        for nm in _tfm_methods:
            # Same-named attribute of every direct base (None where a base lacks it)
            base_td = [getattr(b,nm,None) for b in bases]
            # Defined in this class body: attach the bases' entries to the existing dispatcher;
            # otherwise give the class a fresh dispatcher that only refers to its bases
            if nm in res.__dict__: getattr(res,nm).bases = base_td
            else: setattr(res, nm, TypeDispatch(bases=base_td))
        # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
        res.__signature__ = inspect.signature(res.__init__)
        return res

    def __call__(cls, *args, **kwargs):
        # NOTE(review): `first` presumably yields None for empty `args`, in which case `n` is None
        f = first(args)
        n = getattr(f, '__name__', None)
        # Decorator use (`@SomeTfm` on a function named encodes/decodes/setups):
        # register the function on the class dispatcher and return it instead of instantiating
        if _is_tfm_method(n, f):
            getattr(cls,n).add(f)
            return f
        obj = super().__call__(*args, **kwargs)
        # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
        # instances of cls, fix it
        if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
        return obj

    @classmethod
    def __prepare__(cls, name, bases): return _TfmDict()  # class bodies are evaluated in a `_TfmDict`
#|export
def _get_name(o):
    "Best display name for `o`: its `__qualname__`, else its `__name__`, else its class name"
    for attr in ('__qualname__', '__name__'):
        if hasattr(o, attr): return getattr(o, attr)
    return o.__class__.__name__
#|export
def _is_tuple(o):
    "Whether `o` is a tuple without a `_fields` attribute (so namedtuples do not count)"
    if not isinstance(o, tuple): return False
    return not hasattr(o, '_fields')
#|export
class Transform(metaclass=_TfmMeta):
    "Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
    # Class-level defaults; `__init__` and subclasses may override them
    split_idx,init_enc,order,train_setup = None,None,0,None
    def __init__(self, enc=None, dec=None, split_idx=None, order=None):
        # `enc`/`dec`: optional functions to register as this instance's encodes/decodes
        self.split_idx = ifnone(split_idx, self.split_idx)
        if order is not None: self.order=order
        self.init_enc = enc or dec
        # Nothing passed in: keep using the class-level dispatchers
        if not self.init_enc: return

        # Functions were passed in: give this instance its own dispatchers
        self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
        if enc:
            self.encodes.add(enc)
            # An `order` attribute on `enc` takes precedence over the current order
            self.order = getattr(enc,'order',self.order)
            # Record the first type hint of `enc` as the accepted input types
            if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values()))
            self._name = _get_name(enc)
        if dec: self.decodes.add(dec)

    # Display name: the wrapped `enc` function's name when there is one, else this object's
    @property
    def name(self): return getattr(self, '_name', _get_name(self))
    def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
    def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)
    def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}'

    def setup(self, items=None, train_setup=False):
        # A non-None class/instance `train_setup` overrides the argument
        train_setup = train_setup if self.train_setup is None else self.train_setup
        # With `train_setup`, use `items.train` when that attribute exists, else `items` itself
        return self.setups(getattr(items, 'train', items) if train_setup else items)

    def _call(self, fn, x, split_idx=None, **kwargs):
        # Pass `x` through untouched when this transform is bound to a different `split_idx`
        if split_idx!=self.split_idx and self.split_idx is not None: return x
        return self._do_call(getattr(self, fn), x, **kwargs)

    def _do_call(self, f, x, **kwargs):
        if not _is_tuple(x):
            if f is None: return x
            # NOTE(review): `f.returns(x)` presumably looks up the return annotation of the
            # implementation matching `x`; it is handed to `retain_type` as the target type
            ret = f.returns(x) if hasattr(f,'returns') else None
            return retain_type(f(x, **kwargs), x, ret)
        # Tuples: recurse into every element, then restore the type of the input tuple
        res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
        return retain_type(res, x)

add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")
show_doc(Transform)
Transform (enc=None, dec=None, split_idx=None, order=None)
Delegates (__call__
,decode
,setup
) to (encodes
,decodes
,setups
) if split_idx
matches
A Transform
is the main building block of the fastai data pipelines. In the most general terms a transform can be any function you want to apply to your data, however the Transform
class provides several mechanisms that make the process of building them easy and flexible.
Transform
features:¶L
, as only tuples gets this specific behavior. An alternative is to use ItemTransform
defined below, which will always take the input as a whole.decodes
method. This is mainly used to turn something like a category which is encoded as a number back into a label understandable by humans for showing purposes. Like the regular call method, the decode
method that is used to decode will be applied over each element of a tuple separately.ArrayImage
which is a thin wrapper of pytorch's Tensor
. You can opt out of this behavior by adding ->None
return type annotation.setup
method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.split_idx
flag you can make the transform be used only in a specific DataSource
subset like in training, but not validation.order
attribute which the Pipeline
uses when it needs to merge two lists of transforms.Transform
by creating encodes
or decodes
methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you wish them patched into. This can be used by the fastai library users to add their own behavior, or multiple modules contributing to the same transform.Transform
¶There are a few ways to create a transform with different ratios of simplicity to flexibility.
Transform
class - Use inheritance to implement the methods you want.Transform
class and pass your functions as enc
and dec
arguments.Transform
by just adding a decorator - very straightforward if all you need is a single encodes
implementation.Pipeline
or TfmdDS
you don't even need a decorator. Your function will get converted to a Transform
automatically.A simple way to create a Transform
is to pass a function to the constructor. In the below example, we pass an anonymous function that does integer division by 2:
f = Transform(lambda o:o//2)
If you call this transform, it will apply the transformation:
test_eq_type(f(2), 1)
Another way to define a Transform is to extend the Transform
class:
class A(Transform): pass
However, to enable your transform to do something, you have to define an encodes
method. Note that we can use the class name as a decorator to add this method to the original class.
@A
def encodes(self, x): return x+1
f1 = A()
test_eq(f1(1), 2) # f1(1) dispatches to the encodes method defined above
In addition to adding an encodes
method, we can also add a decodes
method. This enables you to call the decode
method (without an s). For more information about the purpose of decodes
, see the discussion about Reversibility in the above section.
Just like with encodes, you can add a decodes
method to the original class by using the class name as a decorator:
class B(A): pass
@B
def decodes(self, x): return x-1
f2 = B()
test_eq(f2.decode(2), 1)
test_eq(f2(1), 2) # uses A's encodes method from the parent class
If you do not define an encodes
or decodes
method the original value will be returned:
class _Tst(Transform): pass
f3 = _Tst() # no encodes or decodes method have been defined
test_eq_type(f3.decode(2.0), 2.0)
test_eq_type(f3(2), 2)
Transforms can be created from class methods too:
class A:
@classmethod
def create(cls, x:int): return x+1
test_eq(Transform(A.create)(1), 2)
#|hide
# Test extension of a tfm method defined in the class
class A(Transform):
def encodes(self, x): return 'obj'
@A
def encodes(self, x:int): return 'int'
a = A()
test_eq(a.encodes(0), 'int')
test_eq(a.encodes(0.0), 'obj')
Transform
can be used as a decorator to turn a function into a Transform
.
@Transform
def f(x): return x//2
test_eq_type(f(2), 1)
test_eq_type(f.decode(2.0), 2.0)
@Transform
def f(x): return x*2
test_eq_type(f(2), 4)
test_eq_type(f.decode(2.0), 2.0)
We can also apply different transformations depending on the type of the input passed by using TypeDispatch
. TypeDispatch
automatically works with Transform
when using type hints:
class A(Transform): pass
@A
def encodes(self, x:int): return x//2
@A
def encodes(self, x:float): return x+1
When we pass in an int
, this calls the first encodes method:
f = A()
test_eq_type(f(3), 1)
When we pass in a float
, this calls the second encodes method:
test_eq_type(f(2.), 3.)
When we pass in a type that is not specified in encodes
, the original value is returned:
test_eq(f('a'), 'a')
If the type annotation is a union of several types (e.g. `MyClass|float`), then any type in the union will match:
class MyClass(int): pass
class A(Transform):
def encodes(self, x:MyClass|float): return x/2
def encodes(self, x:str|list): return str(x)+'_1'
f = A()
The below two examples match the first encodes, with a type of MyClass
and float
, respectively:
test_eq(f(MyClass(2)), 1.) # input is of type MyClass
test_eq(f(6.0), 3.0) # input is of type float
The next two examples match the second encodes
method, with a type of str
and list
, respectively:
test_eq(f('a'), 'a_1') # input is of type str
test_eq(f(['a','b','c']), "['a', 'b', 'c']_1") # input is of type list
Without any intervention it is easy for operations to change types in Python. For example, FloatSubclass
(defined below) becomes a float
after performing multiplication:
class FloatSubclass(float): pass
test_eq_type(FloatSubclass(3.0) * 2, 6.0)
This behavior is often not desirable when performing transformations on data. Therefore, Transform
will attempt to cast the output to be of the same type as the input by default. In the below example, the output will be cast to a FloatSubclass
type to match the type of the input:
@Transform
def f(x): return x*2
test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))
We can optionally turn off casting by annotating the transform function with a return type of None
:
@Transform
def f(x)-> None: return x*2 # Same transform as above, but with a -> None annotation
test_eq_type(f(FloatSubclass(3.0)), 6.0) # Casting is turned off because of -> None annotation
However, Transform
will only cast output back to the input type when the input is a subclass of the output. In the below example, the input is of type FloatSubclass
which is not a subclass of the output which is of type str
. Therefore, the output doesn't get cast back to FloatSubclass
and stays as type str
:
@Transform
def f(x): return str(x)
test_eq_type(f(Float(2.)), '2.0')
Just like encodes
, the decodes
method will cast outputs to match the input type in the same way. In the below example, the output of decodes
remains of type MySubclass
:
class MySubclass(int): pass
def enc(x): return MySubclass(x+1)
def dec(x): return x-1
f = Transform(enc,dec)
t = f(1) # t is of type MySubclass
test_eq_type(f.decode(t), MySubclass(1)) # the output of decode is cast to MySubclass to match the input type.
split_idx
¶You can apply transformations to subsets of data by specifying a split_idx
property. If a transform has a split_idx
then it's only applied if the split_idx
param matches. In the below example, we set split_idx
equal to 1
:
def enc(x): return x+1
def dec(x): return x-1
f = Transform(enc,dec)
f.split_idx = 1
The transformations are applied when a matching split_idx
parameter is passed:
test_eq(f(1, split_idx=1),2)
test_eq(f.decode(2, split_idx=1),1)
On the other hand, transformations are ignored when the split_idx
parameter does not match:
test_eq(f(1, split_idx=0), 1)
test_eq(f.decode(2, split_idx=0), 2)
Transform operates on lists as a whole, not element-wise:
class A(Transform):
def encodes(self, x): return dict(x)
def decodes(self, x): return list(x.items())
f = A()
_inp = [(1,2), (3,4)]
t = f(_inp)
test_eq(t, dict(_inp))
test_eq(f.decodes(t), _inp)
#|hide
f.split_idx = 1
test_eq(f(_inp, split_idx=1), dict(_inp))
test_eq(f(_inp, split_idx=0), _inp)
If you want a transform to operate on a list elementwise, you must implement this appropriately in the encodes
and decodes
methods:
class AL(Transform): pass
@AL
def encodes(self, x): return [x_+1 for x_ in x]
@AL
def decodes(self, x): return [x_-1 for x_ in x]
f = AL()
t = f([1,2])
test_eq(t, [2,3])
test_eq(f.decode(t), [1,2])
Unlike lists, Transform
operates on tuples element-wise.
def neg_int(x): return -x
f = Transform(neg_int)
test_eq(f((1,2,3)), (-1,-2,-3))
Transforms will also apply TypeDispatch
element-wise on tuples when an input type annotation is specified. In the below example, the values 1.0
and 3.0
are ignored because they are of type float
, not int
:
def neg_int(x:int): return -x
f = Transform(neg_int)
test_eq(f((1.0, 2, 3.0)), (1.0, -2, 3.0))
#|hide
test_eq(f((1,)), (-1,))
test_eq(f((1.,)), (1.,))
test_eq(f.decode((1,2)), (1,2))
test_eq(f.input_types, int)
Another example of how Transform
can use TypeDispatch
with tuples is shown below:
class B(Transform): pass
@B
def encodes(self, x:int): return x+1
@B
def encodes(self, x:str): return x+'hello'
@B
def encodes(self, x): return str(x)+'!'
If the input is not an int
or str
, the third encodes
method will apply:
b = B()
test_eq(b([1]), '[1]!')
test_eq(b([1.0]), '[1.0]!')
However, if the input is a tuple, then the appropriate method will apply according to the type of each element in the tuple:
test_eq(b(('1',)), ('1hello',))
test_eq(b((1,2)), (2,3))
test_eq(b(('a',1.0)), ('ahello','1.0!'))
#|hide
@B
def decodes(self, x:int): return x-1
test_eq(b.decode((2,)), (1,))
test_eq(b.decode(('2',)), ('2',))
assert pickle.loads(pickle.dumps(b))
Dispatching over tuples works recursively, by the way:
class B(Transform):
def encodes(self, x:int): return x+1
def encodes(self, x:str): return x+'_hello'
def decodes(self, x:int): return x-1
def decodes(self, x:str): return x.replace('_hello', '')
f = B()
start = (1.,(2,'3'))
t = f(start)
test_eq_type(t, (1.,(3,'3_hello')))
test_eq(f.decode(t), start)
Dispatching also works with typing
module type classes, like numbers.Integral
:
@Transform
def f(x:numbers.Integral): return x+1
t = f((1,'1',1))
test_eq(t, (2, '1', 2))
#|export
class InplaceTransform(Transform):
    "A `Transform` whose methods work in-place: whatever was passed in is what gets returned"
    def _call(self, fn, x, split_idx=None, **kwargs):
        # Run the regular call for its side effects only; its return value is discarded
        super()._call(fn, x, split_idx=split_idx, **kwargs)
        return x
#|hide
import pandas as pd
class A(InplaceTransform): pass
@A
def encodes(self, x:pd.Series): x.fillna(10, inplace=True)
f = A()
test_eq_type(f(pd.Series([1,2,None])),pd.Series([1,2,10],dtype=np.float64)) #fillna fills with floats.
#|export
class DisplayedTransform(Transform):
    "A transform whose `name` (and so its `__repr__`) also shows `__stored_args__`"
    @property
    def name(self):
        base = super().name
        # Falls back to an empty dict when no `__stored_args__` attribute is present
        stored = getattr(self, '__stored_args__', {})
        return f"{base} -- {stored}"
Transforms normally are represented by just their class name and a list of encodes and decodes implementations:
class A(Transform): encodes,decodes = noop,noop
f = A()
f
A: encodes: (object,object) -> noop decodes: (object,object) -> noop
A DisplayedTransform
will in addition show the contents of all attributes saved with store_attr (kept in self.__stored_args__)
:
class A(DisplayedTransform):
encodes = noop
def __init__(self, a, b=2):
super().__init__()
store_attr()
A(a=1,b=2)
A -- {'a': 1, 'b': 2}: encodes: (object,object) -> noop decodes:
#|export
class ItemTransform(Transform):
    "A transform that always takes tuples as items"
    _retain = True  # when False, results for tuple inputs are returned as-is
    def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)
    def decode(self, x, **kwargs): return self._call1(x, 'decode', **kwargs)

    def _call1(self, x, name, **kwargs):
        parent_fn = getattr(super(), name)
        if not _is_tuple(x): return parent_fn(x, **kwargs)
        # Hand the tuple over as a list so the parent treats it as one item, not element-wise
        res = parent_fn(list(x), **kwargs)
        if not self._retain: return res
        if is_listy(res) and not isinstance(res, tuple): res = tuple(res)
        return retain_type(res, x)
ItemTransform
is the class to use to opt out of the default behavior of Transform
.
class AIT(ItemTransform):
def encodes(self, xy): x,y=xy; return (x+y,y)
def decodes(self, xy): x,y=xy; return (x-y,y)
f = AIT()
test_eq(f((1,2)), (3,2))
test_eq(f.decode((3,2)), (1,2))
If you pass a special tuple subclass, the usual retain type behavior of Transform
will keep it:
class _T(tuple): pass
x = _T((1,2))
test_eq_type(f(x), _T((3,2)))
#|hide
f.split_idx = 0
test_eq_type(f((1,2)), (1,2))
test_eq_type(f((1,2), split_idx=0), (3,2))
test_eq_type(f.decode((1,2)), (1,2))
test_eq_type(f.decode((3,2), split_idx=0), (1,2))
#|hide
class Get(ItemTransform):
_retain = False
def encodes(self, x): return x[0]
g = Get()
test_eq(g([1,2,3]), 1)
test_eq(g(L(1,2,3)), 1)
test_eq(g(np.array([1,2,3])), 1)
test_eq_type(g((['a'], ['b', 'c'])), ['a'])
#|hide
class A(ItemTransform):
def encodes(self, x): return _T((x,x))
def decodes(self, x): return _T(x)
f = A()
test_eq(type(f.decode((1,1))), _T)
#|export
def get_func(t, name, *args, **kwargs):
    "Look up `name` on `t` (`noop` if not defined), partial-ized with `args` and `kwargs` when any are given"
    f = nested_callable(t, name)
    if args or kwargs: return partial(f, *args, **kwargs)
    return f
This works for any kind of t
supporting getattr
, so a class or a module.
test_eq(get_func(operator, 'neg', 2)(), -2)
test_eq(get_func(operator.neg, '__call__')(2), -2)
test_eq(get_func(list, 'foobar')([2]), [2])
a = [2,1]
get_func(list, 'sort')(a)
test_eq(a, [1,2])
Transforms are built with multiple-dispatch: a given function can have several methods depending on the type of the object received. This is done directly with the TypeDispatch
module and type-annotation in Transform
, but you can also use the following class.
#|export
class Func():
    "Holds a function `name` plus `args`/`kwargs`, to be resolved later against a given type or module"
    def __init__(self, name, *args, **kwargs):
        self.name = name
        self.args = args
        self.kwargs = kwargs

    def __repr__(self):
        return f'sig: {self.name}({self.args}, {self.kwargs})'

    def _get(self, t):
        "Resolve the stored name on a single `t`"
        return get_func(t, self.name, *self.args, **self.kwargs)

    def __call__(self,t):
        # `mapped` applies `_get` to `t`
        return mapped(self._get, t)
You can call the Func
object on any module name or type, even a list of types. It will return the corresponding function (with a default to noop
if nothing is found) or list of functions.
test_eq(Func('sqrt')(math), math.sqrt)
#|export
class _Sig():
    "Any attribute `k` gives a factory that builds `Func(k, *args, **kwargs)`"
    def __getattr__(self,k):
        def _inner(*args, **kwargs):
            return Func(k, *args, **kwargs)
        return _inner

# Shared instance, used as `Sig.name(*args, **kwargs)`
Sig = _Sig()
show_doc(Sig, name="Sig")
Sig
is just syntactic sugar to create a Func
object more easily with the syntax Sig.name(*args, **kwargs)
.
f = Sig.sqrt()
test_eq(f(math), math.sqrt)
#|export
def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
    "Apply every item of `tfms` (or its `decode` when not `is_enc`) to `x` in turn, maybe in `reverse` order"
    seq = reversed(tfms) if reverse else tfms
    for tfm in seq:
        step = tfm if is_enc else tfm.decode
        x = step(x, **kwargs)
    return x
def to_int (x): return Int(x)
def to_float(x): return Float(x)
def double (x): return x*2
def half(x)->None: return x/2
def test_compose(a, b, *fs): test_eq_type(compose_tfms(a, tfms=map(Transform,fs)), b)
test_compose(1, Int(1), to_int)
test_compose(1, Float(1), to_int,to_float)
test_compose(1, Float(2), to_int,to_float,double)
test_compose(2.0, 2.0, to_int,double,half)
class A(Transform):
def encodes(self, x:float): return Float(x+1)
def decodes(self, x): return x-1
tfms = [A(), Transform(math.sqrt)]
t = compose_tfms(3., tfms=tfms)
test_eq_type(t, Float(2.))
test_eq(compose_tfms(t, tfms=tfms, is_enc=False), 1.)
test_eq(compose_tfms(4., tfms=tfms, reverse=True), 3.)
tfms = [A(), Transform(math.sqrt)]
test_eq(compose_tfms((9,3.), tfms=tfms), (3,2.))
#|export
def mk_transform(f):
    "Return `f` (instantiated if needed) as-is when it is a `Transform` or `Pipeline`, else wrapped in a `Transform`"
    f = instantiate(f)
    if isinstance(f, (Transform,Pipeline)): return f
    return Transform(f)
#|export
def gather_attrs(o, k, nm):
    "Used in __getattr__ to collect all attrs `k` from `self.{nm}`"
    # Private names and `nm` itself are never delegated (looking up `nm` here would recurse)
    if k.startswith('_') or k==nm: raise AttributeError(k)
    found = [v for v in getattr(o,nm).attrgot(k) if v is not None]
    if not found: raise AttributeError(k)
    # A single hit is returned bare, several are returned together as an `L`
    if len(found)==1: return found[0]
    return L(found)
#|export
def gather_attr_names(o, nm):
    "Used in __dir__ to collect the unique attribute names of everything in `self.{nm}`"
    items = L(getattr(o,nm))
    names = items.map(dir).concat()
    return names.unique()
#|export
class Pipeline:
    "A pipeline of composed (for encode/decode) transforms, setup with types"
    def __init__(self, funcs=None, split_idx=None):
        self.split_idx,self.default = split_idx,None
        if funcs is None: funcs = []
        # Built from another Pipeline: its `fs` list is reused directly (same object, not a copy)
        if isinstance(funcs, Pipeline): self.fs = funcs.fs
        else:
            if isinstance(funcs, Transform): funcs = [funcs]
            # NOTE(review): `funcs` can no longer be None here, so the `[noop]` fallback looks unreachable
            self.fs = L(ifnone(funcs,[noop])).map(mk_transform).sorted(key='order')
        # Expose each transform as an attribute named after its class in snake_case;
        # transforms sharing a class name are collected together under that one attribute
        for f in self.fs:
            name = camel2snake(type(f).__name__)
            a = getattr(self,name,None)
            if a is not None: f = L(a)+f
            setattr(self, name, f)

    def setup(self, items=None, train_setup=False):
        # Empty `fs`, then re-add the transforms one by one; `add` calls `setup` on each
        tfms = self.fs[:]
        self.fs.clear()
        for t in tfms: self.add(t,items, train_setup)

    def add(self,ts, items=None, train_setup=False):
        if not is_listy(ts): ts=[ts]
        for t in ts: t.setup(items, train_setup)
        self.fs+=ts
        # Keep the pipeline sorted by each transform's `order`
        self.fs = self.fs.sorted(key='order')

    def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
    # 'noop' entries are left out of the displayed chain
    def __repr__(self): return f"Pipeline: {' -> '.join([f.name for f in self.fs if f.name != 'noop'])}"
    def __getitem__(self,i): return self.fs[i]
    def __setstate__(self,data): self.__dict__.update(data)
    # Attributes not found on the pipeline itself are looked up on its transforms
    def __getattr__(self,k): return gather_attrs(self, k, 'fs')
    def __dir__(self): return super().__dir__() + gather_attr_names(self, 'fs')

    def decode (self, o, full=True):
        if full: return compose_tfms(o, tfms=self.fs, is_enc=False, reverse=True, split_idx=self.split_idx)
        #Not full means we decode up to the point the item knows how to show itself.
        for f in reversed(self.fs):
            if self._is_showable(o): return o
            o = f.decode(o, split_idx=self.split_idx)
        return o

    def show(self, o, ctx=None, **kwargs):
        o = self.decode(o, full=False)
        o1 = (o,) if not _is_tuple(o) else o
        # Prefer the object's own `show`; otherwise show each element that has one,
        # passing the returned `ctx` on to the next
        if hasattr(o, 'show'): ctx = o.show(ctx=ctx, **kwargs)
        else:
            for o_ in o1:
                if hasattr(o_, 'show'): ctx = o_.show(ctx=ctx, **kwargs)
        return ctx

    def _is_showable(self, o):
        # Showable: has `show` itself, or is a plain tuple whose elements all have `show`
        if hasattr(o, 'show'): return True
        if _is_tuple(o): return all(hasattr(o_, 'show') for o_ in o)
        return False

add_docs(Pipeline,
         __call__="Compose `__call__` of all `fs` on `o`",
         decode="Compose `decode` of all `fs` on `o`",
         show="Show `o`, a single item from a tuple, decoding as needed",
         add="Add transforms `ts`",
         setup="Call each tfm's `setup` in order")
Pipeline
is a wrapper for compose_tfms
. You can pass instances of Transform
or regular functions in funcs
, the Pipeline
will wrap them all in Transform
(and instantiate them if needed) during the initialization. It handles the transform setup
by adding them one at a time and calling setup on each, goes through them in order in __call__
or decode
and can show
an object by decoding through the transforms up until the point it gets an object that knows how to show itself.
# Empty pipeline is noop
pipe = Pipeline()
test_eq(pipe(1), 1)
test_eq(pipe((1,)), (1,))
# Check pickle works
assert pickle.loads(pickle.dumps(pipe))
class IntFloatTfm(Transform):
def encodes(self, x): return Int(x)
def decodes(self, x): return Float(x)
foo=1
int_tfm=IntFloatTfm()
def neg(x): return -x
neg_tfm = Transform(neg, neg)
pipe = Pipeline([neg_tfm, int_tfm])
start = 2.0
t = pipe(start)
test_eq_type(t, Int(-2))
test_eq_type(pipe.decode(t), Float(start))
test_stdout(lambda:pipe.show(t), '-2')
pipe = Pipeline([neg_tfm, int_tfm])
t = pipe(start)
test_stdout(lambda:pipe.show(pipe((1.,2.))), '-1\n-2')
test_eq(pipe.foo, 1)
assert 'foo' in dir(pipe)
assert 'int_float_tfm' in dir(pipe)
You can add a single transform or multiple transforms ts
using Pipeline.add
. Transforms will be ordered by Transform.order
.
pipe = Pipeline([neg_tfm, int_tfm])
class SqrtTfm(Transform):
order=-1
def encodes(self, x):
return x**(.5)
def decodes(self, x): return x**2
pipe.add(SqrtTfm())
test_eq(pipe(4),-2)
test_eq(pipe.decode(-2),4)
pipe.add([SqrtTfm(),SqrtTfm()])
test_eq(pipe(256),-2)
test_eq(pipe.decode(-2),256)
Transforms are available as attributes named with the snake_case version of the names of their types. Attributes in transforms can be directly accessed as attributes of the pipeline.
test_eq(pipe.int_float_tfm, int_tfm)
test_eq(pipe.foo, 1)
pipe = Pipeline([int_tfm, int_tfm])
pipe.int_float_tfm
test_eq(pipe.int_float_tfm[0], int_tfm)
test_eq(pipe.foo, [1,1])
# Check opposite order
pipe = Pipeline([int_tfm,neg_tfm])
t = pipe(start)
test_eq(t, -2)
test_stdout(lambda:pipe.show(t), '-2')
class A(Transform):
def encodes(self, x): return int(x)
def decodes(self, x): return Float(x)
pipe = Pipeline([neg_tfm, A])
t = pipe(start)
test_eq_type(t, -2)
test_eq_type(pipe.decode(t), Float(start))
test_stdout(lambda:pipe.show(t), '-2.0')
s2 = (1,2)
pipe = Pipeline([neg_tfm, A])
t = pipe(s2)
test_eq_type(t, (-1,-2))
test_eq_type(pipe.decode(t), (Float(1.),Float(2.)))
test_stdout(lambda:pipe.show(t), '-1.0\n-2.0')
from PIL import Image
class ArrayImage(ndarray):
    "An `ndarray` view that can `show` itself as an image on a matplotlib axis"
    _show_args = {'cmap':'viridis'}  # default `imshow` kwargs; overridden by kwargs passed to `show`
    def __new__(cls, x, *args, **kwargs):
        # NOTE(review): the array built here for tuple input is discarded (no `return`), so a
        # tuple is still converted with `array(x)` below; kept as-is to preserve behavior
        if isinstance(x,tuple): super().__new__(cls, x, *args, **kwargs)
        if args or kwargs: raise RuntimeError('Unknown array init args')
        if not isinstance(x,ndarray): x = array(x)
        return x.view(cls)

    def show(self, ctx=None, figsize=None, **kwargs):
        "Draw this array with `imshow` on `ctx` (a new axis if `ctx` is None) and return `ctx`"
        if ctx is None: _,ctx = plt.subplots(figsize=figsize)
        # Draw this instance, not the module-level `im` the original accidentally referenced
        ctx.imshow(self, **{**self._show_args, **kwargs})
        ctx.axis('off')
        return ctx
im = Image.open(TEST_IMAGE)
im_t = ArrayImage(im)
def f1(x:ArrayImage): return -x
def f2(x): return Image.open(x).resize((128,128))
def f3(x:Image.Image): return(ArrayImage(array(x)))
pipe = Pipeline([f2,f3,f1])
t = pipe(TEST_IMAGE)
test_eq(type(t), ArrayImage)
test_eq(t, -array(f3(f2(TEST_IMAGE))))
pipe = Pipeline([f2,f3])
t = pipe(TEST_IMAGE)
ax = pipe.show(t)
#test_fig_exists(ax)
#Check filtering is properly applied
add1 = B()
add1.split_idx = 1
pipe = Pipeline([neg_tfm, A(), add1])
test_eq(pipe(start), -2)
pipe.split_idx=1
test_eq(pipe(start), -1)
pipe.split_idx=0
test_eq(pipe(start), -2)
for t in [None, 0, 1]:
pipe.split_idx=t
test_eq(pipe.decode(pipe(start)), start)
test_stdout(lambda: pipe.show(pipe(start)), "-2.0")
def neg(x): return -x
test_eq(type(mk_transform(neg)), Transform)
test_eq(type(mk_transform(math.sqrt)), Transform)
test_eq(type(mk_transform(lambda a:a*2)), Transform)
test_eq(type(mk_transform(Pipeline([neg]))), Pipeline)
#TODO: method examples
show_doc(Pipeline.__call__)
show_doc(Pipeline.decode)
show_doc(Pipeline.setup)
During the setup, the Pipeline
starts with no transform and adds them one at a time, so that during its setup, each transform gets the items processed up to its point and not after.
#|hide
#Test is with TfmdList
#|hide
#|eval: false
from nbdev import nbdev_export
nbdev_export()