Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protoc gen oas v2 cleanup #2996

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 4 additions & 6 deletions protoc-gen-openapiv2/internal/genopenapi/generator.go
Expand Up @@ -23,9 +23,7 @@ import (
legacydescriptor "github.com/golang/protobuf/descriptor"
)

var (
errNoTargetService = errors.New("no target service defined in the file")
)
var errNoTargetService = errors.New("no target service defined in the file")

type generator struct {
reg *descriptor.Registry
Expand Down Expand Up @@ -326,7 +324,7 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.Response
for _, file := range targets {
glog.V(1).Infof("Processing %s", file.GetName())
swagger, err := applyTemplate(param{File: file, reg: g.reg})
if err == errNoTargetService {
if errors.Is(err, errNoTargetService) {
glog.V(1).Infof("%s: %v", file.GetName(), err)
continue
}
Expand All @@ -343,15 +341,15 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.Response
targetOpenAPI := mergeTargetFile(openapis, g.reg.GetMergeFileName())
f, err := encodeOpenAPI(targetOpenAPI, g.format)
if err != nil {
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %s", g.reg.GetMergeFileName(), err)
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", g.reg.GetMergeFileName(), err)
}
files = append(files, f)
glog.V(1).Infof("New OpenAPI file will emit")
} else {
for _, file := range openapis {
f, err := encodeOpenAPI(file, g.format)
if err != nil {
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %s", file.fileName, err)
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", file.fileName, err)
}
files = append(files, f)
glog.V(1).Infof("New OpenAPI file will emit")
Expand Down
81 changes: 32 additions & 49 deletions protoc-gen-openapiv2/internal/genopenapi/template.go
Expand Up @@ -3,6 +3,7 @@ package genopenapi
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"math"
"net/textproto"
Expand Down Expand Up @@ -180,7 +181,7 @@ func newCycleChecker(recursive int) *cycleChecker {
// toleration
func (c *cycleChecker) Check(name string) bool {
count, ok := c.m[name]
count = count + 1
count += 1
isCycle := count > c.count

if isCycle {
Expand All @@ -201,7 +202,7 @@ func (c *cycleChecker) Check(name string) bool {
func (c *cycleChecker) Branch() *cycleChecker {
copy := &cycleChecker{
count: c.count,
m: map[string]int{},
m: make(map[string]int, len(c.m)),
}

for k, v := range c.m {
Expand Down Expand Up @@ -336,8 +337,7 @@ func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, pre
}

// Check for cyclical message reference:
isOK := cycle.Check(*msg.Name)
if !isOK {
if ok := cycle.Check(*msg.Name); !ok {
return nil, fmt.Errorf("exceeded recursive count (%d) for query parameter %q", cycle.count, fieldType)
}

Expand Down Expand Up @@ -800,7 +800,7 @@ func renderEnumerationsAsDefinition(enums enumMap, d openapiDefinitionsObject, r
for _, enum := range enums {
swgName, ok := fullyQualifiedNameToOpenAPIName(enum.FQEN(), reg)
if !ok {
panic(fmt.Sprintf("can't resolve OpenAPI name from FQEN '%v'", enum.FQEN()))
panic(fmt.Sprintf("can't resolve OpenAPI name from FQEN %q", enum.FQEN()))
}
enumComments := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index))

Expand Down Expand Up @@ -958,8 +958,7 @@ func partsToOpenAPIPath(parts []string, overrides map[string]string) string {
}
parts[index] = part
}
last := len(parts) - 1
if strings.HasPrefix(parts[last], ":") {
if last := len(parts) - 1; strings.HasPrefix(parts[last], ":") {
// Last item is a verb (":" LITERAL).
return strings.Join(parts[:last], "/") + parts[last]
}
Expand Down Expand Up @@ -988,7 +987,7 @@ func partsToRegexpMap(parts []string) map[string]string {
regExps := make(map[string]string)
for _, part := range parts {
if strings.Contains(part, "/") {
glog.Warningf("Path parameter '%s' contains '/', which is not supported in OpenAPI", part)
glog.Warningf("Path parameter %q contains '/', which is not supported in OpenAPI", part)
}
if submatch := canRegexp.FindStringSubmatch(part); len(submatch) > 2 {
if strings.HasPrefix(submatch[2], "=") { // this part matches the standard and should be made into a regular expression
Expand Down Expand Up @@ -1086,7 +1085,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
defaultValue = schema.Default
extensions = schema.extensions
} else {
return fmt.Errorf("only primitive and well-known types are allowed in path parameters")
return errors.New("only primitive and well-known types are allowed in path parameters")
}
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
enum, err := reg.LookupEnum("", parameter.Target.GetTypeName())
Expand Down Expand Up @@ -1194,8 +1193,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
return err
}
if len(b.PathParams) == 0 {
err = schema.setRefFromFQN(meth.RequestType.FQMN(), reg)
if err != nil {
if err := schema.setRefFromFQN(meth.RequestType.FQMN(), reg); err != nil {
return err
}
desc = messageSchema.Description
Expand All @@ -1211,7 +1209,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
// "NOTE: the referred field must be present at the top-level of the request message type."
// Ref: https://github.com/googleapis/googleapis/blob/b3397f5febbf21dfc69b875ddabaf76bee765058/google/api/http.proto#L350-L352
if len(b.Body.FieldPath) > 1 {
return fmt.Errorf("Body of request '%s' is not a top level field: '%v'.", meth.Service.GetName(), b.Body.FieldPath)
return fmt.Errorf("Body of request %q is not a top level field: '%v'.", meth.Service.GetName(), b.Body.FieldPath)
}
bodyField := b.Body.FieldPath[0]
if reg.GetUseJSONNamesForFields() {
Expand Down Expand Up @@ -1315,8 +1313,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
// well, without a definition
wknSchemaCore, isWkn := wktSchemas[meth.ResponseType.FQMN()]
if !isWkn {
err := responseSchema.setRefFromFQN(meth.ResponseType.FQMN(), reg)
if err != nil {
if err := responseSchema.setRefFromFQN(meth.ResponseType.FQMN(), reg); err != nil {
return err
}
} else {
Expand Down Expand Up @@ -1873,7 +1870,7 @@ func applyTemplate(p param) (*openapiSwaggerObject, error) {
}

func processExtensions(inputExts map[string]*structpb.Value) ([]extension, error) {
exts := []extension{}
exts := make([]extension, 0, len(inputExts))
for k, v := range inputExts {
if !strings.HasPrefix(k, "x-") {
return nil, fmt.Errorf("extension keys need to start with \"x-\": %q", k)
Expand All @@ -1888,7 +1885,7 @@ func processExtensions(inputExts map[string]*structpb.Value) ([]extension, error
return exts, nil
}

func validateHeaderTypeAndFormat(headerType string, format string) error {
func validateHeaderTypeAndFormat(headerType, format string) error {
// The type of the object. The value MUST be one of "string", "number", "integer", "boolean", or "array"
// See: https://github.com/OAI/OpenAPI-Specification/blob/3.0.0/versions/2.0.md#headerObject
// Note: currently not implementing array as we are only implementing this in the operation response context
Expand Down Expand Up @@ -1956,50 +1953,40 @@ func validateDefaultValueTypeAndFormat(headerType string, defaultValue string, f
switch format {
case "date-time":
unquoteTime := strings.Trim(defaultValue, `"`)
_, err := time.Parse(time.RFC3339, unquoteTime)
if err != nil {
if _, err := time.Parse(time.RFC3339, unquoteTime); err != nil {
return fmt.Errorf("the provided default value %q is not a valid RFC3339 date-time string", defaultValue)
}
case "date":
const (
layoutRFC3339Date = "2006-01-02"
)
const layoutRFC3339Date = "2006-01-02"
unquoteDate := strings.Trim(defaultValue, `"`)
_, err := time.Parse(layoutRFC3339Date, unquoteDate)
if err != nil {
if _, err := time.Parse(layoutRFC3339Date, unquoteDate); err != nil {
return fmt.Errorf("the provided default value %q is not a valid RFC3339 date-time string", defaultValue)
}
}
case "number":
err := isJSONNumber(defaultValue, headerType)
if err != nil {
if err := isJSONNumber(defaultValue, headerType); err != nil {
return err
}
case "integer":
switch format {
case "int32":
_, err := strconv.ParseInt(defaultValue, 0, 32)
if err != nil {
if _, err := strconv.ParseInt(defaultValue, 0, 32); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
case "uint32":
_, err := strconv.ParseUint(defaultValue, 0, 32)
if err != nil {
if _, err := strconv.ParseUint(defaultValue, 0, 32); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
case "int64":
_, err := strconv.ParseInt(defaultValue, 0, 64)
if err != nil {
if _, err := strconv.ParseInt(defaultValue, 0, 64); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
case "uint64":
_, err := strconv.ParseUint(defaultValue, 0, 64)
if err != nil {
if _, err := strconv.ParseUint(defaultValue, 0, 64); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
default:
_, err := strconv.ParseInt(defaultValue, 0, 64)
if err != nil {
if _, err := strconv.ParseInt(defaultValue, 0, 64); err != nil {
return fmt.Errorf("the provided default value %q does not match provided type %q", defaultValue, headerType)
}
}
Expand Down Expand Up @@ -2037,22 +2024,20 @@ func isBool(s string) bool {
}

func processHeaders(inputHdrs map[string]*openapi_options.Header) (openapiHeadersObject, error) {
hdrs := map[string]openapiHeaderObject{}
hdrs := make(map[string]openapiHeaderObject, len(inputHdrs))
for k, v := range inputHdrs {
header := textproto.CanonicalMIMEHeaderKey(k)
ret := openapiHeaderObject{
Description: v.Description,
Format: v.Format,
Pattern: v.Pattern,
}
err := validateHeaderTypeAndFormat(v.Type, v.Format)
if err != nil {
if err := validateHeaderTypeAndFormat(v.Type, v.Format); err != nil {
return nil, err
}
ret.Type = v.Type
if v.Default != "" {
err := validateDefaultValueTypeAndFormat(v.Type, v.Default, v.Format)
if err != nil {
if err := validateDefaultValueTypeAndFormat(v.Type, v.Default, v.Format); err != nil {
return nil, err
}
ret.Default = RawExample(v.Default)
Expand Down Expand Up @@ -2144,7 +2129,7 @@ func updateOpenAPIDataFromComments(reg *descriptor.Registry, swaggerObject inter
return nil
}

return fmt.Errorf("no description nor summary property")
return errors.New("no description nor summary property")
}

func fieldProtoComments(reg *descriptor.Registry, msg *descriptor.Message, field *descriptor.Field) string {
Expand All @@ -2168,8 +2153,7 @@ func enumValueProtoComments(reg *descriptor.Registry, enum *descriptor.Enum) str
if reg.GetEnumsAsInts() {
name = strconv.Itoa(int(value.GetNumber()))
}
str := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index), protoPath, int32(idx))
if str != "" {
if str := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index), protoPath, int32(idx)); str != "" {
comments = append(comments, name+": "+str)
}
}
Expand Down Expand Up @@ -2213,7 +2197,7 @@ func protoComments(reg *descriptor.Registry, file *descriptor.File, outers []str
// - determine if every (but first and last) line begins with " "
// - trim every line only if that is the case
// - join by \n
comments = strings.Replace(comments, "\n ", "\n", -1)
comments = strings.ReplaceAll(comments, "\n ", "\n")
}
if loc.TrailingComments != nil {
trailing := strings.TrimSpace(*loc.TrailingComments)
Expand Down Expand Up @@ -2250,8 +2234,7 @@ func goTemplateComments(comment string, data interface{}, reg *descriptor.Regist
// to make it easier to debug the template error
return err.Error()
}
err = tpl.Execute(&temp, data)
if err != nil {
if err := tpl.Execute(&temp, data); err != nil {
// If there is an error executing the templating insert the error as string in the comment
// to make it easier to debug the error
return err.Error()
Expand Down Expand Up @@ -2725,7 +2708,7 @@ func openapiExamplesFromProtoExamples(in map[string]string) map[string]interface
if len(in) == 0 {
return nil
}
out := make(map[string]interface{})
out := make(map[string]interface{}, len(in))
for mimeType, exampleStr := range in {
switch mimeType {
case "application/json":
Expand Down Expand Up @@ -2799,7 +2782,7 @@ func addCustomRefs(d openapiDefinitionsObject, reg *descriptor.Registry, refs re
for ref := range refs {
swgName, swgOk := fullyQualifiedNameToOpenAPIName(ref, reg)
if !swgOk {
glog.Errorf("can't resolve OpenAPI name from CustomRef '%v'", ref)
glog.Errorf("can't resolve OpenAPI name from CustomRef %q", ref)
continue
}
if _, ok := d[swgName]; ok {
Expand Down Expand Up @@ -2835,7 +2818,7 @@ func lowerCamelCase(fieldName string, fields []*descriptor.Field, msgs []*descri
return oneField.GetJsonName()
}
}
messageNameToFieldsToJSONName := make(map[string]map[string]string)
messageNameToFieldsToJSONName := make(map[string]map[string]string, len(msgs))
fieldNameToType := make(map[string]string)
for _, msg := range msgs {
fieldNameToJSONName := make(map[string]string)
Expand Down
44 changes: 18 additions & 26 deletions protoc-gen-openapiv2/main.go
Expand Up @@ -78,8 +78,7 @@ func main() {
glog.V(1).Info("Parsed code generator request")
pkgMap := make(map[string]string)
if req.Parameter != nil {
err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap)
if err != nil {
if err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap); err != nil {
glog.Fatalf("Error parsing flags: %v", err)
}
}
Expand Down Expand Up @@ -166,7 +165,7 @@ func main() {
}
}

var targets []*descriptor.File
targets := make([]*descriptor.File, 0, len(req.FileToGenerate))
for _, target := range req.FileToGenerate {
f, err := reg.LookupFile(target)
if err != nil {
Expand Down Expand Up @@ -218,37 +217,30 @@ func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) erro
for _, p := range strings.Split(param, ",") {
spec := strings.SplitN(p, "=", 2)
if len(spec) == 1 {
if spec[0] == "allow_delete_body" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
switch spec[0] {
case "allow_delete_body":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
if spec[0] == "allow_merge" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
case "allow_merge":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
if spec[0] == "allow_repeated_fields_in_body" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
case "allow_repeated_fields_in_body":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
if spec[0] == "include_package_in_tags" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
case "include_package_in_tags":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
err := f.Set(spec[0], "")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
if err := f.Set(spec[0], ""); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
Expand All @@ -258,7 +250,7 @@ func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) erro
continue
}
if err := f.Set(name, value); err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
}
return nil
Expand Down