Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <cstring>
#include <cuda.h>
#include <nvPTXCompiler.h>
#include <nvrtc.h>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -47,10 +48,24 @@
} \
} while (0)

// Evaluates `content` (an expression yielding an nvPTXCompileResult) exactly once.
// If the result is anything other than NVPTXCOMPILE_SUCCESS, the macro records the fixed message
// "nvPTXCompiler Internal Error" via setErrorString() and makes the *enclosing function* return
// TLLM_XQA_JIT_INTERNAL_ERROR. It can therefore only be used inside functions whose return type accepts
// that status value.
// NOTE: because the macro returns early, any cleanup written after the call site is skipped on failure.
// The do { ... } while (0) wrapper makes the expansion behave as a single statement.
#define CHECK_NVPTX_ERROR(content) \
do \
{ \
nvPTXCompileResult status_ = content; \
if (status_ != NVPTXCOMPILE_SUCCESS) \
{ \
setErrorString("nvPTXCompiler Internal Error"); \
return TLLM_XQA_JIT_INTERNAL_ERROR; \
} \
} while (0)

// State behind a tllmXqaJitProgram handle (instances are created with `new _tllmXqaJitProgram`).
// Every member has an in-class initializer so that a freshly allocated object never carries
// indeterminate values, even if program creation fails before all fields are assigned.
struct _tllmXqaJitProgram
{
    // NVRTC program handle that is compiled and later queried for its output.
    nvrtcProgram program{};
    // Compilation context supplied by the caller; read for fields such as `sm` and `kernel_type`.
    tllmXqaJitContext const* context = nullptr;
    // For SM120 two-stage compilation: store cubin data from nvPTXCompiler
    std::vector<char> cubin_data;
    // True once cubin_data holds the final cubin; the CUBIN getters then read from cubin_data
    // instead of querying NVRTC.
    bool use_stored_cubin = false;
};

namespace
Expand Down Expand Up @@ -78,6 +93,22 @@ std::string getSMFlag(int SM)
return "-arch=sm_" + smStr;
}

// Builds the NVRTC "-arch=compute_XX" option used when generating PTX (virtual architecture).
//   * SM 120 and SM 121 are mapped to compute_89 for PTX generation.
//   * SM 90 gets the "a" suffix (compute_90a).
//   * Every other value is emitted verbatim as compute_<SM>.
std::string getPTXSMFlag(int SM)
{
    bool const usesCompute89 = (SM == 120) || (SM == 121);
    if (usesCompute89)
    {
        return "-arch=compute_89";
    }

    std::string archFlag = "-arch=compute_";
    archFlag += std::to_string(SM);
    if (SM == 90)
    {
        archFlag += "a";
    }
    return archFlag;
}

