339 lines
10 KiB
Python
339 lines
10 KiB
Python
import hjson
|
|
import sys
|
|
|
|
# Indentation depth (in levels) that add_text() prefixes to every line.
indent_count: int = 0

# Accumulated generated C++ source; add_text() and newline() append to it
# and main() writes it to the output file.
text: str = ''

# Bytes counted for each VarInt (packet length and packet id) when a
# message's MaxSize is computed.
varint_max_size: int = 5
|
|
|
|
# Print to stderr
|
|
def warn(value, *args, sep='', end='\n', flush=False):
|
|
print(value, *args, sep=sep, end=end, file=sys.stderr, flush=flush)
|
|
|
|
def indent():
    """Increase the indentation depth applied by add_text() by one level."""
    global indent_count
    indent_count = indent_count + 1
|
|
|
|
def unindent():
    """Decrease the indentation depth applied by add_text() by one level."""
    global indent_count
    indent_count = indent_count - 1
|
|
|
|
def add_text(fmt, *args):
    """Append one formatted, indented line to the output buffer ``text``.

    ``fmt`` is a ``str.format`` template filled with ``args``; literal braces
    in the emitted C++ must therefore be doubled (``'{{'`` emits ``{``).
    The line is prefixed with one ``' '`` per current indentation level and
    terminated by a newline.
    """
    global text
    # indent_count is only read here, so it needs no ``global`` declaration;
    # string repetition replaces the per-character concatenation loop.
    text += ' ' * indent_count + (fmt + '\n').format(*args)
|
|
|
|
def newline():
    """Append an empty line (with no indentation) to the output buffer."""
    global text
    text = text + '\n'
|
|
|
|
def extract_array_type(type):
    """Return ``type`` with everything from its last '[' onward removed.

    Only applies when both '[' and ']' occur in the string; otherwise the
    input is returned unchanged. For example ``'byte[16]'`` -> ``'byte'``
    and ``'int32'`` -> ``'int32'``.
    """
    if '[' not in type or ']' not in type:
        return type
    return type[:type.rfind('[')]
|
|
|
|
def extract_array_count(type):
    """Return the element count ``N`` of a ``name[N]`` type string.

    Returns ``None`` when ``type`` lacks '[' or ']'. The text between the
    first '[' and the last ']' is converted with ``int()``, so a malformed
    count raises ``ValueError``.
    """
    if not ('[' in type and ']' in type):
        return None
    opening = type.find('[')
    closing = type.rfind(']')
    return int(type[opening + 1:closing])
|
|
|
|
def get_type_info(types, typename, aliases=None):
    """Return the schema entry for ``typename`` with aliases resolved.

    Any trailing array suffix (``name[N]``) is ignored for the lookup.

    If ``aliases`` (local name -> schema type name) contains the name, the
    aliased schema type is looked up. The result then reports the local name
    as ``'type'`` and the aliased schema type name as ``'alias'``.

    Entries in ``types`` that carry an ``'alias'`` key are followed
    recursively. When the aliased target is generic, or has no ``'method'``,
    ``'type'`` becomes this entry's own name; otherwise the target's
    ``'type'`` is kept. Alias cycles are not detected.

    The result is always a fresh dict containing a ``'type'`` key, and
    ``types`` itself is never modified. An unknown name produces a warning
    on stderr and the result ``{'type': name}``.
    """
    name = extract_array_type(typename)

    alias_name = None
    if aliases and name in aliases:
        alias_name = name
        name = aliases[name]

    if name not in types:
        warn('WARNING: Type name "{}" is not a known type or alias.'.format(name))
        # Key by the string 'type' (not the builtin ``type`` object) so that
        # callers checking ``'type' in tp`` / ``tp['type']`` find it.
        return {'type': name}

    # Work on a copy: the entries in ``types`` may be shared between callers,
    # so filling in 'type'/'alias' below must not write back into them.
    tp = types[name].copy()

    if 'alias' in tp:
        alias = tp['alias']
        # The recursive call already returns a fresh dict.
        tp = get_type_info(types, alias)
        tp['alias'] = alias
        generic = 'generic' in tp and tp['generic'] == True
        # primitive types are always generic
        if generic or 'method' not in tp:
            tp['type'] = name

    if alias_name:
        tp['alias'] = name
        tp['type'] = alias_name

    # types without a 'type' value use their own name (e.g. double, float)
    if 'type' not in tp:
        tp['type'] = name

    return tp
