Skip to content

vllm.v1.worker.gpu.mm.mrope_utils

MRopeState

Source code in vllm/v1/worker/gpu/mm/mrope_utils.py
class MRopeState:
    """Buffers holding M-RoPE position ids for the requests of a worker.

    Three buffers are kept:

    * ``prefill_mrope_positions`` -- a staged-write int32 table with three
      consecutive rows per request slot (rows ``3 * req_idx`` through
      ``3 * req_idx + 2``), each ``max_model_len`` entries wide.
    * ``prefill_mrope_delta`` -- one int32 entry per request slot.
    * ``mrope_positions`` -- the ``(3, max_num_tokens + 1)`` int64 tensor
      that the kernel launched by ``prepare_mrope_positions`` writes into.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        """Record the capacity limits and allocate the position buffers.

        Args:
            max_num_reqs: Number of request slots to reserve.
            max_num_tokens: Number of token columns in ``mrope_positions``
                (one extra dummy column is added on top of this).
            max_model_len: Width of each per-request prefill row.
            device: Device on which the tensors are created.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.max_model_len = max_model_len
        self.device = device

        # Three rows per request slot, one column per position.
        # NOTE(woosuk): This tensor can be extremely large (e.g., several
        # GBs), which wastes a lot of CPU memory.
        prefill_table_shape = (max_num_reqs * 3, max_model_len)
        self.prefill_mrope_positions = StagedWriteTensor(
            prefill_table_shape,
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)

        # NOTE: `mrope_positions` is implemented with one additional dummy
        # position on purpose to make it non-contiguous so that it can work
        # with torch compile.
        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
        # the modality of inputs. For text-only inputs, each dimension has
        # identical position IDs, making M-RoPE functionally equivalent to
        # 1D-RoPE.
        # See page 5 of https://arxiv.org/abs/2409.12191
        num_token_slots = max_num_tokens + 1
        self.mrope_positions = torch.zeros(
            (3, num_token_slots), dtype=torch.int64, device=device
        )

    def init_prefill_mrope_positions(
        self,
        req_idx: int,
        mrope_model: SupportsMRoPE,
        prefill_token_ids: list[int],
        mm_features: list,
    ) -> None:
        """Stage the prefill M-RoPE positions and delta of one request.

        The positions and delta come from
        ``mrope_model.get_mrope_input_positions``. The first three rows of
        the returned positions are staged into rows ``3 * req_idx + i`` of
        ``prefill_mrope_positions`` starting at column 0, and the delta is
        stored at index ``req_idx`` of ``prefill_mrope_delta.np``.

        Args:
            req_idx: Request slot to fill.
            mrope_model: Model that computes the positions.
            prefill_token_ids: Token ids forwarded to the model.
            mm_features: Multimodal features forwarded to the model.
        """
        positions, delta = mrope_model.get_mrope_input_positions(
            prefill_token_ids,
            mm_features,
        )
        first_row = 3 * req_idx
        for axis in range(3):
            row_values = positions[axis].tolist()
            self.prefill_mrope_positions.stage_write(first_row + axis, 0, row_values)
        self.prefill_mrope_delta.np[req_idx] = delta

    def apply_staged_writes(self) -> None:
        """Apply the staged position writes, then copy the deltas to UVA."""
        positions_table = self.prefill_mrope_positions
        delta_buffer = self.prefill_mrope_delta
        positions_table.apply_write()
        delta_buffer.copy_to_uva()

    def prepare_mrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        """Launch the kernel that fills ``mrope_positions``.

        One kernel program is launched per entry along the first dimension
        of ``idx_mapping``.

        Args:
            idx_mapping: Passed to the kernel as ``idx_mapping_ptr``.
            query_start_loc: Passed to the kernel as ``query_start_loc_ptr``.
            prefill_lens: Passed to the kernel as ``prefill_lens_ptr``.
            num_computed_tokens: Passed to the kernel as
                ``num_computed_tokens_ptr``.
        """
        output = self.mrope_positions
        # Each request slot owns three rows of `max_model_len` entries, so
        # the per-request stride is three times the per-row stride.
        row_stride = self.max_model_len
        grid = (idx_mapping.shape[0],)
        _prepare_mrope_positions_kernel[grid](
            output,
            output.stride(0),
            self.prefill_mrope_positions.gpu,
            3 * row_stride,
            row_stride,
            self.prefill_mrope_delta.gpu,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
        )

