# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    # NOTE: pull out the token positions the current expert handles (top_x)
    # and the top-k slot that routed each of them here (idx)
    idx, top_x = torch.where(expert_mask[expert_idx])
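    # Assuming `expert_mask` was built upstream as in the Mixtral reference,
    # i.e. one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0),
    # expert_mask[expert_idx] has shape (top_k, batch * sequence_length), so
    # `idx` and `top_x` both have length num_selected_tokens for this expert.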
    if top_x.shape[0] == 0:
        continue
    # in torch it is faster to index using lists than torch tensors
    top_x_list = top_x.tolist()
    idx_list = idx.tolist()
    # Index the correct hidden states and compute the expert hidden state for
    # the current expert. We need to make sure to multiply the output hidden
    # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
    # NOTE: hidden_states: (batch * sequence_length, hidden_dim)
    # NOTE: hidden_states[None, top_x_list] -> (1, num_selected_tokens, hidden_dim)
    current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
    # NOTE: routing_weights.shape = (batch * sequence_length, top_k)
    # run the selected tokens through the expert MLP, then weight the output
    # by each token's routing weight for this expert
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
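    # routing_weights[top_x_list, idx_list, None] above has shape
    # (num_selected_tokens, 1), which broadcasts against the expert output of
    # shape (num_selected_tokens, hidden_dim)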
    # However `index_add_` only supports torch tensors for indexing, so we use
    # the `top_x` tensor here.
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
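To see the whole dispatch/combine pattern end to end, here is a minimal, self-contained sketch. The routing setup (softmax over router logits, `topk`, `one_hot(...).permute(2, 1, 0)`) follows the Mixtral reference pattern; the toy dimensions and the plain `Linear` stand-ins for the expert MLPs are placeholder assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

# Toy dimensions (hypothetical, for illustration only)
batch_size, sequence_length, hidden_dim = 2, 3, 4
num_experts, top_k = 4, 2
num_tokens = batch_size * sequence_length

torch.manual_seed(0)
hidden_states = torch.randn(num_tokens, hidden_dim)

# Stand-in "experts": plain linear layers instead of the real expert MLPs
experts = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]

# Router: softmax over logits, keep the top-k experts per token, renormalize
router_logits = torch.randn(num_tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=-1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

# one_hot -> (num_tokens, top_k, num_experts)
# permute -> (num_experts, top_k, num_tokens)
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

final_hidden_states = torch.zeros(num_tokens, hidden_dim)
for expert_idx in range(num_experts):
    # idx: top-k slot, top_x: flat token position, for this expert's tokens
    idx, top_x = torch.where(expert_mask[expert_idx])
    if top_x.shape[0] == 0:
        continue
    current_state = hidden_states[None, top_x.tolist()].reshape(-1, hidden_dim)
    # (num_selected, hidden_dim) * (num_selected, 1) -> weighted expert output
    current_hidden_states = experts[expert_idx](current_state) \
        * routing_weights[top_x.tolist(), idx.tolist(), None]
    # scatter-add each token's weighted output back to its original position
    final_hidden_states.index_add_(0, top_x, current_hidden_states)

final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
print(final_hidden_states.shape)  # torch.Size([2, 3, 4])
```

The `one_hot(...).permute(2, 1, 0)` trick is what makes the loop cheap: each expert recovers exactly its assigned tokens with a single `torch.where`, and `index_add_` combines the overlapping top-k contributions without ever materializing a per-expert padded batch.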