|
|
|
|
def resolve_type(types, typename, aliases={}):
    """Return the C++ type name that ``typename`` maps to.

    Uses the ``'type'`` entry of get_type_info()'s result, falling back to
    ``typename`` itself when that entry is absent.
    """
    info = get_type_info(types, typename, aliases)
    return info.get('type', typename)
|
|
|
|
def get_type_size(types, typename, aliases=None):
    """Return the size that ``typename`` contributes to a message's MaxSize.

    The value comes from the resolved type's ``'size'`` entry. If the entry
    is absent but the type is an alias, the aliased type's size is used.
    A size of ``'count'`` means the size is taken from the array suffix of
    ``typename`` (e.g. ``name[16]`` -> 16).

    Prints an error to stderr and exits with status 1 if no usable size
    can be determined.
    """
    tp = get_type_info(types, typename, aliases)

    if 'size' not in tp:
        if 'alias' in tp:
            return get_type_size(types, tp['alias'])
        warn('ERROR: Non-alias type "{}" does not have "size" field.'.format(typename))
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the ``site`` module for interactive use and may be missing.
        sys.exit(1)

    size = tp['size']

    if size == 'count':
        size = extract_array_count(typename)

    if not size:
        warn('ERROR: Cannot get size for type "{}"'.format(typename))
        sys.exit(1)

    return int(size)
|
|
|
|
def get_rw_func(types, typename, read, aliases={}):
    """Build the reader/writer method name used for a field of ``typename``.

    The name is ``Read`` or ``Write`` (depending on ``read``), followed by
    the type's ``'method'`` suffix when it has one. Types without a
    ``'method'``, and generic types that are aliases, additionally get an
    explicit template argument, e.g. ``Read<int32_t>``.
    """
    info = get_type_info(types, typename, aliases)
    has_method = 'method' in info

    func = 'Read' if read else 'Write'
    if has_method:
        func += info['method']

    is_generic_alias = bool(info.get('generic')) and 'alias' in info
    if is_generic_alias or not has_method:
        func += '<' + info['type'] + '>'

    return func
|
|
|
|
def get_enum_type(dict, fallback='int32'):
    """Return the enum's declared underlying type (its '.type' entry), or ``fallback``."""
    return dict.get('.type', fallback)
|
|
|
|
def print_enum(name, dict, types={}, fallback=None):
    """Emit a C++ ``enum class`` named ``name`` into the output buffer.

    Each key of ``dict`` becomes an enumerator initialised with its value.
    Keys starting with '.' are metadata (such as '.type') and are skipped.

    The underlying type is the enum's '.type' entry (or ``fallback``),
    resolved through ``types``. When neither is available, ``int32_t``
    is used.
    """
    prim_type = get_enum_type(dict, fallback)
    underlying = 'int32_t' if prim_type is None else resolve_type(types, prim_type)

    add_text('enum class {} : {}', name, underlying)
    add_text('{{')
    indent()
    members = ((key, value) for key, value in dict.items() if not key.startswith('.'))
    for key, value in members:
        add_text('{} = {},', key, value)
    unindent()
    add_text('}};')
