Skip to content

Commit

Permalink
Merge branch 'grpc:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafasen81 committed Jan 9, 2024
2 parents 2e642bb + 3a8270f commit d437bd7
Show file tree
Hide file tree
Showing 19 changed files with 617 additions and 58 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/codeql-analysis.yml
Expand Up @@ -29,9 +29,9 @@ jobs:

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@407ffafae6a767df3e0230c3df91b6443ae8df75 # v2.22.8
uses: github/codeql-action/init@1500a131381b66de0c52ac28abb13cd79f4b7ecc # v2.22.12
with:
languages: go

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@407ffafae6a767df3e0230c3df91b6443ae8df75 # v2.22.8
uses: github/codeql-action/analyze@1500a131381b66de0c52ac28abb13cd79f4b7ecc # v2.22.12
2 changes: 1 addition & 1 deletion credentials/alts/alts_test.go
Expand Up @@ -397,7 +397,7 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
if err == nil {
break
}
if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
if code := status.Code(err); code == codes.Unavailable {
// The server is not ready yet. Try again.
continue
}
Expand Down
10 changes: 6 additions & 4 deletions credentials/alts/internal/handshaker/handshaker.go
Expand Up @@ -61,6 +61,8 @@ var (
// control number of concurrent created (but not closed) handshakes.
clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
// errOutOfBound occurs when the handshake service returns a consumed
// bytes value larger than the buffer that was passed to it originally.
errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
Expand Down Expand Up @@ -154,8 +156,8 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn,
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if err := clientHandshakes.Acquire(ctx, 1); err != nil {
return nil, nil, err
if !clientHandshakes.TryAcquire(1) {
return nil, nil, errDropped
}
defer clientHandshakes.Release(1)

Expand Down Expand Up @@ -207,8 +209,8 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if err := serverHandshakes.Acquire(ctx, 1); err != nil {
return nil, nil, err
if !serverHandshakes.TryAcquire(1) {
return nil, nil, errDropped
}
defer serverHandshakes.Release(1)

Expand Down
12 changes: 6 additions & 6 deletions credentials/alts/internal/handshaker/handshaker_test.go
Expand Up @@ -193,10 +193,10 @@ func (s) TestClientHandshake(t *testing.T) {
}()
}

// Ensure that there are no errors.
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil {
t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err)
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}

Expand Down Expand Up @@ -250,10 +250,10 @@ func (s) TestServerHandshake(t *testing.T) {
}()
}

// Ensure that there are no errors.
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil {
t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err)
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}

Expand Down
6 changes: 3 additions & 3 deletions credentials/tls/certprovider/pemfile/builder.go
Expand Up @@ -29,7 +29,7 @@ import (
)

const (
pluginName = "file_watcher"
PluginName = "file_watcher"
defaultRefreshInterval = 10 * time.Minute
)

Expand All @@ -48,13 +48,13 @@ func (p *pluginBuilder) ParseConfig(c any) (*certprovider.BuildableConfig, error
if err != nil {
return nil, err
}
return certprovider.NewBuildableConfig(pluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
return certprovider.NewBuildableConfig(PluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
return newProvider(opts)
}), nil
}

func (p *pluginBuilder) Name() string {
return pluginName
return PluginName
}

