Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nest message/enum names #597

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,55 @@ class OutputTemplate:
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True

def structure(self):
def recursive_structure(descriptor_proto):
branch = {}
for msg in descriptor_proto.nested_type:
branch[msg.name] = {"children": recursive_structure(msg), "kind": "msg"}
for enum_ in descriptor_proto.enum_type:
branch[enum_.name] = {"kind": "enum"}
return branch

tree = {}
for msg in self.package_proto_obj.message_type:
tree[msg.name] = {"children": recursive_structure(msg), "kind": "msg"}
for enum_ in self.package_proto_obj.enum_type:
tree[enum_.name] = {"kind": "enum"}

return {"root": tree}

def structure_with_obj(self):
def recursive_structure(descriptor_proto):
branch = {}
for msg in descriptor_proto.nested_type:
branch[msg.name] = {
"children": recursive_structure(msg),
"kind": "msg",
"obj": self.messages[msg.name],
}
for enum_ in descriptor_proto.enum_type:
branch[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]}
return branch

tree = {}
for msg in self.package_proto_obj.message_type:
tree[msg.name] = {
"children": recursive_structure(msg),
"kind": "msg",
"obj": self.messages[msg.name],
}
for enum_ in self.package_proto_obj.enum_type:
tree[enum_.name] = {"kind": "enum", "obj": self.enums[enum_.name]}

return {"root": tree}

@property
def package(self) -> str:
"""Name of input package.
Expand Down Expand Up @@ -305,9 +347,9 @@ def __post_init__(self) -> None:
# Add message to output file
if isinstance(self.parent, OutputTemplate):
if isinstance(self, EnumDefinitionCompiler):
self.output_file.enums.append(self)
self.output_file.enums[self.proto_name] = self
else:
self.output_file.messages.append(self)
self.output_file.messages[self.proto_name] = self
self.deprecated = self.proto_obj.options.deprecated
super().__post_init__()

Expand Down
2 changes: 1 addition & 1 deletion src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
output_paths.add(output_path)

# Render and then format the output file
response.file.append(
CodeGeneratorResponseFile(
name=str(output_path),
# Render and then format the output file
content=outputfile_compiler(output_file=output_package),
)
)
Expand Down
70 changes: 47 additions & 23 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,30 @@ if TYPE_CHECKING:
{% endfor %}
{% endif %}

{% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{#


original markdown is missing until I finish debugging this


#}

{% set tree = output_file.structure() %}
"""
{{ tree|tojson(indent=4) }}
"""

{% macro render_class(tree_, parent=none) %}
{% for k, v in tree_.items() %}
{% if v.kind == "enum" %}
{% set enum = v.obj %}

{% set enum_name = enum.py_name %}
{% if parent is not none and enum.py_name.startswith(parent) %}
{% set enum_name = enum.py_name[parent|length:] %}
{% endif %}

class {{ enum_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}

Expand All @@ -61,17 +83,23 @@ class {{ enum.py_name }}(betterproto.Enum):

{% endif %}
{% endfor %}
{% else %}
{% set message = v.obj %}


{% endfor %}
{% set msg_name = message.py_name %}
{% if parent is not none and message.py_name.startswith(parent) %}
{% set msg_name = message.py_name[parent|length:] %}
{% endif %}
{% for message in output_file.messages %}

@dataclass(eq=False, repr=False)
class {{ message.py_name }}(betterproto.Message):
class {{ msg_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}

{% endif %}

{{ render_class(v.children, parent=message.py_name)|indent }}

{% for field in message.fields %}
{{ field.get_field_string() }}
{% if field.comment %}
Expand All @@ -83,25 +111,21 @@ class {{ message.py_name }}(betterproto.Message):
pass
{% endif %}

{% if message.deprecated or message.has_deprecated_fields %}
def __post_init__(self) -> None:
{% if message.deprecated %}
warnings.warn("{{ message.py_name }} is deprecated", DeprecationWarning)
{% endif %}
super().__post_init__()
{% for field in message.deprecated_fields %}
if self.is_set("{{ field }}"):
warnings.warn("{{ message.py_name }}.{{ field }} is deprecated", DeprecationWarning)
{% endfor %}
{% endif %}
{% endif %}
{% endfor %}
{% endmacro %}

{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
{% set tree_with_objs = output_file.structure_with_obj() %}
{{ render_class(tree_with_objs["root"]) }}

{#


original markdown is missing until I finish debugging this


#}

{% endfor %}
{% for service in output_file.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
Expand Down
45 changes: 45 additions & 0 deletions tests/inputs/nested_enum/nested_enum.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
syntax = "proto3";

package nested_enum;

message Test {
enum Inner {
NONE = 0;
THIS = 1;
}
Inner status = 1;

message Doubly {
enum Inner {
NONE = 0;
THIS = 1;
}
Inner status = 1;
}
}


message TestInner {
int32 foo = 1;
}

message TestDoublyInner {
int32 foo = 1;
string bar = 2;
}

enum Outer {
foo = 0;
bar = 1;
}

message Content {
message Status {
string code = 1;
}
Status status = 1;
}

message ContentStatus {
int32 id = 1;
}
3 changes: 3 additions & 0 deletions tests/test_nested_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import betterproto
from dataclasses import dataclass

9 changes: 9 additions & 0 deletions tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class NestedData(betterproto.Message):
)



class Color(betterproto.Enum):
BLUE = 0
GREEN = 1


@dataclass(eq=False, repr=False)
class Complex(betterproto.Message):
foo_str: str = betterproto.string_field(1)
Expand All @@ -55,6 +61,7 @@ class Complex(betterproto.Message):
mapping: Dict[str, "google.Any"] = betterproto.map_field(
7, betterproto.TYPE_STRING, betterproto.TYPE_MESSAGE
)
color: "Color" = betterproto.enum_field(8)


def complex_msg():
Expand All @@ -81,6 +88,7 @@ def complex_msg():
"message": google.Any(value=bytes(Fi(abc="hi"))),
"string": google.Any(value=b"howdy"),
},
color=Color.BLUE,
)


Expand All @@ -89,6 +97,7 @@ def test_pickling_complex_message():
deser = unpickled(msg)
assert msg == deser
assert msg.fe.abc == "1"
assert msg.color == Color.BLUE
assert msg.is_set("fi") is not True
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
assert msg.mapping["string"].value.decode() == "howdy"
Expand Down