Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector

ReqId module-attribute

# Alias for request identifiers; used as the key type of the per-request
# dictionaries in this module.
ReqId = str

logger module-attribute

# Module-level logger, named after this module.
logger = init_logger(__name__)

OffloadingConnector

Bases: KVConnectorBase_V1

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadingConnector(KVConnectorBase_V1):
    """KV connector that forwards every hook to a role-specific backend.

    ``__init__`` builds an :class:`OffloadingConnectorScheduler` when running
    with the scheduler role, or an :class:`OffloadingConnectorWorker` when
    running with the worker role, both from the same offloading spec. The
    other backend attribute stays ``None``. Scheduler-side hooks require the
    scheduler backend and worker-side hooks require the worker backend; this
    is enforced with ``assert``.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Always ``True``: this connector asks for cross-layer KV blocks."""
        return True

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        offload_spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)

        # At most one backend is created; the other stays None so the
        # role-specific hooks below can assert on it.
        self.connector_scheduler: OffloadingConnectorScheduler | None = None
        self.connector_worker: OffloadingConnectorWorker | None = None
        if role == KVConnectorRole.WORKER:
            self.connector_worker = OffloadingConnectorWorker(offload_spec)
        elif role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = OffloadingConnectorScheduler(offload_spec)

    # ------------------------------------------------------------------
    # Backend accessors
    # ------------------------------------------------------------------

    def _require_worker(self) -> "OffloadingConnectorWorker":
        """Return the worker backend, asserting that it exists."""
        worker = self.connector_worker
        assert worker is not None
        return worker

    def _require_scheduler(self) -> "OffloadingConnectorScheduler":
        """Return the scheduler backend, asserting that it exists."""
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler

    # ------------------------------------------------------------------
    # Worker-side hooks
    # ------------------------------------------------------------------

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Hand the per-layer KV cache tensors to the worker backend."""
        self._require_worker().register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Hand a cross-layer KV cache tensor to the worker backend."""
        self._require_worker().register_cross_layers_kv_cache(kv_cache, attn_backend)

    def handle_preemptions(self, preempted_req_ids: set[str]):
        """Forward the set of preempted request IDs to the worker backend."""
        self._require_worker().handle_preemptions(preempted_req_ids)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Start the transfers described by the bound connector metadata."""
        worker = self._require_worker()
        metadata = self._connector_metadata
        assert isinstance(metadata, OffloadingConnectorMetadata)
        worker.start_kv_transfers(metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Intentionally does nothing; there is no per-layer load wait."""

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs,
    ) -> None:
        """Intentionally does nothing; stores are prepared in wait_for_save."""

    def wait_for_save(self):
        """Ask the worker backend to prepare the stores in the metadata."""
        worker = self._require_worker()
        metadata = self._connector_metadata
        assert isinstance(metadata, OffloadingConnectorMetadata)
        worker.prepare_store_kv(metadata)

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Return the worker backend's pair of finished request-ID sets."""
        return self._require_worker().get_finished(finished_req_ids)

    # ------------------------------------------------------------------
    # Scheduler-side hooks
    # ------------------------------------------------------------------

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """Ask the scheduler backend how many extra tokens can be loaded."""
        return self._require_scheduler().get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Notify the scheduler backend that blocks were allocated."""
        return self._require_scheduler().update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Build this step's connector metadata via the scheduler backend."""
        return self._require_scheduler().build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """Feed worker-side connector output back to the scheduler backend."""
        self._require_scheduler().update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Tell the scheduler backend that ``request`` has finished."""
        return self._require_scheduler().request_finished(request, block_ids)

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Return the scheduler backend's pending KV cache events."""
        return self._require_scheduler().take_events()

connector_scheduler instance-attribute

connector_scheduler: OffloadingConnectorScheduler | None = (
    None
)

connector_worker instance-attribute

connector_worker: OffloadingConnectorWorker | None = None

prefer_cross_layer_blocks property

prefer_cross_layer_blocks: bool

__init__

__init__(
    vllm_config: VllmConfig,
    role: KVConnectorRole,
    kv_cache_config: KVCacheConfig | None = None,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def __init__(
    self,
    vllm_config: VllmConfig,
    role: KVConnectorRole,
    kv_cache_config: KVCacheConfig | None = None,
):
    """Create the role-specific offloading backend.

    Args:
        vllm_config: the vLLM configuration, also used to build the spec.
        role: connector role; selects the scheduler or worker backend.
        kv_cache_config: optional KV cache configuration for the spec.
    """
    super().__init__(vllm_config, role, kv_cache_config)

    offload_spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)

    # At most one backend is created; the other stays None.
    self.connector_scheduler: OffloadingConnectorScheduler | None = None
    self.connector_worker: OffloadingConnectorWorker | None = None
    if role == KVConnectorRole.WORKER:
        self.connector_worker = OffloadingConnectorWorker(offload_spec)
    elif role == KVConnectorRole.SCHEDULER:
        self.connector_scheduler = OffloadingConnectorScheduler(offload_spec)

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """Build this step's connector metadata via the scheduler backend.

    Args:
        scheduler_output: output of the current scheduler step.

    Returns:
        Whatever ``connector_scheduler.build_connector_meta`` returns.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    return scheduler.build_connector_meta(scheduler_output)

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str], set[str]]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
    """Return the worker backend's pair of finished request-ID sets.

    Args:
        finished_req_ids: IDs of requests the caller reports as finished.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.get_finished(finished_req_ids)

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_num_new_matched_tokens(
    self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
    assert self.connector_scheduler is not None
    return self.connector_scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens
    )

handle_preemptions

handle_preemptions(preempted_req_ids: set[str])
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def handle_preemptions(self, preempted_req_ids: set[str]):
    """Forward the IDs of preempted requests to the worker backend.

    The backend's return value, if any, is discarded.
    """
    worker = self.connector_worker
    assert worker is not None
    worker.handle_preemptions(preempted_req_ids)

register_cross_layers_kv_cache

