diff --git a/src/pb_tensor.cc b/src/pb_tensor.cc
index 9fde62fe..26e77586 100644
--- a/src/pb_tensor.cc
+++ b/src/pb_tensor.cc
@@ -41,9 +41,70 @@ namespace py = pybind11;
 typedef SSIZE_T ssize_t;
 #endif
 
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
 namespace triton { namespace backend { namespace python {
 
 #ifdef TRITON_PB_STUB
+py::array
+deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size)
+{
+  if (data_size == 0) {
+    py::module numpy = py::module::import("numpy");
+    return numpy.attr("empty")(0, py::dtype("object"));
+  }
+
+  // First pass: count the number of strings and calculate total size
+  size_t offset = 0;
+  size_t num_strings = 0;
+  size_t total_string_size = 0;
+
+  while (offset < data_size) {
+    if (offset + 4 > data_size) {
+      throw PythonBackendException(
+          "Invalid bytes tensor data: incomplete length field");
+    }
+
+    // Read 4-byte length (little-endian)
+    uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
+    offset += 4;
+
+    if (offset + length > data_size) {
+      throw PythonBackendException(
+          "Invalid bytes tensor data: string extends beyond buffer");
+    }
+
+    num_strings++;
+    total_string_size += length;
+    offset += length;
+  }
+
+  // Create numpy array of objects using pybind11's numpy module
+  py::module numpy = py::module::import("numpy");
+  py::array result = numpy.attr("empty")(num_strings, py::dtype("object"));
+  auto result_ptr = static_cast<PyObject**>(result.request().ptr);
+
+  // Second pass: extract strings
+  offset = 0;
+  size_t string_index = 0;
+
+  while (offset < data_size) {
+    uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
+    offset += 4;
+
+    // Create Python bytes object using pybind11
+    py::bytes bytes_obj(reinterpret_cast<const char*>(data + offset), length);
+    Py_INCREF(bytes_obj.ptr());  // Keep the object alive in the result array
+    result_ptr[string_index] = bytes_obj.ptr();
+    string_index++;
+    offset += length;
+  }
+
+  return result;
+}
+
 PbTensor::PbTensor(const std::string& name, py::array& numpy_array)
     : name_(name)
 {
@@ -160,14 +221,9 @@ PbTensor::PbTensor(
           py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
       numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
     } else {
-      py::object numpy_array = py::array(
-          triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size},
-          (void*)memory_ptr_);
-      py::module triton_pb_utils =
-          py::module::import("triton_python_backend_utils");
-      numpy_array_ =
-          triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
-              .attr("reshape")(dims);
+      py::object numpy_array = deserialize_bytes_tensor_cpp(
+          static_cast<const uint8_t*>(memory_ptr_), byte_size_);
+      numpy_array_ = numpy_array.attr("reshape")(dims_);
     }
   } else {
     numpy_array_ = py::none();
@@ -234,6 +290,7 @@ delete_unused_dltensor(PyObject* dlp)
   }
 }
 
+
 std::shared_ptr<PbTensor>
 PbTensor::FromNumpy(const std::string& name, py::array& numpy_array)
 {
@@ -668,14 +725,9 @@ PbTensor::PbTensor(
           py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
       numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
     } else {
-      py::object numpy_array = py::array(
-          triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size_},
-          (void*)memory_ptr_);
-      py::module triton_pb_utils =
-          py::module::import("triton_python_backend_utils");
-      numpy_array_ =
-          triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
-              .attr("reshape")(dims_);
+      py::object numpy_array = deserialize_bytes_tensor_cpp(
+          static_cast<const uint8_t*>(memory_ptr_), byte_size_);
+      numpy_array_ = numpy_array.attr("reshape")(dims_);
     }
   } else {
     numpy_array_ = py::none();
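
Note on the serialized BYTES layout: as the parsing code in this diff shows, each element is a 4-byte little-endian length prefix followed by that many raw bytes, and deserialize_bytes_tensor_cpp walks the buffer twice (once to count elements, once to fill the object array) instead of calling triton_python_backend_utils.deserialize_bytes_tensor from Python. The sketch below is a standalone, illustrative round-trip of that layout in plain C++ (no pybind11/numpy); SerializeBytesTensor and ParseBytesTensor are hypothetical names, not part of this change, and a little-endian host is assumed.

// Standalone sketch of the BYTES wire format consumed by the new parser.
// Illustrative only; not part of the PR. Assumes a little-endian host.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Serialize elements as [4-byte length][payload] pairs, back to back.
std::vector<uint8_t>
SerializeBytesTensor(const std::vector<std::string>& elements)
{
  std::vector<uint8_t> buffer;
  for (const std::string& s : elements) {
    uint32_t length = static_cast<uint32_t>(s.size());
    const uint8_t* len_bytes = reinterpret_cast<const uint8_t*>(&length);
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(length));
    buffer.insert(buffer.end(), s.begin(), s.end());
  }
  return buffer;
}

// Parse the buffer back with the same bounds checks as the diff's parser.
std::vector<std::string>
ParseBytesTensor(const uint8_t* data, size_t data_size)
{
  std::vector<std::string> elements;
  size_t offset = 0;
  while (offset < data_size) {
    if (offset + sizeof(uint32_t) > data_size) {
      throw std::runtime_error("incomplete length field");
    }
    uint32_t length;
    std::memcpy(&length, data + offset, sizeof(length));
    offset += sizeof(length);
    if (offset + length > data_size) {
      throw std::runtime_error("string extends beyond buffer");
    }
    elements.emplace_back(
        reinterpret_cast<const char*>(data + offset), length);
    offset += length;
  }
  return elements;
}

int
main()
{
  const std::vector<std::string> input{"hello", "", "triton"};
  const std::vector<uint8_t> buffer = SerializeBytesTensor(input);
  for (const std::string& s : ParseBytesTensor(buffer.data(), buffer.size())) {
    std::cout << "element: \"" << s << "\" (" << s.size() << " bytes)\n";
  }
  return 0;
}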