Skip to content

Commit 0e7f657

Browse files
peterguyampcode-comburmudar
authored
refactor: validate and parse the endpoint and proxy at program load (#1267)
* refactor: validate and parse the endpoint and proxy at program load Amp-Thread-ID: https://ampcode.com/threads/T-019cdb3f-f7de-750b-b4c3-13762c7dfc11 Co-authored-by: Amp <amp@ampcode.com> * Update internal/oauth/flow.go Co-authored-by: William Bezuidenhout <william.bezuidenhout@sourcegraph.com> --------- Co-authored-by: Amp <amp@ampcode.com> Co-authored-by: William Bezuidenhout <william.bezuidenhout@sourcegraph.com>
1 parent 22191a7 commit 0e7f657

28 files changed

+441
-325
lines changed

cmd/src/batch_common.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ func executeBatchSpec(ctx context.Context, opts executeBatchSpecOpts) (err error
537537
if err != nil {
538538
return execUI.CreatingBatchSpecError(lr.MaxUnlicensedChangesets, err)
539539
}
540-
previewURL := cfg.Endpoint + url
540+
previewURL := cfg.endpointURL.JoinPath(url).String()
541541
execUI.CreatingBatchSpecSuccess(previewURL)
542542

543543
hasWorkspaceFiles := false
@@ -567,7 +567,7 @@ func executeBatchSpec(ctx context.Context, opts executeBatchSpecOpts) (err error
567567
if err != nil {
568568
return err
569569
}
570-
execUI.ApplyingBatchSpecSuccess(cfg.Endpoint + batch.URL)
570+
execUI.ApplyingBatchSpecSuccess(cfg.endpointURL.JoinPath(batch.URL).String())
571571

572572
return nil
573573
}

cmd/src/batch_remote.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"flag"
66
"fmt"
77
cliLog "log"
8-
"strings"
98
"time"
109

1110
"github.com/sourcegraph/sourcegraph/lib/errors"
@@ -155,13 +154,14 @@ Examples:
155154
}
156155
ui.ExecutingBatchSpecSuccess()
157156

158-
executionURL := fmt.Sprintf(
159-
"%s/%s/batch-changes/%s/executions/%s",
160-
strings.TrimSuffix(cfg.Endpoint, "/"),
161-
strings.TrimPrefix(namespace.URL, "/"),
162-
batchChangeName,
163-
batchSpecID,
164-
)
157+
executionURL := cfg.endpointURL.JoinPath(
158+
fmt.Sprintf(
159+
"%s/batch-changes/%s/executions/%s",
160+
namespace.URL,
161+
batchChangeName,
162+
batchSpecID,
163+
),
164+
).String()
165165
ui.RemoteSuccess(executionURL)
166166

167167
return nil

cmd/src/batch_repositories.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ Examples:
131131
Max: max,
132132
RepoCount: len(repos),
133133
Repos: repos,
134-
SourcegraphEndpoint: cfg.Endpoint,
134+
SourcegraphEndpoint: cfg.endpointURL.String(),
135135
}); err != nil {
136136
return err
137137
}

cmd/src/code_intel_upload.go

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"flag"
88
"fmt"
99
"io"
10-
"net/url"
1110
"os"
1211
"strings"
1312
"time"
@@ -87,10 +86,7 @@ func handleCodeIntelUpload(args []string) error {
8786
return handleUploadError(uploadOptions.SourcegraphInstanceOptions.AccessToken, err)
8887
}
8988

90-
uploadURL, err := makeCodeIntelUploadURL(uploadID)
91-
if err != nil {
92-
return err
93-
}
89+
uploadURL := makeCodeIntelUploadURL(uploadID)
9490

