Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions src/vercel/_internal/sandbox/network_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def to_dict(self) -> dict[str, Any]:
payload["headerNames"] = self.header_names
return payload

def to_redacted_headers(self) -> dict[str, str]:
if self.header_names:
redacted: dict[str, str] = {}
lower_to_name: dict[str, str] = {}
for name in self.header_names:
lower_name = name.lower()
previous_name = lower_to_name.get(lower_name)
if previous_name is not None and previous_name != name:
redacted.pop(previous_name, None)
lower_to_name[lower_name] = name
redacted[name] = _REDACTED_HEADER_VALUE
return redacted
return dict.fromkeys(self.headers or {}, _REDACTED_HEADER_VALUE)


@dataclass(frozen=True, slots=True)
class ApiNetworkPolicy:
Expand Down Expand Up @@ -184,10 +198,7 @@ def to_network_policy(self) -> NetworkPolicy:
allow: dict[str, list[NetworkPolicyRule]] = {domain: [] for domain in allowed_domains}
for rule in injection_rules:
allow.setdefault(rule.domain, [])
header_names = list(rule.header_names or [])
if not header_names and rule.headers is not None:
header_names = list(rule.headers.keys())
headers = _redacted_headers_from_names(header_names)
headers = rule.to_redacted_headers()
if not headers:
continue
allow[rule.domain].append(
Expand All @@ -200,26 +211,38 @@ def to_network_policy(self) -> NetworkPolicy:
def _merge_headers_case_insensitively(
headers: Sequence[Mapping[str, str] | None],
) -> dict[str, str]:
merged: dict[str, tuple[str, str]] = {}
merged: dict[str, str] = {}
lower_to_names: dict[str, set[str]] = {}
for header_map in headers:
for name, value in (header_map or {}).items():
merged[name.lower()] = (name, value)
return dict(merged.values())
if not header_map:
continue

current_lower_to_names: dict[str, set[str]] = {}
for name, value in header_map.items():
merged[name] = value
current_lower_to_names.setdefault(name.lower(), set()).add(name)

def _redacted_headers_from_names(header_names: Sequence[str]) -> dict[str, str]:
return dict.fromkeys(
_merge_headers_case_insensitively([dict.fromkeys(header_names, "")]),
_REDACTED_HEADER_VALUE,
)
for lower_name, current_names in current_lower_to_names.items():
for previous_name in lower_to_names.get(lower_name, set()) - current_names:
merged.pop(previous_name, None)
lower_to_names[lower_name] = current_names

return merged


def _merge_rule_headers(rules: Sequence[NetworkPolicyRule]) -> dict[str, str]:
return _merge_headers_case_insensitively(
[transform.headers for rule in rules for transform in rule.transform or []]
[_merge_rule_transform_headers(rule) for rule in rules]
)


def _merge_rule_transform_headers(rule: NetworkPolicyRule) -> dict[str, str]:
merged: dict[str, str] = {}
for transform in rule.transform or []:
merged.update(transform.headers or {})
return merged


def _subnets_from_api(network_policy: ApiNetworkPolicy) -> NetworkPolicySubnets | None:
if network_policy.allowed_cidrs is None and network_policy.denied_cidrs is None:
return None
Expand Down
233 changes: 134 additions & 99 deletions tests/unit/test_sandbox_network_policy_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ def _rule_header_names(rules: Iterable[NetworkPolicyRule]) -> set[str]:
return header_names


def _case_insensitive_headers(items: Iterable[tuple[str, str]]) -> dict[str, str]:
merged: dict[str, tuple[str, str]] = {}
for name, value in items:
merged[name.lower()] = (name, value)
return dict(merged.values())


def _domain_strategy() -> st.SearchStrategy[str]:
label = st.from_regex(r"[a-z][a-z0-9-]{0,5}", fullmatch=True)
wildcard = st.just("*")
Expand All @@ -64,18 +57,6 @@ def _header_name_strategy() -> st.SearchStrategy[str]:
return st.from_regex(r"X-[A-Z][A-Za-z0-9-]{0,10}", fullmatch=True)


@st.composite
def _header_name_case_variant_strategy(draw: st.DrawFn) -> str:
canonical = draw(st.from_regex(r"x-[a-z][a-z0-9-]{0,10}", fullmatch=True))
chars: list[str] = []
for char in canonical:
if char.isalpha():
chars.append(char.upper() if draw(st.booleans()) else char.lower())
else:
chars.append(char)
return "".join(chars)


def _header_value_strategy() -> st.SearchStrategy[str]:
return st.text(
alphabet=st.characters(blacklist_categories=["Cs"]),
Expand Down Expand Up @@ -170,35 +151,41 @@ def _policy_semantics(policy: NetworkPolicy) -> Any:
return ("record", domain_semantics, subnet_semantics)


@st.composite
def _duplicate_header_assignments_strategy(
draw: st.DrawFn,
) -> list[tuple[str, str]]:
canonical_names = draw(
st.lists(
st.from_regex(r"x-[a-z][a-z0-9-]{0,10}", fullmatch=True),
min_size=1,
max_size=4,
unique=True,
def _subnet_semantics(policy: NetworkPolicy) -> tuple[tuple[str, ...], tuple[str, ...]] | None:
if isinstance(policy, str) or policy.subnets is None:
return None

return (
tuple(sorted(policy.subnets.allow or [])),
tuple(sorted(policy.subnets.deny or [])),
)


def _normalized_custom_policy_semantics(
policy: NetworkPolicy,
) -> tuple[tuple[tuple[str, frozenset[str]], ...], tuple[tuple[str, ...], tuple[str, ...]] | None]:
if isinstance(policy, str):
raise TypeError("expected custom policy")

domain_semantics = tuple(
sorted(
(domain, frozenset(name.lower() for name in names))
for domain, names in _record_policy_domains(policy).items()
)
)
assignments: list[tuple[str, str]] = []
for canonical_name in canonical_names:
variant_count = draw(st.integers(min_value=1, max_value=3))
for _ in range(variant_count):
variant_chars: list[str] = []
for char in canonical_name:
if char.isalpha():
variant_chars.append(char.upper() if draw(st.booleans()) else char.lower())
else:
variant_chars.append(char)
assignments.append(
(
"".join(variant_chars),
draw(_header_value_strategy()),
)
)
return draw(st.permutations(assignments).map(list))
return (domain_semantics, _subnet_semantics(policy))


def _header_values(policy: NetworkPolicyCustom) -> set[str]:
if isinstance(policy.allow, list):
return set()

values: set[str] = set()
for rules in policy.allow.values():
for rule in rules:
for transform in rule.transform or []:
values.update((transform.headers or {}).values())
return values


class TestNetworkPolicyModes:
Expand Down Expand Up @@ -463,6 +450,54 @@ def test_multiple_transforms_merge_case_insensitive_header_names(self) -> None:
],
)

def test_single_rule_preserves_distinct_case_variants_in_api_headers(self) -> None:
policy = NetworkPolicyCustom(
allow={
"example.com": [
NetworkPolicyRule(
transform=[
NetworkTransformer(headers={"X-Trace": "first", "x-trace": "second"})
]
)
]
}
)

assert ApiNetworkPolicy.from_network_policy(policy) == ApiNetworkPolicy(
mode="custom",
allowed_domains=["example.com"],
injection_rules=[
ApiNetworkInjectionRule(
domain="example.com",
headers={"X-Trace": "first", "x-trace": "second"},
)
],
)

def test_later_rules_replace_earlier_case_variants_for_same_header(self) -> None:
policy = NetworkPolicyCustom(
allow={
"example.com": [
NetworkPolicyRule(transform=[NetworkTransformer(headers={"X-Trace": "first"})]),
NetworkPolicyRule(
transform=[NetworkTransformer(headers={"x-trace": "second"})]
),
NetworkPolicyRule(transform=[NetworkTransformer(headers={"X-Other": "other"})]),
]
}
)

assert ApiNetworkPolicy.from_network_policy(policy) == ApiNetworkPolicy(
mode="custom",
allowed_domains=["example.com"],
injection_rules=[
ApiNetworkInjectionRule(
domain="example.com",
headers={"x-trace": "second", "X-Other": "other"},
)
],
)

def test_api_response_header_names_merge_case_insensitively(self) -> None:
assert ApiNetworkPolicy(
mode="custom",
Expand All @@ -487,6 +522,47 @@ def test_api_response_header_names_merge_case_insensitively(self) -> None:
}
)

def test_api_response_header_names_keep_last_case_variant(self) -> None:
assert ApiNetworkPolicy(
mode="custom",
allowed_domains=["example.com"],
injection_rules=[
ApiNetworkInjectionRule(
domain="example.com",
header_names=["X-Trace", "x-trace", "X-TRACE"],
)
],
).to_network_policy() == NetworkPolicyCustom(
allow={
"example.com": [
NetworkPolicyRule(
transform=[NetworkTransformer(headers={"X-TRACE": "<redacted>"})]
)
]
}
)

def test_empty_api_header_names_fall_back_to_headers(self) -> None:
assert ApiNetworkPolicy(
mode="custom",
allowed_domains=["example.com"],
injection_rules=[
ApiNetworkInjectionRule(
domain="example.com",
headers={"X-Trace": "trace-value"},
header_names=[],
)
],
).to_network_policy() == NetworkPolicyCustom(
allow={
"example.com": [
NetworkPolicyRule(
transform=[NetworkTransformer(headers={"X-Trace": "<redacted>"})]
)
]
}
)

def test_api_network_policy_to_dict_uses_wire_keys(self) -> None:
policy = ApiNetworkPolicy(
mode="custom",
Expand Down Expand Up @@ -557,63 +633,22 @@ def test_generated_list_form_policies_round_trip_exactly(

@given(_record_policy_strategy())
@settings(max_examples=25, deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_generated_record_form_policies_preserve_domains_and_header_names(
def test_generated_record_form_policies_preserve_normalized_semantics(
self, policy: NetworkPolicyCustom
) -> None:
assert _record_policy_domains(
ApiNetworkPolicy.from_network_policy(policy).to_network_policy()
) == (_record_policy_domains(policy))
round_tripped = ApiNetworkPolicy.from_network_policy(policy).to_network_policy()

@given(_duplicate_header_assignments_strategy())
@settings(max_examples=25, deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_generated_record_form_merges_duplicate_headers_case_insensitively(
self, assignments: list[tuple[str, str]]
) -> None:
policy = NetworkPolicyCustom(
allow={
"example.com": [
NetworkPolicyRule(transform=[NetworkTransformer(headers={name: value})])
for name, value in assignments
]
}
)

api_policy = ApiNetworkPolicy.from_network_policy(policy)

assert api_policy == ApiNetworkPolicy(
mode="custom",
allowed_domains=["example.com"],
injection_rules=[
ApiNetworkInjectionRule(
domain="example.com",
headers=_case_insensitive_headers(assignments),
)
],
)
assert isinstance(round_tripped, NetworkPolicyCustom)
assert _normalized_custom_policy_semantics(
round_tripped
) == _normalized_custom_policy_semantics(policy)

@given(st.lists(_header_name_case_variant_strategy(), min_size=1, max_size=8))
@given(_record_policy_strategy())
@settings(max_examples=25, deadline=None, suppress_health_check=[HealthCheck.too_slow])
def test_generated_api_header_names_collapse_case_insensitive_duplicates(
self, header_names: list[str]
def test_generated_record_form_policies_decode_to_redacted_headers(
self, policy: NetworkPolicyCustom
) -> None:
expected_headers = dict.fromkeys(
_case_insensitive_headers((name, "<redacted>") for name in header_names),
"<redacted>",
)
round_tripped = ApiNetworkPolicy.from_network_policy(policy).to_network_policy()

assert ApiNetworkPolicy(
mode="custom",
allowed_domains=["example.com"],
injection_rules=[
ApiNetworkInjectionRule(
domain="example.com",
header_names=header_names,
)
],
).to_network_policy() == NetworkPolicyCustom(
allow={
"example.com": [
NetworkPolicyRule(transform=[NetworkTransformer(headers=expected_headers)])
]
}
)
assert isinstance(round_tripped, NetworkPolicyCustom)
assert _header_values(round_tripped) <= {"<redacted>"}
Loading