device instance-attribute

device = device

max_model_len instance-attribute

max_model_len = max_model_len

max_num_reqs instance-attribute

max_num_reqs = max_num_reqs

max_num_tokens instance-attribute

max_num_tokens = max_num_tokens

mrope_positions instance-attribute

mrope_positions = zeros(
    (3, max_num_tokens + 1), dtype=int64, device=device
)

prefill_mrope_delta instance-attribute

prefill_mrope_delta = UvaBackedTensor(
    max_num_reqs, dtype=int32
)

prefill_mrope_positions instance-attribute

prefill_mrope_positions = StagedWriteTensor(
    (max_num_reqs * 3, max_model_len),
    dtype=int32,
    device=device,
    uva_instead_of_gpu=True,
)

__init__

__init__(
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: device,
)
Source code in vllm/v1/worker/gpu/mm/mrope_utils.py
def __init__(
    self,
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: torch.device,
):
    """Record the capacity limits and allocate the position buffers.

    Args:
        max_num_reqs: Number of request slots to reserve.
        max_num_tokens: Number of token columns in ``mrope_positions``
            (one extra dummy column is added on top of this).
        max_model_len: Width of each per-request prefill row.
        device: Device on which the tensors are created.
    """
    self.max_num_reqs = max_num_reqs
    self.max_num_tokens = max_num_tokens
    self.max_model_len = max_model_len
    self.device = device

    # Three rows per request slot, one column per position.
    # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs),
    # which wastes a lot of CPU memory.
    prefill_table_shape = (max_num_reqs * 3, max_model_len)
    self.prefill_mrope_positions = StagedWriteTensor(
        prefill_table_shape,
        dtype=torch.int32,
        device=device,
        uva_instead_of_gpu=True,
    )
    self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)

    # NOTE: `mrope_positions` is implemented with one additional dummy
    # position on purpose to make it non-contiguous so that it can work
    # with torch compile.
    # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
    # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
    # the modality of inputs. For text-only inputs, each dimension has
    # identical position IDs, making M-RoPE functionally equivalent to
    # 1D-RoPE.
    # See page 5 of https://arxiv.org/abs/2409.12191
    num_token_slots = max_num_tokens + 1
    self.mrope_positions = torch.zeros(
        (3, num_token_slots), dtype=torch.int64, device=device
    )

apply_staged_writes

apply_staged_writes() -> None
Source code in vllm/v1/worker/gpu/mm/mrope_utils.py
def apply_staged_writes(self) -> None:
    """Apply the staged position writes, then copy the deltas to UVA."""
    positions_table = self.prefill_mrope_positions
    delta_buffer = self.prefill_mrope_delta
    positions_table.apply_write()
    delta_buffer.copy_to_uva()

init_prefill_mrope_positions

init_prefill_mrope_positions(
    req_idx: int,
    mrope_model: SupportsMRoPE,
    prefill_token_ids: list[int],
    mm_features: list,
) -> None
Source code in vllm/v1/worker/gpu/mm/mrope_utils.py
def init_prefill_mrope_positions(
    self,
    req_idx: int,
    mrope_model: SupportsMRoPE,
    prefill_token_ids: list[int],
    mm_features: list,
) -> None:
    """Stage the prefill M-RoPE positions and delta of one request.

    The positions and delta come from
    ``mrope_model.get_mrope_input_positions``. The first three rows of the
    returned positions are staged into rows ``3 * req_idx + i`` of
    ``prefill_mrope_positions`` starting at column 0, and the delta is
    stored at index ``req_idx`` of ``prefill_mrope_delta.np``.

    Args:
        req_idx: Request slot to fill.
        mrope_model: Model that computes the positions.
        prefill_token_ids: Token ids forwarded to the model.
        mm_features: Multimodal features forwarded to the model.
    """
    positions, delta = mrope_model.get_mrope_input_positions(
        prefill_token_ids,
        mm_features,
    )
    first_row = 3 * req_idx
    for axis in range(3):
        row_values = positions[axis].tolist()
        self.prefill_mrope_positions.stage_write(first_row + axis, 0, row_values)
    self.prefill_mrope_delta.np[req_idx] = delta

