! pip install -U jax jaxlib
Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.69) Requirement already up-to-date: jaxlib in /usr/local/lib/python3.6/dist-packages (0.1.47) Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0) Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5) Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.2.1) Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib) (1.4.1) Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.12.0)
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import itertools
import operator
import threading
import numpy as onp
from jax import api
from jax import core
from jax import dtypes
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple)
from jax import ad_util
import jax.numpy as jnp
import jax.test_util as jtu
# Deliberately shadow the builtins with the length-checking variants from
# jax.util: safe_map/safe_zip raise on mismatched argument lengths instead
# of silently truncating like builtin map/zip.
map = safe_map
zip = safe_zip
def einsum(*operands):
    """Einsum entry point routed through the custom ``einsum_p`` primitive.

    Accepts the same calling conventions as ``numpy.einsum``: either a
    subscripts string followed by arrays, or interleaved array/subscript
    lists. The parsed contraction is bound to ``einsum_p`` so the whole
    einsum stays a single opaque primitive in the jaxpr.
    """
    in_spec, out_spec, arrays = _parse_einsum_input(operands)
    results = einsum_p.bind(*arrays,
                            input_strings=in_spec.split(','),
                            output_string=out_spec)
    (result,) = results
    return result
def _einsum_impl(*operands, input_strings, output_string):
subscripts = ','.join(input_strings) + '->' + output_string
return [jnp.einsum(subscripts, *operands)]
def sum_tangents(tangents):
    """Fold a non-empty iterable of tangent values into one tangent sum.

    Equivalent to ``functools.reduce(ad.add_tangents, tangents)``: the
    first element seeds the accumulator, so an empty iterable raises
    (StopIteration here, matching reduce's failure on empty input in
    spirit — callers always pass at least one tangent).
    """
    it = iter(tangents)
    total = next(it)
    for t in it:
        total = ad.add_tangents(total, t)
    return total
def _einsum_jvp(primals, tangents, *, input_strings, output_string):
    """JVP rule for the einsum primitive.

    einsum is multilinear in its operands, so the output tangent is the
    sum over operands of the same contraction with that one operand
    replaced by its tangent. Symbolic-zero tangents contribute nothing
    and are skipped.
    """
    spec = ','.join(input_strings) + '->' + output_string
    contract = functools.partial(einsum, spec)
    # One argument list per non-zero tangent: primals with slot i swapped.
    swapped = [list(primals[:i]) + [t] + list(primals[i + 1:])
               for i, t in enumerate(tangents) if type(t) is not ad.Zero]
    primal_out = contract(*primals)
    tangent_out = sum_tangents(contract(*args) for args in swapped)
    return [primal_out], [tangent_out]
def _einsum_transpose_rule(cotangent, *primals, input_strings, output_string):
    """Transpose rule for the einsum primitive.

    Exactly one primal is an undefined (linear) variable; its cotangent is
    another einsum that contracts the remaining defined operands with the
    output cotangent, targeting that operand's index string. Slots for
    defined primals are returned as None.
    """
    undef_positions = [i for i, p in enumerate(primals)
                       if ad.is_undefined_primal(p)]
    (index,) = undef_positions
    kept_inputs = input_strings[:index] + input_strings[index + 1:]
    spec = (','.join(kept_inputs) + ',' + output_string
            + '->' + input_strings[index])
    # cotangent arrives as a list (multiple_results primitive); splice it in.
    args = primals[:index] + primals[index + 1:] + tuple(cotangent)
    result = [None] * len(primals)
    result[index] = einsum(spec, *args)
    return result
# Define the custom 'einsum' primitive and register its interpreter rules.
einsum_p = core.Primitive('einsum')
# bind() returns a list of outputs (always length one here).
einsum_p.multiple_results = True
einsum_p.def_impl(_einsum_impl)
def generic_abstract_eval(*avals, **params):
    # Derive output avals by abstractly evaluating the impl function
    # (shape/dtype inference without running the contraction).
    return pe.abstract_eval_fun(_einsum_impl, *avals, **params)
