一 整体架构
该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training
二 输入数据处理
这里主要包括图像和文本数据输入处理
2.1 图像处理
这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B∗3∗1024∗1024)
2.2 文本处理
输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask
2.3 图像编码
对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B∗4∗128∗128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)
2.4 文本编码
包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B∗77∗1024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B∗256∗2048)
2.5 位置编码
这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 4096∗88)
最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img
三 DIT模型
3.1 add noise过程
- 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
3.2 HunYuanDiT模型训练过程
- 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B∗77∗1024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B∗256∗2048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B∗256∗1024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B∗333∗1024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B∗333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B∗333∗1024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
- 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B∗1408)
- 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B∗4096∗1408,4096是64∗64)
- 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B∗256∗2048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B∗1024)
- 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B∗1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B∗1408,作为所有extra_vectors)
3.2.1 进入Dit Block
一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip
3.2.1.1 前面0到18共19个block
- 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B∗4096∗1408),c( B ∗ 1408 B*1408 B∗1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B∗333∗1024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B∗4096∗88)
HunYuanDiTBlock(
(norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
(attn1): FlashSelfMHAModified(
(Wqkv): Linear(in_features=1408, out_features=4224, bias=True)
(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
(inner_attn): FlashSelfAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(out_proj): Linear(in_features=1408, out_features=1408, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=1408, out_features=6144, bias=True)
(act): GELU(approximate='tanh')
(drop1): Dropout(p=0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=6144, out_features=1408, bias=True)
(drop2): Dropout(p=0, inplace=False)
)
(default_modulation): Sequential(
(0): FP32_SiLU()
(1): Linear(in_features=1408, out_features=1408, bias=True)
)
(attn2): FlashCrossMHAModified(
(q_proj): Linear(in_features=1408, out_features=1408, bias=True)
(kv_proj): Linear(in_features=1024, out_features=2816, bias=True)
(q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
(k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
(inner_attn): FlashCrossAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(out_proj): Linear(in_features=1408, out_features=1408, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
- 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B∗4096∗1408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
- 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
- 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
- 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
- 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
(skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)
(skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
- 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样
3.2.2 最后FInal layer处理
- 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B∗2816),并拆分成shift 和 scale(分别为 B ∗ 1408 B*1408 B∗1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B∗4096∗32),然后通过转换self.unpatchify得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B∗8∗128∗128,这里包括输出的噪声,以及预测的var,这里diffusion var type是 LEARNED_RANGE,分别对应的shape B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B∗4∗128∗128)
3.3 gussian_diffusion过程
该类是diffusion的主类,包括q_sample和loss计算,具体可以参考https://blog.csdn.net/qq_19841133/article/details/126948474
3.3.1 计算变分下界
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert_shape(decoder_nll, x_start)
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"], "extra": out["extra"]}
-
计算后验均值方差 self.q_posterior_mean_variance,也就是根据训练输入的x_start 和 x_t,计算得到true_mean,和true_log_variance_clipped
-
计算预测均值方差 self.p_mean_variance,根据3.2.2模型输出的结果,包括预测的噪声和var;通过预测的噪声,生成pred_xstart;在通过pred_xstart 和 x_t,通过self.q_posterior_mean_variance,得到预测的均值
-
计算kl散度,根据true_mean,true_log_variance_clipped,pred_mean, pred_variance
-
对于time=0时候,计算离散化高斯对数似然,discretized_gaussian_log_likelihood
-
最终对于模型预测的噪声和实际添加的noise做mse loss再加上上面计算的kl loss,得到最终loss