register_cross_layers_kv_cache(
    kv_cache: Tensor, attn_backend: type[AttentionBackend]
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
    """Hand a cross-layer KV cache tensor and its backend to the worker.

    Args:
        kv_cache: the cross-layer KV cache tensor.
        attn_backend: attention backend class associated with the cache.
    """
    worker = self.connector_worker
    assert worker is not None
    worker.register_cross_layers_kv_cache(kv_cache, attn_backend)

register_kv_caches

register_kv_caches(kv_caches: dict[str, Tensor])
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Hand the per-layer KV cache tensors (keyed by name) to the worker."""
    worker = self.connector_worker
    assert worker is not None
    worker.register_kv_caches(kv_caches)

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, dict[str, Any] | None]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    assert self.connector_scheduler is not None
    return self.connector_scheduler.request_finished(request, block_ids)

save_kv_layer

save_kv_layer(
    layer_name: str,
    kv_layer: Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs,
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: "AttentionMetadata",
    **kwargs,
) -> None:
    """Intentionally does nothing; no per-layer saving happens here."""
    return None

start_load_kv

start_load_kv(
    forward_context: ForwardContext, **kwargs
) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
    """Start the transfers described by the bound connector metadata.

    Requires a worker backend and ``OffloadingConnectorMetadata`` bound as
    ``self._connector_metadata``; both are checked with ``assert``.
    """
    worker = self.connector_worker
    assert worker is not None
    metadata = self._connector_metadata
    assert isinstance(metadata, OffloadingConnectorMetadata)
    worker.start_kv_transfers(metadata)

take_events

take_events() -> Iterable[KVCacheEvent]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def take_events(self) -> Iterable[KVCacheEvent]:
    """Return the scheduler backend's pending KV cache events.

    The scheduler backend is checked eagerly, at call time.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    return scheduler.take_events()

update_connector_output

update_connector_output(
    connector_output: KVConnectorOutput,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def update_connector_output(self, connector_output: KVConnectorOutput):
    """Feed worker-side connector output back to the scheduler backend."""
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.update_connector_output(connector_output)

update_state_after_alloc

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """Notify the scheduler backend that blocks were allocated for a request.

    Args:
        request: the request that received blocks.
        blocks: the allocated KV cache blocks.
        num_external_tokens: tokens to be provided by the connector.

    Returns:
        Whatever the scheduler backend returns.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    return scheduler.update_state_after_alloc(request, blocks, num_external_tokens)

wait_for_layer_load

wait_for_layer_load(layer_name: str) -> None
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def wait_for_layer_load(self, layer_name: str) -> None:
    """Intentionally does nothing; there is no per-layer load wait."""
    return None

wait_for_save

wait_for_save()
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def wait_for_save(self):
    """Ask the worker backend to prepare the stores in the bound metadata.

    Requires a worker backend and ``OffloadingConnectorMetadata`` bound as
    ``self._connector_metadata``; both are checked with ``assert``.
    """
    worker = self.connector_worker
    assert worker is not None
    metadata = self._connector_metadata
    assert isinstance(metadata, OffloadingConnectorMetadata)
    worker.prepare_store_kv(metadata)

OffloadingConnectorMetadata dataclass

Bases: KVConnectorMetadata

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
    """Per-step transfer plan sent from the scheduler side to the workers."""

    # request ID -> transfer spec (a (src_spec, dst_spec) pair) to load
    reqs_to_load: dict[ReqId, TransferSpec]
    # request ID -> transfer spec (a (src_spec, dst_spec) pair) to store
    reqs_to_store: dict[ReqId, TransferSpec]

reqs_to_load instance-attribute

reqs_to_load: dict[ReqId, TransferSpec]

reqs_to_store instance-attribute

reqs_to_store: dict[ReqId, TransferSpec]

__init__

__init__(
    reqs_to_load: dict[ReqId, TransferSpec],
    reqs_to_store: dict[ReqId, TransferSpec],
) -> None

OffloadingConnectorScheduler

Implementation of the scheduler-side methods.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadingConnectorScheduler:
    """Implementation of the scheduler-side methods of the offloading connector.

    Responsibilities:
      * decide how many tokens of a request can be loaded from blocks known
        to the offloading manager (``get_num_new_matched_tokens``);
      * queue the corresponding loads once GPU blocks are allocated
        (``update_state_after_alloc``);
      * queue stores of newly computed full blocks (``_get_reqs_to_store``);
      * track in-flight loads/stores until workers report them finished
        (``update_connector_output``).

    An offloaded block spans ``block_size_factor`` consecutive GPU blocks and
    is identified by the hash of the last of them (see ``_get_block_hashes``).
    """

    def __init__(self, spec: OffloadingSpec):
        self.gpu_block_size = spec.gpu_block_size
        self.offloaded_block_size = spec.offloaded_block_size
        # number of GPU blocks per offloaded block
        self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
        self.manager: OffloadingManager = spec.get_manager()

        self._requests: dict[ReqId, Request] = {}
        # list of GPU block IDs per request
        self._request_block_ids: dict[ReqId, list[int]] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # request blocks are stored in order
        # index of next block (of size offloaded_block_size) to offload
        self._next_stored_block_idx: dict[ReqId, int] = {}
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[BlockHash] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )

        # request ID -> set(block hashes being stored/loaded)
        self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

    def _get_block_hashes(
        self,
        req: Request,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> Iterable[BlockHash]:
        """Return the hashes identifying offloaded blocks ``[start_idx, end_idx)``.

        Picks every ``block_size_factor``-th entry of ``req.block_hashes``,
        starting at the last GPU block of offloaded block ``start_idx``.

        Args:
            req: request whose ``block_hashes`` are read.
            start_idx: index of the first offloaded block to include.
            end_idx: index one past the last offloaded block, or ``None`` to
                run to the end of ``req.block_hashes``.

        Returns:
            A one-shot ``islice`` iterator; rebuild it to iterate again.
        """
        factor = self.block_size_factor
        # Compare against None explicitly: end_idx == 0 must be an empty
        # range, not "to the end" as a truthiness test would make it.
        stop = None if end_idx is None else factor * end_idx
        return islice(req.block_hashes, factor * start_idx + factor - 1, stop, factor)

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        num_blocks = request.num_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor == num_blocks
        block_hashes = self._get_block_hashes(request)

        # touch all of the request's full offloaded blocks in the manager
        # (presumably refreshes their recency — confirm in OffloadingManager)
        self.manager.touch(block_hashes)

        full_block_tokens = self.offloaded_block_size * num_blocks
        if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # we can load less than a block, skip
            return 0, False

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        hits = self.manager.lookup(
            self._get_block_hashes(request, start_idx=start_block_idx)
        )
        if hits is None:
            # indicates a lookup that should be tried later
            return None, False
        if hits == 0:
            return 0, False

        num_hit_tokens = (
            self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
        )
        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        if num_hit_tokens < self.offloaded_block_size:
            return 0, False

        # None (prefix caching off) or empty set: nothing can conflict
        if self._blocks_being_loaded:
            block_hashes = self._get_block_hashes(
                request, start_idx=start_block_idx, end_idx=start_block_idx + hits
            )

            if any(
                block_hash in self._blocks_being_loaded for block_hash in block_hashes
            ):
                # hit blocks are being loaded, delay request
                logger.debug(
                    "Delaying request %s since some of its blocks are already"
                    " being loaded",
                    request.request_id,
                )
                return None, False

        return num_hit_tokens, True

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """Register a request and, if needed, queue the load of its hit blocks.

        Args:
            request: the request that received GPU blocks.
            blocks: the allocated blocks; only the first group is used.
            num_external_tokens: tokens to load from offloaded blocks;
                0 means nothing is loaded.
        """
        self._requests[request.request_id] = request
        # the block ids are updated in _get_reqs_to_store
        self._request_block_ids[request.request_id] = []

        if num_external_tokens == 0:
            return

        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]

        # blocks that already carry a hash are counted as computed
        num_computed_gpu_blocks = sum(
            block.block_hash is not None for block in blocks.blocks[0]
        )
        num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
        full_block_tokens = num_computed_tokens + num_external_tokens
        assert full_block_tokens % self.offloaded_block_size == 0

        num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
        assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        num_blocks = full_block_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor >= num_blocks
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        src_spec = self.manager.prepare_load(block_hashes)
        # load into the GPU blocks that are not yet computed
        dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

        # the islice above is one-shot and may have been consumed; rebuild it
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
        req_blocks_being_loaded.update(block_hashes)
        # blocks up to num_blocks came from offloaded storage; storing resumes
        # after them
        self._next_stored_block_idx[request.request_id] = num_blocks

        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(req_blocks_being_loaded)

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        """Compute the store (offload) transfers for this scheduler step.

        For each scheduled request, refresh its tracked GPU block IDs and, if
        new full offloaded blocks will be computed after this step, ask the
        manager which of them to store and build a GPU -> offload transfer.

        Returns:
            dict mapping request ID to a ``(src_spec, dst_spec)`` pair.
        """
        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            if preempted:
                self._request_block_ids[req_id] = []

            if new_block_id_groups:
                # only the first group's block IDs are tracked
                new_block_ids = new_block_id_groups[0]
                self._request_block_ids[req_id] += new_block_ids

            block_ids = self._request_block_ids[req_id]

            req = self._requests[req_id]
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            total_tokens = req.num_computed_tokens + new_tokens
            num_blocks = total_tokens // self.offloaded_block_size
            start_block_idx = self._next_stored_block_idx.get(req_id, 0)
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            # NOTE: In async scheduling, placeholders may temporarily make
            # len(req.block_hashes) < num_blocks * self.block_size_factor.

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            store_output = self.manager.prepare_store(new_block_hashes)
            if store_output is None:
                # _next_stored_block_idx is not advanced, so a later step
                # retries these blocks
                logger.warning(
                    "Request %s: cannot store %s blocks", req_id, num_new_blocks
                )
                continue

            self._next_stored_block_idx[req_id] = num_blocks

            if not store_output.block_hashes_to_store:
                continue
            block_hashes_to_store = set(store_output.block_hashes_to_store)

            block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
            self.manager.touch(block_hashes)

            # rebuild the one-shot islice (prepare_store may have consumed it)
            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
            for idx, blk_hash in enumerate(new_block_hashes):
                if blk_hash not in block_hashes_to_store:
                    continue
                # each offloaded block maps to block_size_factor consecutive
                # GPU blocks of this request
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.block_size_factor
                for i in range(self.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(src_block_ids)

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= block_hashes_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(block_hashes_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Build this step's metadata (queued loads + new stores).

        Also resets the load queue and marks the in-flight stores of
        preempted requests as complete in the manager.
        """
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output),
        )
        self._reqs_to_load = {}

        # NOTE (orozery): we should move this logic to update_connector_output
        # once KVConnectorOutput allows us to report completed transfers
        for req_id in scheduler_output.preempted_req_ids or ():
            block_hashes = self._reqs_being_stored.get(req_id)
            if block_hashes:
                self.manager.complete_store(block_hashes)
                block_hashes.clear()

        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            block_hashes = self._reqs_being_stored.pop(req_id, None)
            if block_hashes:
                self.manager.complete_store(block_hashes)

        for req_id in connector_output.finished_recving or []:
            block_hashes = self._reqs_being_loaded.pop(req_id, None)
            if block_hashes:
                if self._blocks_being_loaded:
                    self._blocks_being_loaded.difference_update(block_hashes)
                self.manager.complete_load(block_hashes)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        self._requests.pop(req_id, None)
        self._request_block_ids.pop(req_id, None)
        self._next_stored_block_idx.pop(req_id, None)

        # an entry exists until update_connector_output pops it on
        # finished_sending
        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        This is a generator: manager events are drained lazily as it is
        iterated.

        Yields:
            ``BlockRemoved`` for removal events, ``BlockStored`` otherwise.
        """
        for event in self.manager.take_events():
            if event.removed:
                yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
            else:
                yield BlockStored(
                    block_hashes=event.block_hashes,
                    parent_block_hash=None,
                    token_ids=[],
                    lora_id=None,
                    block_size=event.block_size,
                    medium=event.medium,
                    lora_name=None,
                )

_blocks_being_loaded instance-attribute

_blocks_being_loaded: set[BlockHash] | None = (
    set() if enable_prefix_caching else None
)

_next_stored_block_idx instance-attribute

_next_stored_block_idx: dict[ReqId, int] = {}

_reqs_being_loaded instance-attribute

_reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

_reqs_being_stored instance-attribute

_reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)

_reqs_to_load instance-attribute

_reqs_to_load: dict[ReqId, TransferSpec] = {}

_request_block_ids instance-attribute

_request_block_ids: dict[ReqId, list[int]] = {}

_requests instance-attribute

_requests: dict[ReqId, Request] = {}

block_size_factor instance-attribute

block_size_factor = offloaded_block_size // gpu_block_size

gpu_block_size instance-attribute

gpu_block_size = gpu_block_size

manager instance-attribute

manager: OffloadingManager = get_manager()

offloaded_block_size instance-attribute

offloaded_block_size = offloaded_block_size

__init__

__init__(spec: OffloadingSpec)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def __init__(self, spec: OffloadingSpec):
    """Initialize scheduler-side state from an offloading spec.

    Args:
        spec: provides the GPU/offloaded block sizes, the offloading manager,
            and the vLLM config whose
            ``cache_config.enable_prefix_caching`` decides whether in-flight
            loads are tracked across requests.
    """
    gpu_block_size = spec.gpu_block_size
    offloaded_block_size = spec.offloaded_block_size
    self.gpu_block_size = gpu_block_size
    self.offloaded_block_size = offloaded_block_size
    # GPU blocks per offloaded block
    self.block_size_factor = offloaded_block_size // gpu_block_size
    self.manager: OffloadingManager = spec.get_manager()

    # Per-request bookkeeping, keyed by request ID.
    self._requests: dict[ReqId, Request] = {}
    # GPU block IDs of each request, in allocation order
    self._request_block_ids: dict[ReqId, list[int]] = {}
    # loads queued for the current scheduler step
    self._reqs_to_load: dict[ReqId, TransferSpec] = {}
    # blocks are stored in order: index of the next offloaded-size block
    # of each request still to be offloaded
    self._next_stored_block_idx: dict[ReqId, int] = {}

    # With GPU prefix caching on, remember which block hashes are being
    # loaded so another request does not load them redundantly.
    prefix_caching_on = spec.vllm_config.cache_config.enable_prefix_caching
    self._blocks_being_loaded: set[BlockHash] | None = (
        set() if prefix_caching_on else None
    )

    # request ID -> block hashes with an in-flight store / load
    self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
    self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

_get_block_hashes

_get_block_hashes(
    req: Request,
    start_idx: int = 0,
    end_idx: int | None = None,
) -> Iterable[BlockHash]
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def _get_block_hashes(
    self,
    req: Request,
    start_idx: int = 0,
    end_idx: int | None = None,
) -> Iterable[BlockHash]:
    return islice(
        req.block_hashes,
        self.block_size_factor * start_idx + self.block_size_factor - 1,
        self.block_size_factor * end_idx if end_idx else None,
        self.block_size_factor,
    )

_get_reqs_to_store

_get_reqs_to_store(scheduler_output: SchedulerOutput)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
    """Compute the store (offload) transfers to issue for this scheduler step.

    For every request yielded by ``yield_req_data`` (new and cached), refresh
    its tracked GPU block IDs. If at least one new full offloaded block will
    be computed once this step's scheduled tokens are done, ask the manager
    which of those blocks to store and build a GPU -> offload transfer for
    them.

    Args:
        scheduler_output: output of the current scheduler step.

    Returns:
        dict mapping request ID to a ``(src_spec, dst_spec)`` pair, where
        ``src_spec`` lists the GPU block IDs to copy and ``dst_spec`` is the
        store spec returned by ``manager.prepare_store``.
    """
    reqs_to_store: dict[ReqId, TransferSpec] = {}
    # iterate over both new and cached requests
    for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
        if preempted:
            # discard the previously tracked GPU block IDs; the IDs reported
            # from now on replace them
            self._request_block_ids[req_id] = []

        if new_block_id_groups:
            # only the first group's block IDs are tracked
            new_block_ids = new_block_id_groups[0]
            self._request_block_ids[req_id] += new_block_ids

        block_ids = self._request_block_ids[req_id]

        req = self._requests[req_id]
        new_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # tokens computed once this step finishes
        total_tokens = req.num_computed_tokens + new_tokens
        # number of full offloaded blocks those tokens cover
        num_blocks = total_tokens // self.offloaded_block_size
        start_block_idx = self._next_stored_block_idx.get(req_id, 0)
        num_new_blocks = num_blocks - start_block_idx

        if num_new_blocks <= 0:
            continue

        # NOTE: In async scheduling, placeholders may temporarily make
        # len(req.block_hashes) < num_blocks * self.block_size_factor.

        new_block_hashes = self._get_block_hashes(
            req, start_idx=start_block_idx, end_idx=num_blocks
        )
        store_output = self.manager.prepare_store(new_block_hashes)
        if store_output is None:
            # _next_stored_block_idx is not advanced, so these blocks are
            # considered again on a later step
            logger.warning(
                "Request %s: cannot store %s blocks", req_id, num_new_blocks
            )
            continue

        self._next_stored_block_idx[req_id] = num_blocks

        if not store_output.block_hashes_to_store:
            continue
        block_hashes_to_store = set(store_output.block_hashes_to_store)

        # touch the hashes of all of the request's offloaded blocks up to
        # num_blocks (presumably refreshes their recency — confirm)
        block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
        self.manager.touch(block_hashes)

        # new_block_hashes is a one-shot islice that prepare_store may have
        # consumed; rebuild it
        new_block_hashes = self._get_block_hashes(
            req, start_idx=start_block_idx, end_idx=num_blocks
        )
        dst_spec = store_output.store_spec
        src_block_ids: list[int] = []
        for idx, blk_hash in enumerate(new_block_hashes):
            if blk_hash not in block_hashes_to_store:
                continue
            # offloaded block k maps to the block_size_factor consecutive
            # GPU blocks starting at k * block_size_factor
            offloaded_block_idx = start_block_idx + idx
            gpu_block_idx = offloaded_block_idx * self.block_size_factor
            for i in range(self.block_size_factor):
                src_block_ids.append(block_ids[gpu_block_idx + i])
        src_spec = GPULoadStoreSpec(src_block_ids)

        reqs_to_store[req_id] = (src_spec, dst_spec)
        self._reqs_being_stored[req_id] |= block_hashes_to_store

        logger.debug(
            "Request %s offloading %s blocks starting from block #%d",
            req_id,
            len(block_hashes_to_store),
            start_block_idx,
        )

    return reqs_to_store

build_connector_meta

build_connector_meta(
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """Build this step's metadata from queued loads and newly computed stores.

    Side effects: resets the load queue. For every preempted request, any
    in-flight stores are reported to the manager as complete and then
    forgotten.

    Args:
        scheduler_output: output of the current scheduler step.

    Returns:
        An ``OffloadingConnectorMetadata`` with ``reqs_to_load`` and
        ``reqs_to_store`` filled in.
    """
    pending_loads = self._reqs_to_load
    pending_stores = self._get_reqs_to_store(scheduler_output)
    meta = OffloadingConnectorMetadata(
        reqs_to_load=pending_loads,
        reqs_to_store=pending_stores,
    )
    self._reqs_to_load = {}

    # NOTE (orozery): we should move this logic to update_connector_output
    # once KVConnectorOutput allows us to report completed transfers
    preempted_ids = scheduler_output.preempted_req_ids or ()
    for req_id in preempted_ids:
        in_flight = self._reqs_being_stored.get(req_id)
        if not in_flight:
            continue
        self.manager.complete_store(in_flight)
        in_flight.clear()

    return meta

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]

Get number of new tokens that can be loaded beyond the num_computed_tokens.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `request` | `Request` | The request object. | required |
| `num_computed_tokens` | `int` | The number of locally computed tokens for this request. | required |

Returns:

Type Description
tuple[int | None, bool]

A tuple with the following elements:

- The number of tokens that can be loaded beyond what is already computed. If `None`, the connector needs more time to determine the number of matched tokens, and the scheduler should query for this request again later.
- `True` if tokens will be loaded asynchronously (between scheduler steps).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_num_new_matched_tokens(
    self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
    """
    Get number of new tokens that can be loaded beyond the
    num_computed_tokens.

    Only whole offloaded blocks are reported: the result is 0 whenever fewer
    than ``self.offloaded_block_size`` tokens could be loaded. The lookup
    starts at the offloaded block containing ``num_computed_tokens`` and is
    delegated to ``self.manager.lookup``.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple with the following elements:
            - The number of tokens that can be loaded beyond what is
              already computed.
              If None, it means that the connector needs more time to
              determine the number of matched tokens, and the scheduler
              should query for this request again later. Here that happens
              when the manager lookup returns None, or when some of the
              hit blocks are already being loaded.
            - `True` if tokens will be loaded asynchronously
              (between scheduler steps).
    """
    # number of complete offloaded blocks covered by the request's tokens
    num_blocks = request.num_tokens // self.offloaded_block_size

    # each offloaded block presumably spans block_size_factor entries of
    # request.block_hashes — TODO confirm against _get_block_hashes
    assert len(request.block_hashes) // self.block_size_factor == num_blocks
    block_hashes = self._get_block_hashes(request)

    # NOTE(review): presumably refreshes the recency of all of this request's
    # blocks in the manager (e.g. for eviction) — confirm in the manager.
    self.manager.touch(block_hashes)

    full_block_tokens = self.offloaded_block_size * num_blocks
    if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
        # we can load less than a block, skip
        return 0, False

    # offloaded block that contains the first not-yet-computed token
    start_block_idx = num_computed_tokens // self.offloaded_block_size
    # Hashes are fetched again rather than reusing `block_hashes` above.
    # NOTE(review): consistent with _get_block_hashes returning a one-shot
    # iterator that touch() already consumed — confirm before reusing it.
    hits = self.manager.lookup(
        self._get_block_hashes(request, start_idx=start_block_idx)
    )
    if hits is None:
        # indicates a lookup that should be tried later
        return None, False
    if hits == 0:
        return 0, False

    # tokens from num_computed_tokens up to the end of the last hit block
    # (presumably `hits` counts leading hit blocks from start_block_idx)
    num_hit_tokens = (
        self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
    )
    logger.debug(
        "Request %s hit %s offloaded tokens after %s GPU hit tokens",
        request.request_id,
        num_hit_tokens,
        num_computed_tokens,
    )
    # a partially-computed first block can leave less than one block to load
    if num_hit_tokens < self.offloaded_block_size:
        return 0, False

    if self._blocks_being_loaded:
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=start_block_idx + hits
        )

        if any(
            block_hash in self._blocks_being_loaded for block_hash in block_hashes
        ):
            # hit blocks are being loaded, delay request
            logger.debug(
                "Delaying request %s since some of its blocks are already"
                " being loaded",
                request.request_id,
            )
            return None, False

    # hit tokens will be loaded asynchronously
    return num_hit_tokens, True

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, dict[str, Any] | None]

Called when a request has finished, before its blocks are freed.

Returns:

Type Description
tuple[bool, dict[str, Any] | None]

A tuple with the following elements: (1) True if the request is being saved/sent asynchronously and its blocks should not be freed until the request_id is returned from get_finished(); (2) optional KVTransferParams to be included in the request outputs returned by the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def request_finished(
    self,
    request: Request,
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """
    Called when a request has finished, before its blocks are freed.

    Drops the request from the scheduler-side bookkeeping maps
    (``_requests``, ``_request_block_ids``, ``_next_stored_block_idx``).
    ``block_ids`` is not used by this connector.

    Returns:
        A ``(delay_free, kv_transfer_params)`` tuple. ``delay_free`` is True
        when the request still has an entry in ``_reqs_being_stored``; in
        that case its blocks should not be freed until the request_id is
        returned from get_finished(). ``kv_transfer_params`` is always None.
    """
    rid = request.request_id
    for bookkeeping in (
        self._requests,
        self._request_block_ids,
        self._next_stored_block_idx,
    ):
        bookkeeping.pop(rid, None)

    still_storing = rid in self._reqs_being_stored
    return still_storing, None

take_events

take_events() -> Iterable[KVCacheEvent]

Take the KV cache events from the connector.

Returns:

Type Description
Iterable[KVCacheEvent]

An iterable of KV cache events.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def take_events(self) -> Iterable[KVCacheEvent]:
    """Take the KV cache events from the connector.

    Each event drained from ``self.manager.take_events()`` is translated to
    a ``BlockRemoved`` event when its ``removed`` flag is set, and to a
    ``BlockStored`` event otherwise. Stored events carry no parent hash,
    token IDs or LoRA information.

    Returns:
        An iterable of KV cache events (this method is a generator).
    """
    for manager_event in self.manager.take_events():
        is_removal = manager_event.removed
        hashes = manager_event.block_hashes
        medium = manager_event.medium
        if is_removal:
            yield BlockRemoved(block_hashes=hashes, medium=medium)
            continue
        yield BlockStored(
            block_hashes=hashes,
            parent_block_hash=None,
            token_ids=[],
            lora_id=None,
            block_size=manager_event.block_size,
            medium=medium,
            lora_name=None,
        )

update_connector_output

update_connector_output(
    connector_output: KVConnectorOutput,
)

Update the KVConnector state from the worker-side connectors' output.

Parameters:

Name Type Description Default
connector_output KVConnectorOutput

the worker-side connectors' output.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from the worker-side connectors' output.

    For each request reported in ``finished_sending``, its tracked store
    hashes are removed from ``_reqs_being_stored`` and, if non-empty,
    reported to ``self.manager.complete_store``. For each request reported
    in ``finished_recving``, its tracked load hashes are removed from
    ``_reqs_being_loaded``; if non-empty, they are also dropped from the
    global ``_blocks_being_loaded`` set (when that set is non-empty) and
    reported to ``self.manager.complete_load``.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors' output.
    """
    stored_ids = connector_output.finished_sending or []
    for finished_id in stored_ids:
        stored_hashes = self._reqs_being_stored.pop(finished_id, None)
        if not stored_hashes:
            continue
        self.manager.complete_store(stored_hashes)

    loaded_ids = connector_output.finished_recving or []
    for finished_id in loaded_ids:
        loaded_hashes = self._reqs_being_loaded.pop(finished_id, None)
        if not loaded_hashes:
            continue
        if self._blocks_being_loaded:
            self._blocks_being_loaded.difference_update(loaded_hashes)
        self.manager.complete_load(loaded_hashes)

