Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ dependencies = [

[dependency-groups]
dev = [
"pyright>=1.1.407",
"pytest>=8.3.4",
"pytest-benchmark>=5.1.0",
"ruff>=0.14.2",
]

[build-system]
Expand Down
416 changes: 334 additions & 82 deletions src/deigma/proxy.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/deigma/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __dataclass_fields__(self) -> dict[str, Field]: ...
if TYPE_CHECKING:

@final
@runtime_checkable
class Dataclass(Protocol):
__dataclass_fields__: ClassVar[dict[str, Any]]

Expand Down
49 changes: 26 additions & 23 deletions src/deigma/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def template(
path: None = None,
serialize: Serialize = DEFAULT_SERIALIZE,
use_proxy: bool = USE_PROXY,
) -> Callable[[type[T]], type[Template]]: ...
) -> Callable[[type[T]], type[T]]: ...


@overload
Expand All @@ -46,7 +46,7 @@ def template(
path: str | PathLike,
serialize: Serialize = DEFAULT_SERIALIZE,
use_proxy: bool = USE_PROXY,
) -> Callable[[type[T]], type[Template]]: ...
) -> Callable[[type[T]], type[T]]: ...


@dataclass_transform()
Expand All @@ -56,7 +56,7 @@ def template(
path: str | PathLike | None = None,
serialize: Serialize = DEFAULT_SERIALIZE,
use_proxy: bool = USE_PROXY,
) -> Callable[[type[T]], type[Template]]:
) -> Callable[[type[T]], type[T]]:
if source is None and path is None:
raise ValueError("Either source or path must be provided")

Expand Down Expand Up @@ -87,26 +87,26 @@ def inline_template(
*,
serialize: Serialize = DEFAULT_SERIALIZE,
use_proxy: bool = USE_PROXY,
) -> Callable[[type[T]], type[Template]]:
) -> Callable[[type[T]], type[T]]:
def decorator(cls: type[T]) -> type[T]:
config = ConfigDict(arbitrary_types_allowed=True)
cls = pydantic_dataclass(init=False, config=config)(
type(cls.__name__, (cls,), dict(cls.__dict__))
)
cls.__is_deigma_template__ = True
cls._source = cleandoc(source)
# Apply pydantic_dataclass directly to preserve type information
cls = pydantic_dataclass(config=config)(cls) # pyright: ignore[reportAssignmentType]
# Add template-specific class attributes (not visible to type checker)
cls.__is_deigma_template__ = True # pyright: ignore[reportAttributeAccessIssue]
cls._source = cleandoc(source) # pyright: ignore[reportAttributeAccessIssue]
engine = Jinja2Engine(serialize=serialize)
cls._engine = engine
cls._variables = engine.introspect_variables(source)
cls._engine = engine # pyright: ignore[reportAttributeAccessIssue]
cls._variables = engine.introspect_variables(source) # pyright: ignore[reportAttributeAccessIssue]

static_fields = set(cls.__annotations__)
properties = {
prop for prop in vars(cls) if isinstance(getattr(cls, prop), property)
}
fields = static_fields | properties

