Skip to content

Commit bd1adbc

Browse files
committed
feat: add exception handling for SuperTokens in Litestar framework
1 parent a9690e6 commit bd1adbc

File tree

4 files changed

+202
-10
lines changed

4 files changed

+202
-10
lines changed

supertokens_python/framework/litestar/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15+
from .litestar_exception_handlers import (
16+
get_exception_handlers,
17+
supertokens_exception_handler,
18+
)
1519
from .litestar_middleware import create_supertokens_middleware
1620
from .litestar_plugin import SupertokensPlugin, get_supertokens_plugin
1721

1822
__all__ = [
1923
"SupertokensPlugin",
2024
"get_supertokens_plugin",
2125
"create_supertokens_middleware",
26+
"get_exception_handlers",
27+
"supertokens_exception_handler",
2228
]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
import asyncio
15+
from typing import Any
16+
17+
from litestar import Request, Response
18+
19+
from supertokens_python.exceptions import SuperTokensError
20+
21+
22+
def supertokens_exception_handler(
23+
request: Request[Any, Any, Any], exc: SuperTokensError
24+
) -> Response[Any]:
25+
"""
26+
Exception handler for SuperTokens errors in Litestar applications.
27+
28+
This handler intercepts SuperTokens exceptions and converts them to proper
29+
HTTP responses with session management.
30+
31+
Note: This is a synchronous wrapper around async SuperTokens error handling.
32+
Litestar requires exception handlers to be synchronous, so we run the async
33+
logic in the event loop.
34+
35+
Args:
36+
request: The Litestar request object
37+
exc: The SuperTokens exception
38+
39+
Returns:
40+
A Litestar Response object with proper status code and session cookies
41+
"""
42+
from supertokens_python import Supertokens
43+
from supertokens_python.framework.litestar.litestar_request import LitestarRequest
44+
from supertokens_python.framework.litestar.litestar_response import LitestarResponse
45+
from supertokens_python.utils import default_user_context
46+
47+
async def handle_async() -> Response[Any]:
48+
"""Async logic for handling the SuperTokens error"""
49+
st = Supertokens.get_instance()
50+
custom_request = LitestarRequest(request)
51+
user_context = default_user_context(custom_request)
52+
53+
# Create a response for SuperTokens to populate
54+
response_obj = Response(content=None)
55+
response = LitestarResponse(response_obj)
56+
57+
# Handle the error through SuperTokens
58+
# This will modify the response object with proper status code,
59+
# clear tokens, and set appropriate headers
60+
result = await st.handle_supertokens_error(
61+
custom_request,
62+
exc,
63+
response,
64+
user_context, # type: ignore
65+
)
66+
67+
# Return the modified Litestar response
68+
# The response object has been updated by SuperTokens error handlers
69+
if isinstance(result, LitestarResponse):
70+
litestar_response = result.response
71+
72+
# Clear the session from request.state to prevent the middleware
73+
# from re-applying session cookies after we've cleared them
74+
if hasattr(request.state, "supertokens"):
75+
delattr(request.state, "supertokens")
76+
77+
# Litestar stores cookies separately from headers. When an exception
78+
# handler returns a response, Litestar will automatically convert cookies
79+
# to Set-Cookie headers. So we can just return the response as-is.
80+
return litestar_response
81+
82+
# Fallback to a generic error response
83+
return Response(
84+
content={"message": str(exc)},
85+
status_code=500,
86+
)
87+
88+
# Run the async logic in the event loop
89+
try:
90+
loop = asyncio.get_running_loop()
91+
except RuntimeError:
92+
# No event loop running, create one
93+
loop = asyncio.new_event_loop()
94+
asyncio.set_event_loop(loop)
95+
try:
96+
return loop.run_until_complete(handle_async())
97+
finally:
98+
loop.close()
99+
else:
100+
# Event loop is already running (which is the case in Litestar)
101+
# Create a task and run it
102+
import nest_asyncio # type: ignore
103+
104+
nest_asyncio.apply()
105+
return loop.run_until_complete(handle_async())
106+
107+
108+
def get_exception_handlers() -> dict[int | type[Exception], Any]:
109+
"""
110+
Get exception handlers for SuperTokens errors.
111+
112+
Returns:
113+
A dictionary mapping exception types to handler functions.
114+
115+
Example:
116+
```python
117+
from litestar import Litestar
118+
from supertokens_python.framework.litestar import (
119+
get_exception_handlers,
120+
get_supertokens_plugin,
121+
create_supertokens_middleware,
122+
)
123+
124+
app = Litestar(
125+
route_handlers=[...],
126+
middleware=[create_supertokens_middleware()],
127+
plugins=[get_supertokens_plugin(api_base_path="/auth")],
128+
exception_handlers=get_exception_handlers(),
129+
)
130+
```
131+
"""
132+
return {SuperTokensError: supertokens_exception_handler} # type: ignore

