Skip to content

vllm.model_executor.layers.fused_moe.router.grouped_topk_router

GroupedTopKRouter

Bases: BaseRouter

Router using grouped top-k routing (e.g., DeepSeekV2/V3).

Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
class GroupedTopKRouter(BaseRouter):
    """Expert router that applies grouped top-k selection (e.g., DeepSeekV2/V3).

    When the expert count cannot be split into ``num_expert_group`` groups,
    routing falls back to the ungrouped ``fused_topk`` / ``fused_topk_bias``
    helpers instead.
    """

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        num_expert_group: int,
        topk_group: int,
        renormalize: bool = True,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        num_fused_shared_experts: int = 0,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        """Store the grouped top-k settings and pick the routing method type.

        ``top_k``, ``global_num_experts``, ``eplb_state``, ``enable_eplb`` and
        ``indices_type_getter`` are forwarded unchanged to the parent
        initializer; every other argument is kept on the instance.
        """
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

        # Expert-group selection settings.
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.num_fused_shared_experts = num_fused_shared_experts

        # Scoring and weight post-processing settings.
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.renormalize = renormalize
        self.routed_scaling_factor = routed_scaling_factor

        # Only sigmoid scoring is tagged as DeepSeekV3 routing. Any other
        # scoring function stays Unspecified.
        # NOTE: Unspecified prohibits the FLASHINFER_TRTLLM kernels from
        # being selected, since they only support DeepSeek-style.
        self._routing_method_type = (
            RoutingMethodType.DeepSeekV3
            if scoring_func == "sigmoid"
            else RoutingMethodType.Unspecified
        )

    @property
    def routing_method_type(self) -> RoutingMethodType:
        """Routing method type chosen in ``__init__`` from ``scoring_func``."""
        return self._routing_method_type

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick experts for each token and return ``(topk_weights, topk_ids)``.

        Grouped top-k is used only when the expert count (last dimension of
        ``router_logits``) is strictly larger than ``num_expert_group`` and
        evenly divisible by it. Any other configuration is routed through the
        ungrouped fused top-k helpers, where ``indices_type`` is only passed
        to the bias-free helper.
        """
        expert_count = router_logits.shape[-1]
        group_count = self.num_expert_group
        can_group = expert_count > group_count and expert_count % group_count == 0

        if not can_group:
            bias = self.e_score_correction_bias
            if bias is None:
                # Plain top-k; the helper's third return value is not needed.
                weights, ids, _ = fused_topk(
                    hidden_states=hidden_states,
                    gating_output=router_logits,
                    topk=self.top_k,
                    renormalize=self.renormalize,
                    indices_type=indices_type,
                )
                return weights, ids

            weights, ids = fused_topk_bias(
                hidden_states=hidden_states,
                gating_output=router_logits,
                e_score_correction_bias=bias.data,
                topk=self.top_k,
                renormalize=self.renormalize,
            )
            if self.routed_scaling_factor != 1.0:
                # Scale in place on the tensor returned by the helper.
                weights *= self.routed_scaling_factor
            return weights, ids

        # Choose which grouped top-k implementation handles this call.
        if not rocm_aiter_ops.is_fused_moe_enabled():
            route = grouped_topk
        else:
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                # Fused shared experts require the matching AITER feature.
                assert self.num_fused_shared_experts == 0
            route = partial(
                rocm_aiter_grouped_topk,
                num_fused_shared_experts=self.num_fused_shared_experts,
            )

        weights, ids = route(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.e_score_correction_bias,
        )
        return weights, ids

_routing_method_type instance-attribute

_routing_method_type = DeepSeekV3

e_score_correction_bias instance-attribute

e_score_correction_bias = e_score_correction_bias

num_expert_group instance-attribute

num_expert_group = num_expert_group

num_fused_shared_experts instance-attribute

num_fused_shared_experts = num_fused_shared_experts

renormalize instance-attribute

renormalize = renormalize

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

routing_method_type property

routing_method_type: RoutingMethodType

scoring_func instance-attribute

scoring_func = scoring_func

topk_group instance-attribute

topk_group = topk_group

__init__

__init__(
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool = True,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Tensor | None = None,
    num_fused_shared_experts: int = 0,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], dtype | None]
    | None = None,
)
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def __init__(
    self,
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool = True,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    num_fused_shared_experts: int = 0,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
    """Store the grouped top-k settings and pick the routing method type.

    ``top_k``, ``global_num_experts``, ``eplb_state``, ``enable_eplb`` and
    ``indices_type_getter`` are forwarded unchanged to the parent
    initializer; every other argument is kept on the instance.
    """
    super().__init__(
        top_k=top_k,
        global_num_experts=global_num_experts,
        eplb_state=eplb_state,
        enable_eplb=enable_eplb,
        indices_type_getter=indices_type_getter,
    )

    # Expert-group selection settings.
    self.num_expert_group = num_expert_group
    self.topk_group = topk_group
    self.num_fused_shared_experts = num_fused_shared_experts

    # Scoring and weight post-processing settings.
    self.scoring_func = scoring_func
    self.e_score_correction_bias = e_score_correction_bias
    self.renormalize = renormalize
    self.routed_scaling_factor = routed_scaling_factor

    # Only sigmoid scoring is tagged as DeepSeekV3 routing. Any other
    # scoring function stays Unspecified.
    # NOTE: Unspecified prohibits the FLASHINFER_TRTLLM kernels from
    # being selected, since they only support DeepSeek-style.
    self._routing_method_type = (
        RoutingMethodType.DeepSeekV3
        if scoring_func == "sigmoid"
        else RoutingMethodType.Unspecified
    )

_compute_routing

_compute_routing(
    hidden_states: Tensor,
    router_logits: Tensor,
    indices_type: dtype | None,
) -> tuple[Tensor, Tensor]

Compute routing using grouped top-k.

Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def _compute_routing(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick experts for each token and return ``(topk_weights, topk_ids)``.

    Grouped top-k is used only when the expert count (last dimension of
    ``router_logits``) is strictly larger than ``num_expert_group`` and
    evenly divisible by it. Any other configuration is routed through the
    ungrouped fused top-k helpers, where ``indices_type`` is only passed
    to the bias-free helper.
    """
    expert_count = router_logits.shape[-1]
    group_count = self.num_expert_group
    can_group = expert_count > group_count and expert_count % group_count == 0

    if not can_group:
        bias = self.e_score_correction_bias
        if bias is None:
            # Plain top-k; the helper's third return value is not needed.
            weights, ids, _ = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=self.top_k,
                renormalize=self.renormalize,
                indices_type=indices_type,
            )
            return weights, ids

        weights, ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=bias.data,
            topk=self.top_k,
            renormalize=self.renormalize,
        )
        if self.routed_scaling_factor != 1.0:
            # Scale in place on the tensor returned by the helper.
            weights *= self.routed_scaling_factor
        return weights, ids

    # Choose which grouped top-k implementation handles this call.
    if not rocm_aiter_ops.is_fused_moe_enabled():
        route = grouped_topk
    else:
        if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
            # Fused shared experts require the matching AITER feature.
            assert self.num_fused_shared_experts == 0
        route = partial(
            rocm_aiter_grouped_topk,
            num_fused_shared_experts=self.num_fused_shared_experts,
        )

    weights, ids = route(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=self.top_k,
        renormalize=self.renormalize,
        num_expert_group=self.num_expert_group,
        topk_group=self.topk_group,
        scoring_func=self.scoring_func,
        routed_scaling_factor=self.routed_scaling_factor,
        e_score_correction_bias=self.e_score_correction_bias,
    )
    return weights, ids

