2014-11-15 01:47:56 +00:00
|
|
|
#
|
|
|
|
# Copyright (C) 2014 Intel Corporation
|
|
|
|
#
|
|
|
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
|
|
|
# copy of this software and associated documentation files (the "Software"),
|
|
|
|
# to deal in the Software without restriction, including without limitation
|
|
|
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
|
|
# and/or sell copies of the Software, and to permit persons to whom the
|
|
|
|
# Software is furnished to do so, subject to the following conditions:
|
|
|
|
#
|
|
|
|
# The above copyright notice and this permission notice (including the next
|
|
|
|
# paragraph) shall be included in all copies or substantial portions of the
|
|
|
|
# Software.
|
|
|
|
#
|
|
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
|
|
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
|
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
|
|
|
# IN THE SOFTWARE.
|
|
|
|
#
|
|
|
|
# Authors:
|
|
|
|
# Jason Ekstrand (jason@jlekstrand.net)
|
|
|
|
|
2016-04-25 19:36:08 +01:00
|
|
|
from __future__ import print_function
|
2016-04-25 20:23:38 +01:00
|
|
|
import ast
|
2018-06-27 11:37:38 +01:00
|
|
|
from collections import OrderedDict
|
2014-11-15 01:47:56 +00:00
|
|
|
import itertools
|
|
|
|
import struct
|
|
|
|
import sys
|
|
|
|
import mako.template
|
2015-01-29 00:42:20 +00:00
|
|
|
import re
|
2016-04-25 19:36:08 +01:00
|
|
|
import traceback
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2016-04-26 04:58:47 +01:00
|
|
|
from nir_opcodes import opcodes
|
|
|
|
|
2018-08-09 09:27:21 +01:00
|
|
|
# Python 2/3 compatibility aliases used when classifying pattern elements:
# integer_types covers every builtin integer type and string_type is the
# builtin text-string type of the running interpreter.
if sys.version_info >= (3, 0):
    integer_types = (int, )
    string_type = str
else:
    integer_types = (int, long)  # noqa: F821 -- Python 2 only
    string_type = unicode  # noqa: F821 -- Python 2 only
|
|
|
|
|
2016-04-26 04:58:47 +01:00
|
|
|
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a NIR type name.

    "float32" -> 32, "uint64" -> 64, while unsized names such as "float" or
    "int" yield 0.  type_str must begin with one of the base type names
    (int, uint, bool, float); otherwise an AssertionError is raised.
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
|
|
|
|
|
2014-11-15 01:47:56 +00:00
|
|
|
class VarSet(object):
    """A set of pattern variables, each mapped to a unique integer id.

    Ids are assigned from 0 upwards in order of first lookup.  After lock()
    has been called, looking up a name that has not been seen yet fails with
    an AssertionError instead of allocating a new id.
    """

    def __init__(self):
        self.names = {}               # variable name -> id
        self.ids = itertools.count()  # source of fresh ids
        self.immutable = False        # set by lock()

    def __getitem__(self, name):
        try:
            return self.names[name]
        except KeyError:
            pass

        assert not self.immutable, "Unknown replacement variable: " + name
        fresh_id = next(self.ids)
        self.names[name] = fresh_id
        return fresh_id

    def lock(self):
        """Forbid allocating ids for names that are not already known."""
        self.immutable = True
|
|
|
|
|
2014-11-15 01:47:56 +00:00
|
|
|
class Value(object):
    """Base class for every element of a search or replace pattern.

    Each value renders (through __template) to a static C definition of type
    nir_search_<type_str> named self.name.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Convert a pattern element into a Value.

        Tuples become Expressions, strings become Variables, and bools,
        integers and floats become Constants.  Anything that already is a
        Value (Expression, Variable or Constant) is returned unchanged.
        Byte strings are decoded as UTF-8 first.

        Raises TypeError for any other kind of element instead of returning
        None, so a malformed pattern is reported where it is parsed.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Value):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)
        else:
            raise TypeError('Cannot convert {0!r} into a nir_algebraic '
                            'value'.format(val))

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
{ ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
${val.index}, /* ${val.var_name} */
${'true' if val.is_constant else 'false'},
${val.type() or 'nir_type_invalid' },
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
nir_op_${val.opcode},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, val, name, type_str):
        # Keep the textual form of the original pattern for error messages.
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    @property
    def type_enum(self):
        """C enum name of this value's kind, e.g. nir_search_value_constant."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """C struct type used for this value, e.g. nir_search_expression."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression for a pointer to this value's embedded header."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C definition of this value."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
