Reading:
class Module:
    """Minimal module whose ``forward`` squares its input."""

    def forward(self, x):
        """Return ``x`` squared."""
        squared = x ** 2
        return squared


Module().forward(10)
100
class Module:
    """Module whose ``forward`` runs ``before_forward``/``after_forward`` hooks.

    The input and output are stored on ``self.x`` / ``self.y`` so hooks
    (for example cooperative mixins that call ``super()``) can inspect or
    replace them.
    """

    def forward(self, x):
        """Square the input, running the hooks around the computation.

        Returns ``self.y`` as it stands after ``after_forward``, so an output
        replaced by a hook is what the caller receives.
        """
        self.x = x
        self.before_forward()
        # Compute from self.x rather than the local ``x`` so that any
        # replacement a before_forward hook made to the input takes effect.
        self.y = self.x ** 2
        self.after_forward()
        return self.y

    def before_forward(self):
        """Hook run after ``self.x`` is set; a no-op ending the super() chain."""

    def after_forward(self):
        """Hook run after ``self.y`` is set; a no-op ending the super() chain."""


Module().forward(10)
100
class LoggingMixin:
    """Cooperative mixin that prints the module's input and output.

    Each hook prints, then delegates to ``super()``. The mixin therefore
    expects a later class in the MRO to define ``before_forward`` and
    ``after_forward``.
    """

    def before_forward(self):
        """Print ``self.x``, then continue the hook chain."""
        print(f'{self.x=}')
        super().before_forward()

    def after_forward(self):
        """Print ``self.y``, then continue the hook chain."""
        print(f'{self.y=}')
        super().after_forward()
class MyModule(LoggingMixin, Module):
    """Module whose hooks print ``self.x`` and ``self.y`` via LoggingMixin."""


MyModule().forward(x=10)
self.x=10
self.y=100
100
from torch import as_tensor, tensor
class TensorMixin(Module):
    """Cooperative mixin that converts the module's input and output to tensors.

    Uses ``as_tensor`` so a value that is already a tensor is passed through
    as-is. ``tensor()`` on an existing tensor copies it and emits a
    UserWarning.
    """

    def before_forward(self):
        """Convert ``self.x`` to a tensor, then continue the hook chain."""
        self.x = as_tensor(self.x)
        super().before_forward()

    def after_forward(self):
        """Convert ``self.y`` to a tensor, then continue the hook chain."""
        self.y = as_tensor(self.y)
        super().after_forward()
class MyModule(TensorMixin, LoggingMixin, Module):
    """Module combining tensor conversion and logging.

    MRO order is TensorMixin -> LoggingMixin -> Module, so each hook converts
    to a tensor first and logs the converted value second.
    """


MyModule().forward(x=10)
self.x=tensor(10)
self.y=tensor(100)
tensor(100)
class Module:
    """Module whose ``forward`` notifies a list of callback objects.

    Each callback gets a ``mod`` attribute pointing back at this module and
    may define any subset of ``before_forward`` / ``after_forward``.
    """

    def __init__(self, cbs=None):
        """Store the callbacks and point each one's ``mod`` at this module.

        Args:
            cbs: sequence of callback objects, kept as given; ``None`` (the
                default) means no callbacks.
        """
        self.cbs = [] if cbs is None else cbs
        for cb in self.cbs:
            cb.mod = self

    def forward(self, x):
        """Square the input, running callbacks around the computation.

        Returns ``self.y`` as it stands after the ``after_forward`` callbacks.
        """
        self.x = x
        self.callback('before_forward')
        # Use self.x so a before_forward callback that replaces the input
        # affects the result.
        self.y = self.x ** 2
        self.callback('after_forward')
        return self.y

    def callback(self, nm):
        """Call method ``nm`` on every callback that defines it, in list order."""
        for cb in self.cbs:
            # The fallback is invoked with no arguments, so it must accept
            # none; a one-argument default would raise TypeError for any
            # callback that does not define ``nm``.
            getattr(cb, nm, lambda: None)()
class LoggingCB:
    """Callback that prints its module's input and output.

    Relies on a ``mod`` attribute being assigned before the hooks run.
    ``Module.__init__`` does this for every callback it receives.
    """

    def before_forward(self):
        """Print the module's ``x``."""
        print(f'{self.mod.x=}')

    def after_forward(self):
        """Print the module's ``y``."""
        print(f'{self.mod.y=}')
# Run forward with a single logging callback attached.
Module(cbs=[LoggingCB()]).forward(x=10)
self.mod.x=10
self.mod.y=100
100
class TensorCB:
    """Callback that converts its module's input and output to tensors.

    Uses ``as_tensor`` so a value that is already a tensor is left as-is.
    ``tensor()`` on an existing tensor copies it and emits a UserWarning.
    """

    def before_forward(self):
        """Replace the module's ``x`` with a tensor version of it."""
        self.mod.x = as_tensor(self.mod.x)

    def after_forward(self):
        """Replace the module's ``y`` with a tensor version of it."""
        self.mod.y = as_tensor(self.mod.y)
# Callbacks run in list order: TensorCB converts before LoggingCB prints.
Module(cbs=[TensorCB(), LoggingCB()]).forward(x=10)
self.mod.x=tensor(10)
self.mod.y=tensor(100)
tensor(100)