From e2444ad6c1120db13b59ddfde0d5b81d369ce80d Mon Sep 17 00:00:00 2001
From: Philipp Zabel
Date: Tue, 23 Apr 2024 15:35:15 +0200
Subject: [PATCH] etnaviv/nn: Extend post-multiplier for v8 architecture
The post-multiplier was extended by 8 bits for improved precision.
The shift offset appears to have changed as well.
Signed-off-by: Philipp Zabel
Reviewed-by: Tomeu Vizoso
Part-of:
---
src/gallium/drivers/etnaviv/etnaviv_ml_nn.c | 23 ++++++++++++++++-----
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c b/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c
index 73bb3d1349a45..4fe23d505a2db 100644
--- a/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c
+++ b/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c
@@ -149,7 +149,8 @@ struct etna_nn_params {
FIELD(out_zero_point, 8)
FIELD(kernel_direct_stream_from_VIP_sram, 1)
FIELD(depthwise, 1)
- FIELD(unused11, 14)
+ FIELD(post_multiplier_15_to_22, 8)
+ FIELD(unused11, 6)
/* 23, from here they aren't set on */
FIELD(unused12, 32)
@@ -641,6 +642,7 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation
struct pipe_context *context = subgraph->base.context;
struct etna_context *ctx = etna_context(context);
unsigned nn_core_count = ctx->screen->specs.nn_core_count;
+ unsigned nn_core_version = ctx->screen->specs.nn_core_version;
unsigned oc_sram_size = ctx->screen->specs.on_chip_sram_size;
struct etna_bo *bo = etna_bo_new(ctx->screen->dev,
sizeof(struct etna_nn_params),
@@ -840,16 +842,27 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation
float conv_scale = (operation->input_scale * operation->weight_scale) / operation->output_scale;
uint32_t scale_bits = fui(conv_scale);
/* Taken from https://github.com/pytorch/QNNPACK/blob/master/src/qnnpack/requantization.h#L130 */
- unsigned shift = 127 + 31 - 32 - (scale_bits >> 23) + 16;
+ unsigned shift = 127 + 31 - 32 - (scale_bits >> 23);
+ if (nn_core_version == 8)
+ shift += 1;
+ else
+ shift += 16;
/* Divides by 2 * (post_shift - 18), rounding to nearest integer. If result doesn't fit in 8 bits, it is clamped to 255. galcore sets to 15 if INT8, to 0 if UINT8. */
map->post_shift = shift & 0x1f;
map->post_shift_bit_5_6 = (shift >> 5) & 0x3;
/* Multiplies by (multiplier * 2^15) */
- map->post_multiplier = (scale_bits >> 8) & 0x1;
- map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f;
- map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff;
+ if (nn_core_version == 8) {
+ map->post_multiplier = scale_bits & 0x1;
+ map->post_multiplier_1_to_6 = (scale_bits >> 1) & 0x3f;
+ map->post_multiplier_7_to_14 = (scale_bits >> 7) & 0xff;
+ map->post_multiplier_15_to_22 = (scale_bits >> 15) & 0xff;
+ } else {
+ map->post_multiplier = (scale_bits >> 8) & 0x1;
+ map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f;
+ map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff;
+ }
map->per_channel_post_mul = 0x0;