235 lines
7.4 KiB
Python
235 lines
7.4 KiB
Python
import hjson
|
|
import sys
|
|
|
|
# Generated C++ source, accumulated by add_text() and newline().
text = ''

# Current nesting depth; add_text() prefixes one space per level.
indent_count = 0

# Byte budget reserved for every VarInt when computing MaxSize
# (presumably the 32-bit VarInt limit, ceil(32 / 7) = 5 -- confirm).
varint_max_size = 5
|
|
|
|
def indent():
    """Increase the current nesting depth used by add_text() by one level."""
    global indent_count
    indent_count = indent_count + 1
|
|
|
|
def unindent():
    """Decrease the current nesting depth used by add_text() by one level."""
    global indent_count
    indent_count = indent_count - 1
|
|
|
|
def add_text(fmt, *args):
    """Append one indented, formatted line to the output buffer.

    The line is prefixed with one space per current indent level and then
    rendered with ``str.format(*args)``, so literal braces in *fmt* must be
    doubled (``'{{'`` / ``'}}'``). A trailing newline is always added.
    """
    global text
    # Padding is appended before formatting, so a formatting error leaves
    # the padding in place (same ordering as the per-space loop form).
    text = text + ' ' * indent_count
    rendered = (fmt + '\n').format(*args)
    text = text + rendered
|
|
|
|
def newline():
    """Append an empty line (no indentation) to the output buffer."""
    global text
    text = text + '\n'
|
|
|
|
def extract_array_type(type):
    """Strip an array suffix such as ``[16]`` from a type name.

    If the name contains both ``[`` and ``]``, everything from the last
    ``[`` onward is dropped (``'string[16]'`` -> ``'string'``); otherwise the
    name is returned unchanged.
    """
    if '[' not in type or ']' not in type:
        return type
    element, _, _ = type.rpartition('[')
    return element
|
|
|
|
def extract_array_count(type):
    """Return the integer inside an array suffix, or 1 if there is none.

    The count is the text between the first ``[`` and the last ``]``
    (``'string[16]'`` -> ``16``). Raises ``ValueError`` if that text is not
    an integer literal.
    """
    has_suffix = '[' in type and ']' in type
    if not has_suffix:
        return 1
    start = type.find('[') + 1
    stop = type.rfind(']')
    return int(type[start:stop])
|
|
|
|
def resolve_type(aliases, type):
    """Map *type* through the *aliases* table, falling back to *type* itself.

    Trailing underscores are stripped from the name before the lookup. An
    array suffix such as ``[16]`` is NOT removed, so a suffixed name only
    matches if the table contains it verbatim.
    NOTE(review): the stripped form was originally called ``type_no_array``,
    but the code strips ``_`` rather than an array suffix -- confirm intent.
    """
    lookup_key = type.rstrip('_')
    if lookup_key not in aliases:
        return type
    return aliases[lookup_key]
|
|
|
|
def get_type_size(type):
    """Return the maximum encoded size, in bytes, of a wire type.

    *type* is a wire-type name such as ``'int32'`` or ``'varint'``, optionally
    followed by an array suffix like ``'string[16]'``. For ``string`` the
    bracketed count (1 when absent) is returned as the size. For every other
    type the suffix is ignored and the size of a single element is returned.

    Raises:
        ValueError: if the element type is not a known wire type. (This was
            previously an ``assert``, which is skipped under ``python -O``.)
    """
    fixed_sizes = {
        'int64': 8, 'uint64': 8, 'double': 8,
        'int32': 4, 'uint32': 4, 'float': 4,
        'int16': 2, 'uint16': 2,
        'int8': 1, 'uint8': 1, 'bool': 1,
        'uuid': 16,
    }

    count = extract_array_count(type)
    element = extract_array_type(type)

    if element == 'varint':
        return varint_max_size
    if element == 'string':
        # NOTE(review): this is only the bracketed count. It adds nothing for
        # a length prefix, and an unsuffixed 'string' counts as 1 byte.
        # Confirm against how WriteString encodes strings.
        return count
    if element in fixed_sizes:
        return fixed_sizes[element]
    raise ValueError('Unknown protocol type: {!r}'.format(type))
|
|
|
|
def print_states(states):
    """Emit the ``ProtocolState`` enum class, one enumerator per state.

    Each key of *states* becomes an enumerator name, and its value becomes
    the enumerator's initializer.
    """
    add_text('enum class ProtocolState : int32_t')
    add_text('{{')
    indent()
    for state_name in states:
        add_text('{} = {},', state_name, states[state_name])
    unindent()
    add_text('}};')