func pluginConfigFromJSON(jd json.RawMessage) (Options, error) {
Expand Down
4 changes: 2 additions & 2 deletions internal/testutils/xds/e2e/setup_certs.go
Expand Up @@ -98,7 +98,7 @@ func CreateClientTLSCredentials(t *testing.T) credentials.TransportCredentials {

// CreateServerTLSCredentials creates server-side TLS transport credentials
// using certificate and key files from testdata/x509 directory.
func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials {
func CreateServerTLSCredentials(t *testing.T, clientAuth tls.ClientAuthType) credentials.TransportCredentials {
t.Helper()

cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
Expand All @@ -114,7 +114,7 @@ func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials {
t.Fatal("Failed to append certificates")
}
return credentials.NewTLS(&tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientAuth: clientAuth,
Certificates: []tls.Certificate{cert},
ClientCAs: ca,
})
Expand Down
8 changes: 6 additions & 2 deletions rpc_util.go
Expand Up @@ -640,14 +640,18 @@ func encode(c baseCodec, msg any) ([]byte, error) {
return b, nil
}

// compress returns the input bytes compressed by compressor or cp. If both
// compressors are nil, returns nil.
// compress returns the input bytes compressed by compressor or cp.
// If both compressors are nil, or if the message has zero length, returns nil,
// indicating no compression was done.
//
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
if compressor == nil && cp == nil {
return nil, nil
}
if len(in) == 0 {
return nil, nil
}
wrapErr := func(err error) error {
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
Expand Down
30 changes: 22 additions & 8 deletions test/compressor_test.go
Expand Up @@ -290,19 +290,29 @@ func (s) TestSetSendCompressorSuccess(t *testing.T) {
for _, tt := range []struct {
name string
desc string
payload *testpb.Payload
dialOpts []grpc.DialOption
resCompressor string
wantCompressInvokes int32
}{
{
name: "identity_request_and_gzip_response",
desc: "request is uncompressed and response is gzip compressed",
payload: &testpb.Payload{Body: []byte("payload")},
resCompressor: "gzip",
wantCompressInvokes: 1,
},
{
name: "identity_request_and_empty_response",
desc: "request is uncompressed and response is gzip compressed",
payload: nil,
resCompressor: "gzip",
wantCompressInvokes: 0,
},
{
name: "gzip_request_and_identity_response",
desc: "request is gzip compressed and response is uncompressed with identity",
payload: &testpb.Payload{Body: []byte("payload")},
resCompressor: "identity",
dialOpts: []grpc.DialOption{
// Use WithCompressor instead of UseCompressor to avoid counting
Expand All @@ -314,24 +324,26 @@ func (s) TestSetSendCompressorSuccess(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
t.Run("unary", func(t *testing.T) {
testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
testUnarySetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
})

t.Run("stream", func(t *testing.T) {
testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
testStreamSetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
})
})
}
}

func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
func testUnarySetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
wc := setupGzipWrapCompressor(t)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
return &testpb.SimpleResponse{
Payload: payload,
}, nil
},
}
if err := ss.Start(nil, dialOpts...); err != nil {
Expand All @@ -342,7 +354,7 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantC
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
}

Expand All @@ -352,7 +364,7 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantC
}
}

func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
func testStreamSetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
wc := setupGzipWrapCompressor(t)
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
Expand All @@ -364,7 +376,9 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, want
return err
}

return stream.Send(&testpb.StreamingOutputCallResponse{})
return stream.Send(&testpb.StreamingOutputCallResponse{
Payload: payload,
})
},
}
if err := ss.Start(nil, dialOpts...); err != nil {
Expand Down
7 changes: 4 additions & 3 deletions test/xds/xds_client_certificate_providers_test.go
Expand Up @@ -20,6 +20,7 @@ package xds_test

import (
"context"
"crypto/tls"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -198,7 +199,7 @@ func (s) TestClientSideXDS_WithNoCertificateProvidersInBootstrap_Failure(t *test
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)

// Make an RPC and ensure that expected error is returned.
wantErr := fmt.Sprintf("identitiy certificate provider instance name %q missing in bootstrap configuration", e2e.ClientSideCertProviderInstance)
wantErr := fmt.Sprintf("identity certificate provider instance name %q missing in bootstrap configuration", e2e.ClientSideCertProviderInstance)
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(err.Error(), wantErr) {
t.Fatalf("EmptyCall() failed: %v, wantCode: %s, wantErr: %s", err, codes.Unavailable, wantErr)
Expand Down Expand Up @@ -226,7 +227,7 @@ func (s) TestClientSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T
// backend1 configured with TLS creds, represents cluster1
// backend2 configured with insecure creds, represents cluster2
// backend3 configured with insecure creds, represents cluster3
creds := e2e.CreateServerTLSCredentials(t)
creds := e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)
server1 := stubserver.StartTestService(t, nil, grpc.Creds(creds))
defer server1.Stop()
server2 := stubserver.StartTestService(t, nil)
Expand Down Expand Up @@ -355,7 +356,7 @@ func (s) TestClientSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T
}

// Make an RPC to be routed to cluster3 and verify that it fails.
const wantErr = `identitiy certificate provider instance name "non-existent-certificate-provider-instance-name" missing in bootstrap configuration`
const wantErr = `identity certificate provider instance name "non-existent-certificate-provider-instance-name" missing in bootstrap configuration`
if _, err := client.FullDuplexCall(ctx); status.Code(err) != codes.Unavailable || !strings.Contains(err.Error(), wantErr) {
t.Fatalf("FullDuplexCall failed: %v, wantCode: %s, wantErr: %s", err, codes.Unavailable, wantErr)
}
Expand Down
1 change: 0 additions & 1 deletion vet.sh
Expand Up @@ -186,7 +186,6 @@ GetSuffixMatch
GetTlsCertificateCertificateProviderInstance
GetValidationContextCertificateProviderInstance
XXXXX TODO: Remove the below deprecation usages:
CloseNotifier
Roots.Subjects
XXXXX PleaseIgnoreUnused'

Expand Down
6 changes: 4 additions & 2 deletions xds/bootstrap/bootstrap.go
Expand Up @@ -37,8 +37,10 @@ var registry = make(map[string]Credentials)
// Credentials interface encapsulates a credentials.Bundle builder
// that can be used for communicating with the xDS Management server.
type Credentials interface {
// Build returns a credential bundle associated with this credential.
Build(config json.RawMessage) (credentials.Bundle, error)
// Build returns a credential bundle associated with this credential, and
// a function to cleans up additional resources associated with this bundle
// when it is no longer needed.
Build(config json.RawMessage) (credentials.Bundle, func(), error)
// Name returns the credential name associated with this credential.
Name() string
}
Expand Down
6 changes: 3 additions & 3 deletions xds/bootstrap/bootstrap_test.go
Expand Up @@ -36,9 +36,9 @@ type testCredsBuilder struct {
config json.RawMessage
}

func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) {
func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) {
t.config = config
return nil, nil
return nil, nil, nil
}

