Skip to content

Commit 6d8d1fd

Browse files
committed
feat: refactor Litestar framework integration and update middleware imports
1 parent a21d10d commit 6d8d1fd

File tree

6 files changed

+64
-57
lines changed

6 files changed

+64
-57
lines changed

dev-requirements.txt

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
1-
black==24.8.0
2-
Django==4.2.6
3-
django-cors-headers==4.4.0
4-
django-stubs==4.2.7
5-
django-stubs-ext==4.2.7
6-
fastapi==0.115.5
7-
Flask==3.0.3
8-
flask-cors==5.0.0
1+
black==25.9.0
2+
Django==5.2.7
3+
django-cors-headers==4.9.0
4+
django-stubs==5.2.7
5+
django-stubs-ext==5.2.7
6+
fastapi==0.120.1
7+
Flask==3.1.2
8+
flask-cors==6.0.1
99
nest-asyncio==1.6.0
10-
pdoc3==0.11.0
11-
pre-commit==3.5.0
12-
pyfakefs==5.7.4
13-
pylint==3.2.7
14-
pyright==1.1.393
15-
python-dotenv==1.0.1
16-
pytest==8.3.3
17-
pytest-asyncio==0.24.0
18-
pytest-mock==3.14.0
19-
pytest-rerunfailures==14.0
20-
pyyaml==6.0.2
10+
pdoc3==0.11.6
11+
pre-commit==4.3.0
12+
pyfakefs==5.10.1
13+
pylint==4.0.2
14+
pyright==1.1.407
15+
python-dotenv==1.2.1
16+
pytest==8.4.2
17+
pytest-asyncio==1.2.0
18+
pytest-mock==3.15.1
19+
pytest-rerunfailures==16.1
20+
pyyaml==6.0.3
2121
requests-mock==1.12.1
2222
respx>=0.13.0, <1.0.0
23-
uvicorn==0.32.0
24-
wasmtime==25.0.0
25-
litestar==2.16.0
23+
uvicorn==0.38.0
24+
wasmtime==38.0.0
25+
litestar==2.18.0
2626
-e .

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ indent-style = "space" # Default
1313
typeCheckingMode = "strict"
1414
reportImportCycles = false
1515
include = ["supertokens_python/", "tests/", "examples/"]
16+
exclude = [".venv"]
17+
venvPath = "."
18+
venv = ".venv"
1619

1720
[tool.pytest.ini_options]
1821
addopts = " -v -p no:warnings"

supertokens_python/framework/litestar/framework.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,11 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from typing import TYPE_CHECKING, Any
14+
from typing import Any
1515

16+
from litestar import Request
1617
from supertokens_python.framework.types import Framework
1718

18-
if TYPE_CHECKING:
19-
from litestar import Request
20-
2119

2220
class LitestarFramework(Framework):
2321
def wrap_request(self, unwrapped: Request[Any, Any, Any]):

supertokens_python/framework/litestar/middleware.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Optional, cast
22

