@overload_classmethod for NumPy Array subclasses
In this release, experimental support is added for specializing the allocator in NumPy ndarray subclasses. Two key enhancements were added to enable this:
- @overload_classmethod permits the specializing of classmethod on specific types; and,
- Array._allocate as an overloadable classmethod on Numba's Array type.
The rest of this notebook demonstrates the use of @overload_classmethod to override the allocator for a custom NumPy ndarray subclass.
# All necessary imports
import builtins
import ctypes
from numbers import Number
import numpy as np
# We'll need to write some LLVM IR
from llvmlite import ir
from numba import njit
from numba.extending import (
    overload_classmethod,
    typeof_impl,
    register_model,
    intrinsic,
)
from numba.core import cgutils, types, typing
from numba.core.datamodel import models
from numba.np import numpy_support
Make a NumPy ndarray subclass called MyArray. It needs to override __array_ufunc__ to specialize how certain ufuncs are handled.
class MyArray(np.ndarray):
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # This is a "magic" method in NumPy subclasses to override
        # the behavior of NumPy's ufuncs.
        if method == "__call__":
            N = None
            scalars = []
            for inp in inputs:
                # If scalar?
                if isinstance(inp, Number):
                    scalars.append(inp)
                # If array?
                elif isinstance(inp, (type(self), np.ndarray)):
                    if isinstance(inp, type(self)):
                        scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
                    else:
                        scalars.append(inp)
                    # Guard shape
                    if N is not None:
                        if N != inp.shape:
                            raise TypeError("inconsistent sizes")
                    else:
                        N = inp.shape
                # If unknown type?
                else:
                    return NotImplemented
            print(f"NumPy: {type(self)}.__array_ufunc__ method={method} inputs={inputs}")
            ret = ufunc(*scalars, **kwargs)
            return self.__class__(ret.shape, ret.dtype, ret)
        else:
            return NotImplemented
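As a smaller illustration of the same dispatch mechanism, the sketch below uses a hypothetical LoggedArray subclass (not part of this notebook) that only records which ufuncs NumPy routes through __array_ufunc__. It assumes plain NumPy with no Numba involved, and it strips the subclass from the inputs before calling the ufunc so the override does not recurse into itself.

```python
import numpy as np


class LoggedArray(np.ndarray):
    """Hypothetical subclass that records every ufunc dispatched to it."""

    calls = []  # class-level log of (ufunc name, method) pairs

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        LoggedArray.calls.append((ufunc.__name__, method))
        # View subclass inputs as plain ndarrays to avoid infinite recursion
        plain = tuple(
            x.view(np.ndarray) if isinstance(x, LoggedArray) else x
            for x in inputs
        )
        result = getattr(ufunc, method)(*plain, **kwargs)
        # Re-wrap array results so the subclass survives the operation
        if isinstance(result, np.ndarray):
            return result.view(LoggedArray)
        return result


x = np.arange(3).view(LoggedArray)
y = x * 2 + x  # dispatches np.multiply, then np.add
```

After this runs, LoggedArray.calls holds one entry for the multiplication and one for the addition, which mirrors the two NumPy-side prints that MyArray produces later in this notebook.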
Make a subclass of the Numba Array type to represent MyArray as a Numba type. Similar to the NumPy ndarray subclass, the Numba type also has a __array_ufunc__ method, but the difference is that it operates in the Numba typing domain. Concretely, it receives inputs that are the argument types, not the argument values, and it returns the type of the returned value, not the return value itself.
class MyArrayType(types.Array):
    def __init__(self, dtype, ndim, layout, readonly=False, aligned=True):
        name = f"MyArray({ndim}, {dtype}, {layout})"
        super().__init__(dtype, ndim, layout, readonly=readonly,
                         aligned=aligned, name=name)
    # Tell Numba typing how to combine MyArrayType with other ndarray types.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        This is the parallel for NumPy's __array_ufunc__ but operates on Numba types instead.
        In NumPy's __array_ufunc__, this performs the calculation, but here we
        only produce the return type.
        """
        if method == "__call__":
            for inp in inputs:
                if not isinstance(inp, (types.Array, types.Number)):
                    return NotImplemented
            print(f"Numba: {self}.__array_ufunc__ method={method} inputs={inputs}")
            return MyArrayType
        else:
            return NotImplemented
We need to teach Numba that MyArray corresponds to the Numba type MyArrayType. This is done by registering the implementation of typeof for MyArray.
@typeof_impl.register(MyArray)
def typeof_ta_ndarray(val, c):
    # Determine dtype
    try:
        dtype = numpy_support.from_dtype(val.dtype)
    except NotImplementedError:
        raise ValueError("Unsupported array dtype: %s" % (val.dtype,))
    # Determine memory layout
    layout = numpy_support.map_layout(val)
    # Determine writeability
    readonly = not val.flags.writeable
    return MyArrayType(dtype, val.ndim, layout, readonly=readonly)
We also need to teach Numba how MyArrayType is represented in memory. For our purpose, it is the same as the basic Array type. This is done by registering a datamodel for MyArrayType.
register_model(MyArrayType)(models.ArrayModel)
numba.core.datamodel.models.ArrayModel
We define a new allocator to use inside Numba for MyArray. Numba exposes an API for external code to register a new allocator table. The C structure for the allocator table is defined below:
(From: https://github.com/numba/numba/blob/0.54.0/numba/core/runtime/nrt_external.h#L10-L19)
typedef void *(*NRT_external_malloc_func)(size_t size, void *opaque_data);
typedef void *(*NRT_external_realloc_func)(void *ptr, size_t new_size, void *opaque_data);
typedef void (*NRT_external_free_func)(void *ptr, void *opaque_data);
struct ExternalMemAllocator {
    NRT_external_malloc_func malloc;
    NRT_external_realloc_func realloc;
    NRT_external_free_func free;
    void *opaque_data;
};
In the following, we use ctypes to expose Python functions as C functions (using ctypes.CFUNCTYPE). These functions will be used as the allocator and deallocator. Then, we put the pointers to these functions into a ctypes.Structure that matches the ExternalMemAllocator structure shown above.
As this is not a performance-focused implementation, we use Python functions as the allocator and deallocator so that we can print() when they are invoked. For production use, users are expected to write the allocator and deallocator in native code.
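The mechanism described above can be tried in isolation before Numba is involved. The following is a minimal sketch using only the standard library; FREE_T, demo_free, and DemoTable are hypothetical names chosen for illustration, and the callback only records its arguments rather than releasing memory.

```python
import ctypes

# Hypothetical callback type mirroring NRT_external_free_func:
#   void (*)(void *ptr, void *opaque_data)
FREE_T = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)

seen = []  # arguments observed by the callback


@FREE_T
def demo_free(ptr, data):
    # A real deallocator would release memory here; this one only records
    seen.append((ptr, data))


class DemoTable(ctypes.Structure):
    # Four pointer-sized slots, in the same order as the C struct above
    _fields_ = [
        ("malloc", ctypes.c_void_p),
        ("realloc", ctypes.c_void_p),
        ("free", ctypes.c_void_p),
        ("opaque_data", ctypes.c_void_p),
    ]


# Store the callback's C-callable address in the table; other slots stay NULL
table = DemoTable(free=ctypes.cast(demo_free, ctypes.c_void_p))

# Rebuild a callable from the raw address, as native code would see it
FREE_T(table.free)(0x1000, None)

table_size = ctypes.sizeof(DemoTable)
```

The table occupies exactly four pointers, and calling through the stored address reaches the Python function, which is the property the allocator table in this notebook relies on.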
WARNING: DO NOT rerun the following cells. Doing so will cause a segfault because the deallocator (free_func()) can be removed before all the Numba dynamic memory is released.
lib = ctypes.CDLL(None)
lib.malloc.argtypes = [ctypes.c_size_t]
lib.malloc.restype = ctypes.c_size_t
lib.free.argtypes = [ctypes.c_void_p]
@ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p)
def malloc_func(size, data):
    """
    The allocator. Numba takes opaque data as a void* in the second argument.
    """
    # Call underlying C malloc
    out = lib.malloc(size)
    print(f">>> Malloc size={size} data={data} -> {hex(np.uintp(out))}")
    return out
@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)
def free_func(ptr, data):
    """
    The deallocator. Numba takes opaque data as a void* in the second argument.
    """
    if lib is None:
        # Note: in practice guard against global being removed during interpreter shutdown
        return
    print(f">>> Free ptr={hex(ptr)} data={data}")
    # Call underlying C free()
    lib.free(ptr)
    return
class ExternalMemAllocator(ctypes.Structure):
    """
    This defines a struct for the allocator table.
    Its fields must match ExternalMemAllocator defined in `nrt_external.h`
    """
    _fields_ = [
        ("malloc_func", ctypes.c_void_p),
        ("realloc_func", ctypes.c_void_p),
        ("free_func", ctypes.c_void_p),
        ("data", ctypes.c_void_p),
    ]
# Instantiate the allocator table
allocator_table = ExternalMemAllocator(
    malloc_func=ctypes.cast(malloc_func, ctypes.c_void_p),
    realloc_func=None,  # unused; skipped for demo purpose
    free_func=ctypes.cast(free_func, ctypes.c_void_p),
    data=None,  # no extra data needed
)
# Inspect the address of the table
print("allocator_table:", hex(ctypes.addressof(allocator_table)))
allocator_table: 0x7faad81b2fb0
Now to override the memory allocator for this array subclass...
Note: For demonstration purposes, the allocator references the dynamic runtime address of the allocator table. This disables several features of Numba, including caching and AOT compilation.
@overload_classmethod(MyArrayType, "_allocate")
def _ol_array_allocate(cls, allocsize, align):
    """Implements a Numba-only classmethod on the array type.
    """
    def impl(cls, allocsize, align):
        # The bulk of the work is implemented in the intrinsic below.
        return allocator_MyArray(allocsize, align)
    return impl
@intrinsic
def allocator_MyArray(typingctx, allocsize, align):
    def impl(context, builder, sig, args):
        context.nrt._require_nrt()
        size, align = args
        mod = builder.module
        u32 = ir.IntType(32)
        voidptr = cgutils.voidptr_t
        # We will use our custom allocator table here.
        # The table is referenced by its dynamic runtime address.
        addr = ctypes.addressof(allocator_table)
        ext_alloc = context.add_dynamic_addr(builder, addr, info='custom_alloc_table')
        # Invoke the allocator routine that uses our custom allocator
        fnty = ir.FunctionType(voidptr, [cgutils.intp_t, u32, voidptr])
        fn = cgutils.get_or_insert_function(
            mod, fnty, name="NRT_MemInfo_alloc_safe_aligned_external"
        )
        fn.return_value.add_attribute("noalias")
        if isinstance(align, builtins.int):
            align = context.get_constant(types.uint32, align)
        else:
            assert align.type == u32, "align must be a uint32"
        call = builder.call(fn, [size, align, ext_alloc])
        return call
    mip = types.MemInfoPointer(types.voidptr)  # return untyped pointer
    sig = typing.signature(mip, allocsize, align)
    return sig, impl
To test, we define a simple function that computes a * 2 + a:
def foo(a):
    return a * 2 + a
buf = np.arange(4)
a = MyArray(buf.shape, buf.dtype, buf)
a
MyArray([0, 1, 2, 3])
When foo(), which is not Numba-compiled, is executed, we can see that the MyArray.__array_ufunc__ method is used for the * and + operations.
foo(a)
NumPy: <class '__main__.MyArray'>.__array_ufunc__ method=__call__ inputs=(MyArray([0, 1, 2, 3]), 2)
NumPy: <class '__main__.MyArray'>.__array_ufunc__ method=__call__ inputs=(MyArray([0, 2, 4, 6]), MyArray([0, 1, 2, 3]))
MyArray([0, 3, 6, 9])
Below is the Numba JIT version:
jit_foo = njit(foo)
When jit_foo() is executed, the MyArrayType.__array_ufunc__ method is used to compute the types of the * and + operations. Note that type inference invokes the __array_ufunc__ method multiple times due to specifics of the algorithm. We can also see a series of prints to stdout from the allocator (malloc_func()) and the deallocator (free_func()). They show two allocations, one each for the results of * and +, and one deallocation for the intermediate result of *.
jit_foo(a)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))
>>> Malloc size=144 data=None -> 0x7faa85511aa0
>>> Malloc size=144 data=None -> 0x7faa85561760
>>> Free ptr=0x7faa85511aa0 data=None
array([0, 3, 6, 9])
Lastly, we can observe the use of the MyArray type in the annotated IR.
jit_foo.inspect_types()
foo (MyArray(1, int64, C),)
--------------------------------------------------------------------------------
# File: <ipython-input-8-0d11c4a7f23d>
# --- LINE 1 ---
def foo(a):
    # --- LINE 2 ---
    # label 0
    #   a = arg(0, name=a) :: MyArray(1, int64, C)
    #   $const4.1 = const(int, 2) :: Literal[int](2)
    #   $6binary_multiply.2 = a * $const4.1 :: MyArray(1, int64, C)
    #   del $const4.1
    #   $10binary_add.4 = $6binary_multiply.2 + a :: MyArray(1, int64, C)
    #   del a
    #   del $6binary_multiply.2
    #   $12return_value.5 = cast(value=$10binary_add.4) :: MyArray(1, int64, C)
    #   del $10binary_add.4
    #   return $12return_value.5
    return a * 2 + a
================================================================================