# FeatherMC/src/protocol/generate_protocol.py (578 lines, 18 KiB, Python)

import hjson
import sys
# Accumulated generated output; main() writes it to the output file.
text = ''
# Size added to a message's MaxSize for each VarInt header field
# (packet length and packet id) in print_messages().
varint_max_size = 5
# Current nesting depth; add_text() emits one indentation unit per level.
indent_count = 0
# Maps a field name to its enum type name for the message being generated
# (filled in by print_rw_logic()); print_messages() resets it before each
# message is processed.
enum_vars = {}
# Print to stderr
def warn(value, *args, sep='', end='\n', flush=False):
    """Write a diagnostic message to standard error.

    Takes the same keyword arguments as the built-in ``print`` except for
    the output stream, which is always ``sys.stderr``. Note that ``sep``
    defaults to the empty string, so positional arguments are concatenated
    without spaces.
    """
    parts = (value,) + args
    print(*parts, sep=sep, end=end, flush=flush, file=sys.stderr)
def indent():
    """Increase the indentation depth applied to subsequent output lines."""
    global indent_count
    indent_count = indent_count + 1
def unindent():
    """Decrease the indentation depth applied to subsequent output lines."""
    global indent_count
    indent_count = indent_count - 1
def add_text(fmt, *args):
    """Append one indented line to the generated output.

    ``fmt`` is always passed through ``str.format`` (even when no ``args``
    are given), so literal braces in the emitted text must be written as
    ``{{`` and ``}}``.
    """
    global text
    # One indentation unit per nesting level.
    # NOTE(review): the unit is a single space as written here; confirm this
    # is the intended width and not the result of collapsed whitespace.
    padding = ' ' * indent_count
    text += padding
    text += (fmt + '\n').format(*args)
def newline():
    """Append an empty line (no indentation) to the generated output."""
    global text
    text = text + '\n'
# returns None if this is not a .switch(cond) {} block
def get_switch_condition(name):
    """Extract ``cond`` from a key of the form ``.switch(cond)``.

    The key must start with ``.switch(``; the condition ends at the first
    closing parenthesis. Returns ``None`` when the key does not have this
    form or the parenthesis is never closed.
    """
    marker = '.switch('
    if not name.startswith(marker):
        return None
    close = name.find(')')
    if close <= 0:
        return None
    return name[len(marker):close]
def extract_array_type(type):
    """Return the element type of an array type string.

    ``'int[5]'`` becomes ``'int'``: everything from the last ``[`` onwards
    is dropped. Strings that do not contain both ``[`` and ``]`` are
    returned unchanged.
    """
    looks_like_array = '[' in type and ']' in type
    return type[:type.rfind('[')] if looks_like_array else type
def extract_array_count(value, name=''):
    """Return the array length declared for a field, or ``None``.

    Two syntaxes are recognised:

    * scalar fields carry the count in the type string: ``key: type[count]``
    * object fields carry it in the key: ``key(#count): {...}``

    Args:
        value: The field's type string, or a dict for an object field.
        name: The field's key; required when ``value`` is a dict.

    Returns:
        An ``int`` for a numeric count, a ``str`` for any other count
        (e.g. the name of another field), or ``None`` when the field is
        not an array.

    Raises:
        Exception: If ``value`` is a dict and ``name`` is empty or ``None``.
    """
    if isinstance(value, dict):
        if name == '' or name is None:
            raise Exception('extract_array_count called with dict value but name not provided')
        idx = name.find('(')
        if idx <= 0:
            return None
        idxend = name.find(')')
        if idxend <= idx:
            warn('WARNING: Var name "{}" has unterminated () brackets'.format(name))
            return None
        attribute = name[idx + 1:idxend]
        # array of objects key(#count): {...}
        if not attribute.startswith('#'):
            # Some other attribute (e.g. a condition). Stop here: the
            # bracket parsing below only applies to type strings.
            return None
        count = attribute[1:]
    elif '[' in value and ']' in value:
        # standard array syntax key: type[count]
        count = value[value.find('[') + 1:value.rfind(']')]
    else:
        return None

    return int(count) if count.isdigit() else count
