Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,14 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
set(VLLM_ROCM_EXT_SRC
  "csrc/rocm/torch_bindings.cpp"
  "csrc/rocm/skinny_gemms.cu"
  "csrc/rocm/attention.cu"
  "csrc/rocm/ck_extensions/ck_tile_gemm_bf16/gemm_bf16.cu")

# Include paths for the CK-tile bf16 GEMM sources.
# NOTE(review): the ROCm path must be absolute — "opt/rocm/include"
# (no leading slash) is resolved relative to the source directory and
# silently misses the system headers. If the build already defines
# ROCM_PATH, prefer "${ROCM_PATH}/include" so non-default installs work.
set(CK_TILE_INCLUDE_DIR
  "csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include"
  "csrc/rocm/ck_extensions/include/"
  "/opt/rocm/include"
)

define_gpu_extension_target(
_rocm_C
Expand All @@ -933,6 +940,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
SOURCES ${VLLM_ROCM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CK_TILE_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
endif()
Expand Down
412 changes: 412 additions & 0 deletions csrc/rocm/ck_extensions/ck_tile_gemm_bf16/gemm_bf16.cu

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

// NOTE(review): <torch/all.h> is sufficient for torch::Tensor here.
// The previously-included <torch/extension.h> drags in Python.h, which is
// incompatible with the stable-ABI build (USE_SABI 3) this extension is
// compiled under, and is unnecessary for a plain declaration.
#include <torch/all.h>

/// @brief CK-tile bf16 GEMM entry point.
///
/// Presumably computes a bf16 GEMM of XQ and WQ with a bias term, writing
/// the result into / returning Y — confirm exact semantics (in-place vs
/// returned tensor, bias broadcasting) against gemm_bf16.cu.
///
/// @param XQ   activation/input matrix
/// @param WQ   weight matrix
/// @param bias bias tensor added to the GEMM result
/// @param Y    output tensor
/// @return the output tensor
torch::Tensor ck_tile_gemm_bf16(
    torch::Tensor &XQ,
    torch::Tensor &WQ,
    torch::Tensor &bias,
    torch::Tensor &Y);
172 changes: 172 additions & 0 deletions csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_kernel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "universal_gemm_bias_kernel.hpp"
namespace ck_tile {

/// @brief The GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
/// object. It contain all necessary information required to build proper kernel argument
/// and launch kernel on GPU.
/// This structure defines the GEMM problem configuration by stating all required information
/// like M,N,K sizes and respective strides.
struct GemmHostArgs_bias
{
CK_TILE_HOST GemmHostArgs_bias() = default;
// NOTE(review): k_batch_ is the 4th-from-last-group constructor parameter
// (between bias_ptr_ and M_) but the LAST declared member — easy to
// mis-order at call sites; double-check callers pass arguments in this
// exact order.
CK_TILE_HOST GemmHostArgs_bias(const void* a_ptr_,
const void* b_ptr_,
void* e_ptr_,
const void* bias_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
e_ptr(e_ptr_),
bias_ptr(bias_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}

const void* a_ptr;   // input matrix A (device pointer, read-only)
const void* b_ptr;   // input matrix B (device pointer, read-only)
// Output matrix pointer: "E" and "C" are two aliases for the same storage
// (different CK naming conventions for the GEMM result tensor).
union
{
void* e_ptr;
void* c_ptr;
};
// Bias pointer consumed by the epilogue; shape/broadcast semantics are not
// visible in this header — confirm against the kernel implementation.
const void* bias_ptr;
index_t M;           // GEMM problem size M
index_t N;           // GEMM problem size N
index_t K;           // GEMM problem size K
index_t stride_A;    // leading-dimension stride of A
index_t stride_B;    // leading-dimension stride of B

// Output stride: same aliasing as e_ptr/c_ptr above.
union
{
index_t stride_E;
index_t stride_C;
};

index_t k_batch;     // split-K factor (number of K partitions)
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel_bias
{
    /// Thin adapter over the bias-aware universal GEMM kernel: every entry
    /// point below simply forwards to UniversalGemmKernel_bias.
    using UniversalGemmKernel_bias =
        UniversalGemmKernel_bias<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    /// Layouts for A, B and the output E (E reuses the pipeline's C layout).
    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// Element types for A, B and the output E.
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    // Exactly one A tensor and one B tensor are supported: tuple (multi-tensor)
    // layouts/types are rejected at compile time.
    static_assert(
        !is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
        "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(
        !is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
        "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(!is_detected<is_tuple, ELayout>::value &&
                      !is_detected<is_tuple, EDataType>::value,
                  "C/ELayout and C/EDataType must be scalars.");

    static constexpr index_t NumATensor = 1;
    static constexpr index_t NumBTensor = 1;

    /// @brief Kernel name, as reported by the underlying universal kernel.
    CK_TILE_HOST static const std::string GetName()
    {
        return UniversalGemmKernel_bias::GetName();
    }

    /// @brief Launch grid for the given problem sizes and split-K factor.
    CK_TILE_HOST static constexpr dim3 GridSize(index_t M, index_t N, index_t KBatch)
    {
        return UniversalGemmKernel_bias::GridSize(M, N, KBatch);
    }

    /// @brief Largest grid that still fits the device at full occupancy.
    CK_TILE_HOST static dim3 MaxOccupancyGridSize(const stream_config& s)
    {
        return UniversalGemmKernel_bias::MaxOccupancyGridSize(s);
    }

    CK_TILE_HOST static constexpr dim3 BlockSize()
    {
        return UniversalGemmKernel_bias::BlockSize();
    }

    /// @brief Translate the flat host arguments into the universal kernel's
    /// argument struct. A and B become single-element arrays; there are no
    /// D tensors in this configuration.
    CK_TILE_HOST static constexpr typename UniversalGemmKernel_bias::KernelArgs
    MakeKernelArgs(const GemmHostArgs_bias& hostArgs)
    {
        const UniversalGemmHostArgs_bias<NumATensor, NumBTensor /*NumDTensor = 0 */>
            universal_args({hostArgs.a_ptr},
                           {hostArgs.b_ptr},
                           {/* no D tensors */},
                           hostArgs.e_ptr,
                           hostArgs.bias_ptr,
                           hostArgs.k_batch,
                           hostArgs.M,
                           hostArgs.N,
                           hostArgs.K,
                           {hostArgs.stride_A},
                           {hostArgs.stride_B},
                           {/* no D strides */},
                           hostArgs.stride_E);
        return UniversalGemmKernel_bias::MakeKernelArgs(universal_args);
    }

    /// @brief Whether the underlying kernel can run with these arguments.
    CK_TILE_HOST static bool
    IsSupportedArgument(const typename UniversalGemmKernel_bias::KernelArgs& kargs)
    {
        return UniversalGemmKernel_bias::IsSupportedArgument(kargs);
    }

    /// Device entry point — delegates the whole computation to the
    /// universal kernel's call operator.
    CK_TILE_DEVICE void operator()(typename UniversalGemmKernel_bias::KernelArgs kargs) const
    {
        UniversalGemmKernel_bias{}(kargs);
    }
};
} // namespace ck_tile
Loading
Loading