supertokens_python/framework/litestar/litestar_plugin.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,37 @@ class SupertokensPlugin(InitPluginProtocol):
2626
that processes SuperTokens authentication requests.
2727
"""
2828

29-
def __init__(self, api_base_path: str = "/auth"):
29+
def __init__(self, api_base_path: str = "/auth", app_root_path: str = ""):
3030
"""
3131
Initialize the SuperTokens plugin.
3232
3333
Args:
3434
api_base_path: The base path for SuperTokens API routes (default: "/auth")
35+
app_root_path: The root path of the Litestar app (e.g., "api/v1").
36+
If provided, will be stripped from api_base_path for mounting.
3537
"""
36-
self.api_base_path = api_base_path.rstrip("/")
38+
self.full_api_base_path = api_base_path.rstrip("/")
39+
self.app_root_path = app_root_path.strip("/")
40+
41+
# Calculate the mount path by removing app_root_path prefix from api_base_path
42+
if self.app_root_path and self.full_api_base_path.startswith(
43+
f"/{self.app_root_path}"
44+
):
45+
# Remove the root path prefix for mounting
46+
self.api_base_path = self.full_api_base_path[
47+
len(f"/{self.app_root_path}") :
48+
]
49+
elif self.app_root_path and self.full_api_base_path.startswith(
50+
self.app_root_path
51+
):
52+
# Handle case where api_base_path doesn't start with /
53+
self.api_base_path = self.full_api_base_path[len(self.app_root_path) :]
54+
else:
55+
self.api_base_path = self.full_api_base_path
56+
57+
# Ensure it starts with /
58+
if not self.api_base_path.startswith("/"):
59+
self.api_base_path = f"/{self.api_base_path}"
3760

3861
def on_app_init(self, app_config: AppConfig) -> AppConfig:
3962
"""
@@ -123,6 +146,11 @@ async def supertokens_asgi_app(
123146
custom_request, e, error_response, user_context
124147
)
125148

