Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 105 additions & 3 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,73 @@ type Auth struct {
shouldLogin func() (bool, error)
}

const BrevAPIKeyPrefix = "bak-"

const MissingAPIKeyOrgIDMessage = "api key auth requires an org id; run brev login --api-key <api-key> --org-id <org-id>"

type APIKeyAuthStore interface {
GetAuthTokens() (*entity.AuthTokens, error)
}

type CurrentUserAuthStore interface {
APIKeyAuthStore
GetCurrentUser() (*entity.User, error)
}

type CLIAuth struct {
apiKey bool
user *entity.User
}

func (a CLIAuth) IsAPIKey() bool {
return a.apiKey
}

func (a CLIAuth) User() *entity.User {
return a.user
}

func ResolveCLIAuth(store CurrentUserAuthStore) (CLIAuth, error) {
if IsAPIKeyAuthStore(store) {
return CLIAuth{apiKey: true}, nil
}
user, err := store.GetCurrentUser()
if err != nil {
return CLIAuth{}, breverrors.WrapAndTrace(err)
}
return CLIAuth{user: user}, nil
}

func IsBrevAPIKey(token string) bool {
return strings.HasPrefix(strings.TrimSpace(token), BrevAPIKeyPrefix)
}

func IsAPIKeyAuthStore(authTokensProvider APIKeyAuthStore) bool {
tokens, err := authTokensProvider.GetAuthTokens()
if err != nil {
return false
}
if tokens == nil {
return false
}
return IsBrevAPIKey(tokens.APIKey)
}

func GetAPIKeyOrgID(authTokensProvider APIKeyAuthStore) (string, error) {
tokens, err := authTokensProvider.GetAuthTokens()
if err != nil {
return "", breverrors.WrapAndTrace(err)
}
if tokens == nil {
return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
}
orgID := strings.TrimSpace(tokens.APIKeyOrgID)
if orgID == "" {
return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
}
return orgID, nil
}

func NewAuth(authStore AuthStore, oauth OAuth) *Auth {
return &Auth{
authStore: authStore,
Expand Down Expand Up @@ -146,6 +213,11 @@ func (t Auth) GetFreshAccessTokenOrNil() (string, error) {
return "", nil
}

apiKey := strings.TrimSpace(tokens.APIKey)
if apiKey != "" {
return apiKey, nil
}

// should always at least have access token?
if tokens.AccessToken == "" {
breverrors.GetDefaultErrorReporter().ReportMessage("access token is an empty string but shouldn't be")
Expand Down Expand Up @@ -222,6 +294,36 @@ func (t Auth) LoginWithToken(token string) error {
return nil
}

func (t Auth) LoginWithAPIKey(apiKey string, orgID string) error {
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
return breverrors.NewValidationError("api key is empty")
}
if !IsBrevAPIKey(apiKey) {
return breverrors.NewValidationError(fmt.Sprintf("api key must start with %s", BrevAPIKeyPrefix))
}
orgID = strings.TrimSpace(orgID)
if orgID == "" {
return breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
}

tokens, err := t.getSavedTokensOrNil()
if err != nil {
return breverrors.WrapAndTrace(err)
}
if tokens == nil {
tokens = &entity.AuthTokens{}
}
tokens.APIKey = apiKey
tokens.APIKeyOrgID = orgID

err = t.authStore.SaveAuthTokens(*tokens)
if err != nil {
return breverrors.WrapAndTrace(err)
}
return nil
}

// showLoginURL displays the login link and CLI alternative for manual navigation.
func showLoginURL(url string) {
urlType := color.New(color.FgCyan, color.Bold).SprintFunc()
Expand Down Expand Up @@ -313,7 +415,7 @@ func (t Auth) getSavedTokensOrNil() (*entity.AuthTokens, error) {
}
return nil, breverrors.WrapAndTrace(err)
}
if tokens != nil && tokens.AccessToken == "" && tokens.RefreshToken == "" {
if tokens != nil && tokens.AccessToken == "" && tokens.RefreshToken == "" && tokens.APIKey == "" {
return nil, nil
}
return tokens, nil
Expand Down Expand Up @@ -415,7 +517,7 @@ func AuthProviderFlagToCredentialProvider(authProviderFlag string) entity.Creden
func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) OAuth {
// Set KAS as the default authenticator
shouldPromptEmail := false
if email == "" && tokens != nil && tokens.AccessToken != "" {
if email == "" && tokens != nil && tokens.AccessToken != "" && tokens.APIKey == "" {
email = GetEmailFromToken(tokens.AccessToken)
shouldPromptEmail = true
}
Expand Down Expand Up @@ -445,7 +547,7 @@ func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens)
kasAuthenticator,
})

if tokens != nil && tokens.AccessToken != "" {
if tokens != nil && tokens.AccessToken != "" && tokens.APIKey == "" {
authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken)
if errr != nil {
fmt.Printf("%v\n", errr)
Expand Down
Loading
Loading