data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GLOBAL_RANKS
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
all_data_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
    start_rank = i * num_pipeline_model_parallel_groups
    end_rank = (i + 1) * num_pipeline_model_parallel_groups
    for j in range(tensor_model_parallel_size):
        ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
        all_data_parallel_group_ranks.append(list(ranks))
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GLOBAL_RANKS = ranks
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for i in range(data_parallel_size):
    ranks = [data_parallel_group_ranks[i]
             for data_parallel_group_ranks in all_data_parallel_group_ranks]
    group = torch.distributed.new_group(ranks)
    if rank in ranks:
        _MODEL_PARALLEL_GROUP = group
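# Illustration (not part of the original code above): a minimal, pure-Python
# sketch of the grouping just built, assuming world_size = 16,
# tensor_model_parallel_size = 2 and pipeline_model_parallel_size = 4.
# The ex_* names and the 16-GPU config are made up for this sketch only.
# Each data-parallel group holds the ranks that own the same model shard;
# each model-parallel group takes the i-th rank of every data-parallel group
# and therefore holds one complete copy of the model.
ex_world_size, ex_tp, ex_pp = 16, 2, 4
ex_num_pp_groups = ex_world_size // ex_pp          # 4
ex_dp_size = ex_world_size // (ex_tp * ex_pp)      # 2

ex_dp_groups = []
for stage in range(ex_pp):
    start, end = stage * ex_num_pp_groups, (stage + 1) * ex_num_pp_groups
    for j in range(ex_tp):
        ex_dp_groups.append(list(range(start + j, end, ex_tp)))
# ex_dp_groups == [[0, 2], [1, 3], [4, 6], [5, 7],
#                  [8, 10], [9, 11], [12, 14], [13, 15]]

ex_mp_groups = [[dp[i] for dp in ex_dp_groups] for i in range(ex_dp_size)]
# ex_mp_groups == [[0, 1, 4, 5, 8, 9, 12, 13],
#                  [2, 3, 6, 7, 10, 11, 14, 15]]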
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
    'tensor model parallel group is already initialized'
for i in range(num_tensor_model_parallel_groups):
    ranks = range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if rank in ranks:
        _TENSOR_MODEL_PARALLEL_GROUP = group
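# Illustration (not part of the original code above): with the same assumed
# config (world_size = 16, tensor_model_parallel_size = 2; ex_* names are
# made up for this sketch), the tensor model-parallel groups are consecutive
# blocks of 2 ranks, so the frequent intra-layer collectives run between
# neighbouring devices (typically within one node).
ex_world_size, ex_tp = 16, 2
ex_tp_groups = [list(range(i * ex_tp, (i + 1) * ex_tp))
                for i in range(ex_world_size // ex_tp)]
# ex_tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7],
#                  [8, 9], [10, 11], [12, 13], [14, 15]]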
# Build the pipeline model-parallel groups and embedding groups
# (analogous in structure to "Build the model-parallel groups" above)
# (first and last rank in each pipeline model-parallel group).
for i in range(num_pipeline_model_parallel_groups):
    ranks = range(i, world_size, num_pipeline_model_parallel_groups)
    group = torch.distributed.new_group(ranks)
    if rank in ranks:
        _PIPELINE_MODEL_PARALLEL_GROUP = group
        _PIPELINE_GLOBAL_RANKS = ranks
    # Setup embedding group (to exchange gradients between
    # first and last stages).
    if len(ranks) > 1:
        embedding_ranks = [ranks[0], ranks[-1]]
        position_embedding_ranks = [ranks[0]]
        if pipeline_model_parallel_split_rank is not None:
            if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                embedding_ranks = [ranks[0],
                                   ranks[pipeline_model_parallel_split_rank],
                                   ranks[-1]]
            if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                position_embedding_ranks = [ranks[0],
                                            ranks[pipeline_model_parallel_split_rank]]
    else:
        embedding_ranks = ranks
        position_embedding_ranks = ranks

    group = torch.distributed.new_group(embedding_ranks)
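# Illustration (not part of the original code above): with the same assumed
# config (world_size = 16, pipeline_model_parallel_size = 4; ex_* names are
# made up for this sketch), each pipeline group strides through the ranks
# with step num_pipeline_model_parallel_groups. With no
# pipeline_model_parallel_split_rank, the embedding group is just the first
# and last stage of each pipeline, i.e. the ranks that must exchange
# gradients for the tied input/output embeddings.
ex_world_size, ex_pp = 16, 4
ex_num_pp_groups = ex_world_size // ex_pp  # 4
ex_pipeline_groups = [list(range(i, ex_world_size, ex_num_pp_groups))
                      for i in range(ex_num_pp_groups)]
# ex_pipeline_groups == [[0, 4, 8, 12], [1, 5, 9, 13],
#                        [2, 6, 10, 14], [3, 7, 11, 15]]
ex_embedding_groups = [[g[0], g[-1]] for g in ex_pipeline_groups]
# ex_embedding_groups == [[0, 12], [1, 13], [2, 14], [3, 15]]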