Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit 5c85495

Browse files
authoredFeb 2, 2020
Use "." to refer to the current path's package in reflect mode (#387)
* feat: use "." to refer to the current path's package * doc: update reflect mode * fix: generated code lose package name
1 parent 3dcdcb6 commit 5c85495

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed
 

Diff for: ‎README.md

+5
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,15 @@ that uses reflection to understand interfaces. It is enabled
5151
by passing two non-flag arguments: an import path, and a
5252
comma-separated list of symbols.
5353

54+
You can use "." to refer to the current path's package.
55+
5456
Example:
5557

5658
```bash
5759
mockgen database/sql/driver Conn,Driver
60+
61+
# Convenient for `go:generate`.
62+
mockgen . Conn,Driver
5863
```
5964

6065
The `mockgen` command is used to generate source code for a mock

Diff for: ‎mockgen/mockgen.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,26 @@ func main() {
6969

7070
var pkg *model.Package
7171
var err error
72+
var packageName string
7273
if *source != "" {
7374
pkg, err = sourceMode(*source)
7475
} else {
7576
if flag.NArg() != 2 {
7677
usage()
7778
log.Fatal("Expected exactly two arguments")
7879
}
79-
pkg, err = reflectMode(flag.Arg(0), strings.Split(flag.Arg(1), ","))
80+
packageName = flag.Arg(0)
81+
if packageName == "." {
82+
dir, err := os.Getwd()
83+
if err != nil {
84+
log.Fatalf("Get current directory failed: %v", err)
85+
}
86+
packageName, err = packageNameOfDir(dir)
87+
if err != nil {
88+
log.Fatalf("Parse package name failed: %v", err)
89+
}
90+
}
91+
pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ","))
8092
}
8193
if err != nil {
8294
log.Fatalf("Loading input failed: %v", err)
@@ -130,7 +142,7 @@ func main() {
130142
if *source != "" {
131143
g.filename = *source
132144
} else {
133-
g.srcPackage = flag.Arg(0)
145+
g.srcPackage = packageName
134146
g.srcInterfaces = flag.Arg(1)
135147
}
136148

Diff for: ‎mockgen/parse.go

+45-10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"go/build"
2525
"go/parser"
2626
"go/token"
27+
"io/ioutil"
2728
"log"
2829
"path"
2930
"path/filepath"
@@ -48,19 +49,10 @@ func sourceMode(source string) (*model.Package, error) {
4849
return nil, fmt.Errorf("failed getting source directory: %v", err)
4950
}
5051

51-
cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir}
52-
pkgs, err := packages.Load(cfg, "file="+source)
52+
packageImport, err := parsePackageImport(source, srcDir)
5353
if err != nil {
5454
return nil, err
5555
}
56-
if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 {
57-
return nil, errors.New("loading package failed")
58-
}
59-
60-
packageImport := pkgs[0].PkgPath
61-
62-
// It is illegal to import a _test package.
63-
packageImport = strings.TrimSuffix(packageImport, "_test")
6456

6557
fs := token.NewFileSet()
6658
file, err := parser.ParseFile(fs, source, nil, 0)
@@ -519,3 +511,46 @@ func isVariadic(f *ast.FuncType) bool {
519511
_, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis)
520512
return ok
521513
}
514+
515+
// packageNameOfDir get package import path via dir
516+
func packageNameOfDir(srcDir string) (string, error) {
517+
files, err := ioutil.ReadDir(srcDir)
518+
if err != nil {
519+
log.Fatal(err)
520+
}
521+
522+
var goFilePath string
523+
for _, file := range files {
524+
if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") {
525+
goFilePath = file.Name()
526+
break
527+
}
528+
}
529+
if goFilePath == "" {
530+
return "", fmt.Errorf("go source file not found %s", srcDir)
531+
}
532+
533+
packageImport, err := parsePackageImport(goFilePath, srcDir)
534+
if err != nil {
535+
return "", err
536+
}
537+
return packageImport, nil
538+
}
539+
540+
// parseImportPackage get package import path via source file
541+
func parsePackageImport(source, srcDir string) (string, error) {
542+
cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir}
543+
pkgs, err := packages.Load(cfg, "file="+source)
544+
if err != nil {
545+
return "", err
546+
}
547+
if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 {
548+
return "", errors.New("loading package failed")
549+
}
550+
551+
packageImport := pkgs[0].PkgPath
552+
553+
// It is illegal to import a _test package.
554+
packageImport = strings.TrimSuffix(packageImport, "_test")
555+
return packageImport, nil
556+
}

0 commit comments

Comments
 (0)
This repository has been archived.