From bdfec2b525b2d750fd89d3ba10402a3d54a5b3d5 Mon Sep 17 00:00:00 2001 From: sdghchj Date: Sun, 20 Nov 2022 18:26:56 +0800 Subject: [PATCH] parse global enums (#1387) * parse global enums --- const.go | 122 +++++++++++++++++ enums.go | 13 ++ enums_test.go | 23 ++++ field_parser.go | 4 +- package.go | 58 ++++++++ packages.go | 128 ++++++++++++++---- parser.go | 21 ++- testdata/enums/api/api.go | 13 ++ testdata/enums/expected.json | 104 ++++++++++++++ testdata/enums/main.go | 18 +++ testdata/enums/types/model.go | 46 +++++++ .../internal/path3/v1/product.go | 17 +++ types.go | 14 +- 13 files changed, 542 insertions(+), 39 deletions(-) create mode 100644 const.go create mode 100644 enums.go create mode 100644 enums_test.go create mode 100644 package.go create mode 100644 testdata/enums/api/api.go create mode 100644 testdata/enums/expected.json create mode 100644 testdata/enums/main.go create mode 100644 testdata/enums/types/model.go create mode 100644 testdata/generics_package_alias/internal/path3/v1/product.go diff --git a/const.go b/const.go new file mode 100644 index 000000000..ffe3a50cf --- /dev/null +++ b/const.go @@ -0,0 +1,122 @@ +package swag + +import ( + "go/ast" + "go/token" + "strconv" +) + +// ConstVariable a model to record an const variable +type ConstVariable struct { + Name *ast.Ident + Type ast.Expr + Value interface{} + Comment *ast.CommentGroup +} + +// EvaluateValue evaluate the value +func (cv *ConstVariable) EvaluateValue(constTable map[string]*ConstVariable) interface{} { + if expr, ok := cv.Value.(ast.Expr); ok { + value, evalType := evaluateConstValue(cv.Name.Name, cv.Name.Obj.Data.(int), expr, constTable, make(map[string]struct{})) + if cv.Type == nil && evalType != nil { + cv.Type = evalType + } + if value != nil { + cv.Value = value + } + return value + } + return cv.Value +} + +func evaluateConstValue(name string, iota int, expr ast.Expr, constTable map[string]*ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) { + if len(name) > 0 { + if _, ok := recursiveStack[name]; ok { + return nil, nil + } + recursiveStack[name] = struct{}{} + } + + switch valueExpr := expr.(type) { + case *ast.Ident: + if valueExpr.Name == "iota" { + return iota, nil + } + if constTable != nil { + if cv, ok := constTable[valueExpr.Name]; ok { + if expr, ok = cv.Value.(ast.Expr); ok { + value, evalType := evaluateConstValue(valueExpr.Name, cv.Name.Obj.Data.(int), expr, constTable, recursiveStack) + if cv.Type == nil { + cv.Type = evalType + } + if value != nil { + cv.Value = value + } + return value, evalType + } + return cv.Value, cv.Type + } + } + case *ast.BasicLit: + switch valueExpr.Kind { + case token.INT: + x, err := strconv.ParseInt(valueExpr.Value, 10, 64) + if err != nil { + return nil, nil + } + return int(x), nil + case token.STRING, token.CHAR: + return valueExpr.Value[1 : len(valueExpr.Value)-1], nil + } + case *ast.UnaryExpr: + x, evalType := evaluateConstValue("", iota, valueExpr.X, constTable, recursiveStack) + switch valueExpr.Op { + case token.SUB: + return -x.(int), evalType + case token.XOR: + return ^(x.(int)), evalType + } + case *ast.BinaryExpr: + x, evalTypex := evaluateConstValue("", iota, valueExpr.X, constTable, recursiveStack) + y, evalTypey := evaluateConstValue("", iota, valueExpr.Y, constTable, recursiveStack) + evalType := evalTypex + if evalType == nil { + evalType = evalTypey + } + switch valueExpr.Op { + case token.ADD: + if ix, ok := x.(int); ok { + return ix + y.(int), evalType + } else if sx, ok := x.(string); ok { + return sx + y.(string), evalType + } + case token.SUB: + return x.(int) - y.(int), evalType + case token.MUL: + return x.(int) * y.(int), evalType + case token.QUO: + return x.(int) / y.(int), evalType + case token.REM: + return x.(int) % y.(int), evalType + case token.AND: + return x.(int) & y.(int), evalType + case token.OR: + return x.(int) | y.(int), evalType + case token.XOR: + return x.(int) ^ y.(int), evalType + case token.SHL: + return x.(int) << y.(int), evalType + case token.SHR: + return x.(int) >> y.(int), evalType + } + case *ast.ParenExpr: + return evaluateConstValue("", iota, valueExpr.X, constTable, recursiveStack) + case *ast.CallExpr: + //data conversion + if ident, ok := valueExpr.Fun.(*ast.Ident); ok && len(valueExpr.Args) == 1 && IsGolangPrimitiveType(ident.Name) { + arg, _ := evaluateConstValue("", iota, valueExpr.Args[0], constTable, recursiveStack) + return arg, nil + } + } + return nil, nil +} diff --git a/enums.go b/enums.go new file mode 100644 index 000000000..4dc5547ff --- /dev/null +++ b/enums.go @@ -0,0 +1,13 @@ +package swag + +const ( + enumVarNamesExtension = "x-enum-varnames" + enumCommentsExtension = "x-enum-comments" +) + +// EnumValue a model to record an enum const variable +type EnumValue struct { + key string + Value interface{} + Comment string +} diff --git a/enums_test.go b/enums_test.go new file mode 100644 index 000000000..b77f1d227 --- /dev/null +++ b/enums_test.go @@ -0,0 +1,23 @@ +package swag + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseGlobalEnums(t *testing.T) { + searchDir := "testdata/enums" + expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + + p := New() + err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, err := json.MarshalIndent(p.swagger, "", " ") + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} diff --git a/field_parser.go b/field_parser.go index ce1ef4bf5..66f85427f 100644 --- a/field_parser.go +++ b/field_parser.go @@ -405,13 +405,13 @@ func (ps *tagBaseFieldParser) ComplementSchema(schema *spec.Schema) error { if schema.Items.Schema.Extensions == nil { schema.Items.Schema.Extensions = map[string]interface{}{} } - schema.Items.Schema.Extensions["x-enum-varnames"] = field.enumVarNames + schema.Items.Schema.Extensions[enumVarNamesExtension] = field.enumVarNames } else { // Add to top level schema if schema.Extensions == nil { schema.Extensions = map[string]interface{}{} } - schema.Extensions["x-enum-varnames"] = field.enumVarNames + schema.Extensions[enumVarNamesExtension] = field.enumVarNames } } diff --git a/package.go b/package.go new file mode 100644 index 000000000..c11fe0710 --- /dev/null +++ b/package.go @@ -0,0 +1,58 @@ +package swag + +import "go/ast" + +// PackageDefinitions files and definition in a package. +type PackageDefinitions struct { + // files in this package, map key is file's relative path starting package path + Files map[string]*ast.File + + // definitions in this package, map key is typeName + TypeDefinitions map[string]*TypeSpecDef + + // const variables in this package, map key is the name + ConstTable map[string]*ConstVariable + + // const variables in order in this package + OrderedConst []*ConstVariable + + // package name + Name string +} + +// NewPackageDefinitions new a PackageDefinitions object +func NewPackageDefinitions(name string) *PackageDefinitions { + return &PackageDefinitions{ + Name: name, + Files: make(map[string]*ast.File), + TypeDefinitions: make(map[string]*TypeSpecDef), + ConstTable: make(map[string]*ConstVariable), + } +} + +// AddFile add a file +func (pkg *PackageDefinitions) AddFile(pkgPath string, file *ast.File) *PackageDefinitions { + pkg.Files[pkgPath] = file + return pkg +} + +// AddTypeSpec add a type spec. +func (pkg *PackageDefinitions) AddTypeSpec(name string, typeSpec *TypeSpecDef) *PackageDefinitions { + pkg.TypeDefinitions[name] = typeSpec + return pkg +} + +// AddConst add a const variable. +func (pkg *PackageDefinitions) AddConst(valueSpec *ast.ValueSpec) *PackageDefinitions { + for i := 0; i < len(valueSpec.Names) && i < len(valueSpec.Values); i++ { + variable := &ConstVariable{ + Name: valueSpec.Names[i], + Type: valueSpec.Type, + Value: valueSpec.Values[i], + Comment: valueSpec.Comment, + } + pkg.ConstTable[valueSpec.Names[i].Name] = variable + pkg.OrderedConst = append(pkg.OrderedConst, variable) + } + return pkg +} diff --git a/packages.go b/packages.go index 361dd9344..74547bc7e 100644 --- a/packages.go +++ b/packages.go @@ -59,11 +59,7 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF dependency.Files[path] = astFile } else { - pkgDefs.packages[packageDir] = &PackageDefinitions{ - Name: astFile.Name.Name, - Files: map[string]*ast.File{path: astFile}, - TypeDefinitions: make(map[string]*TypeSpecDef), - } + pkgDefs.packages[packageDir] = NewPackageDefinitions(astFile.Name.Name).AddFile(path, astFile) } pkgDefs.files[astFile] = &AstFileInfo{ @@ -110,12 +106,18 @@ func (pkgDefs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, erro pkgDefs.parseFunctionScopedTypesFromFile(astFile, info.PackagePath, parsedSchemas) } pkgDefs.removeAllNotUniqueTypes() + pkgDefs.evaluateConstVariables() + pkgDefs.collectConstEnums(parsedSchemas) return parsedSchemas, nil } func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) { for _, astDeclaration := range astFile.Decls { - if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE { + generalDeclaration, ok := astDeclaration.(*ast.GenDecl) + if !ok { + continue + } + if generalDeclaration.Tok == token.TYPE { for _, astSpec := range generalDeclaration.Specs { if typeSpec, ok := astSpec.(*ast.TypeSpec); ok { typeSpecDef := &TypeSpecDef{ @@ -142,28 +144,30 @@ func (pkgDefs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packag if ok { if anotherTypeDef == nil { typeSpecDef.NotUnique = true - pkgDefs.uniqueDefinitions[typeSpecDef.TypeName()] = typeSpecDef + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } else if typeSpecDef.PkgPath != anotherTypeDef.PkgPath { - anotherTypeDef.NotUnique = true - typeSpecDef.NotUnique = true pkgDefs.uniqueDefinitions[fullName] = nil + anotherTypeDef.NotUnique = true pkgDefs.uniqueDefinitions[anotherTypeDef.TypeName()] = anotherTypeDef - pkgDefs.uniqueDefinitions[typeSpecDef.TypeName()] = typeSpecDef + typeSpecDef.NotUnique = true + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } } else { pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } if pkgDefs.packages[typeSpecDef.PkgPath] == nil { - pkgDefs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{ - Name: astFile.Name.Name, - TypeDefinitions: map[string]*TypeSpecDef{typeSpecDef.Name(): typeSpecDef}, - } + pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name).AddTypeSpec(typeSpecDef.Name(), typeSpecDef) } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok { - pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()] = typeSpecDef + pkgDefs.packages[typeSpecDef.PkgPath].AddTypeSpec(typeSpecDef.Name(), typeSpecDef) } } } + } else if generalDeclaration.Tok == token.CONST { + // collect const + pkgDefs.collectConstVariables(astFile, packagePath, generalDeclaration) } } } @@ -202,25 +206,24 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as if ok { if anotherTypeDef == nil { typeSpecDef.NotUnique = true - pkgDefs.uniqueDefinitions[typeSpecDef.TypeName()] = typeSpecDef + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } else if typeSpecDef.PkgPath != anotherTypeDef.PkgPath { - anotherTypeDef.NotUnique = true - typeSpecDef.NotUnique = true pkgDefs.uniqueDefinitions[fullName] = nil + anotherTypeDef.NotUnique = true pkgDefs.uniqueDefinitions[anotherTypeDef.TypeName()] = anotherTypeDef - pkgDefs.uniqueDefinitions[typeSpecDef.TypeName()] = typeSpecDef + typeSpecDef.NotUnique = true + fullName = typeSpecDef.TypeName() + pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } } else { pkgDefs.uniqueDefinitions[fullName] = typeSpecDef } if pkgDefs.packages[typeSpecDef.PkgPath] == nil { - pkgDefs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{ - Name: astFile.Name.Name, - TypeDefinitions: map[string]*TypeSpecDef{fullName: typeSpecDef}, - } + pkgDefs.packages[typeSpecDef.PkgPath] = NewPackageDefinitions(astFile.Name.Name).AddTypeSpec(fullName, typeSpecDef) } else if _, ok = pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[fullName]; !ok { - pkgDefs.packages[typeSpecDef.PkgPath].TypeDefinitions[fullName] = typeSpecDef + pkgDefs.packages[typeSpecDef.PkgPath].AddTypeSpec(fullName, typeSpecDef) } } } @@ -232,6 +235,83 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as } } +func (pkgDefs *PackagesDefinitions) collectConstVariables(astFile *ast.File, packagePath string, generalDeclaration *ast.GenDecl) { + pkg, ok := pkgDefs.packages[packagePath] + if !ok { + pkg = NewPackageDefinitions(astFile.Name.Name) + pkgDefs.packages[packagePath] = pkg + } + + var lastValueSpec *ast.ValueSpec + for _, astSpec := range generalDeclaration.Specs { + valueSpec, ok := astSpec.(*ast.ValueSpec) + if !ok { + continue + } + if len(valueSpec.Names) == 1 && len(valueSpec.Values) == 1 { + lastValueSpec = valueSpec + } else if len(valueSpec.Names) == 1 && len(valueSpec.Values) == 0 && valueSpec.Type == nil && lastValueSpec != nil { + valueSpec.Type = lastValueSpec.Type + valueSpec.Values = lastValueSpec.Values + } + pkg.AddConst(valueSpec) + } +} + +func (pkgDefs *PackagesDefinitions) evaluateConstVariables() { + //TODO evaluate enum cross packages + for _, pkg := range pkgDefs.packages { + for _, constVar := range pkg.OrderedConst { + constVar.EvaluateValue(pkg.ConstTable) + } + } +} + +func (pkgDefs *PackagesDefinitions) collectConstEnums(parsedSchemas map[*TypeSpecDef]*Schema) { + for _, pkg := range pkgDefs.packages { + for _, constVar := range pkg.OrderedConst { + if constVar.Type == nil { + continue + } + ident, ok := constVar.Type.(*ast.Ident) + if !ok || IsGolangPrimitiveType(ident.Name) { + continue + } + typeDef, ok := pkg.TypeDefinitions[ident.Name] + if !ok { + continue + } + + //delete it from parsed schemas, and will parse it again + if _, ok := parsedSchemas[typeDef]; ok { + delete(parsedSchemas, typeDef) + } + + if typeDef.Enums == nil { + typeDef.Enums = make([]EnumValue, 0) + } + + name := constVar.Name.Name + if _, ok := constVar.Value.(ast.Expr); ok { + continue + } + + enumValue := EnumValue{ + key: name, + Value: constVar.Value, + } + if constVar.Comment != nil && len(constVar.Comment.List) > 0 { + enumValue.Comment = constVar.Comment.List[0].Text + enumValue.Comment = strings.TrimLeft(enumValue.Comment, "//") + enumValue.Comment = strings.TrimLeft(enumValue.Comment, "/*") + enumValue.Comment = strings.TrimRight(enumValue.Comment, "*/") + enumValue.Comment = strings.TrimSpace(enumValue.Comment) + } + typeDef.Enums = append(typeDef.Enums, enumValue) + } + } +} + func (pkgDefs *PackagesDefinitions) removeAllNotUniqueTypes() { for key, ud := range pkgDefs.uniqueDefinitions { if ud == nil { diff --git a/parser.go b/parser.go index a36ba6d2d..621b8cfb3 100644 --- a/parser.go +++ b/parser.go @@ -912,7 +912,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( } } - if ref && len(schema.Schema.Type) > 0 && schema.Schema.Type[0] == OBJECT { + if ref && (len(schema.Schema.Type) > 0 && schema.Schema.Type[0] == OBJECT || len(schema.Enum) > 0) { return parser.getRefTypeSchema(typeSpecDef, schema), nil } @@ -982,6 +982,25 @@ func (parser *Parser) ParseDefinition(typeSpecDef *TypeSpecDef) (*Schema, error) fillDefinitionDescription(definition, typeSpecDef.File, typeSpecDef) } + if len(typeSpecDef.Enums) > 0 { + var varnames []string + var enumComments = make(map[string]string) + for _, value := range typeSpecDef.Enums { + definition.Enum = append(definition.Enum, value.Value) + varnames = append(varnames, value.key) + if len(value.Comment) > 0 { + enumComments[value.key] = value.Comment + } + } + if definition.Extensions == nil { + definition.Extensions = make(spec.Extensions) + } + definition.Extensions[enumVarNamesExtension] = varnames + if len(enumComments) > 0 { + definition.Extensions[enumCommentsExtension] = enumComments + } + } + sch := Schema{ Name: typeName, PkgPath: typeSpecDef.PkgPath, diff --git a/testdata/enums/api/api.go b/testdata/enums/api/api.go new file mode 100644 index 000000000..ea1766be0 --- /dev/null +++ b/testdata/enums/api/api.go @@ -0,0 +1,13 @@ +package api + +import "github.com/swaggo/swag/testdata/enums/types" + +// enum example +// +// @Summary enums +// @Description enums +// @Failure 400 {object} types.Person "ok" +// @Router /students [post] +func API() { + _ = types.Person{} +} diff --git a/testdata/enums/expected.json b/testdata/enums/expected.json new file mode 100644 index 000000000..1110f6cfa --- /dev/null +++ b/testdata/enums/expected.json @@ -0,0 +1,104 @@ +{ + "swagger": "2.0", + "info": { + "contact": {} + }, + "basePath": "/v2", + "paths": { + "/students": { + "post": { + "description": "enums", + "summary": "enums", + "responses": { + "400": { + "description": "ok", + "schema": { + "$ref": "#/definitions/types.Person" + } + } + } + } + } + }, + "definitions": { + "types.Class": { + "type": "integer", + "enum": [ + -1, + 1, + 2, + 3, + 4, + 5 + ], + "x-enum-comments": { + "A": "AAA", + "B": "BBB" + }, + "x-enum-varnames": [ + "None", + "A", + "B", + "C", + "D", + "F" + ] + }, + "types.Mask": { + "type": "integer", + "enum": [ + 1, + 2, + 4, + 8 + ], + "x-enum-comments": { + "Mask1": "Mask1", + "Mask2": "Mask2", + "Mask3": "Mask3", + "Mask4": "Mask4" + }, + "x-enum-varnames": [ + "Mask1", + "Mask2", + "Mask3", + "Mask4" + ] + }, + "types.Person": { + "type": "object", + "properties": { + "class": { + "$ref": "#/definitions/types.Class" + }, + "mask": { + "$ref": "#/definitions/types.Mask" + }, + "name": { + "type": "string" + }, + "type": { + "$ref": "#/definitions/types.Type" + } + } + }, + "types.Type": { + "type": "string", + "enum": [ + "teacher", + "student", + "Other" + ], + "x-enum-comments": { + "Other": "Other", + "Student": "student", + "Teacher": "teacher" + }, + "x-enum-varnames": [ + "Teacher", + "Student", + "Other" + ] + } + } +} \ No newline at end of file diff --git a/testdata/enums/main.go b/testdata/enums/main.go new file mode 100644 index 000000000..8238df86d --- /dev/null +++ b/testdata/enums/main.go @@ -0,0 +1,18 @@ +package main + +// @title Swagger Example API +// @version 1.0 +// @description This is a sample server. +// @termsOfService http://swagger.io/terms/ + +// @contact.name API Support +// @contact.url http://www.swagger.io/support +// @contact.email support@swagger.io + +// @license.name Apache 2.0 +// @license.url http://www.apache.org/licenses/LICENSE-2.0.html + +// @BasePath /v2 +func main() { + +} diff --git a/testdata/enums/types/model.go b/testdata/enums/types/model.go new file mode 100644 index 000000000..cd5c8e08e --- /dev/null +++ b/testdata/enums/types/model.go @@ -0,0 +1,46 @@ +package types + +type Class int + +const Base = 1 + +const ( + None Class = -1 + A Class = Base + (iota+1-1)*2/2%100 - (1&1 | 1) + (2 ^ 2) // AAA + B /* BBB */ + C + D + F = D + 1 + //G is not enum + G = H + 10 + //H is not enum + H = 10 + //I is not enum + I = int(F + 2) +) + +type Mask int + +const ( + Mask1 Mask = 2 << iota >> 1 // Mask1 + Mask2 /* Mask2 */ + Mask3 // Mask3 + Mask4 // Mask4 +) + +type Type string + +const ( + Teacher Type = "teacher" // teacher + Student Type = "student" /* student */ + Other Type = "Other" // Other + Unknown = "Unknown" + OtherUnknown = string(Other + Unknown) +) + +type Person struct { + Name string + Class Class + Mask Mask + Type Type +} diff --git a/testdata/generics_package_alias/internal/path3/v1/product.go b/testdata/generics_package_alias/internal/path3/v1/product.go new file mode 100644 index 000000000..be5e25127 --- /dev/null +++ b/testdata/generics_package_alias/internal/path3/v1/product.go @@ -0,0 +1,17 @@ +package v1 + +type ProductDto struct { + Name3 string `json:"name3"` +} + +type ListResult[T any] struct { + Items3 []T `json:"items3,omitempty"` +} + +type RenamedProductDto struct { + Name33 string `json:"name33"` +} // @name ProductDtoV3 + +type RenamedListResult[T any] struct { + Items33 []T `json:"items33,omitempty"` +} // @name ListResultV3 diff --git a/types.go b/types.go index e149f8b28..f86d8212f 100644 --- a/types.go +++ b/types.go @@ -22,6 +22,8 @@ type TypeSpecDef struct { // the TypeSpec of this type definition TypeSpec *ast.TypeSpec + Enums []EnumValue + // path of package starting from under ${GOPATH}/src or from module path in go.mod PkgPath string ParentSpec ast.Decl @@ -87,15 +89,3 @@ type AstFileInfo struct { // PackagePath package import path of the ast.File PackagePath string } - -// PackageDefinitions files and definition in a package. -type PackageDefinitions struct { - // files in this package, map key is file's relative path starting package path - Files map[string]*ast.File - - // definitions in this package, map key is typeName - TypeDefinitions map[string]*TypeSpecDef - - // package name - Name string -}