Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions src/marshmallow/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing
from abc import ABC, abstractmethod
from operator import attrgetter
from urllib.parse import urlsplit, urlunsplit

from marshmallow.exceptions import ValidationError

Expand Down Expand Up @@ -195,6 +196,45 @@ def _repr_args(self) -> str:
def _format_error(self, value) -> str:
return self.error.format(input=value)

@staticmethod
def _encode_idn(value: str) -> str:
"""Encode internationalized domain names (IDN) to punycode.

Converts Unicode hostnames to their ASCII-compatible encoding
so that the existing ASCII-only regex can validate them.
"""
try:
parsed = urlsplit(value)
except ValueError:
return value
if not parsed.hostname:
return value
# Only encode if hostname contains non-ASCII characters
try:
parsed.hostname.encode("ascii")
except UnicodeEncodeError:
pass
else:
return value
try:
encoded_hostname = parsed.hostname.encode("idna").decode("ascii")
except (UnicodeError, UnicodeDecodeError):
return value
# Reconstruct netloc with encoded hostname (preserve port)
if parsed.port:
netloc = f"{encoded_hostname}:{parsed.port}"
else:
netloc = encoded_hostname
# Preserve userinfo if present
if parsed.username:
userinfo = parsed.username
if parsed.password:
userinfo += f":{parsed.password}"
netloc = f"{userinfo}@{netloc}"
return urlunsplit(
(parsed.scheme, netloc, parsed.path, parsed.query, parsed.fragment)
)

def __call__(self, value: str) -> str:
message = self._format_error(value)
if not value:
Expand All @@ -211,12 +251,15 @@ def __call__(self, value: str) -> str:
relative=self.relative, absolute=self.absolute, require_tld=self.require_tld
)

# Encode IDN hostnames to punycode for validation
validation_value = self._encode_idn(value)

# Hostname is optional for file URLS. If absent it means `localhost`.
# Fill it in for the validation if needed
if scheme == "file" and value.lower().startswith("file:///"):
matched = regex.search("file://localhost/" + value[8:])
if scheme == "file" and validation_value.lower().startswith("file:///"):
matched = regex.search("file://localhost/" + validation_value[8:])
else:
matched = regex.search(value)
matched = regex.search(validation_value)

if not matched:
raise ValidationError(message)
Expand Down