Python loops are inefficient for numeric operations.
import numpy as np
Here's a function that computes the sum of the logs of all non-zero values, using a pure-Python loop.
def sum_log_nz(ary):
    res = np.zeros(ary.shape[0])
    for i in range(ary.shape[0]):
        v = ary[i]
        if v != 0:
            res[i] = np.log(v)
    return res.sum()
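For comparison, the same computation can be written as a vectorized NumPy one-liner (a sketch; `sum_log_nz_vec` is a name introduced here, not part of the original notebook):

```python
import numpy as np

def sum_log_nz_vec(ary):
    """Vectorized equivalent: a boolean mask instead of a Python loop."""
    nz = ary[ary != 0]       # keep only the non-zero entries
    return np.log(nz).sum()  # log and sum happen in compiled NumPy code
```

For positive inputs this produces the same result as the loop version, and it is typically far faster because all per-element work runs in C.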
Test the function
a = np.random.random(5_000_000)
a
sum_log_nz(a)
Time the function
%%time
sum_log_nz(a)
Numba can compile the inefficient pure-Python loop into a SIMD-vectorized native loop.
import numba
Try compiling the function with Numba. Notice the difference between the `fastmath=True` and `fastmath=False` settings.
fast_sum_log_nz = numba.njit(fastmath=True)(sum_log_nz)
fast_sum_log_nz
fast_sum_log_nz(a)
Notice the improved performance.
%%time
fast_sum_log_nz(a)
fast_sum_log_nz.inspect_cfg(fast_sum_log_nz.signatures[0]).display()
Numba can auto-parallelize the function to leverage multiple threads.
par_sum_log_nz = numba.njit(parallel=True)(sum_log_nz)
par_sum_log_nz(a)
Use `.parallel_diagnostics()` to inspect what the compiler has done to optimize the function.
par_sum_log_nz.parallel_diagnostics()
Use `numba.prange` to mark a loop for parallelization.
@numba.njit(parallel=True, fastmath=True)
def par_sum_log_nz(ary):
    res = np.zeros(ary.shape[0])
    for i in numba.prange(ary.shape[0]):
        v = ary[i]
        if v != 0:
            res[i] = np.log(v)
    return res.sum()
par_sum_log_nz(a)
%%time
par_sum_log_nz(a)
Compare the result of `.parallel_diagnostics()` with the previous version.
par_sum_log_nz.parallel_diagnostics()