From fa44a54ef9fe34eebbb6ccd5d04ec1b34863be5c Mon Sep 17 00:00:00 2001 From: HirenGajjar Date: Tue, 12 May 2026 22:24:04 -0700 Subject: [PATCH] tests: check that client and server schemas match Add 15 parametrized tests across EmissionBase, RunBase, ExperimentBase, ProjectBase and OrganizationBase that verify the client (codecarbon/core/schemas.py) and server (carbonserver/carbonserver/api/schemas.py) schemas are compatible. Tests check: - All client fields exist in the server schema - All required server fields exist in the client schema - Shared fields have wire-compatible types Uses ast to parse schema files without importing either module, so the test runs in the client test environment with no extra deps. This would have caught the on_cloud: str vs bool mismatch fixed in #1189. Closes #1190 --- tests/test_schema_compatibility.py | 323 +++++++++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 tests/test_schema_compatibility.py diff --git a/tests/test_schema_compatibility.py b/tests/test_schema_compatibility.py new file mode 100644 index 000000000..41797adb8 --- /dev/null +++ b/tests/test_schema_compatibility.py @@ -0,0 +1,323 @@ +""" +Test that client (codecarbon/core/schemas.py) and server +(carbonserver/carbonserver/api/schemas.py) schemas are compatible. + +A mismatch between these schemas can cause silent data corruption or API errors. +This test was added to prevent regressions like the one fixed in PR #1189, +where `on_cloud` was typed as `str` on one side and `bool` on the other. + +Related issue: https://github.com/mlco2/codecarbon/issues/1190 +""" +import ast +from pathlib import Path + +import pytest + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- +REPO_ROOT = Path(__file__).parent.parent +CLIENT_SCHEMA_PATH = REPO_ROOT / "codecarbon" / "core" / "schemas.py" +SERVER_SCHEMA_PATH = REPO_ROOT / "carbonserver" / "carbonserver" / "api" / "schemas.py" + + +# --------------------------------------------------------------------------- +# AST helpers - parse schema files without importing them +# --------------------------------------------------------------------------- + + +def _is_pydantic_required_field(value_node: ast.expr) -> bool: + """ + Return True when the node marks a Pydantic field as required. + + Covers two cases: + - Field(...) -- the classic explicit-required sentinel. + - Field() -- bare call with no positional arg and no default + or default_factory keyword; Pydantic v2 treats + this as required. + """ + if not isinstance(value_node, ast.Call): + return False + func = value_node.func + func_name = ( + func.id + if isinstance(func, ast.Name) + else (func.attr if isinstance(func, ast.Attribute) else "") + ) + if func_name != "Field": + return False + + # Field(...) - Ellipsis as first positional argument + if ( + value_node.args + and isinstance(value_node.args[0], ast.Constant) + and value_node.args[0].value is ... + ): + return True + + # Bare Field() - no positional args, no default/default_factory keyword + if not value_node.args: + kw_names = {kw.arg for kw in value_node.keywords} + if "default" not in kw_names and "default_factory" not in kw_names: + return True + + return False + + +def _parse_class_fields(filepath: Path, class_name: str) -> dict[str, dict]: + """ + Parse filepath with the ast module and return a dict of annotated fields + declared directly on class_name: + + { + "field_name": { + "annotation": "Optional[bool]", + "required": True | False, + }, + ... + } + + Works for both plain Python dataclasses and Pydantic BaseModel subclasses. + Un-annotated class-level assignments (e.g. model_config = ...) are + intentionally ignored. + """ + source = filepath.read_text(encoding="utf-8") + tree = ast.parse(source, filename=str(filepath)) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + fields: dict[str, dict] = {} + for item in node.body: + if not ( + isinstance(item, ast.AnnAssign) + and isinstance(item.target, ast.Name) + ): + continue + + field_name = item.target.id + annotation = ast.unparse(item.annotation) + + # A field is required when: + # 1. It has no right-hand side at all (pure annotation) + # 2. Its right-hand side is a Pydantic required-field marker + if item.value is None or _is_pydantic_required_field(item.value): + required = True + else: + required = False + + fields[field_name] = {"annotation": annotation, "required": required} + return fields + + return {} # class not found - tests that call this will get empty dicts + + +# --------------------------------------------------------------------------- +# Type-compatibility helpers +# --------------------------------------------------------------------------- + + +def _unwrap_optional(annotation: str) -> str: + """ + Strip an Optional / Union-with-None wrapper to return the core type. + + Handles all three common spellings: + - Optional[X] -> X + - X | None -> X (PEP 604 union syntax) + - Union[X, None] -> X + + Any annotation that is not a nullable wrapper is returned unchanged. + """ + # Optional[X] + if annotation.startswith("Optional[") and annotation.endswith("]"): + return annotation[len("Optional["):-1] + + # X | None or None | X + if "|" in annotation: + parts = [p.strip() for p in annotation.split("|")] + non_none = [p for p in parts if p != "None"] + if len(non_none) == 1: + return non_none[0] + + # Union[X, None] or Union[None, X] + if annotation.startswith("Union[") and annotation.endswith("]"): + inner = annotation[len("Union["):-1] + parts = [p.strip() for p in inner.split(",")] + non_none = [p for p in parts if p != "None"] + if len(non_none) == 1: + return non_none[0] + + return annotation + + +# Pydantic coerces these client-side types to the server-side types at +# validation time, so they are considered wire-compatible. +# Key = client core type (after unwrapping Optional) +# Value = set of acceptable server core types (after unwrapping Optional) +_COMPATIBLE_CORE_TYPES: dict[str, set[str]] = { + # The client uses plain str for UUIDs and datetime strings; + # Pydantic on the server will parse those correctly. + "str": {"str", "UUID", "datetime"}, + "UUID": {"UUID", "str"}, + "datetime": {"datetime", "str"}, + # Scalar types must match exactly. + "bool": {"bool"}, + "int": {"int"}, + "float": {"float"}, +} + + +def _types_compatible(client_annotation: str, server_annotation: str) -> bool: + """ + Return True when client_annotation is safe to send to an endpoint + that expects server_annotation. + + Optional wrappers are stripped before comparison so that, for example, + bool (client) and Optional[bool] (server) are treated as compatible -- + the server simply allows None in addition to a bool value. + + A bool vs str mismatch (the bug fixed in #1189) returns False. + """ + if client_annotation == server_annotation: + return True + + client_core = _unwrap_optional(client_annotation) + server_core = _unwrap_optional(server_annotation) + + if client_core == server_core: + return True + + return server_core in _COMPATIBLE_CORE_TYPES.get(client_core, set()) + + +# --------------------------------------------------------------------------- +# Schema pairs under test +# +# Each entry is (label, client_class_name, server_class_name). +# The client and server classes share the same names today; the three-tuple +# structure is intentional so that if the server classes are ever renamed +# (e.g. EmissionBase -> EmissionRead) only this list needs updating, not the +# test functions themselves. +# --------------------------------------------------------------------------- +SCHEMA_PAIRS = [ + ("EmissionBase", "EmissionBase", "EmissionBase"), + ("RunBase", "RunBase", "RunBase"), + ("ExperimentBase", "ExperimentBase", "ExperimentBase"), + ("ProjectBase", "ProjectBase", "ProjectBase"), + ("OrganizationBase", "OrganizationBase", "OrganizationBase"), +] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_client_fields_exist_in_server(label: str, client_cls: str, server_cls: str): + """ + Every field declared in the client schema must also exist in the server + schema. If the server drops a field the client sends, the payload will + be silently ignored or rejected. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + missing = set(client_fields) - set(server_fields) + assert not missing, ( + f"[{label}] Fields present in client but missing from server schema: {missing}\n" + f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n" + f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + ) + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_required_server_fields_exist_in_client( + label: str, client_cls: str, server_cls: str +): + """ + Every required server field (one without a default value) must also + appear in the client schema. If the client never sends a required field, + every API call for that resource will fail validation. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + required_server_fields = { + name for name, meta in server_fields.items() if meta["required"] + } + missing = required_server_fields - set(client_fields) + assert not missing, ( + f"[{label}] Required server fields missing from client schema: {missing}\n" + f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n" + f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + ) + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_shared_field_types_are_compatible( + label: str, client_cls: str, server_cls: str +): + """ + For every field that appears in both schemas, the client-side type must + be wire-compatible with the server-side type. + + This test would have caught the on_cloud: str (server) vs + on_cloud: bool (client) mismatch that was fixed in PR #1189. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + shared = set(client_fields) & set(server_fields) + mismatches: list[str] = [] + + for field in sorted(shared): + c_type = client_fields[field]["annotation"] + s_type = server_fields[field]["annotation"] + if not _types_compatible(c_type, s_type): + mismatches.append(f" {field}: client={c_type!r} server={s_type!r}") + + assert not mismatches, ( + f"[{label}] Incompatible types between client and server schemas:\n" + + "\n".join(mismatches) + + f"\n client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}" + + f"\n server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + ) + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_required_alignment_on_shared_fields( + label: str, client_cls: str, server_cls: str +): + """ + If a server field is required, the client must not treat it as optional. + + When the client marks a field Optional (or gives it a default of None) + but the server requires a value, the client can legally send None and + the server will reject it -- the same class of silent wire mismatch as + the type bug fixed in #1189, just on the optionality axis. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + weakened = [ + f + for f in set(client_fields) & set(server_fields) + if server_fields[f]["required"] and not client_fields[f]["required"] + ] + assert not weakened, ( + f"[{label}] server requires fields client treats as optional: {weakened}\n" + f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n" + f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + )