diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 8eee6f4c4..6719b4917 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1515,24 +1515,30 @@ def to_dict( else: output[cased_name] = b64encode(value).decode("utf8") elif meta.proto_type == TYPE_ENUM: + def name(enum_class, value): + obj = enum_class(value) + if hasattr(obj.__class__, 'full_name') and isinstance(obj.__class__.full_name, property): + return obj.full_name + return obj.name if field_is_repeated: enum_class = field_types[field_name].__args__[0] if isinstance(value, typing.Iterable) and not isinstance( value, str ): - output[cased_name] = [enum_class(el).name for el in value] + output[cased_name] = [ + name(enum_class, el) for el in value] else: # transparently upgrade single value to repeated - output[cased_name] = [enum_class(value).name] + output[cased_name] = [name(enum_class, value)] elif value is None: if include_default_values: output[cased_name] = value elif meta.optional: enum_class = field_types[field_name].__args__[0] - output[cased_name] = enum_class(value).name + output[cased_name] = name(enum_class, value) else: enum_class = field_types[field_name] # noqa - output[cased_name] = enum_class(value).name + output[cased_name] = name(enum_class, value) elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): if field_is_repeated: output[cased_name] = [_dump_float(n) for n in value] @@ -1592,10 +1598,13 @@ def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: ) elif meta.proto_type == TYPE_ENUM: enum_cls = cls._betterproto.cls_by_field[field_name] + + def obj(enum_class, value): + return enum_class.from_full_name(value) if hasattr(enum_class, 'from_full_name') else enum_class.from_string(value) if isinstance(value, list): - value = [enum_cls.from_string(e) for e in value] + value = [obj(enum_cls, e) for e in value] elif isinstance(value, str): - value = enum_cls.from_string(value) + value = obj(enum_cls, value) elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): value = ( [_parse_float(n) for n in value] diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c1eb9d7ed..ca14e82a3 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -65,6 +65,7 @@ from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest from .. import which_one_of +from ..casing import sanitize_name from ..compile.importing import ( get_type_reference, parse_source_type_name, @@ -639,6 +640,7 @@ class EnumEntry: """Representation of an Enum entry.""" name: str + full_name: str value: int comment: str @@ -649,6 +651,7 @@ def __post_init__(self) -> None: name=pythonize_enum_member_name( entry_proto_value.name, self.proto_obj.name ), + full_name=sanitize_name(entry_proto_value.name), value=entry_proto_value.number, comment=get_comment( proto_file=self.source_file, path=self.path + [2, entry_number] diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index e8ed3d8f4..56ab864fd 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -20,6 +20,26 @@ class {{ enum.py_name }}(betterproto.Enum): return core_schema.int_schema(ge=0) {% endif %} + @property + def full_name(self): + return {{ enum.py_name }}_full_name_map[self.value] + + @classmethod + def from_full_name(cls, full_name): + return cls.from_string({{ enum.py_name }}_full_name_reverse_map[full_name]) + + +{{ enum.py_name }}_full_name_map = { +{% for entry in enum.entries %} + {{ entry.value }}: "{{ entry.full_name }}", +{% endfor %} +} + +{{ enum.py_name }}_full_name_reverse_map = { +{% for entry in enum.entries %} + "{{ entry.full_name }}": "{{ entry.name }}", +{% endfor %} +} {% endfor %} {% endif %} {% for message in output_file.messages %} diff --git a/tests/inputs/enum/enum.proto b/tests/inputs/enum/enum.proto index 5e2e80c1f..073f2e061 100644 --- a/tests/inputs/enum/enum.proto +++ b/tests/inputs/enum/enum.proto @@ -6,6 +6,7 @@ package enum; message Test { Choice choice = 1; repeated Choice choices = 2; + ArithmeticOperator op = 3; } enum Choice { diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py index 21a5ac3b9..1476d3dc2 100644 --- a/tests/inputs/enum/test_enum.py +++ b/tests/inputs/enum/test_enum.py @@ -1,3 +1,6 @@ +import betterproto +from dataclasses import dataclass + from tests.output_betterproto.enum import ( ArithmeticOperator, Choice, @@ -112,3 +115,50 @@ def test_renamed_enum_members(): "MINUS", "_0_PREFIXED", } + + +def test_enum_full_name(): + assert ArithmeticOperator.NONE.full_name == "ARITHMETIC_OPERATOR_NONE" + assert ArithmeticOperator.PLUS.full_name == "ARITHMETIC_OPERATOR_PLUS" + assert ArithmeticOperator._0_PREFIXED.full_name == "ARITHMETIC_OPERATOR_0_PREFIXED" + + +def test_enum_to_json(): + assert Test(op=ArithmeticOperator.NONE).to_json() == '{}' + assert Test(op=ArithmeticOperator.PLUS).to_json( + ) == '{"op": "ARITHMETIC_OPERATOR_PLUS"}' + assert Test(op=ArithmeticOperator._0_PREFIXED).to_json( + ) == '{"op": "ARITHMETIC_OPERATOR_0_PREFIXED"}' + + +def test_enum_from_json(): + assert Test().from_json('{}').op == ArithmeticOperator.NONE + assert Test().from_json( + '{"op": "ARITHMETIC_OPERATOR_PLUS"}').op == ArithmeticOperator.PLUS + assert Test().from_json( + '{"op": "ARITHMETIC_OPERATOR_0_PREFIXED"}').op == ArithmeticOperator._0_PREFIXED + + +class EnumCompat(betterproto.Enum): + NONE = 0 + PLUS = 1 + MINUS = 2 + + +@dataclass(eq=False, repr=False) +class CompatTest(betterproto.Message): + enum: "EnumCompat" = betterproto.enum_field(1) + + +def test_enum_to_json_backwards_compat(): + assert CompatTest(enum=EnumCompat.NONE).to_json() == '{}' + assert CompatTest(enum=EnumCompat.PLUS).to_json( + ) == '{"enum": "PLUS"}' + assert CompatTest(enum=EnumCompat.MINUS).to_json( + ) == '{"enum": "MINUS"}' + + +def test_enum_from_json_backwards_compat(): + assert CompatTest().from_json('{}').enum == EnumCompat.NONE + assert CompatTest().from_json('{"enum": "PLUS"}').enum == EnumCompat.PLUS + assert CompatTest().from_json('{"enum": "MINUS"}').enum == EnumCompat.MINUS