|
|
|
|
def get_rw_func(primitiveType, aliasedType, read):
    """Return the serialization method name emitted for one field.

    Args:
        primitiveType: C++ type of the field (e.g. ``'int32_t'``); used as
            the explicit template argument where one is emitted.
        aliasedType: wire type of the field after alias resolution.
        read: truthy for the ``Read...`` variant, falsy for ``Write...``.
    """
    verb = 'Write'
    if read:
        verb = 'Read'

    if aliasedType == 'string':
        return '{}String'.format(verb)
    if aliasedType != 'varint':
        return '{}<{}>'.format(verb, primitiveType)
    # int32_t varints are emitted without an explicit template argument.
    if primitiveType == 'int32_t':
        return '{}VarInt'.format(verb)
    return '{}VarInt<{}>'.format(verb, primitiveType)
|
|
|
|
def _message_max_size(message, aliases):
    """Return the MaxSize of *message*: VarInt length + VarInt id + each field's size."""
    size = varint_max_size  # Packet Length
    size += varint_max_size  # Packet Id
    for var_type in message['vars'].values():
        size += get_type_size(resolve_type(aliases, var_type))
    return size


def _field_types(var_type, aliases, primitives):
    """Return ``(cpp_type, wire_type)`` for a field, both resolved from its array-stripped type."""
    element = extract_array_type(var_type)
    return resolve_type(primitives, element), resolve_type(aliases, element)


def _print_reader_constructor(struct_name, fields, aliases, primitives):
    """Emit a constructor that reads every field, in order, from a ``PacketReader``."""
    add_text('{}(PacketReader& reader)', struct_name)
    add_text('{{')
    indent()
    for name, var_type in fields.items():
        cpp_type, wire_type = _field_types(var_type, aliases, primitives)
        add_text('{} = reader.{}();', name, get_rw_func(cpp_type, wire_type, True))
    unindent()
    add_text('}}')


def _print_writer_operator(fields, aliases, primitives):
    """Emit ``operator NetworkMessage()`` that writes the id and every field, in order."""
    add_text('operator NetworkMessage() const')
    add_text('{{')
    indent()
    add_text('NetworkMessage msg(MaxSize);')
    add_text('msg.WriteVarInt(PacketId);')
    for name, var_type in fields.items():
        cpp_type, wire_type = _field_types(var_type, aliases, primitives)
        add_text('msg.{}({});', get_rw_func(cpp_type, wire_type, False), name)
    add_text('msg.Finalize();')
    add_text('return msg;')
    unindent()
    add_text('}}')


def _print_fields(fields, primitives, serverbound):
    """Emit one member declaration per field; clientbound std::string members are const references."""
    for name, var_type in fields.items():
        cpp_type = resolve_type(primitives, extract_array_type(var_type))
        if not serverbound and cpp_type == 'std::string':
            add_text('const {}& {};', cpp_type, name)
        else:
            add_text('{} {};', cpp_type, name)


def print_messages(list, aliases, primitives):
    """Emit one C++ namespace per protocol state containing a struct per message.

    Args:
        list: mapping ``state -> direction -> message name -> message``. Each
            message has an ``'id'`` and an ordered ``'vars'`` mapping of
            field name to type. (The parameter name shadows the builtin; it
            is kept for backward compatibility.)
        aliases: mapping from custom type names to wire types.
        primitives: mapping from type names to C++ types.

    Serverbound structs get a constructor that reads from a ``PacketReader``.
    Every other direction gets a conversion operator to ``NetworkMessage``.
    Both are followed by a blank line and then the member declarations.
    """
    for state, direction_list in list.items():
        namespace = state.capitalize()
        add_text('namespace {}', namespace)
        add_text('{{')
        indent()
        for direction, messages in direction_list.items():
            serverbound = direction == 'serverbound'
            for message_name, message in messages.items():
                fields = message['vars']
                # Compute the size first, so an unknown type fails before any
                # text for this struct has been emitted.
                size = _message_max_size(message, aliases)
                struct_name = '{}{}'.format(direction.capitalize(), message_name)

                add_text('struct {}', struct_name)
                add_text('{{')
                indent()
                add_text('static constexpr int32_t PacketId = {};', message['id'])
                add_text('static constexpr bool Serverbound = {};', 'true' if serverbound else 'false')
                add_text('static constexpr ProtocolState PacketState = ProtocolState::{};', namespace)
                add_text('static constexpr size_t MaxSize = {};', size)
                newline()

                if serverbound:
                    _print_reader_constructor(struct_name, fields, aliases, primitives)
                else:
                    _print_writer_operator(fields, aliases, primitives)
                # Blank line before the members, for both directions alike.
                newline()

                _print_fields(fields, primitives, serverbound)
                unindent()
                add_text('}};')
                newline()
        unindent()
        add_text('}}')
        newline()
