Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions src/pb_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,70 @@ namespace py = pybind11;
typedef SSIZE_T ssize_t;
#endif

#include <cstdint>
#include <cstring>
#include <vector>

namespace triton { namespace backend { namespace python {

#ifdef TRITON_PB_STUB
/// Deserialize a serialized bytes tensor into a 1-D NumPy array of dtype
/// "object" whose elements are Python `bytes` objects.
///
/// The buffer is parsed as a back-to-back sequence of elements, each made of
/// a 4-byte length prefix followed by that many payload bytes. The prefix is
/// read in host byte order; the serialized format is described as
/// little-endian, so this presumes a little-endian host.
///
/// \param data Pointer to the start of the serialized buffer. It is only
/// dereferenced when `data_size` is non-zero.
/// \param data_size Size of the serialized buffer in bytes.
/// \return A NumPy object array with one `bytes` element per serialized
/// string (an empty object array when `data_size` is 0).
/// \throws PythonBackendException if a length prefix is truncated or a
/// payload extends past the end of the buffer.
py::array
deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size)
{
  py::module numpy = py::module::import("numpy");
  if (data_size == 0) {
    return numpy.attr("empty")(0, py::dtype("object"));
  }

  constexpr size_t kLengthFieldSize = sizeof(uint32_t);

  // First pass: validate the whole buffer and count the elements so the
  // result array can be allocated once with the right size.
  size_t offset = 0;
  size_t num_strings = 0;
  while (offset < data_size) {
    // Compare against the remaining byte count instead of computing
    // `offset + n`, which could wrap around on platforms with a narrow
    // size_t.
    if (data_size - offset < kLengthFieldSize) {
      throw PythonBackendException(
          "Invalid bytes tensor data: incomplete length field");
    }

    // The prefix can sit at any byte offset, so it is copied out with memcpy
    // rather than dereferenced through a (possibly misaligned) uint32_t*.
    uint32_t length = 0;
    std::memcpy(&length, data + offset, kLengthFieldSize);
    offset += kLengthFieldSize;

    if (length > data_size - offset) {
      throw PythonBackendException(
          "Invalid bytes tensor data: string extends beyond buffer");
    }

    ++num_strings;
    offset += length;
  }

  // Create numpy array of objects using pybind11's numpy module
  py::array result = numpy.attr("empty")(num_strings, py::dtype("object"));
  auto result_ptr = static_cast<PyObject**>(result.request().ptr);

  // Second pass: materialize each element. The first pass already validated
  // every length, so no bounds checks are repeated here.
  offset = 0;
  for (size_t string_index = 0; string_index < num_strings; ++string_index) {
    uint32_t length = 0;
    std::memcpy(&length, data + offset, kLengthFieldSize);
    offset += kLengthFieldSize;

    py::bytes bytes_obj(reinterpret_cast<const char*>(data + offset), length);

    // A freshly created object array may already hold a reference in each
    // slot (e.g. to None); drop it before overwriting so it is not leaked.
    // Py_XDECREF tolerates a null slot.
    Py_XDECREF(result_ptr[string_index]);
    // release() hands the owned reference to the array slot without an extra
    // INCREF/DECREF pair.
    result_ptr[string_index] = bytes_obj.release().ptr();

    offset += length;
  }

  return result;
}

PbTensor::PbTensor(const std::string& name, py::array& numpy_array)
: name_(name)
{
Expand Down Expand Up @@ -160,14 +221,9 @@ PbTensor::PbTensor(
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
} else {
py::object numpy_array = py::array(
triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size},
(void*)memory_ptr_);
py::module triton_pb_utils =
py::module::import("triton_python_backend_utils");
numpy_array_ =
triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
.attr("reshape")(dims);
py::object numpy_array = deserialize_bytes_tensor_cpp(
static_cast<const uint8_t*>(memory_ptr_), byte_size_);
numpy_array_ = numpy_array.attr("reshape")(dims_);
}
} else {
numpy_array_ = py::none();
Expand Down Expand Up @@ -234,6 +290,7 @@ delete_unused_dltensor(PyObject* dlp)
}
}


std::shared_ptr<PbTensor>
PbTensor::FromNumpy(const std::string& name, py::array& numpy_array)
{
Expand Down Expand Up @@ -668,14 +725,9 @@ PbTensor::PbTensor(
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
} else {
py::object numpy_array = py::array(
triton_to_pybind_dtype(TRITONSERVER_TYPE_UINT8), {byte_size_},
(void*)memory_ptr_);
py::module triton_pb_utils =
py::module::import("triton_python_backend_utils");
numpy_array_ =
triton_pb_utils.attr("deserialize_bytes_tensor")(numpy_array)
.attr("reshape")(dims_);
py::object numpy_array = deserialize_bytes_tensor_cpp(
static_cast<const uint8_t*>(memory_ptr_), byte_size_);
numpy_array_ = numpy_array.attr("reshape")(dims_);
}
} else {
numpy_array_ = py::none();
Expand Down
Loading