def get_type_info(types, typename):
    """Look up the description of a type, resolving aliases.

    Any array suffix on ``typename`` is ignored for the lookup.

    Args:
        types: Mapping of type name to its description dict.
        typename: Type string as written in the schema (may be ``T[n]``).

    Returns:
        A new dict describing the type. It always contains a ``'type'``
        entry: the description's own value if it has one, otherwise the
        looked-up name. For aliases, ``'generic'`` and ``'method'`` are
        taken from the aliased type, and ``'size'`` is inherited when the
        alias does not define its own. Unknown types produce a warning and
        a dict holding only ``'type'``.
    """
    name = extract_array_type(typename)
    if name not in types:
        warn('WARNING: Type name "{}" is not a known type or alias.'.format(name))
        return {'type': name}

    # Always work on a copy so the caller's type table is never modified.
    tp = types[name].copy()
    if 'alias' in tp:
        realtype = get_type_info(types, tp['alias'])
        # copy certain properties from real type
        for key in ('generic', 'method'):
            if key in realtype:
                tp[key] = realtype[key]
        # use real type's size if we have none
        if 'size' in realtype and 'size' not in tp:
            tp['size'] = realtype['size']

    # types without a 'type' value use their own name (e.g. double, float)
    if 'type' not in tp:
        tp['type'] = name
    return tp
def resolve_name(types, name, value, decl=False):
    """Turn a schema key into the identifier used in the generated code.

    Anything from the first ``(`` onwards (attributes such as ``(#count)``
    or ``(cond?)``) is stripped. The process exits with status 1 if the
    key starts with ``(``.

    Args:
        types: Mapping of type name to its description dict.
        name: The field's key as written in the schema.
        value: The field's type string, or a dict for an object field.
        decl: When true, decorate the identifier for a declaration:
            ``name[N]`` for a fixed-size array, ``*name`` for an array
            whose count is not an integer literal.

    Returns:
        The (optionally decorated) identifier.
    """
    fullname = name
    idx = name.find('(')
    if idx == 0:
        warn('ERROR: Invalid variable name "{}"'.format(name))
        sys.exit(1)
    if idx > 0:
        name = name[:idx]

    # standard array syntax key: type[count]
    if decl and is_type_array(types, fullname, value):
        count = extract_array_count(value, fullname)
        if isinstance(count, int):
            # array of fixed size
            name += '[' + str(count) + ']'
        else:
            # array of variable size
            name = '*' + name
    return name
def resolve_type(types, value):
    """Return the type name to emit for a scalar field's type string.

    Uses the ``'type'`` entry of the looked-up type description when there
    is one, and falls back to ``value`` itself otherwise. Object (dict)
    values are not supported: a warning is printed and an empty string is
    returned.
    """
    if not isinstance(value, dict):
        info = get_type_info(types, value)
        return info['type'] if 'type' in info else value

    warn('WARNING: Attempted to call resolve_type on object type')
    return ""
def is_type_array(types, name, value):
    """Tell whether a field is an array.

    Object fields are arrays when their key carries a ``(#count)``
    attribute. Scalar fields are arrays when their type string carries a
    ``[count]`` suffix, unless the type's ``'size'`` is ``'count'``, in
    which case the bracket value is not treated as an array length.
    """
    if isinstance(value, dict):
        count = extract_array_count(value, name)
    else:
        info = get_type_info(types, value)
        if 'size' in info and info['size'] == 'count':
            return False
        count = extract_array_count(value)
    return count is not None
