Skip to content

Commit

Permalink
protoc gen oas v2 cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sashamelentyev committed Nov 8, 2022
1 parent 1b7515f commit 4792df0
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 81 deletions.
10 changes: 4 additions & 6 deletions protoc-gen-openapiv2/internal/genopenapi/generator.go
Expand Up @@ -23,9 +23,7 @@ import (
legacydescriptor "github.com/golang/protobuf/descriptor"
)

var (
errNoTargetService = errors.New("no target service defined in the file")
)
var errNoTargetService = errors.New("no target service defined in the file")

type generator struct {
reg *descriptor.Registry
Expand Down Expand Up @@ -326,7 +324,7 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.Response
for _, file := range targets {
glog.V(1).Infof("Processing %s", file.GetName())
swagger, err := applyTemplate(param{File: file, reg: g.reg})
if err == errNoTargetService {
if errors.Is(err, errNoTargetService) {
glog.V(1).Infof("%s: %v", file.GetName(), err)
continue
}
Expand All @@ -343,15 +341,15 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.Response
targetOpenAPI := mergeTargetFile(openapis, g.reg.GetMergeFileName())
f, err := encodeOpenAPI(targetOpenAPI, g.format)
if err != nil {
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %s", g.reg.GetMergeFileName(), err)
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", g.reg.GetMergeFileName(), err)
}
files = append(files, f)
glog.V(1).Infof("New OpenAPI file will emit")
} else {
for _, file := range openapis {
f, err := encodeOpenAPI(file, g.format)
if err != nil {
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %s", file.fileName, err)
return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", file.fileName, err)
}
files = append(files, f)
glog.V(1).Infof("New OpenAPI file will emit")
Expand Down
81 changes: 32 additions & 49 deletions protoc-gen-openapiv2/internal/genopenapi/template.go
Expand Up @@ -3,6 +3,7 @@ package genopenapi
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"math"
"net/textproto"
Expand Down Expand Up @@ -180,7 +181,7 @@ func newCycleChecker(recursive int) *cycleChecker {
// toleration
func (c *cycleChecker) Check(name string) bool {
count, ok := c.m[name]
count = count + 1
count += 1
isCycle := count > c.count

if isCycle {
Expand All @@ -201,7 +202,7 @@ func (c *cycleChecker) Check(name string) bool {
func (c *cycleChecker) Branch() *cycleChecker {
copy := &cycleChecker{
count: c.count,
m: map[string]int{},
m: make(map[string]int, len(c.m)),
}

for k, v := range c.m {
Expand Down Expand Up @@ -336,8 +337,7 @@ func nestedQueryParams(message *descriptor.Message, field *descriptor.Field, pre
}

// Check for cyclical message reference:
isOK := cycle.Check(*msg.Name)
if !isOK {
if ok := cycle.Check(*msg.Name); !ok {
return nil, fmt.Errorf("exceeded recursive count (%d) for query parameter %q", cycle.count, fieldType)
}

Expand Down Expand Up @@ -800,7 +800,7 @@ func renderEnumerationsAsDefinition(enums enumMap, d openapiDefinitionsObject, r
for _, enum := range enums {
swgName, ok := fullyQualifiedNameToOpenAPIName(enum.FQEN(), reg)
if !ok {
panic(fmt.Sprintf("can't resolve OpenAPI name from FQEN '%v'", enum.FQEN()))
panic(fmt.Sprintf("can't resolve OpenAPI name from FQEN %q", enum.FQEN()))
}
enumComments := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index))

Expand Down Expand Up @@ -958,8 +958,7 @@ func partsToOpenAPIPath(parts []string, overrides map[string]string) string {
}
parts[index] = part
}
last := len(parts) - 1
if strings.HasPrefix(parts[last], ":") {
if last := len(parts) - 1; strings.HasPrefix(parts[last], ":") {
// Last item is a verb (":" LITERAL).
return strings.Join(parts[:last], "/") + parts[last]
}
Expand Down Expand Up @@ -988,7 +987,7 @@ func partsToRegexpMap(parts []string) map[string]string {
regExps := make(map[string]string)
for _, part := range parts {
if strings.Contains(part, "/") {
glog.Warningf("Path parameter '%s' contains '/', which is not supported in OpenAPI", part)
glog.Warningf("Path parameter %q contains '/', which is not supported in OpenAPI", part)
}
if submatch := canRegexp.FindStringSubmatch(part); len(submatch) > 2 {
if strings.HasPrefix(submatch[2], "=") { // this part matches the standard and should be made into a regular expression
Expand Down Expand Up @@ -1086,7 +1085,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
defaultValue = schema.Default
extensions = schema.extensions
} else {
return fmt.Errorf("only primitive and well-known types are allowed in path parameters")
return errors.New("only primitive and well-known types are allowed in path parameters")
}
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
enum, err := reg.LookupEnum("", parameter.Target.GetTypeName())
Expand Down Expand Up @@ -1194,8 +1193,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
return err
}
if len(b.PathParams) == 0 {
err = schema.setRefFromFQN(meth.RequestType.FQMN(), reg)
if err != nil {
if err := schema.setRefFromFQN(meth.RequestType.FQMN(), reg); err != nil {
return err
}
desc = messageSchema.Description
Expand All @@ -1211,7 +1209,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
// "NOTE: the referred field must be present at the top-level of the request message type."
// Ref: https://github.com/googleapis/googleapis/blob/b3397f5febbf21dfc69b875ddabaf76bee765058/google/api/http.proto#L350-L352
if len(b.Body.FieldPath) > 1 {
return fmt.Errorf("Body of request '%s' is not a top level field: '%v'.", meth.Service.GetName(), b.Body.FieldPath)
return fmt.Errorf("Body of request %q is not a top level field: '%v'.", meth.Service.GetName(), b.Body.FieldPath)
}
bodyField := b.Body.FieldPath[0]
if reg.GetUseJSONNamesForFields() {
Expand Down Expand Up @@ -1315,8 +1313,7 @@ func renderServices(services []*descriptor.Service, paths openapiPathsObject, re
// well, without a definition
wknSchemaCore, isWkn := wktSchemas[meth.ResponseType.FQMN()]
if !isWkn {
err := responseSchema.setRefFromFQN(meth.ResponseType.FQMN(), reg)
if err != nil {
if err := responseSchema.setRefFromFQN(meth.ResponseType.FQMN(), reg); err != nil {
return err
}
} else {
Expand Down Expand Up @@ -1873,7 +1870,7 @@ func applyTemplate(p param) (*openapiSwaggerObject, error) {
}

func processExtensions(inputExts map[string]*structpb.Value) ([]extension, error) {
exts := []extension{}
exts := make([]extension, 0, len(inputExts))
for k, v := range inputExts {
if !strings.HasPrefix(k, "x-") {
return nil, fmt.Errorf("extension keys need to start with \"x-\": %q", k)
Expand All @@ -1888,7 +1885,7 @@ func processExtensions(inputExts map[string]*structpb.Value) ([]extension, error
return exts, nil
}

func validateHeaderTypeAndFormat(headerType string, format string) error {
func validateHeaderTypeAndFormat(headerType, format string) error {
// The type of the object. The value MUST be one of "string", "number", "integer", "boolean", or "array"
// See: https://github.com/OAI/OpenAPI-Specification/blob/3.0.0/versions/2.0.md#headerObject
// Note: currently not implementing array as we are only implementing this in the operation response context
Expand Down Expand Up @@ -1956,50 +1953,40 @@ func validateDefaultValueTypeAndFormat(headerType string, defaultValue string, f
switch format {
case "date-time":
unquoteTime := strings.Trim(defaultValue, `"`)
_, err := time.Parse(time.RFC3339, unquoteTime)
if err != nil {
if _, err := time.Parse(time.RFC3339, unquoteTime); err != nil {
return fmt.Errorf("the provided default value %q is not a valid RFC3339 date-time string", defaultValue)
}
case "date":
const (
layoutRFC3339Date = "2006-01-02"
)
const layoutRFC3339Date = "2006-01-02"
unquoteDate := strings.Trim(defaultValue, `"`)
_, err := time.Parse(layoutRFC3339Date, unquoteDate)
if err != nil {
if _, err := time.Parse(layoutRFC3339Date, unquoteDate); err != nil {
return fmt.Errorf("the provided default value %q is not a valid RFC3339 date-time string", defaultValue)
}
}
case "number":
err := isJSONNumber(defaultValue, headerType)
if err != nil {
if err := isJSONNumber(defaultValue, headerType); err != nil {
return err
}
case "integer":
switch format {
case "int32":
_, err := strconv.ParseInt(defaultValue, 0, 32)
if err != nil {
if _, err := strconv.ParseInt(defaultValue, 0, 32); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
case "uint32":
_, err := strconv.ParseUint(defaultValue, 0, 32)
if err != nil {
if _, err := strconv.ParseUint(defaultValue, 0, 32); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
case "int64":
_, err := strconv.ParseInt(defaultValue, 0, 64)
if err != nil {
if _, err := strconv.ParseInt(defaultValue, 0, 64); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
case "uint64":
_, err := strconv.ParseUint(defaultValue, 0, 64)
if err != nil {
if _, err := strconv.ParseUint(defaultValue, 0, 64); err != nil {
return fmt.Errorf("the provided default value %q does not match provided format %q", defaultValue, format)
}
default:
_, err := strconv.ParseInt(defaultValue, 0, 64)
if err != nil {
if _, err := strconv.ParseInt(defaultValue, 0, 64); err != nil {
return fmt.Errorf("the provided default value %q does not match provided type %q", defaultValue, headerType)
}
}
Expand Down Expand Up @@ -2037,22 +2024,20 @@ func isBool(s string) bool {
}

func processHeaders(inputHdrs map[string]*openapi_options.Header) (openapiHeadersObject, error) {
hdrs := map[string]openapiHeaderObject{}
hdrs := make(map[string]openapiHeaderObject, len(inputHdrs))
for k, v := range inputHdrs {
header := textproto.CanonicalMIMEHeaderKey(k)
ret := openapiHeaderObject{
Description: v.Description,
Format: v.Format,
Pattern: v.Pattern,
}
err := validateHeaderTypeAndFormat(v.Type, v.Format)
if err != nil {
if err := validateHeaderTypeAndFormat(v.Type, v.Format); err != nil {
return nil, err
}
ret.Type = v.Type
if v.Default != "" {
err := validateDefaultValueTypeAndFormat(v.Type, v.Default, v.Format)
if err != nil {
if err := validateDefaultValueTypeAndFormat(v.Type, v.Default, v.Format); err != nil {
return nil, err
}
ret.Default = RawExample(v.Default)
Expand Down Expand Up @@ -2144,7 +2129,7 @@ func updateOpenAPIDataFromComments(reg *descriptor.Registry, swaggerObject inter
return nil
}

return fmt.Errorf("no description nor summary property")
return errors.New("no description nor summary property")
}

func fieldProtoComments(reg *descriptor.Registry, msg *descriptor.Message, field *descriptor.Field) string {
Expand All @@ -2168,8 +2153,7 @@ func enumValueProtoComments(reg *descriptor.Registry, enum *descriptor.Enum) str
if reg.GetEnumsAsInts() {
name = strconv.Itoa(int(value.GetNumber()))
}
str := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index), protoPath, int32(idx))
if str != "" {
if str := protoComments(reg, enum.File, enum.Outers, "EnumType", int32(enum.Index), protoPath, int32(idx)); str != "" {
comments = append(comments, name+": "+str)
}
}
Expand Down Expand Up @@ -2213,7 +2197,7 @@ func protoComments(reg *descriptor.Registry, file *descriptor.File, outers []str
// - determine if every (but first and last) line begins with " "
// - trim every line only if that is the case
// - join by \n
comments = strings.Replace(comments, "\n ", "\n", -1)
comments = strings.ReplaceAll(comments, "\n ", "\n")
}
if loc.TrailingComments != nil {
trailing := strings.TrimSpace(*loc.TrailingComments)
Expand Down Expand Up @@ -2250,8 +2234,7 @@ func goTemplateComments(comment string, data interface{}, reg *descriptor.Regist
// to make it easier to debug the template error
return err.Error()
}
err = tpl.Execute(&temp, data)
if err != nil {
if err := tpl.Execute(&temp, data); err != nil {
// If there is an error executing the templating insert the error as string in the comment
// to make it easier to debug the error
return err.Error()
Expand Down Expand Up @@ -2725,7 +2708,7 @@ func openapiExamplesFromProtoExamples(in map[string]string) map[string]interface
if len(in) == 0 {
return nil
}
out := make(map[string]interface{})
out := make(map[string]interface{}, len(in))
for mimeType, exampleStr := range in {
switch mimeType {
case "application/json":
Expand Down Expand Up @@ -2799,7 +2782,7 @@ func addCustomRefs(d openapiDefinitionsObject, reg *descriptor.Registry, refs re
for ref := range refs {
swgName, swgOk := fullyQualifiedNameToOpenAPIName(ref, reg)
if !swgOk {
glog.Errorf("can't resolve OpenAPI name from CustomRef '%v'", ref)
glog.Errorf("can't resolve OpenAPI name from CustomRef %q", ref)
continue
}
if _, ok := d[swgName]; ok {
Expand Down Expand Up @@ -2835,7 +2818,7 @@ func lowerCamelCase(fieldName string, fields []*descriptor.Field, msgs []*descri
return oneField.GetJsonName()
}
}
messageNameToFieldsToJSONName := make(map[string]map[string]string)
messageNameToFieldsToJSONName := make(map[string]map[string]string, len(msgs))
fieldNameToType := make(map[string]string)
for _, msg := range msgs {
fieldNameToJSONName := make(map[string]string)
Expand Down
44 changes: 18 additions & 26 deletions protoc-gen-openapiv2/main.go
Expand Up @@ -78,8 +78,7 @@ func main() {
glog.V(1).Info("Parsed code generator request")
pkgMap := make(map[string]string)
if req.Parameter != nil {
err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap)
if err != nil {
if err := parseReqParam(req.GetParameter(), flag.CommandLine, pkgMap); err != nil {
glog.Fatalf("Error parsing flags: %v", err)
}
}
Expand Down Expand Up @@ -166,7 +165,7 @@ func main() {
}
}

var targets []*descriptor.File
targets := make([]*descriptor.File, 0, len(req.FileToGenerate))
for _, target := range req.FileToGenerate {
f, err := reg.LookupFile(target)
if err != nil {
Expand Down Expand Up @@ -218,37 +217,30 @@ func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) erro
for _, p := range strings.Split(param, ",") {
spec := strings.SplitN(p, "=", 2)
if len(spec) == 1 {
if spec[0] == "allow_delete_body" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
switch spec[0] {
case "allow_delete_body":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
if spec[0] == "allow_merge" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
case "allow_merge":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
if spec[0] == "allow_repeated_fields_in_body" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
case "allow_repeated_fields_in_body":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
if spec[0] == "include_package_in_tags" {
err := f.Set(spec[0], "true")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
case "include_package_in_tags":
if err := f.Set(spec[0], "true"); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
err := f.Set(spec[0], "")
if err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
if err := f.Set(spec[0], ""); err != nil {
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
continue
}
Expand All @@ -258,7 +250,7 @@ func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) erro
continue
}
if err := f.Set(name, value); err != nil {
return fmt.Errorf("cannot set flag %s: %v", p, err)
return fmt.Errorf("cannot set flag %s: %w", p, err)
}
}
return nil
Expand Down

0 comments on commit 4792df0

Please sign in to comment.