Skip to content

Commit

Permalink
Support adding gRPC metadata to outgoing RPCs (#601)
Browse files Browse the repository at this point in the history
Added the RemoteHeaders field to DialParams, allowing the client to provide
a map[string][]string of headers. These get attached to outgoing RPCs with
client interceptors.
  • Loading branch information
jayconrod authored Nov 12, 2024
1 parent f4821a2 commit 3d9543e
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 19 deletions.
2 changes: 2 additions & 0 deletions go/pkg/client/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ go_library(
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/oauth",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//encoding/prototext",
"@org_golang_google_protobuf//encoding/protowire",
Expand Down Expand Up @@ -82,6 +83,7 @@ go_test(
"@go_googleapis//google/rpc:status_go_proto",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//testing/protocmp",
Expand Down
33 changes: 33 additions & 0 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"errors"

"github.com/bazelbuild/remote-apis-sdks/go/pkg/actas"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
Expand All @@ -26,6 +27,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/oauth"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

// Redundant imports are required for the google3 mirror. Aliases should not be changed.
Expand Down Expand Up @@ -189,6 +191,7 @@ type Client struct {
casDownloaders *semaphore.Weighted
casDownloadRequests chan *downloadRequest
rpcTimeouts RPCTimeouts
remoteHeaders map[string]string
creds credentials.PerRPCCredentials
uploadOnce sync.Once
downloadOnce sync.Once
Expand Down Expand Up @@ -551,6 +554,10 @@ type DialParams struct {
//
// If this is specified, TLSClientAuthCert must also be specified.
TLSClientAuthKey string

// RemoteHeaders specifies additional gRPC metadata headers to be passed with
// each RPC. These headers are not meant to be used for authentication.
RemoteHeaders map[string][]string
}

func createTLSConfig(params DialParams) (*tls.Config, error) {
Expand Down Expand Up @@ -651,6 +658,32 @@ func OptsFromParams(ctx context.Context, params DialParams) ([]grpc.DialOption,
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
}

if len(params.RemoteHeaders) > 0 {
md := metadata.MD(params.RemoteHeaders)
opts = append(
opts,
grpc.WithChainUnaryInterceptor(func(
ctx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption) error {
ctx = metadata.NewOutgoingContext(ctx, md)
return invoker(ctx, method, req, reply, cc, opts...)
}),
grpc.WithChainStreamInterceptor(func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
ctx = metadata.NewOutgoingContext(ctx, md)
return streamer(ctx, desc, cc, method, opts...)
}))
}

return opts, authUsed, nil
}

Expand Down
104 changes: 104 additions & 0 deletions go/pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@ package client
import (
"context"
"errors"
"io"
"net"
"os"
"path"
"testing"

"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
svpb "github.com/bazelbuild/remote-apis/build/bazel/semver"
"github.com/google/go-cmp/cmp"
bspb "google.golang.org/genproto/googleapis/bytestream"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

const (
Expand Down Expand Up @@ -266,3 +275,98 @@ func TestResourceName(t *testing.T) {
})
}
}

func TestRemoteHeaders(t *testing.T) {
one := []byte{1}
oneDigest := digest.NewFromBlob(one)
want := map[string][]string{"x-test": {"test123"}}
checkHeaders := func(t *testing.T, got metadata.MD) {
t.Helper()
for k, wantV := range want {
if gotV, ok := got[k]; !ok {
t.Errorf("header %s not seen in server metadata", k)
} else if len(gotV) != 1 {
t.Errorf("header %s seen %d times", k, len(wantV))
} else if diff := cmp.Diff(gotV, wantV); diff != "" {
t.Errorf("got header %s value %q; want %q; diff (-got, +want) %s", k, gotV, wantV, diff)
}
}
}

ctx := context.Background()
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot listen: %v", err)
}
defer listener.Close()
server := grpc.NewServer()
fake := &fakeByteStreamForRemoteHeaders{}
bspb.RegisterByteStreamServer(server, fake)
repb.RegisterCapabilitiesServer(server, &fakeCapabilitiesForRemoteHeaders{})
go server.Serve(listener)
defer server.Stop()

client, err := NewClient(ctx, "instance", DialParams{
Service: listener.Addr().String(),
NoSecurity: true,
RemoteHeaders: want,
})
if err != nil {
t.Fatalf("Cannot create client: %v", err)
}
defer client.Close()

t.Run("unary", func(t *testing.T) {
if _, err := client.WriteBlob(ctx, one); err != nil {
t.Fatalf("Writing blob: %v", err)
}
checkHeaders(t, fake.writeHeaders)
})

t.Run("stream", func(t *testing.T) {
if _, _, err := client.ReadBlob(ctx, oneDigest); err != nil {
t.Fatalf("Reading blob: %v", err)
}
checkHeaders(t, fake.readHeaders)
})
}

type fakeByteStreamForRemoteHeaders struct {
bspb.UnimplementedByteStreamServer
readHeaders, writeHeaders metadata.MD
}

func (bs *fakeByteStreamForRemoteHeaders) Read(req *bspb.ReadRequest, stream bspb.ByteStream_ReadServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
return status.Error(codes.InvalidArgument, "metadata not found")
}
bs.readHeaders = md
stream.Send(&bspb.ReadResponse{Data: []byte{1}})
return nil
}

func (bs *fakeByteStreamForRemoteHeaders) Write(stream bspb.ByteStream_WriteServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
return status.Error(codes.InvalidArgument, "metadata not found")
}
bs.writeHeaders = md
for {
_, err := stream.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}
}
return stream.SendAndClose(&bspb.WriteResponse{})
}

