diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 0ed13a3f931..82837e78b0f 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -3183,6 +3183,13 @@ typedef enum { nir_lower_divmod64 = (1 << 2), /** Lower all 64-bit umul_high and imul_high opcodes */ nir_lower_imul_high64 = (1 << 3), + nir_lower_mov64 = (1 << 4), + nir_lower_icmp64 = (1 << 5), + nir_lower_iadd64 = (1 << 6), + nir_lower_iabs64 = (1 << 7), + nir_lower_ineg64 = (1 << 8), + nir_lower_logic64 = (1 << 9), + nir_lower_minmax64 = (1 << 10), } nir_lower_int64_options; bool nir_lower_int64(nir_shader *shader, nir_lower_int64_options options); diff --git a/src/compiler/nir/nir_lower_int64.c b/src/compiler/nir/nir_lower_int64.c index 575aa839581..75cd47feeea 100644 --- a/src/compiler/nir/nir_lower_int64.c +++ b/src/compiler/nir/nir_lower_int64.c @@ -505,6 +505,38 @@ opcode_to_options_mask(nir_op opcode) case nir_op_imod: case nir_op_irem: return nir_lower_divmod64; + case nir_op_b2i64: + case nir_op_i2b1: + case nir_op_i2i32: + case nir_op_i2i64: + case nir_op_u2u32: + case nir_op_u2u64: + case nir_op_bcsel: + return nir_lower_mov64; + case nir_op_ieq: + case nir_op_ine: + case nir_op_ult: + case nir_op_ilt: + case nir_op_uge: + case nir_op_ige: + return nir_lower_icmp64; + case nir_op_iadd: + case nir_op_isub: + return nir_lower_iadd64; + case nir_op_imin: + case nir_op_imax: + case nir_op_umin: + case nir_op_umax: + return nir_lower_minmax64; + case nir_op_iabs: + return nir_lower_iabs64; + case nir_op_ineg: + return nir_lower_ineg64; + case nir_op_iand: + case nir_op_ior: + case nir_op_ixor: + case nir_op_inot: + return nir_lower_logic64; default: return 0; } @@ -536,6 +568,59 @@ lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu) return lower_imod64(b, src[0], src[1]); case nir_op_irem: return lower_irem64(b, src[0], src[1]); + case nir_op_b2i64: + return lower_b2i64(b, src[0]); + case nir_op_i2b1: + return lower_i2b(b, src[0]); + case nir_op_i2i8: + return lower_i2i8(b, src[0]); + case nir_op_i2i16: + return lower_i2i16(b, src[0]); + case nir_op_i2i32: + return lower_i2i32(b, src[0]); + case nir_op_i2i64: + return lower_i2i64(b, src[0]); + case nir_op_u2u8: + return lower_u2u8(b, src[0]); + case nir_op_u2u16: + return lower_u2u16(b, src[0]); + case nir_op_u2u32: + return lower_u2u32(b, src[0]); + case nir_op_u2u64: + return lower_u2u64(b, src[0]); + case nir_op_bcsel: + return lower_bcsel64(b, src[0], src[1], src[2]); + case nir_op_ieq: + case nir_op_ine: + case nir_op_ult: + case nir_op_ilt: + case nir_op_uge: + case nir_op_ige: + return lower_int64_compare(b, alu->op, src[0], src[1]); + case nir_op_iadd: + return lower_iadd64(b, src[0], src[1]); + case nir_op_isub: + return lower_isub64(b, src[0], src[1]); + case nir_op_imin: + return lower_imin64(b, src[0], src[1]); + case nir_op_imax: + return lower_imax64(b, src[0], src[1]); + case nir_op_umin: + return lower_umin64(b, src[0], src[1]); + case nir_op_umax: + return lower_umax64(b, src[0], src[1]); + case nir_op_iabs: + return lower_iabs64(b, src[0]); + case nir_op_ineg: + return lower_ineg64(b, src[0]); + case nir_op_iand: + return lower_iand64(b, src[0], src[1]); + case nir_op_ior: + return lower_ior64(b, src[0], src[1]); + case nir_op_ixor: + return lower_ixor64(b, src[0], src[1]); + case nir_op_inot: + return lower_inot64(b, src[0]); default: unreachable("Invalid ALU opcode to lower"); } @@ -554,9 +639,41 @@ lower_int64_impl(nir_function_impl *impl, nir_lower_int64_options options) continue; nir_alu_instr *alu = nir_instr_as_alu(instr); - assert(alu->dest.dest.is_ssa); - if (alu->dest.dest.ssa.bit_size != 64) - continue; + switch (alu->op) { + case nir_op_i2b1: + case nir_op_i2i32: + case nir_op_u2u32: + assert(alu->src[0].src.is_ssa); + if (alu->src[0].src.ssa->bit_size != 64) + continue; + break; + case nir_op_bcsel: + assert(alu->src[1].src.is_ssa); + assert(alu->src[2].src.is_ssa); + assert(alu->src[1].src.ssa->bit_size == + alu->src[2].src.ssa->bit_size); + if (alu->src[1].src.ssa->bit_size != 64) + continue; + break; + case nir_op_ieq: + case nir_op_ine: + case nir_op_ult: + case nir_op_ilt: + case nir_op_uge: + case nir_op_ige: + assert(alu->src[0].src.is_ssa); + assert(alu->src[1].src.is_ssa); + assert(alu->src[0].src.ssa->bit_size == + alu->src[1].src.ssa->bit_size); + if (alu->src[0].src.ssa->bit_size != 64) + continue; + break; + default: + assert(alu->dest.dest.is_ssa); + if (alu->dest.dest.ssa.bit_size != 64) + continue; + break; + } if (!(options & opcode_to_options_mask(alu->op))) continue;