# NOTE(review): the first notebook cell was the IPython shell escape
#   !pip install -U jax jaxlib
# which is not valid Python syntax in a .py file. Upgrading to the newest JAX
# is also likely counterproductive: the code below relies on JAX internals
# (`jax.api`, `jax.abstract_arrays`, `ad.zero`,
# `xla.initial_style_translations`, ...) that appear to have been removed in
# later releases. Install a contemporaneous JAX release instead (TODO confirm
# the exact version before pinning it here).

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import functools
import itertools
import operator
import threading

import numpy as onp

from jax import api
from jax import core
from jax import dtypes
from jax.lax import lax
from jax import linear_util as lu
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                      split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           treedef_children, treedef_tuple)
from jax import ad_util
import jax.numpy as jnp
import jax.test_util as jtu

map = safe_map
zip = safe_zip


def einsum(*operands):
  """Einsum implemented as the custom primitive ``einsum_p``.

  Accepts the same argument forms as ``_parse_einsum_input``: a subscripts
  string followed by the operands, or interleaved operand/sublist pairs.

  Returns:
    The single output array produced by ``einsum_p``.
  """
  input_string, output_string, operands = _parse_einsum_input(operands)
  # JAX expects primitive params to be hashable (some of its dispatch and
  # compilation caches key on them), so pass a tuple rather than a list.
  out, = einsum_p.bind(*operands,
                       input_strings=tuple(input_string.split(',')),
                       output_string=output_string)
  return out


def _einsum_impl(*operands, input_strings, output_string):
  """Evaluate ``einsum_p`` by delegating to ``jnp.einsum``.

  Returns a one-element list because ``einsum_p`` is registered with
  ``multiple_results = True``.
  """
  subscripts = ','.join(input_strings) + '->' + output_string
  return [jnp.einsum(subscripts, *operands)]


def sum_tangents(tangents):
  """Add tangents with ``ad.add_tangents``; an empty iterable gives ``ad.zero``."""
  return functools.reduce(ad.add_tangents, tangents, ad.zero)


def _einsum_jvp(primals, tangents, *, input_strings, output_string):
  """JVP rule for ``einsum_p``.

  einsum is linear in each operand separately, so the output tangent is the
  sum, over operands whose tangent is not ``ad.zero``, of the einsum with that
  operand's primal replaced by its tangent.
  """
  subscripts = ','.join(input_strings) + '->' + output_string
  this_einsum = functools.partial(einsum, subscripts)
  operands_list = []
  for index, tangent in enumerate(tangents):
    if tangent is not ad.zero:
      operands = list(primals)
      operands[index] = tangent
      operands_list.append(operands)
  out_primal = this_einsum(*primals)
  out_tangent = sum_tangents(this_einsum(*ops) for ops in operands_list)
  return [out_primal], [out_tangent]


def _einsum_transpose_rule(cotangent, *primals, input_strings, output_string):
  """Transpose rule for ``einsum_p`` with respect to its one linear operand.

  Exactly one entry of ``primals`` must be an undefined primal; the unpacking
  below raises otherwise. Its cotangent is an einsum of the output cotangent
  with all of the other operands.

  Indices that occur only in the transposed operand (e.g. ``j`` in
  ``'ij->i'``) were summed out in the forward pass. Their cotangent is
  therefore broadcast back along those axes.

  Repeated indices within the transposed operand (diagonals such as
  ``'ii->i'``) are not handled here.

  Args:
    cotangent: one-element list holding the output cotangent (``einsum_p`` is
      a ``multiple_results`` primitive); may be ``ad.zero``.
    *primals: the einsum operands, one of which is an undefined primal.
    input_strings: per-operand subscripts.
    output_string: output subscripts.

  Returns:
    A list with one entry per operand: the cotangent for the undefined
    operand and ``None`` for every other operand.
  """
  index, = [i for i, p in enumerate(primals) if ad.is_undefined_primal(p)]
  ct, = cotangent
  out = [None] * len(primals)
  if ct is ad.zero:
    out[index] = ad.zero
    return out

  target = input_strings[index]
  others = list(input_strings[:index]) + list(input_strings[index + 1:])
  # Indices recoverable from the remaining operands or the output; any other
  # index of `target` was summed away and is re-broadcast afterwards.
  available = set(''.join(others) + output_string)
  kept = ''.join(c for c in target if c in available)
  # Joining one list avoids the stray leading comma the old construction
  # produced when the transposed operand was the only input
  # (',i->ij' for 'ij->i').
  subscripts = ','.join(others + [output_string]) + '->' + kept
  operands = primals[:index] + primals[index + 1:] + (ct,)
  result = einsum(subscripts, *operands)
  if len(kept) != len(target):
    shape = primals[index].aval.shape
    # `kept` preserves the order of `target`, so the reshape only inserts
    # size-1 axes at the positions of the missing indices.
    expanded = [dim if c in available else 1 for c, dim in zip(target, shape)]
    result = jnp.broadcast_to(jnp.reshape(result, expanded), shape)
  out[index] = result
  return out


