# FeatherMC/src/protocol/generate_protocol.py
#
# 339 lines
# 10 KiB
# Python

import hjson
import sys
# Generated output accumulated by add_text() and newline(); written out by main().
text = ""
# Maximum encoded size of a VarInt; counted once each for the packet length
# and packet id when computing a message's MaxSize.
varint_max_size = 5
# Current indentation depth (number of indent units) applied by add_text().
indent_count = 0
def warn(value, *args, sep='', end='\n', flush=False):
    """Print a diagnostic message to stderr.

    Behaves like print(), except that output always goes to sys.stderr and
    the default separator between arguments is the empty string.
    """
    stream = sys.stderr
    print(value, *args, sep=sep, end=end, file=stream, flush=flush)
def indent():
    """Increase the indentation depth applied by add_text() by one level."""
    global indent_count
    indent_count = indent_count + 1
def unindent():
    """Decrease the indentation depth applied by add_text() by one level."""
    global indent_count
    indent_count = indent_count - 1
def add_text(fmt, *args):
    """Append one line of generated output at the current indentation depth.

    ``fmt`` is a str.format template filled with ``args``, so literal braces
    in the emitted code must be doubled ('{{' / '}}'). A newline is appended.
    """
    global text
    # One indent unit per level; emitted before formatting, as before.
    text += ' ' * indent_count
    line = fmt + '\n'
    text += line.format(*args)
def newline():
    """Append an empty line (without indentation) to the generated output."""
    global text
    text = text + '\n'
def extract_array_type(type):
    """Strip a trailing ``[N]`` suffix from a type name.

    Everything before the last '[' is returned when the name contains both
    '[' and ']'; otherwise the name is returned unchanged.
    """
    has_brackets = '[' in type and ']' in type
    if not has_brackets:
        return type
    return type[:type.rfind('[')]
def extract_array_count(type):
    """Return the integer N from a ``name[N]`` type, or None without brackets.

    Raises ValueError if the bracketed text is not an integer.
    """
    if '[' not in type or ']' not in type:
        return None
    start = type.find('[') + 1
    stop = type.rfind(']')
    return int(type[start:stop])
def get_type_info(types, typename, aliases=None):
    """Look up ``typename`` in the type table and return a normalised copy.

    types:    mapping of schema type name -> type description dict.
    typename: schema type name, optionally with an ``[N]`` suffix (stripped).
    aliases:  optional message-local mapping of local name -> schema type name.

    Returns a new dict (entries in ``types`` are never modified) that always
    has a 'type' key:
      * unknown names print a warning to stderr and yield ``{'type': name}``;
      * entries with an 'alias' are resolved recursively (without the local
        ``aliases``) and keep the alias target in 'alias';
      * generic or method-less types get 'type' set to their schema name;
      * a local alias keeps the target's properties, records the target in
        'alias' and uses the local name as 'type';
      * any remaining entry without 'type' uses its own name.
    """
    name = extract_array_type(typename)
    alias_name = None
    if aliases and name in aliases:
        alias_name = name
        name = aliases[name]
    if name not in types:
        warn('WARNING: Type name "{}" is not a known type or alias.'.format(name))
        # The key must be the string 'type': that is what every caller reads.
        return {'type': name}
    # Copy so the normalisation below never writes into the caller's table.
    tp = dict(types[name])
    if 'alias' in tp:
        alias = tp['alias']
        # The recursive call always returns a fresh dict, safe to modify.
        tp = get_type_info(types, alias)
        tp['alias'] = alias
    generic = tp.get('generic') == True
    # Primitive (method-less) types are always treated as generic.
    if generic or 'method' not in tp:
        tp['type'] = name
    if alias_name:
        tp['alias'] = name
        tp['type'] = alias_name
    # Types without a 'type' value use their own name (e.g. double, float).
    tp.setdefault('type', name)
    return tp
