# Copyright 2020 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

import jax
from jax import random
def _with_key(func):
def method(self, *args, **kwargs):
return func(self.key, *args, **kwargs)
return method
class RNG:
    """Object wrapper around a PRNG key.

    The key is stored on the ``key`` attribute. ``uniform`` and ``normal``
    forward to the ``jax.random`` functions of the same name, passing the
    stored key as the first argument.
    """

    # Samplers: the stored key is supplied as the first argument, all other
    # positional and keyword arguments are passed through unchanged.
    uniform = _with_key(random.uniform)
    normal = _with_key(random.normal)

    def __init__(self, key):
        self.key = key

    def __repr__(self):
        return f'{type(self).__name__}({self.key!r})'

    def split(self, num=2):
        """Split the stored key into ``num`` keys, each wrapped in an ``RNG``.

        Returns:
            A list of ``RNG`` objects, one per key produced by
            ``random.split(self.key, num)``.
        """
        subkeys = random.split(self.key, num)
        return list(map(RNG, subkeys))
def _flatten_rng(node):
    """Flatten an ``RNG``: its key is the only child, with no aux data."""
    return [node.key], None


def _unflatten_rng(aux, values):
    """Rebuild an ``RNG`` from its single child; ``aux`` is not used."""
    return RNG(values[0])


# Register RNG as a pytree node so JAX transformations can take and return it.
jax.tree_util.register_pytree_node(RNG, _flatten_rng, _unflatten_rng)
def rng(seed):
    """Build an ``RNG`` whose key is ``random.PRNGKey(seed)``."""
    key = random.PRNGKey(seed)
    return RNG(key)
# Usage examples. Lines starting with "# ->" are the outputs recorded in the
# original source; they are kept as comments so the module stays importable.

# Draw a single uniform sample from a seeded RNG.
rng(10).uniform()
# -> DeviceArray(0.08938682, dtype=float32)

# Batched sampling with raw keys under vmap.
keys = jax.vmap(random.PRNGKey)(jax.numpy.arange(3))
samples = jax.vmap(random.uniform)(keys)
samples
# -> DeviceArray([0.41845703, 0.11815023, 0.4240216 ], dtype=float32)

# The same batched sampling expressed with RNG objects.
jax.vmap(RNG.uniform)(jax.vmap(rng)(jax.numpy.arange(3)))
# -> DeviceArray([0.41845703, 0.11815023, 0.4240216 ], dtype=float32)
@jax.jit
def split_and_sample(rng):
    """Split ``rng`` in two and draw three normal samples from the second half.

    Returns:
        A tuple ``(next_rng, val)``: the first object returned by
        ``rng.split()``, and the result of calling ``normal(shape=(3,))`` on
        the second.
    """
    next_rng, sample_rng = rng.split()
    return next_rng, sample_rng.normal(shape=(3,))
# Usage example: RNG objects pass through a jitted function as arguments and
# return values. The "# ->" line is the output recorded in the original
# source, kept as a comment so the module stays importable.
split_and_sample(rng(10))
# -> (RNG(DeviceArray([3912842007, 31661381], dtype=uint32)), DeviceArray([0.47754696, 0.2578578 , 2.4254863 ], dtype=float32))