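This notebook contains demonstrations of new features in the 0.46 release of Numba that are intended for use by library developers and compiler engineers.
Features demonstrated in this notebook include:
- Inlining functions at the Numba IR level via the inline kwarg.
- The reworked compiler pipeline and custom compilers.
- Writing and registering custom compiler passes.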
Other new features present but not demonstrated here include:
- C modules and associated helper functions. Documentation here.
- jit decorator. Documentation here.
First, import the necessary...
from numba import jit, njit, config, __version__, errors
from numba.extending import overload
import numpy as np
assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 46)
config.SHOW_HELP = False # switch off help messages
Numba gains a lot from LLVM itself being able to inline functions, and Numba's internals are geared towards making this easy. However, numerous use cases have arisen where it would be useful to inline a function at the Numba IR level. Numba 0.46 adds support for this via the keyword argument inline, which can be supplied to the numba.jit family of decorators and also to numba.extending.overload; documentation is here.
As a motivating use case, the following function compiles without issue:
from numba.typed import List
@njit
def foo():
    l = List()
    for i in range(10):
        l.append(i * 123.45)
    return l
foo()
This minor variation on the above cannot be compiled. The type of the List() in bar cannot be inferred because type inference cannot "see" across the function call into baz, where it becomes apparent that the type must be ListType[float64].
@njit
def baz(l):
    for i in range(10):
        l.append(i * 123.45)
@njit
def bar():
    l = List()
    baz(l)
    return l
try:
    bar()
except errors.TypingError as e:
    print(e)
Something similar to the above use case was the exact reason the ability to perform inlining was explored. The following demonstrates how to resolve the above situation: supplying the kwarg inline='always' to the called function forces its body to be inlined at the call site in the caller, so there is no longer a type inference issue.
@njit(inline='always')
def baz(l):
    for i in range(10):
        l.append(i * 123.45)
@njit
def bar():
    l = List()
    baz(l)
    return l
bar() # works fine
# baz got inlined, bar was effectively seen as:
# def bar():
#     l = List()
#     for i in range(10):
#         l.append(i * 123.45)
#     return l
#
# which is the same as foo above
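To make the inlining capability as flexible as possible, three options were added for the kwarg:
- 'never': never inline (default).
- 'always': always inline.
- a function acting as a cost model: it is called with the call expression, the caller information and the callee information, and returns True to inline or False otherwise.
An example using all of the above follows (it also uses the new environment variable/config option DEBUG_PRINT_AFTER to show the IR, docs are here):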
from numba import njit, ir
import numba
# enable printing of the IR post legalization, i.e. just before it is lowered
numba.config.DEBUG_PRINT_AFTER="ir_legalization"
@njit(inline='never')
def never_inline():
    return 100
@njit(inline='always')
def always_inline():
    return 200
def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site
    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True
    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        # inline only if the constant precedes the call site
                        return before_expr
    return False
@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model
    # The following is ir.Const(300).
    return 300
@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model
    # The following is ir.Const(37).
    return 37
@njit
def foo():
    a = never_inline() # will never inline
    b = always_inline() # will always inline
    # will not inline as the function does not contain a magic constant known to
    # the cost model, and the IR up to the call site does not contain a magic
    # constant either
    d = maybe_inline1()
    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13
    # will inline due to above constant declaration
    e = maybe_inline1()
    # will inline as the maybe_inline2 function contains a magic constant known
    # to the cost model
    c = maybe_inline2()
    return a + b + c + d + e + magic_const
foo()
Note in the above IR, as dead code elimination is not performed by default, there are superfluous statements present.
Further, the same inline kwarg is implemented for the numba.extending.overload decorator; documentation and examples are here.
numba.config.DEBUG_PRINT_AFTER="" # disable debug print again
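The cost-model signature used in the example above, (expr, caller_info, callee_info), lends itself to other heuristics. The following is a minimal sketch of a size-based cost model, not something shipped with Numba: the name small_callee_cost_model, the max_stmts threshold and the fake_info stand-in helper are all hypothetical. It relies only on the attributes already used by sentinel_cost_model (callee_info.blocks and blk.body), and it is exercised here with plain stand-in objects rather than real Numba IR.

```python
from types import SimpleNamespace


def small_callee_cost_model(expr, caller_info, callee_info, max_stmts=5):
    """Hypothetical cost model: inline when the callee IR is small.

    Returns True (inline) if the callee has at most `max_stmts` statements
    summed over all of its blocks, otherwise False.
    """
    n_stmts = sum(len(blk.body) for blk in callee_info.blocks.values())
    return n_stmts <= max_stmts


def fake_info(*block_sizes):
    """Stand-in for the caller/callee info objects (assumption: only the
    `.blocks` mapping and each block's `.body` list are needed)."""
    blocks = {label: SimpleNamespace(body=[None] * size)
              for label, size in enumerate(block_sizes)}
    return SimpleNamespace(blocks=blocks)


tiny_callee = fake_info(2)       # one block with two statements
large_callee = fake_info(4, 6)   # two blocks, ten statements in total

decisions = (
    small_callee_cost_model(None, None, tiny_callee),
    small_callee_cost_model(None, None, large_callee),
)
print(decisions)
```

With real Numba such a function would be supplied in the same way as sentinel_cost_model, i.e. as the value of the inline kwarg; whether statement count is a good inlining heuristic for a given workload is a separate question.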
In Numba 0.46 the main compiler pipeline was significantly reworked to make it more easily extendable and to permit users to essentially build their own custom compiler frontends. This change is based on a design similar to that found in LLVM. Full documentation is here.
For a large number of releases the Numba @jit family of decorators has permitted the definition of a custom compiler pipeline via the kwarg pipeline_class. This has not changed; however, the type of the class passed as the value has. Numba 0.46 now requires a class that inherits from numba.compiler.CompilerBase to be passed as the value, which is much more flexible than the previously required pipeline class.
The default compiler used by Numba is the numba.compiler.Compiler class, which itself makes use of pre-canned pipelines defined in numba.compiler.DefaultPassBuilder by the methods:
- .define_nopython_pipeline() for the nopython mode pipeline
- .define_objectmode_pipeline() for the object-mode pipeline
- .define_interpreted_pipeline() for the interpreted pipeline
Creating a new custom compiler requires extending from the numba.compiler.CompilerBase class and overriding the .define_pipelines() method, e.g.
from numba.compiler import CompilerBase, DefaultPassBuilder
class CustomCompiler(CompilerBase): # custom compiler extends from CompilerBase
    def define_pipelines(self):
        # define a new set of pipelines (just one in this case) and for demonstration purposes
        # reuse an existing pipeline from the DefaultPassBuilder, namely the "nopython" pipeline
        pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
        # return as an iterable, any number of pipelines may be defined!
        return [pm]
Using the custom compiler is just a question of supplying it via the aforementioned pipeline_class kwarg, for example:
@jit(pipeline_class=CustomCompiler)
def foo(x):
    return x + 1
foo(10)
The next example won't work with the CustomCompiler because only the nopython mode pipeline is available in it and this function returns a Python object.
@jit(pipeline_class=CustomCompiler)
def foo(x):
    return x + 1, object()
from numba import errors
try:
    foo(10)
except errors.TypingError as e:
    print(str(e))
Numba has a large number of pre-defined passes for use. They are categorised as being:
- untyped, i.e. do not require type information; these are found in numba.untyped_passes
- typed, i.e. require type information; these are found in numba.typed_passes
- object mode, i.e. require object mode; these are found in numba.object_mode_passes
For reference, these are the ones in the code base for 0.46.
for x in numba.compiler_machinery._pass_registry._registry.keys():
    print(x)
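Let's implement a new pipeline that:
- translates the bytecode and processes the IR
- rewrites semantic constants and prunes dead branches
- runs nopython mode type inference followed by dead code elimination
- legalises the IR and lowers it with the nopython mode backend
and use it in a new custom compiler. The pipeline management code is found in numba.compiler_machinery.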
from numba.compiler_machinery import PassManager
from numba.untyped_passes import (TranslateByteCode, FixupArgs, IRProcessing, DeadBranchPrune,
                                  RewriteSemanticConstants)
from numba.typed_passes import (NopythonTypeInference, DeadCodeElimination, IRLegalization,
                                NoPythonBackend)
def gen_pipeline():
    """ pipeline generation function, it need not be a function, pipelines are often
    defined directly in `ClassExtendingCompilerBase.define_pipelines` but it'll be used
    in a later example for another purpose.
    """
    # create a new PassManager to handle the passes for the pipeline
    pm = PassManager("custom_pipeline")
    # untyped
    pm.add_pass(TranslateByteCode, "analyzing bytecode")
    pm.add_pass(IRProcessing, "processing IR")
    pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
    pm.add_pass(DeadBranchPrune, "dead branch pruning")
    # typed
    pm.add_pass(NopythonTypeInference, "nopython frontend")
    pm.add_pass(DeadCodeElimination, "DCE")
    # legalise
    pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")
    # lower
    pm.add_pass(NoPythonBackend, "nopython mode backend")
    # finalise the contents
    pm.finalize()
    return pm
class NewPipelineCompiler(CompilerBase):
    def define_pipelines(self):
        return [gen_pipeline()]
Now use the NewPipelineCompiler in a deliberately contrived example to demonstrate the effect of certain passes.
numba.config.DEBUG_PRINT_AFTER="ir_processing,rewrite_semantic_constants,dead_branch_prune,dead_code_elimination"
@jit(pipeline_class=NewPipelineCompiler)
def foo(arr):
    if arr.ndim == 1:
        return 100
    else:
        return 200
x = np.arange(10) # 1d array input, x.ndim = 1
foo(x)
In the output above, the following can be seen:
- ir_processing pass produces the initial IR.
- rewrite_semantic_constants pass replaces the expression $0.2 = getattr(value=arr, attr=ndim) with $0.2 = const(int, 1).
- dead_branch_prune pass spotted that the block with label 14 is dead and removed it because:
    $0.2 = const(int, 1) ['$0.2']
    del arr []
    $const0.3 = const(int, 1) ['$const0.3']
    $0.4 = $0.2 == $const0.3
  evaluates to $0.4 always being True and, as a result, its use as the predicate in branch $0.4, 10, 14 means the 10 branch will always be taken, so 14 is dead.
- dead_code_elimination pass removed all the statements which were dead (had no effect).
In the final output there are now two blocks, labels 0 and 10. Block 0 has only one statement, an unconditional jump to 10. In the next section a new pass is going to be written to simplify the control flow graph in such situations, as it's clear that the blocks can be fused.
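To make the idea of fusing blocks concrete, here is a toy sketch in plain Python. It is not Numba's implementation and does not operate on Numba IR: the block representation (a dict mapping labels to lists of tuples) and the function name fuse_trivial_jumps are assumptions made purely for illustration. A block that ends in an unconditional jump is merged with its target when nothing else jumps to that target.

```python
def fuse_trivial_jumps(blocks, entry=0):
    """Toy CFG simplification (illustrative only).

    `blocks` maps a label to a list of statements. The last statement of a
    block is its terminator: ("jump", target), ("branch", pred, t, f) or
    ("return", value). A block ending in a jump is fused with its target
    when that target has exactly one predecessor.
    """
    blocks = {label: list(body) for label, body in blocks.items()}  # copy
    changed = True
    while changed:
        changed = False
        # count how many terminators point at each label
        preds = {label: 0 for label in blocks}
        for body in blocks.values():
            term = body[-1]
            if term[0] == "jump":
                preds[term[1]] += 1
            elif term[0] == "branch":
                preds[term[2]] += 1
                preds[term[3]] += 1
        for label, body in list(blocks.items()):
            term = body[-1]
            if term[0] != "jump":
                continue
            target = term[1]
            if target != label and target != entry and preds[target] == 1:
                # replace the jump with the statements of the target block
                blocks[label] = body[:-1] + blocks.pop(target)
                changed = True
                break  # predecessor counts are stale, recompute them
    return blocks


# Same shape as the final IR above: block 0 only jumps to block 10.
cfg = {
    0: [("jump", 10)],
    10: [("assign", "x", 100), ("return", "x")],
}
fused = fuse_trivial_jumps(cfg)
print(fused)
```

The input dictionary is left untouched and a simplified copy is returned, which keeps the before and after states easy to compare.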
Implementing a new compiler pass involves writing a class that inherits from numba.compiler_machinery.CompilerPass. It must be registered with the pass registry before use and, through the process of registration, declare some information about what it will do in certain scenarios. Documentation for this feature is here.
Continuing with the above example, Numba has a function numba.ir_utils.simplify_CFG which does the control flow graph simplification alluded to in the final paragraph above. In the following, this function is wrapped in a compiler pass and then used in a new pipeline.
from numba.ir_utils import simplify_CFG
from numba.compiler_machinery import register_pass, FunctionPass
# Register this pass with the compiler framework, declare that it can mutate the control
# flow graph and that it is not an analysis_only pass (it potentially mutates the IR).
@register_pass(mutates_CFG=True, analysis_only=False)
# Inherit from FunctionPass, the base class for passes operating on functions
class SimplifyCFG(FunctionPass):
    _name = "simplify_cfg" # the common name for the pass
    def __init__(self):
        FunctionPass.__init__(self)
    # implement the method to do the work, "state" is the internal compiler
    # state from the CompilerBase instance.
    def run_pass(self, state):
        # get the IR blocks
        blks = state.func_ir.blocks
        # run the simplification
        new_blks = simplify_CFG(blks)
        # update the reference to the block state
        state.func_ir.blocks = new_blks
        # return whether the IR was mutated (here, CFG change implies IR change)
        mutated = blks != new_blks
        return mutated
# define a new compiler
class NewPipelineWSimplifyCFGCompiler(CompilerBase):
    def define_pipelines(self):
        # generate the same pipeline as in the previous example
        pm = gen_pipeline()
        # add the new pass after DeadCodeElimination
        pm.add_pass_after(SimplifyCFG, DeadCodeElimination)
        # re-finalize the pipeline since the above mutated it
        pm.finalize()
        return [pm]
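The add, insert-after and re-finalize workflow used above can be pictured with a toy pass manager written in plain Python. This is only a sketch of the pattern, not Numba's PassManager: the class ToyPassManager, its attributes and the four example pass functions are hypothetical, and passes here are plain callables that record their name in a list standing in for the compiler state.

```python
class ToyPassManager:
    """Toy model of an ordered, finalizable pipeline (illustrative only)."""

    def __init__(self, name):
        self.name = name
        self.passes = []        # ordered (pass_callable, description) pairs
        self.finalized = False

    def add_pass(self, pss, description):
        self.passes.append((pss, description))
        self.finalized = False  # any change invalidates finalization

    def add_pass_after(self, pss, location):
        # insert `pss` directly after the first occurrence of `location`
        for idx, (existing, _) in enumerate(self.passes):
            if existing is location:
                self.passes.insert(idx + 1, (pss, pss.__name__))
                self.finalized = False
                return
        raise ValueError(f"{location.__name__} is not in the pipeline")

    def finalize(self):
        self.finalized = True

    def run(self, state):
        if not self.finalized:
            raise RuntimeError("pipeline must be finalized before running")
        for pss, _ in self.passes:
            pss(state)
        return state


# Hypothetical passes: each one just records that it ran.
def parse(state):
    state.append("parse")

def eliminate_dead_code(state):
    state.append("eliminate_dead_code")

def simplify(state):
    state.append("simplify")

def lower(state):
    state.append("lower")


pm = ToyPassManager("toy_pipeline")
pm.add_pass(parse, "parse the input")
pm.add_pass(eliminate_dead_code, "remove dead statements")
pm.add_pass(lower, "lower to machine code")
pm.finalize()

# Mutating a finalized pipeline means it has to be finalized again.
pm.add_pass_after(simplify, eliminate_dead_code)
pm.finalize()

trace = pm.run([])
print(trace)
```

Keeping the pipeline as an ordered list makes positional insertion cheap, and the explicit finalize step gives a single point at which a real implementation could validate or analyse the whole pipeline before it runs.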
Now re-run the foo function with the updated custom compiler, which includes the new pass in its pipeline. Also, print the IR after dead code elimination (the end of the output from the last example) and after the new SimplifyCFG pass.
numba.config.DEBUG_PRINT_AFTER="dead_code_elimination,simplify_cfg"
@jit(pipeline_class=NewPipelineWSimplifyCFGCompiler)
def foo(arr):
    if arr.ndim == 1:
        return 100
    else:
        return 200
x = np.arange(10) # 1d array input, x.ndim = 1
foo(x)
It can be seen in the above that the CFG has been simplified after the new simplify_cfg pass has run: the IR is now a single block.