Skip to content

Commit

Permalink
Merge pull request #1066 from codablock/fix-unknown-extensions
Browse files Browse the repository at this point in the history
Properly support skipping of non-mandatory extensions
  • Loading branch information
pjbgf committed Apr 9, 2024
2 parents 7a9304e + 23fa589 commit cd6633c
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 41 deletions.
104 changes: 65 additions & 39 deletions plumbing/format/index/decoder.go
Expand Up @@ -24,8 +24,8 @@ var (
// ErrInvalidChecksum is returned by Decode if the SHA1 hash mismatch with
// the read content
ErrInvalidChecksum = errors.New("invalid checksum")

errUnknownExtension = errors.New("unknown extension")
// ErrUnknownExtension is returned when an index extension is encountered that is considered mandatory
ErrUnknownExtension = errors.New("unknown extension")
)

const (
Expand All @@ -39,6 +39,7 @@ const (

// A Decoder reads and decodes index files from an input stream.
type Decoder struct {
buf *bufio.Reader
r io.Reader
hash hash.Hash
lastEntry *Entry
Expand All @@ -49,8 +50,10 @@ type Decoder struct {
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader) *Decoder {
h := hash.New(hash.CryptoType)
buf := bufio.NewReader(r)
return &Decoder{
r: io.TeeReader(r, h),
buf: buf,
r: io.TeeReader(buf, h),
hash: h,
extReader: bufio.NewReader(nil),
}
Expand Down Expand Up @@ -210,71 +213,76 @@ func (d *Decoder) readExtensions(idx *Index) error {
// count that they are not supported by jgit or libgit

var expected []byte
var peeked []byte
var err error

var header [4]byte
// we should always be able to peek for 4 bytes (header) + 4 bytes (extlen) + final hash
// if this fails, we know that we're at the end of the index
peekLen := 4 + 4 + d.hash.Size()

for {
expected = d.hash.Sum(nil)

var n int
if n, err = io.ReadFull(d.r, header[:]); err != nil {
if n == 0 {
err = io.EOF
}

peeked, err = d.buf.Peek(peekLen)
if len(peeked) < peekLen {
// there can't be an extension at this point, so let's bail out
err = nil
break
}
if err != nil {
return err
}

err = d.readExtension(idx, header[:])
err = d.readExtension(idx)
if err != nil {
break
return err
}
}

if err != errUnknownExtension {
return d.readChecksum(expected)
}

func (d *Decoder) readExtension(idx *Index) error {
var header [4]byte

if _, err := io.ReadFull(d.r, header[:]); err != nil {
return err
}

return d.readChecksum(expected, header)
}
r, err := d.getExtensionReader()
if err != nil {
return err
}

func (d *Decoder) readExtension(idx *Index, header []byte) error {
switch {
case bytes.Equal(header, treeExtSignature):
r, err := d.getExtensionReader()
if err != nil {
return err
}

case bytes.Equal(header[:], treeExtSignature):
idx.Cache = &Tree{}
d := &treeExtensionDecoder{r}
if err := d.Decode(idx.Cache); err != nil {
return err
}
case bytes.Equal(header, resolveUndoExtSignature):
r, err := d.getExtensionReader()
if err != nil {
return err
}

case bytes.Equal(header[:], resolveUndoExtSignature):
idx.ResolveUndo = &ResolveUndo{}
d := &resolveUndoDecoder{r}
if err := d.Decode(idx.ResolveUndo); err != nil {
return err
}
case bytes.Equal(header, endOfIndexEntryExtSignature):
r, err := d.getExtensionReader()
if err != nil {
return err
}

case bytes.Equal(header[:], endOfIndexEntryExtSignature):
idx.EndOfIndexEntry = &EndOfIndexEntry{}
d := &endOfIndexEntryDecoder{r}
if err := d.Decode(idx.EndOfIndexEntry); err != nil {
return err
}
default:
return errUnknownExtension
// See https://git-scm.com/docs/index-format, which says:
// If the first byte is 'A'..'Z' the extension is optional and can be ignored.
if header[0] < 'A' || header[0] > 'Z' {
return ErrUnknownExtension
}

d := &unknownExtensionDecoder{r}
if err := d.Decode(); err != nil {
return err
}
}

return nil
Expand All @@ -290,11 +298,10 @@ func (d *Decoder) getExtensionReader() (*bufio.Reader, error) {
return d.extReader, nil
}

func (d *Decoder) readChecksum(expected []byte, alreadyRead [4]byte) error {
func (d *Decoder) readChecksum(expected []byte) error {
var h plumbing.Hash
copy(h[:4], alreadyRead[:])

if _, err := io.ReadFull(d.r, h[4:]); err != nil {
if _, err := io.ReadFull(d.r, h[:]); err != nil {
return err
}

Expand Down Expand Up @@ -476,3 +483,22 @@ func (d *endOfIndexEntryDecoder) Decode(e *EndOfIndexEntry) error {
_, err = io.ReadFull(d.r, e.Hash[:])
return err
}

type unknownExtensionDecoder struct {
r *bufio.Reader
}

func (d *unknownExtensionDecoder) Decode() error {
var buf [1024]byte

for {
_, err := d.r.Read(buf[:])
if err == io.EOF {
break
}
if err != nil {
return err
}
}
return nil
}
102 changes: 102 additions & 0 deletions plumbing/format/index/decoder_test.go
@@ -1,6 +1,11 @@
package index

import (
"bytes"
"crypto"
"github.com/go-git/go-git/v5/plumbing/hash"
"github.com/go-git/go-git/v5/utils/binary"
"io"
"testing"

"github.com/go-git/go-git/v5/plumbing"
Expand Down Expand Up @@ -218,3 +223,100 @@ func (s *IndexSuite) TestDecodeEndOfIndexEntry(c *C) {
c.Assert(idx.EndOfIndexEntry.Offset, Equals, uint32(716))
c.Assert(idx.EndOfIndexEntry.Hash.String(), Equals, "922e89d9ffd7cefce93a211615b2053c0f42bd78")
}

func (s *IndexSuite) readSimpleIndex(c *C) *Index {
f, err := fixtures.Basic().One().DotGit().Open("index")
c.Assert(err, IsNil)
defer func() { c.Assert(f.Close(), IsNil) }()

idx := &Index{}
d := NewDecoder(f)
err = d.Decode(idx)
c.Assert(err, IsNil)

return idx
}

func (s *IndexSuite) buildIndexWithExtension(c *C, signature string, data string) []byte {
idx := s.readSimpleIndex(c)

buf := bytes.NewBuffer(nil)
e := NewEncoder(buf)

err := e.encode(idx, false)
c.Assert(err, IsNil)
err = e.encodeRawExtension(signature, []byte(data))
c.Assert(err, IsNil)

err = e.encodeFooter()
c.Assert(err, IsNil)

return buf.Bytes()
}

func (s *IndexSuite) TestDecodeUnknownOptionalExt(c *C) {
f := bytes.NewReader(s.buildIndexWithExtension(c, "TEST", "testdata"))

idx := &Index{}
d := NewDecoder(f)
err := d.Decode(idx)
c.Assert(err, IsNil)
}

func (s *IndexSuite) TestDecodeUnknownMandatoryExt(c *C) {
f := bytes.NewReader(s.buildIndexWithExtension(c, "test", "testdata"))

idx := &Index{}
d := NewDecoder(f)
err := d.Decode(idx)
c.Assert(err, ErrorMatches, ErrUnknownExtension.Error())
}

func (s *IndexSuite) TestDecodeTruncatedExt(c *C) {
idx := s.readSimpleIndex(c)

buf := bytes.NewBuffer(nil)
e := NewEncoder(buf)

err := e.encode(idx, false)
c.Assert(err, IsNil)

_, err = e.w.Write([]byte("TEST"))
c.Assert(err, IsNil)

err = binary.WriteUint32(e.w, uint32(100))
c.Assert(err, IsNil)

_, err = e.w.Write([]byte("truncated"))
c.Assert(err, IsNil)

err = e.encodeFooter()
c.Assert(err, IsNil)

idx = &Index{}
d := NewDecoder(buf)
err = d.Decode(idx)
c.Assert(err, ErrorMatches, io.EOF.Error())
}

func (s *IndexSuite) TestDecodeInvalidHash(c *C) {
idx := s.readSimpleIndex(c)

buf := bytes.NewBuffer(nil)
e := NewEncoder(buf)

err := e.encode(idx, false)
c.Assert(err, IsNil)

err = e.encodeRawExtension("TEST", []byte("testdata"))
c.Assert(err, IsNil)

h := hash.New(crypto.SHA1)
err = binary.Write(e.w, h.Sum(nil))
c.Assert(err, IsNil)

idx = &Index{}
d := NewDecoder(buf)
err = d.Decode(idx)
c.Assert(err, ErrorMatches, ErrInvalidChecksum.Error())
}
34 changes: 33 additions & 1 deletion plumbing/format/index/encoder.go
Expand Up @@ -3,6 +3,7 @@ package index
import (
"bytes"
"errors"
"fmt"
"io"
"sort"
"time"
Expand Down Expand Up @@ -35,6 +36,11 @@ func NewEncoder(w io.Writer) *Encoder {

// Encode writes the Index to the stream of the encoder.
func (e *Encoder) Encode(idx *Index) error {
return e.encode(idx, true)
}

func (e *Encoder) encode(idx *Index, footer bool) error {

// TODO: support v4
// TODO: support extensions
if idx.Version > EncodeVersionSupported {
Expand All @@ -49,7 +55,10 @@ func (e *Encoder) Encode(idx *Index) error {
return err
}

return e.encodeFooter()
if footer {
return e.encodeFooter()
}
return nil
}

func (e *Encoder) encodeHeader(idx *Index) error {
Expand Down Expand Up @@ -135,6 +144,29 @@ func (e *Encoder) encodeEntry(entry *Entry) error {
return binary.Write(e.w, []byte(entry.Name))
}

func (e *Encoder) encodeRawExtension(signature string, data []byte) error {
if len(signature) != 4 {
return fmt.Errorf("invalid signature length")
}

_, err := e.w.Write([]byte(signature))
if err != nil {
return err
}

err = binary.WriteUint32(e.w, uint32(len(data)))
if err != nil {
return err
}

_, err = e.w.Write(data)
if err != nil {
return err
}

return nil
}

func (e *Encoder) timeToUint32(t *time.Time) (uint32, uint32, error) {
if t.IsZero() {
return 0, 0, nil
Expand Down
2 changes: 1 addition & 1 deletion storage/filesystem/index.go
Expand Up @@ -48,7 +48,7 @@ func (s *IndexStorage) Index() (i *index.Index, err error) {

defer ioutil.CheckClose(f, &err)

d := index.NewDecoder(bufio.NewReader(f))
d := index.NewDecoder(f)
err = d.Decode(idx)
return idx, err
}

0 comments on commit cd6633c

Please sign in to comment.