Commit 9fa1ae5
wasi_nn_tensorflowlite.cpp: fix get_output return size
It should be the byte size, not the number of (fp32) values. I'm ambivalent about how to deal with compatibility for the legacy wamr-specific "wasi_nn" ABI; for now, I avoided changing it, so that existing tests using the legacy ABI, namely test_tensorflow.c and test_tensorflow_quantized.c, pass as they are. If we have any users who still want to use the legacy ABI, I suppose they consider compatibility more important than consistency with the other backends. cf. #4376
1 parent a29f394
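To make the two size conventions concrete: for a hypothetical 1x5 fp32 output tensor, the sketch below shows what `get_output` would report through `*output_tensor_size` under each ABI. The numbers are made up; only the two conventions come from this commit.

```c
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int
main(void)
{
    /* hypothetical 1x5 fp32 output tensor */
    uint32_t model_tensor_size = 5;                                 /* elements */
    uint32_t bytes = (uint32_t)(model_tensor_size * sizeof(float)); /* 20 */

    /* wasi_ephemeral_nn (WASM_ENABLE_WASI_EPHEMERAL_NN != 0):
     * the reported size is in bytes, consistent with other backends */
    printf("ephemeral abi: *output_tensor_size = %" PRIu32 "\n", bytes);

    /* legacy wamr-specific "wasi_nn": the reported size is the number
     * of fp32 values, kept for bug-to-bug compatibility */
    printf("legacy abi:    *output_tensor_size = %" PRIu32 "\n",
           bytes / (uint32_t)sizeof(float));
    return 0;
}
```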


core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 46 additions & 16 deletions
```diff
@@ -389,23 +389,23 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         return too_large;
     }
 
-    uint32_t model_tensor_size = 1;
-    for (int i = 0; i < (int)tensor->dims->size; ++i)
-        model_tensor_size *= (uint32_t)tensor->dims->data[i];
-
-    if (*output_tensor_size < model_tensor_size) {
-        NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
-        return too_large;
-    }
-
     if (tensor->quantization.type == kTfLiteNoQuantization) {
         NN_DBG_PRINTF("No quantization information");
-        float *ot =
-            tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
-                index);
-
-        int size = model_tensor_size * sizeof(float);
-        bh_memcpy_s(output_tensor, size, ot, size);
+        if (*output_tensor_size < tensor->bytes) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+        bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
+                    tensor->bytes);
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        *output_tensor_size = tensor->bytes;
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        *output_tensor_size = tensor->bytes / sizeof(float);
+#endif
     }
     else { // TODO: Assuming uint8 quantized networks.
         TfLiteAffineQuantization *quant_info =
```
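The fp32 branch now sizes the copy from `tensor->bytes` (TFLite's byte count for the tensor's buffer) and reads the raw bytes from `tensor->data.data`, rather than recomputing element-count x sizeof(float). Below is a minimal standalone sketch of the same bounds-checked copy; plain `memcpy` stands in for wamr's bounded `bh_memcpy_s`, and the helper name is made up:

```c
#include <stdint.h>
#include <string.h>

/* sketch: copy a tensor's raw bytes into a caller-supplied buffer, as in
 * the fp32 branch above; returns 0 on success, -1 if the buffer is too
 * small (the "too_large" case) */
static int
copy_tensor_bytes(void *dst, uint32_t dst_capacity, const void *tensor_data,
                  size_t tensor_bytes)
{
    if (dst_capacity < tensor_bytes)
        return -1; /* "Insufficient memory to copy tensor" */
    /* the real code uses bh_memcpy_s(dst, capacity, src, bytes) */
    memcpy(dst, tensor_data, tensor_bytes);
    return 0;
}
```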
```diff
@@ -414,6 +414,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             NN_ERR_PRINTF("Quantization per channel is not supported");
             return runtime_error;
         }
+
+        uint32_t model_tensor_size = 1;
+        for (int i = 0; i < (int)tensor->dims->size; ++i)
+            model_tensor_size *= (uint32_t)tensor->dims->data[i];
+
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        if (*output_tensor_size / sizeof(float) < model_tensor_size) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        if (*output_tensor_size < model_tensor_size) {
+            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+            return too_large;
+        }
+#endif
+
         uint8_t *ot = tfl_ctx->interpreters[ctx]
                           .interpreter->typed_output_tensor<uint8_t>(index);
 
```
```diff
@@ -426,9 +447,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         for (uint32_t i = 0; i < model_tensor_size; ++i) {
             output_tensor_f[i] = (ot[i] - zero_point) * scale;
         }
+
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        *output_tensor_size = model_tensor_size * sizeof(float);
+#else
+        /*
+         * for now, maintain the bug-to-bug compatibility with the old abi,
+         * where the size here is the number of fp32, not bytes.
+         */
+        *output_tensor_size = model_tensor_size;
+#endif
     }
 
-    *output_tensor_size = model_tensor_size;
     return success;
 }
```
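For reference, the quantized branch converts each uint8 value with the affine dequantization formula `f = (q - zero_point) * scale` before reporting the size. Here is a self-contained sketch with made-up quantization parameters; the formula and the two size conventions are from the diff above, everything else is illustrative:

```c
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int
main(void)
{
    /* made-up per-tensor quantization parameters and data */
    const float scale = 0.5f;
    const int32_t zero_point = 128;
    const uint8_t quantized[4] = { 128, 130, 126, 255 };
    const uint32_t model_tensor_size = 4; /* product of tensor->dims */

    float dequantized[4];
    for (uint32_t i = 0; i < model_tensor_size; ++i)
        dequantized[i] = (quantized[i] - zero_point) * scale;

    for (uint32_t i = 0; i < model_tensor_size; ++i)
        printf("%f\n", dequantized[i]);

    /* ephemeral abi reports bytes; legacy abi reports the fp32 count */
    printf("ephemeral: %" PRIu32 ", legacy: %" PRIu32 "\n",
           (uint32_t)(model_tensor_size * sizeof(float)), model_tensor_size);
    return 0;
}
```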
