nir/opcodes: Pull in the type helpers from constant_expressions
While we're at it, we rework them a bit to all use regular expressions and assert more. Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
This commit is contained in:
parent
a0ae12ca91
commit
03571a7a6c
|
@ -1,23 +1,8 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
|
||||
type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
|
||||
|
||||
def type_has_size(type_):
|
||||
return type_[-1:].isdigit()
|
||||
|
||||
def type_size(type_):
|
||||
assert type_has_size(type_)
|
||||
return int(type_split_re.match(type_).group('bits'))
|
||||
|
||||
def type_sizes(type_):
|
||||
if type_has_size(type_):
|
||||
return [type_size(type_)]
|
||||
elif type_ == 'float':
|
||||
return [16, 32, 64]
|
||||
else:
|
||||
return [8, 16, 32, 64]
|
||||
from nir_opcodes import opcodes
|
||||
from nir_opcodes import type_has_size, type_size, type_sizes, type_base_type
|
||||
|
||||
def type_add_size(type_, size):
|
||||
if type_has_size(type_):
|
||||
|
@ -44,10 +29,7 @@ def get_const_field(type_):
|
|||
elif type_ == "float16":
|
||||
return "u16"
|
||||
else:
|
||||
m = type_split_re.match(type_)
|
||||
if not m:
|
||||
raise Exception(str(type_))
|
||||
return m.group('type')[0] + m.group('bits')
|
||||
return type_base_type(type_)[0] + str(type_size(type_))
|
||||
|
||||
template = """\
|
||||
/*
|
||||
|
@ -429,7 +411,6 @@ nir_eval_const_opcode(nir_op op, unsigned num_components,
|
|||
}
|
||||
}"""
|
||||
|
||||
from nir_opcodes import opcodes
|
||||
from mako.template import Template
|
||||
|
||||
print(Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
# Authors:
|
||||
# Connor Abbott (cwabbott0@gmail.com)
|
||||
|
||||
import re
|
||||
|
||||
# Class that represents all the information we have about the opcode
|
||||
# NOTE: this must be kept in sync with nir_op_info
|
||||
|
@ -99,6 +100,33 @@ tint64 = "int64"
|
|||
tuint64 = "uint64"
|
||||
tfloat64 = "float64"
|
||||
|
||||
_TYPE_SPLIT_RE = re.compile(r'(?P<type>int|uint|float|bool)(?P<bits>\d+)?')
|
||||
|
||||
def type_has_size(type_):
|
||||
m = _TYPE_SPLIT_RE.match(type_)
|
||||
assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
|
||||
return m.group('bits') is not None
|
||||
|
||||
def type_size(type_):
|
||||
m = _TYPE_SPLIT_RE.match(type_)
|
||||
assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
|
||||
assert m.group('bits') is not None, \
|
||||
'NIR type string has no bit size: "{}"'.format(type_)
|
||||
return int(m.group('bits'))
|
||||
|
||||
def type_sizes(type_):
|
||||
if type_has_size(type_):
|
||||
return [type_size(type_)]
|
||||
elif type_ == 'float':
|
||||
return [16, 32, 64]
|
||||
else:
|
||||
return [8, 16, 32, 64]
|
||||
|
||||
def type_base_type(type_):
|
||||
m = _TYPE_SPLIT_RE.match(type_)
|
||||
assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
|
||||
return m.group('type')
|
||||
|
||||
commutative = "commutative "
|
||||
associative = "associative "
|
||||
|
||||
|
@ -175,11 +203,7 @@ for src_t in [tint, tuint, tfloat]:
|
|||
dst_types = [tint, tuint, tfloat]
|
||||
|
||||
for dst_t in dst_types:
|
||||
if dst_t == tfloat:
|
||||
bit_sizes = [16, 32, 64]
|
||||
else:
|
||||
bit_sizes = [8, 16, 32, 64]
|
||||
for bit_size in bit_sizes:
|
||||
for bit_size in type_sizes(dst_t):
|
||||
if bit_size == 16 and dst_t == tfloat and src_t == tfloat:
|
||||
rnd_modes = ['_rtne', '_rtz', '']
|
||||
for rnd_mode in rnd_modes:
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
from __future__ import print_function
|
||||
|
||||
from nir_opcodes import opcodes
|
||||
from nir_opcodes import opcodes, type_sizes
|
||||
from mako.template import Template
|
||||
|
||||
template = Template("""
|
||||
|
@ -64,12 +64,7 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd
|
|||
% endif
|
||||
% endif
|
||||
switch (dst_bit_size) {
|
||||
% if dst_t == 'float':
|
||||
<% bit_sizes = [16, 32, 64] %>
|
||||
% else:
|
||||
<% bit_sizes = [8, 16, 32, 64] %>
|
||||
% endif
|
||||
% for dst_bits in bit_sizes:
|
||||
% for dst_bits in type_sizes(dst_t):
|
||||
case ${dst_bits}:
|
||||
% if src_t == 'float' and dst_t == 'float' and dst_bits == 16:
|
||||
switch(rnd) {
|
||||
|
@ -137,4 +132,4 @@ const nir_op_info nir_op_infos[nir_num_opcodes] = {
|
|||
};
|
||||
""")
|
||||
|
||||
print(template.render(opcodes=opcodes))
|
||||
print(template.render(opcodes=opcodes, type_sizes=type_sizes))
|
||||
|
|
Loading…
Reference in New Issue