diff --git a/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c b/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c index 73bb3d1349a45..4fe23d505a2db 100644 --- a/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c +++ b/src/gallium/drivers/etnaviv/etnaviv_ml_nn.c @@ -149,7 +149,8 @@ struct etna_nn_params { FIELD(out_zero_point, 8) FIELD(kernel_direct_stream_from_VIP_sram, 1) FIELD(depthwise, 1) - FIELD(unused11, 14) + FIELD(post_multiplier_15_to_22, 8) + FIELD(unused11, 6) /* 23, from here they aren't set on */ FIELD(unused12, 32) @@ -641,6 +642,7 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation struct pipe_context *context = subgraph->base.context; struct etna_context *ctx = etna_context(context); unsigned nn_core_count = ctx->screen->specs.nn_core_count; + unsigned nn_core_version = ctx->screen->specs.nn_core_version; unsigned oc_sram_size = ctx->screen->specs.on_chip_sram_size; struct etna_bo *bo = etna_bo_new(ctx->screen->dev, sizeof(struct etna_nn_params), @@ -840,16 +842,27 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation float conv_scale = (operation->input_scale * operation->weight_scale) / operation->output_scale; uint32_t scale_bits = fui(conv_scale); /* Taken from https://github.com/pytorch/QNNPACK/blob/master/src/qnnpack/requantization.h#L130 */ - unsigned shift = 127 + 31 - 32 - (scale_bits >> 23) + 16; + unsigned shift = 127 + 31 - 32 - (scale_bits >> 23); + if (nn_core_version == 8) + shift += 1; + else + shift += 16; /* Divides by 2 * (post_shift - 18), rounding to nearest integer. If result doesn't fit in 8 bits, it is clamped to 255. galcore sets to 15 if INT8, to 0 if UINT8. */ map->post_shift = shift & 0x1f; map->post_shift_bit_5_6 = (shift >> 5) & 0x3; /* Multiplies by (multiplier * 2^15) */ - map->post_multiplier = (scale_bits >> 8) & 0x1; - map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f; - map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff; + if (nn_core_version == 8) { + map->post_multiplier = scale_bits & 0x1; + map->post_multiplier_1_to_6 = (scale_bits >> 1) & 0x3f; + map->post_multiplier_7_to_14 = (scale_bits >> 7) & 0xff; + map->post_multiplier_15_to_22 = (scale_bits >> 15) & 0xff; + } else { + map->post_multiplier = (scale_bits >> 8) & 0x1; + map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f; + map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff; + } map->per_channel_post_mul = 0x0;