Cheat Sheet
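MLP: down(act(gate(x)) * up(x)) # the gate and up paths are multiplied elementwise, then passed through the final linear layer
Each token selects its top-k experts and the matching weights
[B, L, D] -> [B, L, topk, D]
Each expert ends up processing some number x of tokens
sorted_tokens: group the tokens that belong to the same expert together
for each expert: select its tokens, process them, and append the outputs in order -> [total_num_tokens, D]
Per expert: [expert_token_num, D] -> [expert_token_num, D]
Restore the original order by index: new_x = [B*L, topk, D], the outputs of each token's top-k experts
Weighted sum of the experts' outputs: new_x * topk_weights = [B*L, topk, D], sum(dim=1) = [B*L, D]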
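The cheat-sheet steps can be sketched as follows. This is a minimal illustration, not library code: the function name `moe_sorted_dispatch`, the use of `nn.Linear` layers as stand-in experts, and the random router scores are all assumptions made for the example.

```python
import torch

def moe_sorted_dispatch(x, topk_idx, topk_weights, experts):
    """x: [N, D]; topk_idx and topk_weights: [N, K]; experts: callables [n, D] -> [n, D]."""
    n_tokens, k = topk_idx.shape
    flat_expert = topk_idx.reshape(-1)             # [N*K], expert id of each (token, slot) entry
    order = flat_expert.argsort()                  # puts entries of the same expert next to each other
    token_of_entry = order // k                    # token that each sorted entry came from
    sorted_tokens = x[token_of_entry]              # [N*K, D]
    counts = torch.bincount(flat_expert, minlength=len(experts)).tolist()

    outputs, start = [], 0
    for expert, n in zip(experts, counts):
        if n == 0:
            continue                               # this expert received no tokens
        outputs.append(expert(sorted_tokens[start:start + n]))  # [n, D] -> [n, D]
        start += n
    sorted_out = torch.cat(outputs, dim=0)         # [N*K, D]

    new_x = torch.empty_like(sorted_out)
    new_x[order] = sorted_out                      # undo the sort
    new_x = new_x.view(n_tokens, k, -1)            # [N, K, D]
    return (new_x * topk_weights.unsqueeze(-1)).sum(dim=1)  # [N, D]

torch.manual_seed(0)
B, L, D, E, K = 2, 3, 4, 4, 2
x = torch.randn(B, L, D)
experts = [torch.nn.Linear(D, D) for _ in range(E)]  # stand-in experts
scores = torch.randn(B * L, E).softmax(dim=-1)       # stand-in router output
topk_weights, topk_idx = torch.topk(scores, K, dim=-1)
with torch.no_grad():
    y = moe_sorted_dispatch(x.view(-1, D), topk_idx, topk_weights, experts).view(B, L, D)
print(y.shape)
```

Sorting once and slicing contiguous ranges is one way to batch the tokens per expert; the Mixtral code discussed below reaches the same result with a one-hot mask and `index_add_` instead.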
Mixtral MoE Source Code Notes
transformers/src/transformers/models/mixtral/modeling_mixtral.py
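Note that this is mixtral, not mistral.
The model is essentially the same as Llama; the main difference is in the MLP. A mixture-of-experts layer holds num_experts MLPs, whereas Llama has only one. The core code is MixtralSparseMoeBlock.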
```python
class MixtralDecoderLayer(nn.Module):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        # The sparse MoE block takes the place of Llama's single MLP
        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Pre-norm residual branch around self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Pre-norm residual branch around the MoE block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        hidden_states = residual + hidden_states
        # (excerpt: assembling the returned tuple is omitted)
```
```python
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        # A single dense MLP
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Pre-norm residual branch around self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Pre-norm residual branch around the MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        # (excerpt: assembling the returned tuple is omitted)
```
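To make the "one MLP versus num_experts MLPs" contrast concrete, here is a small sketch of the gated MLP from the cheat sheet, `down(act(gate(x)) * up(x))`. The class name `GatedMLP`, the attribute names, the SiLU activation, and the sizes are assumptions chosen for illustration, not the library's definitions.

```python
import torch
from torch import nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """Hypothetical stand-in for one expert: down(act(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The two paths are combined by an elementwise product, not a matmul
        return self.down(F.silu(self.gate(x)) * self.up(x))

dense_mlp = GatedMLP(8, 16)                                   # dense case: one MLP
experts = nn.ModuleList([GatedMLP(8, 16) for _ in range(4)])  # MoE case: several MLPs

x = torch.randn(2, 3, 8)
out = dense_mlp(x)
print(out.shape)

n_dense = sum(p.numel() for p in dense_mlp.parameters())
n_moe = sum(p.numel() for p in experts.parameters())
print(n_dense, n_moe)
```

The parameter count of the expert list grows linearly with the number of experts, while each token only passes through top-k of them.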
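MixtralSparseMoeBlock
MixtralSparseMoeBlock takes hidden_states, the output of the attention sublayer after the post-attention norm, and uses it to select the top-k experts (MLPs) for each token.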
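```python
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # After topk, routing_weights and selected_experts: (batch * sequence_length, top_k)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # expert_mask: (n_experts, top_k, batch * sequence_length)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        # idx: top-k slot, top_x: token position
        idx, top_x = torch.where(expert_mask[expert_idx])

        if top_x.shape[0] == 0:
            continue

        top_x_list = top_x.tolist()
        idx_list = idx.tolist()

        # Gather this expert's tokens, run the expert, and scale by the routing weight
        current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

        # Add the expert's outputs back into the rows of the tokens they came from
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits
```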
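Because batch * seqlen can be large, for computational efficiency the one-hot encoding of selected_experts is permuted with permute(2, 1, 0), so its shape becomes (num_experts, top_k, batch * sequence_length). With the expert dimension first, expert_mask[expert_idx] directly gives the (top_k, batch * sequence_length) mask for a single expert.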
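The shapes involved can be checked with a toy example. The sizes and the random router logits below are made up for illustration; only the tensor operations mirror the ones in the source excerpt.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, top_k = 5, 4, 2   # num_tokens plays the role of batch * seqlen
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

one_hot = F.one_hot(selected_experts, num_classes=num_experts)  # (num_tokens, top_k, num_experts)
expert_mask = one_hot.permute(2, 1, 0)                          # (num_experts, top_k, num_tokens)
print(one_hot.shape, expert_mask.shape)

expert_idx = 0
# Row indices of the 2-D mask are top-k slots, column indices are token positions
idx, top_x = torch.where(expert_mask[expert_idx])
print("slots :", idx.tolist())
print("tokens:", top_x.tolist())
```

For every returned pair, `selected_experts[top_x[i], idx[i]]` equals `expert_idx`, which is why the same pair can later be used to index `routing_weights`.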
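After that, the processing flow is:
Iterate over every expert
Find the positions along bs * seqlen that the current expert is responsible for: idx, top_x = torch.where(expert_mask[expert_idx])
idx is the top-k slot (0 to top_k - 1) in which the current expert was selected for each of those tokens; it is not the expert's own index
top_x holds the positions along bs * seqlen that are routed to the current expert
Select the data the current expert is responsible for (from bs * seqlen), process it, and weight the result by the routing weight of the selected expert
Select the data: current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
hidden_states[None, top_x_list] has shape (1, len(top_x_list), hidden_dim); the reshape gives current_state with shape (len(top_x_list), hidden_dim)
Run the expert: expert_layer(current_state)
Apply the top-k routing weight: current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
After torch.topk, routing_weights.shape = (batch * sequence_length, top_k), so it is indexed by (token, slot)
Keep accumulating the partial results until all experts have been visited
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states.shape = (bs * seqlen, hidden_dim)
The outputs computed by the current expert are added into the rows of the final result at the matching bs * seqlen positions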
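The steps above can be put together in a compact, standalone sketch. It is a simplified illustration under stated assumptions: the function name `sparse_moe_forward` is hypothetical, the experts and the gate are plain `nn.Linear` layers, and the raw top-k softmax weights are used as they are.

```python
import torch
import torch.nn.functional as F

def sparse_moe_forward(hidden_states, gate, experts, top_k):
    batch_size, seq_len, hidden_dim = hidden_states.shape
    flat = hidden_states.reshape(-1, hidden_dim)              # (bs * seqlen, hidden_dim)
    router_logits = gate(flat)                                # (bs * seqlen, n_experts)
    weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    weights, selected = torch.topk(weights, top_k, dim=-1)    # both (bs * seqlen, top_k)
    weights = weights.to(flat.dtype)

    out = torch.zeros_like(flat)
    mask = F.one_hot(selected, num_classes=len(experts)).permute(2, 1, 0)
    for expert_idx, expert in enumerate(experts):
        idx, top_x = torch.where(mask[expert_idx])            # slot, token position
        if top_x.numel() == 0:
            continue                                          # no token chose this expert
        current = flat[top_x]                                 # (n_selected, hidden_dim)
        weighted = expert(current) * weights[top_x, idx, None]
        out.index_add_(0, top_x, weighted)                    # accumulate into the token rows
    return out.reshape(batch_size, seq_len, hidden_dim), router_logits

torch.manual_seed(0)
experts = [torch.nn.Linear(8, 8) for _ in range(4)]   # stand-in experts
gate = torch.nn.Linear(8, 4, bias=False)              # stand-in router
x = torch.randn(2, 3, 8)
with torch.no_grad():
    y, logits = sparse_moe_forward(x, gate, experts, top_k=2)
print(y.shape, logits.shape)
```

Each token appears in exactly top_k of the per-expert batches, so after the loop every row of `out` holds the sum of top_k weighted expert outputs.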
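Key points:
A weighted sum over several MLP layers
Computational efficiency: parallelism along the bs * seqlen dimension
Iterate over all experts
Each expert processes the bs * seqlen rows it is responsible for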
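The "weighted sum over several MLPs" can also be written as a formula for a single token. The notation is introduced here for illustration and does not come from the source code: $x$ is the token's hidden state, $W_g$ the gate weight, $E_i$ the $i$-th expert MLP, and $k$ the value of top_k.

```latex
p = \operatorname{softmax}(W_g x), \qquad
\mathcal{T} = \operatorname{TopK}(p, k), \qquad
y = \sum_{i \in \mathcal{T}} w_i \, E_i(x)
```

Here $w_i$ is the routing weight of expert $i$: the softmax probability $p_i$, optionally renormalized so that the selected weights sum to 1. The loop over experts computes the same sum, only grouped by expert rather than by token.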