Skip to content

Commit dabb1dd

Browse files
committed
feat: add Litestar framework support and related middleware
1 parent 1d431dc commit dabb1dd

File tree

12 files changed

+1506
-5
lines changed

12 files changed

+1506
-5
lines changed

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ pyyaml==6.0.2
2121
requests-mock==1.12.1
2222
respx>=0.13.0, <1.0.0
2323
uvicorn==0.32.0
24+
litestar==2.16.0
2425
-e .

supertokens_python/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def init(
3232
app_info: InputAppInfo,
33-
framework: Literal["fastapi", "flask", "django"],
33+
framework: Literal["fastapi", "flask", "django", "litestar"],
3434
supertokens_config: SupertokensConfig,
3535
recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]],
3636
mode: Optional[Literal["asgi", "wsgi"]] = None,
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from __future__ import annotations
15+
16+
from typing import TYPE_CHECKING, Any
17+
18+
from supertokens_python.framework.types import Framework
19+
20+
if TYPE_CHECKING:
21+
from litestar import Request
22+
23+
24+
class LitestarFramework(Framework):
25+
def wrap_request(self, unwrapped: Request[Any, Any, Any]):
26+
from supertokens_python.framework.litestar.litestar_request import (
27+
LitestarRequest,
28+
)
29+
30+
return LitestarRequest(unwrapped)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from typing import Any
2+
3+
from litestar import Request, Response
4+
from litestar.datastructures import MutableScopeHeaders
5+
from litestar.middleware.base import AbstractMiddleware
6+
from litestar.types import Message, Receive, Scope, Send
7+
from supertokens_python import Supertokens
8+
from supertokens_python.exceptions import SuperTokensError
9+
from supertokens_python.framework.litestar.litestar_request import LitestarRequest
10+
from supertokens_python.framework.litestar.litestar_response import LitestarResponse
11+
from supertokens_python.recipe.session import SessionContainer
12+
from supertokens_python.supertokens import manage_session_post_response
13+
from supertokens_python.utils import default_user_context
14+
15+
16+
class LitestarMiddleware(AbstractMiddleware):
17+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
18+
if scope["type"] != "http":
19+
await self.app(scope, receive, send)
20+
return
21+
22+
st = Supertokens.get_instance()
23+
request = Request[Any, Any, Any](scope, receive=receive, send=send)
24+
custom_request = LitestarRequest(request)
25+
user_context = default_user_context(custom_request)
26+
27+
try:
28+
response = LitestarResponse(Response[Any](content=None))
29+
result = await st.middleware(custom_request, response, user_context)
30+
if result is not None:
31+
# SuperTokens handled the request
32+
if hasattr(request.state, "supertokens") and isinstance(
33+
request.state.supertokens, SessionContainer
34+
):
35+
manage_session_post_response(
36+
request.state.supertokens, result, user_context
37+
)
38+
# Add cookies using MutableScopeHeaders
39+
asgi_response = await result.response.to_asgi_response(
40+
app=None, request=request
41+
)
42+
43+
async def modified_send(message: Message):
44+
if message["type"] == "http.response.start":
45+
mutable_headers = MutableScopeHeaders(message)
46+
for cookie in result.response.cookies:
47+
cookie_value = cookie.to_header().split(": ", 1)[1]
48+
mutable_headers.add("set-cookie", cookie_value)
49+
await send(message)
50+
51+
await asgi_response(scope, receive, modified_send)
52+
return
53+
else:
54+
# SuperTokens didn’t handle the request; wrap the send function
55+
async def send_wrapper(message: Message):
56+
if message["type"] == "http.response.start":
57+
if hasattr(request.state, "supertokens") and isinstance(
58+
request.state.supertokens, SessionContainer
59+
):
60+
temp_response = Response[Any](content=None)
61+
temp_response.headers = MutableScopeHeaders(message)
62+
litestar_response = LitestarResponse(temp_response)
63+
manage_session_post_response(
64+
request.state.supertokens,
65+
litestar_response,
66+
user_context,
67+
)
68+
mutable_headers = MutableScopeHeaders(message)
69+
for cookie in litestar_response.response.cookies:
70+
cookie_value = cookie.to_header().split(": ", 1)[1]
71+
mutable_headers.add("set-cookie", cookie_value)
72+
await send(message)
73+
74+
await self.app(scope, receive, send_wrapper)
75+
return
76+
77+
except SuperTokensError as e:
78+
response = LitestarResponse(Response[Any](content=None))
79+
result = await st.handle_supertokens_error(
80+
custom_request, e, response, user_context
81+
)
82+
if isinstance(result, LitestarResponse):
83+
# Add cookies using MutableScopeHeaders
84+
asgi_response = await result.response.to_asgi_response(
85+
app=None, request=request
86+
)
87+
88+
async def modified_send(message: Message):
89+
if message["type"] == "http.response.start":
90+
mutable_headers = MutableScopeHeaders(message)
91+
for cookie in result.response.cookies:
92+
cookie_value = cookie.to_header().split(": ", 1)[1]
93+
mutable_headers.add("set-cookie", cookie_value)
94+
await send(message)
95+
96+
await asgi_response(scope, receive, modified_send)
97+
return
98+
raise Exception("Should never come here")
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from typing import Any, Dict, Union
15+
from urllib.parse import parse_qsl
16+
17+
from litestar import Request
18+
from litestar.exceptions import SerializationException
19+
from supertokens_python.framework.request import BaseRequest
20+
from supertokens_python.recipe.session.interfaces import SessionContainer
21+
22+
23+
class LitestarRequest(BaseRequest):
24+
def __init__(self, request: Request[Any, Any, Any]):
25+
super().__init__()
26+
self.request = request
27+
28+
def get_original_url(self) -> str:
29+
print(f"The request url is {self.request.url}")
30+
print(f"The request url is {self.request.url.from_components()}")
31+
return str(self.request.url)
32+
33+
def get_query_param(
34+
self, key: str, default: Union[str, None] = None
35+
) -> Union[str, None]:
36+
return self.request.query_params.get(key, default)
37+
38+
def get_query_params(self) -> Dict[str, Any]:
39+
return dict(self.request.query_params.items()) # type: ignore
40+
41+
async def json(self) -> Union[Any, None]:
42+
try:
43+
return await self.request.json()
44+
except SerializationException:
45+
return {}
46+
47+
def method(self) -> str:
48+
return self.request.method
49+
50+
def get_cookie(self, key: str) -> Union[str, None]:
51+
return self.request.cookies.get(key)
52+
53+
def get_header(self, key: str) -> Union[str, None]:
54+
return self.request.headers.get(key, None)
55+
56+
def get_session(self) -> Union[SessionContainer, None]:
57+
return self.request.state.supertokens
58+
59+
def set_session(self, session: SessionContainer):
60+
self.request.state.supertokens = session
61+
62+
def set_session_as_none(self):
63+
self.request.state.supertokens = None
64+
65+
def get_path(self) -> str:
66+
print(f"The request url is {self.request.url}")
67+
return self.request.url.path
68+
69+
async def form_data(self):
70+
return dict(parse_qsl((await self.request.body()).decode("utf-8")))
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from math import ceil
15+
from typing import Any, Dict, Literal, Optional
16+
17+
from supertokens_python.framework.response import BaseResponse
18+
from supertokens_python.utils import get_timestamp_ms
19+
20+
21+
class LitestarResponse(BaseResponse):
22+
from litestar import Response
23+
24+
def __init__(self, response: Response[Any]):
25+
super().__init__({})
26+
self.response = response
27+
self.original = response
28+
self.parser_checked = False
29+
self.response_sent = False
30+
self.status_set = False
31+
32+
def set_html_content(self, content: str):
33+
if not self.response_sent:
34+
body = bytes(content, "utf-8")
35+
self.set_header("Content-Length", str(len(body)))
36+
self.set_header("Content-Type", "text/html")
37+
self.response.content = body
38+
self.response_sent = True
39+
40+
def set_cookie(
41+
self,
42+
key: str,
43+
value: str,
44+
expires: int,
45+
path: str = "/",
46+
domain: Optional[str] = None,
47+
secure: bool = False,
48+
httponly: bool = False,
49+
samesite: Literal["lax", "strict", "none"] = "lax",
50+
):
51+
self.response.set_cookie(
52+
key=key,
53+
value=value,
54+
expires=ceil((expires - get_timestamp_ms()) / 1000),
55+
path=path,
56+
domain=domain,
57+
secure=secure,
58+
httponly=httponly,
59+
samesite=samesite,
60+
)
61+
62+
def set_header(self, key: str, value: str):
63+
self.response.set_header(key, value)
64+
65+
def get_header(self, key: str) -> Optional[str]:
66+
return self.response.headers.get(key, None)
67+
68+
def remove_header(self, key: str):
69+
del self.response.headers[key]
70+
71+
def set_status_code(self, status_code: int):
72+
if not self.status_set:
73+
self.response.status_code = status_code
74+
self.status_code = status_code
75+
self.status_set = True
76+
77+
def set_json_content(self, content: Dict[str, Any]):
78+
if not self.response_sent:
79+
from litestar.serialization import encode_json
80+
81+
body = encode_json(
82+
content,
83+
)
84+
self.set_header("Content-Type", "application/json; charset=utf-8")
85+
self.set_header("Content-Length", str(len(body)))
86+
self.response.content = body
87+
self.response_sent = True
88+
89+
def redirect(self, url: str) -> BaseResponse:
90+
if not self.response_sent:
91+
self.set_header("Location", url)
92+
self.set_status_code(302)
93+
self.response_sent = True
94+
return self

supertokens_python/framework/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818

1919
from supertokens_python.framework.request import BaseRequest
2020

21-
frameworks = ["fastapi", "flask", "django"]
21+
frameworks = ["fastapi", "flask", "django", "litestar"]
2222

2323

2424
class FrameworkEnum(Enum):
2525
FASTAPI = 1
2626
FLASK = 2
2727
DJANGO = 3
28+
LITESTAR = 4
2829

2930

3031
class Framework(ABC):

0 commit comments

Comments
 (0)