Skip to content

Commit 8bab505

Browse files
authored
Merge pull request #1 from PaleNeutron/sqlalchemy_polymorphic_support
Add support for SQLAlchemy polymorphic models
2 parents 3e971b2 + 64f774f commit 8bab505

3 files changed

Lines changed: 380 additions & 10 deletions

File tree

sqlmodel/_compat.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from pydantic import VERSION as P_VERSION
2323
from pydantic import BaseModel
2424
from pydantic.fields import FieldInfo
25+
from sqlalchemy import inspect
26+
from sqlalchemy.orm import Mapper
2527
from typing_extensions import Annotated, get_args, get_origin
2628

2729
# Reassign variable to make it reexported for mypy
@@ -66,6 +68,35 @@ def _is_union_type(t: Any) -> bool:
6668
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
6769

6870

71+
def set_polymorphic_default_value(
72+
self_instance: _TSQLModel,
73+
values: Dict[str, Any],
74+
) -> bool:
75+
"""By default, when init a model, pydantic will set the polymorphic_on
76+
value to field default value. But when inherit a model, the polymorphic_on
77+
should be set to polymorphic_identity value by default."""
78+
cls = type(self_instance)
79+
mapper = inspect(cls)
80+
ret = False
81+
if isinstance(mapper, Mapper):
82+
polymorphic_on = mapper.polymorphic_on
83+
if polymorphic_on is not None:
84+
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
85+
field_info = get_model_fields(cls).get(polymorphic_property.key)
86+
if field_info:
87+
v = values.get(polymorphic_property.key)
88+
# if model is inherited or polymorphic_on is not explicitly set
89+
# set the polymorphic_on by default
90+
if mapper.inherits or v is None:
91+
setattr(
92+
self_instance,
93+
polymorphic_property.key,
94+
mapper.polymorphic_identity,
95+
)
96+
ret = True
97+
return ret
98+
99+
69100
@contextmanager
70101
def partial_init() -> Generator[None, None, None]:
71102
token = finish_init.set(False)
@@ -312,6 +343,8 @@ def sqlmodel_table_construct(
312343
if value is not Undefined:
313344
setattr(self_instance, key, value)
314345
# End SQLModel override
346+
# Override polymorphic_on default value
347+
set_polymorphic_default_value(self_instance, values)
315348
return self_instance
316349

317350
def sqlmodel_validate(

sqlmodel/main.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ipaddress
44
import uuid
5+
import warnings
56
import weakref
67
from datetime import date, datetime, time, timedelta
78
from decimal import Decimal
@@ -43,9 +44,10 @@
4344
)
4445
from sqlalchemy import Enum as sa_Enum
4546
from sqlalchemy.orm import (
47+
InstrumentedAttribute,
4648
Mapped,
49+
MappedColumn,
4750
RelationshipProperty,
48-
declared_attr,
4951
registry,
5052
relationship,
5153
)
@@ -539,7 +541,42 @@ def __new__(
539541
config_kwargs = {
540542
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
541543
}
542-
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
544+
is_polymorphic = False
545+
if IS_PYDANTIC_V2:
546+
base_fields = {}
547+
base_annotations = {}
548+
for base in bases[::-1]:
549+
if issubclass(base, BaseModel):
550+
base_fields.update(get_model_fields(base))
551+
base_annotations.update(base.__annotations__)
552+
if hasattr(base, "__sqlmodel_relationships__"):
553+
for k in base.__sqlmodel_relationships__:
554+
# create a dummy attribute to avoid inherit
555+
# pydantic will treat it as class variables, and will not become fields on model instances
556+
anno = base_annotations.get(k, Any)
557+
if get_origin(anno) is not ClassVar:
558+
dummy_anno = ClassVar[anno]
559+
dict_used["__annotations__"][k] = dummy_anno
560+
561+
if hasattr(base, "__tablename__"):
562+
is_polymorphic = True
563+
# use base_fields overwriting the ones from the class for inherit
564+
# if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute
565+
# thus pydantic will use the value of the attribute as the default value
566+
base_annotations.update(dict_used["__annotations__"])
567+
dict_used["__annotations__"] = base_annotations
568+
base_fields.update(dict_used)
569+
dict_used = base_fields
570+
# if is_polymorphic, disable pydantic `shadows an attribute` warning
571+
if is_polymorphic:
572+
with warnings.catch_warnings():
573+
warnings.filterwarnings(
574+
"ignore",
575+
message="Field name .+ shadows an attribute in parent.+",
576+
)
577+
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
578+
else:
579+
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
543580
new_cls.__annotations__ = {
544581
**relationship_annotations,
545582
**pydantic_annotations,
@@ -559,9 +596,22 @@ def get_config(name: str) -> Any:
559596

560597
config_table = get_config("table")
561598
if config_table is True:
599+
# sqlalchemy mark a class as table by check if it has __tablename__ attribute
600+
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
601+
# a table model
602+
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
603+
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
562604
# If it was passed by kwargs, ensure it's also set in config
563605
set_config_value(model=new_cls, parameter="table", value=config_table)
564606
for k, v in get_model_fields(new_cls).items():
607+
original_v = getattr(new_cls, k, None)
608+
if (
609+
isinstance(original_v, InstrumentedAttribute)
610+
and k not in class_dict
611+
):
612+
# The attribute was already set by SQLAlchemy, don't override it
613+
# Needed for polymorphic models, see #36
614+
continue
565615
col = get_column_from_field(v)
566616
setattr(new_cls, k, col)
567617
# Set a config flag to tell FastAPI that this should be read with a field
@@ -595,7 +645,15 @@ def __init__(
595645
# trying to create a new SQLAlchemy, for a new table, with the same name, that
596646
# triggers an error
597647
base_is_table = any(is_table_model_class(base) for base in bases)
598-
if is_table_model_class(cls) and not base_is_table:
648+
_mapper_args = dict_.get("__mapper_args__", {})
649+
polymorphic_identity = _mapper_args.get("polymorphic_identity")
650+
polymorphic_abstract = _mapper_args.get("polymorphic_abstract")
651+
has_polymorphic = (
652+
polymorphic_identity is not None or polymorphic_abstract is not None
653+
)
654+
655+
# allow polymorphic models inherit from table models
656+
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
599657
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
600658
if rel_info.sa_relationship:
601659
# There's a SQLAlchemy relationship declared, that takes precedence
@@ -703,13 +761,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
703761
raise ValueError(f"{type_} has no matching SQLAlchemy type")
704762

705763

706-
def get_column_from_field(field: Any) -> Column: # type: ignore
764+
def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore
707765
if IS_PYDANTIC_V2:
708766
field_info = field
709767
else:
710768
field_info = field.field_info
711769
sa_column = getattr(field_info, "sa_column", Undefined)
712-
if isinstance(sa_column, Column):
770+
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
713771
return sa_column
714772
sa_type = get_sqlalchemy_type(field)
715773
primary_key = getattr(field_info, "primary_key", Undefined)
@@ -773,7 +831,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
773831
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
774832
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
775833
__slots__ = ("__weakref__",)
776-
__tablename__: ClassVar[Union[str, Callable[..., str]]]
777834
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
778835
__name__: ClassVar[str]
779836
metadata: ClassVar[MetaData]
@@ -837,10 +894,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
837894
if not (isinstance(k, str) and k.startswith("_sa_"))
838895
]
839896

840-
@declared_attr # type: ignore
841-
def __tablename__(cls) -> str:
842-
return cls.__name__.lower()
843-
844897
@classmethod
845898
def model_validate( # type: ignore[override]
846899
cls: Type[_TSQLModel],

0 commit comments

Comments
 (0)