def get_type_size(types, value, name=''):
    """Return the size a field contributes to a message's ``MaxSize``.

    Args:
        types: Mapping of type name to its description dict.
        value: The field's type string, or a dict for an object field.
        name: The field's key; required when ``value`` is a dict.

    Returns:
        * Object fields: the sum of all member sizes, multiplied by the
          count for fixed-size object arrays.
        * Types whose ``'size'`` is ``'count'``: the bracket value from the
          type string, or 1 when that value is not an integer.
        * Other scalar types: the type's ``'size'``, multiplied by the
          count for fixed-size arrays.

        Arrays with a non-literal count are sized as a single element.

    Raises:
        Exception: If ``value`` is a dict and ``name`` is empty or ``None``.

    Exits the process with status 1 when no size can be determined.
    """
    if isinstance(value, dict):
        if name == '' or name is None:
            raise Exception('get_type_size called on dict value with no name')
        size = 0
        for k, v in value.items():
            size += get_type_size(types, v, k)
        if is_type_array(types, name, value):
            count = extract_array_count(value, name)
            if isinstance(count, int):
                return size * count
            # TODO: handling size for variable arrays...
        return size

    typename = value
    tp = get_type_info(types, typename)
    if 'size' not in tp:
        if 'alias' not in tp:
            warn('ERROR: Non-alias type "{}" does not have "size" field.'.format(typename))
            sys.exit(1)
        size = get_type_size(types, tp['alias'])
    else:
        size = tp['size']
        if size == 'count':
            # the bracket value is the size itself, not an element count
            size = extract_array_count(typename)
            if not isinstance(size, int):
                # TODO: how do we handle type size for variably sized arrays...
                return 1
            if not size:
                warn('ERROR: Cannot get size for type "{}"'.format(typename))
                sys.exit(1)
            return size
        if not size:
            warn('ERROR: Cannot get size for type "{}"'.format(typename))
            sys.exit(1)
        size = int(size)

    # A fixed-size array occupies one element size per entry, matching the
    # object-array handling above.
    count = extract_array_count(typename)
    if isinstance(count, int):
        return size * count
    return size
def get_rw_func(types, typename, read):
    """Build the name of the reader/writer method for a type.

    The name starts with ``Read`` or ``Write`` and is followed by the
    type's ``'method'`` value when it has one. A ``<type>`` template
    argument is appended when the type has no ``'method'``, or when it is
    an alias of a type marked ``'generic'``.
    """
    info = get_type_info(types, typename)
    has_method = 'method' in info

    pieces = ['Read' if read else 'Write']
    if has_method:
        pieces.append(info['method'])

    is_generic = 'generic' in info and info['generic']
    is_alias = 'alias' in info
    if (is_generic and is_alias) or not has_method:
        pieces.append('<' + info['type'] + '>')
    return ''.join(pieces)
def _count_expression(prefix, count):
    """Return the C++ expression for an array length.

    A fixed length is emitted as a literal. Any other count is emitted as
    ``prefix + count``, i.e. qualified with the same path as the field it
    belongs to.
    """
    if isinstance(count, int):
        return str(count)
    return f'{prefix}{count}'


