Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backend/infrahub/graphql/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from infrahub.graphql.resolvers.account_metadata import AccountMetadataResolver
from infrahub.graphql.resolvers.many_relationship import ManyRelationshipResolver
from infrahub.graphql.resolvers.single_relationship import SingleRelationshipResolver
from infrahub.graphql.resolvers.template_pool_source import TemplatePoolSourceResolver
from infrahub.permissions import PermissionManager

if TYPE_CHECKING:
Expand Down Expand Up @@ -121,7 +122,9 @@ async def prepare_graphql_params(
context=GraphqlContext(
db=db,
branch=branch,
single_relationship_resolver=SingleRelationshipResolver(),
single_relationship_resolver=SingleRelationshipResolver(
pool_source_resolver=TemplatePoolSourceResolver(),
),
many_relationship_resolver=ManyRelationshipResolver(),
account_metadata_resolver=AccountMetadataResolver(),
at=Timestamp(at),
Expand Down
16 changes: 15 additions & 1 deletion backend/infrahub/graphql/resolvers/single_relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from infrahub.database import InfrahubDatabase
from infrahub.graphql.field_extractor import extract_graphql_fields
from infrahub.graphql.metadata import build_metadata_query_options, get_metadata_options_from_fields
from infrahub.graphql.resolvers.template_pool_source import TemplatePoolSourceResolver

from ..loaders.node import GetManyParams, NodeDataLoader
from ..types import RELATIONS_PROPERTY_MAP, RELATIONS_PROPERTY_MAP_REVERSED
Expand All @@ -25,8 +26,9 @@


class SingleRelationshipResolver:
def __init__(self) -> None:
def __init__(self, pool_source_resolver: TemplatePoolSourceResolver) -> None:
self._data_loader_instances: dict[GetManyParams, NodeDataLoader] = {}
self._pool_source_resolver = pool_source_resolver

def _build_relationship_meta_response(
self, relationship: Relationship, metadata_fields: dict[str, Any]
Expand Down Expand Up @@ -126,6 +128,18 @@ async def resolve(self, parent: dict, info: GraphQLResolveInfo, **kwargs: Any) -
)

if not relationship and not peer_node:
pool_source_response = await self._pool_source_resolver.get_pool_source(
db=graphql_context.db,
branch=graphql_context.branch,
at=graphql_context.at,
parent_id=parent["id"],
source_kind=node_schema.kind,
node_schema=node_schema,
rel_name=info.field_name,
property_fields=property_fields,
)
if pool_source_response:
return pool_source_response
return response

async with graphql_context.db.start_session(read_only=True) as db:
Expand Down
68 changes: 68 additions & 0 deletions backend/infrahub/graphql/resolvers/template_pool_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from infrahub.core.constants import BranchSupportType
from infrahub.core.constants.schema import RESOURCE_POOL_REL_SUFFIX
from infrahub.core.manager import NodeManager

if TYPE_CHECKING:
from infrahub.core.branch.models import Branch
from infrahub.core.schema.node_schema import NodeSchema
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase


class TemplatePoolSourceResolver:
"""Resolves pool source information for template relationship properties.

When a template has a _from_resource_pool relationship set for a given relationship
(e.g. primary_ip_from_resource_pool) but no direct peer on the main relationship,
this resolver returns the pool node as the "source" property of that relationship.
"""

async def get_pool_source(
self,
db: InfrahubDatabase,
branch: Branch,
at: Timestamp | None,
parent_id: str,
source_kind: str,
node_schema: NodeSchema,
rel_name: str,
property_fields: dict[str, Any],
) -> dict[str, Any] | None:
if not node_schema.is_template_schema or not property_fields:
return None

pool_rel_name = f"{rel_name}{RESOURCE_POOL_REL_SUFFIX}"
try:
pool_rel_schema = node_schema.get_relationship(pool_rel_name)
except ValueError:
return None

async with db.start_session(read_only=True) as dbs:
pool_rels = await NodeManager.query_peers(
db=dbs,
ids=[parent_id],
source_kind=source_kind,
schema=pool_rel_schema,
filters={},
at=at,
branch=branch,
branch_agnostic=pool_rel_schema.branch is BranchSupportType.AGNOSTIC,
fetch_peers=True,
)
if not pool_rels:
return None

pool_rel = pool_rels[0]
response: dict[str, Any] = {"node": None, "properties": {}}

if "source" in property_fields:
source_fields = property_fields.get("source", {})
async with db.start_session(read_only=True) as dbs:
pool_peer = await pool_rel.get_peer(db=dbs)
response["properties"]["source"] = await pool_peer.to_graphql(db=dbs, fields=source_fields or None)

return response
5 changes: 4 additions & 1 deletion backend/infrahub/graphql/subscription/graphql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from infrahub.graphql.resolvers.account_metadata import AccountMetadataResolver
from infrahub.graphql.resolvers.many_relationship import ManyRelationshipResolver
from infrahub.graphql.resolvers.single_relationship import SingleRelationshipResolver
from infrahub.graphql.resolvers.template_pool_source import TemplatePoolSourceResolver
from infrahub.log import get_logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,7 +50,9 @@ async def resolver_graphql_query(
at=Timestamp(),
related_node_ids=set(),
types=graphql_context.types,
single_relationship_resolver=SingleRelationshipResolver(),
single_relationship_resolver=SingleRelationshipResolver(
pool_source_resolver=TemplatePoolSourceResolver(),
),
many_relationship_resolver=ManyRelationshipResolver(),
account_metadata_resolver=AccountMetadataResolver(),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@
}
"""

QUERY_TEMPLATE_WITH_SOURCE = """
query QueryTemplate($id: ID!) {
TemplateTestingDevice(ids: [$id]) {
edges {
node {
id
primary_ip {
node { id }
properties {
source { id }
}
}
}
}
}
}
"""


class TestTemplatePoolRelationships:
"""Component tests for template pool relationship routing via GraphQL.
Expand Down Expand Up @@ -426,3 +444,97 @@ async def test_update_template_with_pool_by_name(

direct_peer = await template.get_relationship("primary_ip").get_peer(db=db)
assert direct_peer is None

async def test_query_template_with_pool_shows_source_no_direct_peer(
self,
db: InfrahubDatabase,
default_branch_scope_class: Branch,
device_schema_with_pool_rel: None,
ip_address_pool: CoreIPAddressPool,
) -> None:
template_schema = registry.schema.get_template_schema(
name="TemplateTestingDevice", branch=default_branch_scope_class
)
template = await Node.init(db=db, schema=template_schema, branch=default_branch_scope_class)
await template.new(
db=db, template_name="device-tpl-query-pool-source", primary_ip_from_resource_pool=ip_address_pool
)
await template.save(db=db)

gql_params = await prepare_graphql_params(db=db, branch=default_branch_scope_class)
result = await graphql(
schema=gql_params.schema,
source=QUERY_TEMPLATE_WITH_SOURCE,
context_value=gql_params.context,
root_value=None,
variable_values={"id": template.id},
)

assert not result.errors
assert result.data
edges = result.data["TemplateTestingDevice"]["edges"]
assert len(edges) == 1
primary_ip = edges[0]["node"]["primary_ip"]
assert primary_ip["node"] is None
assert primary_ip["properties"]["source"]["id"] == ip_address_pool.id

async def test_query_template_with_direct_peer_no_pool_source(
self,
db: InfrahubDatabase,
default_branch_scope_class: Branch,
device_schema_with_pool_rel: None,
ip_address: Node,
) -> None:
template_schema = registry.schema.get_template_schema(
name="TemplateTestingDevice", branch=default_branch_scope_class
)
template = await Node.init(db=db, schema=template_schema, branch=default_branch_scope_class)
await template.new(db=db, template_name="device-tpl-query-direct-peer", primary_ip=ip_address)
await template.save(db=db)

gql_params = await prepare_graphql_params(db=db, branch=default_branch_scope_class)
result = await graphql(
schema=gql_params.schema,
source=QUERY_TEMPLATE_WITH_SOURCE,
context_value=gql_params.context,
root_value=None,
variable_values={"id": template.id},
)

assert not result.errors
assert result.data
edges = result.data["TemplateTestingDevice"]["edges"]
assert len(edges) == 1
primary_ip = edges[0]["node"]["primary_ip"]
assert primary_ip["node"]["id"] == ip_address.id
assert primary_ip["properties"]["source"] is None

async def test_query_template_without_pool_or_peer_returns_null(
self,
db: InfrahubDatabase,
default_branch_scope_class: Branch,
device_schema_with_pool_rel: None,
) -> None:
template_schema = registry.schema.get_template_schema(
name="TemplateTestingDevice", branch=default_branch_scope_class
)
template = await Node.init(db=db, schema=template_schema, branch=default_branch_scope_class)
await template.new(db=db, template_name="device-tpl-query-no-pool")
await template.save(db=db)

gql_params = await prepare_graphql_params(db=db, branch=default_branch_scope_class)
result = await graphql(
schema=gql_params.schema,
source=QUERY_TEMPLATE_WITH_SOURCE,
context_value=gql_params.context,
root_value=None,
variable_values={"id": template.id},
)

assert not result.errors
assert result.data
edges = result.data["TemplateTestingDevice"]["edges"]
assert len(edges) == 1
primary_ip = edges[0]["node"]["primary_ip"]
assert primary_ip["node"] is None
assert primary_ip["properties"]["source"] is None
Loading