diff --git a/object.go b/object.go index 11376fa0a..45841e0fa 100644 --- a/object.go +++ b/object.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "hash" "io" "net" "os" @@ -127,6 +128,7 @@ var ( ErrObjectNotFound = errors.New("nats: object not found") ErrInvalidStoreName = errors.New("nats: invalid object-store name") ErrDigestMismatch = errors.New("nats: received a corrupt object, digests do not match") + ErrInvalidDigestFormat = errors.New("nats: object digest hash has invalid format") ErrNoObjectsFound = errors.New("nats: no objects found") ErrObjectAlreadyExists = errors.New("nats: an object already exists with that name") ErrNameRequired = errors.New("nats: name is required") @@ -476,10 +478,11 @@ func (obs *obs) Put(meta *ObjectMeta, r io.Reader, opts ...ObjectOpt) (*ObjectIn // ObjectResult impl. type objResult struct { sync.Mutex - info *ObjectInfo - r io.ReadCloser - err error - ctx context.Context + info *ObjectInfo + r io.ReadCloser + err error + ctx context.Context + digest hash.Hash } func (info *ObjectInfo) isLink() bool { @@ -542,7 +545,7 @@ func (obs *obs) Get(name string, opts ...ObjectOpt) (ObjectResult, error) { } // For calculating sum256 - h := sha256.New() + result.digest = sha256.New() processChunk := func(m *Msg) { if ctx != nil { @@ -577,24 +580,12 @@ func (obs *obs) Get(name string, opts ...ObjectOpt) (ObjectResult, error) { b = b[n:] } // Update sha256 - h.Write(m.Data) + result.digest.Write(m.Data) // Check if we are done. if tokens[ackNumPendingTokenPos] == objNoPending { pw.Close() m.Sub.Unsubscribe() - - // Make sure the digest matches. - sha := h.Sum(nil) - rsha, err := base64.URLEncoding.DecodeString(info.Digest) - if err != nil { - gotErr(m, err) - return - } - if !bytes.Equal(sha[:], rsha) { - gotErr(m, ErrDigestMismatch) - return - } } } @@ -1070,6 +1061,24 @@ func (o *objResult) Read(p []byte) (n int, err error) { } } } + if err == io.EOF { + // Make sure the digest matches. + sha := o.digest.Sum(nil) + digest := strings.SplitN(o.info.Digest, "=", 2) + if len(digest) != 2 { + o.err = ErrInvalidDigestFormat + return 0, o.err + } + rsha, decodeErr := base64.URLEncoding.DecodeString(o.info.Digest) + if decodeErr != nil { + o.err = decodeErr + return 0, o.err + } + if !bytes.Equal(sha[:], rsha) { + o.err = ErrDigestMismatch + return 0, o.err + } + } return n, err } diff --git a/test/object_test.go b/test/object_test.go index 6d8f09cdb..afde046d8 100644 --- a/test/object_test.go +++ b/test/object_test.go @@ -114,9 +114,6 @@ func TestObjectBasics(t *testing.T) { _, err = obs.Get("") expectErr(t, err) - _, err = obs.Get("") - expectErr(t, err) - _, err = obs.PutBytes("", blob) expectErr(t, err) } @@ -592,25 +589,16 @@ func TestObjectLinks(t *testing.T) { } func expectLinkIsCorrect(t *testing.T, originalObject *nats.ObjectInfo, linkObject *nats.ObjectInfo) { - if linkObject.Opts.Link != nil { - if expectLinkPartsAreCorrect(t, linkObject, originalObject.Bucket, originalObject.Name) { - return - } + if linkObject.Opts.Link == nil || !expectLinkPartsAreCorrect(t, linkObject, originalObject.Bucket, originalObject.Name) { + t.Fatalf("Link info not what was expected:\nActual: %+v\nTarget: %+v", linkObject, originalObject) } - t.Fatalf("Link info not what was expected:\nActual: %+v\nTarget: %+v", linkObject, originalObject) } func expectLinkPartsAreCorrect(t *testing.T, linkObject *nats.ObjectInfo, bucket, name string) bool { - if linkObject.Opts.Link.Bucket == bucket { - if linkObject.Opts.Link.Name == name { - if !linkObject.ModTime.IsZero() { - if linkObject.NUID != "" { - return true - } - } - } - } - return false + return linkObject.Opts.Link.Bucket == bucket && + linkObject.Opts.Link.Name == name && + !linkObject.ModTime.IsZero() && + linkObject.NUID != "" } // Right now no history, just make sure we are cleaning up after ourselves.