diff --git a/crates/sts/src/lib.rs b/crates/sts/src/lib.rs index 9b3e54d..d2911a4 100644 --- a/crates/sts/src/lib.rs +++ b/crates/sts/src/lib.rs @@ -12,7 +12,7 @@ //! use multistore_sts::route_handler::StsRouterExt; //! //! let router = Router::new() -//! .with_sts(config, jwks_cache, token_key); +//! .with_sts("/.sts", config, jwks_cache, token_key); //! ``` //! //! # Flow diff --git a/crates/sts/src/route_handler.rs b/crates/sts/src/route_handler.rs index 7aae62e..298e57f 100644 --- a/crates/sts/src/route_handler.rs +++ b/crates/sts/src/route_handler.rs @@ -27,13 +27,14 @@ impl RouteHandler for StsHandler { /// Extension trait for registering STS routes on a [`Router`]. pub trait StsRouterExt { - /// Register the STS handler on the root path (`/`). + /// Register the STS handler on the given `path`. /// /// STS requests are identified by query parameters - /// (`Action=AssumeRoleWithWebIdentity`), not by path, and clients - /// always send them to `/`. + /// (`Action=AssumeRoleWithWebIdentity`), not by path, so any path + /// can be used (e.g. `"/"` or `"/.sts"`). fn with_sts( self, + path: &str, config: C, cache: JwksCache, key: Option, @@ -43,11 +44,12 @@ pub trait StsRouterExt { impl StsRouterExt for Router { fn with_sts( self, + path: &str, config: C, cache: JwksCache, key: Option, ) -> Self { - self.route("/", StsHandler { config, cache, key }) + self.route(path, StsHandler { config, cache, key }) } } @@ -75,7 +77,7 @@ mod tests { fn test_router() -> Router { let cache = JwksCache::new(reqwest::Client::new(), std::time::Duration::from_secs(60)); - Router::new().with_sts(EmptyRegistry, cache, None) + Router::new().with_sts("/", EmptyRegistry, cache, None) } #[tokio::test] diff --git a/examples/cf-workers/src/lib.rs b/examples/cf-workers/src/lib.rs index f53bf3c..162d0d8 100644 --- a/examples/cf-workers/src/lib.rs +++ b/examples/cf-workers/src/lib.rs @@ -77,7 +77,7 @@ async fn fetch(req: web_sys::Request, env: Env, _ctx: Context) -> Result Result<(), Error> { if let (Some(signer), Some(issuer)) = (oidc_signer, oidc_issuer) { router = router.with_oidc_discovery(issuer, vec![signer]); } - router = router.with_sts(sts_creds, jwks_cache, token_key.clone()); + router = router.with_sts("/.sts", sts_creds, jwks_cache, token_key.clone()); // Build the gateway with the router. let mut handler = ProxyGateway::new(backend, config.clone(), config, domain) diff --git a/examples/server/src/server.rs b/examples/server/src/server.rs index 5ca6c97..5d671e0 100644 --- a/examples/server/src/server.rs +++ b/examples/server/src/server.rs @@ -115,7 +115,7 @@ where if let (Some(signer), Some(issuer)) = (oidc_signer, oidc_issuer) { proxy_router = proxy_router.with_oidc_discovery(issuer, vec![signer]); } - proxy_router = proxy_router.with_sts(sts_creds, jwks_cache, token_key.clone()); + proxy_router = proxy_router.with_sts("/.sts", sts_creds, jwks_cache, token_key.clone()); // Build the gateway with the router. let mut handler = ProxyGateway::new( diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index c412362..01071b3 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -29,7 +29,7 @@ def assume_role(role_arn: str, oidc_token: str) -> dict: """Assume a role via the STS proxy and return parsed credentials.""" resp = requests.get( - PROXY_URL, + f"{PROXY_URL}/.sts", params={ "Action": "AssumeRoleWithWebIdentity", "RoleArn": role_arn, diff --git a/tests/smoke/test_smoke.py b/tests/smoke/test_smoke.py index a683942..dc9735b 100644 --- a/tests/smoke/test_smoke.py +++ b/tests/smoke/test_smoke.py @@ -21,7 +21,7 @@ def assume_role(role_arn: str, oidc_token: str) -> dict: """Assume a role via the STS proxy and return parsed credentials.""" resp = requests.get( - DEPLOY_URL, + f"{DEPLOY_URL}/.sts", params={ "Action": "AssumeRoleWithWebIdentity", "RoleArn": role_arn,