Bases: FusedMoE
A FusedMoE operation that also computes the results of shared experts. If an all2all communicator is being used, the shared expert computation can be interleaved with the fused all2all dispatch communication step.
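A model layer typically constructs this class by passing its shared-expert MLP (and optionally its router gate) alongside the usual FusedMoE keyword arguments, which are forwarded through **kwargs. A minimal sketch; the keyword argument names below are illustrative assumptions based on typical FusedMoE configurations, not an authoritative list:

# Illustrative sketch only: kwargs are forwarded to FusedMoE.__init__; the
# config attribute names are assumptions for illustration.
self.experts = SharedFusedMoE(
    shared_experts=self.shared_experts,   # dense MLP run on every token, or None
    gate=self.gate,                       # optional router module
    num_experts=config.n_routed_experts,
    top_k=config.num_experts_per_tok,
    hidden_size=config.hidden_size,
    intermediate_size=config.moe_intermediate_size,
    reduce_results=False,
    renormalize=config.norm_topk_prob,
    quant_config=quant_config,
    prefix=f"{prefix}.experts",
)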
Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
class SharedFusedMoE(FusedMoE):
    """
    A FusedMoE operation that also computes the results of shared experts.
    If an all2all communicator is being used, the shared expert computation
    can be interleaved with the fused all2all dispatch communication step.
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module | None,
        gate: torch.nn.Module | None = None,
        use_overlapped: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._shared_experts = shared_experts
        # Disable shared expert overlap if:
        # - we are using eplb, because of correctness issues
        # - we are using flashinfer with DP, since there is nothing to gain
        # - we are using marlin kernels
        self.use_overlapped = (
            use_overlapped
            and not (
                # TODO(wentao): find the root cause and remove this condition
                self.enable_eplb
                or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
                or self.use_marlin_kernels
            )
            and self._shared_experts is not None
        )
        self._gate = gate

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared_experts if self.use_overlapped else None

    @property
    def gate(self) -> torch.nn.Module | None:
        return self._gate if self.use_overlapped else None

    @property
    def is_internal_router(self) -> bool:
        return self.gate is not None

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not self.use_overlapped:
            if self._shared_experts is not None:
                shared_out = self._shared_experts(hidden_states)

                # Reduce shared expert outputs if necessary, since the MLP
                # should have been created with reduce_results=False.
                if (
                    self.reduce_results
                    and get_tensor_model_parallel_world_size() > 1
                    and self.must_reduce_shared_expert_outputs()
                ):
                    shared_out = tensor_model_parallel_all_reduce(shared_out)
            else:
                shared_out = None

            fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
        else:
            shared_out, fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )
            # ensure early TP reduction of shared expert outputs when required
            if (
                shared_out is not None
                and self.reduce_results
                and get_tensor_model_parallel_world_size() > 1
                and self.must_reduce_shared_expert_outputs()
            ):
                shared_out = tensor_model_parallel_all_reduce(shared_out)
        return shared_out, fused_out
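The shared_experts and gate properties only expose the wrapped modules when overlap is actually enabled, which is how the base FusedMoE decides whether the shared experts can be folded into the dispatch step. A standalone illustration of that gating; mlp and moe_kwargs are placeholders, not vLLM identifiers:

# Illustration only: with overlap disabled, the property hides the module from
# the base class, but forward() still runs it eagerly via _shared_experts.
layer = SharedFusedMoE(shared_experts=mlp, use_overlapped=False, **moe_kwargs)
assert layer.shared_experts is None   # base class sees no shared experts to overlap
assert layer._shared_experts is mlp   # forward() still computes them eagerly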
_shared_experts instance-attribute
_shared_experts = shared_experts
is_internal_router property
use_overlapped instance-attribute
use_overlapped = (
    use_overlapped
    and not (
        enable_eplb
        or (use_flashinfer_cutlass_kernels and dp_size > 1)
        or use_marlin_kernels
    )
    and _shared_experts is not None
)
__init__
__init__(
    shared_experts: Module | None,
    gate: Module | None = None,
    use_overlapped: bool = True,
    **kwargs,
)
Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
def __init__(
    self,
    shared_experts: torch.nn.Module | None,
    gate: torch.nn.Module | None = None,
    use_overlapped: bool = True,
    **kwargs,
):
    super().__init__(**kwargs)
    self._shared_experts = shared_experts
    # Disable shared expert overlap if:
    # - we are using eplb, because of correctness issues
    # - we are using flashinfer with DP, since there is nothing to gain
    # - we are using marlin kernels
    self.use_overlapped = (
        use_overlapped
        and not (
            # TODO(wentao): find the root cause and remove this condition
            self.enable_eplb
            or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
            or self.use_marlin_kernels
        )
        and self._shared_experts is not None
    )
    self._gate = gate
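Note that even when use_overlapped=True is requested, the conditions above can force the eager (non-overlapped) path. A hypothetical sketch assuming enable_eplb is forwarded to FusedMoE via **kwargs, with mlp and other_moe_kwargs as placeholders:

# Hypothetical: EPLB enabled -> overlap is disabled regardless of the flag.
layer = SharedFusedMoE(
    shared_experts=mlp,
    use_overlapped=True,
    enable_eplb=True,          # assumed FusedMoE kwarg, shown for illustration
    **other_moe_kwargs,
)
assert layer.use_overlapped is False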
forward
Source code in vllm/model_executor/layers/fused_moe/shared_fused_moe.py
def forward(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    if not self.use_overlapped:
        if self._shared_experts is not None:
            shared_out = self._shared_experts(hidden_states)

            # Reduce shared expert outputs if necessary, since the MLP
            # should have been created with reduce_results=False.
            if (
                self.reduce_results
                and get_tensor_model_parallel_world_size() > 1
                and self.must_reduce_shared_expert_outputs()
            ):
                shared_out = tensor_model_parallel_all_reduce(shared_out)
        else:
            shared_out = None

        fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
    else:
        shared_out, fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # ensure early TP reduction of shared expert outputs when required
        if (
            shared_out is not None
            and self.reduce_results
            and get_tensor_model_parallel_world_size() > 1
            and self.must_reduce_shared_expert_outputs()
        ):
            shared_out = tensor_model_parallel_all_reduce(shared_out)
    return shared_out, fused_out
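Because forward returns a (shared_out, fused_out) tuple rather than a single tensor, the calling layer is responsible for combining the two contributions. A minimal caller-side sketch; self.experts and the variable names are assumptions for illustration:

# Hypothetical caller-side use of the returned tuple.
shared_out, fused_out = self.experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
)
final_hidden_states = fused_out
if shared_out is not None:
    # Add the shared-expert contribution to the routed-expert output.
    final_hidden_states = final_hidden_states + shared_out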