Skip to content

Commit c31e6ba

Browse files
committed
Use envconfig AdditionalProfileFields
1 parent 2ebdc57 commit c31e6ba

5 files changed

Lines changed: 160 additions & 292 deletions

File tree

cliext/config.oauth.go

Lines changed: 110 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
11
package cliext
22

33
import (
4-
"bytes"
54
"context"
5+
"encoding/json"
66
"errors"
77
"fmt"
88
"os"
99
"path/filepath"
1010
"time"
1111

12-
"github.com/BurntSushi/toml"
1312
"go.temporal.io/sdk/contrib/envconfig"
1413
"golang.org/x/oauth2"
1514
)
1615

16+
// oauthConfigJSON is an intermediate struct for JSON serialization of OAuth config.
17+
type oauthConfigJSON struct {
18+
ClientID string `json:"client_id,omitempty"`
19+
ClientSecret string `json:"client_secret,omitempty"`
20+
TokenURL string `json:"token_url,omitempty"`
21+
AuthURL string `json:"auth_url,omitempty"`
22+
RedirectURL string `json:"redirect_url,omitempty"`
23+
AccessToken string `json:"access_token,omitempty"`
24+
RefreshToken string `json:"refresh_token,omitempty"`
25+
TokenType string `json:"token_type,omitempty"`
26+
ExpiresAt string `json:"expires_at,omitempty"`
27+
Scopes []string `json:"scopes,omitempty"`
28+
}
29+
1730
// OAuthConfig combines OAuth client configuration with token information.
1831
type OAuthConfig struct {
1932
// ClientConfig is the OAuth 2.0 client configuration.
@@ -84,52 +97,72 @@ func loadOAuthConfigFromFile(path string) (map[string]*OAuthConfig, error) {
8497
return nil, fmt.Errorf("failed to read config file: %w", err)
8598
}
8699

87-
var raw rawConfigWithOAuth
88-
if _, err := toml.Decode(string(data), &raw); err != nil {
100+
// Use envconfig's FromTOML with AdditionalProfileFields to capture OAuth fields
101+
var conf envconfig.ClientConfig
102+
additional := make(map[string]map[string]any)
103+
if err := conf.FromTOML(data, envconfig.ClientConfigFromTOMLOptions{
104+
AdditionalProfileFields: additional,
105+
}); err != nil {
89106
return nil, fmt.Errorf("failed to parse config file: %w", err)
90107
}
91108

92109
oauthByProfile := make(map[string]*OAuthConfig)
93-
for profileName, profile := range raw.Profile {
94-
if profile == nil || profile.OAuth == nil {
95-
oauthByProfile[profileName] = nil
110+
for profileName, fields := range additional {
111+
oauthRaw, ok := fields["oauth"].(map[string]any)
112+
if !ok {
96113
continue
97114
}
98-
cfg := profile.OAuth
99-
100-
// Parse expiry time if present
101-
var expiry time.Time
102-
if cfg.ExpiresAt != "" {
103-
t, err := time.Parse(time.RFC3339, cfg.ExpiresAt)
104-
if err != nil {
105-
return nil, fmt.Errorf("failed to parse expires_at for profile %q: %w", profileName, err)
106-
}
107-
expiry = t
108-
}
109-
110-
oauth := &OAuthConfig{
111-
ClientConfig: &oauth2.Config{
112-
ClientID: cfg.ClientID,
113-
ClientSecret: cfg.ClientSecret,
114-
RedirectURL: cfg.RedirectURL,
115-
Scopes: cfg.Scopes,
116-
Endpoint: oauth2.Endpoint{
117-
AuthURL: cfg.AuthURL,
118-
TokenURL: cfg.TokenURL,
119-
},
120-
},
121-
Token: &oauth2.Token{
122-
AccessToken: cfg.AccessToken,
123-
RefreshToken: cfg.RefreshToken,
124-
TokenType: cfg.TokenType,
125-
Expiry: expiry,
126-
},
115+
oauth, err := oauthConfigFromMap(oauthRaw)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to parse oauth for profile %q: %w", profileName, err)
127118
}
128119
oauthByProfile[profileName] = oauth
129120
}
130121
return oauthByProfile, nil
131122
}
132123

124+
// oauthConfigFromMap converts a map[string]any to OAuthConfig using JSON as intermediary.
125+
func oauthConfigFromMap(m map[string]any) (*OAuthConfig, error) {
126+
data, err := json.Marshal(m)
127+
if err != nil {
128+
return nil, fmt.Errorf("failed to marshal oauth config: %w", err)
129+
}
130+
131+
var cfg oauthConfigJSON
132+
if err := json.Unmarshal(data, &cfg); err != nil {
133+
return nil, fmt.Errorf("failed to unmarshal oauth config: %w", err)
134+
}
135+
136+
// Parse expiry time if present
137+
var expiry time.Time
138+
if cfg.ExpiresAt != "" {
139+
t, err := time.Parse(time.RFC3339, cfg.ExpiresAt)
140+
if err != nil {
141+
return nil, fmt.Errorf("failed to parse expires_at: %w", err)
142+
}
143+
expiry = t
144+
}
145+
146+
return &OAuthConfig{
147+
ClientConfig: &oauth2.Config{
148+
ClientID: cfg.ClientID,
149+
ClientSecret: cfg.ClientSecret,
150+
RedirectURL: cfg.RedirectURL,
151+
Scopes: cfg.Scopes,
152+
Endpoint: oauth2.Endpoint{
153+
AuthURL: cfg.AuthURL,
154+
TokenURL: cfg.TokenURL,
155+
},
156+
},
157+
Token: &oauth2.Token{
158+
AccessToken: cfg.AccessToken,
159+
RefreshToken: cfg.RefreshToken,
160+
TokenType: cfg.TokenType,
161+
Expiry: expiry,
162+
},
163+
}, nil
164+
}
165+
133166
// resolveConfigAndProfile resolves the config file path and profile name.
134167
func resolveConfigAndProfile(configFilePath, profileName string, envLookup envconfig.EnvLookup) (string, string, error) {
135168
if envLookup == nil {
@@ -182,43 +215,45 @@ func StoreClientOAuth(opts StoreClientOAuthOptions) error {
182215
return err
183216
}
184217

185-
// Read and parse existing file content.
218+
// Read and parse existing file content using envconfig.
186219
existingContent, err := os.ReadFile(configFilePath)
187220
if err != nil && !errors.Is(err, os.ErrNotExist) {
188221
return fmt.Errorf("failed to read config file: %w", err)
189222
}
190223

191-
var existingRaw map[string]any
224+
var conf envconfig.ClientConfig
225+
additional := make(map[string]map[string]any)
192226
if len(existingContent) > 0 {
193-
if _, err := toml.Decode(string(existingContent), &existingRaw); err != nil {
227+
if err := conf.FromTOML(existingContent, envconfig.ClientConfigFromTOMLOptions{
228+
AdditionalProfileFields: additional,
229+
}); err != nil {
194230
return fmt.Errorf("failed to parse existing config: %w", err)
195231
}
196232
}
197-
if existingRaw == nil {
198-
existingRaw = make(map[string]any)
199-
}
200233

201-
// Load existing OAuth configs from the parsed content.
202-
oauthByProfile, err := parseOAuthFromRaw(existingRaw)
203-
if err != nil {
204-
return fmt.Errorf("failed to parse existing OAuth config: %w", err)
234+
// Ensure the profile exists in the config.
235+
if conf.Profiles == nil {
236+
conf.Profiles = make(map[string]*envconfig.ClientConfigProfile)
205237
}
206-
if oauthByProfile == nil {
207-
oauthByProfile = make(map[string]*OAuthConfig)
238+
if conf.Profiles[profileName] == nil {
239+
conf.Profiles[profileName] = &envconfig.ClientConfigProfile{}
208240
}
209241

210-
// Update the OAuth config for this profile.
211-
oauthByProfile[profileName] = opts.OAuth
212-
213-
// Merge OAuth configs back into the raw structure.
214-
if err := mergeOAuthIntoRaw(existingRaw, oauthByProfile); err != nil {
215-
return err
242+
// Update the OAuth config for this profile in additional fields.
243+
if additional[profileName] == nil {
244+
additional[profileName] = make(map[string]any)
245+
}
246+
if opts.OAuth == nil {
247+
delete(additional[profileName], "oauth")
248+
} else {
249+
additional[profileName]["oauth"] = oauthConfigToMap(opts.OAuth)
216250
}
217251

218-
// Marshal back to TOML.
219-
var buf bytes.Buffer
220-
enc := toml.NewEncoder(&buf)
221-
if err := enc.Encode(existingRaw); err != nil {
252+
// Marshal back to TOML using envconfig.
253+
data, err := conf.ToTOML(envconfig.ClientConfigToTOMLOptions{
254+
AdditionalProfileFields: additional,
255+
})
256+
if err != nil {
222257
return fmt.Errorf("failed to encode config: %w", err)
223258
}
224259

@@ -227,122 +262,20 @@ func StoreClientOAuth(opts StoreClientOAuthOptions) error {
227262
return fmt.Errorf("failed to create config directory: %w", err)
228263
}
229264

230-
if err := os.WriteFile(configFilePath, buf.Bytes(), 0600); err != nil {
265+
if err := os.WriteFile(configFilePath, data, 0600); err != nil {
231266
return fmt.Errorf("failed to write config file: %w", err)
232267
}
233268

234269
return nil
235270
}
236271

237-
func parseOAuthFromRaw(raw map[string]any) (map[string]*OAuthConfig, error) {
238-
var parsed rawConfigWithOAuth
239-
240-
// Re-encode and decode to convert map[string]any to our struct.
241-
// This is simpler than manual type assertions for nested structures.
242-
var buf bytes.Buffer
243-
enc := toml.NewEncoder(&buf)
244-
if err := enc.Encode(raw); err != nil {
245-
return nil, err
246-
}
247-
if _, err := toml.Decode(buf.String(), &parsed); err != nil {
248-
return nil, err
249-
}
250-
251-
oauthByProfile := make(map[string]*OAuthConfig)
252-
for profileName, profile := range parsed.Profile {
253-
if profile == nil || profile.OAuth == nil {
254-
continue
255-
}
256-
cfg := profile.OAuth
257-
258-
// Parse expiry time if present
259-
var expiry time.Time
260-
if cfg.ExpiresAt != "" {
261-
t, err := time.Parse(time.RFC3339, cfg.ExpiresAt)
262-
if err != nil {
263-
return nil, fmt.Errorf("failed to parse expires_at for profile %q: %w", profileName, err)
264-
}
265-
expiry = t
266-
}
267-
268-
oauth := &OAuthConfig{
269-
ClientConfig: &oauth2.Config{
270-
ClientID: cfg.ClientID,
271-
ClientSecret: cfg.ClientSecret,
272-
RedirectURL: cfg.RedirectURL,
273-
Scopes: cfg.Scopes,
274-
Endpoint: oauth2.Endpoint{
275-
AuthURL: cfg.AuthURL,
276-
TokenURL: cfg.TokenURL,
277-
},
278-
},
279-
Token: &oauth2.Token{
280-
AccessToken: cfg.AccessToken,
281-
RefreshToken: cfg.RefreshToken,
282-
TokenType: cfg.TokenType,
283-
Expiry: expiry,
284-
},
285-
}
286-
oauthByProfile[profileName] = oauth
287-
}
288-
return oauthByProfile, nil
289-
}
290-
291-
// mergeOAuthIntoRaw merges OAuth configurations into a raw TOML structure.
292-
func mergeOAuthIntoRaw(raw map[string]any, oauthByProfile map[string]*OAuthConfig) error {
293-
// Get or create the profile section.
294-
profileSection, ok := raw["profile"].(map[string]any)
295-
if !ok {
296-
profileSection = make(map[string]any)
297-
raw["profile"] = profileSection
298-
}
299-
300-
// Update OAuth for each profile.
301-
for profileName, oauth := range oauthByProfile {
302-
profile, ok := profileSection[profileName].(map[string]any)
303-
if !ok {
304-
profile = make(map[string]any)
305-
profileSection[profileName] = profile
306-
}
307-
308-
if oauth == nil {
309-
delete(profile, "oauth")
310-
} else {
311-
profile["oauth"] = oauthConfigToTOML(oauth)
312-
}
313-
}
314-
315-
return nil
316-
}
317-
318-
// oauthConfigTOML is the TOML representation of OAuthConfig.
319-
type oauthConfigTOML struct {
320-
ClientID string `toml:"client_id,omitempty"`
321-
ClientSecret string `toml:"client_secret,omitempty"`
322-
TokenURL string `toml:"token_url,omitempty"`
323-
AuthURL string `toml:"auth_url,omitempty"`
324-
RedirectURL string `toml:"redirect_url,omitempty"`
325-
AccessToken string `toml:"access_token,omitempty"`
326-
RefreshToken string `toml:"refresh_token,omitempty"`
327-
TokenType string `toml:"token_type,omitempty"`
328-
ExpiresAt string `toml:"expires_at,omitempty"`
329-
Scopes []string `toml:"scopes,omitempty"`
330-
}
331-
332-
type rawProfileWithOAuth struct {
333-
OAuth *oauthConfigTOML `toml:"oauth"`
334-
}
335-
336-
type rawConfigWithOAuth struct {
337-
Profile map[string]*rawProfileWithOAuth `toml:"profile"`
338-
}
339-
340-
// oauthConfigToTOML converts OAuthConfig to its TOML representation.
341-
func oauthConfigToTOML(oauth *OAuthConfig) *oauthConfigTOML {
272+
// oauthConfigToMap converts OAuthConfig to map[string]any using JSON as intermediary.
273+
func oauthConfigToMap(oauth *OAuthConfig) map[string]any {
342274
if oauth == nil || oauth.ClientConfig == nil || oauth.Token == nil {
343275
return nil
344276
}
345-
result := &oauthConfigTOML{
277+
278+
cfg := oauthConfigJSON{
346279
ClientID: oauth.ClientConfig.ClientID,
347280
ClientSecret: oauth.ClientConfig.ClientSecret,
348281
TokenURL: oauth.ClientConfig.Endpoint.TokenURL,
@@ -354,7 +287,18 @@ func oauthConfigToTOML(oauth *OAuthConfig) *oauthConfigTOML {
354287
Scopes: oauth.ClientConfig.Scopes,
355288
}
356289
if !oauth.Token.Expiry.IsZero() {
357-
result.ExpiresAt = oauth.Token.Expiry.Format(time.RFC3339)
290+
cfg.ExpiresAt = oauth.Token.Expiry.Format(time.RFC3339)
291+
}
292+
293+
data, err := json.Marshal(cfg)
294+
if err != nil {
295+
return nil
296+
}
297+
298+
var result map[string]any
299+
if err := json.Unmarshal(data, &result); err != nil {
300+
return nil
358301
}
359302
return result
360303
}
304+

0 commit comments

Comments
 (0)