if not set(cls._variables).issubset(fields):
variables = cls._variables
variables = cls._variables # pyright: ignore[reportAttributeAccessIssue]
if not set(variables).issubset(fields):
msg = (
"Template variables mismatch. Template fields must match variables in source:\n\n"
f"fields on type: {fields}, variables in source: {variables}"
Expand All @@ -123,38 +123,41 @@ def decorator(cls: type[T]) -> type[T]:

raise ValueError(msg)

cls._compiled_template = engine.compile_template(cls._source)
cls._type_adapter = TypeAdapter(cls)
cls._compiled_template = engine.compile_template(cls._source) # pyright: ignore[reportAttributeAccessIssue]
cls._type_adapter = TypeAdapter(cls) # pyright: ignore[reportAttributeAccessIssue]

if use_proxy:

def __str__(instance):
proxied = {
field: getattr(instance._proxy, field) for field in cls._variables
field: getattr(instance._proxy, field) # pyright: ignore[reportAttributeAccessIssue]
for field in cls._variables # pyright: ignore[reportAttributeAccessIssue]
}
return instance._compiled_template.render(proxied)
return instance._compiled_template.render(proxied) # pyright: ignore[reportAttributeAccessIssue]

original_init = cls.__init__

def __init__(instance, *args, **kwargs):
original_init(instance, *args, **kwargs)
instance._proxy = SerializationProxy.build(instance, cls._type_adapter)
instance._proxy = SerializationProxy.build( # pyright: ignore[reportAttributeAccessIssue]
instance, cls._type_adapter # pyright: ignore[reportAttributeAccessIssue]
)

cls.__init__ = __init__
cls.__init__ = __init__ # pyright: ignore[reportAttributeAccessIssue]

else:

def __str__(instance):
serialized = cls._type_adapter.dump_python(instance)
serialized = cls._type_adapter.dump_python(instance) # pyright: ignore[reportAttributeAccessIssue]
rendered_fields = {
field: _render_field_maybe(
getattr(instance, field), serialized[field]
)
for field in cls._variables
for field in cls._variables # pyright: ignore[reportAttributeAccessIssue]
}
return instance._compiled_template.render(rendered_fields)
return instance._compiled_template.render(rendered_fields) # pyright: ignore[reportAttributeAccessIssue]

cls.__str__ = __str__
cls.__str__ = __str__ # pyright: ignore[reportAttributeAccessIssue]

return cls

Expand Down
30 changes: 17 additions & 13 deletions src/deigma/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class TemplateKwargs(TypedDict, total=False):
def copy(cls: type[_T]) -> type[_T]:
copy_ = type(cls.__name__, (cls,), {})
copy_.__annotations__ = cls.__annotations__
return template(cls._source, cls._engine._env.serialize)(copy_)
# Access dynamically added template attributes
return template(cls._source, cls._engine._env.serialize)(copy_) # pyright: ignore[reportAttributeAccessIssue, reportCallIssue]


def replace(instance: _T, **kwargs: Any) -> _T:
Expand All @@ -29,34 +30,37 @@ def replace(instance: _T, **kwargs: Any) -> _T:

def with_serialize(cls: type[_T], serialize: Serialize) -> type[_T]:
if is_template_type(cls):
copy_: type[_T] = copy(cls)
copy_._engine._env.serialize = serialize
copy_: type[_T] = copy(cls) # pyright: ignore[reportArgumentType]
# Access and modify dynamically added template attributes
copy_._engine._env.serialize = serialize # pyright: ignore[reportAttributeAccessIssue]
return copy_
return cls


def with_source(cls: type[_T], source: str) -> type[_T]:
if is_template_type(cls):
copy_: type[_T] = copy(cls)
copy_._source = source
copy_._compiled_template = copy_._engine._env.from_string(cleandoc(source))
parsed_source = copy_._engine._env.parse(source)
copy_._variables = meta.find_undeclared_variables(parsed_source)
copy_: type[_T] = copy(cls) # pyright: ignore[reportArgumentType]
# Access and modify dynamically added template attributes
copy_._source = source # pyright: ignore[reportAttributeAccessIssue]
copy_._compiled_template = copy_._engine._env.from_string(cleandoc(source)) # pyright: ignore[reportAttributeAccessIssue]
parsed_source = copy_._engine._env.parse(source) # pyright: ignore[reportAttributeAccessIssue]
copy_._variables = meta.find_undeclared_variables(parsed_source) # pyright: ignore[reportAttributeAccessIssue]
return copy_
return cls


def with_(cls_or_instance: type[_T] | _T, **kwargs: Unpack[TemplateKwargs]) -> type[_T]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The return type hint for this function appears to be incorrect. It's declared as type[_T], but when an instance is passed as cls_or_instance, the function returns a new instance of _T (on line 65), not a type. This causes a type mismatch that is currently suppressed with pyright: ignore[reportReturnType].

The return type should be type[_T] | _T to accurately reflect both execution paths. For even better type inference, you could consider using @overload to define separate signatures for when a type is passed versus when an instance is passed.

Suggested change
def with_(cls_or_instance: type[_T] | _T, **kwargs: Unpack[TemplateKwargs]) -> type[_T]:
def with_(cls_or_instance: type[_T] | _T, **kwargs: Unpack[TemplateKwargs]) -> type[_T] | _T:

if is_template_type(cls_or_instance):
copy_: type[_T] = copy(cls_or_instance)
if is_template_type(cls_or_instance): # pyright: ignore[reportArgumentType]
copy_: type[_T] = copy(cls_or_instance) # pyright: ignore[reportArgumentType]
match kwargs:
case {"source": source}:
copy_ = with_source(copy_, source)
case {"serialize": serialize}:
copy_ = with_serialize(copy_, serialize)
return copy_
if is_template(instance := cls_or_instance):
fields = {k: v for k, v in vars(instance).items() if k in instance._variables}
copy_: type[_T] = copy(type(instance))
return with_(copy_, **kwargs)(**fields)
# Access dynamically added template attributes
fields = {k: v for k, v in vars(instance).items() if k in instance._variables} # pyright: ignore[reportAttributeAccessIssue]
copy_: type[_T] = copy(type(instance)) # pyright: ignore[reportArgumentType, reportAssignmentType]
return with_(copy_, **kwargs)(**fields) # pyright: ignore[reportReturnType]
raise ValueError(f"Expected a Template type or instance, got {cls_or_instance!r}")
1 change: 1 addition & 0 deletions src/deigma/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __is_deigma_template__(self) -> Literal[True]: ...
if TYPE_CHECKING:
from pydantic._internal._dataclasses import PydanticDataclass

@runtime_checkable
class Template(Template, PydanticDataclass):
@property
def __is_deigma_template__(self) -> Literal[True]: ...
6 changes: 3 additions & 3 deletions tests/benches/test_serialization_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any

import pytest
from pydantic import BaseModel, Field, TypeAdapter, field_serializer
from pydantic import BaseModel, TypeAdapter, field_serializer

from deigma.proxy import SerializationProxy

Expand Down Expand Up @@ -406,7 +406,7 @@ def complete_workflow():
proxy = SerializationProxy.build(nested_model)
_ = proxy.id
_ = proxy.data.name
items_len = len(proxy.items)
_ = len(proxy.items)
count = 0
for item in proxy.items:
count += 1
Expand All @@ -423,7 +423,7 @@ def test_benchmark_direct_complete_workflow(benchmark, nested_model: NestedModel
def complete_workflow():
_ = nested_model.id
_ = nested_model.data.name
items_len = len(nested_model.items)
_ = len(nested_model.items)
count = 0
for item in nested_model.items:
count += 1
Expand Down
65 changes: 40 additions & 25 deletions tests/integration/test_field_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,55 @@
"""

from dataclasses import dataclass
from typing import Annotated, TypedDict
from typing import Annotated, TypeAlias, TypedDict

import pytest
from pydantic import PlainSerializer, field_serializer
from pydantic import PlainSerializer, WrapSerializer, field_serializer

from deigma import template


# Fixtures for reusable test models
@pytest.fixture
def user_inline_type():
"""TypedDict User for inline serialization tests."""
# Module-level type definitions for use in tests
class User(TypedDict):
"""TypedDict for user data."""
first_name: str
last_name: str


UserInlineType: TypeAlias = Annotated[
User, PlainSerializer(lambda user: f"{user['first_name']} {user['last_name']}")
]

UserWrapType: TypeAlias = Annotated[
User,
WrapSerializer(
lambda user, nxt: f"User: {nxt(user)['first_name']} {nxt(user)['last_name']}",
return_type=str,
),
]

Keywords: TypeAlias = Annotated[list[str], PlainSerializer(lambda xs: ", ".join(xs))]

SQLKeywordName: TypeAlias = Annotated[str, PlainSerializer(lambda keyword: keyword.upper())]

PhoneNumber: TypeAlias = Annotated[
str,
PlainSerializer(
lambda phone: f"({phone[:3]}) {phone[3:6]}-{phone[6:]}", return_type=str
),
]

class User(TypedDict):
first_name: str
last_name: str

# Fixtures for backward compatibility
@pytest.fixture
def user_inline_type() -> type[User]:
"""Fixture returning User type."""
return User


# Field Serializer with Decorator
def test_field_serializer_decorator(user_inline_type):
def test_field_serializer_decorator():
"""Test field serialization using @field_serializer."""
User = user_inline_type

@template("{{ user }}")
class UserTemplate:
user: User
Expand All @@ -44,16 +68,11 @@ def inline_user(self, user: User) -> str:


# Plain Serializer Annotation
def test_plain_serializer_annotation(user_inline_type):
def test_plain_serializer_annotation():
"""Test field serialization using PlainSerializer annotation."""
User = user_inline_type
UserInline = Annotated[
User, PlainSerializer(lambda user: f"{user['first_name']} {user['last_name']}")
]

@template("{{ user }}")
class UserTemplate:
user: UserInline
user: UserInlineType

result = str(UserTemplate(user={"first_name": "Li", "last_name": "Si"}))
assert result.strip() == "Li Si"
Expand All @@ -62,8 +81,6 @@ class UserTemplate:
# SQL Keyword Example from README
def test_sql_keyword_example():
"""Test the SQL keyword example from README - tests field serializers in loops."""
SQLKeywordName = Annotated[str, PlainSerializer(lambda keyword: keyword.upper())]

@dataclass
class SQLKeyword:
name: SQLKeywordName
Expand Down Expand Up @@ -96,8 +113,6 @@ class SQLKeywordListingTemplate:

def test_sql_keyword_literal_rendering():
"""Test field serializers applied when rendering compound object natively."""
SQLKeywordName = Annotated[str, PlainSerializer(lambda keyword: keyword.upper())]

@dataclass
class SQLKeyword:
name: SQLKeywordName
Expand Down Expand Up @@ -269,11 +284,11 @@ def counting_serializer(value: str) -> str:
counter.count += 1
return value.upper()

Name = Annotated[str, PlainSerializer(counting_serializer)]
Name = Annotated[str, PlainSerializer(counting_serializer)] # pyright: ignore[reportInvalidTypeForm]

@dataclass
class Person:
name: Name
name: Name # pyright: ignore[reportInvalidTypeForm]
title: str

@template(
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_template_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import Field

from deigma import replace, template
from deigma.serialize import serialize_json, serialize_str
from deigma.serialize import serialize_json


# Fixtures
Expand Down
Loading