diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 7b6dd414a27..d2696168b10 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -269,9 +269,39 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([] return out, nil } +// fileDescWithDependencies returns a slice of serialized fileDescriptors in +// wire format ([]byte). The fileDescriptors will include fd and all the +// transitive dependencies of fd with names not in sentFileDescriptors. +func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { + r := [][]byte{} + queue := []*dpb.FileDescriptorProto{fd} + for len(queue) > 0 { + currentfd := queue[0] + queue = queue[1:] + if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { + sentFileDescriptors[currentfd.GetName()] = true + currentfdEncoded, err := proto.Marshal(currentfd) + if err != nil { + return nil, err + } + r = append(r, currentfdEncoded) + } + for _, dep := range currentfd.Dependency { + fdenc := proto.FileDescriptor(dep) + fdDep, err := decodeFileDesc(fdenc) + if err != nil { + continue + } + queue = append(queue, fdDep) + } + } + return r, nil +} + // fileDescEncodingByFilename finds the file descriptor for given filename, -// does marshalling on it and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { +// finds all of its previously unsent transitive dependencies, does marshalling +// on them, and returns the marshalled result. +func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { enc := proto.FileDescriptor(name) if enc == nil { return nil, fmt.Errorf("unknown file: %v", name) @@ -280,7 +310,7 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte if err != nil { return nil, err } - return proto.Marshal(fd) + return fileDescWithDependencies(fd, sentFileDescriptors) } // parseMetadata finds the file descriptor bytes specified meta. @@ -301,10 +331,11 @@ func parseMetadata(meta interface{}) ([]byte, bool) { return nil, false } -// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, -// does marshalling on it and returns the marshalled result. -// The given symbol can be a type, a service or a method. -func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { +// fileDescEncodingContainingSymbol finds the file descriptor containing the +// given symbol, finds all of its previously unsent transitive dependencies, +// does marshalling on them, and returns the marshalled result. The given symbol +// can be a type, a service or a method. +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { _, symbols := s.getSymbols() fd := symbols[name] if fd == nil { @@ -322,12 +353,13 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ( return nil, fmt.Errorf("unknown symbol: %v", name) } - return proto.Marshal(fd) + return fileDescWithDependencies(fd, sentFileDescriptors) } -// fileDescEncodingContainingExtension finds the file descriptor containing given extension, -// does marshalling on it and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { +// fileDescEncodingContainingExtension finds the file descriptor containing +// given extension, finds all of its previously unsent transitive dependencies, +// does marshalling on them, and returns the marshalled result. +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { st, err := typeForName(typeName) if err != nil { return nil, err @@ -336,7 +368,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName st if err != nil { return nil, err } - return proto.Marshal(fd) + return fileDescWithDependencies(fd, sentFileDescriptors) } // allExtensionNumbersForTypeName returns all extension numbers for the given type. @@ -354,6 +386,7 @@ func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([] // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { + sentFileDescriptors := make(map[string]bool) for { in, err := stream.Recv() if err == io.EOF { @@ -369,7 +402,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: - b, err := s.fileDescEncodingByFilename(req.FileByFilename) + b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -379,11 +412,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, + FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, } } case *rpb.ServerReflectionRequest_FileContainingSymbol: - b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol) + b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -393,13 +426,13 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, + FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, } } case *rpb.ServerReflectionRequest_FileContainingExtension: typeName := req.FileContainingExtension.ContainingType extNum := req.FileContainingExtension.ExtensionNumber - b, err := s.fileDescEncodingContainingExtension(typeName, extNum) + b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -409,7 +442,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, + FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, } } case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index db5fce2d893..9f252d778b9 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -214,6 +214,8 @@ func (x) TestReflectionEnd2end(t *testing.T) { t.Fatalf("cannot get ServerReflectionInfo: %v", err) } + testFileByFilenameTransitiveClosure(t, stream, true) + testFileByFilenameTransitiveClosure(t, stream, false) testFileByFilename(t, stream) testFileByFilenameError(t, stream) testFileContainingSymbol(t, stream) @@ -227,6 +229,39 @@ func (x) TestReflectionEnd2end(t *testing.T) { s.Stop() } +func testFileByFilenameTransitiveClosure(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient, expectClosure bool) { + filename := "reflection/grpc_testing/proto2_ext2.proto" + if err := stream.Send(&rpb.ServerReflectionRequest{ + MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ + FileByFilename: filename, + }, + }); err != nil { + t.Fatalf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + switch r.MessageResponse.(type) { + case *rpb.ServerReflectionResponse_FileDescriptorResponse: + if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) { + t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) + } + if expectClosure { + if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 2 { + t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 2", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto)) + } else if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) { + t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) + } + } else if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 1 { + t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 1", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto)) + } + default: + t.Errorf("FileByFilename(%v) = %v, want type ", filename, r.MessageResponse) + } +} + func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) { for _, test := range []struct { filename string