diff --git a/cmd/cerbos/compile/compile.go b/cmd/cerbos/compile/compile.go index 5de7bdfad..04f4c3e9f 100644 --- a/cmd/cerbos/compile/compile.go +++ b/cmd/cerbos/compile/compile.go @@ -96,6 +96,7 @@ func (c *Cmd) Run(k *kong.Kong) error { } store := disk.NewFromIndexWithConf(idx, &disk.Conf{}) + defer store.Close() enforcement := internalschema.EnforcementReject if c.IgnoreSchemas { diff --git a/go.mod b/go.mod index be6998b86..3048b408d 100644 --- a/go.mod +++ b/go.mod @@ -50,7 +50,7 @@ require ( github.com/lestrrat-go/jwx/v2 v2.0.12 github.com/mattn/go-isatty v0.0.19 github.com/minio/minio-go/v7 v7.0.61 - github.com/nlepage/go-tarfs v1.1.0 + github.com/nlepage/go-tarfs v1.2.0 github.com/oklog/ulid/v2 v2.1.0 github.com/olekukonko/tablewriter v0.0.5 github.com/ory/dockertest/v3 v3.10.0 diff --git a/go.sum b/go.sum index 19cafeef4..56b2c748f 100644 --- a/go.sum +++ b/go.sum @@ -650,8 +650,8 @@ github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nlepage/go-tarfs v1.1.0 h1:bsACOiZMB/zFjYG/sE01070i9Fl26MnRpw0L6WuyfVs= -github.com/nlepage/go-tarfs v1.1.0/go.mod h1:IhxRcLhLkawBetnwu/JNuoPkq/6cclAllhgEa6SmzS8= +github.com/nlepage/go-tarfs v1.2.0 h1:UDFlDHRCjTjvUxMpZ6K2JzDwj6O3gPZto/eQYDcsSbQ= +github.com/nlepage/go-tarfs v1.2.0/go.mod h1:rno18mpMy9aEH1IiJVftFsqPyIpwqSUiAOpJYjlV2NA= github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= diff --git a/internal/storage/bundle/bundle.go b/internal/storage/bundle/bundle.go index d4f43079d..5620f2d83 100644 --- a/internal/storage/bundle/bundle.go +++ b/internal/storage/bundle/bundle.go @@ -244,5 +244,9 @@ func (b *Bundle) LoadSchema(_ context.Context, path string) (io.ReadCloser, erro } func (b *Bundle) Release() error { + return b.Close() +} + +func (b *Bundle) Close() error { return b.cleanup() } diff --git a/internal/storage/bundle/remote_source.go b/internal/storage/bundle/remote_source.go index 964e9e99e..f94a7ed49 100644 --- a/internal/storage/bundle/remote_source.go +++ b/internal/storage/bundle/remote_source.go @@ -460,3 +460,16 @@ func (s *RemoteSource) Reload(ctx context.Context) error { func (s *RemoteSource) SourceKind() string { return "remote" } + +func (s *RemoteSource) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.bundle != nil { + err := s.bundle.Close() + s.bundle = nil + return err + } + + return nil +} diff --git a/internal/storage/bundle/remote_source_test.go b/internal/storage/bundle/remote_source_test.go index 5ed1cf78a..17d410173 100644 --- a/internal/storage/bundle/remote_source_test.go +++ b/internal/storage/bundle/remote_source_test.go @@ -32,6 +32,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, true)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) require.NoError(t, rs.InitWithClient(context.Background(), mockClient), "Failed to init") ids, err := rs.ListPolicyIDs(context.Background(), storage.ListPolicyIDsParams{IncludeDisabled: true}) @@ -46,6 +47,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, true)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) require.NoError(t, rs.InitWithClient(context.Background(), mockClient), "Failed to init") ids, err := rs.ListPolicyIDs(context.Background(), storage.ListPolicyIDsParams{IncludeDisabled: true}) @@ -60,6 +62,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, true)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) require.Error(t, rs.InitWithClient(context.Background(), mockClient), "Expected error") require.False(t, rs.IsHealthy(), "Source should be unhealthy") @@ -75,6 +78,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, true)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) require.NoError(t, rs.InitWithClient(context.Background(), mockClient), "Failed to init") require.NoError(t, rs.Reload(context.Background()), "Failed to reload") @@ -93,6 +97,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, false)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) require.NoError(t, rs.InitWithClient(context.Background(), mockClient), "Failed to init") ids, err := rs.ListPolicyIDs(context.Background(), storage.ListPolicyIDsParams{IncludeDisabled: true}) @@ -114,6 +119,7 @@ func TestRemoteSource(t *testing.T) { Once() rs, err := bundle.NewRemoteSource(mkConf(t, false)) + t.Cleanup(func() { _ = rs.Close() }) require.NoError(t, err, "Failed to create remote source") ctx, cancelFn := context.WithCancel(context.Background()) @@ -152,6 +158,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, false)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) ctx, cancelFn := context.WithCancel(context.Background()) t.Cleanup(cancelFn) @@ -192,6 +199,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, false)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) ctx, cancelFn := context.WithCancel(context.Background()) t.Cleanup(cancelFn) @@ -237,6 +245,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, false)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) ctx, cancelFn := context.WithCancel(context.Background()) t.Cleanup(cancelFn) @@ -278,6 +287,7 @@ func TestRemoteSource(t *testing.T) { rs, err := bundle.NewRemoteSource(mkConf(t, false)) require.NoError(t, err, "Failed to create remote source") + t.Cleanup(func() { _ = rs.Close() }) ctx, cancelFn := context.WithCancel(context.Background()) t.Cleanup(cancelFn) diff --git a/internal/storage/bundle/store.go b/internal/storage/bundle/store.go index 5a36a6516..ca0a8473d 100644 --- a/internal/storage/bundle/store.go +++ b/internal/storage/bundle/store.go @@ -9,11 +9,13 @@ import ( "fmt" "io" + "go.uber.org/multierr" + "go.uber.org/zap" + runtimev1 "github.com/cerbos/cerbos/api/genpb/cerbos/runtime/v1" "github.com/cerbos/cerbos/internal/config" "github.com/cerbos/cerbos/internal/namer" "github.com/cerbos/cerbos/internal/storage" - "go.uber.org/zap" ) const DriverName = "bundle" @@ -123,3 +125,15 @@ func (hs *HybridStore) GetFirstMatch(ctx context.Context, candidates []namer.Mod func (hs *HybridStore) SourceKind() string { return "hybrid" } + +func (hs *HybridStore) Close() (outErr error) { + if c, ok := hs.remote.(io.Closer); ok { + outErr = multierr.Append(outErr, c.Close()) + } + + if c, ok := hs.local.(io.Closer); ok { + outErr = multierr.Append(outErr, c.Close()) + } + + return outErr +} diff --git a/internal/storage/disk/disk.go b/internal/storage/disk/disk.go index d733d248e..9abeb1d52 100644 --- a/internal/storage/disk/disk.go +++ b/internal/storage/disk/disk.go @@ -132,3 +132,7 @@ func (s *Store) Reload(ctx context.Context) error { return nil } + +func (s *Store) Close() error { + return s.idx.Close() +} diff --git a/internal/storage/disk/disk_test.go b/internal/storage/disk/disk_test.go index 0b6925ac5..89014c4b2 100644 --- a/internal/storage/disk/disk_test.go +++ b/internal/storage/disk/disk_test.go @@ -30,6 +30,7 @@ func mkStore(t *testing.T, dir string) *Store { store, err := NewStore(context.Background(), &Conf{Directory: dir}) require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) return store } diff --git a/internal/storage/index/builder.go b/internal/storage/index/builder.go index aae586bc6..4e848735e 100644 --- a/internal/storage/index/builder.go +++ b/internal/storage/index/builder.go @@ -6,6 +6,7 @@ package index import ( "context" "fmt" + "io" "io/fs" "path" @@ -137,6 +138,9 @@ func build(ctx context.Context, fsys fs.FS, opts buildOptions) (Index, error) { return nil }) if err != nil { + if c, ok := fsys.(io.Closer); ok { + _ = c.Close() + } return nil, err } diff --git a/internal/storage/index/index.go b/internal/storage/index/index.go index 59824f1ed..1392247cf 100644 --- a/internal/storage/index/index.go +++ b/internal/storage/index/index.go @@ -40,6 +40,7 @@ type Entry struct { } type Index interface { + io.Closer storage.Instrumented GetFirstMatch([]namer.ModuleID) (*policy.CompilationUnit, error) GetCompilationUnits(...namer.ModuleID) (map[namer.ModuleID]*policy.CompilationUnit, error) @@ -517,3 +518,10 @@ func (idx *index) Reload(ctx context.Context) ([]storage.Event, error) { return []storage.Event{storage.NewReloadEvent()}, nil } + +func (idx *index) Close() error { + if c, ok := idx.fsys.(io.Closer); ok { + return c.Close() + } + return nil +} diff --git a/internal/storage/index/index_test.go b/internal/storage/index/index_test.go index 5f01e72e5..5060c8c77 100644 --- a/internal/storage/index/index_test.go +++ b/internal/storage/index/index_test.go @@ -50,6 +50,7 @@ func TestIndexLoadPolicy(t *testing.T) { require.NoError(t, err) idx, err := index.Build(context.Background(), fsys) require.NoError(t, err) + t.Cleanup(func() { _ = idx.Close() }) t.Run("should load the policies", func(t *testing.T) { policies, err := idx.LoadPolicy(context.Background(), policyFiles...) diff --git a/internal/storage/overlay/store.go b/internal/storage/overlay/store.go index 121d072c1..142220e34 100644 --- a/internal/storage/overlay/store.go +++ b/internal/storage/overlay/store.go @@ -9,17 +9,20 @@ import ( "fmt" "io" - runtimev1 "github.com/cerbos/cerbos/api/genpb/cerbos/runtime/v1" + "go.uber.org/multierr" "go.uber.org/zap" + runtimev1 "github.com/cerbos/cerbos/api/genpb/cerbos/runtime/v1" + + "github.com/sony/gobreaker" + "github.com/sourcegraph/conc/pool" + "github.com/cerbos/cerbos/internal/compile" "github.com/cerbos/cerbos/internal/config" "github.com/cerbos/cerbos/internal/engine" "github.com/cerbos/cerbos/internal/namer" "github.com/cerbos/cerbos/internal/schema" "github.com/cerbos/cerbos/internal/storage" - "github.com/sony/gobreaker" - "github.com/sourcegraph/conc/pool" ) const DriverName = "overlay" @@ -219,3 +222,15 @@ func (s *Store) Reload(ctx context.Context) error { return p.Wait() } + +func (s *Store) Close() (outErr error) { + if c, ok := s.baseStore.(io.Closer); ok { + outErr = multierr.Append(outErr, c.Close()) + } + + if c, ok := s.fallbackStore.(io.Closer); ok { + outErr = multierr.Append(outErr, c.Close()) + } + + return outErr +} diff --git a/internal/test/mocks/Index.go b/internal/test/mocks/Index.go index b499fc23d..11c31e4ac 100644 --- a/internal/test/mocks/Index.go +++ b/internal/test/mocks/Index.go @@ -126,6 +126,47 @@ func (_c *Index_Clear_Call) RunAndReturn(run func() error) *Index_Clear_Call { return _c } +// Close provides a mock function with given fields: +func (_m *Index) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Index_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type Index_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *Index_Expecter) Close() *Index_Close_Call { + return &Index_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *Index_Close_Call) Run(run func()) *Index_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Index_Close_Call) Return(_a0 error) *Index_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Index_Close_Call) RunAndReturn(run func() error) *Index_Close_Call { + _c.Call.Return(run) + return _c +} + // Delete provides a mock function with given fields: _a0 func (_m *Index) Delete(_a0 index.Entry) (storage.Event, error) { ret := _m.Called(_a0) diff --git a/internal/util/filesystem.go b/internal/util/filesystem.go index 7d2286ca9..bd3ae90ee 100644 --- a/internal/util/filesystem.go +++ b/internal/util/filesystem.go @@ -15,6 +15,7 @@ import ( "strings" "github.com/nlepage/go-tarfs" + "go.uber.org/multierr" "google.golang.org/protobuf/proto" ) @@ -82,12 +83,30 @@ func IsArchiveFile(fileName string) bool { return IsZip(fileName) || IsTar(fileName) || IsGzip(fileName) } -func getFsFromTar(r io.Reader) (fs.FS, error) { +func getFsFromTar(r io.Reader, closers ...io.Closer) (fs.FS, error) { tfs, err := tarfs.New(r) if err != nil { + for _, c := range closers { + _ = c.Close() + } return nil, fmt.Errorf("failed to open tar file: %w", err) } - return tfs, nil + + return ClosableFS{FS: tfs, closers: closers}, nil +} + +type ClosableFS struct { + fs.FS + io.Closer + closers []io.Closer +} + +func (cfs ClosableFS) Close() (outErr error) { + for _, c := range cfs.closers { + outErr = multierr.Append(outErr, c.Close()) + } + + return outErr } // OpenDirectoryFS attempts to open a directory FS at the given location. It'll initially check if the target file is an archive, @@ -101,29 +120,27 @@ func OpenDirectoryFS(path string) (fs.FS, error) { if err != nil { return nil, fmt.Errorf("failed to open zip file: %w", err) } - return zr, nil + return ClosableFS{FS: zr, closers: []io.Closer{zr}}, nil case IsTar(path): f, err := os.Open(path) if err != nil { return nil, fmt.Errorf("failed to open tar file: %w", err) } - defer f.Close() - return getFsFromTar(f) + return getFsFromTar(f, f) case IsGzip(path): f, err := os.Open(path) if err != nil { return nil, fmt.Errorf("failed to open gzip file: %w", err) } - defer f.Close() gzr, err := gzip.NewReader(f) if err != nil { + _ = f.Close() return nil, fmt.Errorf("failed to open gzip file: %w", err) } - defer gzr.Close() - return getFsFromTar(gzr) + return getFsFromTar(gzr, gzr, f) } return os.DirFS(path), nil