Skip to content

Commit

Permalink
Figuring out nested enum collissions - #212
Browse files Browse the repository at this point in the history
  • Loading branch information
cetanu committed Oct 16, 2023
1 parent 8659c51 commit b9ec59f
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 27 deletions.
50 changes: 46 additions & 4 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,55 @@ class OutputTemplate:
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True

def structure(self):
def recursive_structure(descriptor_proto):
branch = {}
for msg in descriptor_proto.nested_type:
branch[msg.name] = {"children": recursive_structure(msg), "kind": "msg"}
for enum_ in descriptor_proto.enum_type:
branch[enum_.name] = {"kind": "enum"}
return branch

tree = {}
for msg in self.package_proto_obj.message_type:
tree[msg.name] = {"children": recursive_structure(msg), "kind": "msg"}
for enum_ in self.package_proto_obj.enum_type:
tree[enum_.name] = {"kind": "enum"}

return {"root": tree}

def structure_with_obj(self):
def recursive_structure(descriptor_proto):
branch = {}
for msg in descriptor_proto.nested_type:
branch[msg.name] = {
"children": recursive_structure(msg),
"kind": "msg",
"obj": self.messages[msg.name],
}
for enum_ in descriptor_proto.enum_type:
branch[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]}
return branch

tree = {}
for msg in self.package_proto_obj.message_type:
tree[msg.name] = {
"children": recursive_structure(msg),
"kind": "msg",
"obj": self.messages[msg.name],
}
for enum_ in self.package_proto_obj.enum_type:
tree[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]}

return {"root": tree}

@property
def package(self) -> str:
"""Name of input package.
Expand Down Expand Up @@ -305,9 +347,9 @@ def __post_init__(self) -> None:
# Add message to output file
if isinstance(self.parent, OutputTemplate):
if isinstance(self, EnumDefinitionCompiler):
self.output_file.enums.append(self)
self.output_file.enums[self.proto_name] = self
else:
self.output_file.messages.append(self)
self.output_file.messages[self.proto_name] = self
self.deprecated = self.proto_obj.options.deprecated
super().__post_init__()

Expand Down
58 changes: 35 additions & 23 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,23 @@ if TYPE_CHECKING:
{% endfor %}
{% endif %}

{% if output_file.enums %}{% for enum in output_file.enums %}
{#
original markdown is missing until I finish debugging this
#}

{% set tree = output_file.structure() %}
"""
{{ tree|tojson(indent=4) }}
"""

{% macro render_class(tree_) %}
{% for k, v in tree_.items() %}
{% if v.kind == "enum" %}
{% set enum = v.obj %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}
Expand All @@ -61,17 +77,17 @@ class {{ enum.py_name }}(betterproto.Enum):

{% endif %}
{% endfor %}


{% endfor %}
{% endif %}
{% for message in output_file.messages %}
{% else %}
{% set message = v.obj %}
@dataclass(eq=False, repr=False)
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}

{% endif %}

{{ render_class(v.children)|indent }}

{% for field in message.fields %}
{{ field.get_field_string() }}
{% if field.comment %}
Expand All @@ -83,25 +99,21 @@ class {{ message.py_name }}(betterproto.Message):
pass
{% endif %}

{% if message.deprecated or message.has_deprecated_fields %}
def __post_init__(self) -> None:
{% if message.deprecated %}
warnings.warn("{{ message.py_name }} is deprecated", DeprecationWarning)
{% endif %}
super().__post_init__()
{% for field in message.deprecated_fields %}
if self.is_set("{{ field }}"):
warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning)
{% endfor %}
{% endif %}
{% endif %}
{% endfor %}
{% endmacro %}

{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
{% set tree_with_objs = output_file.structure_with_obj() %}
{{ render_class(tree_with_objs["root"]) }}

{#
original markdown is missing until I finish debugging this
#}

{% endfor %}
{% for service in output_file.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
Expand Down
45 changes: 45 additions & 0 deletions tests/inputs/nested_enum/nested_enum.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
syntax = "proto3";

package nested_enum;

message Test {
enum Inner {
NONE = 0;
THIS = 1;
}
Inner status = 1;

message Doubly {
enum Inner {
NONE = 0;
THIS = 1;
}
Inner status = 1;
}
}


message TestInner {
int32 foo = 1;
}

message TestDoublyInner {
int32 foo = 1;
string bar = 2;
}

enum Outer {
foo = 0;
bar = 1;
}

message Content {
message Status {
string code = 1;
}
Status status = 1;
}

message ContentStatus {
int32 id = 1;
}
3 changes: 3 additions & 0 deletions tests/test_nested_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import betterproto
from dataclasses import dataclass

0 comments on commit b9ec59f

Please sign in to comment.