变分自编码器图像生成与循环神经网络文本生成
立即解锁
发布时间: 2025-09-05 01:43:17 阅读量: 5 订阅数: 19 AIGC 

### 变分自编码器图像生成与循环神经网络文本生成
#### 变分自编码器(VAE)图像生成
变分自编码器(VAE)在图像生成领域有着独特的优势。通过在潜在空间中操作,VAE 能够生成训练集中未出现过的新图像。
##### 1. VAE 生成新图像
我们可以在潜在空间中随机抽取向量表示,然后将其输入到训练好的 VAE 的解码器中,从而得到新的图像。这些图像与训练集中的任何原始图像都不对应,这是因为 VAE 的潜在空间是连续且可解释的,新的、未见过的编码可以被有意义地解码为与训练集中图像相似但不同的图像。
##### 2. 编码算术操作
VAE 的损失函数中包含正则化项(KL 散度),这使得潜在空间近似于正态分布。这样的正则化确保了潜在变量不仅仅是记忆训练数据,而是捕捉到了数据的潜在分布,从而形成一个结构良好的潜在空间,相似的数据点在这个空间中会被映射到相近的位置,使得空间具有连续性和可解释性。基于此,我们可以对编码进行操作以获得新的结果。
为了演示编码算术在 VAE 中的工作原理,我们按以下步骤操作:
1. **收集不同特征的图像**:手动收集四组图像,每组三张,分别是戴眼镜的男性、不戴眼镜的男性、戴眼镜的女性和不戴眼镜的女性。
```python
torch.manual_seed(0)
glasses=[]
for i in range(25):
img,label=data[i]
glasses.append(img)
plt.subplot(5,5,i+1)
plt.imshow(img.numpy().transpose((1,2,0)))
plt.axis("off")
plt.show()
men_g=[glasses[0],glasses[3],glasses[14]]
women_g=[glasses[9],glasses[15],glasses[21]]
noglasses=[]
for i in range(25):
img,label=data[-i-1]
noglasses.append(img)
plt.subplot(5,5,i+1)
plt.imshow(img.numpy().transpose((1,2,0)))
plt.axis("off")
plt.show()
men_ng=[noglasses[1],noglasses[7],noglasses[22]]
women_ng=[noglasses[4],noglasses[9],noglasses[19]]
```
2. **对每组图像进行编码和解码**:将每组的三张图像分别输入到训练好的 VAE 的编码器中,得到它们在潜在空间中的编码,然后计算每组的平均编码,再将平均编码输入到解码器中得到解码后的图像。
```python
# create a batch of images of men with glasses
men_g_batch = torch.cat((men_g[0].unsqueeze(0),
men_g[1].unsqueeze(0),
men_g[2].unsqueeze(0)), dim=0).to(device)
# Obtain the three encodings
_,_,men_g_encodings=vae.encoder(men_g_batch)
# Average over the three images to obtain the encoding for the group
men_g_encoding=men_g_encodings.mean(dim=0)
# Decode the average encoding to create an image of a man with glasses
men_g_recon=vae.decoder(men_g_encoding.unsqueeze(0))
# Do the same for the other three groups
# group 2, women with glasses
women_g_batch = torch.cat((women_g[0].unsqueeze(0),
women_g[1].unsqueeze(0),
women_g[2].unsqueeze(0)), dim=0).to(device)
# group 3, men without glasses
men_ng_batch = torch.cat((men_ng[0].unsqueeze(0),
men_ng[1].unsqueeze(0),
men_ng[2].unsqueeze(0)), dim=0).to(device)
# group 4, women without glasses
women_ng_batch = torch.cat((women_ng[0].unsqueeze(0),
women_ng[1].unsqueeze(0),
women_ng[2].unsqueeze(0)), dim=0).to(device)
# obtain average encoding for each group
_,_,women_g_encodings=vae.encoder(women_g_batch)
women_g_encoding=women_g_encodings.mean(dim=0)
_,_,men_ng_encodings=vae.encoder(men_ng_batch)
men_ng_encoding=men_ng_encodings.mean(dim=0)
_,_,women_ng_encodings=vae.encoder(women_ng_batch)
women_ng_encoding=women_ng_encodings.mean(dim=0)
# decode for each group
women_g_recon=vae.decoder(women_g_encoding.unsqueeze(0))
men_ng_recon=vae.decoder(men_ng_encoding.unsqueeze(0))
women_ng_recon=vae.decoder(women_ng_enc
```
0
0
复制全文
相关推荐









