! pip install -U jax jaxlib
from jax.experimental import loops
import jax
import jax.numpy as jnp
import numpy as np
import numba
@jax.jit
def smooth_image_jax(x, n):
with loops.Scope() as s:
s.x = x
s.y = jnp.zeros_like(x)
s.k = 0
for _ in s.while_range(lambda: s.k < n):
s.k += 1
k, m = x.shape
for i in s.range(k):
for j in s.range(m):
new_value = 0.25 * (s.x[i - 1, j]
+ s.x[(i + 1) % k, j]
+ s.x[i, j - 1]
+ s.x[i, (j + 1) % m])
s.y = s.y.at[i, j].set(new_value)
s.x = s.y
return s.y
@numba.jit
def _smooth_image_numba(x, n):
y = np.zeros_like(x)
for _ in range(n):
k, m = x.shape
for i in range(k):
for j in range(m):
y[i, j] = 0.25 * (x[i - 1, j]
+ x[(i + 1) % k, j]
+ x[i, j - 1]
+ x[i, (j + 1) % m])
x[:] = y
return y
def smooth_image_numba(x, n):
return _smooth_image_numba(x.copy(), n)
x = np.arange(25.0).reshape(5, 5)
x
array([[ 0., 1., 2., 3., 4.], [ 5., 6., 7., 8., 9.], [10., 11., 12., 13., 14.], [15., 16., 17., 18., 19.], [20., 21., 22., 23., 24.]])
smooth_image_numba(x, 2)
array([[ 9.375 , 9.125 , 9.8125, 10.5 , 10.25 ], [ 8.125 , 7.875 , 8.5625, 9.25 , 9. ], [11.5625, 11.3125, 12. , 12.6875, 12.4375], [15. , 14.75 , 15.4375, 16.125 , 15.875 ], [13.75 , 13.5 , 14.1875, 14.875 , 14.625 ]])
smooth_image_numba(x, 2)
array([[ 9.375 , 9.125 , 9.8125, 10.5 , 10.25 ], [ 8.125 , 7.875 , 8.5625, 9.25 , 9. ], [11.5625, 11.3125, 12. , 12.6875, 12.4375], [15. , 14.75 , 15.4375, 16.125 , 15.875 ], [13.75 , 13.5 , 14.1875, 14.875 , 14.625 ]])
smooth_image_jax(x, 1)
DeviceArray([[ 7.5 , 7.25, 8.25, 9.25, 9. ], [ 6.25, 6. , 7. , 8. , 7.75], [11.25, 11. , 12. , 13. , 12.75], [16.25, 16. , 17. , 18. , 17.75], [15. , 14.75, 15.75, 16.75, 16.5 ]], dtype=float32)
smooth_image_jax(x, 1)
DeviceArray([[ 7.5 , 7.25, 8.25, 9.25, 10. ], [ 6.25, 6. , 7. , 8. , 8.75], [11.25, 11. , 12. , 13. , 13.75], [16.25, 16. , 17. , 18. , 18.75], [20. , 19.75, 20.75, 21.75, 22.5 ]], dtype=float32)
x = np.arange(256.0 ** 2).reshape(256, 256)
%time print(smooth_image_numba(x, 64))
%timeit smooth_image_numba(x, 64)
[[30580.57911198 30563.8374853 30547.90229619 ... 30630.27476735 30614.33957825 30597.59795157] [26294.72268224 26277.98105556 26262.04586645 ... 26344.41833761 26328.48314851 26311.74152183] [22215.31427203 22198.57264535 22182.63745625 ... 22265.00992741 22249.0747383 22232.33311163] ... [43302.66688837 43285.9252617 43269.99007259 ... 43352.36254375 43336.42735465 43319.68572797] [39223.25847817 39206.51685149 39190.58166239 ... 39272.95413355 39257.01894444 39240.27731776] [34937.40204843 34920.66042175 34904.72523265 ... 34987.09770381 34971.1625147 34954.42088802]] CPU times: user 442 ms, sys: 10.7 ms, total: 453 ms Wall time: 454 ms 10 loops, best of 3: 31 ms per loop
%time print(smooth_image_jax(x, 64).block_until_ready())
%timeit smooth_image_jax(x, 64).block_until_ready()
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.')
[[30580.578 30563.84 30547.902 ... 30630.277 30614.34 30597.602] [26294.727 26277.98 26262.047 ... 26344.418 26328.488 26311.74 ] [22215.312 22198.574 22182.637 ... 22265.014 22249.072 22232.338] ... [43302.67 43285.926 43269.992 ... 43352.363 43336.43 43319.688] [39223.258 39206.52 39190.582 ... 39272.957 39257.023 39240.28 ] [34937.406 34920.66 34904.727 ... 34987.098 34971.164 34954.42 ]] CPU times: user 814 ms, sys: 37.1 ms, total: 851 ms Wall time: 844 ms 10 loops, best of 3: 59.3 ms per loop