From 447fc0f65434d9bb69ba00feee480f9f927e3aac Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Wed, 24 Aug 2022 12:04:37 +0200 Subject: [PATCH] Fix hash decoding in object store, propagate errors from Get() correctly --- object.go | 45 +++++++++++++++++++++++---------------- test/object_test.go | 51 +++++++++++++++++++++++++++++---------------- 2 files changed, 60 insertions(+), 36 deletions(-) diff --git a/object.go b/object.go index 11376fa0a..dcc468f32 100644 --- a/object.go +++ b/object.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "hash" "io" "net" "os" @@ -127,6 +128,7 @@ var ( ErrObjectNotFound = errors.New("nats: object not found") ErrInvalidStoreName = errors.New("nats: invalid object-store name") ErrDigestMismatch = errors.New("nats: received a corrupt object, digests do not match") + ErrInvalidDigestFormat = errors.New("nats: object digest hash has invalid format") ErrNoObjectsFound = errors.New("nats: no objects found") ErrObjectAlreadyExists = errors.New("nats: an object already exists with that name") ErrNameRequired = errors.New("nats: name is required") @@ -476,10 +478,11 @@ func (obs *obs) Put(meta *ObjectMeta, r io.Reader, opts ...ObjectOpt) (*ObjectIn // ObjectResult impl. type objResult struct { sync.Mutex - info *ObjectInfo - r io.ReadCloser - err error - ctx context.Context + info *ObjectInfo + r io.ReadCloser + err error + ctx context.Context + digest hash.Hash } func (info *ObjectInfo) isLink() bool { @@ -542,7 +545,7 @@ func (obs *obs) Get(name string, opts ...ObjectOpt) (ObjectResult, error) { } // For calculating sum256 - h := sha256.New() + result.digest = sha256.New() processChunk := func(m *Msg) { if ctx != nil { @@ -577,24 +580,12 @@ func (obs *obs) Get(name string, opts ...ObjectOpt) (ObjectResult, error) { b = b[n:] } // Update sha256 - h.Write(m.Data) + result.digest.Write(m.Data) // Check if we are done. if tokens[ackNumPendingTokenPos] == objNoPending { pw.Close() m.Sub.Unsubscribe() - - // Make sure the digest matches. - sha := h.Sum(nil) - rsha, err := base64.URLEncoding.DecodeString(info.Digest) - if err != nil { - gotErr(m, err) - return - } - if !bytes.Equal(sha[:], rsha) { - gotErr(m, ErrDigestMismatch) - return - } } } @@ -1070,6 +1061,24 @@ func (o *objResult) Read(p []byte) (n int, err error) { } } } + if err == io.EOF { + // Make sure the digest matches. + sha := o.digest.Sum(nil) + digest := strings.SplitN(o.info.Digest, "=", 2) + if len(digest) != 2 { + o.err = ErrInvalidDigestFormat + return 0, o.err + } + rsha, decodeErr := base64.URLEncoding.DecodeString(digest[1]) + if decodeErr != nil { + o.err = decodeErr + return 0, o.err + } + if !bytes.Equal(sha[:], rsha) { + o.err = ErrDigestMismatch + return 0, o.err + } + } return n, err } diff --git a/test/object_test.go b/test/object_test.go index 6d8f09cdb..c88b60cd9 100644 --- a/test/object_test.go +++ b/test/object_test.go @@ -16,6 +16,7 @@ package test import ( "bytes" "crypto/rand" + "fmt" "io/ioutil" "os" "path" @@ -114,13 +115,36 @@ func TestObjectBasics(t *testing.T) { _, err = obs.Get("") expectErr(t, err) - _, err = obs.Get("") - expectErr(t, err) - _, err = obs.PutBytes("", blob) expectErr(t, err) } +func TestGetObjectDigestMismatch(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + obs, err := js.CreateObjectStore(&nats.ObjectStoreConfig{Bucket: "FOO"}) + expectOk(t, err) + + _, err = obs.PutString("A", "abc") + expectOk(t, err) + info, err := obs.GetInfo("A") + expectOk(t, err) + + // add new chunk after using Put(), this will change the digest hash on Get() + _, err = js.Publish(fmt.Sprintf("$O.FOO.C.%s", info.NUID), []byte("123")) + expectOk(t, err) + + res, err := obs.Get("A") + expectOk(t, err) + _, err = ioutil.ReadAll(res) + expectErr(t, err, nats.ErrDigestMismatch) + expectErr(t, res.Error(), nats.ErrDigestMismatch) +} + func TestDefaultObjectStatus(t *testing.T) { s := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, s) @@ -592,25 +616,16 @@ func TestObjectLinks(t *testing.T) { } func expectLinkIsCorrect(t *testing.T, originalObject *nats.ObjectInfo, linkObject *nats.ObjectInfo) { - if linkObject.Opts.Link != nil { - if expectLinkPartsAreCorrect(t, linkObject, originalObject.Bucket, originalObject.Name) { - return - } + if linkObject.Opts.Link == nil || !expectLinkPartsAreCorrect(t, linkObject, originalObject.Bucket, originalObject.Name) { + t.Fatalf("Link info not what was expected:\nActual: %+v\nTarget: %+v", linkObject, originalObject) } - t.Fatalf("Link info not what was expected:\nActual: %+v\nTarget: %+v", linkObject, originalObject) } func expectLinkPartsAreCorrect(t *testing.T, linkObject *nats.ObjectInfo, bucket, name string) bool { - if linkObject.Opts.Link.Bucket == bucket { - if linkObject.Opts.Link.Name == name { - if !linkObject.ModTime.IsZero() { - if linkObject.NUID != "" { - return true - } - } - } - } - return false + return linkObject.Opts.Link.Bucket == bucket && + linkObject.Opts.Link.Name == name && + !linkObject.ModTime.IsZero() && + linkObject.NUID != "" } // Right now no history, just make sure we are cleaning up after ourselves.