einsum_p.def_abstract_eval(generic_abstract_eval)
ad.primitive_jvps[einsum_p] = _einsum_jvp
# XLA lowering: trace _einsum_impl at lowering time (initial-style), so the
# jaxpr keeps a single opaque 'einsum' equation until compilation.
xla.initial_style_translations[einsum_p] = xla.lower_fun_initial_style(_einsum_impl)
ad.primitive_transposes[einsum_p] = _einsum_transpose_rule
# TODO(shoyer): batching rule (should be pretty easy)
# batching.primitive_batchers[einsum_p] = _einsum_batching_rule
#@title define `_parse_einsum_input` (from numpy) { display-mode: "form" }
# from numpy.core.einsumfunc
# Valid einsum index letters; a set copy gives O(1) membership tests when
# picking unused letters for ellipsis expansion.
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
einsum_symbols_set = set(einsum_symbols)
# asarray = lambda x: x
# Coerce operands to jax arrays (the identity lambda above was an
# alternative used while debugging).
asarray = jnp.asarray
def _parse_einsum_input(operands):
    """
    A reproduction of einsum's C-side parsing in Python
    (copied from numpy.core.einsumfunc).

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b])

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b])
    """
    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        # String form: subscripts string followed by the array operands.
        subscripts = operands[0].replace(" ", "")
        operands = [asarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
    else:
        # Interleaved form: arr0, sub0, arr1, sub1, ..., [optional out_sub].
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # An odd trailing element, if present, is the output subscript list.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asarray(v) for v in operand_list]
        subscripts = ""
        last = len(subscript_list) - 1
        # Convert integer/Ellipsis subscript lists to a subscripts string.
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                elif isinstance(s, int):
                    subscripts += einsum_symbols[s]
                else:
                    raise TypeError("For this input type lists must contain "
                                    "either int or Ellipsis")
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                elif isinstance(s, int):
                    subscripts += einsum_symbols[s]
                else:
                    raise TypeError("For this input type lists must contain "
                                    "either int or Ellipsis")

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each '...' with concrete unused index letters.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    # Number of dimensions the ellipsis stands for in this
                    # operand (len(sub) - 3 accounts for the '...' itself).
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # The trailing letters of ellipse_inds, so all operands'
                    # ellipses align on their rightmost broadcast dims.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: output keeps the ellipse
            # dims plus, alphabetically, every index appearing exactly once.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: indices that appear exactly once, sorted.
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
import jax
import jax.test_util as jtu
jax.make_jaxpr(partial(einsum, 'i,ij->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)))
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.')
{ lambda ; a b. let c = einsum[ input_strings=['i', 'ij'] output_string=ij ] a b in (c,) }
jax.make_jaxpr(partial(einsum, 'i,ij,jk->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)), jnp.zeros((3, 4)))
{ lambda ; a b c. let d = einsum[ input_strings=['i', 'ij', 'jk'] output_string=ij ] a b c in (d,) }
def make_einsum_grad(subscripts, einsum_fun=einsum, argnums=0, *, einsum=None):
    """Build a gradient function for ``sum(einsum(subscripts, *ops) ** 2)``.

    Args:
        subscripts: einsum subscripts string, e.g. ``'ij,jk->ij'``.
        einsum_fun: the einsum implementation to differentiate through
            (the custom-primitive ``einsum`` by default, or ``jnp.einsum``).
        argnums: which operand position(s) to differentiate with respect to,
            passed straight to ``jax.grad``.
        einsum: backward-compatible keyword alias for ``einsum_fun``; several
            call sites pass ``einsum=jnp.einsum``, which the old signature
            rejected with a TypeError. When given, it overrides
            ``einsum_fun``.

    Returns:
        A function of the operand arrays returning the gradient(s).
    """
    if einsum is not None:
        einsum_fun = einsum

    @partial(jax.grad, argnums=argnums)
    def f(*operands):
        # Squared-sum objective gives a nontrivial, scalar-valued loss.
        return jnp.sum(einsum_fun(subscripts, *operands) ** 2)
    return f
jax.make_jaxpr(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
{ lambda c ; a b. let d = einsum[ input_strings=['ij', 'jk'] output_string=ij ] a b e = mul 2.0 d f = mul c e g = einsum[ input_strings=['jk', 'ij'] output_string=ij ] b f in (g,) }
import opt_einsum
opt_einsum
import collections
operands = 'abc,ad,be,cf,def,dg,eh,fi->ghi'
sizes = collections.defaultdict(lambda: 100)
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
jax.make_jaxpr(make_einsum_grad(operands))(*arrays)
{ lambda i ; a b c d e f g h. let j = einsum[ input_strings=['abc', 'ad', 'be', 'cf', 'def', 'dg', 'eh', 'fi'] output_string=ghi ] a b c d e f g h k = mul 2.0 j l = mul i k m = einsum[ input_strings=['ad', 'be', 'cf', 'def', 'dg', 'eh', 'fi', 'ghi'] output_string=abc ] b c d e f g h l in (m,) }
operands = 'ad,be,cf,def,dg,eh,fi,ghi->abc'
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
jax.make_jaxpr(partial(jnp.einsum, operands))(*arrays)
{ lambda ; a b c d e f g h. let i = xla_call[ backend=None call_jaxpr={ lambda ; a b c d e f g h. let i = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] h e j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] i f k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] j g l = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2))) precision=None ] k d m = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] l a n = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] m b o = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] n c in (o,) } device=None donated_invars=(False, False, False, False, False, False, False, False) name=_einsum ] a b c d e f g h in (i,) }
operands = 'abc,ad,be,cf,def,dg,eh,fi->ghi'
sizes = collections.defaultdict(lambda: 100)
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
jax.make_jaxpr(make_einsum_grad(operands, einsum=jnp.einsum))(*arrays)
{ lambda i s ; a b c d e f g h. let j k l m n o p q r = xla_call[ backend=None call_jaxpr={ lambda ; a b c d e f g h i. let j = dot_general[ dimension_numbers=(((0,), (0,)), ((), ())) precision=None ] b a k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ())) precision=None ] j c l = dot_general[ dimension_numbers=(((1,), (0,)), ((), ())) precision=None ] k d m = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2))) precision=None ] l e n = dot_general[ dimension_numbers=(((0,), (0,)), ((), ())) precision=None ] m f o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ())) precision=None ] n g p = dot_general[ dimension_numbers=(((0,), (0,)), ((), ())) precision=None ] o h in (p, *, b, c, d, e, f, g, h) } device=None donated_invars=(False, False, False, False, False, False, False, False, False) name=jvp(_einsum) ] a b c d e f g h i t = mul 2.0 j u = mul s t v = xla_call[ backend=None call_jaxpr={ lambda ; a b c d e f g h i j k l m n o. let p = dot_general[ dimension_numbers=(((2,), (1,)), ((), ())) precision=None ] o g q = transpose[ permutation=(2, 0, 1) ] p r = dot_general[ dimension_numbers=(((2,), (1,)), ((), ())) precision=None ] q f s = transpose[ permutation=(2, 0, 1) ] r t = dot_general[ dimension_numbers=(((2,), (1,)), ((), ())) precision=None ] s e u = transpose[ permutation=(2, 0, 1) ] t v = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2))) precision=None ] u d w = dot_general[ dimension_numbers=(((2,), (1,)), ((), ())) precision=None ] v c x = transpose[ permutation=(0, 2, 1) ] w y = dot_general[ dimension_numbers=(((2,), (1,)), ((), ())) precision=None ] x b z = transpose[ permutation=(0, 2, 1) ] y ba = dot_general[ dimension_numbers=(((0,), (1,)), ((), ())) precision=None ] z a bb = transpose[ permutation=(2, 0, 1) ] ba in (bb,) } device=None donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False) name=transpose(jvp(_einsum)) ] l m n o p q r i i i i i i 
i u in (v,) }
operands = 'bdik,acaj,ikab,ajac,ikbd->'
sizes = collections.defaultdict(lambda: 10)
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
jax.make_jaxpr(make_einsum_grad(operands))(*arrays)
{ lambda ; a b c d e. let f = einsum[ input_strings=['bdik', 'acaj', 'ikab', 'ajac', 'ikbd'] output_string= ] a b c d e g = mul 2.0 f h = mul 1.0 g i = einsum[ input_strings=['acaj', 'ikab', 'ajac', 'ikbd', ''] output_string=bdik ] b c d e h in (i,) }
block_until_ready = partial(jax.tree_map, lambda x: x.block_until_ready())
def make_einsum_grad2(subscripts, einsum_fun=einsum, argnums=0):
    """Build a gradient function for a raw (scalar-output) einsum.

    Unlike the squared-sum variant above it in spirit, the objective here is
    the einsum result itself, so ``subscripts`` must contract down to a
    scalar for ``jax.grad`` to be applicable.
    """
    def objective(*operands):
        return einsum_fun(subscripts, *operands)

    return partial(jax.grad, argnums=argnums)(objective)
