diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index ad841e3d311..0ff8dfbf6c7 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -22,13 +22,18 @@ def type_add_size(type_, size): return type_ + str(size) def op_bit_sizes(op): - sizes = set([8, 16, 32, 64]) + sizes = None if not type_has_size(op.output_type): - sizes = sizes.intersection(set(type_sizes(op.output_type))) + sizes = set(type_sizes(op.output_type)) + for input_type in op.input_types: if not type_has_size(input_type): - sizes = sizes.intersection(set(type_sizes(input_type))) - return sorted(list(sizes)) + if sizes is None: + sizes = set(type_sizes(input_type)) + else: + sizes = sizes.intersection(set(type_sizes(input_type))) + + return sorted(list(sizes)) if sizes is not None else None def get_const_field(type_): if type_ == "bool32": @@ -375,17 +380,21 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size, { nir_const_value _dst_val = { {0, } }; - switch (bit_size) { - % for bit_size in op_bit_sizes(op): - case ${bit_size}: { - ${evaluate_op(op, bit_size)} - break; - } - % endfor + % if op_bit_sizes(op) is not None: + switch (bit_size) { + % for bit_size in op_bit_sizes(op): + case ${bit_size}: { + ${evaluate_op(op, bit_size)} + break; + } + % endfor - default: - unreachable("unknown bit width"); - } + default: + unreachable("unknown bit width"); + } + % else: + ${evaluate_op(op, 0)} + % endif return _dst_val; }