This notebook contains a demonstration of new features present in the 0.47.0 release of Numba. Whilst release notes are produced as part of the [CHANGE_LOG
](
Demonstrations of new features include:
* `map`, `filter` and `reduce`
* `list.sort()` and `sorted` with a `key`
* `try`/`except`
First, import the necessary items from Numba and NumPy...
from numba import jit, njit, config, __version__, errors
from numba.errors import NumbaPendingDeprecationWarning
import warnings
# we're going to ignore a couple of deprecation warnings
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
from numba.extending import overload
config.SHOW_HELP = 0
import numba
import numpy as np
assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 47)
config.FULL_TRACEBACKS = 1
@njit(boundscheck=True)
def OOB_access(x):
    sz = len(x)
    a = x[0] # fine, first element of x
    a += x[sz - 1] # fine, last element of x
    a += x[sz] # oops, out of bounds!
try:
    OOB_access(np.ones(10))
except IndexError as e:
    print(type(e), e)
The setting of `config.FULL_TRACEBACKS` (or its environment variable equivalent) forces the printing of the index, axis and dimension size to the terminal (assuming a terminal was used to invoke python). For example, the terminal that launched this notebook now shows the line:
debug: IndexError: index 10 is out of bounds for axis 0 with size 10
A future release will enhance this feature to include the out of bounds access information in the error message.
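To make the check concrete, here is a minimal pure-Python sketch of the kind of test a bounds check performs before an element access. `checked_get` is a hypothetical helper written for illustration and is not part of Numba; only the message format is taken from the terminal output quoted above.

```python
def checked_get(seq, index, axis=0):
    """Hypothetical helper: return seq[index] after an explicit bounds check."""
    size = len(seq)
    # Negative indices count from the end, as in ordinary Python indexing.
    wrapped = index + size if index < 0 else index
    if not 0 <= wrapped < size:
        raise IndexError(
            f"index {index} is out of bounds for axis {axis} with size {size}"
        )
    return seq[wrapped]


data = [1.0] * 10
first = checked_get(data, 0)               # fine, first element
last = checked_get(data, len(data) - 1)    # fine, last element
try:
    checked_get(data, len(data))           # one past the end
    message = None
except IndexError as e:
    message = str(e)
print(message)
```

The sketch reports the index, axis and size in the exception message itself, which is the kind of information the terminal line carries.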
The 0.47.0 release adds the following new capability to Numba: dynamic function generation. Essentially, functions (closures) defined in a JIT-decorated function can now "escape" the function they are defined in and be used as arguments in subsequent function calls. For example:
# takes a function and calls it with argument arg, multiplies the result by 7
@njit
def consumer(function, arg):
    return function(arg) * 7
_GLOBAL = 5
@njit
def generator_func():
    _FREEVAR = 10
    def escapee(x): # closure, 'a' is a local, '_FREEVAR' is a freevar, '_GLOBAL' is global
        a = 9
        return x * _FREEVAR + a * _GLOBAL
    # data argument for the consumer call
    x = np.arange(5)
    # escapee function is passed to the consumer function along with its argument
    return consumer(escapee, x)
generator_func()
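The local/freevar/global distinction in the comment above can be inspected in plain Python. The following is a sketch in the interpreter only, without Numba; `make_escapee` is a hypothetical wrapper introduced here so the closure can be returned and examined, and the list comprehension stands in for the NumPy array used in the notebook.

```python
_GLOBAL = 5


def make_escapee():
    _FREEVAR = 10

    def escapee(x):
        a = 9
        return x * _FREEVAR + a * _GLOBAL

    return escapee  # the closure "escapes" the function that defined it


def consumer(function, arg):
    return function(arg) * 7


escapee = make_escapee()
code = escapee.__code__
free_names = code.co_freevars    # names captured from the enclosing scope
local_names = code.co_varnames   # arguments and locals
global_names = code.co_names     # names looked up globally

# Element-wise equivalent of consumer(escapee, np.arange(5))
expected = [consumer(escapee, v) for v in range(5)]
print(free_names, local_names, global_names, expected)
```

Each element is `(x * 10 + 9 * 5) * 7`, which gives a quick way to sanity-check the array returned by the JIT-compiled version.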
`map`, `filter`, `reduce`
The ability to create dynamic functions led to support for `map`, `filter` and `reduce`. This makes it possible to write more "pythonic" code in Numba :-)
import operator
from functools import reduce
from numba.typed import List
@njit
def demo_map_filter_reduce():
    # This will be used in map
    def mul_n(x, multiplier):
        return x * multiplier
    # This will be used in filter
    V = 20
    def greater_than_V(x):
        return x > V # captures V from freevars
    # this will be used in reduce
    reduce_lambda = lambda x, y: (x * 2) + y
    a = [x ** 2 for x in range(10)]
    n = len(a)
    return reduce(reduce_lambda, filter(greater_than_V, map(mul_n, a, range(n))))
demo_map_filter_reduce()
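The pipeline is easier to follow when each stage is materialised. This is a plain-Python walk-through of the same computation (no Numba involved); the intermediate names `squares`, `mapped` and `kept` are introduced here purely for illustration.

```python
from functools import reduce

squares = [x ** 2 for x in range(10)]

# map with two iterables: each square is multiplied by its position
mapped = list(map(lambda x, multiplier: x * multiplier, squares, range(len(squares))))

# filter keeps only the values greater than V = 20
kept = list(filter(lambda x: x > 20, mapped))

# reduce folds left to right: acc = acc * 2 + next
total = reduce(lambda x, y: (x * 2) + y, kept)

print(mapped)
print(kept)
print(total)
```

The final value, 10629, is what the interpreter computes for this pipeline and serves as a reference for the compiled function.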
`list.sort()`/`sorted` with `key`
A further extension born from the ability to create dynamic functions is support for the `key` argument to `list.sort` and `sorted`. A quick demonstration:
@njit
def demo_sort_sorted(chars):
    def key(x):
        return x.upper()
    x = chars[:]
    x.sort()
    print("sorted:", ''.join(x))
    x = chars[:]
    x.sort(reverse=True)
    print("sorted backwards:", ''.join(x))
    x = chars[:]
    x.sort(key=key)
    print("sorted key=x.upper():", ''.join(x))
    print("sorted(), reversed", ''.join(sorted(x, reverse=True)))
    def numba_order(x):
        return 'NUMBA🐍numba⚡'.index(x)
    x = chars[:]
    x.sort(key=numba_order)
    print("sorted key=numba_order:", ''.join(x))
# let's sort a list of characters
input_list = ['m','M','a','N','n','u','⚡','🐍','B','b','U','A']
demo_sort_sorted(input_list)
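For reference, here is what the same orderings look like in the plain CPython interpreter. This is only a cross-check sketch: CPython's sort is stable, so characters with equal keys (such as `'a'` and `'A'` under `str.upper`) keep their input order, and whether the compiled sort preserves the same tie order is not something this notebook states.

```python
chars = ['m', 'M', 'a', 'N', 'n', 'u', '⚡', '🐍', 'B', 'b', 'U', 'A']

# Default ordering compares Unicode code points: upper case, lower case, then symbols
by_code_point = ''.join(sorted(chars))
backwards = ''.join(sorted(chars, reverse=True))

# key=str.upper makes the comparison case-insensitive; ties keep input order
by_upper = ''.join(sorted(chars, key=str.upper))

# A custom ordering: position of each character in a reference string
by_numba_order = ''.join(sorted(chars, key='NUMBA🐍numba⚡'.index))

print(by_code_point, backwards, by_upper, by_numba_order)
```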
@njit
def demo_try_bare_except(a, b):
    try:
        c = a / b
        return c
    except:
        print("caught exception")
        return -1
print("ok input:", demo_try_bare_except(5., 10.))
print("div by zero input:", demo_try_bare_except(5, 0))
The class `Exception` can also be caught. Let's mix this with the new bounds checking support:
@njit(boundscheck=True)
def demo_try_except_exception(array, index):
    try:
        return array[index]
    except Exception:
        print("caught exception")
        return -1
x = np.ones(5)
print("ok input:", demo_try_except_exception(x, 0))
print("OOB access:", demo_try_except_exception(x, 10))
User defined exception classes also work:
class UserDefinedException(Exception):
    def __init__(self, some_arg):
        self._some_arg = some_arg
@njit(boundscheck=True)
def demo_try_except_ude():
    try:
        raise UserDefinedException(123)
    except Exception:
        return "caught UDE!"
print(demo_try_except_ude())
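The reason a single `except Exception` clause covers all three demonstrations is the exception class hierarchy: `ZeroDivisionError`, `IndexError` and the user-defined class all derive from `Exception`. A short plain-Python sketch of that relationship follows; `run`, `div_by_zero`, `oob_access` and `raise_ude` are hypothetical helpers, and the `as e` binding is used here only in the interpreter.

```python
class UserDefinedException(Exception):
    def __init__(self, some_arg):
        self._some_arg = some_arg


def run(thunk):
    """Call thunk() and report the name of any Exception subclass it raises."""
    try:
        thunk()
    except Exception as e:
        return type(e).__name__
    return "no exception"


def div_by_zero():
    return 5 / 0


def oob_access():
    return [1.0] * 5 and ([1.0] * 5)[10]


def raise_ude():
    raise UserDefinedException(123)


results = [run(div_by_zero), run(oob_access), run(raise_ude), run(lambda: 5.0 / 10.0)]
all_subclasses = all(
    issubclass(t, Exception)
    for t in (ZeroDivisionError, IndexError, UserDefinedException)
)
print(results, all_subclasses)
```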
As users of Numba are very aware, Numba has to be able to work out the type of every variable in a function to be able to compile it (the function must be statically typable!). Prior to Numba 0.47.0, tuples of heterogeneous type could not be iterated over: the type of the induction variable in a loop could not be statically computed, and the loop body would need a different set of types for each type in the tuple. For example, this doesn't work:
from numba import literal_unroll
@njit
def does_not_work():
    tup = (1, 'a', 2j)
    for i in tup:
        print(i) # Numba cannot work out type of `i`, it changes each loop iteration
print("Typing problem")
try:
    does_not_work()
except errors.TypingError as e:
    print(e)
In Numba 0.47.0 a new function, `numba.literal_unroll`, is introduced. The function itself does nothing much; it's just a token to tell the Numba compiler that the argument needs special treatment for use as an iterable. When this function is applied in situations like the following, the body of the loop is "versioned" based on the types in the tuple, so that Numba can statically work out the types for each iteration and compilation will succeed. Here's a working version of the above failing example:
# use special function `numba.literal_unroll`
@njit
def works():
    tup = (1, 'a', 2j)
    for i in literal_unroll(tup):
        print(i) # literal_unroll tells the compiler to version the loop body based on type.
print("Apply literal_unroll():")
works()
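One way to picture loop-body "versioning" is to unroll the loop by hand, so that every copy of the body deals with exactly one type. The sketch below is a conceptual illustration in plain Python only; it is not the transformation Numba actually performs internally, and both function names are hypothetical.

```python
def loop_version():
    tup = (1, 'a', 2j)
    seen = []
    for i in tup:  # `i` is an int, then a str, then a complex
        seen.append(type(i).__name__)
    return seen


def unrolled_by_hand():
    tup = (1, 'a', 2j)
    seen = []
    # One copy of the loop body per element, each with a single fixed type
    i_int = tup[0]
    seen.append(type(i_int).__name__)
    i_str = tup[1]
    seen.append(type(i_str).__name__)
    i_complex = tup[2]
    seen.append(type(i_complex).__name__)
    return seen


print(loop_version())
print(unrolled_by_hand())
```

In the unrolled form every variable has one statically known type, which is the property a statically typed compilation needs.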
A more involved example might be a tuple of locally defined functions (which are all different types by virtue of the Numba type system) that are iterated over:
@njit
def fruit_cookbook():
    def get_apples(x):
        return ['apple' for _ in range(x * 3)]
    def get_oranges(x):
        return ['orange' for _ in range(x * 4)]
    def get_bananas(x):
        return ['banana' for _ in range(x * 2)]
    ingredients = (get_apples, get_oranges, get_bananas)
    def fruit_salad(scale):
        shopping_list = []
        for ingredient in literal_unroll(ingredients):
            shopping_list.extend(ingredient(scale))
        return shopping_list
    print(fruit_salad(2))
fruit_cookbook()
Finally, because Numba has string and integer literal support, it's possible to dispatch on these values at compile time and version the loop body with value-based specialisations:
from numba import types
# function stub to overload
def dt(value):
    pass
@overload(dt, inline='always')
def ol_dt(li):
    # dispatch based on a string literal
    if isinstance(li, types.StringLiteral):
        value = li.literal_value
        if value == "apple":
            def impl(li):
                return 1
        elif value == "orange":
            def impl(li):
                return 2
        elif value == "banana":
            def impl(li):
                return 3
        return impl
    # dispatch based on an integer literal
    elif isinstance(li, types.IntegerLiteral):
        value = li.literal_value
        if value == 0xca11ab1e:
            def impl(li):
                # close over the dispatcher :)
                return 0x5ca1ab1e + value
            return impl
@njit
def unroll_and_dispatch_on_literal():
    acc = 0
    for t in literal_unroll(('apple', 'orange', 'banana', 0xca11ab1e)):
        acc += dt(t)
    return acc
print(unroll_and_dispatch_on_literal())
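The arithmetic behind the printed value can be checked with a plain-Python lookup table that mirrors the four specialisations. This is a runtime emulation for cross-checking only; the `dispatch` dictionary is a hypothetical stand-in and says nothing about how the compile-time dispatch is implemented.

```python
# One entry per literal handled by the overload above
dispatch = {
    "apple": 1,
    "orange": 2,
    "banana": 3,
    0xca11ab1e: 0x5ca1ab1e + 0xca11ab1e,
}

acc = 0
for t in ('apple', 'orange', 'banana', 0xca11ab1e):
    acc += dispatch[t]

print(acc, hex(acc))
```

The three string literals contribute 6 in total and the integer literal contributes `0x5ca1ab1e + 0xca11ab1e`, so the accumulated value is 4944254530.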
It's hoped that in a future version of Numba the token function `literal_unroll` will not be needed and loop body versioning opportunities will be automatically identified.
This release contains a number of newly supported NumPy functions, all written by contributors from the Numba community:
* `np.arange` now supports the `dtype` keyword argument.
Also now supported are:
* `np.lcm`
* `np.gcd`
A quick demo of the above:
@njit
def demo_numpy():
    a = np.arange(5, dtype=np.uint8)
    b = np.lcm(a, 2)
    c = np.gcd(a, 3)
    return a, b, c
demo_numpy()
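The element-wise values can be cross-checked with the standard library alone. The sketch below uses `math.gcd` and a small hypothetical `lcm` helper (defined here so it does not depend on a particular Python version); it assumes the convention that the least common multiple with zero is zero.

```python
from math import gcd


def lcm(a, b):
    """Hypothetical helper: least common multiple, with lcm(0, n) taken as 0."""
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // gcd(a, b)


a = list(range(5))           # stands in for np.arange(5, dtype=np.uint8)
b = [lcm(v, 2) for v in a]   # element-wise lcm with 2
c = [gcd(v, 3) for v in a]   # element-wise gcd with 3

print(a, b, c)
```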
A large number of unicode string features/enhancements were added in 0.47.0, namely:
* `str.index()`
* `str.rindex()`
* `start`/`end` parameters for `str.find()`
* `str.rpartition()`
* `str.lower()`
and a lot of querying functions:
* `str.isalnum()`
* `str.isalpha()`
* `str.isascii()`
* `str.isidentifier()`
* `str.islower()`
* `str.isprintable()`
* `str.isspace()`
* `str.istitle()`
@njit
def demo_string_enhancements(arg):
    print("index:", arg.index("🐍")) # index of snake
    print("rindex:", arg.rindex("🐍")) # rindex of snake
    print("find:", arg.find("🐍", start=2, end=6)) # find snake with start+end
    print("rpartition:", arg.rpartition("🐍")) # rpartition snake
    print("lower:", arg.lower()) # lower snake
    print("isalnum:", 'abc123'.isalnum(), '🐍'.isalnum())
    print("isalpha:", 'abc'.isalpha(), '123'.isalpha())
    print("isascii:", 'abc'.isascii(), '🐍'.isascii())
    print("isidentifier:", '1'.isidentifier(), 'var'.isidentifier())
    print("islower:", 'SHOUT'.islower(), 'whisper'.islower())
    print("isprintable:", '\x07'.isprintable(), 'BEL'.isprintable())
    print("isspace:", ' '.isspace(), '_'.isspace())
    print("istitle:", "Titlestring".istitle(), "notTitlestring".istitle())
arg = "N🐍u🐍M🐍b🐍A⚡"
demo_string_enhancements(arg)