Files
2024-12-09 08:01:09 -07:00

389 lines
12 KiB
Python

#!/usr/bin/env python3
###############################################################################
## Copyright (C) Photon Vision.
###############################################################################
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <https://www.gnu.org/licenses/>.
###############################################################################
import argparse
import copy
import hashlib
import os
import sys
from pathlib import Path
from typing import List, TypedDict, cast
import yaml
from jinja2 import Environment, FileSystemLoader
class _SerdeFieldRequired(TypedDict):
    """Keys that every field entry of a message definition must provide."""

    # Name of the field inside its message
    name: str
    # Type name: either an intrinsic type or the name of another message
    type: str


class SerdeField(_SerdeFieldRequired, total=False):
    """
    One field of a message definition.

    `name` and `type` are mandatory. The modifier flags below may be left out
    entirely, which is why the rest of this file checks for their presence
    before reading them. A field may not set both flags at once.
    """

    # optional extra args
    # When true the field is wrapped in std::optional on the C++ side and
    # prefixed with "optional " in the schema string
    optional: bool
    # When true the field is a variable-length sequence: std::vector on the
    # C++ side and a "[?]" suffix in the schema string
    vla: bool
class _MessageTypeRequired(TypedDict):
    """Keys that every message definition must provide."""

    # Message name; other messages refer to it through a field's `type`
    name: str
    # Ordered field list; the schema string (and so the hash) follows this order
    fields: List[SerdeField]


class MessageType(_MessageTypeRequired, total=False):
    """
    One message definition.

    Only `name` and `fields` are mandatory. The shim keys are read for
    messages flagged as shimmed, and the hash/schema keys are filled in by
    get_message_hash rather than being supplied up front.
    """

    # True if the message is shimmed; the shim keys below are then expected
    shimmed: bool
    java_decode_shim: str
    java_encode_shim: str

    # C++ helpers
    # Fully qualified C++ type, read by get_qualified_cpp_name for shimmed messages
    cpp_type: str
    # Include entry emitted verbatim by get_includes for shimmed messages
    cpp_include: str

    # python shim types
    python_encode_shim: str
    python_decode_shim: str

    # Java import name
    java_import: str

    # Remember our message hash. Recalculated by us. All intrinsic types are
    # unhashed so this is fine to live here
    message_hash: str
    # Schema string the hash above was computed from
    schema_str: str
