In resnet_main.py, add the following imports and functions (they also use `tf`, so the file must already import TensorFlow as `tf`):
import tensorflow_probability as tfp
tfd = tfp.distributions
def ezx_dist(x):
    """Return the encoder distribution e(z|x).

    The distribution is a diagonal multivariate Normal whose mean is ``x``
    itself. No ``scale_diag`` is passed, so TFP uses an identity scale
    (unit variance in every dimension).

    Args:
      x: the encoder output used directly as the mean of z. NOTE(review):
        presumably a float tensor of shape [batch, z_dims] — confirm with
        the caller.

    Returns:
      A ``tfd.MultivariateNormalDiag`` centred on ``x``.
    """
    return tfd.MultivariateNormalDiag(loc=x)
def bzy_dist(y, num_classes=1000, z_dims=2048):
    """Return the backward distribution b(z|y).

    Each label in ``y`` is one-hot encoded and passed through a linear
    (no activation) ``tf.layers.dense`` layer, which produces a
    ``z_dims``-dimensional mean. The result is a diagonal Normal with that
    mean and identity scale.

    Args:
      y: integer class labels. NOTE(review): presumably shape [batch] —
        confirm with the caller.
      num_classes: depth of the one-hot encoding. The default of 1000
        presumably matches ImageNet — confirm.
      z_dims: dimensionality of the latent z. The default of 2048
        presumably matches the ResNet-50 pooled feature size — confirm.

    Returns:
      A ``tfd.MultivariateNormalDiag`` over z, conditioned on the label.
    """
    # NOTE(review): with no name/reuse given, tf.layers.dense creates a new
    # set of variables on every call; call this once per graph (or wrap it in
    # a reusing scope) if the same b(z|y) is needed in several places.
    label_means = tf.layers.dense(
        tf.one_hot(y, num_classes), z_dims, activation=None)
    return tfd.MultivariateNormalDiag(loc=label_means)
def cyz_dist(z, num_classes=1000):
"""Builds the classifier distribution, c(y|z)."""