
Commit df5f297

Merge branch 'aiter_integration_final' into aiter_integration_ck_fused_moe
2 parents cdeb54e + c0dd5ad

14 files changed: +89 additions, -70 deletions

Dockerfile.rocm

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ ENV TOKENIZERS_PARALLELISM=false
 ENV HIP_FORCE_DEV_KERNARG=1
 
 # Enable Aiter. Make sure this only exists on the aiter branch.
-ENV VLLM_USE_AITER=1
+# ENV VLLM_USE_AITER=1
 
 CMD ["/bin/bash"]

Dockerfile.rocm_base

Lines changed: 3 additions & 6 deletions
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="485b4b28"
+ARG AITER_BRANCH="41297e56"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
@@ -118,17 +118,14 @@ RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
 FROM base AS build_aiter
 ARG AITER_BRANCH
 ARG AITER_REPO
-COPY requirements-rocm.txt /app
-COPY requirements-common.txt /app
-RUN pip install -r requirements-rocm.txt
 RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
     pip install /install/*.whl
 RUN git clone --recursive ${AITER_REPO}
 RUN cd aiter \
     && git checkout ${AITER_BRANCH} \
     && git submodule update --init --recursive \
-    && pip install -r requirements.txt \
-    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
+    && pip install -r requirements.txt
+RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
 RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
 
 FROM base AS final

csrc/rocm/custom_kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -1715,7 +1715,7 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
   dim3 block(64, _WvPrGrp); \
   if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
     int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
-    wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
+    wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
        <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \
                                     s_b, __wvPrGrp, Otp_in, CuCount); \
   } else { \

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 4 additions & 3 deletions
@@ -12,8 +12,9 @@
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
+from vllm.utils import aiter_paged_attn_enabled
 
-if envs.VLLM_USE_AITER_PAGED_ATTN:
+if aiter_paged_attn_enabled():
     from vllm.attention.ops.paged_attn_aiter import (PagedAttention,
                                                      PagedAttentionMetadata)
 else:
@@ -616,7 +617,7 @@ def forward(
         else:
             assert value is None
 
-        if (envs.VLLM_USE_AITER_PAGED_ATTN and kv_cache.dtype.itemsize == 1
+        if (aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
                 and not self.aiter_kv_scales_initialized
                 and kv_cache.shape != torch.Size([0])):
             num_blocks = kv_cache.shape[1]
@@ -804,7 +805,7 @@ def forward(
                 use_custom = _use_rocm_custom_paged_attention(
                     decode_query.dtype, head_size, block_size, gqa_ratio,
                     decode_meta.max_decode_seq_len)
-                if envs.VLLM_USE_AITER_PAGED_ATTN:
+                if aiter_paged_attn_enabled():
                     out = output[num_prefill_tokens:]
                     PagedAttention.forward_decode(
                         decode_query,
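
Note: the aiter_*_enabled() helpers imported from vllm.utils are not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply combines the master VLLM_USE_AITER switch with the per-feature flag (everything beyond the env vars already referenced above is an assumption):

import torch  # noqa: F401  (only to mirror the surrounding module's imports)

from vllm import envs


def aiter_paged_attn_enabled() -> bool:
    # Hypothetical gate: on only when both the master switch and the
    # per-feature flag are set in the environment.
    return bool(envs.VLLM_USE_AITER and envs.VLLM_USE_AITER_PAGED_ATTN)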

vllm/envs.py

Lines changed: 2 additions & 1 deletion
@@ -304,7 +304,8 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
 
     # use ater ck fused moe op if ater ops are enabled
     "VLLM_USE_AITER_2STAGE_MOE":
-    lambda: (os.getenv("VLLM_USE_AITER_2STAGE_MOE", "True").lower() in ("true", "1")),
+    lambda: (os.getenv("VLLM_USE_AITER_2STAGE_MOE", "True").lower() in
+             ("true", "1")),
 
     # use ater paged attn op if ater ops are enabled
     "VLLM_USE_AITER_PAGED_ATTN":

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -15,9 +15,9 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import aiter_moe_enabled, direct_register_custom_op
 
-if envs.VLLM_USE_AITER_MOE:
+if aiter_moe_enabled():
     import aiter
 
 logger = init_logger(__name__)
@@ -950,7 +950,7 @@ def fused_topk(
                                           dtype=torch.int32,
                                           device=hidden_states.device)
 
-    if envs.VLLM_USE_AITER_MOE:
+    if aiter_moe_enabled():
        aiter.topk_softmax(topk_weights, topk_ids, token_expert_indicies,
                           gating_output.float(), renormalize)
    else:
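
Here fused_topk hands the preallocated topk_weights/topk_ids buffers to aiter.topk_softmax when the MoE helper is on. As a rough plain-PyTorch reference for what that fused step produces (an assumption about its semantics, not aiter's implementation):

import torch


def topk_softmax_reference(gating_output: torch.Tensor, topk: int,
                           renormalize: bool):
    # Rough reference: softmax over experts, then per-token top-k expert
    # weights and ids; optional renormalization so the kept weights sum to 1.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)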

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 4 deletions
@@ -11,16 +11,16 @@
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
-from vllm.envs import VLLM_USE_AITER_MOE
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
+from vllm.utils import aiter_moe_enabled
 
-if VLLM_USE_AITER_MOE:
+if aiter_moe_enabled():
     from aiter import ck_moe
     from aiter.ops.shuffle import shuffle_weight
 
@@ -101,7 +101,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
-        if envs.VLLM_USE_AITER_MOE:
+        if aiter_moe_enabled():
            layer.w13_weight = torch.nn.Parameter(shuffle_weight(
                layer.w13_weight.data),
                                                  requires_grad=False)
@@ -189,7 +189,7 @@ def forward_cuda(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
-        if VLLM_USE_AITER_MOE:
+        if aiter_moe_enabled():
             return ck_moe(hidden_states=x,
                           w1=layer.w13_weight,
                           w2=layer.w2_weight,
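
The unquantized MoE path pre-shuffles the expert weights once after loading so ck_moe can consume them directly at runtime. Condensed, the post-load step above amounts to the following sketch (it assumes shuffle_weight returns a same-shape tensor in aiter's preferred layout; the standalone function name is illustrative):

import torch
from aiter.ops.shuffle import shuffle_weight


def prepare_unquantized_moe_weights(layer: torch.nn.Module) -> None:
    # Re-lay-out both expert weight tensors and freeze them as Parameters.
    layer.w13_weight = torch.nn.Parameter(shuffle_weight(layer.w13_weight.data),
                                          requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(shuffle_weight(layer.w2_weight.data),
                                         requires_grad=False)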

vllm/model_executor/layers/layernorm.py

Lines changed: 4 additions & 4 deletions
@@ -5,10 +5,10 @@
 import torch
 import torch.nn as nn
 
-from vllm.envs import VLLM_USE_AITER_NORM
 from vllm.model_executor.custom_op import CustomOp
+from vllm.utils import aiter_norm_enabled
 
-if VLLM_USE_AITER_NORM:
+if aiter_norm_enabled():
     import aiter
 
 
@@ -100,7 +100,7 @@ def forward_cuda(
             return out
 
         if residual is not None:
-            if VLLM_USE_AITER_NORM:
+            if aiter_norm_enabled():
                 aiter.rmsnorm2d_fwd_with_add(
                     x,
                     x,
@@ -118,7 +118,7 @@ def forward_cuda(
             )
             return x, residual
 
-        if VLLM_USE_AITER_NORM:
+        if aiter_norm_enabled():
             out = aiter.rms_norm(x, self.weight.data, self.variance_epsilon)
         else:
             out = torch.empty_like(x)
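
On this path aiter.rms_norm and aiter.rmsnorm2d_fwd_with_add stand in for the CUDA custom ops. As a reference for the residual-add variant's assumed semantics (matching the non-aiter branch: add the residual first, then RMS-normalize the sum):

import torch


def rmsnorm_with_add_reference(x: torch.Tensor, residual: torch.Tensor,
                               weight: torch.Tensor, eps: float):
    # The sum becomes both the new residual and the input to RMSNorm.
    residual = x + residual
    hidden = residual.float()
    hidden = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + eps)
    return hidden.to(x.dtype) * weight, residual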

vllm/model_executor/layers/linear.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,6 @@
 import torch
 from torch.nn.parameter import Parameter, UninitializedParameter
 
-from vllm import envs
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -16,8 +15,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.utils import aiter_linear_enabled
 
-if envs.VLLM_USE_AITER_LINEAR:
+if aiter_linear_enabled():
     from aiter.tuned_gemm import tgemm
 else:
     from vllm.model_executor.layers.tuned_gemm import tgemm
@@ -256,7 +256,7 @@ def forward(
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         if type(self.quant_method
-                ) is UnquantizedLinearMethod and envs.VLLM_USE_AITER_LINEAR:
+                ) is UnquantizedLinearMethod and aiter_linear_enabled():
             output = tgemm.mm(x, self.weight, bias, self.out_dtype)
         else:
             output = self.quant_method.apply(self, x, bias)
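
On the unquantized path, tgemm.mm(x, self.weight, bias, self.out_dtype) takes over the plain GEMM. Its assumed semantics, expressed as a plain-PyTorch reference (the real tuned path additionally selects a kernel per problem shape):

from typing import Optional

import torch
import torch.nn.functional as F


def tgemm_mm_reference(x: torch.Tensor, weight: torch.Tensor,
                       bias: Optional[torch.Tensor],
                       out_dtype: torch.dtype) -> torch.Tensor:
    # x @ weight.T (+ bias), cast to the requested output dtype.
    return F.linear(x, weight, bias).to(out_dtype)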

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 33 additions & 34 deletions
@@ -32,11 +32,11 @@
                                            PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_navi
+from vllm.utils import aiter_2stage_moe_enabled, aiter_moe_enabled, is_navi
 
-if envs.VLLM_USE_AITER_MOE:
+if aiter_moe_enabled():
     from aiter.fused_moe_bf16_asm import asm_moe
-    if envs.VLLM_USE_AITER_2STAGE_MOE:
+    if aiter_2stage_moe_enabled():
         from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -621,7 +621,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
-            if envs.VLLM_USE_AITER_MOE:
+            if aiter_moe_enabled():
                 w13_scales = layer.w13_weight_scale.data.unsqueeze(
                     -1).unsqueeze(-1).expand(
                         (-1, layer.w13_weight.shape[1], -1))
@@ -632,13 +632,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 layer.w13_weight_scale = torch.nn.Parameter(
                     w13_scales.contiguous(), requires_grad=False)
 
-                if envs.VLLM_USE_AITER_2STAGE_MOE:
-                    layer.w13_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w13_weight, layout=(32, 32)),
-                        requires_grad=False)
-                    layer.w2_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w2_weight, layout=(32, 32)),
-                        requires_grad=False)
+                if aiter_2stage_moe_enabled():
+                    layer.w13_weight = torch.nn.Parameter(shuffle_weight(
+                        layer.w13_weight, layout=(32, 32)),
+                                                          requires_grad=False)
+                    layer.w2_weight = torch.nn.Parameter(shuffle_weight(
+                        layer.w2_weight, layout=(32, 32)),
+                                                         requires_grad=False)
                 else:
                     layer.w13_weight = torch.nn.Parameter(shuffle_weight(
                         layer.w13_weight),
@@ -715,32 +715,31 @@ def process_weights_after_loading(self, layer: Module) -> None:
                         dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
-            if envs.VLLM_USE_AITER_MOE:
-                if envs.VLLM_USE_AITER_2STAGE_MOE:
+            if aiter_moe_enabled():
+                if aiter_2stage_moe_enabled():
                     max_w13_scales = max_w13_scales.unsqueeze(-1)
                     w2_scales = layer.w2_weight_scale.data.unsqueeze(-1)
+                    layer.w13_weight = torch.nn.Parameter(shuffle_weight(
+                        layer.w13_weight, layout=(32, 32)),
+                                                          requires_grad=False)
+                    layer.w2_weight = torch.nn.Parameter(shuffle_weight(
+                        layer.w2_weight, layout=(32, 32)),
+                                                         requires_grad=False)
                 else:
                     max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
                         -1).expand((-1, layer.w13_weight.shape[1], -1))
-                    w2_scales = layer.w2_weight_scale.data.unsqueeze(-1).unsqueeze(
-                        -1).expand((-1, layer.w2_weight.shape[1], -1))
-
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_scales.contiguous(), requires_grad=False)
-                if envs.VLLM_USE_AITER_2STAGE_MOE:
-                    layer.w13_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w13_weight, layout=(32, 32)),
-                        requires_grad=False)
-                    layer.w2_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w2_weight, layout=(32, 32)),
-                        requires_grad=False)
-                else:
+                    w2_scales = layer.w2_weight_scale.data.unsqueeze(
+                        -1).unsqueeze(-1).expand(
+                            (-1, layer.w2_weight.shape[1], -1))
                     layer.w13_weight = torch.nn.Parameter(shuffle_weight(
                         layer.w13_weight),
                                                           requires_grad=False)
                     layer.w2_weight = torch.nn.Parameter(shuffle_weight(
                         layer.w2_weight),
                                                          requires_grad=False)
+
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales.contiguous(), requires_grad=False)
                 layer.w13_weight_scale = torch.nn.Parameter(
                     max_w13_scales.contiguous(), requires_grad=False)
                 return
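
The reordering above keeps one distinction worth noting: on the two-stage (ck_moe_2stages) path the per-expert scales only gain a trailing dim, while the asm_moe path expands them to one scale per output row before .contiguous() materializes the copy. A shape-only illustration (E and N are placeholder sizes):

import torch

E, N = 8, 4096                      # placeholder: experts, rows of w2
w2_weight_scale = torch.rand(E)     # one fp8 scale per expert

# ck_moe_2stages path: per-tensor scales, shape (E, 1)
w2_scales_2stage = w2_weight_scale.unsqueeze(-1)

# asm_moe path: broadcast to one scale per output row, shape (E, N, 1)
w2_scales_asm = w2_weight_scale.unsqueeze(-1).unsqueeze(-1).expand(E, N, 1)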
@@ -776,15 +775,15 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
-        if envs.VLLM_USE_AITER_MOE:
-            if envs.VLLM_USE_AITER_2STAGE_MOE:
+        if aiter_moe_enabled():
+            if aiter_2stage_moe_enabled():
                 return ck_moe_2stages(a1=x,
-                                        w1=layer.w13_weight,
-                                        w2=layer.w2_weight,
-                                        topk_weight=topk_weights,
-                                        topk_ids=topk_ids,
-                                        fc1_scale=layer.w13_weight_scale,
-                                        fc2_scale=layer.w2_weight_scale)
+                                      w1=layer.w13_weight,
+                                      w2=layer.w2_weight,
+                                      topk_weight=topk_weights,
+                                      topk_ids=topk_ids,
+                                      fc1_scale=layer.w13_weight_scale,
+                                      fc2_scale=layer.w2_weight_scale)
 
         return asm_moe(
             hidden_states=x,
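
Taken together with the Dockerfile.rocm change at the top (VLLM_USE_AITER=1 is no longer baked into the image), the aiter paths now have to be switched on per run. A usage sketch, assuming VLLM_USE_AITER acts as the master switch consulted by the aiter_*_enabled() helpers (the model name is a placeholder):

import os

# Set the flags before importing vllm so the module-level aiter imports and
# lazy env reads see them.
os.environ["VLLM_USE_AITER"] = "1"
os.environ["VLLM_USE_AITER_2STAGE_MOE"] = "1"   # ck_moe_2stages path for fp8 MoE
os.environ["VLLM_USE_AITER_PAGED_ATTN"] = "1"   # aiter paged-attention backend

from vllm import LLM

llm = LLM(model="some-org/some-fp8-moe-model")  # placeholder model id
print(llm.generate("Hello"))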