|
|
|
|
|
2016-05-07 18:01:24 +01:00
|
|
|
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal pattern operand: a bool, an integer or a float.

    val may also be a string "<python literal>[@<bits>]"; the literal part is
    parsed with ast.literal_eval and the optional suffix sets bit_size.
    Booleans are always 32-bit.
    """

    def __init__(self, val, name):
        # Value.__init__ already records str(val) as self.in_val.
        Value.__init__(self, val, name, "constant")

        # Accept both native and text strings: on Python 2, text arrives as
        # unicode (string_type), which is not an instance of str.
        if isinstance(val, (str, string_type)):
            m = _constant_re.match(val)
            assert m, 'Invalid constant: {0}'.format(self.in_val)
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

    def hex(self):
        """Return the C initializer text for this constant's value.

        Booleans become NIR_TRUE / NIR_FALSE, integers their hex form, and
        floats the hex form of their IEEE double bit pattern.
        """
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            assert False

    def type(self):
        """Return the nir_alu_type name matching the Python value's type."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
|
|
|
|
|
2016-04-25 20:23:38 +01:00
|
|
|
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named pattern variable: "[#]name[@[type][bits]][(cond)]".

    A leading '#' sets is_constant; "@type" and "@bits" record a required
    type and bit size; a trailing parenthesised condition is emitted verbatim
    into the generated C (NULL when absent).  Variables named the same share
    one id from varset.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        parsed = _var_name_re.match(val)
        assert parsed and parsed.group('name') is not None

        self.var_name = parsed.group('name')
        self.is_constant = parsed.group('const') is not None
        self.cond = parsed.group('cond')
        self.required_type = parsed.group('type')

        bits = parsed.group('bits')
        self.bit_size = int(bits) if bits else 0

        # Booleans are only ever 32-bit.
        if self.required_type == 'bool':
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_alu_type name for the required type, or None."""
        return {
            'bool': "nir_type_bool",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }.get(self.required_type)
|
|
|
|
|
2017-01-10 04:47:31 +00:00
|
|
|
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU operation pattern: ("[~]opcode[@bits][(cond)]", src0, ...).

    A leading '~' marks the expression inexact, "@bits" requests a result
    bit size, and a trailing parenthesised condition is emitted verbatim into
    the generated C.  Sources are converted with Value.create and named
    "<name_base>_<index>".
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')

        # Catch unknown opcodes and wrong source counts here with a clear
        # message.  The bit-size validator only looks at the first
        # num_inputs sources, so surplus sources would otherwise be
        # silently accepted.
        assert self.opcode in opcodes, \
            'Unknown opcode nir_op_{0}: {1}'.format(self.opcode, str(self))
        num_inputs = opcodes[self.opcode].num_inputs
        assert len(expr) - 1 == num_inputs, \
            'nir_op_{0} takes {1} sources but {2} were given: {3}' \
            .format(self.opcode, num_inputs, len(expr) - 1, str(self))

        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Render the sources first, since this struct takes their addresses."""
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
|
|
|
|
|
2016-04-26 04:58:47 +01:00
|
|
|
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent. Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.

    Internally this is a union-find forest: _remap links an integer to a
    strictly larger integer of the same class, and an integer that does not
    appear as a key is its own canonical form.
    """

    def __init__(self):
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every link points to a strictly larger integer, so the walk ends.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form.

        The two classes are merged by linking their current canonical forms
        (not a and b themselves) to the larger one.  Linking a or b directly
        would overwrite an existing link and split a's or b's class.
        """
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)
        if canon_a != c:
            self._remap[canon_a] = c
        if canon_b != c:
            self._remap[canon_b] = c
        return c
