diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index ea819d44d..618a4865e 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -248,13 +248,55 @@ class OutputTemplate: typing_imports: Set[str] = field(default_factory=set) pydantic_imports: Set[str] = field(default_factory=set) builtins_import: bool = False - messages: List["MessageCompiler"] = field(default_factory=list) - enums: List["EnumDefinitionCompiler"] = field(default_factory=list) + messages: Dict[str, "MessageCompiler"] = field(default_factory=dict) + enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict) services: List["ServiceCompiler"] = field(default_factory=list) imports_type_checking_only: Set[str] = field(default_factory=set) pydantic_dataclasses: bool = False output: bool = True + def structure(self): + def recursive_structure(descriptor_proto): + branch = {} + for msg in descriptor_proto.nested_type: + branch[msg.name] = {"children": recursive_structure(msg), "kind": "msg"} + for enum_ in descriptor_proto.enum_type: + branch[enum_.name] = {"kind": "enum"} + return branch + + tree = {} + for msg in self.package_proto_obj.message_type: + tree[msg.name] = {"children": recursive_structure(msg), "kind": "msg"} + for enum_ in self.package_proto_obj.enum_type: + tree[enum_.name] = {"kind": "enum"} + + return {"root": tree} + + def structure_with_obj(self): + def recursive_structure(descriptor_proto): + branch = {} + for msg in descriptor_proto.nested_type: + branch[msg.name] = { + "children": recursive_structure(msg), + "kind": "msg", + "obj": self.messages[msg.name], + } + for enum_ in descriptor_proto.enum_type: + branch[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]} + return branch + + tree = {} + for msg in self.package_proto_obj.message_type: + tree[msg.name] = { + "children": recursive_structure(msg), + "kind": "msg", + "obj": self.messages[msg.name], + } + for enum_ in self.package_proto_obj.enum_type: + tree[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]} + + return {"root": tree} + @property def package(self) -> str: """Name of input package. @@ -305,9 +347,9 @@ def __post_init__(self) -> None: # Add message to output file if isinstance(self.parent, OutputTemplate): if isinstance(self, EnumDefinitionCompiler): - self.output_file.enums.append(self) + self.output_file.enums[self.proto_name] = self else: - self.output_file.messages.append(self) + self.output_file.messages[self.proto_name] = self self.deprecated = self.proto_obj.options.deprecated super().__post_init__() diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 5b6715605..ef537be9a 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -48,7 +48,23 @@ if TYPE_CHECKING: {% endfor %} {% endif %} -{% if output_file.enums %}{% for enum in output_file.enums %} +{# + + +original markdown is missing until I finish debugging this + + +#} + +{% set tree = output_file.structure() %} +""" +{{ tree|tojson(indent=4) }} +""" + +{% macro render_class(tree_) %} +{% for k, v in tree_.items() %} +{% if v.kind == "enum" %} +{% set enum = v.obj %} class {{ enum.py_name }}(betterproto.Enum): {% if enum.comment %} {{ enum.comment }} @@ -61,17 +77,17 @@ class {{ enum.py_name }}(betterproto.Enum): {% endif %} {% endfor %} - - -{% endfor %} -{% endif %} -{% for message in output_file.messages %} +{% else %} +{% set message = v.obj %} @dataclass(eq=False, repr=False) class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} {% endif %} + + {{ render_class(v.children)|indent }} + {% for field in message.fields %} {{ field.get_field_string() }} {% if field.comment %} @@ -83,25 +99,21 @@ class {{ message.py_name }}(betterproto.Message): pass {% endif %} - {% if message.deprecated or message.has_deprecated_fields %} - def __post_init__(self) -> None: - {% if message.deprecated %} - warnings.warn("{{ message.py_name }} is deprecated", DeprecationWarning) - {% endif %} - super().__post_init__() - {% for field in message.deprecated_fields %} - if self.is_set("{{ field }}"): - warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning) - {% endfor %} - {% endif %} +{% endif %} +{% endfor %} +{% endmacro %} - {% if output_file.pydantic_dataclasses and message.has_oneof_fields %} - @root_validator() - def check_oneof(cls, values): - return cls._validate_field_groups(values) - {% endif %} +{% set tree_with_objs = output_file.structure_with_obj() %} +{{ render_class(tree_with_objs["root"]) }} + +{# + + +original markdown is missing until I finish debugging this + + +#} -{% endfor %} {% for service in output_file.services %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} diff --git a/tests/inputs/nested_enum/nested_enum.proto b/tests/inputs/nested_enum/nested_enum.proto new file mode 100644 index 000000000..570dcd2ae --- /dev/null +++ b/tests/inputs/nested_enum/nested_enum.proto @@ -0,0 +1,45 @@ +syntax = "proto3"; + +package nested_enum; + +message Test { + enum Inner { + NONE = 0; + THIS = 1; + } + Inner status = 1; + + message Doubly { + enum Inner { + NONE = 0; + THIS = 1; + } + Inner status = 1; + } +} + + +message TestInner { + int32 foo = 1; +} + +message TestDoublyInner { + int32 foo = 1; + string bar = 2; +} + +enum Outer { + foo = 0; + bar = 1; +} + +message Content { + message Status { + string code = 1; + } + Status status = 1; +} + +message ContentStatus { + int32 id = 1; +} diff --git a/tests/test_nested_enums.py b/tests/test_nested_enums.py new file mode 100644 index 000000000..b7f3185b0 --- /dev/null +++ b/tests/test_nested_enums.py @@ -0,0 +1,3 @@ +import betterproto +from dataclasses import dataclass +