|
|
|
|
def print_messages(list, globalTypes):
    """Emit one C++ namespace per protocol state, with one struct per message.

    ``list`` maps state name -> direction name -> message name -> message
    definition. Each definition is read for 'id' and 'vars' (field name ->
    type name), plus optional 'aliases' and 'enums'. ``globalTypes`` is the
    schema-wide type table handed to the type helpers.

    Each struct gets the following members:
    - PacketId, Serverbound, PacketState and MaxSize constants;
    - any per-message enums;
    - either a PacketReader constructor (serverbound messages) or a
      NetworkMessage conversion operator (all other directions);
    - the field declarations.
    """
    global text
    for state, direction_list in list.items():
        # Each protocol state gets its own namespace, named with state.capitalize().
        add_text('namespace {}', state.capitalize())
        add_text('{{')
        indent()
        for direction, messages in direction_list.items():
            serverbound = direction == 'serverbound'
            for message_name, message in messages.items():

                # global and local types
                # (a shallow copy, so the per-message enum entries added
                # below do not leak into globalTypes or other messages)
                types = globalTypes.copy()

                # add any local aliases
                aliases = {}
                if 'aliases' in message:
                    for alias in message['aliases']:
                        aliases[alias] = message['aliases'][alias]

                # add enums to local types
                # (an enum that is not also a local alias is registered as an
                # alias of its '.type', falling back to 'int32')
                if 'enums' in message:
                    for enum in message['enums']:
                        if not enum in aliases:
                            types[enum] = {'alias': get_enum_type(message['enums'][enum])}

                # MaxSize = one varint_max_size each for the packet length and
                # the packet id, plus get_type_size() of every field.
                global varint_max_size
                size = varint_max_size # Packet Length
                size += varint_max_size # Packet Id
                for name, typename in message['vars'].items():
                    size += get_type_size(types, typename, aliases)

                # Struct name is the capitalized direction followed by the message name.
                struct_name = '{}{}'.format(direction.capitalize(), message_name)
                add_text('struct {}', struct_name)
                add_text('{{')
                indent()
                add_text('static constexpr int32_t PacketId = {};', message['id'])
                add_text('static constexpr bool Serverbound = {};', 'true' if serverbound else 'false')
                add_text('static constexpr ProtocolState PacketState = ProtocolState::{};', state.capitalize())
                add_text('static constexpr size_t MaxSize = {};', size)
                newline()
                if 'enums' in message:
                    for enum in message['enums']:
                        if enum in aliases:
                            # lookup enum primitive type from aliases
                            prim = resolve_type(types, aliases[enum])
                            print_enum(enum, message['enums'][enum], types, prim)
                        else:
                            print_enum(enum, message['enums'][enum], types)
                        newline()
                if serverbound:
                    # Serverbound: a constructor that reads every field from
                    # the PacketReader, in declaration order.
                    add_text('{}(PacketReader& reader)', struct_name)
                    add_text('{{')
                    indent()
                    for name, typename in message['vars'].items():
                        add_text('{} = reader.{}();',
                            name,
                            get_rw_func(types, typename, True, aliases)
                        )
                    unindent()
                    add_text('}}')
                    newline()
                else:
                    # Other directions: a conversion to NetworkMessage that
                    # writes the packet id and then every field, in order.
                    add_text('operator NetworkMessage() const')
                    add_text('{{')
                    indent()
                    add_text('NetworkMessage msg(MaxSize);')
                    add_text('msg.WriteVarInt(PacketId);')
                    for name, typename in message['vars'].items():
                        add_text('msg.{}({});',
                            get_rw_func(types, typename, False, aliases),
                            name
                        )
                    add_text('msg.Finalize();')
                    add_text('return msg;')
                    unindent()
                    add_text('}}')

                # Field declarations. Non-serverbound structs hold std::string
                # fields by const reference. NOTE(review): the referenced
                # strings must outlive the struct; confirm at call sites.
                for name, typename in message['vars'].items():
                    resolved_type = resolve_type(types, extract_array_type(typename), aliases)
                    if not serverbound and resolved_type == 'std::string':
                        add_text('const {}& {};', resolved_type, name)
                    else:
                        add_text('{} {};', resolved_type, name)

                unindent()
                add_text('}};')
                newline()
        unindent()
        add_text('}}')
        newline()