GroupedTopk

Bases: CustomOp

GroupedTopk used by the Deepseek-V2 and Deepseek-V3 models.

Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
    """GroupedTopk op used by the Deepseek-V2 and Deepseek-V3 models.

    Holds the routing configuration and exposes per-platform ``forward_*``
    entry points. Every entry point returns the two-tensor result of the
    underlying grouped top-k implementation.
    """

    # --8<-- [end:grouped_topk]

    def __init__(
        self,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        num_fused_shared_experts: int = 0,
    ) -> None:
        """Record the routing configuration on the instance.

        ``native_impl`` is bound to the module-level ``grouped_topk``; all
        other arguments are stored under attributes of the same name.
        """
        super().__init__()
        self.native_impl = grouped_topk
        self.topk = topk
        self.renormalize = renormalize
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call ``self.native_impl`` with the stored configuration.

        Arguments are passed positionally, so their order must match the
        parameter order of ``grouped_topk``. ``num_fused_shared_experts`` is
        not forwarded on this path.
        """
        return self.native_impl(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """CUDA entry point; delegates directly to ``forward_native``."""
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """HIP entry point: use the AITER grouped top-k when it is enabled.

        When ``rocm_aiter_ops.is_fused_moe_enabled()`` is false this is the
        same as ``forward_native``. Otherwise ``rocm_aiter_grouped_topk`` is
        called with the stored configuration plus
        ``num_fused_shared_experts``.
        """
        if rocm_aiter_ops.is_fused_moe_enabled():
            # Fused shared experts are only accepted when the corresponding
            # AITER feature is switched on.
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                assert self.num_fused_shared_experts == 0
            return rocm_aiter_grouped_topk(
                hidden_states,
                gating_output,
                self.topk,
                self.renormalize,
                self.num_expert_group,
                self.topk_group,
                self.scoring_func,
                self.routed_scaling_factor,
                e_score_correction_bias,
                self.num_fused_shared_experts,
            )
        else:
            return self.forward_native(
                hidden_states, gating_output, e_score_correction_bias
            )

native_impl instance-attribute

native_impl = grouped_topk

num_expert_group instance-attribute

num_expert_group = num_expert_group

num_fused_shared_experts instance-attribute

num_fused_shared_experts = num_fused_shared_experts

renormalize instance-attribute

renormalize = renormalize

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

scoring_func instance-attribute

scoring_func = scoring_func

topk instance-attribute

topk = topk

topk_group instance-attribute

topk_group = topk_group

__init__

__init__(
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    num_fused_shared_experts: int = 0,
) -> None
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def __init__(
    self,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    num_fused_shared_experts: int = 0,
) -> None:
    """Record the routing configuration on the instance.

    ``native_impl`` is bound to the module-level ``grouped_topk``; all other
    arguments are stored under attributes of the same name.
    """
    super().__init__()
    self.native_impl = grouped_topk
    self.topk = topk
    self.renormalize = renormalize
    self.num_expert_group = num_expert_group
    self.topk_group = topk_group
    self.scoring_func = scoring_func
    self.routed_scaling_factor = routed_scaling_factor
    self.num_fused_shared_experts = num_fused_shared_experts

forward_cuda

forward_cuda(
    hidden_states: Tensor,
    gating_output: Tensor,
    e_score_correction_bias: Tensor | None = None,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def forward_cuda(
    self,
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """CUDA entry point; delegates directly to ``forward_native``."""
    return self.forward_native(
        hidden_states, gating_output, e_score_correction_bias
    )

forward_hip

forward_hip(
    hidden_states: Tensor,
    gating_output: Tensor,
    e_score_correction_bias: Tensor | None = None,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def forward_hip(
    self,
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """HIP entry point: use the AITER grouped top-k when it is enabled.

    When ``rocm_aiter_ops.is_fused_moe_enabled()`` is false this is the same
    as ``forward_native``. Otherwise ``rocm_aiter_grouped_topk`` is called
    with the stored configuration plus ``num_fused_shared_experts``.
    """
    if rocm_aiter_ops.is_fused_moe_enabled():
        # Fused shared experts are only accepted when the corresponding
        # AITER feature is switched on.
        if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
            assert self.num_fused_shared_experts == 0
        return rocm_aiter_grouped_topk(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
            self.num_fused_shared_experts,
        )
    else:
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

forward_native

forward_native(
    hidden_states: Tensor,
    gating_output: Tensor,
    e_score_correction_bias: Tensor | None = None,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def forward_native(
    self,
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call ``self.native_impl`` with the stored configuration.

    Arguments are passed positionally, so their order must match the
    parameter order of ``grouped_topk``. ``num_fused_shared_experts`` is not
    forwarded on this path.
    """
    return self.native_impl(
        hidden_states,
        gating_output,
        self.topk,
        self.renormalize,
        self.num_expert_group,
        self.topk_group,
        self.scoring_func,
        self.routed_scaling_factor,
        e_score_correction_bias,
    )

