diff --git a/src/filterpath/_get.py b/src/filterpath/_get.py index 94c423f..910d504 100644 --- a/src/filterpath/_get.py +++ b/src/filterpath/_get.py @@ -30,8 +30,8 @@ def get( # noqa: C901, PLR0915 :return: :rtype: Any | list[Any] """ - escapable_sequences = frozenset({path_separator, "\\", "["}) - sentinel = object() + escapable_sequences: frozenset[str] = frozenset({path_separator, "\\", "[", ":"}) + sentinel: object = object() def _deep_get(_obj: ObjTypes, _path: PathTypes, container: list) -> Any | list[Any]: # noqa: C901 if _obj is sentinel: @@ -55,11 +55,11 @@ def _deep_get(_obj: ObjTypes, _path: PathTypes, container: list) -> Any | list[A if has_container: logger.trace("encountering container") # Strip brackets for any filtering key or function - filter_key = key[1:-1] + filter_key: str = key[1:-1] logger.trace(f"filtering container on '{key}'") try: - filtered_obj = _deep_get(_obj, filter_key, container) + filtered_obj: Any | list[Any] = _deep_get(_obj, filter_key, container) except KeyError: logger.trace(f"unable to filter '{_obj}' on '{filter_key}', return empty list") return container @@ -78,7 +78,7 @@ def _deep_get(_obj: ObjTypes, _path: PathTypes, container: list) -> Any | list[A for item in filtered_obj: logger.trace(f"getting path '{_path}' of '{item}'") try: - deep_obj = _deep_get(item, _path, container) + deep_obj: Any | list[Any] = _deep_get(item, _path, container) if deep_obj is not container: container.append(deep_obj) except KeyError: @@ -91,15 +91,19 @@ def _deep_get(_obj: ObjTypes, _path: PathTypes, container: list) -> Any | list[A def _parse_path(_path: PathTypes) -> tuple[Any, PathTypes, bool]: if isinstance(_path, str): - is_escaped = False - has_container = _path.startswith("[") - escape_indexes = [] + is_escaped: bool = False + has_container: bool = _path.startswith("[") + slice_operator_count: int = 0 + escape_indexes: list = [] for idx, char in enumerate(_path): if not is_escaped: if char == path_separator: # Non-escaped path separator break + if char == ":": + slice_operator_count += 1 + elif char in escapable_sequences: # Escaped value, store index of escape character (previous index) escape_indexes.append(idx - 1) @@ -109,11 +113,16 @@ def _parse_path(_path: PathTypes) -> tuple[Any, PathTypes, bool]: # No path separators; increment the index in order to encapsulate the entire string idx += 1 - parsed_path = _remove_char_at_index(_path[:idx], escape_indexes) + parsed_path: str = _remove_char_at_index(_path[:idx], escape_indexes) + if slice_operator_count in {1, 2}: + with contextlib.suppress(ValueError): + sliced_path: slice = slice(*(int(part) if part else None for part in parsed_path.split(":"))) + return sliced_path, _path[idx + 1 :], False + return parsed_path, _path[idx + 1 :], has_container and parsed_path.endswith("]") # Get next from _path, operating on a list/tuple - curr_path = _path[0] + curr_path: Any = _path[0] if isinstance(curr_path, str) and path_separator in curr_path: # Parse the returned key for any unescaped subpaths curr_path, remaining_path, has_container = _parse_path(curr_path) @@ -134,7 +143,7 @@ def _remove_char_at_index(string: str, index: int | list[int]) -> str: return string def _get_any(_obj: ObjTypes, key: Any) -> Any: - value = sentinel + value: Any = sentinel # Try as a dict, must use `.get()` to prevent defaultdict from being autofilled if isinstance(_obj, dict) and isinstance(key, Hashable): value = _obj.get(key, sentinel) diff --git a/tests/get_test.py b/tests/get_test.py index 2e72d7a..0808e7a 100644 --- a/tests/get_test.py +++ b/tests/get_test.py @@ -131,6 +131,19 @@ def test_get(obj, args, expected): ("a.[].b", [[3, 4], [5, 6]]), ("a.[].b.0", [3, 5]), ("a.[].b.[]", [3, 4, 5, 6]), + ("x.:0", None), + ("x.\\:0", 99), + ("a.:", [1, 2, {"b": [3, 4]}, {"b": [5, 6]}]), + ("a.::-1", [{"b": [5, 6]}, {"b": [3, 4]}, 2, 1]), + ("a.0:2", [1, 2]), + ("a.::2", [1, {"b": [3, 4]}]), + ("a.1::2", [2, {"b": [5, 6]}]), + ("a.1::2.[].b", [[5, 6]]), + ("a.[1::2].b", [[5, 6]]), + ("a.0:4:3", [1, {"b": [5, 6]}]), + ("a.0:4:3:", None), + ("a.[0:3].b", [[3, 4]]), + ("a.[0:3].b.0", [3]), ], ) def test_get_enhanced(path, expected):