Skip to content

Commit

Permalink
Merge pull request cli#8639 from dmgardiner25/non-interactive-ssh-kee…
Browse files Browse the repository at this point in the history
…p-alive

Send activity signals during non-interactive codespace SSH command
  • Loading branch information
dmgardiner25 committed Jan 30, 2024
2 parents cac9788 + 400db0f commit 023d711
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 30 deletions.
33 changes: 25 additions & 8 deletions internal/codespaces/rpc/invoker.go
Expand Up @@ -31,6 +31,7 @@ const (
codespacesInternalSessionName = "CodespacesInternal"
clientName = "gh"
connectedEventName = "connected"
keepAliveEventName = "keepAlive"
)

type StartSSHServerOptions struct {
Expand All @@ -43,16 +44,18 @@ type Invoker interface {
RebuildContainer(ctx context.Context, full bool) error
StartSSHServer(ctx context.Context) (int, string, error)
StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error)
KeepAlive()
}

type invoker struct {
conn *grpc.ClientConn
fwd portforwarder.PortForwarder
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
codespaceClient codespace.CodespaceHostClient
sshClient ssh.SshServerHostClient
cancelPF context.CancelFunc
conn *grpc.ClientConn
fwd portforwarder.PortForwarder
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
codespaceClient codespace.CodespaceHostClient
sshClient ssh.SshServerHostClient
cancelPF context.CancelFunc
keepAliveOverride bool
}

// Connects to the internal RPC server and returns a new invoker for it
Expand Down Expand Up @@ -256,6 +259,12 @@ func listenTCP() (*net.TCPListener, error) {
return listener, nil
}

// KeepAlive sets a flag to continuously send activity signals to
// the codespace even if there is no other activity (e.g. stdio)
func (i *invoker) KeepAlive() {
i.keepAliveOverride = true
}

// Periodically check whether there is a reason to keep the connection alive, and if so, notify the codespace to do so
func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
Expand All @@ -266,7 +275,15 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
case <-ctx.Done():
return
case <-ticker.C:
reason := i.fwd.GetKeepAliveReason()
reason := ""

// If the keep alive override flag is set, we don't need to check for activity on the forwarder
// Otherwise, grab the reason from the forwarder
if i.keepAliveOverride {
reason = keepAliveEventName
} else {
reason = i.fwd.GetKeepAliveReason()
}
_ = i.notifyCodespaceOfClientActivity(ctx, reason)
}
}
Expand Down
33 changes: 17 additions & 16 deletions internal/codespaces/ssh.go
Expand Up @@ -19,9 +19,9 @@ type printer interface {
// port-forwarding session. It runs until the shell is terminated
// (including by cancellation of the context).
func Shell(
ctx context.Context, p printer, sshArgs []string, port int, destination string, printConnDetails bool,
ctx context.Context, p printer, sshArgs []string, command []string, port int, destination string, printConnDetails bool,
) error {
cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs)
cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs, command)
if err != nil {
return fmt.Errorf("failed to create ssh command: %w", err)
}
Expand Down Expand Up @@ -51,30 +51,24 @@ func Copy(ctx context.Context, scpArgs []string, port int, destination string) e
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
// command on the remote machine.
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) {
cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs)
sshArgs, command, err := ParseSSHArgs(sshArgs)
if err != nil {
return nil, err
}

cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs, command)
return cmd, err
}

// newSSHCommand populates an exec.Cmd to run a command (or if blank,
// an interactive shell) over ssh.
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) {
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string, command []string) (*exec.Cmd, []string, error) {
connArgs := []string{
"-p", strconv.Itoa(port),
"-o", "NoHostAuthenticationForLocalhost=yes",
"-o", "PasswordAuthentication=no",
}

// The ssh command syntax is: ssh [flags] user@host command [args...]
// There is no way to specify the user@host destination as a flag.
// Unfortunately, that means we need to know which user-provided words are
// SSH flags and which are command arguments so that we can place
// them before or after the destination, and that means we need to know all
// the flags and their arities.
cmdArgs, command, err := parseSSHArgs(cmdArgs)
if err != nil {
return nil, nil, err
}

cmdArgs = append(cmdArgs, connArgs...)
cmdArgs = append(cmdArgs, "-C") // Compression
cmdArgs = append(cmdArgs, dst) // user@host
Expand All @@ -96,7 +90,14 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string)
return cmd, connArgs, nil
}

func parseSSHArgs(args []string) (cmdArgs, command []string, err error) {
// ParseSSHArgs parses the given array of arguments into two distinct slices of flags and command.
// The ssh command syntax is: ssh [flags] user@host command [args...]
// There is no way to specify the user@host destination as a flag.
// Unfortunately, that means we need to know which user-provided words are
// SSH flags and which are command arguments so that we can place
// them before or after the destination, and that means we need to know all
// the flags and their arities.
func ParseSSHArgs(args []string) (cmdArgs, command []string, err error) {
return parseArgs(args, "bcDeFIiLlmOopRSWw")
}

Expand Down
2 changes: 1 addition & 1 deletion internal/codespaces/ssh_test.go
Expand Up @@ -74,7 +74,7 @@ func TestParseSSHArgs(t *testing.T) {
}

for _, tcase := range testCases {
args, command, err := parseSSHArgs(tcase.Args)
args, command, err := ParseSSHArgs(tcase.Args)

checkParseResult(t, tcase, args, command, err)
}
Expand Down
22 changes: 17 additions & 5 deletions pkg/cmd/codespace/ssh.go
Expand Up @@ -276,16 +276,28 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e

shellClosed := make(chan error, 1)
go func() {
var err error
if opts.scpArgs != nil {
// args is the correct variable to use here, we just use scpArgs as the check for which command to run
err = codespaces.Copy(ctx, args, localSSHServerPort, connectDestination)
shellClosed <- codespaces.Copy(ctx, args, localSSHServerPort, connectDestination)
} else {
err = codespaces.Shell(
ctx, a.errLogger, args, localSSHServerPort, connectDestination, opts.printConnDetails,
// Parse the ssh args to determine if the user specified a command
args, command, err := codespaces.ParseSSHArgs(args)
if err != nil {
shellClosed <- err
return
}

// If the user specified a command, we need to keep the shell alive
// since it will be non-interactive and the codespace might shut down
// before the command finishes
if command != nil {
invoker.KeepAlive()
}

shellClosed <- codespaces.Shell(
ctx, a.errLogger, args, command, localSSHServerPort, connectDestination, opts.printConnDetails,
)
}
shellClosed <- err
}()

select {
Expand Down

0 comments on commit 023d711

Please sign in to comment.