vllm.model_executor.layers.layernorm

Custom normalization layers.

GemmaRMSNorm

Bases: CustomOp

RMS normalization for Gemma.

Two differences from RMSNorm:
  1. x * (1 + w) instead of x * w.
  2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
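
A minimal usage sketch (not part of the vLLM source) illustrating both points: the weight is stored as an offset from 1, so a zero-initialized weight leaves the normalized activations unscaled, and the scale is applied before the cast back to the input dtype.

import torch
from vllm.model_executor.layers.layernorm import GemmaRMSNorm

norm = GemmaRMSNorm(hidden_size=8, eps=1e-6)
x = torch.randn(2, 8, dtype=torch.float16)

# With the default zero weight, the output is simply the RMS-normalized x.
out = norm.forward_native(x)

# Passing a residual fuses the addition and also returns the updated residual.
out_fused, new_residual = norm.forward_native(x, residual=torch.randn_like(x))
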
Source code in vllm/model_executor/layers/layernorm.py
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        orig_dtype = x.dtype
        if residual is not None:
            x = (
                x.float() + residual.float()
                if orig_dtype == torch.float16
                else x + residual
            )
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)

variance_epsilon instance-attribute

variance_epsilon = eps

weight instance-attribute

weight = Parameter(zeros(hidden_size))

__init__

__init__(hidden_size: int, eps: float = 1e-06) -> None
Source code in vllm/model_executor/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    eps: float = 1e-6,
) -> None:
    super().__init__()
    self.weight = nn.Parameter(torch.zeros(hidden_size))
    self.variance_epsilon = eps

forward_cuda

forward_cuda(
    x: Tensor, residual: Tensor | None = None
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/layernorm.py
def forward_cuda(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    if torch.compiler.is_compiling():
        return self.forward_native(x, residual)

    if not getattr(self, "_is_compiled", False):
        self.forward_static = torch.compile(  # type: ignore
            self.forward_static
        )
        self._is_compiled = True
    return self.forward_native(x, residual)

forward_native

forward_native(
    x: Tensor, residual: Tensor | None = None
) -> Tensor | tuple[Tensor, Tensor]

PyTorch-native implementation equivalent to forward().

Source code in vllm/model_executor/layers/layernorm.py
def forward_native(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""
    return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

forward_static staticmethod

forward_static(
    weight: Tensor,
    variance_epsilon: float,
    x: Tensor,
    residual: Tensor | None,
) -> Tensor | tuple[Tensor, Tensor]

PyTorch-native implementation equivalent to forward().

Source code in vllm/model_executor/layers/layernorm.py
@staticmethod
def forward_static(
    weight: torch.Tensor,
    variance_epsilon: float,
    x: torch.Tensor,
    residual: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""
    orig_dtype = x.dtype
    if residual is not None:
        x = (
            x.float() + residual.float()
            if orig_dtype == torch.float16
            else x + residual
        )
        residual = x

    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
    # See https://github.com/huggingface/transformers/pull/29402
    x = x * (1.0 + weight.float())
    x = x.to(orig_dtype)
    return x if residual is None else (x, residual)

LayerNorm

Bases: Module

Layer Normalization.
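
A minimal usage sketch (not part of the vLLM source): the parameters are created in float32, the normalization runs in float32, and the result is cast back to the input dtype.

import torch
from vllm.model_executor.layers.layernorm import LayerNorm

ln = LayerNorm(dim=16, eps=1e-6)
x = torch.randn(4, 16, dtype=torch.bfloat16)

y = ln(x)                  # computed in float32, returned as bfloat16
assert y.dtype == x.dtype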

Source code in vllm/model_executor/layers/layernorm.py
class LayerNorm(nn.Module):
    """
    Layer Normalization.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(
            x.float(), (self.dim,), self.weight, self.bias, self.eps
        ).type_as(x)

bias instance-attribute

bias = Parameter(zeros(dim, dtype=float32))

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

weight instance-attribute

weight = Parameter(ones(dim, dtype=float32))

__init__

__init__(dim: int, eps: float = 1e-06)
Source code in vllm/model_executor/layers/layernorm.py
def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.dim = dim
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
    self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

forward

forward(x: Tensor)
Source code in vllm/model_executor/layers/layernorm.py
def forward(self, x: torch.Tensor):
    return F.layer_norm(
        x.float(), (self.dim,), self.weight, self.bias, self.eps
    ).type_as(x)

RMSNorm

Bases: CustomOp

Root mean square normalization.

Computes x -> w * x / sqrt(E[x^2] + eps), where w is the learned weight. Refer to https://arxiv.org/abs/1910.07467.
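
A minimal sketch (not part of the vLLM source) checking that formula against the PyTorch-native path:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(hidden_size=8, eps=1e-6)
x = torch.randn(2, 8)

# Reference: w * x / sqrt(E[x^2] + eps), written out directly.
ref = norm.weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
torch.testing.assert_close(norm.forward_native(x), ref)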

Source code in vllm/model_executor/layers/layernorm.py
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )

            x_var = x[:, :, :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        if add_residual:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        else:
            return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s

has_weight instance-attribute

has_weight = has_weight

hidden_size instance-attribute

hidden_size = hidden_size

rocm_norm_func instance-attribute

rocm_norm_func = dispatch_rocm_rmsnorm_func(
    with_fused_add=False,
    dtype=weight_dtype,
    use_aiter=aiter_rmsnorm_enabled,
)

rocm_norm_func_with_add instance-attribute

rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
    with_fused_add=True,
    dtype=weight_dtype,
    use_aiter=aiter_rmsnorm_enabled,
)

variance_epsilon instance-attribute

variance_epsilon = eps

variance_size_override instance-attribute

variance_size_override = (
    None
    if var_hidden_size == hidden_size
    else var_hidden_size
)

weight instance-attribute

weight = ones(hidden_size, dtype=weight_dtype)

__init__

__init__(
    hidden_size: int,
    eps: float = 1e-06,
    var_hidden_size: int | None = None,
    has_weight: bool = True,
    dtype: dtype | None = None,
) -> None
Source code in vllm/model_executor/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    eps: float = 1e-6,
    var_hidden_size: int | None = None,
    has_weight: bool = True,
    dtype: torch.dtype | None = None,
) -> None:
    super().__init__()

    self.hidden_size = hidden_size
    self.variance_epsilon = eps
    self.variance_size_override = (
        None if var_hidden_size == hidden_size else var_hidden_size
    )
    weight_dtype = dtype or torch.get_default_dtype()
    self.has_weight = has_weight
    self.weight = torch.ones(hidden_size, dtype=weight_dtype)
    if self.has_weight:
        self.weight = nn.Parameter(self.weight)

    if current_platform.is_rocm():
        aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
        self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
            with_fused_add=False,
            dtype=weight_dtype,
            use_aiter=aiter_rmsnorm_enabled,
        )
        self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
            with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
        )

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/layernorm.py
def extra_repr(self) -> str:
    s = f"hidden_size={self.weight.data.size(0)}"
    s += f", eps={self.variance_epsilon}"
    return s

forward_cuda

forward_cuda(
    x: Tensor, residual: Tensor | None = None
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/layernorm.py
def forward_cuda(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    add_residual = residual is not None
    if add_residual:
        return fused_add_rms_norm(
            x, residual, self.weight.data, self.variance_epsilon
        )
    else:
        return rms_norm(x, self.weight.data, self.variance_epsilon)

forward_hip

forward_hip(
    x: Tensor, residual: Tensor | None = None
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/layernorm.py
def forward_hip(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    add_residual = residual is not None
    if add_residual:
        return self.rocm_norm_func_with_add(
            x, residual, self.weight.data, self.variance_epsilon
        )
    else:
        return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

forward_native

forward_native(
    x: Tensor, residual: Tensor | None = None
) -> Tensor | tuple[Tensor, Tensor]

PyTorch-native implementation equivalent to forward().

Source code in vllm/model_executor/layers/layernorm.py
def forward_native(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""

    return self.forward_static(
        x,
        self.variance_epsilon,
        self.hidden_size,
        x.dtype,
        self.weight.data if self.has_weight else None,
        residual,
        self.variance_size_override,
    )

forward_static staticmethod

forward_static(
    x: Tensor,
    variance_epsilon: float,
    hidden_size: int,
    orig_dtype: dtype,
    weight: Tensor | None = None,
    residual: Tensor | None = None,
    variance_size_override: int | None = None,
) -> Tensor | tuple[Tensor, Tensor]

PyTorch-native implementation equivalent to forward().

Source code in vllm/model_executor/layers/layernorm.py
@staticmethod
def forward_static(
    x: torch.Tensor,
    variance_epsilon: float,
    hidden_size: int,
    orig_dtype: torch.dtype,
    weight: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
    variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward()."""
    x = x.to(torch.float32)
    if residual is not None:
        # residual promoted f16->f32 automatically,
        # otherwise Inductor eliminates the casts to and from f16,
        # increasing memory usage (and complicating pattern matching)
        x = x + residual
        residual = x.to(orig_dtype)

    if x.shape[-1] != hidden_size:
        raise ValueError(
            f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
        )

    if variance_size_override is None:
        x_var = x
    else:
        if hidden_size < variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{variance_size_override}, but found: {hidden_size}"
            )

        x_var = x[:, :, :variance_size_override]

    variance = x_var.pow(2).mean(dim=-1, keepdim=True)

    x = x * torch.rsqrt(variance + variance_epsilon)
    x = x.to(orig_dtype)
    if weight is not None:
        x = x * weight
    if residual is None:
        return x
    else:
        return x, residual

forward_xpu

forward_xpu(
    x: Tensor, residual: Tensor | None = None
) -> Tensor | tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/layernorm.py
def forward_xpu(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    from vllm._ipex_ops import ipex_ops as ops

    if residual is not None:
        ops.fused_add_rms_norm(
            x,
            residual,
            self.weight.data,
            self.variance_epsilon,
        )
        return x, residual
    return ops.rms_norm(
        x,
        self.weight.data,
        self.variance_epsilon,
    )

RMSNormGated

Bases: CustomOp

RMS Normalization with optional gating.

This is a native PyTorch implementation that supports:
  - Standard RMS normalization
  - Group RMS normalization
  - Optional gating with SiLU activation
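
A minimal usage sketch (not part of the vLLM source) of the two gating orders and of grouped normalization; the grouped path additionally requires einops.

import torch
from vllm.model_executor.layers.layernorm import RMSNormGated

x = torch.randn(2, 16)
z = torch.randn(2, 16)

gate_then_norm = RMSNormGated(hidden_size=16, norm_before_gate=False)
norm_then_gate = RMSNormGated(hidden_size=16, norm_before_gate=True)
grouped = RMSNormGated(hidden_size=16, group_size=4)  # four groups of four

out_a = gate_then_norm.forward_native(x, z)  # norm(x * silu(z))
out_b = norm_then_gate.forward_native(x, z)  # norm(x) * silu(z)
out_c = grouped.forward_native(x)            # per-group RMS statistics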

Source code in vllm/model_executor/layers/layernorm.py
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                        having group_size elements.
                        group_size=None is equivalent to group_size=hidden_size
                        (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                              If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * F.silu(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * self.weight
        else:
            # Group RMS norm
            from einops import rearrange

            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * F.silu(z)

        return out

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )

eps instance-attribute

eps = eps

group_size instance-attribute

group_size = group_size

norm_before_gate instance-attribute

norm_before_gate = norm_before_gate

weight instance-attribute

weight = Parameter(empty(hidden_size, **factory_kwargs))

__init__

__init__(
    hidden_size: int,
    eps: float = 1e-05,
    group_size: int | None = None,
    norm_before_gate: bool = False,
    device: device | None = None,
    dtype: dtype | None = None,
)

Initialize RMSNormGated.

Parameters:

  hidden_size (int): Size of the hidden dimension. Required.
  eps (float): Epsilon for numerical stability. Default: 1e-05.
  group_size (int | None): If not None, do GroupNorm with each group having
      group_size elements. group_size=None is equivalent to
      group_size=hidden_size (i.e. there's only 1 group). Default: None.
  norm_before_gate (bool): If True and z is provided: out = norm(x) * silu(z).
      If False and z is provided: out = norm(x * silu(z)). Default: False.
  device (device | None): Device to create parameters on. Default: None.
  dtype (dtype | None): Data type for parameters. Default: None.
Source code in vllm/model_executor/layers/layernorm.py
def __init__(
    self,
    hidden_size: int,
    eps: float = 1e-5,
    group_size: int | None = None,
    norm_before_gate: bool = False,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
):
    """Initialize RMSNormGated.

    Args:
        hidden_size: Size of the hidden dimension
        eps: Epsilon for numerical stability
        group_size: If not None, do GroupNorm with each group
                    having group_size elements.
                    group_size=None is equivalent to group_size=hidden_size
                    (i.e. there's only 1 group).
        norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                          If False and z is provided: out = norm(x * silu(z))
        device: Device to create parameters on
        dtype: Data type for parameters
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
    self.register_parameter("bias", None)
    self.group_size = group_size
    self.norm_before_gate = norm_before_gate
    self.reset_parameters()

forward_cuda

forward_cuda(x: Tensor, z: Tensor | None = None) -> Tensor
Source code in vllm/model_executor/layers/layernorm.py
def forward_cuda(
    self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
    from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

    return rmsnorm_fn(
        x,
        self.weight,
        self.bias,
        z=z,
        eps=self.eps,
        group_size=self.group_size,
        norm_before_gate=self.norm_before_gate,
    )

forward_native

forward_native(
    x: Tensor, z: Tensor | None = None
) -> Tensor

Native PyTorch implementation of RMS normalization with gating.

Parameters:

  x (Tensor): Input tensor. Required.
  z (Tensor | None): Optional gating tensor. Default: None.

Returns:

  Tensor: Normalized (and optionally gated) tensor.

If z is not None:
  • norm_before_gate=True: out = norm(x) * silu(z)
  • norm_before_gate=False: out = norm(x * silu(z))
Source code in vllm/model_executor/layers/layernorm.py
def forward_native(
    self, x: torch.Tensor, z: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Native PyTorch implementation of RMS normalization with gating.

    Args:
        x: Input tensor
        z: Optional gating tensor

    Returns:
        Normalized (and optionally gated) tensor

    If z is not None:
        - norm_before_gate=True: out = norm(x) * silu(z)
        - norm_before_gate=False: out = norm(x * silu(z))
    """
    # Apply gating before normalization if needed
    if z is not None and not self.norm_before_gate:
        x = x * F.silu(z)

    # RMS Normalization
    if self.group_size is None:
        # Standard RMS norm across the last dimension
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(variance + self.eps)
        out = x_normed * self.weight
    else:
        # Group RMS norm
        from einops import rearrange

        x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
        variance = x_group.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x_group * torch.rsqrt(variance + self.eps)
        out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight

    # Apply gating after normalization if needed
    if z is not None and self.norm_before_gate:
        out = out * F.silu(z)

    return out

reset_parameters

reset_parameters()
Source code in vllm/model_executor/layers/layernorm.py
def reset_parameters(self):
    torch.nn.init.ones_(self.weight)

dispatch_rocm_rmsnorm_func

dispatch_rocm_rmsnorm_func(
    with_fused_add: bool,
    dtype: dtype,
    use_aiter: bool = False,
)
Source code in vllm/model_executor/layers/layernorm.py
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    use_aiter = use_aiter and dtype in [
        torch.float16,
        torch.bfloat16,
    ]

    if use_aiter and with_fused_add:
        return rocm_aiter_ops.rms_norm2d_with_add
    if use_aiter:
        return rocm_aiter_ops.rms_norm

    # fall back to CUDA implementation
    if with_fused_add:
        return fused_add_rms_norm
    return rms_norm
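
A minimal sketch (not part of the vLLM source): with use_aiter left at its default of False, or with a dtype AITER does not support, the dispatcher falls back to the CUDA-backed helpers defined in this module.

import torch
from vllm.model_executor.layers.layernorm import (
    dispatch_rocm_rmsnorm_func,
    fused_add_rms_norm,
    rms_norm,
)

norm_fn = dispatch_rocm_rmsnorm_func(with_fused_add=False, dtype=torch.float32)
norm_add_fn = dispatch_rocm_rmsnorm_func(with_fused_add=True, dtype=torch.float32)
assert norm_fn is rms_norm
assert norm_add_fn is fused_add_rms_norm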

fused_add_rms_norm

fused_add_rms_norm(
    x: Tensor,
    residual: Tensor,
    weight: Tensor,
    variance_epsilon: float,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/layernorm.py
def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(
            x + residual, weight, variance_epsilon
        ), x + residual
    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual
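
A minimal calling sketch (assumes a CUDA device and vLLM's compiled custom ops): outside the batch-invariant path, the op updates x and residual in place and the function returns both tensors.

import torch
from vllm.model_executor.layers.layernorm import fused_add_rms_norm

hidden = torch.randn(4, 32, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden)
weight = torch.ones(32, dtype=torch.float16, device="cuda")

out, new_residual = fused_add_rms_norm(hidden, residual, weight, 1e-6)
# out holds the normalized (x + residual); new_residual holds the raw sum.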

poly_norm

poly_norm(
    x: Tensor,
    weight: Tensor,
    bias: Tensor,
    variance_epsilon: float,
) -> Tensor
Source code in vllm/model_executor/layers/layernorm.py
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    from vllm import _custom_ops as ops

    out = torch.empty_like(x)
    ops.poly_norm(
        out,
        x,
        weight,
        bias,
        variance_epsilon,
    )
    return out

rms_norm

rms_norm(
    x: Tensor, weight: Tensor, variance_epsilon: float
) -> Tensor
Source code in vllm/model_executor/layers/layernorm.py
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)
    out = torch.empty_like(x)
    ops.rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out