Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 146 additions & 120 deletions src/yamltrip/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -53,59 +54,6 @@ def _make_route(keys: Sequence[KeyPart]) -> Route:
return Route(list(keys))


def _ensure_str_keys(keys: tuple[KeyPart, ...]) -> tuple[str, ...]:
"""Validate all keys are strings for creation operations.

Raises PatchError if any key is an int (cannot create sequences via upsert).
"""
for k in keys:
if isinstance(k, int):
msg = (
f"Cannot create intermediate structure with integer key {k}; "
"only string keys can create new mappings"
)
raise PatchError(msg)
return cast("tuple[str, ...]", keys)


def _flow_seq_replacements(
core_doc: CoreDocument,
old_value: Any,
new_value: Any,
path: tuple[KeyPart, ...],
) -> list[Patch]:
"""Find flow sequences that need modification and emit targeted replace patches."""
patches: list[Patch] = []

if isinstance(old_value, list) and isinstance(new_value, list):
if old_value != new_value:
route = _make_route(path)
try:
feature = core_doc.query_exact(route)
if feature and feature.kind == FeatureKind.FlowSequence:
patches.append(Patch(route=route, operation=Op.replace(new_value)))
return patches
except (KeyError, ValueError):
pass
# Recurse into shared list elements to find nested flow sequences
for i in range(min(len(old_value), len(new_value))):
sub_patches = _flow_seq_replacements(
core_doc, old_value[i], new_value[i], (*path, i)
)
patches.extend(sub_patches)
return patches

if isinstance(old_value, dict) and isinstance(new_value, dict):
for key in new_value:
if key in old_value:
sub_patches = _flow_seq_replacements(
core_doc, old_value[key], new_value[key], (*path, key)
)
patches.extend(sub_patches)

return patches


