#!/usr/bin/env python
# coding: utf-8

# # Subclassing Bijectors
#
# > In this post, we are going to build customized transformations with our
# > own bijectors for flexibility.

# NOTE: this import used to be fused into the header comment above, which left
# `tf` undefined for the rest of the script.
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt

# Short aliases used throughout the script.
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors

# Default size for every figure created below.
plt.rcParams['figure.figsize'] = (10, 6)


# In[2]:


# Report the library versions this script is running against.
print("Tensorflow Version: ", tf.__version__)
print("Tensorflow Probability Version: ", tfp.__version__)


# ## Overview
# 
# We can create our bijector by using bijector subclassing.
# 
# ### Sample bijector

# In[3]:


class MySigmoid(tfb.Bijector):
    """Sigmoid bijector, ``y = sigmoid(x)``, with the inverse log-det-Jacobian spelled out.

    The forward log-det-Jacobian is obtained by negating the inverse one,
    evaluated at ``y = sigmoid(x)``.

    Args:
        validate_args: forwarded to ``tfb.Bijector.__init__``.
        name: forwarded to ``tfb.Bijector.__init__``.
    """

    def __init__(self, validate_args=False, name='sigmoid'):
        super(MySigmoid, self).__init__(validate_args=validate_args,
                                        forward_min_event_ndims=0,
                                        name=name)

    def _forward(self, x):
        """Return ``sigmoid(x)``."""
        return tf.math.sigmoid(x)

    def _inverse(self, y):
        """Return ``log(y) - log(1 - y)``, the inverse of the sigmoid."""
        return tf.math.log(y) - tf.math.log(1 - y)

    def _inverse_log_det_jacobian(self, y):
        """Return ``-log(y) - log(1 - y)``, i.e. ``log|dx/dy|`` for ``x = log(y) - log(1 - y)``."""
        return -tf.math.log(y) - tf.math.log(1 - y)

    def _forward_log_det_jacobian(self, x):
        """Return the negated inverse log-det-Jacobian evaluated at ``sigmoid(x)``."""
        # Call the private implementation directly: the public
        # `inverse_log_det_jacobian` wrapper takes an `event_ndims` argument
        # that was not being supplied here.
        return -self._inverse_log_det_jacobian(self._forward(x))


# Note that when implementing our own bijector, the methods we override must
# be named with a leading underscore (`_forward`, `_inverse`, ...).

# ### Same bijector (implementing forward log det jacobian)

# In[4]:


class MySigmoid(tfb.Bijector):
    """Sigmoid bijector, ``y = sigmoid(x)``, with the forward log-det-Jacobian spelled out.

    This redefines the class above; here the inverse log-det-Jacobian is the
    one derived from the forward one.

    Args:
        validate_args: forwarded to ``tfb.Bijector.__init__``.
        name: forwarded to ``tfb.Bijector.__init__``.
    """

    def __init__(self, validate_args=False, name='sigmoid'):
        super(MySigmoid, self).__init__(validate_args=validate_args,
                                        forward_min_event_ndims=0,
                                        name=name)

    def _forward(self, x):
        """Return ``sigmoid(x)``."""
        return tf.math.sigmoid(x)

    def _inverse(self, y):
        """Return ``log(y) - log(1 - y)``, the inverse of the sigmoid."""
        return tf.math.log(y) - tf.math.log(1 - y)

    def _inverse_log_det_jacobian(self, y):
        """Return the negated forward log-det-Jacobian evaluated at ``_inverse(y)``."""
        # Must go through `self`: a bare `_forward_log_det_jacobian(...)` is an
        # undefined global name and raised NameError.
        return -self._forward_log_det_jacobian(self._inverse(y))

    def _forward_log_det_jacobian(self, x):
        """Return ``-softplus(-x) - softplus(x)``, i.e. ``log(sigmoid(x) * (1 - sigmoid(x)))``."""
        return -tf.math.softplus(-x) - tf.math.softplus(x)


# ### Simple shift bijector

# In[5]:


class MyShift(tfb.Bijector):
    """Bijector that adds a fixed offset: ``y = x + shift``.

    Args:
        shift: value added in ``_forward`` and subtracted in ``_inverse``.
        validate_args: forwarded to ``tfb.Bijector.__init__``.
        name: forwarded to ``tfb.Bijector.__init__``.
    """

    def __init__(self, shift, validate_args=False, name='shift'):
        self.shift = shift
        super(MyShift, self).__init__(validate_args=validate_args,
                                      forward_min_event_ndims=0,
                                      name=name,
                                      is_constant_jacobian=True)

    def _forward(self, x):
        """Return ``x + shift``."""
        return x + self.shift

    def _inverse(self, y):
        """Return ``y - shift``."""
        return y - self.shift

    def _forward_log_det_jacobian(self, x):
        """Return a scalar zero of ``x``'s dtype (``dy/dx = 1``, so ``log|dy/dx| = 0``)."""
        return tf.constant(0., x.dtype)


# ### Simple Tanh bijector

# In[7]:


