# Copyright 2021 Google LLC.
# SPDX-License-Identifier: Apache-2.0
from jax import lax, jit
from functools import partial
import jax.numpy as jnp
import jax
import numpy as np
@partial(jit, static_argnames=('unroll',), backend='cpu')
def polyval(p, x, unroll=64):
    """Evaluate the polynomial with coefficients ``p`` at ``x`` (CPU build).

    Horner's rule is expressed as a ``lax.scan`` over the leading axis of
    ``p``. Coefficients therefore run from the highest power down, as in
    ``numpy.polyval``: ``p[0] * x**(n-1) + ... + p[n-1]`` with
    ``n = p.shape[0]``. An empty ``p`` leaves the zero initial value
    unchanged.

    Args:
      p: Coefficient array. Axis 0 is scanned; any remaining axes broadcast
        against ``x``.
      x: Array of points at which to evaluate.
      unroll: Unroll factor forwarded to ``lax.scan``. It is a static
        argument, so every distinct value is compiled separately.

    Returns:
      Array of shape ``broadcast_shapes(p.shape[1:], x.shape)`` with dtype
      ``result_type(p, x)``.
    """
    # NOTE(review): recent JAX releases may warn that jit's ``backend=``
    # argument is deprecated; confirm against the installed version.
    out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)
    out_dtype = jnp.result_type(p, x)
    acc0 = lax.full_like(x, 0, shape=out_shape, dtype=out_dtype)

    def horner_step(acc, coeff):
        # acc <- acc * x + coeff; no per-step output is collected.
        return acc * x + coeff, None

    acc, _ = lax.scan(horner_step, acc0, p, unroll=unroll)
    return acc
x = np.random.rand(100).astype(np.float32)
p = np.random.randn(10000).astype(np.float32)
print("CPU")
for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:
print(f"unroll={unroll}")
%time polyval(p, x, unroll).block_until_ready()
%timeit polyval(p, x, unroll).block_until_ready()
# CPU
# unroll=1
# CPU times: user 29.3 ms, sys: 0 ns, total: 29.3 ms
# Wall time: 29.6 ms
# 10000 loops, best of 5: 78 µs per loop
# unroll=2
# CPU times: user 35.2 ms, sys: 0 ns, total: 35.2 ms
# Wall time: 35 ms
# 10000 loops, best of 5: 45.4 µs per loop
# unroll=4
# CPU times: user 38.3 ms, sys: 0 ns, total: 38.3 ms
# Wall time: 38.5 ms
# 10000 loops, best of 5: 35.8 µs per loop
# unroll=8
# CPU times: user 47 ms, sys: 0 ns, total: 47 ms
# Wall time: 47.1 ms
# 10000 loops, best of 5: 37.2 µs per loop
# unroll=16
# CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
# Wall time: 61.4 ms
# 10000 loops, best of 5: 46.9 µs per loop
# unroll=32
# CPU times: user 135 ms, sys: 0 ns, total: 135 ms
# Wall time: 135 ms
# 1000 loops, best of 5: 358 µs per loop
# unroll=64
# CPU times: user 178 ms, sys: 0 ns, total: 178 ms
# Wall time: 177 ms
# 10000 loops, best of 5: 98.7 µs per loop
# unroll=128
# CPU times: user 307 ms, sys: 0 ns, total: 307 ms
# Wall time: 307 ms
# 10000 loops, best of 5: 130 µs per loop
@partial(jit, static_argnames=['unroll'], backend='gpu')
def polyval(p, x, unroll=64):
    """Horner-rule polynomial evaluation, compiled for the GPU backend.

    Computes ``p[0] * x**(n-1) + p[1] * x**(n-2) + ... + p[n-1]``, where
    ``n = p.shape[0]``. This is the same coefficient order as
    ``numpy.polyval``. The work is a ``lax.scan`` along axis 0 of ``p``.
    Trailing axes of ``p`` broadcast against ``x``.

    Args:
      p: Coefficients, highest power first, along axis 0.
      x: Array of evaluation points.
      unroll: Static ``lax.scan`` unroll factor. A new value forces a
        recompile.

    Returns:
      Array with shape ``broadcast_shapes(p.shape[1:], x.shape)`` and dtype
      ``result_type(p, x)``.
    """
    # NOTE(review): recent JAX releases may warn that jit's ``backend=``
    # argument is deprecated; confirm against the installed version.
    def fold(total, c):
        # Multiply the running value by x, then add the next coefficient.
        return total * x + c, None

    zeros = lax.full_like(
        x,
        0,
        shape=lax.broadcast_shapes(p.shape[1:], x.shape),
        dtype=jnp.result_type(p, x),
    )
    return lax.scan(fold, zeros, p, unroll=unroll)[0]
