diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c
index 6796fcaad5b..d200412dc9c 100644
--- a/src/compiler/nir/nir_instr_set.c
+++ b/src/compiler/nir/nir_instr_set.c
@@ -352,12 +352,31 @@ nir_const_value_negative_equal(nir_const_value c1,
  * This function does not detect the general case when \p alu1 and \p alu2 are
  * SSA values that are the negations of each other (e.g., \p alu1 represents
  * (a * b) and \p alu2 represents (-a * b)).
+ *
+ * \warning
+ * It is the responsibility of the caller to ensure that the component counts,
+ * write masks, and base types of the sources being compared are compatible.
  */
 bool
 nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
                             const nir_alu_instr *alu2,
                             unsigned src1, unsigned src2)
 {
+#ifndef NDEBUG
+   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+      assert(nir_alu_instr_channel_used(alu1, src1, i) ==
+             nir_alu_instr_channel_used(alu2, src2, i));
+   }
+
+   if (nir_op_infos[alu1->op].input_types[src1] == nir_type_float) {
+      assert(nir_op_infos[alu1->op].input_types[src1] ==
+             nir_op_infos[alu2->op].input_types[src2]);
+   } else {
+      assert(nir_op_infos[alu1->op].input_types[src1] == nir_type_int);
+      assert(nir_op_infos[alu2->op].input_types[src2] == nir_type_int);
+   }
+#endif
+
    if (alu1->src[src1].abs != alu2->src[src2].abs)
       return false;
 
@@ -385,12 +404,13 @@ nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
           nir_src_bit_size(alu2->src[src2].src))
          return false;
 
-      /* FINISHME: Apply the swizzle? */
-      const unsigned components = nir_ssa_alu_instr_src_components(alu1, src1);
       const nir_alu_type full_type = nir_op_infos[alu1->op].input_types[src1] |
                                      nir_src_bit_size(alu1->src[src1].src);
-      for (unsigned i = 0; i < components; i++) {
-         if (!nir_const_value_negative_equal(const1[i], const2[i], full_type))
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+         if (nir_alu_instr_channel_used(alu1, src1, i) &&
+             !nir_const_value_negative_equal(const1[alu1->src[src1].swizzle[i]],
+                                             const2[alu2->src[src2].swizzle[i]],
+                                             full_type))
             return false;
       }
 
diff --git a/src/compiler/nir/tests/comparison_pre_tests.cpp b/src/compiler/nir/tests/comparison_pre_tests.cpp
index fe1cc23fb3b..a48aeca8da4 100644
--- a/src/compiler/nir/tests/comparison_pre_tests.cpp
+++ b/src/compiler/nir/tests/comparison_pre_tests.cpp
@@ -473,6 +473,56 @@ TEST_F(comparison_pre_test, a_lt_neg_imm_vs_a_plus_imm)
    EXPECT_TRUE(nir_opt_comparison_pre_impl(bld.impl));
 }
 
+TEST_F(comparison_pre_test, swizzle_of_same_immediate_vector)
+{
+   /* Before:
+    *
+    * vec4 32 ssa_0 = load_const (-2.0, -1.0, 1.0, 2.0)
+    * vec4 32 ssa_1 = load_const ( 2.0, 1.0, -1.0, -2.0)
+    * vec4 32 ssa_2 = load_const ( 3.0, 4.0, 5.0, 6.0)
+    * vec4 32 ssa_3 = fadd ssa_0, ssa_2
+    * vec1 1 ssa_4 = flt ssa_0.x, ssa_3.x
+    *
+    * if ssa_4 {
+    *    vec1 32 ssa_5 = fadd ssa_0.w, ssa_3.x
+    * } else {
+    * }
+    */
+   nir_ssa_def *a = nir_fadd(&bld, v1, v3);
+
+   nir_alu_instr *flt = nir_alu_instr_create(bld.shader, nir_op_flt);
+
+   flt->src[0].src = nir_src_for_ssa(v1);
+   flt->src[1].src = nir_src_for_ssa(a);
+
+   memcpy(&flt->src[0].swizzle, xxxx, sizeof(xxxx));
+   memcpy(&flt->src[1].swizzle, xxxx, sizeof(xxxx));
+
+   nir_builder_alu_instr_finish_and_insert(&bld, flt);
+
+   flt->dest.dest.ssa.num_components = 1;
+   flt->dest.write_mask = 1;
+
+   nir_if *nif = nir_push_if(&bld, &flt->dest.dest.ssa);
+
+   nir_alu_instr *fadd = nir_alu_instr_create(bld.shader, nir_op_fadd);
+
+   fadd->src[0].src = nir_src_for_ssa(v1);
+   fadd->src[1].src = nir_src_for_ssa(a);
+
+   memcpy(&fadd->src[0].swizzle, wwww, sizeof(wwww));
+   memcpy(&fadd->src[1].swizzle, xxxx, sizeof(xxxx));
+
+   nir_builder_alu_instr_finish_and_insert(&bld, fadd);
+
+   fadd->dest.dest.ssa.num_components = 1;
+   fadd->dest.write_mask = 1;
+
+   nir_pop_if(&bld, nif);
+
+   EXPECT_TRUE(nir_opt_comparison_pre_impl(bld.impl));
+}
+
 TEST_F(comparison_pre_test, non_scalar_add_result)
 {
    /* The optimization pass should not do anything because the result of the
diff --git a/src/compiler/nir/tests/negative_equal_tests.cpp b/src/compiler/nir/tests/negative_equal_tests.cpp
index 5e13c8fd28a..9fedb987166 100644
--- a/src/compiler/nir/tests/negative_equal_tests.cpp
+++ b/src/compiler/nir/tests/negative_equal_tests.cpp
@@ -270,6 +270,42 @@ compare_with_negation(nir_type_uint32)
 compare_with_negation(nir_type_int64)
 compare_with_negation(nir_type_uint64)
 
+TEST_F(alu_srcs_negative_equal_test, swizzle_scalar_to_vector)
+{
+   nir_ssa_def *v = nir_imm_vec2(&bld, 1.0, -1.0);
+   const uint8_t s0[4] = { 0, 0, 0, 0 };
+   const uint8_t s1[4] = { 1, 1, 1, 1 };
+
+   /* We can't use nir_swizzle here because it inserts an extra MOV. */
+   nir_alu_instr *instr = nir_alu_instr_create(bld.shader, nir_op_fadd);
+
+   instr->src[0].src = nir_src_for_ssa(v);
+   instr->src[1].src = nir_src_for_ssa(v);
+
+   memcpy(&instr->src[0].swizzle, s0, sizeof(s0));
+   memcpy(&instr->src[1].swizzle, s1, sizeof(s1));
+
+   nir_builder_alu_instr_finish_and_insert(&bld, instr);
+
+   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
+}
+
+TEST_F(alu_srcs_negative_equal_test, unused_components_mismatch)
+{
+   nir_ssa_def *v1 = nir_imm_vec4(&bld, -2.0, 18.0, 43.0, 1.0);
+   nir_ssa_def *v2 = nir_imm_vec4(&bld, 2.0, 99.0, 76.0, -1.0);
+
+   nir_ssa_def *result = nir_fadd(&bld, v1, v2);
+
+   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
+
+   /* Disable the channels that aren't negations of each other. */
+   instr->dest.dest.is_ssa = false;
+   instr->dest.write_mask = 8 + 1;
+
+   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
+}
+
 static void
 count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
                nir_alu_type full_type, int first)