class Document:
"""An immutable YAML document.

Expand All @@ -128,22 +76,6 @@ def __init__(self, source: str) -> None:
raise ParseError(str(e)) from None
self._source: str = source

@classmethod
def _from_core(cls, core_doc: CoreDocument) -> Document:
"""Construct a Document from an already-parsed CoreDocument."""
obj = object.__new__(cls)
obj._core_doc = core_doc
obj._source = core_doc.source()
return obj

def _apply_patches(self, patches: list[Patch]) -> Document:
"""Apply patches to this document and return a new Document."""
try:
core_doc = self._core_doc.apply_patches(patches)
except RuntimeError as e:
raise PatchError(str(e)) from None
return Document._from_core(core_doc)

@property
def source(self) -> str:
"""The current YAML source text."""
Expand Down Expand Up @@ -267,9 +199,34 @@ def add(self, *keys: KeyPart, key: str, value: Any) -> Document:
patch = Patch(route=route, operation=op)
return self._apply_patches([patch])

def _is_empty_document(self) -> bool:
"""True if the document has no root data node."""
return not self._core_doc.query_exists(_make_route(()))
def upsert(self, *keys: KeyPart, value: Any) -> Document:
"""Replace if exists, create (with intermediate mappings) if not."""
if not keys:
if self._is_empty_document():
msg = (
"Cannot replace root of an empty document; provide at least one key"
)
raise PatchError(msg)
route = _make_route(())
op = Op.replace(value)
patch = Patch(route=route, operation=op)
return self._apply_patches([patch])

full_route = _make_route(keys)
if self._core_doc.query_exists(full_route):
return self.replace(*keys, value=value)

# Find deepest existing ancestor
for depth in range(len(keys) - 1, 0, -1):
ancestor_keys = keys[:depth]
ancestor_route = _make_route(ancestor_keys)
if self._core_doc.query_exists(ancestor_route):
return self._create_at(
ancestor_keys, _ensure_str_keys(keys[depth:]), value
)

# No path exists — add at root
return self._create_at((), _ensure_str_keys(keys), value)

def _create_at(
self,
Expand Down Expand Up @@ -315,34 +272,13 @@ def _create_at(
patch = Patch(route=route, operation=op)
return self._apply_patches([patch])

def upsert(self, *keys: KeyPart, value: Any) -> Document:
"""Replace if exists, create (with intermediate mappings) if not."""
if not keys:
if self._is_empty_document():
msg = (
"Cannot replace root of an empty document; provide at least one key"
)
raise PatchError(msg)
route = _make_route(())
op = Op.replace(value)
patch = Patch(route=route, operation=op)
return self._apply_patches([patch])

full_route = _make_route(keys)
if self._core_doc.query_exists(full_route):
return self.replace(*keys, value=value)

# Find deepest existing ancestor
for depth in range(len(keys) - 1, 0, -1):
ancestor_keys = keys[:depth]
ancestor_route = _make_route(ancestor_keys)
if self._core_doc.query_exists(ancestor_route):
return self._create_at(
ancestor_keys, _ensure_str_keys(keys[depth:]), value
)
def _is_empty_document(self) -> bool:
"""True if the document has no root data node."""
return not self._core_doc.query_exists(_make_route(()))

# No path exists — add at root
return self._create_at((), _ensure_str_keys(keys), value)
def prune_remove(self, *keys: KeyPart) -> Document:
"""Remove key and prune empty parents."""
return self.remove(*keys, prune=True)

def remove(self, *keys: KeyPart, prune: bool = False) -> Document:
"""Remove the key/index at path."""
Expand All @@ -364,10 +300,6 @@ def remove(self, *keys: KeyPart, prune: bool = False) -> Document:
break
return doc

def prune_remove(self, *keys: KeyPart) -> Document:
"""Remove key and prune empty parents."""
return self.remove(*keys, prune=True)

def append(self, *keys: KeyPart, value: Any) -> Document:
"""Append a single item to the sequence at path."""
route = _make_route(keys)
Expand All @@ -376,15 +308,14 @@ def append(self, *keys: KeyPart, value: Any) -> Document:
try:
return self._apply_patches([patch])
except PatchError as e:
msg = str(e)
# yamlpatch raises "...flow sequence..." for append on FlowSequence nodes
if "flow sequence" in msg:
kind = _classify_patch_error(e)
if kind == _PatchErrorKind.FLOW_SEQUENCE:
current = self[keys]
new_list = [*list(current), value]
replace_op = Op.replace(new_list)
return self._apply_patches([Patch(route=route, operation=replace_op)])
if "only permitted against sequence" in msg:
raise NodeTypeError(msg) from None
if kind == _PatchErrorKind.NOT_A_SEQUENCE:
raise NodeTypeError(str(e)) from None
raise

def insert(self, *keys: KeyPart, index: int, value: Any) -> Document:
Expand All @@ -399,14 +330,13 @@ def insert(self, *keys: KeyPart, index: int, value: Any) -> Document:
try:
return self._apply_patches([patch])
except PatchError as e:
msg = str(e)
# Rust apply_insert_at raises "expected BlockSequence, got ..." for
# both FlowSequence and non-sequence nodes (Scalar, BlockMapping, etc.)
if "expected BlockSequence" not in msg:
# Rust apply_insert_at raises BLOCK_SEQUENCE_EXPECTED for both
# FlowSequence and non-sequence nodes (Scalar, BlockMapping, etc.)
if _classify_patch_error(e) != _PatchErrorKind.BLOCK_SEQUENCE_EXPECTED:
raise
current = self[keys]
if not isinstance(current, list):
raise NodeTypeError(msg) from None
raise NodeTypeError(str(e)) from None
new_list: list[Any] = list(current)
new_list.insert(index, value)
replace_op = Op.replace(new_list)
Expand All @@ -421,15 +351,14 @@ def extend_list(self, *keys: KeyPart, values: Sequence[Any]) -> Document:
try:
return self._apply_patches(patches)
except PatchError as e:
msg = str(e)
# yamlpatch raises "...flow sequence..." for append on FlowSequence nodes
if "flow sequence" in msg:
kind = _classify_patch_error(e)
if kind == _PatchErrorKind.FLOW_SEQUENCE:
current = self[keys]
new_list = [*list(current), *values]
replace_op = Op.replace(new_list)
return self._apply_patches([Patch(route=route, operation=replace_op)])
if "only permitted against sequence" in msg:
raise NodeTypeError(msg) from None
if kind == _PatchErrorKind.NOT_A_SEQUENCE:
raise NodeTypeError(str(e)) from None
raise

def remove_from_list(self, *keys: KeyPart, values: Sequence[Any]) -> Document:
Expand Down Expand Up @@ -543,7 +472,7 @@ def sync(self, *keys: KeyPart, value: Any) -> Document:
try:
return doc._apply_patches(patches)
except PatchError as e:
if "expected BlockSequence" not in str(e):
if _classify_patch_error(e) != _PatchErrorKind.BLOCK_SEQUENCE_EXPECTED:
raise
# Fallback: a flow sequence was missed by pre-detection (e.g. due to
# list reordering). Replace the entire synced value.
Expand Down Expand Up @@ -574,7 +503,7 @@ def merge(self, *keys: KeyPart, value: Any) -> Document:
try:
return self.upsert(*normalized, value=value)
except PatchError as e:
if "unexpected node" in str(e):
if _classify_patch_error(e) == _PatchErrorKind.UNEXPECTED_NODE:
# Find deepest existing ancestor to report
failing = normalized
for i in range(len(normalized), 0, -1):
Expand Down Expand Up @@ -634,3 +563,100 @@ def find_index(self, *keys: KeyPart, where: dict[str, Any]) -> int | None:
if all(k in entry and entry[k] == v for k, v in where.items()):
return i
return None

def _apply_patches(self, patches: list[Patch]) -> Document:
"""Apply patches to this document and return a new Document."""
try:
core_doc = self._core_doc.apply_patches(patches)
except RuntimeError as e:
raise PatchError(str(e)) from None
return Document._from_core(core_doc)

@classmethod
def _from_core(cls, core_doc: CoreDocument) -> Document:
"""Construct a Document from an already-parsed CoreDocument."""
obj = object.__new__(cls)
obj._core_doc = core_doc
obj._source = core_doc.source()
return obj


def _ensure_str_keys(keys: tuple[KeyPart, ...]) -> tuple[str, ...]:
"""Validate all keys are strings for creation operations.

