Skip to content

Commit a9690e6

Browse files
committed
feat: integrate SuperTokens middleware and plugin for Litestar framework
1 parent b338d5e commit a9690e6

File tree

5 files changed

+260
-111
lines changed

5 files changed

+260
-111
lines changed

examples/with-litestar/with-thirdpartyemailpassword/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
get_all_cors_headers,
1313
init,
1414
)
15+
from supertokens_python.framework.litestar import (
16+
create_supertokens_middleware,
17+
get_supertokens_plugin,
18+
)
1519
from supertokens_python.recipe import (
1620
dashboard,
1721
emailverification,
@@ -182,6 +186,8 @@ def f_405(_, __: Exception):
182186
exception_handlers={
183187
Exception: f_405,
184188
},
189+
middleware=[create_supertokens_middleware()],
190+
plugins=[get_supertokens_plugin(api_base_path="/auth")],
185191
)
186192

187193
if __name__ == "__main__":

supertokens_python/framework/litestar/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14+
15+
from .litestar_middleware import create_supertokens_middleware
16+
from .litestar_plugin import SupertokensPlugin, get_supertokens_plugin
17+
18+
__all__ = [
19+
"SupertokensPlugin",
20+
"get_supertokens_plugin",
21+
"create_supertokens_middleware",
22+
]

supertokens_python/framework/litestar/litestar_middleware.py

Lines changed: 79 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -11,114 +11,87 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from typing import Any, Union, cast
15-
16-
17-
def get_middleware():
18-
from litestar import Request as LitestarRequestClass
19-
from litestar import Response as LitestarResponseClass
20-
from starlette.responses import Response as StarletteResponse
21-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
22-
23-
from supertokens_python import Supertokens
24-
from supertokens_python.exceptions import SuperTokensError
25-
from supertokens_python.framework import BaseResponse
26-
from supertokens_python.framework.litestar.litestar_request import (
27-
LitestarRequest,
28-
)
29-
from supertokens_python.framework.litestar.litestar_response import (
30-
LitestarResponse,
31-
)
32-
from supertokens_python.recipe.session import SessionContainer
33-
from supertokens_python.supertokens import manage_session_post_response
34-
from supertokens_python.utils import default_user_context
35-
36-
class ASGIMiddleware:
37-
def __init__(self, app: ASGIApp) -> None:
38-
self.app = app
39-
40-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
41-
if scope["type"] != "http":
42-
await self.app(scope, receive, send)
43-
return
44-
45-
st = Supertokens.get_instance()
46-
47-
request = LitestarRequestClass[Any, Any, Any](
48-
cast(Any, scope), receive=cast(Any, receive)
49-
)
50-
custom_request = LitestarRequest(request)
51-
user_context = default_user_context(custom_request)
52-
53-
try:
54-
response = LitestarResponse(LitestarResponseClass[Any](content={}))
55-
result: Union[BaseResponse, None] = await st.middleware(
56-
custom_request, response, user_context
57-
)
58-
if result is None:
59-
60-
async def send_wrapper(message: Message):
61-
if message["type"] == "http.response.start":
62-
if hasattr(request.state, "supertokens") and isinstance(
63-
request.state.supertokens, SessionContainer
64-
):
65-
starlette_response = StarletteResponse()
66-
starlette_response.raw_headers = message["headers"]
67-
response = LitestarResponse(
68-
cast(LitestarResponseClass[Any], starlette_response)
69-
)
70-
manage_session_post_response(
71-
request.state.supertokens, response, user_context
72-
)
73-
message["headers"] = starlette_response.raw_headers
74-
75-
await send(message)
76-
77-
await self.app(scope, receive, send_wrapper)
78-
return
79-
14+
from litestar.middleware import DefineMiddleware
15+
from litestar.types import ASGIApp, Receive, Scope, Send
16+
17+
18+
class SupertokensSessionMiddleware:
19+
"""
20+
Middleware to handle session management for non-auth routes.
21+
22+
This middleware applies session-related response mutators (like setting cookies)
23+
for routes that use SuperTokens sessions but aren't auth routes.
24+
"""
25+
26+
def __init__(self, app: ASGIApp) -> None:
27+
self.app = app
28+
29+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
30+
if scope["type"] != "http":
31+
await self.app(scope, receive, send)
32+
return
33+
34+
from litestar import Request
35+
from litestar import Response as LitestarResponseObj
36+
from litestar.types import Message
37+
38+
from supertokens_python.framework.litestar.litestar_request import (
39+
LitestarRequest,
40+
)
41+
from supertokens_python.framework.litestar.litestar_response import (
42+
LitestarResponse,
43+
)
44+
from supertokens_python.recipe.session import SessionContainer
45+
from supertokens_python.supertokens import manage_session_post_response
46+
from supertokens_python.utils import default_user_context
47+
48+
request = Request(scope, receive=receive, send=send) # type: ignore
49+
custom_request = LitestarRequest(request)
50+
user_context = default_user_context(custom_request)
51+
52+
async def send_wrapper(message: Message) -> None:
53+
if message["type"] == "http.response.start":
54+
# Apply session mutators to response headers
8055
if hasattr(request.state, "supertokens") and isinstance(
81-
request.state.supertokens, SessionContainer
56+
getattr(request.state, "supertokens", None), SessionContainer
8257
):
58+
# Create a temporary Litestar Response
59+
temp_response = LitestarResponseObj(content=None)
60+
61+
# Convert raw ASGI headers to dict for Litestar Response
62+
for name, value in message.get("headers", []): # type: ignore
63+
temp_response.headers[
64+
name.decode() if isinstance(name, bytes) else name
65+
] = value.decode() if isinstance(value, bytes) else value
66+
67+
# Wrap it for SuperTokens
68+
wrapped_response = LitestarResponse(temp_response)
69+
70+
# Apply session mutators (this will modify temp_response)
8371
manage_session_post_response(
84-
request.state.supertokens, result, user_context
72+
getattr(request.state, "supertokens"),
73+
wrapped_response,
74+
user_context,
8575
)
8676

