"""Generate the C++ protocol header (packet structs and dispatcher) from an
hjson message scheme.

Usage: generate_protocol.py <messages.hjson> <output.h>
"""
import sys

# Accumulated output text; add_text()/newline() append to it, main() writes it.
text = ''
# Byte budget used for each VarInt header field (packet length and packet id)
# when computing a packet's MaxSize.
varint_max_size = 5
# Current nesting depth applied by add_text() (one space per level).
indent_count = 0


def warn(value, *args, sep='', end='\n', flush=False):
    """Print to stderr. Note that *sep* defaults to '' (unlike print())."""
    print(value, *args, sep=sep, end=end, file=sys.stderr, flush=flush)


def indent():
    """Increase the indentation depth used by add_text()."""
    global indent_count
    indent_count += 1


def unindent():
    """Decrease the indentation depth used by add_text()."""
    global indent_count
    indent_count -= 1


def add_text(fmt, *args):
    """Append one indented line, ``fmt.format(*args)``, to the output.

    Literal braces in *fmt* must be doubled (``'{{'`` / ``'}}'``) because the
    line goes through str.format.
    """
    global text
    text += ' ' * indent_count + (fmt + '\n').format(*args)


def newline():
    """Append an empty (unindented) line to the output."""
    global text
    text += '\n'


def extract_array_type(type):
    """Strip a trailing ``[N]`` suffix: ``'string[16]'`` -> ``'string'``."""
    if '[' in type and ']' in type:
        return type[:type.rfind('[')]
    return type


def extract_array_count(type):
    """Return N from a ``'name[N]'`` suffix, or None when there is no suffix."""
    if '[' in type and ']' in type:
        # Use the last '[' so this agrees with extract_array_type().
        return int(type[type.rfind('[') + 1:type.rfind(']')])
    return None


def get_type_info(types, typename, aliases=None):
    """Resolve *typename* to a type-description dict from *types*.

    Args:
        types: the scheme's ``types`` table (name -> description dict).
        typename: a type name, optionally carrying an ``[N]`` suffix, which
            is ignored here.
        aliases: optional per-message mapping of local name -> type name.

    Returns:
        A new dict that always contains a ``'type'`` key. A description with
        an ``'alias'`` key is replaced by the aliased type's info, with
        ``'alias'`` recorded. Generic types and per-message aliases get
        ``'type'`` set to the name the caller used. Unknown names emit a
        warning and yield ``{'type': name}``. *types* is never modified.
    """
    name = extract_array_type(typename)
    alias_name = None
    if aliases and name in aliases:
        alias_name = name
        name = aliases[name]
    if name not in types:
        warn('WARNING: Type name "{}" is not a known type or alias.'.format(name))
        # Keyed by the string 'type' (not the builtin `type`) so callers such
        # as resolve_type()/get_rw_func() can read it.
        return {'type': name}
    # Work on a copy so the fields set below never leak into the shared table.
    tp = dict(types[name])
    if 'alias' in tp:
        alias = tp['alias']
        tp = get_type_info(types, alias)  # always returns a fresh dict
        tp['alias'] = alias
    # 'generic' is optional in the scheme (get_rw_func treats it that way too).
    if tp.get('generic'):
        tp['type'] = name
    if alias_name:
        tp['alias'] = name
        tp['type'] = alias_name
    # Types without a 'type' value use their own name (e.g. double, float).
    if 'type' not in tp:
        tp['type'] = name
    return tp


def resolve_type(types, typename, aliases=None):
    """Return the C++ type name that *typename* maps to (see get_type_info)."""
    return get_type_info(types, typename, aliases).get('type', typename)


