diff --git a/wire-schema/api/wire-schema.api b/wire-schema/api/wire-schema.api index a19b53b261..3061842fba 100644 --- a/wire-schema/api/wire-schema.api +++ b/wire-schema/api/wire-schema.api @@ -484,6 +484,7 @@ public final class com/squareup/wire/schema/MessageType : com/squareup/wire/sche public final fun isDeprecated ()Z public fun linkMembers (Lcom/squareup/wire/schema/Linker;)V public fun linkOptions (Lcom/squareup/wire/schema/Linker;Lcom/squareup/wire/schema/SyntaxRules;Z)V + public final fun oneOf (Ljava/lang/String;)Lcom/squareup/wire/schema/OneOf; public fun retainAll (Lcom/squareup/wire/schema/Schema;Lcom/squareup/wire/schema/MarkSet;)Lcom/squareup/wire/schema/Type; public fun retainLinked (Ljava/util/Set;Ljava/util/Set;)Lcom/squareup/wire/schema/Type; public final fun toElement ()Lcom/squareup/wire/schema/internal/parser/MessageElement; diff --git a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/MessageType.kt b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/MessageType.kt index f50f5b4ebb..cd7508da7b 100644 --- a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/MessageType.kt +++ b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/MessageType.kt @@ -86,6 +86,10 @@ data class MessageType( fun extensionField(qualifiedName: String): Field? = extensionFields.firstOrNull { it.qualifiedName == qualifiedName } + /** Returns the oneOf named [name], or null if this type has no such oneOf. */ + fun oneOf(name: String): OneOf? = + oneOfs.firstOrNull { it.name == name } + /** Returns the field tagged [tag], or null if this type has no such field. */ fun field(tag: Int): Field? { for (field in declaredFields) { diff --git a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Pruner.kt b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Pruner.kt index 64d4cd5082..7b993f5834 100644 --- a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Pruner.kt +++ b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Pruner.kt @@ -105,8 +105,12 @@ class Pruner( val member = protoMember.member return when (val type = schema.getType(protoMember.type)) { is MessageType -> { - val field = type.field(member) ?: type.extensionField(member)!! - pruningRules.isFieldRetainedVersion(field.options) + val field = type.field(member) ?: type.extensionField(member) + if (field != null) { + pruningRules.isFieldRetainedVersion(field.options) + } else { + pruningRules.isFieldRetainedVersion(type.oneOf(member)!!.options) + } } is EnumType -> { val enumConstant = type.constant(member)!! @@ -177,9 +181,14 @@ class Pruner( if (type is MessageType) { val field = type.field(member) ?: type.extensionField(member) - checkNotNull(field) { "unexpected member: $member" } - result.add(field.type) - options = field.options + if (field != null) { + result.add(field.type) + options = field.options + } else { + val oneOf = type.oneOf(member) + checkNotNull(oneOf) { "unexpected member: $member" } + options = oneOf.options + } } else if (type is EnumType) { val constant = type.constant(member) ?: throw IllegalStateException("unexpected member: $member") @@ -218,6 +227,7 @@ class Pruner( result.add(get(root, field.qualifiedName)) } for (oneOf in type.oneOfs) { + result.add(get(root, oneOf.name)) for (field in oneOf.fields) { result.add(get(root, field.name)) } diff --git a/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/PrunerTest.kt b/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/PrunerTest.kt index b754b472b7..ce218703cc 100644 --- a/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/PrunerTest.kt +++ b/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/PrunerTest.kt @@ -56,6 +56,135 @@ class PrunerTest { assertThat(pruned.getType("MessageB")).isNull() } + @Test + fun oneOfOptionsAreNotArbitrarilyPruned() { + val schema = buildSchema { + add( + "test_event.proto".toPath(), + """ + |syntax = "proto3"; + | + |import "test_event_custom_option.proto"; + | + |package test.oneOf.options.test; + | + |message TestMessage { + | oneof element { + | option (my_custom_oneOf_option) = true; + | string one = 1; + | string two = 2; + | } + |} + """.trimMargin(), + ) + add( + "test_event_custom_option.proto".toPath(), + """ + |syntax = "proto3"; + | + |import "google/protobuf/descriptor.proto"; + | + |package test.oneOf.options; + | + |extend google.protobuf.OneofOptions { + | bool my_custom_oneOf_option = 101400; + |} + """.trimMargin(), + ) + } + val pruned = schema.prune( + PruningRules.Builder() + .addRoot("test.oneOf.options.test.TestMessage") + .build(), + ) + assertThat(pruned.protoFile("test_event.proto")!!.toSchema()) + .isEqualTo( + // spotless:off because spotless will remove the indents (trailing spaces) in the oneof block. + """ + |// Proto schema formatted by Wire, do not edit. + |// Source: test_event.proto + | + |syntax = "proto3"; + | + |package test.oneOf.options.test; + | + |import "test_event_custom_option.proto"; + | + |message TestMessage { + | oneof element { + | option (my_custom_oneOf_option) = true; + | + | string one = 1; + | string two = 2; + | } + |} + |""".trimMargin(), + // spotless:on + ) + } + + @Test + fun oneOfOptionsArePruned() { + val schema = buildSchema { + add( + "test_event.proto".toPath(), + """ + |syntax = "proto3"; + | + |import "test_event_custom_option.proto"; + | + |package test.oneOf.options.test; + | + |message TestMessage { + | oneof element { + | option (my_custom_oneOf_option) = true; + | string one = 1; + | string two = 2; + | } + |} + """.trimMargin(), + ) + add( + "test_event_custom_option.proto".toPath(), + """ + |syntax = "proto3"; + | + |import "google/protobuf/descriptor.proto"; + | + |package test.oneOf.options; + | + |extend google.protobuf.OneofOptions { + | bool my_custom_oneOf_option = 101400; + |} + """.trimMargin(), + ) + } + val pruned = schema.prune( + PruningRules.Builder() + .prune("google.protobuf.OneofOptions#test.oneOf.options.my_custom_oneOf_option") + .build(), + ) + assertThat(pruned.protoFile("test_event.proto")!!.toSchema()) + .isEqualTo( + """ + |// Proto schema formatted by Wire, do not edit. + |// Source: test_event.proto + | + |syntax = "proto3"; + | + |package test.oneOf.options.test; + | + |message TestMessage { + | oneof element { + | string one = 1; + | string two = 2; + | } + |} + | + """.trimMargin(), + ) + } + @Test fun retainMap() { val schema = buildSchema {