Skip to content

Commit 27df543

Browse files
Support aoti_torch_cuda__weight_int4pack_mm (#15089)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #15030 by @desertfire ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/desertfire/1/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/desertfire/1/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/desertfire/1/orig Differential Revision: [D84395275](https://our.internmc.facebook.com/intern/diff/D84395275) @diff-train-skip-merge Co-authored-by: Bin Bao <[email protected]>
1 parent 51bd8ed commit 27df543

File tree

12 files changed

+1883
-5
lines changed

12 files changed

+1883
-5
lines changed

backends/aoti/common_shims.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ int32_t aoti_torch_dtype_bfloat16() {
172172
return 15; // PyTorch's bfloat16 dtype code
173173
}
174174

175+
int32_t aoti_torch_dtype_int8() {
176+
return 1; // PyTorch's int8 dtype code
177+
}
178+
179+
int32_t aoti_torch_dtype_int16() {
180+
return 2; // PyTorch's int16 dtype code
181+
}
182+
183+
int32_t aoti_torch_dtype_int32() {
184+
return 3; // PyTorch's int32 dtype code
185+
}
186+
175187
int32_t aoti_torch_dtype_int64() {
176188
return 4; // PyTorch's int64 dtype code
177189
}

backends/aoti/common_shims.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ int32_t aoti_torch_device_type_cpu();
5959
int32_t aoti_torch_layout_strided();
6060
int32_t aoti_torch_dtype_float32();
6161
int32_t aoti_torch_dtype_bfloat16();
62+
int32_t aoti_torch_dtype_int8();
63+
int32_t aoti_torch_dtype_int16();
64+
int32_t aoti_torch_dtype_int32();
6265
int32_t aoti_torch_dtype_int64();
6366

6467
// Dtype utility function needed by Metal backend

backends/aoti/utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
3434
// Convert based on known PyTorch dtype codes (without CUDA-specific
3535
// dependency)
3636
switch (dtype) {
37+
case 1: // PyTorch's int8 dtype code
38+
return executorch::aten::ScalarType::Char;
39+
case 2: // PyTorch's int16 dtype code
40+
return executorch::aten::ScalarType::Short;
41+
case 3: // PyTorch's int32 dtype code
42+
return executorch::aten::ScalarType::Int;
3743
case 4: // PyTorch's int64 dtype code
3844
return executorch::aten::ScalarType::Long;
3945
case 6: // PyTorch's float32 dtype code

backends/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ find_package_torch()
3838
set(_aoti_cuda_sources
3939
runtime/cuda_backend.cpp runtime/shims/memory.cpp
4040
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
41-
runtime/shims/cuda_guard.cpp
41+
runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
4242
)
4343
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
4444
target_include_directories(

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
}
3434

3535
# exist fallback operators in et namespace;
36-
supported_fallback_kernels: Dict[str, Any] = {}
36+
supported_fallback_kernels: Dict[str, Any] = {
37+
"at::_ops::_weight_int4pack_mm::call": None,
38+
}
3739

3840
# required fallback kernels but not supported
3941
missing_fallback_kernels: Set[str] = set()

backends/cuda/runtime/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")
23

34
oncall("executorch")
45

@@ -7,12 +8,15 @@ runtime.cxx_library(
78
srcs = [
89
"guard.cpp",
910
"shims/cuda_guard.cpp",
11+
"shims/int4mm.cu",
1012
"shims/memory.cpp",
1113
"shims/tensor_attribute.cpp",
1214
],
1315
headers = [
1416
"guard.h",
1517
"shims/cuda_guard.h",
18+
"shims/int4mm.cuh",
19+
"shims/int4mm.h",
1620
"shims/memory.h",
1721
"shims/tensor_attribute.h",
1822
"utils.h",
@@ -30,6 +34,10 @@ runtime.cxx_library(
3034
"//executorch/runtime/core/exec_aten:lib",
3135
"//executorch/runtime/platform:platform",
3236
],
37+
nvcc_flags = get_nvcc_arch_args() + [
38+
"-_NVCC_HOST_COMPILER_FLAG_",
39+
"gcc",
40+
],
3341
external_deps = [
3442
("cuda", None, "cuda-lazy"),
3543
],
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cuda.h>
10+
#include <cuda_runtime.h>
11+
12+
#include <executorch/backends/aoti/utils.h>
13+
#include <executorch/backends/cuda/runtime/shims/int4mm.h>
14+
#include <executorch/backends/cuda/runtime/shims/int4mm.cuh>
15+
#include <executorch/runtime/platform/log.h>
16+
17+
namespace executorch::backends::cuda {
18+
#ifdef __cplusplus
19+
extern "C" {
20+
#endif
21+
22+
AOTITorchError aoti_torch_cuda__weight_int4pack_mm(
23+
Tensor* self,
24+
Tensor* mat2,
25+
int64_t qGroupSize,
26+
Tensor* qScaleAndZeros,
27+
Tensor** ret0) {
28+
// Validate input parameters first
29+
// Only check for null pointers here, as the actual validation of tensor
30+
// properties is done in _weight_int4pack_mm_cuda
31+
ET_CHECK_OR_RETURN_ERROR(
32+
self != nullptr,
33+
InvalidArgument,
34+
"aoti_torch_cuda__weight_int4pack_mm failed: self tensor is null");
35+
36+
ET_CHECK_OR_RETURN_ERROR(
37+
mat2 != nullptr,
38+
InvalidArgument,
39+
"aoti_torch_cuda__weight_int4pack_mm failed: mat2 tensor is null");
40+
41+
ET_CHECK_OR_RETURN_ERROR(
42+
qScaleAndZeros != nullptr,
43+
InvalidArgument,
44+
"aoti_torch_cuda__weight_int4pack_mm failed: qScaleAndZeros tensor is null");
45+
46+
ET_CHECK_OR_RETURN_ERROR(
47+
ret0 != nullptr,
48+
InvalidArgument,
49+
"aoti_torch_cuda__weight_int4pack_mm failed: ret0 is null");
50+
51+
*ret0 = _weight_int4pack_mm_cuda(*self, *mat2, qGroupSize, *qScaleAndZeros);
52+
ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
53+
return Error::Ok;
54+
}
55+
56+
#ifdef __cplusplus
57+
}
58+
#endif
59+
} // namespace executorch::backends::cuda

0 commit comments

Comments
 (0)