|
|
|
|
def print_handler(list):
    """Emit the ``ProcessPacket`` dispatch function template.

    The generated C++ works as follows:
    - Legacy pings are handed to the handler before anything else.
    - Otherwise the packet id is read as a VarInt, and the function switches
      on the protocol state and then on the packet id.
    - Every serverbound message in ``list`` (state -> direction -> messages)
      gets a case that builds the message struct from the reader and calls
      ``handler.HandlePacket<T>``.
    - Any other id goes to ``handler.HandleUnknownPacket``.
    """
    add_text('template <typename HandlerType, typename ClientType>')
    add_text('void ProcessPacket(HandlerType& handler, ClientType& client, PacketReader& packet, ProtocolState state)')
    add_text('{{')
    indent()

    # Legacy pings are checked before any packet id is read.
    add_text('if (packet.IsLegacyPing())')
    add_text('{{')
    indent()
    add_text('handler.HandleLegacyPing(client);')
    add_text('return;')
    unindent()
    add_text('}}')
    newline()

    add_text('const int32_t packetId = packet.ReadVarInt();')
    add_text('switch (state)')
    add_text('{{')
    indent()

    for state_name, by_direction in list.items():
        namespace = state_name.capitalize()
        add_text('case ProtocolState::{}:', namespace)
        add_text('{{')
        indent()
        add_text('switch (packetId)')
        add_text('{{')
        indent()

        # Only serverbound messages are dispatched; other directions are skipped.
        incoming = [msgs for direction, msgs in by_direction.items() if direction == 'serverbound']
        for msgs in incoming:
            for message_name in msgs:
                qualified = '{}::Serverbound{}'.format(namespace, message_name)
                add_text('case {}::PacketId:', qualified)
                indent()
                # `template` disambiguator: HandlePacket is a member template
                # called on an object of dependent type (HandlerType).
                add_text('handler.template HandlePacket<{}>(client, {}(packet));', qualified, qualified)
                add_text('break;')
                unindent()

        # Any id not matched above is an unknown packet for this state.
        add_text('default:')
        indent()
        add_text('handler.HandleUnknownPacket(client, packetId, packet);')
        add_text('break;')
        unindent()

        unindent()
        add_text('}}')
        add_text('break;')
        unindent()
        add_text('}}')

    unindent()
    add_text('}}')
    unindent()
    add_text('}}')
|
|
|
|
|
|
def print_protocol():
    """Load the HJSON protocol description named by ``sys.argv[1]`` and emit it.

    Emits, in order: the ProtocolState enum (from 'states'), the message
    structs (from 'messages' and 'types') and the ProcessPacket handler.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale
    # encoding, which varies between build machines.
    with open(sys.argv[1], encoding='utf-8') as message_file:
        message_scheme = hjson.load(message_file)

    print_enum('ProtocolState', message_scheme['states'])
    newline()

    print_messages(
        message_scheme['messages'],
        message_scheme['types']
    )
    newline()

    print_handler(message_scheme['messages'])
|
|
|
|
def main():
    """Command-line entry point.

    Usage: ``generate_protocol.py <protocol.hjson> <ProtocolDefinitions.h>``.
    Builds the entire header in the module-level ``text`` buffer, then
    writes it to the output path. Exits with status 1 on a usage error.
    """
    if len(sys.argv) != 3:
        # Report usage on stderr and fail, so a mis-invoked build step is noticed.
        print('Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>',
              file=sys.stderr)
        sys.exit(1)

    add_text('#pragma once')
    newline()
    add_text('#include <cstdint>')
    add_text('#include "Types.h"')
    add_text('#include "PacketReader.h"')
    newline()
    add_text('namespace Feather::Protocol')
    add_text('{{')
    indent()
    print_protocol()
    unindent()
    add_text('}}')
    newline()

    # The with-statement closes the file; write UTF-8 explicitly rather than
    # in the locale's default encoding.
    with open(sys.argv[2], 'w', encoding='utf-8') as out_file:
        out_file.write(text)
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()