# @title Install. !pip3 install --upgrade -q --no-cache-dir recsim_ng !pip3 install --upgrade -q --no-cache-dir edward2 # @title Imports and defs. from typing import Any, Callable, Mapping, Sequence, Text import functools import matplotlib as mpl from matplotlib import animation from matplotlib import pyplot as plt import numpy as np import tensorflow as tf import tensorflow_probability as tfp import edward2 as ed from recsim_ng.core import network as network_lib from recsim_ng.core import variable from recsim_ng.core import value from recsim_ng.lib import data from recsim_ng.lib.tensorflow import entity from recsim_ng.lib.tensorflow import log_probability from recsim_ng.lib.tensorflow import runtime from recsim_ng.core import value tfd = tfp.distributions tfb = tfp.bijectors mpl.style.use("classic") Variable = variable.Variable Value = value.Value FieldSpec = value.FieldSpec ValueSpec = value.ValueSpec normal_scale = tf.Variable(1.0) def rw_next(previous_state: Value) -> Value: return Value( state=ed.Normal(loc=previous_state.get("state"), scale=normal_scale)) random_walk_var = Variable(name="rw_var", spec=ValueSpec(state=FieldSpec())) random_walk_var.initial_value = variable.value( lambda: Value(state=ed.Normal(loc=0.0, scale=1.0))) random_walk_var.value = variable.value(rw_next, (random_walk_var.previous,)) tf_runtime = runtime.TFRuntime(network=network_lib.Network([random_walk_var])) horizon = 200 trajectories = [tf_runtime.trajectory(horizon) for _ in range(5)] _, ax = plt.subplots(figsize=(8, 3)) for trajectory in trajectories: ax.plot(range(horizon), trajectory["rw_var"].get("state")) log_probs = [ log_probability.log_probability_from_value_trajectory( [random_walk_var], traj, horizon - 1).numpy() for traj in trajectories ] print(log_probs) with tf.GradientTape() as tape: log_probs = [ log_probability.log_probability_from_value_trajectory([random_walk_var], traj, 100) for traj in trajectories ] negative_likelihood = -tf.reduce_sum(log_probs) grad = tape.gradient(negative_likelihood, normal_scale) print(grad) # Get input data trajectory. trajectory = trajectories[0] rw_trajectory = trajectory["rw_var"] # This is just a horizon x 1 tensor of values wrapped in a Value object. # Inputs like these could come from real-world logs. print("Trajectory shape:", rw_trajectory.get("state").shape) # Now we create the replay variable. replay_var = data.data_variable( name=random_walk_var.name + "_replay", spec=random_walk_var.spec, data_sequence=data.SlicedValue(value=trajectory[random_walk_var.name])) # replay_var now mimics random_walk_var, except that it replays values from # the input trajectory. For example: print("replay_var's initial state:", replay_var.initial_value.fn().get("state"), "\n is identical to the logged trajectory at index 0:", trajectory["rw_var"].get("state")[0]) # Now we create the corresponding transformed log prob variable: log_prob_var = log_probability.log_prob_variables_from_observation( [random_walk_var], [replay_var]) # We can have an aggregator var that sums the log probabilities over time and # keys. aggregators = log_probability.log_prob_accumulator_variables(log_prob_var) # We can also have a global accumulator that adds the log probabilities of all # variables together, but this is redundant since we have only one. # aggregators.append( # log_probability.total_log_prob_accumulator_variable(log_prob_vars)) # We can now run this network tf_runtime = runtime.TFRuntime( network=network_lib.Network(variables=[replay_var] + log_prob_var + aggregators)) # We can either get the aggregated outputs at the last step, which is equivalent # to the previous example: print("Log probability of last step: ", tf_runtime.execute(horizon - 1)["rw_var_log_prob_accum"]) # or get the entire trajectory of log probs: log_prob_trajectory = tf_runtime.trajectory(horizon) print("Disaggregated log probabilities: ", log_prob_trajectory["rw_var_log_prob"]) # For reference, the trajectory contains the following keys: print("Trajectory keys:", log_prob_trajectory.keys()) # corresponding to the replay var, the log probability var and the accumulator. def rw_next(previous_state: Value) -> Value: return Value(state=ed.Normal(loc=previous_state.get("state"), scale=1.0)) random_walk_var = Variable(name="rw_var", spec=ValueSpec(state=FieldSpec())) random_walk_var.initial_value = variable.value( lambda: Value(state=ed.Normal(loc=0.0, scale=1.0))) random_walk_var.value = variable.value(rw_next, (random_walk_var.previous,)) rw_log_prob = log_probability.log_prob_variables_from_direct_output( [random_walk_var])[0] tf_runtime = runtime.TFRuntime( network=network_lib.Network([random_walk_var, rw_log_prob])) trajectory = tf_runtime.trajectory(2) print("Simulated random walk: ", trajectory["rw_var"]) print("Log probabilities: ", trajectory["rw_var_log_prob"]) def rw_next(previous_state: Value) -> Value: return Value( state=tf.random.normal([1], previous_state.get("state"), 1.0, tf.float32)) random_walk_var = Variable(name="rw_var", spec=ValueSpec(state=FieldSpec())) random_walk_var.initial_value = variable.value( lambda: Value(state=tf.random.normal([1], 0.0, 1.0, tf.float32))) random_walk_var.value = variable.value(rw_next, (random_walk_var.previous,)) tf_runtime = runtime.TFRuntime(network=network_lib.Network([random_walk_var])) trajectories = [tf_runtime.trajectory(horizon) for _ in range(5)] _, ax = plt.subplots(figsize=(8, 3)) for trajectory in trajectories: ax.plot(range(horizon), trajectory["rw_var"].get("state")) log_probs = [ log_probability.log_probability_from_value_trajectory( [random_walk_var], traj, horizon - 1).numpy() for traj in trajectories ] print(log_probs) def rw_next(previous_state: Value) -> Value: return Value( state=previous_state.get("state") + ed.Normal(loc=0.0, scale=1.0)) random_walk_var = Variable(name="rw_var", spec=ValueSpec(state=FieldSpec())) random_walk_var.initial_value = variable.value( lambda: Value(state=0.0 + ed.Normal(loc=0.0, scale=1.0))) random_walk_var.value = variable.value(rw_next, (random_walk_var.previous,)) tf_runtime = runtime.TFRuntime(network=network_lib.Network([random_walk_var])) trajectories = [tf_runtime.trajectory(horizon) for _ in range(5)] try: log_probs = [ log_probability.log_probability_from_value_trajectory( [random_walk_var], traj, horizon - 1).numpy() for traj in trajectories ] except Exception as e: print(e) sample = tf.random.normal(shape=[2], mean=0.0, stddev=1.0) indep_sample_rv = ed.Normal(loc=0.0, scale=1.0, sample_shape=[2]) print("Log prob with sample shape [2]:", indep_sample_rv.distribution.log_prob(sample)) batch_sample_rv = ed.Normal(loc=[0.0, 0.0], scale=[1.0, 1.0]) print("Log prob with batch shape [2]:", batch_sample_rv.distribution.log_prob(sample)) event_sample_rv = ed.MultivariateNormalDiag( loc=[0.0, 0.0], scale_identity_multiplier=1.0) print("Log prob with event shape [2]:", event_sample_rv.distribution.log_prob(sample)) def rw_next(previous_state: Value) -> Value: # Next state samples from a batch of two normals. return Value(state=ed.Normal(loc=previous_state.get("state"), scale=1.0)) random_walk_var = Variable(name="rw_var", spec=ValueSpec(state=FieldSpec())) # Initial state samples from a multivariate normal. random_walk_var.initial_value = variable.value(lambda: Value( state=ed.MultivariateNormalDiag( loc=[0.0, 0.0], scale_identity_multiplier=1.0))) random_walk_var.value = variable.value(rw_next, (random_walk_var.previous,)) tf_runtime = runtime.TFRuntime(network=network_lib.Network([random_walk_var])) trajectories = [tf_runtime.trajectory(horizon) for _ in range(5)] try: log_probs = [ log_probability.log_probability_from_value_trajectory( [random_walk_var], traj, horizon - 1).numpy() for traj in trajectories ] except Exception as e: print(e) # @title RecSimNG modeling imports. from recsim_ng.stories import recommendation_simulation as simulation from recsim_ng.applications.latent_variable_model_learning import recommender from recsim_ng.entities.recommendation import user from recsim_ng.entities.state_models import static from recsim_ng.entities.choice_models import selectors from recsim_ng.entities.choice_models import affinities class ModelLearningDemoUser(user.User): def __init__(self, config: Mapping[Text, Any], satisfaction_sensitivity: tf.Tensor, name: Text = 'ModelLearningDemoUser') -> None: super().__init__(config, name) self._slate_size = config.get('slate_size') # Hardcoded parameter values. self._user_intent_variance = 0.1 self._initial_satisfication = 5.0 # Unknown satisfaction sensitivity. self._sat_sensitivity = satisfaction_sensitivity # The intent model as a GMM state model from the RecSim NG state # model library. batch_intent_means = tf.eye( 2, num_columns=2, batch_shape=(self._num_users,)) lop_ctor = lambda params: tf.linalg.LinearOperatorScaledIdentity( num_rows=2, multiplier=params) self._intent_model = static.GMMVector( batch_ndims=1, mixture_logits=tf.zeros((self._num_users, 2)), component_means=batch_intent_means, component_scales=tf.sqrt(self._user_intent_variance), linear_operator_ctor=lop_ctor) # The choice model is a multinomial logit choice model from the RecSim NG # choice model library. self._choice_model = selectors.MultinomialLogitChoiceModel( batch_shape=(self._num_users,), nochoice_logits=tf.ones(self._num_users)) # The affinity model is a target point similarity model, which by default # computes the negative Euclidean distance between the target point and the # item embedding. self._affinity_model = affinities.TargetPointSimilarity( batch_shape=(self._num_users,), slate_size=self._slate_size) def initial_state(self) -> Value: """Initial state value.""" return Value( satisfaction=ed.Deterministic(self._initial_satisfication * tf.ones(self._num_users)), intent=self._intent_model.initial_state().get('state'), max_slate_utility=tf.zeros(self._num_users)) def next_state(self, previous_state: Value, _, slate_docs: Value) -> Value: """State transition kernel.""" # Compute the improvement of slate scores. slate_doc_features = slate_docs.get('features') slate_doc_affinities = self._affinity_model.affinities( previous_state.get('intent'), slate_doc_features).get('affinities') max_slate_utility = tf.reduce_max(slate_doc_affinities, axis=-1) improvement = max_slate_utility - previous_state.get('max_slate_utility') next_satisfaction = self._sat_sensitivity * previous_state.get( 'satisfaction') + improvement return Value( satisfaction=ed.Normal(loc=next_satisfaction, scale=0.01), intent=self._intent_model.next_state( Value(state=previous_state.get('intent'))).get('state'), max_slate_utility=max_slate_utility) def next_response(self, previous_state: Value, slate_docs: Value) -> Value: """The response value after the initial value.""" slate_doc_features = slate_docs.get('features') slate_doc_scores = self._affinity_model.affinities( previous_state.get('intent'), slate_doc_features).get('affinities') # Adding the user's satisfaction to the item scores. adjusted_scores = ( slate_doc_scores + previous_state.get('satisfaction')[..., tf.newaxis]) return self._choice_model.choice(adjusted_scores) def observation(self): pass def specs(self) -> ValueSpec: response_spec = self._choice_model.specs() state_spec = ValueSpec( intent=self._intent_model.specs().get('state'), satisfaction=FieldSpec(), max_slate_utility=FieldSpec()) return state_spec.prefixed_with('state').union( response_spec.prefixed_with('response')) # Initialize simulation parameters gt_satisfaction_sensitivity = 0.8 * tf.ones(5) num_users = 5 num_topics = 2 horizon = 6 config = { "slate_size": 2, "num_users": num_users, "num_topics": num_topics, "num_docs": 0 } # Set up ground truth runtime. gt_user_ctor = functools.partial( ModelLearningDemoUser, satisfaction_sensitivity=gt_satisfaction_sensitivity) gt_variables = simulation.simplified_recs_story(config, gt_user_ctor, recommender.SimpleNormalRecommender) gt_network = network_lib.Network(variables=gt_variables) gt_runtime = runtime.TFRuntime(network=gt_network) traj = dict(gt_runtime.trajectory(length=horizon)) print('===============GROUND TRUTH LIKELIHOOD================') print( log_probability.log_probability_from_value_trajectory( variables=gt_variables, value_trajectory=traj, num_steps=horizon - 1)) print('======================================================') trainable_sat_sensitivity = tf.Variable( tf.math.sigmoid(ed.Normal(loc=tf.zeros(5), scale=1.0))) trainable_user_ctor = functools.partial( ModelLearningDemoUser, satisfaction_sensitivity=trainable_sat_sensitivity) t_variables = simulation.simplified_recs_story( config, trainable_user_ctor, recommender.SimpleNormalRecommender) trainable_network = network_lib.Network(variables=t_variables) trainable_runtime = runtime.TFRuntime(network=trainable_network) print('===============UNTRAINED LIKELIHOOD================') print( log_probability.log_probability_from_value_trajectory( variables=t_variables, value_trajectory=traj, num_steps=horizon - 1)) print('======================================================') optimizer = tf.keras.optimizers.Adam(learning_rate=0.01) @tf.function def training_step(): with tf.GradientTape() as tape: gt_trajectory = gt_runtime.trajectory(length=horizon) neg_likelihood = -log_probability.log_probability_from_value_trajectory( variables=t_variables, value_trajectory=traj, num_steps=horizon - 1) grads = tape.gradient(neg_likelihood, [trainable_sat_sensitivity]) optimizer.apply_gradients(zip(grads, [trainable_sat_sensitivity])) return neg_likelihood likelihood_history = [] print(f"Parameters before training: {trainable_sat_sensitivity.numpy()}") for i in range(101): obj = training_step() likelihood_history.append(-obj) if not i % 20: print(f"Iteration {i}, negative likelihood {obj.numpy()}") print(f"Parameter values: {trainable_sat_sensitivity.numpy()}") _, ax = plt.subplots(figsize=(8, 3)) ax.plot(range(len(likelihood_history)), likelihood_history) plt.xlabel('training iteration') plt.ylabel('trainable model likelihood') plt.show() # Reinitialize trainable variables. trainable_sat_sensitivity.assign( tf.math.sigmoid(ed.Normal(loc=tf.zeros(5), scale=1.0))) @tf.function def unnormalized_log_prob_train(intent: tf.Tensor) -> tf.Tensor: # Expand initial intent to complete intent trajectory. intent_traj = tf.expand_dims( intent, axis=0) + tf.zeros((horizon, num_users, num_topics)) # Combine intent trajectory with the observed data. user_state_dict = dict(traj['user state'].as_dict) user_state_dict['intent'] = intent_traj traj['user state'] = Value(**user_state_dict) # Return the log probability of the imputed intent + observations. return log_probability.log_probability_from_value_trajectory( variables=t_variables, value_trajectory=traj, num_steps=horizon - 1) num_results = int(2e3) num_burnin_steps = int(5e2) adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation( tfp.mcmc.HamiltonianMonteCarlo( target_log_prob_fn=unnormalized_log_prob_train, num_leapfrog_steps=5, step_size=.00008), num_adaptation_steps=int(num_burnin_steps * 0.8)) # Run the chain (with burn-in). @tf.function def run_chain(): samples, is_accepted = tfp.mcmc.sample_chain( num_results=num_results, num_burnin_steps=num_burnin_steps, current_state=tfd.Normal(loc=tf.ones((5, 2)) / 5, scale=0.5).sample(), kernel=adaptive_hmc, trace_fn=lambda _, pkr: pkr.inner_results.is_accepted) sample_mean = tf.reduce_mean(samples) sample_stddev = tf.math.reduce_std(samples) is_accepted = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32)) return samples, sample_mean, sample_stddev, is_accepted #@test {"skip": true} # Initialize the HMC transition kernel. optimizer = tf.keras.optimizers.Adam(learning_rate=0.01) elbo_history = [] for i in range(101): posterior_samples, sample_mean, sample_stddev, is_accepted = run_chain() log_probs = [] with tf.GradientTape() as tape: log_probs = tf.vectorized_map(unnormalized_log_prob_train, posterior_samples[num_burnin_steps:,]) neg_elbo = -tf.reduce_mean(log_probs) grads = tape.gradient(neg_elbo, [trainable_sat_sensitivity]) optimizer.apply_gradients(zip(grads, [trainable_sat_sensitivity])) elbo_history.append(-neg_elbo) if not i % 5: print(f"Iteration {i}, unnormalized negative ELBO {neg_elbo.numpy()}") print(f"Parameter values: {trainable_sat_sensitivity.numpy()}") print(f"Acceptance rate: {is_accepted.numpy()}") #@test {"skip": true} _, ax = plt.subplots(figsize=(8, 3)) ax.plot(range(len(elbo_history)), elbo_history) plt.xlabel('training iteration') plt.ylabel('ELBO')