33
from litestar import Request, Response
44
from litestar.datastructures import MutableScopeHeaders
@@ -32,7 +32,10 @@ async def handle(
3232

3333
try:
3434
response = LitestarResponse(Response[Any](content=None))
35-
result = await st.middleware(custom_request, response, user_context)
35+
result = cast(
36+
Optional[LitestarResponse],
37+
await st.middleware(custom_request, response, user_context),
38+
)
3639

3740
if result is not None:
3841
# SuperTokens handled the request
@@ -50,8 +53,8 @@ async def handle(
5053
async def modified_send(message: Message):
5154
if message["type"] == "http.response.start":
5255
mutable_headers = MutableScopeHeaders(message)
53-
for cookie in result.response.cookies:
54-
cookie_value = cookie.to_header().split(": ", 1)[1]
56+
for cookie in result.response.cookies: # type: ignore
57+
cookie_value = cookie.to_header().split(": ", 1)[1] # type: ignore
5558
mutable_headers.add("set-cookie", cookie_value)
5659
await send(message)
5760

@@ -64,16 +67,16 @@ async def send_wrapper(message: Message):
6467
request.state.supertokens, SessionContainer
6568
):
6669
temp_response = Response[Any](content=None)
67-
temp_response.headers = MutableScopeHeaders(message)
70+
temp_response.headers = MutableScopeHeaders(message) # type: ignore
6871
litestar_response = LitestarResponse(temp_response)
6972
manage_session_post_response(
7073
request.state.supertokens,
7174
litestar_response,
7275
user_context,
7376
)
7477
mutable_headers = MutableScopeHeaders(message)
75-
for cookie in litestar_response.response.cookies:
76-
cookie_value = cookie.to_header().split(": ", 1)[1]
78+
for cookie in litestar_response.response.cookies: # type: ignore
79+
cookie_value = cookie.to_header().split(": ", 1)[1] # type: ignore
7780
mutable_headers.add("set-cookie", cookie_value)
7881
await send(message)
7982

@@ -82,8 +85,11 @@ async def send_wrapper(message: Message):
8285
except SuperTokensError as e:
8386
# Handle SuperTokens errors
8487
response = LitestarResponse(Response[Any](content=None))
85-
result = await st.handle_supertokens_error(
86-
custom_request, e, response, user_context
88+
result = cast(
89+
Optional[LitestarResponse],
90+
await st.handle_supertokens_error(
91+
custom_request, e, response, user_context
92+
),
8793
)
8894
if isinstance(result, LitestarResponse):
8995
asgi_response = await result.response.to_asgi_response(
@@ -93,8 +99,8 @@ async def send_wrapper(message: Message):
9399
async def modified_send(message: Message):
94100
if message["type"] == "http.response.start":
95101
mutable_headers = MutableScopeHeaders(message)
96-
for cookie in result.response.cookies:
97-
cookie_value = cookie.to_header().split(": ", 1)[1]
102+
for cookie in result.response.cookies: # type: ignore
103+
cookie_value = cookie.to_header().split(": ", 1)[1] # type: ignore
98104
mutable_headers.add("set-cookie", cookie_value)
99105
await send(message)
100106

supertokens_python/framework/litestar/request.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
import json
15-
from typing import Any, Dict, Union
15+
from typing import Any, Union
1616
from urllib.parse import parse_qsl
1717

1818
from litestar import Request
@@ -33,10 +33,10 @@ def get_query_param(
3333
) -> Union[str, None]:
3434
return self.request.query_params.get(key, default)
3535

36-
def get_query_params(self) -> Dict[str, Any]:
36+
def get_query_params(self) -> dict[str, Any]:
3737
return dict(self.request.query_params.items()) # type: ignore
3838

39-
async def json(self) -> dict:
39+
async def json(self) -> dict[str, Any]:
4040
"""
4141
Read the entire ASGI stream and JSON-decode it,
4242
sidestepping Litestar’s internal max-body-size logic.

tests/litestar/test_litestar.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
import json
15-
from typing import Any, Dict, Union
15+
from typing import Any, Dict, Optional, Union
1616
from unittest import skip
1717

1818
from litestar import Litestar, MediaType, Request, Response, get, post
@@ -61,7 +61,7 @@
6161

6262
def override_dashboard_functions(original_implementation: RecipeInterface):
6363
async def should_allow_access(
64-
request: BaseRequest, __: DashboardConfig, ___: Dict[str, Any]
64+
request: BaseRequest, config: DashboardConfig, user_context: Dict[str, Any]
6565
):
6666
auth_header = request.get_header("authorization")
6767
return auth_header == "Bearer testapikey"
@@ -121,7 +121,7 @@ async def handle_get(
121121
},
122122
)
123123
async def handle_get_optional(
124-
session: SessionContainer,
124+
session: Optional[SessionContainer],
125125
) -> Dict[str, Any]:
126126
if session is None:
127127
return {"s": "empty session"}
@@ -739,9 +739,9 @@ async def test_search_with_email_t(driver_config_client: TestClient[Litestar]):
739739
querier = Querier.get_instance(DashboardRecipe.recipe_id)
740740
cdi_version = await querier.get_api_version()
741741
if not cdi_version:
742-
skip()
742+
skip("CDI version not available")
743743
if not is_version_gte(cdi_version, "2.20"):
744-
skip()
744+
skip("CDI version too old")
745745
await create_users(emailpassword=True)
746746
query = {"limit": "10", "email": "t"}
747747
res = driver_config_client.get(
@@ -786,9 +786,9 @@ async def test_search_with_email_multiple_email_entry(
786786
querier = Querier.get_instance(DashboardRecipe.recipe_id)
787787
cdi_version = await querier.get_api_version()
788788
if not cdi_version:
789-
skip()
789+
skip("CDI version not available")
790790
if not is_version_gte(cdi_version, "2.20"):
791-
skip()
791+
skip("CDI version too old")
792792
await create_users(emailpassword=True)
793793
query = {"limit": "10", "email": "iresh;john"}
794794
res = driver_config_client.get(
@@ -831,9 +831,9 @@ async def test_search_with_email_iresh(driver_config_client: TestClient[Litestar
831831
querier = Querier.get_instance(DashboardRecipe.recipe_id)
832832
cdi_version = await querier.get_api_version()
833833
if not cdi_version:
834-
skip()
834+
skip("CDI version not available")
835835
if not is_version_gte(cdi_version, "2.20"):
836-
skip()
836+
skip("CDI version too old")
837837
await create_users(emailpassword=True)
838838
query = {"limit": "10", "email": "iresh"}
839839
res = driver_config_client.get(
@@ -879,9 +879,9 @@ async def test_search_with_phone_plus_one(driver_config_client: TestClient[Lites
879879
querier = Querier.get_instance(DashboardRecipe.recipe_id)
880880
cdi_version = await querier.get_api_version()
881881
if not cdi_version:
882-
skip()
882+
skip("CDI version not available")
883883
if not is_version_gte(cdi_version, "2.20"):
884-
skip()
884+
skip("CDI version too old")
885885
await create_users(passwordless=True)
886886
query = {"limit": "10", "phone": "+1"}
887887
res = driver_config_client.get(
@@ -929,9 +929,9 @@ async def test_search_with_phone_one_bracket(
929929
querier = Querier.get_instance(DashboardRecipe.recipe_id)
930930
cdi_version = await querier.get_api_version()
931931
if not cdi_version:
932-
skip()
932+
skip("CDI version not available")
933933
if not is_version_gte(cdi_version, "2.20"):
934-
skip()
934+
skip("CDI version too old")
935935
await create_users(passwordless=True)
936936
query = {"limit": "10", "phone": "1("}
937937
res = driver_config_client.get(
@@ -1016,9 +1016,9 @@ async def test_search_with_provider_google(driver_config_client: TestClient[Lite
10161016
querier = Querier.get_instance(DashboardRecipe.recipe_id)
10171017
cdi_version = await querier.get_api_version()
10181018
if not cdi_version:
1019-
skip()
1019+
skip("CDI version not available")
10201020
if not is_version_gte(cdi_version, "2.20"):
1021-
skip()
1021+
skip("CDI version too old")
10221022
await create_users(thirdparty=True)
10231023
query = {"limit": "10", "provider": "google"}
10241024
res = driver_config_client.get(
@@ -1109,9 +1109,9 @@ async def test_search_with_provider_google_and_phone_1(
11091109
querier = Querier.get_instance(DashboardRecipe.recipe_id)
11101110
cdi_version = await querier.get_api_version()
11111111
if not cdi_version:
1112-
skip()
1112+
skip("CDI version not available")
11131113
if not is_version_gte(cdi_version, "2.20"):
1114-
skip()
1114+
skip("CDI version too old")
11151115
await create_users(thirdparty=True, passwordless=True)
11161116
query = {"limit": "10", "provider": "google", "phone": "1"}
11171117
res = driver_config_client.get(

0 commit comments

Comments
 (0)