Skip to content

Commit

Permalink
[swift] Provide a default value for sub fields and common types as pe…
Browse files Browse the repository at this point in the history
…r proto spec
  • Loading branch information
dnkoutso committed Oct 11, 2023
1 parent b9f6835 commit 6db6816
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 63 deletions.
7 changes: 4 additions & 3 deletions wire-runtime-swift/src/test/swift/sample/Dinosaur.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public struct Dinosaur {
public var picture_urls: [String] = []
public var length_meters: Double?
public var mass_kilograms: Double?
@Defaulted(defaultValue: Period.CRETACEOUS)
public var period: Period?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -40,7 +41,7 @@ extension Dinosaur {
self.picture_urls = picture_urls
self.length_meters = length_meters
self.mass_kilograms = mass_kilograms
self.period = period
_period.wrappedValue = period
}

}
Expand Down Expand Up @@ -95,7 +96,7 @@ extension Dinosaur : Proto2Codable {
self.picture_urls = picture_urls
self.length_meters = length_meters
self.mass_kilograms = mass_kilograms
self.period = period
_period.wrappedValue = period
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -118,7 +119,7 @@ extension Dinosaur : Codable {
self.picture_urls = try container.decodeProtoArray(Swift.String.self, firstOfKeys: "pictureUrls", "picture_urls")
self.length_meters = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "lengthMeters", "length_meters")
self.mass_kilograms = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "massKilograms", "mass_kilograms")
self.period = try container.decodeIfPresent(Period.self, forKey: "period")
_period.wrappedValue = try container.decodeIfPresent(Period.self, forKey: "period")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,39 @@ class SwiftGenerator private constructor(
else -> null
}

private val Field.defaultedValue: CodeBlock?
get() = default?.let {
return defaultFieldInitializer(type!!, it)
} ?: when {
!isOptional -> null
type == ProtoType.ANY -> null
typeName == BOOL -> CodeBlock.of("%L", false)
typeName == INT -> CodeBlock.of("%L", 0)
typeName == INT32 -> CodeBlock.of("%L", 0)
typeName == INT64 -> CodeBlock.of("%L", 0)
typeName == UINT32 -> CodeBlock.of("%L", 0)
typeName == UINT64 -> CodeBlock.of("%L", 0)
typeName == FLOAT -> CodeBlock.of("%L", 0)
typeName == DOUBLE -> CodeBlock.of("%L", 0)
typeName == STRING -> CodeBlock.of("%S", "")
typeName == DATA -> CodeBlock.of(
"%T(base64Encoded: %S)!",
FOUNDATION_DATA,
"".encode(charset = Charsets.ISO_8859_1).base64(),
)

isEnum -> {
val enumType = schema.getType(type!!) as EnumType
CodeBlock.of("%T.%L", typeName.makeNonOptional(), enumType.constants[0].name)
}
isMessage && !isRequiredParameter && !isCollection -> {
val messageType = schema.getType(type!!) as MessageType
if (messageType.fields.any { it.isRequiredParameter }) null else CodeBlock.of("%T()", messageType.typeName)
}

else -> null
}

