2 changes: 1 addition & 1 deletion 3rdparty/xgrammar
Submodule xgrammar updated 173 files
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/CMakeLists.txt
@@ -59,7 +59,7 @@ set(SRCS

file(GLOB_RECURSE XGRAMMAR_SRCS "${3RDPARTY_DIR}/xgrammar/cpp/*.cc")
list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX
"${3RDPARTY_DIR}/xgrammar/cpp/pybind/.*\\.cc")
"${3RDPARTY_DIR}/xgrammar/cpp/nanobind/.*\\.cc")
list(APPEND SRCS ${XGRAMMAR_SRCS})

if(NOT WIN32)
62 changes: 48 additions & 14 deletions cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
@@ -18,8 +18,10 @@
#include "tensorrt_llm/batch_manager/guidedDecoder.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/logitsBitmask.h"

#include <nlohmann/json.hpp>
#include <xgrammar/xgrammar.h>

using namespace tensorrt_llm::runtime;
Expand All @@ -41,20 +43,23 @@ GuidedDecoder::GuidedDecoder(executor::GuidedDecodingConfig const& guidedDecodin
if (mGuidedDecodingBackend == executor::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
{
mXGrammarMatchers.resize(mMaxNumSequences);
xgrammar::VocabType vocabType = xgrammar::VocabType::RAW;
bool addPrefixSpace = false;
auto const& tokenizerStr = guidedDecodingConfig.getTokenizerStr();
if (tokenizerStr)
{
auto const& tokenizerInfo = xgrammar::TokenizerInfo::FromHuggingFace(
guidedDecodingConfig.getEncodedVocab().value(), guidedDecodingConfig.getTokenizerStr().value(),
mVocabSizePadded, guidedDecodingConfig.getStopTokenIds());
mXGrammarCompiler = std::make_shared<xgrammar::GrammarCompiler>(tokenizerInfo);
}
else
{
auto const& tokenizerInfo = xgrammar::TokenizerInfo(guidedDecodingConfig.getEncodedVocab().value(),
xgrammar::VocabType::RAW, mVocabSizePadded, guidedDecodingConfig.getStopTokenIds());
mXGrammarCompiler = std::make_shared<xgrammar::GrammarCompiler>(tokenizerInfo);
auto const& metadata = xgrammar::TokenizerInfo::DetectMetadataFromHF(tokenizerStr.value());
auto const& metadataJson = nlohmann::json::parse(metadata);
vocabType = metadataJson.at("vocab_type").template get<xgrammar::VocabType>();
addPrefixSpace = metadataJson.at("add_prefix_space").template get<bool>();
}
auto const& tokenizerInfo = xgrammar::TokenizerInfo(guidedDecodingConfig.getEncodedVocab().value(), vocabType,
mVocabSizePadded, guidedDecodingConfig.getStopTokenIds(), addPrefixSpace);

auto const cacheLimitGb = common::getFloatEnv("XGRAMMAR_CACHE_LIMIT_GB");
mXGrammarCompiler = std::make_shared<xgrammar::GrammarCompiler>(tokenizerInfo, /*max_threads=*/8,
/*cache_enabled=*/true,
/*cache_limit_bytes=*/static_cast<long long>(cacheLimitGb.value_or(1.0f) * 1024 * 1024 * 1024));

auto const logitsPtrDtype = BufferDataType{mLogitsDtype, false, true};
auto constexpr bitmaskDtype = TRTDataType<BitmaskT>::value;
@@ -89,27 +94,56 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
// The request is in the first context forward step (considering kv cache reuse).
auto const& guideType = guidedDecodingParams->getGuideType();
auto const& guide = guidedDecodingParams->getGuide();
if (guideType == executor::GuidedDecodingParams::GuideType::kJSON)
switch (guideType)
{
case executor::GuidedDecodingParams::GuideType::kJSON:
{
mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
mXGrammarCompiler->CompileBuiltinJSONGrammar());
break;
}
else if (guideType == executor::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
case executor::GuidedDecodingParams::GuideType::kJSON_SCHEMA:
{
mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
mXGrammarCompiler->CompileJSONSchema(guide.value()));
break;
}
else if (guideType == executor::GuidedDecodingParams::GuideType::kREGEX)
case executor::GuidedDecodingParams::GuideType::kREGEX:
{
auto const& grammar = xgrammar::Grammar::FromRegex(guide.value());
mXGrammarMatchers.at(seqSlot)
= std::make_shared<xgrammar::GrammarMatcher>(mXGrammarCompiler->CompileGrammar(grammar));
break;
}
else if (guideType == executor::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
case executor::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR:
{
auto const& grammar = xgrammar::Grammar::FromEBNF(guide.value());
mXGrammarMatchers.at(seqSlot)
= std::make_shared<xgrammar::GrammarMatcher>(mXGrammarCompiler->CompileGrammar(grammar));
break;
}
case executor::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG:
{
auto const& structuralTagParametersJson = nlohmann::json::parse(guide.value());
auto const& structuralTagItemsJson
= structuralTagParametersJson.at("structures").template get<std::vector<nlohmann::json>>();
std::vector<xgrammar::StructuralTagItem> structuralTagItems;
for (auto const& s : structuralTagItemsJson)
{
structuralTagItems.emplace_back(
xgrammar::StructuralTagItem{s.at("begin").template get<std::string>(),
s.at("schema").dump(), s.at("end").template get<std::string>()});
}
auto const& triggers
= structuralTagParametersJson.at("triggers").template get<std::vector<std::string>>();
mXGrammarMatchers.at(seqSlot) = std::make_shared<xgrammar::GrammarMatcher>(
mXGrammarCompiler->CompileStructuralTag(structuralTagItems, triggers));
break;
}
default:
{
TLLM_THROW("Unsupported guide type.");
}
}
}
else if (llmReq->isGenerationInProgressState())
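The new kSTRUCTURAL_TAG branch parses the guide string as a JSON document with a `structures` array (each entry carrying `begin`, `schema`, and `end`) and a `triggers` array. Below is a minimal sketch of such a payload; the tag strings and the schema are illustrative, not part of this change:

```python
import json

# Illustrative structural-tag payload with the fields the new branch reads:
# "structures" (begin/schema/end per item) and "triggers".
structural_tag = json.dumps({
    "structures": [
        {
            "begin": "<function=get_weather>",
            # "schema" is itself JSON; the C++ code re-serializes it via dump().
            "schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
            "end": "</function>",
        }
    ],
    # xgrammar starts enforcing a structure once a trigger string is generated.
    "triggers": ["<function="],
})
```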
11 changes: 11 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.cpp
@@ -50,6 +50,17 @@ std::optional<size_t> getUInt64Env(char const* name)
return {val};
};

std::optional<float> getFloatEnv(char const* name)
{
char const* const env = std::getenv(name);
if (env == nullptr)
{
return std::nullopt;
}
float const val = std::stof(env);
return {val};
}

std::optional<std::string> getStrEnv(char const* name)
{
char const* const env = std::getenv(name);
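`getFloatEnv` backs the new `XGRAMMAR_CACHE_LIMIT_GB` knob in `guidedDecoder.cpp`, where `value_or(1.0f)` keeps the default cache at 1 GB. A sketch of the same conversion as performed on the Python side (see `grammar_matcher.py` below):

```python
import os

# Same default as the C++ value_or(1.0f): a 1 GB compiler cache.
cache_limit_gb = float(os.getenv("XGRAMMAR_CACHE_LIMIT_GB", "1"))
cache_limit_bytes = int(cache_limit_gb * 1024 * 1024 * 1024)
```

Note that, like `float()` here, `std::stof` raises on a malformed value, so a typo in the variable surfaces as an error rather than a silent fallback to the default.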
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.h
@@ -27,6 +27,8 @@ std::optional<int32_t> getIntEnv(char const* name);

std::optional<size_t> getUInt64Env(char const* name);

std::optional<float> getFloatEnv(char const* name);

bool getBoolEnv(char const* name);

// XQA kernels (optimized kernels for generation phase).
3 changes: 0 additions & 3 deletions cpp/tensorrt_llm/executor/executorImpl.cpp
@@ -1621,9 +1621,6 @@ std::tuple<Executor::Impl::RequestList, double> Executor::Impl::fetchNewRequests
TLLM_CHECK_WITH_INFO(mModel->hasGuidedDecoder(),
"Request is specified with GuidedDecodingParams, but GuidedDecoder is not setup. Please "
"provide a valid GuidedDecodingConfig to setup GuidedDecoder.");
TLLM_CHECK_WITH_INFO(newReq->getGuidedDecodingParams()->getGuideType()
!= executor::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG,
"Structural tag is not supported for guided decoding in C++ Executor.");
}

if (mModel->getWorldConfig().isLastPipelineParallelRank() && newReq->hasAdditionalOutputs())
2 changes: 1 addition & 1 deletion requirements.txt
@@ -51,7 +51,7 @@ peft
einops
flashinfer-python==0.2.5
opencv-python-headless
xgrammar==0.1.19
xgrammar==0.1.21
backoff
nvtx
matplotlib # FIXME: this is added to make nvtx happy
24 changes: 12 additions & 12 deletions tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
@@ -49,21 +49,21 @@ class XGrammarMatcherFactory(GrammarMatcherFactory):
def __init__(self, guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int):
super().__init__()
vocab_type = xgrammar.VocabType.RAW
add_prefix_space = False
if guided_decoding_config.tokenizer_str is not None:
metadata = xgrammar.TokenizerInfo._detect_metadata_from_hf(
guided_decoding_config.tokenizer_str)
tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
vocab_type=metadata["vocab_type"],
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids,
add_prefix_space=metadata["add_prefix_space"])
else:
tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
xgrammar.VocabType.RAW,
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids)
vocab_type = metadata["vocab_type"]
add_prefix_space = metadata["add_prefix_space"]

tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
vocab_type=vocab_type,
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids,
add_prefix_space=add_prefix_space)

# Default cache limit is 1GB.
cache_limit_gb = float(os.getenv("XGRAMMAR_CACHE_LIMIT_GB", "1"))
cache_limit_bytes = int(cache_limit_gb * 1024 * 1024 * 1024)
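For context, here is a rough sketch of how a compiler/matcher pair built from this factory is driven each decoding step, following xgrammar's Python API as I understand it; `tokenizer_info` and `logits` are assumptions (shape `[1, vocab_size]`), not part of this diff:

```python
import torch
import xgrammar

# Assumes `tokenizer_info` was built as in XGrammarMatcherFactory above and
# `logits` holds the model's last-step scores of shape [1, vocab_size].
compiler = xgrammar.GrammarCompiler(tokenizer_info, max_threads=8)
matcher = xgrammar.GrammarMatcher(compiler.compile_builtin_json_grammar())

bitmask = xgrammar.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(bitmask)  # mark tokens legal under the grammar
xgrammar.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))
token_id = int(torch.argmax(logits[0]).item())  # greedy pick among legal tokens
matcher.accept_token(token_id)  # advance the matcher state
```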
3 changes: 1 addition & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -28,8 +28,7 @@

# yapf: disable
# isort: off
from ..bindings.executor import (
BatchingType as _BatchingType,
from ..bindings.executor import (BatchingType as _BatchingType,
CacheTransceiverBackendType as _CacheTransceiverBackendType,
CacheTransceiverConfig as _CacheTransceiverConfig,
CapacitySchedulerPolicy as _CapacitySchedulerPolicy,
2 changes: 1 addition & 1 deletion tensorrt_llm/sampling_params.py
@@ -19,7 +19,7 @@ class GuidedDecodingParams:
regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Structural tag is supported by xgrammar in PyTorch backend only. Defaults to None.
structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Structural tag is supported by xgrammar backend only. Defaults to None.
""" # noqa: E501

json: Optional[Union[str, BaseModel, dict]] = None
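End to end, the updated docstring corresponds to usage along these lines through the LLM API. This is a hedged sketch: the model path is a placeholder, and the payload reuses the structural-tag shape shown earlier.

```python
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams

llm = LLM(model="/path/to/model", guided_decoding_backend="xgrammar")

tag = json.dumps({
    "structures": [{
        "begin": "<function=get_weather>",
        "schema": {"type": "object",
                   "properties": {"city": {"type": "string"}}},
        "end": "</function>",
    }],
    "triggers": ["<function="],
})
params = SamplingParams(
    max_tokens=128,
    guided_decoding=GuidedDecodingParams(structural_tag=tag),
)
outputs = llm.generate(["What's the weather in Paris?"], params)
print(outputs[0].outputs[0].text)
```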