Skip to content

Commit

Permalink
[swift] Generate a default value for sub fields with only optional fi…
Browse files Browse the repository at this point in the history
…elds.
  • Loading branch information
dnkoutso committed Oct 2, 2023
1 parent 2d58614 commit c535f94
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ class SwiftGenerator private constructor(
else -> null
}

private val Field.defaultedValue: CodeBlock?
get() = default?.let {
return defaultFieldInitializer(type!!, it)
} ?: if (isMessage && !isRequiredParameter && !isCollection) {
val subType = schema.getType(type!!) as MessageType
if (subType!!.fields.all { !it.isRequiredParameter }) CodeBlock.of("%T()", subType.typeName) else null
} else null

// see https://protobuf.dev/programming-guides/proto3/#default
private val Field.proto3InitialValue: String
get() = when {
Expand Down Expand Up @@ -651,7 +659,7 @@ class SwiftGenerator private constructor(
}
}
addStatement(
if (field.default != null) "_%N.wrappedValue = %L" else { "self.%N = %L" },
if (field.defaultedValue != null) "_%N.wrappedValue = %L" else { "self.%N = %L" },
field.name,
initializer,
)
Expand Down Expand Up @@ -831,7 +839,7 @@ class SwiftGenerator private constructor(
.map { CodeBlock.of("%S", it) }
.joinToCode()

val prefix = if (field.default != null) { "_%1N.wrappedValue" } else { "self.%1N" }
val prefix = if (field.defaultedValue != null) { "_%1N.wrappedValue" } else { "self.%1N" }
addStatement(
"$prefix = try container.$decode($typeArg%2T.self, $forKeys: $keys)",
field.name,
Expand Down Expand Up @@ -1146,7 +1154,7 @@ class SwiftGenerator private constructor(
.apply {
type.fields.filter { it.isRequiredParameter }.forEach { field ->
addStatement(
if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
if (field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
field.name,
)
}
Expand All @@ -1172,7 +1180,7 @@ class SwiftGenerator private constructor(
.apply {
type.fields.forEach { field ->
addStatement(
if (field.default != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
if (field.defaultedValue != null) "_%1N.wrappedValue = %1N" else { "self.%1N = %1N" },
field.name,
)
}
Expand Down Expand Up @@ -1217,10 +1225,9 @@ class SwiftGenerator private constructor(
if (isIndirect(type, field)) {
property.addAttribute(AttributeSpec.builder(indirect).build())
}
val default = field.default
if (default != null) {
val defaultValue = defaultFieldInitializer(field.type!!, default)
property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultValue").build())
val defaultedValue = field.defaultedValue
if (defaultedValue != null) {
property.addAttribute(AttributeSpec.builder(defaulted).addArgument("defaultValue: $defaultedValue").build())
}

if (field.isMap) {
Expand Down
7 changes: 4 additions & 3 deletions wire-tests-proto3-swift/src/main/swift/ContainsDuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Wire

public struct ContainsDuration {

@Defaulted(defaultValue: Duration())
public var duration: Duration?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -20,7 +21,7 @@ extension ContainsDuration {
@_disfavoredOverload
@available(*, deprecated)
public init(duration: Duration? = nil) {
self.duration = duration
_duration.wrappedValue = duration
}

}
Expand Down Expand Up @@ -63,7 +64,7 @@ extension ContainsDuration : Proto3Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.duration = duration
_duration.wrappedValue = duration
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -78,7 +79,7 @@ extension ContainsDuration : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.duration = try container.decodeIfPresent(Duration.self, forKey: "duration")
_duration.wrappedValue = try container.decodeIfPresent(Duration.self, forKey: "duration")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Wire

public struct ContainsTimestamp {

@Defaulted(defaultValue: Timestamp())
public var timestamp: Timestamp?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -20,7 +21,7 @@ extension ContainsTimestamp {
@_disfavoredOverload
@available(*, deprecated)
public init(timestamp: Timestamp? = nil) {
self.timestamp = timestamp
_timestamp.wrappedValue = timestamp
}

}
Expand Down Expand Up @@ -63,7 +64,7 @@ extension ContainsTimestamp : Proto3Codable {
}
self.unknownFields = try protoReader.endMessage(token: token)

self.timestamp = timestamp
_timestamp.wrappedValue = timestamp
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -78,7 +79,7 @@ extension ContainsTimestamp : Codable {

public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.timestamp = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp")
_timestamp.wrappedValue = try container.decodeIfPresent(Timestamp.self, forKey: "timestamp")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
14 changes: 8 additions & 6 deletions wire-tests-swift/src/main/swift/AllTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,7 @@ extension AllTypes.Storage {
self.opt_string = opt_string
self.opt_bytes = opt_bytes
self.opt_nested_enum = opt_nested_enum
self.opt_nested_message = opt_nested_message
_opt_nested_message.wrappedValue = opt_nested_message
self.req_int32 = req_int32
self.req_uint32 = req_uint32
self.req_sint32 = req_sint32
Expand Down Expand Up @@ -1957,7 +1957,7 @@ extension AllTypes.Storage {
self.ext_opt_string = ext_opt_string
self.ext_opt_bytes = ext_opt_bytes
self.ext_opt_nested_enum = ext_opt_nested_enum
self.ext_opt_nested_message = ext_opt_nested_message
_ext_opt_nested_message.wrappedValue = ext_opt_nested_message
self.ext_rep_int32 = ext_rep_int32
self.ext_rep_uint32 = ext_rep_uint32
self.ext_rep_sint32 = ext_rep_sint32
Expand Down Expand Up @@ -2017,6 +2017,7 @@ extension AllTypes {
public var opt_string: Swift.String?
public var opt_bytes: Foundation.Data?
public var opt_nested_enum: AllTypes.NestedEnum?
@Wire.Defaulted(defaultValue: AllTypes.NestedMessage())
public var opt_nested_message: AllTypes.NestedMessage?
public var req_int32: Swift.Int32
public var req_uint32: Swift.UInt32
Expand Down Expand Up @@ -2130,6 +2131,7 @@ extension AllTypes {
public var ext_opt_string: Swift.String?
public var ext_opt_bytes: Foundation.Data?
public var ext_opt_nested_enum: AllTypes.NestedEnum?
@Wire.Defaulted(defaultValue: AllTypes.NestedMessage())
public var ext_opt_nested_message: AllTypes.NestedMessage?
public var ext_rep_int32: [Swift.Int32] = []
public var ext_rep_uint32: [Swift.UInt32] = []
Expand Down Expand Up @@ -2549,7 +2551,7 @@ extension AllTypes.Storage : Proto2Codable {
self.opt_string = opt_string
self.opt_bytes = opt_bytes
self.opt_nested_enum = opt_nested_enum
self.opt_nested_message = opt_nested_message
_opt_nested_message.wrappedValue = opt_nested_message
self.req_int32 = try AllTypes.checkIfMissing(req_int32, "req_int32")
self.req_uint32 = try AllTypes.checkIfMissing(req_uint32, "req_uint32")
self.req_sint32 = try AllTypes.checkIfMissing(req_sint32, "req_sint32")
Expand Down Expand Up @@ -2646,7 +2648,7 @@ extension AllTypes.Storage : Proto2Codable {
self.ext_opt_string = ext_opt_string
self.ext_opt_bytes = ext_opt_bytes
self.ext_opt_nested_enum = ext_opt_nested_enum
self.ext_opt_nested_message = ext_opt_nested_message
_ext_opt_nested_message.wrappedValue = ext_opt_nested_message
self.ext_rep_int32 = ext_rep_int32
self.ext_rep_uint32 = ext_rep_uint32
self.ext_rep_sint32 = ext_rep_sint32
Expand Down Expand Up @@ -2852,7 +2854,7 @@ extension AllTypes.Storage : Codable {
self.opt_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "optString", "opt_string")
self.opt_bytes = try container.decodeIfPresent(stringEncoded: Foundation.Data.self, firstOfKeys: "optBytes", "opt_bytes")
self.opt_nested_enum = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "optNestedEnum", "opt_nested_enum")
self.opt_nested_message = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "optNestedMessage", "opt_nested_message")
_opt_nested_message.wrappedValue = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "optNestedMessage", "opt_nested_message")
self.req_int32 = try container.decode(Swift.Int32.self, firstOfKeys: "reqInt32", "req_int32")
self.req_uint32 = try container.decode(Swift.UInt32.self, firstOfKeys: "reqUint32", "req_uint32")
self.req_sint32 = try container.decode(Swift.Int32.self, firstOfKeys: "reqSint32", "req_sint32")
Expand Down Expand Up @@ -2949,7 +2951,7 @@ extension AllTypes.Storage : Codable {
self.ext_opt_string = try container.decodeIfPresent(Swift.String.self, firstOfKeys: "extOptString", "ext_opt_string")
self.ext_opt_bytes = try container.decodeIfPresent(stringEncoded: Foundation.Data.self, firstOfKeys: "extOptBytes", "ext_opt_bytes")
self.ext_opt_nested_enum = try container.decodeIfPresent(AllTypes.NestedEnum.self, firstOfKeys: "extOptNestedEnum", "ext_opt_nested_enum")
self.ext_opt_nested_message = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "extOptNestedMessage", "ext_opt_nested_message")
_ext_opt_nested_message.wrappedValue = try container.decodeIfPresent(AllTypes.NestedMessage.self, firstOfKeys: "extOptNestedMessage", "ext_opt_nested_message")
self.ext_rep_int32 = try container.decodeProtoArray(Swift.Int32.self, firstOfKeys: "extRepInt32", "ext_rep_int32")
self.ext_rep_uint32 = try container.decodeProtoArray(Swift.UInt32.self, firstOfKeys: "extRepUint32", "ext_rep_uint32")
self.ext_rep_sint32 = try container.decodeProtoArray(Swift.Int32.self, firstOfKeys: "extRepSint32", "ext_rep_sint32")
Expand Down
7 changes: 4 additions & 3 deletions wire-tests-swift/src/main/swift/FooBar.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ public struct FooBar {

public var foo: Int32?
public var bar: String?
@Defaulted(defaultValue: FooBar.Nested())
public var baz: FooBar.Nested?
public var qux: UInt64?
public var fred: [Float] = []
Expand Down Expand Up @@ -42,7 +43,7 @@ extension FooBar {
) {
self.foo = foo
self.bar = bar
self.baz = baz
_baz.wrappedValue = baz
self.qux = qux
self.fred = fred
self.daisy = daisy
Expand Down Expand Up @@ -124,7 +125,7 @@ extension FooBar : Proto2Codable {

self.foo = foo
self.bar = bar
self.baz = baz
_baz.wrappedValue = baz
self.qux = qux
self.fred = fred
self.daisy = daisy
Expand Down Expand Up @@ -157,7 +158,7 @@ extension FooBar : Codable {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.foo = try container.decodeIfPresent(Swift.Int32.self, forKey: "foo")
self.bar = try container.decodeIfPresent(Swift.String.self, forKey: "bar")
self.baz = try container.decodeIfPresent(FooBar.Nested.self, forKey: "baz")
_baz.wrappedValue = try container.decodeIfPresent(FooBar.Nested.self, forKey: "baz")
self.qux = try container.decodeIfPresent(stringEncoded: Swift.UInt64.self, forKey: "qux")
self.fred = try container.decodeProtoArray(Swift.Float.self, forKey: "fred")
self.daisy = try container.decodeIfPresent(Swift.Double.self, forKey: "daisy")
Expand Down
7 changes: 4 additions & 3 deletions wire-tests-swift/src/main/swift/OuterMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Wire
public struct OuterMessage {

public var outer_number_before: Int32?
@Defaulted(defaultValue: EmbeddedMessage())
public var embedded_message: EmbeddedMessage?
public var unknownFields: Foundation.Data = .init()

Expand All @@ -22,7 +23,7 @@ extension OuterMessage {
@available(*, deprecated)
public init(outer_number_before: Swift.Int32? = nil, embedded_message: EmbeddedMessage? = nil) {
self.outer_number_before = outer_number_before
self.embedded_message = embedded_message
_embedded_message.wrappedValue = embedded_message
}

}
Expand Down Expand Up @@ -68,7 +69,7 @@ extension OuterMessage : Proto2Codable {
self.unknownFields = try protoReader.endMessage(token: token)

self.outer_number_before = outer_number_before
self.embedded_message = embedded_message
_embedded_message.wrappedValue = embedded_message
}

public func encode(to protoWriter: Wire.ProtoWriter) throws {
Expand All @@ -85,7 +86,7 @@ extension OuterMessage : Codable {
public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.outer_number_before = try container.decodeIfPresent(Swift.Int32.self, firstOfKeys: "outerNumberBefore", "outer_number_before")
self.embedded_message = try container.decodeIfPresent(EmbeddedMessage.self, firstOfKeys: "embeddedMessage", "embedded_message")
_embedded_message.wrappedValue = try container.decodeIfPresent(EmbeddedMessage.self, firstOfKeys: "embeddedMessage", "embedded_message")
}

public func encode(to encoder: Swift.Encoder) throws {
Expand Down
7 changes: 4 additions & 3 deletions wire-tests-swift/src/main/swift/VersionOne.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Wire
public struct VersionOne {

public var i: Int32?
@Defaulted(defaultValue: NestedVersionOne())
public var obj: NestedVersionOne?
public var en: EnumVersionOne?
public var unknownFields: Foundation.Data = .init()
Expand All @@ -27,7 +28,7 @@ extension VersionOne {
en: EnumVersionOne? = nil
) {
self.i = i
self.obj = obj
_obj.wrappedValue = obj
self.en = en
}

Expand Down Expand Up @@ -76,7 +77,7 @@ extension VersionOne : Proto2Codable {
self.unknownFields = try protoReader.endMessage(token: token)

self.i = i
self.obj = obj
_obj.wrappedValue = obj
self.en = en
}

Expand All @@ -95,7 +96,7 @@ extension VersionOne : Codable {
public init(from decoder: Swift.Decoder) throws {
let container = try decoder.container(keyedBy: Wire.StringLiteralCodingKeys.self)
self.i = try container.decodeIfPresent(Swift.Int32.self, forKey: "i")
self.obj = try container.decodeIfPresent(NestedVersionOne.self, forKey: "obj")
_obj.wrappedValue = try container.decodeIfPresent(NestedVersionOne.self, forKey: "obj")
self.en = try container.decodeIfPresent(EnumVersionOne.self, forKey: "en")
}

Expand Down
7 changes: 4 additions & 3 deletions wire-tests-swift/src/main/swift/VersionTwo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public struct VersionTwo {
public var v2_f32: UInt32?
public var v2_f64: UInt64?
public var v2_rs: [String] = []
@Defaulted(defaultValue: NestedVersionTwo())
public var obj: NestedVersionTwo?
public var en: EnumVersionTwo?
public var unknownFields: Foundation.Data = .init()
Expand Down Expand Up @@ -42,7 +43,7 @@ extension VersionTwo {
self.v2_f32 = v2_f32
self.v2_f64 = v2_f64
self.v2_rs = v2_rs
self.obj = obj
_obj.wrappedValue = obj
self.en = en
}

Expand Down Expand Up @@ -106,7 +107,7 @@ extension VersionTwo : Proto2Codable {
self.v2_f32 = v2_f32
self.v2_f64 = v2_f64
self.v2_rs = v2_rs
self.obj = obj
_obj.wrappedValue = obj
self.en = en
}

Expand Down Expand Up @@ -135,7 +136,7 @@ extension VersionTwo : Codable {
self.v2_f32 = try container.decodeIfPresent(Swift.UInt32.self, firstOfKeys: "v2F32", "v2_f32")
self.v2_f64 = try container.decodeIfPresent(stringEncoded: Swift.UInt64.self, firstOfKeys: "v2F64", "v2_f64")
self.v2_rs = try container.decodeProtoArray(Swift.String.self, firstOfKeys: "v2Rs", "v2_rs")
self.obj = try container.decodeIfPresent(NestedVersionTwo.self, forKey: "obj")
_obj.wrappedValue = try container.decodeIfPresent(NestedVersionTwo.self, forKey: "obj")
self.en = try container.decodeIfPresent(EnumVersionTwo.self, forKey: "en")
}

Expand Down

0 comments on commit c535f94

Please sign in to comment.