diff --git a/loader/env_test.go b/loader/env_test.go new file mode 100644 index 0000000..bc18666 --- /dev/null +++ b/loader/env_test.go @@ -0,0 +1,149 @@ +package loader + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReplaceEnvInJSON(t *testing.T) { + tests := []struct { + name string + input string + envVars map[string]string + expected string + }{ + { + name: "simple replacement", + input: `{"name": "$TEST_VAR"}`, + envVars: map[string]string{"TEST_VAR": "hello"}, + expected: `{"name": "hello"}`, + }, + { + name: "multiple replacements", + input: `{"name": "$NAME", "host": "$HOST"}`, + envVars: map[string]string{"NAME": "test", "HOST": "localhost"}, + expected: `{"name": "test", "host": "localhost"}`, + }, + { + name: "quotes in env value", + input: `{"message": "$QUOTED_VAR"}`, + envVars: map[string]string{"QUOTED_VAR": `say "hello"`}, + expected: `{"message": "say \"hello\""}`, + }, + { + name: "backslashes in env value", + input: `{"path": "$PATH_VAR"}`, + envVars: map[string]string{"PATH_VAR": `C:\Program Files\App`}, + expected: `{"path": "C:\\Program Files\\App"}`, + }, + { + name: "quotes and backslashes together", + input: `{"command": "$CMD_VAR"}`, + envVars: map[string]string{"CMD_VAR": `echo "C:\test"`}, + expected: `{"command": "echo \"C:\\test\""}`, + }, + { + name: "no matching env var", + input: `{"name": "$NONEXISTENT"}`, + envVars: map[string]string{}, + expected: `{"name": ""}`, + }, + { + name: "mixed env and regular strings", + input: `{"env": "$TEST", "regular": "value"}`, + envVars: map[string]string{"TEST": "replaced"}, + expected: `{"env": "replaced", "regular": "value"}`, + }, + { + name: "env var in nested structure", + input: `{"config": {"host": "$HOST", "port": 8080}}`, + envVars: map[string]string{"HOST": "api.example.com"}, + expected: `{"config": {"host": "api.example.com", "port": 8080}}`, + }, + { + name: "env var in array", + input: `{"servers": ["$SERVER1", "$SERVER2"]}`, + envVars: map[string]string{"SERVER1": "host1", "SERVER2": "host2"}, + expected: `{"servers": ["host1", "host2"]}`, + }, + { + name: "complex escaping case", + input: `{"script": "$COMPLEX_SCRIPT"}`, + envVars: map[string]string{"COMPLEX_SCRIPT": `echo "Hello \"World\"" && echo 'C:\Program Files\test'`}, + expected: `{"script": "echo \"Hello \\\"World\\\"\" \u0026\u0026 echo 'C:\\Program Files\\test'"}`, + }, + { + name: "newlines and tabs in env value", + input: `{"multiline": "$MULTILINE_VAR"}`, + envVars: map[string]string{"MULTILINE_VAR": "line1\nline2\tindented"}, + expected: `{"multiline": "line1\nline2\tindented"}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Set up environment variables + for key, value := range test.envVars { + os.Setenv(key, value) + defer os.Unsetenv(key) + } + + result := replaceEnvInJSON([]byte(test.input)) + assert.Equal(t, test.expected, string(result)) + }) + } +} + +func TestEscapeForJSON(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no escaping needed", + input: "simple string", + expected: "simple string", + }, + { + name: "escape quotes", + input: `say "hello"`, + expected: `say \"hello\"`, + }, + { + name: "escape backslashes", + input: `C:\Program Files`, + expected: `C:\\Program Files`, + }, + { + name: "escape both quotes and backslashes", + input: `echo "C:\test"`, + expected: `echo \"C:\\test\"`, + }, + { + name: "multiple backslashes", + input: `path\\to\\file`, + expected: `path\\\\to\\\\file`, + }, + { + name: "newlines and tabs", + input: "line1\nline2\tindented", + expected: "line1\\nline2\\tindented", + }, + { + name: "empty string", + input: "", + expected: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := escapeForJSON(test.input) + assert.Equal(t, test.expected, result) + }) + } +} + diff --git a/loader/loader.go b/loader/loader.go index 5f32b04..da4ca0d 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -5,7 +5,7 @@ import ( "fmt" "os" "reflect" - "strings" + "regexp" "sync" "github.com/segmentio/encoding/json" @@ -17,19 +17,19 @@ import ( // The config must be a pointer to a struct. // It parses the byteslice as hujson, which allows for C-style comments and // trailing commas on arrays and maps. -// It then unmarshals the JSON into the config struct. -// Finally, it replaces any environment variables in the struct with their -// values referenced by the corresponding environment variables. +// It replaces any environment variables in the JSON string before unmarshalling, +// properly escaping values to maintain JSON structure. +// Finally, it unmarshals the JSON into the config struct. func LoadConfig(bts []byte, cfg any) error { bts, err := hujson.Standardize(bts) if err != nil { return err } + bts = replaceEnvInJSON(bts) err = json.Unmarshal(bts, cfg) if err != nil { return err } - replaceEnv(reflect.ValueOf(cfg)) return nil } @@ -120,35 +120,33 @@ func (l Loader[T]) Configure() (T, error) { return l.Builder.Configure() } -func replaceEnv(v reflect.Value) { - if !v.IsValid() { - return - } +var envVarRegex = regexp.MustCompile(`"\$([A-Za-z_][A-Za-z0-9_]*)"`) - switch v.Kind() { - case reflect.String: - val := v.String() - if v.CanSet() && strings.HasPrefix(val, "$") { - envVar, _ := strings.CutPrefix(val, "$") - v.SetString(os.Getenv(envVar)) - } - case reflect.Ptr: - replaceEnv(v.Elem()) - case reflect.Struct: - for i := 0; i < v.NumField(); i++ { - replaceEnv(v.Field(i)) - } - case reflect.Slice: - for i := 0; i < v.Len(); i++ { - replaceEnv(v.Index(i)) - } - case reflect.Interface: - if v.IsNil() { - return - } - copied := reflect.New(v.Elem().Type()).Elem() - copied.Set(v.Elem()) - replaceEnv(copied) - v.Set(copied) +// replaceEnvInJSON replaces environment variables in JSON string values +// while properly escaping the replacement values to maintain JSON structure +func replaceEnvInJSON(jsonBytes []byte) []byte { + return envVarRegex.ReplaceAllFunc(jsonBytes, func(match []byte) []byte { + // Extract the environment variable name (without the $ prefix and quotes) + envVar := string(match[2 : len(match)-1]) + envValue := os.Getenv(envVar) + + // Escape backslashes and quotes in the environment value + escapedValue := escapeForJSON(envValue) + + // Return the escaped value wrapped in quotes + return []byte(`"` + escapedValue + `"`) + }) +} + +// escapeForJSON escapes a string for JSON using Go's json.Marshal +// This ensures RFC 7159/8259 compliance for all special characters +func escapeForJSON(s string) string { + // Use json.Marshal to properly escape the string per RFC 7159/8259 + escaped, err := json.Marshal(s) + if err != nil { + // This should never happen for a string, but fallback just in case + return s } + // Remove the outer quotes that json.Marshal adds + return string(escaped[1 : len(escaped)-1]) } diff --git a/loader/loader_test.go b/loader/loader_test.go index 2e9a516..24f98f5 100644 --- a/loader/loader_test.go +++ b/loader/loader_test.go @@ -38,6 +38,7 @@ func TestLoadConfigUnreg(t *testing.T) { assert.Error(t, err) } + func TestLoadConfig(t *testing.T) { loader.Register("aTypeOfSource", func() loader.Builder[Source] { return &srcConfigA{} })