Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ async def query_peers(
branch_agnostic: bool = False,
fetch_peers: bool = False,
include_metadata: MetadataQueryOptions | MetadataOptions = MetadataOptions.NONE,
order: OrderModel | None = None,
) -> list[Relationship]:
branch = await registry.get_branch(branch=branch, db=db)
at = Timestamp(at)
Expand All @@ -308,6 +309,7 @@ async def query_peers(
at=at,
branch_agnostic=branch_agnostic,
include_metadata=relationship_metadata_options,
requested_order=order,
)
await query.execute(db=db)

Expand Down
9 changes: 8 additions & 1 deletion backend/infrahub/core/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Self

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, model_validator

from infrahub.constants.enums import OrderDirection # noqa: TC001
from infrahub.exceptions import ValidationError
Expand All @@ -15,6 +15,8 @@


class NodeMetaOrder(BaseModel):
model_config = ConfigDict(frozen=True)

created_at: OrderDirection | None = None
updated_at: OrderDirection | None = None

Expand All @@ -23,9 +25,14 @@ def __bool__(self) -> bool:


class OrderModel(BaseModel):
model_config = ConfigDict(frozen=True)

disable: bool | None = None
node_metadata: NodeMetaOrder | None = None

def __bool__(self) -> bool:
return bool(self.disable) or bool(self.node_metadata)

@model_validator(mode="after")
def validate_metadata(self) -> Self:
if self.node_metadata and self.node_metadata.created_at and self.node_metadata.updated_at:
Expand Down
246 changes: 122 additions & 124 deletions backend/infrahub/core/query/node.py

Large diffs are not rendered by default.

38 changes: 37 additions & 1 deletion backend/infrahub/core/query/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from infrahub.core.constants import InfrahubKind, MetadataOptions, RelationshipDirection, RelationshipStatus
from infrahub.core.constants.database import DatabaseEdgeType
from infrahub.core.order import METADATA_CREATED_AT, METADATA_UPDATED_AT, OrderModel
from infrahub.core.query import Query, QueryResult, QueryType
from infrahub.core.query.subquery import build_subquery_filter, build_subquery_order, build_subquery_order_metadata
from infrahub.core.schema.order_by import OrderByTargetKind, parse_order_by_entry
Expand All @@ -26,6 +27,7 @@

from neo4j.graph import Relationship as Neo4jRelationship

