From c6ee46a7532291fc8583400e174e77b1833daf23 Mon Sep 17 00:00:00 2001 From: Ian Romanick Date: Tue, 22 May 2018 18:18:07 -0700 Subject: [PATCH] nir: Add nir_alu_srcs_negative_equal v2: Move bug fix in get_neg_instr from the next patch to this patch (where it was intended to be in the first place). Noticed by Caio. Reviewed-by: Kenneth Graunke --- src/compiler/nir/nir.h | 4 + src/compiler/nir/nir_instr_set.c | 104 ++++++++++++++++++ .../nir/tests/negative_equal_tests.cpp | 84 ++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index e123a59cca8..3ddf97bb12c 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -997,6 +997,10 @@ bool nir_const_value_negative_equal(const nir_const_value *c1, bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2, unsigned src1, unsigned src2); +bool nir_alu_srcs_negative_equal(const nir_alu_instr *alu1, + const nir_alu_instr *alu2, + unsigned src1, unsigned src2); + typedef enum { nir_deref_type_var, nir_deref_type_array, diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c index 1307fe2f3c9..9aa1f3bbe5e 100644 --- a/src/compiler/nir/nir_instr_set.c +++ b/src/compiler/nir/nir_instr_set.c @@ -276,6 +276,20 @@ nir_srcs_equal(nir_src src1, nir_src src2) } } +/** + * If the \p s is an SSA value that was generated by a negation instruction, + * that instruction is returned as a \c nir_alu_instr. Otherwise \c NULL is + * returned. + */ +static const struct nir_alu_instr * +get_neg_instr(const nir_src *s) +{ + const struct nir_alu_instr *const alu = nir_src_as_alu_instr_const(s); + + return alu != NULL && (alu->op == nir_op_fneg || alu->op == nir_op_ineg) + ? alu : NULL; +} + bool nir_const_value_negative_equal(const nir_const_value *c1, const nir_const_value *c2, @@ -377,6 +391,96 @@ nir_const_value_negative_equal(const nir_const_value *c1, return false; } +/** + * Shallow compare of ALU srcs to determine if one is the negation of the other + * + * This function detects cases where \p alu1 is a constant and \p alu2 is a + * constant that is its negation. It will also detect cases where \p alu2 is + * an SSA value that is a \c nir_op_fneg applied to \p alu1 (and vice versa). + * + * This function does not detect the general case when \p alu1 and \p alu2 are + * SSA values that are the negations of each other (e.g., \p alu1 represents + * (a * b) and \p alu2 represents (-a * b)). + */ +bool +nir_alu_srcs_negative_equal(const nir_alu_instr *alu1, + const nir_alu_instr *alu2, + unsigned src1, unsigned src2) +{ + if (alu1->src[src1].abs != alu2->src[src2].abs) + return false; + + bool parity = alu1->src[src1].negate != alu2->src[src2].negate; + + /* Handling load_const instructions is tricky. */ + + const nir_const_value *const const1 = + nir_src_as_const_value(alu1->src[src1].src); + + if (const1 != NULL) { + /* Assume that constant folding will eliminate source mods and unary + * ops. + */ + if (parity) + return false; + + const nir_const_value *const const2 = + nir_src_as_const_value(alu2->src[src2].src); + + if (const2 == NULL) + return false; + + /* FINISHME: Apply the swizzle? */ + return nir_const_value_negative_equal(const1, + const2, + nir_ssa_alu_instr_src_components(alu1, src1), + nir_op_infos[alu1->op].input_types[src1], + alu1->dest.dest.ssa.bit_size); + } + + uint8_t alu1_swizzle[4] = {}; + nir_src alu1_actual_src; + const struct nir_alu_instr *const neg1 = get_neg_instr(&alu1->src[src1].src); + + if (neg1) { + parity = !parity; + alu1_actual_src = neg1->src[0].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg1, 0); i++) + alu1_swizzle[i] = neg1->src[0].swizzle[i]; + } else { + alu1_actual_src = alu1->src[src1].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++) + alu1_swizzle[i] = i; + } + + uint8_t alu2_swizzle[4] = {}; + nir_src alu2_actual_src; + const struct nir_alu_instr *const neg2 = get_neg_instr(&alu2->src[src2].src); + + if (neg2) { + parity = !parity; + alu2_actual_src = neg2->src[0].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg2, 0); i++) + alu2_swizzle[i] = neg2->src[0].swizzle[i]; + } else { + alu2_actual_src = alu2->src[src2].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu2, src2); i++) + alu2_swizzle[i] = i; + } + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++) { + if (alu1_swizzle[alu1->src[src1].swizzle[i]] != + alu2_swizzle[alu2->src[src2].swizzle[i]]) + return false; + } + + return parity && nir_srcs_equal(alu1_actual_src, alu2_actual_src); +} + bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2, unsigned src1, unsigned src2) diff --git a/src/compiler/nir/tests/negative_equal_tests.cpp b/src/compiler/nir/tests/negative_equal_tests.cpp index e450a8172db..b38a0c10da5 100644 --- a/src/compiler/nir/tests/negative_equal_tests.cpp +++ b/src/compiler/nir/tests/negative_equal_tests.cpp @@ -22,6 +22,7 @@ */ #include #include "nir.h" +#include "nir_builder.h" #include "util/half_float.h" static nir_const_value count_sequence(nir_alu_type base_type, unsigned bits, @@ -47,6 +48,21 @@ protected: nir_const_value c2; }; +class alu_srcs_negative_equal_test : public ::testing::Test { +protected: + alu_srcs_negative_equal_test() + { + static const nir_shader_compiler_options options = { }; + nir_builder_init_simple_shader(&bld, NULL, MESA_SHADER_VERTEX, &options); + } + + ~alu_srcs_negative_equal_test() + { + ralloc_free(bld.shader); + } + + struct nir_builder bld; +}; TEST_F(const_value_negative_equal_test, float32_zero) { @@ -130,6 +146,74 @@ compare_fewer_components(nir_type_uint, 32) compare_fewer_components(nir_type_int, 64) compare_fewer_components(nir_type_uint, 64) +TEST_F(alu_srcs_negative_equal_test, trivial_float) +{ + nir_ssa_def *two = nir_imm_float(&bld, 2.0f); + nir_ssa_def *negative_two = nir_imm_float(&bld, -2.0f); + + nir_ssa_def *result = nir_fadd(&bld, two, negative_two); + nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr); + + ASSERT_NE((void *) 0, instr); + EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1)); +} + +TEST_F(alu_srcs_negative_equal_test, trivial_int) +{ + nir_ssa_def *two = nir_imm_int(&bld, 2); + nir_ssa_def *negative_two = nir_imm_int(&bld, -2); + + nir_ssa_def *result = nir_iadd(&bld, two, negative_two); + nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr); + + ASSERT_NE((void *) 0, instr); + EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1)); +} + +TEST_F(alu_srcs_negative_equal_test, trivial_negation_float) +{ + /* Cannot just do the negation of a nir_load_const_instr because + * nir_alu_srcs_negative_equal expects that constant folding will convert + * fneg(2.0) to just -2.0. + */ + nir_ssa_def *two = nir_imm_float(&bld, 2.0f); + nir_ssa_def *two_plus_two = nir_fadd(&bld, two, two); + nir_ssa_def *negation = nir_fneg(&bld, two_plus_two); + + nir_ssa_def *result = nir_fadd(&bld, two_plus_two, negation); + + nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr); + + ASSERT_NE((void *) 0, instr); + EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1)); +} + +TEST_F(alu_srcs_negative_equal_test, trivial_negation_int) +{ + /* Cannot just do the negation of a nir_load_const_instr because + * nir_alu_srcs_negative_equal expects that constant folding will convert + * ineg(2) to just -2. + */ + nir_ssa_def *two = nir_imm_int(&bld, 2); + nir_ssa_def *two_plus_two = nir_iadd(&bld, two, two); + nir_ssa_def *negation = nir_ineg(&bld, two_plus_two); + + nir_ssa_def *result = nir_iadd(&bld, two_plus_two, negation); + + nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr); + + ASSERT_NE((void *) 0, instr); + EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0)); + EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1)); +} + static nir_const_value count_sequence(nir_alu_type base_type, unsigned bits, int first) {