Skip to content

Commit e73709b

Browse files
committed
create database constraints from the Literal values
1 parent 41f61aa commit e73709b

File tree

3 files changed

+161
-8
lines changed

3 files changed

+161
-8
lines changed

sqlmodel/_compat.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,34 @@ def _is_union_type(t: Any) -> bool:
6565
return t is UnionType or t is Union
6666

6767

68+
def get_literal_annotation_info(annotation: Any) -> tuple[type[Any], tuple[Any, ...]] | None:
69+
if annotation is None or get_origin(annotation) is None:
70+
return None
71+
origin = get_origin(annotation)
72+
if origin is Annotated:
73+
return get_literal_annotation_info(get_args(annotation)[0])
74+
if _is_union_type(origin):
75+
bases = get_args(annotation)
76+
if len(bases) > 2:
77+
raise ValueError("Cannot have a Union with more than 2 members")
78+
if bases[0] is not NoneType and bases[1] is not NoneType:
79+
raise ValueError("Cannot have a Union without None")
80+
use_type = bases[0] if bases[0] is not NoneType else bases[1]
81+
return get_literal_annotation_info(use_type)
82+
if origin is Literal:
83+
literal_args = get_args(annotation)
84+
if not literal_args:
85+
return None
86+
if all(isinstance(arg, bool) for arg in literal_args): # all bools
87+
base_type: type[Any] = bool
88+
elif all(isinstance(arg, int) for arg in literal_args): # all ints
89+
base_type = int
90+
else:
91+
base_type = str
92+
return base_type, tuple(literal_args)
93+
return None
94+
95+
6896
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
6997

7098

@@ -191,12 +219,11 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any:
191219
use_type = bases[0] if bases[0] is not NoneType else bases[1]
192220
return get_sa_type_from_type_annotation(use_type)
193221
if origin is Literal:
194-
literal_args = get_args(annotation)
195-
if all(isinstance(arg, bool) for arg in literal_args): # all bools
196-
return bool
197-
if all(isinstance(arg, int) for arg in literal_args): # all ints
198-
return int
199-
return str
222+
literal_info = get_literal_annotation_info(annotation)
223+
if literal_info is None:
224+
raise ValueError("Literal without values is not supported")
225+
base_type, _ = literal_info
226+
return base_type
200227
return origin
201228

202229

