
Commit fcadce9

[fix] Eagle-2 LLMAPI pybind argument fix. (#3967)
Signed-off-by: Jhao-Ting Chen <[email protected]>
Co-authored-by: Haohang Huang <[email protected]>
Parent: 255779a

12 files changed: +192 -31 lines

cpp/tensorrt_llm/pybind/executor/request.cpp

Lines changed: 1 addition & 1 deletion
@@ -449,7 +449,7 @@ void initRequestBindings(pybind11::module_& m)
             {
                 throw std::runtime_error("Invalid EagleConfig state!");
             }
-            return tle::EagleConfig(state[0].cast<tle::EagleChoices>(), state[1].cast<bool>(),
+            return tle::EagleConfig(state[0].cast<std::optional<tle::EagleChoices>>(), state[1].cast<bool>(),
                 state[2].cast<std::optional<float>>(), state[3].cast<bool>(), state[4].cast<std::optional<SizeType32>>());
         };
     py::class_<tle::EagleConfig>(m, "EagleConfig")
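
In effect, the old binding could not unpickle an EagleConfig whose eagle_choices was unset, which is exactly the state Eagle-2's dynamic tree produces. A minimal sketch of the round trip this one-line cast change unblocks (module path assumed from the bindings tests; pre-fix behavior inferred from the diff above):

import pickle

import tensorrt_llm.bindings.executor as trtllm

# Eagle-2 (dynamic tree) leaves eagle_choices unset, so pickled state[0] is None.
config = trtllm.EagleConfig(None, False, 0.5, True, 3)

# Before this fix, __setstate__ cast state[0] straight to tle::EagleChoices,
# which cannot represent None; casting to std::optional<tle::EagleChoices>
# lets None round-trip.
config_copy = pickle.loads(pickle.dumps(config))
assert config.eagle_choices == config_copy.eagle_choices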
examples/llm-api/llm_eagle2_decoding.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+### Generate Text Using Eagle2 Decoding
+
+from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm.llmapi import (LLM, EagleDecodingConfig, KvCacheConfig,
+                                 SamplingParams)
+
+
+def main():
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # The end user can customize the sampling configuration with the SamplingParams class
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    # The end user can customize the kv cache configuration with the KVCache class
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True)
+
+    llm_kwargs = {}
+
+    model = "lmsys/vicuna-7b-v1.3"
+
+    # The end user can customize the eagle decoding configuration by specifying the
+    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices,
+    # greedy_sampling, posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
+    # with the EagleDecodingConfig class
+
+    speculative_config = EagleDecodingConfig(
+        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
+        max_draft_len=63,
+        num_eagle_layers=4,
+        max_non_leaves_per_layer=10,
+        use_dynamic_tree=True,
+        dynamic_tree_max_topK=10)
+
+    llm = LLM(model=model,
+              kv_cache_config=kv_cache_config,
+              speculative_config=speculative_config,
+              max_batch_size=1,
+              max_seq_len=1024,
+              **llm_kwargs)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+if __name__ == '__main__':
+    main()
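
For comparison with the pre-existing Eagle-1 example that follows, a hedged side-by-side sketch of the two configurations (parameter names and values taken from this commit's diffs; the static choices list is truncated here for brevity):

from tensorrt_llm.llmapi import EagleDecodingConfig

# Eagle-1: a static draft tree, spelled out explicitly via eagle_choices.
eagle1 = EagleDecodingConfig(
    speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
    max_draft_len=63,
    num_eagle_layers=4,
    max_non_leaves_per_layer=10,
    eagle_choices=[[0], [0, 0], [1], [0, 1], [2]])  # truncated; the full tree has 63 nodes

# Eagle-2: the draft tree is grown dynamically at runtime, so eagle_choices
# is left unset and branching is bounded by dynamic_tree_max_topK instead.
eagle2 = EagleDecodingConfig(
    speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
    max_draft_len=63,
    num_eagle_layers=4,
    max_non_leaves_per_layer=10,
    use_dynamic_tree=True,
    dynamic_tree_max_topK=10)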

examples/llm-api/llm_eagle_decoding.py

Lines changed: 4 additions & 6 deletions
@@ -1,8 +1,8 @@
 ### Generate Text Using Eagle Decoding

 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (LLM, BuildConfig, EagleDecodingConfig,
-                                 KvCacheConfig, SamplingParams)
+from tensorrt_llm.llmapi import (LLM, EagleDecodingConfig, KvCacheConfig,
+                                 SamplingParams)


 def main():
@@ -16,9 +16,6 @@ def main():
     # The end user can customize the sampling configuration with the SamplingParams class
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

-    # The end user can customize the build configuration with the BuildConfig class
-    build_config = BuildConfig(max_batch_size=1, max_seq_len=1024)
-
     # The end user can customize the kv cache configuration with the KVCache class
     kv_cache_config = KvCacheConfig(enable_block_reuse=True)

@@ -45,9 +42,10 @@ def main():
     )

     llm = LLM(model=model,
-              build_config=build_config,
               kv_cache_config=kv_cache_config,
               speculative_config=speculative_config,
+              max_batch_size=1,
+              max_seq_len=1024,
               **llm_kwargs)

     outputs = llm.generate(prompts, sampling_params)
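
The substance of this example change: engine-size limits move off BuildConfig and onto the LLM constructor. A minimal before/after sketch (keyword arguments as shown in the diff above):

from tensorrt_llm import LLM

# After this commit, build limits are plain LLM keyword arguments.
llm = LLM(model="lmsys/vicuna-7b-v1.3",
          max_batch_size=1,  # previously BuildConfig(max_batch_size=1, ...)
          max_seq_len=1024)  # previously BuildConfig(..., max_seq_len=1024)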

tests/integration/defs/accuracy/test_cli_flow.py

Lines changed: 27 additions & 0 deletions
@@ -466,6 +466,33 @@ def test_eagle(self, cuda_graph, chunked_context, typical_acceptance,
                  ],
                  extra_summarize_args=extra_summarize_args)

+    @skip_post_blackwell
+    @parametrize_with_ids("cuda_graph,chunked_context", [(False, False),
+                                                         (True, True),
+                                                         (True, False)])
+    def test_eagle_2(self, cuda_graph, chunked_context, mocker):
+        mocker.patch.object(self.__class__, "EXAMPLE_FOLDER", "eagle")
+        mocker.patch.object(CnnDailymail, "MAX_BATCH_SIZE", 8)
+
+        extra_summarize_args = [
+            "--eagle_use_dynamic_tree", "--eagle_dynamic_tree_max_top_k=10"
+        ]
+        if cuda_graph:
+            extra_summarize_args.append("--cuda_graph_mode")
+        if chunked_context:
+            extra_summarize_args.append("--enable_chunked_context")
+
+        self.run(spec_dec_algo=EagleDecodingConfig.decoding_type,
+                 extra_convert_args=[
+                     f"--eagle_model_dir={self.EAGLE_MODEL_PATH}",
+                     "--max_draft_len=63", "--num_eagle_layers=4",
+                     "--max_non_leaves_per_layer=10"
+                 ],
+                 extra_build_args=[
+                     "--speculative_decoding_mode=eagle", "--max_draft_len=63"
+                 ],
+                 extra_summarize_args=extra_summarize_args)
+

 class TestLlama7B(CliFlowAccuracyTestHarness):
     MODEL_NAME = "llama-7b-hf"

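The three parametrizations above expand into the test IDs registered in the l0_a10 and l0_h100 lists later in this commit; a small sketch of the ID shape (format inferred from those list entries; parametrize_with_ids is the repo's own helper and is not reimplemented here):

params = [(False, False), (True, True), (True, False)]
ids = [f"test_eagle_2[cuda_graph={cg}-chunked_context={cc}]" for cg, cc in params]
# e.g. "test_eagle_2[cuda_graph=True-chunked_context=True]"
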
tests/integration/defs/accuracy/test_llm_api.py

Lines changed: 48 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 import pytest

-from tensorrt_llm.llmapi import LLM
+from tensorrt_llm.llmapi import LLM, EagleDecodingConfig
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization import QuantAlgo

@@ -290,3 +290,50 @@ def test_fp8_kvcache(self):
             extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
+
+
+class TestEagleVicuna_7B_v1_3(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "lmsys/vicuna-7b-v1.3"
+    MODEL_PATH = f"{llm_models_root()}/vicuna-7b-v1.3"
+
+    speculative_config = EagleDecodingConfig(
+        max_draft_len=63,
+        speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
+        num_eagle_layers=4,
+        max_non_leaves_per_layer=10,
+        eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
+            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
+            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
+            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], \
+            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], \
+            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
+    )
+
+    def test_auto_dtype(self):
+        with LLM(
+                self.MODEL_PATH,
+                max_batch_size=8,  # Spec-dec use case less than bs=8
+                speculative_config=self.speculative_config) as llm:
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
+class TestEagle2Vicuna_7B_v1_3(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "lmsys/vicuna-7b-v1.3"
+    MODEL_PATH = f"{llm_models_root()}/vicuna-7b-v1.3"
+
+    speculative_config = EagleDecodingConfig(
+        max_draft_len=63,
+        speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
+        num_eagle_layers=4,
+        max_non_leaves_per_layer=10,
+        use_dynamic_tree=True,
+        dynamic_tree_max_topK=10)
+
+    def test_auto_dtype(self):
+        with LLM(
+                self.MODEL_PATH,
+                max_batch_size=8,  # Spec-dec use case less than bs=8
+                speculative_config=self.speculative_config) as llm:
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)

tests/integration/defs/llmapi/test_llm_examples.py

Lines changed: 5 additions & 0 deletions
@@ -141,6 +141,11 @@ def test_llmapi_example_eagle_decoding(llm_root, engine_dir, llm_venv):
     _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_eagle_decoding.py")


+def test_llmapi_example_eagle2_decoding(llm_root, engine_dir, llm_venv):
+    _run_llmapi_example(llm_root, engine_dir, llm_venv,
+                        "llm_eagle2_decoding.py")
+
+
 @pytest.mark.skip_less_device(2)
 def test_llmapi_example_distributed_tp2(llm_root, engine_dir, llm_venv):
     _run_llmapi_example(llm_root, engine_dir, llm_venv,

tests/integration/defs/test_e2e.py

Lines changed: 0 additions & 3 deletions
@@ -1442,9 +1442,6 @@ def test_build_time_benchmark_sanity(llm_root, llm_venv):
     ])


-# End of HLAPI examples
-
-
 ### Pivot-To-Python examples
 def test_ptp_quickstart(llm_root, llm_venv):
     example_root = Path(os.path.join(llm_root, "examples", "pytorch"))

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 5 additions & 0 deletions
@@ -107,6 +107,9 @@ l0_a10:
   - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=False-typical_acceptance=False] # 5 mins
   - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=True-typical_acceptance=False] # 5 mins
   - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=False-typical_acceptance=True] # 5 mins
+  - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=False-chunked_context=False] # 5 mins
+  - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=False] # 5 mins
+  - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=True] # 5 mins
   - accuracy/test_cli_flow.py::TestLlama2_7B::test_auto_dtype
   - examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-disable_weight_only]
   - unittest/trt/attention/test_gpt_attention_IFB.py
@@ -165,3 +168,5 @@ l0_a10:
   - test_e2e.py::test_build_time_benchmark_sanity
   - examples/test_whisper.py::test_llm_whisper_general[large-v3-enable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime]
   - examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime] # 4 mins
+  - accuracy/test_llm_api.py::TestEagleVicuna_7B_v1_3::test_auto_dtype
+  - accuracy/test_llm_api.py::TestEagle2Vicuna_7B_v1_3::test_auto_dtype

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 5 additions & 0 deletions
@@ -268,6 +268,9 @@ l0_h100:
   - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=False-typical_acceptance=False] # 5 mins
   - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=True-typical_acceptance=False] # 5 mins
   - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=False-typical_acceptance=True] # 5 mins
+  - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=False-chunked_context=False] # 5 mins
+  - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=False] # 5 mins
+  - accuracy/test_cli_flow.py::TestVicuna7B::test_eagle_2[cuda_graph=True-chunked_context=True] # 5 mins
   - accuracy/test_cli_flow.py::TestPhi2::test_auto_dtype # 2 mins
   - accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8
   - accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8_lm_head
@@ -289,3 +292,5 @@ l0_h100:
   - unittest/trt/model_api/test_model_quantization.py # 20 mins on H100
   - unittest/bindings # 8 mins on H100
   - test_e2e.py::test_build_time_benchmark_sanity
+  - accuracy/test_llm_api.py::TestEagleVicuna_7B_v1_3::test_auto_dtype
+  - accuracy/test_llm_api.py::TestEagle2Vicuna_7B_v1_3::test_auto_dtype

tests/unittest/bindings/test_executor_bindings.py

Lines changed: 8 additions & 0 deletions
@@ -1444,6 +1444,14 @@ def test_eagle_config_pickle():
     assert config.use_dynamic_tree == config_copy.use_dynamic_tree
     assert config.greedy_sampling == config_copy.greedy_sampling

+    config = trtllm.EagleConfig(None, False, 0.5, True, 3)
+    config_copy = pickle.loads(pickle.dumps(config))
+    assert config.eagle_choices == config_copy.eagle_choices
+    assert config.greedy_sampling == config_copy.greedy_sampling
+    assert config.posterior_threshold == config_copy.posterior_threshold
+    assert config.use_dynamic_tree == config_copy.use_dynamic_tree
+    assert config.dynamic_tree_max_topK == config_copy.dynamic_tree_max_topK
+

 def test_decoding_mode():
     mode = trtllm.DecodingMode.Auto()
