@@ -48,7 +48,62 @@ typedef SSIZE_T ssize_t;
48
48
namespace triton { namespace backend { namespace python {
49
49
50
50
#ifdef TRITON_PB_STUB
51
- py::array deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size);
51
+ py::array
52
+ deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size)
53
+ {
54
+ if (data_size == 0 ) {
55
+ py::module numpy = py::module::import (" numpy" );
56
+ return numpy.attr (" empty" )(0 , py::dtype (" object" ));
57
+ }
58
+
59
+ // First pass: count the number of strings and calculate total size
60
+ size_t offset = 0 ;
61
+ size_t num_strings = 0 ;
62
+ size_t total_string_size = 0 ;
63
+
64
+ while (offset < data_size) {
65
+ if (offset + 4 > data_size) {
66
+ throw PythonBackendException (
67
+ " Invalid bytes tensor data: incomplete length field" );
68
+ }
69
+
70
+ // Read 4-byte length (little-endian)
71
+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
72
+ offset += 4 ;
73
+
74
+ if (offset + length > data_size) {
75
+ throw PythonBackendException (
76
+ " Invalid bytes tensor data: string extends beyond buffer" );
77
+ }
78
+
79
+ num_strings++;
80
+ total_string_size += length;
81
+ offset += length;
82
+ }
83
+
84
+ // Create numpy array of objects using pybind11's numpy module
85
+ py::module numpy = py::module::import (" numpy" );
86
+ py::array result = numpy.attr (" empty" )(num_strings, py::dtype (" object" ));
87
+ auto result_ptr = static_cast <PyObject**>(result.request ().ptr );
88
+
89
+ // Second pass: extract strings
90
+ offset = 0 ;
91
+ size_t string_index = 0 ;
92
+
93
+ while (offset < data_size) {
94
+ uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
95
+ offset += 4 ;
96
+
97
+ // Create Python bytes object using pybind11
98
+ py::bytes bytes_obj (reinterpret_cast <const char *>(data + offset), length);
99
+ Py_INCREF (bytes_obj.ptr ()); // Increment reference count
100
+ result_ptr[string_index] = bytes_obj.ptr ();
101
+ string_index++;
102
+ offset += length;
103
+ }
104
+
105
+ return result;
106
+ }
52
107
53
108
PbTensor::PbTensor (const std::string& name, py::array& numpy_array)
54
109
: name_(name)
@@ -166,9 +221,9 @@ PbTensor::PbTensor(
166
221
py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
167
222
numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
168
223
} else {
169
- numpy_array_ = deserialize_bytes_tensor_cpp (
170
- static_cast <const uint8_t *>(memory_ptr_), byte_size)
171
- .attr (" reshape" )(dims );
224
+ py::object numpy_array = deserialize_bytes_tensor_cpp (
225
+ static_cast <const uint8_t *>(memory_ptr_), byte_size_);
226
+ numpy_array_ = numpy_array .attr (" reshape" )(dims_ );
172
227
}
173
228
} else {
174
229
numpy_array_ = py::none ();
@@ -235,62 +290,6 @@ delete_unused_dltensor(PyObject* dlp)
235
290
}
236
291
}
237
292
238
- py::array
239
- deserialize_bytes_tensor_cpp (const uint8_t * data, size_t data_size)
240
- {
241
- if (data_size == 0 ) {
242
- py::module numpy = py::module::import (" numpy" );
243
- return numpy.attr (" empty" )(0 , py::dtype (" object" ));
244
- }
245
-
246
- // First pass: count the number of strings and calculate total size
247
- size_t offset = 0 ;
248
- size_t num_strings = 0 ;
249
- size_t total_string_size = 0 ;
250
-
251
- while (offset < data_size) {
252
- if (offset + 4 > data_size) {
253
- throw PythonBackendException (
254
- " Invalid bytes tensor data: incomplete length field" );
255
- }
256
-
257
- // Read 4-byte length (little-endian)
258
- uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
259
- offset += 4 ;
260
-
261
- if (offset + length > data_size) {
262
- throw PythonBackendException (
263
- " Invalid bytes tensor data: string extends beyond buffer" );
264
- }
265
-
266
- num_strings++;
267
- total_string_size += length;
268
- offset += length;
269
- }
270
-
271
- // Create numpy array of objects using pybind11's numpy module
272
- py::module numpy = py::module::import (" numpy" );
273
- py::array result = numpy.attr (" empty" )(num_strings, py::dtype (" object" ));
274
- auto result_ptr = static_cast <PyObject**>(result.request ().ptr );
275
-
276
- // Second pass: extract strings
277
- offset = 0 ;
278
- size_t string_index = 0 ;
279
-
280
- while (offset < data_size) {
281
- uint32_t length = *reinterpret_cast <const uint32_t *>(data + offset);
282
- offset += 4 ;
283
-
284
- // Create Python bytes object using pybind11
285
- py::bytes bytes_obj (reinterpret_cast <const char *>(data + offset), length);
286
- Py_INCREF (bytes_obj.ptr ()); // Increment reference count
287
- result_ptr[string_index] = bytes_obj.ptr ();
288
- string_index++;
289
- offset += length;
290
- }
291
-
292
- return result;
293
- }
294
293
295
294
std::shared_ptr<PbTensor>
296
295
PbTensor::FromNumpy (const std::string& name, py::array& numpy_array)
@@ -726,9 +725,9 @@ PbTensor::PbTensor(
726
725
py::array (triton_to_pybind_dtype (dtype_), dims_, (void *)memory_ptr_);
727
726
numpy_array_ = numpy_array.attr (" view" )(triton_to_numpy_type (dtype_));
728
727
} else {
729
- numpy_array_ = deserialize_bytes_tensor_cpp (
730
- static_cast <const uint8_t *>(memory_ptr_), byte_size_)
731
- .attr (" reshape" )(dims_);
728
+ py::object numpy_array = deserialize_bytes_tensor_cpp (
729
+ static_cast <const uint8_t *>(memory_ptr_), byte_size_);
730
+ numpy_array_ = numpy_array .attr (" reshape" )(dims_);
732
731
}
733
732
} else {
734
733
numpy_array_ = py::none ();
0 commit comments