class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None:
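        """Build per-expert GEMM callables for the gate/up and down projections.

        When the oneDNN matmul path is available, each expert's (transposed)
        weight is prepacked into a oneDNN handle and the original weight
        tensors are released afterwards; otherwise the callables fall back to
        torch.nn.functional.linear on the stored weights.
        """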
use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
num_experts = layer.w13_weight.size(0)
has_w13_bias = hasattr(layer, "w13_bias")
has_w2_bias = hasattr(layer, "w2_bias")
layer.gate_up_linear = []
layer.down_linear = []
for i in range(num_experts):
layer_w13_weight = layer.w13_weight[i]
layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
layer_w2_weight = layer.w2_weight[i]
layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
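            # The lambdas below bind this expert's handle/weight/bias through
            # default arguments, avoiding late-binding capture of the loop variables.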
if use_onednn_mm:
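                # Prepack the transposed weights into oneDNN matmul primitives;
                # the bias (if any) is applied inside ops.onednn_mm.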
gate_up_handle = ops.create_onednn_mm(layer_w13_weight.t(), 32)
layer.gate_up_linear.append(
lambda x, handle=gate_up_handle, bias=layer_w13_bias: ops.onednn_mm(
handle, x, bias
)
)
down_handle = ops.create_onednn_mm(layer_w2_weight.t(), 32)
layer.down_linear.append(
lambda x, handle=down_handle, bias=layer_w2_bias: ops.onednn_mm(
handle, x, bias
)
)
else:
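                # No oneDNN support: fall back to plain F.linear on the original weights.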
layer.gate_up_linear.append(
lambda x, w=layer_w13_weight, b=layer_w13_bias: F.linear(x, w, b)
)
layer.down_linear.append(
lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
)
        if use_onednn_mm:  # weights are prepacked inside the oneDNN handles; free the originals
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)

    def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
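        """Route tokens to experts and run each expert's MLP on CPU.

        Tokens are assigned to their top-k experts via select_experts, grouped
        by expert id so each expert runs on a contiguous slice, passed through
        the gate/up projection, activation, and down projection, then scattered
        back to the original order and combined with the routing weights.
        """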
assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
        assert not apply_router_weight_on_input, (
            "apply_router_weight_on_input is not supported."
        )
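        # Pick the top-k experts and routing weights for every token.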
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts
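        # Count how many tokens are routed to each expert via a one-hot scatter.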
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
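        # Sort the flattened (token, expert) assignments by expert id so each
        # expert sees a contiguous block; dividing by topk_ids.shape[1] (top_k)
        # recovers the source token index.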
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
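        # Run each expert's gate/up projection, activation, and down projection
        # on its contiguous slice of sorted tokens.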
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
if activation == "swigluoai":
gate_up = swigluoai_and_mul(gate_up)
else:
gate_up = silu_and_mul(gate_up)
expert_out = layer.down_linear[i](gate_up)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
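        # Undo the sort: scatter each expert output back to its original
        # (token, top-k slot) position.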
new_x = torch.empty_like(outs)
new_x[idxs] = outs
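        # Weight each token's top-k expert outputs by the routing weights and sum.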
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out