|
|
|
|
def print_handler(list):
    """Emit the ``ProcessPacket`` function template that dispatches serverbound packets.

    The generated function first handles legacy pings. It then reads a VarInt
    packet id and switches on the protocol state, then on the id. Each
    serverbound message is constructed from the reader and passed to
    ``handler.HandlePacket<Message>(client, ...)``.

    Args:
        list: mapping ``state -> direction -> message name -> message``, the
            same structure consumed by print_messages(). (The parameter name
            shadows the builtin; it is kept for backward compatibility.)
    """
    add_text('template <typename HandlerType, typename ClientType>')
    add_text('void ProcessPacket(HandlerType& handler, ClientType& client, PacketReader& packet, ProtocolState state)')
    add_text('{{')
    indent()

    add_text('if (packet.IsLegacyPing())')
    add_text('{{')
    indent()
    add_text('handler.HandleLegacyPing(client);')
    add_text('return;')
    unindent()
    add_text('}}')
    newline()

    add_text('const int32_t packetId = packet.ReadVarInt();')
    add_text('switch (state)')
    add_text('{{')
    indent()
    for state, direction_list in list.items():
        namespace = state.capitalize()
        add_text('case ProtocolState::{}:', namespace)
        add_text('{{')
        indent()
        add_text('switch (packetId)')
        add_text('{{')
        indent()
        for direction, messages in direction_list.items():
            # Only serverbound packets are received, so only they are dispatched.
            if direction != 'serverbound':
                continue
            for message_name in messages:
                name = '{}::Serverbound{}'.format(namespace, message_name)
                add_text('case {}::PacketId:', name)
                indent()
                # `handler` has a dependent type (HandlerType&), so calling its
                # member template with explicit arguments requires the
                # `template` disambiguator; without it the `<` parses as
                # less-than and conforming compilers reject the code.
                add_text('handler.template HandlePacket<{}>(client, {}(packet));', name, name)
                add_text('break;')
                unindent()
        unindent()
        add_text('}}')
        add_text('break;')
        unindent()
        add_text('}}')
    unindent()
    add_text('}}')

    unindent()
    add_text('}}')
|
|
|
|
|
|
def print_protocol(path=None):
    """Load the Hjson protocol schema and emit the enum, message structs and dispatcher.

    Args:
        path: path of the protocol description. Defaults to ``sys.argv[1]``
            (the original, argument-less behaviour).

    The schema must provide ``'states'``, ``'messages'`` and
    ``'types'`` (with ``'aliases'`` and ``'primitives'``).
    """
    if path is None:
        path = sys.argv[1]

    # Decode explicitly as UTF-8 so parsing doesn't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(path, encoding='utf-8') as message_file:
        message_scheme = hjson.load(message_file)

    print_states(message_scheme['states'])
    newline()

    print_messages(
        message_scheme['messages'],
        message_scheme['types']['aliases'],
        message_scheme['types']['primitives'])
    newline()

    print_handler(message_scheme['messages'])
|
|
|
|
def main():
    """Generate the C++ protocol header from an Hjson schema.

    Command line: ``generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>``

    Returns:
        0 on success, or 1 when the argument count is wrong (the usage text
        is then written to stderr and nothing is generated).
    """
    if len(sys.argv) != 3:
        print('Usage: generate_protocol.py <input: protocol.hjson> <output: ProtocolDefinitions.h>', file=sys.stderr)
        return 1

    add_text('#pragma once')
    newline()
    add_text('#include <cstdint>')
    add_text('#include "Types.h"')
    add_text('#include "PacketReader.h"')
    newline()
    add_text('namespace Feather::Protocol')
    add_text('{{')
    indent()
    print_protocol()
    unindent()
    add_text('}}')
    newline()

    # The `with` block closes the file; no explicit close() is needed.
    with open(sys.argv[2], 'w', encoding='utf-8') as out_file:
        out_file.write(text)
    return 0
|
|
|
|
# Run only when executed as a script, so importing the module has no side
# effects; main()'s return value becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())