def print_rw_logic(types, value, name, read, prefix=''):
    """Emit the C++ statements that read or write one field.

    Object fields (dict values) are expanded recursively: plain objects
    emit a braced block, ``key(#count)`` arrays emit a ``for`` loop over
    that block, and ``.switch(cond)`` blocks emit a ``switch`` with one
    braced ``case`` per entry. Scalar fields emit a single read/write
    call, wrapped in a loop for arrays and in an ``if`` block for
    conditional fields (``key(cond?)``).

    Args:
        types: Mapping of type name to its description dict.
        value: The field's type string, or a dict for an object field.
        name: The field's key as written in the schema.
        read: True to emit ``reader.Read...`` code, False to emit
            ``msg.Write...`` code.
        prefix: Access path prepended to the field name, to conditions and
            to non-literal array counts (e.g. ``'outer.'``).

    Exits the process with status 1 on invalid key syntax.
    """
    global enum_vars
    resolved_name = prefix + resolve_name(types, name, value)

    if isinstance(value, dict):
        is_array = is_type_array(types, name, value)
        if is_array:
            count_expr = _count_expression(prefix, extract_array_count(value, name))
            add_text(f'for (int i = 0; i < {count_expr}; i++)')

        switch = get_switch_condition(name)
        is_switch = switch is not None
        case_prefix = ''
        if is_switch:
            add_text(f'switch ({switch})')
            # qualify case labels when the switch variable is a known enum
            if switch in enum_vars:
                case_prefix = enum_vars[switch] + '::'

        # invalid: .switch(#value) {}
        if is_array and is_switch:
            warn(f'ERROR: Invalid syntax, ambiguity between .switch statement and object(#count) array: "{name}"')
            sys.exit(1)

        add_text('{{')
        indent()
        for k, v in value.items():
            if not is_switch:
                scoped_name = resolved_name
                if is_array:
                    scoped_name += '[i]'
                scoped_name += '.'
                print_rw_logic(types, v, k, read, prefix=scoped_name)
            else:
                # switch cases
                add_text(f'case ({case_prefix}{k}):')
                add_text('{{')
                indent()
                if isinstance(v, dict):  # print_definition() already warns if this is false
                    if len(v) > 0:
                        prefix_case = f'{prefix}{k}.'
                        for k1, v1 in v.items():
                            print_rw_logic(types, v1, k1, read, prefix=prefix_case)
                    else:
                        add_text('// no other fields')
                add_text('break;')
                unindent()
                add_text('}}')
        unindent()
        add_text('}}')
        return

    typename = value
    tp = get_type_info(types, typename)
    if 'enum' in tp and tp['enum']:
        # record any variables that are enums for this message.
        # this is used to resolve scope in switch-case blocks
        enum_vars[name] = typename

    conditional = False
    start = name.find('(')
    if start != -1:
        end = name.rfind(')')
        if end == -1 or end < start:
            warn('ERROR: Invalid variable name syntax "{}"'.format(name))
            sys.exit(1)
        # conditional
        if name[end - 1] == '?':
            cond = name[start + 1:end - 1]
            add_text(f'if ({prefix}{cond})')
            # Braces keep every statement emitted below (allocation and
            # loop) under the condition, not just the first one.
            add_text('{{')
            indent()
            conditional = True

    rw_func = get_rw_func(types, typename, read)
    if is_type_array(types, name, typename):
        # arrays
        count = extract_array_count(typename)
        count_expr = _count_expression(prefix, count)
        # if the array is variable size then we currently have to allocate
        if read and not isinstance(count, int):
            # Allocate the same element type that print_definition()
            # declares for the field.
            if 'type' in tp:
                element_type = tp['type']
            else:
                element_type = extract_array_type(typename)
            # TODO: avoid dynamic allocation here?
            # TODO: delete[] this on deconstruct
            add_text('{} = new {}[{}];', resolved_name, element_type, count_expr)
        add_text(f'for (int i = 0; i < {count_expr}; i++)')
        indent()
        if read:
            add_text('{}[i] = reader.{}();', resolved_name, rw_func)
        else:
            # TODO: bounds checks
            add_text('msg.{}({}[i]);', rw_func, resolved_name)
        unindent()
    elif read:
        add_text('{} = reader.{}();', resolved_name, rw_func)
    else:
        add_text('msg.{}({});', rw_func, resolved_name)

    if conditional:
        unindent()
        add_text('}}')
def print_definition(types, name, value, serverbound):
    """Emit the C++ member declaration for one field.

    Scalar fields become ``<type> <name>;``. Object fields become an
    anonymous ``struct`` named after the field. ``.switch(cond)`` blocks
    become an anonymous ``union`` holding one named ``struct`` per case;
    cases whose value is not an object are skipped with a warning.

    ``serverbound`` is only forwarded to the recursive calls.
    """
    member_name = resolve_name(types, name, value, decl=True)

    # plain field
    if not isinstance(value, dict):
        member_type = resolve_type(types, value)
        add_text('{} {};', member_type, member_name)
        return

    condition = get_switch_condition(name)
    if condition is None:
        # object definition
        add_text('struct {{')
        indent()
        for field_name, field_value in value.items():
            print_definition(types, field_name, field_value, serverbound)
        unindent()
        add_text('}} {};', member_name)
        return

    # .switch(cond) {} block
    newline()
    add_text(f'union {{{{ // switch ({condition})')
    indent()
    for case_name, case_fields in value.items():
        if not isinstance(case_fields, dict):
            warn(f'WARNING: Switch case statement "{case_name}" (in switch block "{condition}") of non-object type.')
            continue
        # switch case block
        newline()
        add_text('struct {{')
        indent()
        if not case_fields:
            add_text('// no other fields')
        for field_name, field_value in case_fields.items():
            print_definition(types, field_name, field_value, serverbound)
        unindent()
        add_text(f'}}}} {case_name};')
    unindent()
    add_text('}};')