from infrahub.constants.enums import OrderDirection
from infrahub.core.branch import Branch
from infrahub.core.node import Node
from infrahub.core.relationship import Relationship
Expand Down Expand Up @@ -689,6 +691,7 @@ def __init__(
branch: Branch | None = None,
at: Timestamp | str | None = None,
include_metadata: MetadataOptions = MetadataOptions.NONE,
requested_order: OrderModel | None = None,
**kwargs,
) -> None:
if not source and not source_ids:
Expand All @@ -712,6 +715,7 @@ def __init__(
self.rel_type = rel_type or self.rel.rel_type
self.schema = schema or self.rel.schema
self.include_metadata = include_metadata
self.requested_order = requested_order

if not branch and inspect.isclass(rel) and not hasattr(rel, "branch"):
raise ValueError("Either an instance of Relationship or a valid branch must be provided.")
Expand Down Expand Up @@ -843,7 +847,39 @@ def _add_updated_metadata_to_query(self, branch_filter_str: str) -> None:
""" % {"branch_filter": branch_filter_str, "time_details": time_details}
self.add_to_query(last_updated_query)

def _get_requested_metadata_order_fields(self) -> list[tuple[str, OrderDirection]]:
if not (self.requested_order and self.requested_order.node_metadata):
return []
fields: list[tuple[str, OrderDirection]] = []
nm = self.requested_order.node_metadata
if nm.created_at:
fields.append((METADATA_CREATED_AT, nm.created_at))
if nm.updated_at:
fields.append((METADATA_UPDATED_AT, nm.updated_at))
return fields

async def _add_peer_order_by(self, db: InfrahubDatabase, peer_schema: MainSchemaTypes, branch_filter: str) -> None:
if self.requested_order and self.requested_order.disable:
return

query_time_order_overrides_schema = bool(self.requested_order)
if query_time_order_overrides_schema:
for order_cnt, (metadata_field, direction) in enumerate(
self._get_requested_metadata_order_fields(), start=1
):
subquery, subquery_params, subquery_result_name = build_subquery_order_metadata(
metadata_field=metadata_field,
branch=self.branch,
branch_filter=branch_filter,
branch_agnostic=self.branch_agnostic,
node_alias="peer",
subquery_idx=order_cnt,
)
self.order_by.append(f"{subquery_result_name} {direction.value}")
self.params.update(subquery_params)
self.add_subquery(subquery=subquery, node_alias="peer")
return

if not (hasattr(peer_schema, "order_by") and peer_schema.order_by):
return

Expand Down Expand Up @@ -876,11 +912,11 @@ async def _add_peer_order_by(self, db: InfrahubDatabase, peer_schema: MainSchema
subquery, subquery_params, subquery_result_name = await build_subquery_order(
db=db,
field=field,
node_alias="peer",
name=order_by_field_name,
order_by=order_by_next_name,
branch_filter=branch_filter,
branch=self.branch,
node_alias="peer",
subquery_idx=order_cnt,
)
self.order_by.append(f"{subquery_result_name} {parsed.direction.value}")
Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/graphql/loaders/peers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from infrahub.core.branch.models import Branch
from infrahub.core.manager import NodeManager
from infrahub.core.metadata.model import MetadataQueryOptions
from infrahub.core.order import OrderModel
from infrahub.core.relationship.model import Relationship
from infrahub.core.schema.relationship_schema import RelationshipSchema
from infrahub.core.timestamp import Timestamp
Expand All @@ -24,6 +25,7 @@ class QueryPeerParams:
at: Timestamp | str | None = None
branch_agnostic: bool = False
include_metadata: MetadataQueryOptions = field(default_factory=MetadataQueryOptions)
order: OrderModel | None = None

def __hash__(self) -> int:
frozen_fields: frozenset | None = None
Expand All @@ -42,6 +44,7 @@ def __hash__(self) -> int:
str(self.source_kind),
str(self.branch_agnostic),
str(hash(self.include_metadata)),
str(hash(self.order)) if self.order is not None else "",
]
)
return hash(hash_str)
Expand All @@ -67,6 +70,7 @@ async def batch_load_fn(self, keys: list[Any]) -> list[list[Relationship]]:
branch_agnostic=self.query_params.branch_agnostic,
include_metadata=self.query_params.include_metadata,
fetch_peers=True,
order=self.query_params.order,
)
peer_rels_by_node_id: dict[str, list[Relationship]] = {}
for rel in peer_rels:
Expand Down
11 changes: 11 additions & 0 deletions backend/infrahub/graphql/resolvers/many_relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from infrahub.core.manager import NodeManager
from infrahub.core.metadata.model import MetadataQueryOptions
from infrahub.core.order import OrderModel
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.core.relationship import Relationship
from infrahub.core.schema.node_schema import NodeSchema
Expand All @@ -19,6 +20,7 @@
from infrahub.graphql.metadata import build_metadata_query_options, get_metadata_options_from_fields

from ..loaders.peers import PeerRelationshipsDataLoader, QueryPeerParams
from ..order import deserialize_order_input
from ..types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED

if TYPE_CHECKING:
Expand Down Expand Up @@ -99,6 +101,7 @@ async def resolve(
include_descendants: bool = False,
offset: int | None = None,
limit: int | None = None,
order: dict[str, Any] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Resolver for relationships of cardinality=one for Edged responses.
Expand Down Expand Up @@ -174,6 +177,8 @@ async def resolve(
# Add relationship properties metadata to relationship_level
include_metadata |= MetadataQueryOptions(relationship_level=get_metadata_options_from_fields(property_fields))

order_model = deserialize_order_input(input_data=order)

if offset or limit:
relationships = await self._get_entities_simple(
db=graphql_context.db,
Expand All @@ -187,6 +192,7 @@ async def resolve(
include_metadata=include_metadata,
offset=offset,
limit=limit,
order=order_model,
)
else:
relationships = await self._get_entities_with_data_loader(
Expand All @@ -199,6 +205,7 @@ async def resolve(
filters=filters,
node_fields=node_fields,
include_metadata=include_metadata,
order=order_model,
)

if not relationships:
Expand Down Expand Up @@ -248,6 +255,7 @@ async def _get_entities_simple(
include_metadata: MetadataQueryOptions,
offset: int | None = None,
limit: int | None = None,
order: OrderModel | None = None,
) -> list[Relationship] | None:
async with db.start_session(read_only=True) as dbs:
objs = await NodeManager.query_peers(
Expand All @@ -264,6 +272,7 @@ async def _get_entities_simple(
branch_agnostic=rel_schema.branch is BranchSupportType.AGNOSTIC,
fetch_peers=True,
include_metadata=include_metadata,
order=order,
)
if not objs:
return None
Expand All @@ -280,6 +289,7 @@ async def _get_entities_with_data_loader(
filters: dict[str, Any],
node_fields: dict[str, Any],
include_metadata: MetadataQueryOptions,
order: OrderModel | None = None,
) -> list[Relationship] | None:
if node_fields and "hfid" in node_fields:
node_fields["human_friendly_id"] = None
Expand All @@ -293,6 +303,7 @@ async def _get_entities_with_data_loader(
at=at,
branch_agnostic=rel_schema.branch is BranchSupportType.AGNOSTIC,
include_metadata=include_metadata,
order=order,
)
if query_params in self._data_loader_instances:
loader = self._data_loader_instances[query_params]
Expand Down
129 changes: 129 additions & 0 deletions backend/tests/component/core/test_node_get_hierarchy_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,132 @@ async def test_NodeGetHierarchyQuery_order_by_uuid_tiebreaker(

descendant_ids = list(query.get_peer_ids())
assert descendant_ids == sorted([paris_r1.id, paris_r2.id])


async def test_NodeGetHierarchyQuery_order_by_multi_field_mixed_direction(
db: InfrahubDatabase,
hierarchical_location_data_simple: dict[str, Node],
default_branch: Branch,
) -> None:
location_generic = registry.schema.get(name="LocationGeneric", branch=default_branch, duplicate=False)
location_generic.order_by = ["status__value__desc", "name__value"]

site_schema = registry.schema.get_node_schema(name="LocationSite", branch=default_branch, duplicate=False)
rack_schema = registry.schema.get_node_schema(name="LocationRack", branch=default_branch, duplicate=False)

multi_region = await Node.init(db=db, branch=default_branch, schema="LocationRegion")
await multi_region.new(db=db, name="multi-region")
await multi_region.save(db=db)

multi_site = await Node.init(db=db, branch=default_branch, schema="LocationSite")
await multi_site.new(db=db, name="multi-site", parent=multi_region.id)
await multi_site.save(db=db)

specs = [
("multi-rack-alpha", "offline"),
("multi-rack-bravo", "online"),
("multi-rack-charlie", "offline"),
("multi-rack-delta", "online"),
]
racks_by_name: dict[str, Node] = {}
for rack_name, status_value in specs:
rack = await Node.init(db=db, branch=default_branch, schema=rack_schema)
await rack.new(db=db, name=rack_name, parent=multi_site.id, status=status_value)
await rack.save(db=db)
racks_by_name[rack_name] = rack

query = await NodeGetHierarchyQuery.init(
db=db,
direction=RelationshipHierarchyDirection.DESCENDANTS,
node_id=multi_site.id,
node_schema=site_schema,
branch=default_branch,
)
await query.execute(db=db)

descendant_ids = list(query.get_peer_ids())
assert descendant_ids == [
racks_by_name["multi-rack-bravo"].id,
racks_by_name["multi-rack-delta"].id,
racks_by_name["multi-rack-alpha"].id,
racks_by_name["multi-rack-charlie"].id,
]


@dataclass
class MixedDirectionWithMetadataCase:
name: str
order_by: list[str]
expected_indices: list[int]


# Creation order: 0=alpha(status=offline), 1=bravo(status=online), 2=charlie(status=offline), 3=delta(status=online).
# created_at strictly increases with index.
MIXED_DIRECTION_WITH_METADATA_CASES_HIERARCHY = [
MixedDirectionWithMetadataCase(
name="status_desc_then_metadata_created_desc",
order_by=["status__value__desc", "node_metadata__created_at__desc"],
expected_indices=[3, 1, 2, 0],
),
MixedDirectionWithMetadataCase(
name="status_desc_then_metadata_created_asc",
order_by=["status__value__desc", "node_metadata__created_at"],
expected_indices=[1, 3, 0, 2],
),
MixedDirectionWithMetadataCase(
name="metadata_created_desc_then_status_asc",
order_by=["node_metadata__created_at__desc", "status__value"],
expected_indices=[3, 2, 1, 0],
),
MixedDirectionWithMetadataCase(
name="metadata_created_asc_then_name_desc",
order_by=["node_metadata__created_at", "name__value__desc"],
expected_indices=[0, 1, 2, 3],
),
]


async def test_NodeGetHierarchyQuery_order_by_multi_field_mixed_direction_with_metadata(
db: InfrahubDatabase,
hierarchical_location_data_simple: dict[str, Node],
default_branch: Branch,
) -> None:
location_generic = registry.schema.get(name="LocationGeneric", branch=default_branch, duplicate=False)
site_schema = registry.schema.get_node_schema(name="LocationSite", branch=default_branch, duplicate=False)
rack_schema = registry.schema.get_node_schema(name="LocationRack", branch=default_branch, duplicate=False)

mfm_region = await Node.init(db=db, branch=default_branch, schema="LocationRegion")
await mfm_region.new(db=db, name="mfm-region")
await mfm_region.save(db=db)

mfm_site = await Node.init(db=db, branch=default_branch, schema="LocationSite")
await mfm_site.new(db=db, name="mfm-site", parent=mfm_region.id)
await mfm_site.save(db=db)

specs = [
("mfm-rack-alpha", "offline"),
("mfm-rack-bravo", "online"),
("mfm-rack-charlie", "offline"),
("mfm-rack-delta", "online"),
]
racks: list[Node] = []
for rack_name, status_value in specs:
rack = await Node.init(db=db, branch=default_branch, schema=rack_schema)
await rack.new(db=db, name=rack_name, parent=mfm_site.id, status=status_value)
await rack.save(db=db)
racks.append(rack)

for case in MIXED_DIRECTION_WITH_METADATA_CASES_HIERARCHY:
location_generic.order_by = case.order_by
query = await NodeGetHierarchyQuery.init(
db=db,
direction=RelationshipHierarchyDirection.DESCENDANTS,
node_id=mfm_site.id,
node_schema=site_schema,
branch=default_branch,
)
await query.execute(db=db)
descendant_ids = list(query.get_peer_ids())
assert descendant_ids == [racks[i].id for i in case.expected_indices], (
f"order_by={case.order_by!r} produced wrong order"
)
Loading
Loading