nir/algebraic: Move relocations for expression conds to a table.

This helps concentrate the dirty pages from the relocations, reduces how
many relocations there are, and reduces the size of each expression
(assuming expressions mostly don't have conditions or the conditions are
mostly reused).  Reduces libvulkan_intel.so size by 8.7kb.

Reviewed-by: Adam Jackson <ajax@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13987>
This commit is contained in:
Emma Anholt 2021-11-30 14:23:39 -08:00 committed by Marge Bot
parent 7635379dc7
commit 8485a78977
4 changed files with 57 additions and 20 deletions

View File

@ -53,6 +53,17 @@ conv_opcode_types = {
'f2b' : 'bool',
}
def get_cond_index(conds, cond):
    """Return the table index for an expression condition, adding it if new.

    conds is a dict mapping each condition to the index it was assigned;
    it is updated in place.  Indices are handed out in insertion order
    starting at 0, so a condition seen more than once always maps back to
    the same slot.

    Returns -1 when cond is empty or None, i.e. there is no condition.
    """
    if not cond:
        return -1

    # setdefault() only stores len(conds) when cond is not already a key,
    # so an existing condition keeps its original index.
    return conds.setdefault(cond, len(conds))
def get_c_opcode(op):
if op in conv_opcode_types:
return 'nir_search_op_' + op
@ -89,12 +100,12 @@ class VarSet(object):
class Value(object):
@staticmethod
def create(val, name_base, varset):
def create(val, name_base, varset, algebraic_pass):
if isinstance(val, bytes):
val = val.decode('utf-8')
if isinstance(val, tuple):
return Expression(val, name_base, varset)
return Expression(val, name_base, varset, algebraic_pass)
elif isinstance(val, Expression):
return val
elif isinstance(val, str):
@ -178,7 +189,7 @@ class Value(object):
${val.comm_expr_idx}, ${val.comm_exprs},
${val.c_opcode()},
{ ${', '.join(src.array_index for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
${val.cond_index},
% endif
} },
""")
@ -326,7 +337,7 @@ _opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bit
r"(?P<cond>\([^\)]+\))?")
class Expression(Value):
def __init__(self, expr, name_base, varset):
def __init__(self, expr, name_base, varset, algebraic_pass):
Value.__init__(self, expr, name_base, "expression")
assert isinstance(expr, tuple)
@ -356,7 +367,11 @@ class Expression(Value):
self.cond = c[0] if c else None
self.many_commutative_expressions = True
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
# Deduplicate references to the condition functions for the expressions
# and save the index for the order they were added.
self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
for (i, src) in enumerate(expr[1:]) ]
# nir_search_expression::srcs is hard-coded to 4
@ -730,7 +745,7 @@ _optimization_ids = itertools.count()
condition_list = ['true']
class SearchAndReplace(object):
def __init__(self, transform):
def __init__(self, transform, algebraic_pass):
self.id = next(_optimization_ids)
search = transform[0]
@ -748,14 +763,14 @@ class SearchAndReplace(object):
if isinstance(search, Expression):
self.search = search
else:
self.search = Expression(search, "search{0}".format(self.id), varset)
self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
varset.lock()
if isinstance(replace, Value):
self.replace = replace
else:
self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
BitSizeValidator(varset).validate(self.search, self.replace)
@ -1041,6 +1056,14 @@ ${xform.replace.render(cache)}
% endfor
};
% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
${cond[0]},
% endfor
};
% endif
% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
@ -1100,6 +1123,7 @@ static const nir_algebraic_table ${pass_name}_table = {
.transform_counts = ${pass_name}_transform_counts,
.pass_op_table = ${pass_name}_pass_op_table,
.values = ${pass_name}_values,
.expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
};
bool
@ -1134,13 +1158,14 @@ class AlgebraicPass(object):
self.xforms = []
self.opcode_xforms = defaultdict(lambda : [])
self.pass_name = pass_name
self.expression_cond = {}
error = False
for xform in transforms:
if not isinstance(xform, SearchAndReplace):
try:
xform = SearchAndReplace(xform)
xform = SearchAndReplace(xform, self)
except:
print("Failed to parse transformation:", file=sys.stderr)
print(" " + str(xform), file=sys.stderr)
@ -1196,5 +1221,6 @@ class AlgebraicPass(object):
opcode_xforms=self.opcode_xforms,
condition_list=condition_list,
automaton=self.automaton,
expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
get_c_opcode=get_c_opcode,
itertools=itertools)

View File

@ -50,7 +50,7 @@ struct match_state {
};
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state);
static bool
@ -253,7 +253,8 @@ nir_op_for_search_op(uint16_t sop, unsigned bit_size)
}
static bool
match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
match_value(const nir_algebraic_table *table,
const nir_search_value *value, nir_alu_instr *instr, unsigned src,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state)
{
@ -289,7 +290,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
return false;
return match_expression(nir_search_value_as_expression(value),
return match_expression(table, nir_search_value_as_expression(value),
nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
num_components, new_swizzle, state);
@ -390,11 +391,11 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
}
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state)
{
if (expr->cond && !expr->cond(instr))
if (expr->cond_index != -1 && !table->expression_cond[expr->cond_index](instr))
return false;
if (!nir_op_matches_search_op(instr->op, expr->opcode))
@ -441,7 +442,7 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
/* 2src_commutative instructions that have 3 sources are only commutative
* in the first two sources. Source 2 is always source 2.
*/
if (!match_value(&state->table->values[expr->srcs[i]].value, instr,
if (!match_value(table, &state->table->values[expr->srcs[i]].value, instr,
i < 2 ? i ^ comm_op_flip : i,
num_components, swizzle, state)) {
matched = false;
@ -720,7 +721,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
state.comm_op_direction = comb;
state.variables_seen = 0;
if (match_expression(search, instr,
if (match_expression(table, search, instr,
instr->dest.dest.ssa.num_components,
swizzle, &state)) {
found = true;

View File

@ -160,13 +160,13 @@ typedef struct {
/* Index in table->values[] for the expression operands */
uint16_t srcs[4];
/** Optional condition fxn ptr
/** Optional table->expression_cond[] fxn ptr index
*
* This allows additional constraints on expression matching; it is
* typically used to match an expression's uses, such as the number of times
* the expression is used and whether it's used by an if.
*/
bool (*cond)(nir_alu_instr *instr);
int16_t cond_index;
} nir_search_expression;
struct per_op_table {
@ -189,12 +189,20 @@ typedef union {
nir_search_expression expression;
} nir_search_value_union;
typedef bool (*nir_search_expression_cond)(nir_alu_instr *instr);
/* Generated data table for an algebraic optimization pass. */
typedef struct {
const struct transform **transforms;
const uint16_t *transform_counts;
const struct per_op_table *pass_op_table;
/* Search values; nir_search_expression::srcs[] holds indices into this. */
const nir_search_value_union *values;
/**
* Array of condition functions for expressions, indexed by
* nir_search_expression::cond_index.  A cond_index of -1 means the
* expression has no condition.  The generator emits NULL here when the
* pass has no expression conditions at all.
*/
const nir_search_expression_cond *expression_cond;
} nir_algebraic_table;
/* Note: these must match the start states created in

View File

@ -26,7 +26,7 @@ import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from nir_algebraic import SearchAndReplace
from nir_algebraic import SearchAndReplace, AlgebraicPass
# These tests check that the bitsize validator correctly rejects various
# different kinds of malformed expressions, and documents what the error
@ -40,9 +40,11 @@ class ValidatorTests(unittest.TestCase):
pattern = ()
message = ''
algebraic_pass = AlgebraicPass("test", [])
def common(self, pattern, message):
with self.assertRaises(AssertionError) as context:
SearchAndReplace(pattern)
SearchAndReplace(pattern, self.algebraic_pass)
self.assertEqual(message, str(context.exception))