def get_type_size(types, typename, aliases=None):
    """Return the maximum serialized size, in bytes, of *typename*.

    A ``'size'`` of ``'count'`` takes the size from the ``[N]`` suffix of
    *typename*. Types without ``'size'`` fall back to their alias target.
    Exits the process with status 1 when no usable size can be found.
    """
    tp = get_type_info(types, typename, aliases)
    if 'size' not in tp:
        if 'alias' in tp:
            return get_type_size(types, tp['alias'])
        warn('ERROR: Non-alias type "{}" does not have "size" field.'.format(typename))
        sys.exit(1)
    size = tp['size']
    if size == 'count':
        size = extract_array_count(typename)
    if not size:
        warn('ERROR: Cannot get size for type "{}"'.format(typename))
        sys.exit(1)
    return int(size)


def print_enum(name, dict, primitive='int32_t'):
    """Emit ``enum class name : primitive`` with one ``key = value,`` per item.

    The parameter name *dict* is kept for compatibility (it shadows the builtin).
    """
    add_text('enum class {} : {}', name, primitive)
    add_text('{{')
    indent()
    for key, value in dict.items():
        add_text('{} = {},', key, value)
    unindent()
    add_text('}};')


def get_rw_func(types, typename, read, aliases=None):
    """Return the reader/writer method name used for *typename*.

    The name is ``Read``/``Write`` followed by the type's ``'method'``. A
    ``<type>`` template argument is appended for generic aliased types and
    for types without a ``'method'``; e.g. ``'WriteDouble'``,
    ``'ReadInt<Mode>'``.
    """
    tp = get_type_info(types, typename, aliases)
    method = ('Read' if read else 'Write') + tp.get('method', '')
    if (tp.get('generic') and 'alias' in tp) or 'method' not in tp:
        method += '<' + tp['type'] + '>'
    return method


def print_messages(list, types):
    """Emit one ``namespace <State>`` per state, holding a struct per message.

    *list* maps state -> direction -> message name -> message description.
    The parameter name is kept for compatibility (it shadows the builtin).
    """
    for state, direction_list in list.items():
        add_text('namespace {}', state.capitalize())
        add_text('{{')
        indent()
        for direction, messages in direction_list.items():
            for message_name, message in messages.items():
                _print_message(types, state, direction, message_name, message)
        unindent()
        add_text('}}')
        newline()


def _print_message(types, state, direction, message_name, message):
    """Emit the struct for one message: constants, enums, (de)serializer, fields."""
    serverbound = direction == 'serverbound'
    # Per-message local aliases (local name -> type name).
    aliases = dict(message.get('aliases') or {})
    # A message may legitimately carry no fields at all.
    fields = message.get('vars') or {}

    # Packet length + packet id (both VarInts), then every field's maximum.
    size = 2 * varint_max_size
    for typename in fields.values():
        size += get_type_size(types, typename, aliases)

    struct_name = '{}{}'.format(direction.capitalize(), message_name)
    add_text('struct {}', struct_name)
    add_text('{{')
    indent()
    add_text('static constexpr int32_t PacketId = {};', message['id'])
    add_text('static constexpr bool Serverbound = {};', 'true' if serverbound else 'false')
    add_text('static constexpr ProtocolState PacketState = ProtocolState::{};', state.capitalize())
    add_text('static constexpr size_t MaxSize = {};', size)
    newline()

    for enum, values in (message.get('enums') or {}).items():
        if enum in aliases:
            # The enum's underlying type comes from the local alias.
            print_enum(enum, values, resolve_type(types, aliases[enum]))
        else:
            print_enum(enum, values)
        newline()

    if serverbound:
        # Serverbound structs are constructed from a PacketReader.
        add_text('{}(PacketReader& reader)', struct_name)
        add_text('{{')
        indent()
        for name, typename in fields.items():
            add_text('{} = reader.{}();', name, get_rw_func(types, typename, True, aliases))
        unindent()
        add_text('}}')
        newline()
    else:
        # Clientbound structs convert to a NetworkMessage.
        add_text('operator NetworkMessage() const')
        add_text('{{')
        indent()
        add_text('NetworkMessage msg(MaxSize);')
        add_text('msg.WriteVarInt(PacketId);')
        for name, typename in fields.items():
            add_text('msg.{}({});', get_rw_func(types, typename, False, aliases), name)
        add_text('msg.Finalize();')
        add_text('return msg;')
        unindent()
        add_text('}}')

    for name, typename in fields.items():
        resolved_type = resolve_type(types, extract_array_type(typename), aliases)
        if not serverbound and resolved_type == 'std::string':
            # Clientbound string fields are declared as const references.
            add_text('const {}& {};', resolved_type, name)
        else:
            add_text('{} {};', resolved_type, name)
    unindent()
    add_text('}};')
    newline()


