vllm.v1.cudagraph_dispatcher

logger module-attribute

logger = init_logger(__name__)

CudagraphDispatcher

Runtime cudagraph dispatcher that dispatches keys for multiple sets of cudagraphs.

The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one for FULL cudagraph runtime mode. The keys are initialized depending on attention support and which cudagraph mode is set in CompilationConfig. The keys stored in the dispatcher are the only source of truth for the valid cudagraphs that can be dispatched at runtime.

At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) based on the input key. After dispatching (communicated via forward context), the cudagraph wrappers will trust the dispatch key to either capture or replay (if the mode matches), or pass through to the underlying runnable without cudagraph (if the mode does not match or mode is NONE).
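
A minimal usage sketch (not part of the module source; it assumes a fully built VllmConfig named vllm_config, uses illustrative token counts, and elides the attention-backend resolution that normally happens before key initialization):

# Hypothetical driver code, for illustration only.
dispatcher = CudagraphDispatcher(vllm_config)

# Called once, after the attention backend has been resolved, with the
# final cudagraph mode from the compilation config.
dispatcher.initialize_cudagraph_keys(
    cudagraph_mode=vllm_config.compilation_config.cudagraph_mode,
    uniform_decode_query_len=1,
)

# Per step: look up the runtime mode and the padded batch descriptor.
runtime_mode, batch_descriptor = dispatcher.dispatch(
    num_tokens=24, uniform_decode=True, has_lora=False
)
# runtime_mode is FULL, PIECEWISE, or NONE; the descriptor is the key the
# cudagraph wrappers use (via forward context) to capture or replay.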

Source code in vllm/v1/cudagraph_dispatcher.py
class CudagraphDispatcher:
    """
    Runtime cudagraph dispatcher to dispatch keys for multiple set of
    cudagraphs.

    The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
    for FULL cudagraph runtime mode. The keys are initialized depending on
    attention support and what cudagraph mode is set in CompilationConfig. The
    keys stored in dispatcher are the only source of truth for valid
    cudagraphs that can be dispatched at runtime.

    At runtime, the dispatch method generates the runtime cudagraph mode (FULL,
    PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
    based on the input key. After dispatching (communicated via forward
    context), the cudagraph wrappers will trust the dispatch key to either
    capture or replay (if the mode matches), or pass through to the underlying
    runnable without cudagraph (if the mode does not match or mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.uniform_decode_query_len = (
            1
            if not self.vllm_config.speculative_config
            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
        )

        # Dict to store valid cudagraph dispatching keys.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            CUDAGraphMode.PIECEWISE: set(),
            CUDAGraphMode.FULL: set(),
        }

        assert (
            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
            or self.compilation_config.is_attention_compiled_piecewise()
        ), (
            "Compilation mode should be CompilationMode.VLLM_COMPILE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            "and attention should be in splitting_ops or "
            "inductor splitting should be used. "
            f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
            f"compilation_mode={self.compilation_config.mode}, "
            f"splitting_ops={self.compilation_config.splitting_ops}"
        )

        self.keys_initialized = False
        # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
        self.cudagraph_mode = CUDAGraphMode.NONE

    def _compute_bs_to_padded_graph_size(self) -> None:
        """Pre-compute the mapping from batch size to padded graph size."""
        max_size = self.compilation_config.max_cudagraph_capture_size
        capture_sizes = self.compilation_config.cudagraph_capture_sizes
        self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
        for end, start in zip(
            capture_sizes + [max_size + 1],
            [0] + capture_sizes,
        ):
            for bs in range(start, end):
                if bs == start:
                    self._bs_to_padded_graph_size[bs] = start
                else:
                    self._bs_to_padded_graph_size[bs] = end

        # Validate that compile_sizes won't be changed by padding.
        # Only validate when cudagraphs are actually being used.
        if (
            self.compilation_config.compile_sizes
            and self.cudagraph_mode != CUDAGraphMode.NONE
        ):
            for size in self.compilation_config.compile_sizes:
                if size <= self.compilation_config.max_cudagraph_capture_size:
                    padded = self._bs_to_padded_graph_size[size]
                    if padded != size:
                        raise ValueError(
                            f"compile_sizes contains {size} which would be "
                            f"padded to {padded}. All compile_sizes must be "
                            "values that won't be changed by cudagraph padding. "
                            "Use values from cudagraph_capture_sizes."
                        )

    def _create_padded_batch_descriptor(
        self, num_tokens: int, uniform_decode: bool, has_lora: bool
    ) -> BatchDescriptor:
        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
        uniform_decode_query_len = self.uniform_decode_query_len
        num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

        if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
            num_reqs = num_tokens_padded // uniform_decode_query_len
            assert num_tokens_padded % uniform_decode_query_len == 0
        else:
            uniform_decode = False
            num_reqs = min(num_tokens_padded, max_num_seqs)

        return BatchDescriptor(
            num_tokens=num_tokens_padded,
            num_reqs=num_reqs,
            uniform=uniform_decode,
            has_lora=has_lora,
        )

    def add_cudagraph_key(
        self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
    ):
        assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
        )
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(
        self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1
    ):
        # This should be called only after attention backend is initialized. So we can
        # get the correct cudagraph mode after backend support is resolved.
        self.cudagraph_mode = cudagraph_mode

        # Early exit if cudagraphs are disabled
        if cudagraph_mode == CUDAGraphMode.NONE:
            self.keys_initialized = True
            return

        self._compute_bs_to_padded_graph_size()

        # LoRA activation cases to specialize the cuda graphs on
        if self.vllm_config.lora_config:
            if self.compilation_config.cudagraph_specialize_lora:
                lora_cases = [True, False]
            else:
                lora_cases = [True]
        else:
            lora_cases = [False]

        # Note: we create all valid keys for cudagraph here but do not
        # guarantee all keys would be used. For example, if we allow lazy
        # capturing in future PR, some keys may never be triggered.
        if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
            for bs, has_lora in product(
                self.compilation_config.cudagraph_capture_sizes, lora_cases
            ):
                self.add_cudagraph_key(
                    cudagraph_mode.mixed_mode(),
                    self._create_padded_batch_descriptor(
                        bs, False, has_lora
                    ).relax_for_mixed_batch_cudagraphs(),
                )

        # if decode cudagraph mode is FULL, and we don't already have mixed
        # mode full cudagraphs then add them here.
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine()
        ):
            max_num_tokens = (
                uniform_decode_query_len
                * self.vllm_config.scheduler_config.max_num_seqs
            )
            cudagraph_capture_sizes_for_decode = [
                x
                for x in self.compilation_config.cudagraph_capture_sizes
                if x <= max_num_tokens and x >= uniform_decode_query_len
            ]
            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
                self.add_cudagraph_key(
                    CUDAGraphMode.FULL,
                    self._create_padded_batch_descriptor(bs, True, has_lora),
                )

        self.keys_initialized = True

    def dispatch(
        self,
        num_tokens: int,
        uniform_decode: bool = False,
        has_lora: bool = False,
        disable_full: bool = False,
    ) -> tuple[CUDAGraphMode, BatchDescriptor]:
        """
        Given conditions(e.g.,batch descriptor and if using piecewise only),
        dispatch to a cudagraph runtime mode and the valid batch descriptor.
        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).

        Args:
            num_tokens: Number of tokens in the batch.
            uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
                length is uniform_decode_query_len).
            has_lora: Whether LoRA is active.
            disable_full: If True, skip FULL cudagraph checks and
                return PIECEWISE or NONE only. (can be used for features like
                cascade attention that are not supported by full cudagraphs)
        """
        if (
            not self.keys_initialized
            or self.cudagraph_mode == CUDAGraphMode.NONE
            or num_tokens > self.compilation_config.max_cudagraph_capture_size
        ):
            return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

        batch_desc = self._create_padded_batch_descriptor(
            num_tokens, uniform_decode, has_lora
        )
        relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

        if not disable_full:
            # check if key exists for full cudagraph
            if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, batch_desc

            # otherwise, check if the relaxed key exists
            if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                return CUDAGraphMode.FULL, relaxed_batch_desc

        # also check if the relaxed key exists for more "general"
        # piecewise cudagraph
        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

        # finally, just return no cudagraphs and a trivial batch descriptor
        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

    def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
        """
        Returns capture descriptors for cudagraph capturing.

        Returns:
            List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE
            first then FULL. Batch descriptors are sorted largest-first for
            memory efficiency.
        """
        if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
            return []

        result = []
        # Return in order: PIECEWISE first, then FULL
        for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
            descs = list(self.cudagraph_keys[mode])
            if descs:
                # Sort by num_tokens descending (largest first)
                descs.sort(key=lambda d: d.num_tokens, reverse=True)
                result.append((mode, descs))

        return result

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_keys instance-attribute

cudagraph_keys: dict[
    CUDAGraphMode, set[BatchDescriptor]
] = {PIECEWISE: set(), FULL: set()}

cudagraph_mode instance-attribute

cudagraph_mode = NONE

keys_initialized instance-attribute

keys_initialized = False

uniform_decode_query_len instance-attribute

uniform_decode_query_len = (
    1
    if not speculative_config
    else 1 + num_speculative_tokens
)

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(vllm_config: VllmConfig)
Source code in vllm/v1/cudagraph_dispatcher.py
def __init__(self, vllm_config: VllmConfig):
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.uniform_decode_query_len = (
        1
        if not self.vllm_config.speculative_config
        else 1 + self.vllm_config.speculative_config.num_speculative_tokens
    )

    # Dict to store valid cudagraph dispatching keys.
    self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
        CUDAGraphMode.PIECEWISE: set(),
        CUDAGraphMode.FULL: set(),
    }

    assert (
        not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
        or self.compilation_config.is_attention_compiled_piecewise()
    ), (
        "Compilation mode should be CompilationMode.VLLM_COMPILE when "
        "cudagraph_mode piecewise cudagraphs is used, "
        "and attention should be in splitting_ops or "
        "inductor splitting should be used. "
        f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
        f"compilation_mode={self.compilation_config.mode}, "
        f"splitting_ops={self.compilation_config.splitting_ops}"
    )

    self.keys_initialized = False
    # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
    self.cudagraph_mode = CUDAGraphMode.NONE

_compute_bs_to_padded_graph_size

_compute_bs_to_padded_graph_size() -> None

Pre-compute the mapping from batch size to padded graph size.

Source code in vllm/v1/cudagraph_dispatcher.py
def _compute_bs_to_padded_graph_size(self) -> None:
    """Pre-compute the mapping from batch size to padded graph size."""
    max_size = self.compilation_config.max_cudagraph_capture_size
    capture_sizes = self.compilation_config.cudagraph_capture_sizes
    self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
    for end, start in zip(
        capture_sizes + [max_size + 1],
        [0] + capture_sizes,
    ):
        for bs in range(start, end):
            if bs == start:
                self._bs_to_padded_graph_size[bs] = start
            else:
                self._bs_to_padded_graph_size[bs] = end

    # Validate that compile_sizes won't be changed by padding.
    # Only validate when cudagraphs are actually being used.
    if (
        self.compilation_config.compile_sizes
        and self.cudagraph_mode != CUDAGraphMode.NONE
    ):
        for size in self.compilation_config.compile_sizes:
            if size <= self.compilation_config.max_cudagraph_capture_size:
                padded = self._bs_to_padded_graph_size[size]
                if padded != size:
                    raise ValueError(
                        f"compile_sizes contains {size} which would be "
                        f"padded to {padded}. All compile_sizes must be "
                        "values that won't be changed by cudagraph padding. "
                        "Use values from cudagraph_capture_sizes."
                    )
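
For intuition, a standalone sketch of the same bucketing with assumed capture sizes (illustrative values, not from any real config): an exact capture size maps to itself, and anything in between is padded up to the next capture size.

# Illustrative re-statement of the mapping above, outside the class.
capture_sizes = [1, 2, 4, 8]           # assumed sorted ascending
max_size = capture_sizes[-1]           # stands in for max_cudagraph_capture_size

bs_to_padded = [0] * (max_size + 1)
for end, start in zip(capture_sizes + [max_size + 1], [0] + capture_sizes):
    for bs in range(start, end):
        bs_to_padded[bs] = start if bs == start else end

assert bs_to_padded[3] == 4   # 3 tokens run in the 4-token graph
assert bs_to_padded[4] == 4   # exact capture sizes are unchanged
assert bs_to_padded[5] == 8   # 5 tokens run in the 8-token graph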

_create_padded_batch_descriptor

_create_padded_batch_descriptor(
    num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor
Source code in vllm/v1/cudagraph_dispatcher.py
def _create_padded_batch_descriptor(
    self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
    max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
    uniform_decode_query_len = self.uniform_decode_query_len
    num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

    if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
        num_reqs = num_tokens_padded // uniform_decode_query_len
        assert num_tokens_padded % uniform_decode_query_len == 0
    else:
        uniform_decode = False
        num_reqs = min(num_tokens_padded, max_num_seqs)

    return BatchDescriptor(
        num_tokens=num_tokens_padded,
        num_reqs=num_reqs,
        uniform=uniform_decode,
        has_lora=has_lora,
    )

add_cudagraph_key

add_cudagraph_key(
    runtime_mode: CUDAGraphMode,
    batch_descriptor: BatchDescriptor,
)
Source code in vllm/v1/cudagraph_dispatcher.py
def add_cudagraph_key(
    self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
):
    assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
        f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
    )
    self.cudagraph_keys[runtime_mode].add(batch_descriptor)

dispatch

dispatch(
    num_tokens: int,
    uniform_decode: bool = False,
    has_lora: bool = False,
    disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]

Given the runtime conditions (e.g., the batch descriptor and whether only piecewise graphs may be used), dispatch to a cudagraph runtime mode and the valid batch descriptor. A new batch descriptor is returned because a uniform batch may be dispatched to a graph that supports a more general batch (uniform to non-uniform).

Parameters:

num_tokens (int, required): Number of tokens in the batch.

uniform_decode (bool, default False): Whether the batch is uniform decode (i.e. uniform and the query length is uniform_decode_query_len).

has_lora (bool, default False): Whether LoRA is active.

disable_full (bool, default False): If True, skip FULL cudagraph checks and return PIECEWISE or NONE only (can be used for features like cascade attention that are not supported by full cudagraphs).
Source code in vllm/v1/cudagraph_dispatcher.py
def dispatch(
    self,
    num_tokens: int,
    uniform_decode: bool = False,
    has_lora: bool = False,
    disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
    """
    Given conditions(e.g.,batch descriptor and if using piecewise only),
    dispatch to a cudagraph runtime mode and the valid batch descriptor.
    A new batch descriptor is returned as we might dispatch a uniform batch
    to a graph that supports a more general batch (uniform to non-uniform).

    Args:
        num_tokens: Number of tokens in the batch.
        uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
            length is uniform_decode_query_len).
        has_lora: Whether LoRA is active.
        disable_full: If True, skip FULL cudagraph checks and
            return PIECEWISE or NONE only. (can be used for features like
            cascade attention that are not supported by full cudagraphs)
    """
    if (
        not self.keys_initialized
        or self.cudagraph_mode == CUDAGraphMode.NONE
        or num_tokens > self.compilation_config.max_cudagraph_capture_size
    ):
        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

    batch_desc = self._create_padded_batch_descriptor(
        num_tokens, uniform_decode, has_lora
    )
    relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

    if not disable_full:
        # check if key exists for full cudagraph
        if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, batch_desc

        # otherwise, check if the relaxed key exists
        if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, relaxed_batch_desc

    # also check if the relaxed key exists for more "general"
    # piecewise cudagraph
    if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
        return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

    # finally, just return no cudagraphs and a trivial batch descriptor
    return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
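
A hedged usage note, continuing the sketch above with an arbitrary token count: a caller that temporarily cannot use full cudagraphs (e.g. for cascade attention, as the docstring mentions) passes disable_full=True, so only the PIECEWISE key set is consulted and the result is PIECEWISE or NONE.

# Only PIECEWISE or NONE can come back when disable_full=True.
mode, desc = dispatcher.dispatch(num_tokens=96, disable_full=True)
assert mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE)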

get_capture_descs

get_capture_descs() -> list[
    tuple[CUDAGraphMode, list[BatchDescriptor]]
]

Returns capture descriptors for cudagraph capturing.

Returns:

list[tuple[CUDAGraphMode, list[BatchDescriptor]]]: List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE first then FULL. Batch descriptors are sorted largest-first for memory efficiency.

Source code in vllm/v1/cudagraph_dispatcher.py
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
    """
    Returns capture descriptors for cudagraph capturing.

    Returns:
        List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE
        first then FULL. Batch descriptors are sorted largest-first for
        memory efficiency.
    """
    if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
        return []

    result = []
    # Return in order: PIECEWISE first, then FULL
    for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
        descs = list(self.cudagraph_keys[mode])
        if descs:
            # Sort by num_tokens descending (largest first)
            descs.sort(key=lambda d: d.num_tokens, reverse=True)
            result.append((mode, descs))

    return result
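
A minimal capture-loop sketch (capture_graph is a hypothetical callback standing in for the model runner's actual capture step, which this page does not show):

# Capture piecewise graphs first, then full graphs, largest batches first.
for runtime_mode, descs in dispatcher.get_capture_descs():
    for desc in descs:
        capture_graph(runtime_mode, desc)   # hypothetical capture routine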

initialize_cudagraph_keys

initialize_cudagraph_keys(
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
)
Source code in vllm/v1/cudagraph_dispatcher.py
def initialize_cudagraph_keys(
    self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1
):
    # This should be called only after attention backend is initialized. So we can
    # get the correct cudagraph mode after backend support is resolved.
    self.cudagraph_mode = cudagraph_mode

    # Early exit if cudagraphs are disabled
    if cudagraph_mode == CUDAGraphMode.NONE:
        self.keys_initialized = True
        return

    self._compute_bs_to_padded_graph_size()

    # LoRA activation cases to specialize the cuda graphs on
    if self.vllm_config.lora_config:
        if self.compilation_config.cudagraph_specialize_lora:
            lora_cases = [True, False]
        else:
            lora_cases = [True]
    else:
        lora_cases = [False]

    # Note: we create all valid keys for cudagraph here but do not
    # guarantee all keys would be used. For example, if we allow lazy
    # capturing in future PR, some keys may never be triggered.
    if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
        for bs, has_lora in product(
            self.compilation_config.cudagraph_capture_sizes, lora_cases
        ):
            self.add_cudagraph_key(
                cudagraph_mode.mixed_mode(),
                self._create_padded_batch_descriptor(
                    bs, False, has_lora
                ).relax_for_mixed_batch_cudagraphs(),
            )

    # if decode cudagraph mode is FULL, and we don't already have mixed
    # mode full cudagraphs then add them here.
    if (
        cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        and cudagraph_mode.separate_routine()
    ):
        max_num_tokens = (
            uniform_decode_query_len
            * self.vllm_config.scheduler_config.max_num_seqs
        )
        cudagraph_capture_sizes_for_decode = [
            x
            for x in self.compilation_config.cudagraph_capture_sizes
            if x <= max_num_tokens and x >= uniform_decode_query_len
        ]
        for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
            self.add_cudagraph_key(
                CUDAGraphMode.FULL,
                self._create_padded_batch_descriptor(bs, True, has_lora),
            )

    self.keys_initialized = True
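
As a worked example of the decode-size filter above (all values assumed, not taken from any real config): with one speculative token the uniform decode query length is 2, so only capture sizes between 2 and 2 * max_num_seqs are kept for the uniform-decode FULL graphs.

# Assumed values for illustration only.
uniform_decode_query_len = 2                 # 1 + 1 speculative token
max_num_seqs = 8
capture_sizes = [1, 2, 4, 8, 16, 32]

max_num_tokens = uniform_decode_query_len * max_num_seqs   # 16
decode_sizes = [
    x for x in capture_sizes
    if uniform_decode_query_len <= x <= max_num_tokens
]
assert decode_sizes == [2, 4, 8, 16]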