def resolve_type(types, typename, aliases=None):
    """Return the type name used for ``typename`` in the generated code.

    Looks the name up via get_type_info() (``aliases`` is the optional
    message-local alias map) and falls back to ``typename`` unchanged when
    the returned info has no 'type' entry.
    """
    tp = get_type_info(types, typename, aliases)
    if 'type' in tp:
        return tp['type']
    return typename
def get_type_size(types, typename, aliases=None):
    """Return the size ``typename`` contributes to a message's MaxSize.

    The size comes from the type's 'size' field. If that field is missing,
    the size of the aliased type is used instead. A size of 'count' means the
    size is the bracketed count in ``typename`` (e.g. 'name[16]' -> 16).

    Prints an error to stderr and exits with status 1 in two cases: a
    non-alias type has no 'size' field, or no non-zero size can be determined.
    """
    tp = get_type_info(types, typename, aliases)
    if 'size' not in tp:
        if 'alias' in tp:
            return get_type_size(types, tp['alias'])
        warn('ERROR: Non-alias type "{}" does not have "size" field.'.format(typename))
        # sys.exit rather than the site-injected exit(), which is meant for
        # the interactive shell only.
        sys.exit(1)
    size = tp['size']
    if size == 'count':
        size = extract_array_count(typename)
    if not size:
        warn('ERROR: Cannot get size for type "{}"'.format(typename))
        sys.exit(1)
    return int(size)
def get_rw_func(types, typename, read, aliases=None):
    """Return the reader/writer method name used to (de)serialize a field.

    The name is 'Read' (``read`` true) or 'Write', followed by the type's
    'method' suffix if it has one. A '<type>' template argument is appended
    in two cases: the type has no 'method', or it is both generic and an
    alias. The result is emitted as ``reader.<name>()`` or
    ``msg.<name>(field)``.
    """
    method = 'Read' if read else 'Write'
    tp = get_type_info(types, typename, aliases)
    if 'method' in tp:
        method += tp['method']
    generic = 'generic' in tp and tp['generic']
    alias = 'alias' in tp
    if (generic and alias) or 'method' not in tp:
        method += '<' + tp['type'] + '>'
    return method
def get_enum_type(dict, fallback='int32'):
    """Return the schema type backing an enum description.

    Uses the enum's '.type' metadata entry when present, else ``fallback``.
    """
    return dict.get('.type', fallback)
def print_enum(name, dict, types=None, fallback=None):
    """Emit a C++ ``enum class`` named ``name``.

    ``dict`` maps enumerator names to values. Keys starting with '.' are
    metadata and are skipped; '.type' selects the underlying schema type.
    The underlying type is chosen in this order:
      * '.type', if present;
      * ``fallback``;
      * int32_t, if neither is available.
    A schema type name is resolved through ``types`` with resolve_type().
    """
    if types is None:
        types = {}
    primType = get_enum_type(dict, fallback)
    if primType is None:
        # Nothing to resolve with and no fallback: default to int32_t.
        primitive = 'int32_t'
    else:
        primitive = resolve_type(types, primType)
    add_text('enum class {} : {}', name, primitive)
    add_text('{{')
    indent()
    for key, value in dict.items():
        if key.startswith('.'):
            continue  # metadata, not an enumerator
        add_text('{} = {},', key, value)
    unindent()
    add_text('}};')
def print_messages(list, globalTypes):
    """Emit one C++ namespace per protocol state, with a struct per message.

    ``list`` maps state name -> direction -> message name -> message schema.
    A message schema has:
      * 'id' (the packet id);
      * optional 'vars' (field name -> type name);
      * optional 'aliases' (local name -> type name);
      * optional 'enums' (enum name -> enum description).
    ``globalTypes`` is the shared type table.

    Structs in the 'serverbound' direction get a constructor that reads every
    field from a PacketReader. All others get an ``operator NetworkMessage``
    that writes every field.
    """
    for state, direction_list in list.items():
        state_name = state.capitalize()
        add_text('namespace {}', state_name)
        add_text('{{')
        indent()
        for direction, messages in direction_list.items():
            for message_name, message in messages.items():
                _print_message(state_name, direction, message_name, message, globalTypes)
        unindent()
        add_text('}}')
        newline()


