Author: shoyer@google.com
Date: April 16, 2020
The wrapped functions compose well with JAX's autodiff system!
Limitations (both directions):

Limitations for TF in JAX:
- jit. This would need support for wrapping Python functions via XLA's CustomCall.
- vmap. Conceivably, if we implemented this via a JAX Primitive instead, we could define a batching rule with tf.vectorized_map (see the sketch after this list).

Current rough edges:
- custom_vjp insists on auxiliary outputs being pytrees, so we lie and wrap the closure in tree_util.Partial (illustrated after wrap_tf_in_jax below).
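To make the vmap idea concrete, here is a rough, untested sketch of what a Primitive-based wrapper with a tf.vectorized_map batching rule might look like. It is specialized to tf.square; the names tf_square_p, tf_square, _impl, and _batch_rule are illustrative, and it relies on the semi-internal batching.primitive_batchers registry, so treat it as a sketch rather than part of the recipe below.

import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

tf_square_p = core.Primitive("tf_square")  # hypothetical primitive wrapping tf.square

def tf_square(x):
  return tf_square_p.bind(x)

def _impl(x):
  # Concrete (unbatched) evaluation: round-trip through TF, as in wrap_tf_in_jax below.
  return jnp.asarray(np.asarray(tf.square(tf.convert_to_tensor(np.asarray(x)))))

tf_square_p.def_impl(_impl)

def _batch_rule(batched_args, batch_dims):
  # Batching rule: move the mapped axis to the front and let TF vectorize over it.
  (x,), (bdim,) = batched_args, batch_dims
  x = jnp.moveaxis(x, bdim, 0)
  y = tf.vectorized_map(tf.square, tf.convert_to_tensor(np.asarray(x)))
  return jnp.asarray(np.asarray(y)), 0  # output is batched along axis 0

batching.primitive_batchers[tf_square_p] = _batch_rule

# jax.vmap(tf_square)(jnp.arange(6.0).reshape(2, 3)) would now dispatch to the rule.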
Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
! pip install -q -U jaxlib jax
%tensorflow_version 2.x
import jax
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from functools import partial
def as_jax(x):
  return jnp.asarray(np.asarray(x))  # TF tensor (or array-like) -> JAX array, via NumPy

def as_tf(x):
  return tf.convert_to_tensor(np.asarray(x))  # JAX array (or array-like) -> TF tensor, via NumPy
def wrap_tf_in_jax(tf_func):
  """Wraps a TensorFlow function so it can be called and differentiated from JAX."""
  @jax.custom_vjp  # requires latest JAX release
  def f(x):
    return as_jax(tf_func(as_tf(x)))

  def f_fwd(x):
    # Forward pass: run the TF function under a GradientTape so we can compute
    # vector-Jacobian products in the backward pass.
    with tf.GradientTape() as tape:
      x = as_tf(x)
      tape.watch(x)
      y = tf_func(x)
    # custom_vjp insists on the residuals being a pytree, so wrap the TF gradient
    # closure in tree_util.Partial.
    vjp_func = jax.tree_util.Partial(partial(tape.gradient, y, x))
    return as_jax(y), vjp_func

  def f_rev(vjp_func, ct_y):
    # Backward pass: evaluate the TF gradient with the incoming cotangent.
    ct_x = vjp_func(as_tf(ct_y))
    return (as_jax(ct_x),)

  f.defvjp(f_fwd, f_rev)
  return f
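A note on the tree_util.Partial wrapper in f_fwd: custom_vjp flattens the residuals returned by the forward rule as a pytree, and a bare closure would be an opaque leaf. jax.tree_util.Partial is a functools.partial subclass registered as a pytree node, which is why wrapping the TF gradient closure in it works. A small illustration (the scale helper is just for demonstration, not part of the recipe):

def scale(factor, x):
  return factor * x

# A plain functools.partial is an unregistered type, so it is a single opaque leaf:
print(jax.tree_util.tree_leaves(partial(scale, 2.0)))                # one leaf: the partial itself
# jax.tree_util.Partial is a registered pytree node: its bound arguments flatten out.
print(jax.tree_util.tree_leaves(jax.tree_util.Partial(scale, 2.0)))  # [2.0]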
x = jnp.arange(3.0)
wrapped_sum = wrap_tf_in_jax(tf.reduce_sum)
np.testing.assert_allclose(wrapped_sum(x), jnp.sum(x))
np.testing.assert_allclose(jax.grad(wrapped_sum)(x), jax.grad(jnp.sum)(x))
wrapped_square = wrap_tf_in_jax(tf.square)
def tf_and_jax(x):
  return wrapped_square(x).sum()

def jax_only(x):
  return (x ** 2).sum()
np.testing.assert_allclose(tf_and_jax(x), jax_only(x))
np.testing.assert_allclose(jax.grad(tf_and_jax)(x), jax.grad(jax_only)(x))
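As a concrete illustration of the jit limitation listed at the top: under jax.jit, x becomes an abstract tracer inside f, so the np.asarray call in as_tf has no concrete value to convert and tracing fails. A quick check (the exact exception type varies across JAX versions, so this catches broadly; not part of the original recipe):

try:
  jax.jit(wrapped_sum)(x)
except Exception as e:  # expect a tracer/concretization error raised from as_tf
  print(type(e).__name__)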
def wrap_jax_in_tf(jax_func):
  """Wraps a JAX function so it can be called and differentiated from TensorFlow."""
  @tf.custom_gradient
  def f(x):
    # Forward pass: jax.vjp returns the primal output and a VJP function.
    y, jax_vjp_fn = jax.vjp(jax_func, as_jax(x))
    def tf_vjp_fn(ct_y):
      # Backward pass: apply the JAX VJP to the incoming cotangent.
      ct_x, = jax_vjp_fn(as_jax(ct_y))
      return as_tf(ct_x)
    return as_tf(y), tf_vjp_fn
  return f
def tf_grad(f):
  # TF analogue of jax.grad: returns a function computing df/dx via GradientTape.
  def f2(x):
    with tf.GradientTape() as g:
      g.watch(x)
      y = f(x)
    return g.gradient(y, x)
  return f2
x = tf.range(3.0)
wrapped_sum = wrap_jax_in_tf(jnp.sum)
np.testing.assert_allclose(wrapped_sum(x), tf.reduce_sum(x))
np.testing.assert_allclose(tf_grad(wrapped_sum)(x), tf_grad(tf.reduce_sum)(x))
wrapped_square = wrap_jax_in_tf(jnp.square)
def tf_and_jax(x):
  return tf.reduce_sum(wrapped_square(x))

def tf_only(x):
  return tf.reduce_sum(x ** 2)
np.testing.assert_allclose(tf_and_jax(x), tf_only(x))
np.testing.assert_allclose(tf_grad(tf_and_jax)(x), tf_grad(tf_only)(x))