diff --git a/apistar/types.py b/apistar/types.py index e96531da..b68b7082 100644 --- a/apistar/types.py +++ b/apistar/types.py @@ -52,6 +52,9 @@ def __new__(cls, name, bases, attrs): class Type(Mapping, metaclass=TypeMetaclass): + + formatter = None + def __init__(self, *args, **kwargs): definitions = None allow_coerce = False @@ -116,8 +119,8 @@ def __getitem__(self, key): if value is None: return None validator = self.validator.properties[key] - if hasattr(validator, 'format') and validator.format in validators.FORMATS: - formatter = validators.FORMATS[validator.format] + if validator.formatter is not None: + formatter = validator.formatter return formatter.to_string(value) return value diff --git a/apistar/validators.py b/apistar/validators.py index 9df9e4f8..684cf56c 100644 --- a/apistar/validators.py +++ b/apistar/validators.py @@ -26,7 +26,8 @@ class Validator: errors = {} _creation_counter = 0 - def __init__(self, title='', description='', default=NO_DEFAULT, allow_null=False, definitions=None, def_name=None): + def __init__(self, title='', description='', default=NO_DEFAULT, allow_null=False, + definitions=None, def_name=None, formatter=None): definitions = {} if (definitions is None) else dict_type(definitions) assert isinstance(title, str) @@ -46,6 +47,7 @@ def __init__(self, title='', description='', default=NO_DEFAULT, allow_null=Fals self.allow_null = allow_null self.definitions = definitions self.def_name = def_name + self.formatter = formatter # We need this global counter to determine what order fields have # been declared in when used with `Type`. @@ -127,13 +129,15 @@ def __init__(self, max_length=None, min_length=None, pattern=None, self.pattern = pattern self.enum = enum self.format = format + if isinstance(self.format, str) and self.formatter is None and self.format in FORMATS: + self.formatter = FORMATS[self.format] def validate(self, value, definitions=None, allow_coerce=False): if value is None and self.allow_null: return None elif value is None: self.error('null') - elif self.format in FORMATS and FORMATS[self.format].is_native_type(value): + elif self.formatter is not None and self.formatter.is_native_type(value): return value elif not isinstance(value, str): self.error('type') @@ -159,8 +163,8 @@ def validate(self, value, definitions=None, allow_coerce=False): if not re.search(self.pattern, value): self.error('pattern') - if self.format in FORMATS: - return FORMATS[self.format].validate(value) + if self.formatter is not None: + return self.formatter.validate(value) return value diff --git a/docs/api-guide/type-system.md b/docs/api-guide/type-system.md index 3ec307b5..778fa4d9 100644 --- a/docs/api-guide/type-system.md +++ b/docs/api-guide/type-system.md @@ -270,3 +270,39 @@ You can also access the serialized string representation if needed. * `title` - A title to use in API schemas and documentation. * `description` - A description to use in API schemas and documentation. * `allow_null` - Indicates if `None` should be considered a valid value. Defaults to `False`. + +## Custom Formats + +Custom formatters can be provided for validators to enable them to return any native type + +```python +from apistar.formats import BaseFormat + +class Foo: + def __init__(self, bar): + self.bar = bar + +class FooFormatter(BaseFormat): + def is_native_type(self, value): + return isinstance(value, Foo) + + def to_string(self, value): + return value.bar + + def validate(self, value): + if not isinstance(value, str) or not value.startswith('bar_'): + raise exceptions.ValidationError('Must start with bar_.') + return Foo(value) + +class Example(types.Type): + foo = validators.String(formatter=FooFormatter()) + +>>> data = {'foo': 'bar_foo'} +>>> obj = Example(data) + +>>> obj.foo +<__main__.Foo object at 0x7f143ec8ec88> + +>>> obj['foo'] +"bar_foo" +``` \ No newline at end of file diff --git a/tests/test_formats.py b/tests/test_formats.py index b7fb0efc..5db504c0 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -3,6 +3,7 @@ import pytest from apistar import exceptions, types, validators +from apistar.formats import BaseFormat UTC = datetime.timezone.utc @@ -120,3 +121,35 @@ class Example(types.Type): }) assert example.when is None assert example['when'] is None + + +def test_custom_formatter(): + class Foo: + def __init__(self, bar): + self.bar = bar + + class FooFormatter(BaseFormat): + def is_native_type(self, value): + return isinstance(value, Foo) + + def to_string(self, value): + return value.bar + + def validate(self, value): + if not isinstance(value, str) or not value.startswith('bar_'): + raise exceptions.ValidationError('Must start with bar_.') + return Foo(value) + + class Example(types.Type): + foo = validators.String(formatter=FooFormatter()) + + with pytest.raises(exceptions.ValidationError) as exc: + example = Example({ + 'foo': 'foo' + }) + assert exc.value.detail == {'foo': 'Must start with bar_.'} + + example = Example({'foo': 'bar_foo'}) + assert isinstance(example.foo, Foo) + assert example.foo.bar == 'bar_foo' + assert example['foo'] == 'bar_foo'