Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gopls/internal/golang: add extract interface code action #478

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 42 additions & 23 deletions gopls/internal/golang/codeaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
}
}

pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
if err != nil {
return nil, err
}

if want[protocol.RefactorExtract] {
extractions, err := getExtractCodeActions(pgf, rng, snapshot.Options())
extractions, err := getExtractCodeActions(pkg, pgf, rng, snapshot.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -198,20 +203,18 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
}

// getExtractCodeActions returns any refactor.extract code actions for the selection.
func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
if rng.Start == rng.End {
return nil, nil
}

func getExtractCodeActions(pkg *cache.Package, pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
start, end, err := pgf.RangePos(rng)
if err != nil {
return nil, err
}

puri := pgf.URI
var commands []protocol.Command
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
Fix: fixExtractFunction,

if _, _, ok, _ := CanExtractInterface(pkg, start, end, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract interface", command.ApplyFixArgs{
Fix: fixExtractInterface,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
Expand All @@ -220,9 +223,12 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
return nil, err
}
commands = append(commands, cmd)
if methodOk {
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
Fix: fixExtractMethod,
}

if rng.Start != rng.End {
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract function", command.ApplyFixArgs{
Fix: fixExtractFunction,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
Expand All @@ -231,20 +237,33 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
return nil, err
}
commands = append(commands, cmd)
if methodOk {
cmd, err := command.NewApplyFixCommand("Extract method", command.ApplyFixArgs{
Fix: fixExtractMethod,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
})
if err != nil {
return nil, err
}
commands = append(commands, cmd)
}
}
}
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
Fix: fixExtractVariable,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
})
if err != nil {
return nil, err
if _, _, ok, _ := CanExtractVariable(start, end, pgf.File); ok {
cmd, err := command.NewApplyFixCommand("Extract variable", command.ApplyFixArgs{
Fix: fixExtractVariable,
URI: puri,
Range: rng,
ResolveEdits: supportsResolveEdits(options),
})
if err != nil {
return nil, err
}
commands = append(commands, cmd)
}
commands = append(commands, cmd)
}

var actions []protocol.CodeAction
for i := range commands {
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))
Expand Down
34 changes: 34 additions & 0 deletions gopls/internal/golang/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/util/bug"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/internal/analysisinternal"
Expand Down Expand Up @@ -127,6 +128,39 @@ func CanExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.N
return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
}

// CanExtractInterface reports whether the code in the given position is for a
// type which can be represented as an interface.
func CanExtractInterface(pkg *cache.Package, start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
path, _ := astutil.PathEnclosingInterval(file, start, end)
if len(path) == 0 {
return nil, nil, false, fmt.Errorf("no path enclosing interval")
}

node := path[0]
expr, ok := node.(ast.Expr)
if !ok {
return nil, nil, false, fmt.Errorf("node is not an expression")
}

switch e := expr.(type) {
case *ast.Ident:
o, ok := pkg.TypesInfo().ObjectOf(e).(*types.TypeName)
if !ok {
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
}

if _, ok := o.Type().(*types.Basic); ok {
return nil, nil, false, fmt.Errorf("cannot extract a basic type to an interface")
}

return expr, path, true, nil
case *ast.StarExpr, *ast.SelectorExpr:
return expr, path, true, nil
default:
return nil, nil, false, fmt.Errorf("cannot extract a %T to an interface", expr)
}
}

// Calculate indentation for insertion.
// When inserting lines of code, we must ensure that the lines have consistent
// formatting (i.e. the proper indentation). To do so, we observe the indentation on the
Expand Down
141 changes: 141 additions & 0 deletions gopls/internal/golang/fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
package golang