87-
if isinstance(result, LitestarResponse):
88-
resp_any = cast(Any, result.response)
89-
asgi_response = resp_any.to_asgi_response(app=None, request=request)
90-
await asgi_response(scope, receive, send)
91-
return
92-
93-
return
94-
95-
except SuperTokensError as e:
96-
response = LitestarResponse(LitestarResponseClass[Any](content={}))
97-
result: Union[BaseResponse, None] = await st.handle_supertokens_error(
98-
LitestarRequest(request), e, response, user_context
99-
)
100-
if isinstance(result, LitestarResponse):
101-
resp_any = cast(Any, result.response)
102-
asgi_response = resp_any.to_asgi_response(app=None, request=request)
103-
await asgi_response(scope, receive, send)
104-
return
105-
106-
async def send_wrapper(message: Message):
107-
if message["type"] == "http.response.start":
108-
if hasattr(request.state, "supertokens") and isinstance(
109-
request.state.supertokens, SessionContainer
110-
):
111-
starlette_response = StarletteResponse()
112-
starlette_response.raw_headers = message["headers"]
113-
response = LitestarResponse(
114-
cast(LitestarResponseClass[Any], starlette_response)
115-
)
116-
manage_session_post_response(
117-
request.state.supertokens, response, user_context
118-
)
119-
message["headers"] = starlette_response.raw_headers
120-
await send(message)
121-
122-
await self.app(scope, receive, send_wrapper)
123-
124-
return ASGIMiddleware
77+
# Convert the Litestar Response to ASGI format to get cookies as Set-Cookie headers
78+
asgi_response = temp_response.to_asgi_response(
79+
app=None, request=None
80+
) # type: ignore
81+
82+
# Use the encoded headers which include Set-Cookie headers from cookies
83+
message["headers"] = asgi_response.encoded_headers
84+
85+
await send(message)
86+
87+
await self.app(scope, receive, send_wrapper)
88+
89+
90+
def create_supertokens_middleware() -> DefineMiddleware:
91+
"""
92+
Create a DefineMiddleware instance for SuperTokens session management.
93+
94+
Returns:
95+
A DefineMiddleware configured with SupertokensSessionMiddleware
96+
"""
97+
return DefineMiddleware(SupertokensSessionMiddleware)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from typing import Union
15+
16+
from litestar import asgi
17+
from litestar.config.app import AppConfig
18+
from litestar.plugins import InitPluginProtocol
19+
20+
21+
class SupertokensPlugin(InitPluginProtocol):
22+
"""
23+
Litestar plugin for SuperTokens integration.
24+
25+
This plugin handles authentication routes by mounting a custom ASGI app
26+
that processes SuperTokens authentication requests.
27+
"""
28+
29+
def __init__(self, api_base_path: str = "/auth"):
30+
"""
31+
Initialize the SuperTokens plugin.
32+
33+
Args:
34+
api_base_path: The base path for SuperTokens API routes (default: "/auth")
35+
"""
36+
self.api_base_path = api_base_path.rstrip("/")
37+
38+
def on_app_init(self, app_config: AppConfig) -> AppConfig:
39+
"""
40+
Called during app initialization to register the SuperTokens ASGI app.
41+
42+
Args:
43+
app_config: The Litestar application configuration
44+
45+
Returns:
46+
The modified application configuration
47+
"""
48+
from litestar import Request, Response
49+
from litestar.types import Receive, Scope, Send
50+
51+
from supertokens_python import Supertokens
52+
from supertokens_python.exceptions import SuperTokensError
53+
from supertokens_python.framework.litestar.litestar_request import (
54+
LitestarRequest,
55+
)
56+
from supertokens_python.framework.litestar.litestar_response import (
57+
LitestarResponse,
58+
)
59+
from supertokens_python.recipe.session import SessionContainer
60+
from supertokens_python.supertokens import manage_session_post_response
61+
from supertokens_python.utils import default_user_context
62+
63+
async def supertokens_asgi_app(
64+
scope: Scope, receive: Receive, send: Send
65+
) -> None:
66+
"""
67+
ASGI app that handles SuperTokens authentication requests.
68+
"""
69+
if scope["type"] != "http":
70+
# Pass through non-HTTP requests
71+
not_found = Response(content=None, status_code=404)
72+
await not_found.to_asgi_response(app=None, request=None)(
73+
scope, receive, send
74+
) # type: ignore
75+
return
76+
77+
st = Supertokens.get_instance()
78+
79+
# Create Litestar request and wrap it for SuperTokens
80+
litestar_request = Request(scope, receive=receive, send=send)
81+
custom_request = LitestarRequest(litestar_request)
82+
user_context = default_user_context(custom_request)
83+
84+
try:
85+
# Create a response object for SuperTokens to use
86+
litestar_response = Response(content=None)
87+
response = LitestarResponse(litestar_response)
88+
89+
# Let SuperTokens middleware handle the request
90+
result: Union[LitestarResponse, None] = await st.middleware(
91+
custom_request, response, user_context
92+
)
93+
94+
if result is None:
95+
# Request was not handled by SuperTokens
96+
not_found_response = Response(content=None, status_code=404)
97+
await not_found_response.to_asgi_response(app=None, request=None)(
98+
scope, receive, send
99+
) # type: ignore
100+
return
101+
102+
# Handle session management
103+
if hasattr(litestar_request.state, "supertokens") and isinstance(
104+
litestar_request.state.supertokens, SessionContainer
105+
):
106+
manage_session_post_response(
107+
litestar_request.state.supertokens, result, user_context
108+
)
109+
110+
# Send the response
111+
if isinstance(result, LitestarResponse):
112+
asgi_response = result.response.to_asgi_response(
113+
app=None, request=None
114+
) # type: ignore
115+
await asgi_response(scope, receive, send)
116+
return
117+
118+
except SuperTokensError as e:
119+
# Handle SuperTokens-specific errors
120+
error_response_obj = Response(content=None)
121+
error_response = LitestarResponse(error_response_obj)
122+
result = await st.handle_supertokens_error(
123+
custom_request, e, error_response, user_context
124+
)
125+
126+
if isinstance(result, LitestarResponse):
127+
asgi_response = result.response.to_asgi_response(
128+
app=None, request=None
129+
) # type: ignore
130+
await asgi_response(scope, receive, send)
131+
return
132+
133+
# Fallback - this should not normally be reached
134+
fallback_response = Response(content=None, status_code=500)
135+
await fallback_response.to_asgi_response(app=None, request=None)(
136+
scope, receive, send
137+
) # type: ignore
138+
139+
# Mount the SuperTokens ASGI app to handle auth routes
140+
app_mount = asgi(self.api_base_path, is_mount=True)(supertokens_asgi_app)
141+
app_config.route_handlers.append(app_mount)
142+
143+
return app_config
144+
145+
146+
def get_supertokens_plugin(api_base_path: str = "/auth") -> SupertokensPlugin:
147+
"""
148+
Get a configured SuperTokens plugin for Litestar.
149+
150+
Args:
151+
api_base_path: The base path for SuperTokens API routes (default: "/auth")
152+
153+
Returns:
154+
A configured SupertokensPlugin instance
155+
"""
156+
return SupertokensPlugin(api_base_path=api_base_path)

0 commit comments

Comments
 (0)