Skip to content

Commit cf1a489

Browse files
author
Wei Chen
committed
Address PR comments
1 parent 145b89e commit cf1a489

File tree

1 file changed

+62
-63
lines changed

1 file changed

+62
-63
lines changed

src/pb_tensor.cc

Lines changed: 62 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,62 @@ typedef SSIZE_T ssize_t;
4848
namespace triton { namespace backend { namespace python {
4949

5050
#ifdef TRITON_PB_STUB
51-
py::array deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size);
51+
py::array
52+
deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size)
53+
{
54+
if (data_size == 0) {
55+
py::module numpy = py::module::import("numpy");
56+
return numpy.attr("empty")(0, py::dtype("object"));
57+
}
58+
59+
// First pass: count the number of strings and calculate total size
60+
size_t offset = 0;
61+
size_t num_strings = 0;
62+
size_t total_string_size = 0;
63+
64+
while (offset < data_size) {
65+
if (offset + 4 > data_size) {
66+
throw PythonBackendException(
67+
"Invalid bytes tensor data: incomplete length field");
68+
}
69+
70+
// Read 4-byte length (little-endian)
71+
uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
72+
offset += 4;
73+
74+
if (offset + length > data_size) {
75+
throw PythonBackendException(
76+
"Invalid bytes tensor data: string extends beyond buffer");
77+
}
78+
79+
num_strings++;
80+
total_string_size += length;
81+
offset += length;
82+
}
83+
84+
// Create numpy array of objects using pybind11's numpy module
85+
py::module numpy = py::module::import("numpy");
86+
py::array result = numpy.attr("empty")(num_strings, py::dtype("object"));
87+
auto result_ptr = static_cast<PyObject**>(result.request().ptr);
88+
89+
// Second pass: extract strings
90+
offset = 0;
91+
size_t string_index = 0;
92+
93+
while (offset < data_size) {
94+
uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
95+
offset += 4;
96+
97+
// Create Python bytes object using pybind11
98+
py::bytes bytes_obj(reinterpret_cast<const char*>(data + offset), length);
99+
Py_INCREF(bytes_obj.ptr()); // Increment reference count
100+
result_ptr[string_index] = bytes_obj.ptr();
101+
string_index++;
102+
offset += length;
103+
}
104+
105+
return result;
106+
}
52107

53108
PbTensor::PbTensor(const std::string& name, py::array& numpy_array)
54109
: name_(name)
@@ -166,9 +221,9 @@ PbTensor::PbTensor(
166221
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
167222
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
168223
} else {
169-
numpy_array_ = deserialize_bytes_tensor_cpp(
170-
static_cast<const uint8_t*>(memory_ptr_), byte_size)
171-
.attr("reshape")(dims);
224+
py::object numpy_array = deserialize_bytes_tensor_cpp(
225+
static_cast<const uint8_t*>(memory_ptr_), byte_size_);
226+
numpy_array_ = numpy_array.attr("reshape")(dims_);
172227
}
173228
} else {
174229
numpy_array_ = py::none();
@@ -235,62 +290,6 @@ delete_unused_dltensor(PyObject* dlp)
235290
}
236291
}
237292

238-
py::array
239-
deserialize_bytes_tensor_cpp(const uint8_t* data, size_t data_size)
240-
{
241-
if (data_size == 0) {
242-
py::module numpy = py::module::import("numpy");
243-
return numpy.attr("empty")(0, py::dtype("object"));
244-
}
245-
246-
// First pass: count the number of strings and calculate total size
247-
size_t offset = 0;
248-
size_t num_strings = 0;
249-
size_t total_string_size = 0;
250-
251-
while (offset < data_size) {
252-
if (offset + 4 > data_size) {
253-
throw PythonBackendException(
254-
"Invalid bytes tensor data: incomplete length field");
255-
}
256-
257-
// Read 4-byte length (little-endian)
258-
uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
259-
offset += 4;
260-
261-
if (offset + length > data_size) {
262-
throw PythonBackendException(
263-
"Invalid bytes tensor data: string extends beyond buffer");
264-
}
265-
266-
num_strings++;
267-
total_string_size += length;
268-
offset += length;
269-
}
270-
271-
// Create numpy array of objects using pybind11's numpy module
272-
py::module numpy = py::module::import("numpy");
273-
py::array result = numpy.attr("empty")(num_strings, py::dtype("object"));
274-
auto result_ptr = static_cast<PyObject**>(result.request().ptr);
275-
276-
// Second pass: extract strings
277-
offset = 0;
278-
size_t string_index = 0;
279-
280-
while (offset < data_size) {
281-
uint32_t length = *reinterpret_cast<const uint32_t*>(data + offset);
282-
offset += 4;
283-
284-
// Create Python bytes object using pybind11
285-
py::bytes bytes_obj(reinterpret_cast<const char*>(data + offset), length);
286-
Py_INCREF(bytes_obj.ptr()); // Increment reference count
287-
result_ptr[string_index] = bytes_obj.ptr();
288-
string_index++;
289-
offset += length;
290-
}
291-
292-
return result;
293-
}
294293

295294
std::shared_ptr<PbTensor>
296295
PbTensor::FromNumpy(const std::string& name, py::array& numpy_array)
@@ -726,9 +725,9 @@ PbTensor::PbTensor(
726725
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
727726
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
728727
} else {
729-
numpy_array_ = deserialize_bytes_tensor_cpp(
730-
static_cast<const uint8_t*>(memory_ptr_), byte_size_)
731-
.attr("reshape")(dims_);
728+
py::object numpy_array = deserialize_bytes_tensor_cpp(
729+
static_cast<const uint8_t*>(memory_ptr_), byte_size_);
730+
numpy_array_ = numpy_array.attr("reshape")(dims_);
732731
}
733732
} else {
734733
numpy_array_ = py::none();

0 commit comments

Comments
 (0)