From e99175f116a693749ea412bd719b069832648bdf Mon Sep 17 00:00:00 2001 From: "gargut@google.com" Date: Wed, 26 Aug 2020 15:25:57 -0700 Subject: [PATCH 1/5] Sending transitive closure of file dependencies --- reflection/serverreflection.go | 71 +++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 7b6dd414a27..f0e58a98d22 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -269,9 +269,41 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([] return out, nil } +// fileDescWithDependencies returns a slice of serialized fileDescriptors in +// wire format ([]byte). The fileDescriptors will include fd and all the +// transitive dependencies of fd with names not in sentFileDescriptors. +func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]struct{}) ([][]byte, error) { + r := [][]byte{} + queue := make([]*dpb.FileDescriptorProto, 0) + queue = append(queue, fd) + for len(queue) > 0 { + currentfd := queue[0] + queue = queue[1:] + if _, exists := sentFileDescriptors[fd.GetName()]; len(r) == 0 || !exists { + sentFileDescriptors[fd.GetName()] = struct{}{} + currentfdEncoded, err := proto.Marshal(currentfd) + if err != nil { + return nil, err + } + r = append(r, currentfdEncoded) + + } + for _, dep := range currentfd.Dependency { + fdenc := proto.FileDescriptor(dep) + fdDep, err := decodeFileDesc(fdenc) + if err != nil { + continue + } + queue = append(queue, fdDep) + } + } + return r, nil +} + // fileDescEncodingByFilename finds the file descriptor for given filename, -// does marshalling on it and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { +// finds all of its previously unsent transitive dependencies, does marshalling +// on them, and returns the marshalled result. +func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { enc := proto.FileDescriptor(name) if enc == nil { return nil, fmt.Errorf("unknown file: %v", name) @@ -280,7 +312,7 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte if err != nil { return nil, err } - return proto.Marshal(fd) + return fileDescWithDependencies(fd, sentFileDescriptors) } // parseMetadata finds the file descriptor bytes specified meta. @@ -301,10 +333,11 @@ func parseMetadata(meta interface{}) ([]byte, bool) { return nil, false } -// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, -// does marshalling on it and returns the marshalled result. -// The given symbol can be a type, a service or a method. -func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { +// fileDescEncodingContainingSymbol finds the file descriptor containing the +// given symbol, finds all of its previously unsent transitive dependencies, +// does marshalling on them, and returns the marshalled result. The given symbol +// can be a type, a service or a method. +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { _, symbols := s.getSymbols() fd := symbols[name] if fd == nil { @@ -322,12 +355,13 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ( return nil, fmt.Errorf("unknown symbol: %v", name) } - return proto.Marshal(fd) + return fileDescWithDependencies(fd, sentFileDescriptors) } -// fileDescEncodingContainingExtension finds the file descriptor containing given extension, -// does marshalling on it and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { +// fileDescEncodingContainingExtension finds the file descriptor containing +// given extension, finds all of its previously unsent transitive dependencies, +// does marshalling on them, and returns the marshalled result. +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]struct{}) ([][]byte, error) { st, err := typeForName(typeName) if err != nil { return nil, err @@ -336,7 +370,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName st if err != nil { return nil, err } - return proto.Marshal(fd) + return fileDescWithDependencies(fd, sentFileDescriptors) } // allExtensionNumbersForTypeName returns all extension numbers for the given type. @@ -354,6 +388,7 @@ func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([] // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { + sentFileDescriptors := make(map[string]struct{}) for { in, err := stream.Recv() if err == io.EOF { @@ -369,7 +404,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: - b, err := s.fileDescEncodingByFilename(req.FileByFilename) + b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -379,11 +414,11 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, + FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, } } case *rpb.ServerReflectionRequest_FileContainingSymbol: - b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol) + b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -393,13 +428,13 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, + FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, } } case *rpb.ServerReflectionRequest_FileContainingExtension: typeName := req.FileContainingExtension.ContainingType extNum := req.FileContainingExtension.ExtensionNumber - b, err := s.fileDescEncodingContainingExtension(typeName, extNum) + b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -409,7 +444,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } else { out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, + FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, } } case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: From 0461eb64280a1a803fdd3ab9b5925e42febe0323 Mon Sep 17 00:00:00 2001 From: "gargut@google.com" Date: Wed, 2 Sep 2020 14:05:44 -0700 Subject: [PATCH 2/5] Using bool map as set --- reflection/serverreflection.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index f0e58a98d22..9346e3c7f11 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -272,21 +272,19 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([] // fileDescWithDependencies returns a slice of serialized fileDescriptors in // wire format ([]byte). The fileDescriptors will include fd and all the // transitive dependencies of fd with names not in sentFileDescriptors. -func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]struct{}) ([][]byte, error) { +func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { r := [][]byte{} - queue := make([]*dpb.FileDescriptorProto, 0) - queue = append(queue, fd) + queue := []*dpb.FileDescriptorProto{fd} for len(queue) > 0 { currentfd := queue[0] queue = queue[1:] - if _, exists := sentFileDescriptors[fd.GetName()]; len(r) == 0 || !exists { - sentFileDescriptors[fd.GetName()] = struct{}{} + if sent := sentFileDescriptors[fd.GetName()]; len(r) == 0 || !sent { + sentFileDescriptors[fd.GetName()] = true currentfdEncoded, err := proto.Marshal(currentfd) if err != nil { return nil, err } r = append(r, currentfdEncoded) - } for _, dep := range currentfd.Dependency { fdenc := proto.FileDescriptor(dep) @@ -303,7 +301,7 @@ func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors m // fileDescEncodingByFilename finds the file descriptor for given filename, // finds all of its previously unsent transitive dependencies, does marshalling // on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { +func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { enc := proto.FileDescriptor(name) if enc == nil { return nil, fmt.Errorf("unknown file: %v", name) @@ -337,7 +335,7 @@ func parseMetadata(meta interface{}) ([]byte, bool) { // given symbol, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. The given symbol // can be a type, a service or a method. -func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]struct{}) ([][]byte, error) { +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { _, symbols := s.getSymbols() fd := symbols[name] if fd == nil { @@ -361,7 +359,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, s // fileDescEncodingContainingExtension finds the file descriptor containing // given extension, finds all of its previously unsent transitive dependencies, // does marshalling on them, and returns the marshalled result. -func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]struct{}) ([][]byte, error) { +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { st, err := typeForName(typeName) if err != nil { return nil, err @@ -388,7 +386,7 @@ func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([] // ServerReflectionInfo is the reflection service handler. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { - sentFileDescriptors := make(map[string]struct{}) + sentFileDescriptors := make(map[string]bool) for { in, err := stream.Recv() if err == io.EOF { From f5a3d32d1a97d1336768ae274244854ff935df2e Mon Sep 17 00:00:00 2001 From: "gargut@google.com" Date: Wed, 2 Sep 2020 15:19:45 -0700 Subject: [PATCH 3/5] Adding correct file descriptor --- reflection/serverreflection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 9346e3c7f11..d2696168b10 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -278,8 +278,8 @@ func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors m for len(queue) > 0 { currentfd := queue[0] queue = queue[1:] - if sent := sentFileDescriptors[fd.GetName()]; len(r) == 0 || !sent { - sentFileDescriptors[fd.GetName()] = true + if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { + sentFileDescriptors[currentfd.GetName()] = true currentfdEncoded, err := proto.Marshal(currentfd) if err != nil { return nil, err From ef91cc3ddabbcbf5500ff8d4d654cc52c9cbaef4 Mon Sep 17 00:00:00 2001 From: "gargut@google.com" Date: Wed, 9 Sep 2020 14:59:50 -0700 Subject: [PATCH 4/5] Added test for transitive closure --- reflection/serverreflection_test.go | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index db5fce2d893..ff7c954d0d8 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -214,6 +214,8 @@ func (x) TestReflectionEnd2end(t *testing.T) { t.Fatalf("cannot get ServerReflectionInfo: %v", err) } + testFileByFilenameTransitiveClosure(t, stream, true) + testFileByFilenameTransitiveClosure(t, stream, false) testFileByFilename(t, stream) testFileByFilenameError(t, stream) testFileContainingSymbol(t, stream) @@ -227,6 +229,39 @@ func (x) TestReflectionEnd2end(t *testing.T) { s.Stop() } +func testFileByFilenameTransitiveClosure(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient, expectClosure bool) { + filename := "reflection/grpc_testing/proto2_ext2.proto" + if err := stream.Send(&rpb.ServerReflectionRequest{ + MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ + FileByFilename: filename, + }, + }); err != nil { + t.Fatalf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + switch r.MessageResponse.(type) { + case *rpb.ServerReflectionResponse_FileDescriptorResponse: + if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) { + t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) + } + if expectClosure { + if len(r.GetFileDescriptorResponse().FileDescriptorProto) < 2 { + t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 2", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto)) + } else if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) { + t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) + } + } else if len(r.GetFileDescriptorResponse().FileDescriptorProto) > 1 { + t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 1", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto)) + } + default: + t.Errorf("FileByFilename(%v) = %v, want type ", filename, r.MessageResponse) + } +} + func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) { for _, test := range []struct { filename string From b6536fef492cdc297694a039615db25c73c18c8f Mon Sep 17 00:00:00 2001 From: "gargut@google.com" Date: Wed, 9 Sep 2020 15:18:53 -0700 Subject: [PATCH 5/5] != over >< --- reflection/serverreflection_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index ff7c954d0d8..9f252d778b9 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -249,12 +249,12 @@ func testFileByFilenameTransitiveClosure(t *testing.T, stream rpb.ServerReflecti t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], fdProto2Ext2Byte) } if expectClosure { - if len(r.GetFileDescriptorResponse().FileDescriptorProto) < 2 { + if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 2 { t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 2", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto)) } else if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) { t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", filename, r.GetFileDescriptorResponse().FileDescriptorProto[1], fdProto2Byte) } - } else if len(r.GetFileDescriptorResponse().FileDescriptorProto) > 1 { + } else if len(r.GetFileDescriptorResponse().FileDescriptorProto) != 1 { t.Errorf("FileByFilename(%v) returned %v file descriptors, expected 1", filename, len(r.GetFileDescriptorResponse().FileDescriptorProto)) } default: