Skip to content

Commit 54925f7

Browse files
committed
feat: update Litestar framework integration with new request and response modules
1 parent e931dc0 commit 54925f7

File tree

9 files changed

+140
-130
lines changed

9 files changed

+140
-130
lines changed

examples/with-litestar/with-thirdpartyemailpassword/main.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
get_all_cors_headers,
1313
init,
1414
)
15-
from supertokens_python.framework.litestar.middleware import LitestarMiddleware
1615
from supertokens_python.recipe import (
1716
dashboard,
1817
emailverification,
@@ -179,9 +178,6 @@ def f_405(_, __: Exception):
179178
route_handlers=[
180179
get_session_info,
181180
],
182-
middleware=[
183-
LitestarMiddleware(),
184-
],
185181
cors_config=cors,
186182
exception_handlers={
187183
Exception: f_405,

supertokens_python/framework/litestar/__init__.py

Whitespace-only changes.

supertokens_python/framework/litestar/framework.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from typing import Any
1515

1616
from litestar import Request
17+
1718
from supertokens_python.framework.types import Framework
1819

1920

2021
class LitestarFramework(Framework):
2122
def wrap_request(self, unwrapped: Request[Any, Any, Any]):
22-
from supertokens_python.framework.litestar.request import (
23+
from supertokens_python.framework.litestar.litestar_request import (
2324
LitestarRequest,
2425
)
2526

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from typing import Any, Union, cast
15+
16+
17+
def get_middleware():
18+
from litestar import Request as LitestarRequestClass
19+
from litestar import Response as LitestarResponseClass
20+
from starlette.responses import Response as StarletteResponse
21+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
22+
23+
from supertokens_python import Supertokens
24+
from supertokens_python.exceptions import SuperTokensError
25+
from supertokens_python.framework import BaseResponse
26+
from supertokens_python.framework.litestar.litestar_request import (
27+
LitestarRequest,
28+
)
29+
from supertokens_python.framework.litestar.litestar_response import (
30+
LitestarResponse,
31+
)
32+
from supertokens_python.recipe.session import SessionContainer
33+
from supertokens_python.supertokens import manage_session_post_response
34+
from supertokens_python.utils import default_user_context
35+
36+
class ASGIMiddleware:
37+
def __init__(self, app: ASGIApp) -> None:
38+
self.app = app
39+
40+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
41+
if scope["type"] != "http":
42+
await self.app(scope, receive, send)
43+
return
44+
45+
st = Supertokens.get_instance()
46+
47+
request = LitestarRequestClass[Any, Any, Any](
48+
cast(Any, scope), receive=cast(Any, receive)
49+
)
50+
custom_request = LitestarRequest(request)
51+
user_context = default_user_context(custom_request)
52+
53+
try:
54+
response = LitestarResponse(LitestarResponseClass[Any](content={}))
55+
result: Union[BaseResponse, None] = await st.middleware(
56+
custom_request, response, user_context
57+
)
58+
if result is None:
59+
60+
async def send_wrapper(message: Message):
61+
if message["type"] == "http.response.start":
62+
if hasattr(request.state, "supertokens") and isinstance(
63+
request.state.supertokens, SessionContainer
64+
):
65+
starlette_response = StarletteResponse()
66+
starlette_response.raw_headers = message["headers"]
67+
response = LitestarResponse(
68+
cast(LitestarResponseClass[Any], starlette_response)
69+
)
70+
manage_session_post_response(
71+
request.state.supertokens, response, user_context
72+
)
73+
message["headers"] = starlette_response.raw_headers
74+
75+
await send(message)
76+
77+
await self.app(scope, receive, send_wrapper)
78+
return
79+
80+
if hasattr(request.state, "supertokens") and isinstance(
81+
request.state.supertokens, SessionContainer
82+
):
83+
manage_session_post_response(
84+
request.state.supertokens, result, user_context
85+
)
86+
87+
if isinstance(result, LitestarResponse):
88+
resp_any = cast(Any, result.response)
89+
asgi_response = resp_any.to_asgi_response(app=None, request=request)
90+
await asgi_response(scope, receive, send)
91+
return
92+
93+
return
94+
95+
except SuperTokensError as e:
96+
response = LitestarResponse(LitestarResponseClass[Any](content={}))
97+
result: Union[BaseResponse, None] = await st.handle_supertokens_error(
98+
LitestarRequest(request), e, response, user_context
99+
)
100+
if isinstance(result, LitestarResponse):
101+
resp_any = cast(Any, result.response)
102+
asgi_response = resp_any.to_asgi_response(app=None, request=request)
103+
await asgi_response(scope, receive, send)
104+
return
105+
106+
async def send_wrapper(message: Message):
107+
if message["type"] == "http.response.start":
108+
if hasattr(request.state, "supertokens") and isinstance(
109+
request.state.supertokens, SessionContainer
110+
):
111+
starlette_response = StarletteResponse()
112+
starlette_response.raw_headers = message["headers"]
113+
response = LitestarResponse(
114+
cast(LitestarResponseClass[Any], starlette_response)
115+
)
116+
manage_session_post_response(
117+
request.state.supertokens, response, user_context
118+
)
119+
message["headers"] = starlette_response.raw_headers
120+
await send(message)
121+
122+
await self.app(scope, receive, send_wrapper)
123+
124+
return ASGIMiddleware

supertokens_python/framework/litestar/request.py renamed to supertokens_python/framework/litestar/litestar_request.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
import json
15-
from typing import Any, Union
15+
from typing import Any, Dict, Union
1616
from urllib.parse import parse_qsl
1717

1818
from litestar import Request
19+
1920
from supertokens_python.framework.request import BaseRequest
2021
from supertokens_python.recipe.session.interfaces import SessionContainer
2122

@@ -33,10 +34,10 @@ def get_query_param(
3334
) -> Union[str, None]:
3435
return self.request.query_params.get(key, default)
3536

36-
def get_query_params(self) -> dict[str, Any]:
37+
def get_query_params(self) -> Dict[str, Any]:
3738
return dict(self.request.query_params.items()) # type: ignore
3839

39-
async def json(self) -> dict[str, Any]:
40+
async def json(self) -> Dict[str, Any]:
4041
"""
4142
Read the entire ASGI stream and JSON-decode it,
4243
sidestepping Litestar’s internal max-body-size logic.

supertokens_python/framework/litestar/response.py renamed to supertokens_python/framework/litestar/litestar_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from litestar import Response
1818
from litestar.serialization import encode_json
19+
1920
from supertokens_python.framework.response import BaseResponse
2021
from supertokens_python.utils import get_timestamp_ms
2122

@@ -60,7 +61,7 @@ def set_cookie(
6061
)
6162

6263
def set_header(self, key: str, value: str):
63-
self.response.set_header(key, value)
64+
self.response.headers[key] = value
6465

6566
def get_header(self, key: str) -> Optional[str]:
6667
return self.response.headers.get(key, None)

supertokens_python/framework/litestar/middleware.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

supertokens_python/recipe/session/framework/litestar/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
from supertokens_python import Supertokens
2020
from supertokens_python.exceptions import SuperTokensError
21-
from supertokens_python.framework.litestar.request import LitestarRequest
22-
from supertokens_python.framework.litestar.response import LitestarResponse
21+
from supertokens_python.framework.litestar.litestar_request import LitestarRequest
22+
from supertokens_python.framework.litestar.litestar_response import LitestarResponse
2323
from supertokens_python.recipe.session import SessionContainer, SessionRecipe
2424
from supertokens_python.recipe.session.interfaces import SessionClaimValidator
2525
from supertokens_python.types import MaybeAwaitable

tests/litestar/test_litestar.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytest import fixture, mark
2222
from supertokens_python import InputAppInfo, SupertokensConfig, init
2323
from supertokens_python.framework import BaseRequest
24-
from supertokens_python.framework.litestar.middleware import LitestarMiddleware
24+
from supertokens_python.framework.litestar.litestar_middleware import get_middleware
2525
from supertokens_python.querier import Querier
2626
from supertokens_python.recipe import emailpassword, session, thirdparty
2727
from supertokens_python.recipe.dashboard import DashboardRecipe, InputOverrideConfig
@@ -169,12 +169,9 @@ async def _create_throw(request: Request[Any, Any, Any]) -> None:
169169
_create,
170170
_create_throw,
171171
],
172-
middleware=[
173-
LitestarMiddleware(),
174-
],
175172
)
176-
177-
return TestClient(app)
173+
return TestClient(get_middleware()(app))
174+
# return TestClient(app)
178175

179176

180177
def apis_override_session(param: APIInterface):
@@ -575,11 +572,10 @@ def test_litestar_root_path(litestar_root_path: str):
575572
litestar_root_path = litestar_root_path[1:]
576573
app = Litestar(
577574
path=litestar_root_path,
578-
middleware=[
579-
LitestarMiddleware(),
580-
],
581575
)
582-
test_client = TestClient(app)
576+
577+
test_client = TestClient(get_middleware()(app))
578+
# test_client = TestClient(app)
583579

584580
response = test_client.get(
585581
f"{litestar_root_path}/auth/signup/email/exists?email=test@example.com"

0 commit comments

Comments
 (0)