nir/algebraic: Move relocations for expression conds to a table.
This helps concentrate the dirty pages from the relocations, reduces how many relocations there are, and reduces the size of each expression (assuming expressions mostly don't have conditions or the conditions are mostly reused). Reduces libvulkan_intel.so size by 8.7kb. Reviewed-by: Adam Jackson <ajax@redhat.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13987>
This commit is contained in:
parent
7635379dc7
commit
8485a78977
|
@ -53,6 +53,17 @@ conv_opcode_types = {
|
|||
'f2b' : 'bool',
|
||||
}
|
||||
|
||||
def get_cond_index(conds, cond):
    """Return the deduplicated table index for an expression condition.

    ``conds`` is a dict mapping each condition seen so far to the index it
    was assigned; it is updated in place when ``cond`` is new.  Indices are
    handed out in first-seen order (0, 1, 2, ...), so sorting the dict's
    items by value reproduces the order in which conditions were added.

    Args:
        conds: dict of condition -> index; mutated when ``cond`` is new.
        cond: the condition to look up; any falsy value (``None``, ``''``)
            means "no condition".

    Returns:
        The index assigned to ``cond``, or -1 when ``cond`` is falsy.
    """
    # -1 is the "no condition" sentinel; nothing is recorded in the table.
    if not cond:
        return -1

    # setdefault() returns the existing index if cond was already seen,
    # otherwise it stores (and returns) the next free index.  len(conds) is
    # evaluated before any insertion, so a new entry gets the old length.
    return conds.setdefault(cond, len(conds))
|
||||
|
||||
def get_c_opcode(op):
|
||||
if op in conv_opcode_types:
|
||||
return 'nir_search_op_' + op
|
||||
|
@ -89,12 +100,12 @@ class VarSet(object):
|
|||
|
||||
class Value(object):
|
||||
@staticmethod
|
||||
def create(val, name_base, varset):
|
||||
def create(val, name_base, varset, algebraic_pass):
|
||||
if isinstance(val, bytes):
|
||||
val = val.decode('utf-8')
|
||||
|
||||
if isinstance(val, tuple):
|
||||
return Expression(val, name_base, varset)
|
||||
return Expression(val, name_base, varset, algebraic_pass)
|
||||
elif isinstance(val, Expression):
|
||||
return val
|
||||
elif isinstance(val, str):
|
||||
|
@ -178,7 +189,7 @@ class Value(object):
|
|||
${val.comm_expr_idx}, ${val.comm_exprs},
|
||||
${val.c_opcode()},
|
||||
{ ${', '.join(src.array_index for src in val.sources)} },
|
||||
${val.cond if val.cond else 'NULL'},
|
||||
${val.cond_index},
|
||||
% endif
|
||||
} },
|
||||
""")
|
||||
|
@ -326,7 +337,7 @@ _opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bit
|
|||
r"(?P<cond>\([^\)]+\))?")
|
||||
|
||||
class Expression(Value):
|
||||
def __init__(self, expr, name_base, varset):
|
||||
def __init__(self, expr, name_base, varset, algebraic_pass):
|
||||
Value.__init__(self, expr, name_base, "expression")
|
||||
assert isinstance(expr, tuple)
|
||||
|
||||
|
@ -356,7 +367,11 @@ class Expression(Value):
|
|||
self.cond = c[0] if c else None
|
||||
self.many_commutative_expressions = True
|
||||
|
||||
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
|
||||
# Deduplicate references to the condition functions for the expressions
|
||||
# and save the index for the order they were added.
|
||||
self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
|
||||
|
||||
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
|
||||
for (i, src) in enumerate(expr[1:]) ]
|
||||
|
||||
# nir_search_expression::srcs is hard-coded to 4
|
||||
|
@ -730,7 +745,7 @@ _optimization_ids = itertools.count()
|
|||
condition_list = ['true']
|
||||
|
||||
class SearchAndReplace(object):
|
||||
def __init__(self, transform):
|
||||
def __init__(self, transform, algebraic_pass):
|
||||
self.id = next(_optimization_ids)
|
||||
|
||||
search = transform[0]
|
||||
|
@ -748,14 +763,14 @@ class SearchAndReplace(object):
|
|||
if isinstance(search, Expression):
|
||||
self.search = search
|
||||
else:
|
||||
self.search = Expression(search, "search{0}".format(self.id), varset)
|
||||
self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
|
||||
|
||||
varset.lock()
|
||||
|
||||
if isinstance(replace, Value):
|
||||
self.replace = replace
|
||||
else:
|
||||
self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
|
||||
self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
|
||||
|
||||
BitSizeValidator(varset).validate(self.search, self.replace)
|
||||
|
||||
|
@ -1041,6 +1056,14 @@ ${xform.replace.render(cache)}
|
|||
% endfor
|
||||
};
|
||||
|
||||
% if expression_cond:
|
||||
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
|
||||
% for cond in expression_cond:
|
||||
${cond[0]},
|
||||
% endfor
|
||||
};
|
||||
% endif
|
||||
|
||||
% for state_id, state_xforms in enumerate(automaton.state_patterns):
|
||||
% if state_xforms: # avoid emitting a 0-length array for MSVC
|
||||
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
|
||||
|
@ -1100,6 +1123,7 @@ static const nir_algebraic_table ${pass_name}_table = {
|
|||
.transform_counts = ${pass_name}_transform_counts,
|
||||
.pass_op_table = ${pass_name}_pass_op_table,
|
||||
.values = ${pass_name}_values,
|
||||
.expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
|
||||
};
|
||||
|
||||
bool
|
||||
|
@ -1134,13 +1158,14 @@ class AlgebraicPass(object):
|
|||
self.xforms = []
|
||||
self.opcode_xforms = defaultdict(lambda : [])
|
||||
self.pass_name = pass_name
|
||||
self.expression_cond = {}
|
||||
|
||||
error = False
|
||||
|
||||
for xform in transforms:
|
||||
if not isinstance(xform, SearchAndReplace):
|
||||
try:
|
||||
xform = SearchAndReplace(xform)
|
||||
xform = SearchAndReplace(xform, self)
|
||||
except:
|
||||
print("Failed to parse transformation:", file=sys.stderr)
|
||||
print(" " + str(xform), file=sys.stderr)
|
||||
|
@ -1196,5 +1221,6 @@ class AlgebraicPass(object):
|
|||
opcode_xforms=self.opcode_xforms,
|
||||
condition_list=condition_list,
|
||||
automaton=self.automaton,
|
||||
expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
|
||||
get_c_opcode=get_c_opcode,
|
||||
itertools=itertools)
|
||||
|
|
|
@ -50,7 +50,7 @@ struct match_state {
|
|||
};
|
||||
|
||||
static bool
|
||||
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
unsigned num_components, const uint8_t *swizzle,
|
||||
struct match_state *state);
|
||||
static bool
|
||||
|
@ -253,7 +253,8 @@ nir_op_for_search_op(uint16_t sop, unsigned bit_size)
|
|||
}
|
||||
|
||||
static bool
|
||||
match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
|
||||
match_value(const nir_algebraic_table *table,
|
||||
const nir_search_value *value, nir_alu_instr *instr, unsigned src,
|
||||
unsigned num_components, const uint8_t *swizzle,
|
||||
struct match_state *state)
|
||||
{
|
||||
|
@ -289,7 +290,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
|
|||
if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
|
||||
return false;
|
||||
|
||||
return match_expression(nir_search_value_as_expression(value),
|
||||
return match_expression(table, nir_search_value_as_expression(value),
|
||||
nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
|
||||
num_components, new_swizzle, state);
|
||||
|
||||
|
@ -390,11 +391,11 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
|
|||
}
|
||||
|
||||
static bool
|
||||
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
unsigned num_components, const uint8_t *swizzle,
|
||||
struct match_state *state)
|
||||
{
|
||||
if (expr->cond && !expr->cond(instr))
|
||||
if (expr->cond_index != -1 && !table->expression_cond[expr->cond_index](instr))
|
||||
return false;
|
||||
|
||||
if (!nir_op_matches_search_op(instr->op, expr->opcode))
|
||||
|
@ -441,7 +442,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
|
|||
/* 2src_commutative instructions that have 3 sources are only commutative
|
||||
* in the first two sources. Source 2 is always source 2.
|
||||
*/
|
||||
if (!match_value(&state->table->values[expr->srcs[i]].value, instr,
|
||||
if (!match_value(table, &state->table->values[expr->srcs[i]].value, instr,
|
||||
i < 2 ? i ^ comm_op_flip : i,
|
||||
num_components, swizzle, state)) {
|
||||
matched = false;
|
||||
|
@ -720,7 +721,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
|||
state.comm_op_direction = comb;
|
||||
state.variables_seen = 0;
|
||||
|
||||
if (match_expression(search, instr,
|
||||
if (match_expression(table, search, instr,
|
||||
instr->dest.dest.ssa.num_components,
|
||||
swizzle, &state)) {
|
||||
found = true;
|
||||
|
|
|
@ -160,13 +160,13 @@ typedef struct {
|
|||
/* Index in table->values[] for the expression operands */
|
||||
uint16_t srcs[4];
|
||||
|
||||
/** Optional condition fxn ptr
|
||||
/** Optional table->expression_cond[] fxn ptr index
|
||||
*
|
||||
* This allows additional constraints on expression matching, it is
|
||||
* typically used to match an expression's uses such as the number of times
|
||||
* the expression is used, and whether it's used by an if.
|
||||
*/
|
||||
bool (*cond)(nir_alu_instr *instr);
|
||||
int16_t cond_index;
|
||||
} nir_search_expression;
|
||||
|
||||
struct per_op_table {
|
||||
|
@ -189,12 +189,20 @@ typedef union {
|
|||
nir_search_expression expression;
|
||||
} nir_search_value_union;
|
||||
|
||||
typedef bool (*nir_search_expression_cond)(nir_alu_instr *instr);
|
||||
|
||||
/* Generated data table for an algebraic optimization pass. */
|
||||
typedef struct {
|
||||
const struct transform **transforms;
|
||||
const uint16_t *transform_counts;
|
||||
const struct per_op_table *pass_op_table;
|
||||
const nir_search_value_union *values;
|
||||
|
||||
/**
|
||||
* Array of condition functions for expressions, referenced by
|
||||
* nir_search_expression->cond_index.
|
||||
*/
|
||||
const nir_search_expression_cond *expression_cond;
|
||||
} nir_algebraic_table;
|
||||
|
||||
/* Note: these must match the start states created in
|
||||
|
|
|
@ -26,7 +26,7 @@ import sys
|
|||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from nir_algebraic import SearchAndReplace
|
||||
from nir_algebraic import SearchAndReplace, AlgebraicPass
|
||||
|
||||
# These tests check that the bitsize validator correctly rejects various
|
||||
# different kinds of malformed expressions, and documents what the error
|
||||
|
@ -40,9 +40,11 @@ class ValidatorTests(unittest.TestCase):
|
|||
pattern = ()
|
||||
message = ''
|
||||
|
||||
algebraic_pass = AlgebraicPass("test", [])
|
||||
|
||||
def common(self, pattern, message):
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
SearchAndReplace(pattern)
|
||||
SearchAndReplace(pattern, self.algebraic_pass)
|
||||
|
||||
self.assertEqual(message, str(context.exception))
|
||||
|
||||
|
|
Loading…
Reference in New Issue