diff --git a/client/client.go b/client/client.go index effadf029..e9bb4ec43 100644 --- a/client/client.go +++ b/client/client.go @@ -69,6 +69,8 @@ type config struct { tlsClientKey string userAgent string playgroundInstance string + streamInterceptors []grpc.StreamClientInterceptor + unaryInterceptors []grpc.UnaryClientInterceptor connectTimeout time.Duration retryTimeout time.Duration maxRetries uint @@ -151,6 +153,20 @@ func WithPlaygroundInstance(instance string) Opt { } } +// WithStreamInterceptors sets the interceptors to be used for streaming gRPC operations. +func WithStreamInterceptors(interceptors ...grpc.StreamClientInterceptor) Opt { + return func(c *config) { + c.streamInterceptors = interceptors + } +} + +// WithUnaryInterceptors sets the interceptors to be used for unary gRPC operations. +func WithUnaryInterceptors(interceptors ...grpc.UnaryClientInterceptor) Opt { + return func(c *config) { + c.unaryInterceptors = interceptors + } +} + // New creates a new Cerbos client. func New(address string, opts ...Opt) (Client, error) { grpcConn, _, err := mkConn(address, opts...) @@ -194,23 +210,39 @@ func mkDialOpts(conf *config) ([]grpc.DialOption, error) { dialOpts = append(dialOpts, grpc.WithConnectParams(grpc.ConnectParams{MinConnectTimeout: conf.connectTimeout})) } + streamInterceptors := conf.streamInterceptors + unaryInterceptors := conf.unaryInterceptors + if conf.maxRetries > 0 && conf.retryTimeout > 0 { - dialOpts = append(dialOpts, - grpc.WithChainStreamInterceptor( + streamInterceptors = append( + []grpc.StreamClientInterceptor{ grpc_retry.StreamClientInterceptor( grpc_retry.WithMax(conf.maxRetries), grpc_retry.WithPerRetryTimeout(conf.retryTimeout), ), - ), - grpc.WithChainUnaryInterceptor( + }, + streamInterceptors..., + ) + + unaryInterceptors = append( + []grpc.UnaryClientInterceptor{ grpc_retry.UnaryClientInterceptor( grpc_retry.WithMax(conf.maxRetries), grpc_retry.WithPerRetryTimeout(conf.retryTimeout), ), - ), + }, + unaryInterceptors..., ) } + if len(streamInterceptors) > 0 { + dialOpts = append(dialOpts, grpc.WithChainStreamInterceptor(streamInterceptors...)) + } + + if len(unaryInterceptors) > 0 { + dialOpts = append(dialOpts, grpc.WithChainUnaryInterceptor(unaryInterceptors...)) + } + if conf.plaintext { dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { diff --git a/client/client_test.go b/client/client_test.go index ce7abbd64..2e6d06bd5 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -13,7 +13,11 @@ import ( "time" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + svcv1 "github.com/cerbos/cerbos/api/genpb/cerbos/svc/v1" "github.com/cerbos/cerbos/client" "github.com/cerbos/cerbos/client/testutil" "github.com/cerbos/cerbos/internal/test" @@ -108,6 +112,41 @@ func TestClient(t *testing.T) { }) }) } + + t.Run("interceptors", func(t *testing.T) { + errCanceled := status.Error(codes.Canceled, "canceled") + + t.Run("stream", func(t *testing.T) { + var called string + + c, err := client.NewAdminClientWithCredentials("unix:/dev/null", "username", "password", client.WithStreamInterceptors(func(_ context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) { + called = method + return nil, errCanceled + })) + require.NoError(t, err, "Failed to create client") + + _, err = c.AuditLogs(context.Background(), client.AuditLogOptions{ + Type: client.DecisionLogs, + Tail: 1, + }) + require.ErrorIs(t, err, errCanceled) + require.Equal(t, svcv1.CerbosAdminService_ListAuditLogEntries_FullMethodName, called) + }) + + t.Run("unary", func(t *testing.T) { + var called string + + c, err := client.New("unix:/dev/null", client.WithUnaryInterceptors(func(_ context.Context, method string, _, _ any, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error { + called = method + return errCanceled + })) + require.NoError(t, err, "Failed to create client") + + _, err = c.IsAllowed(context.Background(), client.NewPrincipal("id", "role"), client.NewResource("kind", "id"), "action") + require.ErrorIs(t, err, errCanceled) + require.Equal(t, svcv1.CerbosService_CheckResources_FullMethodName, called) + }) + }) } func mkServerOpts(t *testing.T, withTLS bool) []testutil.ServerOpt { diff --git a/go.mod b/go.mod index 295bc49a8..3d079a197 100644 --- a/go.mod +++ b/go.mod @@ -76,6 +76,7 @@ require ( github.com/twmb/franz-go/plugin/kzap v1.1.2 go.elastic.co/ecszap v1.0.1 go.opencensus.io v0.24.0 + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 go.opentelemetry.io/contrib/propagators/autoprop v0.42.0 go.opentelemetry.io/contrib/propagators/b3 v1.17.0 diff --git a/go.sum b/go.sum index a96212e62..8c74fa988 100644 --- a/go.sum +++ b/go.sum @@ -864,6 +864,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0 h1:ZOLJc06r4CB42laIXg/7udr0pbZyuAihN10A/XuiQRY= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0/go.mod h1:5z+/ZWJQKXa9YT34fQNx5K8Hd1EoIhvtUygUQPqEOgQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 h1:pginetY7+onl4qN1vl0xW/V/v6OBZ0vVdH+esuJgvmM= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0/go.mod h1:XiYsayHc36K3EByOO6nbAXnAWbrUxdjUROCEeeROOH8= go.opentelemetry.io/contrib/propagators/autoprop v0.42.0 h1:s2RzYOAqHVgG23q8fPWYChobUoZM6rJZ98EnylJr66w= diff --git a/internal/server/server.go b/internal/server/server.go index 63cb90e34..7d47711a0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -30,6 +30,7 @@ import ( "go.opencensus.io/plugin/ochttp" "go.opencensus.io/stats/view" "go.opencensus.io/zpages" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/zap" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -443,6 +444,7 @@ func (s *Server) mkGRPCServer(log *zap.Logger, auditLog audit.Log) (*grpc.Server grpc.ChainStreamInterceptor( grpc_recovery.StreamServerInterceptor(), telemetryInt.StreamServerInterceptor(), + otelgrpc.StreamServerInterceptor(), grpc_validator.StreamServerInterceptor(), grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractorForInitialReq(svc.ExtractRequestFields)), grpc_zap.StreamServerInterceptor(log, @@ -454,6 +456,7 @@ func (s *Server) mkGRPCServer(log *zap.Logger, auditLog audit.Log) (*grpc.Server grpc.ChainUnaryInterceptor( grpc_recovery.UnaryServerInterceptor(), telemetryInt.UnaryServerInterceptor(), + otelgrpc.UnaryServerInterceptor(), grpc_validator.UnaryServerInterceptor(), grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractorForInitialReq(svc.ExtractRequestFields)), XForwardedHostUnaryServerInterceptor,