mirror of https://gitlab.freedesktop.org/mesa/mesa
etnaviv/nn: Extend post-multiplier for v8 architecture
The post-multiplier was extended by 8 bits for improved precision. The shift offset appears to have changed as well. Signed-off-by: Philipp Zabel <p.zabel@pengutronix.de> Reviewed-by: Tomeu Vizoso <tomeu@tomeuvizoso.net> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28878>
This commit is contained in:
parent
c2290843df
commit
e2444ad6c1
|
@ -149,7 +149,8 @@ struct etna_nn_params {
|
|||
FIELD(out_zero_point, 8)
|
||||
FIELD(kernel_direct_stream_from_VIP_sram, 1)
|
||||
FIELD(depthwise, 1)
|
||||
FIELD(unused11, 14)
|
||||
FIELD(post_multiplier_15_to_22, 8)
|
||||
FIELD(unused11, 6)
|
||||
|
||||
/* 23, from here they aren't set on */
|
||||
FIELD(unused12, 32)
|
||||
|
@ -641,6 +642,7 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation
|
|||
struct pipe_context *context = subgraph->base.context;
|
||||
struct etna_context *ctx = etna_context(context);
|
||||
unsigned nn_core_count = ctx->screen->specs.nn_core_count;
|
||||
unsigned nn_core_version = ctx->screen->specs.nn_core_version;
|
||||
unsigned oc_sram_size = ctx->screen->specs.on_chip_sram_size;
|
||||
struct etna_bo *bo = etna_bo_new(ctx->screen->dev,
|
||||
sizeof(struct etna_nn_params),
|
||||
|
@ -840,16 +842,27 @@ create_nn_config(struct etna_ml_subgraph *subgraph, const struct etna_operation
|
|||
float conv_scale = (operation->input_scale * operation->weight_scale) / operation->output_scale;
|
||||
uint32_t scale_bits = fui(conv_scale);
|
||||
/* Taken from https://github.com/pytorch/QNNPACK/blob/master/src/qnnpack/requantization.h#L130 */
|
||||
unsigned shift = 127 + 31 - 32 - (scale_bits >> 23) + 16;
|
||||
unsigned shift = 127 + 31 - 32 - (scale_bits >> 23);
|
||||
if (nn_core_version == 8)
|
||||
shift += 1;
|
||||
else
|
||||
shift += 16;
|
||||
|
||||
/* Divides by 2 * (post_shift - 18), rounding to nearest integer. If result doesn't fit in 8 bits, it is clamped to 255. galcore sets to 15 if INT8, to 0 if UINT8. */
|
||||
map->post_shift = shift & 0x1f;
|
||||
map->post_shift_bit_5_6 = (shift >> 5) & 0x3;
|
||||
|
||||
/* Multiplies by (multiplier * 2^15) */
|
||||
map->post_multiplier = (scale_bits >> 8) & 0x1;
|
||||
map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f;
|
||||
map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff;
|
||||
if (nn_core_version == 8) {
|
||||
map->post_multiplier = scale_bits & 0x1;
|
||||
map->post_multiplier_1_to_6 = (scale_bits >> 1) & 0x3f;
|
||||
map->post_multiplier_7_to_14 = (scale_bits >> 7) & 0xff;
|
||||
map->post_multiplier_15_to_22 = (scale_bits >> 15) & 0xff;
|
||||
} else {
|
||||
map->post_multiplier = (scale_bits >> 8) & 0x1;
|
||||
map->post_multiplier_1_to_6 = (scale_bits >> 9) & 0x3f;
|
||||
map->post_multiplier_7_to_14 = (scale_bits >> 15) & 0xff;
|
||||
}
|
||||
|
||||
map->per_channel_post_mul = 0x0;
|
||||
|
||||
|
|
Loading…
Reference in New Issue