import torch
import torch.nn as nn


class ParallelAttention_router(nn.Module):
    """Attention-based router: scores every token against every expert."""

    def __init__(self, config):
        super(ParallelAttention_router, self).__init__()
        layer_number = 0
        self.layer_number = max(1, layer_number)
        self.flash_attn_drop = 0.01
        self.hidden_size = config.hidden_size
        # The projection size equals the number of experts, so each token
        # produces one query/key/value scalar per expert.
        self.projection_size = config.moe_config['moe_num_experts']
        self.query = nn.Linear(self.hidden_size, self.projection_size, bias=False)
        self.key = nn.Linear(self.hidden_size, self.projection_size, bias=False)
        self.value = nn.Linear(self.hidden_size, self.projection_size, bias=False)
    def forward(self, hidden_states, attention_mask=None, enc_position_ids=None,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # Kept from the original source; not used on the router path.
        is_first_step = False
        before_hidden_states = None

        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)

        b = query_layer.size(0)  # batch size
        s = query_layer.size(1)  # number of tokens (seq * batch when tokens are pre-flattened)
        z = query_layer.size(2)  # number of experts

        # Run the router in fp32 for numerical stability.
        query_layer = query_layer.float().view(b, s, z, 1)
        key_layer = key_layer.float().view(b, s, z, 1)
        value_layer = value_layer.float().view(b, s, z, 1)

        # Per-token attention over the expert dimension:
        # (b, s, z, 1) @ (b, s, 1, z) -> (b, s, z, z)
        attn_weights = torch.matmul(query_layer, key_layer.transpose(2, 3))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        # (b, s, z, z) @ (b, s, z, 1) -> (b, s, z, 1)
        attn_output = torch.matmul(attn_weights, value_layer)

        # One routing score per token and per expert.
        router_output = attn_output.view(b * s, z)
        return router_output
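A minimal usage sketch of the router follows. The config values and tensor sizes below are illustrative assumptions (DummyConfig is a stand-in, not the model's real config class); only the expert count of 32 reflects Yuan2.0-M32. It shows that the router maps [batch, seq_len, hidden_size] activations to one score per token and per expert.

# Illustrative only: DummyConfig and the sizes are assumptions for this sketch.
import torch

class DummyConfig:
    hidden_size = 2048
    moe_config = {'moe_num_experts': 32}

router = ParallelAttention_router(DummyConfig())
hidden_states = torch.randn(4, 16, 2048)   # [batch, seq_len, hidden_size]
router_logits = router(hidden_states)      # [batch*seq_len, num_experts]
print(router_logits.shape)                 # torch.Size([64, 32])

In the surrounding MoE layer, a top-k selection over these per-expert scores (k=2 for Yuan2.0-M32) would typically follow; that step is outside this excerpt.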
class YuanExpertMLP(nn.Module):
    """A single expert: a feed-forward block, optionally using a gated linear unit."""

    def __init__(self, config):
        super(YuanExpertMLP, self).__init__()
        self.gated_linear_unit = config.moe_config['gated_linear_unit']
        self.ffn_hidden_size = config.moe_config['ffn_hidden_size']
        if self.gated_linear_unit:
            # Doubled output: one half acts as the gate, the other as the value.
            self.w1 = nn.Linear(config.hidden_size, self.ffn_hidden_size * 2, bias=False)
        else:
            self.w1 = nn.Linear(config.hidden_size, self.ffn_hidden_size, bias=False)
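The reason w1 doubles its output in the gated case is that a gated linear unit splits the projection into a gate half and a value half. The expert's forward pass is not part of this excerpt, so the following is only a hedged sketch of how such a gated forward is commonly written; the down-projection w2 and the SiLU activation are assumptions for illustration, not confirmed by the code above.

# Sketch only: assumes w2 projects ffn_hidden_size back to hidden_size and
# that SiLU is the activation; both are assumptions, not from the excerpt.
import torch
import torch.nn as nn

def gated_mlp_forward(x, w1, w2, act_fn=nn.functional.silu):
    # w1: hidden_size -> 2*ffn_hidden_size, w2: ffn_hidden_size -> hidden_size
    gate, value = torch.chunk(w1(x), 2, dim=-1)  # split the doubled projection
    return w2(act_fn(gate) * value)              # gated linear unit (GLU)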
Yuan2.0-M32 Large Model: Hybrid Attention Experts - Source Code Analysis