nir/algebraic: Add lowering for dot_4x8 instructions

v2: Fix copy-and-paste bugs in lowering patterns.

v3: Add has_sudot_4x8 flag.  Requested by Rhys.

v4: Since the names of the opcodes changed from dp4 to dot_4x8, also
change the names of the lowering helpers.  Suggested by Jason.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12142>
This commit is contained in:
Ian Romanick 2021-06-09 14:53:49 -07:00 committed by Marge Bot
parent 0f809dbf40
commit 839495efc6
2 changed files with 57 additions and 0 deletions

View File

@ -223,6 +223,42 @@ optimizations = [
(('sudot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sudot_4x8_iadd', a, b, 0), c), '!options->lower_add_sat'),
]
# Shorthand for the expansion of just the dot product part of the [iu]dp4a
# instructions.
sdot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_i8', b, 0)),
('imul', ('extract_i8', a, 1), ('extract_i8', b, 1))),
('iadd', ('imul', ('extract_i8', a, 2), ('extract_i8', b, 2)),
('imul', ('extract_i8', a, 3), ('extract_i8', b, 3))))
udot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_u8', a, 0), ('extract_u8', b, 0)),
('imul', ('extract_u8', a, 1), ('extract_u8', b, 1))),
('iadd', ('imul', ('extract_u8', a, 2), ('extract_u8', b, 2)),
('imul', ('extract_u8', a, 3), ('extract_u8', b, 3))))
sudot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_u8', b, 0)),
('imul', ('extract_i8', a, 1), ('extract_u8', b, 1))),
('iadd', ('imul', ('extract_i8', a, 2), ('extract_u8', b, 2)),
('imul', ('extract_i8', a, 3), ('extract_u8', b, 3))))
optimizations.extend([
(('sdot_4x8_iadd', a, b, c), ('iadd', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
(('udot_4x8_uadd', a, b, c), ('iadd', udot_4x8_a_b, c), '!options->has_dot_4x8'),
(('sudot_4x8_iadd', a, b, c), ('iadd', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
# For the unsigned dot-product, the largest possible value 4*(255*255) =
# 0x3f804, so we don't have to worry about that intermediate result
# overflowing. 0x100000000 - 0x3f804 = 0xfffc07fc. If c is a constant
# that is less than 0xfffc07fc, then the result cannot overflow ever.
(('udot_4x8_uadd_sat', a, b, '#c(is_ult_0xfffc07fc)'), ('udot_4x8_uadd', a, b, c)),
(('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', udot_4x8_a_b, c), '!options->has_dot_4x8'),
# For the signed dot-product, the largest positive value is 4*(-128*-128) =
# 0x10000, and the largest negative value is 4*(-128*127) = -0xfe00. We
# don't have to worry about that intermediate result overflowing or
# underflowing.
(('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
(('sudot_4x8_iadd_sat', a, b, c), ('iadd_sat', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
])
# Float sizes
for s in [16, 32, 64]:
optimizations.extend([

View File

@ -205,6 +205,27 @@ is_not_const_zero(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
return true;
}
/** Is value unsigned less than 0xfffc07fc? */
static inline bool
is_ult_0xfffc07fc(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
unsigned src, unsigned num_components,
const uint8_t *swizzle)
{
/* only constant srcs: */
if (!nir_src_is_const(instr->src[src].src))
return false;
for (unsigned i = 0; i < num_components; i++) {
const unsigned val =
nir_src_comp_as_uint(instr->src[src].src, swizzle[i]);
if (val >= 0xfffc07fcU)
return false;
}
return true;
}
static inline bool
is_not_const(UNUSED struct hash_table *ht, const nir_alu_instr *instr,
unsigned src, UNUSED unsigned num_components,