Raises PatchError if any key is an int (cannot create sequences via upsert).
"""
for k in keys:
if isinstance(k, int):
msg = (
f"Cannot create intermediate structure with integer key {k}; "
"only string keys can create new mappings"
)
raise PatchError(msg)
return cast("tuple[str, ...]", keys)


def _flow_seq_replacements(
core_doc: CoreDocument,
old_value: Any,
new_value: Any,
path: tuple[KeyPart, ...],
) -> list[Patch]:
"""Find flow sequences that need modification and emit targeted replace patches."""
patches: list[Patch] = []

if isinstance(old_value, list) and isinstance(new_value, list):
if old_value != new_value:
route = _make_route(path)
try:
feature = core_doc.query_exact(route)
if feature and feature.kind == FeatureKind.FlowSequence:
patches.append(Patch(route=route, operation=Op.replace(new_value)))
return patches
except (KeyError, ValueError):
pass
# Recurse into shared list elements to find nested flow sequences
for i in range(min(len(old_value), len(new_value))):
sub_patches = _flow_seq_replacements(
core_doc, old_value[i], new_value[i], (*path, i)
)
patches.extend(sub_patches)
return patches

if isinstance(old_value, dict) and isinstance(new_value, dict):
for key in new_value:
if key in old_value:
sub_patches = _flow_seq_replacements(
core_doc, old_value[key], new_value[key], (*path, key)
)
patches.extend(sub_patches)

return patches


class _PatchErrorKind(enum.Enum):
"""Classifies a PatchError by its originating yamlpatch error message."""

FLOW_SEQUENCE = "flow sequence"
NOT_A_SEQUENCE = "only permitted against sequence"
BLOCK_SEQUENCE_EXPECTED = "expected BlockSequence"
UNEXPECTED_NODE = "unexpected node"
UNKNOWN = ""


def _classify_patch_error(err: PatchError) -> _PatchErrorKind:
"""Return the kind of a PatchError based on its message string.

All yamlpatch error-message substrings are confined here so that
callers can branch on the enum rather than matching raw strings.
Comment thread
nathanjmcdougall marked this conversation as resolved.
"""
msg = str(err)
if _PatchErrorKind.FLOW_SEQUENCE.value in msg:
return _PatchErrorKind.FLOW_SEQUENCE
if _PatchErrorKind.NOT_A_SEQUENCE.value in msg:
return _PatchErrorKind.NOT_A_SEQUENCE
if _PatchErrorKind.BLOCK_SEQUENCE_EXPECTED.value in msg:
return _PatchErrorKind.BLOCK_SEQUENCE_EXPECTED
if _PatchErrorKind.UNEXPECTED_NODE.value in msg:
return _PatchErrorKind.UNEXPECTED_NODE
return _PatchErrorKind.UNKNOWN
16 changes: 8 additions & 8 deletions src/yamltrip/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,6 @@ def original(self) -> Document:
raise RuntimeError(msg)
return self._original

@property
def document(self) -> Document:
"""The current in-progress document."""
if self._document is None:
msg = "Editor must be used as a context manager"
raise RuntimeError(msg)
return self._document

@property
def root(self) -> Any:
"""The entire document parsed as a Python object."""
Expand Down Expand Up @@ -170,3 +162,11 @@ def query(self, *keys: KeyPart) -> Feature:
def extract(self, feature: Feature) -> str:
"""Extract the raw YAML text for a feature."""
return self.document.extract(feature)

@property
def document(self) -> Document:
"""The current in-progress document."""
if self._document is None:
msg = "Editor must be used as a context manager"
raise RuntimeError(msg)
return self._document
Loading
Loading