@@ -41,9 +41,15 @@ namespace py = pybind11;
41
41
typedef SSIZE_T ssize_t ;
42
42
#endif
43
43
44
+ #include < cstdint>
45
+ #include < cstring>
46
+ #include < vector>
47
+
44
48
namespace triton { namespace backend { namespace python {
45
49
46
50
#ifdef TRITON_PB_STUB
51
+ py::array deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size);
52
+
47
53
PbTensor::PbTensor (const std::string& name, py::array& numpy_array)
48
54
: name_(name)
49
55
{
@@ -160,14 +166,9 @@ PbTensor::PbTensor(
160
166
py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
161
167
numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
162
168
} else {
163
- py::object numpy_array = py::array (
164
- triton_to_pybind_dtype (TRITONSERVER_TYPE_UINT8), {byte_size},
165
- (void *)memory_ptr_);
166
- py::module triton_pb_utils =
167
- py::module::import (" triton_python_backend_utils" );
168
- numpy_array_ =
169
- triton_pb_utils.attr (" deserialize_bytes_tensor" )(numpy_array)
170
- .attr (" reshape" )(dims);
169
+ numpy_array_ = deserialize_bytes_tensor_cpp (
170
+ static_cast <const uint8_t *>(memory_ptr_), byte_size)
171
+ .attr (" reshape" )(dims);
171
172
}
172
173
} else {
173
174
numpy_array_ = py::none ();
@@ -234,6 +235,63 @@ delete_unused_dltensor(PyObject* dlp)
234
235
}
235
236
}
236
237
238
+ py::array
239
+ deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size)
240
+ {
241
+ if (data_size == 0 ) {
242
+ py::module numpy = py::module::import (" numpy" );
243
+ return numpy.attr (" empty" )(0 , py::dtype (" object" ));
244
+ }
245
+
246
+ // First pass: count the number of strings and calculate total size
247
+ size_t offset = 0 ;
248
+ size_t num_strings = 0 ;
249
+ size_t total_string_size = 0 ;
250
+
251
+ while (offset < data_size) {
252
+ if (offset + 4 > data_size) {
253
+ throw PythonBackendException (
254
+ " Invalid bytes tensor data: incomplete length field" );
255
+ }
256
+
257
+ // Read 4-byte length (little-endian)
258
+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
259
+ offset += 4 ;
260
+
261
+ if (offset + length > data_size) {
262
+ throw PythonBackendException (
263
+ " Invalid bytes tensor data: string extends beyond buffer" );
264
+ }
265
+
266
+ num_strings++;
267
+ total_string_size += length;
268
+ offset += length;
269
+ }
270
+
271
+ // Create numpy array of objects using pybind11's numpy module
272
+ py::module numpy = py::module::import (" numpy" );
273
+ py::array result = numpy.attr (" empty" )(num_strings, py::dtype (" object" ));
274
+ auto result_ptr = static_cast <PyObject**>(result.request ().ptr );
275
+
276
+ // Second pass: extract strings
277
+ offset = 0 ;
278
+ size_t string_index = 0 ;
279
+
280
+ while (offset < data_size) {
281
+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
282
+ offset += 4 ;
283
+
284
+ // Create Python bytes object using pybind11
285
+ py::bytes bytes_obj (reinterpret_cast <const char *>(data + offset), length);
286
+ Py_INCREF (bytes_obj.ptr ()); // Increment reference count
287
+ result_ptr[string_index] = bytes_obj.ptr ();
288
+ string_index++;
289
+ offset += length;
290
+ }
291
+
292
+ return result;
293
+ }
294
+
237
295
std::shared_ptr<PbTensor>
238
296
PbTensor::FromNumpy (const std::string& name, py::array& numpy_array)
239
297
{
@@ -668,14 +726,9 @@ PbTensor::PbTensor(
668
726
py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
669
727
numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
670
728
} else {
671
- py::object numpy_array = py::array (
672
- triton_to_pybind_dtype (TRITONSERVER_TYPE_UINT8), {byte_size_},
673
- (void *)memory_ptr_);
674
- py::module triton_pb_utils =
675
- py::module::import (" triton_python_backend_utils" );
676
- numpy_array_ =
677
- triton_pb_utils.attr (" deserialize_bytes_tensor" )(numpy_array)
678
- .attr (" reshape" )(dims_);
729
+ numpy_array_ = deserialize_bytes_tensor_cpp (
730
+ static_cast <const uint8_t *>(memory_ptr_), byte_size_)
731
+ .attr (" reshape" )(dims_);
679
732
}
680
733
} else {
681
734
numpy_array_ = py::none ();
0 commit comments