def yaml_to_dict(path: str):
    """
    Load a YAML document that lives next to this script.

    Args:
        path: File path, joined onto the directory containing this script.

    Returns:
        Whatever yaml.safe_load produces for the document.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    yaml_file_path = os.path.join(script_dir, path)

    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding, so the same file parses identically on every machine. This also
    # matches the encoding used when the generated files are written.
    with open(yaml_file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
# Table of intrinsic (built-in) types, loaded once at import time from the YAML
# file next to this script. Its keys are the type names that is_intrinsic_type
# tests against; entries are read elsewhere in this file for their "cpp_type".
data_types = yaml_to_dict("message_data_types.yaml")
# Helper to check if we need to use our own decoder
def is_intrinsic_type(type_str: str):
    """Return True when `type_str` names an entry of the module-level data_types table."""
    return type_str in data_types
# Deal with shimmed types
def get_shimmed_filter(message_db):
    """
    Build a predicate that tells whether a type name refers to a shimmed message.

    Args:
        message_db: All loaded message definitions; captured by the returned closure.

    Returns:
        A function taking a type name and returning a bool. Looking up a name
        that is neither intrinsic nor present in message_db raises, because the
        lookup goes through get_message_by_name.
    """

    def is_shimmed(message_name: str):
        # We don't (yet) support shimming intrinsic types
        if not is_intrinsic_type(message_name):
            entry = get_message_by_name(message_db, message_name)
            # A missing key yields None, which does not compare equal to True
            return entry.get("shimmed") == True  # noqa: E712
        return False

    return is_shimmed
def get_qualified_cpp_name(
    message_db: List[MessageType], data_types, field: SerdeField
):
    """
    Get the full C++ type spelled out for a field. Eg:

    std::optional<photon::TargetCorner>
    std::vector<frc::Transform3d>

    Args:
        message_db: All loaded message definitions.
        data_types: Type table used to resolve non-shimmed types. Note that
            this parameter shadows the module-level table of the same name,
            while the shim check still consults the module-level one.
        field: The field whose C++ type is wanted.

    Returns:
        The base C++ type, wrapped in std::optional<> when the field is
        optional, otherwise in std::vector<> when it is a vla.
    """
    type_name = field["type"]

    # Shimmed messages carry their own C++ type; everything else comes from the table
    if get_shimmed_filter(message_db)(type_name):
        cpp_type = get_message_by_name(message_db, type_name)["cpp_type"]
    else:
        cpp_type = data_types[type_name]["cpp_type"]

    # "optional" takes precedence over "vla" if both happen to be set
    if field.get("optional") == True:  # noqa: E712
        return f"std::optional<{cpp_type}>"
    if field.get("vla") == True:  # noqa: E712
        return f"std::vector<{cpp_type}>"
    return cpp_type
def get_message_by_name(message_db: List[MessageType], message_name: str):
    """
    Return the first message in `message_db` whose "name" equals `message_name`.

    Raises:
        Exception: If no message has that name.
    """
    candidates = (entry for entry in message_db if entry["name"] == message_name)
    try:
        found = next(candidates)
    except StopIteration as err:
        raise Exception("Could not find " + message_name) from err
    return found
def get_field_by_name(message: MessageType, field_name: str):
    """
    Return the first field of `message` whose "name" equals `field_name`.

    Raises:
        Exception: If the message has no such field. A descriptive error is
            raised, in the same style as get_message_by_name, instead of
            letting a bare StopIteration escape; inside a generator that would
            be turned into an unrelated RuntimeError.
    """
    try:
        return next(f for f in message["fields"] if f["name"] == field_name)
    except StopIteration as e:
        raise Exception(
            "Could not find field " + field_name + " in " + message["name"]
        ) from e
def get_message_hash(message_db: List[MessageType], message: MessageType) -> str:
    """
    Calculate a unique message hash via MD5 sum. This is a very similar approach to rosmsg, documented:
    http://wiki.ros.org/ROS/Technical%20Overview#Message_serialization_and_msg_MD5_sums

    The hash is the MD5 of the message's schema string. In that string,
    non-intrinsic, non-shimmed field types appear together with the hash of
    their own definition, so a change to a nested message changes the hash of
    every message containing it.

    Side effects: stores the result on `message` under "message_hash" and the
    schema string under "schema_str". The same is done for every message
    reachable through its non-intrinsic fields.

    Args:
        message_db: All loaded message definitions.
        message: The message to hash; mutated as described above.

    Returns:
        The hex MD5 digest of the schema string.

    Raises:
        Exception: If a field refers to a message missing from message_db.
        UnicodeEncodeError: If the schema string contains non-ASCII characters.
    """
    # Hash every user-defined sub-message first. Building the schema string
    # below reads their stored "message_hash". There is no cycle detection
    # here, so message definitions are expected not to reference each other
    # circularly.
    for field in message["fields"]:
        if not is_intrinsic_type(field["type"]):
            sub_message = get_message_by_name(message_db, field["type"])
            get_message_hash(message_db, sub_message)

    schema = get_struct_schema_str(message, message_db)
    message_hash = hashlib.md5(schema.encode("ascii")).hexdigest()

    # and remember the hash
    message["message_hash"] = message_hash
    message["schema_str"] = schema

    return message_hash
def get_includes(db, message: MessageType) -> List[str]:
    """
    Collect the C++ include entries needed by a message's fields.

    Args:
        db: All loaded message definitions.
        message: The message whose fields are inspected.

    Returns:
        A sorted list of unique include entries, each already wrapped in
        quotes or angle brackets. "<stdint.h>" is always present.

    Raises:
        Exception: If a field refers to a message missing from db.
    """
    # stdint types are always needed
    includes = {"<stdint.h>"}

    for field in message["fields"]:
        if not is_intrinsic_type(field["type"]):
            field_msg = get_message_by_name(db, field["type"])

            if "shimmed" in field_msg and field_msg["shimmed"] == True:  # noqa: E712
                # Shimmed messages name their own header
                includes.add(field_msg["cpp_include"])
            else:
                # must be a photon type.
                includes.add(f"\"photon/targeting/{field_msg['name']}.h\"")

        if "optional" in field and field["optional"] == True:  # noqa: E712
            includes.add("<optional>")
        if "vla" in field and field["vla"] == True:  # noqa: E712
            includes.add("<vector>")

    return sorted(includes)
def parse_yaml() -> List[MessageType]:
    """Load and return the message definitions from messages.yaml next to this script."""
    return yaml_to_dict("messages.yaml")
# Intrinsic type names that are rewritten to a width-explicit spelling when a
# field's fully-defined type name is built (see get_fully_defined_field_name).
# Names not listed here are used unchanged.
INTRINSIC_TYPE_ALIASES = {
    "float": "float32",
    "double": "float64",
}
def get_fully_defined_field_name(field: SerdeField, message_db: List[MessageType]):
    """
    Get the fully-defined, globally unique type name for a field.

    - Intrinsic types: the type name, with float/double normalized to
      float32/float64 (eg int64 stays int64).
    - Shimmed messages: the bare type name.
    - Any other message: "<type name>:<message hash>", using the hash already
      stored on that message.

    Args:
        field: The field we want the name of
        message_db: All other loaded messages

    Raises:
        Exception: If the type is not intrinsic and not found in message_db.
        KeyError: If a non-shimmed message has no "message_hash" stored yet.
    """
    type_name = field["type"]

    if is_intrinsic_type(type_name):
        # handle replacing float/doubles
        return INTRINSIC_TYPE_ALIASES.get(type_name, type_name)

    sub_message = get_message_by_name(message_db, type_name)
    if get_shimmed_filter(message_db)(type_name):
        return type_name

    return type_name + ":" + sub_message["message_hash"]
def get_struct_schema_str(message: MessageType, message_db: List[MessageType]):
    """
    Build the schema string of a message: one "<type> <name>;" entry per field,
    concatenated in field order with no separator.

    Optional fields get an "optional " prefix on the type, vla fields get a
    "[?]" suffix on the name.

    Raises:
        Exception: If a field is flagged both optional and vla.
    """

    def render_field(field):
        is_optional = field.get("optional") == True  # noqa: E712
        is_vla = field.get("vla") == True  # noqa: E712

        # The two modifiers are mutually exclusive
        if is_optional and is_vla:
            raise Exception(f"Field {field} must be optional OR vla!")

        type_name = get_fully_defined_field_name(field, message_db)
        prefix = "optional " if is_optional else ""
        suffix = "[?]" if is_vla else ""
        return f"{prefix}{type_name} {field['name']}{suffix};"

    return "".join(render_field(field) for field in message["fields"])
def generate_photon_messages(cpp_java_root, py_root, template_root):
    """
    Render the Java, C++ and Python serde sources for every non-shimmed message.

    Loads the message definitions, computes their hashes, then for each
    message that is not shimmed renders five templates:

    - Message.java.jinja      -> <cpp_java_root>/main/java/org/photonvision/struct/<Name>Serde.java
    - ThingSerde.h.jinja      -> <cpp_java_root>/main/native/include/photon/serde/<Name>Serde.h
    - ThingSerde.cpp.jinja    -> <cpp_java_root>/main/native/cpp/photon/serde/<Name>Serde.cpp
    - ThingStruct.h.jinja     -> <cpp_java_root>/main/native/include/photon/struct/<Name>Struct.h
    - ThingSerde.py.jinja     -> <py_root>/<Name>Serde.py

    Output directories are created if missing; existing files are overwritten.

    Args:
        cpp_java_root: Root directory for the generated Java and C++ files.
        py_root: Directory for the generated Python files.
        template_root: Directory the jinja templates are loaded from.
    """
    messages = parse_yaml()

    # Populates "message_hash" / "schema_str" on every message up front
    for message in messages:
        message["message_hash"] = get_message_hash(messages, message)

    env = Environment(loader=FileSystemLoader(str(template_root)))

    # One shim predicate is enough; it only closes over `messages`
    is_shimmed = get_shimmed_filter(messages)

    env.filters["is_intrinsic"] = is_intrinsic_type
    env.filters["is_shimmed"] = is_shimmed

    # add our custom types
    extended_data_types = data_types.copy()
    for message in messages:
        name = message["name"]
        extended_data_types[name] = {
            "len": -1,
            "java_type": name,
            "cpp_type": "photon::" + name,
        }

    java_output_dir = Path(cpp_java_root) / "main/java/org/photonvision/struct"
    cpp_serde_header_dir = Path(cpp_java_root) / "main/native/include/photon/serde/"
    cpp_serde_source_dir = Path(cpp_java_root) / "main/native/cpp/photon/serde/"
    cpp_struct_header_dir = Path(cpp_java_root) / "main/native/include/photon/struct/"
    py_serde_source_dir = Path(py_root)
    for output_dir in (
        java_output_dir,
        cpp_serde_header_dir,
        cpp_serde_source_dir,
        cpp_struct_header_dir,
        py_serde_source_dir,
    ):
        output_dir.mkdir(parents=True, exist_ok=True)

    env.filters["get_qualified_name"] = lambda field: get_qualified_cpp_name(
        messages, extended_data_types, field
    )

    # Message getter exposed to the templates as a global
    def lookup_message(name):
        return get_message_by_name(messages, name)

    for message in messages:
        # don't generate shimmed types
        if is_shimmed(message["name"]):
            continue

        message = cast(MessageType, message)
        message_name = message["name"]

        # All five templates are loaded before anything is written for this message
        outputs = [
            (
                f"{message_name}Serde.java",
                env.get_template("Message.java.jinja"),
                java_output_dir,
            ),
            (
                f"{message_name}Serde.h",
                env.get_template("ThingSerde.h.jinja"),
                cpp_serde_header_dir,
            ),
            (
                f"{message_name}Serde.cpp",
                env.get_template("ThingSerde.cpp.jinja"),
                cpp_serde_source_dir,
            ),
            (
                f"{message_name}Struct.h",
                env.get_template("ThingStruct.h.jinja"),
                cpp_struct_header_dir,
            ),
            (
                f"{message_name}Serde.py",
                env.get_template("ThingSerde.py.jinja"),
                py_serde_source_dir,
            ),
        ]

        message_hash = get_message_hash(messages, message)

        # Everything below depends only on the message, not on the template,
        # so it is computed once per message rather than once per output file.
        nested_types = [
            field["type"]
            for field in message["fields"]
            if not is_intrinsic_type(field["type"])
        ]
        nested_photon_types = {t for t in nested_types if not is_shimmed(t)}
        nested_wpilib_types = {t for t in nested_types if is_shimmed(t)}

        render_kwargs = {
            "type_map": extended_data_types,
            "message_fmt": get_struct_schema_str(message, messages),
            "message_hash": message_hash,
            "cpp_includes": get_includes(messages, message),
            "nested_photon_types": nested_photon_types,
            "nested_wpilib_types": nested_wpilib_types,
        }

        for output_name, template, output_folder in outputs:
            # Hack in our message getter
            template.globals["get_message_by_name"] = lookup_message

            output_file = output_folder / output_name
            output_file.write_text(
                template.render(message, **render_kwargs),
                encoding="utf-8",
            )
def main(argv):
    """
    Parse the command line and run the generator.

    Args:
        argv: Command-line arguments without the program name.

    All three options are optional paths; their defaults are resolved relative
    to the location of this script.
    """
    script_dir = Path(__file__).resolve().parent
    project_root = script_dir.parent

    # (flag, help text, default path) for every supported option
    path_options = (
        (
            "--cpp_java_output_dir",
            "Optional. If set, will output the generated files to this directory, otherwise it will use a path relative to the script",
            project_root / "photon-targeting/src/generated",
        ),
        (
            "--py_output_dir",
            "Optional. If set, will spit Python serde files here",
            project_root / "photon-lib/py/photonlibpy/generated",
        ),
        (
            "--template_root",
            "Optional. If set, will use this directory as the root for the jinja templates",
            script_dir / "templates",
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, help_text, default_path in path_options:
        parser.add_argument(flag, help=help_text, default=default_path, type=Path)

    options = parser.parse_args(argv)
    generate_photon_messages(
        options.cpp_java_output_dir, options.py_output_dir, options.template_root
    )
# Script entry point: when run directly, pass the command-line arguments
# (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])