Skip to content

Commit

Permalink
enhancement for enums (#1400)
Browse files Browse the repository at this point in the history
* enhancement for enums:
1. explicit enum conversion;
2. implicit integer conversion;
3. keep the type of the right operand of a shift operation;
3. parse escape characters.
  • Loading branch information
sdghchj committed Nov 30, 2022
1 parent 8139731 commit e50db3e
Show file tree
Hide file tree
Showing 8 changed files with 731 additions and 53 deletions.
551 changes: 551 additions & 0 deletions const.go

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions enums_test.go
Expand Up @@ -20,4 +20,13 @@ func TestParseGlobalEnums(t *testing.T) {
b, err := json.MarshalIndent(p.swagger, "", " ")
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
constsPath := "github.com/swaggo/swag/testdata/enums/consts"
assert.Equal(t, 64, p.packages.packages[constsPath].ConstTable["uintSize"].Value)
assert.Equal(t, int32(62), p.packages.packages[constsPath].ConstTable["maxBase"].Value)
assert.Equal(t, 8, p.packages.packages[constsPath].ConstTable["shlByLen"].Value)
assert.Equal(t, 255, p.packages.packages[constsPath].ConstTable["hexnum"].Value)
assert.Equal(t, 15, p.packages.packages[constsPath].ConstTable["octnum"].Value)
assert.Equal(t, `aa\nbb\u8888cc`, p.packages.packages[constsPath].ConstTable["nonescapestr"].Value)
assert.Equal(t, "aa\nbb\u8888cc", p.packages.packages[constsPath].ConstTable["escapestr"].Value)
assert.Equal(t, '\u8888', p.packages.packages[constsPath].ConstTable["escapechar"].Value)
}
115 changes: 69 additions & 46 deletions package.go
Expand Up @@ -3,6 +3,7 @@ package swag
import (
"go/ast"
"go/token"
"reflect"
"strconv"
)

Expand Down Expand Up @@ -31,6 +32,7 @@ type PackageDefinitions struct {
type ConstVariableGlobalEvaluator interface {
EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr)
EvaluateConstValueByName(file *ast.File, pkgPath, constVariableName string, recursiveStack map[string]struct{}) (interface{}, ast.Expr)
FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef
}

// NewPackageDefinitions new a PackageDefinitions object
Expand Down Expand Up @@ -92,68 +94,89 @@ func (pkg *PackageDefinitions) evaluateConstValue(file *ast.File, iota int, expr
case *ast.BasicLit:
switch valueExpr.Kind {
case token.INT:
x, err := strconv.ParseInt(valueExpr.Value, 10, 64)
if err != nil {
return nil, nil
// hexadecimal
if len(valueExpr.Value) > 2 && valueExpr.Value[0] == '0' && valueExpr.Value[1] == 'x' {
if x, err := strconv.ParseInt(valueExpr.Value[2:], 16, 64); err == nil {
return int(x), nil
} else if x, err := strconv.ParseUint(valueExpr.Value[2:], 16, 64); err == nil {
return x, nil
} else {
panic(err)
}
}

//octet
if len(valueExpr.Value) > 1 && valueExpr.Value[0] == '0' {
if x, err := strconv.ParseInt(valueExpr.Value[1:], 8, 64); err == nil {
return int(x), nil
} else if x, err := strconv.ParseUint(valueExpr.Value[1:], 8, 64); err == nil {
return x, nil
} else {
panic(err)
}
}

//a basic literal integer is int type in default, or must have an explicit converting type in front
if x, err := strconv.ParseInt(valueExpr.Value, 10, 64); err == nil {
return int(x), nil
} else if x, err := strconv.ParseUint(valueExpr.Value, 10, 64); err == nil {
return x, nil
} else {
panic(err)
}
return int(x), nil
case token.STRING, token.CHAR:
return valueExpr.Value[1 : len(valueExpr.Value)-1], nil
case token.STRING:
if valueExpr.Value[0] == '`' {
return valueExpr.Value[1 : len(valueExpr.Value)-1], nil
}
return EvaluateEscapedString(valueExpr.Value[1 : len(valueExpr.Value)-1]), nil
case token.CHAR:
return EvaluateEscapedChar(valueExpr.Value[1 : len(valueExpr.Value)-1]), nil
}
case *ast.UnaryExpr:
x, evalType := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack)
if x == nil {
return nil, nil
}
switch valueExpr.Op {
case token.SUB:
return -x.(int), evalType
case token.XOR:
return ^(x.(int)), evalType
return x, evalType
}
return EvaluateUnary(x, valueExpr.Op, evalType)
case *ast.BinaryExpr:
x, evalTypex := pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack)
y, evalTypey := pkg.evaluateConstValue(file, iota, valueExpr.Y, globalEvaluator, recursiveStack)
if x == nil || y == nil {
return nil, nil
}
evalType := evalTypex
if evalType == nil {
evalType = evalTypey
}
switch valueExpr.Op {
case token.ADD:
if ix, ok := x.(int); ok {
return ix + y.(int), evalType
} else if sx, ok := x.(string); ok {
return sx + y.(string), evalType
}
case token.SUB:
return x.(int) - y.(int), evalType
case token.MUL:
return x.(int) * y.(int), evalType
case token.QUO:
return x.(int) / y.(int), evalType
case token.REM:
return x.(int) % y.(int), evalType
case token.AND:
return x.(int) & y.(int), evalType
case token.OR:
return x.(int) | y.(int), evalType
case token.XOR:
return x.(int) ^ y.(int), evalType
case token.SHL:
return x.(int) << y.(int), evalType
case token.SHR:
return x.(int) >> y.(int), evalType
}
return EvaluateBinary(x, y, valueExpr.Op, evalTypex, evalTypey)
case *ast.ParenExpr:
return pkg.evaluateConstValue(file, iota, valueExpr.X, globalEvaluator, recursiveStack)
case *ast.CallExpr:
//data conversion
if ident, ok := valueExpr.Fun.(*ast.Ident); ok && len(valueExpr.Args) == 1 && IsGolangPrimitiveType(ident.Name) {
arg, _ := pkg.evaluateConstValue(file, iota, valueExpr.Args[0], globalEvaluator, recursiveStack)
return arg, nil
if len(valueExpr.Args) != 1 {
return nil, nil
}
arg := valueExpr.Args[0]
if ident, ok := valueExpr.Fun.(*ast.Ident); ok {
name := ident.Name
if name == "uintptr" {
name = "uint"
}
if IsGolangPrimitiveType(name) {
value, _ := pkg.evaluateConstValue(file, iota, arg, globalEvaluator, recursiveStack)
value = EvaluateDataConversion(value, name)
return value, nil
} else if name == "len" {
value, _ := pkg.evaluateConstValue(file, iota, arg, globalEvaluator, recursiveStack)
return reflect.ValueOf(value).Len(), nil
}
typeDef := globalEvaluator.FindTypeSpec(name, file)
if typeDef == nil {
return nil, nil
}
return arg, valueExpr.Fun
} else if selector, ok := valueExpr.Fun.(*ast.SelectorExpr); ok {
typeDef := globalEvaluator.FindTypeSpec(fullTypeName(selector.X.(*ast.Ident).Name, selector.Sel.Name), file)
if typeDef == nil {
return nil, nil
}
return arg, typeDef.TypeSpec.Type
}
}
return nil, nil
Expand Down
9 changes: 9 additions & 0 deletions testdata/enums/consts/const.go
@@ -1,3 +1,12 @@
package consts

const Base = 1

const uintSize = 32 << (^uint(uintptr(0)) >> 63)
const maxBase = 10 + ('z' - 'a' + 1) + ('Z' - 'A' + 1)
const shlByLen = 1 << len("aaa")
const hexnum = 0xFF
const octnum = 017
const nonescapestr = `aa\nbb\u8888cc`
const escapestr = "aa\nbb\u8888cc"
const escapechar = '\u8888'
1 change: 0 additions & 1 deletion testdata/enums/main.go
Expand Up @@ -14,5 +14,4 @@ package main

// @BasePath /v2
func main() {

}
21 changes: 15 additions & 6 deletions testdata/enums/types/model.go
Expand Up @@ -11,8 +11,8 @@ const (
A Class = consts.Base + (iota+1-1)*2/2%100 - (1&1 | 1) + (2 ^ 2) // AAA
B /* BBB */
C
D
F = D + 1
D = C + 1
F = Class(5)
//G is not enum
G = H + 10
//H is not enum
Expand All @@ -21,13 +21,15 @@ const (
I = int(F + 2)
)

const J = 1 << uint16(I)

type Mask int

const (
Mask1 Mask = 2 << iota >> 1 // Mask1
Mask2 /* Mask2 */
Mask3 // Mask3
Mask4 // Mask4
Mask1 Mask = 0x02 << iota >> 1 // Mask1
Mask2 /* Mask2 */
Mask3 // Mask3
Mask4 // Mask4
)

type Type string
Expand All @@ -40,6 +42,13 @@ const (
OtherUnknown = string(Other + Unknown)
)

type Sex rune

const (
Male Sex = 'M'
Female = 'F'
)

type Person struct {
Name string
Class Class
Expand Down
31 changes: 31 additions & 0 deletions utils_go18.go
@@ -0,0 +1,31 @@
//go:build go1.18
// +build go1.18

package swag

import (
"reflect"
"unicode/utf8"
)

// AppendUtf8Rune appends the UTF-8 encoding of r to the end of p and
// returns the extended buffer. If the rune is out of range,
// it appends the encoding of RuneError.
func AppendUtf8Rune(p []byte, r rune) []byte {
return utf8.AppendRune(p, r)
}

// CanIntegerValue a wrapper of reflect.Value
type CanIntegerValue struct {
reflect.Value
}

// CanInt reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanInt() bool {
return v.Value.CanInt()
}

// CanUint reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanUint() bool {
return v.Value.CanUint()
}
47 changes: 47 additions & 0 deletions utils_other.go
@@ -0,0 +1,47 @@
//go:build !go1.18
// +build !go1.18

package swag

import (
"reflect"
"unicode/utf8"
)

// AppendUtf8Rune appends the UTF-8 encoding of r to the end of p and
// returns the extended buffer. If the rune is out of range,
// it appends the encoding of RuneError.
func AppendUtf8Rune(p []byte, r rune) []byte {
length := utf8.RuneLen(rune(r))
if length > 0 {
utf8Slice := make([]byte, length)
utf8.EncodeRune(utf8Slice, rune(r))
p = append(p, utf8Slice...)
}
return p
}

// CanIntegerValue a wrapper of reflect.Value
type CanIntegerValue struct {
reflect.Value
}

// CanInt reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanInt() bool {
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
default:
return false
}
}

// CanUint reports whether Uint can be used without panicking.
func (v CanIntegerValue) CanUint() bool {
switch v.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return true
default:
return false
}
}

0 comments on commit e50db3e

Please sign in to comment.