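# FeatherMC/src/protocol/generate_protocol.py
#
# Generates the C++ protocol header (ProtocolDefinitions.h) from a protocol.hjson
# description.
# Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>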

import hjson
import sys
# Generated C++ source, accumulated by add_text()/newline() and written by main().
text: str = ''
# Bytes budgeted for one VarInt when computing a message's MaxSize.
varint_max_size: int = 5
# Current nesting depth; add_text() prefixes each line with this many spaces.
indent_count: int = 0
def indent():
    """Increase the nesting depth applied by add_text() by one level."""
    global indent_count
    indent_count = indent_count + 1
def unindent():
    """Decrease the nesting depth applied by add_text() by one level."""
    global indent_count
    indent_count = indent_count - 1
def add_text(fmt, *args):
    """Append one formatted, indented line to the global `text` buffer.

    The line is prefixed with one space per nesting level (`indent_count`).
    `fmt` is a str.format template, so literal braces must be written as
    ``{{`` / ``}}``.
    """
    global text
    # The indent is appended before formatting, so it is already in the
    # buffer if the format call raises.
    text += ' ' * indent_count
    text += (fmt + '\n').format(*args)
def newline():
    """Append a bare line break (no indentation) to the global `text` buffer."""
    global text
    text = text + '\n'
def extract_array_type(type):
    """Return `type` with its last ``[...]`` suffix removed.

    If `type` does not contain both ``[`` and ``]`` it is returned unchanged;
    e.g. ``'string[255]'`` -> ``'string'``, ``'varint'`` -> ``'varint'``.
    """
    if '[' not in type or ']' not in type:
        return type
    base, _, _ = type.rpartition('[')
    return base
def extract_array_count(type):
    """Return the integer inside the last ``[...]`` suffix of `type`, or 1.

    ``'string[255]'`` -> 255; a type with no brackets -> 1.

    The suffix is located with ``rfind('[')``, the same way
    extract_array_type() strips it. Both helpers therefore agree on which
    bracket pair is the array count, e.g. ``'byte[2][3]'`` -> 3.

    Raises:
        ValueError: if the bracketed text is not an integer.
    """
    if '[' in type and ']' in type:
        start = type.rfind('[') + 1
        end = type.rfind(']')
        return int(type[start:end])
    return 1
def resolve_type(aliases, type):
    """Map `type` through the `aliases` table, or return it unchanged.

    The lookup key is `type` with trailing underscores stripped. When that
    key is absent, the original `type` (underscores included) is returned.

    NOTE(review): despite the original local name ``type_no_array``, only
    trailing ``_`` is stripped here; array suffixes such as ``[N]`` are not
    removed. Confirm this is intended.
    """
    lookup_key = type.rstrip('_')
    return aliases.get(lookup_key, type)
# Bytes contributed by one field of each fixed-size wire type.
_FIXED_TYPE_SIZES = {
    'int64': 8, 'uint64': 8, 'double': 8, 'position': 8,
    'int32': 4, 'uint32': 4, 'float': 4,
    'int16': 2, 'uint16': 2,
    'int8': 1, 'uint8': 1, 'byte': 1, 'bool': 1,
    'uuid': 16,
}


def get_type_size(type):
    """Return the number of bytes a field of `type` adds to a message's MaxSize.

    `type` is a wire type name, optionally with an array suffix such as
    ``string[255]``. The suffix count is only used for ``string``; for every
    other type it is ignored.

    Raises:
        ValueError: if the base type is not a known wire type.
    """
    count = extract_array_count(type)
    base_type = extract_array_type(type)
    # TODO: encode type sizes in Hjson data
    if base_type == 'varint':
        return varint_max_size
    if base_type == 'string':
        # NOTE(review): this budgets exactly `count` bytes with no length
        # prefix. If `count` is a character limit and strings are
        # length-prefixed on the wire, MaxSize may be too small. Confirm
        # against NetworkMessage/WriteString.
        return count
    try:
        return _FIXED_TYPE_SIZES[base_type]
    except KeyError:
        # Raise explicitly rather than `assert False`. The assert is stripped
        # under `python -O`, which would silently return None here.
        raise ValueError(
            'unknown protocol type {!r} (from {!r})'.format(base_type, type)) from None
def print_enum(name, dict, primitive='int32_t'):
    """Emit a C++ ``enum class`` called `name` with underlying type `primitive`.

    `dict` maps enumerator names to values. Entries are emitted in iteration
    order as ``Name = value,`` inside an indented body.
    """
    add_text('enum class {} : {}', name, primitive)
    add_text('{{')
    indent()
    for enumerator in dict:
        add_text('{} = {},', enumerator, dict[enumerator])
    unindent()
    add_text('}};')
def get_rw_func(primitiveType, aliasedType, read):
    """Return the name of the reader/writer method that handles one field.

    Args:
        primitiveType: C++ type the field is stored as (e.g. ``'int32_t'``).
        aliasedType: wire type after alias resolution (e.g. ``'varint'``).
        read: True for the ``Read*`` variant, False for ``Write*``.
    """
    verb = 'Read' if read else 'Write'
    if aliasedType == 'string':
        return verb + 'String'
    if aliasedType == 'position':
        return verb + 'Position'
    if aliasedType != 'varint':
        return '{}<{}>'.format(verb, primitiveType)
    # An int32_t VarInt uses the untemplated name; other VarInt types are
    # passed as a template argument.
    if primitiveType == 'int32_t':
        return verb + 'VarInt'
    return '{}VarInt<{}>'.format(verb, primitiveType)
def _merge_aliases(global_aliases, message):
    """Return a copy of `global_aliases` overlaid with `message`'s own aliases.

    Message-local entries override global ones of the same name. The global
    table itself is not modified.
    """
    aliases = global_aliases.copy()
    aliases.update(message.get('aliases', {}))
    return aliases


def _message_max_size(fields, aliases):
    """Return the MaxSize budget: two VarInts plus each field's size."""
    size = varint_max_size  # Packet Length
    size += varint_max_size  # Packet Id
    for field_type in fields.values():
        size += get_type_size(resolve_type(aliases, field_type))
    return size


def _field_rw_func(field_type, aliases, primitives, read):
    """Return the Read*/Write* method name for one field type."""
    base_type = extract_array_type(field_type)
    return get_rw_func(resolve_type(primitives, base_type),
                       resolve_type(aliases, base_type),
                       read)


def _print_message_enums(message, aliases, primitives):
    """Emit each enum declared by `message`, each followed by a blank line.

    An enum whose name is also an alias takes its C++ underlying type from
    the aliased wire type. Otherwise print_enum()'s default is used.
    """
    for enum_name, values in message.get('enums', {}).items():
        if enum_name in aliases:
            # lookup enum primitive type from aliases
            underlying = resolve_type(primitives, aliases[enum_name])
            print_enum(enum_name, values, underlying)
        else:
            print_enum(enum_name, values)
        newline()


def _print_reader_constructor(struct_name, fields, aliases, primitives):
    """Emit a constructor that reads every field from a PacketReader."""
    add_text('{}(PacketReader& reader)', struct_name)
    add_text('{{')
    indent()
    for name, field_type in fields.items():
        add_text('{} = reader.{}();', name,
                 _field_rw_func(field_type, aliases, primitives, True))
    unindent()
    add_text('}}')
    newline()


def _print_writer_conversion(fields, aliases, primitives):
    """Emit an ``operator NetworkMessage()`` that serializes every field."""
    add_text('operator NetworkMessage() const')
    add_text('{{')
    indent()
    add_text('NetworkMessage msg(MaxSize);')
    add_text('msg.WriteVarInt(PacketId);')
    for name, field_type in fields.items():
        add_text('msg.{}({});',
                 _field_rw_func(field_type, aliases, primitives, False), name)
    add_text('msg.Finalize();')
    add_text('return msg;')
    unindent()
    add_text('}}')


def _print_fields(fields, primitives, serverbound):
    """Emit one member declaration per field.

    Clientbound ``std::string`` fields are declared as ``const std::string&``
    members (presumably to avoid copying). The referenced string must
    outlive the struct.
    """
    for name, field_type in fields.items():
        cpp_type = resolve_type(primitives, extract_array_type(field_type))
        if not serverbound and cpp_type == 'std::string':
            add_text('const {}& {};', cpp_type, name)
        else:
            add_text('{} {};', cpp_type, name)


def print_messages(list, global_aliases, primitives):
    """Emit one C++ namespace per protocol state, with a struct per message.

    Args:
        list: mapping state -> direction -> message name -> message dict.
            Each message dict carries ``id`` and optionally ``vars``
            (field name -> wire type), ``aliases`` and ``enums``.
        global_aliases: wire-type aliases shared by all messages.
        primitives: mapping from wire types to C++ storage types.

    Serverbound structs get a PacketReader constructor. All other directions
    get a NetworkMessage conversion operator. A message without ``vars`` is
    emitted as a field-less packet instead of raising KeyError.
    """
    for state, direction_list in list.items():
        namespace = state.capitalize()
        add_text('namespace {}', namespace)
        add_text('{{')
        indent()
        for direction, messages in direction_list.items():
            serverbound = direction == 'serverbound'
            for message_name, message in messages.items():
                aliases = _merge_aliases(global_aliases, message)
                fields = message.get('vars', {})
                # Size is computed before any struct text is emitted.
                max_size = _message_max_size(fields, aliases)
                struct_name = '{}{}'.format(direction.capitalize(), message_name)
                add_text('struct {}', struct_name)
                add_text('{{')
                indent()
                add_text('static constexpr int32_t PacketId = {};', message['id'])
                add_text('static constexpr bool Serverbound = {};', 'true' if serverbound else 'false')
                add_text('static constexpr ProtocolState PacketState = ProtocolState::{};', namespace)
                add_text('static constexpr size_t MaxSize = {};', max_size)
                newline()
                _print_message_enums(message, aliases, primitives)
                if serverbound:
                    _print_reader_constructor(struct_name, fields, aliases, primitives)
                else:
                    _print_writer_conversion(fields, aliases, primitives)
                _print_fields(fields, primitives, serverbound)
                unindent()
                add_text('}};')
                newline()
        unindent()
        add_text('}}')
        newline()
def print_handler(list):
    """Emit the ``ProcessPacket`` template that dispatches incoming packets.

    `list` has the same state -> direction -> message shape as
    print_messages(). Only messages under the ``serverbound`` direction get
    a ``case``; other directions are skipped.
    """
    add_text('template <typename HandlerType, typename ClientType>')
    add_text('void ProcessPacket(HandlerType& handler, ClientType& client, PacketReader& packet, ProtocolState state)')
    add_text('{{')
    indent()

    # Legacy pings are handled and returned from before any packet id is read.
    add_text('if (packet.IsLegacyPing())')
    add_text('{{')
    indent()
    add_text('handler.HandleLegacyPing(client);')
    add_text('return;')
    unindent()
    add_text('}}')
    newline()

    add_text('const int32_t packetId = packet.ReadVarInt();')
    add_text('switch (state)')
    add_text('{{')
    indent()
    for state in list:
        namespace = state.capitalize()
        add_text('case ProtocolState::{}:', namespace)
        add_text('{{')
        indent()
        add_text('switch (packetId)')
        add_text('{{')
        indent()
        for message_name in list[state].get('serverbound', {}):
            qualified = '{}::Serverbound{}'.format(namespace, message_name)
            add_text('case {}::PacketId:', qualified)
            indent()
            # HandlePacket is a dependent name, so C++ needs the `template`
            # disambiguator here.
            add_text('handler.template HandlePacket<{}>(client, {}(packet));', qualified, qualified)
            add_text('break;')
            unindent()
        unindent()
        add_text('}}')
        add_text('break;')
        unindent()
        add_text('}}')
    unindent()
    add_text('}}')
    unindent()
    add_text('}}')
def print_protocol():
    """Load the Hjson scheme named by ``sys.argv[1]`` and emit all definitions.

    Emits, in order:
    - the ``ProtocolState`` enum (from ``states``);
    - the per-state message structs (from ``messages`` with ``types.aliases``
      and ``types.primitives``);
    - the ``ProcessPacket`` dispatcher.
    """
    # Decode explicitly as UTF-8. The default open() encoding is
    # locale-dependent, so the same input could otherwise parse differently
    # on different machines.
    with open(sys.argv[1], encoding='utf-8') as message_file:
        message_scheme = hjson.load(message_file)
    types = message_scheme['types']
    print_enum('ProtocolState', message_scheme['states'])
    newline()
    print_messages(
        message_scheme['messages'],
        types['aliases'],
        types['primitives'])
    newline()
    print_handler(message_scheme['messages'])
def main():
    """Command-line entry point: ``generate_protocol.py <input.hjson> <output.h>``.

    Builds the whole header in the module-level `text` buffer, then writes
    it to ``sys.argv[2]``. With the wrong number of arguments, it prints
    usage to stderr and exits with status 1, so build steps see the failure.
    """
    if len(sys.argv) != 3:
        print('Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>',
              file=sys.stderr)
        sys.exit(1)
    add_text('#pragma once')
    newline()
    add_text('#include <cstdint>')
    add_text('#include "Types.h"')
    add_text('#include "PacketReader.h"')
    newline()
    add_text('namespace Feather::Protocol')
    add_text('{{')
    indent()
    print_protocol()
    unindent()
    add_text('}}')
    newline()
    # The `with` statement closes the file; no explicit close() is needed.
    with open(sys.argv[2], 'w', encoding='utf-8') as out_file:
        out_file.write(text)
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    main()