Skip to content

Commit b4ccc25

Browse files
committed
feat: wizard can be used when piping
also remove dead code
1 parent fb8d1a7 commit b4ccc25

4 files changed

Lines changed: 234 additions & 184 deletions

File tree

src/auth.rs

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -784,10 +784,6 @@ async fn run_login_set(base: &BaseArgs, args: AuthLoginArgs) -> Result<()> {
784784
}
785785

786786
async fn run_login_oauth(base: &BaseArgs, args: AuthLoginArgs) -> Result<()> {
787-
if !ui::is_interactive() {
788-
bail!("oauth login requires an interactive terminal");
789-
}
790-
791787
let api_url = base
792788
.api_url
793789
.clone()
@@ -860,7 +856,7 @@ async fn run_login_oauth(base: &BaseArgs, args: AuthLoginArgs) -> Result<()> {
860856
let selected_org = select_login_org(
861857
login_orgs.clone(),
862858
base.org_name.as_deref(),
863-
true,
859+
ui::can_prompt(),
864860
args.verbose,
865861
true,
866862
)?;
@@ -958,7 +954,7 @@ pub(crate) async fn login_interactive_oauth(base: &mut BaseArgs) -> Result<Strin
958954
let selected_org = select_login_org(
959955
login_orgs.clone(),
960956
base.org_name.as_deref(),
961-
true,
957+
ui::can_prompt(),
962958
false,
963959
false,
964960
)?;
@@ -1633,7 +1629,7 @@ fn select_login_org(
16331629
}
16341630
}));
16351631
let label_refs: Vec<&str> = labels.iter().map(String::as_str).collect();
1636-
println!("\n\nA Braintrust organization is usually a team or a company.");
1632+
eprintln!("\n\nA Braintrust organization is usually a team or a company.");
16371633
let selection = ui::fuzzy_select("Select organization", &label_refs, 0)?;
16381634
if allow_cross_org && selection == 0 {
16391635
return Ok(None);
@@ -1710,6 +1706,29 @@ struct OAuthCallbackParams {
17101706
error: Option<String>,
17111707
}
17121708

1709+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1710+
enum OAuthCallbackMode {
1711+
ListenerOnly,
1712+
ListenerOrStdin,
1713+
PromptThenListener,
1714+
}
1715+
1716+
fn oauth_callback_mode(prefer_manual: bool) -> OAuthCallbackMode {
1717+
if prefer_manual {
1718+
if ui::can_prompt() {
1719+
OAuthCallbackMode::PromptThenListener
1720+
} else {
1721+
OAuthCallbackMode::ListenerOnly
1722+
}
1723+
} else if ui::is_interactive() {
1724+
OAuthCallbackMode::ListenerOrStdin
1725+
} else if ui::can_prompt() {
1726+
OAuthCallbackMode::PromptThenListener
1727+
} else {
1728+
OAuthCallbackMode::ListenerOnly
1729+
}
1730+
}
1731+
17131732
async fn wait_for_oauth_callback(listener: TcpListener) -> Result<OAuthCallbackParams> {
17141733
let (mut stream, _) = tokio::time::timeout(OAUTH_CALLBACK_TIMEOUT, listener.accept())
17151734
.await
@@ -1758,27 +1777,34 @@ async fn collect_oauth_callback(
17581777
listener: TcpListener,
17591778
prefer_manual: bool,
17601779
) -> Result<OAuthCallbackParams> {
1761-
if !prefer_manual {
1762-
return wait_for_oauth_callback_or_stdin(listener).await;
1763-
}
1764-
1765-
println!("Remote/SSH OAuth flow: open the URL in a browser on your local machine.");
1766-
println!(
1767-
"After approving access, your browser may show a localhost connection error on remote hosts."
1768-
);
1769-
println!(
1770-
"Copy the full URL from the browser address bar (or just code=...&state=...) and paste it below."
1771-
);
1772-
let pasted = Input::<String>::new()
1773-
.with_prompt("Callback URL/query/JSON (press Enter to wait for automatic callback)")
1774-
.allow_empty(true)
1775-
.report(false)
1776-
.interact_text()
1777-
.context("failed to read callback URL")?;
1778-
if pasted.trim().is_empty() {
1779-
return wait_for_oauth_callback(listener).await;
1780+
match oauth_callback_mode(prefer_manual) {
1781+
OAuthCallbackMode::ListenerOnly => {
1782+
eprintln!("Waiting for browser authorization...");
1783+
wait_for_oauth_callback(listener).await
1784+
}
1785+
OAuthCallbackMode::ListenerOrStdin => wait_for_oauth_callback_or_stdin(listener).await,
1786+
OAuthCallbackMode::PromptThenListener => {
1787+
let term = ui::prompt_term()
1788+
.ok_or_else(|| anyhow::anyhow!("interactive mode requires TTY"))?;
1789+
println!("Remote/SSH OAuth flow: open the URL in a browser on your local machine.");
1790+
println!(
1791+
"After approving access, your browser may show a localhost connection error on remote hosts."
1792+
);
1793+
println!(
1794+
"Copy the full URL from the browser address bar (or just code=...&state=...) and paste it below."
1795+
);
1796+
let pasted = Input::<String>::new()
1797+
.with_prompt("Callback URL/query/JSON (press Enter to wait for automatic callback)")
1798+
.allow_empty(true)
1799+
.report(false)
1800+
.interact_text_on(&term)
1801+
.context("failed to read callback URL")?;
1802+
if pasted.trim().is_empty() {
1803+
return wait_for_oauth_callback(listener).await;
1804+
}
1805+
parse_oauth_callback_input(&pasted)
1806+
}
17801807
}
1781-
parse_oauth_callback_input(&pasted)
17821808
}
17831809

17841810
async fn wait_for_oauth_callback_or_stdin(listener: TcpListener) -> Result<OAuthCallbackParams> {
@@ -3288,6 +3314,31 @@ mod tests {
32883314
);
32893315
}
32903316

3317+
#[tokio::test]
3318+
async fn oauth_callback_mode_uses_listener_only_when_input_is_disabled() {
3319+
let _guard = env_test_lock().lock().await;
3320+
ui::set_no_input(true);
3321+
assert_eq!(oauth_callback_mode(false), OAuthCallbackMode::ListenerOnly);
3322+
assert_eq!(oauth_callback_mode(true), OAuthCallbackMode::ListenerOnly);
3323+
ui::set_no_input(false);
3324+
}
3325+
3326+
#[test]
3327+
fn oauth_callback_mode_prefers_manual_prompt_when_interactive() {
3328+
ui::set_no_input(false);
3329+
3330+
if ui::is_interactive() {
3331+
assert_eq!(
3332+
oauth_callback_mode(true),
3333+
OAuthCallbackMode::PromptThenListener
3334+
);
3335+
assert_eq!(
3336+
oauth_callback_mode(false),
3337+
OAuthCallbackMode::ListenerOrStdin
3338+
);
3339+
}
3340+
}
3341+
32913342
#[tokio::test]
32923343
async fn login_read_only_no_cached_project_id_uses_validated_login_path() {
32933344
let env = TestEnv::new(None, None).await;

0 commit comments

Comments
 (0)