Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Changelog

Features:

- Add ``pre_load`` and ``post_load`` parameters to `marshmallow.fields.Field` for
field-level pre- and post-processing (:issue:`2787`).
- Typing: improvements to `marshmallow.validate` (:pr:`2940`).

4.2.4 (2026-04-02)
Expand Down
2 changes: 1 addition & 1 deletion docs/extending/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The guides below demonstrate how to extend schemas in various ways.
.. toctree::
:maxdepth: 1

pre_and_post_processing_methods
pre_and_post_processing
schema_validation
using_original_input_data
overriding_attribute_access
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Pre-processing and post-processing methods
==========================================
Pre-processing and post-processing
===================================

Decorator API
-------------
Expand Down Expand Up @@ -97,6 +97,41 @@ One common use case is to wrap data in a namespace upon serialization and unwrap
user_objs = user_schema.load(users_data, many=True)
# [<User(name='Keith Richards')>, <User(name='Charlie Watts')>]

.. _field_level_processing:

Field-level pre- and post-processing
-------------------------------------

For field-level processing, pass ``pre_load`` and ``post_load``
callables directly to individual fields. This is useful for simple, field-specific
transformations that don't need access to the full schema data.

Each callable receives the field value and returns a transformed value.
You can pass a single callable or a list of callables, which are applied in order.

.. code-block:: python

from marshmallow import Schema, fields


class UserSchema(Schema):
name = fields.Str(pre_load=str.strip)
birthday = fields.Date(post_load=lambda value: value.year)


schema = UserSchema()
result = schema.load({"name": " Steve ", "birthday": "1994-05-12"})
result["name"] # => 'Steve'
result["birthday"] # => 1994


``pre_load`` callables run before the field's deserialization (and before ``allow_none`` is checked),
while ``post_load`` callables run after validation and deserialization.

Like validators, ``pre_load`` and ``post_load`` callables may raise a
`ValidationError <marshmallow.exceptions.ValidationError>`, which will be
stored under the field's key in the errors dictionary.

Raising errors in pre-/post-processor methods
---------------------------------------------

Expand Down Expand Up @@ -157,11 +192,13 @@ In summary, the processing pipeline for deserialization is as follows:

1. ``@pre_load(pass_many=True)`` methods
2. ``@pre_load(pass_many=False)`` methods
3. ``load(in_data, many)`` (validation and deserialization)
4. ``@validates`` methods (field validators)
5. ``@validates_schema`` methods (schema validators)
6. ``@post_load(pass_many=True)`` methods
7. ``@post_load(pass_many=False)`` methods
3. Field-level ``pre_load`` callables
4. Field deserialization (``_deserialize``)
5. Field-level ``validate`` callables and ``@validates`` methods
6. Field-level ``post_load`` callables
7. ``@validates_schema`` methods (schema validators)
8. ``@post_load(pass_many=False)`` methods
9. ``@post_load(pass_many=True)`` methods

The pipeline for serialization is similar, except that the ``pass_many=True`` processors are invoked *after* the ``pass_many=False`` processors and there are no validators.

Expand Down
6 changes: 6 additions & 0 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,12 @@ You may also pass a collection (list, tuple, generator) of callables to ``valida

If you need to validate multiple fields within a single validator, see :ref:`schema_validation`.

.. seealso::

Need to *transform* a field's value?
Use the ``pre_load`` and ``post_load`` field parameters.
See :ref:`field_level_processing`.


Field validators as methods
+++++++++++++++++++++++++++
Expand Down
2 changes: 1 addition & 1 deletion docs/upgrading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ The pre- and post-processing API was significantly improved for better consisten
data["field_a"] -= 1
return data

See the :doc:`extending/pre_and_post_processing_methods` page for more information on the ``pre_*`` and ``post_*`` decorators.
See the :doc:`extending/pre_and_post_processing` page for more information on the ``pre_*`` and ``post_*`` decorators.

Schema validators
*****************
Expand Down
61 changes: 50 additions & 11 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
]

