nir/algebraic: add ignore_exact() wrapper
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Timur Kristóf <timur.kristof@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13436>
This commit is contained in:
parent
f68797ead7
commit
312a284980
|
@ -98,13 +98,33 @@ class VarSet(object):
|
|||
def lock(self):
|
||||
self.immutable = True
|
||||
|
||||
class SearchExpression(object):
|
||||
def __init__(self, expr):
|
||||
self.opcode = expr[0]
|
||||
self.sources = expr[1:]
|
||||
self.ignore_exact = False
|
||||
|
||||
@staticmethod
|
||||
def create(val):
|
||||
if isinstance(val, tuple):
|
||||
return SearchExpression(val)
|
||||
else:
|
||||
assert(isinstance(val, SearchExpression))
|
||||
return val
|
||||
|
||||
def __repr__(self):
|
||||
l = [self.opcode, *self.sources]
|
||||
if self.ignore_exact:
|
||||
l.append('ignore_exact')
|
||||
return repr((*l,))
|
||||
|
||||
class Value(object):
|
||||
@staticmethod
|
||||
def create(val, name_base, varset, algebraic_pass):
|
||||
if isinstance(val, bytes):
|
||||
val = val.decode('utf-8')
|
||||
|
||||
if isinstance(val, tuple):
|
||||
if isinstance(val, tuple) or isinstance(val, SearchExpression):
|
||||
return Expression(val, name_base, varset, algebraic_pass)
|
||||
elif isinstance(val, Expression):
|
||||
return val
|
||||
|
@ -185,7 +205,9 @@ class Value(object):
|
|||
${val.cond_index},
|
||||
${val.swizzle()},
|
||||
% elif isinstance(val, Expression):
|
||||
${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
|
||||
${'true' if val.inexact else 'false'},
|
||||
${'true' if val.exact else 'false'},
|
||||
${'true' if val.ignore_exact else 'false'},
|
||||
${val.c_opcode()},
|
||||
${val.comm_expr_idx}, ${val.comm_exprs},
|
||||
{ ${', '.join(src.array_index for src in val.sources)} },
|
||||
|
@ -339,15 +361,17 @@ _opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bit
|
|||
class Expression(Value):
|
||||
def __init__(self, expr, name_base, varset, algebraic_pass):
|
||||
Value.__init__(self, expr, name_base, "expression")
|
||||
assert isinstance(expr, tuple)
|
||||
|
||||
m = _opcode_re.match(expr[0])
|
||||
expr = SearchExpression.create(expr)
|
||||
|
||||
m = _opcode_re.match(expr.opcode)
|
||||
assert m and m.group('opcode') is not None
|
||||
|
||||
self.opcode = m.group('opcode')
|
||||
self._bit_size = int(m.group('bits')) if m.group('bits') else None
|
||||
self.inexact = m.group('inexact') is not None
|
||||
self.exact = m.group('exact') is not None
|
||||
self.ignore_exact = expr.ignore_exact
|
||||
self.cond = m.group('cond')
|
||||
|
||||
assert not self.inexact or not self.exact, \
|
||||
|
@ -372,7 +396,7 @@ class Expression(Value):
|
|||
self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
|
||||
|
||||
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
|
||||
for (i, src) in enumerate(expr[1:]) ]
|
||||
for (i, src) in enumerate(expr.sources) ]
|
||||
|
||||
# nir_search_expression::srcs is hard-coded to 4
|
||||
assert len(self.sources) <= 4
|
||||
|
@ -1235,3 +1259,9 @@ class AlgebraicPass(object):
|
|||
variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
|
||||
get_c_opcode=get_c_opcode,
|
||||
itertools=itertools)
|
||||
|
||||
# The replacement expression isn't necessarily exact if the search expression is exact.
|
||||
def ignore_exact(*expr):
|
||||
expr = SearchExpression.create(expr)
|
||||
expr.ignore_exact = True
|
||||
return expr
|
||||
|
|
|
@ -41,6 +41,8 @@ e = 'e'
|
|||
signed_zero_inf_nan_preserve_16 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 16)'
|
||||
signed_zero_inf_nan_preserve_32 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 32)'
|
||||
|
||||
ignore_exact = nir_algebraic.ignore_exact
|
||||
|
||||
# Written in the form (<search>, <replace>) where <search> is an expression
|
||||
# and <replace> is either an expression or a value. An expression is
|
||||
# defined as a tuple of the form ([~]<op>, <src0>, <src1>, <src2>, <src3>)
|
||||
|
|
|
@ -408,7 +408,7 @@ match_expression(const nir_algebraic_table *table, const nir_search_expression *
|
|||
return false;
|
||||
|
||||
state->inexact_match = expr->inexact || state->inexact_match;
|
||||
state->has_exact_alu = instr->exact || state->has_exact_alu;
|
||||
state->has_exact_alu = (instr->exact && !expr->ignore_exact) || state->has_exact_alu;
|
||||
if (state->inexact_match && state->has_exact_alu)
|
||||
return false;
|
||||
|
||||
|
|
|
@ -142,8 +142,11 @@ typedef struct {
|
|||
/** In a replacement, requests that the instruction be marked exact. */
|
||||
bool exact : 1;
|
||||
|
||||
/** Don't make the replacement exact if the search expression is exact. */
|
||||
bool ignore_exact : 1;
|
||||
|
||||
/* One of nir_op or nir_search_op */
|
||||
uint16_t opcode : 14;
|
||||
uint16_t opcode : 13;
|
||||
|
||||
/* Commutative expression index. This is assigned by opt_algebraic.py when
|
||||
* search structures are constructed and is a unique (to this structure)
|
||||
|
|
Loading…
Reference in New Issue