RoPE源码实现

def apply_rotary_pos_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply rotary positional embedding (RoPE) to ``x``.

    Computes ``x * cos + rotate_half(x) * sin``, where the cos/sin rows are
    selected per position along the sequence dimension (``x.shape[-2]``).

    Args:
        x (torch.Tensor): input tensor, with shape ``[..., s, hd]``.
        cos (torch.Tensor): cos basis table, with shape ``[max_seqlen, hd]``.
        sin (torch.Tensor): sin basis table, with shape ``[max_seqlen, hd]``.
        position_ids (torch.Tensor | None): position of each of the ``s``
            tokens, with shape ``[s]``; used on the magi_attention path. If
            ``None`` (without magi_attention), positions ``0 .. s-1`` are used.

    Returns:
        torch.Tensor: embedded output with the same shape (``[..., s, hd]``)
        and dtype as ``x``.
    """
    if position_ids is not None:
        # With magi_attention: gather the table rows for the given positions.
        cos = cos[position_ids]
        sin = sin[position_ids]
    else:
        # Without magi_attention: the sequence occupies positions 0 .. s-1.
        seqlen = x.shape[-2]
        cos = cos[:seqlen]
        sin = sin[:seqlen]

    # cos/sin are [s, hd]. Broadcasting aligns trailing dims, so they apply to
    # every leading batch/head dim of x without an explicit unsqueeze.
    # The tables may be stored in a higher-precision dtype than x. The product
    # is then promoted, so cast the result back to x's dtype.
    output = x * cos + rotate_half(x) * sin
    return output.to(x.dtype)
class LlamaRotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) module with precomputed cos/sin tables.

    The tables cover positions ``0 .. config.max_seqlen - 1`` and are stored
    in ``config.hp_params_dtype``. They are registered as non-persistent
    buffers, so they are not saved in the state dict.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.head_dim = config.head_dim
        self.rope_theta = config.rope_theta
        self.max_seqlen = config.max_seqlen
        self.emb_dtype = config.hp_params_dtype

        # Build the tables in at least float32. Half-precision dtypes cannot
        # represent large integer positions exactly (bf16 rounds integers
        # above 256, fp16 above 2048), which would give neighbouring
        # positions identical rotations. The results are cast to emb_dtype
        # below, so float32/float64 configs produce exactly the same tables.
        compute_dtype = torch.promote_types(self.emb_dtype, torch.float32)

        # inv_freq[i] = rope_theta ** (-2i / head_dim): one frequency per pair
        # of channels.
        exponents = (
            torch.arange(0, self.head_dim, 2, dtype=torch.int64).to(compute_dtype)
            / self.head_dim
        )
        inv_freq = 1.0 / (self.rope_theta**exponents)

        positions = torch.arange(self.max_seqlen, dtype=torch.int64).to(compute_dtype)
        freqs = torch.outer(positions, inv_freq)  # (max_seqlen, ceil(head_dim / 2))
        # Duplicate the frequencies along the last dim; for an even head_dim this
        # gives (max_seqlen, head_dim). Presumably paired with a rotate_half that
        # splits the last dim into two halves — confirm against rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos_emb = emb.cos().to(self.emb_dtype)
        sin_emb = emb.sin().to(self.emb_dtype)

        # Non-persistent: derived from config, so not stored in checkpoints.
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(
        self, x: torch.Tensor, dist_attn_runtime_key: DistAttnRuntimeKey | None = None
    ) -> torch.Tensor:
        """Forward pass of the RoPE module.

        Args:
            x (torch.Tensor): input tensor, with shape ``[..., s, hd]``.
            dist_attn_runtime_key (DistAttnRuntimeKey | None, optional): runtime
                key used with magi_attention to look up the position ids of the
                tokens after dispatching. Default ``None`` (without
                magi_attention; positions ``0 .. s-1`` are used).

        Returns:
            torch.Tensor: output tensor, with shape ``[..., s, hd]``.
        """
        # Explicit None check: the key's own truthiness is not meaningful here.
        position_ids = (
            get_position_ids(dist_attn_runtime_key)
            if dist_attn_runtime_key is not None
            else None
        )

        return apply_rotary_pos_emb(
            x,
            cos=self.cos_emb,
            sin=self.sin_emb,
            position_ids=position_ids,
        )

完整代码：https://github.com/SandAI-org/MagiAttention/blob/main/example/torch_native/modeling_llama.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值