diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..1b88869a42 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,8 +1,10 @@ from __future__ import annotations import builtins +import inspect as inspect_module import ipaddress import uuid +import warnings import weakref from collections.abc import Callable, Mapping, Sequence, Set from dataclasses import dataclass @@ -12,6 +14,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Annotated, Any, ClassVar, Literal, @@ -24,6 +27,7 @@ ) from pydantic import BaseModel, EmailStr +from pydantic import Field as PydanticField from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( Boolean, @@ -90,6 +94,15 @@ ) OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"] +SCHEMA_EXTRA_DEPRECATION_MSG = ( + "This parameter is deprecated.\n" + "Use `json_schema_extra` to add extra information to JSON schema." +) + +FIELD_ACCEPTED_KWARGS = set(inspect_module.signature(PydanticField).parameters.keys()) +if "schema_extra" in FIELD_ACCEPTED_KWARGS: + FIELD_ACCEPTED_KWARGS.remove("schema_extra") + def __dataclass_transform__( *, @@ -271,7 +284,11 @@ def Field( sa_type: type[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, - schema_extra: dict[str, Any] | None = None, + schema_extra: Annotated[ + dict[str, Any] | None, + deprecated(SCHEMA_EXTRA_DEPRECATION_MSG), + ] = None, + json_schema_extra: dict[str, Any] | None = None, ) -> Any: ... @@ -315,7 +332,11 @@ def Field( sa_type: type[Any] | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, - schema_extra: dict[str, Any] | None = None, + schema_extra: Annotated[ + dict[str, Any] | None, + deprecated(SCHEMA_EXTRA_DEPRECATION_MSG), + ] = None, + json_schema_extra: dict[str, Any] | None = None, ) -> Any: ... @@ -359,7 +380,11 @@ def Field( discriminator: str | None = None, repr: bool = True, sa_column: Column[Any] | UndefinedType = Undefined, - schema_extra: dict[str, Any] | None = None, + schema_extra: Annotated[ + dict[str, Any] | None, + deprecated(SCHEMA_EXTRA_DEPRECATION_MSG), + ] = None, + json_schema_extra: dict[str, Any] | None = None, ) -> Any: ... @@ -401,9 +426,23 @@ def Field( sa_column: Column | UndefinedType = Undefined, # type: ignore sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, - schema_extra: dict[str, Any] | None = None, + schema_extra: Annotated[ + dict[str, Any] | None, + deprecated(SCHEMA_EXTRA_DEPRECATION_MSG), + ] = None, + json_schema_extra: dict[str, Any] | None = None, ) -> Any: + if schema_extra: + warnings.warn( + "schema_extra parameter is deprecated. " + "Use json_schema_extra to add extra information to JSON schema.", + DeprecationWarning, + stacklevel=1, + ) + + current_json_schema_extra = json_schema_extra or {} current_schema_extra = schema_extra or {} + # Extract possible alias settings from schema_extra so we can control precedence schema_validation_alias = current_schema_extra.pop("validation_alias", None) schema_serialization_alias = current_schema_extra.pop("serialization_alias", None) @@ -451,6 +490,21 @@ def Field( serialization_alias or schema_serialization_alias or alias ) + # Handle a workaround when json_schema_extra was passed via schema_extra + if "json_schema_extra" in current_schema_extra: + json_schema_extra_from_schema_extra = current_schema_extra.pop( + "json_schema_extra" + ) + if not current_json_schema_extra: + current_json_schema_extra = json_schema_extra_from_schema_extra + # Split parameters from schema_extra to field_info_kwargs and json_schema_extra + for key, value in current_schema_extra.items(): + if key in FIELD_ACCEPTED_KWARGS: + field_info_kwargs[key] = value + else: + current_json_schema_extra[key] = value + field_info_kwargs["json_schema_extra"] = current_json_schema_extra + field_info = FieldInfo( default, default_factory=default_factory, diff --git a/tests/test_field_json_schema_extra.py b/tests/test_field_json_schema_extra.py new file mode 100644 index 0000000000..0cbe3dafaa --- /dev/null +++ b/tests/test_field_json_schema_extra.py @@ -0,0 +1,85 @@ +import pytest +from sqlmodel import Field, SQLModel + + +def test_json_schema_extra_applied(): + """test json_schema_extra is applied to the field""" + + class Item(SQLModel): + name: str = Field( + json_schema_extra={ + "example": "Sword of Power", + "x-custom-key": "Important Data", + } + ) + + schema = Item.model_json_schema() + name_schema = schema["properties"]["name"] + + assert name_schema["example"] == "Sword of Power" + assert name_schema["x-custom-key"] == "Important Data" + + +def test_schema_extra_and_json_schema_extra_conflict(): + """ + Test that passing schema_extra and json_schema_extra at the same time produces + a warning. + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + Field(schema_extra={"legacy": 1}, json_schema_extra={"new": 2}) + + +def test_schema_extra_backward_compatibility(): + """ + test that schema_extra is backward compatible with json_schema_extra + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + + class LegacyItem(SQLModel): + name: str = Field( + schema_extra={ + "example": "Sword of Old", + "x-custom-key": "Important Data", + "serialization_alias": "id_test", + } + ) + + schema = LegacyItem.model_json_schema() + name_schema = schema["properties"]["name"] + + assert name_schema["example"] == "Sword of Old" + assert name_schema["x-custom-key"] == "Important Data" + + # With Pydantic V1 serialization_alias from schema_extra is applied + field_info = LegacyItem.model_fields["name"] + assert field_info.serialization_alias == "id_test" + + +def test_json_schema_extra_mix_in_schema_extra(): + """ + Test workaround when json_schema_extra was passed via schema_extra. + """ + + with pytest.warns(DeprecationWarning, match="schema_extra parameter is deprecated"): + + class Item(SQLModel): + name: str = Field( + schema_extra={ + "json_schema_extra": { + "example": "Sword of Power", + "x-custom-key": "Important Data", + }, + "serialization_alias": "id_test", + } + ) + + schema = Item.model_json_schema() + + name_schema = schema["properties"]["name"] + assert name_schema["example"] == "Sword of Power" + assert name_schema["x-custom-key"] == "Important Data" + + field_info = Item.model_fields["name"] + assert field_info.serialization_alias == "id_test"