def print_handler(list):
    """Emit the ``ProcessPacket`` template that dispatches serverbound packets.

    Legacy pings are handled first. Then, per state in *list*, a switch on
    the packet id calls ``handler.HandlePacket<State::ServerboundX>``, and
    unknown ids go to ``handler.HandleUnknownPacket``.
    """
    # The template needs its parameter list: HandlerType and ClientType are
    # used in the signature, and a bare 'template ' is not valid C++ here.
    add_text('template <typename HandlerType, typename ClientType>')
    add_text('void ProcessPacket(HandlerType& handler, ClientType& client, PacketReader& packet, ProtocolState state)')
    add_text('{{')
    indent()
    add_text('if (packet.IsLegacyPing())')
    add_text('{{')
    indent()
    add_text('handler.HandleLegacyPing(client);')
    add_text('return;')
    unindent()
    add_text('}}')
    newline()
    add_text('const int32_t packetId = packet.ReadVarInt();')
    add_text('switch (state)')
    add_text('{{')
    indent()
    for state, direction_list in list.items():
        add_text('case ProtocolState::{}:', state.capitalize())
        add_text('{{')
        indent()
        add_text('switch (packetId)')
        add_text('{{')
        indent()
        # Only serverbound packets are received, so only they get cases.
        for message_name in direction_list.get('serverbound') or {}:
            name = '{}::Serverbound{}'.format(state.capitalize(), message_name)
            add_text('case {}::PacketId:', name)
            indent()
            # Here we need the obscure `template` disambiguator for dependent names
            add_text('handler.template HandlePacket<{}>(client, {}(packet));', name, name)
            add_text('break;')
            unindent()
        # handle unknown packets
        add_text('default:')
        indent()
        add_text('handler.HandleUnknownPacket(client, packetId, packet);')
        add_text('break;')
        unindent()
        unindent()
        add_text('}}')
        add_text('break;')
        unindent()
        add_text('}}')
    unindent()
    add_text('}}')
    unindent()
    add_text('}}')


def print_protocol(path=None):
    """Load the hjson scheme at *path* (default ``sys.argv[1]``) and emit it."""
    # Imported lazily so the generator helpers can be imported (e.g. by
    # tests) without the third-party hjson package being installed.
    import hjson

    if path is None:
        path = sys.argv[1]
    with open(path, encoding='utf-8') as message_file:
        message_scheme = hjson.load(message_file)
    print_enum('ProtocolState', message_scheme['states'])
    newline()
    print_messages(message_scheme['messages'], message_scheme['types'])
    newline()
    print_handler(message_scheme['messages'])


def main(argv=None):
    """Entry point: ``generate_protocol.py <messages.hjson> <output.h>``.

    Exits with status 1 (usage on stderr) when the argument count is wrong.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 3:
        warn('Usage: generate_protocol.py <messages.hjson> <output.h>')
        sys.exit(1)
    add_text('#pragma once')
    newline()
    # The generated constants use size_t and int32_t.
    add_text('#include <cstddef>')
    add_text('#include <cstdint>')
    add_text('#include "Types.h"')
    add_text('#include "PacketReader.h"')
    newline()
    add_text('namespace Feather::Protocol')
    add_text('{{')
    indent()
    print_protocol(argv[1])
    unindent()
    add_text('}}')
    newline()
    with open(argv[2], 'w', encoding='utf-8') as out_file:
        out_file.write(text)


if __name__ == '__main__':
    main()