9591
if codeintelUploadFlags.json {
9692
serialized, err := json.Marshal(map[string]any{
@@ -132,7 +128,7 @@ func codeintelUploadOptions(out *output.Output) upload.UploadOptions {
132128
associatedIndexID = &codeintelUploadFlags.associatedIndexID
133129
}
134130

135-
cfg.AdditionalHeaders["Content-Type"] = "application/x-protobuf+scip"
131+
cfg.additionalHeaders["Content-Type"] = "application/x-protobuf+scip"
136132

137133
logger := upload.NewRequestLogger(
138134
os.Stdout,
@@ -153,9 +149,9 @@ func codeintelUploadOptions(out *output.Output) upload.UploadOptions {
153149
AssociatedIndexID: associatedIndexID,
154150
},
155151
SourcegraphInstanceOptions: upload.SourcegraphInstanceOptions{
156-
SourcegraphURL: cfg.Endpoint,
157-
AccessToken: cfg.AccessToken,
158-
AdditionalHeaders: cfg.AdditionalHeaders,
152+
SourcegraphURL: cfg.endpointURL.String(),
153+
AccessToken: cfg.accessToken,
154+
AdditionalHeaders: cfg.additionalHeaders,
159155
MaxRetries: 5,
160156
RetryInterval: time.Second,
161157
Path: codeintelUploadFlags.uploadRoute,
@@ -191,16 +187,12 @@ func printInferredArguments(out *output.Output) {
191187

192188
// makeCodeIntelUploadURL constructs a URL to the upload with the given internal identifier.
193189
// The base of the URL is constructed from the configured Sourcegraph instance.
194-
func makeCodeIntelUploadURL(uploadID int) (string, error) {
195-
url, err := url.Parse(cfg.Endpoint)
196-
if err != nil {
197-
return "", err
198-
}
199-
190+
func makeCodeIntelUploadURL(uploadID int) string {
191+
// Careful: copy by dereference makes a shallow copy, so User is not duplicated.
192+
url := *cfg.endpointURL
200193
graphqlID := base64.URLEncoding.EncodeToString(fmt.Appendf(nil, `SCIPUpload:%d`, uploadID))
201194
url.Path = codeintelUploadFlags.repo + "/-/code-intelligence/uploads/" + graphqlID
202-
url.User = nil
203-
return url.String(), nil
195+
return url.String()
204196
}
205197

206198
type errorWithHint struct {

cmd/src/debug_compose.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Examples:
7575
return errors.Wrap(err, "failed to get containers for subcommand with err")
7676
}
7777
// Safety check user knows what they are targeting with this debug command
78-
log.Printf("This command will archive docker-cli data for %d containers\n SRC_ENDPOINT: %v\n Output filename: %v", len(containers), cfg.Endpoint, base)
78+
log.Printf("This command will archive docker-cli data for %d containers\n SRC_ENDPOINT: %v\n Output filename: %v", len(containers), cfg.endpointURL, base)
7979
if verified, _ := verify("Do you want to start writing to an archive?"); !verified {
8080
return nil
8181
}

cmd/src/debug_kube.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ Examples:
8484
return errors.Wrapf(err, "failed to get current-context")
8585
}
8686
// Safety check user knows what they've targeted with this command
87-
log.Printf("Archiving kubectl data for %d pods\n SRC_ENDPOINT: %v\n Context: %s Namespace: %v\n Output filename: %v", len(pods.Items), cfg.Endpoint, kubectx, namespace, base)
87+
log.Printf("Archiving kubectl data for %d pods\n SRC_ENDPOINT: %v\n Context: %s Namespace: %v\n Output filename: %v", len(pods.Items), cfg.endpointURL, kubectx, namespace, base)
8888
if verified, _ := verify("Do you want to start writing to an archive?"); !verified {
8989
return nil
9090
}

cmd/src/debug_server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Examples:
7272
defer zw.Close()
7373

7474
// Safety check user knows what they are targeting with this debug command
75-
log.Printf("This command will archive docker-cli data for container: %s\n SRC_ENDPOINT: %s\n Output filename: %s", container, cfg.Endpoint, base)
75+
log.Printf("This command will archive docker-cli data for container: %s\n SRC_ENDPOINT: %s\n Output filename: %s", container, cfg.endpointURL, base)
7676
if verified, _ := verify("Do you want to start writing to an archive?"); !verified {
7777
return nil
7878
}

cmd/src/login.go

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"flag"
66
"fmt"
77
"io"
8+
"net/url"
89
"os"
910

1011
"github.com/sourcegraph/src-cli/internal/api"
@@ -48,23 +49,26 @@ Examples:
4849
if err := flagSet.Parse(args); err != nil {
4950
return err
5051
}
51-
endpoint := cfg.Endpoint
52+
53+
var loginEndpointURL *url.URL
5254
if flagSet.NArg() >= 1 {
53-
endpoint = flagSet.Arg(0)
54-
}
55-
if endpoint == "" {
56-
return cmderrors.Usage("expected exactly one argument: the Sourcegraph URL, or SRC_ENDPOINT to be set")
55+
arg := flagSet.Arg(0)
56+
u, err := parseEndpoint(arg)
57+
if err != nil {
58+
return cmderrors.Usage(fmt.Sprintf("invalid endpoint URL: %s", arg))
59+
}
60+
loginEndpointURL = u
5761
}
5862

5963
client := cfg.apiClient(apiFlags, io.Discard)
6064

6165
return loginCmd(context.Background(), loginParams{
62-
cfg: cfg,
63-
client: client,
64-
endpoint: endpoint,
65-
out: os.Stdout,
66-
apiFlags: apiFlags,
67-
oauthClient: oauth.NewClient(oauth.DefaultClientID),
66+
cfg: cfg,
67+
client: client,
68+
out: os.Stdout,
69+
apiFlags: apiFlags,
70+
oauthClient: oauth.NewClient(oauth.DefaultClientID),
71+
loginEndpointURL: loginEndpointURL,
6872
})
6973
}
7074

@@ -76,12 +80,12 @@ Examples:
7680
}
7781

7882
type loginParams struct {
79-
cfg *config
80-
client api.Client
81-
endpoint string
82-
out io.Writer
83-
apiFlags *api.Flags
84-
oauthClient oauth.Client
83+
cfg *config
84+
client api.Client
85+
out io.Writer
86+
apiFlags *api.Flags
87+
oauthClient oauth.Client
88+
loginEndpointURL *url.URL
8589
}
8690

8791
type loginFlow func(context.Context, loginParams) error
@@ -96,9 +100,9 @@ const (
96100
)
97101

98102
func loginCmd(ctx context.Context, p loginParams) error {
99-
if p.cfg.ConfigFilePath != "" {
103+
if p.cfg.configFilePath != "" {
100104
fmt.Fprintln(p.out)
101-
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.ConfigFilePath)
105+
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath)
102106
}
103107

104108
_, flow := selectLoginFlow(p)
@@ -107,15 +111,13 @@ func loginCmd(ctx context.Context, p loginParams) error {
107111

108112
// selectLoginFlow decides what login flow to run based on configured AuthMode.
109113
func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) {
110-
endpointArg := cleanEndpoint(p.endpoint)
111-
114+
if p.loginEndpointURL != nil && p.loginEndpointURL.String() != p.cfg.endpointURL.String() {
115+
return loginFlowEndpointConflict, runEndpointConflictLogin
116+
}
112117
switch p.cfg.AuthMode() {
113118
case AuthModeOAuth:
114119
return loginFlowOAuth, runOAuthLogin
115120
case AuthModeAccessToken:
116-
if endpointArg != p.cfg.Endpoint {
117-
return loginFlowEndpointConflict, runEndpointConflictLogin
118-
}
119121
return loginFlowValidate, runValidatedLogin
120122
default:
121123
return loginFlowMissingAuth, runMissingAuthLogin
@@ -126,7 +128,7 @@ func printLoginProblem(out io.Writer, problem string) {
126128
fmt.Fprintf(out, "❌ Problem: %s\n", problem)
127129
}
128130

129-
func loginAccessTokenMessage(endpoint string) string {
131+
func loginAccessTokenMessage(endpointURL *url.URL) string {
130132
return fmt.Sprintf("\n"+`🛠 To fix: Create an access token by going to %s/user/settings/tokens, then set the following environment variables in your terminal:
131133
132134
export SRC_ENDPOINT=%s
@@ -135,5 +137,5 @@ func loginAccessTokenMessage(endpoint string) string {
135137
To verify that it's working, run the login command again.
136138
137139
Alternatively, you can try logging in interactively by running: src login %s
138-
`, endpoint, endpoint, endpoint)
140+
`, endpointURL, endpointURL, endpointURL)
139141
}

cmd/src/login_oauth.go

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@ import (
1818
var loadStoredOAuthToken = oauth.LoadToken
1919

2020
func runOAuthLogin(ctx context.Context, p loginParams) error {
21-
endpointArg := cleanEndpoint(p.endpoint)
22-
client, err := oauthLoginClient(ctx, p, endpointArg)
21+
client, err := oauthLoginClient(ctx, p)
2322
if err != nil {
2423
printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
25-
fmt.Fprintln(p.out, loginAccessTokenMessage(endpointArg))
24+
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
2625
return cmderrors.ExitCode1
2726
}
2827

29-
if err := validateCurrentUser(ctx, client, p.out, endpointArg); err != nil {
28+
if err := validateCurrentUser(ctx, client, p.out, p.cfg.endpointURL); err != nil {
3029
return err
3130
}
3231

@@ -39,13 +38,13 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
3938
// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
4039
// and use it if one is present.
4140
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
42-
func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) {
41+
func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, error) {
4342
// if we have a stored token, used it. Otherwise run the device flow
44-
if token, err := loadStoredOAuthToken(ctx, endpoint); err == nil {
45-
return newOAuthAPIClient(p, endpoint, token), nil
43+
if token, err := loadStoredOAuthToken(ctx, p.cfg.endpointURL); err == nil {
44+
return newOAuthAPIClient(p, token), nil
4645
}
4746

48-
token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient)
47+
token, err := runOAuthDeviceFlow(ctx, p.cfg.endpointURL, p.out, p.oauthClient)
4948
if err != nil {
5049
return nil, err
5150
}
@@ -55,23 +54,23 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.
5554
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
5655
}
5756

58-
return newOAuthAPIClient(p, endpoint, token), nil
57+
return newOAuthAPIClient(p, token), nil
5958
}
6059

61-
func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.Client {
60+
func newOAuthAPIClient(p loginParams, token *oauth.Token) api.Client {
6261
return api.NewClient(api.ClientOpts{
63-
Endpoint: endpoint,
64-
AdditionalHeaders: p.cfg.AdditionalHeaders,
62+
EndpointURL: p.cfg.endpointURL,
63+
AdditionalHeaders: p.cfg.additionalHeaders,
6564
Flags: p.apiFlags,
6665
Out: p.out,
67-
ProxyURL: p.cfg.ProxyURL,
68-
ProxyPath: p.cfg.ProxyPath,
66+
ProxyURL: p.cfg.proxyURL,
67+
ProxyPath: p.cfg.proxyPath,
6968
OAuthToken: token,
7069
})
7170
}
7271

73-
func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) {
74-
authResp, err := client.Start(ctx, endpoint, nil)
72+
func runOAuthDeviceFlow(ctx context.Context, endpointURL *url.URL, out io.Writer, client oauth.Client) (*oauth.Token, error) {
73+
authResp, err := client.Start(ctx, endpointURL, nil)
7574
if err != nil {
7675
return nil, err
7776
}
@@ -95,12 +94,12 @@ func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, cli
9594
interval = 5 * time.Second
9695
}
9796

98-
resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn)
97+
resp, err := client.Poll(ctx, endpointURL, authResp.DeviceCode, interval, authResp.ExpiresIn)
9998
if err != nil {
10099
return nil, err
101100
}
102101

103-
token := resp.Token(endpoint)
102+
token := resp.Token(endpointURL)
104103
token.ClientID = client.ClientID()
105104
return token, nil
106105
}

0 commit comments

Comments
 (0)