|
|
|
|
|
|
|
|
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64. The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size. Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value. NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match. They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above. This is similar to type inference in many common programming
    languages. If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit. There are, however, cases where this
    inference can be ambiguous or contradictory. Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer. However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value. As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size. If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code. This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle. Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer. A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size. The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable. If it ever comes across an inconsistency, it
    assert-fails. Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # Bit class of each variable, indexed by Variable.index (0 = unknown).
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that replace can be built wherever search matches.

        Raises AssertionError describing the first inconsistency found.
        """
        search_dst_class = self._propagate_bit_size_up(search)
        if search_dst_class == 0:
            search_dst_class = self._new_class()
        self._propagate_bit_class_down(search, search_dst_class)

        # Propagating down may have merged a generic destination class with
        # other classes, including an actual bit size.  The replace-side
        # passes only produce canonical classes, so compare (and propagate)
        # the canonical form of the search destination class.
        search_dst_class = self._class_relation.get_canonical(search_dst_class)

        replace_dst_class = self._validate_bit_class_up(replace)
        if replace_dst_class != 0:
            if replace_dst_class > 0:
                assert search_dst_class > 0, \
                    'Search expression matches any bit size but replace ' \
                    'expression can only generate {0}-bit values' \
                    .format(replace_dst_class)

            assert search_dst_class == replace_dst_class, \
                'Search expression matches any {0}-bit values but replace ' \
                'expression can only generate {1}-bit values' \
                .format(search_dst_class, replace_dst_class)

        self._validate_bit_class_down(replace, search_dst_class)

    def _new_class(self):
        """Allocate a fresh generic (negative) bit class."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var, bit_class):
        """Require var to have bit class bit_class.

        Merges bit_class with the variable's existing class, asserting if
        that would force two different actual bit sizes on the variable.
        """
        assert bit_class != 0
        var_class = self._var_classes[var.index]
        if var_class == 0:
            self._var_classes[var.index] = bit_class
        else:
            canon_var_class = self._class_relation.get_canonical(var_class)
            canon_bit_class = self._class_relation.get_canonical(bit_class)
            # Report canonical classes: the assert only fires when both are
            # actual (positive) bit sizes.
            assert canon_var_class < 0 or canon_bit_class < 0 or \
                canon_var_class == canon_bit_class, \
                'Variable {0} cannot be both {1}-bit and {2}-bit' \
                .format(str(var), canon_bit_class, canon_var_class)
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var.index] = var_class

    def _get_var_bit_class(self, var):
        """Return the canonical bit class currently recorded for var."""
        return self._class_relation.get_canonical(self._var_classes[var.index])

    def _propagate_bit_size_up(self, val):
        """Infer bit sizes bottom-up in the search tree.

        Returns the bit size of val, or 0 if unknown.  For expressions the
        size shared by the unsized sources is stored in val.common_size.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits, \
                        'Source {0} of nir_op_{1} must be a {2}-bit value but ' \
                        'the only possible matched values are {3}-bit: {4}' \
                        .format(i, val.opcode, src_type_bits, src_bits, str(val))
                else:
                    assert val.common_size == 0 or src_bits == val.common_size, \
                        'Expression cannot have both {0}-bit and {1}-bit ' \
                        'variable-width sources: {2}' \
                        .format(src_bits, val.common_size, str(val))
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits, \
                    'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \
                    'result was requested' \
                    .format(val.opcode, dst_type_bits, val.bit_size)
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size, \
                        'Variable width expression must be {0}-bit based on ' \
                        'the sources but a {1}-bit result was requested: {2}' \
                        .format(val.common_size, val.bit_size, str(val))
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push bit_class down the search tree, assigning variable classes."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                'Constant is {0}-bit but a {1}-bit value is required: {2}' \
                .format(val.bit_size, bit_class, str(val))

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                'Variable is {0}-bit but a {1}-bit value is required: {2}' \
                .format(val.bit_size, bit_class, str(val))
            self._set_var_bit_class(val, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits, \
                    'nir_op_{0} produces a {1}-bit result but the parent ' \
                    'expression wants a {2}-bit value' \
                    .format(val.opcode, dst_type_bits, bit_class)
            else:
                assert val.common_size == 0 or val.common_size == bit_class, \
                    'Variable-width expression produces a {0}-bit result ' \
                    'based on the source widths but the parent expression ' \
                    'wants a {1}-bit value: {2}' \
                    .format(val.common_size, bit_class, str(val))
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Compute bit classes bottom-up in the replace tree.

        Returns the class of val (0 if unconstrained); for expressions the
        class shared by unsized sources is stored in val.common_class.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search. It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check that every replace value gets a definite class bit_class."""
        # At this point, everything *must* have a bit class. Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
