From 77fb30d585d88a341b7d0e5aec40d3fa12bababf Mon Sep 17 00:00:00 2001 From: "Andrew K. Musselman" Date: Fri, 27 Feb 2026 18:29:00 -0800 Subject: [PATCH] Adding checks for redirect url and tests; fixes #58 --- pyproject.toml | 1 + src/asfquart/generics.py | 4 ++ tests/generics.py | 147 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+) create mode 100644 tests/generics.py diff --git a/pyproject.toml b/pyproject.toml index 3cf9053..f09610a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,5 +57,6 @@ markers = [ "config: Configuration parsing tests", "session: Client session management tests", "auth: Authentication/Authorization tests", + "generics: Generic endpoint tests", ] asyncio_mode = "auto" diff --git a/src/asfquart/generics.py b/src/asfquart/generics.py index 1b9e5e6..233dac7 100644 --- a/src/asfquart/generics.py +++ b/src/asfquart/generics.py @@ -40,6 +40,8 @@ async def oauth_endpoint(): login_uri = quart.request.args.get("login") logout_uri = quart.request.args.get("logout") if login_uri or quart.request.query_string == b"login": + if login_uri and ((not login_uri.startswith("/")) or login_uri.startswith("//")): + return quart.Response(status=400, response="Invalid redirect URI.\n") state = secrets.token_hex(16) # Save the time we initialized this state and the optional login redirect URI pending_states[state] = [time.time(), login_uri] @@ -55,6 +57,8 @@ async def oauth_endpoint(): elif logout_uri or quart.request.query_string == b"logout": asfquart.session.clear() if logout_uri: # if called with /auth=logout=/foo, redirect to /foo + if (not logout_uri.startswith("/")) or logout_uri.startswith("//"): + return quart.Response(status=400, response="Invalid redirect URI.\n") return quart.redirect(logout_uri) return quart.Response( status=200, diff --git a/tests/generics.py b/tests/generics.py new file mode 100644 index 0000000..9ce41f7 --- /dev/null +++ b/tests/generics.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +"""Tests for generics.py — redirect URI validation (CWE-601, CWE-79)""" + +import itertools + +import pytest +import quart + +import asfquart + + +# Counter for unique app names to avoid duplicate route registration +_counter = itertools.count() + + +def _make_app(): + """Create a minimal Quart app with the OAuth endpoint for testing. + asfquart.construct() calls setup_oauth() internally when oauth=True (the default), + so we do NOT call setup_oauth() again here. + """ + name = f"test_generics_{next(_counter)}" + app = asfquart.construct(name, token_file=None) + return app + + +# ---- Endpoint integration tests ---- + +@pytest.mark.generics +async def test_login_with_valid_redirect(): + """?login=/dashboard should initiate OAuth flow (302 to oauth.apache.org).""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?login=/dashboard") + assert resp.status_code == 302 + location = resp.headers.get("Location", "") + assert "oauth.apache.org" in location + + +@pytest.mark.generics +async def test_login_bare(): + """Bare ?login (no redirect value) should initiate OAuth flow normally.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?login") + assert resp.status_code == 302 + location = resp.headers.get("Location", "") + assert "oauth.apache.org" in location + + +@pytest.mark.generics +async def test_login_rejects_javascript_uri(): + """?login=javascript:... must return 400.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?login=javascript:alert(1)") + assert resp.status_code == 400 + body = (await resp.get_data()).decode() + assert "Invalid redirect" in body + + +@pytest.mark.generics +async def test_login_rejects_absolute_url(): + """?login=https://evil.com must return 400.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?login=https://evil.com") + assert resp.status_code == 400 + + +@pytest.mark.generics +async def test_login_rejects_protocol_relative(): + """?login=//evil.com must return 400.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?login=//evil.com") + assert resp.status_code == 400 + + +@pytest.mark.generics +async def test_login_rejects_data_uri(): + """?login=data:text/html,... must return 400.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?login=data:text/html,") + assert resp.status_code == 400 + + +@pytest.mark.generics +async def test_logout_rejects_javascript_uri(): + """?logout=javascript:... must return 400.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?logout=javascript:alert(1)") + assert resp.status_code == 400 + body = (await resp.get_data()).decode() + assert "Invalid redirect" in body + + +@pytest.mark.generics +async def test_logout_rejects_absolute_url(): + """?logout=https://evil.com must return 400.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?logout=https://evil.com") + assert resp.status_code == 400 + + +@pytest.mark.generics +async def test_logout_bare(): + """Bare ?logout (no redirect value) should clear session and return 200.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?logout") + assert resp.status_code == 200 + body = (await resp.get_data()).decode() + assert "goodbye" in body.lower() + + +@pytest.mark.generics +async def test_logout_with_valid_redirect(): + """?logout=/goodbye should clear session and redirect.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth?logout=/goodbye") + assert resp.status_code == 302 + location = resp.headers.get("Location", "") + assert "/goodbye" in location + + +@pytest.mark.generics +async def test_no_session_returns_404(): + """Bare /auth with no session should return 404.""" + app = _make_app() + async with app.test_app(): + client = app.test_client() + resp = await client.get("/auth") + assert resp.status_code == 404