Skip to content

Commit

Permalink
Treat -0.0 as consistently distinct from 0.0 in pure python
Browse files Browse the repository at this point in the history
The previous treatment was a conformance violation: implicit-presence float fields holding the non-default value -0.0 could be dropped.

PiperOrigin-RevId: 705728806
  • Loading branch information
mkruskal-google authored and copybara-github committed Dec 13, 2024
1 parent 5188672 commit bc16fe8
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 20 deletions.
10 changes: 0 additions & 10 deletions conformance/text_format_failure_list_python.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
# This is the list of text format conformance tests that are known to fail right
# now.
# TODO: These should be fixed.
Required.*.TextFormatInput.FloatFieldNegativeZero.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.FloatFieldNegativeZero.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.FloatFieldNegativeZero_F.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.FloatFieldNegativeZero_F.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.FloatFieldNegativeZero_f.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.FloatFieldNegativeZero_f.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.FloatFieldNoNegativeOctal # Should have failed to parse, but didn't.
Required.*.TextFormatInput.FloatFieldNoOctal # Should have failed to parse, but didn't.
Required.*.TextFormatInput.NegDoubleFieldLargeNegativeExponentParsesAsNegZero.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_double: -0
Required.*.TextFormatInput.NegDoubleFieldLargeNegativeExponentParsesAsNegZero.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_double: -0
Required.*.TextFormatInput.NegFloatFieldLargeNegativeExponentParsesAsNegZero.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.NegFloatFieldLargeNegativeExponentParsesAsNegZero.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0
Required.*.TextFormatInput.StringLiteralBasicEscapesBytes.ProtobufOutput # Output was not equivalent to reference message: modified: optional_bytes: "\007\010\014\n\r\t\013?\\\'\"" -> "\007\010\014\n\r\t
Required.*.TextFormatInput.StringLiteralBasicEscapesBytes.TextFormatOutput # Output was not equivalent to reference message: modified: optional_bytes: "\007\010\014\n\r\t\013?\\\'\"" -> "\007\010\014\n\r\t
Required.*.TextFormatInput.StringLiteralBasicEscapesString.ProtobufOutput # Output was not equivalent to reference message: modified: optional_string: "\007\010\014\n\r\t\013?\\\'\"" -> "\007\010\014\n\r\
Expand Down
30 changes: 26 additions & 4 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
__author__ = '[email protected] (Kenton Varda)'

import math
import numbers
import struct

from google.protobuf import message
Expand All @@ -71,6 +72,27 @@
_DecodeError = message.DecodeError


def IsDefaultScalarValue(value):
  """Returns whether or not a scalar value is the default value of its type.

  Specifically, this should be used to determine presence of implicit-presence
  fields, where we disallow custom defaults.

  Args:
    value: A scalar value to check.

  Returns:
    True if the value is equivalent to a default value, False otherwise.
  """
  if value:
    # Any truthy value is non-default. Deciding this first keeps values that
    # cannot be converted to a float (e.g. very large ints) away from
    # math.copysign below, and skips the ABC isinstance check for them.
    return False

  # The value is falsy. Negative zero is falsy too, but it is not the default:
  # it compares equal to 0.0, so only its sign bit tells the two apart.
  if isinstance(value, numbers.Number) and math.copysign(1.0, value) < 0:
    return False

  return True


