Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Annotated

from sqlmodel import Field, Relationship, Session, SQLModel, create_engine


class Team(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
headquarters: str

heroes: Annotated[list["Hero"] | None, Relationship(back_populates="team")] = None


class Hero(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
secret_name: str
age: Annotated[int | None, Field(index=True)] = None

team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
team: Annotated[Team | None, Relationship(back_populates="heroes")] = None


sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

engine = create_engine(sqlite_url, echo=True)


def create_db_and_tables():
SQLModel.metadata.create_all(engine)


def create_heroes():
with Session(engine) as session:
team_preventers = Team(name="Preventers", headquarters="Sharp Tower")
team_z_force = Team(name="Z-Force", headquarters="Sister Margaret's Bar")

hero_deadpond = Hero(
name="Deadpond", secret_name="Dive Wilson", team=team_z_force
)
hero_rusty_man = Hero(
name="Rusty-Man", secret_name="Tommy Sharp", age=48, team=team_preventers
)
hero_spider_boy = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
session.add(hero_deadpond)
session.add(hero_rusty_man)
session.add(hero_spider_boy)
session.commit()

session.refresh(hero_deadpond)
session.refresh(hero_rusty_man)
session.refresh(hero_spider_boy)

print("Created hero:", hero_deadpond)
print("Created hero:", hero_rusty_man)
print("Created hero:", hero_spider_boy)

hero_spider_boy.team = team_preventers
session.add(hero_spider_boy)
session.commit()


def main():
create_db_and_tables()
create_heroes()


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def get_relationship_to(
elif origin is list:
use_annotation = get_args(annotation)[0]

elif origin is Annotated:
use_annotation = get_args(annotation)[0]

return get_relationship_to(name=name, rel_info=rel_info, annotation=use_annotation)


Expand Down
21 changes: 20 additions & 1 deletion sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
TypeAlias,
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
)
Expand Down Expand Up @@ -518,6 +520,16 @@ def Relationship(
return relationship_info


def get_annotated_relationshipinfo(t: Any) -> RelationshipInfo | None:
"""Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo."""
if get_origin(t) is not Annotated:
return None
for a in get_args(t):
if isinstance(a, RelationshipInfo):
return a
return None


@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: dict[str, RelationshipInfo]
Expand Down Expand Up @@ -551,7 +563,11 @@ def __new__(
pydantic_annotations = {}
relationship_annotations = {}
for k, v in class_dict.items():
if isinstance(v, RelationshipInfo):
a = original_annotations.get(k, None)
r = get_annotated_relationshipinfo(a)
if r is not None:
relationships[k] = r
elif isinstance(v, RelationshipInfo):
relationships[k] = v
else:
dict_for_pydantic[k] = v
Expand Down Expand Up @@ -644,6 +660,9 @@ def __init__(
origin: Any = get_origin(raw_ann)
if origin is Mapped:
ann = raw_ann.__args__[0]
if origin is Annotated:
ann = get_args(raw_ann)[0]
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
else:
ann = raw_ann
# Plain forward references, for models not yet defined, are not
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from unittest.mock import patch

from sqlmodel import create_engine

from ....conftest import get_testing_print_function, needs_py310

expected_calls = [
[
"Created hero:",
{
"name": "Deadpond",
"age": None,
"team_id": 1,
"id": 1,
"secret_name": "Dive Wilson",
},
],
[
"Created hero:",
{
"name": "Rusty-Man",
"age": 48,
"team_id": 2,
"id": 2,
"secret_name": "Tommy Sharp",
},
],
[
"Created hero:",
{
"name": "Spider-Boy",
"age": None,
"team_id": None,
"id": 3,
"secret_name": "Pedro Parqueador",
},
],
]


@needs_py310
def test_tutorial(clear_sqlmodel):
from docs_src.tutorial.relationship_attributes.define_relationship_attributes import (
tutorial001_an_py310 as mod,
)

mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []

new_print = get_testing_print_function(calls)

with patch("builtins.print", new=new_print):
mod.main()
assert calls == expected_calls