Hunuan-DiT代码阅读

一 整体架构

该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training

二 输入数据处理

这里主要包括图像和文本数据输入处理

2.1 图像处理

这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B310241024

2.2 文本处理

输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask

2.3 图像编码

对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)

2.4 文本编码

包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B771024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048

2.5 位置编码

这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 409688)

最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img

三 DIT模型

3.1 add noise过程

  • 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
    在这里插入图片描述

3.2 HunYuanDiT模型训练过程

  • 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B771024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B2561024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
  • 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B1408)
  • 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B4096140840966464
  • 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B1024
  • 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B1408,作为所有extra_vectors)

3.2.1 进入Dit Block

一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip

3.2.1.1 前面0到18共19个block
  • 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),c( B ∗ 1408 B*1408 B1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B409688
HunYuanDiTBlock(
  (norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
  (attn1): FlashSelfMHAModified(
    (Wqkv): Linear(in_features=1408, out_features=4224, bias=True)
    (q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (inner_attn): FlashSelfAttention(
      (drop): Dropout(p=0.0, inplace=False)
    )
    (out_proj): Linear(in_features=1408, out_features=1408, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=1408, out_features=6144, bias=True)
    (act): GELU(approximate='tanh')
    (drop1): Dropout(p=0, inplace=False)
    (norm): Identity()
    (fc2): Linear(in_features=6144, out_features=1408, bias=True)
    (drop2): Dropout(p=0, inplace=False)
  )
  (default_modulation): Sequential(
    (0): FP32_SiLU()
    (1): Linear(in_features=1408, out_features=1408, bias=True)
  )
  (attn2): FlashCrossMHAModified(
    (q_proj): Linear(in_features=1408, out_features=1408, bias=True)
    (kv_proj): Linear(in_features=1024, out_features=2816, bias=True)
    (q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (inner_attn): FlashCrossAttention(
      (drop): Dropout(p=0.0, inplace=False)
    )
    (out_proj): Linear(in_features=1408, out_features=1408, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
  • 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
  • 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
  • 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
  • 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
  • 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
  (skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)
  (skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
  • 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样

3.2.2 最后FInal layer处理

  • 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B2816),并拆分成shift 和 scale(分别为 B ∗ 1408 B*1408 B1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B409632),然后通过转换self.unpatchify得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B8128128,这里包括输出的噪声,以及预测的var,这里diffusion var type是 LEARNED_RANGE,分别对应的shape B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128

3.3 gussian_diffusion过程

该类是diffusion的主类,包括q_sample和loss计算,具体可以参考https://blog.csdn.net/qq_19841133/article/details/126948474

3.3.1 计算变分下界

def _vb_terms_bpd(
            self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert_shape(decoder_nll, x_start)
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"], "extra": out["extra"]}
  • 计算后验均值方差 self.q_posterior_mean_variance,也就是根据训练输入的x_start 和 x_t,计算得到true_mean,和true_log_variance_clipped
    在这里插入图片描述
    在这里插入图片描述

  • 计算预测均值方差 self.p_mean_variance,根据3.2.2模型输出的结果,包括预测的噪声和var;通过预测的噪声,生成pred_xstart;在通过pred_xstart 和 x_t,通过self.q_posterior_mean_variance,得到预测的均值
    在这里插入图片描述

  • 计算kl散度,根据true_mean,true_log_variance_clipped,pred_mean, pred_variance
    在这里插入图片描述

  • 对于time=0时候,计算离散化高斯对数似然,discretized_gaussian_log_likelihood

  • 最终对于模型预测的噪声和实际添加的noise做mse loss再加上上面计算的kl loss,得到最终loss

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值