notations
- $x$ : image
- $z$ : latent
- $y$ : label (omitted to lighten notation)
- $q(z \mid x)$ : encoder, implemented by the `Encoder` module
- $p(x \mid z)$ : decoder, implemented by the `Decoder` module
- $\hat{p}(z)$ : prior over the latent (fit by variational inference), implemented by the `PriorEncoder` module
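For reference, these modules combine in the standard conditional ELBO (written with $y$ explicit even though the notation above omits it), and the KL term between two Gaussians has the closed form that the loss code below implements:

$$
\log p(x \mid y) \;\ge\; \mathbb{E}_{q(z \mid x, y)}\!\left[\log p(x \mid z, y)\right] \;-\; \mathrm{KL}\!\left(q(z \mid x, y)\,\big\|\,\hat{p}(z \mid y)\right)
$$

$$
\mathrm{KL}\!\left(\mathcal{N}(\mu, \sigma^2)\,\big\|\,\mathcal{N}(\mu_p, \sigma_p^2)\right) \;=\; \log\frac{\sigma_p}{\sigma} \;+\; \frac{\sigma^2 + (\mu - \mu_p)^2}{2\sigma_p^2} \;-\; \frac{1}{2}
$$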
model structure
```python
import torch
import torch.nn as nn


class CVAE(nn.Module):
    def __init__(self, config):
        super(CVAE, self).__init__()
        self.encoder = Encoder(...)             # q(z|x): posterior over the latent
        self.decoder = Decoder(...)             # p(x|z): reconstructs the image
        self.priorEncoder = PriorEncoder(...)   # learned label-conditioned prior over z

    def forward(self, x, y):
        x = x.reshape((-1, 784))  # flatten MNIST images to vectors
        mu, sigma = self.encoder(x, y)
        prior_mu, prior_sigma = self.priorEncoder(y)
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = torch.randn_like(mu)
        z = z * sigma + mu
        reconstructed_x = self.decoder(z, y)
        reconstructed_x = reconstructed_x.reshape((-1, 28, 28))
        return reconstructed_x, mu, sigma, prior_mu, prior_sigma

    def infer(self, y):
        # at generation time, sample z from the learned prior instead of the posterior
        prior_mu, prior_sigma = self.priorEncoder(y)
        z = torch.randn_like(prior_mu)
        z = z * prior_sigma + prior_mu
        reconstructed_x = self.decoder(z, y)
        return reconstructed_x
```
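A quick usage sketch for conditional generation. This is hypothetical: the `Encoder`/`Decoder`/`PriorEncoder` definitions are elided above, and both `config` and the one-hot label encoding are assumptions, not part of the original snippet.

```python
import torch
import torch.nn.functional as F

# hypothetical usage: generate one sample per digit class, assuming
# PriorEncoder/Decoder take one-hot label vectors (their definitions are elided)
model = CVAE(config)  # config is assumed to carry the layer sizes
model.eval()
with torch.no_grad():
    y = F.one_hot(torch.arange(10), num_classes=10).float()
    samples = model.infer(y).reshape(-1, 28, 28)  # one 28x28 image per class
```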
```python
# loss: reconstruction (MSE) plus weighted KL between posterior and prior
class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()
        self.loss_fn = nn.MSELoss(reduction='mean')
        self.kld_loss_weight = 1e-5  # small weight keeps KL from swamping reconstruction

    def forward(self, x, reconstructed_x, mu, sigma, prior_mu, prior_sigma):
        mse_loss = self.loss_fn(x, reconstructed_x)
        # closed-form KL of N(mu, sigma^2) from N(prior_mu, prior_sigma^2), elementwise
        kld_loss = torch.log(prior_sigma / sigma) \
            + (sigma**2 + (mu - prior_mu)**2) / (2 * prior_sigma**2) - 0.5
        kld_loss = kld_loss.mean()
        return mse_loss + self.kld_loss_weight * kld_loss
```
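A minimal training-loop sketch tying the two modules together. Assumptions not in the original snippet: MNIST loaded via torchvision, one-hot labels, Adam with `lr=1e-3`, and batch size 128; since the submodule definitions are elided above, treat this as a sketch rather than a drop-in script.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

loader = DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True,
)
model = CVAE(config)  # config assumed to hold the layer sizes
criterion = Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, y in loader:
    y = F.one_hot(y, num_classes=10).float()
    reconstructed_x, mu, sigma, prior_mu, prior_sigma = model(x, y)
    # compare in image shape, matching the (-1, 28, 28) output of forward()
    loss = criterion(x.reshape(-1, 28, 28), reconstructed_x,
                     mu, sigma, prior_mu, prior_sigma)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```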