def _print_message(state_name, direction, message_name, message, globalTypes):
    """Emit the complete struct for one message."""
    serverbound = direction == 'serverbound'
    types, aliases = _message_types(message, globalTypes)
    # Field-less packets may omit 'vars' entirely.
    fields = message.get('vars', {})
    size = _message_max_size(types, aliases, fields)
    struct_name = '{}{}'.format(direction.capitalize(), message_name)
    add_text('struct {}', struct_name)
    add_text('{{')
    indent()
    add_text('static constexpr int32_t PacketId = {};', message['id'])
    add_text('static constexpr bool Serverbound = {};', 'true' if serverbound else 'false')
    add_text('static constexpr ProtocolState PacketState = ProtocolState::{};', state_name)
    add_text('static constexpr size_t MaxSize = {};', size)
    newline()
    _print_message_enums(message, types, aliases)
    if serverbound:
        _print_reader_constructor(struct_name, fields, types, aliases)
    else:
        _print_writer_operator(fields, types, aliases)
    _print_fields(fields, types, aliases, serverbound)
    unindent()
    add_text('}};')
    newline()


def _message_types(message, globalTypes):
    """Return (types, aliases): the global table plus the message's enums, and its local aliases."""
    types = globalTypes.copy()
    aliases = dict(message.get('aliases', {}))
    for enum, values in message.get('enums', {}).items():
        if enum not in aliases:
            # Un-aliased enums are typed by their '.type' (int32 by default).
            types[enum] = {'alias': get_enum_type(values)}
    return types, aliases


def _message_max_size(types, aliases, fields):
    """Return MaxSize: VarInt packet length and packet id plus every field's size."""
    size = 2 * varint_max_size  # packet length + packet id
    for typename in fields.values():
        size += get_type_size(types, typename, aliases)
    return size


def _print_message_enums(message, types, aliases):
    """Emit the message's local enums, each followed by a blank line."""
    for enum, values in message.get('enums', {}).items():
        if enum in aliases:
            # Look up the enum's underlying type from its alias target.
            prim = resolve_type(types, aliases[enum])
            print_enum(enum, values, types, prim)
        else:
            print_enum(enum, values, types)
        newline()


def _print_reader_constructor(struct_name, fields, types, aliases):
    """Emit a constructor that reads every field from a PacketReader."""
    add_text('{}(PacketReader& reader)', struct_name)
    add_text('{{')
    indent()
    for name, typename in fields.items():
        add_text('{} = reader.{}();', name, get_rw_func(types, typename, True, aliases))
    unindent()
    add_text('}}')
    newline()


def _print_writer_operator(fields, types, aliases):
    """Emit ``operator NetworkMessage`` that writes the id and every field."""
    add_text('operator NetworkMessage() const')
    add_text('{{')
    indent()
    add_text('NetworkMessage msg(MaxSize);')
    add_text('msg.WriteVarInt(PacketId);')
    for name, typename in fields.items():
        add_text('msg.{}({});', get_rw_func(types, typename, False, aliases), name)
    add_text('msg.Finalize();')
    add_text('return msg;')
    unindent()
    add_text('}}')


def _print_fields(fields, types, aliases, serverbound):
    """Emit the struct's member declarations.

    In non-serverbound structs, std::string members are declared as const
    references.
    NOTE(review): the referenced string must outlive the struct.
    """
    for name, typename in fields.items():
        resolved_type = resolve_type(types, extract_array_type(typename), aliases)
        if not serverbound and resolved_type == 'std::string':
            add_text('const {}& {};', resolved_type, name)
        else:
            add_text('{} {};', resolved_type, name)