fused_grouped_topk

fused_grouped_topk(
    hidden_states: Tensor,
    gating_output: Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run grouped top-k routing through the fused ``ops.grouped_topk`` kernel.

    ``hidden_states`` is only used to check that its first dimension matches
    that of ``gating_output``.

    Returns:
        The ``(topk_values, topk_indices)`` pair produced by the kernel.

    Raises:
        AssertionError: If the two inputs disagree on the number of tokens.
        ValueError: If ``scoring_func`` is neither ``"sigmoid"`` nor
            ``"softmax"``.
    """
    n_tokens = hidden_states.size(0)
    assert n_tokens == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        # Fully fused path: the kernel receives raw logits and applies the
        # sigmoid itself (scoring_func=1).
        kernel_input = gating_output
        kernel_scoring = 1
    elif scoring_func == "softmax":
        # Softmax is applied in Python and the kernel is told to skip its
        # activation (scoring_func=0, scores already computed).
        # TODO: Add support for softmax in kernel
        kernel_input = torch.softmax(gating_output, dim=-1)
        kernel_scoring = 0
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Fused kernel outputs float32 values and int32 indices directly
    values, indices = ops.grouped_topk(
        kernel_input,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        routed_scaling_factor,
        e_score_correction_bias,
        kernel_scoring,
    )
    return values, indices

grouped_topk

grouped_topk(
    hidden_states: Tensor,
    gating_output: Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Tensor | None = None,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select ``topk`` experts per token, restricted to the best expert groups.

    The experts (last dimension of ``gating_output``) are viewed as
    ``num_expert_group`` equally sized groups. The ``topk_group`` best groups
    are kept per token, and the ``topk`` experts are then chosen from those
    groups only.

    When ``VLLM_USE_FUSED_MOE_GROUPED_TOPK`` is set, the platform is CUDA,
    ``num_expert_group`` and ``topk`` are both at most 32, and a correction
    bias is given, the work is handed to ``fused_grouped_topk`` instead.

    Args:
        hidden_states: Only used to check the token count against
            ``gating_output``.
        gating_output: Router logits, indexed as ``[token, expert]``.
        topk: Number of experts selected per token.
        renormalize: If true, divide each token's weights by their sum.
        num_expert_group: Number of groups the experts are split into.
        topk_group: Number of groups kept per token.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.
        routed_scaling_factor: Multiplier applied to the final weights when
            it differs from 1.0.
        e_score_correction_bias: Optional per-expert bias that influences
            which experts are chosen but not the returned weights.

    Returns:
        ``(topk_weights, topk_ids)`` converted to ``float32`` and ``int32``
        respectively on the non-fused path.

    Raises:
        AssertionError: If the token counts of the two inputs differ.
        ValueError: If ``scoring_func`` is not supported.
    """
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        # With a bias, a group is ranked by the sum of its two largest
        # biased scores.
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        # Without a bias, a group is ranked by its single largest score.
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    # Build a 0/1 mask over groups, then broadcast it to every expert in each
    # group so it lines up with the flat per-expert scores.
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # Experts outside the selected groups get -inf so top-k cannot pick them
    # ahead of any expert in a selected group.
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)