type fakeCapabilitiesForRemoteHeaders struct {
repb.UnimplementedCapabilitiesServer
}

func (cap *fakeCapabilitiesForRemoteHeaders) GetCapabilities(ctx context.Context, req *repb.GetCapabilitiesRequest) (*repb.ServerCapabilities, error) {
return &repb.ServerCapabilities{}, nil
}
5 changes: 5 additions & 0 deletions go/pkg/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ var (
// Instance gives the instance of remote execution to test (in
// projects/[PROJECT_ID]/instances/[INSTANCE_NAME] format for Google RBE).
Instance = flag.String("instance", "", "The instance ID to target when calling remote execution via gRPC (e.g., projects/$PROJECT/instances/default_instance for Google RBE).")
// RemoteHeaders stores additional headers to pass with each RPC.
RemoteHeaders map[string][]string
// CASConcurrency specifies the maximum number of concurrent upload & download RPCs that can be in flight.
CASConcurrency = flag.Int("cas_concurrency", client.DefaultCASConcurrency, "Num concurrent upload / download RPCs that the SDK is allowed to do.")
// MaxConcurrentRequests denotes the maximum number of concurrent RPCs on a single gRPC connection.
Expand Down Expand Up @@ -85,6 +87,8 @@ func init() {
// themselves with every RPC, otherwise it is easy to accidentally enforce a timeout on
// WaitExecution, for example.
flag.Var((*moreflag.StringMapValue)(&RPCTimeouts), "rpc_timeouts", "Comma-separated key value pairs in the form rpc_name=timeout. The key for default RPC is named default. 0 indicates no timeout. Example: GetActionResult=500ms,Execute=0,default=10s.")

flag.Var((*moreflag.StringListMapValue)(&RemoteHeaders), "remote_headers", "Comma-separated headers to pass with each RPC in the form key=value.")
}

// NewClientFromFlags connects to a remote execution service and returns a client suitable for higher-level
Expand Down Expand Up @@ -152,5 +156,6 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client
TLSClientAuthCert: *TLSClientAuthCert,
TLSClientAuthKey: *TLSClientAuthKey,
MaxConcurrentRequests: uint32(*MaxConcurrentRequests),
RemoteHeaders: RemoteHeaders,
}, opts...)
}
84 changes: 69 additions & 15 deletions go/pkg/moreflag/moreflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,16 @@ func (m *StringMapValue) String() string {
// Set updates the map with key and value pair(s) in the format key1=value1,key2=value2.
func (m *StringMapValue) Set(s string) error {
*m = make(map[string]string)
pairs := strings.Split(s, ",")
for _, p := range pairs {
if p == "" {
continue
}
pair := strings.Split(p, "=")
if len(pair) != 2 {
return fmt.Errorf("wrong format for key-value pair: %v", p)
}
if pair[0] == "" {
return fmt.Errorf("key not provided")
}
if _, ok := (*m)[pair[0]]; ok {
return fmt.Errorf("key %v already defined in list of key-value pairs %v", pair[0], s)
pairs, err := parsePairs(s)
if err != nil {
return err
}
for i := 0; i < len(pairs); i += 2 {
k, v := pairs[i], pairs[i+1]
if _, ok := (*m)[k]; ok {
return fmt.Errorf("key %v already defined in list of key-value pairs %v", k, s)
}
(*m)[pair[0]] = pair[1]
(*m)[k] = v
}
return nil
}
Expand Down Expand Up @@ -107,3 +101,63 @@ func (m *StringListValue) Set(s string) error {
func (m *StringListValue) Get() interface{} {
return []string(*m)
}

// StringListMapValue is like StringMapValue, but it allows a key to be used
// with multiple values. The command-line syntax is the same: for example,
// the string key1=a,key1=b,key2=c parses as a map with "key1" having values
// "a" and "b", and "key2" having the value "c".
type StringListMapValue map[string][]string

func (m *StringListMapValue) String() string {
keys := make([]string, 0, len(*m))
for key := range *m {
keys = append(keys, key)
}
sort.Strings(keys)
var b bytes.Buffer
for _, key := range keys {
for _, value := range (*m)[key] {
if b.Len() > 0 {
b.WriteRune(',')
}
b.WriteString(key)
b.WriteRune('=')
b.WriteString(value)
}
}
return b.String()
}

func (m *StringListMapValue) Set(s string) error {
*m = make(map[string][]string)
pairs, err := parsePairs(s)
if err != nil {
return err
}
for i := 0; i < len(pairs); i += 2 {
k, v := pairs[i], pairs[i+1]
(*m)[k] = append((*m)[k], v)
}
return nil
}

// parsePairs parses a string of the form "key1=value1,key2=value2", returning
// a slice with an even number of strings like "key1", "value1", "key2",
// "value2". Pairs are separated by ','; keys and values are separated by '='.
func parsePairs(s string) ([]string, error) {
var pairs []string
for _, p := range strings.Split(s, ",") {
if p == "" {
continue
}
k, v, ok := strings.Cut(p, "=")
if !ok {
return nil, fmt.Errorf("wrong format for key=value pair: %v", p)
}
if k == "" {
return nil, fmt.Errorf("key not provided")
}
pairs = append(pairs, k, v)
}
return pairs, nil
}
Loading

0 comments on commit 3d9543e

Please sign in to comment.