prepare_mrope_positions

prepare_mrope_positions(
    idx_mapping: Tensor,
    query_start_loc: Tensor,
    prefill_lens: Tensor,
    num_computed_tokens: Tensor,
) -> None
Source code in vllm/v1/worker/gpu/mm/mrope_utils.py
def prepare_mrope_positions(
    self,
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    prefill_lens: torch.Tensor,
    num_computed_tokens: torch.Tensor,
) -> None:
    """Launch the kernel that fills ``mrope_positions``.

    One kernel program is launched per entry along the first dimension of
    ``idx_mapping``.

    Args:
        idx_mapping: Passed to the kernel as ``idx_mapping_ptr``.
        query_start_loc: Passed to the kernel as ``query_start_loc_ptr``.
        prefill_lens: Passed to the kernel as ``prefill_lens_ptr``.
        num_computed_tokens: Passed to the kernel as
            ``num_computed_tokens_ptr``.
    """
    output = self.mrope_positions
    # The per-request stride handed to the kernel is three times the
    # per-row stride (`max_model_len`).
    row_stride = self.max_model_len
    grid = (idx_mapping.shape[0],)
    _prepare_mrope_positions_kernel[grid](
        output,
        output.stride(0),
        self.prefill_mrope_positions.gpu,
        3 * row_stride,
        row_stride,
        self.prefill_mrope_delta.gpu,
        idx_mapping,
        query_start_loc,
        prefill_lens,
        num_computed_tokens,
        BLOCK_SIZE=1024,
    )

_prepare_mrope_positions_kernel

_prepare_mrope_positions_kernel(
    mrope_positions_ptr,
    mrope_positions_stride,
    prefill_mrope_positions_ptr,
    prefill_mrope_positions_stride0,
    prefill_mrope_positions_stride1,
    prefill_mrope_delta_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/worker/gpu/mm/mrope_utils.py
# Fill the three rows of the output position buffer for every token of the
# request handled by this program (program id 0 selects the batch entry).
#
# For a request that is still in prefill (num_computed < prefill_len) the
# positions are copied from the pre-computed per-request table; otherwise
# they are computed as (num_computed + token offset) + mrope_delta.
#
# Pointer/stride arguments:
#   mrope_positions_ptr / mrope_positions_stride: output buffer and the
#       element distance between its three rows.
#   prefill_mrope_positions_ptr: pre-computed table;
#       `..._stride0` is the element distance between consecutive requests,
#       `..._stride1` the distance between the three rows of one request.
#   prefill_mrope_delta_ptr, prefill_lens_ptr, num_computed_tokens_ptr:
#       indexed by the request-state index.
#   idx_mapping_ptr, query_start_loc_ptr: indexed by the batch index.
#   BLOCK_SIZE: number of tokens processed per loop iteration.
@triton.jit
def _prepare_mrope_positions_kernel(
    mrope_positions_ptr,
    mrope_positions_stride,
    prefill_mrope_positions_ptr,
    prefill_mrope_positions_stride0,
    prefill_mrope_positions_stride1,
    prefill_mrope_delta_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    # Translate the batch position into the request-state slot.
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    # Decided once per request; every token of the query takes the same branch.
    is_prefill = num_computed < prefill_len

    # Half-open token range [query_start, query_end) of this request in the
    # flattened output.
    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        # Disable the lanes that fall past the end of the query.
        mask = block < query_len
        # Position of each token within its own request.
        orig_pos = num_computed + block

        # Unrolled over the three M-RoPE rows.
        for j in tl.static_range(3):
            if is_prefill:
                # Read from pre-computed M-RoPE positions.
                # NOTE(review): the offset below is computed in the integer
                # type of the loaded index/stride values; confirm it cannot
                # overflow for very large max_num_reqs * max_model_len.
                pos = tl.load(
                    prefill_mrope_positions_ptr
                    + req_state_idx * prefill_mrope_positions_stride0
                    + j * prefill_mrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                # Apply M-RoPE delta.
                pos = orig_pos + mrope_delta
            # The same mask guards the store, so lanes that were masked off
            # in the load above are never written out.
            tl.store(
                mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
                pos,
                mask=mask,
            )