the variable name from `logits` to `net`.
with tf.contrib.tpu.bfloat16_scope():
net = build_network()
net = tf.cast(net, tf.float32)
elif params['precision'] == 'float32':
net = build_network()