# Copyright 2019 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import jax
import jax.numpy as jnp
import numpy as np
@jax.jit
def gradient_conv(x):
    """Central-difference gradient of a 1-D array via `jax.lax.conv`.

    Equivalent to 0.5 * (x[:-2] - x[2:]): lax.conv performs
    cross-correlation (no kernel flip), so the kernel [0.5, 0, -0.5]
    yields out[i] = 0.5 * (x[i] - x[i+2]) with 'VALID' padding.
    """
    kernel = jnp.array([0.5, 0.0, -0.5])
    # lax.conv expects NCW layout: (batch, channels, width).
    signal = x.reshape(1, 1, -1)
    filt = kernel.reshape(1, 1, -1)
    out = jax.lax.conv(signal, filt, window_strides=(1,), padding='VALID')
    # Drop the singleton batch and channel axes to return a 1-D result.
    return out.squeeze(axis=(0, 1))
@jax.jit
def gradient_slicing(x):
    """Central-difference gradient of a 1-D array via slicing.

    Returns 0.5 * (x[i] - x[i+2]) for each valid i; output is two
    elements shorter than the input.
    """
    left = x[:-2]
    right = x[2:]
    return 0.5 * (left - right)
# Sanity check: the convolution and slicing implementations agree.
x = jnp.array([1, -3, 0, 4, 2, 5., 0.])
np.testing.assert_allclose(gradient_conv(x), gradient_slicing(x))
# CPU
x = jax.device_put(jnp.ones((100, 100, 100)))
%timeit jax.device_get(gradient_slicing(x))
%timeit jax.device_get(gradient_conv(x))
# gradient_slicing: 100 loops, best of 3: 3 ms per loop
# gradient_conv:    100 loops, best of 3: 4.5 ms per loop
# GPU
x = jax.device_put(jnp.ones((100, 100, 100)))
%timeit jax.device_get(gradient_slicing(x))
%timeit jax.device_get(gradient_conv(x))
# gradient_slicing: 1000 loops, best of 3: 1.59 ms per loop
# gradient_conv:    100 loops, best of 3: 1.87 ms per loop