Skip to content

Commit

Permalink
service reflection: include transitive closure for a file (#3851)
Browse files Browse the repository at this point in the history
  • Loading branch information
GarrettGutierrez1 committed Sep 9, 2020
1 parent 15157e2 commit 52029da
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 18 deletions.
69 changes: 51 additions & 18 deletions reflection/serverreflection.go
Expand Up @@ -269,9 +269,39 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]
return out, nil
}

// fileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
r := [][]byte{}
queue := []*dpb.FileDescriptorProto{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.GetName()] = true
currentfdEncoded, err := proto.Marshal(currentfd)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)
}
for _, dep := range currentfd.Dependency {
fdenc := proto.FileDescriptor(dep)
fdDep, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
queue = append(queue, fdDep)
}
}
return r, nil
}

// fileDescEncodingByFilename finds the file descriptor for given filename,
// does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
// finds all of its previously unsent transitive dependencies, does marshalling
// on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
Expand All @@ -280,7 +310,7 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte
if err != nil {
return nil, err
}
return proto.Marshal(fd)
return fileDescWithDependencies(fd, sentFileDescriptors)
}

// parseMetadata finds the file descriptor bytes specified meta.
Expand All @@ -301,10 +331,11 @@ func parseMetadata(meta interface{}) ([]byte, bool) {
return nil, false
}

// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
// does marshalling on it and returns the marshalled result.
// The given symbol can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
// fileDescEncodingContainingSymbol finds the file descriptor containing the
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
_, symbols := s.getSymbols()
fd := symbols[name]
if fd == nil {
Expand All @@ -322,12 +353,13 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) (
return nil, fmt.Errorf("unknown symbol: %v", name)
}

return proto.Marshal(fd)
return fileDescWithDependencies(fd, sentFileDescriptors)
}

// fileDescEncodingContainingExtension finds the file descriptor containing given extension,
// does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
// fileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
st, err := typeForName(typeName)
if err != nil {
return nil, err
Expand All @@ -336,7 +368,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName st
if err != nil {
return nil, err
}
return proto.Marshal(fd)
return fileDescWithDependencies(fd, sentFileDescriptors)
}

// allExtensionNumbersForTypeName returns all extension numbers for the given type.
Expand All @@ -354,6 +386,7 @@ func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]

// ServerReflectionInfo is the reflection service handler.
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
Expand All @@ -369,7 +402,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescEncodingByFilename(req.FileByFilename)
b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
Expand All @@ -379,11 +412,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
Expand All @@ -393,13 +426,13 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
Expand All @@ -409,7 +442,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
Expand Down
35 changes: 35 additions & 0 deletions reflection/serverreflection_test.go
Expand Up @@ -214,6 +214,8 @@ func (x) TestReflectionEnd2end(t *testing.T) {
t.Fatalf("cannot get ServerReflectionInfo: %v", err)
}

testFileByFilenameTransitiveClosure(t, stream, true)
testFileByFilenameTransitiveClosure(t, stream, false)
testFileByFilename(t, stream)
testFileByFilenameError(t, stream)
testFileContainingSymbol(t, stream)
Expand All @@ -227,6 +229,39 @@ func (x) TestReflectionEnd2end(t *testing.T) {
s.Stop()
}

func testFileByFilenameTransitiveClosure(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient, expectClosure bool) {
filename := "reflection/grpc_testing/proto2_ext2.proto"
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
FileByFilename: filename,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) {
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte)
}
if expectClosure {
if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 2 {
t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 2", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto))
} else if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) {
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte)
}
} else if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 1 {
t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 1", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto))
}
default:
t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", filename, r.MessageResponse)
}
}

func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []struct {
filename string
Expand Down

0 comments on commit 52029da

Please sign in to comment.