Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ad66935
Add CUTLASS as a submodule
WoosukKwon Jan 31, 2024
396e537
Port CUTLASS extensions
WoosukKwon Jan 31, 2024
0cd9436
Port MoE kernels
WoosukKwon Jan 31, 2024
cb4524c
Move moe_kernels
WoosukKwon Jan 31, 2024
c191207
Port MoE GEMM
WoosukKwon Jan 31, 2024
cfa4554
Port CUTLASS kernels
WoosukKwon Jan 31, 2024
90ccdfa
Remove MoE gemm
WoosukKwon Jan 31, 2024
3e90c1a
Merge branch 'main' into cutlass-moe
WoosukKwon Feb 2, 2024
77a5c8d
Remove unused CUTLASS kernels
WoosukKwon Feb 2, 2024
f1583de
Minor
WoosukKwon Feb 2, 2024
de7a749
Add topk_softmax kernels
WoosukKwon Feb 2, 2024
e5c62e8
Remove unnecessary headers
WoosukKwon Feb 2, 2024
e127d9b
Add MoE namespace
WoosukKwon Feb 2, 2024
c3096a0
Minor
WoosukKwon Feb 2, 2024
9a561cc
Add permute_kernels
WoosukKwon Feb 2, 2024
ba07256
Remove unused
WoosukKwon Feb 2, 2024
def2ccd
Move
WoosukKwon Feb 2, 2024
72256cc
Move
WoosukKwon Feb 2, 2024
e86fd06
Remove
WoosukKwon Feb 4, 2024
612f961
Add MoE MLP
WoosukKwon Feb 4, 2024
0bf8fb9
Add cudaUtils
WoosukKwon Feb 4, 2024
c09179d
Fix headers
WoosukKwon Feb 4, 2024
2ab65df
Enable BF16
WoosukKwon Feb 4, 2024
c74fc79
Err msg
WoosukKwon Feb 4, 2024
6320de4
Add unpermute_and_reduce
WoosukKwon Feb 4, 2024
9b57e39
Add renormalize
WoosukKwon Feb 5, 2024
55fae45
Add FusedMoE
WoosukKwon Feb 5, 2024
d355702
Merge branch 'main' into cutlass-moe
WoosukKwon Feb 7, 2024
fb9c524
Minor fix
WoosukKwon Feb 7, 2024
6b20148
Merge branch 'main' into optimized-fused-moe
pcmoritz Feb 18, 2024
bd52cbb
Use autotuned config
pcmoritz Feb 18, 2024
855d98a
fix
pcmoritz Feb 18, 2024
437edcf
add config
pcmoritz Feb 19, 2024
e13bc81
fix
pcmoritz Feb 19, 2024
4b94875
update
pcmoritz Feb 19, 2024
5bd4256
update
pcmoritz Feb 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ include LICENSE
include requirements.txt

recursive-include csrc *
recursive-include third_party *
120 changes: 120 additions & 0 deletions csrc/cutlass_extensions/arch/mma.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/

#pragma once
#include "cutlass_extensions/weight_only_quant_op.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace arch
{

// Tag selecting the mixed-input MMA path: the interleaved, quantized B operand
// is dequantized to A's element type before the multiply-add is issued.
struct OpMultiplyAddDequantizeInterleavedBToA;

/*
Below we have extra tags to signal what kind of dequantization we want to do
(per col, scale only fine grained, finegrained with zero). This still lets us
use the existing template infrastructure (incl. that in CUTLASS). However, we
split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
with the quantization op before instantiating the GEMM pieces.

Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;

// The default just forwards the original operator unchanged (no quantization
// info is attached for ordinary MMA operators).
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator
{
    using TaggedOperator = MmaOp;
};

// Specializations below attach more information to the operator: each pairs
// OpMultiplyAddDequantizeInterleavedBToA with one WeightOnlyQuantOp flavor.
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};

// Here we instantiate some structs to "detag" the tagged operator. It splits it
// back to the original operator + the extra information. If no extra info was
// tagged, the dequant op defaults to per-column scaling.
template <typename TaggedMmaOp>
struct DetagOperator
{
    using Operator = TaggedMmaOp;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};

} // namespace arch
} // namespace cutlass
61 changes: 61 additions & 0 deletions csrc/cutlass_extensions/compute_occupancy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cuda_runtime_api.h>

#include "cutlass/device_kernel.h"
#include "cudaUtils.h"

namespace tensorrt_llm
{
namespace cutlass_extensions
{

template <typename GemmKernel>
inline int compute_occupancy_for_kernel()
{

int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

if (smem_size > (48 << 10))
{
cudaFuncAttributes attr;
int device = 0;
int max_smem_per_block = 0;
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
// wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
// configuration.
return 0;
}
}

int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));

return max_active_blocks;
}

} // namespace cutlass_extensions
} // namespace tensorrt_llm
117 changes: 117 additions & 0 deletions csrc/cutlass_extensions/cudaUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cerrno>
#include <cinttypes>
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

namespace tensorrt_llm::common
{

/* **************************** debug tools ********************************* */

/// Maps a CUDA runtime error code to its human-readable description.
static const char* _cudaGetErrorEnum(cudaError_t error)
{
    return cudaGetErrorString(error);
}

/// Maps a cuBLAS status code to the name of its enumerator; returns
/// "<unknown>" for values outside the enumeration.
static const char* _cudaGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
    case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
    }
    return "<unknown>";
}

/// printf-style formatting into a std::string from an already-started va_list.
/// Returns "" when formatting fails or produces no output.
static std::string vformat(char const* fmt, va_list args)
{
    // vsnprintf consumes the va_list, so measure with a copy first.
    va_list args0;
    va_copy(args0, args);
    auto const size = std::vsnprintf(nullptr, 0, fmt, args0);
    va_end(args0); // va_copy must be paired with va_end
    if (size <= 0)
        return "";

    std::string stringBuf(size, char{});
    // size + 1 is safe: since C++11, std::string guarantees writable storage
    // for the terminating NUL at data()[size()].
    std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
    return stringBuf;
}

/// printf-style formatting into a std::string.
static std::string fmtstr(char const* format, ...)
{
    va_list args;
    va_start(args, format);
    std::string result = vformat(format, args);
    va_end(args);
    return result;
} // (stray ';' after the function body removed)

// FIXME(woosuk)
/// Throws std::runtime_error when `result` is a non-zero (error) status.
/// Works for any status type with an _cudaGetErrorEnum overload
/// (cudaError_t, cublasStatus_t).
template <typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result)
    {
        throw std::runtime_error(
            fmtstr("[ERROR] CUDA runtime error in %s: %s %s:%d\n", func, _cudaGetErrorEnum(result), file, line));
    }
}

// Wrap a CUDA/cuBLAS call so failures throw with the call site recorded.
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)

/// Returns the current device's compute capability as major*10 + minor
/// (e.g. 80 for SM80 / Ampere).
inline int getSMVersion()
{
    int device{-1};
    check_cuda_error(cudaGetDevice(&device));
    int sm_major = 0;
    int sm_minor = 0;
    check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
    check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
    return sm_major * 10 + sm_minor;
}

} // namespace tensorrt_llm::common
Loading