Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta):
if cls.should_set_field_value(
field_meta, _build_context=_build_context, **kwargs
) and not cls.should_use_default_value(field_meta):
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
Expand Down Expand Up @@ -1122,7 +1124,7 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)

if cls.should_set_field_value(field_meta, **kwargs):
if cls.should_set_field_value(field_meta, _build_context=_build_context, **kwargs):
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
Expand Down
53 changes: 51 additions & 2 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
Protocol,
TypeVar,
Union,
cast,
)

from sqlalchemy.util.langhelpers import duck_type_collection

from polyfactory.exceptions import ConfigurationException, MissingDependencyException, ParameterException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.base import BuildContext as BaseBuildContext
from polyfactory.field_meta import Constraints, FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
from polyfactory.utils.types import Frozendict
Expand All @@ -38,12 +40,20 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, scoped_session
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import TypeGuard
from typing_extensions import NotRequired, TypeGuard


T = TypeVar("T")


class SQLAlchemyBuildContext(BaseBuildContext):
skip_computed_fields: bool


class SQLAlchemyConstraints(Constraints):
computed: NotRequired[bool]


class SQLAlchemyPersistenceMethod(enum.Enum):
FLUSH = "flush"
COMMIT = "commit"
Expand Down Expand Up @@ -146,6 +156,30 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"__persistence_method__",
)

@classmethod
def _get_build_context(
cls, build_context: BaseBuildContext | SQLAlchemyBuildContext | None
) -> SQLAlchemyBuildContext:
build_context = cast("SQLAlchemyBuildContext", super()._get_build_context(build_context))
if build_context.get("skip_computed_fields") is None:
build_context["skip_computed_fields"] = False

return build_context

@classmethod
def create_sync(cls, **kwargs: Any) -> T:
build_context = cls._get_build_context(kwargs.get("_build_context"))
build_context["skip_computed_fields"] = True
kwargs["_build_context"] = build_context
return super().create_sync(**kwargs)

@classmethod
async def create_async(cls, **kwargs: Any) -> T:
build_context = cls._get_build_context(kwargs.get("_build_context"))
build_context["skip_computed_fields"] = True
kwargs["_build_context"] = build_context
return await super().create_async(**kwargs)

@classmethod
def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
"""Get mapping of types where column type should be used directly.
Expand Down Expand Up @@ -200,6 +234,17 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return False
return isinstance(inspected, (Mapper, InstanceState))

@classmethod
def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
build_context = kwargs.get("_build_context", {})

if field_meta.constraints:
constraints = cast("SQLAlchemyConstraints", field_meta.constraints)
if constraints.get("computed") and build_context.get("skip_computed_fields"):
return False

return super().should_set_field_value(field_meta, **kwargs)

@classmethod
def should_column_be_set(cls, column: Any) -> bool:
if not isinstance(column, Column):
Expand Down Expand Up @@ -238,7 +283,7 @@ def _get_type_from_type_engine(cls, type_engine: TypeEngine) -> type:
raise ParameterException(msg) from None
annotation = type_engine.impl.python_type # pyright: ignore[reportAttributeAccessIssue]

constraints: Constraints = {}
constraints: SQLAlchemyConstraints = {}
for type_, constraint_fields in cls.get_sqlalchemy_constraints().items():
if not isinstance(type_engine, type_):
continue
Expand All @@ -262,6 +307,10 @@ def get_type_from_column(cls, column: Column) -> type:
if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]

if column.computed:
constraints: SQLAlchemyConstraints = {"computed": True}
annotation = Annotated[annotation, Frozendict(constraints)] # type: ignore[assignment]

return annotation

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion tests/sqlalchemy_factory/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Optional

from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func, orm, text
from sqlalchemy import Boolean, Column, Computed, DateTime, ForeignKey, Integer, String, func, orm, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry
Expand Down Expand Up @@ -146,3 +146,11 @@ class Employee(Base):
name = Column(String)
company_id = Column(Integer, ForeignKey("companies.id"))
company = relationship(Company, back_populates="employees")


class Shape(Base):
__tablename__ = "shape"

id = Column(Integer, primary_key=True)
side: Any = Column(Integer(), nullable=False, default=10)
area: Any = Column(Integer, Computed("side * side"), nullable=False)
29 changes: 29 additions & 0 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
CollectionChildMixin,
CollectionParentMixin,
NonSQLAchemyClass,
Shape,
)
from tests.sqlalchemy_factory.types import ListLike, SetLike

Expand Down Expand Up @@ -132,6 +133,34 @@ class ModelFactory(SQLAlchemyFactory[Model]): ...
assert instance.age * 3 == instance.triple_age


def test_computed_column_sync_persistence(engine: Engine) -> None:
Base.metadata.create_all(engine)

class ShapeFactory(SQLAlchemyFactory[Shape]):
__model__ = Shape
__session__ = Session(engine)

instance = ShapeFactory.create_sync()
assert instance.area == pow(instance.side, 2)


async def test_computed_column_async_persistence(engine: Engine, async_engine: AsyncEngine) -> None:
class ShapeFactory(SQLAlchemyFactory[Shape]):
__model__ = Shape
__async_session__ = AsyncSession(async_engine)

instance = await ShapeFactory.create_async()
assert instance.area == pow(instance.side, 2)


def test_computed_column_no_persistence() -> None:
class ShapeFactory(SQLAlchemyFactory[Shape]):
__model__ = Shape

fields = ShapeFactory.get_model_fields()
assert "area" in [field.name for field in fields]


@pytest.mark.parametrize(
"type_",
tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),
Expand Down
Loading