def apply_rotary_pos_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor | None = None,
) -> torch.Tensor:
    """Applies rotary positional embedding to the input tensor.

    Args:
        x (torch.Tensor): input tensor, with shape: [..., s, hd]
        cos (torch.Tensor): cos basis tensor, with shape: [max_seqlen, hd]
        sin (torch.Tensor): sin basis tensor, with shape: [max_seqlen, hd]
        position_ids (torch.Tensor | None): position id tensor, with shape: [s].
            When None, the first ``s`` rows of ``cos`` / ``sin`` are used.
            Default: None.

    Returns:
        torch.Tensor: embedded output tensor, cast back to ``x.dtype``,
            with shape: [..., s, hd]
    """
    if position_ids is not None:  # with magi_attention
        # pick the basis rows at the given position ids
        cos = cos[position_ids]
        sin = sin[position_ids]
    else:  # without magi_attention
        # positions are simply 0..s-1, so a prefix slice is enough
        seqlen = x.shape[-2]
        cos = cos[:seqlen]
        sin = sin[:seqlen]

    # cos / sin are [s, hd] here; broadcasting aligns them with the trailing
    # two dims of x, so no explicit unsqueeze over the leading dims is needed
    output = (x * cos) + (rotate_half(x) * sin)

    # the product may have been promoted to the dtype of cos / sin
    return output.to(x.dtype)
class LlamaRotaryEmbedding(nn.Module):
    """RoPE module that precomputes the cos / sin basis tables once.

    The tables cover positions ``0..max_seqlen-1`` and are registered
    as non-persistent buffers ``cos_emb`` and ``sin_emb``.
    """

    def __init__(self, config: LlamaConfig):
        """Builds the cos / sin basis tables from the config.

        Args:
            config (LlamaConfig): model config; the attributes read here are
                ``head_dim``, ``rope_theta``, ``max_seqlen``
                and ``hp_params_dtype``.
        """
        super().__init__()

        self.head_dim = config.head_dim
        self.rope_theta = config.rope_theta
        self.max_seqlen = config.max_seqlen
        self.emb_dtype = config.hp_params_dtype

        # inv_freq[i] = 1 / rope_theta ** (2i / head_dim)
        inv_freq = 1.0 / (
            self.rope_theta
            ** (
                torch.arange(0, self.head_dim, 2, dtype=torch.int64).to(
                    dtype=self.emb_dtype
                )
                / self.head_dim
            )
        )

        # angle table: one row per position, one column per frequency
        t = torch.arange(self.max_seqlen, dtype=self.emb_dtype)
        freqs = torch.outer(t, inv_freq)

        # duplicate the frequencies along the last dim to span the head dim
        emb = torch.cat((freqs, freqs), dim=-1)
        cos_emb, sin_emb = emb.cos(), emb.sin()  # shape: (max_seqlen, hd)

        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(
        self, x: torch.Tensor, dist_attn_runtime_key: DistAttnRuntimeKey | None = None
    ) -> torch.Tensor:
        """Forward pass of the RoPE Module.

        Args:
            x (torch.Tensor): Input tensor, with shape: [..., s, hd].
            dist_attn_runtime_key (DistAttnRuntimeKey, optional): runtime key
                used to look up the position ids. Default None.

        Returns:
            torch.Tensor: Output tensor, with shape: [..., s, hd].
        """
        if dist_attn_runtime_key is not None:  # with magi_attention
            # get the position_ids after dispatching
            position_ids = get_position_ids(dist_attn_runtime_key)
        else:  # without magi_attention
            position_ids = None

        return apply_rotary_pos_emb(
            x,
            cos=self.cos_emb,
            sin=self.sin_emb,
            position_ids=position_ids,
        )
# Full code: https://github.com/SandAI-org/MagiAttention/blob/main/example/torch_native/modeling_llama.py