tllmXqaJitStatus getMacroFlags(tllmXqaJitContext const* context, std::vector<std::string>* result)
{
// Macro name -> Macro value.
Expand Down Expand Up @@ -219,6 +250,27 @@ tllmXqaJitStatus getBuildOptions(_tllmXqaJitProgram const* prog, std::vector<std
return TLLM_XQA_JIT_SUCCESS;
}

// Appends to `result` the NVRTC options for the C++ -> PTX stage of two-stage compilation:
// the common flags, the PTX architecture flag from getPTXSMFlag(), then the macro flags
// produced by getMacroFlags() for the program's context.
// Returns TLLM_XQA_JIT_SUCCESS, or propagates the failure reported through
// CHECK_TLLM_XQA_JIT_ERROR when getMacroFlags() fails (flags appended so far stay in `result`).
tllmXqaJitStatus getBuildOptionsPTX(_tllmXqaJitProgram const* prog, std::vector<std::string>* result)
{
    // Flags shared by every build.
    for (char const* commonFlag : {"-dw", "--use_fast_math", "-default-device"})
    {
        result->emplace_back(commonFlag);
    }

    // Virtual (PTX) architecture rather than a real sm_XX target.
    result->push_back(getPTXSMFlag(prog->context->sm));

    // Kernel configuration macros.
    std::vector<std::string> macroFlags;
    CHECK_TLLM_XQA_JIT_ERROR(getMacroFlags(prog->context, &macroFlags));
    result->insert(result->end(), macroFlags.begin(), macroFlags.end());

    return TLLM_XQA_JIT_SUCCESS;
}

tllmXqaJitStatus createProgram(tllmXqaJitProgram* prog, tllmXqaJitContext const* context)
{
*prog = new _tllmXqaJitProgram;
Expand All @@ -241,28 +293,94 @@ tllmXqaJitStatus createProgram(tllmXqaJitProgram* prog, tllmXqaJitContext const*

tllmXqaJitStatus compileProgram(tllmXqaJitProgram prog)
{
std::vector<std::string> options;
CHECK_TLLM_XQA_JIT_ERROR(getBuildOptions(prog, &options));
std::vector<char const*> options_cstr;
for (auto const& option : options)
bool needsTwoStageCompilation
= (prog->context->sm == 120 || prog->context->sm == 121) && (prog->context->kernel_type == TLLM_XQA_JIT_HMMA);

if (needsTwoStageCompilation)
{
options_cstr.push_back(option.c_str());
}
#ifndef NDEBUG
// Two-stage compilation avoids accuracy regressions and cubin compatibility issues on SM120/SM121
// by using compute_89 for PTX generation then targeting sm_120 for final cubin
printf(
"Using two-stage compilation for SM120/SM121: NVRTC (C++ -> PTX compute_89) + nvPTXCompiler (PTX -> cubin "
"sm_120)\n");
#endif
// Stage 1: Compile C++ to PTX using compute_89
std::vector<std::string> ptx_options;
CHECK_TLLM_XQA_JIT_ERROR(getBuildOptionsPTX(prog, &ptx_options));
std::vector<char const*> ptx_options_cstr;
for (auto const& option : ptx_options)
{
ptx_options_cstr.push_back(option.c_str());
}

#ifdef NDEBUG
CHECK_NVRTC_ERROR(nvrtcCompileProgram(prog->program, options_cstr.size(), options_cstr.data()));
CHECK_NVRTC_ERROR(nvrtcCompileProgram(prog->program, ptx_options_cstr.size(), ptx_options_cstr.data()));
#else
auto const err = nvrtcCompileProgram(prog->program, options_cstr.size(), options_cstr.data());
if (err != NVRTC_SUCCESS)
{
size_t logSize;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog->program, &logSize));
std::string log;
log.resize(logSize);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog->program, log.data()));
printf("nvrtc error log:\n%s\n", log.c_str());
CHECK_NVRTC_ERROR(err);
auto const err = nvrtcCompileProgram(prog->program, ptx_options_cstr.size(), ptx_options_cstr.data());
if (err != NVRTC_SUCCESS)
{
size_t logSize;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog->program, &logSize));
std::string log;
log.resize(logSize);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog->program, log.data()));
printf("nvrtc PTX compilation error log:\n%s\n", log.c_str());
CHECK_NVRTC_ERROR(err);
}
#endif

size_t ptx_size;
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog->program, &ptx_size));
std::vector<char> ptx_data(ptx_size);
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog->program, ptx_data.data()));

// Stage 2: Compile PTX to cubin for sm_120 using nvPTXCompiler
nvPTXCompilerHandle ptx_compiler;
CHECK_NVPTX_ERROR(nvPTXCompilerCreate(&ptx_compiler, ptx_size, ptx_data.data()));

std::vector<char const*> ptx_compile_options = {"--gpu-name=sm_120"};
CHECK_NVPTX_ERROR(nvPTXCompilerCompile(ptx_compiler, ptx_compile_options.size(), ptx_compile_options.data()));

size_t cubin_size;
CHECK_NVPTX_ERROR(nvPTXCompilerGetCompiledProgramSize(ptx_compiler, &cubin_size));

prog->cubin_data.resize(cubin_size);
CHECK_NVPTX_ERROR(nvPTXCompilerGetCompiledProgram(ptx_compiler, prog->cubin_data.data()));
prog->use_stored_cubin = true;