149+
# Clear the session from request.state to prevent the middleware
150+
# from re-applying session cookies after we've cleared them
151+
if hasattr(litestar_request.state, "supertokens"):
152+
delattr(litestar_request.state, "supertokens")
153+
126154
if isinstance(result, LitestarResponse):
127155
asgi_response = result.response.to_asgi_response(
128156
app=None, request=None
@@ -143,14 +171,31 @@ async def supertokens_asgi_app(
143171
return app_config
144172

145173

146-
def get_supertokens_plugin(api_base_path: str = "/auth") -> SupertokensPlugin:
174+
def get_supertokens_plugin(
175+
api_base_path: str = "/auth", app_root_path: str = ""
176+
) -> SupertokensPlugin:
147177
"""
148178
Get a configured SuperTokens plugin for Litestar.
149179
150180
Args:
151-
api_base_path: The base path for SuperTokens API routes (default: "/auth")
181+
api_base_path: The base path for SuperTokens API routes (default: "/auth").
182+
This should match the api_base_path in your SuperTokens init().
183+
app_root_path: The root path of your Litestar app if using app path (e.g., "api/v1").
184+
This will be automatically stripped from api_base_path for proper mounting.
152185
153186
Returns:
154187
A configured SupertokensPlugin instance
188+
189+
Example:
190+
# Without app root path
191+
app = Litestar(
192+
plugins=[get_supertokens_plugin(api_base_path="/auth")]
193+
)
194+
195+
# With app root path
196+
app = Litestar(
197+
path="api/v1",
198+
plugins=[get_supertokens_plugin(api_base_path="/api/v1/auth", app_root_path="api/v1")]
199+
)
155200
"""
156-
return SupertokensPlugin(api_base_path=api_base_path)
201+
return SupertokensPlugin(api_base_path=api_base_path, app_root_path=app_root_path)

tests/litestar/test_litestar.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
from typing import Any, Dict, Optional, Union
1616
from unittest import skip
1717

18-
from litestar import Litestar, MediaType, Request, Response, get, post
18+
from litestar import Litestar, Request, Response, get, post
1919
from litestar.di import Provide
2020
from litestar.testing import TestClient
2121
from pytest import fixture, mark
2222
from supertokens_python import InputAppInfo, SupertokensConfig, init
2323
from supertokens_python.framework import BaseRequest
2424
from supertokens_python.framework.litestar import (
2525
create_supertokens_middleware,
26+
get_exception_handlers,
2627
get_supertokens_plugin,
2728
)
2829
from supertokens_python.querier import Querier
@@ -138,16 +139,16 @@ async def custom_logout(request: Request[Any, Any, Any]) -> Dict[str, Any]:
138139
await session.revoke_session()
139140
return {}
140141

141-
@post("/create", media_type=MediaType.TEXT)
142-
async def _create(request: Request[Any, Any, Any]) -> str:
142+
@post("/create")
143+
async def _create(request: Request[Any, Any, Any]) -> Dict[str, Any]:
143144
await create_new_session(
144145
request=request,
145146
tenant_id="public",
146147
recipe_user_id=RecipeUserId("userId"),
147148
access_token_payload={},
148149
session_data_in_database={},
149150
)
150-
return ""
151+
return {}
151152

152153
@post("/create-throw")
153154
async def _create_throw(request: Request[Any, Any, Any]) -> None:
@@ -174,6 +175,7 @@ async def _create_throw(request: Request[Any, Any, Any]) -> None:
174175
],
175176
middleware=[create_supertokens_middleware()],
176177
plugins=[get_supertokens_plugin(api_base_path="/auth")],
178+
exception_handlers=get_exception_handlers(),
177179
)
178180
return TestClient(app)
179181

@@ -577,7 +579,12 @@ def test_litestar_root_path(litestar_root_path: str):
577579
app = Litestar(
578580
path=litestar_root_path,
579581
middleware=[create_supertokens_middleware()],
580-
plugins=[get_supertokens_plugin(api_base_path=f"{litestar_root_path}/auth")],
582+
plugins=[
583+
get_supertokens_plugin(
584+
api_base_path=f"{litestar_root_path}/auth",
585+
app_root_path=litestar_root_path,
586+
)
587+
],
581588
)
582589

583590
test_client = TestClient(app)
@@ -636,6 +643,8 @@ async def refresh_post(
636643
if token_transfer_method == "header":
637644
headers.update({"authorization": f"Bearer {info['refreshTokenFromAny']}"})
638645
else:
646+
# Clear existing cookies from the client to avoid duplicate cookie issues
647+
driver_config_client.cookies.clear()
639648
cookies.update(
640649
{"sRefreshToken": info["refreshTokenFromAny"], "sIdRefreshToken": "asdf"}
641650
)

0 commit comments

Comments
 (0)