Skip to content

Commit 145b89e

Browse files
author
Wei Chen
committed
perf: optimize string tensor deserialization with a high-performance C++ implementation
1 parent 8b5a055 commit 145b89e

File tree

1 file changed

+69
-16
lines changed

1 file changed

+69
-16
lines changed

src/pb_tensor.cc

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,15 @@ namespace py = pybind11;
4141
typedef SSIZE_T ssize_t;
4242
#endif
4343

44+
#include <cstdint>
45+
#include <cstring>
46+
#include <vector>
47+
4448
namespace triton { namespace backend { namespace python {
4549

4650
#ifdef TRITON_PB_STUB
51+
py::array deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size);
52+
4753
PbTensor::PbTensor(const std::string& name, py::array& numpy_array)
4854
: name_(name)
4955
{
@@ -160,14 +166,9 @@ PbTensor::PbTensor(
160166
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
161167
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
162168
} else {
163-
py::object numpy_array = py::array(
164-
triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size},
165-
(void*)memory_ptr_);
166-
py::module triton_pb_utils =
167-
py::module::import("triton_python_backend_utils");
168-
numpy_array_ =
169-
triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
170-
.attr("reshape")(dims);
169+
numpy_array_ = deserialize_bytes_tensor_cpp(
170+
static_cast<const uint8_t*>(memory_ptr_), byte_size)
171+
.attr("reshape")(dims);
171172
}
172173
} else {
173174
numpy_array_ = py::none();
@@ -234,6 +235,63 @@ delete_unused_dltensor(PyObject* dlp)
234235
}
235236
}
236237

238+
// Decodes the 4-byte little-endian length prefix that starts at `p`.
//
// The value is assembled byte by byte instead of dereferencing a
// reinterpret_cast'ed uint32_t pointer: `p` points into a packed byte stream
// and is generally not 4-byte aligned, and explicit assembly yields the same
// result regardless of host endianness.
static uint32_t
read_le_uint32(const uint8_t* p)
{
  return static_cast<uint32_t>(p[0]) | (static_cast<uint32_t>(p[1]) << 8) |
         (static_cast<uint32_t>(p[2]) << 16) |
         (static_cast<uint32_t>(p[3]) << 24);
}

// Deserializes a serialized bytes tensor into a 1-D numpy array of dtype
// `object` whose elements are Python `bytes` objects.
//
// The input is a sequence of records, each made of a 4-byte little-endian
// length followed by that many payload bytes.
//
// @param data       Start of the serialized buffer. Not dereferenced when
//                   `data_size` is 0.
// @param data_size  Size of the serialized buffer in bytes.
// @return           A 1-D object array with one `bytes` element per record
//                   (empty when `data_size` is 0). The caller is expected to
//                   reshape it to the tensor dimensions.
// @throws PythonBackendException if a length prefix is truncated or a record
//                   extends past the end of the buffer.
//
// NOTE(review): this creates Python objects, so it presumably must be called
// with the GIL held -- confirm against the callers.
py::array
deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size)
{
  constexpr size_t kLengthPrefixSize = 4;

  py::module numpy = py::module::import("numpy");
  if (data_size == 0) {
    return numpy.attr("empty")(0, py::dtype("object"));
  }

  // First pass: validate the framing and count the records so that the result
  // array can be allocated once with its final size. The comparisons are
  // written against the number of remaining bytes so they cannot wrap around,
  // even where size_t is only 32 bits wide.
  size_t num_strings = 0;
  for (size_t offset = 0; offset < data_size;) {
    if (data_size - offset < kLengthPrefixSize) {
      throw PythonBackendException(
          "Invalid bytes tensor data: incomplete length field");
    }

    const uint32_t length = read_le_uint32(data + offset);
    offset += kLengthPrefixSize;

    if (length > data_size - offset) {
      throw PythonBackendException(
          "Invalid bytes tensor data: string extends beyond buffer");
    }

    offset += length;
    ++num_strings;
  }

  py::array result = numpy.attr("empty")(num_strings, py::dtype("object"));
  PyObject** slots = static_cast<PyObject**>(result.request().ptr);

  // Second pass: the layout was validated above, so the records can be
  // materialized without repeating the bounds checks.
  size_t offset = 0;
  for (size_t i = 0; i < num_strings; ++i) {
    const uint32_t length = read_le_uint32(data + offset);
    offset += kLengthPrefixSize;

    py::bytes element(reinterpret_cast<const char*>(data + offset), length);
    offset += length;

    // numpy.empty() initializes object arrays to None, so every slot already
    // owns a reference. Hand the new bytes object's reference over to the
    // array and drop the one that was stored before; Py_XDECREF also copes
    // with a NULL slot.
    PyObject* previous = slots[i];
    slots[i] = element.release().ptr();
    Py_XDECREF(previous);
  }

  return result;
}
294+
237295
std::shared_ptr<PbTensor>
238296
PbTensor::FromNumpy(const std::string& name, py::array& numpy_array)
239297
{
@@ -668,14 +726,9 @@ PbTensor::PbTensor(
668726
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
669727
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
670728
} else {
671-
py::object numpy_array = py::array(
672-
triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size_},
673-
(void*)memory_ptr_);
674-
py::module triton_pb_utils =
675-
py::module::import("triton_python_backend_utils");
676-
numpy_array_ =
677-
triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
678-
.attr("reshape")(dims_);
729+
numpy_array_ = deserialize_bytes_tensor_cpp(
730+
static_cast<const uint8_t*>(memory_ptr_), byte_size_)
731+
.attr("reshape")(dims_);
679732
}
680733
} else {
681734
numpy_array_ = py::none();

0 commit comments

Comments
 (0)