From e63c7f59d5b9b156667f3f569c79b82055a3d523 Mon Sep 17 00:00:00 2001 From: Scott Hart Date: Thu, 18 Jul 2024 15:05:25 -0400 Subject: [PATCH] impl(generator): support protobuf wrapper types as query params --- generator/internal/descriptor_utils_test.cc | 4 +- generator/internal/http_option_utils.cc | 90 ++++++++++---- generator/internal/http_option_utils.h | 18 +++ generator/internal/http_option_utils_test.cc | 120 ++++++++++++++++++- 4 files changed, 208 insertions(+), 24 deletions(-) diff --git a/generator/internal/descriptor_utils_test.cc b/generator/internal/descriptor_utils_test.cc index 8cf9147d7f748..ecd2b6fc4f070 100644 --- a/generator/internal/descriptor_utils_test.cc +++ b/generator/internal/descriptor_utils_test.cc @@ -1281,7 +1281,7 @@ INSTANTIATE_TEST_SUITE_P( MethodVarsTestValues("my.service.v1.Service.Method1", "method_http_query_parameters", R"""(, rest_internal::TrimEmptyQueryParameters({std::make_pair("number", std::to_string(request.number())), - std::make_pair("toggle", request.toggle() ? "1" : "0"), + std::make_pair("toggle", (request.toggle() ? "1" : "0")), std::make_pair("title", request.title()), std::make_pair("parent", request.parent())}))"""), // Method2 @@ -1310,7 +1310,7 @@ INSTANTIATE_TEST_SUITE_P( "method_http_query_parameters", R"""(, rest_internal::TrimEmptyQueryParameters({std::make_pair("number", std::to_string(request.number())), std::make_pair("name", request.name()), - std::make_pair("toggle", request.toggle() ? "1" : "0"), + std::make_pair("toggle", (request.toggle() ? "1" : "0")), std::make_pair("title", request.title())}))"""), // Method3 MethodVarsTestValues("my.service.v1.Service.Method3", diff --git a/generator/internal/http_option_utils.cc b/generator/internal/http_option_utils.cc index 38acc073b9163..fce38e0c8b675 100644 --- a/generator/internal/http_option_utils.cc +++ b/generator/internal/http_option_utils.cc @@ -195,7 +195,52 @@ void SetHttpDerivedMethodVars( absl::visit(HttpInfoVisitor(method, method_vars), parsed_http_info); } -// Request fields not appering in the path may not wind up as part of the json +absl::optional DetermineQueryParameterInfo( + google::protobuf::FieldDescriptor const& field) { + static auto* const kSupportedWellKnownValueTypes = [] { + auto foo = std::make_unique< + std::unordered_map>(); + foo->emplace("google.protobuf.BoolValue", + protobuf::FieldDescriptor::CPPTYPE_BOOL); + foo->emplace("google.protobuf.DoubleValue", + protobuf::FieldDescriptor::CPPTYPE_DOUBLE); + foo->emplace("google.protobuf.FloatValue", + protobuf::FieldDescriptor::CPPTYPE_FLOAT); + foo->emplace("google.protobuf.Int32Value", + protobuf::FieldDescriptor::CPPTYPE_INT32); + foo->emplace("google.protobuf.Int64Value", + protobuf::FieldDescriptor::CPPTYPE_INT64); + foo->emplace("google.protobuf.StringValue", + protobuf::FieldDescriptor::CPPTYPE_STRING); + foo->emplace("google.protobuf.UInt32Value", + protobuf::FieldDescriptor::CPPTYPE_UINT32); + foo->emplace("google.protobuf.UInt64Value", + protobuf::FieldDescriptor::CPPTYPE_UINT64); + return foo.release(); + }(); + + absl::optional param_info; + // Only attempt to make non-repeated, simple fields query parameters. + if (!field.is_repeated() && !field.options().deprecated()) { + if (field.cpp_type() != protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + param_info = QueryParameterInfo{ + field.cpp_type(), absl::StrCat("request.", field.name(), "()"), + false}; + } else { + // But also consider protobuf well known types that wrap simple types. + auto iter = kSupportedWellKnownValueTypes->find( + field.message_type()->full_name()); + if (iter != kSupportedWellKnownValueTypes->end()) { + param_info = QueryParameterInfo{ + iter->second, absl::StrCat("request.", field.name(), "().value()"), + true}; + } + } + } + return param_info; +} + +// Request fields not appearing in the path may not wind up as part of the json // request body, so per https://cloud.google.com/apis/design/standard_methods, // for HTTP transcoding we need to turn the request fields into query // parameters. @@ -211,36 +256,41 @@ void SetHttpQueryParameters( : method(method), method_vars(method_vars) {} void FormatQueryParameterCode( std::vector const& param_field_names) { - std::vector> + std::vector> remaining_request_fields; auto const* request = method.input_type(); for (int i = 0; i < request->field_count(); ++i) { auto const* field = request->field(i); - // Only attempt to make non-repeated, simple fields query parameters. - if (!field->is_repeated() && !field->options().deprecated() && - field->cpp_type() != protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - if (!internal::Contains(param_field_names, field->name())) { - remaining_request_fields.emplace_back(field->name(), - field->cpp_type()); - } + auto param_info = DetermineQueryParameterInfo(*field); + if (param_info && + !internal::Contains(param_field_names, field->name())) { + remaining_request_fields.emplace_back(field->name(), *param_info); } } + auto format = [](auto* out, auto const& i) { - if (i.second == protobuf::FieldDescriptor::CPPTYPE_STRING) { - out->append(absl::StrFormat("std::make_pair(\"%s\", request.%s())", - i.first, i.first)); - return; + std::string field_access; + if (i.second.cpp_type == protobuf::FieldDescriptor::CPPTYPE_STRING) { + field_access = i.second.request_field_accessor; + } else if (i.second.cpp_type == + protobuf::FieldDescriptor::CPPTYPE_BOOL) { + field_access = absl::StrCat("(", i.second.request_field_accessor, + R"""( ? "1" : "0"))"""); + } else { + field_access = absl::StrCat("std::to_string(", + i.second.request_field_accessor, ")"); } - if (i.second == protobuf::FieldDescriptor::CPPTYPE_BOOL) { + + if (i.second.check_presence) { out->append(absl::StrFormat( - R"""(std::make_pair("%s", request.%s() ? "1" : "0"))""", i.first, - i.first)); - return; + R"""(std::make_pair("%s", (request.has_%s() ? %s : "")))""", + i.first, i.first, field_access)); + } else { + out->append(absl::StrFormat(R"""(std::make_pair("%s", %s))""", + i.first, field_access)); } - out->append(absl::StrFormat( - "std::make_pair(\"%s\", std::to_string(request.%s()))", i.first, - i.first)); }; + if (remaining_request_fields.empty()) { method_vars["method_http_query_parameters"] = ""; } else { diff --git a/generator/internal/http_option_utils.h b/generator/internal/http_option_utils.h index 7b5474fefb8cb..40972cbe2541c 100644 --- a/generator/internal/http_option_utils.h +++ b/generator/internal/http_option_utils.h @@ -69,6 +69,24 @@ void SetHttpDerivedMethodVars( google::protobuf::MethodDescriptor const& method, VarsDictionary& method_vars); +struct QueryParameterInfo { + protobuf::FieldDescriptor::CppType cpp_type; + // A code fragment the generator emits to access the value of the field. + std::string request_field_accessor; + // Check presence for MESSAGE types as their default values may result in + // undesired behavior. + bool check_presence; +}; + +/** + * Determine if a field is a query parameter candidate, such that it's a + * non-repeated field that is also not an aggregate type. This includes numeric, + * bool, and string native protobuf data types, as well as, protobuf "Well Known + * Types" that wrap those data types. + */ +absl::optional DetermineQueryParameterInfo( + google::protobuf::FieldDescriptor const& field); + /** * Sets the "method_http_query_parameters" value in method_vars based on the * parsed_http_info. diff --git a/generator/internal/http_option_utils_test.cc b/generator/internal/http_option_utils_test.cc index 2f7f408d58a5c..0e0ffdc7a9c6b 100644 --- a/generator/internal/http_option_utils_test.cc +++ b/generator/internal/http_option_utils_test.cc @@ -106,6 +106,42 @@ syntax = "proto3"; package google.protobuf; // Leading comments about message Empty. message Empty {} +message DoubleValue { + // The double value. + double value = 1; +} +message FloatValue { + // The float value. + float value = 1; +} +message Int64Value { + // The int64 value. + int64 value = 1; +} +message UInt64Value { + // The uint64 value. + uint64 value = 1; +} +message Int32Value { + // The int32 value. + int32 value = 1; +} +message UInt32Value { + // The uint32 value. + uint32 value = 1; +} +message BoolValue { + // The bool value. + bool value = 1; +} +message StringValue { + // The string value. + string value = 1; +} +message BytesValue { + // The bytes value. + bytes value = 1; +} )"""; char const* const kServiceProto = @@ -275,6 +311,65 @@ char const* const kServiceProtoWithoutVersion = " }\n" "}\n"; +char const* const kBigQueryServiceProto = R"""( +syntax = "proto3"; +package my.package.v1; +import "google/api/annotations.proto"; +import "google/api/client.proto"; +import "google/api/http.proto"; +import "google/protobuf/well_known.proto"; + +service BigQueryLikeService { + // RPC to get the results of a query job. + rpc GetQueryResults(GetQueryResultsRequest) + returns (GetQueryResultsResponse) { + option (google.api.http) = { + get: "/bigquery/v2/projects/{project_id=*}/queries/{job_id=*}" + }; + } +} + +// Request object of GetQueryResults. +message GetQueryResultsRequest { + // Required. Project ID of the query job. + string project_id = 1; + + // Required. Job ID of the query job. + string job_id = 2; + + // Zero-based index of the starting row. + google.protobuf.UInt64Value start_index = 3; + + // Page token, returned by a previous call, to request the next page of + // results. + google.protobuf.StringValue page_token = 4; + + // Maximum number of results to read. + google.protobuf.UInt32Value max_results = 5; + + // The geographic location of the job. + google.protobuf.BoolValue include_location = 7; + + // Double field. + google.protobuf.DoubleValue double_value = 8; + + // Float field. + google.protobuf.FloatValue float_value = 9; + + // Int32 field. + google.protobuf.Int32Value int32_value = 10; + + // Int64 field. + google.protobuf.Int64Value int64_value = 11; + + // Non supported message type that is not a query param. + google.protobuf.Empty non_supported_type = 12; +} + +// Response object of GetQueryResults. +message GetQueryResultsResponse {} +)"""; + struct MethodVarsTestValues { MethodVarsTestValues(std::string m, std::string k, std::string v) : method(std::move(m)), @@ -298,7 +393,9 @@ class HttpOptionUtilsTest {std::string("google/protobuf/well_known.proto"), kWellKnownProto}, {std::string("google/foo/v1/service.proto"), kServiceProto}, {std::string("google/foo/v1/service_without_version.proto"), - kServiceProtoWithoutVersion}}), + kServiceProtoWithoutVersion}, + {std::string("google/foo/v1/big_query_service.proto"), + kBigQueryServiceProto}}), source_tree_db_(&source_tree_), merged_db_(&simple_db_, &source_tree_db_), pool_(&merged_db_, &collector_) { @@ -550,7 +647,26 @@ TEST_F(HttpOptionUtilsTest, SetHttpGetQueryParametersGetPaginated) { rest_internal::TrimEmptyQueryParameters({std::make_pair("page_size", std::to_string(request.page_size())), std::make_pair("page_token", request.page_token()), std::make_pair("name", request.name()), - std::make_pair("include_foo", request.include_foo() ? "1" : "0")}))""")); + std::make_pair("include_foo", (request.include_foo() ? "1" : "0"))}))""")); +} + +TEST_F(HttpOptionUtilsTest, + SetHttpGetQueryParametersGetWellKnownTypesPaginated) { + FileDescriptor const* service_file_descriptor = + pool_.FindFileByName("google/foo/v1/big_query_service.proto"); + MethodDescriptor const* method = + service_file_descriptor->service(0)->method(0); + VarsDictionary vars; + SetHttpQueryParameters(ParseHttpExtension(*method), *method, vars); + EXPECT_THAT(vars.at("method_http_query_parameters"), Eq(R"""(, + rest_internal::TrimEmptyQueryParameters({std::make_pair("start_index", (request.has_start_index() ? std::to_string(request.start_index().value()) : "")), + std::make_pair("page_token", (request.has_page_token() ? request.page_token().value() : "")), + std::make_pair("max_results", (request.has_max_results() ? std::to_string(request.max_results().value()) : "")), + std::make_pair("include_location", (request.has_include_location() ? (request.include_location().value() ? "1" : "0") : "")), + std::make_pair("double_value", (request.has_double_value() ? std::to_string(request.double_value().value()) : "")), + std::make_pair("float_value", (request.has_float_value() ? std::to_string(request.float_value().value()) : "")), + std::make_pair("int32_value", (request.has_int32_value() ? std::to_string(request.int32_value().value()) : "")), + std::make_pair("int64_value", (request.has_int64_value() ? std::to_string(request.int64_value().value()) : ""))}))""")); } TEST_F(HttpOptionUtilsTest, HasHttpAnnotationRoutingHeaderSuccess) {