CHECK_NVPTX_ERROR(nvPTXCompilerDestroy(&ptx_compiler));

#ifndef NDEBUG
printf("Two-stage compilation completed: PTX size=%zu, cubin size=%zu\n", ptx_size, cubin_size);
#endif
}
else
{
std::vector<std::string> options;
CHECK_TLLM_XQA_JIT_ERROR(getBuildOptions(prog, &options));
std::vector<char const*> options_cstr;
for (auto const& option : options)
{
options_cstr.push_back(option.c_str());
}
#ifdef NDEBUG
CHECK_NVRTC_ERROR(nvrtcCompileProgram(prog->program, options_cstr.size(), options_cstr.data()));
#else
auto const err = nvrtcCompileProgram(prog->program, options_cstr.size(), options_cstr.data());
if (err != NVRTC_SUCCESS)
{
size_t logSize;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog->program, &logSize));
std::string log;
log.resize(logSize);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog->program, log.data()));
printf("nvrtc error log:\n%s\n", log.c_str());
CHECK_NVRTC_ERROR(err);
}
#endif
}

return TLLM_XQA_JIT_SUCCESS;
}

Expand All @@ -277,14 +395,32 @@ tllmXqaJitStatus tllmXqaJitCreateAndCompileProgram(tllmXqaJitProgram* prog, tllm

// Writes the size in bytes of the compiled cubin for `prog` to *cubinSizeRet.
// When the program holds a stored cubin (SM120 two-stage compilation) the size of that buffer is
// reported; otherwise the size is queried from NVRTC.
// The stored-cubin check must come first: an unconditional NVRTC query followed by a return would make
// the stored-cubin branch unreachable.
// Returns TLLM_XQA_JIT_SUCCESS, or the status produced by CHECK_NVRTC_ERROR if the NVRTC query fails.
tllmXqaJitStatus tllmXqaJitGetCUBINSize(tllmXqaJitProgram prog, size_t* cubinSizeRet)
{
    // For SM120 two-stage compilation, return stored cubin size
    if (prog->use_stored_cubin)
    {
        *cubinSizeRet = prog->cubin_data.size();
        return TLLM_XQA_JIT_SUCCESS;
    }

    CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog->program, cubinSizeRet));
    return TLLM_XQA_JIT_SUCCESS;
}

// Copies the compiled cubin for `prog` into the caller-provided buffer `cubin`.
// The buffer must hold at least the number of bytes reported by tllmXqaJitGetCUBINSize().
// When the program holds a stored cubin (SM120 two-stage compilation) the bytes come from that buffer;
// otherwise they are fetched from NVRTC.
// The stored-cubin check must come first: an unconditional NVRTC fetch followed by a return would make
// the stored-cubin branch unreachable.
// Returns TLLM_XQA_JIT_SUCCESS, or the status produced by CHECK_NVRTC_ERROR if the NVRTC fetch fails.
tllmXqaJitStatus tllmXqaJitGetCUBIN(tllmXqaJitProgram prog, char* cubin)
{
    // For SM120 two-stage compilation, copy stored cubin data
    if (prog->use_stored_cubin)
    {
        // Skip memcpy for an empty buffer: data() may be null then, which memcpy does not permit.
        if (!prog->cubin_data.empty())
        {
            std::memcpy(cubin, prog->cubin_data.data(), prog->cubin_data.size());
        }
        return TLLM_XQA_JIT_SUCCESS;
    }

    CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog->program, cubin));
    return TLLM_XQA_JIT_SUCCESS;
}

tllmXqaJitStatus tllmXqaJitDestroyProgram(tllmXqaJitProgram* prog)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParam
}
else
{
// If no env var set, default to precompiled impl for sm120, otherwise default to JIT.
return tensorrt_llm::common::getSMVersion() == 120 ? mPrecompiledImpl.get() : mJITImpl.get();
return mJITImpl.get();
}
}

Expand Down