import (
"bytes"
"context"
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"slices"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/analysis/embeddirective"
"golang.org/x/tools/gopls/internal/analysis/fillstruct"
"golang.org/x/tools/gopls/internal/analysis/stubmethods"
Expand All @@ -22,6 +26,7 @@ import (
"golang.org/x/tools/gopls/internal/file"
"golang.org/x/tools/gopls/internal/protocol"
"golang.org/x/tools/gopls/internal/util/bug"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/internal/imports"
)

Expand Down Expand Up @@ -61,6 +66,7 @@ func singleFile(fixer1 singleFileFixer) fixer {
const (
fixExtractVariable = "extract_variable"
fixExtractFunction = "extract_function"
fixExtractInterface = "extract_interface"
fixExtractMethod = "extract_method"
fixInlineCall = "inline_call"
fixInvertIfCondition = "invert_if_condition"
Expand Down Expand Up @@ -112,6 +118,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file

// Ad-hoc fixers: these are used when the command is
// constructed directly by logic in server/code_action.
fixExtractInterface: extractInterface,
fixExtractFunction: singleFile(extractFunction),
fixExtractMethod: singleFile(extractMethod),
fixExtractVariable: singleFile(extractVariable),
Expand Down Expand Up @@ -142,6 +149,140 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
return suggestedFixToEdits(ctx, snapshot, fixFset, suggestion)
}

func extractInterface(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)

var field *ast.Field
var decl ast.Decl
for _, node := range path {
if f, ok := node.(*ast.Field); ok {
field = f
continue
}

// Record the node that starts the declaration of the type that contains
// the field we are creating the interface for.
if d, ok := node.(ast.Decl); ok {
decl = d
break // we have both the field and the declaration
}
}

if field == nil || decl == nil {
return nil, nil, nil
}

p := safetoken.StartPosition(pkg.FileSet(), field.Pos())
pos := protocol.Position{
Line: uint32(p.Line - 1), // Line is zero-based
Character: uint32(p.Column - 1), // Character is zero-based
}

fh, err := snapshot.ReadFile(ctx, pgf.URI)
if err != nil {
return nil, nil, err
}

refs, err := references(ctx, snapshot, fh, pos, false)
if err != nil {
return nil, nil, err
}

type method struct {
signature *types.Signature
name string
}

var methods []method
for _, ref := range refs {
locPkg, locPgf, err := NarrowestPackageForFile(ctx, snapshot, ref.location.URI)
if err != nil {
return nil, nil, err
}

_, end, err := locPgf.RangePos(ref.location.Range)
if err != nil {
return nil, nil, err
}

// We are interested in the method call, so we need the node after the dot
rangeEnd := end + token.Pos(len("."))
path, _ := astutil.PathEnclosingInterval(locPgf.File, rangeEnd, rangeEnd)
id, ok := path[0].(*ast.Ident)
if !ok {
continue
}

obj := locPkg.TypesInfo().ObjectOf(id)
if obj == nil {
continue
}

sig, ok := obj.Type().(*types.Signature)
if !ok {
return nil, nil, errors.New("cannot extract interface with non-method accesses")
}

fc := method{signature: sig, name: obj.Name()}
if !slices.Contains(methods, fc) {
methods = append(methods, fc)
}
}

interfaceName := "I" + pkg.TypesInfo().ObjectOf(field.Names[0]).Name()
var buf bytes.Buffer
buf.WriteString("\ntype ")
buf.WriteString(interfaceName)
buf.WriteString(" interface {\n")
for _, fc := range methods {
buf.WriteString("\t")
buf.WriteString(fc.name)
types.WriteSignature(&buf, fc.signature, relativeTo(pkg.Types()))
buf.WriteByte('\n')
}
buf.WriteByte('}')
buf.WriteByte('\n')

interfacePos := decl.Pos() - 1
// Move the interface above the documentation comment if the type declaration
// includes one.
switch d := decl.(type) {
case *ast.GenDecl:
if d.Doc != nil {
interfacePos = d.Doc.Pos() - 1
}
case *ast.FuncDecl:
if d.Doc != nil {
interfacePos = d.Doc.Pos() - 1
}
}

return pkg.FileSet(), &analysis.SuggestedFix{
Message: "Extract interface",
TextEdits: []analysis.TextEdit{{
Pos: interfacePos,
End: interfacePos,
NewText: buf.Bytes(),
}, {
Pos: field.Type.Pos(),
End: field.Type.End(),
NewText: []byte(interfaceName),
}},
}, nil
}

func relativeTo(pkg *types.Package) types.Qualifier {
if pkg == nil {
return nil
}
return func(other *types.Package) string {
if pkg == other {
return "" // same package; unqualified
}
return other.Name()
}
}

// suggestedFixToEdits converts the suggestion's edits from analysis form into protocol form.
func suggestedFixToEdits(ctx context.Context, snapshot *cache.Snapshot, fset *token.FileSet, suggestion *analysis.SuggestedFix) ([]protocol.TextDocumentEdit, error) {
editsPerFile := map[protocol.DocumentURI]*protocol.TextDocumentEdit{}
Expand Down