Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,6 @@ markers = [
"config: Configuration parsing tests",
"session: Client session management tests",
"auth: Authentication/Authorization tests",
"generics: Generic endpoint tests",
]
asyncio_mode = "auto"
4 changes: 4 additions & 0 deletions src/asfquart/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ async def oauth_endpoint():
login_uri = quart.request.args.get("login")
logout_uri = quart.request.args.get("logout")
if login_uri or quart.request.query_string == b"login":
if login_uri and ((not login_uri.startswith("/")) or login_uri.startswith("//")):
return quart.Response(status=400, response="Invalid redirect URI.\n")
state = secrets.token_hex(16)
# Save the time we initialized this state and the optional login redirect URI
pending_states[state] = [time.time(), login_uri]
Expand All @@ -55,6 +57,8 @@ async def oauth_endpoint():
elif logout_uri or quart.request.query_string == b"logout":
asfquart.session.clear()
if logout_uri: # if called with /auth=logout=/foo, redirect to /foo
if (not logout_uri.startswith("/")) or logout_uri.startswith("//"):
return quart.Response(status=400, response="Invalid redirect URI.\n")
return quart.redirect(logout_uri)
return quart.Response(
status=200,
Expand Down
147 changes: 147 additions & 0 deletions tests/generics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""Tests for generics.py — redirect URI validation (CWE-601, CWE-79)"""

import itertools

import pytest
import quart

import asfquart


# Counter for unique app names to avoid duplicate route registration
_counter = itertools.count()


def _make_app():
"""Create a minimal Quart app with the OAuth endpoint for testing.
asfquart.construct() calls setup_oauth() internally when oauth=True (the default),
so we do NOT call setup_oauth() again here.
"""
name = f"test_generics_{next(_counter)}"
app = asfquart.construct(name, token_file=None)
return app


# ---- Endpoint integration tests ----

@pytest.mark.generics
async def test_login_with_valid_redirect():
"""?login=/dashboard should initiate OAuth flow (302 to oauth.apache.org)."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?login=/dashboard")
assert resp.status_code == 302
location = resp.headers.get("Location", "")
assert "oauth.apache.org" in location


@pytest.mark.generics
async def test_login_bare():
"""Bare ?login (no redirect value) should initiate OAuth flow normally."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?login")
assert resp.status_code == 302
location = resp.headers.get("Location", "")
assert "oauth.apache.org" in location


@pytest.mark.generics
async def test_login_rejects_javascript_uri():
"""?login=javascript:... must return 400."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?login=javascript:alert(1)")
assert resp.status_code == 400
body = (await resp.get_data()).decode()
assert "Invalid redirect" in body


@pytest.mark.generics
async def test_login_rejects_absolute_url():
"""?login=https://evil.com must return 400."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?login=https://evil.com")
assert resp.status_code == 400


@pytest.mark.generics
async def test_login_rejects_protocol_relative():
"""?login=//evil.com must return 400."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?login=//evil.com")
assert resp.status_code == 400


@pytest.mark.generics
async def test_login_rejects_data_uri():
"""?login=data:text/html,... must return 400."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?login=data:text/html,<script>alert(1)</script>")
assert resp.status_code == 400


@pytest.mark.generics
async def test_logout_rejects_javascript_uri():
"""?logout=javascript:... must return 400."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?logout=javascript:alert(1)")
assert resp.status_code == 400
body = (await resp.get_data()).decode()
assert "Invalid redirect" in body


@pytest.mark.generics
async def test_logout_rejects_absolute_url():
"""?logout=https://evil.com must return 400."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?logout=https://evil.com")
assert resp.status_code == 400


@pytest.mark.generics
async def test_logout_bare():
"""Bare ?logout (no redirect value) should clear session and return 200."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?logout")
assert resp.status_code == 200
body = (await resp.get_data()).decode()
assert "goodbye" in body.lower()


@pytest.mark.generics
async def test_logout_with_valid_redirect():
"""?logout=/goodbye should clear session and redirect."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth?logout=/goodbye")
assert resp.status_code == 302
location = resp.headers.get("Location", "")
assert "/goodbye" in location


@pytest.mark.generics
async def test_no_session_returns_404():
"""Bare /auth with no session should return 404."""
app = _make_app()
async with app.test_app():
client = app.test_client()
resp = await client.get("/auth")
assert resp.status_code == 404
Loading