Skip to content

Commit

Permalink
metadata: fix validation issues (#6001)
Browse files Browse the repository at this point in the history
  • Loading branch information
ktalg committed Feb 28, 2023
1 parent 75bed1d commit dba41ef
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 55 deletions.
62 changes: 37 additions & 25 deletions internal/metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,33 +76,11 @@ func Set(addr resolver.Address, md metadata.MD) resolver.Address {
return addr
}

// Validate returns an error if the input md contains invalid keys or values.
//
// If the header is not a pseudo-header, the following items are checked:
// - header names must contain one or more characters from this set [0-9 a-z _ - .].
// - if the header-name ends with a "-bin" suffix, no validation of the header value is performed.
// - otherwise, the header value must contain one or more characters from the set [%x20-%x7E].
// Validate validates every pair in md with ValidatePair.
func Validate(md metadata.MD) error {
for k, vals := range md {
// pseudo-header will be ignored
if k[0] == ':' {
continue
}
// check key, for i that saving a conversion if not using for range
for i := 0; i < len(k); i++ {
r := k[i]
if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", k)
}
}
if strings.HasSuffix(k, "-bin") {
continue
}
// check value
for _, val := range vals {
if hasNotPrintable(val) {
return fmt.Errorf("header key %q contains value with non-printable ASCII characters", k)
}
if err := ValidatePair(k, vals...); err != nil {
return err
}
}
return nil
Expand All @@ -118,3 +96,37 @@ func hasNotPrintable(msg string) bool {
}
return false
}

// ValidatePair validate a key-value pair with the following rules (the pseudo-header will be skipped) :
//
// - key must contain one or more characters.
// - the characters in the key must be contained in [0-9 a-z _ - .].
// - if the key ends with a "-bin" suffix, no validation of the corresponding value is performed.
// - the characters in the every value must be printable (in [%x20-%x7E]).
func ValidatePair(key string, vals ...string) error {
// key should not be empty
if key == "" {
return fmt.Errorf("there is an empty key in the header")
}
// pseudo-header will be ignored
if key[0] == ':' {
return nil
}
// check key, for i that saving a conversion if not using for range
for i := 0; i < len(key); i++ {
r := key[i]
if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", key)
}
}
if strings.HasSuffix(key, "-bin") {
return nil
}
// check value
for _, val := range vals {
if hasNotPrintable(val) {
return fmt.Errorf("header key %q contains value with non-printable ASCII characters", key)
}
}
return nil
}
4 changes: 4 additions & 0 deletions internal/metadata/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ func TestValidate(t *testing.T) {
md: map[string][]string{"test": {string(rune(0x19))}},
want: errors.New("header key \"test\" contains value with non-printable ASCII characters"),
},
{
md: map[string][]string{"": {"valid"}},
want: errors.New("there is an empty key in the header"),
},
{
md: map[string][]string{"test-bin": {string(rune(0x19))}},
want: nil,
Expand Down
11 changes: 10 additions & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,19 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}

func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
if md, _, ok := metadata.FromOutgoingContextRaw(ctx); ok {
if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
// validate md
if err := imetadata.Validate(md); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
// validate added
for _, kvs := range added {
for i := 0; i < len(kvs); i += 2 {
if err := imetadata.ValidatePair(kvs[i], kvs[i+1]); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
}
}
}
if channelz.IsOn() {
cc.incrCallsStarted()
Expand Down
91 changes: 62 additions & 29 deletions test/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,55 @@ import (
)

func (s) TestInvalidMetadata(t *testing.T) {
grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 2)
grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 5)

tests := []struct {
md metadata.MD
want error
recv error
name string
md metadata.MD
appendMD []string
want error
recv error
}{
{
name: "invalid key",
md: map[string][]string{string(rune(0x19)): {"testVal"}},
want: status.Error(codes.Internal, "header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "invalid value",
md: map[string][]string{"test": {string(rune(0x19))}},
want: status.Error(codes.Internal, "header key \"test\" contains value with non-printable ASCII characters"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "invalid appended value",
md: map[string][]string{"test": {"test"}},
appendMD: []string{"/", "value"},
want: status.Error(codes.Internal, "header key \"/\" contains illegal characters not in [0-9a-z-_.]"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "empty appended key",
md: map[string][]string{"test": {"test"}},
appendMD: []string{"", "value"},
want: status.Error(codes.Internal, "there is an empty key in the header"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "empty key",
md: map[string][]string{"": {"test"}},
want: status.Error(codes.Internal, "there is an empty key in the header"),
recv: status.Error(codes.Internal, "invalid header field"),
},
{
name: "-bin key with arbitrary value",
md: map[string][]string{"test-bin": {string(rune(0x19))}},
want: nil,
recv: io.EOF,
},
{
name: "valid key and value",
md: map[string][]string{"test": {"value"}},
want: nil,
recv: io.EOF,
Expand All @@ -77,13 +103,16 @@ func (s) TestInvalidMetadata(t *testing.T) {
}
test := tests[testNum]
testNum++
if err := stream.SetHeader(test.md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
// merge original md and added md.
md := metadata.Join(test.md, metadata.Pairs(test.appendMD...))

if err := stream.SetHeader(md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want)
}
if err := stream.SendHeader(test.md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
if err := stream.SendHeader(md); !reflect.DeepEqual(test.want, err) {
return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want)
}
stream.SetTrailer(test.md)
stream.SetTrailer(md)
return nil
},
}
Expand All @@ -93,29 +122,33 @@ func (s) TestInvalidMetadata(t *testing.T) {
defer ss.Stop()

for _, test := range tests {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

ctx = metadata.NewOutgoingContext(ctx, test.md)
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) {
t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
t.Run("unary "+test.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, test.md)
ctx = metadata.AppendToOutgoingContext(ctx, test.appendMD...)
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) {
t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
}
})
}

// call the stream server's api to drive the server-side unit testing
for _, test := range tests {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
stream, err := ss.Client.FullDuplexCall(ctx)
defer cancel()
if err != nil {
t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err)
continue
}
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err)
}
if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) {
t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv)
}
t.Run("streaming "+test.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err)
return
}
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err)
}
if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) {
t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv)
}
})
}
}

0 comments on commit dba41ef

Please sign in to comment.