Skip to content

vllm.model_executor.layers.fused_moe.router.custom_routing_router

CustomRoutingRouter

Bases: BaseRouter

Router using a custom user-provided routing function.

Source code in vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
class CustomRoutingRouter(BaseRouter):
    """Router that delegates expert selection to a user-provided callable."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        custom_routing_function: Callable,
        renormalize: bool = True,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        """Set up shared router state and record the user's routing callable.

        Args:
            top_k: Number of experts selected per token.
            global_num_experts: Total expert count across all ranks.
            eplb_state: Expert-parallel load-balancing state for this layer.
            custom_routing_function: Callable invoked with
                (hidden_states, gating_output, topk, renormalize) keywords;
                must return (topk_weights, topk_ids).
            renormalize: Forwarded to the custom routing function.
            enable_eplb: Whether expert-parallel load balancing is enabled.
            indices_type_getter: Optional callable yielding the dtype to use
                for expert indices.
        """
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        # Retained for _compute_routing.
        self.custom_routing_function = custom_routing_function
        self.renormalize = renormalize

    @property
    def routing_method_type(self) -> RoutingMethodType:
        from vllm.model_executor.models.llama4 import Llama4MoE

        # NOTE: FLASHINFER_TRTLLM supports the Llama4 router, so advertise
        # Llama4 when the configured callable is exactly that function.
        is_llama4 = (
            self.custom_routing_function == Llama4MoE.custom_routing_function
        )
        return RoutingMethodType.Llama4 if is_llama4 else RoutingMethodType.Custom

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using the custom routing function."""
        weights, ids = self.custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
        )
        # Weights are standardized to float32; indices default to int32
        # unless the caller requested a specific dtype.
        target_dtype = torch.int32 if indices_type is None else indices_type
        return weights.to(torch.float32), ids.to(target_dtype)

custom_routing_function instance-attribute

custom_routing_function = custom_routing_function

renormalize instance-attribute

renormalize = renormalize

routing_method_type property

routing_method_type: RoutingMethodType

__init__

__init__(
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    custom_routing_function: Callable,
    renormalize: bool = True,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], dtype | None]
    | None = None,
)
Source code in vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
def __init__(
    self,
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    custom_routing_function: Callable,
    renormalize: bool = True,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
    """Initialize base-router state, then record the custom routing callable.

    Args:
        top_k: Number of experts selected per token.
        global_num_experts: Total expert count across all ranks.
        eplb_state: Expert-parallel load-balancing state for this layer.
        custom_routing_function: User-supplied routing callable.
        renormalize: Forwarded to the custom routing function.
        enable_eplb: Whether expert-parallel load balancing is enabled.
        indices_type_getter: Optional callable yielding the dtype to use
            for expert indices.
    """
    # Everything except the custom function and renormalize flag is
    # owned by the base class.
    base_kwargs = dict(
        top_k=top_k,
        global_num_experts=global_num_experts,
        eplb_state=eplb_state,
        enable_eplb=enable_eplb,
        indices_type_getter=indices_type_getter,
    )
    super().__init__(**base_kwargs)
    self.custom_routing_function = custom_routing_function
    self.renormalize = renormalize

_compute_routing

_compute_routing(
    hidden_states: Tensor,
    router_logits: Tensor,
    indices_type: dtype | None,
) -> tuple[Tensor, Tensor]

Compute routing using the custom routing function.

Source code in vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
def _compute_routing(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute routing using the custom routing function."""
    weights, ids = self.custom_routing_function(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=self.top_k,
        renormalize=self.renormalize,
    )
    # Weights are standardized to float32; indices default to int32 unless
    # the caller requested a specific dtype.
    target_dtype = torch.int32 if indices_type is None else indices_type
    return weights.to(torch.float32), ids.to(target_dtype)