Skip to content

Commit

Permalink
fix: exit on ctrl+c in interactive printers by default (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
panbanda committed Dec 16, 2023
1 parent 1446167 commit 857a7ac
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 10 deletions.
3 changes: 1 addition & 2 deletions interactive_continue_printer.go
Expand Up @@ -2,7 +2,6 @@ package pterm

import (
"fmt"
"os"
"strings"

"atomicgo.dev/cursor"
Expand Down Expand Up @@ -155,7 +154,7 @@ func (p InteractiveContinuePrinter) Show(text ...string) (string, error) {
result = p.Options[p.DefaultValueIndex]
return true, nil
case keys.CtrlC:
os.Exit(1)
internal.Exit(1)
return true, nil
}
return false, nil
Expand Down
25 changes: 19 additions & 6 deletions interactive_textinput_printer_test.go
@@ -1,6 +1,7 @@
package pterm_test

import (
"os"
"reflect"
"testing"
"time"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/MarvinJWendt/testza"

"github.com/pterm/pterm"
"github.com/pterm/pterm/internal"
)

func TestInteractiveTextInputPrinter_WithDefaultText(t *testing.T) {
Expand Down Expand Up @@ -56,12 +58,23 @@ func TestInteractiveTextInputPrinter_WithMask(t *testing.T) {
}

func TestInteractiveTextInputPrinter_WithCancel(t *testing.T) {
go func() {
time.Sleep(1 * time.Millisecond)
keyboard.SimulateKeyPress(keys.CtrlC)
}()
result, _ := pterm.DefaultInteractiveTextInput.WithMask("*").Show()
testza.AssertEqual(t, "", result)
exitCalled := false
internal.DefaultExitFunc = func(code int) {
exitCalled = true
}
defer func() { internal.DefaultExitFunc = os.Exit }()

go func() {
time.Sleep(1 * time.Millisecond)
keyboard.SimulateKeyPress(keys.CtrlC)
}()

result, _ := pterm.DefaultInteractiveTextInput.WithMask("*").Show()
testza.AssertEqual(t, "", result)

if !exitCalled {
t.Errorf("Expected exit to be called on Ctrl+C")
}
}

func TestInteractiveTextInputPrinter_OnEnter(t *testing.T) {
Expand Down
8 changes: 6 additions & 2 deletions internal/cancelation_signal.go
Expand Up @@ -9,8 +9,12 @@ func NewCancelationSignal(interruptFunc func()) (func(), func()) {
}

exit := func() {
if canceled && interruptFunc != nil {
interruptFunc()
if canceled {
if interruptFunc != nil {
interruptFunc()
} else {
Exit(1)
}
}
}

Expand Down
54 changes: 54 additions & 0 deletions internal/cancelation_signal_test.go
@@ -0,0 +1,54 @@
package internal

import (
"os"
"testing"
)

func TestNewCancelationSignal(t *testing.T) {
// Mock DefaultExitFunc
exitCalled := false
exitCode := 0
DefaultExitFunc = func(code int) {
exitCalled = true
exitCode = code
}
defer func() { DefaultExitFunc = os.Exit }() // Reset after tests

// Scenario 1: Testing cancel function
cancel, exit := NewCancelationSignal(nil)
if exitCalled {
t.Error("Exit function should not be called immediately after NewCancelationSignal")
}

// Scenario 2: Testing exit function without cancel
exit()
if exitCalled {
t.Error("Exit function should not be called when cancel is not set")
}

// Scenario 3: Testing cancel then exit with no interruptFunc
cancel()
exit()
if !exitCalled || exitCode != 1 {
t.Errorf("Expected Exit(1) to be called, exitCalled: %v, exitCode: %d", exitCalled, exitCode)
}

// Reset for next scenario
exitCalled = false
exitCode = 0

// Scenario 4: Testing cancel then exit with interruptFunc
interruptCalled := false
cancel, exit = NewCancelationSignal(func() {
interruptCalled = true
})
cancel()
exit()
if interruptCalled == false {
t.Error("Expected interruptFunc to be called")
}
if exitCalled {
t.Error("Exit should not be called when interruptFunc is provided")
}
}
14 changes: 14 additions & 0 deletions internal/exit.go
@@ -0,0 +1,14 @@
package internal

import "os"

// ExitFuncType is the type of function used to exit the program.
type ExitFuncType func(int)

// DefaultExitFunc is the default function used to exit the program.
var DefaultExitFunc ExitFuncType = os.Exit

// Exit calls the current exit function.
func Exit(code int) {
DefaultExitFunc(code)
}
23 changes: 23 additions & 0 deletions internal/exit_test.go
@@ -0,0 +1,23 @@
package internal_test

import (
"os"
"testing"

"github.com/pterm/pterm/internal"
)

func TestExit(t *testing.T) {
var lastExitCode int
internal.DefaultExitFunc = func(code int) {
lastExitCode = code
}

defer func() { internal.DefaultExitFunc = os.Exit }()

internal.Exit(1)

if lastExitCode != 1 {
t.Errorf("Expected exit code 1, got %d", lastExitCode)
}
}

0 comments on commit 857a7ac

Please sign in to comment.