update_state_after_alloc

update_state_after_alloc(
    request: Request,
    blocks: KVCacheBlocks,
    num_external_tokens: int,
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def update_state_after_alloc(
    self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
    """Register a newly allocated request and, if needed, queue its load.

    The request is always recorded in ``self._requests`` with an empty
    block-id list. When ``num_external_tokens`` is non-zero, a load from the
    offloading manager into the request's not-yet-computed GPU blocks is
    queued in ``self._reqs_to_load``, and the hashes being loaded are tracked
    in ``self._reqs_being_loaded`` (and in ``self._blocks_being_loaded`` when
    that set is not None).

    Args:
        request: the request whose blocks were just allocated.
        blocks: the KV cache blocks allocated for the request.
        num_external_tokens: number of tokens to load from the offloaded
            medium (presumably the count previously returned by
            get_num_new_matched_tokens — confirm with the scheduler).
    """
    self._requests[request.request_id] = request
    # the block ids are updated in _get_reqs_to_store
    self._request_block_ids[request.request_id] = []

    if num_external_tokens == 0:
        return

    # only the first KV cache group is considered here
    block_groups = blocks.get_block_ids()
    block_ids = block_groups[0]

    # blocks that already have a block hash are counted as computed on GPU
    num_computed_gpu_blocks = sum(
        block.block_hash is not None for block in blocks.blocks[0]
    )
    num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
    full_block_tokens = num_computed_tokens + num_external_tokens
    # the load must end on an offloaded-block boundary
    assert full_block_tokens % self.offloaded_block_size == 0

    # every allocated GPU block past the computed ones is filled by the load
    num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
    assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size

    # range [start_block_idx, num_blocks) of offloaded blocks to load
    start_block_idx = num_computed_tokens // self.offloaded_block_size
    num_blocks = full_block_tokens // self.offloaded_block_size

    assert len(request.block_hashes) // self.block_size_factor >= num_blocks
    block_hashes = self._get_block_hashes(
        request, start_idx=start_block_idx, end_idx=num_blocks
    )

    # source: offloaded blocks; destination: the not-yet-computed GPU blocks
    src_spec = self.manager.prepare_load(block_hashes)
    dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

    # Fetched again rather than reusing the value above.
    # NOTE(review): consistent with _get_block_hashes returning a one-shot
    # iterator consumed by prepare_load — confirm before removing.
    block_hashes = self._get_block_hashes(
        request, start_idx=start_block_idx, end_idx=num_blocks
    )

    self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
    # assumes _reqs_being_loaded is a defaultdict(set) — defined outside view
    req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
    req_blocks_being_loaded.update(block_hashes)
    # presumably so that blocks being loaded here are not stored back
    self._next_stored_block_idx[request.request_id] = num_blocks

    if self._blocks_being_loaded is not None:
        self._blocks_being_loaded.update(req_blocks_being_loaded)

OffloadingConnectorWorker

Implementation of the worker-side methods.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadingConnectorWorker:
    """Implementation of Worker side methods.

    Turns the loads and stores listed in ``OffloadingConnectorMetadata``
    into asynchronous jobs on an ``OffloadingWorker`` and reports, per
    request, when all of those jobs have finished.
    """

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        # source of job IDs: 0, 1, 2, ...
        self._job_counter = 0

        # job_id -> (req_id, is_store)
        self._jobs: dict[int, tuple[ReqId, bool]] = {}
        # req_id -> ID of its active load job (at most one per request)
        self._load_job: dict[ReqId, int] = {}
        # req_id -> IDs of its active store jobs (submitted or deferred)
        self._store_jobs = defaultdict[ReqId, set[int]](set)
        # store jobs created but not yet handed to the worker:
        # (job_id, transfer_spec)
        self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []

        # finished requests whose store jobs are still in flight
        self._finished_reqs_waiting_for_store: set[ReqId] = set()

    def _generate_job_id(self) -> int:
        """Return the next job ID and advance the counter."""
        issued = self._job_counter
        self._job_counter += 1
        return issued

    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        """Register every (src, dst, handler) triple the spec provides."""
        handlers = self.spec.get_handlers(kv_caches, attn_backends)
        for src_cls, dst_cls, handler in handlers:
            self.worker.register_handler(src_cls, dst_cls, handler)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register per-layer KV caches, resolving each layer's backend."""
        layer_names = list(kv_caches)
        attn_layers = get_layers_from_vllm_config(
            self.spec.vllm_config, Attention, layer_names
        )
        attn_backends = {
            name: attn_layers[name].get_attn_backend() for name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Register a single KV cache tensor shared by all layers.

        The tensor is registered under the pseudo layer name "ALL_LAYERS".
        """
        self._register_handlers(
            {"ALL_LAYERS": kv_cache}, {"ALL_LAYERS": attn_backend}
        )

    def _flush_unsubmitted_store_jobs(self):
        """Submit all deferred store jobs to the worker, then empty the queue."""
        queue = self._unsubmitted_store_jobs
        for queued_job_id, queued_spec in queue:
            accepted = self.worker.transfer_async(queued_job_id, queued_spec)
            assert accepted
        queue.clear()

    def handle_preemptions(self, preempted_req_ids: set[str]):
        """Block until the store jobs of preempted requests are done.

        Deferred store jobs are submitted first. This is presumably needed
        because ``_store_jobs`` also holds IDs of deferred jobs (see
        prepare_store_kv), which must be submitted before they are waited on.
        """
        self._flush_unsubmitted_store_jobs()

        for req_id in preempted_req_ids:
            store_job_ids = self._store_jobs.get(req_id)
            if store_job_ids:
                self.worker.wait(store_job_ids)

    def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
        """Submit deferred stores, then start one load job per request."""
        self._flush_unsubmitted_store_jobs()

        for req_id, load_spec in metadata.reqs_to_load.items():
            load_job_id = self._generate_job_id()
            self._jobs[load_job_id] = (req_id, False)
            assert req_id not in self._load_job
            self._load_job[req_id] = load_job_id
            accepted = self.worker.transfer_async(load_job_id, load_spec)
            assert accepted

    def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
        """Create (but do not yet submit) one store job per request."""
        for req_id, store_spec in metadata.reqs_to_store.items():
            store_job_id = self._generate_job_id()
            self._jobs[store_job_id] = (req_id, True)
            self._store_jobs[req_id].add(store_job_id)
            # NOTE(orozery): defer the store to the beginning of the next engine step,
            # so that offloading starts AFTER transfers related to token sampling,
            # thereby avoiding delays to token generation due to offloading.
            self._unsubmitted_store_jobs.append((store_job_id, store_spec))

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """
        Notify the worker-side connector of the IDs of requests that have
        finished generating tokens, and collect requests whose transfers are
        done.

        A finished request is reported as "sending" only once all of its
        store jobs have completed; until then it is parked in
        ``_finished_reqs_waiting_for_store``.

        Returns:
            A tuple of (sending/saving IDs, recving/loading IDs) of requests
            whose asynchronous transfers have finished.
        """
        done_storing: set[str] = set()
        done_loading: set[str] = set()

        for job_id, ok in self.worker.get_finished():
            # job failures are not supported yet
            assert ok
            req_id, is_store = self._jobs.pop(job_id)

            if not is_store:
                expected_job_id = self._load_job[req_id]
                assert job_id == expected_job_id
                del self._load_job[req_id]
                done_loading.add(req_id)
                continue

            remaining = self._store_jobs[req_id]
            remaining.remove(job_id)
            if remaining or req_id not in self._finished_reqs_waiting_for_store:
                # either more stores are in flight, or the request has not
                # finished yet (its empty entry is kept for a later call)
                continue
            self._finished_reqs_waiting_for_store.remove(req_id)
            done_storing.add(req_id)
            del self._store_jobs[req_id]

        for req_id in finished_req_ids:
            pending = self._store_jobs.get(req_id)
            if pending is None:
                # the request never had a store job
                continue
            if pending:
                self._finished_reqs_waiting_for_store.add(req_id)
            else:
                done_storing.add(req_id)
                del self._store_jobs[req_id]

        return done_storing, done_loading

_finished_reqs_waiting_for_store instance-attribute

_finished_reqs_waiting_for_store: set[ReqId] = set()

_job_counter instance-attribute

_job_counter = 0

_jobs instance-attribute

_jobs: dict[int, tuple[ReqId, bool]] = {}

_load_job instance-attribute

_load_job: dict[ReqId, int] = {}

_store_jobs instance-attribute

_store_jobs = defaultdict[ReqId, set[int]](set)

_unsubmitted_store_jobs instance-attribute

_unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []

spec instance-attribute

spec = spec

worker instance-attribute

worker = OffloadingWorker()

__init__

__init__(spec: OffloadingSpec)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def __init__(self, spec: OffloadingSpec):
    # the spec supplies the vllm config and the transfer handlers
    self.spec = spec
    self.worker = OffloadingWorker()

    # next job ID to hand out
    self._job_counter = 0

    # job_id -> (req_id, is_store)
    self._jobs: dict[int, tuple[ReqId, bool]] = {}
    # req_id -> ID of its single active load job
    self._load_job: dict[ReqId, int] = {}
    # req_id -> IDs of its active store jobs (submitted or deferred)
    self._store_jobs: defaultdict[ReqId, set[int]] = defaultdict(set)
    # store jobs awaiting submission, as (job_id, transfer_spec)
    self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []

    # finished requests that still have store jobs in flight
    self._finished_reqs_waiting_for_store: set[ReqId] = set()

_generate_job_id

_generate_job_id() -> int
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def _generate_job_id(self) -> int:
    job_id = self._job_counter
    self._job_counter = job_id + 1
    return job_id

_register_handlers

_register_handlers(
    kv_caches: dict[str, Tensor],
    attn_backends: dict[str, type[AttentionBackend]],
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def _register_handlers(
    self,
    kv_caches: dict[str, torch.Tensor],
    attn_backends: dict[str, type[AttentionBackend]],
):
    """Register with the worker every (src, dst, handler) the spec yields."""
    handlers = self.spec.get_handlers(kv_caches, attn_backends)
    for src_cls, dst_cls, handler in handlers:
        self.worker.register_handler(src_cls, dst_cls, handler)

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str], set[str]]

Notifies the worker-side connector of the IDs of requests that have finished generating tokens. Returns the IDs of requests whose loading or storing has finished.

Returns:

Type Description
tuple[set[str], set[str]]

The IDs of requests that have finished their asynchronous transfers, as a tuple of (sending/saving IDs, recving/loading IDs).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
    """
    Notify the worker-side connector of the IDs of requests that have
    finished generating tokens, and collect requests whose transfers are
    done.

    A finished request is reported as "sending" only after all of its store
    jobs have completed; until then it is kept in
    ``_finished_reqs_waiting_for_store``.

    Returns:
        A tuple of (sending/saving IDs, recving/loading IDs) of requests
        whose asynchronous transfers have finished.
    """
    done_storing: set[str] = set()
    done_loading: set[str] = set()

    for job_id, ok in self.worker.get_finished():
        # job failures are not supported yet
        assert ok
        req_id, is_store = self._jobs.pop(job_id)

        if not is_store:
            expected_job_id = self._load_job[req_id]
            assert job_id == expected_job_id
            del self._load_job[req_id]
            done_loading.add(req_id)
            continue

        remaining = self._store_jobs[req_id]
        remaining.remove(job_id)
        if remaining or req_id not in self._finished_reqs_waiting_for_store:
            # more stores in flight, or the request has not finished yet
            # (its now-empty entry is kept for a later call)
            continue
        self._finished_reqs_waiting_for_store.remove(req_id)
        done_storing.add(req_id)
        del self._store_jobs[req_id]

    for req_id in finished_req_ids:
        pending = self._store_jobs.get(req_id)
        if pending is None:
            # the request never had a store job
            continue
        if pending:
            self._finished_reqs_waiting_for_store.add(req_id)
        else:
            done_storing.add(req_id)
            del self._store_jobs[req_id]

    return done_storing, done_loading

handle_preemptions

handle_preemptions(preempted_req_ids: set[str])
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def handle_preemptions(self, preempted_req_ids: set[str]):
    """Block until the store jobs of the preempted requests have completed.

    Deferred store jobs are submitted first. This is presumably needed
    because ``_store_jobs`` also holds the IDs of deferred jobs (see
    prepare_store_kv), which must be submitted before they are waited on.
    """
    queue = self._unsubmitted_store_jobs
    for queued_job_id, queued_spec in queue:
        accepted = self.worker.transfer_async(queued_job_id, queued_spec)
        assert accepted
    queue.clear()

    for preempted_id in preempted_req_ids:
        active_store_jobs = self._store_jobs.get(preempted_id)
        if not active_store_jobs:
            continue
        self.worker.wait(active_store_jobs)

prepare_store_kv

prepare_store_kv(metadata: OffloadingConnectorMetadata)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
    """Create one store job per request in ``metadata.reqs_to_store``.

    The jobs are registered in ``_jobs`` and ``_store_jobs`` right away, but
    only queued in ``_unsubmitted_store_jobs`` rather than submitted.
    """
    for req_id, store_spec in metadata.reqs_to_store.items():
        store_job_id = self._generate_job_id()
        self._jobs[store_job_id] = (req_id, True)
        self._store_jobs[req_id].add(store_job_id)
        # NOTE(orozery): defer the store to the beginning of the next engine step,
        # so that offloading starts AFTER transfers related to token sampling,
        # thereby avoiding delays to token generation due to offloading.
        self._unsubmitted_store_jobs.append((store_job_id, store_spec))

register_cross_layers_kv_cache

register_cross_layers_kv_cache(
    kv_cache: Tensor, attn_backend: type[AttentionBackend]
)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
    """Register one KV cache tensor shared by all layers.

    The tensor and its backend are registered under the pseudo layer name
    "ALL_LAYERS".
    """
    self._register_handlers(
        {"ALL_LAYERS": kv_cache},
        {"ALL_LAYERS": attn_backend},
    )

register_kv_caches

register_kv_caches(kv_caches: dict[str, Tensor])
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register per-layer KV caches, looking up each layer's attention backend."""
    layer_names = list(kv_caches)
    attn_layers = get_layers_from_vllm_config(
        self.spec.vllm_config, Attention, layer_names
    )
    attn_backends = {
        name: attn_layers[name].get_attn_backend() for name in layer_names
    }
    self._register_handlers(kv_caches, attn_backends)

start_kv_transfers

start_kv_transfers(metadata: OffloadingConnectorMetadata)
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
    """Submit the deferred store jobs, then start one load job per request.

    Each request may have at most one active load job; this is asserted.
    """
    queue = self._unsubmitted_store_jobs
    for queued_job_id, queued_spec in queue:
        accepted = self.worker.transfer_async(queued_job_id, queued_spec)
        assert accepted
    queue.clear()

    for req_id, load_spec in metadata.reqs_to_load.items():
        load_job_id = self._generate_job_id()
        self._jobs[load_job_id] = (req_id, False)
        assert req_id not in self._load_job
        self._load_job[req_id] = load_job_id
        accepted = self.worker.transfer_async(load_job_id, load_spec)
        assert accepted