From d0b3f79bff1b50f81fc9e7ae708c65fe534771b3 Mon Sep 17 00:00:00 2001 From: Sam Lock Date: Tue, 9 May 2023 16:20:10 +0100 Subject: [PATCH] feat: Storage overlay (#1560) Introduces the ability to configure a second fallback storage driver using a configurable circuit breaker pattern. Signed-off-by: Sam Lock Co-authored-by: Charith Ellawala --- docs/modules/configuration/pages/storage.adoc | 25 ++ .../partials/fullconfiguration.adoc | 6 + go.mod | 3 +- go.sum | 2 + internal/server/server.go | 20 +- internal/storage/blob/store_test.go | 4 +- internal/storage/blob/tests.go | 4 +- internal/storage/overlay/conf.go | 74 +++++ internal/storage/overlay/overlay.go | 19 ++ internal/storage/overlay/store.go | 205 ++++++++++++ internal/storage/overlay/store_test.go | 298 ++++++++++++++++++ internal/storage/store.go | 13 + 12 files changed, 663 insertions(+), 10 deletions(-) create mode 100644 internal/storage/overlay/conf.go create mode 100644 internal/storage/overlay/overlay.go create mode 100644 internal/storage/overlay/store.go create mode 100644 internal/storage/overlay/store_test.go diff --git a/docs/modules/configuration/pages/storage.adoc b/docs/modules/configuration/pages/storage.adoc index d68e9201b..5755c9a28 100644 --- a/docs/modules/configuration/pages/storage.adoc +++ b/docs/modules/configuration/pages/storage.adoc @@ -383,3 +383,28 @@ You can customise the script below to suit your environment. Make sure to specif include::example$sqlserver_schema.sql[] ---- + +[#redundancy] +== Redundancy + +You can provide redundancy by configuring an `overlay` driver, which wraps a `base` and a `fallback` driver. Under normal operation, the base driver will be targeted as usual. However, if the driver consistently errors, the PDP will start targeting the fallback driver instead. The fallback is determined by a configurable https://learn.microsoft.com/en-us/previous-versions/msp-n-p/dn589784(v=pandp.10)[circuit breaker pattern]. + +You can configure the fallback error threshold and the fallback error window to determine how many errors can occur within a rolling window before the circuit breaker is tripped. + +[source,yaml,linenums] +---- +storage: + driver: "overlay" + overlay: + baseDriver: postgres + fallbackDriver: disk + fallbackErrorThreshold: 5 # number of errors that occur within the fallbackErrorWindow to trigger failover + fallbackErrorWindow: 5s # the rolling window in which errors are aggregated + disk: + directory: policies + watchForChanges: true + postgres: + url: "postgres://${PG_USER}:${PG_PASSWORD}@localhost:5432/postgres?sslmode=disable&search_path=cerbos" +---- + +NOTE: The overlay driver assumes the same interface as the base driver. Any operations that are available on the base driver but not the fallback driver will error if the circuit breaker is open and the fallback driver is being targeted. Likewise, even if the fallback driver supports additional operations compared to the base driver, these will still not be available should failover occur. diff --git a/docs/modules/configuration/partials/fullconfiguration.adoc b/docs/modules/configuration/partials/fullconfiguration.adoc index fcb7b2a93..3a3ce799f 100644 --- a/docs/modules/configuration/partials/fullconfiguration.adoc +++ b/docs/modules/configuration/partials/fullconfiguration.adoc @@ -157,6 +157,12 @@ storage: cert: /path/to/certificate key: /path/to/private_key caCert: /path/to/CA_certificate + overlay: + # This section is required only if storage.driver is overlay. + baseDriver: blob # Required. BaseDriver is the default storage driver + fallbackDriver: disk # Required. FallbackDriver is the secondary or fallback storage driver + fallbackErrorThreshold: 5 # FallbackErrorThreshold is the max number of errors we allow within the fallbackErrorWindow period + fallbackErrorWindow: 5m # FallbackErrorWindow is the cyclic period within which we aggregate failures postgres: # This section is required only if storage.driver is postgres. connPool: diff --git a/go.mod b/go.mod index 6e62ec7bf..515558c92 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,8 @@ require ( github.com/rs/cors v1.9.0 github.com/rudderlabs/analytics-go v3.3.3+incompatible github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 + github.com/sony/gobreaker v0.5.0 + github.com/sourcegraph/conc v0.3.0 github.com/spf13/afero v1.9.5 github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 @@ -234,7 +236,6 @@ require ( github.com/shopspring/decimal v1.3.1 // indirect github.com/sirupsen/logrus v1.9.0 // indirect github.com/skeema/knownhosts v1.1.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect github.com/stretchr/objx v0.5.0 // indirect diff --git a/go.sum b/go.sum index 1b19c17d6..2ba6d11ed 100644 --- a/go.sum +++ b/go.sum @@ -2155,6 +2155,8 @@ github.com/snowflakedb/gosnowflake v1.6.3/go.mod h1:6hLajn6yxuJ4xUHZegMekpq9rnQb github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg= +github.com/sony/gobreaker v0.5.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= diff --git a/internal/server/server.go b/internal/server/server.go index 6320e5b10..903b82f92 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -67,6 +67,8 @@ import ( // Import blob to register the storage driver. _ "github.com/cerbos/cerbos/internal/storage/blob" + "github.com/cerbos/cerbos/internal/storage/overlay" + // Import bundle to register the storage driver. _ "github.com/cerbos/cerbos/internal/storage/bundle" // Import mysql to register the storage driver. @@ -140,16 +142,24 @@ func Start(ctx context.Context, zpagesEnabled bool) error { } var policyLoader engine.PolicyLoader - if bs, ok := store.(storage.BinaryStore); ok { - policyLoader = bs - } else if ss, ok := store.(storage.SourceStore); ok { + switch st := store.(type) { + case storage.BinaryStore: + policyLoader = st + case storage.SourceStore: // create compile manager - compileMgr, err := compile.NewManager(ctx, ss, schemaMgr) + compileMgr, err := compile.NewManager(ctx, st, schemaMgr) if err != nil { return fmt.Errorf("failed to create compile manager: %w", err) } policyLoader = compileMgr - } else { + case overlay.Overlay: + // create wrapped policy loader + pl, err := st.GetOverlayPolicyLoader(ctx, schemaMgr) + if err != nil { + return fmt.Errorf("failed to create overlay policy loader: %w", err) + } + policyLoader = pl + default: return ErrInvalidStore } diff --git a/internal/storage/blob/store_test.go b/internal/storage/blob/store_test.go index 5ca85d99b..d1276b7fe 100644 --- a/internal/storage/blob/store_test.go +++ b/internal/storage/blob/store_test.go @@ -61,7 +61,7 @@ func TestNewStore(t *testing.T) { must := require.New(t) bucketName := "test" - endpoint := startMinio(ctx, t, bucketName) + endpoint := StartMinio(ctx, t, bucketName) t.Setenv("AWS_ACCESS_KEY_ID", minioUsername) t.Setenv("AWS_SECRET_ACCESS_KEY", minioPassword) conf.Bucket = MinioBucketURL(bucketName, endpoint) @@ -132,7 +132,7 @@ func mkAddFn(t *testing.T, bucket *blob.Bucket) internal.MutateStoreFn { func mkStore(t *testing.T, dir string) (*Store, *blob.Bucket) { t.Helper() - endpoint := startMinio(context.Background(), t, bucketName) + endpoint := StartMinio(context.Background(), t, bucketName) conf := mkConf(t, dir, bucketName, endpoint) bucket, err := newBucket(context.Background(), conf) require.NoError(t, err) diff --git a/internal/storage/blob/tests.go b/internal/storage/blob/tests.go index e32130b49..32eae3f72 100644 --- a/internal/storage/blob/tests.go +++ b/internal/storage/blob/tests.go @@ -85,7 +85,7 @@ func newMinioBucket(ctx context.Context, t *testing.T, prefix string) *blob.Buck ctx, cancelFunc := context.WithDeadline(ctx, deadline) defer cancelFunc() - endpoint := startMinio(ctx, t, bucketName) + endpoint := StartMinio(ctx, t, bucketName) param := UploadParam{ BucketURL: MinioBucketURL(bucketName, endpoint), @@ -140,7 +140,7 @@ func uploadDirToBucket(tb testing.TB, ctx context.Context, dir string, bucket *b return files, err } -func startMinio(ctx context.Context, t *testing.T, bucketName string) string { +func StartMinio(ctx context.Context, t *testing.T, bucketName string) string { t.Helper() is := require.New(t) pool, err := dockertest.NewPool("") diff --git a/internal/storage/overlay/conf.go b/internal/storage/overlay/conf.go new file mode 100644 index 000000000..38693148e --- /dev/null +++ b/internal/storage/overlay/conf.go @@ -0,0 +1,74 @@ +// Copyright 2021-2023 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package overlay + +import ( + "errors" + "time" + + "github.com/cerbos/cerbos/internal/config" + "github.com/cerbos/cerbos/internal/storage" + "go.uber.org/multierr" +) + +const ( + confKey = storage.ConfKey + ".overlay" + defaultFallbackErrorThreshold = 5 + defaultFallbackErrorWindow = 5 * time.Minute +) + +// Conf is required (if driver is set to 'overlay') configuration for overlay storage driver. +// +desc=This section is required only if storage.driver is overlay. +type Conf struct { + // BaseDriver is the default storage driver + BaseDriver string `yaml:"baseDriver" conf:"required,example=blob"` + // FallbackDriver is the secondary or fallback storage driver + FallbackDriver string `yaml:"fallbackDriver" conf:"required,example=disk"` + // FallbackErrorThreshold is the max number of errors we allow within the fallbackErrorWindow period + FallbackErrorThreshold int `yaml:"fallbackErrorThreshold,omitempty" conf:",example=5"` + // FallbackErrorWindow is the cyclic period within which we aggregate failures + FallbackErrorWindow time.Duration `yaml:"fallbackErrorWindow" conf:",example=5m"` +} + +func (conf *Conf) Key() string { + return confKey +} + +func (conf *Conf) Validate() error { + var errs []error + + if conf.BaseDriver == "" { + errs = append(errs, errors.New("baseDriver is required")) + } + + if conf.FallbackDriver == "" { + errs = append(errs, errors.New("fallbackDriver is required")) + } + + if conf.BaseDriver != "" && conf.BaseDriver == conf.FallbackDriver { + errs = append(errs, errors.New("baseDriver and fallbackDriver cannot be the same")) + } + + if len(errs) > 0 { + return multierr.Combine(errs...) + } + + return nil +} + +func (conf *Conf) SetDefaults() { + if conf.FallbackErrorThreshold == 0 { + conf.FallbackErrorThreshold = defaultFallbackErrorThreshold + } + if conf.FallbackErrorWindow == 0 { + conf.FallbackErrorWindow = defaultFallbackErrorWindow + } +} + +func GetConf() (*Conf, error) { + conf := &Conf{} + err := config.GetSection(conf) + + return conf, err +} diff --git a/internal/storage/overlay/overlay.go b/internal/storage/overlay/overlay.go new file mode 100644 index 000000000..2f847eb15 --- /dev/null +++ b/internal/storage/overlay/overlay.go @@ -0,0 +1,19 @@ +// Copyright 2021-2023 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package overlay + +import ( + "context" + + "github.com/cerbos/cerbos/internal/engine" + "github.com/cerbos/cerbos/internal/schema" +) + +// The interface is defined here because placing in storage causes a circular dependency, +// probably because it blurs the lines by implementing `SourceStore` whilst having a dependency on +// `schema` in order to build the compile managers in the GetOverlayPolicyLoader method. +type Overlay interface { + // GetOverlayPolicyLoader returns a PolicyLoader implementation that wraps two SourceStores + GetOverlayPolicyLoader(ctx context.Context, schemaMgr schema.Manager) (engine.PolicyLoader, error) +} diff --git a/internal/storage/overlay/store.go b/internal/storage/overlay/store.go new file mode 100644 index 000000000..04369cb4f --- /dev/null +++ b/internal/storage/overlay/store.go @@ -0,0 +1,205 @@ +// Copyright 2021-2023 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package overlay + +import ( + "context" + "errors" + "fmt" + "io" + + runtimev1 "github.com/cerbos/cerbos/api/genpb/cerbos/runtime/v1" + "go.uber.org/zap" + + "github.com/cerbos/cerbos/internal/compile" + "github.com/cerbos/cerbos/internal/config" + "github.com/cerbos/cerbos/internal/engine" + "github.com/cerbos/cerbos/internal/namer" + "github.com/cerbos/cerbos/internal/schema" + "github.com/cerbos/cerbos/internal/storage" + "github.com/sony/gobreaker" + "github.com/sourcegraph/conc/pool" +) + +const DriverName = "overlay" + +var ( + _ Overlay = (*Store)(nil) + _ storage.BinaryStore = (*Store)(nil) + _ storage.Reloadable = (*Store)(nil) +) + +func init() { + storage.RegisterDriver(DriverName, func(ctx context.Context, confW *config.Wrapper) (storage.Store, error) { + conf := new(Conf) + if err := confW.GetSection(conf); err != nil { + return nil, fmt.Errorf("failed to read overlay configuration: %w", err) + } + + return NewStore(ctx, conf, confW) + }) +} + +func NewStore(ctx context.Context, conf *Conf, confW *config.Wrapper) (*Store, error) { + getStore := func(key string) (storage.Store, error) { + cons, err := storage.GetDriverConstructor(key) + if err != nil { + return nil, fmt.Errorf("unknown storage driver [%s]", key) + } + + store, err := cons(ctx, confW) + if err != nil { + return nil, fmt.Errorf("failed to create overlay store: %w", err) + } + + return store, nil + } + + logger := zap.S().Named(confKey+".store").With("baseDriver", conf.BaseDriver, "fallbackDriver", conf.FallbackDriver) + + baseStore, err := getStore(conf.BaseDriver) + if err != nil { + return nil, fmt.Errorf("failed to create base policy loader: %w", err) + } + + fallbackStore, err := getStore(conf.FallbackDriver) + if err != nil { + return nil, fmt.Errorf("failed to create fallback policy loader: %w", err) + } + + return &Store{ + log: logger, + conf: conf, + baseStore: baseStore, + fallbackStore: fallbackStore, + circuitBreaker: newCircuitBreaker(conf), + }, nil +} + +type Store struct { + log *zap.SugaredLogger + conf *Conf + baseStore storage.Store + fallbackStore storage.Store + basePolicyLoader engine.PolicyLoader + fallbackPolicyLoader engine.PolicyLoader + circuitBreaker *gobreaker.CircuitBreaker +} + +func newCircuitBreaker(conf *Conf) *gobreaker.CircuitBreaker { + breakerSettings := gobreaker.Settings{ + Name: "Store", + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= uint32(conf.FallbackErrorThreshold) + }, + Interval: conf.FallbackErrorWindow, + Timeout: 0, + } + return gobreaker.NewCircuitBreaker(breakerSettings) +} + +// GetOverlayPolicyLoader instantiates both the base and fallback policy loaders and then returns itself. +func (s *Store) GetOverlayPolicyLoader(ctx context.Context, schemaMgr schema.Manager) (engine.PolicyLoader, error) { + getPolicyLoader := func(storeInterface storage.Store, key string) (engine.PolicyLoader, error) { + switch st := storeInterface.(type) { + case storage.SourceStore: + pl, err := compile.NewManager(ctx, st, schemaMgr) + if err != nil { + return nil, fmt.Errorf("failed to create %s compile manager: %w", key, err) + } + return pl, nil + case storage.BinaryStore: + return st, nil + default: + return nil, errors.New("overlaid store does not implement either SourceStore or BinaryStore interfaces") + } + } + + var err error + if s.basePolicyLoader, err = getPolicyLoader(s.baseStore, "base"); err != nil { + return nil, err + } + if s.fallbackPolicyLoader, err = getPolicyLoader(s.fallbackStore, "fallback"); err != nil { + return nil, err + } + + return s, nil +} + +func withCircuitBreaker[T any](s *Store, baseFn, fallbackFn func() (T, error)) (T, error) { + if s.circuitBreaker.State() == gobreaker.StateOpen { + s.log.Debug("Calling overlay fallback method") + return fallbackFn() + } + + s.log.Debug("Calling overlay base method") + result, err := s.circuitBreaker.Execute(func() (interface{}, error) { + // TODO(saml) only increment on network specific errors? + return baseFn() + }) + + //nolint:forcetypeassert + return result.(T), err +} + +// +// PolicyLoader interface +// + +func (s *Store) GetPolicySet(ctx context.Context, id namer.ModuleID) (*runtimev1.RunnablePolicySet, error) { + // Both `SourceStore` (via `compile.Manager`) and `BinaryStore` implement GetPolicySet + return withCircuitBreaker( + s, + func() (*runtimev1.RunnablePolicySet, error) { return s.basePolicyLoader.GetPolicySet(ctx, id) }, + func() (*runtimev1.RunnablePolicySet, error) { return s.fallbackPolicyLoader.GetPolicySet(ctx, id) }, + ) +} + +// +// Store interface methods +// + +func (s *Store) Driver() string { + return DriverName +} + +func (s *Store) ListPolicyIDs(ctx context.Context, includeDisabled bool) ([]string, error) { + return withCircuitBreaker( + s, + func() ([]string, error) { return s.baseStore.ListPolicyIDs(ctx, includeDisabled) }, + func() ([]string, error) { return s.fallbackStore.ListPolicyIDs(ctx, includeDisabled) }, + ) +} + +func (s *Store) ListSchemaIDs(ctx context.Context) ([]string, error) { + return withCircuitBreaker( + s, + func() ([]string, error) { return s.baseStore.ListSchemaIDs(ctx) }, + func() ([]string, error) { return s.fallbackStore.ListSchemaIDs(ctx) }, + ) +} + +func (s *Store) LoadSchema(ctx context.Context, url string) (io.ReadCloser, error) { + return withCircuitBreaker( + s, + func() (io.ReadCloser, error) { return s.baseStore.LoadSchema(ctx, url) }, + func() (io.ReadCloser, error) { return s.fallbackStore.LoadSchema(ctx, url) }, + ) +} + +func (s *Store) Reload(ctx context.Context) error { + // We attempt to reload all stores in parallel, regardless of base/fallback configuration. + // Attempts on non-Reloadable stores will result in a noop. + p := pool.New().WithContext(ctx).WithCancelOnError().WithFirstError() + + if bs, ok := s.baseStore.(storage.Reloadable); ok { + p.Go(func(ctx context.Context) error { return bs.Reload(ctx) }) + } + + if fs, ok := s.fallbackStore.(storage.Reloadable); ok { + p.Go(func(ctx context.Context) error { return fs.Reload(ctx) }) + } + + return p.Wait() +} diff --git a/internal/storage/overlay/store_test.go b/internal/storage/overlay/store_test.go new file mode 100644 index 000000000..3670c8e1d --- /dev/null +++ b/internal/storage/overlay/store_test.go @@ -0,0 +1,298 @@ +// Copyright 2021-2023 Zenauth Ltd. +// SPDX-License-Identifier: Apache-2.0 + +package overlay + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + runtimev1 "github.com/cerbos/cerbos/api/genpb/cerbos/runtime/v1" + "github.com/cerbos/cerbos/internal/config" + "github.com/cerbos/cerbos/internal/namer" + "github.com/cerbos/cerbos/internal/schema" + "github.com/cerbos/cerbos/internal/storage" + "github.com/cerbos/cerbos/internal/storage/blob" + + "github.com/cerbos/cerbos/internal/storage/disk" +) + +var ( + _ storage.Store = (*MockStore)(nil) + _ storage.BinaryStore = (*MockBinaryStore)(nil) + _ storage.Reloadable = (*MockReloadable)(nil) +) + +func TestDriverInstantiation(t *testing.T) { + ctx := context.Background() + + bucketName := "test" + t.Setenv("AWS_ACCESS_KEY_ID", "minioadmin") + t.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin") + + conf := map[string]any{ + "storage": map[string]any{ + "driver": "overlay", + "overlay": map[string]any{ + "baseDriver": "blob", + "fallbackDriver": "disk", + "fallbackErrorThreshold": 3, + }, + "blob": map[string]any{ + "bucket": blob.MinioBucketURL(bucketName, blob.StartMinio(ctx, t, bucketName)), + "workDir": t.TempDir(), + "updatePollInterval": "10s", + }, + "disk": map[string]any{ + "directory": t.TempDir(), + }, + }, + } + require.NoError(t, config.LoadMap(conf)) + + // policy loader successfully created + t.Run("policy loader creation successful", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + store, err := storage.New(ctx) + require.NoError(t, err, "error creating store") + require.Equal(t, DriverName, store.Driver()) + + schemaMgr, err := schema.New(ctx, store) + require.NoError(t, err, "error creating schema manager") + + overlayStore, ok := store.(Overlay) + require.True(t, ok, "store does not implement Overlay interface") + + _, err = overlayStore.GetOverlayPolicyLoader(ctx, schemaMgr) + require.NoError(t, err, "error creating overlay policy loader") + + wrappedStore, ok := store.(*Store) + require.True(t, ok) + + _, ok = wrappedStore.baseStore.(*blob.Store) + require.True(t, ok, "baseStore should be of type *blob.Store") + + _, ok = wrappedStore.fallbackStore.(*disk.Store) + require.True(t, ok, "baseStore should be of type *disk.Store") + }) +} + +func TestFailover(t *testing.T) { + fallbackErrorThreshold := 3 + confMap := map[string]any{ + "storage": map[string]any{ + "driver": "overlay", + "overlay": map[string]any{ + "baseDriver": "foo", + "fallbackDriver": "bar", + "fallbackErrorThreshold": fallbackErrorThreshold, + }, + }, + } + require.NoError(t, config.LoadMap(confMap)) + + conf := new(Conf) + err := config.Get(confKey, conf) + require.NoError(t, err) + + t.Run("failover not triggered when consecutive failures within threshold", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + nFailures := fallbackErrorThreshold - 1 + nRequests := nFailures + 1 + basePolicyLoader := new(MockPolicyLoader) + basePolicyLoader.On("GetPolicySet", ctx, mock.AnythingOfType("namer.ModuleID")).Return((*runtimev1.RunnablePolicySet)(nil), errors.New("base store error")).Times(nFailures) + basePolicyLoader.On("GetPolicySet", ctx, mock.AnythingOfType("namer.ModuleID")).Return(&runtimev1.RunnablePolicySet{}, nil).Once() + + fallbackPolicyLoader := new(MockPolicyLoader) + + wrappedSourceStore := &Store{ + log: zap.S(), + basePolicyLoader: basePolicyLoader, + fallbackPolicyLoader: fallbackPolicyLoader, + circuitBreaker: newCircuitBreaker(conf), + } + + for i := 0; i < nRequests; i++ { + _, err := wrappedSourceStore.GetPolicySet(ctx, namer.GenModuleIDFromFQN("example")) + if i < nFailures { + require.Error(t, err, "expected base store to return an error") + } else { + require.NoError(t, err, "expected base store to succeed") + } + } + + basePolicyLoader.AssertExpectations(t) + }) + + t.Run("failover triggered when consecutive failures exceed threshold", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + nFailures := fallbackErrorThreshold + nRequests := nFailures + 1 + basePolicyLoader := new(MockPolicyLoader) + basePolicyLoader.On("GetPolicySet", ctx, mock.AnythingOfType("namer.ModuleID")).Return((*runtimev1.RunnablePolicySet)(nil), errors.New("base store error")).Times(nFailures) + + fallbackPolicyLoader := new(MockPolicyLoader) + fallbackPolicyLoader.On("GetPolicySet", ctx, mock.AnythingOfType("namer.ModuleID")).Return(&runtimev1.RunnablePolicySet{}, nil).Once() + + wrappedSourceStore := &Store{ + log: zap.S(), + basePolicyLoader: basePolicyLoader, + fallbackPolicyLoader: fallbackPolicyLoader, + circuitBreaker: newCircuitBreaker(conf), + } + + for i := 0; i < nRequests; i++ { + _, err := wrappedSourceStore.GetPolicySet(ctx, namer.GenModuleIDFromFQN("example")) + if i < nFailures { + require.Error(t, err, "expected base store to return an error") + } else { + require.NoError(t, err, "expected fallback store to succeed") + } + } + + basePolicyLoader.AssertExpectations(t) + fallbackPolicyLoader.AssertExpectations(t) + }) + + t.Run("reload only called on baseStore if not implemented on fallbackStore", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + baseStore := new(MockReloadable) + baseStore.On("Reload", mock.AnythingOfType("*context.cancelCtx")).Return(nil).Once() + + // Fallback store does not implement required method + fallbackStore := new(MockBinaryStore) + + wrappedSourceStore := &Store{ + log: zap.S(), + baseStore: baseStore, + fallbackStore: fallbackStore, + circuitBreaker: newCircuitBreaker(conf), + } + + err := wrappedSourceStore.Reload(ctx) + require.NoError(t, err, "error calling overlay reload method") + + baseStore.AssertExpectations(t) + fallbackStore.AssertNotCalled(t, "Reload") + }) + + t.Run("reload only called on fallbackStore if not implemented on baseStore", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + baseStore := new(MockBinaryStore) + + fallbackStore := new(MockReloadable) + fallbackStore.On("Reload", mock.AnythingOfType("*context.cancelCtx")).Return(nil).Once() + + wrappedSourceStore := &Store{ + log: zap.S(), + baseStore: baseStore, + fallbackStore: fallbackStore, + circuitBreaker: newCircuitBreaker(conf), + } + + err := wrappedSourceStore.Reload(ctx) + require.NoError(t, err, "error calling overlay reload method") + + baseStore.AssertNotCalled(t, "Reload") + fallbackStore.AssertExpectations(t) + }) + + t.Run("reload not called if not implemented on either store", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + baseStore := new(MockBinaryStore) + fallbackStore := new(MockBinaryStore) + + wrappedSourceStore := &Store{ + log: zap.S(), + baseStore: baseStore, + fallbackStore: fallbackStore, + circuitBreaker: newCircuitBreaker(conf), + } + + err := wrappedSourceStore.Reload(ctx) + require.NoError(t, err, "error calling overlay reload method") + + baseStore.AssertNotCalled(t, "Reload") + fallbackStore.AssertNotCalled(t, "Reload") + }) +} + +type MockPolicyLoader struct { + mock.Mock +} + +func (m *MockPolicyLoader) GetPolicySet(ctx context.Context, id namer.ModuleID) (*runtimev1.RunnablePolicySet, error) { + args := m.Called(ctx, id) + return args.Get(0).(*runtimev1.RunnablePolicySet), args.Error(1) +} + +type MockStore struct { + mock.Mock +} + +func (ms *MockStore) Driver() string { + args := ms.Called() + return args.String(0) +} + +func (ms *MockStore) ListPolicyIDs(ctx context.Context, _ bool) ([]string, error) { + args := ms.Called(ctx) + if res := args.Get(0); res == nil { + return nil, args.Error(0) + } + return args.Get(0).([]string), args.Error(0) +} + +func (ms *MockStore) ListSchemaIDs(ctx context.Context) ([]string, error) { + args := ms.Called(ctx) + if res := args.Get(0); res == nil { + return nil, args.Error(0) + } + return args.Get(0).([]string), args.Error(0) +} + +func (ms *MockStore) LoadSchema(ctx context.Context, _ string) (io.ReadCloser, error) { + args := ms.Called(ctx) + if res := args.Get(0); res == nil { + return nil, args.Error(0) + } + return nil, nil +} + +type MockBinaryStore struct { + mock.Mock + MockStore +} + +func (m *MockBinaryStore) GetPolicySet(ctx context.Context, id namer.ModuleID) (*runtimev1.RunnablePolicySet, error) { + args := m.Called(ctx, id) + return args.Get(0).(*runtimev1.RunnablePolicySet), args.Error(1) +} + +type MockReloadable struct { + mock.Mock + MockStore +} + +func (m *MockReloadable) Reload(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} diff --git a/internal/storage/store.go b/internal/storage/store.go index 2fca6d5b8..ff10ba799 100644 --- a/internal/storage/store.go +++ b/internal/storage/store.go @@ -68,6 +68,19 @@ func RegisterDriver(name string, cons Constructor) { drivers[name] = cons } +// GetDriverConstructor registers a storage driver. +func GetDriverConstructor(name string) (Constructor, error) { + driversMu.RLock() + defer driversMu.RUnlock() + + cons, ok := drivers[name] + if !ok { + return nil, fmt.Errorf("unknown storage driver [%s]", name) + } + + return cons, nil +} + // New returns a storage driver implementation based on the configured driver. func New(ctx context.Context) (Store, error) { return NewFromConf(ctx, config.Global())