diff --git a/docs_src/tutorial/relationship_attributes/define_relationship_attributes/tutorial001_an_py310.py b/docs_src/tutorial/relationship_attributes/define_relationship_attributes/tutorial001_an_py310.py new file mode 100644 index 0000000000..a4584836f7 --- /dev/null +++ b/docs_src/tutorial/relationship_attributes/define_relationship_attributes/tutorial001_an_py310.py @@ -0,0 +1,70 @@ +from typing import Annotated + +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine + + +class Team(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + headquarters: str + + heroes: Annotated[list["Hero"] | None, Relationship(back_populates="team")] = None + + +class Hero(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + secret_name: str + age: Annotated[int | None, Field(index=True)] = None + + team_id: Annotated[int | None, Field(foreign_key="team.id")] = None + team: Annotated[Team | None, Relationship(back_populates="heroes")] = None + + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +engine = create_engine(sqlite_url, echo=True) + + +def create_db_and_tables(): + SQLModel.metadata.create_all(engine) + + +def create_heroes(): + with Session(engine) as session: + team_preventers = Team(name="Preventers", headquarters="Sharp Tower") + team_z_force = Team(name="Z-Force", headquarters="Sister Margaret's Bar") + + hero_deadpond = Hero( + name="Deadpond", secret_name="Dive Wilson", team=team_z_force + ) + hero_rusty_man = Hero( + name="Rusty-Man", secret_name="Tommy Sharp", age=48, team=team_preventers + ) + hero_spider_boy = Hero(name="Spider-Boy", secret_name="Pedro Parqueador") + session.add(hero_deadpond) + session.add(hero_rusty_man) + session.add(hero_spider_boy) + session.commit() + + session.refresh(hero_deadpond) + session.refresh(hero_rusty_man) + session.refresh(hero_spider_boy) + + print("Created hero:", hero_deadpond) + print("Created hero:", hero_rusty_man) + print("Created hero:", hero_spider_boy) + + hero_spider_boy.team = team_preventers + session.add(hero_spider_boy) + session.commit() + + +def main(): + create_db_and_tables() + create_heroes() + + +if __name__ == "__main__": + main() diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index a220b193f1..841fe5744c 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -151,6 +151,9 @@ def get_relationship_to( elif origin is list: use_annotation = get_args(annotation)[0] + elif origin is Annotated: + use_annotation = get_args(annotation)[0] + return get_relationship_to(name=name, rel_info=rel_info, annotation=use_annotation) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..0d2217267e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Annotated, Any, ClassVar, Literal, @@ -19,6 +20,7 @@ TypeVar, Union, cast, + get_args, get_origin, overload, ) @@ -518,6 +520,16 @@ def Relationship( return relationship_info +def get_annotated_relationshipinfo(t: Any) -> RelationshipInfo | None: + """Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo.""" + if get_origin(t) is not Annotated: + return None + for a in get_args(t): + if isinstance(a, RelationshipInfo): + return a + return None + + @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: dict[str, RelationshipInfo] @@ -551,7 +563,11 @@ def __new__( pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): - if isinstance(v, RelationshipInfo): + a = original_annotations.get(k, None) + r = get_annotated_relationshipinfo(a) + if r is not None: + relationships[k] = r + elif isinstance(v, RelationshipInfo): relationships[k] = v else: dict_for_pydantic[k] = v @@ -644,6 +660,9 @@ def __init__( origin: Any = get_origin(raw_ann) if origin is Mapped: ann = raw_ann.__args__[0] + if origin is Annotated: + ann = get_args(raw_ann)[0] + cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type] else: ann = raw_ann # Plain forward references, for models not yet defined, are not diff --git a/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py b/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py index c1f5c269be..74e0a754eb 100644 --- a/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py +++ b/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py @@ -4,13 +4,14 @@ import pytest from sqlmodel import create_engine -from ....conftest import PrintMock, needs_py310 +from ....conftest import PrintMock @pytest.fixture( name="mod", params=[ - pytest.param("tutorial001_py310", marks=needs_py310), + pytest.param("tutorial001_py310"), + pytest.param("tutorial001_an_py310"), ], ) def get_module(request: pytest.FixtureRequest) -> ModuleType: