vllm.attention

Modules:

Name       Description
backends
layer      Attention layer.
layers
ops
selector
utils

__all__ module-attribute

__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "get_attn_backend",
]

Attention

Bases: Module, AttentionLayerBase

Attention layer.

This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. The class does the following:

  1. Store the input key and value tensors in the KV cache.
  2. Perform (multi-head/multi-query/grouped-query) attention.
  3. Return the output tensor.
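
A minimal usage sketch (not taken from the vLLM source) of how a decoder block might construct and call this layer. The shapes, config values, and prefix below are illustrative assumptions, and construction normally happens inside vLLM's model-building context, where get_current_vllm_config() returns a populated config.

# Illustrative sketch only. Assumes it runs inside vLLM's model-construction
# context and that the prefix is unique across layers.
import torch
from vllm.attention import Attention, AttentionType

num_heads, num_kv_heads, head_size = 32, 8, 128

attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,                   # standard 1/sqrt(head_size) scaling
    num_kv_heads=num_kv_heads,               # grouped-query attention
    prefix="model.layers.0.self_attn.attn",  # must be unique per layer
    attn_type=AttentionType.DECODER,
)

# In the model's forward pass, q/k/v come from the QKV projection, flattened to
# [num_tokens, num_heads * head_size] (kv heads for k/v):
# out = attn(q, k, v)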
Source code in vllm/attention/layer.py
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
            )
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input()
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input():
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        self.impl.process_weights_after_loading(act_dtype)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )

attn_backend instance-attribute

attn_backend = get_attn_backend(
    head_size,
    dtype,
    kv_cache_dtype,
    block_size,
    use_mla=False,
    has_sink=has_sink,
)

attn_type instance-attribute

attn_type = attn_type

backend instance-attribute

backend = AttentionBackendEnum[attn_backend.get_name()]

dtype instance-attribute

dtype = dtype

has_sink instance-attribute

has_sink = extra_impl_args.get("sinks") is not None

head_size instance-attribute

head_size = head_size

impl instance-attribute

impl = impl_cls(
    num_heads,
    head_size,
    scale,
    num_kv_heads,
    alibi_slopes,
    sliding_window,
    kv_cache_dtype,
    logits_soft_cap,
    attn_type,
    kv_sharing_target_layer_name,
    **extra_impl_args,
)

k_range instance-attribute

k_range = tensor(K_SCALE_CONSTANT, dtype=float32)

kv_cache instance-attribute

kv_cache = [
    tensor([]) for _ in range(pipeline_parallel_size)
]

kv_cache_torch_dtype instance-attribute

kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
    kv_cache_dtype, model_config
)

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

q_range instance-attribute

q_range = tensor(Q_SCALE_CONSTANT, dtype=float32)

query_quant instance-attribute

query_quant = None

sliding_window instance-attribute

sliding_window = sliding_window

use_direct_call instance-attribute

use_direct_call = not current_platform.opaque_attention_op()

use_output instance-attribute

use_output = attn_backend.accept_output_buffer

v_range instance-attribute

v_range = tensor(V_SCALE_CONSTANT, dtype=float32)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    logits_soft_cap: float | None = None,
    per_layer_sliding_window: int | None = None,
    prefix: str = "",
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: str | None = None,
    attn_backend: type[AttentionBackend] | None = None,
    **extra_impl_args,
) -> None

The KV cache is stored inside this class and is accessed via self.kv_cache.
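
To make the cache plumbing concrete, here is a rough sketch of what the placeholder initialization in the constructor amounts to; the pipeline-parallel size is an assumed value, and the real cache layout is backend-dependent.

# Illustrative only: one empty placeholder tensor per pipeline-parallel rank
# ("virtual engine"). The engine later replaces each entry with the real paged
# KV-cache tensor via bind_kv_cache; the tensor's shape comes from the chosen
# backend's get_kv_cache_shape(...) and varies by backend.
import torch

pipeline_parallel_size = 2      # assumed value
kv_cache = [torch.tensor([]) for _ in range(pipeline_parallel_size)]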

Source code in vllm/attention/layer.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    logits_soft_cap: float | None = None,
    per_layer_sliding_window: int | None = None,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: str | None = None,
    attn_backend: type[AttentionBackend] | None = None,
    **extra_impl_args,
) -> None:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.
    """
    super().__init__()
    if per_layer_sliding_window is not None:
        # per-layer sliding window
        sliding_window = per_layer_sliding_window
    elif cache_config is not None:
        # model-level sliding window
        sliding_window = cache_config.sliding_window
    else:
        sliding_window = None

    vllm_config = get_current_vllm_config()
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        calculate_kv_scales = cache_config.calculate_kv_scales
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False
    self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
        kv_cache_dtype, vllm_config.model_config
    )
    if num_kv_heads is None:
        num_kv_heads = num_heads
    assert num_heads % num_kv_heads == 0, (
        f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
    )

    # Initialize KV cache quantization attributes
    _init_kv_cache_quant(
        self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
    )

    self.num_heads = num_heads
    self.head_size = head_size
    self.num_kv_heads = num_kv_heads
    self.sliding_window = sliding_window
    self.has_sink = extra_impl_args.get("sinks") is not None

    # During model initialization, the default dtype is set as the model
    # weight and activation dtype.
    dtype = torch.get_default_dtype()
    if attn_backend is None:
        self.attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=False,
            has_sink=self.has_sink,
        )
    else:
        self.attn_backend = attn_backend

    impl_cls = self.attn_backend.get_impl_cls()
    self.impl = impl_cls(
        num_heads,
        head_size,
        scale,
        num_kv_heads,
        alibi_slopes,
        sliding_window,
        kv_cache_dtype,
        logits_soft_cap,
        attn_type,
        kv_sharing_target_layer_name,
        **extra_impl_args,
    )
    self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
    self.dtype = dtype

    # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
    # torch.compile works by registering the attention as one giant
    # opaque custom op. For other platforms, we directly call them
    # and let torch.compile handle them.
    self.use_direct_call = not current_platform.opaque_attention_op()

    self.use_output = self.attn_backend.accept_output_buffer
    compilation_config = vllm_config.compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
    self.layer_name = prefix
    self.attn_type = attn_type

    if kv_sharing_target_layer_name is not None:
        validate_kv_sharing_target(
            prefix,
            kv_sharing_target_layer_name,
            compilation_config.static_forward_context,
        )
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    # use a placeholder kv cache tensor during init, which will be replaced
    # by bind_kv_cache
    # this variable will not be accessed if use_direct_call is True
    self.kv_cache = [
        torch.tensor([])
        for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
    ]

    # Initialize q/k/v range constants.
    self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    # for attn backends supporting query quantization
    self.query_quant = None
    if (
        self.kv_cache_dtype.startswith("fp8")
        and self.impl.supports_quant_query_input()
    ):
        self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

calc_kv_scales

calc_kv_scales(query, key, value)
Source code in vllm/attention/layer.py
def calc_kv_scales(self, query, key, value):
    self._q_scale.copy_(torch.abs(query).max() / self.q_range)
    self._k_scale.copy_(torch.abs(key).max() / self.k_range)
    self._v_scale.copy_(torch.abs(value).max() / self.v_range)
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()
    # We only calculate the scales once
    self.calculate_kv_scales = False

extra_repr

extra_repr() -> str
Source code in vllm/attention/layer.py
def extra_repr(self) -> str:
    s = f"head_size={self.impl.head_size}"  # type: ignore
    s += f", num_heads={self.impl.num_heads}"  # type: ignore
    s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
    s += f", scale={self.impl.scale}"  # type: ignore
    s += f", backend={self.impl.__class__.__name__}"
    return s

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor

The KV cache is stored inside this class and is accessed via self.kv_cache.

Attention metadata (attn_metadata) is set using a context manager in the model runner's execute_model method. It is accessed via forward context using vllm.forward_context.get_forward_context().attn_metadata.
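
A small sketch of the per-layer metadata lookup that the direct-call path performs; it mirrors the source below rather than adding new API.

# Sketch of the lookup done in the direct-call branch of forward().
from vllm.forward_context import get_forward_context

def lookup_attn_metadata(layer_name: str):
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # With multiple attention layers, metadata is keyed by layer name;
    # otherwise a single metadata object covers the whole model.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    return attn_metadata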

Source code in vllm/attention/layer.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Attention metadata (`attn_metadata`) is set using a context manager in
    the model runner's `execute_model` method. It is accessed via forward
    context using
    `vllm.forward_context.get_forward_context().attn_metadata`.
    """
    if self.calculate_kv_scales:
        torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
    output_dtype = query.dtype
    if self.query_quant is not None:
        # quantizing with a simple torch operation enables
        # torch.compile to fuse this into previous ops
        # which reduces overheads during decoding.
        # Otherwise queries are quantized using custom ops
        # which causes decoding overheads
        assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

        # check if query quantization is supported
        if self.impl.supports_quant_query_input():
            query, _ = self.query_quant(query, self._q_scale)

    if self.use_output:
        output_shape = output_shape if output_shape is not None else query.shape
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]
        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size)
        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata, output=output
            )
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name
            )
        return output.view(-1, hidden_size)
    else:
        if self.use_direct_call:
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata
            )
        else:
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name
            )

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/attention/layer.py
def get_attn_backend(self) -> type[AttentionBackend]:
    return self.attn_backend

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/attention/layer.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # Block size may get updated after model loading, refresh it
    block_size = vllm_config.cache_config.block_size
    # Should not be called for enc-dec or encoder-only attention.
    assert self.attn_type == AttentionType.DECODER
    if self.sliding_window is not None:
        assert not vllm_config.model_config.use_mla, (
            "MLA is not supported for slidingwindow"
        )
        return SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            sliding_window=self.sliding_window,
        )
    else:
        return FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/attention/layer.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    self.impl.process_weights_after_loading(act_dtype)

AttentionBackend

Bases: ABC

Abstract class for attention backends.
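
A hypothetical skeleton of a backend subclass, to show which hooks the abstract class expects; MyAttentionImpl and MyMetadataBuilder are placeholders, not real vLLM classes, and the KV-cache shape is purely illustrative.

# Hypothetical backend skeleton; only the abstract hooks below are filled in.
from vllm.attention.backends.abstract import AttentionBackend

class MyAttentionImpl:          # placeholder; a real backend returns an AttentionImpl subclass
    ...

class MyMetadataBuilder:        # placeholder metadata builder
    ...

class MyAttentionBackend(AttentionBackend):
    accept_output_buffer = True     # the impl writes into a caller-provided buffer

    @staticmethod
    def get_name() -> str:
        return "MY_BACKEND"

    @staticmethod
    def get_impl_cls():
        return MyAttentionImpl

    @staticmethod
    def get_builder_cls():
        return MyMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # Paged layout with separate K and V planes per block (illustrative).
        return (2, num_blocks, block_size, num_kv_heads, head_size)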

Source code in vllm/attention/backends/abstract.py
class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True

        valid_sizes = get_args(BlockSize)
        if block_size not in valid_sizes:
            return False

        if not cls.supported_kernel_block_sizes:
            return True

        for supported_size in cls.supported_kernel_block_sizes:
            is_multiple_of = (
                isinstance(supported_size, MultipleOf)
                and block_size % supported_size.base == 0
            )
            is_int_equal = (
                isinstance(supported_size, int) and block_size == supported_size
            )
            if is_multiple_of or is_int_equal:
                return True
        return False

    @classmethod
    def is_mla(cls) -> bool:
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> list[str]:
        invalid_reasons = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        return None

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = False

supported_dtypes class-attribute

supported_dtypes: list[dtype] = [float16, bfloat16]

supported_kernel_block_sizes class-attribute

supported_kernel_block_sizes: list[int | MultipleOf] = [
    MultipleOf(1)
]

supported_kv_cache_dtypes class-attribute

supported_kv_cache_dtypes: list[CacheDType] = ['auto']

full_cls_name classmethod

full_cls_name() -> tuple[str, str]
Source code in vllm/attention/backends/abstract.py
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
    return (cls.__module__, cls.__qualname__)

get_builder_cls abstractmethod staticmethod

get_builder_cls()
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
    raise NotImplementedError

get_impl_cls abstractmethod staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
    raise NotImplementedError

get_kv_cache_shape abstractmethod staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    raise NotImplementedError

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order() -> tuple[int, ...]
Source code in vllm/attention/backends/abstract.py
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    raise NotImplementedError

get_name abstractmethod staticmethod

get_name() -> str
Source code in vllm/attention/backends/abstract.py
@staticmethod
@abstractmethod
def get_name() -> str:
    raise NotImplementedError

get_required_kv_cache_layout classmethod

get_required_kv_cache_layout() -> KVCacheLayoutType | None
Source code in vllm/attention/backends/abstract.py
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
    return None

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/attention/backends/abstract.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return []

is_mla classmethod

is_mla() -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def is_mla(cls) -> bool:
    return False

is_sparse classmethod

is_sparse() -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def is_sparse(cls) -> bool:
    return False

supports_block_size classmethod

supports_block_size(block_size: int | None) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
    from vllm.config.cache import BlockSize

    if block_size is None:
        return True

    valid_sizes = get_args(BlockSize)
    if block_size not in valid_sizes:
        return False

    if not cls.supported_kernel_block_sizes:
        return True

    for supported_size in cls.supported_kernel_block_sizes:
        is_multiple_of = (
            isinstance(supported_size, MultipleOf)
            and block_size % supported_size.base == 0
        )
        is_int_equal = (
            isinstance(supported_size, int) and block_size == supported_size
        )
        if is_multiple_of or is_int_equal:
            return True
    return False

supports_combination classmethod

supports_combination(
    head_size: int,
    dtype: dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> str | None
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_combination(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: "CacheDType | None",
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: "DeviceCapability",
) -> str | None:
    return None

supports_compute_capability classmethod

supports_compute_capability(
    capability: DeviceCapability,
) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
    return True

supports_dtype classmethod

supports_dtype(dtype: dtype) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
    return dtype in cls.supported_dtypes

supports_head_size classmethod

supports_head_size(head_size: int) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
    supported_head_sizes = cls.get_supported_head_sizes()
    return (not supported_head_sizes) or head_size in supported_head_sizes

supports_kv_cache_dtype classmethod

supports_kv_cache_dtype(
    kv_cache_dtype: CacheDType | None,
) -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
    if kv_cache_dtype is None:
        return True
    return (not cls.supported_kv_cache_dtypes) or (
        kv_cache_dtype in cls.supported_kv_cache_dtypes
    )

supports_sink classmethod

supports_sink() -> bool
Source code in vllm/attention/backends/abstract.py
@classmethod
def supports_sink(cls) -> bool:
    return False

validate_configuration classmethod

validate_configuration(
    head_size: int,
    dtype: dtype,
    kv_cache_dtype: CacheDType | None,
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: DeviceCapability,
) -> list[str]
Source code in vllm/attention/backends/abstract.py
@classmethod
def validate_configuration(
    cls,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: "CacheDType | None",
    block_size: int | None,
    use_mla: bool,
    has_sink: bool,
    use_sparse: bool,
    device_capability: "DeviceCapability",
) -> list[str]:
    invalid_reasons = []
    if not cls.supports_head_size(head_size):
        invalid_reasons.append("head_size not supported")
    if not cls.supports_dtype(dtype):
        invalid_reasons.append("dtype not supported")
    if not cls.supports_kv_cache_dtype(kv_cache_dtype):
        invalid_reasons.append("kv_cache_dtype not supported")
    if not cls.supports_block_size(block_size):
        invalid_reasons.append("block_size not supported")
    if use_mla != cls.is_mla():
        if use_mla:
            invalid_reasons.append("MLA not supported")
        else:
            invalid_reasons.append("non-MLA not supported")
    if has_sink and not cls.supports_sink():
        invalid_reasons.append("sink setting not supported")
    if use_sparse != cls.is_sparse():
        if use_sparse:
            invalid_reasons.append("sparse not supported")
        else:
            invalid_reasons.append("non-sparse not supported")
    if not cls.supports_compute_capability(device_capability):
        invalid_reasons.append("compute capability not supported")
    combination_reason = cls.supports_combination(
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla,
        has_sink,
        use_sparse,
        device_capability,
    )
    if combination_reason is not None:
        invalid_reasons.append(combination_reason)
    return invalid_reasons

AttentionMetadata

Source code in vllm/attention/backends/abstract.py
class AttentionMetadata:
    pass

AttentionType

Attention type. Use string to be compatible with torch.compile.
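
The members are plain strings rather than an Enum, which is what keeps them compatible with torch.compile as the docstring notes; a quick check:

from vllm.attention import AttentionType

assert AttentionType.DECODER == "decoder"                   # plain string constants
assert AttentionType.ENCODER_DECODER == "encoder_decoder"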

Source code in vllm/attention/backends/abstract.py
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""

DECODER class-attribute instance-attribute

DECODER = 'decoder'

Decoder attention between previous layer Q/K/V.

ENCODER class-attribute instance-attribute

ENCODER = 'encoder'

Encoder attention between previous layer Q/K/V for encoder-decoder.

ENCODER_DECODER class-attribute instance-attribute

ENCODER_DECODER = 'encoder_decoder'

Attention between dec. Q and enc. K/V for encoder-decoder.

ENCODER_ONLY class-attribute instance-attribute

ENCODER_ONLY = 'encoder_only'

Encoder attention between previous layer Q/K/V.

get_attn_backend

get_attn_backend(
    head_size: int,
    dtype: dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]

Selects which attention backend to use and lazily imports it.
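
A hedged usage sketch: the argument values are illustrative, and the backend class returned depends on the platform and the attention kernels installed.

# Illustrative call; the returned class varies with platform, dtype, and
# available kernels.
import torch
from vllm.attention import get_attn_backend

backend_cls = get_attn_backend(
    head_size=128,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    block_size=16,
    use_mla=False,
    has_sink=False,
)
print(backend_cls.get_name())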

Source code in vllm/attention/selector.py
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""

    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )

    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )