Copyright 2021 The RecSim Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [ ]:
# @title Install.
!pip3 install --upgrade -q --no-cache-dir recsim_ng
!pip3 install --upgrade -q --no-cache-dir edward2
In [ ]:
# @title  Imports and defs.

from typing import Any, Callable, Sequence

from IPython.display import HTML
import matplotlib as mpl
from matplotlib import animation
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

import edward2 as ed
from recsim_ng.core import network as network_lib
from recsim_ng.core import variable
from recsim_ng.core import value
from recsim_ng.lib.tensorflow import entity
from recsim_ng.lib.tensorflow import field_spec
from recsim_ng.lib.tensorflow import runtime

from ipykernel.pylab import backend_inline
try:
  get_ipython().events.callbacks["post_execute"].remove(
      backend_inline.flush_figures)
except:

  # discard the exception if the callback has already been removed.
  pass

tfd = tfp.distributions
tfb = tfp.bijectors

mpl.style.use("classic")

Variable = variable.Variable
Value = value.Value
FieldSpec = value.FieldSpec
ValueSpec = value.ValueSpec
Space = field_spec.Space


def animate_game_of_life(cell_state_traj: tf.Tensor, horizon: int) -> Any:
  """Game of life animation in matplotlib."""
  fig, ax = plt.subplots(figsize=(5, 5))
  im = ax.imshow(
      255 * cell_state_traj[0],
      animated=True,
      cmap="gray_r",
      interpolation="none")

  def updatefig(*args):
    im.set_array(255 * cell_state_traj[args[0]])
    return im,

  try:
    ani = animation.FuncAnimation(fig, updatefig, frames=horizon, blit=True)
    return HTML(ani.to_jshtml())
  except:
    return

Introduction

At its core, RecSim NG is a software package for reasoning about recommender ecosystems consisting of agents (for example, users, content providers, advertisers, recommendation algorithms) interacting with each other in the context of a recommendation-driven service. There are two main components of RecSim NG:

  • a composable modeling API for specifying complex (possibly stochastic) agent behaviors (along with a library of reusable building blocks);
  • a runtime library which enables various reasoning tasks such as Monte Carlo simulation, marginal/posterior inference, model learning and others.

In this tutorial we review the basics of these APIs and learn how to set up simple deterministic simulations. We also give an overview of the various things that RecSim NG enables us to do with them.

The RecSim NG Model

Abstractly, a RecSim model (also called a RecSim "simulation") represents a Markovian stochastic process. That is, we assume there is some state space $\cal S$ over which a Markov process gives rise to trajectories, that is, sequences of states in $\cal S$ of length $n$, $(s_i \in {\cal S})_{i=0}^{n-1}$, such that $$p(\left(s_i \in {\cal S}\right)_{i=0}^{n-1}) = p_0(s_0)\prod_{i=0}^{n-2}T(s_{i}, s_{i+1}),$$ where $p_0$ is an initial state density, and $T$ is a transition kernel.
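To make the factorization concrete, here is a minimal plain-Python sketch that evaluates the probability of a trajectory as the initial density times a product of transition probabilities. The two-state chain and its numbers (p0, T, the state names) are hypothetical and chosen only for illustration; nothing here uses the RecSim NG API.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical two-state chain, used only for illustration.
p0: Dict[str, float] = {"idle": 0.8, "active": 0.2}
T: Dict[Tuple[str, str], float] = {
    ("idle", "idle"): 0.9,
    ("idle", "active"): 0.1,
    ("active", "idle"): 0.5,
    ("active", "active"): 0.5,
}


def trajectory_probability(states: Sequence[str]) -> float:
  """Returns p_0(s_0) times the product of T(s_i, s_{i+1}) along the path."""
  prob = p0[states[0]]
  for prev_state, next_state in zip(states, states[1:]):
    prob *= T[(prev_state, next_state)]
  return prob


# 0.8 * 0.9 * 0.1, up to floating-point rounding.
print(trajectory_probability(["idle", "idle", "active"]))
```

Summing this quantity over all trajectories of a fixed length gives 1, which is a quick sanity check on any hand-written initial density and kernel.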

We can think of $\cal S$, for example, as the state of the recommender ecosystem, which might itself simply be the collection of states of all the individual actors in it.

In most cases, $s \in \cal S$ is not just some vector in Euclidean space. Instead, $\cal S$ is a highly structured object consisting of sub-components (e.g., the states of the individual actors) which may evolve autonomously, or through interactions with each other. In RecSim NG, we assume that the state space factors via ${\cal S} = \times_i S^i$, where the $S^i$'s are individual components of the state space. For example, each component might reflect the state of some actor, or some random variable in the recommender environment. Such a factorization of the state space allows us to specify the initial state and transition kernel in a factored way as well. To do so, we use the language of Dynamic Bayesian Networks (DBNs).

Suppose that $\Gamma, \Gamma_{-1}$ are two (possibly different) directed acyclic graphs (DAGs) whose nodes are the components of the state space. We then assume that $$T(s_{t-1}, s_t) = \prod_i T^i(s_t^i | s^{\operatorname{Pa}_{\Gamma}(i)}_t, s_{t-1}^{{\operatorname{Pa}_{\Gamma_{-1}}(i)}}),$$ that is, the state of the component $i$ depends on the state of its parents in $\Gamma$ (its intra-slice dependencies), as well as the preceding state of its parents in $\Gamma_{-1}$ (its inter-slice dependencies). The components $T^i$ of the transition kernel are termed conditional probability densities (CPDs) or sometimes factors. We assume that the initial state density can be factored similarly, but the graphs may be different.
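As a rough illustration of a factored kernel, the sketch below uses two hypothetical binary components: "mood", whose only parent is its own previous value (an inter-slice dependency), and "click", whose only parent is the current mood (an intra-slice dependency). The component names and probabilities are assumptions made up for this example.

```python
from typing import Dict

State = Dict[str, int]


def mood_cpd(mood: int, prev: State) -> float:
  """Factor for "mood": depends only on the previous mood (inter-slice)."""
  stay = 0.7  # hypothetical probability of keeping the same mood
  return stay if mood == prev["mood"] else 1.0 - stay


def click_cpd(click: int, current_mood: int) -> float:
  """Factor for "click": depends only on the current mood (intra-slice)."""
  p_click = 0.6 if current_mood == 1 else 0.1  # hypothetical click rates
  return p_click if click == 1 else 1.0 - p_click


def transition_density(prev: State, curr: State) -> float:
  """The full kernel is the product of the per-component factors."""
  return mood_cpd(curr["mood"], prev) * click_cpd(curr["click"], curr["mood"])


prev_state = {"mood": 1, "click": 0}
# 0.7 * 0.6, up to floating-point rounding.
print(transition_density(prev_state, {"mood": 1, "click": 1}))
```

Because each factor is a normalized conditional distribution, the product sums to 1 over all next states for any fixed previous state.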

The purpose of the RecSim NG modeling API is to allow the intuitive specification of components, their dependencies, and their stochastic behavior.

Hello world: counting to $n$.

We first introduce the RecSim modeling API by mapping these abstract concepts to code in a very simple model. Our first application is a process that counts from $0$ to $n-1$. In the language of DBNs, the state has a single component $S^0$, where ${\cal S} = S^0 = {\mathbb N}$. The dynamics of this process are specified as follows:

  • the initial state $s_0^0$ is always $0$;
  • the state updates as $s_t^0 = s_{t-1}^0 + 1$.

In RecSim NG, a component is represented by an instance of a Variable. A Variable is a component whose state space is a hierarchical dictionary (more on hierarchy later) with fixed keys, the values of which update deterministically or stochastically over time. A RecSim variable is declared as follows.

In [ ]:
count_var = Variable(name="count", spec=ValueSpec(n=FieldSpec()))

This line declares a variable whose name is "count" and whose state space consists of a dictionary having the single key "n". The declaration of the state space keys is accomplished by the ValueSpec object. Declaring a component with two keys might look like:

count_var = Variable(name="count", spec=ValueSpec(key_1=FieldSpec(), key_2=FieldSpec()))

meaning that the value of count_var is a dictionary of the form {"key_1": ..., "key_2":...}. The class FieldSpec can be used to supply additional hints as to what the values of those keys could be. We will ignore this for the time being.

We must now specify the generation of the initial state and transition kernel for count_var.

In [ ]:
def count_init() -> Value:
  return Value(n=0)

count_var.initial_value = variable.value(count_init)

The initial value distribution of a RecSim NG variable is defined by assigning a generating function to the .initial_value property of the variable. In our example, the initial state generating function count_init always returns the dictionary {"n": 0}. Note that methods which interact with variables must always return an instance of Value. The Value class enables compositions of models as well as various input/output validations to take place. We dive into this later on. It is sometimes more convenient to work with a Python dictionary, so we might use the equivalent pattern return Value(**{"n": 0}).

Now, let us specify the component dynamics of the "count" variable.

In [ ]:
def count_next(previous_value: Value) -> Value:
  n_prev = previous_value.get("n")
  return Value(n=n_prev + 1)

count_var.value = variable.value(count_next, (count_var.previous,))

This is very similar to the initial value specification, with the difference that we are now assigning the .value property. Note, moreover, that we have passed an additional argument to variable.value containing the singleton tuple count_var.previous. This declares a $\Gamma_{-1}$ dependency between "count" and itself, meaning that at time $t$ the simulator will pass the $t-1$ value of count_var as the previous_value argument of the update function count_next.

Line 2 in count_next also shows how we access fields of a Value object. This could also be done as n_prev = previous_value.as_dict["n"]. We could also update the previous state using Value.map(), which applies a function to all fields of a Value object. Here is an alternative definition:

In [ ]:
def count_next_map(previous_value: Value) -> Value:
  return previous_value.map(lambda x: x + 1)

For a single-field Value such as this one, the two definitions give identical results.

In [ ]:
print(count_next(Value(n=1)))
print(count_next_map(Value(n=1)))
Value[{'n': 2}]
Value[{'n': 2}]
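Since Value.map() applies the function to every field, the map-based update matches count_next only while the value has a single field. The dictionary-based sketch below illustrates this; map_fields is a hypothetical stand-in written for this example, not the Value implementation, and the "limit" field is invented.

```python
from typing import Any, Callable, Dict


def map_fields(value: Dict[str, Any],
               fn: Callable[[Any], Any]) -> Dict[str, Any]:
  """Applies `fn` to every field, in the spirit of Value.map()."""
  return {key: fn(field) for key, field in value.items()}


# Single field: same effect as incrementing "n" explicitly.
print(map_fields({"n": 1}, lambda x: x + 1))  # {'n': 2}
# Two fields: both are incremented, which may not be what we want.
print(map_fields({"n": 1, "limit": 10}, lambda x: x + 1))  # {'n': 2, 'limit': 11}
```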

We can now build a network from the individual components which we can then execute.

In [ ]:
count_network = network_lib.Network([count_var])
tf_runtime = runtime.TFRuntime(network=count_network)

A Network object represents a fully formed DBN. The default RecSim NG runtime allows us to "execute" this network in two ways.

In [ ]:
result = tf_runtime.execute(4)
print(result["count"])
Value[{'n': <tf.Tensor: shape=(), dtype=int32, numpy=4>}]

Runtime.execute(n) initializes the variables according to their initial_value definitions and returns a dictionary whose keys are the variable names and whose values are the state of these variables after $n$ state transitions.

In [ ]:
result = tf_runtime.trajectory(5)
print(result["count"])
Value[{'n': <tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 1, 2, 3, 4], dtype=int32)>}]

Runtime.trajectory(n) generates a trajectory of length $n$. As with .execute, it returns a dictionary of variable names; however, now the values contain the entire trajectory as an additional dimension.
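The difference between the two calls can be mimicked in plain Python. The sketch below is not the RecSim NG runtime; the helper names execute and trajectory are reused only to mirror the outputs shown above for the counting model.

```python
from typing import Callable, List


def execute(init: Callable[[], int], step: Callable[[int], int],
            num_steps: int) -> int:
  """Returns the state reached after `num_steps` transitions."""
  state = init()
  for _ in range(num_steps):
    state = step(state)
  return state


def trajectory(init: Callable[[], int], step: Callable[[int], int],
               length: int) -> List[int]:
  """Returns `length` states: the initial one plus `length - 1` transitions."""
  states = [init()]
  for _ in range(length - 1):
    states.append(step(states[-1]))
  return states


print(execute(lambda: 0, lambda n: n + 1, 4))     # 4
print(trajectory(lambda: 0, lambda n: n + 1, 5))  # [0, 1, 2, 3, 4]
```

Note the off-by-one relationship: a trajectory of length $n$ contains the initial state, so its last element is the state after $n-1$ transitions.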

Multiple variables: the Fibonacci numbers

Let's expand the above example to incorporate multiple interacting variables. The Fibonacci sequence is commonly defined with the following recursion: $f_0 = 0, f_1 = 1, f_n = f_{n-1} + f_{n-2}.$ We can implement this in RecSim NG using two variables as follows:

In [ ]:
fib0 = Variable(name="fib0", spec=ValueSpec(f0=FieldSpec()))
fib1 = Variable(name="fib1", spec=ValueSpec(f1=FieldSpec()))

def f0_init() -> Value:
  return Value(f0=0)

def f0_next(f1_previous: Value) -> Value:
  return Value(f0=f1_previous.get("f1"))

fib0.initial_value = variable.value(f0_init)
fib0.value = variable.value(f0_next, (fib1.previous,))

def f1_init() -> Value:
  return Value(f1=1)

def f1_next(f0_previous: Value, f1_previous: Value) -> Value:
  return Value(f1=f0_previous.get("f0") + f1_previous.get("f1"))

fib1.initial_value = variable.value(f1_init)
fib1.value = variable.value(f1_next, (fib0.previous, fib1.previous))

tf_runtime = runtime.TFRuntime(network=network_lib.Network([fib0, fib1]))
trajectory = tf_runtime.trajectory(5)
print(trajectory["fib0"])
print(trajectory["fib1"])
Value[{'f0': <tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 1, 1, 2, 3], dtype=int32)>}]
Value[{'f1': <tf.Tensor: shape=(5,), dtype=int32, numpy=array([1, 1, 2, 3, 5], dtype=int32)>}]

In this example, the fib0 variable acts as memory for the value of $f_{n-2}$. Specifically, its update just needs to remember the previous value of fib1. The fib1 value updates to be the sum of the previous values of fib0 and fib1. Note that we could have just implemented this recursion with a single variable having two keys. The following code is functionally equivalent.

In [ ]:
def fib_init():
  return Value(f0=0, f1=1)


def fib_next(previous_value):
  return Value(
      f0=previous_value.get("f1"),
      f1=previous_value.get("f0") + previous_value.get("f1"))

fibonacci = Variable(name="fib", spec=ValueSpec(f0=FieldSpec(), f1=FieldSpec()))
fibonacci.initial_value = variable.value(fib_init)
fibonacci.value = variable.value(fib_next, (fibonacci.previous,))

tf_runtime = runtime.TFRuntime(
    network=network_lib.Network(variables=[fibonacci]))
print(tf_runtime.trajectory(5)["fib"])
Value[{'f0': <tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 1, 1, 2, 3], dtype=int32)>, 'f1': <tf.Tensor: shape=(5,), dtype=int32, numpy=array([1, 1, 2, 3, 5], dtype=int32)>}]
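In both formulations, every update reads only previous-step values, so the two fields advance simultaneously. A plain-Python sketch of the same recursion (dictionaries stand in for Value objects; fib_step and fib_trajectory are names made up for this example) reproduces the numbers printed above.

```python
from typing import Dict, List


def fib_step(prev: Dict[str, int]) -> Dict[str, int]:
  """Computes both fields from the previous-step values only."""
  return {"f0": prev["f1"], "f1": prev["f0"] + prev["f1"]}


def fib_trajectory(length: int) -> List[Dict[str, int]]:
  """Returns the initial state followed by `length - 1` updates."""
  states = [{"f0": 0, "f1": 1}]
  for _ in range(length - 1):
    states.append(fib_step(states[-1]))
  return states


traj = fib_trajectory(5)
print([s["f0"] for s in traj])  # [0, 1, 1, 2, 3]
print([s["f1"] for s in traj])  # [1, 1, 2, 3, 5]
```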

Is one of these preferred over the other? No. We can use either pattern depending on the application. Typically, this choice is determined by the semantics of the variables. For example, if we are modeling some complex behavior, such as a user reacting to a recommendation by generating three reactions (say, a click, a rating, and a comment), then it might make sense to generate these jointly as fields of a single variable (say, something like user_response). This choice also largely determines what behaviors can be modularized in the application. We discuss this further in the next section.

Entities, Behaviors, and Stories: the RecSim Game of Life

We have so far covered how to specify DBN components as RecSim variables and their dependencies, and how to run the resulting simulation in a RecSim runtime. RecSim variables are the "assembly language" of a RecSim model: they reflect the bare minimum information necessary to define a simulation. As such, the variable API leaves a lot of useful structure on the table.

In an agent-based simulation, there are two common sources of structure that can be harnessed to make complex models easier to implement:

  1. Parameter sharing/encapsulation: An agent may have multiple behaviors such as "choose," "observe," "update state", etc. All of these behaviors might be influenced by a common set of parameters, such as the agent's preferences, its risk sensitivity, its overall satisfaction, etc. This fits the object-oriented paradigm in which the various behaviors are expressed as methods and the shared parameters are encapsulated properties of the object. (In fact, agent-based modeling is often used to justify object-oriented programming.)

  2. Behavioral generalization: While every agent in a population is unique to some degree, not everything they do is idiosyncratic. In fact, most behavioral models tend to be expressed in terms of a population "template" that abstracts away the individuality into some small set of parameters. For example, a multinomial logit model for user choice (i.e., a model which specifies the probability with which a user selects a specific item from a slate of recommended items) posits that an agent will choose an item with probability proportional to:

    $$\exp(\operatorname{affinity model}(\mathrm{agent\ features}, \mathrm{item\ features})).$$ In this example, the agent features are personalized to each agent. However, the affinity model and the choice distribution family (i.e., softmax) are common to the entire population. Templatized models of this form offer two critical benefits: (a) they allow us to define large populations of agents very concisely; and (b) they make heavy use of the accelerated computation present on modern hardware (through, for example, highly parallel hardware architectures).
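As a small illustration of the multinomial logit template, the sketch below normalizes exponentiated affinities over a slate. The affinity model is assumed here to be a plain dot product, and the feature vectors are invented; the text does not prescribe either.

```python
import math
from typing import List, Sequence


def affinity(agent: Sequence[float], item: Sequence[float]) -> float:
  """Hypothetical affinity model: a dot product of the feature vectors."""
  return sum(a * b for a, b in zip(agent, item))


def choice_probabilities(agent: Sequence[float],
                         slate: Sequence[Sequence[float]]) -> List[float]:
  """Softmax over the affinities of the items in the slate."""
  scores = [affinity(agent, item) for item in slate]
  max_score = max(scores)  # subtracted for numerical stability
  weights = [math.exp(score - max_score) for score in scores]
  total = sum(weights)
  return [weight / total for weight in weights]


agent_features = [1.0, 0.0]
slate_features = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
print(choice_probabilities(agent_features, slate_features))
```

Only agent_features changes from agent to agent; the affinity function and the softmax are shared, which is what makes a batched implementation over a whole population possible.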

In RecSim, these two types of structure are harnessed through the Entity pattern. A RecSim Entity is a class that models the parameters and behaviors of an entire population of agents, making use of batched execution as much as possible. The purpose of an Entity is to provide methods that update one or more RecSim variables; we call these methods behaviors. Let's illustrate this by implementing a RecSim version of Conway's Game of Life.

The Game of Life is defined as follows:

  • The game consists of some number of agents (cells), typically arranged in a rectangular grid. Each cell has one bit of state reflecting whether it's currently alive.
  • At any point in time $t$, a cell counts its living neighbors (i.e., the occupants of the 8 adjacent grid points). If the number of living neighbors is less than 2 or more than 3, then the cell dies (if it was alive) due to under/overpopulation (if it was dead, it remains dead). If the number of living neighbors is exactly 2, the cell keeps its current state. If the number of living neighbors is exactly 3, the cell is reborn if it was dead (and remains alive if it was living).

The Game of Life embodies the principle of behavioral generalization in a very strict way: the rules are identical for every cell, and the outcome of the state transition is determined by a small set of parameters---the state of a cell and the states of its neighbors. In fact, we can use this to our advantage. By representing the state of the entire cell population as a binary rank-2 tensor, we can implement the neighbor-counting operation as a 2-dimensional convolution with the following kernel:

$$\begin{array}{ccc} 1 & 1 & 1\\ 1 & 0 & 1\\ 1 & 1 & 1 \end{array}$$
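Before moving to a tensor implementation, it may help to see the same computation without any libraries. The sketch below counts neighbors by applying the kernel above with zero padding at the grid boundary (an assumption that corresponds to treating everything outside the grid as dead) and then applies the Game of Life rules. The function names are made up for this example.

```python
from typing import List

Grid = List[List[int]]


def count_neighbors(grid: Grid) -> Grid:
  """Applies the 3x3 neighbor-counting kernel with zero padding."""
  rows, cols = len(grid), len(grid[0])
  counts = [[0] * cols for _ in range(rows)]
  for r in range(rows):
    for c in range(cols):
      for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
          if (dr, dc) == (0, 0):
            continue  # the center weight of the kernel is 0
          rr, cc = r + dr, c + dc
          if 0 <= rr < rows and 0 <= cc < cols:
            counts[r][c] += grid[rr][cc]
  return counts


def life_step(grid: Grid) -> Grid:
  """One update: 3 neighbors means alive, 2 keeps the state, else dead."""
  counts = count_neighbors(grid)
  return [[1 if n == 3 or (n == 2 and cell == 1) else 0
           for cell, n in zip(row, count_row)]
          for row, count_row in zip(grid, counts)]


blinker = [[0, 0, 0], [1, 1, 1], [0, 0, 0]]
print(life_step(blinker))             # [[0, 1, 0], [0, 1, 0], [0, 1, 0]]
print(life_step(life_step(blinker)))  # [[0, 0, 0], [1, 1, 1], [0, 0, 0]]
```

The "blinker" oscillates with period 2, which makes it a convenient test case for any implementation of the rules.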

In RecSim, our cell population Entity has two behaviors: initializing a cell's state and updating the state. For this example, we pass the initial state as a parameter at construction time. The implementation of a cell Entity using a 2D convolution is given below.

In [ ]:
class Cell(entity.Entity):
  """Entity representing the dynamics of a population of cells."""

  def __init__(self, initial_configuration: tf.Tensor) -> None:
    """Stores the initial configuration and sets up the convolution layer."""
    super().__init__(name="Cell")
    self._initial_configuration = initial_configuration
    grid_dim = tf.shape(initial_configuration)[0]

    # Initialize the convolution layer.
    def kernel_init(shape: int, dtype=None) -> tf.Tensor:
      del shape, dtype
      return tf.constant([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0],
                          [1.0, 1.0, 1.0]])[..., tf.newaxis, tf.newaxis]

    self._conv = tf.keras.layers.Conv2D(
        1, 3, use_bias=False, padding="same",
        input_shape=(grid_dim, grid_dim, 1),
        kernel_initializer=kernel_init)

  def initial_state(self) -> Value:
    # Returns the stored initial configuration.
    return Value(cell_state=self._initial_configuration)

  def next_state(self, old_state: Value) -> Value:
    # Updates the cell population state using the GoL rules.
    old_cell_state = old_state.get("cell_state")
    # Expand and then squeeze to accommodate the input shape expectations of the
    # convolutional layer.
    neighbor_count = self._conv(old_cell_state[tf.newaxis, ..., tf.newaxis])
    neighbor_count = tf.squeeze(neighbor_count)
    # Underpopulated or overpopulated cells die off.
    cell_state = tf.where(
        tf.logical_or(neighbor_count > 3.0, neighbor_count < 2.0), 0.0,
        old_cell_state)
    # New cells are born.
    cell_state = tf.where(
        tf.cast(tf.round(neighbor_count), tf.int32) == 3, 1.0, cell_state)
    return Value(cell_state=cell_state)

The Cell entity provides two behaviors: initial_state and next_state, which update a simulation variable.

In general there are no strict commitments as to what behaviors need to be implemented within an entity---this depends on the simulation. However, entities can be used to define interfaces (via abstract methods), which, together with stories, create modular templatized applications. We dive deeper into building applications in another tutorial.

The main requirement when implementing behaviors is that the method must return a Value object, as well as expect a Value for all of its non-optional arguments. This enables a behavior to receive inputs from other variables and have its output bound to some variable.

Finally, a RecSim story is just a function that encapsulates the creation of simulation variables, entities and the bindings between them. A story is expected to receive as inputs whatever parameters we choose to define our simulation, and output a list of well-formed simulation variables. In our example, we choose to pass in the cell entity constructor and the initial state as parameters.

Note also that we have passed a name string to the entity constructor with super().__init__(name="Cell"). The purpose of this will become clear in the next section.

In [ ]:
def game_of_life_story(cell_ctor: Callable[[tf.Tensor], entity.Entity],
                       initial_configuration: tf.Tensor) -> Sequence[Variable]:
  # Create simulation variable.
  cell_state = Variable(
      "cell_state_var", spec=ValueSpec(cell_state=FieldSpec()))
  # Create entity.
  cell = cell_ctor(initial_configuration)
  # Bind the simulation variable to the entity's behaviors.
  cell_state.initial_value = variable.value(cell.initial_state, ())
  cell_state.value = variable.value(cell.next_state, (cell_state.previous,))
  # Return the variable.
  return [cell_state]

Note that passing entity constructors to a story is a very useful pattern when building simulations: it allows us to swap out agent types on the fly (e.g., when comparing different user models, recommenders, etc.). At this point we can run our Game of Life simulation and look at some fascinating patterns, such as the "101" pattern.

In [ ]:
pattern_101 = tf.constant(
    [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
     [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1],
     [1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1],
     [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0],
     [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0],
     [1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1],
     [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1],
     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
     [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
     [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]],
    dtype=tf.float32)
variables = game_of_life_story(Cell, pattern_101)
p101_runtime = runtime.TFRuntime(
    network=network_lib.Network(variables))
In [ ]:
trajectory = p101_runtime.trajectory(100)
animate_game_of_life(trajectory["cell_state_var"].get("cell_state"), 10)
Out[ ]:

We can also explore random initial configurations for larger populations.

In [ ]:
init_conf = ed.Bernoulli(0.1 * tf.ones((100, 100)), dtype=tf.float32)
variables = game_of_life_story(Cell, init_conf)
tf_runtime = runtime.TFRuntime(network=network_lib.Network(variables))
trajectory = tf_runtime.trajectory(100)
animate_game_of_life(trajectory["cell_state_var"].get("cell_state"), 20)
Out[ ]:

Learning About Life: Differentiable Simulation

We have so far looked at creating and running RecSim NG simulations, as well as adding structure using entities and stories. Now that we're able to run simulations, let's illustrate a few possible use cases.

Being able to generate trajectories from a model alone is rarely the end of the story for a simulation project. Most of the time, we use simulation to do sampling-based estimation of various objectives that present themselves as expectations over random variables and adapt model parameters based on these objectives, or perform other inference tasks.
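As a minimal illustration of sampling-based estimation, the sketch below estimates an expectation by averaging over simulated trajectories. The process is a hypothetical stochastic counter (it increments with probability 0.5 at each step) that is not part of this tutorial's models; it is chosen because the exact answer is known.

```python
import random


def simulate_counter(num_steps: int, p_increment: float,
                     rng: random.Random) -> int:
  """Runs one trajectory and returns the final count."""
  count = 0
  for _ in range(num_steps):
    if rng.random() < p_increment:
      count += 1
  return count


def monte_carlo_mean(num_samples: int, num_steps: int, p_increment: float,
                     seed: int = 0) -> float:
  """Averages the final count over independent simulated trajectories."""
  rng = random.Random(seed)
  total = sum(
      simulate_counter(num_steps, p_increment, rng)
      for _ in range(num_samples))
  return total / num_samples


estimate = monte_carlo_mean(num_samples=10_000, num_steps=10, p_increment=0.5)
print(estimate)  # should be close to the exact expectation 10 * 0.5 = 5
```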

Model learning

In the following we illustrate a simple example of learning simulation models in RecSim. In particular, we will attempt to learn the dynamics of the Game of Life using a neural net by feeding it example trajectories from the "ground truth" Game of Life simulation. This is similar to the task of fitting a learned dynamics model to a data set of trajectories generated in some real-world environment (though in this case, the "real world" is in our control).

We begin by implementing a new entity to represent the trainable cell population dynamics model. Like the ground truth Game of Life dynamics, to compute the next cell population state, the trainable entity applies a 2-dimensional convolution operator to the previous state; however, the results of the convolution, together with the previous state, then go through several dense (perceptron) layers with ReLU activations. Our goal in this example will be to learn the weights of the perceptron layers together with the convolution kernel to imitate the rules of the Game of Life as well as possible.

The implementation of the trainable cell entity is given below.

In [ ]:
class NeuralCell(entity.Entity):
  """Cell population dynamics implemented with a neural network."""

  def __init__(self, initial_configuration: tf.Tensor) -> None:
    """Stores the initial configuration and sets up the convolution layer."""
    super().__init__(name="NeuralCell")
    self._initial_configuration = initial_configuration
    grid_dim = tf.shape(initial_configuration)[0]
    self._conv = tf.keras.layers.Conv2D(
        1, 3, use_bias=False, padding="same",
        input_shape=(grid_dim, grid_dim, 1))
    self._sequential = tf.keras.Sequential([
        tf.keras.layers.Dense(15, activation="relu"),
        tf.keras.layers.Dense(15, activation="relu"),
        tf.keras.layers.Dense(15, activation="relu"),
        tf.keras.layers.Dense(1, activation="relu")
    ])
    # Initialize all Keras weights.
    self.next_state(self.initial_state())

  def initial_state(self) -> Value:
    # Returns the stored initial configuration.
    return Value(cell_state=self._initial_configuration)

  def next_state(self, old_state: Value) -> Value:
    # GoL dynamics using a neural net.
    old_cell_state = old_state.get("cell_state")
    neighbor_conv = self._conv(old_cell_state[tf.newaxis, ..., tf.newaxis])
    neighbor_conv = tf.squeeze(neighbor_conv)
    sequential_inputs = tf.stack((neighbor_conv, old_cell_state), axis=-1)
    return Value(cell_state=tf.squeeze(self._sequential(sequential_inputs)))

Let's run the trainable cell and examine its outputs.

In [ ]:
neural_cell = NeuralCell(
    ed.Bernoulli(0.2 * tf.ones((10, 10), dtype=tf.float32), dtype=tf.float32))
next_state = neural_cell.next_state(neural_cell.initial_state())
print(next_state.get("cell_state"))
tf.Tensor(
[[0.         0.         0.         0.         0.         0.
  0.         0.         0.04322901 0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.00690586 0.         0.         0.         0.         0.
  0.         0.         0.         0.00602371]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.01382275 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.00546031 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.00116605 0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.04467456
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]], shape=(10, 10), dtype=float32)

As expected, the trainable cell population's state outputs look nothing like Game of Life states: they are not even integers!

To be able to train the cell, we need access to its trainable variables. Such access can be acquired in two different ways (among others). One (somewhat inconvenient) way is to create the weights of the various layers a priori and pass them as arguments to the cell at construction. This, however, could get very unwieldy in simulations that accept a wide range of hot-swappable entities.

To circumvent this, RecSim can capture all trainables encountered within a story using the entity.story_with_trainable_variables call.

In [ ]:
recsim_vars, trainable_vars = entity.story_with_trainable_variables(
    lambda: game_of_life_story(NeuralCell, init_conf))

The outputs are as follows:

  • recsim_vars is a list of simulation variables (identical to simply running game_of_life_story(NeuralCell, init_conf)).
  • trainable_vars is a dictionary whose keys are the entity names (as specified in super().__init__(name=...) during the entity construction) and values are the results of entity_object.trainable_variables. The RecSim entity class inherits from tf.Module so these are found using reflection.
In [ ]:
tf.print(trainable_vars)
{'NeuralCell': ([[[[-0.0775674582]]

  [[0.334556699]]

  [[0.355787158]]]


 [[[0.0490371]]

  [[-0.435905695]]

  [[-0.435599029]]]


 [[[0.125528395]]

  [[0.411838531]]

  [[0.276391864]]]], [[0.0527278185 -0.508410335 -0.111527145 ... -0.207783192 0.380178511 0.131137908]
 [-0.439756274 -0.202046573 -0.48110348 ... -0.0908879638 -0.00762683153 -0.515277088]], [0 0 0 ... 0 0 0], [[-0.36731571 -0.031799227 0.277867556 ... 0.196124136 -0.282098055 -0.0157352686]
 [0.00674086809 0.275848329 0.0402102768 ... 0.134226561 0.322926462 0.109724045]
 [0.0659481287 -0.352079809 -0.386200339 ... 0.0993827581 0.424138248 -0.435369372]
 ...
 [-0.393255 0.273308516 -0.317875326 ... -0.385964155 0.115151584 0.0644667149]
 [0.141053438 -0.00288140774 -0.123884112 ... -0.00430825353 -0.298331916 0.174455881]
 [-0.414522439 -0.343310505 0.384101868 ... 0.357141316 0.413959146 0.102271855]], [0 0 0 ... 0 0 0], [[-0.432652056 -0.121035665 -0.418340981 ... 0.159153759 -0.302733243 0.382452905]
 [-0.248741165 -0.235086039 0.0331261754 ... -0.144766331 -0.174206495 -0.320373654]
 [-0.125471115 -0.426644534 -0.227025479 ... -0.138436586 0.326952219 0.353766561]
 ...
 [0.201197505 -0.118076831 -0.279283 ... -0.394584835 0.230459392 -0.121304244]
 [0.140716434 0.17597872 0.381311178 ... 0.331919134 0.0701503754 0.377870142]
 [-0.37088263 0.195860326 0.382169187 ... 0.0805368423 -0.113157094 0.259937227]], [0 0 0 ... 0 0 0], [[-0.0329420567]
 [-0.10004586]
 [-0.0730355382]
 ...
 [0.15977484]
 [0.121239066]
 [-0.347004294]], [0])}

We can now implement a basic stochastic training loop as follows. We use two RecSim runtimes: one with the ground truth Game of Life simulation and another with the trainable one. At each iteration, we draw a random initial state (cell configuration), generate a trajectory from both runtimes, and then attempt to reduce the squared difference between the two as a function of the trainable parameters.

This also demonstrates another runtime feature. We often need to run multiple simulation trajectories using the same initial state, so RecSim allows one to provide an initial state using the starting_value argument. Its datatype is NetworkValue, i.e., a dictionary of the form {"variable name": Value}.

In [ ]:
recsim_vars, trainable_vars = entity.story_with_trainable_variables(
    lambda: game_of_life_story(NeuralCell, init_conf))
trainable_tf_runtime = runtime.TFRuntime(
    network=network_lib.Network(recsim_vars))
loss = lambda label, pred: tf.reduce_sum(
    tf.math.squared_difference(label, pred))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)


@tf.function
def training_step():
  with tf.GradientTape() as tape:
    random_initial_state = ed.Bernoulli(
        0.1 * tf.ones((100, 100)), dtype=tf.float32)
    initial_network_value = {
        "cell_state_var": Value(cell_state=random_initial_state)
    }
    gt_trajectory = tf_runtime.trajectory(
        length=5, starting_value=initial_network_value)
    predicted_trajectory = trainable_tf_runtime.trajectory(
        length=5, starting_value=initial_network_value)
    objective = loss(gt_trajectory["cell_state_var"].get("cell_state"),
                     predicted_trajectory["cell_state_var"].get("cell_state"))
  grads = tape.gradient(objective, trainable_vars["NeuralCell"])
  optimizer.apply_gradients(zip(grads, trainable_vars["NeuralCell"]))
  return objective


for i in range(1001):
  obj = training_step()
  if not i % 100:
    print(f"Iteration {i}, loss {obj.numpy()}")
Iteration 0, loss 9298.2216796875
Iteration 100, loss 4150.26904296875
Iteration 200, loss 328.1241455078125
Iteration 300, loss 0.07827343046665192
Iteration 400, loss 0.004944309126585722
Iteration 500, loss 0.0003180106286890805
Iteration 600, loss 1.3583032341557555e-05
Iteration 700, loss 3.0732940103916917e-07
Iteration 800, loss 6.824411258321561e-09
Iteration 900, loss 5.6892730526669766e-09
Iteration 1000, loss 5.625295784739137e-09

(If this loop doesn't converge to a loss on the order of $10^{-2}$ or below, it might have to be rerun a few times.)

We can now visualize the results.

In [ ]:
predicted_trajectory = trainable_tf_runtime.trajectory(length=50)
animate_game_of_life(predicted_trajectory["cell_state_var"].get("cell_state"),
                     20)
Out[ ]:

It looks like the network has learned to reproduce the Game of Life dynamics rather convincingly. It's not perfect: some anomalous artifacts can be seen in later iterations, because we do not enforce integrality of the cell state in any way. We could easily fix this by rounding.
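As an illustration of the rounding fix, the following NumPy sketch clips a slightly off-integer state to $[0, 1]$ and rounds it. The helper name `round_cell_state` and the sample values are made up for this example; in the notebook one would presumably apply the equivalent TensorFlow operations to the cell state.

```python
import numpy as np


def round_cell_state(state: np.ndarray) -> np.ndarray:
  """Clips to [0, 1] and rounds, so every cell is exactly 0.0 or 1.0."""
  return np.round(np.clip(state, 0.0, 1.0))


# A made-up, slightly off-integer state of the kind a learned model may emit.
noisy_state = np.array([[0.97, 0.02, -0.01],
                        [0.49, 1.03, 0.51]], dtype=np.float32)
print(round_cell_state(noisy_state))
```

Rounding is piecewise constant, so it carries no useful gradient. It is therefore better suited to cleaning up states at simulation or visualization time than to being part of the training objective.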

Inference

Having a differentiable Game of Life model offers some interesting benefits. By construction, the model outputs are differentiable with respect to the model parameters. Perhaps more interestingly, they are differentiable with respect to the initial state as well, which opens the door to some interesting use cases. For example, suppose we are given a Game of Life trajectory $(s_1, \ldots, s_k)$ that is missing its initial state $s_0$. We can try to reconstruct the initial state as:

$$s^*_0 = \arg\min_{\hat{s}_0} \operatorname{loss}((s_1, \ldots, s_k), (\hat{s}_1, \ldots, \hat{s}_k)),$$

where $(\hat{s}_1, \ldots, \hat{s}_k)$ is the sequence of states generated by starting the Game of Life at $\hat{s}_0$. Since our learned model is differentiable with respect to the previous state, we can attempt to solve this problem by gradient descent on the loss function. (For simplicity, we make no attempt to ensure integrality of the state.)
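One way to read "gradient descent on the loss function" is as the update rule below. This is a sketch rather than the exact procedure used in this notebook: the step size $\eta$, the iteration index $t$, and the elementwise clipping operator $\Pi_{[0,1]}$ are notation introduced here, and the code that follows uses the Adam optimizer instead of this plain update.

$$\hat{s}_0^{(t+1)} = \Pi_{[0,1]}\left(\hat{s}_0^{(t)} - \eta \, \nabla_{\hat{s}_0} \operatorname{loss}\big((s_1, \ldots, s_k), (\hat{s}_1, \ldots, \hat{s}_k)\big)\Big|_{\hat{s}_0 = \hat{s}_0^{(t)}}\right)$$

The gradient is taken through the learned transition model, since each $\hat{s}_i$ depends on $\hat{s}_0$ via $i$ applications of that model. Clipping keeps every cell value in $[0, 1]$ but does not force it to be exactly 0 or 1.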

Let's try this out on the 101 pattern. Since we want to reuse our runtimes, which are built to work with (100, 100) grids, we pad the pattern with zeroes.
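The padding can be checked in isolation with NumPy. This sketch assumes the pattern is a 12 x 18 array, as the shapes in the cell below imply, and uses an all-ones placeholder instead of the real `pattern_101`, which is defined earlier in the notebook.

```python
import numpy as np

# Hypothetical stand-in for the 12 x 18 pattern; the real `pattern_101`
# is defined earlier in the notebook.
pattern = np.ones((12, 18), dtype=np.float32)

# Pad 88 rows below and 82 columns to the right with zeros.
padded = np.pad(pattern, pad_width=((0, 88), (0, 82)), mode="constant")

# The same result via two concatenations, mirroring the notebook cell.
right = np.concatenate((pattern, np.zeros((12, 82), np.float32)), axis=1)
both = np.concatenate((right, np.zeros((88, 100), np.float32)), axis=0)
print(padded.shape, np.array_equal(padded, both))  # (100, 100) True
```

Either way, the pattern ends up in the top-left corner of the grid, which is why the later cells index the region `[:12, :18]`.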

In [ ]:
p101_expanded = tf.concat((pattern_101, tf.zeros((12, 82))), axis=1)
p101_expanded = tf.concat((p101_expanded, tf.zeros((88, 100))), axis=0)
animate_game_of_life(p101_expanded[tf.newaxis, ...], 1)
Out[ ]:

Now let's set up the learning loop. We generate a "target" trajectory by starting the ground truth runtime from the above configuration and running it for two time steps. We then create a trainable variable to represent the unknown initial state and generate two state transitions from it using our differentiable cell model. Finally, we take the gradient of the squared difference between the target and predicted trajectories with respect to the trainable initial state and use it to update that state.
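The objective in the cell below only compares the steps after the initial state, and only inside the 12 x 18 window that holds the pattern. The following NumPy sketch shows the effect of that slicing on made-up data; the arrays are random placeholders, not actual trajectories, and are assumed to have shape (3, 100, 100), i.e. the initial state plus two transitions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up stand-ins for trajectories of 3 states on a (100, 100) grid.
target = rng.random((3, 100, 100)).astype(np.float32)
predicted = target.copy()
# Perturb the initial state and a cell outside the 12 x 18 window.
predicted[0] += 1.0
predicted[1, 50, 50] += 1.0

# Drop step 0 (the unknown initial state) and keep only the window.
target_window = target[1:, :12, :18]
predicted_window = predicted[1:, :12, :18]
objective = float(np.sum((target_window - predicted_window) ** 2))
print(target_window.shape, objective)
```

Neither perturbation changes the objective: differences at step 0 are excluded so that the ground truth initial state is never revealed to the optimizer, and differences outside the window are ignored.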

In [ ]:
loss = lambda label, pred: tf.reduce_sum(
    tf.math.squared_difference(label, pred))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)
# Initialize the differentiable initial state variable.
initial_loc = np.zeros((100, 100), dtype=np.float32)
initial_loc[:12, :18] = 0.5
initial_scale = np.zeros((100, 100), dtype=np.float32)
initial_scale[:12, :18] = 0.5
trainable_initial_state = tf.Variable(
    ed.Normal(loc=initial_loc, scale=initial_scale),
    constraint=lambda x: tf.clip_by_value(x, 0.0, 1.0))
# Generate the target trajectory.
gt_network_value = {"cell_state_var": Value(cell_state=p101_expanded)}
target_trajectory = tf_runtime.trajectory(
    length=3, starting_value=gt_network_value)


@tf.function
def training_step_end_state():
  with tf.GradientTape() as tape:
    trainable_network_value = {
        "cell_state_var": Value(cell_state=trainable_initial_state)
    }
    predicted_trajectory = trainable_tf_runtime.trajectory(
        length=3, starting_value=trainable_network_value)
    # Filter out the first step to avoid revealing the ground truth and
    # constrain the loss to the top 12 x 18 rectangle.
    objective = loss(
        target_trajectory["cell_state_var"].get("cell_state")[1:, :12, :18],
        predicted_trajectory["cell_state_var"].get("cell_state")[1:, :12, :18])
  grads = tape.gradient(objective, [trainable_initial_state])
  optimizer.apply_gradients(zip(grads, [trainable_initial_state]))
  return objective


for i in range(4000):
  obj = training_step_end_state()
  if not i % 500:
    print(f"Iteration {i}, loss {obj.numpy()}")
Iteration 0, loss 161.31173706054688
Iteration 500, loss 63.2030029296875
Iteration 1000, loss 62.88899612426758
Iteration 1500, loss 53.366485595703125
Iteration 2000, loss 39.3611946105957
Iteration 2500, loss 38.90966796875
Iteration 3000, loss 38.79045104980469
Iteration 3500, loss 38.708499908447266

We can now visualize the result.

In [ ]:
trainable_network_value = {
    "cell_state_var": Value(cell_state=trainable_initial_state)
}
predicted_trajectory = trainable_tf_runtime.trajectory(
    length=10, starting_value=trainable_network_value)
animate_game_of_life(predicted_trajectory["cell_state_var"].get("cell_state"),
                     10)
Out[ ]: