Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,23 +389,34 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
return too_large;
}

uint32_t model_tensor_size = 1;
for (int i = 0; i < (int)tensor->dims->size; ++i)
model_tensor_size *= (uint32_t)tensor->dims->data[i];

if (*output_tensor_size < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}

if (tensor->quantization.type == kTfLiteNoQuantization) {
NN_DBG_PRINTF("No quantization information");
float *ot =
tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
index);

int size = model_tensor_size * sizeof(float);
bh_memcpy_s(output_tensor, size, ot, size);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (*output_tensor_size < tensor->bytes) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
if (*output_tensor_size < tensor->bytes / sizeof(float)) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#endif
bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
tensor->bytes);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
*output_tensor_size = tensor->bytes;
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
*output_tensor_size = tensor->bytes / sizeof(float);
#endif
}
else { // TODO: Assuming uint8 quantized networks.
TfLiteAffineQuantization *quant_info =
Expand All @@ -414,6 +425,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
NN_ERR_PRINTF("Quantization per channel is not supported");
return runtime_error;
}

uint32_t model_tensor_size = 1;
for (int i = 0; i < (int)tensor->dims->size; ++i)
model_tensor_size *= (uint32_t)tensor->dims->data[i];

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (*output_tensor_size / sizeof(float) < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
if (*output_tensor_size < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#endif

uint8_t *ot = tfl_ctx->interpreters[ctx]
.interpreter->typed_output_tensor<uint8_t>(index);

Expand All @@ -426,9 +458,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
for (uint32_t i = 0; i < model_tensor_size; ++i) {
output_tensor_f[i] = (ot[i] - zero_point) * scale;
}

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
*output_tensor_size = model_tensor_size * sizeof(float);
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
*output_tensor_size = model_tensor_size;
#endif
}

*output_tensor_size = model_tensor_size;
return success;
}

Expand Down
Loading