def get_enum_type(dict, fallback = 'int32'):
    """Return the enum's ``'.type'`` metadata entry, or ``fallback`` if absent."""
    return dict.get('.type', fallback)
def print_enum(name, dict, types=None, fallback = None):
    """Emit a C++ ``enum class`` definition.

    Args:
        name: Name of the generated enum.
        dict: Mapping of enumerator name to value. Keys starting with
            ``.`` are metadata and are not emitted; ``'.type'`` selects
            the underlying type.
        types: Mapping of type name to its description dict, used to
            resolve the underlying type. Defaults to an empty table.
        fallback: Underlying type name to use when ``dict`` has no
            ``'.type'``. When this is also ``None``, ``int32_t`` is used.
    """
    # A fresh dict per call avoids sharing one default object between calls.
    if types is None:
        types = {}

    primType = get_enum_type(dict, fallback)
    if primType is None:
        # if we have nothing to resolve with and no fallback then use int32_t
        primitive = 'int32_t'
    else:
        primitive = resolve_type(types, primType)

    add_text('enum class {} : {}', name, primitive)
    add_text('{{')
    indent()
    for key, value in dict.items():
        if key.startswith('.'):
            continue  # ignore metadata
        add_text('{} = {},', key, value)
    unindent()
    add_text('}};')
def print_messages(list, globalTypes):
    """Emit one namespace per protocol state containing its message structs.

    ``list`` maps state name -> direction name -> message name -> message
    description. Each message becomes a struct named
    ``<Direction><MessageName>`` with its packet constants, its enums, a
    ``PacketReader`` constructor (serverbound) or a ``NetworkMessage``
    conversion operator (any other direction), and its field declarations.

    ``globalTypes`` is the shared type table; each message works on a copy
    extended with that message's own enums.
    """
    global enum_vars
    for state, directions in list.items():
        state_name = state.capitalize()
        add_text('namespace {}', state_name)
        add_text('{{')
        indent()
        for direction, messages in directions.items():
            serverbound = direction == 'serverbound'
            for message_name, message in messages.items():
                # clear out enum vars from previous message
                enum_vars = {}

                # global and local types
                types = globalTypes.copy()
                has_enums = 'enums' in message
                # add enums to local types
                if has_enums:
                    for enum_name in message['enums']:
                        enum_def = message['enums'][enum_name]
                        types[enum_name] = {'alias': get_enum_type(enum_def), 'enum': True}

                fields = message['vars']
                size = varint_max_size  # Packet Length
                size += varint_max_size  # Packet Id
                for field_name, field_type in fields.items():
                    size += get_type_size(types, field_type, field_name)

                struct_name = '{}{}'.format(direction.capitalize(), message_name)
                add_text('struct {}', struct_name)
                add_text('{{')
                indent()
                add_text('static constexpr int32_t PacketId = {};', message['id'])
                add_text('static constexpr bool Serverbound = {};', 'true' if serverbound else 'false')
                add_text('static constexpr ProtocolState PacketState = ProtocolState::{};', state_name)
                add_text('static constexpr size_t MaxSize = {};', size)
                newline()

                if has_enums:
                    for enum_name in message['enums']:
                        print_enum(enum_name, message['enums'][enum_name], types)
                        # NOTE(review): assumed to run once per enum; the
                        # original nesting of this call should be confirmed.
                        newline()

                if serverbound:
                    # constructor that reads every field from the packet
                    add_text('{}(PacketReader& reader)', struct_name)
                    add_text('{{')
                    indent()
                    for field_name, field_type in fields.items():
                        print_rw_logic(types, field_type, field_name, True)
                    unindent()
                    add_text('}}')
                    newline()
                else:
                    # conversion operator that writes every field
                    add_text('operator NetworkMessage() const')
                    add_text('{{')
                    indent()
                    add_text('NetworkMessage msg(MaxSize);')
                    add_text('msg.WriteVarInt(PacketId);')
                    for field_name, field_type in fields.items():
                        print_rw_logic(types, field_type, field_name, False)
                    add_text('msg.Finalize();')
                    add_text('return msg;')
                    unindent()
                    add_text('}}')

                for field_name, field_type in fields.items():
                    print_definition(types, field_name, field_type, serverbound)
                unindent()
                add_text('}};')
                newline()
        unindent()
        add_text('}}')
        # NOTE(review): assumed to run once per state; the original nesting
        # of this call should be confirmed.
        newline()
