Merged
11 changes: 10 additions & 1 deletion cpp/tests/resources/scripts/build_redrafter_engines.py
@@ -28,9 +28,18 @@
def build_engine(base_model_dir: _pl.Path, drafter_model_dir: _pl.Path,
                 engine_dir: _pl.Path, *args):

    base_ckpt_dir = f'{base_model_dir}-ckpt'
    convert_cmd_base = [
        _sys.executable, "examples/models/core/llama/convert_checkpoint.py"
    ] + (['--model_dir', str(base_model_dir)] if base_model_dir else []) + [
        '--output_dir', str(base_ckpt_dir), '--dtype=float16'
    ] + list(args)

    run_command(convert_cmd_base)

    convert_cmd = [
        _sys.executable, "examples/redrafter/convert_checkpoint.py"
    ] + (['--base_model_checkpoint_dir', str(base_ckpt_dir)]
         if base_model_dir else []) + [
        '--drafter_model_dir', str(drafter_model_dir),
        '--output_dir', str(engine_dir), '--dtype=float16',
        '--redrafter_num_beams=5', '--redrafter_draft_len_per_beam=5'
Expand Down
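The conditional argument-list idiom in the script above can be sketched in isolation (the helper name `make_convert_cmd` is hypothetical; only the list-building pattern is taken from the script):

```python
import sys

def make_convert_cmd(script, model_dir, output_dir, extra_args=()):
    # Start with the interpreter and the conversion script.
    cmd = [sys.executable, script]
    # Include --model_dir only when a model directory is given,
    # mirroring the conditional list concatenation in the script above.
    if model_dir:
        cmd += ['--model_dir', str(model_dir)]
    cmd += ['--output_dir', str(output_dir), '--dtype=float16']
    cmd += list(extra_args)
    return cmd
```

The resulting list can be passed straight to `subprocess.run` without shell quoting concerns, which is why the script builds commands this way.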
38 changes: 35 additions & 3 deletions examples/redrafter/README.md
@@ -45,24 +45,56 @@ We use `convert_checkpoint.py` script to convert the model for ReDrafter decodin
You can specify the three hyperparameters (described above) during this conversion. The resulting `config.json` file can be edited to change these hyperparameters before building the engine.

```bash
# From the `examples/models/core/llama/` directory, run,
python convert_checkpoint.py --model_dir ./vicuna-7b-v1.3 \
--output_dir ./vicuna-7b-v1.3-ckpt \
--dtype float16

# From this directory, `examples/redrafter/`, run,
python convert_checkpoint.py --base_model_checkpoint_dir ./vicuna-7b-v1.3-ckpt \
--drafter_model_dir ./vicuna-7b-drafter \
--output_dir ./tllm_checkpoint_1gpu_redrafter \
--dtype float16 \
--redrafter_num_beams 4 \
--redrafter_draft_len_per_beam 5


trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_redrafter \
--output_dir ./tmp/redrafter/7B/trt_engines/fp16/1-gpu/ \
--gemm_plugin float16 \
--speculative_decoding_mode explicit_draft_tokens \
--max_batch_size 4

```
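As noted above, the converted checkpoint's `config.json` can be edited to change the hyperparameters before running `trtllm-build`. A minimal sketch of doing that programmatically; the field names are an assumption (taken to mirror the CLI flags `--redrafter_num_beams` and `--redrafter_draft_len_per_beam`) and should be confirmed against the `config.json` your conversion actually produced:

```python
import json
from pathlib import Path

def set_redrafter_params(ckpt_dir, num_beams, draft_len_per_beam):
    # Load the converted checkpoint's config.json.
    cfg_path = Path(ckpt_dir) / 'config.json'
    cfg = json.loads(cfg_path.read_text())
    # Field names assumed to mirror the conversion CLI flags.
    cfg['redrafter_num_beams'] = num_beams
    cfg['redrafter_draft_len_per_beam'] = draft_len_per_beam
    cfg_path.write_text(json.dumps(cfg, indent=2))
```

Run this between `convert_checkpoint.py` and `trtllm-build` to try different beam settings without repeating the conversion.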

Note that `speculative_decoding_mode` is set to `explicit_draft_tokens`, which is the category ReDrafter falls under.

Similarly, we can use an FP8-quantized base model with a BF16 draft head.
```bash
# From the `examples/models/core/qwen/` directory, run the following to quantize the model to FP8 and export a TensorRT-LLM checkpoint
python ../../../quantization/quantize.py --model_dir ./Qwen2.5-7B-Instruct/ \
--dtype bfloat16 \
--qformat fp8 \
--output_dir ./qwen_checkpoint_1gpu_fp8 \
--calib_size 1024

# From this directory, `examples/redrafter/`, run,
python convert_checkpoint.py --base_model_checkpoint_dir ./qwen_checkpoint_1gpu_fp8 \
--drafter_model_dir ./qwen-7b-drafter \
--output_dir ./tllm_checkpoint_1gpu_fp8 \
--dtype bfloat16 \
--redrafter_num_beams 1 \
--redrafter_draft_len_per_beam 3

# Build TensorRT-LLM engines from the checkpoint
# Enable FP8 context FMHA for further acceleration by passing `--use_fp8_context_fmha enable`
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp8 \
--output_dir ./engine_outputs \
--gemm_plugin fp8 \
--speculative_decoding_mode explicit_draft_tokens \
--max_beam_width 1 \
--max_batch_size 4
```
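Here `--max_beam_width 1` matches the `--redrafter_num_beams 1` used during conversion. A hedged sketch of reading that value back from the checkpoint so the two stay in sync (the field name is an assumption mirroring the CLI flag):

```python
import json
from pathlib import Path

def redrafter_beam_width(ckpt_dir):
    # Read the converted checkpoint's config.json; fall back to 1
    # when the (assumed) field is absent.
    cfg = json.loads((Path(ckpt_dir) / 'config.json').read_text())
    return cfg.get('redrafter_num_beams', 1)
```

A build wrapper could call this and pass the result as `--max_beam_width`, avoiding a silent mismatch between conversion and build flags.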

### Run

Since the hyperparameters are fixed at engine-build time, running a ReDrafter engine is similar to running the base model alone.
Expand Down