Skip to content

Commit c1c9d0d

Browse files
update log in flow and fix jwt claim issue
1 parent c554a57 commit c1c9d0d

10 files changed

Lines changed: 390 additions & 107 deletions

File tree

lib/src/datum_cloud/auth.rs

Lines changed: 125 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::{
2+
future::Future,
23
sync::{
34
Arc,
45
atomic::{AtomicU64, Ordering},
@@ -17,6 +18,7 @@ use openidconnect::{
1718
};
1819
use serde::{Deserialize, Serialize};
1920
use tokio::sync::watch;
21+
use tokio_util::sync::CancellationToken;
2022
use tracing::{debug, error, info, warn};
2123

2224
use crate::Repo;
@@ -143,13 +145,19 @@ impl StatelessClient {
143145
.build()
144146
.expect("Client should build");
145147

146-
// Use OpenID Connect Discovery to fetch the provider metadata.
148+
// Use OpenID Connect Discovery to fetch the provider metadata (including JWKs).
149+
// We fetch fresh metadata each time to avoid "No matching key found" when
150+
// Datum Cloud rotates signing keys (see datum-cloud/app#121).
147151
let provider_metadata = CoreProviderMetadata::discover_async(
148152
IssuerUrl::new(provider.issuer_url).std_context("Invalid OIDC provider issuer URL")?,
149153
&http,
150154
)
151155
.await
152156
.std_context("Failed to discover OIDC provider metadata")?;
157+
debug!(
158+
jwks_uri=?provider_metadata.jwks_uri(),
159+
"fetched fresh OIDC provider metadata"
160+
);
153161

154162
// Create an OpenID Connect client
155163
let oidc = CoreClient::from_provider_metadata(
@@ -162,7 +170,12 @@ impl StatelessClient {
162170
Ok(Self { oidc, http, env })
163171
}
164172

165-
pub async fn login(&self) -> Result<AuthState> {
173+
pub async fn login<F, Fut, C>(&self, open_url: F) -> Result<AuthState>
174+
where
175+
F: FnOnce(String, CancellationToken) -> Fut,
176+
Fut: Future<Output = Option<C>>,
177+
C: FnOnce() + Send + 'static,
178+
{
166179
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
167180

168181
let (auth_url, csrf_token, nonce) = self
@@ -183,13 +196,22 @@ impl StatelessClient {
183196
// Bind a localhost HTTP server to receive the redirect.
184197
let mut redirect_server = RedirectServer::bind(csrf_token.clone()).await?;
185198

186-
// Open the auth URL in the platform's default browser.
187-
if let Err(err) = open::that(auth_url.to_string()) {
188-
warn!("Failed to auto-open url: {err}");
189-
println!("Open this URL in a browser to complete the login:\n{auth_url}")
199+
let cancel_token = CancellationToken::new();
200+
let cancel_token_for_opener = cancel_token.clone();
201+
202+
// Open the auth URL; opener may return a close handle to close the window when done.
203+
let mut close_handle = open_url(auth_url.to_string(), cancel_token_for_opener).await;
204+
205+
let recv_result = redirect_server
206+
.recv_with_timeout(LOGIN_TIMEOUT, Some(&cancel_token))
207+
.await;
208+
209+
// Close the auth window if one was opened (e.g. in-app webview), on success or error.
210+
if let Some(close) = close_handle.take() {
211+
close();
190212
}
191213

192-
let authorization_code = redirect_server.recv_with_timeout(LOGIN_TIMEOUT).await?;
214+
let authorization_code = recv_result?;
193215
debug!("received redirect with authorization code");
194216

195217
// Exchange auth code for ID and access tokens.
@@ -257,8 +279,14 @@ impl StatelessClient {
257279

258280
let claims = id_token
259281
.claims(&id_token_verifier, nonce_verifier)
260-
.std_context("Failed to verify claims")
261-
.inspect_err(|e| error!("{e:#}"))?;
282+
.map_err(|e| {
283+
error!(
284+
error=%e,
285+
signing_alg=?id_token.signing_alg(),
286+
"Failed to verify ID token claims, try logging in again"
287+
);
288+
anyerr!("Failed to verify login. Please try again — if the problem persists, your session may need to be refreshed.")
289+
})?;
262290

263291
// Verify the access token hash to ensure that the access token hasn't been substituted for
264292
// another user's.
@@ -507,17 +535,21 @@ fn set_sentry_user(auth: Option<&AuthState>) {
507535
#[derive(derive_more::Debug, Clone)]
508536
pub struct AuthClient {
509537
state: AuthStateWrapper,
510-
client: StatelessClient,
538+
env: ApiEnv,
539+
/// OIDC client with JWKs. Swapped before each login/refresh so we always have fresh keys
540+
/// (avoids "No matching key found" when Datum Cloud rotates signing keys; datum-cloud/app#121).
541+
client: Arc<ArcSwap<StatelessClient>>,
511542
_refresh_task: Option<Arc<n0_future::task::AbortOnDropHandle<()>>>,
512543
}
513544

514545
impl AuthClient {
515546
pub async fn with_repo(env: ApiEnv, repo: Repo) -> Result<Self> {
516547
let auth = AuthStateWrapper::from_repo(repo, env.oauth_storage_key()).await?;
517-
let auth_client = StatelessClient::new(env).await?;
548+
let auth_client = Arc::new(StatelessClient::new(env).await?);
518549
let mut client = Self {
519550
state: auth,
520-
client: auth_client,
551+
env,
552+
client: Arc::new(ArcSwap::new(auth_client)),
521553
_refresh_task: None,
522554
};
523555
client.start_refresh_loop();
@@ -526,16 +558,27 @@ impl AuthClient {
526558

527559
pub async fn new(env: ApiEnv) -> Result<Self> {
528560
let auth = AuthStateWrapper::empty();
529-
let auth_client = StatelessClient::new(env).await?;
561+
let auth_client = Arc::new(StatelessClient::new(env).await?);
530562
let mut client = Self {
531563
state: auth,
532-
client: auth_client,
564+
env,
565+
client: Arc::new(ArcSwap::new(auth_client)),
533566
_refresh_task: None,
534567
};
535568
client.start_refresh_loop();
536569
Ok(client)
537570
}
538571

572+
/// Fetch fresh OIDC provider metadata (including JWKs) and swap in a new client.
573+
/// Call before login/refresh to avoid "No matching key found" when keys rotate.
574+
async fn ensure_fresh_client(&self) -> Result<Arc<StatelessClient>> {
575+
let fresh = Arc::new(
576+
StatelessClient::with_provider(self.env, self.env.auth_provider()).await?,
577+
);
578+
self.client.store(fresh.clone());
579+
Ok(fresh)
580+
}
581+
539582
pub fn login_state(&self) -> LoginState {
540583
match self.state.load().get().ok() {
541584
None => LoginState::Missing,
@@ -624,15 +667,44 @@ impl AuthClient {
624667
}
625668

626669
pub async fn login(&self) -> Result<()> {
670+
self.login_with_opener(|url, _cancel_token| async move {
671+
if let Err(err) = open::that(&url) {
672+
warn!("Failed to auto-open url: {err}");
673+
eprintln!("Open this URL in a browser to complete the login:\n{url}");
674+
}
675+
None::<Box<dyn FnOnce() + Send>>
676+
})
677+
.await
678+
}
679+
680+
pub async fn login_with_opener<F, Fut, C>(&self, open_url: F) -> Result<()>
681+
where
682+
F: FnOnce(String, CancellationToken) -> Fut,
683+
Fut: Future<Output = Option<C>>,
684+
C: FnOnce() + Send + 'static,
685+
{
627686
let auth = self.state.load();
628687
let auth = match auth.get() {
629-
Err(_) => self.client.login().await?,
688+
Err(_) => {
689+
let client = self.ensure_fresh_client().await?;
690+
client.login(open_url).await?
691+
}
630692
Ok(auth) if auth.tokens.expires_in_less_than(REFRESH_AUTH_WHEN) => {
631-
match self.client.refresh(&auth.tokens).await {
693+
let client = self.ensure_fresh_client().await?;
694+
match client.refresh(&auth.tokens).await {
632695
Ok(auth) => auth,
633696
Err(err) => {
634697
warn!("Failed to refresh auth token: {err:#}");
635-
self.client.login().await?
698+
let client = self.ensure_fresh_client().await?;
699+
client
700+
.login(|url, _cancel_token| async move {
701+
if let Err(e) = open::that(&url) {
702+
warn!("Failed to auto-open url: {e}");
703+
eprintln!("Open this URL in a browser to complete the login:\n{url}");
704+
}
705+
None::<Box<dyn FnOnce() + Send>>
706+
})
707+
.await?
636708
}
637709
}
638710
}
@@ -645,7 +717,8 @@ impl AuthClient {
645717
pub async fn refresh(&self) -> Result<()> {
646718
let auth = self.state.load();
647719
let auth = auth.get()?;
648-
let new_auth = match self.client.refresh(&auth.tokens).await {
720+
let client = self.ensure_fresh_client().await?;
721+
let new_auth = match client.refresh(&auth.tokens).await {
649722
Ok(auth) => auth,
650723
Err(err) => {
651724
warn!("Failed to refresh auth tokens, logging out: {err:#}");
@@ -664,6 +737,7 @@ impl AuthClient {
664737
let user_id = auth.profile.user_id.clone();
665738
let new_profile = self
666739
.client
740+
.load()
667741
.fetch_user_profile(&auth.tokens, &user_id)
668742
.await?;
669743
let new_auth = AuthState {
@@ -740,14 +814,15 @@ mod redirect_server {
740814
extract::{Query, State},
741815
routing::get,
742816
};
743-
use n0_error::StdResultExt;
817+
use n0_error::{StdResultExt, anyerr};
744818
use openidconnect::{CsrfToken, RedirectUrl};
745819
use serde::Deserialize;
746820
use std::{
747821
net::{Ipv4Addr, SocketAddr},
748822
time::Duration,
749823
};
750-
use tokio::{net::TcpListener, sync::mpsc};
824+
use tokio::net::TcpSocket;
825+
use tokio::sync::mpsc;
751826
use tokio_util::sync::CancellationToken;
752827
use tracing::{Instrument, debug, instrument, warn};
753828

@@ -776,7 +851,10 @@ mod redirect_server {
776851
let app = Router::new()
777852
.route("/oauth/redirect", get(oauth_redirect))
778853
.with_state(state);
779-
let listener = TcpListener::bind(bind_addr).await?;
854+
let socket = TcpSocket::new_v4()?;
855+
socket.set_reuseaddr(true)?;
856+
socket.bind(bind_addr)?;
857+
let listener = socket.listen(128)?;
780858
debug!(addr=%bind_addr, "OIDC redirect HTTP server listening");
781859

782860
tokio::spawn({
@@ -810,10 +888,21 @@ mod redirect_server {
810888
.expect("valid url")
811889
}
812890

813-
pub async fn recv_with_timeout(&mut self, timeout: Duration) -> n0_error::Result<String> {
814-
let res = tokio::time::timeout(timeout, self.recv()).await;
891+
pub async fn recv_with_timeout(
892+
&mut self,
893+
timeout: Duration,
894+
cancel: Option<&CancellationToken>,
895+
) -> n0_error::Result<String> {
896+
let res = if let Some(cancel_token) = cancel {
897+
tokio::select! {
898+
_ = cancel_token.cancelled() => Err(anyerr!("Login cancelled")),
899+
r = tokio::time::timeout(timeout, self.recv()) => r.anyerr()?,
900+
}
901+
} else {
902+
tokio::time::timeout(timeout, self.recv()).await.anyerr()?
903+
};
815904
self.cancel_token.cancel();
816-
res.anyerr()?
905+
res
817906
}
818907

819908
pub async fn recv(&mut self) -> n0_error::Result<String> {
@@ -844,9 +933,19 @@ mod redirect_server {
844933
sender: mpsc::Sender<n0_error::Result<OauthRedirectData>>,
845934
}
846935

847-
async fn oauth_redirect(state: State<AppState>, query: Query<OauthRedirectData>) -> String {
936+
async fn oauth_redirect(state: State<AppState>, query: Query<OauthRedirectData>) -> axum::response::Html<String> {
848937
let data = query.0;
849938
state.sender.send(Ok(data)).await.ok();
850-
"You are now logged in and can close this window.".to_string()
939+
axum::response::Html(
940+
r#"<!DOCTYPE html>
941+
<html>
942+
<head><meta charset="utf-8"><title>Login complete</title></head>
943+
<body>
944+
<p>You are now logged in. This window will close automatically.</p>
945+
<script>window.close();</script>
946+
</body>
947+
</html>"#
948+
.to_string(),
949+
)
851950
}
852951
}

ui/assets/icons/move-right.svg

Lines changed: 1 addition & 0 deletions
Loading

0 commit comments

Comments
 (0)