class MyTanh(tfb.Bijector):
    """Hyperbolic-tangent bijector: ``y = tanh(x)``.

    Args:
        validate_args: forwarded to ``tfb.Bijector.__init__``.
        name: forwarded to ``tfb.Bijector.__init__``.
    """

    # NOTE(review): the default name 'cube' looks like a copy-paste leftover
    # ('tanh' would be expected); it is kept unchanged so the default seen by
    # existing callers does not change.
    def __init__(self, validate_args=False, name='cube'):
        super(MyTanh, self).__init__(validate_args=validate_args,
                                     forward_min_event_ndims=0,
                                     name=name)

    def _forward(self, x):
        """Return ``tanh(x)``."""
        return tf.math.tanh(x)

    def _inverse(self, y):
        """Return ``atanh(y)``."""
        return tf.math.atanh(y)

    def _forward_log_det_jacobian(self, x):
        """Return ``log(1 - tanh(x)**2)``, computed via ``log1p``."""
        return tf.math.log1p(-tf.square(tf.tanh(x)))


# ## Tutorial

# In[12]:


# Define a new bijector: Cubic

class Cubic(tfb.Bijector):
    """Cubic bijector: ``y = (a * x + b) ** 3``.

    Args:
        a: multiplicative coefficient; cast to ``tf.float32``.
        b: additive coefficient; cast to ``tf.float32``.
        validate_args: when true, every entry of ``a`` and of ``b`` is
            required to have absolute value of at least ``1e-5``. Also
            forwarded to ``tfb.Bijector.__init__``.
        name: forwarded to ``tfb.Bijector.__init__``.

    Raises:
        AssertionError: if ``validate_args`` is true and the check above fails.
    """

    def __init__(self, a, b, validate_args=False, name='Cubic'):
        self.a = tf.cast(a, tf.float32)
        self.b = tf.cast(b, tf.float32)
        if validate_args:
            # Raised explicitly (rather than with `assert`) so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            for label, coeff in (('a', self.a), ('b', self.b)):
                all_large_enough = tf.reduce_all(
                    tf.math.greater_equal(tf.abs(coeff), 1e-5))
                if not all_large_enough:
                    raise AssertionError(
                        'Cubic: every entry of `{}` must satisfy '
                        '|value| >= 1e-5'.format(label))
        super(Cubic, self).__init__(validate_args=validate_args,
                                    forward_min_event_ndims=0,
                                    name=name)

    def _forward(self, x):
        """Return ``(a * x + b) ** 3`` as float32, with size-1 dimensions squeezed out."""
        x = tf.cast(x, tf.float32)
        # NOTE(review): `tf.squeeze` drops size-1 dimensions, so the output
        # shape can differ from the input shape (and from what
        # `_forward_log_det_jacobian` returns) — confirm this is intended.
        return tf.squeeze(tf.pow(self.a * x + self.b, 3))

    def _inverse(self, y):
        """Return ``(cbrt(y) - b) / a`` as float32."""
        y = tf.cast(y, tf.float32)
        # sign(y) * |y|**(1/3) is the real cube root, valid for negative y too.
        return (tf.math.sign(y) * tf.pow(tf.abs(y), 1/3) - self.b) / self.a

    def _forward_log_det_jacobian(self, x):
        """Return ``log(3|a|) + 2 log|a * x + b|``, i.e. ``log|3a (a x + b)**2|``."""
        x = tf.cast(x, tf.float32)
        return tf.math.log(3. * tf.abs(self.a)) + 2. * tf.math.log(tf.abs(self.a * x + self.b))


# In[13]:


# Cubic bijector

cubic = Cubic([1.0, -2.0], [-1.0, 0.4], validate_args=True)


# In[14]:


# Apply forward transformation

x = tf.constant([[1, 2], [3, 4]])
y = cubic.forward(x)
y


# In[16]:


# Check inverse

np.linalg.norm(x - cubic.inverse(y))


# ### Function plots

# In[17]:


# Plot the forward transformation

x = np.linspace(-10, 10, 500).reshape(-1, 1)
plt.plot(x, cubic.forward(x))
plt.show()


# In[18]:


# Display shape

cubic.forward(x).shape


# In[19]:


# Plot the inverse

plt.plot(x, cubic.inverse(x))
plt.show()


# In[20]:


# Plot the forward log Jacobian determinant

plt.plot(x, cubic.forward_log_det_jacobian(x, event_ndims=0))
plt.show()


# In[21]:


# Plot the inverse log Jacobian determinant

plt.plot(x, cubic.inverse_log_det_jacobian(x, event_ndims=0))
plt.show()


# ### TransformedDistribution and plots

# In[22]:


# Create a transformed distribution with cubic

normal = tfd.Normal(loc=0., scale=1.)
cubed_normal = tfd.TransformedDistribution(tfd.Sample(normal, sample_shape=[2]), cubic)


# In[23]:


# Sample cubed_normal

n = 1000
g = cubed_normal.sample(n)
g.shape


# In[24]:


# Plot histogram of each of the two sampled components

plt.subplot(1, 2, 1)
plt.hist(g[..., 0].numpy(), bins=50, density=True)
plt.subplot(1, 2, 2)
plt.hist(g[..., 1].numpy(), bins=50, density=True)
plt.show()


