Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service reflection: include transitive closure for a file #3851

Merged
merged 6 commits into from Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 8 additions & 10 deletions reflection/serverreflection.go
Expand Up @@ -272,21 +272,19 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]
// fileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]struct{}) ([][]byte, error) {
func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
r := [][]byte{}
queue := make([]*dpb.FileDescriptorProto, 0)
queue = append(queue, fd)
queue := []*dpb.FileDescriptorProto{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if _, exists := sentFileDescriptors[fd.GetName()]; len(r) == 0 || !exists {
sentFileDescriptors[fd.GetName()] = struct{}{}
if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.GetName()] = true
currentfdEncoded, err := proto.Marshal(currentfd)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)

}
for _, dep := range currentfd.Dependency {
fdenc := proto.FileDescriptor(dep)
Expand All @@ -303,7 +301,7 @@ func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors m
// fileDescEncodingByFilename finds the file descriptor for given filename,
// finds all of its previously unsent transitive dependencies, does marshalling
// on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) {
func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
Expand Down Expand Up @@ -337,7 +335,7 @@ func parseMetadata(meta interface{}) ([]byte, bool) {
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) {
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
_, symbols := s.getSymbols()
fd := symbols[name]
if fd == nil {
Expand All @@ -361,7 +359,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, s
// fileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]struct{}) ([][]byte, error) {
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
st, err := typeForName(typeName)
if err != nil {
return nil, err
Expand All @@ -388,7 +386,7 @@ func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]

// ServerReflectionInfo is the reflection service handler.
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]struct{})
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
Expand Down
35 changes: 35 additions & 0 deletions reflection/serverreflection_test.go
Expand Up @@ -214,6 +214,8 @@ func (x) TestReflectionEnd2end(t *testing.T) {
t.Fatalf("cannot get ServerReflectionInfo: %v", err)
}

testFileByFilenameTransitiveClosure(t, stream, true)
testFileByFilenameTransitiveClosure(t, stream, false)
testFileByFilename(t, stream)
testFileByFilenameError(t, stream)
testFileContainingSymbol(t, stream)
Expand All @@ -227,6 +229,39 @@ func (x) TestReflectionEnd2end(t *testing.T) {
s.Stop()
}

func testFileByFilenameTransitiveClosure(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient, expectClosure bool) {
filename := "reflection/grpc_testing/proto2_ext2.proto"
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
FileByFilename: filename,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) {
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte)
}
if expectClosure {
if len(r.GetFileDescriptorResponse().FileDescriptorProto) < 2 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!= 2 would be even more precise.

t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 2", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto))
} else if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) {
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte)
}
} else if len(r.GetFileDescriptorResponse().FileDescriptorProto) > 1 {
t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 1", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto))
}
default:
t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", filename, r.MessageResponse)
}
}

func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []struct {
filename string
Expand Down