sqlmodel/main.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pydantic.fields import FieldInfo as PydanticFieldInfo
2727
from sqlalchemy import (
2828
Boolean,
29+
CheckConstraint,
2930
Column,
3031
Date,
3132
DateTime,
@@ -62,6 +63,7 @@
6263
finish_init,
6364
get_annotations,
6465
get_field_metadata,
66+
get_literal_annotation_info,
6567
get_model_fields,
6668
get_relationship_to,
6769
get_sa_type_from_field,
@@ -631,6 +633,31 @@ def __init__(
631633
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
632634
# Tag: 1.4.36
633635
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
636+
table = getattr(cls, "__table__", None)
637+
if table is not None:
638+
# Attach Literal-based value constraints at the database level
639+
for field_name, field in get_model_fields(cls).items():
640+
annotation = getattr(field, "annotation", None)
641+
literal_info = get_literal_annotation_info(annotation)
642+
if literal_info is None:
643+
continue
644+
base_type, values = literal_info
645+
assert base_type in (str, int, bool)
646+
column = table.c.get(field_name)
647+
if column is None:
648+
continue
649+
if base_type is int:
650+
coerced_values = tuple(int(v) for v in values)
651+
elif base_type is bool:
652+
coerced_values = tuple(bool(v) for v in values)
653+
else:
654+
coerced_values = tuple(str(v) for v in values)
655+
constraint_name = f"ck_{table.name}_{field_name}_literal"
656+
constraint = CheckConstraint(
657+
column.in_(coerced_values),
658+
name=constraint_name,
659+
)
660+
table.append_constraint(constraint)
634661
else:
635662
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
636663

tests/test_main.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Literal, Optional
1+
from typing import Literal, Optional, Union
22

33
import pytest
44
from sqlalchemy.exc import IntegrityError
5+
from sqlalchemy import text
56
from sqlalchemy.orm import RelationshipProperty
67
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
78

@@ -127,7 +128,7 @@ class Hero(SQLModel, table=True):
127128
assert hero_rusty_man.team.name == "Preventers"
128129

129130

130-
def test_literal_str(clear_sqlmodel, caplog):
131+
def test_literal_valid_values(clear_sqlmodel, caplog):
131132
"""Test https://github.com/fastapi/sqlmodel/issues/57"""
132133

133134
class Model(SQLModel, table=True):
@@ -172,3 +173,101 @@ class Model(SQLModel, table=True):
172173
assert obj.int_bool == 1
173174
assert isinstance(obj.all_bool, bool)
174175
assert obj.all_bool is False
176+
177+
178+
def test_literal_constraints_invalid_values(clear_sqlmodel):
179+
"""DB should reject values that are not part of the Literal choices."""
180+
181+
class Model(SQLModel, table=True):
182+
id: Optional[int] = Field(default=None, primary_key=True)
183+
all_str: Literal["a", "b", "c"]
184+
mixed: Literal["yes", "no", 1, 0]
185+
all_int: Literal[1, 2, 3]
186+
int_bool: Literal[0, 1, True, False]
187+
all_bool: Literal[True, False]
188+
189+
engine = create_engine("sqlite://")
190+
SQLModel.metadata.create_all(engine)
191+
192+
# Helper to attempt a raw insert that bypasses Pydantic validation so we
193+
# can verify that the database-level CHECK constraints are enforced.
194+
def insert_raw(values: dict[str, object]) -> None:
195+
stmt = text(
196+
"INSERT INTO model (all_str, mixed, all_int, int_bool, all_bool) "
197+
"VALUES (:all_str, :mixed, :all_int, :int_bool, :all_bool)"
198+
).bindparams(**values)
199+
with pytest.raises(IntegrityError):
200+
with Session(engine) as session:
201+
session.exec(stmt)
202+
session.commit()
203+
204+
# Invalid string literal for all_str
205+
insert_raw(
206+
{
207+
"all_str": "z", # invalid, not in {"a","b","c"}
208+
"mixed": "yes",
209+
"all_int": 1,
210+
"int_bool": 1,
211+
"all_bool": 0,
212+
}
213+
)
214+
215+
# Invalid int literal for all_int
216+
insert_raw(
217+
{
218+
"all_str": "a",
219+
"mixed": "yes",
220+
"all_int": 5, # invalid, not in {1,2,3}
221+
"int_bool": 1,
222+
"all_bool": 0,
223+
}
224+
)
225+
226+
# Invalid bool literal for all_bool
227+
insert_raw(
228+
{
229+
"all_str": "a",
230+
"mixed": "yes",
231+
"all_int": 1,
232+
"int_bool": 1,
233+
"all_bool": 2, # invalid boolean value
234+
}
235+
)
236+
237+
238+
def test_literal_optional_and_union_constraints(clear_sqlmodel):
239+
"""Literals inside Optional/Union should also be enforced at the DB level."""
240+
241+
class Model(SQLModel, table=True):
242+
id: Optional[int] = Field(default=None, primary_key=True)
243+
opt_str: Optional[Literal["x", "y"]] = None
244+
union_int: Union[Literal[10, 20], None] = None
245+
246+
engine = create_engine("sqlite://")
247+
SQLModel.metadata.create_all(engine)
248+
249+
# Valid values should be accepted
250+
obj = Model(opt_str="x", union_int=10)
251+
with Session(engine) as session:
252+
session.add(obj)
253+
session.commit()
254+
session.refresh(obj)
255+
assert obj.opt_str == "x"
256+
assert obj.union_int == 10
257+
258+
# Invalid values should be rejected by the database
259+
def insert_raw(values: dict[str, object]) -> None:
260+
stmt = text(
261+
"INSERT INTO model (opt_str, union_int) "
262+
"VALUES (:opt_str, :union_int)"
263+
).bindparams(**values)
264+
with pytest.raises(IntegrityError):
265+
with Session(engine) as session:
266+
session.exec(stmt)
267+
session.commit()
268+
269+
# opt_str not in {"x", "y"}
270+
insert_raw({"opt_str": "z", "union_int": 10})
271+
272+
# union_int not in {10, 20}
273+
insert_raw({"opt_str": "x", "union_int": 30})

0 commit comments

Comments
 (0)