|
|
|
|
|
2014-11-15 01:47:56 +00:00
|
|
|
# Source of unique ids used to name the generated C structs of each transform.
_optimization_ids = itertools.count()

# C condition expressions used by the transforms; entry 0 is the default
# 'true'.  Each SearchAndReplace stores the index of its condition here.
condition_list = ['true']


class SearchAndReplace(object):
    """One parsed transform: (search, replace[, condition]).

    search is an expression tuple (or an Expression), replace is any pattern
    element (or a Value), and condition is a C expression string that
    defaults to 'true'.  The search/replace pair is bit-size validated here.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search, replace = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            # NOTE(review): a prebuilt Expression's variables are not
            # registered in this varset, so the replace lookup and the
            # validator see an empty, locked varset -- confirm this path is
            # unused.
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replace pattern may only refer to variables bound by search.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
|
|
|
|
|
2014-11-15 01:47:56 +00:00
|
|
|
_algebraic_pass_template = mako.template.Template("""
|
|
|
|
#include "nir.h"
|
|
|
|
#include "nir_search.h"
|
2017-01-18 17:21:07 +00:00
|
|
|
#include "nir_search_helpers.h"
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2015-03-24 00:22:44 +00:00
|
|
|
#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
|
|
|
|
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
|
|
|
|
|
2014-11-15 01:47:56 +00:00
|
|
|
struct transform {
|
|
|
|
const nir_search_expression *search;
|
|
|
|
const nir_search_value *replace;
|
2015-02-03 00:20:06 +00:00
|
|
|
unsigned condition_offset;
|
2014-11-15 01:47:56 +00:00
|
|
|
};
|
|
|
|
|
2015-03-24 00:22:44 +00:00
|
|
|
#endif
|
|
|
|
|
2018-07-06 11:20:26 +01:00
|
|
|
% for (opcode, xform_list) in xform_dict.items():
|
2014-11-15 01:47:56 +00:00
|
|
|
% for xform in xform_list:
|
|
|
|
${xform.search.render()}
|
|
|
|
${xform.replace.render()}
|
|
|
|
% endfor
|
|
|
|
|
2015-01-28 00:42:38 +00:00
|
|
|
static const struct transform ${pass_name}_${opcode}_xforms[] = {
|
2014-11-15 01:47:56 +00:00
|
|
|
% for xform in xform_list:
|
2015-02-03 00:20:06 +00:00
|
|
|
{ &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
|
2014-11-15 01:47:56 +00:00
|
|
|
% endfor
|
|
|
|
};
|
|
|
|
% endfor
|
|
|
|
|
|
|
|
static bool
|
2016-04-12 20:30:22 +01:00
|
|
|
${pass_name}_block(nir_block *block, const bool *condition_flags,
|
|
|
|
void *mem_ctx)
|
2014-11-15 01:47:56 +00:00
|
|
|
{
|
2016-04-12 20:30:22 +01:00
|
|
|
bool progress = false;
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2016-04-27 02:34:19 +01:00
|
|
|
nir_foreach_instr_reverse_safe(instr, block) {
|
2014-11-15 01:47:56 +00:00
|
|
|
if (instr->type != nir_instr_type_alu)
|
|
|
|
continue;
|
|
|
|
|
|
|
|
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
|
|
|
if (!alu->dest.dest.is_ssa)
|
|
|
|
continue;
|
|
|
|
|
|
|
|
switch (alu->op) {
|
|
|
|
% for opcode in xform_dict.keys():
|
|
|
|
case nir_op_${opcode}:
|
|
|
|
for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
|
2015-01-28 00:42:38 +00:00
|
|
|
const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
|
2016-04-12 20:30:22 +01:00
|
|
|
if (condition_flags[xform->condition_offset] &&
|
2015-02-03 00:20:06 +00:00
|
|
|
nir_replace_instr(alu, xform->search, xform->replace,
|
2016-04-12 20:30:22 +01:00
|
|
|
mem_ctx)) {
|
|
|
|
progress = true;
|
2015-01-15 03:08:32 +00:00
|
|
|
break;
|
|
|
|
}
|
2014-11-15 01:47:56 +00:00
|
|
|
}
|
|
|
|
break;
|
|
|
|
% endfor
|
|
|
|
default:
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-04-12 20:30:22 +01:00
|
|
|
return progress;
|
2014-11-15 01:47:56 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
static bool
|
2015-02-03 00:20:06 +00:00
|
|
|
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
|
2014-11-15 01:47:56 +00:00
|
|
|
{
|
2016-04-12 20:30:22 +01:00
|
|
|
void *mem_ctx = ralloc_parent(impl);
|
|
|
|
bool progress = false;
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2016-04-12 20:30:22 +01:00
|
|
|
nir_foreach_block_reverse(block, impl) {
|
|
|
|
progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
|
|
|
|
}
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2016-04-12 20:30:22 +01:00
|
|
|
if (progress)
|
2014-12-13 00:22:46 +00:00
|
|
|
nir_metadata_preserve(impl, nir_metadata_block_index |
|
|
|
|
nir_metadata_dominance);
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2016-04-12 20:30:22 +01:00
|
|
|
return progress;
|
2014-11-15 01:47:56 +00:00
|
|
|
}
|
|
|
|
|
2015-02-03 00:20:06 +00:00
|
|
|
|
2014-11-15 01:47:56 +00:00
|
|
|
bool
|
|
|
|
${pass_name}(nir_shader *shader)
|
|
|
|
{
|
|
|
|
bool progress = false;
|
2015-02-03 00:20:06 +00:00
|
|
|
bool condition_flags[${len(condition_list)}];
|
|
|
|
const nir_shader_compiler_options *options = shader->options;
|
2016-04-07 23:03:39 +01:00
|
|
|
(void) options;
|
2015-02-03 00:20:06 +00:00
|
|
|
|
|
|
|
% for index, condition in enumerate(condition_list):
|
|
|
|
condition_flags[${index}] = ${condition};
|
|
|
|
% endfor
|
2014-11-15 01:47:56 +00:00
|
|
|
|
2016-04-27 04:26:42 +01:00
|
|
|
nir_foreach_function(function, shader) {
|
2015-12-26 18:00:47 +00:00
|
|
|
if (function->impl)
|
|
|
|
progress |= ${pass_name}_impl(function->impl, condition_flags);
|
2014-11-15 01:47:56 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return progress;
|
|
|
|
}
|
|
|
|
""")
|
|
|
|
|
|
|
|
class AlgebraicPass(object):
    """A generated NIR pass applying a list of search-and-replace transforms.

    transforms may contain SearchAndReplace instances or raw transform
    tuples; tuples are parsed here.  Every transform that fails to parse is
    reported on stderr (with its traceback) and parsing continues, so that
    all broken rules are listed in one run.  The script then exits with
    status 1.
    """

    def __init__(self, pass_name, transforms):
        # Insertion-ordered: the generated arrays and switch cases follow the
        # order in which opcodes first appear in transforms.
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Only ordinary errors (including the AssertionErrors the
                    # validator raises) are reported here.  KeyboardInterrupt
                    # and SystemExit are allowed to propagate.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print("  " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the C source code of this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)
|