operands = 'abcde,abfg,cdhi,ghjk,ielm,fjno,klpq,nopqm->'
dim_size = 8
print(f"expression: {operands}")
print(f"dim_size: {dim_size}")
sizes = collections.defaultdict(lambda: dim_size)
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
argnums = (1, 2, 3, 4, 5, 6, 7)
print(f"gradient argnums: {argnums}")
print()
print("einsum primitive")
f = jax.jit(make_einsum_grad(operands, einsum_fun=einsum, argnums=argnums))
# print(jax.make_jaxpr(f)(*arrays))
block_until_ready(f(*arrays)) # compile
%timeit block_until_ready(f(*arrays))
print()
print("dot_general primitive")
f = jax.jit(make_einsum_grad(operands, einsum_fun=jnp.einsum, argnums=argnums))
# print(jax.make_jaxpr(f)(*arrays))
block_until_ready(f(*arrays)) # compile
%timeit block_until_ready(f(*arrays))
expression: abcde,abfg,cdhi,ghjk,ielm,fjno,klpq,nopqm-> dim_size: 8 gradient argnums: (1, 2, 3, 4, 5, 6, 7) einsum primitive 100 loops, best of 3: 4.88 ms per loop dot_general primitive 100 loops, best of 3: 3.93 ms per loop
operands = 'abcde,abfg,cdhi,ghjk,ielm,fjno,klpq->nopqm'
sizes = collections.defaultdict(lambda: 16)
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
f = jax.jit(make_einsum_grad(operands, einsum=jnp.einsum, argnums=(0,)))
print(jax.make_jaxpr(f)(*arrays))
block_until_ready(f(*arrays)) # compile
%timeit block_until_ready(f(*arrays))
{ lambda h i ; a b c d e f g. let j = xla_call[ backend=None call_jaxpr={ lambda ; h q a b c d e f g. let i j k l m n o p = xla_call[ backend=None call_jaxpr={ lambda ; a b c d e f g h. let i = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ())) precision=None ] b a j = dot_general[ dimension_numbers=(((2, 3), (0, 1)), ((), ())) precision=None ] i c k = dot_general[ dimension_numbers=(((1, 3), (0, 1)), ((), ())) precision=None ] j d l = dot_general[ dimension_numbers=(((1, 2), (1, 0)), ((), ())) precision=None ] k e m = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ())) precision=None ] l f n = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ())) precision=None ] m g o = transpose[ permutation=(1, 2, 3, 4, 0) ] n in (o, *, b, c, d, e, f, g) } device=None donated_invars=(False, False, False, False, False, False, False, False) name=jvp(_einsum) ] a b c d e f g h r = mul 2.0 i s = mul q r t = xla_call[ backend=None call_jaxpr={ lambda ; a b c d e f g h i j k l m. let n = transpose[ permutation=(4, 0, 1, 2, 3) ] m o = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ())) precision=None ] n f p = transpose[ permutation=(3, 4, 0, 1, 2) ] o q = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ())) precision=None ] p e r = transpose[ permutation=(3, 4, 0, 1, 2) ] q s = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ())) precision=None ] r d t = transpose[ permutation=(0, 4, 3, 1, 2) ] s u = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ())) precision=None ] t c v = transpose[ permutation=(0, 3, 1, 4, 2) ] u w = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ())) precision=None ] v b x = transpose[ permutation=(0, 1, 3, 4, 2) ] w y = dot_general[ dimension_numbers=(((0, 1), (2, 3)), ((), ())) precision=None ] x a z = transpose[ permutation=(3, 4, 0, 1, 2) ] y in (z,) } device=None donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False) name=transpose(jvp(_einsum)) 
] k l m n o p h h h h h h s in (t,) } device=None donated_invars=(False, False, False, False, False, False, False, False, False) name=f ] h i a b c d e f g in (j,) } 10 loops, best of 3: 130 ms per loop
operands = 'acaj,ikab,ajac,ikbd,->bdik'
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
jax.make_jaxpr(partial(jnp.einsum, operands))(*arrays)
{ lambda f g ; a b c d e. let h = xla_call[ backend=None call_jaxpr={ lambda ; f j a b c d e. let g = mul c f h = reduce_sum[ axes=(0,) ] g i = transpose[ permutation=(1, 0, 2) ] h k = mul a j l = reduce_sum[ axes=(0,) ] k m = transpose[ permutation=(1, 0, 2) ] l n = dot_general[ dimension_numbers=(((2, 1), (1, 2)), ((0,), (0,))) precision=None ] i m o = dot_general[ dimension_numbers=(((0,), (2,)), ((), ())) precision=None ] n b p = reshape[ dimensions=None new_sizes=() ] e q = dot_general[ dimension_numbers=(((), ()), ((), ())) precision=None ] o p r = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2))) precision=None ] q d s = transpose[ permutation=(2, 3, 0, 1) ] r in (s,) } device=None donated_invars=(False, False, False, False, False, False, False) name=_einsum ] f g a b c d e in (h,) }
#
operands = 'bdik,acaj,ikab,ajac,ikbd->'
sizes = collections.defaultdict(lambda: 10)
arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]
jax.make_jaxpr(make_einsum_grad(operands, einsum=jnp.einsum))(*arrays)
{ lambda f g h ; a b c d e. let i j k l m = xla_call[ backend=None call_jaxpr={ lambda ; l p a b c d e f. let g = transpose[ permutation=(2, 0, 1, 3) ] e h = transpose[ permutation=(0, 2, 3, 1) ] a i = dot_general[ dimension_numbers=(((3,), (3,)), ((0, 1, 2), (0, 1, 2))) precision=None ] g h j = transpose[ permutation=(1, 2, 0) ] i k = dot_general[ dimension_numbers=(((2, 0, 1), (3, 0, 1)), ((), ())) precision=None ] j c m = mul d l n = reduce_sum[ axes=(0,) ] m o = transpose[ permutation=(1, 0, 2) ] n q = mul b p r = reduce_sum[ axes=(0,) ] q s = transpose[ permutation=(1, 0, 2) ] r t = dot_general[ dimension_numbers=(((2, 1), (1, 2)), ((0,), (0,))) precision=None ] o s u = dot_general[ dimension_numbers=(((0,), (0,)), ((), ())) precision=None ] k t in (u, *, g, c, t) } device=None donated_invars=(False, False, False, False, False, False, False, False) name=jvp(_einsum) ] f g a b c d e h n = mul 2.0 i o = mul 1.0 n p = xla_call[ backend=None call_jaxpr={ lambda ; a b c d e f g h. let i = dot_general[ dimension_numbers=(((), ()), ((), ())) precision=None ] h c j = dot_general[ dimension_numbers=(((0,), (2,)), ((), ())) precision=None ] i b k = transpose[ permutation=(2, 0, 1) ] j l = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2))) precision=None ] k a m = transpose[ permutation=(0, 3, 1, 2) ] l in (m,) } device=None donated_invars=(False, False, False, False, False, False, False, False) name=transpose(jvp(_einsum)) ] k l m h h h h o in (p,) }
jax.jit(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
DeviceArray([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
make_einsum_grad('ij,jk->ij')(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
DeviceArray([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
from functools import partial
import numpy as np
rs = np.random.RandomState(0)
f = partial(einsum, 'i,ij,j->ij')
args = (rs.randn(2), rs.randn(2, 3), rs.randn(3,))
jtu.check_grads(f, args, order=2)
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:116: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.')
from functools import partial
import numpy as np
rs = np.random.RandomState(0)
operands = 'ijk,ij,jk->ij'
f = partial(einsum, operands)
args = (rs.randn(2, 3, 4), rs.randn(2, 3), rs.randn(3, 4))
jtu.check_grads(f, args, order=2)
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU. warnings.warn('No GPU/TPU found, falling back to CPU.')
jax.make_jaxpr(make_einsum_grad(operands))(*args)
{ lambda d ; a b c. let e = einsum[ input_strings=['ijk', 'ij', 'jk'] output_string=ij ] a b c f = mul 2.0 e g = mul d f h = einsum[ input_strings=['ij', 'jk', 'ij'] output_string=ijk ] b c g in (h,) }
print(jax.xla_computation(make_einsum_grad(operands, einsum=jnp.einsum))(*args).as_hlo_text())
HloModule xla_computation_f__3.44 jit_pe_jvp__einsum__.8 { parameter.12 = pred[] parameter(3) parameter.11 = f32[3,4]{1,0} parameter(2) parameter.9 = f32[2,3,4]{2,1,0} parameter(0) transpose.14 = f32[3,2,4]{2,0,1} transpose(parameter.9), dimensions={1,0,2} dot.15 = f32[3,2]{1,0} dot(parameter.11, transpose.14), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={2} parameter.10 = f32[2,3]{1,0} parameter(1) transpose.16 = f32[3,2]{0,1} transpose(parameter.10), dimensions={1,0} dot.17 = f32[3,2]{1,0} dot(dot.15, transpose.16), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={} transpose.18 = f32[2,3]{0,1} transpose(dot.17), dimensions={1,0} constant.13 = pred[] constant(false) ROOT tuple.19 = (f32[2,3]{0,1}, pred[], f32[3,4]{1,0}, f32[3,2]{0,1}) tuple(transpose.18, constant.13, parameter.11, transpose.16) } jit_transpose_pe_jvp__einsum___.29 { parameter.32 = pred[] parameter(2) parameter.33 = pred[] parameter(3) constant.35 = pred[] constant(false) parameter.34 = f32[2,3]{1,0} parameter(4) transpose.36 = f32[3,2]{0,1} transpose(parameter.34), dimensions={1,0} parameter.31 = f32[3,2]{0,1} parameter(1) dot.37 = f32[3,2]{1,0} dot(transpose.36, parameter.31), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={} parameter.30 = f32[3,4]{1,0} parameter(0) dot.38 = f32[3,2,4]{2,1,0} dot(dot.37, parameter.30), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={} transpose.39 = f32[2,3,4]{2,0,1} transpose(dot.38), dimensions={1,0,2} ROOT tuple.40 = (f32[2,3,4]{2,0,1}) tuple(transpose.39) } ENTRY xla_computation_f__3.44 { constant.7 = pred[] constant(false) parameter.4 = f32[2,3,4]{2,1,0} parameter(0) parameter.5 = f32[2,3]{1,0} parameter(1) parameter.6 = f32[3,4]{1,0} parameter(2) constant.1 = pred[] constant(false) call.20 = (f32[2,3]{0,1}, pred[], f32[3,4]{1,0}, f32[3,2]{0,1}) call(parameter.4, parameter.5, parameter.6, 
constant.1), to_apply=jit_pe_jvp__einsum__.8 get-tuple-element.22 = pred[] get-tuple-element(call.20), index=1 get-tuple-element.23 = f32[3,4]{1,0} get-tuple-element(call.20), index=2 get-tuple-element.24 = f32[3,2]{0,1} get-tuple-element(call.20), index=3 constant.2 = f32[] constant(1) broadcast.3 = f32[2,3]{1,0} broadcast(constant.2), dimensions={} constant.25 = f32[] constant(2) broadcast.26 = f32[2,3]{1,0} broadcast(constant.25), dimensions={} get-tuple-element.21 = f32[2,3]{0,1} get-tuple-element(call.20), index=0 multiply.27 = f32[2,3]{1,0} multiply(broadcast.26, get-tuple-element.21) multiply.28 = f32[2,3]{1,0} multiply(broadcast.3, multiply.27) call.41 = (f32[2,3,4]{2,0,1}) call(get-tuple-element.23, get-tuple-element.24, constant.1, constant.1, multiply.28), to_apply=jit_transpose_pe_jvp__einsum___.29 get-tuple-element.42 = f32[2,3,4]{2,0,1} get-tuple-element(call.41), index=0 ROOT tuple.43 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.42) }
print(jax.xla_computation(make_einsum_grad(operands))(*args).as_hlo_text())
HloModule xla_computation_f__2.43 jit__einsum__260.8 { constant.12 = pred[] constant(false) parameter.11 = f32[3,4]{1,0} parameter(2) parameter.9 = f32[2,3,4]{2,1,0} parameter(0) transpose.13 = f32[3,2,4]{2,0,1} transpose(parameter.9), dimensions={1,0,2} dot.14 = f32[3,2]{1,0} dot(parameter.11, transpose.13), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={2} parameter.10 = f32[2,3]{1,0} parameter(1) transpose.15 = f32[3,2]{0,1} transpose(parameter.10), dimensions={1,0} dot.16 = f32[3,2]{1,0} dot(dot.14, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={} transpose.17 = f32[2,3]{0,1} transpose(dot.16), dimensions={1,0} ROOT tuple.18 = (f32[2,3]{0,1}) tuple(transpose.17) } jit__einsum__261.28 { constant.32 = pred[] constant(false) parameter.31 = f32[2,3]{1,0} parameter(2) parameter.29 = f32[2,3]{1,0} parameter(0) dot.33 = f32[2,3]{1,0} dot(parameter.31, parameter.29), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={} transpose.34 = f32[3,2]{0,1} transpose(dot.33), dimensions={1,0} parameter.30 = f32[3,4]{1,0} parameter(1) dot.35 = f32[3,2,4]{2,1,0} dot(transpose.34, parameter.30), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={} transpose.36 = f32[2,3,4]{2,0,1} transpose(dot.35), dimensions={1,0,2} ROOT tuple.37 = (f32[2,3,4]{2,0,1}) tuple(transpose.36) } ENTRY xla_computation_f__2.43 { constant.6 = pred[] constant(false) constant.7 = pred[] constant(false) constant.27 = pred[] constant(false) parameter.4 = f32[2,3]{1,0} parameter(1) parameter.5 = f32[3,4]{1,0} parameter(2) constant.1 = f32[] constant(1) broadcast.2 = f32[2,3]{1,0} broadcast(constant.1), dimensions={} constant.23 = f32[] constant(2) broadcast.24 = f32[2,3]{1,0} broadcast(constant.23), dimensions={} parameter.3 = f32[2,3,4]{2,1,0} parameter(0) call.19 = (f32[2,3]{0,1}) call(parameter.3, parameter.4, parameter.5), 
to_apply=jit__einsum__260.8 get-tuple-element.20 = f32[2,3]{0,1} get-tuple-element(call.19), index=0 tuple.21 = (f32[2,3]{0,1}) tuple(get-tuple-element.20) get-tuple-element.22 = f32[2,3]{0,1} get-tuple-element(tuple.21), index=0 multiply.25 = f32[2,3]{1,0} multiply(broadcast.24, get-tuple-element.22) multiply.26 = f32[2,3]{1,0} multiply(broadcast.2, multiply.25) call.38 = (f32[2,3,4]{2,0,1}) call(parameter.4, parameter.5, multiply.26), to_apply=jit__einsum__261.28 get-tuple-element.39 = f32[2,3,4]{2,0,1} get-tuple-element(call.38), index=0 tuple.40 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.39) get-tuple-element.41 = f32[2,3,4]{2,0,1} get-tuple-element(tuple.40), index=0 ROOT tuple.42 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.41) }