Work in progress.
import conx as cx
import keras.backend as K
Using TensorFlow backend.
ConX, version 3.7.3
We need a function to use as the activation function for the Sampler layer:
LENGTH = 5  # latent size

def sampler(inputs):
    ## inputs is the merged concatenation of the mean and stddev banks
    mean, stddev = inputs[:, :LENGTH], inputs[:, LENGTH:]
    # sample a batch_size x latent_size matrix from the standard normal
    # (K.shape keeps the batch dimension correct for minibatches)
    std_norm = K.random_normal(shape=(K.shape(mean)[0], LENGTH),
                               mean=0, stddev=1)
    # sampling from Z ~ N(mu, sigma^2) is the same as sampling mu + sigma*X,
    # X ~ N(0, 1); stddev is treated as a log value, so exponentiate it
    # to get a positive scale
    return mean + K.exp(stddev) * std_norm
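As a quick sanity check (an addition, not part of the original notebook), we can evaluate the sampler on a constant batch; test_input is just an illustrative name:
test_input = K.constant([[0.0] * LENGTH + [0.0] * LENGTH,
                         [1.0] * LENGTH + [0.0] * LENGTH]) # two rows: means, then log-scales
K.eval(sampler(test_input)).shape # expect (2, 5): one latent sample per row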
CAPACITY = 5 # size of encoded bank
BETA = 1.5
def bvae_loss(tensor):
    print("HERE!", tensor.shape)  # debug: show the shape this loss receives
    LENGTH = tensor.shape[1]
    if LENGTH == 10:
        # the bank holds the concatenated (mean, log-variance) pair
        mean, stddev = tensor[:, :LENGTH//2], tensor[:, LENGTH//2:]
    else:
        # fallback: treat the bank itself as both mean and log-variance
        mean, stddev = tensor, tensor
    # KL divergence from N(mean, exp(stddev)) to the standard normal:
    latent_loss = -0.5 * K.mean(1 + stddev
                                - K.square(mean)
                                - K.exp(stddev), axis=-1)
    # use beta to force less usage of the vector space;
    # also try to use only <CAPACITY> dimensions of the space:
    latent_loss = BETA * K.abs(latent_loss - CAPACITY/LENGTH)
    return K.sum(latent_loss)
def vae_loss(tensor):
    print("HERE!", tensor.shape)  # debug: show the shape this loss receives
    LENGTH = tensor.shape[1]
    if LENGTH == 10:
        # the bank holds the concatenated (mean, log-variance) pair
        mean, stddev = tensor[:, :LENGTH//2], tensor[:, LENGTH//2:]
    else:
        # fallback: treat the bank itself as both mean and log-variance
        mean, stddev = tensor, tensor
    # KL divergence from N(mean, exp(stddev)) to the standard normal:
    latent_loss = -0.5 * K.mean(1 + stddev
                                - K.square(mean)
                                - K.exp(stddev), axis=-1)
    return K.sum(latent_loss)
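As a hedged spot check (again an addition, not from the original run): an all-zero code has mean 0 and log-variance 0, so its KL divergence from the standard normal should be exactly zero:
z = K.constant([[0.0] * LENGTH]) # single code: mean = 0, log-variance = 0
K.eval(vae_loss(z)) # -0.5 * (1 + 0 - 0 - exp(0)) = 0.0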
net = cx.Network("vae")
net.add(cx.Layer("input", 2),
        cx.Layer("mean", LENGTH, activation="sigmoid"),
        cx.Layer("stddev", LENGTH, activation="sigmoid"),
        cx.LambdaLayer("encode", LENGTH, sampler), # sampler receives the incoming banks' merged output
        cx.Layer("output", 1, activation="sigmoid"));
#net.additional_output_banks = ["encode"]
net.connect("input", "mean")
net.connect("input", "stddev")
net.connect("mean", "encode")
net.connect("stddev", "encode")
net.connect("encode", "output")
net.build_model()
To allow an additional error function, we need to declare "encode" (an internal bank) as an output:
net.add_loss("encode", vae_loss)
HERE! (?, 5)
And then we can provide a dictionary of error functions by name:
net.compile(loss="mse", optimizer="adam")
net.picture([1,-1], hspace=200, scale=1.0)
The dataset is the XOR truth table:
net.dataset.load([
[[0, 0], [0]],
[[0, 1], [1]],
[[1, 0], [1]],
[[1, 1], [0]],
])
net.evaluate(show=True)
========================================================
Testing validation dataset with tolerance 0.1...
Total count: 4
      correct: 0
      incorrect: 4
Total percentage correct: 0.0
net.dataset.summary()
_________________________________________________________________
vae Dataset:
Patterns    Shape    Range
=================================================================
inputs      (2,)     (0.0, 1.0)
targets     (1,)     (0.0, 1.0)
=================================================================
Total patterns: 4
   Training patterns: 4
   Testing patterns: 0
_________________________________________________________________
Propagating the inputs to the "encode" bank shows the sampled latent codes:
net.propagate_to("encode", net.dataset.inputs, sequence=True)
array([[-0.5848038 ,  0.86490923, -0.04118258, -0.6218531 , -2.0275912 ],
       [ 1.4013709 , -0.40050203,  2.3741612 ,  0.33983442,  1.589536  ],
       [-1.9229012 ,  3.6163    ,  1.7998898 ,  0.44848615,  1.546372  ],
       [ 2.0491967 , -2.4284263 , -0.8186321 ,  1.8166237 ,  0.07840455]],
      dtype=float32)
Because "encode" samples stochastically, propagating again gives different codes:
net.propagate_to("encode", net.dataset.inputs, sequence=True)
array([[ 0.7198903 ,  0.86922455,  2.8350933 ,  3.4182205 ,  1.9641373 ],
       [ 0.5212698 ,  0.49677068,  2.7393699 ,  0.24070835,  1.8382547 ],
       [ 1.5064052 , -0.4074353 ,  0.6014143 ,  0.78827447, -2.2850318 ],
       [-1.8501136 ,  0.9037933 ,  0.26733384, -0.2333259 ,  1.3512714 ]],
      dtype=float32)
net.propagate(net.dataset.inputs[0])
[0.9892443418502808]
Propagating all four inputs at once, via the dataset's internal array:
net.propagate(net.dataset._inputs[0], sequence=True)
array([[0.9027333 ],
       [0.8923163 ],
       [0.36847237],
       [0.4038974 ]], dtype=float32)
CAPACITY is used to encourage the encoding to use a set number of basis dimensions of the latent space.
BETA (> 1) is used to strengthen the latent regularizer.
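Note that bvae_loss, defined above, is never attached to this network. As a sketch (using only the same conx calls shown earlier, not part of the original run), the beta-VAE penalty could be swapped in on a fresh copy of the network:
bnet = cx.Network("bvae")
bnet.add(cx.Layer("input", 2),
         cx.Layer("mean", LENGTH, activation="sigmoid"),
         cx.Layer("stddev", LENGTH, activation="sigmoid"),
         cx.LambdaLayer("encode", LENGTH, sampler),
         cx.Layer("output", 1, activation="sigmoid"))
bnet.connect("input", "mean")
bnet.connect("input", "stddev")
bnet.connect("mean", "encode")
bnet.connect("stddev", "encode")
bnet.connect("encode", "output")
bnet.build_model()
bnet.add_loss("encode", bvae_loss) # the beta-VAE penalty instead of vae_loss
bnet.compile(loss="mse", optimizer="adam")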
#net.reset()
net.train(epochs=30000, report_rate=100)
========================================================
       |  Training |  Training
Epochs |     Error |  Accuracy
------ | --------- | ---------
#45000 |   3.95330 |   0.00000
net.plot_activation_map(to_layer="encode")