etnaviv/nn: Extend post-multiplier for v8 architecture

The post-multiplier was extended by 8 bits for improved precision.
The shift offset appears to have changed as well.

Signed-off-by: Philipp Zabel <p.zabel@pengutronix.de>
Reviewed-by: Tomeu Vizoso <tomeu@tomeuvizoso.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28878>
This commit is contained in:
Philipp Zabel 2024-04-23 15:35:15 +02:00 committed by Marge Bot
parent c2290843df
commit e2444ad6c1
1 changed files with 18 additions and 5 deletions

View File

@ -149,7 +149,8 @@ struct etna_nn_params {
FIELD(out_zero_point, 8)
FIELD(kernel_direct_stream_from_VIP_sram, 1)
FIELD(depthwise, 1)
FIELD(unused11, 14)
FIELD(post_multiplier_15_to_22, 8)
FIELD(unused11, 6)
/* 23, from here they aren't set on */
FIELD(unused12, 32)
@ -641,6 +642,7 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation
struct pipe_context *context = subgraph->base.context;
struct etna_context *ctx = etna_context(context);
unsigned nn_core_count = ctx->screen->specs.nn_core_count;
unsigned nn_core_version = ctx->screen->specs.nn_core_version;
unsigned oc_sram_size = ctx->screen->specs.on_chip_sram_size;
struct etna_bo *bo = etna_bo_new(ctx->screen->dev,
sizeof(struct etna_nn_params),
@ -840,16 +842,27 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation
float conv_scale = (operation->input_scale * operation->weight_scale) / operation->output_scale;
uint32_t scale_bits = fui(conv_scale);
/* Taken from https://github.com/pytorch/QNNPACK/blob/master/src/qnnpack/requantization.h#L130 */
unsigned shift = 127 + 31 - 32 - (scale_bits >> 23) + 16;
unsigned shift = 127 + 31 - 32 - (scale_bits >> 23);
if (nn_core_version == 8)
shift += 1;
else
shift += 16;
/* Divides by 2 * (post_shift - 18), rounding to nearest integer. If result doesn't fit in 8 bits, it is clamped to 255. galcore sets to 15 if INT8, to 0 if UINT8. */
map->post_shift = shift & 0x1f;
map->post_shift_bit_5_6 = (shift >> 5) & 0x3;
/* Multiplies by (multiplier * 2^15) */
map->post_multiplier = (scale_bits >> 8) & 0x1;
map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f;
map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff;
if (nn_core_version == 8) {
map->post_multiplier = scale_bits & 0x1;
map->post_multiplier_1_to_6 = (scale_bits >> 1) & 0x3f;
map->post_multiplier_7_to_14 = (scale_bits >> 7) & 0xff;
map->post_multiplier_15_to_22 = (scale_bits >> 15) & 0xff;
} else {
map->post_multiplier = (scale_bits >> 8) & 0x1;
map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f;
map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff;
}
map->per_channel_post_mul = 0x0;