Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 109 additions & 28 deletions tests/test_qblike_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typemap.typing import (
Attrs,
Bool,
Cls,
Length,
GetArg,
GetMemberType,
Expand All @@ -29,6 +30,7 @@
Matches,
Member,
NewProtocol,
PropName,
Slice,
)

Expand Down Expand Up @@ -137,13 +139,15 @@ class Table[name: str]:
pass


class Field[Table, Name, PyType]:
class FieldType[Table, Name, PyType]:
def __lt__(self, other: Any) -> Filter[Table]: ...


type FieldTable[T] = GetArg[T, Field, Literal[0]]
type FieldName[T] = GetArg[T, Field, Literal[1]]
type FieldPyType[T] = GetArg[T, Field, Literal[2]]
type Field[T] = FieldType[Cls, PropName, T]

type FieldTable[T: FieldType] = GetArg[T, FieldType, Literal[0]]
type FieldName[T: FieldType] = GetArg[T, FieldType, Literal[1]]
type FieldPyType[T: FieldType] = GetArg[T, FieldType, Literal[2]]


class ColumnArgs(TypedDict, total=False):
Expand Down Expand Up @@ -230,7 +234,9 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):

type MakeQueryEntryAllFields[T: Table] = QueryEntry[
T,
tuple[*[GetName[m] for m in Iter[Attrs[T]] if IsSub[GetType[m], Field]],],
tuple[
*[GetName[m] for m in Iter[Attrs[T]] if IsSub[GetType[m], FieldType]],
],
]
type MakeQueryEntryNamedFields[
T: Table,
Expand All @@ -241,7 +247,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
*[
GetName[m]
for m in Iter[Attrs[T]]
if IsSub[GetType[m], Field]
if IsSub[GetType[m], FieldType]
and any(IsSub[FieldName[GetType[m]], f] for f in Iter[FieldNames])
],
],
Expand All @@ -258,7 +264,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
else [MakeQueryEntryAllFields[New]]
),
]
type AddField[Entries, New: Field] = tuple[
type AddField[Entries, New: FieldType] = tuple[
*[ # Existing entries
(
e # Non-matching entry
Expand All @@ -276,7 +282,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
if not Bool[EntriesHasTable[Entries, FieldTable[New]]]
),
]
type AddEntries[Entries, News: tuple[Table | Field, ...]] = (
type AddEntries[Entries, News: tuple[Table | FieldType, ...]] = (
Entries
if IsSub[Length[News], Literal[0]]
else AddEntries[
Expand Down Expand Up @@ -341,50 +347,44 @@ def execute[Es: tuple[type[Table], ...]](


class User(Table[Literal["users"]]):
id: Field[User, Literal["id"], int] = column(
id: Field[int] = column(
db_type=DbInteger(), primary_key=True, autoincrement=True
)
name: Field[User, Literal["name"], str] = column(
db_type=DbString(length=150), nullable=False
)
email: Field[User, Literal["email"], str] = column(
name: Field[str] = column(db_type=DbString(length=150), nullable=False)
email: Field[str] = column(
db_type=DbString(length=100), unique=True, nullable=False
)
age: Field[User, Literal["age"], int | None] = column(db_type=DbInteger())
active: Field[User, Literal["active"], bool] = column(
age: Field[int | None] = column(db_type=DbInteger())
active: Field[bool] = column(
db_type=DbBoolean(), default=True, nullable=False
)
posts: Field[User, Literal["posts"], list[Post]] = column(
posts: Field[list[Post]] = column(
db_type=DbLinkSource(source="Post", cardinality=Cardinality.MANY)
)


class Post(Table[Literal["posts"]]):
id: Field[Post, Literal["id"], int] = column(
id: Field[int] = column(
db_type=DbInteger(), primary_key=True, autoincrement=True
)
content: Field[Post, Literal["content"], str] = column(
db_type=DbString(length=1000), nullable=False
)
author: Field[Post, Literal["author"], User] = column(
content: Field[str] = column(db_type=DbString(length=1000), nullable=False)
author: Field[User] = column(
db_type=DbLinkTarget(target=User), nullable=False
)
comments: Field[Post, Literal["comments"], list[Comment]] = column(
comments: Field[list[Comment]] = column(
db_type=DbLinkSource(source="Comment", cardinality=Cardinality.MANY)
)


class Comment(Table[Literal["comments"]]):
id: Field[Comment, Literal["id"], int] = column(
id: Field[int] = column(
db_type=DbInteger(), primary_key=True, autoincrement=True
)
content: Field[Comment, Literal["content"], str] = column(
db_type=DbString(length=1000), nullable=False
)
author: Field[Comment, Literal["author"], User] = column(
content: Field[str] = column(db_type=DbString(length=1000), nullable=False)
author: Field[User] = column(
db_type=DbLinkTarget(target=User), nullable=False
)
post: Field[Comment, Literal["post"], Post] = column(
post: Field[Post] = column(
db_type=DbLinkTarget(target=Post), nullable=False
)

Expand Down Expand Up @@ -573,3 +573,84 @@ class Select[tests.test_qblike_3.User, tuple[typing.Literal['name']]]:
class Select[tests.test_qblike_3.Post, tuple[typing.Literal['content']]]:
content: str
""")


def test_qblike_3_select_08():
class UserAlias(User):
pass

query = eval_call_with_types(select, UserAlias)
fmt = format_helper.format_class(query)
assert fmt == textwrap.dedent("""\
class Query[tuple[tuple[tests.test_qblike_3.test_qblike_3_select_08.<locals>.UserAlias, tuple[typing.Literal['id'], typing.Literal['name'], typing.Literal['email'], typing.Literal['age'], typing.Literal['active'], typing.Literal['posts']]]]]:
""")

results = eval_call_with_types(Session.execute, Session, query)
result = eval_typing(GetArg[results, list, Literal[0]])

fmt = format_helper.format_class(result)
assert fmt == textwrap.dedent("""\
class Select[tests.test_qblike_3.test_qblike_3_select_08.<locals>.UserAlias, tuple[typing.Literal['id'], typing.Literal['name'], typing.Literal['email'], typing.Literal['age'], typing.Literal['active'], typing.Literal['posts']]]:
id: int
name: str
email: str
age: int | None
active: bool
posts: list[tests.test_qblike_3.Post]
""")


def test_qblike_3_select_09():
class UserAlias(User):
pass

user_alias_name = eval_typing(GetMemberType[UserAlias, Literal["name"]])
user_alias_email = eval_typing(GetMemberType[UserAlias, Literal["email"]])
query = eval_call_with_types(select, user_alias_name, user_alias_email)
fmt = format_helper.format_class(query)
assert fmt == textwrap.dedent("""\
class Query[tuple[tuple[tests.test_qblike_3.test_qblike_3_select_09.<locals>.UserAlias, tuple[typing.Literal['name'], typing.Literal['email']]]]]:
""")

results = eval_call_with_types(Session.execute, Session, query)
result = eval_typing(GetArg[results, list, Literal[0]])

fmt = format_helper.format_class(result)
assert fmt == textwrap.dedent("""\
class Select[tests.test_qblike_3.test_qblike_3_select_09.<locals>.UserAlias, tuple[typing.Literal['name'], typing.Literal['email']]]:
name: str
email: str
""")


def test_qblike_3_select_10():
class UserAlias(User):
pass

user_name = eval_typing(GetMemberType[User, Literal["name"]])
user_alias_name = eval_typing(GetMemberType[UserAlias, Literal["name"]])
query = eval_call_with_types(select, user_name, user_alias_name)
fmt = format_helper.format_class(query)
assert fmt == textwrap.dedent("""\
class Query[tuple[tuple[tests.test_qblike_3.User, tuple[typing.Literal['name']]], tuple[tests.test_qblike_3.test_qblike_3_select_10.<locals>.UserAlias, tuple[typing.Literal['name']]]]]:
""")

results = eval_call_with_types(Session.execute, Session, query)
result = eval_typing(GetArg[results, list, Literal[0]])

result_names = eval_typing(AttrNames[result])
assert result_names == tuple[Literal["User"], Literal["UserAlias"]]

result_user = eval_typing(GetMemberType[result, Literal["User"]])
fmt = format_helper.format_class(result_user)
assert fmt == textwrap.dedent("""\
class Select[tests.test_qblike_3.User, tuple[typing.Literal['name']]]:
name: str
""")

result_user_alias = eval_typing(GetMemberType[result, Literal["UserAlias"]])
fmt = format_helper.format_class(result_user_alias)
assert fmt == textwrap.dedent("""\
class Select[tests.test_qblike_3.test_qblike_3_select_10.<locals>.UserAlias, tuple[typing.Literal['name']]]:
name: str
""")
Loading