def print_handler(list):
    """Emit the ``ProcessPacket`` dispatch function template.

    The generated function first handles legacy pings, then reads the
    packet id and switches on the protocol state and the packet id to call
    ``handler.HandlePacket<T>`` for the matching serverbound message.
    Unmatched ids go to ``handler.HandleUnknownPacket``. Messages of any
    direction other than ``serverbound`` are ignored.

    ``list`` maps state name -> direction name -> message name -> message.
    """
    def open_block():
        add_text('{{')
        indent()

    def close_block():
        unindent()
        add_text('}}')

    add_text('template <typename HandlerType, typename ClientType>')
    add_text('void ProcessPacket(HandlerType& handler, ClientType& client, PacketReader& packet, ProtocolState state)')
    open_block()

    add_text('if (packet.IsLegacyPing())')
    open_block()
    add_text('handler.HandleLegacyPing(client);')
    add_text('return;')
    close_block()
    newline()

    add_text('const int32_t packetId = packet.ReadVarInt();')
    add_text('switch (state)')
    open_block()
    for state, directions in list.items():
        state_name = state.capitalize()
        add_text('case ProtocolState::{}:', state_name)
        open_block()
        add_text('switch (packetId)')
        open_block()
        for direction, messages in directions.items():
            if direction != 'serverbound':
                continue
            for message_name in messages:
                qualified = '{}::Serverbound{}'.format(state_name, message_name)
                add_text('case {}::PacketId:', qualified)
                indent()
                # Here we need the obscure `template` disambiguator for dependent names
                add_text('handler.template HandlePacket<{}>(client, {}(packet));', qualified, qualified)
                add_text('break;')
                unindent()

        # handle unknown packets
        add_text('default:')
        indent()
        add_text('handler.HandleUnknownPacket(client, packetId, packet);')
        add_text('break;')
        unindent()

        close_block()  # switch (packetId)
        add_text('break;')
        close_block()  # case ProtocolState::...
    close_block()  # switch (state)
    close_block()  # ProcessPacket
def print_protocol():
    """Load the protocol description and emit all generated definitions.

    Reads the hjson file named by ``sys.argv[1]``, then emits the
    ``ProtocolState`` enum (from ``'states'``), the message structs (from
    ``'messages'`` and ``'types'``) and the packet dispatch function.
    """
    input_path = sys.argv[1]
    with open(input_path) as message_file:
        message_scheme = hjson.load(message_file)

    print_enum('ProtocolState', message_scheme['states'])
    newline()

    print_messages(message_scheme['messages'], message_scheme['types'])
    newline()

    print_handler(message_scheme['messages'])
def main():
    """Generate the protocol header.

    Expects exactly two command-line arguments: the input hjson file and
    the output header path. Prints a usage message and returns when the
    argument count is wrong.
    """
    if len(sys.argv) != 3:
        print('Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>')
        return

    add_text('#pragma once')
    newline()

    include_lines = (
        '#include <cstdint>',
        '#include "Types.h"',
        '#include "PacketReader.h"',
        '#include "NetworkMessage.h"',
        '#include <string_view>',
    )
    for line in include_lines:
        add_text(line)
    newline()

    add_text('namespace Feather::Protocol')
    add_text('{{')
    indent()
    print_protocol()
    unindent()
    add_text('}}')
    newline()

    # the with-statement closes the file on exit
    output_path = sys.argv[2]
    with open(output_path, 'w') as out_file:
        out_file.write(text)
main()