From 8ba8d5aec4d2f031057316e380a5264b277fb5b8 Mon Sep 17 00:00:00 2001 From: James Bardin Date: Tue, 8 Nov 2022 10:21:08 -0500 Subject: [PATCH] fix typo in upload size check The scp upload size check had a typo preventing files from reporting their size, causing an extra temp file to be created. --- internal/communicator/ssh/communicator.go | 8 ++++++- .../communicator/ssh/communicator_test.go | 24 ++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/internal/communicator/ssh/communicator.go b/internal/communicator/ssh/communicator.go index c6af68839e0d..609dc1fbaf0e 100644 --- a/internal/communicator/ssh/communicator.go +++ b/internal/communicator/ssh/communicator.go @@ -418,7 +418,7 @@ func (c *Communicator) Upload(path string, input io.Reader) error { switch src := input.(type) { case *os.File: fi, err := src.Stat() - if err != nil { + if err == nil { size = fi.Size() } case *bytes.Buffer: @@ -641,7 +641,13 @@ func checkSCPStatus(r *bufio.Reader) error { return nil } +var testUploadSizeHook func(size int64) + func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { + if testUploadSizeHook != nil { + testUploadSizeHook(size) + } + if size == 0 { // Create a temporary file where we can copy the contents of the src // so that we can determine the length, since SCP is length-prefixed. diff --git a/internal/communicator/ssh/communicator_test.go b/internal/communicator/ssh/communicator_test.go index 8d7db9996708..b829e5b9afb3 100644 --- a/internal/communicator/ssh/communicator_test.go +++ b/internal/communicator/ssh/communicator_test.go @@ -577,10 +577,28 @@ func TestAccUploadFile(t *testing.T) { } tmpDir := t.TempDir() + source, err := os.CreateTemp(tmpDir, "tempfile.in") + if err != nil { + t.Fatal(err) + } + + content := "this is the file content" + if _, err := source.WriteString(content); err != nil { + t.Fatal(err) + } + source.Seek(0, io.SeekStart) - content := []byte("this is the file content") - source := bytes.NewReader(content) tmpFile := filepath.Join(tmpDir, "tempFile.out") + + testUploadSizeHook = func(size int64) { + if size != int64(len(content)) { + t.Errorf("expected %d bytes, got %d\n", len(content), size) + } + } + defer func() { + testUploadSizeHook = nil + }() + err = c.Upload(tmpFile, source) if err != nil { t.Fatalf("error uploading file: %s", err) @@ -591,7 +609,7 @@ func TestAccUploadFile(t *testing.T) { t.Fatal(err) } - if !bytes.Equal(data, content) { + if string(data) != content { t.Fatalf("bad: %s", data) } }