Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions src/quart_schema/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from dataclasses import fields, is_dataclass
from inspect import isclass
from typing import Any, Literal, TypeGuard, TypeVar
from typing import Any, get_origin, Literal, TypeGuard, TypeVar

import humps
from quart import current_app
Expand Down Expand Up @@ -208,7 +208,11 @@ def model_schema(
)
elif _use_msgspec(model_class, preference):
_, schema = schema_components([model_class], ref_template=MSGSPEC_REF_TEMPLATE)
return list(schema.values())[0]
schema_name = list(schema.keys())[0]
main_schema = schema.pop(schema_name)
if schema: # Remaining schemas (like Attribute) become $defs
main_schema["$defs"] = schema
return main_schema
elif not PYDANTIC_INSTALLED and not MSGSPEC_INSTALLED:
raise TypeError(
f"Cannot create schema for {model_class} - try installing msgspec or pydantic"
Expand Down Expand Up @@ -255,31 +259,36 @@ def _is_list_or_dict(type_: type) -> bool:
return origin in (dict, dict, list, list)


def _valid_model_class(model_class: type) -> bool:
"""Validate if a type can be used as a schema class.

Returns True for types that don't require conversion:
- TypedDict, dataclasses, and attrs classes
- Built-in dict/list and their generic aliases (e.g., dict[str, int])
"""
if (
_is_list_or_dict(model_class)
or is_dataclass(model_class)
or is_typeddict(model_class)
# Generic aliases: https://github.com/python/cpython/issues/149574
or is_dataclass(get_origin(model_class))
or is_typeddict(get_origin(model_class))
):
return True
return False


def _use_pydantic(model_class: type, preference: str | None) -> bool:
return PYDANTIC_INSTALLED and (
is_pydantic_dataclass(model_class)
or (isclass(model_class) and issubclass(model_class, BaseModel))
or (
(
_is_list_or_dict(model_class)
or is_dataclass(model_class)
or is_typeddict(model_class)
)
and preference != "msgspec"
)
or (_valid_model_class(model_class) and preference != "msgspec")
)


def _use_msgspec(model_class: type, preference: str | None) -> bool:
return MSGSPEC_INSTALLED and (
(isclass(model_class) and issubclass(model_class, Struct))
or is_attrs(model_class)
or (
(
_is_list_or_dict(model_class)
or is_dataclass(model_class)
or is_typeddict(model_class)
)
and preference != "pydantic"
)
or (_valid_model_class(model_class) and preference != "pydantic")
)
18 changes: 17 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
from dataclasses import dataclass
from typing import Annotated
from typing import Annotated, Generic, TypeVar

from attrs import define
from msgspec import Struct
Expand All @@ -12,6 +12,11 @@
else:
from typing_extensions import TypedDict

try:
from typing import NotRequired
except ImportError:
from typing_extensions import NotRequired


@define
class ADetails:
Expand Down Expand Up @@ -44,3 +49,14 @@ class PyDCDetails:
class TDetails(TypedDict):
name: str
age: Annotated[int | None, Field(default=None)]


N = TypeVar("N")


class _TGDetails(TypedDict, Generic[N]):
name: N
age: NotRequired[int | None]


TGDetails = _TGDetails[str]
207 changes: 141 additions & 66 deletions tests/test_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from dataclasses import dataclass
from typing import TypedDict
from typing import Any, Generic, TypeVar

import pytest
from attrs import define
Expand All @@ -8,69 +9,75 @@
from pydantic.dataclasses import dataclass as pydantic_dataclass

from quart_schema.conversion import convert_headers, model_dump, model_load, model_schema
from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails, TDetails
from .helpers import ADetails, DCDetails, MDetails, PyDCDetails, PyDetails, TDetails, TGDetails

if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict


class ValidationError(Exception):
pass


@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails, TDetails])
@pytest.mark.parametrize(
"type_",
[ADetails, DCDetails, MDetails, PyDetails, PyDCDetails, TDetails, TGDetails],
)
def test_model_dump(
type_: type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TDetails],
type_: type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TGDetails],
) -> None:
assert model_dump(type_(name="bob", age=2)) == { # type: ignore
"name": "bob",
"age": 2,
}


@pytest.mark.parametrize(
"type_, preference",
[
(ADetails, "msgspec"),
(DCDetails, "msgspec"),
(DCDetails, "pydantic"),
(MDetails, "msgspec"),
(PyDetails, "pydantic"),
(PyDCDetails, "pydantic"),
(TDetails, "pydantic"),
],
)
def test_model_dump_list(
type_: type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TDetails],
preference: str,
) -> None:
test_types_and_preference = [
(ADetails, "msgspec"),
(DCDetails, "msgspec"),
(DCDetails, "pydantic"),
(MDetails, "msgspec"),
(TGDetails, "msgspec"),
(TGDetails, "pydantic"),
(PyDetails, "pydantic"),
(PyDCDetails, "pydantic"),
(TDetails, "pydantic"),
]
test_types = [
ADetails,
DCDetails,
MDetails,
PyDetails,
PyDCDetails,
TDetails,
TGDetails,
]

