|
1 | | -from typing import Literal, Optional |
| 1 | +from typing import Literal, Optional, Union |
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 | from sqlalchemy.exc import IntegrityError |
| 5 | +from sqlalchemy import text |
5 | 6 | from sqlalchemy.orm import RelationshipProperty |
6 | 7 | from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select |
7 | 8 |
|
@@ -127,7 +128,7 @@ class Hero(SQLModel, table=True): |
127 | 128 | assert hero_rusty_man.team.name == "Preventers" |
128 | 129 |
|
129 | 130 |
|
130 | | -def test_literal_str(clear_sqlmodel, caplog): |
| 131 | +def test_literal_valid_values(clear_sqlmodel, caplog): |
131 | 132 | """Test https://github.com/fastapi/sqlmodel/issues/57""" |
132 | 133 |
|
133 | 134 | class Model(SQLModel, table=True): |
@@ -172,3 +173,101 @@ class Model(SQLModel, table=True): |
172 | 173 | assert obj.int_bool == 1 |
173 | 174 | assert isinstance(obj.all_bool, bool) |
174 | 175 | assert obj.all_bool is False |
| 176 | + |
| 177 | + |
| 178 | +def test_literal_constraints_invalid_values(clear_sqlmodel): |
| 179 | + """DB should reject values that are not part of the Literal choices.""" |
| 180 | + |
| 181 | + class Model(SQLModel, table=True): |
| 182 | + id: Optional[int] = Field(default=None, primary_key=True) |
| 183 | + all_str: Literal["a", "b", "c"] |
| 184 | + mixed: Literal["yes", "no", 1, 0] |
| 185 | + all_int: Literal[1, 2, 3] |
| 186 | + int_bool: Literal[0, 1, True, False] |
| 187 | + all_bool: Literal[True, False] |
| 188 | + |
| 189 | + engine = create_engine("sqlite://") |
| 190 | + SQLModel.metadata.create_all(engine) |
| 191 | + |
| 192 | + # Helper to attempt a raw insert that bypasses Pydantic validation so we |
| 193 | + # can verify that the database-level CHECK constraints are enforced. |
| 194 | + def insert_raw(values: dict[str, object]) -> None: |
| 195 | + stmt = text( |
| 196 | + "INSERT INTO model (all_str, mixed, all_int, int_bool, all_bool) " |
| 197 | + "VALUES (:all_str, :mixed, :all_int, :int_bool, :all_bool)" |
| 198 | + ).bindparams(**values) |
| 199 | + with pytest.raises(IntegrityError): |
| 200 | + with Session(engine) as session: |
| 201 | + session.exec(stmt) |
| 202 | + session.commit() |
| 203 | + |
| 204 | + # Invalid string literal for all_str |
| 205 | + insert_raw( |
| 206 | + { |
| 207 | + "all_str": "z", # invalid, not in {"a","b","c"} |
| 208 | + "mixed": "yes", |
| 209 | + "all_int": 1, |
| 210 | + "int_bool": 1, |
| 211 | + "all_bool": 0, |
| 212 | + } |
| 213 | + ) |
| 214 | + |
| 215 | + # Invalid int literal for all_int |
| 216 | + insert_raw( |
| 217 | + { |
| 218 | + "all_str": "a", |
| 219 | + "mixed": "yes", |
| 220 | + "all_int": 5, # invalid, not in {1,2,3} |
| 221 | + "int_bool": 1, |
| 222 | + "all_bool": 0, |
| 223 | + } |
| 224 | + ) |
| 225 | + |
| 226 | + # Invalid bool literal for all_bool |
| 227 | + insert_raw( |
| 228 | + { |
| 229 | + "all_str": "a", |
| 230 | + "mixed": "yes", |
| 231 | + "all_int": 1, |
| 232 | + "int_bool": 1, |
| 233 | + "all_bool": 2, # invalid boolean value |
| 234 | + } |
| 235 | + ) |
| 236 | + |
| 237 | + |
| 238 | +def test_literal_optional_and_union_constraints(clear_sqlmodel): |
| 239 | + """Literals inside Optional/Union should also be enforced at the DB level.""" |
| 240 | + |
| 241 | + class Model(SQLModel, table=True): |
| 242 | + id: Optional[int] = Field(default=None, primary_key=True) |
| 243 | + opt_str: Optional[Literal["x", "y"]] = None |
| 244 | + union_int: Union[Literal[10, 20], None] = None |
| 245 | + |
| 246 | + engine = create_engine("sqlite://") |
| 247 | + SQLModel.metadata.create_all(engine) |
| 248 | + |
| 249 | + # Valid values should be accepted |
| 250 | + obj = Model(opt_str="x", union_int=10) |
| 251 | + with Session(engine) as session: |
| 252 | + session.add(obj) |
| 253 | + session.commit() |
| 254 | + session.refresh(obj) |
| 255 | + assert obj.opt_str == "x" |
| 256 | + assert obj.union_int == 10 |
| 257 | + |
| 258 | + # Invalid values should be rejected by the database |
| 259 | + def insert_raw(values: dict[str, object]) -> None: |
| 260 | + stmt = text( |
| 261 | + "INSERT INTO model (opt_str, union_int) " |
| 262 | + "VALUES (:opt_str, :union_int)" |
| 263 | + ).bindparams(**values) |
| 264 | + with pytest.raises(IntegrityError): |
| 265 | + with Session(engine) as session: |
| 266 | + session.exec(stmt) |
| 267 | + session.commit() |
| 268 | + |
| 269 | + # opt_str not in {"x", "y"} |
| 270 | + insert_raw({"opt_str": "z", "union_int": 10}) |
| 271 | + |
| 272 | + # union_int not in {10, 20} |
| 273 | + insert_raw({"opt_str": "x", "union_int": 30}) |
0 commit comments