def _VarintDecoder(mask, result_type):
"""Return an encoder for a basic varint value (does not include tag).
Expand Down Expand Up @@ -237,7 +259,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
(new_value, pos) = decode_value(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
if clear_if_default and not new_value:
if clear_if_default and IsDefaultScalarValue(new_value):
field_dict.pop(key, None)
else:
field_dict[key] = new_value
Expand Down Expand Up @@ -478,7 +500,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
if clear_if_default and not enum_value:
if clear_if_default and IsDefaultScalarValue(enum_value):
field_dict.pop(key, None)
return pos
# pylint: disable=protected-access
Expand Down Expand Up @@ -573,7 +595,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
if clear_if_default and not size:
if clear_if_default and IsDefaultScalarValue(size):
field_dict.pop(key, None)
else:
field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
Expand Down Expand Up @@ -614,7 +636,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
if clear_if_default and not size:
if clear_if_default and IsDefaultScalarValue(size):
field_dict.pop(key, None)
else:
field_dict[key] = buffer[pos:new_pos].tobytes()
Expand Down
12 changes: 11 additions & 1 deletion python/google/protobuf/internal/decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import wire_format

from google.protobuf.internal import _parameterized


_INPUT_BYTES = b'\x84r\x12'
_EXPECTED = (14596, 18)


@testing_refleaks.TestCase
class DecoderTest(unittest.TestCase):
class DecoderTest(_parameterized.TestCase):

def test_decode_varint_bytes(self):
(size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0)
Expand Down Expand Up @@ -128,6 +130,14 @@ def test_unknown_message_set_decoder_mismatched_end_group(self):
memoryview(b'\054\014'),
)

@_parameterized.parameters(0, 0.0, False, '')
def test_default_scalar(self, value):
  """Checks that each zero-like scalar is reported as a default value."""
  is_default = decoder.IsDefaultScalarValue(value)
  self.assertTrue(is_default)

@_parameterized.parameters(1, -0.0, 1.0, True, 'a')
def test_not_default_scalar(self, value):
  """Checks that non-zero scalars and -0.0 are reported as non-default."""
  is_default = decoder.IsDefaultScalarValue(value)
  self.assertFalse(is_default)


if __name__ == '__main__':
unittest.main()
18 changes: 17 additions & 1 deletion python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,13 +1784,29 @@ def testProto3ParserDropDefaultScalar(self):
message_proto2 = unittest_pb2.TestAllTypes()
message_proto2.optional_int32 = 0
message_proto2.optional_string = ''
message_proto2.optional_float = 0.0
message_proto2.optional_bytes = b''
self.assertEqual(len(message_proto2.ListFields()), 3)
self.assertEqual(len(message_proto2.ListFields()), 4)

message_proto3 = unittest_proto3_arena_pb2.TestAllTypes()
message_proto3.ParseFromString(message_proto2.SerializeToString())
self.assertEqual(len(message_proto3.ListFields()), 0)

def testProto3ParserKeepsNonDefaultScalar(self):
  """Checks that proto3 parsing keeps non-default scalars, including -0.0."""
  non_default_values = {
      'optional_int32': 1,
      'optional_string': '\0',
      'optional_float': -0.0,
      'optional_double': -0.0,
      'optional_bytes': b'\0',
  }
  source = unittest_pb2.TestAllTypes()
  for field_name, field_value in non_default_values.items():
    setattr(source, field_name, field_value)
  self.assertEqual(len(source.ListFields()), 5)
  wire_bytes = source.SerializeToString()

  # Round-trip through the proto3 message: nothing may be dropped on parse,
  # and re-serializing must reproduce the original bytes.
  parsed = unittest_proto3_arena_pb2.TestAllTypes()
  parsed.ParseFromString(wire_bytes)
  self.assertEqual(len(parsed.ListFields()), 5)
  self.assertEqual(parsed.SerializeToString(), wire_bytes)

def testProto3Optional(self):
msg = test_proto3_optional_pb2.TestProto3Optional()
self.assertFalse(msg.HasField('optional_int32'))
Expand Down
7 changes: 4 additions & 3 deletions python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import datetime
from io import BytesIO
import math
import struct
import sys
import warnings
Expand Down Expand Up @@ -716,14 +717,14 @@ def getter(self):

def field_setter(self, new_value):
# pylint: disable=protected-access
# Testing the value for truthiness captures all of the proto3 defaults
# (0, 0.0, enum 0, and False).
# Testing the value for truthiness captures all of the implicit presence
# defaults (0, 0.0, enum 0, and False), except for -0.0.
try:
new_value = type_checker.CheckValue(new_value)
except TypeError as e:
raise TypeError(
'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
if not field.has_presence and not new_value:
if not field.has_presence and decoder.IsDefaultScalarValue(new_value):
self._fields.pop(field, None)
else:
self._fields[field] = new_value
Expand Down
14 changes: 14 additions & 0 deletions python/google/protobuf/internal/text_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,20 @@ def testParseDuplicateScalars(self, message_module):
r'have multiple "optional_int32" fields.'), text_format.Parse, text,
message)

def testParseDuplicateNegativeZero(self, message_module):
  """Checks that a second optional_double after -0.0 is a parse error."""
  text = 'optional_double: -0.0 optional_double: 3'
  message = message_module.TestAllTypes()
  expected_error = (
      r'1:40 : Message type "\w+.TestAllTypes" should not '
      r'have multiple "optional_double" fields.'
  )
  with self.assertRaisesRegex(text_format.ParseError, expected_error):
    text_format.Parse(text, message)

def testParseExistingScalarInMessage(self, message_module):
message = message_module.TestAllTypes(optional_int32=42)
text = 'optional_int32: 67'
Expand Down
4 changes: 3 additions & 1 deletion python/google/protobuf/text_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,9 @@ def _MergeScalarField(self, tokenizer, message, field):
else:
# For a field that doesn't track presence, make a best-effort check for
# duplicate scalars by comparing the current value against the default.
duplicate_error = bool(getattr(message, field.name))
duplicate_error = not decoder.IsDefaultScalarValue(
getattr(message, field.name)
)

if duplicate_error:
raise tokenizer.ParseErrorPreviousToken(
Expand Down

0 comments on commit bc16fe8

Please sign in to comment.