x = jax.device_put(np.random.rand(100))
p = jax.device_put(np.random.randn(10000))
print("GPU")
for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:
print(f"unroll={unroll}")
%time polyval(p, x, unroll).block_until_ready()
%timeit polyval(p, x, unroll).block_until_ready()
# GPU
# unroll=1
# CPU times: user 112 ms, sys: 34.1 ms, total: 146 ms
# Wall time: 730 ms
# 10 loops, best of 5: 70.6 ms per loop
# unroll=2
# CPU times: user 62.9 ms, sys: 11.2 ms, total: 74.1 ms
# Wall time: 150 ms
# 10 loops, best of 5: 35.6 ms per loop
# unroll=4
# CPU times: user 47.7 ms, sys: 13.6 ms, total: 61.3 ms
# Wall time: 122 ms
# 100 loops, best of 5: 17.5 ms per loop
# unroll=8
# CPU times: user 42.9 ms, sys: 27.7 ms, total: 70.6 ms
# Wall time: 129 ms
# 100 loops, best of 5: 8.86 ms per loop
# unroll=16
# CPU times: user 56 ms, sys: 34.4 ms, total: 90.4 ms
# Wall time: 144 ms
# 100 loops, best of 5: 6.54 ms per loop
# unroll=32
# CPU times: user 105 ms, sys: 38.1 ms, total: 143 ms
# Wall time: 214 ms
# 100 loops, best of 5: 3.16 ms per loop
# unroll=64
# CPU times: user 162 ms, sys: 32.8 ms, total: 195 ms
# Wall time: 258 ms
# 1000 loops, best of 5: 1.83 ms per loop
# unroll=128
# CPU times: user 393 ms, sys: 8.66 ms, total: 402 ms
# Wall time: 501 ms
# 1000 loops, best of 5: 861 µs per loop
@partial(jit, static_argnames=('unroll',))
def polyval(p, x, unroll=64):
    """Evaluate ``p[0] * x**(n-1) + ... + p[n-1]`` on JAX's default backend.

    No backend is pinned here, so the function is compiled for whatever JAX
    picks by default (presumably the TPU on a TPU runtime). The polynomial is
    evaluated by Horner's rule as a ``lax.scan`` over axis 0 of ``p``. This
    matches the coefficient order of ``numpy.polyval``. Trailing axes of
    ``p`` broadcast against ``x``.

    Args:
      p: Coefficients along axis 0, highest power first.
      x: Array of evaluation points.
      unroll: Static ``lax.scan`` unroll factor. Each distinct value is
        compiled separately.

    Returns:
      Array of shape ``broadcast_shapes(p.shape[1:], x.shape)`` and dtype
      ``result_type(p, x)``.
    """
    result_shape = lax.broadcast_shapes(p.shape[1:], x.shape)
    result_dtype = jnp.result_type(p, x)
    start = lax.full_like(x, 0, shape=result_shape, dtype=result_dtype)

    def step(running, coefficient):
        # One Horner update; the scan collects no per-step outputs.
        return running * x + coefficient, None

    final, _unused = lax.scan(step, start, p, unroll=unroll)
    return final
x = jax.device_put(np.random.rand(100))
p = jax.device_put(np.random.randn(10000))
print("TPU")
for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:
print(f"unroll={unroll}")
%time polyval(p, x, unroll).block_until_ready()
%timeit polyval(p, x, unroll).block_until_ready()
# TPU
# unroll=1
# CPU times: user 34.8 ms, sys: 0 ns, total: 34.8 ms
# Wall time: 45.2 ms
# 100 loops, best of 5: 13.2 ms per loop
# unroll=2
# CPU times: user 134 ms, sys: 89 µs, total: 134 ms
# Wall time: 107 ms
# 100 loops, best of 5: 13.3 ms per loop
# unroll=4
# CPU times: user 118 ms, sys: 0 ns, total: 118 ms
# Wall time: 90.7 ms
# 100 loops, best of 5: 9.31 ms per loop
# unroll=8
# CPU times: user 96.1 ms, sys: 2.07 ms, total: 98.2 ms
# Wall time: 94.7 ms
# 100 loops, best of 5: 5.48 ms per loop
# unroll=16
# CPU times: user 118 ms, sys: 0 ns, total: 118 ms
# Wall time: 107 ms
# 100 loops, best of 5: 4.2 ms per loop
# unroll=32
# CPU times: user 218 ms, sys: 0 ns, total: 218 ms
# Wall time: 181 ms
# 100 loops, best of 5: 3.79 ms per loop
# unroll=64
# CPU times: user 325 ms, sys: 2.08 ms, total: 327 ms
# Wall time: 276 ms
# 100 loops, best of 5: 3.64 ms per loop
# unroll=128
# CPU times: user 726 ms, sys: 0 ns, total: 726 ms
# Wall time: 631 ms
# 100 loops, best of 5: 3.49 ms per loop