def print_handler(list):
    """Emit the ``ProcessPacket<HandlerType, ClientType>`` dispatch function.

    The generated function works in three steps:
      * legacy pings go to handler.HandleLegacyPing(client) and return early;
      * otherwise the packet id is read as a VarInt, then the function
        switches on the protocol state and the packet id;
      * every serverbound message in ``list`` becomes a case calling
        handler.HandlePacket<T>(client, T(packet)), and any other id goes to
        handler.HandleUnknownPacket.
    """
    add_text('template <typename HandlerType, typename ClientType>')
    add_text('void ProcessPacket(HandlerType& handler, ClientType& client, PacketReader& packet, ProtocolState state)')
    add_text('{{')
    indent()

    # Legacy pings are handled before any packet id is read.
    add_text('if (packet.IsLegacyPing())')
    add_text('{{')
    indent()
    add_text('handler.HandleLegacyPing(client);')
    add_text('return;')
    unindent()
    add_text('}}')
    newline()

    add_text('const int32_t packetId = packet.ReadVarInt();')
    add_text('switch (state)')
    add_text('{{')
    indent()
    for state, direction_list in list.items():
        namespace = state.capitalize()
        add_text('case ProtocolState::{}:', namespace)
        add_text('{{')
        indent()
        add_text('switch (packetId)')
        add_text('{{')
        indent()
        # Only serverbound messages are dispatched to the handler.
        incoming = direction_list.get('serverbound', {})
        for message_name in incoming:
            qualified = '{}::Serverbound{}'.format(namespace, message_name)
            add_text('case {}::PacketId:', qualified)
            indent()
            # HandlePacket is a member template of a dependent type, so the
            # generated call needs the `template` disambiguator.
            add_text('handler.template HandlePacket<{}>(client, {}(packet));', qualified, qualified)
            add_text('break;')
            unindent()
        # Anything not matched above is an unknown packet.
        add_text('default:')
        indent()
        add_text('handler.HandleUnknownPacket(client, packetId, packet);')
        add_text('break;')
        unindent()
        unindent()
        add_text('}}')
        add_text('break;')
        unindent()
        add_text('}}')
    unindent()
    add_text('}}')
    unindent()
    add_text('}}')
def print_protocol(path=None):
    """Load the hjson protocol schema and emit every generated declaration.

    path: schema file to read; defaults to sys.argv[1].

    Emits, in order:
      * the ProtocolState enum, from the schema's 'states';
      * the per-state message structs, from 'messages' and 'types';
      * the ProcessPacket dispatcher.
    """
    if path is None:
        path = sys.argv[1]
    # Decode explicitly as UTF-8 instead of the locale-dependent default.
    with open(path, encoding='utf-8') as message_file:
        message_scheme = hjson.load(message_file)
    print_enum('ProtocolState', message_scheme['states'])
    newline()
    print_messages(message_scheme['messages'], message_scheme['types'])
    newline()
    print_handler(message_scheme['messages'])
def main():
    """Command-line entry point.

    Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>

    Builds the header in the module-level ``text`` buffer (via print_protocol)
    and writes it to the output path. If the argument count is wrong, it
    prints the usage line to stderr and exits with status 1.
    """
    if len(sys.argv) != 3:
        # Misuse goes to stderr with a failing status so calling build steps notice.
        warn('Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>')
        sys.exit(1)
    add_text('#pragma once')
    newline()
    add_text('#include <cstdint>')
    add_text('#include "Types.h"')
    add_text('#include "PacketReader.h"')
    newline()
    add_text('namespace Feather::Protocol')
    add_text('{{')
    indent()
    print_protocol()
    unindent()
    add_text('}}')
    newline()
    # The with-statement closes the file; no explicit close() is needed.
    with open(sys.argv[2], 'w', encoding='utf-8') as out_file:
        out_file.write(text)
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    main()