Skip to content

Commit

Permalink
[proto] Allow multiple outputs from a proto compiler (#3650)
Browse files Browse the repository at this point in the history
  • Loading branch information
tingilee committed Aug 13, 2023
1 parent 07ec991 commit f5ae196
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 7 deletions.
20 changes: 13 additions & 7 deletions proto/compiler.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,18 @@ def go_proto_compile(go, compiler, protos, imports, importpath):
continue
proto_paths[path] = src

out = go.declare_file(
go,
path = importpath + "/" + src.basename[:-len(".proto")],
ext = compiler.internal.suffix,
)
go_srcs.append(out)
suffixes = compiler.internal.suffixes
if not suffixes:
suffixes = [compiler.internal.suffix]
for suffix in suffixes:
out = go.declare_file(
go,
path = importpath + "/" + src.basename[:-len(".proto")],
ext = suffix,
)
go_srcs.append(out)
if outpath == None:
outpath = out.dirname[:-len(importpath)]
outpath = go_srcs[0].dirname[:-len(importpath)]

transitive_descriptor_sets = depset(direct = [], transitive = desc_sets)

Expand Down Expand Up @@ -174,6 +178,7 @@ def _go_proto_compiler_impl(ctx):
internal = struct(
options = ctx.attr.options,
suffix = ctx.attr.suffix,
suffixes = ctx.attr.suffixes,
protoc = ctx.executable._protoc,
go_protoc = ctx.executable._go_protoc,
plugin = ctx.executable.plugin,
Expand All @@ -190,6 +195,7 @@ _go_proto_compiler = rule(
"deps": attr.label_list(providers = [GoLibrary]),
"options": attr.string_list(),
"suffix": attr.string(default = ".pb.go"),
"suffixes": attr.string_list(),
"valid_archive": attr.bool(default = True),
"import_path_option": attr.bool(default = False),
"plugin": attr.label(
Expand Down
29 changes: 29 additions & 0 deletions tests/core/go_proto_library/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ proto_library(
srcs = ["grpc.proto"],
)

proto_library(
name = "enum_proto",
srcs = ["enum.proto"],
deps = [
"@com_google_protobuf//:descriptor_proto",
],
)

# embed_test
go_proto_library(
name = "embed_go_proto",
Expand Down Expand Up @@ -199,6 +207,27 @@ go_proto_library(
protos = [":grpc_proto"],
)

# compilers with multiple suffixes
go_test(
name = "compilers_multi_suffix_test",
srcs = ["compiler_multi_suffix_test.go"],
deps = [
":compilers_multi_suffix",
],
)

go_proto_library(
name = "compilers_multi_suffix",
compilers = ["//tests/core/go_proto_library/compilers:dbenum_compiler"],
importpath = "github.com/bazelbuild/rules_go/tests/core/go_proto_library/enum",
protos = [":enum_proto"],
deps = [
"@com_github_gogo_protobuf//proto",
"@com_github_gogo_protobuf//protoc-gen-gogo/descriptor",
"@com_github_gogo_protobuf//types",
],
)

# adjusted_import_test
# TODO(#1851): uncomment when Bazel 0.22.0 is the minimum version.
# go_test(
Expand Down
41 changes: 41 additions & 0 deletions tests/core/go_proto_library/compiler_multi_suffix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2019 The Bazel Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package multi_suffix_compiler

import (
"testing"

"github.com/bazelbuild/rules_go/tests/core/go_proto_library/enum"
)

func use(interface{}) {}

func TestMultiSuffixCompiler(t *testing.T) {
// This test expects the compiler to generate two outputs:
// <proto>.pb.go and <proto>_dbenums.pb.go.
// Assert <proto>_dbenums.pb.go contains String() that returns dbenum value.
v := enum.Enum_BYTES
expected := "bytes_type"
if v.String() != expected {
panic(v.String())
}
// Assert <proto>.pb.go contains String() that returns proto Enum key.
v = enum.Enum_INT32
expected = "INT32"
if v.String() != expected {
panic(v.String())
}
}
44 changes: 44 additions & 0 deletions tests/core/go_proto_library/compilers/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
load(
"//proto:compiler.bzl",
"go_proto_compiler",
)
load(
"//proto/wkt:well_known_types.bzl",
"GOGO_WELL_KNOWN_TYPE_REMAPS",
)

go_library(
name = "protoc_gen_dbenum_lib",
srcs = [
"dbenums.go",
"main.go",
],
importpath = "github.com/bazelbuild/rules_go/tests/core/go_proto_library/compilers",
visibility = ["//visibility:private"],
deps = [
"@com_github_gogo_protobuf//proto",
"@com_github_gogo_protobuf//protoc-gen-gogo/descriptor",
"@com_github_gogo_protobuf//protoc-gen-gogo/generator",
"@com_github_gogo_protobuf//protoc-gen-gogo/plugin",
"@com_github_gogo_protobuf//vanity",
"@com_github_gogo_protobuf//vanity/command",
],
)

go_binary(
name = "protoc-gen-dbenum-compiler",
embed = [":protoc_gen_dbenum_lib"],
visibility = ["//visibility:private"],
)

go_proto_compiler(
name = "dbenum_compiler",
options = GOGO_WELL_KNOWN_TYPE_REMAPS,
plugin = "//tests/core/go_proto_library/compilers:protoc-gen-dbenum-compiler",
suffixes = [
"_dbenum.pb.go",
".pb.go",
],
visibility = ["//visibility:public"],
)
189 changes: 189 additions & 0 deletions tests/core/go_proto_library/compilers/dbenums.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package main

import (
"bytes"
"strings"
"text/template"

"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
pb "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
"github.com/gogo/protobuf/protoc-gen-gogo/generator"
)

func init() {
generator.RegisterPlugin(NewGenerator())
}

type Generator struct {
*generator.Generator
generator.PluginImports
write bool
}

func NewGenerator() *Generator {
return &Generator{}
}

func (g *Generator) Name() string {
return "dbenum"
}

func (g *Generator) Init(gen *generator.Generator) {
g.Generator = gen
}

func (g *Generator) GenerateImports(file *generator.FileDescriptor) {
}

func (g *Generator) Generate(file *generator.FileDescriptor) {
for _, enum := range file.Enums() {
g.enumHelper(enum)
}
g.writeTrailer(file.Enums())
}

func (g *Generator) Write() bool {
return g.write
}

const initTmpl = `
`

func (g *Generator) writeTrailer(enums []*generator.EnumDescriptor) {
type desc struct {
PackageName string
TypeName string
LowerCaseTypeName string
}
if !g.write {
return
}
tmpl := template.Must(template.New("db_enum_trailer").Parse(initTmpl))
g.P("func init() {")
for _, e := range enums {
if !HasDBEnum(e.Value) {
continue
}
pkg := e.File().GetPackage()
if pkg != "" {
pkg += "."
}
tp := generator.CamelCaseSlice(e.TypeName())
var buf bytes.Buffer
tmpl.Execute(&buf, desc{
PackageName: pkg + tp,
TypeName: tp,
LowerCaseTypeName: strings.ToLower(tp),
})
g.P(buf.String())
}
g.P("}")
}

func (g *Generator) enumHelper(enum *generator.EnumDescriptor) {
type anEnum struct {
PBName string
DBName string
}
type typeDesc struct {
TypeName string
TypeNamespace string
LowerCaseTypeName string
Found map[int32]bool
Names []anEnum
AllNames []anEnum
}
tp := generator.CamelCaseSlice(enum.TypeName())
namespace := tp
enumTypeName := enum.TypeName()
if len(enumTypeName) > 1 { // This is a nested enum.
names := enumTypeName[:len(enumTypeName)-1]
// See https://protobuf.dev/reference/go/go-generated/#enum
namespace = generator.CamelCaseSlice(names)
}
t := typeDesc{
TypeName: tp,
TypeNamespace: namespace,
LowerCaseTypeName: strings.ToLower(tp),
Found: make(map[int32]bool),
}
for _, v := range enum.Value {
enumValue := v.GetNumber()
if validDbEnum, dbName := getDbEnum(v); validDbEnum {
names := anEnum{PBName: v.GetName(), DBName: dbName}
t.AllNames = append(t.AllNames, names)
// Skip enums that are aliased where one value has already been processed.
if t.Found[enumValue] {
continue
}
t.Found[enumValue] = true
t.Names = append(t.Names, names)
} else {
t.Found[enumValue] = true
}
}
if len(t.AllNames) == 0 {
return
}
g.write = true
tmpl := template.Must(template.New("db_enum").Parse(tmpl))
var buf bytes.Buffer
tmpl.Execute(&buf, t)
g.P(buf.String())
}

var E_DbEnum = &proto.ExtensionDesc{
ExtendedType: (*descriptor.EnumValueOptions)(nil),
ExtensionType: (*string)(nil),
Field: 5002,
Name: "tests.core.go_proto_library.enum",
Tag: "bytes,5002,opt,name=db_enum",
}

func getDbEnum(value *pb.EnumValueDescriptorProto) (bool, string) {
if value == nil || value.Options == nil {
return false, ""
}
EDbEnum := E_DbEnum
v, err := proto.GetExtension(value.Options, EDbEnum)
if err != nil {
return false, ""
}
strPtr := v.(*string)
if strPtr == nil {
return false, ""
}
return true, *strPtr
}

// HasDBEnum returns if there is DBEnums extensions defined in given enums.
func HasDBEnum(enums []*pb.EnumValueDescriptorProto) bool {
for _, enum := range enums {
if validDbEnum, _ := getDbEnum(enum); validDbEnum {
return true
}
}
return false
}

const tmpl = `
var {{ .LowerCaseTypeName }}ToStringValue = ` +
`map[{{ .TypeName }}]string { {{ range $names := .Names }}
{{ $.TypeNamespace }}_{{ $names.PBName }}: ` +
`"{{ $names.DBName }}",{{ end }}
}
// String implements the stringer interface and should produce the same output
// that is inserted into the db.
func (v {{ .TypeName }}) String() string {
if val, ok := {{ .LowerCaseTypeName }}ToStringValue[v]; ok {
return val
} else if int(v) == 0 {
return "null"
} else {
return proto.EnumName({{ .TypeName }}_name, int32(v))
}
}`

0 comments on commit f5ae196

Please sign in to comment.