Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service reflection: include transitive closure for a file #3851

Merged
merged 6 commits into from Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
69 changes: 51 additions & 18 deletions reflection/serverreflection.go
Expand Up @@ -269,9 +269,39 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]
return out, nil
}

// fileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
r := [][]byte{}
queue := []*dpb.FileDescriptorProto{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.GetName()] = true
currentfdEncoded, err := proto.Marshal(currentfd)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)
}
for _, dep := range currentfd.Dependency {
fdenc := proto.FileDescriptor(dep)
fdDep, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
queue = append(queue, fdDep)
}
}
return r, nil
}

// fileDescEncodingByFilename finds the file descriptor for given filename,
// does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
// finds all of its previously unsent transitive dependencies, does marshalling
// on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
Expand All @@ -280,7 +310,7 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte
if err != nil {
return nil, err
}
return proto.Marshal(fd)
return fileDescWithDependencies(fd, sentFileDescriptors)
}

// parseMetadata finds the file descriptor bytes specified meta.
Expand All @@ -301,10 +331,11 @@ func parseMetadata(meta interface{}) ([]byte, bool) {
return nil, false
}

// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
// does marshalling on it and returns the marshalled result.
// The given symbol can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
// fileDescEncodingContainingSymbol finds the file descriptor containing the
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
_, symbols := s.getSymbols()
fd := symbols[name]
if fd == nil {
Expand All @@ -322,12 +353,13 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) (
return nil, fmt.Errorf("unknown symbol: %v", name)
}

return proto.Marshal(fd)
return fileDescWithDependencies(fd, sentFileDescriptors)
}

// fileDescEncodingContainingExtension finds the file descriptor containing given extension,
// does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
// fileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
st, err := typeForName(typeName)
if err != nil {
return nil, err
Expand All @@ -336,7 +368,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName st
if err != nil {
return nil, err
}
return proto.Marshal(fd)
return fileDescWithDependencies(fd, sentFileDescriptors)
}

// allExtensionNumbersForTypeName returns all extension numbers for the given type.
Expand All @@ -354,6 +386,7 @@ func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]

// ServerReflectionInfo is the reflection service handler.
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
Expand All @@ -369,7 +402,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescEncodingByFilename(req.FileByFilename)
b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
Expand All @@ -379,11 +412,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
Expand All @@ -393,13 +426,13 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
Expand All @@ -409,7 +442,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
Expand Down
35 changes: 35 additions & 0 deletions reflection/serverreflection_test.go
Expand Up @@ -214,6 +214,8 @@ func (x) TestReflectionEnd2end(t *testing.T) {
t.Fatalf("cannot get ServerReflectionInfo: %v", err)
}

testFileByFilenameTransitiveClosure(t, stream, true)
testFileByFilenameTransitiveClosure(t, stream, false)
testFileByFilename(t, stream)
testFileByFilenameError(t, stream)
testFileContainingSymbol(t, stream)
Expand All @@ -227,6 +229,39 @@ func (x) TestReflectionEnd2end(t *testing.T) {
s.Stop()
}

func testFileByFilenameTransitiveClosure(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient, expectClosure bool) {
filename := "reflection/grpc_testing/proto2_ext2.proto"
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
FileByFilename: filename,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) {
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte)
}
if expectClosure {
if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 2 {
t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 2", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto))
} else if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) {
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte)
}
} else if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 1 {
t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 1", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto))
}
default:
t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", filename, r.MessageResponse)
}
}

func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []struct {
filename string
Expand Down