Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions crates/forge_infra/src/auth/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,62 @@ mod tests {
assert!(matches!(actual.unwrap(), AnyAuthStrategy::CodexDevice(_)));
}

#[test]
fn test_create_auth_strategy_xai_oauth_code_uses_standard() {
// xAI is neither CLAUDE_CODE nor GITHUB_COPILOT, so the OAuthCode
// (SuperGrok loopback) flow must fall through to the generic
// StandardHttpProvider with zero per-provider code.
let config = OAuthConfig {
client_id: "b1a00492-073a-47ea-816f-4c329264a828".to_string().into(),
auth_url: Url::parse("https://auth.x.ai/oauth2/authorize").unwrap(),
token_url: Url::parse("https://auth.x.ai/oauth2/token").unwrap(),
scopes: vec!["api:access".to_string()],
redirect_uri: Some("http://127.0.0.1:56121/callback".to_string()),
use_pkce: true,
token_refresh_url: None,
extra_auth_params: None,
custom_headers: None,
};

let factory = ForgeAuthStrategyFactory;
let actual = factory
.create_auth_strategy(
ProviderId::XAI,
forge_domain::AuthMethod::OAuthCode(config),
vec![],
)
.unwrap();
assert!(matches!(actual, AnyAuthStrategy::OAuthCodeStandard(_)));
}

#[test]
fn test_create_auth_strategy_xai_oauth_device_uses_device() {
// The xAI headless device flow omits token_refresh_url, so it must
// route to the plain OAuthDevice strategy (RFC 8628), not the
// GitHub-Copilot OAuthWithApiKey hybrid.
let config = OAuthConfig {
client_id: "b1a00492-073a-47ea-816f-4c329264a828".to_string().into(),
auth_url: Url::parse("https://auth.x.ai/oauth2/device/code").unwrap(),
token_url: Url::parse("https://auth.x.ai/oauth2/token").unwrap(),
scopes: vec!["api:access".to_string()],
redirect_uri: None,
use_pkce: false,
token_refresh_url: None,
extra_auth_params: None,
custom_headers: None,
};

let factory = ForgeAuthStrategyFactory;
let actual = factory
.create_auth_strategy(
ProviderId::XAI,
forge_domain::AuthMethod::OAuthDevice(config),
vec![],
)
.unwrap();
assert!(matches!(actual, AnyAuthStrategy::OAuthDevice(_)));
}

/// Helper to build a JWT token with the given claims payload.
fn build_jwt(claims: &serde_json::Value) -> String {
use base64::Engine;
Expand Down
41 changes: 40 additions & 1 deletion crates/forge_repo/src/provider/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,46 @@
"response_type": "OpenAI",
"url": "https://api.x.ai/v1/chat/completions",
"models": "https://api.x.ai/v1/models",
"auth_methods": ["api_key"]
"auth_methods": [
{
"oauth_code": {
"auth_url": "https://auth.x.ai/oauth2/authorize",
"token_url": "https://auth.x.ai/oauth2/token",
"client_id": "b1a00492-073a-47ea-816f-4c329264a828",
"scopes": [
"openid",
"profile",
"email",
"offline_access",
"grok-cli:access",
"api:access"
],
"redirect_uri": "http://127.0.0.1:56121/callback",
"use_pkce": true,
"extra_auth_params": {
"plan": "generic",
"referrer": "forgecode"
}
}
},
{
"oauth_device": {
"auth_url": "https://auth.x.ai/oauth2/device/code",
"token_url": "https://auth.x.ai/oauth2/token",
"client_id": "b1a00492-073a-47ea-816f-4c329264a828",
"scopes": [
"openid",
"profile",
"email",
"offline_access",
"grok-cli:access",
"api:access"
],
"use_pkce": false
}
},
"api_key"
]
},
{
"id": "openai",
Expand Down
78 changes: 78 additions & 0 deletions crates/forge_repo/src/provider/provider_repo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,84 @@ mod tests {
);
}

#[test]
fn test_xai_oauth_config() {
let configs = get_provider_configs();
let config = configs.iter().find(|c| c.id == ProviderId::XAI).unwrap();

assert_eq!(config.id, ProviderId::XAI);
assert_eq!(config.api_key_vars, Some("XAI_API_KEY".to_string()));
assert_eq!(config.response_type, Some(ProviderResponse::OpenAI));
assert_eq!(config.url.as_str(), "https://api.x.ai/v1/chat/completions");

// Three auth methods: loopback OAuth, headless device OAuth, manual key.
assert_eq!(config.auth_methods.len(), 3);
assert!(config.auth_methods.contains(&AuthMethod::ApiKey));

let expected_scopes = vec![
"openid".to_string(),
"profile".to_string(),
"email".to_string(),
"offline_access".to_string(),
"grok-cli:access".to_string(),
"api:access".to_string(),
];

// Loopback authorization-code + PKCE (SuperGrok subscription).
let code = config
.auth_methods
.iter()
.find_map(|m| match m {
AuthMethod::OAuthCode(cfg) => Some(cfg),
_ => None,
})
.expect("xai should expose an oauth_code auth method");
assert_eq!(
code.client_id.as_str(),
"b1a00492-073a-47ea-816f-4c329264a828"
);
assert_eq!(code.auth_url.as_str(), "https://auth.x.ai/oauth2/authorize");
assert_eq!(code.token_url.as_str(), "https://auth.x.ai/oauth2/token");
assert_eq!(code.scopes, expected_scopes);
assert_eq!(
code.redirect_uri.as_deref(),
Some("http://127.0.0.1:56121/callback")
);
assert!(code.use_pkce);
let extra = code
.extra_auth_params
.as_ref()
.expect("oauth_code should set extra_auth_params");
// plan=generic is mandatory: xAI rejects loopback OAuth from
// non-allowlisted clients without it.
assert_eq!(extra.get("plan").map(String::as_str), Some("generic"));
assert_eq!(extra.get("referrer").map(String::as_str), Some("forgecode"));

// Headless device-code (remote / VPS). auth_url MUST be the
// device-authorization endpoint, and token_refresh_url must be absent
// so the factory routes to the plain device flow.
let device = config
.auth_methods
.iter()
.find_map(|m| match m {
AuthMethod::OAuthDevice(cfg) => Some(cfg),
_ => None,
})
.expect("xai should expose an oauth_device auth method");
assert_eq!(
device.client_id.as_str(),
"b1a00492-073a-47ea-816f-4c329264a828"
);
assert_eq!(
device.auth_url.as_str(),
"https://auth.x.ai/oauth2/device/code"
);
assert_eq!(device.token_url.as_str(), "https://auth.x.ai/oauth2/token");
assert_eq!(device.scopes, expected_scopes);
assert!(device.redirect_uri.is_none());
assert!(device.token_refresh_url.is_none());
}

#[test]
fn test_vertex_ai_config() {
let configs = get_provider_configs();
Expand Down