From 104b6259947c4a6eaf3dda04c6b333caefada15b Mon Sep 17 00:00:00 2001 From: anhthii Date: Sun, 31 Aug 2025 21:50:04 +0700 Subject: [PATCH 1/2] Add more flags for mpcium-cli to handle configuration management --- .gitignore | 1 - Makefile | 12 ++- cmd/mpcium-cli/main.go | 7 ++ cmd/mpcium-cli/register-peers.go | 2 +- cmd/mpcium/main.go | 103 +++++++++++++++++------ config.prod.yaml.template | 6 +- examples/generate/main.go | 2 +- examples/reshare/main.go | 2 +- examples/sign/main.go | 2 +- pkg/config/init.go | 30 +++++-- pkg/identity/identity.go | 38 ++++++--- pkg/infra/consul.go | 15 ++-- scripts/migration/update-keyinfo/main.go | 2 +- 13 files changed, 163 insertions(+), 59 deletions(-) diff --git a/.gitignore b/.gitignore index 14216ed6..769d7d3e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ db/ tmp/ bin/ -identity/ event_initiator.identity.json event_initiator.key event_initiator.key.age diff --git a/Makefile b/Makefile index 6da06a26..bf640ec1 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build clean mpcium mpc test test-verbose test-coverage e2e-test e2e-clean cleanup-test-env +.PHONY: all build clean mpcium mpc install test test-verbose test-coverage e2e-test e2e-clean cleanup-test-env BIN_DIR := bin @@ -16,6 +16,16 @@ mpcium: mpc: go install ./cmd/mpcium-cli +# Install binaries to /usr/local/bin (auto-detects architecture) +install: + @echo "Building and installing mpcium binaries for Linux..." + GOOS=linux go build -o /tmp/mpcium ./cmd/mpcium + GOOS=linux go build -o /tmp/mpcium-cli ./cmd/mpcium-cli + sudo install -m 755 /tmp/mpcium /usr/local/bin/ + sudo install -m 755 /tmp/mpcium-cli /usr/local/bin/ + rm -f /tmp/mpcium /tmp/mpcium-cli + @echo "Successfully installed mpcium and mpcium-cli to /usr/local/bin/" + # Run all tests test: go test ./... diff --git a/cmd/mpcium-cli/main.go b/cmd/mpcium-cli/main.go index 8c9d0ad1..319ef314 100644 --- a/cmd/mpcium-cli/main.go +++ b/cmd/mpcium-cli/main.go @@ -18,6 +18,13 @@ func main() { cmd := &cli.Command{ Name: "mpcium", Usage: "Fystack MPC node management tools", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Aliases: []string{"c"}, + Usage: "Path to configuration file", + }, + }, Commands: []*cli.Command{ { Name: "generate-peers", diff --git a/cmd/mpcium-cli/register-peers.go b/cmd/mpcium-cli/register-peers.go index 1b482d77..85d66586 100644 --- a/cmd/mpcium-cli/register-peers.go +++ b/cmd/mpcium-cli/register-peers.go @@ -48,7 +48,7 @@ func registerPeers(ctx context.Context, c *cli.Command) error { } // Initialize config and logger - config.InitViperConfig() + config.InitViperConfig(c.String("config")) logger.Init(environment, true) // Connect to Consul diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 6d12286d..2ba8ca4b 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "sync" "syscall" "time" @@ -21,6 +22,7 @@ import ( "github.com/fystack/mpcium/pkg/logger" "github.com/fystack/mpcium/pkg/messaging" "github.com/fystack/mpcium/pkg/mpc" + "github.com/fystack/mpcium/pkg/security" "github.com/hashicorp/consul/api" "github.com/nats-io/nats.go" "github.com/spf13/viper" @@ -49,6 +51,11 @@ func main() { Usage: "Node name", Required: true, }, + &cli.StringFlag{ + Name: "config", + Aliases: []string{"c"}, + Usage: "Path to configuration file", + }, &cli.BoolFlag{ Name: "decrypt-private-key", Aliases: []string{"d"}, @@ -60,6 +67,16 @@ func main() { Aliases: []string{"p"}, Usage: "Prompt for sensitive parameters", }, + &cli.StringFlag{ + Name: "password-file", + Aliases: []string{"f"}, + Usage: "Path to file containing BadgerDB password", + }, + &cli.StringFlag{ + Name: "age-password-file", + Aliases: []string{"k"}, + Usage: "Path to file containing password for decrypting .age encrypted node private key", + }, &cli.BoolFlag{ Name: "debug", Usage: "Enable debug logging", @@ -87,15 +104,24 @@ func main() { func runNode(ctx context.Context, c *cli.Command) error { nodeName := c.String("name") + configPath := c.String("config") decryptPrivateKey := c.Bool("decrypt-private-key") usePrompts := c.Bool("prompt-credentials") + passwordFile := c.String("password-file") + agePasswordFile := c.String("age-password-file") debug := c.Bool("debug") viper.SetDefault("backup_enabled", true) - config.InitViperConfig() + config.InitViperConfig(configPath) environment := viper.GetString("environment") logger.Init(environment, debug) + // Handle password file if provided + if passwordFile != "" { + if err := loadPasswordFromFile(passwordFile); err != nil { + return fmt.Errorf("failed to load password from file: %w", err) + } + } // Handle configuration based on prompt flag if usePrompts { promptForSensitiveCredentials() @@ -120,7 +146,7 @@ func runNode(ctx context.Context, c *cli.Command) error { defer stopBackup() } - identityStore, err := identity.NewFileStore("identity", nodeName, decryptPrivateKey) + identityStore, err := identity.NewFileStore("identity", nodeName, decryptPrivateKey, agePasswordFile) if err != nil { logger.Fatal("Failed to create identity store", err) } @@ -272,6 +298,30 @@ func runNode(ctx context.Context, c *cli.Command) error { return nil } +// loadPasswordFromFile reads the BadgerDB password from a file +func loadPasswordFromFile(filePath string) error { + passwordBytes, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("failed to read password file %s: %w", filePath, err) + } + + // Trim whitespace/newlines without altering content + password := strings.TrimSpace(string(passwordBytes)) + + if password == "" { + security.ZeroBytes(passwordBytes) + return fmt.Errorf("password file %s is empty", filePath) + } + + viper.Set("badger_password", password) + logger.Info(fmt.Sprintf("Loaded BadgerDB password from file: %s", filePath), "password", password) + + security.ZeroBytes(passwordBytes) + security.ZeroString(&password) + + return nil +} + // Prompt user for sensitive configuration values func promptForSensitiveCredentials() { fmt.Println("WARNING: Please back up your Badger DB password in a secure location.") @@ -282,6 +332,12 @@ func promptForSensitiveCredentials() { var confirmPass []byte var err error + // Ensure sensitive buffers are zeroed on exit + defer func() { + security.ZeroBytes(badgerPass) + security.ZeroBytes(confirmPass) + }() + for { fmt.Print("Enter Badger DB password: ") badgerPass, err = term.ReadPassword(int(syscall.Stdin)) @@ -311,28 +367,11 @@ func promptForSensitiveCredentials() { } // Show masked password for confirmation - maskedPassword := maskString(string(badgerPass)) + passwordStr := string(badgerPass) + maskedPassword := maskString(passwordStr) fmt.Printf("Password set: %s\n", maskedPassword) - - viper.Set("badger_password", string(badgerPass)) - - // Prompt for initiator public key (using regular input since it's not as sensitive) - var initiatorKey string - fmt.Print("Enter event initiator public key (hex): ") - if _, err := fmt.Scanln(&initiatorKey); err != nil { - logger.Fatal("Failed to read initiator key", err) - } - - if initiatorKey == "" { - logger.Fatal("Initiator public key cannot be empty", nil) - } - - // Show masked key for confirmation - maskedKey := maskString(initiatorKey) - fmt.Printf("Event initiator public key set: %s\n", maskedKey) - - viper.Set("event_initiator_pubkey", initiatorKey) - fmt.Println("\n✓ Configuration complete!") + viper.Set("badger_password", passwordStr) + security.ZeroString(&passwordStr) } // maskString shows the first and last character of a string, replacing the middle with asterisks @@ -479,9 +518,21 @@ func GetNATSConnection(environment string) (*nats.Conn, error) { } if environment == constant.EnvProduction { - clientCert := filepath.Join(".", "certs", "client-cert.pem") - clientKey := filepath.Join(".", "certs", "client-key.pem") - caCert := filepath.Join(".", "certs", "rootCA.pem") + // Load TLS config from configuration + clientCert := viper.GetString("nats.tls.client_cert") + clientKey := viper.GetString("nats.tls.client_key") + caCert := viper.GetString("nats.tls.ca_cert") + + // Fallback to default paths if not configured + if clientCert == "" { + clientCert = filepath.Join(".", "certs", "client-cert.pem") + } + if clientKey == "" { + clientKey = filepath.Join(".", "certs", "client-key.pem") + } + if caCert == "" { + caCert = filepath.Join(".", "certs", "rootCA.pem") + } opts = append(opts, nats.ClientCert(clientCert, clientKey), diff --git a/config.prod.yaml.template b/config.prod.yaml.template index b1082593..7b219684 100644 --- a/config.prod.yaml.template +++ b/config.prod.yaml.template @@ -2,9 +2,13 @@ nats: url: tls://127.0.0.1:4222 # Please use TLS for production username: "" password: "" + tls: + client_cert: "/etc/mpcium/certs/client-cert.pem" + client_key: "/etc/mpcium/certs/client-key.pem" + ca_cert: "/etc/mpcium/certs/rootCA.pem" consul: - address: https://consul.example.com # Use HTTPS for production + address: https://consul.example.com username: username token: "" password: "" diff --git a/examples/generate/main.go b/examples/generate/main.go index 6f54c31e..3f0135f1 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -28,7 +28,7 @@ func main() { flag.Parse() - config.InitViperConfig() + config.InitViperConfig("") logger.Init(environment, false) algorithm := viper.GetString("event_initiator_algorithm") diff --git a/examples/reshare/main.go b/examples/reshare/main.go index 68ea7863..47c4d858 100644 --- a/examples/reshare/main.go +++ b/examples/reshare/main.go @@ -19,7 +19,7 @@ import ( func main() { const environment = "dev" - config.InitViperConfig() + config.InitViperConfig("") logger.Init(environment, true) algorithm := viper.GetString("event_initiator_algorithm") diff --git a/examples/sign/main.go b/examples/sign/main.go index 4cc4aa1b..3424610f 100644 --- a/examples/sign/main.go +++ b/examples/sign/main.go @@ -19,7 +19,7 @@ import ( func main() { const environment = "dev" - config.InitViperConfig() + config.InitViperConfig("") logger.Init(environment, true) algorithm := viper.GetString("event_initiator_algorithm") diff --git a/pkg/config/init.go b/pkg/config/init.go index 023d2e5d..a9edbab2 100644 --- a/pkg/config/init.go +++ b/pkg/config/init.go @@ -40,15 +40,31 @@ type ConsulConfig struct { } type NATsConfig struct { - URL string `mapstructure:"url"` - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` + URL string `mapstructure:"url"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + TLS *TLSConfig `mapstructure:"tls"` } -func InitViperConfig() { - viper.SetConfigName("config") // name of config file (without extension) - viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name - viper.AddConfigPath(".") // optionally look for config in the working directory +type TLSConfig struct { + ClientCert string `mapstructure:"client_cert"` + ClientKey string `mapstructure:"client_key"` + CACert string `mapstructure:"ca_cert"` +} + +func InitViperConfig(configPath string) { + if configPath != "" { + // Use specific config file path + viper.SetConfigFile(configPath) + } else { + // Use default behavior - search for config.yaml in common locations + viper.SetConfigName("config") // name of config file (without extension) + viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name + viper.AddConfigPath(".") // optionally look for config in the working directory + viper.AddConfigPath("/etc/mpcium/") // look for config in /etc/mpcium/ + viper.AddConfigPath("$HOME/.mpcium/") // look for config in home directory + } + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() err := viper.ReadInConfig() // Find and read the config file diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 3dbccf6a..d312b90c 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -21,6 +21,7 @@ import ( "github.com/fystack/mpcium/pkg/common/pathutil" "github.com/fystack/mpcium/pkg/encryption" "github.com/fystack/mpcium/pkg/logger" + "github.com/fystack/mpcium/pkg/security" "github.com/fystack/mpcium/pkg/types" "github.com/spf13/viper" ) @@ -75,12 +76,12 @@ type fileStore struct { } // NewFileStore creates a new identity store -func NewFileStore(identityDir, nodeName string, decrypt bool) (*fileStore, error) { +func NewFileStore(identityDir, nodeName string, decrypt bool, agePasswordFile string) (*fileStore, error) { if err := os.MkdirAll(identityDir, 0750); err != nil { return nil, fmt.Errorf("failed to create identity directory: %w", err) } - privateKeyHex, err := loadPrivateKey(identityDir, nodeName, decrypt) + privateKeyHex, err := loadPrivateKey(identityDir, nodeName, decrypt, agePasswordFile) if err != nil { return nil, err } @@ -249,7 +250,7 @@ func loadP256InitiatorKey() (*ecdsa.PublicKey, error) { } // loadPrivateKey loads the private key from file, decrypting if necessary -func loadPrivateKey(identityDir, nodeName string, decrypt bool) (string, error) { +func loadPrivateKey(identityDir, nodeName string, decrypt bool, agePasswordFile string) (string, error) { // Check for encrypted or unencrypted private key encryptedKeyFileName := fmt.Sprintf("%s_private.key.age", nodeName) unencryptedKeyFileName := fmt.Sprintf("%s_private.key", nodeName) @@ -279,15 +280,29 @@ func loadPrivateKey(identityDir, nodeName string, decrypt bool) (string, error) } defer encryptedFile.Close() - // Prompt for passphrase using term.ReadPassword - fmt.Print("Enter passphrase to decrypt private key: ") - bytePassword, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Println() // newline after prompt - if err != nil { - return "", fmt.Errorf("failed to read passphrase: %w", err) + var passphrase string + if agePasswordFile != "" { + // Load passphrase from file + data, err := os.ReadFile(agePasswordFile) + if err != nil { + return "", fmt.Errorf("failed to read age key file %s: %w", agePasswordFile, err) + } + passphrase = strings.TrimSpace(string(data)) // trim newline if present + security.ZeroBytes(data) + logger.Infof("Using passphrase from from file: %s to decrypt node private key", agePasswordFile) + } else { + // Prompt for passphrase from terminal + fmt.Print("Enter passphrase to decrypt private key: ") + bytePassword, err := term.ReadPassword(int(syscall.Stdin)) + fmt.Println() // newline after prompt + if err != nil { + return "", fmt.Errorf("failed to read passphrase: %w", err) + } + passphrase = string(bytePassword) + security.ZeroBytes(bytePassword) } - passphrase := string(bytePassword) - // Create an identity with the provided passphrase + + // Create the identity once, regardless of source identity, err := age.NewScryptIdentity(passphrase) if err != nil { return "", fmt.Errorf("failed to create identity for decryption: %w", err) @@ -305,6 +320,7 @@ func loadPrivateKey(identityDir, nodeName string, decrypt bool) (string, error) return "", fmt.Errorf("failed to read decrypted key: %w", err) } + security.ZeroString(&passphrase) return string(decryptedData), nil } else { // Use the unencrypted private key file diff --git a/pkg/infra/consul.go b/pkg/infra/consul.go index 3eca76f4..a438c933 100644 --- a/pkg/infra/consul.go +++ b/pkg/infra/consul.go @@ -18,14 +18,15 @@ type ConsulKV interface { func GetConsulClient(environment string) *api.Client { config := api.DefaultConfig() - if environment != constant.EnvProduction { - config.Scheme = "http" - } else { - config.Scheme = "https" + if environment == constant.EnvProduction { config.Token = viper.GetString("consul.token") - config.HttpAuth = &api.HttpBasicAuth{ - Username: viper.GetString("consul.username"), - Password: viper.GetString("consul.password"), + username := viper.GetString("consul.username") + password := viper.GetString("consul.password") + if username != "" || password != "" { + config.HttpAuth = &api.HttpBasicAuth{ + Username: username, + Password: password, + } } } diff --git a/scripts/migration/update-keyinfo/main.go b/scripts/migration/update-keyinfo/main.go index f96283c6..ebc06c40 100644 --- a/scripts/migration/update-keyinfo/main.go +++ b/scripts/migration/update-keyinfo/main.go @@ -12,7 +12,7 @@ import ( // script to add key type prefix ecdsa for existing keys func main() { - config.InitViperConfig() + config.InitViperConfig("") logger.Init("production", false) appConfig := config.LoadConfig() logger.Info("App config", "config", appConfig) From 31e45bf15bfdf61d34edb0f270370314d22e3462 Mon Sep 17 00:00:00 2001 From: anhthii Date: Sun, 31 Aug 2025 22:37:14 +0700 Subject: [PATCH 2/2] Add security utilities to zero out secrets --- cmd/mpcium/main.go | 2 - examples/generate/kms/main.go | 2 +- pkg/security/memory.go | 86 ++++++++++++ pkg/security/memory_test.go | 249 ++++++++++++++++++++++++++++++++++ 4 files changed, 336 insertions(+), 3 deletions(-) create mode 100644 pkg/security/memory.go create mode 100644 pkg/security/memory_test.go diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 2ba8ca4b..26d91015 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -314,8 +314,6 @@ func loadPasswordFromFile(filePath string) error { } viper.Set("badger_password", password) - logger.Info(fmt.Sprintf("Loaded BadgerDB password from file: %s", filePath), "password", password) - security.ZeroBytes(passwordBytes) security.ZeroString(&password) diff --git a/examples/generate/kms/main.go b/examples/generate/kms/main.go index 2ccb177c..9bf40250 100644 --- a/examples/generate/kms/main.go +++ b/examples/generate/kms/main.go @@ -30,7 +30,7 @@ func main() { flag.Parse() - config.InitViperConfig() + config.InitViperConfig("") logger.Init(environment, false) // KMS signer only supports P256 diff --git a/pkg/security/memory.go b/pkg/security/memory.go new file mode 100644 index 00000000..489c6794 --- /dev/null +++ b/pkg/security/memory.go @@ -0,0 +1,86 @@ +package security + +import ( + "runtime" +) + +// ZeroBytes securely zeros out a byte slice to prevent sensitive data from +// remaining in memory. This uses explicit memory zeroing and garbage collection +// to help ensure the data is actually cleared. +func ZeroBytes(data []byte) { + if len(data) == 0 { + return + } + + // Zero out the slice + for i := range data { + data[i] = 0 + } + + // Force garbage collection to help ensure the zeroed memory is reclaimed + runtime.GC() +} + +// ZeroString securely clears a string reference and encourages garbage collection. +// Note: Go strings are immutable, so we can only clear the reference and rely on GC. +// This provides best-effort security by clearing the reference and forcing GC. +func ZeroString(s *string) { + if s == nil { + return + } + + // Clear the string reference - this is the safe approach + // The actual string data will be garbage collected + *s = "" + + // Force garbage collection to help clear the original string data from memory + // This is best-effort as GC timing is not guaranteed + runtime.GC() + runtime.GC() // Run twice to increase chances of collection +} + +// SecureBytes is a wrapper for sensitive byte data that automatically +// zeros itself when no longer needed +type SecureBytes struct { + data []byte +} + +// NewSecureBytes creates a new SecureBytes instance +func NewSecureBytes(data []byte) *SecureBytes { + // Make a copy to ensure we own the memory + copied := make([]byte, len(data)) + copy(copied, data) + + sb := &SecureBytes{data: copied} + + // Set finalizer to zero memory when GC'd + runtime.SetFinalizer(sb, (*SecureBytes).zero) + + return sb +} + +// Bytes returns the underlying byte slice (use with caution) +func (sb *SecureBytes) Bytes() []byte { + return sb.data +} + +// Copy returns a copy of the data +func (sb *SecureBytes) Copy() []byte { + result := make([]byte, len(sb.data)) + copy(result, sb.data) + return result +} + +// Clear explicitly zeros the data and removes the finalizer +func (sb *SecureBytes) Clear() { + sb.zero() + runtime.SetFinalizer(sb, nil) +} + +// zero securely clears the data +func (sb *SecureBytes) zero() { + if sb.data != nil { + ZeroBytes(sb.data) + sb.data = nil + } +} \ No newline at end of file diff --git a/pkg/security/memory_test.go b/pkg/security/memory_test.go new file mode 100644 index 00000000..f30f189f --- /dev/null +++ b/pkg/security/memory_test.go @@ -0,0 +1,249 @@ +package security + +import ( + "bytes" + "runtime" + "testing" +) + +func TestZeroBytes(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + { + name: "non-empty slice", + data: []byte("sensitive data"), + }, + { + name: "empty slice", + data: []byte{}, + }, + { + name: "nil slice", + data: nil, + }, + { + name: "single byte", + data: []byte{0x42}, + }, + { + name: "binary data", + data: []byte{0x01, 0x02, 0x03, 0xff, 0xfe, 0xfd}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original := make([]byte, len(tt.data)) + copy(original, tt.data) + + ZeroBytes(tt.data) + + // Verify all bytes are zeroed + for i, b := range tt.data { + if b != 0 { + t.Errorf("byte at index %d not zeroed: got %d, want 0", i, b) + } + } + + // Verify we didn't panic on edge cases + if len(tt.data) == 0 { + // Should handle empty/nil slices gracefully + return + } + + // Verify the slice was actually modified + if bytes.Equal(tt.data, original) && len(original) > 0 { + t.Error("slice was not modified") + } + }) + } +} + +func TestZeroString(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "non-empty string", + input: "sensitive password", + expected: "", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "single character", + input: "x", + expected: "", + }, + { + name: "unicode string", + input: "🔐password🔑", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := tt.input + + ZeroString(&s) + + // Verify string is now empty + if s != tt.expected { + t.Errorf("string not cleared: got %q, want %q", s, tt.expected) + } + + // Note: We can't reliably test if the underlying memory was zeroed + // because Go strings are immutable and memory clearing depends on GC timing + }) + } +} + +func TestZeroStringNilPointer(t *testing.T) { + // Test that ZeroString handles nil pointer gracefully + defer func() { + if r := recover(); r != nil { + t.Errorf("ZeroString panicked with nil pointer: %v", r) + } + }() + + ZeroString(nil) +} + +func TestSecureBytes(t *testing.T) { + t.Run("basic functionality", func(t *testing.T) { + original := []byte("secret data") + sb := NewSecureBytes(original) + + // Verify data is accessible + data := sb.Bytes() + if !bytes.Equal(data, original) { + t.Errorf("SecureBytes data mismatch: got %v, want %v", data, original) + } + + // Verify copy works + copied := sb.Copy() + if !bytes.Equal(copied, original) { + t.Errorf("SecureBytes copy mismatch: got %v, want %v", copied, original) + } + + // Verify modifying copy doesn't affect original + copied[0] = 'X' + if bytes.Equal(sb.Bytes(), copied) { + t.Error("SecureBytes copy shares memory with original") + } + }) + + t.Run("manual clear", func(t *testing.T) { + sb := NewSecureBytes([]byte("secret")) + sb.Clear() + + // After Clear(), data should be nil + if sb.data != nil { + t.Error("SecureBytes data not nil after Clear()") + } + + // Calling Clear() again should not panic + sb.Clear() + }) + + t.Run("finalizer behavior", func(t *testing.T) { + // This test verifies that the finalizer doesn't panic + // We can't easily test that it actually zeros memory due to GC timing + func() { + sb := NewSecureBytes([]byte("secret")) + _ = sb // Use the variable to prevent optimization + }() + + // Force garbage collection to potentially trigger finalizer + runtime.GC() + runtime.GC() + + // If we reach here without panic, the finalizer worked correctly + }) + + t.Run("empty data", func(t *testing.T) { + sb := NewSecureBytes([]byte{}) + + if len(sb.Bytes()) != 0 { + t.Error("SecureBytes should handle empty data") + } + + sb.Clear() + // Should not panic + }) + + t.Run("nil data", func(t *testing.T) { + sb := NewSecureBytes(nil) + + if sb.Bytes() == nil { + t.Error("SecureBytes should create empty slice for nil input") + } + + if len(sb.Bytes()) != 0 { + t.Error("SecureBytes should create empty slice for nil input") + } + }) +} + +func TestSecureBytesIsolation(t *testing.T) { + // Verify that SecureBytes creates its own copy of data + original := []byte("secret data") + sb := NewSecureBytes(original) + + // Modify the original + original[0] = 'X' + + // SecureBytes should be unaffected + if sb.Bytes()[0] == 'X' { + t.Error("SecureBytes shares memory with input data") + } +} + +// Benchmark tests to ensure performance is reasonable +func BenchmarkZeroBytes(b *testing.B) { + data := make([]byte, 1024) + for i := range data { + data[i] = byte(i % 256) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Reset data for each iteration + for j := range data { + data[j] = byte(j % 256) + } + ZeroBytes(data) + } +} + +func BenchmarkZeroString(b *testing.B) { + original := "this is a test string that represents a password or other sensitive data" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s := original + ZeroString(&s) + } +} + +func BenchmarkSecureBytes(b *testing.B) { + data := make([]byte, 256) + for i := range data { + data[i] = byte(i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + sb := NewSecureBytes(data) + _ = sb.Copy() + sb.Clear() + } +} \ No newline at end of file