完整代码下载链接
因为之前用生成对抗网络及众多变体生成诸如心电信号,肌电信号,脑电信号,微震信号,机械振动信号,雷达信号等,但生成的信号在频谱或者时频谱上表现很差,所以暂时先不涉及到这些复杂信号,仅仅以手写数字图像为例进行说明,因为Python相关的资源太多了,我就不凑热闹了,使用的编程环境为MALAB R2021B。
首先看一下对抗自编码器AAE(Adversarial AutoEncoder),关于AAE的大致理解,可以查看如下文章
AAE(Adversarial Autoencoders)浅解 - 嘎嘎小鱼仔的文章 - 知乎 AAE(Adversarial Autoencoders)浅解 - 知乎
AAE根据变分自编码器VAE发展而来,其发展之处就在于加入了对抗的思想。
上半部分就是一个简单典型的自编码器AE结构,包含输入层input layer,编码层encoder layer, 隐层hidden layer, 解码层decoder layer , 输出层output layer。encoder把真实分布x映射为隐层z, decoder 再将z解码还原成x。AAE的特点就在于在隐层hidden layer中引入了对抗的思想来优化隐层的z,判别器discriminator 需要在隐层判断采样后的真实数据和生成器encoder所产生的假数据。因此discriminator的目的就是使得q(z | x) 不断向p(z)靠近。
Adversarial Autoencoders论文链接:https://arxiv.org/abs/1511.0564
下面直接上代码
首先,导入相关的mnist手写数字图
load('mnistAll.mat')
然后对训练、测试图像进行预处理
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;%训练标签
testX = preprocess(mnist.test_images);
testY = mnist.test_labels;%测试标签
preprocess为归一化函数,如下
function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end
然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等
settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1];
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;
下面进行编码器初始化,代码还是很容易看懂的