# In[25]:


# Make contour plot

xx = np.linspace(-0.5, 0.5, 100)
yy = np.linspace(-0.5, 0.5, 100)
# NOTE: this rebinds `y` (previously the output of `cubic.forward`).
X, y = np.meshgrid(xx, yy)

fig, ax = plt.subplots(1, 1)
Z = cubed_normal.prob(np.dstack((X, y)))
cp = ax.contourf(X, y, Z)
fig.colorbar(cp)
ax.set_title('Filled Contours Plot')
ax.set_xlabel('X')
ax.set_ylabel('y')
plt.show()


# In[27]:


# Create a transformed distribution with the inverse of Cubic

inverse_cubic = tfb.Invert(cubic)
inv_cubed_normal = inverse_cubic(tfd.Sample(normal, sample_shape=[2]))


# In[28]:


# Sample inv_cubed_normal

g = inv_cubed_normal.sample(n)
g.shape


# In[30]:


# Make contour plot

xx = np.linspace(-3.0, 3.0, 100)
yy = np.linspace(-2.0, 2.0, 100)
X, y = np.meshgrid(xx, yy)

fig, ax = plt.subplots(1, 1)
Z = inv_cubed_normal.prob(np.dstack((X, y)))
cp = ax.contourf(X, y, Z)
fig.colorbar(cp)
ax.set_title('Filled Contours Plot')
ax.set_xlabel('X')
ax.set_ylabel('y')
plt.show()


# In[31]:


# Plot histogram of each of the two sampled components

plt.subplot(1, 2, 1)
plt.hist(g[..., 0].numpy(), bins=50, density=True)
plt.subplot(1, 2, 2)
plt.hist(g[..., 1].numpy(), bins=50, density=True)
plt.show()


# ### Training the bijector

# In[32]:


# Create a mixture of two Gaussians

probs = [0.45, 0.55]

mix_gauss = tfd.Mixture(
    cat=tfd.Categorical(probs=probs),
    components=[
        tfd.Normal(loc=2.3, scale=0.4),
        tfd.Normal(loc=-0.8, scale=0.4)
    ]
)


# In[33]:


# Create the dataset: 10000 samples each for training and validation,
# served in batches of 128

X_train = mix_gauss.sample(10000)
X_train = tf.data.Dataset.from_tensor_slices(X_train).batch(128)

X_valid = mix_gauss.sample(10000)
X_valid = tf.data.Dataset.from_tensor_slices(X_valid).batch(128)


# In[34]:


print(X_train.element_spec)
print(X_valid.element_spec)


# In[35]:


# Plot the data distribution

x = np.linspace(-5.0, 5.0, 100)
plt.plot(x, mix_gauss.prob(x))
plt.title('Data Distribution')
plt.show()


# In[36]:


# Make a trainable bijector

trainable_inv_cubic = tfb.Invert(Cubic(tf.Variable(0.25), tf.Variable(-0.1)))
trainable_inv_cubic.trainable_variables


# In[37]:


# Make a trainable transformed distribution

trainable_dist = tfd.TransformedDistribution(normal, trainable_inv_cubic)


# In[38]:


# Plot the data and the (not yet trained) trainable distribution

x = np.linspace(-5.0, 5.0, 100)

plt.plot(x, mix_gauss.prob(x), label='data')
plt.plot(x, trainable_dist.prob(x), label='trainable')
plt.title('Data and trainable distribution')
plt.legend(loc='best')
plt.show()


# In[40]:


# Train the bijector by minimizing the negative log-likelihood of the data

num_epochs = 10
opt = tf.keras.optimizers.Adam()
train_losses = []
valid_losses = []

for epoch in range(num_epochs):
    print("Epoch {}...".format(epoch))
    # Fresh running means each epoch, so every recorded value is a per-epoch average.
    train_loss = tf.keras.metrics.Mean()
    val_loss = tf.keras.metrics.Mean()

    for train_batch in X_train:
        with tf.GradientTape() as tape:
            tape.watch(trainable_inv_cubic.trainable_variables)
            loss = -trainable_dist.log_prob(train_batch)
        train_loss(loss)
        grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
        opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
    train_losses.append(train_loss.result().numpy())

    # Validation
    for valid_batch in X_valid:
        loss = -trainable_dist.log_prob(valid_batch)
        val_loss(loss)
    valid_losses.append(val_loss.result().numpy())


# In[41]:


# Plot the learning curves

plt.plot(train_losses, label='train')
plt.plot(valid_losses, label='valid')
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Negative log likelihood")
plt.title("Training and validation loss curves")
plt.show()


# In[42]:


# Plot the data and learned distributions

x = np.linspace(-5.0, 5.0, 100)

plt.plot(x, mix_gauss.prob(x), label='data')
plt.plot(x, trainable_dist.prob(x), label='learned')
plt.title('Data and learned distribution')
plt.legend(loc='best')
plt.show()


# In[43]:


# Display trainable variable

trainable_inv_cubic.trainable_variables