pydata: Huiming's learning notes

Keep Looking, Don't Settle

DeepSeek Expert Parallelism Load Balancer (EPLB) Code Reading

Introduction

In the previous Introduction to DeepSeek-V3, a crucial component highlighted was the use of DeepSeekMoE. When employing Expert Parallelism, different Experts are assigned to different GPUs. Since the load on different Experts varies with the current workload, maintaining load balance across GPUs is critical. DeepSeekMoE addresses this by replicating high-load Experts and then packing the replicated Experts onto GPUs so that the load across GPUs and nodes is as balanced as possible. Additionally, as DeepSeek-V3 uses node-limited routing, Experts from the same group are deployed on the same node whenever possible to minimize inter-node data traffic.

DeepSeek has open-sourced its Expert Parallelism load-balancing algorithm for Expert replication and deployment in eplb.py. This article provides a detailed study of this open-source code.

Algorithm

The load-balancing algorithm includes two strategies tailored to different scenarios.

Hierarchical Load Balancing

When the number of server nodes evenly divides the number of Expert groups (i.e., num_groups % num_nodes == 0), a hierarchical load-balancing strategy is used to take advantage of group-limited expert routing. First, Expert groups are evenly packed onto nodes to ensure load balance across nodes. Then, Experts are replicated within each node. Finally, the replicated Experts are packed onto GPUs to ensure load balance across GPUs. This strategy is suitable for the prefill stage, where the Expert-parallel scale is smaller.

Global Load Balancing

In other cases, a global load-balancing strategy is used, where Experts are replicated globally regardless of their group and packed into GPUs. This strategy is suitable for the decoding stage, where the Expert parallelism scale is larger.
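
The choice between the two strategies comes down to a divisibility check, mirroring the condition in rebalance_experts shown later in the code. A minimal sketch (choose_policy is a hypothetical helper, not part of eplb.py):

def choose_policy(num_groups: int, num_nodes: int) -> str:
    # Hierarchical balancing requires that the expert groups can be spread
    # evenly over the nodes; otherwise fall back to global balancing.
    if num_groups % num_nodes == 0:
        return "hierarchical"   # typical for prefill, smaller expert-parallel scale
    return "global"             # typical for decoding, larger expert-parallel scale

print(choose_policy(num_groups=4, num_nodes=2))  # hierarchical
print(choose_policy(num_groups=8, num_nodes=3))  # global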

Terminology

  • Logical Expert: an Expert as designed in the model; the count is fixed by the model architecture. In the example below, each layer has 12 logical Experts.
  • Physical Expert: an Expert instantiated on a GPU for actual computation. Logical Experts with heavy loads are replicated, so the number of physical Experts is greater than or equal to the number of logical Experts. In the example below, each layer has 16 physical Experts involved in computation.
  • Expert Group: A logical grouping of multiple Experts in the model. The purpose of Expert groups is to facilitate coarse-grained load management when allocating Experts to nodes. Load balancing is first achieved at the group level, then within each group.
  • num_groups: The number of Expert groups. Each node has the same number of groups, and each group has the same number of Experts. The goal is to place logical Experts into groups to balance tokens per group as much as possible. group_size = num_logical_experts // num_groups.
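
To make these terms concrete, here is how the sizes relate to each other using the numbers from the Example section below (a small illustrative calculation, not from the original code):

num_logical_experts = 12          # designed into the model, per layer
num_groups = 4                    # expert groups per layer
group_size = num_logical_experts // num_groups           # 3 logical Experts per group
num_physical_experts = 16         # 12 logical Experts + 4 replicas per layer
num_nodes, num_gpus = 2, 8
phy_experts_per_gpu = num_physical_experts // num_gpus   # 2 physical Experts per GPU
print(group_size, phy_experts_per_gpu)                   # 3 2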

Function Explanations

  1. balanced_packing: This function allocates weighted items (Expert groups or physical Experts) to a given number of "packs" (nodes or GPUs, depending on the call site), ensuring each pack receives exactly the same number of items with a total load that is as balanced as possible. Grouping Experts enables coarse-grained load balancing at the group level before finer allocation within groups.
  2. replicate_experts: This function dynamically replicates logical Experts with high loads, increasing parallel processing capacity and distributing high loads across multiple physical devices to achieve load balancing.
  3. rebalance_experts_hierarchical: The num_groups parameter represents the number of Expert groups. The code first divides logical Experts into groups (weight.unflatten(-1, (num_groups, group_size))) and performs preliminary load balancing by allocating groups to nodes using balanced_packing(tokens_per_group, num_nodes). Within each node, replicate_experts creates replicas for the logical Experts assigned to that node, utilizing available physical Expert slots. This step addresses local load hotspots by replicating overloaded Experts. After replication, the effective load per physical Expert slot (original load divided by the number of replicas) is calculated, and balanced_packing is used again to allocate these physical Expert slots to GPUs within the node, aiming to balance the effective load across GPUs in the same node.
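
Putting the three functions together, the shapes flow through rebalance_experts_hierarchical as follows for the example below (a descriptive sketch of the call chain, assuming 2 layers, 12 logical Experts, 4 groups, 2 nodes, 8 GPUs, and 16 physical Experts):

# weight                       [2, 12]   load of each logical expert per layer
#   tokens_per_group           [2, 4]    sum over each group of 3 experts
#   balanced_packing(tokens_per_group, num_nodes=2)             -> groups to nodes
#   tokens_per_mlog            [4, 6]    one row per (layer, node), 6 logical experts each
#   replicate_experts(tokens_per_mlog, 16 // 2 = 8)             -> 6 logical -> 8 physical slots per node
#   tokens_per_phy             [4, 8]    effective load per physical slot
#   balanced_packing(tokens_per_phy, num_gpus // num_nodes = 4) -> slots to GPUs (2 per GPU)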

Example

Consider a load matrix: [[90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27]]. There are two nodes, each with 4 GPUs. How should these loads be allocated to balance the GPU loads as much as possible? The code replicates 4 Experts per layer, resulting in 16 Experts per layer, placed across 2 nodes, each with 4 GPUs.

  • Model: The model has two layers, each with 12 Experts, each with different token loads.
  • Hardware: There are two nodes, each with 4 GPUs, totaling 8 GPUs.
  • Task: Allocate Experts to GPUs to balance loads as much as possible. For example, one layer's load is [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86]. Experts 10, 5, 1, and 4 carry the highest loads (see the short snippet after this list), so replicating these Experts can help distribute their computational pressure. Replication should:
    1. Balance the computational load across GPUs.
    2. Minimize communication between nodes.
    3. Keep physical Experts from the same logical Expert in the same node/group to reduce communication overhead.
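
The heaviest Experts in that layer can be read off with a quick topk (purely illustrative; the actual replication in replicate_experts is done greedily on per-replica load, within each node):

import torch

layer_load = torch.tensor([90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86])
values, indices = layer_load.topk(4)          # the four heaviest logical Experts
print(indices.tolist(), values.tolist())      # [10, 5, 1, 4] [183, 165, 132, 104]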

Step 1: Initial Allocation

Pack the load into groups to obtain group loads: [[262., 330., 116., 325.], [231., 280., 516., 129.]]. Then, use balanced_packing to allocate these grouped loads to packs (nodes) to balance the load across packs as much as possible.
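
In code, this step corresponds to the group summation and the first balanced_packing call. A sketch, assuming balanced_packing from the listing further below is in scope; the group loads match the numbers above:

import torch

weight = torch.tensor([[ 90, 132,  40,  61, 104, 165,  39,   4,  73,  56, 183,  86],
                       [ 20, 107, 104,  64,  19, 197, 187, 157, 172,  86,  16,  27]]).float()
num_groups, num_nodes = 4, 2
group_size = weight.shape[-1] // num_groups                    # 3 Experts per group

# Sum the load of every 3 consecutive Experts to get the per-group load.
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
print(tokens_per_group)   # [[262., 330., 116., 325.], [231., 280., 516., 129.]]

# Greedily assign groups to the 2 nodes so that the node loads stay close;
# e.g. for layer 0 this should pair the 330- and 116-token groups on one node
# and the 325- and 262-token groups on the other.
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)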

Step 2: Replicate Physical Experts

Within each node, replicate logical Experts with high loads to reduce GPU computational pressure. In each iteration, the function evaluates the average load per logical Expert (weight / logcnt), replicates the Expert with the highest average load (redundant_indices = (weight / logcnt).max(dim=-1).indices), and updates the physical-to-logical Expert mapping (phy2log). This dynamic replication increases parallel processing and distributes high loads across multiple physical devices.
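
A small sketch of this step for one node, using replicate_experts from the listing below. Assuming Step 1 placed groups 1 and 2 of layer 0 (Experts 3 through 8) on this node, its 6 logical Experts are expanded to 8 physical slots (16 physical Experts / 2 nodes):

import torch

node_weight = torch.tensor([[61., 104., 165., 39., 4., 73.]])   # loads of Experts 3..8 of layer 0
phy2log, rank, logcnt = replicate_experts(node_weight, 8)
print(phy2log)   # tensor([[0, 1, 2, 3, 4, 5, 2, 1]]) -- the two extra slots copy the heaviest Experts
print(logcnt)    # tensor([[1, 2, 2, 1, 1, 1]])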

Step 3: Pack Physical Experts to GPUs

Calculate the load per physical Expert (tokens_per_phy). If a logical Expert is replicated into two physical Experts, each physical Expert's load is half the original. Then, use balanced_packing to allocate the physical Experts' loads to GPUs, ensuring balanced allocation per pack.
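
Continuing the node-0 sketch from Step 2 (again assuming balanced_packing from the listing below is in scope), the per-slot loads match the first row of tokens_per_phy in the verification code:

import torch

node_weight = torch.tensor([[61., 104., 165., 39., 4., 73.]])
phy2log = torch.tensor([[0, 1, 2, 3, 4, 5, 2, 1]])              # from Step 2
logcnt  = torch.tensor([[1, 2, 2, 1, 1, 1]])

# Each physical slot carries its logical Expert's load divided by the replica count.
tokens_per_phy = (node_weight / logcnt).gather(-1, phy2log)
print(tokens_per_phy)   # [[61.0, 52.0, 82.5, 39.0, 4.0, 73.0, 82.5, 52.0]]

# Pack the 8 slots onto the node's 4 GPUs (2 slots per GPU), balancing GPU load.
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, 4)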

Verification Code

# Verify that the total load of the physical Experts allocated to each GPU is as balanced as possible

import torch

# The load matrix from the example: 2 layers x 12 logical Experts
weight = torch.tensor([[ 90, 132,  40,  61, 104, 165,  39,   4,  73,  56, 183,  86],
                       [ 20, 107, 104,  64,  19, 197, 187, 157, 172,  86,  16,  27]])

tokens_per_phy = \
torch.tensor([[61.0, 52.0, 82.5, 39.0, 4.0, 73.0, 82.5, 52.0],
              [56.0, 91.5, 86.0, 90.0, 66.0, 40.0, 91.5, 66.0],
              [93.5, 157.0, 86.0, 86.0, 16.0, 27.0, 93.5, 86.0],
              [64.0, 19.0, 98.5, 20.0, 53.5, 104.0, 98.5, 53.5]])

pphy2phy = \
torch.tensor([[2, 3, 6, 4, 5, 7, 0, 1],
              [1, 0, 6, 5, 3, 7, 2, 4],
              [1, 4, 0, 7, 6, 5, 2, 3],
              [5, 1, 2, 7, 6, 3, 0, 4]])

# Reorder tokens_per_phy according to pphy2phy to get the load of each physical Expert.

# After reordering, obtain x. Reshape x by the number of layers (2 in this case), then sum the allocations per GPU to get each GPU's load. The loads of the 8 GPUs should be relatively balanced. Each layer has 16 physical Experts, and with 8 GPUs, the load of every 2 physical Experts is summed, then aggregated across all layers to obtain the final load per GPU.

x = tokens_per_phy.gather(-1, pphy2phy)  # Reorder tokens_per_phy by pphy2phy, 4x8
x_by_layer = x.view(weight.shape[0], -1)  # Reshape x by number of layers (2), 2x16
x_by_layer_by_2 = x_by_layer.view(x_by_layer.shape[0], -1, 2)  # Reshape to sum every 2 physical Experts
sum_of_each_2 = x_by_layer_by_2.sum(dim=2)  # Sum the load of every 2 physical Experts
result = sum_of_each_2.sum(dim=0)  # Final load per GPU
print(result)

# tensor([294.5000, 266.0000, 245.5000, 285.0000, 270.5000, 283.5000, 274.5000, 269.5000])

In this example, the loads of the 8 GPUs are [294.5, 266.0, 245.5, 285.0, 270.5, 283.5, 274.5, 269.5]. The loads are relatively balanced, with no GPUs experiencing significantly high or low loads.

from typing import Tuple

import torch
tensor = torch.tensor

def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns: 
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """

    num_layers, num_groups = weight.shape    # 2 x 4 in the example
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs  # 2: number of groups per node
    # In the example: two layers, 12 logical experts per layer, split into 4 groups of 3 experts each.
    # The 4 groups are spread over two nodes, so groups_per_pack = 2.


    if groups_per_pack == 1:
        pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    indices = weight.float().sort(-1, descending=True).indices.cpu()  # column indices that sort each row of weight in descending order
    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu')
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)

    pack_weights_array = torch.zeros(weight.shape)
    for i in range(num_layers):  # num_layers = 2
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:  # indices: per-row order of the groups, heaviest first
            pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), 
                       key=pack_weights.__getitem__)  
            """
            What this line does:
            (i for i in range(num_packs) if pack_items[i] < groups_per_pack) is a generator
            expression over all pack indices i, keeping only the packs that have not yet
            reached the per-pack group limit (pack_items[i] < groups_per_pack).
            min(..., key=pack_weights.__getitem__) then picks, among those candidate indices,
            the index i with the smallest pack_weights[i], i.e. the currently lightest pack.
            Overall: among all packs that still have room, choose the lightest one and
            assign the next (heaviest remaining) group to it.
            """
            print(f"num_packs = {num_packs} ")
            print(f"layer {i}, weight index {group}: pack_items = {pack_items}, groups_per_pack = {groups_per_pack} \n")
            assert pack_items[pack] < groups_per_pack

            pack_index[i, group] = pack
            print(f"layer {i}, weight index {group}: pack_index is {pack_index} \n")
            rank_in_pack[i, group] = pack_items[pack]
            print(f"layer {i}, weight index {group}: rank_in_pack is {rank_in_pack} \n")
            pack_weights[pack] += weight[i, group]
            print(f"layer {i}, weight index {group}: pack_weights[{pack}] is now {pack_weights[pack]} \n")
            pack_items[pack] += 1
            pack_weights_array[i, group] = pack_weights[pack]
    return pack_index, rank_in_pack


def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    """
    When the number of logical experts is smaller than the given number of physical experts, repeatedly replicate the expert with the largest per-replica load, adjust that expert's load by its replica count, and continue until the number of physical experts reaches the given value.

    Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    n, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device
    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
    arangen = torch.arange(n, dtype=torch.int64, device=device)
    for i in range(num_log, num_phy):
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        rank[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, rank, logcnt


def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int, 
                      num_groups: int, num_nodes: int, num_gpus: int):
    """
    Note: when called from rebalance_experts below, num_physical_experts equals num_replicas.
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns: 
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """

    num_layers, num_logical_experts = weight.shape   # 2, 12
    # The number of logical experts must be divisible by the number of groups,
    # i.e. every group holds the same number of experts.
    assert num_logical_experts % num_groups == 0  
    group_size = num_logical_experts // num_groups   # 12 // 4 = 3 logical experts per group
    # The number of groups must be divisible by the number of nodes,
    # i.e. every node holds the same number of groups; here groups_per_node = 2.
    assert num_groups % num_nodes == 0               #  4 // 2 = 2 groups per node
    groups_per_node = num_groups // num_nodes
    # The number of GPUs must be divisible by the number of nodes,
    # i.e. every node has the same number of GPUs.
    assert num_gpus % num_nodes == 0              
    # The number of physical experts must be divisible by the number of GPUs,
    # i.e. every GPU hosts the same number of physical experts.
    assert num_physical_experts % num_gpus == 0   
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape))
        return inv

    # Step 1: pack groups to nodes
    # tokens_per_group: group the load (weight) and compute each group's total load.
    # Each row of weight is reshaped to (num_groups, group_size) and summed over the last dim.
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)  # total tokens/load per group
    # Distribute the grouped loads over the packs (nodes) so each pack's load is as balanced as possible.
    # group_pack_index: the node each group is assigned to
    # group_rank_in_pack: the position (rank) of the group within its node
    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) 
    # log2mlog: index of each logical expert after the group-level packing/balancing.
    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + 
                torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2)
    # mlog2log: the inverse index; reordering weight/load with it arranges the experts node by node.
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes to spread the compute pressure.
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    # weight.gather(-1, mlog2log): re-order the columns of weight according to mlog2log.
    tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes)
    # Replicate experts: returns the logical index behind each physical slot (phy2mlog), the replica
    # rank of each slot (phyrank), and the number of replicas of each logical expert (mlogcnt).
    phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) 

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    # Load of each physical expert: if a logical expert is replicated into two physical experts,
    # each physical expert is assumed to carry half of the original load.
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    # pack_index: the GPU (within the node) each physical expert is packed onto, keeping the packs balanced.
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
    # Index of each physical expert after packing.
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    # After reordering tokens_per_phy by pphy2phy, the total load of the physical experts on each GPU
    # is as balanced as possible; see the verification code above.
    pphy2phy = inverse(phy2pphy)

    # Mapping from physical experts back to (node-local) logical experts.
    # Note: the original inline comment here is probably wrong; the shape should be
    # [num_layers * num_nodes, num_physical_experts // num_nodes].
    pphy2mlog = phy2mlog.gather(-1, pphy2phy)  # [num_layers * num_nodes, num_physical_experts // num_nodes]
    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + 
                 torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2)        # [num_layers, num_phy_experts]
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt


def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
                      num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns: 
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 
                                                                  num_groups, num_nodes, num_gpus)
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), 
                                       -1, dtype=torch.int64, device=logcnt.device)
    log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank, 
            torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1))
    return phy2log, log2phy, logcnt
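
To tie everything together, the whole pipeline can be run on the example load matrix used throughout this note, calling rebalance_experts defined above (this driver is not part of eplb.py; note that the print statements added inside balanced_packing will also fire):

weight = torch.tensor([[ 90, 132,  40,  61, 104, 165,  39,   4,  73,  56, 183,  86],
                       [ 20, 107, 104,  64,  19, 197, 187, 157, 172,  86,  16,  27]])

# 16 physical Experts per layer, 4 Expert groups, 2 nodes, 8 GPUs in total.
phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas=16, num_groups=4,
                                             num_nodes=2, num_gpus=8)
print(phy2log.shape)   # [2, 16]: logical Expert id behind each physical slot (2 slots per GPU)
print(log2phy.shape)   # [2, 12, 2]: physical slot(s) of each logical Expert
print(logcnt)          # [2, 12]: number of physical replicas of each logical Expert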


Reference

  1. Expert Parallelism Load Balancer (EPLB)
  2. DeepSeekMoE Architecture
  3. Original notes (in Chinese)
  4. Draft notes of DeepSeek-R1
  5. DeepEP communication library