// see https://protobuf.dev/programming-guides/proto3/#default
private val Field.proto3InitialValue: String
get() = when {
Expand Down Expand Up @@ -651,7 +684,7 @@ class SwiftGenerator private constructor(
}
}
addStatement(
if (field.default != null) "_%N.wrappedValue = %L" else { "self.%N = %L" },
if (!isIndirect(type, field) && field.defaultedValue != null) "_%N.wrappedValue = %L" else { "self.%N = %L" },
field.name,
initializer,
)
Expand Down Expand Up @@ -831,7 +864,7 @@ class SwiftGenerator private constructor(
.map { CodeBlock.of("%S", it) }
.joinToCode()

val prefix = if (field.default != null) { "_%1N.wrappedValue" } else { "self.%1N" }
val prefix = if (!isIndirect(type, field) && field.defaultedValue != null) { "_%1N.wrappedValue" } else { "self.%1N" }
addStatement(
"$prefix = try container.$decode($typeArg%2T.self, $forKeys: $keys)",
field.name,
Expand Down Expand Up @@ -1146,7 +1179,7 @@ class SwiftGenerator private constructor(
.apply {
type.fields.filter { it.isRequiredParameter }.forEach { field ->
addStatement(
if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
if (!isIndirect(type, field) && field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
field.name,
)
}
Expand All @@ -1172,7 +1205,7 @@ class SwiftGenerator private constructor(
.apply {
type.fields.forEach { field ->
addStatement(
if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
if (!isIndirect(type, field) && field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
field.name,
)
}
Expand Down Expand Up @@ -1224,10 +1257,9 @@ class SwiftGenerator private constructor(
if (isIndirect(type, field)) {
property.addAttribute(AttributeSpec.builder(indirect).build())
}
val default = field.default
if (default != null) {
val defaultValue = defaultFieldInitializer(field.type!!, default)
property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultValue").build())
val defaultedValue = field.defaultedValue
if (!isIndirect(type, field) && defaultedValue != null) {
property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultedValue").build())
}

if (field.isMap) {
Expand Down Expand Up @@ -1278,7 +1310,8 @@ class SwiftGenerator private constructor(
typeName == DOUBLE -> defaultValue.toDoubleFieldInitializer()
typeName == STRING -> CodeBlock.of("%S", stringLiteralWithQuotes2(defaultValue.toString()))
typeName == DATA -> CodeBlock.of(
"Foundation.Data(base64Encoded: %S)!",
"%T(base64Encoded: %S)!",
FOUNDATION_DATA,
defaultValue.toString().encode(charset = Charsets.ISO_8859_1).base64(),
)
protoType.isEnum -> CodeBlock.of("%T.%L", typeName, defaultValue)
Expand Down
7 changes: 4 additions & 3 deletions wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Wire

public struct ContainsDuration {

@Defaulted(defaultValue: Duration())
public var duration: Duration?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -20,7 +21,7 @@ extension ContainsDuration {
@_disfavoredOverload
@available(*, deprecated)
public init(duration: Duration? = nil) {
self.duration = duration
_duration.wrappedValue = duration
}

}
Expand Down Expand Up @@ -63,7 +64,7 @@ extension ContainsDuration : Proto3Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.duration = duration
_duration.wrappedValue = duration
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -78,7 +79,7 @@ extension ContainsDuration : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.duration = try container.decodeIfPresent(Duration.self, forKey: "duration")
_duration.wrappedValue = try container.decodeIfPresent(Duration.self, forKey: "duration")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Wire

public struct ContainsTimestamp {

@Defaulted(defaultValue: Timestamp())
public var timestamp: Timestamp?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -20,7 +21,7 @@ extension ContainsTimestamp {
@_disfavoredOverload
@available(*, deprecated)
public init(timestamp: Timestamp? = nil) {
self.timestamp = timestamp
_timestamp.wrappedValue = timestamp
}

}
Expand Down Expand Up @@ -63,7 +64,7 @@ extension ContainsTimestamp : Proto3Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.timestamp = timestamp
_timestamp.wrappedValue = timestamp
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -78,7 +79,7 @@ extension ContainsTimestamp : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.timestamp = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp")
_timestamp.wrappedValue = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
28 changes: 16 additions & 12 deletions wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1859,8 +1859,8 @@ extension AllTypes.Storage {
self.opt_double = opt_double
self.opt_string = opt_string
self.opt_bytes = opt_bytes
self.opt_nested_enum = opt_nested_enum
self.opt_nested_message = opt_nested_message
_opt_nested_enum.wrappedValue = opt_nested_enum
_opt_nested_message.wrappedValue = opt_nested_message
self.req_int32 = req_int32
self.req_uint32 = req_uint32
self.req_sint32 = req_sint32
Expand Down Expand Up @@ -1956,8 +1956,8 @@ extension AllTypes.Storage {
self.ext_opt_double = ext_opt_double
self.ext_opt_string = ext_opt_string
self.ext_opt_bytes = ext_opt_bytes
self.ext_opt_nested_enum = ext_opt_nested_enum
self.ext_opt_nested_message = ext_opt_nested_message
_ext_opt_nested_enum.wrappedValue = ext_opt_nested_enum
_ext_opt_nested_message.wrappedValue = ext_opt_nested_message
self.ext_rep_int32 = ext_rep_int32
self.ext_rep_uint32 = ext_rep_uint32
self.ext_rep_sint32 = ext_rep_sint32
Expand Down Expand Up @@ -2016,7 +2016,9 @@ extension AllTypes {
public var opt_double: Swift.Double?
public var opt_string: Swift.String?
public var opt_bytes: Foundation.Data?
@Wire.Defaulted(defaultValue: AllTypes.NestedEnum.UNKNOWN)
public var opt_nested_enum: AllTypes.NestedEnum?
@Wire.Defaulted(defaultValue: AllTypes.NestedMessage())
public var opt_nested_message: AllTypes.NestedMessage?
public var req_int32: Swift.Int32
public var req_uint32: Swift.UInt32
Expand Down Expand Up @@ -2129,7 +2131,9 @@ extension AllTypes {
public var ext_opt_double: Swift.Double?
public var ext_opt_string: Swift.String?
public var ext_opt_bytes: Foundation.Data?
@Wire.Defaulted(defaultValue: AllTypes.NestedEnum.UNKNOWN)
public var ext_opt_nested_enum: AllTypes.NestedEnum?
@Wire.Defaulted(defaultValue: AllTypes.NestedMessage())
public var ext_opt_nested_message: AllTypes.NestedMessage?
public var ext_rep_int32: [Swift.Int32] = []
public var ext_rep_uint32: [Swift.UInt32] = []
Expand Down Expand Up @@ -2548,8 +2552,8 @@ extension AllTypes.Storage : Proto2Codable {
self.opt_double = opt_double
self.opt_string = opt_string
self.opt_bytes = opt_bytes
self.opt_nested_enum = opt_nested_enum
self.opt_nested_message = opt_nested_message
_opt_nested_enum.wrappedValue = opt_nested_enum
_opt_nested_message.wrappedValue = opt_nested_message
self.req_int32 = try AllTypes.checkIfMissing(req_int32, "req_int32")
self.req_uint32 = try AllTypes.checkIfMissing(req_uint32, "req_uint32")
self.req_sint32 = try AllTypes.checkIfMissing(req_sint32, "req_sint32")
Expand Down Expand Up @@ -2645,8 +2649,8 @@ extension AllTypes.Storage : Proto2Codable {
self.ext_opt_double = ext_opt_double
self.ext_opt_string = ext_opt_string
self.ext_opt_bytes = ext_opt_bytes
self.ext_opt_nested_enum = ext_opt_nested_enum
self.ext_opt_nested_message = ext_opt_nested_message
_ext_opt_nested_enum.wrappedValue = ext_opt_nested_enum
_ext_opt_nested_message.wrappedValue = ext_opt_nested_message
self.ext_rep_int32 = ext_rep_int32
self.ext_rep_uint32 = ext_rep_uint32
self.ext_rep_sint32 = ext_rep_sint32
Expand Down Expand Up @@ -2851,8 +2855,8 @@ extension AllTypes.Storage : Codable {
self.opt_double = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "optDouble", "opt_double")
self.opt_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "optString", "opt_string")
self.opt_bytes = try container.decodeIfPresent(stringEncoded: Foundation.Data.self, firstOfKeys: "optBytes", "opt_bytes")
self.opt_nested_enum = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "optNestedEnum", "opt_nested_enum")
self.opt_nested_message = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "optNestedMessage", "opt_nested_message")
_opt_nested_enum.wrappedValue = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "optNestedEnum", "opt_nested_enum")
_opt_nested_message.wrappedValue = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "optNestedMessage", "opt_nested_message")
self.req_int32 = try container.decode(Swift.Int32.self, firstOfKeys: "reqInt32", "req_int32")
self.req_uint32 = try container.decode(Swift.UInt32.self, firstOfKeys: "reqUint32", "req_uint32")
self.req_sint32 = try container.decode(Swift.Int32.self, firstOfKeys: "reqSint32", "req_sint32")
Expand Down Expand Up @@ -2948,8 +2952,8 @@ extension AllTypes.Storage : Codable {
self.ext_opt_double = try container.decodeIfPresent(Swift.Double.self, firstOfKeys: "extOptDouble", "ext_opt_double")
self.ext_opt_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "extOptString", "ext_opt_string")
self.ext_opt_bytes = try container.decodeIfPresent(stringEncoded: Foundation.Data.self, firstOfKeys: "extOptBytes", "ext_opt_bytes")
self.ext_opt_nested_enum = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "extOptNestedEnum", "ext_opt_nested_enum")
self.ext_opt_nested_message = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "extOptNestedMessage", "ext_opt_nested_message")
_ext_opt_nested_enum.wrappedValue = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "extOptNestedEnum", "ext_opt_nested_enum")
_ext_opt_nested_message.wrappedValue = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "extOptNestedMessage", "ext_opt_nested_message")
self.ext_rep_int32 = try container.decodeProtoArray(Swift.Int32.self, firstOfKeys: "extRepInt32", "ext_rep_int32")
self.ext_rep_uint32 = try container.decodeProtoArray(Swift.UInt32.self, firstOfKeys: "extRepUint32", "ext_rep_uint32")
self.ext_rep_sint32 = try container.decodeProtoArray(Swift.Int32.self, firstOfKeys: "extRepSint32", "ext_rep_sint32")
Expand Down
21 changes: 12 additions & 9 deletions wire-tests-swift/no-manifest/src/main/swift/FooBar.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ public struct FooBar {

public var foo: Int32?
public var bar: String?
@Defaulted(defaultValue: FooBar.Nested())
public var baz: FooBar.Nested?
public var qux: UInt64?
public var fred: [Float] = []
public var daisy: Double?
public var nested: [FooBar] = []
@Defaulted(defaultValue: FooBar.FooBarBazEnum.FOO)
public var ext: FooBar.FooBarBazEnum?
public var rep: [FooBar.FooBarBazEnum] = []
public var more_string: String?
Expand Down Expand Up @@ -42,12 +44,12 @@ extension FooBar {
) {
self.foo = foo
self.bar = bar
self.baz = baz
_baz.wrappedValue = baz
self.qux = qux
self.fred = fred
self.daisy = daisy
self.nested = nested
self.ext = ext
_ext.wrappedValue = ext
self.rep = rep
self.more_string = more_string
}
Expand Down Expand Up @@ -124,12 +126,12 @@ extension FooBar : Proto2Codable {

self.foo = foo
self.bar = bar
self.baz = baz
_baz.wrappedValue = baz
self.qux = qux
self.fred = fred
self.daisy = daisy
self.nested = nested
self.ext = ext
_ext.wrappedValue = ext
self.rep = rep
self.more_string = more_string
}
Expand Down Expand Up @@ -157,12 +159,12 @@ extension FooBar : Codable {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.foo = try container.decodeIfPresent(Swift.Int32.self, forKey: "foo")
self.bar = try container.decodeIfPresent(Swift.String.self, forKey: "bar")
self.baz = try container.decodeIfPresent(FooBar.Nested.self, forKey: "baz")
_baz.wrappedValue = try container.decodeIfPresent(FooBar.Nested.self, forKey: "baz")
self.qux = try container.decodeIfPresent(stringEncoded: Swift.UInt64.self, forKey: "qux")
self.fred = try container.decodeProtoArray(Swift.Float.self, forKey: "fred")
self.daisy = try container.decodeIfPresent(Swift.Double.self, forKey: "daisy")
self.nested = try container.decodeProtoArray(FooBar.self, forKey: "nested")
self.ext = try container.decodeIfPresent(FooBar.FooBarBazEnum.self, forKey: "ext")
_ext.wrappedValue = try container.decodeIfPresent(FooBar.FooBarBazEnum.self, forKey: "ext")
self.rep = try container.decodeProtoArray(FooBar.FooBarBazEnum.self, forKey: "rep")
self.more_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "moreString", "more_string")
}
Expand Down Expand Up @@ -200,6 +202,7 @@ extension FooBar {

public struct Nested {

@Wire.Defaulted(defaultValue: FooBar.FooBarBazEnum.FOO)
public var value: FooBar.FooBarBazEnum?
public var unknownFields: Foundation.Data = .init()

Expand Down Expand Up @@ -244,7 +247,7 @@ extension FooBar.Nested {
@_disfavoredOverload
@available(*, deprecated)
public init(value: FooBar.FooBarBazEnum? = nil) {
self.value = value
_value.wrappedValue = value
}

}
Expand Down Expand Up @@ -287,7 +290,7 @@ extension FooBar.Nested : Proto2Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.value = value
_value.wrappedValue = value
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -302,7 +305,7 @@ extension FooBar.Nested : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.value = try container.decodeIfPresent(FooBar.FooBarBazEnum.self, forKey: "value")
_value.wrappedValue = try container.decodeIfPresent(FooBar.FooBarBazEnum.self, forKey: "value")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Loading

0 comments on commit 6db6816

Please sign in to comment.