diff --git a/ioutils.go b/ioutils.go index 6ac3fb7..cef43b2 100644 --- a/ioutils.go +++ b/ioutils.go @@ -4,42 +4,54 @@ import ( "fmt" "io" "os" + "path" ) type TmpFile struct { - file *os.File + path string } -func NewTmpFile() (*TmpFile, error) { - file, err := os.CreateTemp(os.TempDir(), "k8s-secret-editor-") +func NewTmpFile(suffix string) (*TmpFile, error) { + name := fmt.Sprintf("k8s-secret-editor-%s", suffix) + p := path.Join(os.TempDir(), name) + f, err := os.Create(p) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating temp file: %w", err) + } + if err := f.Close(); err != nil { + _ = os.Remove(f.Name()) + return nil, fmt.Errorf("error closing temp file: %w", err) } - // Set restrictive permissions to protect sensitive secret data - if err := os.Chmod(file.Name(), 0600); err != nil { - _ = file.Close() - _ = os.Remove(file.Name()) + if err := os.Chmod(p, 0o600); err != nil { + _ = os.Remove(p) return nil, fmt.Errorf("error setting file permissions: %w", err) } - return &TmpFile{file: file}, nil + return &TmpFile{path: p}, nil } func (t *TmpFile) Write(data []byte) error { - if _, err := t.file.Write(data); err != nil { + f, err := os.OpenFile(t.path, os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return fmt.Errorf("error opening temp file for writing: %w", err) + } + if _, err := f.Write(data); err != nil { return fmt.Errorf("error writing to temp file: %w", err) } + if err := f.Sync(); err != nil { + return fmt.Errorf("error syncing temp file: %w", err) + } return nil } func (t *TmpFile) Read() ([]byte, error) { - // Move the file pointer back to the beginning before reading - if _, err := t.file.Seek(0, io.SeekStart); err != nil { - return nil, fmt.Errorf("error seeking to beginning of temp file: %w", err) + f, err := os.Open(t.path) + if err != nil { + return nil, fmt.Errorf("error opening temp file for reading: %w", err) } - data, err := io.ReadAll(t.file) + data, err := io.ReadAll(f) if err != nil { return nil, fmt.Errorf("error reading from temp file: %w", err) } @@ -47,16 +59,14 @@ func (t *TmpFile) Read() ([]byte, error) { } func (t *TmpFile) OpenEditor(editor interface{ Open(filePath string) error }) error { - return editor.Open(t.file.Name()) + if err := editor.Open(t.path); err != nil { + return fmt.Errorf("error opening editor: %w", err) + } + return nil } func (t *TmpFile) Close() error { - name := t.file.Name() - if err := t.file.Close(); err != nil { - return err - } - // Remove temp file after closing to avoid leaking secrets - if err := os.Remove(name); err != nil { + if err := os.Remove(t.path); err != nil { return fmt.Errorf("error removing temp file: %w", err) } return nil diff --git a/ioutils_test.go b/ioutils_test.go index fba951c..36ceec2 100644 --- a/ioutils_test.go +++ b/ioutils_test.go @@ -6,15 +6,10 @@ import ( ) func TestNewTmpFile(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - defer tmp.file.Close() - defer os.Remove(tmp.file.Name()) + tmp := newTestTmpFile(t) // Verify file exists - fi, err := os.Stat(tmp.file.Name()) + fi, err := os.Stat(tmp.path) if err != nil { t.Fatalf("failed to stat temp file: %v", err) } @@ -26,14 +21,9 @@ func TestNewTmpFile(t *testing.T) { } func TestTmpFilePermissions(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - defer tmp.file.Close() - defer os.Remove(tmp.file.Name()) + tmp := newTestTmpFile(t) - fi, err := os.Stat(tmp.file.Name()) + fi, err := os.Stat(tmp.path) if err != nil { t.Fatalf("failed to stat temp file: %v", err) } @@ -47,28 +37,20 @@ func TestTmpFilePermissions(t *testing.T) { } func TestTmpFileWrite(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - defer tmp.Close() + tmp := newTestTmpFile(t) testData := []byte("test data for secret") - err = tmp.Write(testData) + err := tmp.Write(testData) if err != nil { t.Fatalf("failed to write to temp file: %v", err) } } func TestTmpFileRead(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - defer tmp.Close() + tmp := newTestTmpFile(t) testData := []byte("test secret data") - err = tmp.Write(testData) + err := tmp.Write(testData) if err != nil { t.Fatalf("failed to write to temp file: %v", err) } @@ -85,16 +67,12 @@ func TestTmpFileRead(t *testing.T) { } func TestTmpFileReadAfterMultipleWrites(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - defer tmp.Close() + tmp := newTestTmpFile(t) // Write multiple times data1 := []byte("first") data2 := []byte("second") - err = tmp.Write(data1) + err := tmp.Write(data1) if err != nil { t.Fatalf("failed to write first data: %v", err) } @@ -104,28 +82,22 @@ func TestTmpFileReadAfterMultipleWrites(t *testing.T) { t.Fatalf("failed to write second data: %v", err) } - // Read should return both readData, err := tmp.Read() if err != nil { t.Fatalf("failed to read from temp file: %v", err) } - expected := "firstsecond" + expected := "second" if string(readData) != expected { t.Errorf("expected data %s, got %s", expected, string(readData)) } } func TestTmpFileClose(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - - filePath := tmp.file.Name() + tmp := newTestTmpFile(t) // Write some data - err = tmp.Write([]byte("test")) + err := tmp.Write([]byte("test")) if err != nil { t.Fatalf("failed to write to temp file: %v", err) } @@ -137,20 +109,17 @@ func TestTmpFileClose(t *testing.T) { } // Verify file is deleted - _, err = os.Stat(filePath) + _, err = os.Stat(tmp.path) if !os.IsNotExist(err) { - t.Errorf("temp file was not deleted: %s", filePath) + t.Errorf("temp file was not deleted: %s", tmp.path) } } func TestTmpFileCloseMultipleTimes(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } + tmp := newTestTmpFile(t) // First close - err = tmp.Close() + err := tmp.Close() if err != nil { t.Fatalf("first close failed: %v", err) } @@ -163,11 +132,7 @@ func TestTmpFileCloseMultipleTimes(t *testing.T) { } func TestTmpFileOpenEditor(t *testing.T) { - tmp, err := NewTmpFile() - if err != nil { - t.Fatalf("failed to create temp file: %v", err) - } - defer tmp.Close() + tmp := newTestTmpFile(t) editorPath := "/bin/sh" if _, err := os.Stat(editorPath); os.IsNotExist(err) { @@ -185,3 +150,18 @@ func TestTmpFileOpenEditor(t *testing.T) { t.Logf("OpenEditor returned error: %v (expected for sh with no input)", err) } } + +func newTestTmpFile(t *testing.T) *TmpFile { + t.Helper() + + tmp, err := NewTmpFile("test") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + t.Cleanup(func() { + if err := tmp.Close(); err != nil { + t.Logf("failed to clean up temp file: %v", err) + } + }) + return tmp +} diff --git a/main.go b/main.go index d57fcf3..d2ca869 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,7 @@ func main() { //nolint:gocyclo fatalf("Key '%s' not found in secret '%s' in namespace '%s'", selectedKey, selectedSecret, selectedNamespace) } - tmpFile, err := NewTmpFile() + tmpFile, err := NewTmpFile(selectedKey) if err != nil { fatalf("Error creating temp file: %v", err) }