_InternalT = typing.TypeVar("_InternalT")
_ProcessorT = typing.TypeVar(
"_ProcessorT",
bound=types.PostLoadCallable | types.PreLoadCallable | types.Validator,
)


class _BaseFieldKwargs(typing.TypedDict, total=False):
Expand All @@ -90,6 +94,8 @@ class _BaseFieldKwargs(typing.TypedDict, total=False):
data_key: str | None
attribute: str | None
validate: types.Validator | typing.Iterable[types.Validator] | None
pre_load: types.PreLoadCallable | typing.Iterable[types.PreLoadCallable] | None
post_load: types.PostLoadCallable | typing.Iterable[types.PostLoadCallable] | None
required: bool
allow_none: bool | None
load_only: bool
Expand Down Expand Up @@ -135,6 +141,12 @@ class Field(typing.Generic[_InternalT]):
during deserialization. Validator takes a field's input value as
its only parameter and returns a boolean.
If it returns `False`, an :exc:`ValidationError` is raised.
:param pre_load: Callable or collection of callables that are applied to the
raw input value before deserialization. Each callable receives the value
and returns a transformed value.
:param post_load: Callable or collection of callables that are applied to the
deserialized value after validation. Each callable receives the value
and returns a transformed value.
:param required: Raise a :exc:`ValidationError` if the field value
is not supplied during deserialization.
:param allow_none: Set this to `True` if `None` should be considered a valid value during
Expand All @@ -159,6 +171,8 @@ class Field(typing.Generic[_InternalT]):
Use `Raw <marshmallow.fields.Raw>` or another `Field <marshmallow.fields.Field>` subclass instead.
.. versionchanged:: 4.0.0
Remove ``context`` property.
.. versionchanged:: 4.3.0
Add ``pre_load`` and ``post_load``.
"""

# Some fields, such as Method fields and Function fields, are not expected
Expand All @@ -183,6 +197,12 @@ def __init__(
data_key: str | None = None,
attribute: str | None = None,
validate: types.Validator | typing.Iterable[types.Validator] | None = None,
pre_load: (
types.PreLoadCallable | typing.Iterable[types.PreLoadCallable] | None
) = None,
post_load: (
types.PostLoadCallable | typing.Iterable[types.PostLoadCallable] | None
) = None,
required: bool = False,
allow_none: bool | None = None,
load_only: bool = False,
Expand All @@ -196,17 +216,9 @@ def __init__(
self.attribute = attribute
self.data_key = data_key
self.validate = validate
if validate is None:
self.validators = []
elif callable(validate):
self.validators = [validate]
elif utils.is_iterable_but_not_string(validate):
self.validators = list(validate)
else:
raise ValueError(
"The 'validate' parameter must be a callable "
"or a collection of callables."
)
self.validators = self._normalize_processors(validate, param="validate")
self.pre_load = self._normalize_processors(pre_load, param="pre_load")
self.post_load = self._normalize_processors(post_load, param="post_load")

# If allow_none is None and load_default is None
# None should be considered valid by default
Expand Down Expand Up @@ -369,10 +381,21 @@ def deserialize(
if value is missing_:
_miss = self.load_default
return _miss() if callable(_miss) else _miss

# Apply pre_load functions
for func in self.pre_load:
value = func(value)

if self.allow_none and value is None:
return None

output = self._deserialize(value, attr, data, **kwargs)
# Apply validators
self._validate(output)

# Apply post_load functions
for func in self.post_load:
output = func(output)
return output

# Methods for concrete classes to override.
Expand Down Expand Up @@ -433,6 +456,22 @@ def _deserialize(
"""
return value

@staticmethod
def _normalize_processors(
processors: _ProcessorT | typing.Iterable[_ProcessorT] | None,
*,
param: str,
) -> list[_ProcessorT]:
if processors is None:
return []
if callable(processors):
return [processors]
if utils.is_iterable_but_not_string(processors):
return list(processors)
raise ValueError(
f"The '{param}' parameter must be a callable or an iterable of callables."
)


class Raw(Field[typing.Any]):
"""Field that applies no formatting."""
Expand Down
7 changes: 7 additions & 0 deletions src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import typing

