7 changes: 6 additions & 1 deletion vllm_ascend/ops/activation.py
@@ -18,11 +18,16 @@
import torch
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul

from vllm_ascend.utils import is_310p

Check failure on line 21 in vllm_ascend/ops/activation.py (GitHub Actions / lint (3.10)): Module "vllm_ascend.utils" has no attribute "is_310p" [attr-defined]

def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
import torch_npu

out = torch_npu.npu_swiglu(x)
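# On 310P, run npu_swiglu in float32 and cast the result back to float16.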
if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
out = torch_npu.npu_swiglu(x)
return out


17 changes: 16 additions & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -21,7 +21,9 @@
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod

from vllm_ascend.ops.fused_moe import fused_experts, select_experts
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
select_experts)
from vllm_ascend.utils import is_310p

Check failure on line 26 in vllm_ascend/ops/common_fused_moe.py (GitHub Actions / lint (3.10)): Module "vllm_ascend.utils" has no attribute "is_310p" [attr-defined]


def forward_oot(
@@ -55,6 +57,19 @@
e_score_correction_bias=e_score_correction_bias,
)

if is_310p():
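# On 310P, route through the dedicated fused_experts_310p implementation.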
assert global_num_experts is not None
return fused_experts_310p(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)

return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
89 changes: 89 additions & 0 deletions vllm_ascend/ops/fused_moe.py
@@ -549,6 +549,95 @@ def fused_experts_with_all2all_buffer(
return final_hidden_states


# Currently, fused_experts on 310p only supports PanguProMoE.
def fused_experts_310p(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""

Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).

Returns:
hidden_states: Hidden states after routing.
"""
ep_size = get_ep_group().world_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
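# Experts are split evenly across expert-parallel ranks; each rank handles
# top_k // ep_size of a token's selected experts.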

if apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)

bsz, _ = hidden_states.shape
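# Flatten the token-expert assignments and sort them by expert ID so that
# tokens routed to the same expert become contiguous, then gather the
# corresponding hidden states in that order.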
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)

experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
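# Count how many tokens are assigned to each local expert; the cumulative sum
# forms the group_list consumed by npu_grouped_matmul.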
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

w1 = w1.transpose(1, 2)
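# Gate/up projection as a grouped matmul, one group per expert; the SwiGLU
# that follows runs in float32 and is cast back to float16.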
gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]

gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
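# Apply the per-token routing weights, gathered in the same expert-sorted order.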
gate_up_out *= topk_scales

w2 = w2.transpose(1, 2)
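# Down projection as a second grouped matmul over the same per-expert groups.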
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]

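# Invert the sort to restore the original token order, then sum the outputs of
# the experts handled for each token on this rank.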
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(
torch.int32) + torch.Tensor([0]).to(torch.int32).npu()
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)

return final_hidden_states


def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,