vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass

CutlassFP8ScaledMMLinearKernel

Bases: FP8ScaledMMLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "requires CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        # Fused GEMM_DQ
        output = ops.cutlass_scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )
        return output.view(*output_shape)

apply_scaled_mm

apply_scaled_mm(
    *,
    A: Tensor,
    B: Tensor,
    out_dtype: dtype,
    As: Tensor,
    Bs: Tensor,
    bias: Tensor | None,
    output_shape: list,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
def apply_scaled_mm(
    self,
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    # Fused GEMM_DQ
    output = ops.cutlass_scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    return output.view(*output_shape)
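
For intuition, here is a minimal unfused reference for what the fused GEMM + dequantization above computes. The tensors, shapes (per-token As, per-channel Bs), and dtypes below are illustrative assumptions, not requirements of the kernel:

import torch

# Illustrative shapes: M tokens, K input features, N output features.
M, K, N = 4, 64, 32
A = torch.randn(M, K).to(torch.float8_e4m3fn)   # quantized activations
B = torch.randn(K, N).to(torch.float8_e4m3fn)   # quantized weight (already transposed)
As = torch.rand(M, 1)                           # per-token activation scales
Bs = torch.rand(1, N)                           # per-channel weight scales
bias = None
out_dtype = torch.bfloat16
output_shape = [M, N]

# Unfused reference (illustration only; the kernel fuses this into one pass).
ref = (A.to(torch.float32) * As) @ (B.to(torch.float32) * Bs)
if bias is not None:
    ref = ref + bias
ref = ref.to(out_dtype).view(*output_shape)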

can_implement classmethod

can_implement(
    c: FP8ScaledMMLinearLayerConfig,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
    return True, None

is_supported classmethod

is_supported(
    compute_capability: int | None = None,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@classmethod
def is_supported(
    cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
    if not current_platform.is_cuda():
        return False, "requires CUDA."
    return True, None

CutlassInt8ScaledMMLinearKernel

Bases: Int8ScaledMMLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "requires CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
        config = self.config
        # WEIGHT
        # Cutlass kernels need transposed weight.
        weight = getattr(layer, w_q_name)
        replace_parameter(
            layer,
            w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False),
        )

        # WEIGHT SCALE
        # Cutlass kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, w_s_name)
        if is_fused_module and not config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # INPUT SCALE
        if config.is_static_input_scheme:
            input_scale = getattr(layer, i_s_name)

            if config.input_symmetric:
                replace_parameter(
                    layer,
                    i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
                setattr(layer, i_zp_name, None)
            else:
                input_zero_point = getattr(layer, i_zp_name)

                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )

                # AZP loaded as int8 but used as int32
                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
                replace_parameter(
                    layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )

        # azp_adj is the AZP adjustment term, used to account for weights.
        # It does not depend on scales or azp, so it is the same for
        # static and dynamic quantization.
        # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
        # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
        if not config.input_symmetric:
            weight = getattr(layer, w_q_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
            if config.is_static_input_scheme:
                # cutlass_w8a8 requires azp to be folded into azp_adj
                # in the per-tensor case
                azp_adj = getattr(layer, i_zp_name) * azp_adj
            setattr(
                layer,
                azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        x_q, x_s, x_zp = ops.scaled_int8_quant(
            x.contiguous(), i_s, i_zp, symmetric=symmetric
        )

        if x_zp is not None:
            # Currently, static is always per-tensor and dynamic is per-token
            static = i_zp is not None
            azp = None if static else x_zp
            return ops.cutlass_scaled_mm_azp(
                x_q,
                w_q,
                scale_a=x_s,
                scale_b=w_s,
                out_dtype=x.dtype,
                azp_adj=azp_adj,
                azp=azp,
                bias=bias,
            )
        return ops.cutlass_scaled_mm(
            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
        )

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

    # ops.scaled_int8_quant supports both dynamic and static quant:
    # * dynamic, i_s is None and x_s computed from x.
    # * static, i_s is scalar and x_s is i_s.
    symmetric = azp_adj is None
    x_q, x_s, x_zp = ops.scaled_int8_quant(
        x.contiguous(), i_s, i_zp, symmetric=symmetric
    )

    if x_zp is not None:
        # Currently, static is always per-tensor and dynamic is per-token
        static = i_zp is not None
        azp = None if static else x_zp
        return ops.cutlass_scaled_mm_azp(
            x_q,
            w_q,
            scale_a=x_s,
            scale_b=w_s,
            out_dtype=x.dtype,
            azp_adj=azp_adj,
            azp=azp,
            bias=bias,
        )
    return ops.cutlass_scaled_mm(
        x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
    )

can_implement classmethod

can_implement(
    c: Int8ScaledMMLinearLayerConfig,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
    return True, None

is_supported classmethod

is_supported(
    compute_capability: int | None = None,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@classmethod
def is_supported(
    cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
    if not current_platform.is_cuda():
        return False, "requires CUDA."
    return True, None

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
    config = self.config
    # WEIGHT
    # Cutlass kernels need transposed weight.
    weight = getattr(layer, w_q_name)
    replace_parameter(
        layer,
        w_q_name,
        torch.nn.Parameter(weight.t().data, requires_grad=False),
    )

    # WEIGHT SCALE
    # Cutlass kernels support only per-tensor and per-channel.
    # If we have a fused module (QKV, MLP) with per tensor scales (thus N
    # scales being passed to the kernel), convert to the per-channel case.
    is_fused_module = len(layer.logical_widths) > 1
    weight_scale = getattr(layer, w_s_name)
    if is_fused_module and not config.is_channelwise:
        weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
    replace_parameter(
        layer,
        w_s_name,
        torch.nn.Parameter(weight_scale.data, requires_grad=False),
    )

    # INPUT SCALE
    if config.is_static_input_scheme:
        input_scale = getattr(layer, i_s_name)

        if config.input_symmetric:
            replace_parameter(
                layer,
                i_s_name,
                torch.nn.Parameter(input_scale.max(), requires_grad=False),
            )
            setattr(layer, i_zp_name, None)
        else:
            input_zero_point = getattr(layer, i_zp_name)

            # reconstruct the ranges
            int8_traits = torch.iinfo(torch.int8)
            azps = input_zero_point.to(dtype=torch.int32)
            range_max = (input_scale * (int8_traits.max - azps)).max()
            range_min = (input_scale * (int8_traits.min - azps)).min()

            scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
            replace_parameter(
                layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
            )

            # AZP loaded as int8 but used as int32
            azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
            replace_parameter(
                layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
            )

    # azp_adj is the AZP adjustment term, used to account for weights.
    # It does not depend on scales or azp, so it is the same for
    # static and dynamic quantization.
    # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
    # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
    if not config.input_symmetric:
        weight = getattr(layer, w_q_name)
        azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
        if config.is_static_input_scheme:
            # cutlass_w8a8 requires azp to be folded into azp_adj
            # in the per-tensor case
            azp_adj = getattr(layer, i_zp_name) * azp_adj
        setattr(
            layer,
            azp_adj_name,
            torch.nn.Parameter(azp_adj, requires_grad=False),
        )