TestType = type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TDetails]


@pytest.mark.parametrize("type_, preference", test_types_and_preference)
def test_model_dump_list(type_: TestType, preference: str) -> None:
assert model_dump(
[type_(name="bob", age=2), type_(name="jim", age=3)], preference=preference
[type_(name="bob", age=2), type_(name="jim", age=3)],
preference=preference,
) == [{"name": "bob", "age": 2}, {"name": "jim", "age": 3}]


@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails, TDetails])
def test_model_load(
type_: type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TDetails],
) -> None:
assert model_load({"name": "bob", "age": 2}, type_, exception_class=ValidationError) == type_(
name="bob", age=2
)
@pytest.mark.parametrize("type_, preference", test_types_and_preference)
def test_model_load(type_: TestType, preference: str) -> None:
assert model_load(
{"name": "bob", "age": 2},
type_,
exception_class=ValidationError,
preference=preference,
) == type_(name="bob", age=2)


@pytest.mark.parametrize(
"type_, preference",
[
(ADetails, "msgspec"),
(DCDetails, "msgspec"),
(DCDetails, "pydantic"),
(MDetails, "msgspec"),
(PyDetails, "pydantic"),
(PyDCDetails, "pydantic"),
(TDetails, "pydantic"),
],
)
def test_model_load_list(
type_: type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TDetails],
preference: str,
) -> None:
@pytest.mark.parametrize("type_, preference", test_types_and_preference)
def test_model_load_list(type_: TestType, preference: str) -> None:
assert model_load(
[{"name": "bob", "age": 2}],
list[type_], # type: ignore
Expand All @@ -79,43 +86,111 @@ def test_model_load_list(
) == [type_(name="bob", age=2)]


@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails, PyDetails, PyDCDetails, TDetails])
def test_model_load_error(
type_: type[ADetails | DCDetails | MDetails | PyDetails | PyDCDetails | TDetails],
) -> None:
@pytest.mark.parametrize("type_, preference", test_types_and_preference)
def test_model_load_error(type_: TestType, preference: str) -> None:
with pytest.raises(ValidationError):
model_load({"name": "bob", "age": "two"}, type_, exception_class=ValidationError)
model_load(
{"name": "bob", "age": "two"},
type_,
exception_class=ValidationError,
preference=preference,
)


@pytest.mark.parametrize("type_, preference", test_types_and_preference)
def test_model_schema_msgspec(type_: TestType, preference: str) -> None:
schema = model_schema(
type_,
preference=preference,
)

@pytest.mark.parametrize("type_", [ADetails, DCDetails, MDetails])
def test_model_schema_msgspec(type_: type[ADetails | DCDetails | MDetails]) -> None:
assert model_schema(type_, preference="msgspec") == {
# Base expected schema (common to both)
expected: dict[str, Any] = {
"title": type_.__name__,
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None},
"age": {
"anyOf": [
{"type": "integer"},
{"type": "null"},
],
"default": None,
},
},
"required": ["name"],
}

# Pydantic adds "title" fields to properties
if preference == "pydantic":
expected["properties"]["name"]["title"] = "Name"
expected["properties"]["age"]["title"] = "Age"

@pytest.mark.parametrize("type_", [DCDetails, PyDetails, PyDCDetails, TDetails])
def test_model_schema_pydantic(
type_: type[DCDetails | PyDetails | PyDCDetails | TDetails],
) -> None:
assert model_schema(type_, preference="pydantic") == {
# For some reason the name for aliased type dicts
# includes the generic in msgspec
if preference == "msgspec" and type_ is TGDetails:
expected["title"] = "_TGDetails[str]"

# TGDetails does not include the default for age
if type_ is TGDetails:
del expected["properties"]["age"]["default"]

assert schema == expected


A = TypeVar("A")
M = TypeVar("M")


class Modifier(TypedDict):
mod: int


class Attribute(TypedDict):
title: str


class Resource(TypedDict, Generic[A, M]):
foo: str
attribute: A
modifier: M


def test_nested_generic_ref_included() -> None:
schema = model_schema(
Resource[Attribute, Modifier],
preference="msgspec",
)

assert schema == {
"title": "Resource[Attribute, Modifier]",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"age": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"default": None,
"title": "Age",
"attribute": {
"$ref": "#/components/schemas/Attribute",
},
"modifier": {
"$ref": "#/components/schemas/Modifier",
},
"foo": {"type": "string"},
},
"required": ["attribute", "foo", "modifier"],
"$defs": {
"Attribute": {
"properties": {
"title": {"type": "string"},
},
"required": ["title"],
"title": "Attribute",
"type": "object",
},
"Modifier": {
"properties": {"mod": {"type": "integer"}},
"required": ["mod"],
"title": "Modifier",
"type": "object",
},
},
"required": ["name"],
"title": type_.__name__,
"type": "object",
}


Expand Down
Loading