einsum_p = core.Primitive('einsum')
einsum_p.multiple_results = True
einsum_p.def_impl(_einsum_impl)


def generic_abstract_eval(*avals, **params):
  """Infer output avals by abstractly evaluating ``_einsum_impl``."""
  return pe.abstract_eval_fun(_einsum_impl, *avals, **params)


einsum_p.def_abstract_eval(generic_abstract_eval)
ad.primitive_jvps[einsum_p] = _einsum_jvp
xla.initial_style_translations[einsum_p] = xla.lower_fun_initial_style(
    _einsum_impl)
ad.primitive_transposes[einsum_p] = _einsum_transpose_rule
# TODO(shoyer): batching rule (should be pretty easy)
# batching.primitive_batchers[einsum_p] = _einsum_batching_rule


#@title define `_parse_einsum_input` (from numpy) { display-mode: "form" }

# from numpy.core.einsumfunc
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
einsum_symbols_set = set(einsum_symbols)

# asarray = lambda x: x
asarray = jnp.asarray


def _parse_einsum_input(operands):
  """Parse einsum arguments into explicit input/output subscripts.

  A Python reproduction of NumPy's C-side einsum argument parsing (adapted
  from ``numpy.core.einsumfunc``). Operands are converted with the
  module-level ``asarray`` (``jnp.asarray`` here).

  Parameters
  ----------
  operands : tuple
      Either ``(subscripts, *arrays)`` where ``subscripts`` is a string, or
      interleaved ``(array, sublist, array, sublist, ..., [output_sublist])``
      where each sublist holds ints (indices into ``einsum_symbols``) and/or
      ``Ellipsis``.

  Returns
  -------
  input_subscripts : str
      Comma-separated input subscripts with any ellipses expanded.
  output_subscript : str
      Output subscripts. When no '->' was given, they are the indices that
      appear exactly once, in sorted order (ellipsis indices first when an
      ellipsis was used).
  operands : list of array_like
      The operands, converted with ``asarray``.

  Raises
  ------
  ValueError
      On empty input, invalid characters, a malformed '->' or ellipsis,
      an output character missing from the inputs, or a mismatch between the
      number of subscript terms and operands.
  TypeError
      If a sublist contains something other than an int or Ellipsis.

  Examples
  --------
  The operand list is simplified to reduce printing. The letters substituted
  for an ellipsis come from a set difference, so they may differ between runs:

  >>> a = np.random.rand(4, 4)
  >>> b = np.random.rand(4, 4, 4)
  >>> _parse_einsum_input(('...a,...a->...', a, b))
  ('za,xza', 'xz', [a, b])

  >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
  ('za,xza', 'xz', [a, b])
  """
  if len(operands) == 0:
    raise ValueError("No input operands")

  if isinstance(operands[0], str):
    subscripts = operands[0].replace(" ", "")
    operands = [asarray(v) for v in operands[1:]]

    # Ensure all characters are valid
    for s in subscripts:
      if s in '.,->':
        continue
      if s not in einsum_symbols:
        raise ValueError("Character %s is not a valid symbol." % s)

  else:
    tmp_operands = list(operands)
    operand_list = []
    subscript_list = []
    for p in range(len(operands) // 2):
      operand_list.append(tmp_operands.pop(0))
      subscript_list.append(tmp_operands.pop(0))

    output_list = tmp_operands[-1] if len(tmp_operands) else None
    operands = [asarray(v) for v in operand_list]
    subscripts = ""
    last = len(subscript_list) - 1
    for num, sub in enumerate(subscript_list):
      for s in sub:
        if s is Ellipsis:
          subscripts += "..."
        elif isinstance(s, int):
          subscripts += einsum_symbols[s]
        else:
          raise TypeError("For this input type lists must contain "
                          "either int or Ellipsis")
      if num != last:
        subscripts += ","

    if output_list is not None:
      subscripts += "->"
      for s in output_list:
        if s is Ellipsis:
          subscripts += "..."
        elif isinstance(s, int):
          subscripts += einsum_symbols[s]
        else:
          raise TypeError("For this input type lists must contain "
                          "either int or Ellipsis")

  # Check for proper "->"
  if ("-" in subscripts) or (">" in subscripts):
    invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
    if invalid or (subscripts.count("->") != 1):
      raise ValueError("Subscripts can only contain one '->'.")

  # Parse ellipses
  if "." in subscripts:
    used = subscripts.replace(".", "").replace(",", "").replace("->", "")
    unused = list(einsum_symbols_set - set(used))
    ellipse_inds = "".join(unused)
    longest = 0

    if "->" in subscripts:
      input_tmp, output_sub = subscripts.split("->")
      split_subscripts = input_tmp.split(",")
      out_sub = True
    else:
      split_subscripts = subscripts.split(',')
      out_sub = False

    for num, sub in enumerate(split_subscripts):
      if "." in sub:
        if (sub.count(".") != 3) or (sub.count("...") != 1):
          raise ValueError("Invalid Ellipses.")

        # Take into account numerical values
        if operands[num].shape == ():
          ellipse_count = 0
        else:
          ellipse_count = max(operands[num].ndim, 1)
          ellipse_count -= (len(sub) - 3)

        if ellipse_count > longest:
          longest = ellipse_count

        if ellipse_count < 0:
          raise ValueError("Ellipses lengths do not match.")
        elif ellipse_count == 0:
          split_subscripts[num] = sub.replace('...', '')
        else:
          rep_inds = ellipse_inds[-ellipse_count:]
          split_subscripts[num] = sub.replace('...', rep_inds)

    subscripts = ",".join(split_subscripts)
    if longest == 0:
      out_ellipse = ""
    else:
      out_ellipse = ellipse_inds[-longest:]

    if out_sub:
      subscripts += "->" + output_sub.replace("...", out_ellipse)
    else:
      # Special care for outputless ellipses
      output_subscript = ""
      tmp_subscripts = subscripts.replace(",", "")
      for s in sorted(set(tmp_subscripts)):
        if s not in (einsum_symbols):
          raise ValueError("Character %s is not a valid symbol." % s)
        if tmp_subscripts.count(s) == 1:
          output_subscript += s
      normal_inds = ''.join(sorted(set(output_subscript) - set(out_ellipse)))

      subscripts += "->" + out_ellipse + normal_inds

  # Build output string if does not exist
  if "->" in subscripts:
    input_subscripts, output_subscript = subscripts.split("->")
  else:
    input_subscripts = subscripts
    # Build output subscripts
    tmp_subscripts = subscripts.replace(",", "")
    output_subscript = ""
    for s in sorted(set(tmp_subscripts)):
      if s not in einsum_symbols:
        raise ValueError("Character %s is not a valid symbol." % s)
      if tmp_subscripts.count(s) == 1:
        output_subscript += s

  # Make sure output subscripts are in the input
  for char in output_subscript:
    if char not in input_subscripts:
      raise ValueError("Output character %s did not appear in the input"
                       % char)

  # Make sure number operands is equivalent to the number of terms
  if len(input_subscripts.split(',')) != len(operands):
    raise ValueError("Number of einsum subscripts must be equal to the "
                     "number of operands.")

  return (input_subscripts, output_subscript, operands)


# %% Smoke tests: trace, differentiate and jit the primitive.
import jax
import jax.test_util as jtu

jax.make_jaxpr(partial(einsum, 'i,ij->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)))

jax.make_jaxpr(partial(einsum, 'i,ij,jk->ij'))(
    jnp.zeros((2,)), jnp.zeros((2, 3)), jnp.zeros((3, 4)))


def make_einsum_grad(subscripts):
  """Return ``jax.grad`` of ``sum(einsum(subscripts, *operands) ** 2)``.

  ``jax.grad`` differentiates with respect to the first operand by default.
  """
  @jax.grad
  def f(*operands):
    return jnp.sum(einsum(subscripts, *operands) ** 2)
  return f


jax.make_jaxpr(make_einsum_grad('ij,jk->ij'))(
    jnp.zeros((2, 3)), jnp.zeros((3, 4)))

jax.jit(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))

make_einsum_grad('ij,jk->ij')(jnp.zeros((2, 3)), jnp.zeros((3, 4)))

# %% Numerical gradient checks.
from functools import partial
import numpy as np

rs = np.random.RandomState(0)
f = partial(einsum, 'i,ij,j->ij')
args = (rs.randn(2), rs.randn(2, 3), rs.randn(3,))
jtu.check_grads(f, args, order=2)

# %%
from functools import partial
import numpy as np

rs = np.random.RandomState(0)
operands = 'ijk,ij,jk->ij'
f = partial(einsum, operands)
args = (rs.randn(2, 3, 4), rs.randn(2, 3), rs.randn(3, 4))
jtu.check_grads(f, args, order=2)

jax.make_jaxpr(make_einsum_grad(operands))(*args)

# %% Single-operand reduction: exercises the transpose rule with a lone
# operand and an index ('j') that is summed out of the output.
jtu.check_grads(partial(einsum, 'ij->i'), (rs.randn(2, 3),), order=2)