T = typing.TypeVar("T")

#: A type that can be either a sequence of strings or a set of strings
StrSequenceOrSet: typing.TypeAlias = typing.Sequence[str] | typing.AbstractSet[str]

Expand All @@ -24,6 +26,11 @@
#: A valid option for the ``unknown`` schema option and argument
UnknownOption: typing.TypeAlias = typing.Literal["exclude", "include", "raise"]

#: Type for field-level pre-load functions
PreLoadCallable = typing.Callable[[typing.Any], typing.Any]
#: Type for field-level post-load functions
PostLoadCallable = typing.Callable[[T], T]


class SchemaValidator(typing.Protocol):
def __call__(
Expand Down
122 changes: 122 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,125 @@ class Family(Schema):
"daughter": {"value": {"age": ["Missing data for required field."]}}
}
}


class TestFieldPreAndPostLoad:
def test_field_pre_load(self):
class UserSchema(Schema):
name = fields.Str(pre_load=str)

schema = UserSchema()
result = schema.load({"name": 808})
assert result["name"] == "808"

def test_field_pre_load_multiple(self):
def decrement(value):
return value - 1

def add_prefix(value):
return "test_" + value

class UserSchema(Schema):
name = fields.Str(pre_load=[decrement, str, add_prefix])

schema = UserSchema()
result = schema.load({"name": 809})
assert result["name"] == "test_808"

def test_field_post_load(self):
class UserSchema(Schema):
age = fields.Int(post_load=str)

schema = UserSchema()
result = schema.load({"age": 42})
assert result["age"] == "42"

def test_field_post_load_multiple(self):
def multiply_by_2(value):
return value * 2

def decrement(value):
return value - 1

class UserSchema(Schema):
age = fields.Float(post_load=[multiply_by_2, decrement])

schema = UserSchema()
result = schema.load({"age": 21.5})
assert result["age"] == 42.0

def test_field_pre_and_post_load(self):
def multiply_by_2(value):
return value * 2

class UserSchema(Schema):
age = fields.Int(pre_load=[str.strip, int], post_load=[multiply_by_2])

schema = UserSchema()
result = schema.load({"age": " 21 "})
assert result["age"] == 42

def test_field_pre_load_validation_error(self):
def always_fail(value):
raise ValidationError("oops")

class UserSchema(Schema):
age = fields.Int(pre_load=always_fail)

schema = UserSchema()
with pytest.raises(ValidationError) as exc:
schema.load({"age": 42})
assert exc.value.messages == {"age": ["oops"]}

def test_field_post_load_validation_error(self):
def always_fail(value):
raise ValidationError("oops")

class UserSchema(Schema):
age = fields.Int(post_load=always_fail)

schema = UserSchema()
with pytest.raises(ValidationError) as exc:
schema.load({"age": 42})
assert exc.value.messages == {"age": ["oops"]}

def test_field_pre_load_none(self):
def handle_none(value):
if value is None:
return 0
return value

class UserSchema(Schema):
age = fields.Int(pre_load=handle_none, allow_none=True)

schema = UserSchema()
result = schema.load({"age": None})
assert result["age"] == 0

def test_field_post_load_not_called_with_none_input_when_not_allowed(self):
def handle_none(value):
if value is None:
return 0
return value

class UserSchema(Schema):
age = fields.Int(post_load=handle_none, allow_none=False)

schema = UserSchema()
with pytest.raises(ValidationError) as exc:
schema.load({"age": None})
assert exc.value.messages == {"age": ["Field may not be null."]}

def test_invalid_type_passed_to_pre_load(self):
with pytest.raises(
ValueError,
match="The 'pre_load' parameter must be a callable or an iterable of callables.",
):
fields.Int(pre_load="not_callable") # type: ignore[arg-type]

def test_invalid_type_passed_to_post_load(self):
with pytest.raises(
ValueError,
match="The 'post_load' parameter must be a callable or an iterable of callables.",
):
fields.Int(post_load="not_callable") # type: ignore[arg-type]