# Notebook setup. These were IPython/Colab magics, which are not valid Python:
#   !pip install -q -U jaxlib jax
#   %tensorflow_version 2.x   (obsolete Colab magic; Colab now ships TF 2 only)
# Run `pip install -U jax jaxlib` in a shell before executing this file.
"""Differentiable interop between JAX and TensorFlow.

Two wrappers are defined and then checked against native implementations:

* ``wrap_tf_in_jax`` turns a one-argument TensorFlow function into a JAX
  function. Its custom VJP is computed with ``tf.GradientTape``.
* ``wrap_jax_in_tf`` turns a one-argument JAX function into a TensorFlow
  function. Its custom gradient is computed with ``jax.vjp``.

Values cross the framework boundary through ``np.asarray``, which needs
concrete values. The wrappers are therefore intended for eager execution;
presumably they do not work under ``jax.jit`` or ``tf.function`` tracing.
"""
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf


def as_jax(x):
    """Convert an array-like (e.g. an eager ``tf.Tensor``) to a JAX array via NumPy."""
    return jnp.asarray(np.asarray(x))


def as_tf(x):
    """Convert an array-like (e.g. a JAX array) to a ``tf.Tensor`` via NumPy."""
    return tf.convert_to_tensor(np.asarray(x))


def wrap_tf_in_jax(tf_func):
    """Wrap a one-argument TensorFlow function as a differentiable JAX function.

    Args:
      tf_func: callable that takes one ``tf.Tensor`` and returns one ``tf.Tensor``.

    Returns:
      A ``jax.custom_vjp`` function. Its forward pass runs ``tf_func``. Its
      backward pass asks ``tf.GradientTape`` for the vector-Jacobian product.
    """

    @jax.custom_vjp
    def f(x):
        return as_jax(tf_func(as_tf(x)))

    def f_fwd(x):
        # persistent=True lets the saved VJP be evaluated more than once, for
        # example when the pullback returned by jax.vjp is called repeatedly.
        # A non-persistent tape raises on a second gradient() call.
        with tf.GradientTape(persistent=True) as tape:
            x = as_tf(x)
            tape.watch(x)
            y = tf_func(x)
        # Residuals must form a JAX pytree, so the bound tape.gradient is boxed
        # in jax.tree_util.Partial.
        # UnconnectedGradients.ZERO makes an output that does not depend on x
        # yield a zero cotangent. Otherwise TF returns None, which as_jax
        # cannot convert.
        vjp_func = jax.tree_util.Partial(
            partial(
                tape.gradient,
                y,
                x,
                unconnected_gradients=tf.UnconnectedGradients.ZERO,
            )
        )
        return as_jax(y), vjp_func

    def f_rev(vjp_func, ct_y):
        # The remaining positional parameter of tape.gradient is
        # output_gradients, i.e. the incoming cotangent.
        ct_x = vjp_func(as_tf(ct_y))
        return (as_jax(ct_x),)

    f.defvjp(f_fwd, f_rev)
    return f


def wrap_jax_in_tf(jax_func):
    """Wrap a one-argument JAX function as a differentiable TensorFlow function.

    Args:
      jax_func: callable that takes one JAX array and returns one JAX array.

    Returns:
      A ``tf.custom_gradient`` function. Its forward pass runs ``jax_func``
      through ``jax.vjp``. Its gradient applies the saved JAX pullback.
    """

    @tf.custom_gradient
    def f(x):
        y, jax_vjp_fn = jax.vjp(jax_func, as_jax(x))

        def tf_vjp_fn(ct_y):
            # The pullback returns one cotangent per primal input (a 1-tuple
            # here). Unpack it before converting, instead of letting np.asarray
            # stack the tuple into an extra leading axis.
            (ct_x,) = jax_vjp_fn(as_jax(ct_y))
            return as_tf(ct_x)

        return as_tf(y), tf_vjp_fn

    return f


def tf_grad(f):
    """Return a function computing the TensorFlow gradient of ``f`` at ``x``.

    ``x`` is watched explicitly, so it may be a constant tensor rather than a
    ``tf.Variable``.
    """

    def f2(x):
        with tf.GradientTape() as g:
            g.watch(x)
            y = f(x)
        return g.gradient(y, x)

    return f2


# --- TensorFlow inside JAX --------------------------------------------------
x = jnp.arange(3.0)

wrapped_sum = wrap_tf_in_jax(tf.reduce_sum)
np.testing.assert_allclose(wrapped_sum(x), jnp.sum(x))
np.testing.assert_allclose(jax.grad(wrapped_sum)(x), jax.grad(jnp.sum)(x))

wrapped_square = wrap_tf_in_jax(tf.square)


def tf_and_jax(x):
    """TF square (wrapped for JAX) followed by a JAX sum."""
    return wrapped_square(x).sum()


def jax_only(x):
    """Pure-JAX reference for ``tf_and_jax``."""
    return (x ** 2).sum()


np.testing.assert_allclose(tf_and_jax(x), jax_only(x))
np.testing.assert_allclose(jax.grad(tf_and_jax)(x), jax.grad(jax_only)(x))

# --- JAX inside TensorFlow --------------------------------------------------
x = tf.range(3.0)

wrapped_sum = wrap_jax_in_tf(jnp.sum)
np.testing.assert_allclose(wrapped_sum(x), tf.reduce_sum(x))
np.testing.assert_allclose(tf_grad(wrapped_sum)(x), tf_grad(tf.reduce_sum)(x))

wrapped_square = wrap_jax_in_tf(jnp.square)


def tf_and_jax(x):
    """JAX square (wrapped for TF) followed by a TF sum."""
    return tf.reduce_sum(wrapped_square(x))


def tf_only(x):
    """Pure-TF reference for ``tf_and_jax``."""
    return tf.reduce_sum(x ** 2)


np.testing.assert_allclose(tf_and_jax(x), tf_only(x))
np.testing.assert_allclose(tf_grad(tf_and_jax)(x), tf_grad(tf_only)(x))