func (t *testCredsBuilder) Name() string {
Expand All @@ -53,7 +53,7 @@ func TestRegisterNew(t *testing.T) {

const sampleConfig = "sample_config"
rawMessage := json.RawMessage(sampleConfig)
if _, err := c.Build(rawMessage); err != nil {
if _, _, err := c.Build(rawMessage); err != nil {
t.Errorf("Build(%v) error = %v, want nil", rawMessage, err)
}

Expand Down
29 changes: 24 additions & 5 deletions xds/internal/xdsclient/bootstrap/bootstrap.go
Expand Up @@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/xds/bootstrap"
"google.golang.org/grpc/xds/internal/xdsclient/tlscreds"
)

const (
Expand All @@ -60,6 +61,7 @@ const (
func init() {
bootstrap.RegisterCredentials(&insecureCredsBuilder{})
bootstrap.RegisterCredentials(&googleDefaultCredsBuilder{})
bootstrap.RegisterCredentials(&tlsCredsBuilder{})
}

// For overriding in unit tests.
Expand All @@ -69,20 +71,32 @@ var bootstrapFileReadFunc = os.ReadFile
// package `xds/bootstrap` and encapsulates an insecure credential.
type insecureCredsBuilder struct{}

func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) {
return insecure.NewBundle(), nil
func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return insecure.NewBundle(), func() {}, nil
}

func (i *insecureCredsBuilder) Name() string {
return "insecure"
}

// tlsCredsBuilder implements the `Credentials` interface defined in
// package `xds/bootstrap` and encapsulates a TLS credential.
type tlsCredsBuilder struct{}

func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) {
return tlscreds.NewBundle(config)
}

func (t *tlsCredsBuilder) Name() string {
return "tls"
}

// googleDefaultCredsBuilder implements the `Credentials` interface defined in
// package `xds/boostrap` and encapsulates a Google Default credential.
type googleDefaultCredsBuilder struct{}

func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) {
return google.NewDefaultCredentials(), nil
func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) {
return google.NewDefaultCredentials(), func() {}, nil
}

func (d *googleDefaultCredsBuilder) Name() string {
Expand Down Expand Up @@ -151,6 +165,10 @@ type ServerConfig struct {
// when a resource is deleted, nor will it remove the existing resource value
// from its cache.
IgnoreResourceDeletion bool

// Cleanups are called when the xDS client for this server is closed. Allows
// cleaning up resources created specifically for this ServerConfig.
Cleanups []func()
}

// CredsDialOption returns the configured credentials as a grpc dial option.
Expand Down Expand Up @@ -206,12 +224,13 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error {
if c == nil {
continue
}
bundle, err := c.Build(cc.Config)
bundle, cancel, err := c.Build(cc.Config)
if err != nil {
return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err)
}
sc.Creds = ChannelCreds(cc)
sc.credsDialOption = grpc.WithCredentialsBundle(bundle)
sc.Cleanups = append(sc.Cleanups, cancel)
break
}
return nil
Expand Down

0 comments on commit d437bd7

Please sign in to comment.