# Step 1: a bare module whose forward pass just squares its input.
class Module:
    """Minimal module: ``forward`` returns the square of its input."""

    def forward(self, x):
        """Return ``x ** 2``."""
        return x ** 2


Module().forward(10)


# Step 2: the same module, now exposing overridable hook methods that run
# immediately before and after the computation.
class Module:
    """Module with ``before_forward`` / ``after_forward`` hook methods.

    ``forward`` stores its input on ``self.x`` and its result on ``self.y``
    so that the hooks can read or replace them.
    """

    def forward(self, x):
        """Square the input, running the hooks around the computation.

        Stores ``x`` on ``self.x`` and ``x ** 2`` on ``self.y``, and returns
        ``self.y`` as it stands after ``after_forward`` has run.
        """
        self.x = x
        self.before_forward()
        # The computation uses the local ``x``, not ``self.x``, so a hook that
        # rebinds ``self.x`` does not change what gets squared.
        self.y = x ** 2
        self.after_forward()
        return self.y

    def before_forward(self):
        """Hook called before the computation; does nothing by default."""
        pass

    def after_forward(self):
        """Hook called after the computation; does nothing by default."""
        pass


Module().forward(10)


class LoggingMixin:
    """Mixin that prints ``self.x`` / ``self.y`` from the forward hooks.

    Each hook delegates to ``super()`` so further hook providers later in the
    MRO still run.
    """

    def before_forward(self):
        """Print the stored input, then continue along the MRO."""
        print(f'{self.x=}')
        super().before_forward()

    def after_forward(self):
        """Print the stored output, then continue along the MRO."""
        print(f'{self.y=}')
        super().after_forward()


class MyModule(LoggingMixin, Module):
    """Hook-based ``Module`` with logging mixed in."""

    pass


MyModule().forward(10)

from torch import tensor


class TensorMixin(Module):
    """Hook provider that rewraps ``self.x`` / ``self.y`` with ``tensor``."""

    def before_forward(self):
        """Replace ``self.x`` with ``tensor(self.x)``, then continue."""
        self.x = tensor(self.x)
        super().before_forward()

    def after_forward(self):
        """Replace ``self.y`` with ``tensor(self.y)``, then continue."""
        # NOTE(review): if ``self.y`` is already a tensor, ``torch.tensor`` may
        # emit a copy-construct warning -- confirm whether that matters here.
        self.y = tensor(self.y)
        super().after_forward()


# TensorMixin is listed first, so its hooks run before LoggingMixin's and the
# values that get printed are already tensors.
class MyModule(TensorMixin, LoggingMixin, Module):
    """Hook-based ``Module`` with tensor conversion and logging mixed in."""

    pass


MyModule().forward(10)


# Step 3: the same extension points, driven by a list of callback objects
# instead of inheritance.
class Module:
    """Module whose forward pass notifies a list of callback objects.

    Each callback receives a ``mod`` attribute pointing back at this module,
    through which it can read or replace ``x`` and ``y``.
    """

    def __init__(self, cbs):
        """Store the callbacks ``cbs`` and set ``cb.mod = self`` on each one."""
        self.cbs = cbs
        for cb in cbs:
            cb.mod = self

    def forward(self, x):
        """Square the input, firing the named events around the computation.

        Stores ``x`` on ``self.x`` and ``x ** 2`` on ``self.y``, and returns
        ``self.y`` as it stands after the ``'after_forward'`` event.
        """
        self.x = x
        self.callback('before_forward')
        self.y = x ** 2
        self.callback('after_forward')
        return self.y

    def callback(self, nm):
        """Call the method named ``nm`` on every callback that defines it.

        Callbacks are visited in list order. A callback without an attribute
        called ``nm`` is skipped.
        """
        for cb in self.cbs:
            # The previous fallback was ``lambda o: None`` invoked with no
            # arguments, which raised TypeError for any callback missing the
            # hook; look the hook up explicitly and skip it when absent.
            hook = getattr(cb, nm, None)
            if hook is not None:
                hook()


class LoggingCB:
    """Callback that prints the module's ``x`` and ``y``."""

    def before_forward(self):
        """Print the module's stored input."""
        print(f'{self.mod.x=}')

    def after_forward(self):
        """Print the module's stored output."""
        print(f'{self.mod.y=}')


Module([LoggingCB()]).forward(10)


class TensorCB:
    """Callback that rewraps the module's ``x`` and ``y`` with ``tensor``."""

    def before_forward(self):
        """Replace ``mod.x`` with ``tensor(mod.x)``."""
        self.mod.x = tensor(self.mod.x)

    def after_forward(self):
        """Replace ``mod.y`` with ``tensor(mod.y)``."""
        self.mod.y = tensor(self.mod.y)


# Callbacks fire in list order, so TensorCB converts before LoggingCB prints.
Module([TensorCB(), LoggingCB()]).forward(10)