Skip to content

Commit db0a01b

Browse files
committed
use torch custom op instead of pybind
Signed-off-by: qizixi <[email protected]>
Signed-off-by: zixi-qi <[email protected]>
1 parent 6610917 commit db0a01b

File tree

6 files changed

+373
-67
lines changed

6 files changed

+373
-67
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ define_gpu_extension_target(
944944
# _suffix_cache_C extension
945945
#
946946
set(VLLM_SUFFIX_CACHE_EXT_SRC
947-
"csrc/suffix_cache/pybind.cc"
947+
"csrc/suffix_cache/torch_bindings.cpp"
948948
"csrc/suffix_cache/suffix_tree.cc")
949949

950950
message(STATUS "Enabling suffix_cache extension.")

csrc/suffix_cache/pybind.cc

Lines changed: 0 additions & 41 deletions
This file was deleted.
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
// Copyright 2025 Snowflake Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include <torch/library.h>
17+
#include <torch/types.h>
18+
#include <ATen/ATen.h>
19+
#include <ATen/core/jit_type.h>
20+
21+
#include "suffix_tree.h"
22+
#include "core/registration.h"
23+
24+
// Register custom types with PyTorch
25+
namespace {
26+
// Packages one speculation result as an ad-hoc IValue object so it can
// cross the C++/Python boundary without a pybind11-registered class.
//
// NOTE(review): a fresh ClassType is created on every call, with no
// qualified name (c10::nullopt) and no declared attributes, yet setAttr()
// is then invoked with slot indices 0-4.  Whether slot-indexed setAttr on
// a type with zero registered attributes is well-defined depends on the
// c10 internals of the pinned PyTorch version — confirm against it.
c10::intrusive_ptr<c10::ivalue::Object> make_candidate(
    const std::vector<int64_t>& token_ids,  // speculated token ids
    const std::vector<int64_t>& parents,    // parent index per token (tree spec)
    const std::vector<double>& probs,       // per-token probability
    double score,                           // overall candidate score
    int64_t match_len) {                    // length of the matched suffix

  auto obj = c10::ivalue::Object::create(
      c10::StrongTypePtr(nullptr, c10::ClassType::create(
          "_suffix_cache.Candidate", c10::nullopt)));

  // Slot layout relied on by the speculate shim:
  // 0=token_ids, 1=parents, 2=probs, 3=score, 4=match_len.
  obj->setAttr(0, token_ids);
  obj->setAttr(1, parents);
  obj->setAttr(2, probs);
  obj->setAttr(3, score);
  obj->setAttr(4, match_len);

  return obj;
}
46+
// Wrapper functions for SuffixTree operations
47+
class SuffixTreeWrapper {
48+
std::unique_ptr<SuffixTree> tree_;
49+
public:
50+
explicit SuffixTreeWrapper(int64_t max_depth)
51+
: tree_(std::make_unique<SuffixTree>(static_cast<int>(max_depth))) {}
52+
53+
int64_t num_seqs() const {
54+
return static_cast<int64_t>(tree_->num_seqs());
55+
}
56+
57+
void append(int64_t seq_id, int64_t token) {
58+
tree_->append(static_cast<int>(seq_id), static_cast<int>(token));
59+
}
60+
61+
void extend(int64_t seq_id, const std::vector<int64_t>& tokens) {
62+
std::vector<int> int_tokens;
63+
int_tokens.reserve(tokens.size());
64+
for (int64_t token : tokens) {
65+
int_tokens.push_back(static_cast<int>(token));
66+
}
67+
tree_->extend(static_cast<int>(seq_id), int_tokens);
68+
}
69+
70+
void remove(int64_t seq_id) {
71+
tree_->remove(static_cast<int>(seq_id));
72+
}
73+
74+
c10::intrusive_ptr<c10::ivalue::Object> speculate(
75+
const std::vector<int64_t>& pattern,
76+
int64_t max_spec_tokens,
77+
double max_spec_factor,
78+
double max_spec_offset,
79+
double min_token_prob,
80+
bool use_tree_spec) {
81+
82+
std::vector<int> int_pattern;
83+
int_pattern.reserve(pattern.size());
84+
for (int64_t token : pattern) {
85+
int_pattern.push_back(static_cast<int>(token));
86+
}
87+
88+
Candidate result = tree_->speculate(
89+
int_pattern,
90+
static_cast<int>(max_spec_tokens),
91+
static_cast<float>(max_spec_factor),
92+
static_cast<float>(max_spec_offset),
93+
static_cast<float>(min_token_prob),
94+
use_tree_spec);
95+
96+
// Convert Candidate to PyTorch custom type
97+
std::vector<int64_t> token_ids(result.token_ids.begin(), result.token_ids.end());
98+
std::vector<int64_t> parents(result.parents.begin(), result.parents.end());
99+
std::vector<double> probs(result.probs.begin(), result.probs.end());
100+
101+
return make_candidate(token_ids, parents, probs,
102+
static_cast<double>(result.score),
103+
static_cast<int64_t>(result.match_len));
104+
}
105+
106+
std::string check_integrity() {
107+
return tree_->check_integrity();
108+
}
109+
110+
int64_t estimate_memory() const {
111+
return static_cast<int64_t>(tree_->estimate_memory());
112+
}
113+
};
114+
115+
// Shim functions for TORCH_LIBRARY registration
116+
torch::Tensor suffix_tree_create(int64_t max_depth) {
117+
auto wrapper = std::make_unique<SuffixTreeWrapper>(max_depth);
118+
void* ptr = wrapper.release();
119+
120+
// Store the pointer in a tensor (this is a common pattern in vLLM)
121+
// We use a CPU int64 tensor to store the pointer
122+
auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
123+
auto tensor = torch::empty({1}, options);
124+
tensor.data_ptr<int64_t>()[0] = reinterpret_cast<int64_t>(ptr);
125+
return tensor;
126+
}
127+
128+
void suffix_tree_destroy(torch::Tensor handle) {
129+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
130+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
131+
delete wrapper;
132+
}
133+
134+
int64_t suffix_tree_num_seqs(torch::Tensor handle) {
135+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
136+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
137+
return wrapper->num_seqs();
138+
}
139+
140+
void suffix_tree_append(torch::Tensor handle, int64_t seq_id, int64_t token) {
141+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
142+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
143+
wrapper->append(seq_id, token);
144+
}
145+
146+
void suffix_tree_extend(torch::Tensor handle, int64_t seq_id, torch::Tensor tokens) {
147+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
148+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
149+
150+
auto tokens_accessor = tokens.accessor<int64_t, 1>();
151+
std::vector<int64_t> token_vec;
152+
token_vec.reserve(tokens_accessor.size(0));
153+
for (int64_t i = 0; i < tokens_accessor.size(0); ++i) {
154+
token_vec.push_back(tokens_accessor[i]);
155+
}
156+
157+
wrapper->extend(seq_id, token_vec);
158+
}
159+
160+
void suffix_tree_remove(torch::Tensor handle, int64_t seq_id) {
161+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
162+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
163+
wrapper->remove(seq_id);
164+
}
165+
166+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, double, int64_t>
167+
suffix_tree_speculate(torch::Tensor handle,
168+
torch::Tensor pattern,
169+
int64_t max_spec_tokens,
170+
double max_spec_factor,
171+
double max_spec_offset,
172+
double min_token_prob,
173+
bool use_tree_spec) {
174+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
175+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
176+
177+
auto pattern_accessor = pattern.accessor<int64_t, 1>();
178+
std::vector<int64_t> pattern_vec;
179+
pattern_vec.reserve(pattern_accessor.size(0));
180+
for (int64_t i = 0; i < pattern_accessor.size(0); ++i) {
181+
pattern_vec.push_back(pattern_accessor[i]);
182+
}
183+
184+
auto result = wrapper->speculate(pattern_vec, max_spec_tokens,
185+
max_spec_factor, max_spec_offset,
186+
min_token_prob, use_tree_spec);
187+
188+
// Extract attributes from the custom object
189+
auto token_ids_list = result->getAttr(0).toIntList();
190+
auto parents_list = result->getAttr(1).toIntList();
191+
auto probs_list = result->getAttr(2).toDoubleList();
192+
double score = result->getAttr(3).toDouble();
193+
int64_t match_len = result->getAttr(4).toInt();
194+
195+
// Convert to tensors
196+
auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
197+
auto float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
198+
199+
auto token_ids_tensor = torch::tensor(token_ids_list, options);
200+
auto parents_tensor = torch::tensor(parents_list, options);
201+
auto probs_tensor = torch::tensor(probs_list, float_options);
202+
203+
return std::make_tuple(token_ids_tensor, parents_tensor, probs_tensor, score, match_len);
204+
}
205+
206+
std::string suffix_tree_check_integrity(torch::Tensor handle) {
207+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
208+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
209+
return wrapper->check_integrity();
210+
}
211+
212+
int64_t suffix_tree_estimate_memory(torch::Tensor handle) {
213+
int64_t ptr_value = handle.data_ptr<int64_t>()[0];
214+
auto* wrapper = reinterpret_cast<SuffixTreeWrapper*>(ptr_value);
215+
return wrapper->estimate_memory();
216+
}
217+
218+
} // anonymous namespace
219+
220+
// Registers the suffix-cache custom ops under TORCH_EXTENSION_NAME.
// Every op except suffix_tree_create takes the opaque int64 handle tensor
// produced by suffix_tree_create.  All kernels are registered for CPU
// only; the tree itself lives on the host heap.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, suffix_cache) {
  // SuffixTree operations
  // Lifecycle: create returns the handle, destroy frees it.
  suffix_cache.def("suffix_tree_create(int max_depth) -> Tensor");
  suffix_cache.impl("suffix_tree_create", torch::kCPU, &suffix_tree_create);

  suffix_cache.def("suffix_tree_destroy(Tensor handle) -> ()");
  suffix_cache.impl("suffix_tree_destroy", torch::kCPU, &suffix_tree_destroy);

  // Introspection.
  suffix_cache.def("suffix_tree_num_seqs(Tensor handle) -> int");
  suffix_cache.impl("suffix_tree_num_seqs", torch::kCPU, &suffix_tree_num_seqs);

  // Mutation: append one token / a tensor of tokens / remove a sequence.
  suffix_cache.def("suffix_tree_append(Tensor handle, int seq_id, int token) -> ()");
  suffix_cache.impl("suffix_tree_append", torch::kCPU, &suffix_tree_append);

  suffix_cache.def("suffix_tree_extend(Tensor handle, int seq_id, Tensor tokens) -> ()");
  suffix_cache.impl("suffix_tree_extend", torch::kCPU, &suffix_tree_extend);

  suffix_cache.def("suffix_tree_remove(Tensor handle, int seq_id) -> ()");
  suffix_cache.impl("suffix_tree_remove", torch::kCPU, &suffix_tree_remove);

  // Speculation: returns (token_ids, parents, probs, score, match_len).
  suffix_cache.def("suffix_tree_speculate(Tensor handle, Tensor pattern, int max_spec_tokens, float max_spec_factor, float max_spec_offset, float min_token_prob, bool use_tree_spec) -> (Tensor, Tensor, Tensor, float, int)");
  suffix_cache.impl("suffix_tree_speculate", torch::kCPU, &suffix_tree_speculate);

  // Diagnostics.
  suffix_cache.def("suffix_tree_check_integrity(Tensor handle) -> str");
  suffix_cache.impl("suffix_tree_check_integrity", torch::kCPU, &suffix_tree_check_integrity);

  suffix_cache.def("suffix_tree_estimate_memory(Tensor handle) -> int");
  suffix_cache.impl("suffix_tree_estimate_memory", torch::kCPU, &suffix_tree_estimate_memory);
}

REGISTER_EXTENSION(_suffix_cache_C)

tests/v1/spec_decode/test_suffix_tree_cpp.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,9 @@
33
"""Unit tests for C++ suffix tree implementation."""
44

55
import pytest
6+
from vllm.v1.spec_decode.suffix_decode import Candidate, SuffixTree
67

7-
# Try to import C++ implementation
8-
try:
9-
from vllm._suffix_cache_C import Candidate, SuffixTree
10-
cpp_available = True
11-
except ImportError:
12-
cpp_available = False
13-
pytestmark = pytest.mark.skip(reason="C++ suffix tree not available")
148

15-
16-
@pytest.mark.skipif(not cpp_available, reason="C++ suffix tree not available")
179
class TestSuffixTreeCpp:
1810
"""Test suite for C++ SuffixTree implementation."""
1911

vllm/v1/spec_decode/suffix_decode/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from .suffix_cache import SuffixCache, SuffixSpecResult
16+
from .suffix_cache import SuffixCache, SuffixSpecResult, Candidate, SuffixTree
1717

18-
__all__ = ["SuffixCache", "SuffixSpecResult"]
18+
__all__ = ["SuffixCache", "SuffixSpecResult", "Candidate", "SuffixTree"]

0 commit comments

Comments (0)