diff --git a/.vale/styles/Infrahub/sentence-case.yml b/.vale/styles/Infrahub/sentence-case.yml
index 126e18f6..c27cf7a1 100644
--- a/.vale/styles/Infrahub/sentence-case.yml
+++ b/.vale/styles/Infrahub/sentence-case.yml
@@ -52,6 +52,7 @@ exceptions:
- Jinja
- Jinja2
- JWT
+ - MDX
- Namespace
- NATS
- Node
diff --git a/.vale/styles/spelling-exceptions.txt b/.vale/styles/spelling-exceptions.txt
index 0a4c0144..ecba179f 100644
--- a/.vale/styles/spelling-exceptions.txt
+++ b/.vale/styles/spelling-exceptions.txt
@@ -79,6 +79,7 @@ kbps
Keycloak
Loopbacks
markdownlint
+MDX
max_count
memgraph
menu_placement
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6c418e0..a3200c39 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,24 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the changes for the upcoming release can be found in .
+
+## [1.19.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.19.0) - 2026-03-16
+
+### Added
+
+- Added support for FileObject nodes with file upload and download capabilities. New methods `upload_from_path(path)` and `upload_from_bytes(content, name)` allow setting file content before saving, while `download_file(dest)` enables downloading files to memory or streaming to disk for large files. ([#ihs193](https://github.com/opsmill/infrahub-sdk-python/issues/ihs193))
+- Python SDK API documentation is now generated directly from the docstrings of the classes, functions, and methods contained in the code. ([#201](https://github.com/opsmill/infrahub-sdk-python/issues/201))
+- Added a 'py.typed' file to the project. This is to enable type checking when the Infrahub SDK is imported from other projects. The addition of this file could cause new typing issues in external projects until all typing issues have been resolved. Adding it to the project now to better highlight remaining issues.
+
+### Changed
+
+- Updated branch report command to use node metadata for proposed change creator information instead of the deprecated relationship-based approach. Requires Infrahub 1.7 or above.
+
+### Fixed
+
+- Allow SDK tracking feature to continue after encountering delete errors due to impacted nodes having already been deleted by cascade delete. ([#265](https://github.com/opsmill/infrahub-sdk-python/issues/265))
+- Fixed Python SDK query generation regarding from_pool generated attribute value ([#497](https://github.com/opsmill/infrahub-sdk-python/issues/497))
+
## [1.18.1](https://github.com/opsmill/infrahub-sdk-python/tree/v1.18.1) - 2026-01-08
### Fixed
diff --git a/changelog/151.added.md b/changelog/151.added.md
new file mode 100644
index 00000000..c770976f
--- /dev/null
+++ b/changelog/151.added.md
@@ -0,0 +1 @@
+Add `infrahubctl schema export` command to export schemas from Infrahub.
diff --git a/changelog/201.added.md b/changelog/201.added.md
deleted file mode 100644
index bb3fbf00..00000000
--- a/changelog/201.added.md
+++ /dev/null
@@ -1 +0,0 @@
-Python SDK API documentation is now generated directly from the docstrings of the classes, functions, and methods contained in the code.
\ No newline at end of file
diff --git a/changelog/265.fixed.md b/changelog/265.fixed.md
deleted file mode 100644
index 4e3c43a9..00000000
--- a/changelog/265.fixed.md
+++ /dev/null
@@ -1 +0,0 @@
-Allow SDK tracking feature to continue after encountering delete errors due to impacted nodes having already been deleted by cascade delete.
diff --git a/changelog/497.fixed.md b/changelog/497.fixed.md
deleted file mode 100644
index b32323d1..00000000
--- a/changelog/497.fixed.md
+++ /dev/null
@@ -1 +0,0 @@
-Fixed Python SDK query generation regarding from_pool generated attribute value
diff --git a/docs/AGENTS.md b/docs/AGENTS.md
index 23ae5f68..43e104c6 100644
--- a/docs/AGENTS.md
+++ b/docs/AGENTS.md
@@ -1,4 +1,4 @@
-# docs/AGENTS.md
+# Documentation agents
Docusaurus documentation following Diataxis framework.
@@ -34,12 +34,12 @@ Sidebar navigation is dynamic: `sidebars-*.ts` files read the filesystem at buil
No manual sidebar update is needed when adding a new `.mdx` file. However, to control the display order of a new page, add its doc ID to the ordered list in the corresponding `sidebars-*.ts` file.
-## Adding Documentation
+## Adding documentation
1. Create MDX file in appropriate directory
2. Add frontmatter with `title`
-## MDX Pattern
+## MDX pattern
Use Tabs for async/sync examples, callouts for notes:
diff --git a/docs/docs/infrahubctl/infrahubctl-schema.mdx b/docs/docs/infrahubctl/infrahubctl-schema.mdx
index 7d569cc4..1467eae8 100644
--- a/docs/docs/infrahubctl/infrahubctl-schema.mdx
+++ b/docs/docs/infrahubctl/infrahubctl-schema.mdx
@@ -17,6 +17,7 @@ $ infrahubctl schema [OPTIONS] COMMAND [ARGS]...
**Commands**:
* `check`: Check if schema files are valid and what...
+* `export`: Export the schema from Infrahub as YAML...
* `load`: Load one or multiple schema files into...
## `infrahubctl schema check`
@@ -40,6 +41,25 @@ $ infrahubctl schema check [OPTIONS] SCHEMAS...
* `--config-file TEXT`: [env var: INFRAHUBCTL_CONFIG; default: infrahubctl.toml]
* `--help`: Show this message and exit.
+## `infrahubctl schema export`
+
+Export the schema from Infrahub as YAML files, one per namespace.
+
+**Usage**:
+
+```console
+$ infrahubctl schema export [OPTIONS]
+```
+
+**Options**:
+
+* `--directory PATH`: Directory path to store schema files [default: (dynamic)]
+* `--branch TEXT`: Branch from which to export the schema
+* `--namespaces TEXT`: Namespace(s) to export (default: all user-defined)
+* `--debug / --no-debug`: [default: no-debug]
+* `--config-file TEXT`: [env var: INFRAHUBCTL_CONFIG; default: infrahubctl.toml]
+* `--help`: Show this message and exit.
+
## `infrahubctl schema load`
Load one or multiple schema files into Infrahub.
diff --git a/docs/docs/python-sdk/guides/client.mdx b/docs/docs/python-sdk/guides/client.mdx
index 460036a1..90872a0f 100644
--- a/docs/docs/python-sdk/guides/client.mdx
+++ b/docs/docs/python-sdk/guides/client.mdx
@@ -251,7 +251,7 @@ Your client is now configured to use the specified default branch instead of `ma
## Hello world example
-Let's create a simple "Hello World" example to verify your client configuration works correctly. This example will connect to your Infrahub instance and query the available accounts.
+Let's create a "Hello World" example to verify your client configuration works correctly. This example will connect to your Infrahub instance and query the available accounts.
1. Create a new file called `hello_world.py`:
diff --git a/docs/docs/python-sdk/guides/python-typing.mdx b/docs/docs/python-sdk/guides/python-typing.mdx
index 9bc2c323..77780177 100644
--- a/docs/docs/python-sdk/guides/python-typing.mdx
+++ b/docs/docs/python-sdk/guides/python-typing.mdx
@@ -131,7 +131,7 @@ infrahubctl graphql generate-return-types queries/get_tags.gql
### Example workflow
-1. **Create your GraphQL queries** in `.gql` files preferably in a directory (e.g., `queries/`):
+1. **Create your GraphQL queries** in `.gql` files preferably in a directory (for example, `queries/`):
```graphql
# queries/get_tags.gql
diff --git a/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx b/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx
index ff1b817f..e23120dd 100644
--- a/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx
+++ b/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx
@@ -37,6 +37,44 @@ artifact_generate(self, name: str) -> None
artifact_fetch(self, name: str) -> str | dict[str, Any]
```
+#### `download_file`
+
+```python
+download_file(self, dest: Path | None = None) -> bytes | int
+```
+
+Download the file content from this FileObject node.
+
+This method is only available for nodes that inherit from CoreFileObject.
+The node must have been saved (have an id) before calling this method.
+
+**Args:**
+
+- `dest`: Optional destination path. If provided, the file will be streamed
+ directly to this path (memory-efficient for large files) and the
+ number of bytes written will be returned. If not provided, the
+ file content will be returned as bytes.
+
+**Returns:**
+
+- If ``dest`` is None: The file content as bytes.
+- If ``dest`` is provided: The number of bytes written to the file.
+
+**Raises:**
+
+- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject.
+- `ValueError`: If the node hasn't been saved yet or file not found.
+- `AuthenticationError`: If authentication fails.
+
+**Examples:**
+
+```python
+>>> # Download to memory
+>>> content = await contract.download_file()
+>>> # Stream to file (memory-efficient for large files)
+>>> bytes_written = await contract.download_file(dest=Path("/tmp/contract.pdf"))
+```
+
#### `delete`
```python
@@ -180,6 +218,44 @@ artifact_generate(self, name: str) -> None
artifact_fetch(self, name: str) -> str | dict[str, Any]
```
+#### `download_file`
+
+```python
+download_file(self, dest: Path | None = None) -> bytes | int
+```
+
+Download the file content from this FileObject node.
+
+This method is only available for nodes that inherit from CoreFileObject.
+The node must have been saved (have an id) before calling this method.
+
+**Args:**
+
+- `dest`: Optional destination path. If provided, the file will be streamed
+ directly to this path (memory-efficient for large files) and the
+ number of bytes written will be returned. If not provided, the
+ file content will be returned as bytes.
+
+**Returns:**
+
+- If ``dest`` is None: The file content as bytes.
+- If ``dest`` is provided: The number of bytes written to the file.
+
+**Raises:**
+
+- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject.
+- `ValueError`: If the node hasn't been saved yet or file not found.
+- `AuthenticationError`: If authentication fails.
+
+**Examples:**
+
+```python
+>>> # Download to memory
+>>> content = contract.download_file()
+>>> # Stream to file (memory-efficient for large files)
+>>> bytes_written = contract.download_file(dest=Path("/tmp/contract.pdf"))
+```
+
#### `delete`
```python
@@ -373,6 +449,70 @@ is_ip_address(self) -> bool
is_resource_pool(self) -> bool
```
+#### `is_file_object`
+
+```python
+is_file_object(self) -> bool
+```
+
+Check if this node inherits from CoreFileObject and supports file uploads.
+
+#### `upload_from_path`
+
+```python
+upload_from_path(self, path: Path) -> None
+```
+
+Set a file from disk to be uploaded when saving this FileObject node.
+
+The file will be streamed during upload, avoiding loading the entire file into memory.
+
+**Args:**
+
+- `path`: Path to the file on disk.
+
+**Raises:**
+
+- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject.
+
+#### `upload_from_bytes`
+
+```python
+upload_from_bytes(self, content: bytes | BinaryIO, name: str) -> None
+```
+
+Set content to be uploaded when saving this FileObject node.
+
+The content can be provided as bytes or a file-like object.
+Using BinaryIO is recommended for large content to stream during upload.
+
+**Args:**
+
+- `content`: The file content as bytes or a file-like object.
+- `name`: The filename to use for the uploaded file.
+
+**Raises:**
+
+- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject.
+
+**Examples:**
+
+```python
+>>> # Using bytes (for small files)
+>>> node.upload_from_bytes(content=b"file content", name="example.txt")
+>>> # Using file-like object (for large files)
+>>> with open("/path/to/file.bin", "rb") as f:
+... node.upload_from_bytes(content=f, name="file.bin")
+```
+
+#### `clear_file`
+
+```python
+clear_file(self) -> None
+```
+
+Clear any pending file content.
+
#### `get_raw_graphql_data`
```python
diff --git a/docs/docs/python-sdk/topics/object_file.mdx b/docs/docs/python-sdk/topics/object_file.mdx
index 751599f0..75744d89 100644
--- a/docs/docs/python-sdk/topics/object_file.mdx
+++ b/docs/docs/python-sdk/topics/object_file.mdx
@@ -68,13 +68,13 @@ spec:
> Multiple documents in a single YAML file are also supported, each document will be loaded separately. Documents are separated by `---`
-### Data Processing Parameters
+### Data processing parameters
The `parameters` field controls how the data in the object file is processed before loading into Infrahub:
-| Parameter | Description | Default |
-| -------------- | ------------------------------------------------------------------------------------------------------- | ------- |
-| `expand_range` | When set to `true`, range patterns (e.g., `[1-5]`) in string fields are expanded into multiple objects. | `false` |
+| Parameter | Description | Default |
+| -------------- | -------------------------------------------------------------------------------------------------------------- | ------- |
+| `expand_range` | When set to `true`, range patterns (for example, `[1-5]`) in string fields are expanded into multiple objects. | `false` |
When `expand_range` is not specified, it defaults to `false`.
@@ -208,9 +208,9 @@ Metadata support is planned for future releases. Currently, the Object file does
3. Validate object files before loading them into production environments.
4. Use comments in your YAML files to document complex relationships or dependencies.
-## Range Expansion in Object Files
+## Range expansion in object files
-The Infrahub Python SDK supports **range expansion** for string fields in object files when the `parameters > expand_range` is set to `true`. This feature allows you to specify a range pattern (e.g., `[1-5]`) in any string value, and the SDK will automatically expand it into multiple objects during validation and processing.
+The Infrahub Python SDK supports **range expansion** for string fields in object files when the `parameters > expand_range` is set to `true`. This feature allows you to specify a range pattern (for example, `[1-5]`) in any string value, and the SDK will automatically expand it into multiple objects during validation and processing.
```yaml
---
@@ -225,7 +225,7 @@ spec:
type: Country
```
-### How Range Expansion Works
+### How range expansion works
- Any string field containing a pattern like `[1-5]`, `[10-15]`, or `[1,3,5]` will be expanded into multiple objects.
- If multiple fields in the same object use range expansion, **all expanded lists must have the same length**. If not, validation will fail.
@@ -233,7 +233,7 @@ spec:
### Examples
-#### Single Field Expansion
+#### Single field expansion
```yaml
spec:
@@ -256,7 +256,7 @@ This will expand to:
type: Country
```
-#### Multiple Field Expansion (Matching Lengths)
+#### Multiple field expansion (matching lengths)
```yaml
spec:
@@ -283,7 +283,7 @@ This will expand to:
type: Country
```
-#### Error: Mismatched Range Lengths
+#### Error: mismatched range lengths
If you use ranges of different lengths in multiple fields:
diff --git a/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py b/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py
index 95ea8f00..1f680d15 100644
--- a/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py
+++ b/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py
@@ -44,7 +44,7 @@ def _collapse_overloads(self, mdx_file: MdxFile) -> MdxFile:
h3_parsed = _parse_sections(h2.content, heading_level=3)
new_lines = h3_parsed.reassembled(processed_h3)
processed_h2.append(
- MdxSection(name=h2.name, heading_level=h2.heading_level, _lines=[h2.heading] + new_lines)
+ MdxSection(name=h2.name, heading_level=h2.heading_level, _lines=[h2.heading, *new_lines])
)
new_content = "\n".join(parsed_h2.reassembled(processed_h2))
@@ -67,7 +67,7 @@ def _process_class_sections(self, h2_content: list[str]) -> list[ASection] | Non
h4_parsed = _parse_sections(h3.content, heading_level=4)
new_lines = h4_parsed.reassembled(collapsed_methods)
processed.append(
- MdxSection(name=h3.name, heading_level=h3.heading_level, _lines=[h3.heading] + new_lines)
+ MdxSection(name=h3.name, heading_level=h3.heading_level, _lines=[h3.heading, *new_lines])
)
return processed if any_collapsed else None
diff --git a/docs/docs_generation/content_gen_methods/mdx/mdx_section.py b/docs/docs_generation/content_gen_methods/mdx/mdx_section.py
index 6187f67a..ed4e25fa 100644
--- a/docs/docs_generation/content_gen_methods/mdx/mdx_section.py
+++ b/docs/docs_generation/content_gen_methods/mdx/mdx_section.py
@@ -19,7 +19,7 @@ def content(self) -> list[str]: ...
@property
def lines(self) -> list[str]:
- return [self.heading] + self.content
+ return [self.heading, *self.content]
@dataclass
diff --git a/infrahub_sdk/branch.py b/infrahub_sdk/branch.py
index 2c32a481..4c89bd41 100644
--- a/infrahub_sdk/branch.py
+++ b/infrahub_sdk/branch.py
@@ -19,6 +19,7 @@ class BranchStatus(str, Enum):
NEED_REBASE = "NEED_REBASE"
NEED_UPGRADE_REBASE = "NEED_UPGRADE_REBASE"
DELETING = "DELETING"
+ MERGED = "MERGED"
class BranchData(BaseModel):
diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py
index f2ebfab7..988d30bd 100644
--- a/infrahub_sdk/client.py
+++ b/infrahub_sdk/client.py
@@ -5,18 +5,12 @@
import logging
import time
import warnings
-from collections.abc import Callable, Coroutine, Mapping, MutableMapping
+from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Mapping, MutableMapping
+from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from functools import wraps
from time import sleep
-from typing import (
- TYPE_CHECKING,
- Any,
- Literal,
- TypedDict,
- TypeVar,
- overload,
-)
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypedDict, TypeVar, overload
from urllib.parse import urlencode
import httpx
@@ -24,12 +18,7 @@
from typing_extensions import Self
from .batch import InfrahubBatch, InfrahubBatchSync
-from .branch import (
- MUTATION_QUERY_TASK,
- BranchData,
- InfrahubBranchManager,
- InfrahubBranchManagerSync,
-)
+from .branch import MUTATION_QUERY_TASK, BranchData, InfrahubBranchManager, InfrahubBranchManagerSync
from .config import Config
from .constants import InfrahubClientMode
from .convert_object_type import CONVERT_OBJECT_MUTATION, ConversionFieldInput
@@ -44,11 +33,8 @@
ServerNotResponsiveError,
URLNotFoundError,
)
-from .graphql import Mutation, Query
-from .node import (
- InfrahubNode,
- InfrahubNodeSync,
-)
+from .graphql import MultipartBuilder, Mutation, Query
+from .node import InfrahubNode, InfrahubNodeSync
from .object_store import ObjectStore, ObjectStoreSync
from .protocols_base import CoreNode, CoreNodeSync
from .queries import QUERY_USER, get_commit_update_mutation
@@ -994,7 +980,7 @@ async def execute_graphql(
messages = [error.get("message") for error in errors]
raise AuthenticationError(" | ".join(messages)) from exc
if exc.response.status_code == 404:
- raise URLNotFoundError(url=url)
+ raise URLNotFoundError(url=url) from exc
if not resp:
raise Error("Unexpected situation, resp hasn't been initialized.")
@@ -1008,6 +994,128 @@ async def execute_graphql(
# TODO add a special method to execute mutation that will check if the method returned OK
+ async def _execute_graphql_with_file(
+ self,
+ query: str,
+ variables: dict | None = None,
+ file_content: BinaryIO | None = None,
+ file_name: str | None = None,
+ branch_name: str | None = None,
+ timeout: int | None = None,
+ tracker: str | None = None,
+ ) -> dict:
+ """Execute a GraphQL mutation with a file upload using multipart/form-data.
+
+ This method follows the GraphQL Multipart Request Spec for file uploads.
+ The file is attached to the 'file' variable in the mutation.
+
+ Args:
+ query: GraphQL mutation query that includes a $file variable of type Upload!
+ variables: Variables to pass along with the GraphQL query.
+ file_content: The file content as a file-like object (BinaryIO).
+ file_name: The name of the file being uploaded.
+ branch_name: Name of the branch on which the mutation will be executed.
+ timeout: Timeout in seconds for the query.
+ tracker: Optional tracker for request tracing.
+
+ Raises:
+ GraphQLError: When the GraphQL response contains errors.
+
+ Returns:
+ dict: The GraphQL data payload (response["data"]).
+ """
+ branch_name = branch_name or self.default_branch
+ url = self._graphql_url(branch_name=branch_name)
+
+ # Prepare variables with file placeholder
+ variables = variables or {}
+ variables["file"] = None
+
+ headers = copy.copy(self.headers or {})
+ # Remove content-type header - httpx will set it for multipart
+ headers.pop("content-type", None)
+ if self.insert_tracker and tracker:
+ headers["X-Infrahub-Tracker"] = tracker
+
+ self._echo(url=url, query=query, variables=variables)
+
+ resp = await self._post_multipart(
+ url=url,
+ query=query,
+ variables=variables,
+ file_content=file_content,
+ file_name=file_name or "upload",
+ headers=headers,
+ timeout=timeout,
+ )
+
+ resp.raise_for_status()
+ response = decode_json(response=resp)
+
+ if "errors" in response:
+ raise GraphQLError(errors=response["errors"], query=query, variables=variables)
+
+ return response["data"]
+
+ @handle_relogin
+ async def _post_multipart(
+ self,
+ url: str,
+ query: str,
+ variables: dict,
+ file_content: BinaryIO | None,
+ file_name: str,
+ headers: dict | None = None,
+ timeout: int | None = None,
+ ) -> httpx.Response:
+        """Execute an HTTP POST with multipart/form-data for GraphQL file uploads.
+
+ The file_content is streamed directly from the file-like object, avoiding loading the entire file into memory for large files.
+ """
+ await self.login()
+
+ headers = headers or {}
+ base_headers = copy.copy(self.headers or {})
+ # Remove content-type from base headers - httpx will set it for multipart
+ base_headers.pop("content-type", None)
+ headers.update(base_headers)
+
+ # Build the multipart form data according to GraphQL Multipart Request Spec
+ files = MultipartBuilder.build_payload(
+ query=query, variables=variables, file_content=file_content, file_name=file_name
+ )
+
+ return await self._request_multipart(
+ url=url, headers=headers, timeout=timeout or self.default_timeout, files=files
+ )
+
+ def _build_proxy_config(self) -> ProxyConfig:
+ """Build proxy configuration for httpx AsyncClient."""
+ proxy_config: ProxyConfig = {"proxy": None, "mounts": None}
+ if self.config.proxy:
+ proxy_config["proxy"] = self.config.proxy
+ elif self.config.proxy_mounts.is_set:
+ proxy_config["mounts"] = {
+ key: httpx.AsyncHTTPTransport(proxy=value)
+ for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items()
+ }
+ return proxy_config
+
+ async def _request_multipart(
+ self, url: str, headers: dict[str, Any], timeout: int, files: dict[str, Any]
+ ) -> httpx.Response:
+ """Execute a multipart HTTP POST request."""
+ async with httpx.AsyncClient(**self._build_proxy_config(), verify=self.config.tls_context) as client:
+ try:
+ response = await client.post(url=url, headers=headers, timeout=timeout, files=files)
+ except httpx.NetworkError as exc:
+ raise ServerNotReachableError(address=self.address) from exc
+ except httpx.ReadTimeout as exc:
+ raise ServerNotResponsiveError(url=url, timeout=timeout) from exc
+
+ self._record(response)
+ return response
+
@handle_relogin
async def _post(
self,
@@ -1057,6 +1165,36 @@ async def _get(self, url: str, headers: dict | None = None, timeout: int | None
timeout=timeout or self.default_timeout,
)
+ @asynccontextmanager
+ async def _get_streaming(
+ self, url: str, headers: dict | None = None, timeout: int | None = None
+ ) -> AsyncIterator[httpx.Response]:
+ """Execute a streaming HTTP GET with HTTPX.
+
+ Returns an async context manager that yields the streaming response.
+ Use this for downloading large files without loading into memory.
+
+ Raises:
+ ServerNotReachableError if we are not able to connect to the server
+ ServerNotResponsiveError if the server didn't respond before the timeout expired
+ """
+ await self.login()
+
+ headers = headers or {}
+ base_headers = copy.copy(self.headers or {})
+ headers.update(base_headers)
+
+ async with httpx.AsyncClient(**self._build_proxy_config(), verify=self.config.tls_context) as client:
+ try:
+ async with client.stream(
+ method="GET", url=url, headers=headers, timeout=timeout or self.default_timeout
+ ) as response:
+ yield response
+ except httpx.NetworkError as exc:
+ raise ServerNotReachableError(address=self.address) from exc
+ except httpx.ReadTimeout as exc:
+ raise ServerNotResponsiveError(url=url, timeout=timeout or self.default_timeout) from exc
+
async def _request(
self,
url: str,
@@ -1081,19 +1219,7 @@ async def _default_request_method(
if payload:
params["json"] = payload
- proxy_config: ProxyConfig = {"proxy": None, "mounts": None}
- if self.config.proxy:
- proxy_config["proxy"] = self.config.proxy
- elif self.config.proxy_mounts.is_set:
- proxy_config["mounts"] = {
- key: httpx.AsyncHTTPTransport(proxy=value)
- for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items()
- }
-
- async with httpx.AsyncClient(
- **proxy_config,
- verify=self.config.tls_context,
- ) as client:
+ async with httpx.AsyncClient(**self._build_proxy_config(), verify=self.config.tls_context) as client:
try:
response = await client.request(
method=method.value,
@@ -1742,10 +1868,11 @@ async def convert_object_type(
for more information.
"""
- if fields_mapping is None:
- mapping_dict = {}
- else:
- mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()}
+ mapping_dict = (
+ {}
+ if fields_mapping is None
+ else {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()}
+ )
branch_name = branch or self.default_branch
response = await self.execute_graphql(
@@ -1910,7 +2037,7 @@ def execute_graphql(
messages = [error.get("message") for error in errors]
raise AuthenticationError(" | ".join(messages)) from exc
if exc.response.status_code == 404:
- raise URLNotFoundError(url=url)
+ raise URLNotFoundError(url=url) from exc
if not resp:
raise Error("Unexpected situation, resp hasn't been initialized.")
@@ -1924,6 +2051,126 @@ def execute_graphql(
# TODO add a special method to execute mutation that will check if the method returned OK
+ def _execute_graphql_with_file(
+ self,
+ query: str,
+ variables: dict | None = None,
+ file_content: BinaryIO | None = None,
+ file_name: str | None = None,
+ branch_name: str | None = None,
+ timeout: int | None = None,
+ tracker: str | None = None,
+ ) -> dict:
+ """Execute a GraphQL mutation with a file upload using multipart/form-data.
+
+ This method follows the GraphQL Multipart Request Spec for file uploads.
+ The file is attached to the 'file' variable in the mutation.
+
+ Args:
+ query: GraphQL mutation query that includes a $file variable of type Upload!
+ variables: Variables to pass along with the GraphQL query.
+ file_content: The file content as a file-like object (BinaryIO).
+ file_name: The name of the file being uploaded.
+ branch_name: Name of the branch on which the mutation will be executed.
+ timeout: Timeout in seconds for the query.
+ tracker: Optional tracker for request tracing.
+
+ Raises:
+ GraphQLError: When the GraphQL response contains errors.
+
+ Returns:
+ dict: The GraphQL data payload (response["data"]).
+ """
+ branch_name = branch_name or self.default_branch
+ url = self._graphql_url(branch_name=branch_name)
+
+ # Prepare variables with file placeholder
+ variables = variables or {}
+ variables["file"] = None
+
+ headers = copy.copy(self.headers or {})
+ # Remove content-type header - httpx will set it for multipart
+ headers.pop("content-type", None)
+ if self.insert_tracker and tracker:
+ headers["X-Infrahub-Tracker"] = tracker
+
+ self._echo(url=url, query=query, variables=variables)
+
+ resp = self._post_multipart(
+ url=url,
+ query=query,
+ variables=variables,
+ file_content=file_content,
+ file_name=file_name or "upload",
+ headers=headers,
+ timeout=timeout,
+ )
+
+ resp.raise_for_status()
+ response = decode_json(response=resp)
+
+ if "errors" in response:
+ raise GraphQLError(errors=response["errors"], query=query, variables=variables)
+
+ return response["data"]
+
+ @handle_relogin_sync
+ def _post_multipart(
+ self,
+ url: str,
+ query: str,
+ variables: dict,
+ file_content: BinaryIO | None,
+ file_name: str,
+ headers: dict | None = None,
+ timeout: int | None = None,
+ ) -> httpx.Response:
+        """Execute an HTTP POST with multipart/form-data for GraphQL file uploads.
+
+ The file_content is streamed directly from the file-like object, avoiding loading the entire file into memory for large files.
+ """
+ self.login()
+
+ headers = headers or {}
+ base_headers = copy.copy(self.headers or {})
+ # Remove content-type from base headers - httpx will set it for multipart
+ base_headers.pop("content-type", None)
+ headers.update(base_headers)
+
+ # Build the multipart form data according to GraphQL Multipart Request Spec
+ files = MultipartBuilder.build_payload(
+ query=query, variables=variables, file_content=file_content, file_name=file_name
+ )
+
+ return self._request_multipart(url=url, headers=headers, timeout=timeout or self.default_timeout, files=files)
+
+ def _build_proxy_config(self) -> ProxyConfigSync:
+ """Build proxy configuration for httpx Client."""
+ proxy_config: ProxyConfigSync = {"proxy": None, "mounts": None}
+ if self.config.proxy:
+ proxy_config["proxy"] = self.config.proxy
+ elif self.config.proxy_mounts.is_set:
+ proxy_config["mounts"] = {
+ key: httpx.HTTPTransport(proxy=value)
+ for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items()
+ }
+ return proxy_config
+
+ def _request_multipart(
+ self, url: str, headers: dict[str, Any], timeout: int, files: dict[str, Any]
+ ) -> httpx.Response:
+ """Execute a multipart HTTP POST request."""
+ with httpx.Client(**self._build_proxy_config(), verify=self.config.tls_context) as client:
+ try:
+ response = client.post(url=url, headers=headers, timeout=timeout, files=files)
+ except httpx.NetworkError as exc:
+ raise ServerNotReachableError(address=self.address) from exc
+ except httpx.ReadTimeout as exc:
+ raise ServerNotResponsiveError(url=url, timeout=timeout) from exc
+
+ self._record(response)
+ return response
+
def count(
self,
kind: str | type[SchemaType],
@@ -2995,6 +3242,36 @@ def _get(self, url: str, headers: dict | None = None, timeout: int | None = None
timeout=timeout or self.default_timeout,
)
+ @contextmanager
+ def _get_streaming(
+ self, url: str, headers: dict | None = None, timeout: int | None = None
+ ) -> Iterator[httpx.Response]:
+ """Execute a streaming HTTP GET with HTTPX.
+
+ Returns a context manager that yields the streaming response.
+ Use this for downloading large files without loading into memory.
+
+ Raises:
+ ServerNotReachableError if we are not able to connect to the server
+ ServerNotResponsiveError if the server didn't respond before the timeout expired
+ """
+ self.login()
+
+ headers = headers or {}
+ base_headers = copy.copy(self.headers or {})
+ headers.update(base_headers)
+
+ with httpx.Client(**self._build_proxy_config(), verify=self.config.tls_context) as client:
+ try:
+ with client.stream(
+ method="GET", url=url, headers=headers, timeout=timeout or self.default_timeout
+ ) as response:
+ yield response
+ except httpx.NetworkError as exc:
+ raise ServerNotReachableError(address=self.address) from exc
+ except httpx.ReadTimeout as exc:
+ raise ServerNotResponsiveError(url=url, timeout=timeout or self.default_timeout) from exc
+
@handle_relogin_sync
def _post(
self,
@@ -3047,20 +3324,7 @@ def _default_request_method(
if payload:
params["json"] = payload
- proxy_config: ProxyConfigSync = {"proxy": None, "mounts": None}
-
- if self.config.proxy:
- proxy_config["proxy"] = self.config.proxy
- elif self.config.proxy_mounts.is_set:
- proxy_config["mounts"] = {
- key: httpx.HTTPTransport(proxy=value)
- for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items()
- }
-
- with httpx.Client(
- **proxy_config,
- verify=self.config.tls_context,
- ) as client:
+ with httpx.Client(**self._build_proxy_config(), verify=self.config.tls_context) as client:
try:
response = client.request(
method=method.value,
@@ -3162,10 +3426,11 @@ def convert_object_type(
for more information.
"""
- if fields_mapping is None:
- mapping_dict = {}
- else:
- mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()}
+ mapping_dict = (
+ {}
+ if fields_mapping is None
+ else {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()}
+ )
branch_name = branch or self.default_branch
response = self.execute_graphql(
diff --git a/infrahub_sdk/ctl/branch.py b/infrahub_sdk/ctl/branch.py
index 60d67e86..d309c1d5 100644
--- a/infrahub_sdk/ctl/branch.py
+++ b/infrahub_sdk/ctl/branch.py
@@ -110,6 +110,10 @@ def generate_proposed_change_tables(proposed_changes: list[CoreProposedChange])
proposed_change_tables: list[Table] = []
for pc in proposed_changes:
+ metadata = pc.get_node_metadata()
+ created_by = metadata.created_by.display_label if metadata and metadata.created_by else "-"
+ created_at = format_timestamp(metadata.created_at) if metadata and metadata.created_at else "-"
+
# Create proposal table
proposed_change_table = Table(show_header=False, box=None)
proposed_change_table.add_column(justify="left")
@@ -119,8 +123,8 @@ def generate_proposed_change_tables(proposed_changes: list[CoreProposedChange])
proposed_change_table.add_row("Name", pc.name.value)
proposed_change_table.add_row("State", str(pc.state.value))
proposed_change_table.add_row("Is draft", "Yes" if pc.is_draft.value else "No")
- proposed_change_table.add_row("Created by", pc.created_by.peer.name.value) # type: ignore[union-attr]
- proposed_change_table.add_row("Created at", format_timestamp(str(pc.created_by.updated_at)))
+ proposed_change_table.add_row("Created by", created_by)
+ proposed_change_table.add_row("Created at", created_at)
proposed_change_table.add_row("Approvals", str(len(pc.approved_by.peers)))
proposed_change_table.add_row("Rejections", str(len(pc.rejected_by.peers)))
@@ -295,9 +299,9 @@ async def report(
proposed_changes = await client.filters(
kind=CoreProposedChange, # type: ignore[type-abstract]
source_branch__value=branch_name,
- include=["created_by"],
prefetch_relationships=True,
property=True,
+ include_metadata=True,
)
branch_table = generate_branch_report_table(branch=branch, diff_tree=diff_tree, git_files_changed=git_files_changed)
diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py
index 2b571723..d7a636ed 100644
--- a/infrahub_sdk/ctl/cli_commands.py
+++ b/infrahub_sdk/ctl/cli_commands.py
@@ -239,7 +239,7 @@ async def _run_transform(
elif isinstance(error, str) and "Branch:" in error:
console.print(f"[yellow] - {error}")
console.print("[yellow] you can specify a different branch with --branch")
- raise typer.Abort
+ raise typer.Abort from None
if inspect.iscoroutinefunction(transform_func):
output = await transform_func(response)
@@ -350,10 +350,7 @@ def transform(
# Run Transform
result = asyncio.run(transform.run(data=data))
- if isinstance(result, str):
- json_string = result
- else:
- json_string = ujson.dumps(result, indent=2, sort_keys=True)
+ json_string = result if isinstance(result, str) else ujson.dumps(result, indent=2, sort_keys=True)
if out:
write_to_file(Path(out), json_string)
diff --git a/infrahub_sdk/ctl/config.py b/infrahub_sdk/ctl/config.py
index b3d2a404..a5b522b2 100644
--- a/infrahub_sdk/ctl/config.py
+++ b/infrahub_sdk/ctl/config.py
@@ -90,7 +90,7 @@ def load_and_exit(self, config_file: str | Path = "infrahubctl.toml", config_dat
for error in exc.errors():
loc_str = [str(item) for item in error["loc"]]
print(f" {'/'.join(loc_str)} | {error['msg']} ({error['type']})")
- raise typer.Abort
+ raise typer.Abort from None
SETTINGS = ConfiguredSettings()
diff --git a/infrahub_sdk/ctl/exporter.py b/infrahub_sdk/ctl/exporter.py
index ae5e5d18..402b47ec 100644
--- a/infrahub_sdk/ctl/exporter.py
+++ b/infrahub_sdk/ctl/exporter.py
@@ -46,4 +46,4 @@ def dump(
aiorun(exporter.export(export_directory=directory, namespaces=namespace, branch=branch, exclude=exclude))
except TransferError as exc:
console.print(f"[red]{exc}")
- raise typer.Exit(1)
+ raise typer.Exit(1) from None
diff --git a/infrahub_sdk/ctl/importer.py b/infrahub_sdk/ctl/importer.py
index 420c6d75..d3318eb5 100644
--- a/infrahub_sdk/ctl/importer.py
+++ b/infrahub_sdk/ctl/importer.py
@@ -50,4 +50,4 @@ def load(
aiorun(importer.import_data(import_directory=directory, branch=branch))
except TransferError as exc:
console.print(f"[red]{exc}")
- raise typer.Exit(1)
+ raise typer.Exit(1) from None
diff --git a/infrahub_sdk/ctl/schema.py b/infrahub_sdk/ctl/schema.py
index 5a977b59..9532959e 100644
--- a/infrahub_sdk/ctl/schema.py
+++ b/infrahub_sdk/ctl/schema.py
@@ -2,6 +2,7 @@
import asyncio
import time
+from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -211,3 +212,49 @@ def _display_schema_warnings(console: Console, warnings: list[SchemaWarning]) ->
console.print(
f"[yellow] {warning.type.value}: {warning.message} [{', '.join([kind.display for kind in warning.kinds])}]"
)
+
+
+def _default_export_directory() -> Path:
+ timestamp = datetime.now(timezone.utc).astimezone().strftime("%Y%m%d-%H%M%S")
+ return Path(f"infrahub-schema-export-{timestamp}")
+
+
+@app.command()
+@catch_exception(console=console)
+async def export(
+ directory: Path = typer.Option(_default_export_directory, help="Directory path to store schema files"),
+ branch: str = typer.Option(None, help="Branch from which to export the schema"),
+ namespaces: list[str] = typer.Option([], help="Namespace(s) to export (default: all user-defined)"),
+ debug: bool = False,
+ _: str = CONFIG_PARAM,
+) -> None:
+ """Export the schema from Infrahub as YAML files, one per namespace."""
+ init_logging(debug=debug)
+
+ client = initialize_client()
+ user_schemas = await client.schema.export(
+ branch=branch,
+ namespaces=namespaces or None,
+ )
+
+ if not user_schemas.namespaces:
+ console.print("[yellow]No user-defined schema found to export.")
+ return
+
+ directory.mkdir(parents=True, exist_ok=True)
+
+ for ns, data in sorted(user_schemas.namespaces.items()):
+ payload: dict[str, Any] = {"version": "1.0"}
+ if data.generics:
+ payload["generics"] = data.generics
+ if data.nodes:
+ payload["nodes"] = data.nodes
+
+ output_file = directory / f"{ns.lower()}.yml"
+ output_file.write_text(
+ yaml.dump(payload, default_flow_style=False, sort_keys=False, allow_unicode=True),
+ encoding="utf-8",
+ )
+ console.print(f"[green] Exported namespace '{ns}' to {output_file}")
+
+ console.print(f"[green] Schema exported to {directory}")
diff --git a/infrahub_sdk/ctl/validate.py b/infrahub_sdk/ctl/validate.py
index 3ffbd85a..07256faf 100644
--- a/infrahub_sdk/ctl/validate.py
+++ b/infrahub_sdk/ctl/validate.py
@@ -48,7 +48,7 @@ async def validate_schema(schema: Path, _: str = CONFIG_PARAM) -> None:
for error in exc.errors():
loc_str = [str(item) for item in error["loc"]]
console.print(f" '{'/'.join(loc_str)}' | {error['msg']} ({error['type']})")
- raise typer.Exit(1)
+ raise typer.Exit(1) from None
console.print("[green]Schema is valid !!")
diff --git a/infrahub_sdk/exceptions.py b/infrahub_sdk/exceptions.py
index e1dba1d5..727239bf 100644
--- a/infrahub_sdk/exceptions.py
+++ b/infrahub_sdk/exceptions.py
@@ -136,7 +136,7 @@ def __init__(self, position: list[int | str], message: str) -> None:
super().__init__(self.message)
def __str__(self) -> str:
- return f"{'.'.join(map(str, self.position))}: {self.message}"
+ return f"{'.'.join(str(p) for p in self.position)}: {self.message}"
class AuthenticationError(Error):
diff --git a/infrahub_sdk/file_handler.py b/infrahub_sdk/file_handler.py
new file mode 100644
index 00000000..5d32441a
--- /dev/null
+++ b/infrahub_sdk/file_handler.py
@@ -0,0 +1,348 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING, BinaryIO, cast, overload
+
+import anyio
+import httpx
+
+from .exceptions import AuthenticationError, NodeNotFoundError, ServerNotReachableError
+
+if TYPE_CHECKING:
+ from .client import InfrahubClient, InfrahubClientSync
+
+
+@dataclass
+class PreparedFile:
+ file_object: BinaryIO | None
+ filename: str | None
+ should_close: bool
+
+
+class FileHandlerBase:
+ """Base class for file handling operations.
+
+ Provides common functionality for both async and sync file handlers, including upload preparation and error handling.
+ """
+
+ @staticmethod
+ async def prepare_upload(content: bytes | Path | BinaryIO | None, name: str | None = None) -> PreparedFile:
+ """Prepare file content for upload (async version).
+
+ Converts various content types to a consistent BinaryIO interface for streaming uploads.
+ For Path inputs, opens the file handle in a thread pool to avoid blocking the event loop.
+ The actual file reading is streamed by httpx during the HTTP request.
+
+ Args:
+ content: The file content as bytes, a Path to a file, or a file-like object.
+ Can be None if no file is set.
+ name: Optional filename. If not provided and content is a Path,
+ the filename will be derived from the path.
+
+ Returns:
+ A PreparedFile containing the file object, filename, and whether it should be closed.
+ """
+ if content is None:
+ return PreparedFile(file_object=None, filename=None, should_close=False)
+
+ if name is None and isinstance(content, Path):
+ name = content.name
+
+ filename = name or "uploaded_file"
+
+ if isinstance(content, bytes):
+ return PreparedFile(file_object=BytesIO(content), filename=filename, should_close=False)
+ if isinstance(content, Path):
+ # Open file in thread pool to avoid blocking the event loop
+ # Returns a sync file handle that httpx can stream from in chunks
+ file_obj = await anyio.to_thread.run_sync(content.open, "rb")
+ return PreparedFile(file_object=cast("BinaryIO", file_obj), filename=filename, should_close=True)
+
+ # At this point, content must be a BinaryIO (file-like object)
+ return PreparedFile(file_object=cast("BinaryIO", content), filename=filename, should_close=False)
+
+ @staticmethod
+ def prepare_upload_sync(content: bytes | Path | BinaryIO | None, name: str | None = None) -> PreparedFile:
+ """Prepare file content for upload (sync version).
+
+ Converts various content types to a consistent BinaryIO interface for streaming uploads.
+
+ Args:
+ content: The file content as bytes, a Path to a file, or a file-like object.
+ Can be None if no file is set.
+ name: Optional filename. If not provided and content is a Path,
+ the filename will be derived from the path.
+
+ Returns:
+ A PreparedFile containing the file object, filename, and whether it should be closed.
+ """
+ if content is None:
+ return PreparedFile(file_object=None, filename=None, should_close=False)
+
+ if name is None and isinstance(content, Path):
+ name = content.name
+
+ filename = name or "uploaded_file"
+
+ if isinstance(content, bytes):
+ return PreparedFile(file_object=BytesIO(content), filename=filename, should_close=False)
+ if isinstance(content, Path):
+ return PreparedFile(file_object=content.open("rb"), filename=filename, should_close=True)
+
+ # At this point, content must be a BinaryIO (file-like object)
+ return PreparedFile(file_object=cast("BinaryIO", content), filename=filename, should_close=False)
+
+ @staticmethod
+ def handle_error_response(exc: httpx.HTTPStatusError) -> None:
+ """Handle HTTP error responses for file operations.
+
+ Args:
+ exc: The HTTP status error from httpx.
+
+ Raises:
+ AuthenticationError: If authentication fails (401/403).
+ NodeNotFoundError: If the file/node is not found (404).
+ httpx.HTTPStatusError: For other HTTP errors.
+ """
+ if exc.response.status_code in {401, 403}:
+ response = exc.response.json()
+ errors = response.get("errors", [])
+ messages = [error.get("message") for error in errors]
+ raise AuthenticationError(" | ".join(messages)) from exc
+ if exc.response.status_code == 404:
+ response = exc.response.json()
+ detail = response.get("detail", "File not found")
+ raise NodeNotFoundError(node_type="FileObject", identifier=detail) from exc
+ raise exc
+
+ @staticmethod
+ def handle_response(resp: httpx.Response) -> bytes:
+ """Handle the HTTP response and return file content as bytes.
+
+ Args:
+ resp: The HTTP response from httpx.
+
+ Returns:
+ The file content as bytes.
+
+ Raises:
+ AuthenticationError: If authentication fails.
+ NodeNotFoundError: If the file is not found.
+ """
+ try:
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as exc:
+ FileHandlerBase.handle_error_response(exc=exc)
+ return resp.content
+
+
+class FileHandler(FileHandlerBase):
+ """Async file handler for download operations.
+
+ Handles file downloads with support for streaming to disk
+ for memory-efficient handling of large files.
+ """
+
+ def __init__(self, client: InfrahubClient) -> None:
+ """Initialize the async file handler.
+
+ Args:
+ client: The async Infrahub client instance.
+ """
+ self._client = client
+
+ def _build_url(self, node_id: str, branch: str | None) -> str:
+ """Build the download URL for a file.
+
+ Args:
+ node_id: The ID of the FileObject node.
+ branch: Optional branch name.
+
+ Returns:
+ The complete URL for downloading the file.
+ """
+ url = f"{self._client.address}/api/storage/files/{node_id}"
+ if branch:
+ url = f"{url}?branch={branch}"
+ return url
+
+ @overload
+ async def download(self, node_id: str, branch: str | None) -> bytes: ...
+
+ @overload
+ async def download(self, node_id: str, branch: str | None, dest: Path) -> int: ...
+
+ @overload
+ async def download(self, node_id: str, branch: str | None, dest: None) -> bytes: ...
+
+ async def download(self, node_id: str, branch: str | None, dest: Path | None = None) -> bytes | int:
+ """Download file content from a FileObject node.
+
+ Args:
+ node_id: The ID of the FileObject node.
+ branch: Optional branch name. Uses client default if not provided.
+ dest: Optional destination path. If provided, streams to disk.
+
+ Returns:
+ If dest is None: The file content as bytes.
+ If dest is provided: The number of bytes written.
+
+ Raises:
+ ServerNotReachableError: If the server is not reachable.
+ AuthenticationError: If authentication fails.
+ NodeNotFoundError: If the node/file is not found.
+ """
+ effective_branch = branch or self._client.default_branch
+ url = self._build_url(node_id=node_id, branch=effective_branch)
+
+ if dest is not None:
+ return await self._stream_to_file(url=url, dest=dest)
+
+ try:
+ resp = await self._client._get(url=url)
+ except ServerNotReachableError:
+ self._client.log.error(f"Unable to connect to {self._client.address}")
+ raise
+
+ return self.handle_response(resp=resp)
+
+ async def _stream_to_file(self, url: str, dest: Path) -> int:
+ """Stream download directly to a file without loading into memory.
+
+ Args:
+ url: The URL to download from.
+ dest: The destination path to write to.
+
+ Returns:
+ The number of bytes written to the file.
+
+ Raises:
+ ServerNotReachableError: If the server is not reachable.
+ AuthenticationError: If authentication fails.
+ NodeNotFoundError: If the file is not found.
+ """
+ try:
+ async with self._client._get_streaming(url=url) as resp:
+ try:
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as exc:
+ # Need to read the response body for error details
+ await resp.aread()
+ self.handle_error_response(exc=exc)
+
+ bytes_written = 0
+ async with await anyio.Path(dest).open("wb") as f:
+ async for chunk in resp.aiter_bytes(chunk_size=65536):
+ await f.write(chunk)
+ bytes_written += len(chunk)
+ return bytes_written
+ except ServerNotReachableError:
+ self._client.log.error(f"Unable to connect to {self._client.address}")
+ raise
+
+
+class FileHandlerSync(FileHandlerBase):
+ """Sync file handler for download operations.
+
+ Handles file downloads with support for streaming to disk
+ for memory-efficient handling of large files.
+ """
+
+ def __init__(self, client: InfrahubClientSync) -> None:
+ """Initialize the sync file handler.
+
+ Args:
+ client: The sync Infrahub client instance.
+ """
+ self._client = client
+
+ def _build_url(self, node_id: str, branch: str | None) -> str:
+ """Build the download URL for a file.
+
+ Args:
+ node_id: The ID of the FileObject node.
+ branch: Optional branch name.
+
+ Returns:
+ The complete URL for downloading the file.
+ """
+ url = f"{self._client.address}/api/storage/files/{node_id}"
+ if branch:
+ url = f"{url}?branch={branch}"
+ return url
+
+ @overload
+ def download(self, node_id: str, branch: str | None) -> bytes: ...
+
+ @overload
+ def download(self, node_id: str, branch: str | None, dest: Path) -> int: ...
+
+ @overload
+ def download(self, node_id: str, branch: str | None, dest: None) -> bytes: ...
+
+ def download(self, node_id: str, branch: str | None, dest: Path | None = None) -> bytes | int:
+ """Download file content from a FileObject node.
+
+ Args:
+ node_id: The ID of the FileObject node.
+ branch: Optional branch name. Uses client default if not provided.
+ dest: Optional destination path. If provided, streams to disk.
+
+ Returns:
+ If dest is None: The file content as bytes.
+ If dest is provided: The number of bytes written.
+
+ Raises:
+ ServerNotReachableError: If the server is not reachable.
+ AuthenticationError: If authentication fails.
+ NodeNotFoundError: If the node/file is not found.
+ """
+ effective_branch = branch or self._client.default_branch
+ url = self._build_url(node_id=node_id, branch=effective_branch)
+
+ if dest is not None:
+ return self._stream_to_file(url=url, dest=dest)
+
+ try:
+ resp = self._client._get(url=url)
+ except ServerNotReachableError:
+ self._client.log.error(f"Unable to connect to {self._client.address}")
+ raise
+
+ return self.handle_response(resp=resp)
+
+ def _stream_to_file(self, url: str, dest: Path) -> int:
+ """Stream download directly to a file without loading into memory.
+
+ Args:
+ url: The URL to download from.
+ dest: The destination path to write to.
+
+ Returns:
+ The number of bytes written to the file.
+
+ Raises:
+ ServerNotReachableError: If the server is not reachable.
+ AuthenticationError: If authentication fails.
+ NodeNotFoundError: If the file is not found.
+ """
+ try:
+ with self._client._get_streaming(url=url) as resp:
+ try:
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as exc:
+ # Need to read the response body for error details
+ resp.read()
+ self.handle_error_response(exc=exc)
+
+ bytes_written = 0
+ with dest.open("wb") as f:
+ for chunk in resp.iter_bytes(chunk_size=65536):
+ f.write(chunk)
+ bytes_written += len(chunk)
+ return bytes_written
+ except ServerNotReachableError:
+ self._client.log.error(f"Unable to connect to {self._client.address}")
+ raise
diff --git a/infrahub_sdk/graphql/__init__.py b/infrahub_sdk/graphql/__init__.py
index 33438e35..743919b6 100644
--- a/infrahub_sdk/graphql/__init__.py
+++ b/infrahub_sdk/graphql/__init__.py
@@ -1,9 +1,11 @@
from .constants import VARIABLE_TYPE_MAPPING
+from .multipart import MultipartBuilder
from .query import Mutation, Query
from .renderers import render_input_block, render_query_block, render_variables_to_string
__all__ = [
"VARIABLE_TYPE_MAPPING",
+ "MultipartBuilder",
"Mutation",
"Query",
"render_input_block",
diff --git a/infrahub_sdk/graphql/constants.py b/infrahub_sdk/graphql/constants.py
index e2033155..0fed5c57 100644
--- a/infrahub_sdk/graphql/constants.py
+++ b/infrahub_sdk/graphql/constants.py
@@ -1,4 +1,6 @@
from datetime import datetime
+from pathlib import Path
+from typing import BinaryIO
VARIABLE_TYPE_MAPPING = (
(str, "String!"),
@@ -11,4 +13,7 @@
(bool | None, "Boolean"),
(datetime, "DateTime!"),
(datetime | None, "DateTime"),
+ (bytes, "Upload!"),
+ (Path, "Upload!"),
+ (BinaryIO, "Upload!"),
)
diff --git a/infrahub_sdk/graphql/multipart.py b/infrahub_sdk/graphql/multipart.py
new file mode 100644
index 00000000..bdb1f84e
--- /dev/null
+++ b/infrahub_sdk/graphql/multipart.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import ujson
+
+if TYPE_CHECKING:
+ from typing import BinaryIO
+
+
+class MultipartBuilder:
+ """Builds multipart form data payloads for GraphQL file uploads.
+
+ This class implements the GraphQL Multipart Request Spec for uploading files via GraphQL mutations. The spec defines a standard way to send files
+ alongside GraphQL operations using multipart/form-data.
+
+ The payload structure follows the spec:
+ - operations: JSON containing the GraphQL query and variables
+ - map: JSON mapping file keys to variable paths
+ - 0, 1, ...: The actual file contents
+
+ Example payload:
+ {
+ "operations": '{"query": "mutation($file: Upload!) {...}", "variables": {"file": null}}',
+ "map": '{"0": ["variables.file"]}',
+ "0": (filename, file_content)
+ }
+ """
+
+ @staticmethod
+ def build_operations(query: str, variables: dict[str, Any]) -> str:
+ """Build the operations JSON string.
+
+ Args:
+ query: The GraphQL query string.
+ variables: The variables dict (file variable should be null).
+
+ Returns:
+ JSON string containing the query and variables.
+ """
+ return ujson.dumps({"query": query, "variables": variables})
+
+ @staticmethod
+ def build_file_map(file_key: str = "0", variable_path: str = "variables.file") -> str:
+ """Build the file map JSON string.
+
+ Args:
+ file_key: The key used for the file in the multipart payload.
+ variable_path: The path to the file variable in the GraphQL variables.
+
+ Returns:
+ JSON string mapping the file key to the variable path.
+ """
+ return ujson.dumps({file_key: [variable_path]})
+
+ @staticmethod
+ def build_payload(
+ query: str,
+ variables: dict[str, Any],
+ file_content: BinaryIO | None = None,
+ file_name: str = "upload",
+ ) -> dict[str, Any]:
+ """Build the complete multipart form data payload.
+
+ Constructs the payload according to the GraphQL Multipart Request Spec. The returned dict can be passed directly to httpx as the `files`
+ parameter.
+
+ Args:
+ query: The GraphQL query string containing $file: Upload! variable.
+ variables: The variables dict. The 'file' key will be set to null.
+ file_content: The file content as a file-like object (BinaryIO).
+ If None, only the operations and map will be included.
+ file_name: The filename to use for the upload.
+
+ Returns:
+ A dict suitable for httpx's `files` parameter in a POST request.
+
+ Example:
+ >>> builder = MultipartBuilder()
+ >>> payload = builder.build_payload(
+ ... query="mutation($file: Upload!) { upload(file: $file) { id } }",
+ ... variables={"other": "value"},
+ ... file_content=open("file.pdf", "rb"),
+ ... file_name="document.pdf",
+ ... )
+ >>> # payload can be passed to httpx.post(..., files=payload)
+ """
+ # Ensure file variable is null (spec requirement)
+ variables = {**variables, "file": None}
+
+ operations = MultipartBuilder.build_operations(query=query, variables=variables)
+ file_map = MultipartBuilder.build_file_map()
+
+ files: dict[str, Any] = {"operations": (None, operations), "map": (None, file_map)}
+
+ if file_content is not None:
+ # httpx streams from file-like objects automatically
+ files["0"] = (file_name, file_content)
+
+ return files
diff --git a/infrahub_sdk/graphql/renderers.py b/infrahub_sdk/graphql/renderers.py
index b0d2ab28..5b6c2c0f 100644
--- a/infrahub_sdk/graphql/renderers.py
+++ b/infrahub_sdk/graphql/renderers.py
@@ -3,7 +3,8 @@
import json
from datetime import datetime
from enum import Enum
-from typing import Any
+from pathlib import Path
+from typing import Any, BinaryIO
from pydantic import BaseModel
@@ -88,7 +89,7 @@ def convert_to_graphql_as_string(value: Any, convert_enum: bool = False) -> str:
return str(value)
-GRAPHQL_VARIABLE_TYPES = type[str | int | float | bool | datetime | None]
+GRAPHQL_VARIABLE_TYPES = type[str | int | float | bool | datetime | bytes | Path | BinaryIO | None]
def render_variables_to_string(data: dict[str, GRAPHQL_VARIABLE_TYPES]) -> str:
@@ -148,10 +149,7 @@ def render_query_block(data: dict, offset: int = 4, indentation: int = 4, conver
elif isinstance(value, dict) and len(value) == 1 and alias_key in value and value[alias_key]:
lines.append(f"{offset_str}{value[alias_key]}: {key}")
elif isinstance(value, dict):
- if value.get(alias_key):
- key_str = f"{value[alias_key]}: {key}"
- else:
- key_str = key
+ key_str = f"{value[alias_key]}: {key}" if value.get(alias_key) else key
if value.get(filters_key):
filters_str = ", ".join(
diff --git a/infrahub_sdk/node/constants.py b/infrahub_sdk/node/constants.py
index 8d301115..7a0bc6fd 100644
--- a/infrahub_sdk/node/constants.py
+++ b/infrahub_sdk/node/constants.py
@@ -27,6 +27,9 @@
ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE = (
"calling generate is only supported for CoreArtifactDefinition nodes"
)
+FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE = (
+ "calling download_file is only supported for nodes that inherit from CoreFileObject"
+)
HIERARCHY_FETCH_FEATURE_NOT_SUPPORTED_MESSAGE = "Hierarchical fields are not supported for this node."
diff --git a/infrahub_sdk/node/node.py b/infrahub_sdk/node/node.py
index 9d024cbb..a47209dc 100644
--- a/infrahub_sdk/node/node.py
+++ b/infrahub_sdk/node/node.py
@@ -2,10 +2,12 @@
from collections.abc import Iterable
from copy import copy, deepcopy
-from typing import TYPE_CHECKING, Any
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, BinaryIO
from ..constants import InfrahubClientMode
from ..exceptions import FeatureNotSupportedError, NodeNotFoundError, ResourceNotDefinedError, SchemaNotFoundError
+from ..file_handler import FileHandler, FileHandlerBase, FileHandlerSync, PreparedFile
from ..graphql import Mutation, Query
from ..schema import (
GenericSchemaAPI,
@@ -21,6 +23,7 @@
ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE,
ARTIFACT_FETCH_FEATURE_NOT_SUPPORTED_MESSAGE,
ARTIFACT_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE,
+ FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE,
PROPERTIES_OBJECT,
)
from .metadata import NodeMetadata
@@ -65,14 +68,15 @@ def __init__(self, schema: MainSchemaTypesAPI, branch: str, data: dict | None =
self._attributes = [item.name for item in self._schema.attributes]
self._relationships = [item.name for item in self._schema.relationships]
- # GenericSchemaAPI doesn't have inherit_from, so we need to check the type first
- if isinstance(schema, GenericSchemaAPI):
- self._artifact_support = False
- else:
- inherit_from = getattr(schema, "inherit_from", None) or []
- self._artifact_support = "CoreArtifactTarget" in inherit_from
+ # GenericSchemaAPI doesn't have inherit_from
+ inherit_from: list[str] = getattr(schema, "inherit_from", None) or []
+ self._artifact_support = "CoreArtifactTarget" in inherit_from
+ self._file_object_support = "CoreFileObject" in inherit_from
self._artifact_definition_support = schema.kind == "CoreArtifactDefinition"
+ self._file_content: bytes | Path | BinaryIO | None = None
+ self._file_name: str | None = None
+
# Check if this node is hierarchical (supports parent/children and ancestors/descendants)
if not isinstance(schema, (ProfileSchemaAPI, GenericSchemaAPI, TemplateSchemaAPI)):
self._hierarchy_support = getattr(schema, "hierarchy", None) is not None
@@ -143,7 +147,7 @@ def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | N
if not hfid:
return None
if include_kind:
- hfid = [self.get_kind()] + hfid
+ hfid = [self.get_kind(), *hfid]
return "__".join(hfid)
@property
@@ -199,7 +203,7 @@ def get_kind(self) -> str:
def get_all_kinds(self) -> list[str]:
if inherit_from := getattr(self._schema, "inherit_from", None):
- return [self._schema.kind] + inherit_from
+ return [self._schema.kind, *inherit_from]
return [self._schema.kind]
def is_ip_prefix(self) -> bool:
@@ -213,6 +217,72 @@ def is_ip_address(self) -> bool:
def is_resource_pool(self) -> bool:
return hasattr(self._schema, "inherit_from") and "CoreResourcePool" in self._schema.inherit_from # type: ignore[union-attr]
+ def is_file_object(self) -> bool:
+ """Check if this node inherits from CoreFileObject and supports file uploads."""
+ return self._file_object_support
+
+ def upload_from_path(self, path: Path) -> None:
+ """Set a file from disk to be uploaded when saving this FileObject node.
+
+ The file will be streamed during upload, avoiding loading the entire file into memory.
+
+ Args:
+ path: Path to the file on disk.
+
+ Raises:
+ FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject.
+
+ Example:
+ node.upload_from_path(path=Path("/path/to/large_file.pdf"))
+ """
+ if not self._file_object_support:
+ raise FeatureNotSupportedError(
+ f"File upload is not supported for {self._schema.kind}. Only nodes inheriting from CoreFileObject support file uploads."
+ )
+ self._file_content = path
+ self._file_name = path.name
+
+ def upload_from_bytes(self, content: bytes | BinaryIO, name: str) -> None:
+ """Set content to be uploaded when saving this FileObject node.
+
+ The content can be provided as bytes or a file-like object.
+ Using BinaryIO is recommended for large content to stream during upload.
+
+ Args:
+ content: The file content as bytes or a file-like object.
+ name: The filename to use for the uploaded file.
+
+ Raises:
+ FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject.
+
+ Examples:
+ >>> # Using bytes (for small files)
+ >>> node.upload_from_bytes(content=b"file content", name="example.txt")
+
+ >>> # Using file-like object (for large files)
+ >>> with open("/path/to/file.bin", "rb") as f:
+ ... node.upload_from_bytes(content=f, name="file.bin")
+ """
+ if not self._file_object_support:
+ raise FeatureNotSupportedError(
+ f"File upload is not supported for {self._schema.kind}. Only nodes inheriting from CoreFileObject support file uploads."
+ )
+ self._file_content = content
+ self._file_name = name
+
+ def clear_file(self) -> None:
+ """Clear any pending file content."""
+ self._file_content = None
+ self._file_name = None
+
+ async def _get_file_for_upload(self) -> PreparedFile:
+ """Get the file content as a file-like object for upload (async version)."""
+ return await FileHandlerBase.prepare_upload(content=self._file_content, name=self._file_name)
+
+ def _get_file_for_upload_sync(self) -> PreparedFile:
+ """Get the file content as a file-like object for upload (sync version)."""
+ return FileHandlerBase.prepare_upload_sync(content=self._file_content, name=self._file_name)
+
def get_raw_graphql_data(self) -> dict | None:
return self._data
@@ -288,10 +358,16 @@ def _generate_input_data( # noqa: C901
elif self.hfid is not None and not exclude_hfid:
data["hfid"] = self.hfid
- mutation_payload = {"data": data}
+ mutation_payload: dict[str, Any] = {"data": data}
if context_data := self._get_request_context(request_context=request_context):
mutation_payload["context"] = context_data
+ # Add file variable for FileObject nodes with pending file content
+ # file is a mutation argument at the same level as data, not inside data
+ if self._file_object_support and self._file_content is not None:
+ mutation_payload["file"] = "$file"
+ mutation_variables["file"] = bytes
+
return {
"data": mutation_payload,
"variables": variables,
@@ -417,6 +493,10 @@ def _validate_artifact_definition_support(self, message: str) -> None:
if not self._artifact_definition_support:
raise FeatureNotSupportedError(message)
+ def _validate_file_object_support(self, message: str) -> None:
+ if not self._file_object_support:
+ raise FeatureNotSupportedError(message)
+
def generate_query_data_init(
self,
filters: dict[str, Any] | None = None,
@@ -506,6 +586,7 @@ def __init__(
data: Optional data to initialize the node.
"""
self._client = client
+ self._file_handler = FileHandler(client=client)
# Extract node_metadata before extracting node data (node_metadata is sibling to node in edges)
node_metadata_data: dict | None = None
@@ -558,10 +639,7 @@ def _init_relationships(self, data: dict | RelatedNode | None = None) -> None:
)
if value is not None
}
- if peer_id_data:
- rel_data = peer_id_data
- else:
- rel_data = None
+ rel_data = peer_id_data or None
self._relationship_cardinality_one_data[rel_schema.name] = RelatedNode(
name=rel_schema.name, branch=self._branch, client=self._client, schema=rel_schema, data=rel_data
)
@@ -694,6 +772,41 @@ async def artifact_fetch(self, name: str) -> str | dict[str, Any]:
artifact = await self._client.get(kind="CoreArtifact", name__value=name, object__ids=[self.id])
return await self._client.object_store.get(identifier=artifact._get_attribute(name="storage_id").value)
+ async def download_file(self, dest: Path | None = None) -> bytes | int:
+ """Download the file content from this FileObject node.
+
+ This method is only available for nodes that inherit from CoreFileObject.
+ The node must have been saved (have an id) before calling this method.
+
+ Args:
+ dest: Optional destination path. If provided, the file will be streamed
+ directly to this path (memory-efficient for large files) and the
+ number of bytes written will be returned. If not provided, the
+ file content will be returned as bytes.
+
+ Returns:
+ If ``dest`` is None: The file content as bytes.
+ If ``dest`` is provided: The number of bytes written to the file.
+
+ Raises:
+ FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject.
+ ValueError: If the node hasn't been saved yet or file not found.
+ AuthenticationError: If authentication fails.
+
+ Examples:
+ >>> # Download to memory
+ >>> content = await contract.download_file()
+
+ >>> # Stream to file (memory-efficient for large files)
+ >>> bytes_written = await contract.download_file(dest=Path("/tmp/contract.pdf"))
+ """
+ self._validate_file_object_support(message=FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE)
+
+ if not self.id:
+ raise ValueError("Cannot download file for a node that hasn't been saved yet.")
+
+ return await self._file_handler.download(node_id=self.id, branch=self._branch, dest=dest)
+
async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
input_data = {"data": {"id": self.id}}
if context_data := self._get_request_context(request_context=request_context):
@@ -1024,6 +1137,12 @@ async def _process_mutation_result(
async def create(
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
) -> None:
+ if self._file_object_support and self._file_content is None:
+ raise ValueError(
+ f"Cannot create {self._schema.kind} without file content. Use upload_from_path() or upload_from_bytes() to provide "
+ "file content before saving."
+ )
+
mutation_query = self._generate_mutation_query()
# Upserting means we may want to create, meaning payload contains all mandatory fields required for a creation,
@@ -1036,19 +1155,39 @@ async def create(
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)
mutation_name = f"{self._schema.kind}Create"
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
+
query = Mutation(
mutation=mutation_name,
input_data=input_data["data"],
query=mutation_query,
variables=input_data["mutation_variables"],
)
- response = await self._client.execute_graphql(
- query=query.render(),
- branch_name=self._branch,
- tracker=tracker,
- variables=input_data["variables"],
- timeout=timeout,
- )
+
+ if "file" in input_data["mutation_variables"]:
+ prepared = await self._get_file_for_upload()
+ try:
+ response = await self._client._execute_graphql_with_file(
+ query=query.render(),
+ variables=input_data["variables"],
+ file_content=prepared.file_object,
+ file_name=prepared.filename,
+ branch_name=self._branch,
+ tracker=tracker,
+ timeout=timeout,
+ )
+ finally:
+ if prepared.should_close and prepared.file_object:
+ prepared.file_object.close()
+ # Clear the file content after successful upload
+ self.clear_file()
+ else:
+ response = await self._client.execute_graphql(
+ query=query.render(),
+ branch_name=self._branch,
+ tracker=tracker,
+ variables=input_data["variables"],
+ timeout=timeout,
+ )
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
async def update(
@@ -1057,6 +1196,7 @@ async def update(
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
mutation_query = self._generate_mutation_query()
mutation_name = f"{self._schema.kind}Update"
+ tracker = f"mutation-{str(self._schema.kind).lower()}-update"
query = Mutation(
mutation=mutation_name,
@@ -1064,13 +1204,32 @@ async def update(
query=mutation_query,
variables=input_data["mutation_variables"],
)
- response = await self._client.execute_graphql(
- query=query.render(),
- branch_name=self._branch,
- timeout=timeout,
- tracker=f"mutation-{str(self._schema.kind).lower()}-update",
- variables=input_data["variables"],
- )
+
+ if "file" in input_data["mutation_variables"]:
+ prepared = await self._get_file_for_upload()
+ try:
+ response = await self._client._execute_graphql_with_file(
+ query=query.render(),
+ variables=input_data["variables"],
+ file_content=prepared.file_object,
+ file_name=prepared.filename,
+ branch_name=self._branch,
+ tracker=tracker,
+ timeout=timeout,
+ )
+ finally:
+ if prepared.should_close and prepared.file_object:
+ prepared.file_object.close()
+ # Clear the file content after successful upload
+ self.clear_file()
+ else:
+ response = await self._client.execute_graphql(
+ query=query.render(),
+ branch_name=self._branch,
+ timeout=timeout,
+ tracker=tracker,
+ variables=input_data["variables"],
+ )
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
async def _process_relationships(
@@ -1310,6 +1469,7 @@ def __init__(
data (Optional[dict]): Optional data to initialize the node.
"""
self._client = client
+ self._file_handler = FileHandlerSync(client=client)
# Extract node_metadata before extracting node data (node_metadata is sibling to node in edges)
node_metadata_data: dict | None = None
@@ -1362,10 +1522,7 @@ def _init_relationships(self, data: dict | None = None) -> None:
)
if value is not None
}
- if peer_id_data:
- rel_data = peer_id_data
- else:
- rel_data = None
+ rel_data = peer_id_data or None
self._relationship_cardinality_one_data[rel_schema.name] = RelatedNodeSync(
name=rel_schema.name, branch=self._branch, client=self._client, schema=rel_schema, data=rel_data
)
@@ -1499,6 +1656,41 @@ def artifact_fetch(self, name: str) -> str | dict[str, Any]:
artifact = self._client.get(kind="CoreArtifact", name__value=name, object__ids=[self.id])
return self._client.object_store.get(identifier=artifact._get_attribute(name="storage_id").value)
+ def download_file(self, dest: Path | None = None) -> bytes | int:
+ """Download the file content from this FileObject node.
+
+ This method is only available for nodes that inherit from CoreFileObject.
+ The node must have been saved (have an id) before calling this method.
+
+ Args:
+ dest: Optional destination path. If provided, the file will be streamed
+ directly to this path (memory-efficient for large files) and the
+ number of bytes written will be returned. If not provided, the
+ file content will be returned as bytes.
+
+ Returns:
+ If ``dest`` is None: The file content as bytes.
+ If ``dest`` is provided: The number of bytes written to the file.
+
+ Raises:
+ FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject.
+ ValueError: If the node hasn't been saved yet or file not found.
+ AuthenticationError: If authentication fails.
+
+ Examples:
+ >>> # Download to memory
+ >>> content = contract.download_file()
+
+ >>> # Stream to file (memory-efficient for large files)
+ >>> bytes_written = contract.download_file(dest=Path("/tmp/contract.pdf"))
+ """
+ self._validate_file_object_support(message=FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE)
+
+ if not self.id:
+ raise ValueError("Cannot download file for a node that hasn't been saved yet.")
+
+ return self._file_handler.download(node_id=self.id, branch=self._branch, dest=dest)
+
def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
input_data = {"data": {"id": self.id}}
if context_data := self._get_request_context(request_context=request_context):
@@ -1828,6 +2020,12 @@ def _process_mutation_result(
def create(
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
) -> None:
+ if self._file_object_support and self._file_content is None:
+ raise ValueError(
+ f"Cannot create {self._schema.kind} without file content. Use upload_from_path() or upload_from_bytes() to provide "
+ "file content before saving."
+ )
+
mutation_query = self._generate_mutation_query()
if allow_upsert:
@@ -1838,6 +2036,7 @@ def create(
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)
mutation_name = f"{self._schema.kind}Create"
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
+
query = Mutation(
mutation=mutation_name,
input_data=input_data["data"],
@@ -1845,13 +2044,31 @@ def create(
variables=input_data["mutation_variables"],
)
- response = self._client.execute_graphql(
- query=query.render(),
- branch_name=self._branch,
- tracker=tracker,
- variables=input_data["variables"],
- timeout=timeout,
- )
+ if "file" in input_data["mutation_variables"]:
+ prepared = self._get_file_for_upload_sync()
+ try:
+ response = self._client._execute_graphql_with_file(
+ query=query.render(),
+ variables=input_data["variables"],
+ file_content=prepared.file_object,
+ file_name=prepared.filename,
+ branch_name=self._branch,
+ tracker=tracker,
+ timeout=timeout,
+ )
+ finally:
+ if prepared.should_close and prepared.file_object:
+ prepared.file_object.close()
+ # Clear the file content after successful upload
+ self.clear_file()
+ else:
+ response = self._client.execute_graphql(
+ query=query.render(),
+ branch_name=self._branch,
+ tracker=tracker,
+ variables=input_data["variables"],
+ timeout=timeout,
+ )
self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
def update(
@@ -1860,6 +2077,7 @@ def update(
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
mutation_query = self._generate_mutation_query()
mutation_name = f"{self._schema.kind}Update"
+ tracker = f"mutation-{str(self._schema.kind).lower()}-update"
query = Mutation(
mutation=mutation_name,
@@ -1868,13 +2086,31 @@ def update(
variables=input_data["mutation_variables"],
)
- response = self._client.execute_graphql(
- query=query.render(),
- branch_name=self._branch,
- tracker=f"mutation-{str(self._schema.kind).lower()}-update",
- variables=input_data["variables"],
- timeout=timeout,
- )
+ if "file" in input_data["mutation_variables"]:
+ prepared = self._get_file_for_upload_sync()
+ try:
+ response = self._client._execute_graphql_with_file(
+ query=query.render(),
+ variables=input_data["variables"],
+ file_content=prepared.file_object,
+ file_name=prepared.filename,
+ branch_name=self._branch,
+ tracker=tracker,
+ timeout=timeout,
+ )
+ finally:
+ if prepared.should_close and prepared.file_object:
+ prepared.file_object.close()
+ # Clear the file content after successful upload
+ self.clear_file()
+ else:
+ response = self._client.execute_graphql(
+ query=query.render(),
+ branch_name=self._branch,
+ tracker=tracker,
+ variables=input_data["variables"],
+ timeout=timeout,
+ )
self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
def _process_relationships(
diff --git a/infrahub_sdk/protocols.py b/infrahub_sdk/protocols.py
index b3752bed..c359ad5c 100644
--- a/infrahub_sdk/protocols.py
+++ b/infrahub_sdk/protocols.py
@@ -29,7 +29,6 @@
StringOptional,
)
-# pylint: disable=too-many-ancestors
# ---------------------------------------------
# ASYNC
@@ -108,6 +107,14 @@ class CoreCredential(CoreNode):
description: StringOptional
+class CoreFileObject(CoreNode):
+ file_name: String
+ checksum: String
+ file_size: Integer
+ file_type: String
+ storage_id: String
+
+
class CoreGenericAccount(CoreNode):
name: String
password: HashedPassword
@@ -227,6 +234,7 @@ class CoreValidator(CoreNode):
class CoreWebhook(CoreNode):
name: String
event_type: Enum
+ active: Boolean
branch_scope: Dropdown
node_kind: StringOptional
description: StringOptional
@@ -499,7 +507,6 @@ class CoreProposedChange(CoreTaskTarget):
approved_by: RelationshipManager
rejected_by: RelationshipManager
reviewers: RelationshipManager
- created_by: RelatedNode
comments: RelationshipManager
threads: RelationshipManager
validations: RelationshipManager
@@ -665,6 +672,14 @@ class CoreCredentialSync(CoreNodeSync):
description: StringOptional
+class CoreFileObjectSync(CoreNodeSync):
+ file_name: String
+ checksum: String
+ file_size: Integer
+ file_type: String
+ storage_id: String
+
+
class CoreGenericAccountSync(CoreNodeSync):
name: String
password: HashedPassword
@@ -784,6 +799,7 @@ class CoreValidatorSync(CoreNodeSync):
class CoreWebhookSync(CoreNodeSync):
name: String
event_type: Enum
+ active: Boolean
branch_scope: Dropdown
node_kind: StringOptional
description: StringOptional
@@ -1056,7 +1072,6 @@ class CoreProposedChangeSync(CoreTaskTargetSync):
approved_by: RelationshipManagerSync
rejected_by: RelationshipManagerSync
reviewers: RelationshipManagerSync
- created_by: RelatedNodeSync
comments: RelationshipManagerSync
threads: RelationshipManagerSync
validations: RelationshipManagerSync
diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py
index 8a841b5b..7f6569ae 100644
--- a/infrahub_sdk/protocols_base.py
+++ b/infrahub_sdk/protocols_base.py
@@ -6,6 +6,7 @@
import ipaddress
from .context import RequestContext
+ from .node.metadata import NodeMetadata
from .schema import MainSchemaTypes
@@ -203,6 +204,8 @@ def is_resource_pool(self) -> bool: ...
def get_raw_graphql_data(self) -> dict | None: ...
+ def get_node_metadata(self) -> NodeMetadata | None: ...
+
@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
diff --git a/infrahub_sdk/protocols_generator/generator.py b/infrahub_sdk/protocols_generator/generator.py
index e70e221c..38bc968f 100644
--- a/infrahub_sdk/protocols_generator/generator.py
+++ b/infrahub_sdk/protocols_generator/generator.py
@@ -59,13 +59,13 @@ def __init__(self, schema: dict[str, MainSchemaTypesAll]) -> None:
not in {"TYPE_CHECKING", "CoreNode", "Optional", "Protocol", "Union", "annotations", "runtime_checkable"}
]
- self.sorted_generics = self._sort_and_filter_models(self.generics, filters=["CoreNode"] + self.base_protocols)
- self.sorted_nodes = self._sort_and_filter_models(self.nodes, filters=["CoreNode"] + self.base_protocols)
+ self.sorted_generics = self._sort_and_filter_models(self.generics, filters=["CoreNode", *self.base_protocols])
+ self.sorted_nodes = self._sort_and_filter_models(self.nodes, filters=["CoreNode", *self.base_protocols])
self.sorted_profiles = self._sort_and_filter_models(
- self.profiles, filters=["CoreProfile"] + self.base_protocols
+ self.profiles, filters=["CoreProfile", *self.base_protocols]
)
self.sorted_templates = self._sort_and_filter_models(
- self.templates, filters=["CoreObjectTemplate"] + self.base_protocols
+ self.templates, filters=["CoreObjectTemplate", *self.base_protocols]
)
def render(self, sync: bool = True) -> str:
diff --git a/infrahub_sdk/py.typed b/infrahub_sdk/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/infrahub_sdk/pytest_plugin/items/check.py b/infrahub_sdk/pytest_plugin/items/check.py
index f42f4808..8f68b271 100644
--- a/infrahub_sdk/pytest_plugin/items/check.py
+++ b/infrahub_sdk/pytest_plugin/items/check.py
@@ -12,7 +12,7 @@
from .base import InfrahubItem
if TYPE_CHECKING:
- from pytest import ExceptionInfo
+ import pytest
from ...checks import InfrahubCheck
from ...schema.repository import InfrahubRepositoryConfigElement
@@ -46,7 +46,7 @@ def run_check(self, variables: dict[str, Any]) -> Any:
self.instantiate_check()
return asyncio.run(self.check_instance.run(data=variables))
- def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str:
+ def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str:
if isinstance(excinfo.value, HTTPStatusError):
try:
response_content = ujson.dumps(excinfo.value.response.json(), indent=4)
diff --git a/infrahub_sdk/pytest_plugin/items/graphql_query.py b/infrahub_sdk/pytest_plugin/items/graphql_query.py
index defb9fb9..bced5542 100644
--- a/infrahub_sdk/pytest_plugin/items/graphql_query.py
+++ b/infrahub_sdk/pytest_plugin/items/graphql_query.py
@@ -11,7 +11,7 @@
from .base import InfrahubItem
if TYPE_CHECKING:
- from pytest import ExceptionInfo
+ import pytest
class InfrahubGraphQLQueryItem(InfrahubItem):
@@ -25,7 +25,7 @@ def execute_query(self) -> Any:
variables=self.test.spec.get_variables_data(), # type: ignore[union-attr]
)
- def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str:
+ def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str:
if isinstance(excinfo.value, HTTPStatusError):
try:
response_content = ujson.dumps(excinfo.value.response.json(), indent=4)
diff --git a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py
index 433309a4..fe54fd71 100644
--- a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py
+++ b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py
@@ -16,7 +16,7 @@
from .base import InfrahubItem
if TYPE_CHECKING:
- from pytest import ExceptionInfo
+ import pytest
class InfrahubJinja2Item(InfrahubItem):
@@ -57,7 +57,7 @@ def get_result_differences(self, computed: Any) -> str | None:
)
return "\n".join(differences)
- def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str:
+ def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str:
if isinstance(excinfo.value, HTTPStatusError):
try:
response_content = ujson.dumps(excinfo.value.response.json(), indent=4, sort_keys=True)
@@ -94,7 +94,7 @@ def runtest(self) -> None:
if computed is not None and differences and self.test.expect == InfrahubTestExpectedResult.PASS:
raise OutputMatchError(name=self.name, differences=differences)
- def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str:
+ def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str:
if isinstance(excinfo.value, (JinjaTemplateError)):
return str(excinfo.value.message)
diff --git a/infrahub_sdk/pytest_plugin/items/python_transform.py b/infrahub_sdk/pytest_plugin/items/python_transform.py
index 0ec42052..f895e971 100644
--- a/infrahub_sdk/pytest_plugin/items/python_transform.py
+++ b/infrahub_sdk/pytest_plugin/items/python_transform.py
@@ -13,7 +13,7 @@
from .base import InfrahubItem
if TYPE_CHECKING:
- from pytest import ExceptionInfo
+ import pytest
from ...schema.repository import InfrahubRepositoryConfigElement
from ...transforms import InfrahubTransform
@@ -48,7 +48,7 @@ def run_transform(self, variables: dict[str, Any]) -> Any:
self.instantiate_transform()
return asyncio.run(self.transform_instance.run(data=variables))
- def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str:
+ def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str:
if isinstance(excinfo.value, HTTPStatusError):
try:
response_content = ujson.dumps(excinfo.value.response.json(), indent=4)
diff --git a/infrahub_sdk/pytest_plugin/loader.py b/infrahub_sdk/pytest_plugin/loader.py
index 5912a0c3..c26b09d2 100644
--- a/infrahub_sdk/pytest_plugin/loader.py
+++ b/infrahub_sdk/pytest_plugin/loader.py
@@ -6,7 +6,6 @@
import pytest
import yaml
-from pytest import Item
from .exceptions import InvalidResourceConfigError
from .items import (
@@ -66,7 +65,7 @@ def get_resource_config(self, group: InfrahubTestGroup) -> Any | None:
return resource_config
- def collect_group(self, group: InfrahubTestGroup) -> Iterable[Item]:
+ def collect_group(self, group: InfrahubTestGroup) -> Iterable[pytest.Item]:
"""Collect all items for a group."""
marker = MARKER_MAPPING[group.resource]
resource_config = self.get_resource_config(group)
@@ -98,7 +97,7 @@ def collect_group(self, group: InfrahubTestGroup) -> Iterable[Item]:
yield item
- def collect(self) -> Iterable[Item]:
+ def collect(self) -> Iterable[pytest.Item]:
raw = yaml.safe_load(self.path.open(encoding="utf-8"))
if not raw or "infrahub_tests" not in raw:
diff --git a/infrahub_sdk/pytest_plugin/plugin.py b/infrahub_sdk/pytest_plugin/plugin.py
index 74148a05..258e7f9c 100644
--- a/infrahub_sdk/pytest_plugin/plugin.py
+++ b/infrahub_sdk/pytest_plugin/plugin.py
@@ -3,8 +3,7 @@
import os
from pathlib import Path
-from pytest import Collector, Config, Item, Parser, Session
-from pytest import exit as exit_test
+import pytest
from .. import InfrahubClientSync
from ..utils import is_valid_url
@@ -12,7 +11,7 @@
from .utils import find_repository_config_file, load_repository_config
-def pytest_addoption(parser: Parser) -> None:
+def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("pytest-infrahub")
group.addoption(
"--infrahub-repo-config",
@@ -62,7 +61,7 @@ def pytest_addoption(parser: Parser) -> None:
)
-def pytest_sessionstart(session: Session) -> None:
+def pytest_sessionstart(session: pytest.Session) -> None:
if session.config.option.infrahub_repo_config:
session.infrahub_config_path = Path(session.config.option.infrahub_repo_config) # type: ignore[attr-defined]
else:
@@ -72,7 +71,7 @@ def pytest_sessionstart(session: Session) -> None:
session.infrahub_repo_config = load_repository_config(repo_config_file=session.infrahub_config_path) # type: ignore[attr-defined]
if not is_valid_url(session.config.option.infrahub_address):
- exit_test("Infrahub test instance address is not a valid URL", returncode=1)
+ pytest.exit("Infrahub test instance address is not a valid URL", returncode=1)
client_config = {
"address": session.config.option.infrahub_address,
@@ -89,13 +88,13 @@ def pytest_sessionstart(session: Session) -> None:
session.infrahub_client = infrahub_client # type: ignore[attr-defined]
-def pytest_collect_file(parent: Collector | Item, file_path: Path) -> InfrahubYamlFile | None:
+def pytest_collect_file(parent: pytest.Collector | pytest.Item, file_path: Path) -> InfrahubYamlFile | None:
if file_path.suffix in {".yml", ".yaml"} and file_path.name.startswith("test_"):
return InfrahubYamlFile.from_parent(parent, path=file_path)
return None
-def pytest_configure(config: Config) -> None:
+def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("markers", "infrahub: Infrahub test")
config.addinivalue_line("markers", "infrahub_smoke: Smoke test for an Infrahub resource")
config.addinivalue_line("markers", "infrahub_unit: Unit test for an Infrahub resource, works without dependencies")
diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py
index 3e61ad2a..557e76f3 100644
--- a/infrahub_sdk/schema/__init__.py
+++ b/infrahub_sdk/schema/__init__.py
@@ -22,6 +22,7 @@
)
from ..graphql import Mutation
from ..queries import SCHEMA_HASH_SYNC_STATUS
+from .export import RESTRICTED_NAMESPACES, NamespaceExport, SchemaExport, schema_to_export_dict
from .main import (
AttributeSchema,
AttributeSchemaAPI,
@@ -54,6 +55,7 @@
"BranchSupportType",
"GenericSchema",
"GenericSchemaAPI",
+ "NamespaceExport",
"NodeSchema",
"NodeSchemaAPI",
"ProfileSchemaAPI",
@@ -61,9 +63,11 @@
"RelationshipKind",
"RelationshipSchema",
"RelationshipSchemaAPI",
+ "SchemaExport",
"SchemaRoot",
"SchemaRootAPI",
"TemplateSchemaAPI",
+ "schema_to_export_dict",
]
@@ -118,6 +122,47 @@ def __init__(self, client: InfrahubClient | InfrahubClientSync) -> None:
self.client = client
self.cache = {}
+ @staticmethod
+ def _build_export_schemas(
+ schema_nodes: MutableMapping[str, MainSchemaTypesAPI],
+ namespaces: list[str] | None = None,
+ ) -> SchemaExport:
+ """Organize fetched schemas into a per-namespace export structure.
+
+ Filters out system types (Profile/Template) and restricted namespaces
+ (see :data:`RESTRICTED_NAMESPACES`), and optionally limits to specific
+ namespaces. If the caller requests restricted namespaces they are
+ excluded and a warning is emitted via :func:`warnings.warn`.
+
+ Returns:
+ A :class:`SchemaExport` containing user-defined schemas by namespace.
+ """
+ if namespaces:
+ restricted = set(namespaces) & set(RESTRICTED_NAMESPACES)
+ if restricted:
+ warnings.warn(
+ f"Restricted namespace(s) {sorted(restricted)} requested but will be excluded from export",
+ stacklevel=3,
+ )
+
+ ns_map: dict[str, NamespaceExport] = {}
+ for schema in schema_nodes.values():
+ if isinstance(schema, (ProfileSchemaAPI, TemplateSchemaAPI)):
+ continue
+ if schema.namespace in RESTRICTED_NAMESPACES:
+ continue
+ if namespaces and schema.namespace not in namespaces:
+ continue
+ ns = schema.namespace
+ if ns not in ns_map:
+ ns_map[ns] = NamespaceExport()
+ schema_dict = schema_to_export_dict(schema)
+ if isinstance(schema, GenericSchemaAPI):
+ ns_map[ns].generics.append(schema_dict)
+ else:
+ ns_map[ns].nodes.append(schema_dict)
+ return SchemaExport(namespaces=ns_map)
+
def validate(self, data: dict[str, Any]) -> None:
SchemaRoot(**data)
@@ -497,6 +542,32 @@ async def fetch(
return branch_schema.nodes
+ async def export(
+ self,
+ branch: str | None = None,
+ namespaces: list[str] | None = None,
+ ) -> SchemaExport:
+ """Export user-defined schemas organized by namespace.
+
+ Fetches schemas from the server, filters out system types and
+ restricted namespaces (see :data:`RESTRICTED_NAMESPACES`), and returns
+ a :class:`SchemaExport` object with per-namespace data. Restricted
+ namespaces such as ``Core`` and ``Builtin`` are always excluded even if
+ explicitly listed in *namespaces*; a warning is emitted when this
+ happens.
+
+ Args:
+ branch: Branch to export from. Defaults to default_branch.
+ namespaces: Optional list of namespaces to include. If empty/None,
+ all user-defined namespaces are exported.
+
+ Returns:
+ A :class:`SchemaExport` containing user-defined schemas by namespace.
+ """
+ branch = branch or self.client.default_branch
+ schema_nodes = await self.fetch(branch=branch, namespaces=namespaces, populate_cache=False)
+ return self._build_export_schemas(schema_nodes=schema_nodes, namespaces=namespaces)
+
async def get_graphql_schema(self, branch: str | None = None) -> str:
"""Get the GraphQL schema as a string.
@@ -739,6 +810,32 @@ def fetch(
return branch_schema.nodes
+ def export(
+ self,
+ branch: str | None = None,
+ namespaces: list[str] | None = None,
+ ) -> SchemaExport:
+ """Export user-defined schemas organized by namespace.
+
+ Fetches schemas from the server, filters out system types and
+ restricted namespaces (see :data:`RESTRICTED_NAMESPACES`), and returns
+ a :class:`SchemaExport` object with per-namespace data. Restricted
+ namespaces such as ``Core`` and ``Builtin`` are always excluded even if
+ explicitly listed in *namespaces*; a warning is emitted when this
+ happens.
+
+ Args:
+ branch: Branch to export from. Defaults to default_branch.
+ namespaces: Optional list of namespaces to include. If empty/None,
+ all user-defined namespaces are exported.
+
+ Returns:
+ A :class:`SchemaExport` containing user-defined schemas by namespace.
+ """
+ branch = branch or self.client.default_branch
+ schema_nodes = self.fetch(branch=branch, namespaces=namespaces, populate_cache=False)
+ return self._build_export_schemas(schema_nodes=schema_nodes, namespaces=namespaces)
+
def get_graphql_schema(self, branch: str | None = None) -> str:
"""Get the GraphQL schema as a string.
diff --git a/infrahub_sdk/schema/export.py b/infrahub_sdk/schema/export.py
new file mode 100644
index 00000000..d5a09c77
--- /dev/null
+++ b/infrahub_sdk/schema/export.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from .main import GenericSchemaAPI, NodeSchemaAPI
+
+
+class NamespaceExport(BaseModel):
+ """Export data for a single namespace."""
+
+ nodes: list[dict[str, Any]] = Field(default_factory=list)
+ generics: list[dict[str, Any]] = Field(default_factory=list)
+
+
+class SchemaExport(BaseModel):
+ """Result of a schema export, organized by namespace."""
+
+ namespaces: dict[str, NamespaceExport] = Field(default_factory=dict)
+
+ def to_dict(self) -> dict[str, dict[str, list[dict[str, Any]]]]:
+ """Convert to plain dict for YAML serialization."""
+ return {ns: data.model_dump(exclude_defaults=True) for ns, data in self.namespaces.items()}
+
+
+# Namespaces reserved by the Infrahub server — mirrored from
+# backend/infrahub/core/constants/__init__.py in the opsmill/infrahub repo.
+RESTRICTED_NAMESPACES: list[str] = [
+ "Account",
+ "Branch",
+ "Builtin",
+ "Core",
+ "Deprecated",
+ "Diff",
+ "Infrahub",
+ "Internal",
+ "Lineage",
+ "Schema",
+ "Profile",
+ "Template",
+]
+
+_SCHEMA_EXPORT_EXCLUDE: set[str] = {"hash", "hierarchy", "used_by", "id", "state"}
+# branch is inherited from the node and need not be repeated on each field
+_FIELD_EXPORT_EXCLUDE: set[str] = {"inherited", "allow_override", "hierarchical", "id", "state", "branch"}
+
+# Attribute field values that match schema loading defaults — omitted for cleaner output
+_ATTR_EXPORT_DEFAULTS: dict[str, Any] = {
+ "read_only": False,
+ "optional": False,
+}
+
+# Relationship field values that match schema loading defaults — omitted for cleaner output
+_REL_EXPORT_DEFAULTS: dict[str, Any] = {
+ "direction": "bidirectional",
+ "on_delete": "no-action",
+ "cardinality": "many",
+ "optional": True,
+ "min_count": 0,
+ "max_count": 0,
+ "read_only": False,
+}
+
+# Relationship kinds that Infrahub generates automatically — never user-defined
+_AUTO_GENERATED_REL_KINDS: frozenset[str] = frozenset({"Group", "Profile", "Hierarchy"})
+
+
+def schema_to_export_dict(schema: NodeSchemaAPI | GenericSchemaAPI) -> dict[str, Any]:
+ """Convert an API schema object to an export-ready dict (omits API-internal fields)."""
+ data = schema.model_dump(exclude=_SCHEMA_EXPORT_EXCLUDE, exclude_none=True)
+
+ # Pop attrs/rels so they can be re-inserted last for better readability
+ data.pop("attributes", None)
+ data.pop("relationships", None)
+
+ # Generics with Hierarchy relationships were defined with `hierarchical: true`.
+ # Restore that flag and drop the auto-generated rels so the schema round-trips cleanly.
+ if isinstance(schema, GenericSchemaAPI) and any(
+ rel.kind == "Hierarchy" for rel in schema.relationships if not rel.inherited
+ ):
+ data["hierarchical"] = True
+
+ # Strip uniqueness_constraints that are auto-generated from `unique: true` attributes
+ # (single-field entries of the form ["<attr_name>__value"]). User-defined multi-field
+ # constraints are preserved.
+ unique_attr_suffixes = {f"{attr.name}__value" for attr in schema.attributes if attr.unique}
+ user_constraints = [
+ c
+ for c in (data.pop("uniqueness_constraints", None) or [])
+ if not (len(c) == 1 and c[0] in unique_attr_suffixes)
+ ]
+ if user_constraints:
+ data["uniqueness_constraints"] = user_constraints
+
+ attributes = [
+ {
+ k: v
+ for k, v in attr.model_dump(exclude=_FIELD_EXPORT_EXCLUDE, exclude_none=True).items()
+ if k not in _ATTR_EXPORT_DEFAULTS or v != _ATTR_EXPORT_DEFAULTS[k]
+ }
+ for attr in schema.attributes
+ if not attr.inherited
+ ]
+ if attributes:
+ data["attributes"] = attributes
+
+ relationships = [
+ {
+ k: v
+ for k, v in rel.model_dump(exclude=_FIELD_EXPORT_EXCLUDE, exclude_none=True).items()
+ if k not in _REL_EXPORT_DEFAULTS or v != _REL_EXPORT_DEFAULTS[k]
+ }
+ for rel in schema.relationships
+ if not rel.inherited and rel.kind not in _AUTO_GENERATED_REL_KINDS
+ ]
+ if relationships:
+ data["relationships"] = relationships
+
+ return data
diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py
index cf7a6fc3..30d9f93d 100644
--- a/infrahub_sdk/spec/object.py
+++ b/infrahub_sdk/spec/object.py
@@ -265,7 +265,7 @@ async def validate_object(
# First validate if all mandatory fields are present
errors.extend(
- ObjectValidationError(position=position + [element], message=f"{element} is mandatory")
+ ObjectValidationError(position=[*position, element], message=f"{element} is mandatory")
for element in schema.mandatory_input_names
if not any([element in data, element in context])
)
@@ -275,7 +275,7 @@ async def validate_object(
if key not in schema.attribute_names and key not in schema.relationship_names:
errors.append(
ObjectValidationError(
- position=position + [key],
+ position=[*position, key],
message=f"{key} is not a valid attribute or relationship for {schema.kind}",
)
)
@@ -283,7 +283,7 @@ async def validate_object(
if key in schema.attribute_names and not isinstance(value, (str, int, float, bool, list, dict)):
errors.append(
ObjectValidationError(
- position=position + [key],
+ position=[*position, key],
message=f"{key} must be a string, int, float, bool, list, or dict",
)
)
@@ -295,7 +295,7 @@ async def validate_object(
if not rel_info.is_valid:
errors.append(
ObjectValidationError(
- position=position + [key],
+ position=[*position, key],
message=rel_info.reason_relationship_not_valid or "Invalid relationship",
)
)
@@ -303,7 +303,7 @@ async def validate_object(
errors.extend(
await cls.validate_related_nodes(
client=client,
- position=position + [key],
+ position=[*position, key],
rel_info=rel_info,
data=value,
context=context,
@@ -378,7 +378,7 @@ async def validate_related_nodes(
errors.extend(
await cls.validate_object(
client=client,
- position=position + [idx + 1],
+ position=[*position, idx + 1],
schema=peer_schema,
data=peer_data,
context=context,
@@ -403,7 +403,7 @@ async def validate_related_nodes(
errors.extend(
await cls.validate_object(
client=client,
- position=position + [idx + 1],
+ position=[*position, idx + 1],
schema=peer_schema,
data=item["data"],
context=context,
@@ -613,7 +613,7 @@ async def create_related_nodes(
node = await cls.create_node(
client=client,
schema=peer_schema,
- position=position + [rel_info.name, idx + 1],
+ position=[*position, rel_info.name, idx + 1],
data=peer_data,
context=context,
branch=branch,
@@ -639,7 +639,7 @@ async def create_related_nodes(
node = await cls.create_node(
client=client,
schema=peer_schema,
- position=position + [rel_info.name, idx + 1],
+ position=[*position, rel_info.name, idx + 1],
data=item["data"],
context=context,
branch=branch,
@@ -681,7 +681,7 @@ def spec(self) -> InfrahubObjectFileData:
try:
self._spec = InfrahubObjectFileData(**self.data.spec)
except Exception as exc:
- raise ValidationError(identifier=str(self.location), message=str(exc))
+ raise ValidationError(identifier=str(self.location), message=str(exc)) from exc
return self._spec
def validate_content(self) -> None:
@@ -691,7 +691,7 @@ def validate_content(self) -> None:
try:
self._spec = InfrahubObjectFileData(**self.data.spec)
except Exception as exc:
- raise ValidationError(identifier=str(self.location), message=str(exc))
+ raise ValidationError(identifier=str(self.location), message=str(exc)) from exc
async def validate_format(self, client: InfrahubClient, branch: str | None = None) -> None:
self.validate_content()
diff --git a/infrahub_sdk/template/__init__.py b/infrahub_sdk/template/__init__.py
index ff866ecd..6a7f7fe2 100644
--- a/infrahub_sdk/template/__init__.py
+++ b/infrahub_sdk/template/__init__.py
@@ -64,14 +64,11 @@ def get_template(self) -> jinja2.Template:
return self._template_definition
try:
- if self.is_string_based:
- template = self._get_string_based_template()
- else:
- template = self._get_file_based_template()
+ template = self._get_string_based_template() if self.is_string_based else self._get_file_based_template()
except jinja2.TemplateSyntaxError as exc:
self._raise_template_syntax_error(error=exc)
except jinja2.TemplateNotFound as exc:
- raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name))
+ raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name)) from exc
return template
@@ -119,19 +116,18 @@ async def render(self, variables: dict[str, Any]) -> str:
try:
output = await template.render_async(variables)
except jinja2.exceptions.TemplateNotFound as exc:
- raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
+ raise JinjaTemplateNotFoundError(
+ message=exc.message, filename=str(exc.name), base_template=template.name
+ ) from exc
except jinja2.TemplateSyntaxError as exc:
self._raise_template_syntax_error(error=exc)
except jinja2.UndefinedError as exc:
traceback = Traceback(show_locals=False)
errors = _identify_faulty_jinja_code(traceback=traceback)
- raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
+ raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) from exc
except Exception as exc:
- if error_message := getattr(exc, "message", None):
- message = error_message
- else:
- message = str(exc)
- raise JinjaTemplateError(message=message or "Unknown template error")
+ message = error_message if (error_message := getattr(exc, "message", None)) else str(exc)
+ raise JinjaTemplateError(message=message or "Unknown template error") from exc
return output
@@ -195,10 +191,7 @@ def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3
# Extract only the Jinja related exception
for frame in [frame for frame in traceback.trace.stacks[0].frames if not frame.filename.endswith(".py")]:
code = "".join(linecache.getlines(frame.filename))
- if frame.filename == "":
- lexer_name = "text"
- else:
- lexer_name = Traceback._guess_lexer(frame.filename, code)
+ lexer_name = "text" if frame.filename == "" else Traceback._guess_lexer(frame.filename, code)
syntax = Syntax(
code,
lexer_name,
diff --git a/infrahub_sdk/testing/schemas/file_object.py b/infrahub_sdk/testing/schemas/file_object.py
new file mode 100644
index 00000000..dc79b214
--- /dev/null
+++ b/infrahub_sdk/testing/schemas/file_object.py
@@ -0,0 +1,45 @@
+import pytest
+
+from infrahub_sdk import InfrahubClient, InfrahubClientSync
+from infrahub_sdk.schema.main import AttributeKind, NodeSchema, SchemaRoot
+from infrahub_sdk.schema.main import AttributeSchema as Attr
+
+NAMESPACE = "Testing"
+TESTING_FILE_CONTRACT = f"{NAMESPACE}FileContract"
+
+PDF_MAGIC_BYTES = b"%PDF-1.4 fake pdf content for testing"
+PNG_MAGIC_BYTES = b"\x89PNG\r\n\x1a\n fake png content for testing"
+TEXT_CONTENT = b"This is a simple text file content for testing purposes."
+
+
+class SchemaFileObject:
+ @pytest.fixture(scope="class")
+ def schema_file_contract(self) -> NodeSchema:
+ return NodeSchema(
+ name="FileContract",
+ namespace=NAMESPACE,
+ include_in_menu=True,
+ inherit_from=["CoreFileObject"],
+ display_label="file_name__value",
+ human_friendly_id=["contract_ref__value"],
+ order_by=["contract_ref__value"],
+ attributes=[
+ Attr(name="contract_ref", kind=AttributeKind.TEXT, unique=True),
+ Attr(name="description", kind=AttributeKind.TEXT, optional=True),
+ Attr(name="active", kind=AttributeKind.BOOLEAN, default_value=True, optional=True),
+ ],
+ )
+
+ @pytest.fixture(scope="class")
+ def schema_file_object_base(self, schema_file_contract: NodeSchema) -> SchemaRoot:
+ return SchemaRoot(version="1.0", nodes=[schema_file_contract])
+
+ @pytest.fixture(scope="class")
+ async def load_file_object_schema(self, client: InfrahubClient, schema_file_object_base: SchemaRoot) -> None:
+ await client.schema.load(schemas=[schema_file_object_base.to_schema_dict()], wait_until_converged=True)
+
+ @pytest.fixture(scope="class")
+ def load_file_object_schema_sync(
+ self, client_sync: InfrahubClientSync, schema_file_object_base: SchemaRoot
+ ) -> None:
+ client_sync.schema.load(schemas=[schema_file_object_base.to_schema_dict()], wait_until_converged=True)
diff --git a/infrahub_sdk/timestamp.py b/infrahub_sdk/timestamp.py
index 07de7b40..fd69122e 100644
--- a/infrahub_sdk/timestamp.py
+++ b/infrahub_sdk/timestamp.py
@@ -92,7 +92,11 @@ def _parse_string(cls, value: str) -> ZonedDateTime:
params["hours"] = float(match.group(1))
if params:
- return ZonedDateTime.now("UTC").subtract(**params)
+ return ZonedDateTime.now("UTC").subtract(
+ seconds=params.get("seconds", 0.0),
+ minutes=params.get("minutes", 0.0),
+ hours=params.get("hours", 0.0),
+ )
raise TimestampFormatError(f"Invalid time format for {value}")
diff --git a/infrahub_sdk/topological_sort.py b/infrahub_sdk/topological_sort.py
index a323d440..58047bd4 100644
--- a/infrahub_sdk/topological_sort.py
+++ b/infrahub_sdk/topological_sort.py
@@ -61,9 +61,9 @@ def _get_cycles(dependency_dict: dict[str, Iterable[str]], path: list[str]) -> l
cycles = []
for next_node in next_nodes:
if next_node in path:
- cycles.append(path[path.index(next_node) :] + [next_node])
+ cycles.append([*path[path.index(next_node) :], next_node])
else:
- next_cycles = _get_cycles(dependency_dict, path + [next_node])
+ next_cycles = _get_cycles(dependency_dict, [*path, next_node])
if next_cycles:
cycles += next_cycles
return cycles
diff --git a/infrahub_sdk/utils.py b/infrahub_sdk/utils.py
index 6168664b..de9bd625 100644
--- a/infrahub_sdk/utils.py
+++ b/infrahub_sdk/utils.py
@@ -145,7 +145,7 @@ def deep_merge_dict(dicta: dict, dictb: dict, path: list | None = None) -> dict:
if key in dicta:
a_val = dicta[key]
if isinstance(a_val, dict) and isinstance(b_val, dict):
- deep_merge_dict(a_val, b_val, path + [str(key)])
+ deep_merge_dict(a_val, b_val, [*path, str(key)])
elif isinstance(a_val, list) and isinstance(b_val, list):
# Merge lists
# Cannot use compare_list because list of dicts won't work (dict not hashable)
@@ -155,7 +155,7 @@ def deep_merge_dict(dicta: dict, dictb: dict, path: list | None = None) -> dict:
elif a_val == b_val or (a_val is not None and b_val is None):
continue
else:
- raise ValueError("Conflict at %s" % ".".join(path + [str(key)]))
+ raise ValueError("Conflict at %s" % ".".join([*path, str(key)]))
else:
dicta[key] = b_val
return dicta
diff --git a/pyproject.toml b/pyproject.toml
index 6750ef12..4e0716a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "infrahub-sdk"
-version = "1.18.1"
+version = "1.19.0"
description = "Python Client to interact with Infrahub"
authors = [
{name = "OpsMill", email = "info@opsmill.com"}
@@ -127,18 +127,38 @@ python-version = "3.10"
include = ["infrahub_sdk/**"]
[tool.ty.overrides.rules]
-##################################################################################################
-# The ignored rules below should be removed once the code has been updated, they are included #
-# like this so that we can reactivate them one by one. #
-##################################################################################################
-invalid-argument-type = "ignore"
-invalid-assignment = "ignore"
-invalid-await = "ignore"
-invalid-type-form = "ignore"
-no-matching-overload = "ignore"
-unresolved-attribute = "ignore"
unused-ignore-comment = "ignore" # Clashes with mypy's type ignore comments
+# File-specific overrides for remaining type violations
+# Fix these incrementally by addressing violations and removing the override
+
+
+[[tool.ty.overrides]]
+include = ["infrahub_sdk/checks.py"]
+
+[tool.ty.overrides.rules]
+invalid-await = "ignore" # 1 violation
+
+[[tool.ty.overrides]]
+include = ["infrahub_sdk/file_handler.py", "infrahub_sdk/utils.py"]
+
+[tool.ty.overrides.rules]
+unresolved-attribute = "ignore" # 5 violations total (1 in file_handler.py, 4 in utils.py)
+
+[[tool.ty.overrides]]
+include = ["infrahub_sdk/transfer/**"]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # 2 violations in importer/json.py
+invalid-assignment = "ignore" # 1 violation in importer/json.py
+
+[[tool.ty.overrides]]
+include = ["infrahub_sdk/node/node.py"]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # 9 violations - lines 776, 855, 859, 862
+
+
[[tool.ty.overrides]]
include = ["infrahub_sdk/ctl/config.py"]
@@ -150,19 +170,82 @@ unresolved-import = "ignore" # import tomli as tomllib when running on later ver
include = ["tests/**"]
[tool.ty.overrides.rules]
-##################################################################################################
-# The ignored rules below should be removed once the code has been updated, they are included #
-# like this so that we can reactivate them one by one. #
-##################################################################################################
-invalid-argument-type = "ignore"
+unused-ignore-comment = "ignore" # Clashes with mypy's type ignore comments
+
+[[tool.ty.overrides]]
+include = ["tests/fixtures/**"]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # Test fixtures - dynamic mock data
+possibly-missing-attribute = "ignore" # Test fixtures use dynamic attributes
+
+# Test-specific overrides - tests have more lenient type checking
+# Fix these incrementally, starting with files that have fewer violations
+
+[[tool.ty.overrides]]
+include = ["tests/unit/sdk/conftest.py"]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # 434 violations - test fixtures with dynamic types
+
+[[tool.ty.overrides]]
+include = [
+ "tests/unit/sdk/test_node.py",
+ "tests/unit/sdk/test_hierarchical_nodes.py",
+ "tests/unit/sdk/test_schema.py",
+]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # 97 violations total across these files
invalid-assignment = "ignore"
-invalid-method-override = "ignore"
-no-matching-overload = "ignore"
not-subscriptable = "ignore"
not-iterable = "ignore"
-possibly-missing-attribute = "ignore"
unresolved-attribute = "ignore"
-unused-ignore-comment = "ignore" # Clashes with mypy's type ignore comments
+possibly-missing-attribute = "ignore"
+
+[[tool.ty.overrides]]
+include = [
+ "tests/integration/**",
+ "tests/unit/sdk/test_store_branch.py",
+ "tests/unit/sdk/test_repository.py",
+]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # ~120 violations across integration tests
+invalid-assignment = "ignore"
+no-matching-overload = "ignore"
+possibly-missing-attribute = "ignore" # Tests use dynamic node attributes
+
+[[tool.ty.overrides]]
+include = [
+ "tests/unit/sdk/spec/test_object.py",
+ "tests/unit/sdk/test_client.py",
+ "tests/unit/ctl/test_graphql_utils.py",
+]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # 29 violations
+invalid-assignment = "ignore"
+no-matching-overload = "ignore"
+possibly-missing-attribute = "ignore"
+
+[[tool.ty.overrides]]
+include = [
+ "tests/unit/sdk/test_artifact.py",
+ "tests/unit/sdk/test_group_context.py",
+ "tests/unit/sdk/test_utils.py",
+ "tests/unit/sdk/checks/test_checks.py",
+ "tests/unit/sdk/graphql/test_plugin.py",
+ "tests/unit/sdk/test_protocols_generator.py",
+ "tests/unit/sdk/test_schema_sorter.py",
+ "tests/unit/sdk/test_topological_sort.py",
+ "tests/unit/sdk/test_schema_export.py",
+]
+
+[tool.ty.overrides.rules]
+invalid-argument-type = "ignore" # Remaining files with 1-5 violations each
+invalid-method-override = "ignore"
+no-matching-overload = "ignore"
[[tool.ty.overrides]]
include = ["docs/**"]
@@ -238,18 +321,15 @@ ignore = [
# investigation if they are deemed to not make sense. #
##################################################################################################
"B008", # Do not perform function call `typer.Option` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
- "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling
"N802", # Function name should be lowercase
"PERF203", # `try`-`except` within a loop incurs performance overhead
"PLR0913", # Too many arguments in function definition
"PLR0917", # Too many positional arguments
"PLR2004", # Magic value used in comparison
"PLR6301", # Method could be a function, class method, or static method
- "RUF005", # Consider `[*path, str(key)]` instead of concatenation
"RUF029", # Function is declared `async`, but doesn't `await` or use `async` features.
"RUF067", # `__init__` module should only contain docstrings and re-exports
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
- "SIM108", # Use ternary operator `key_str = f"{value[ALIAS_KEY]}: {key}" if ALIAS_KEY in value and value[ALIAS_KEY] else key` instead of `if`-`else`-block
"TC003", # Move standard library import `collections.abc.Iterable` into a type-checking block
"UP031", # Use format specifiers instead of percent format
]
@@ -292,7 +372,6 @@ max-complexity = 17
"ANN202", # Missing return type annotation for private function
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ASYNC240", # Async functions should not use pathlib.Path methods, use trio.Path or anyio.path
- "PT013", # Incorrect import of `pytest`; use `import pytest` instead
]
"infrahub_sdk/client.py" = [
@@ -307,6 +386,7 @@ max-complexity = 17
# Review and change the below later #
##################################################################################################
"PLR0912", # Too many branches
+ "PLR0904", # Too many public methods
]
"infrahub_sdk/node/related_node.py" = [
@@ -348,12 +428,11 @@ max-complexity = 17
"S106", # Possible hardcoded password assigned to argument
"ARG001", # Unused function argument
"ARG002", # Unused method argument
- "PT006", # Wrong type passed to first argument of `pytest.mark.parametrize`; expected `tuple`
"PT011", # `pytest.raises(ValueError)` is too broad, set the `match` parameter or use a more specific exception
"PT012", # `pytest.raises()` block should contain a single simple statement
- "PT013", # Incorrect import of `pytest`; use `import pytest` instead
]
+
# tests/integration/
"tests/integration/test_infrahub_client.py" = ["PLR0904"]
"tests/integration/test_infrahub_client_sync.py" = ["PLR0904"]
diff --git a/tests/integration/test_export_import.py b/tests/integration/test_export_import.py
index a138728f..d317d5e7 100644
--- a/tests/integration/test_export_import.py
+++ b/tests/integration/test_export_import.py
@@ -15,8 +15,6 @@
from infrahub_sdk.transfer.schema_sorter import InfrahubSchemaTopologicalSorter
if TYPE_CHECKING:
- from pytest import TempPathFactory
-
from infrahub_sdk import InfrahubClient
from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.schema import SchemaRoot
@@ -24,7 +22,7 @@
class TestSchemaExportImportBase(TestInfrahubDockerClient, SchemaCarPerson):
@pytest.fixture(scope="class")
- def temporary_directory(self, tmp_path_factory: TempPathFactory) -> Path:
+ def temporary_directory(self, tmp_path_factory: pytest.TempPathFactory) -> Path:
return tmp_path_factory.mktemp("infrahub-integration-tests")
@pytest.fixture(scope="class")
@@ -189,7 +187,7 @@ async def test_step99_import_wrong_directory(self, client: InfrahubClient) -> No
class TestSchemaExportImportManyRelationships(TestInfrahubDockerClient, SchemaCarPerson):
@pytest.fixture(scope="class")
- def temporary_directory(self, tmp_path_factory: TempPathFactory) -> Path:
+ def temporary_directory(self, tmp_path_factory: pytest.TempPathFactory) -> Path:
return tmp_path_factory.mktemp("infrahub-integration-tests-many")
@pytest.fixture(scope="class")
diff --git a/tests/integration/test_file_object.py b/tests/integration/test_file_object.py
new file mode 100644
index 00000000..49dd7421
--- /dev/null
+++ b/tests/integration/test_file_object.py
@@ -0,0 +1,260 @@
+from __future__ import annotations
+
+import hashlib
+import tempfile
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import pytest
+
+from infrahub_sdk.testing.docker import TestInfrahubDockerClient
+from infrahub_sdk.testing.schemas.file_object import (
+ PDF_MAGIC_BYTES,
+ PNG_MAGIC_BYTES,
+ TESTING_FILE_CONTRACT,
+ TEXT_CONTENT,
+ SchemaFileObject,
+)
+
+if TYPE_CHECKING:
+ from infrahub_sdk import InfrahubClient, InfrahubClientSync
+
+
+@pytest.mark.xfail(reason="Requires Infrahub 1.8+")
+class TestFileObjectAsync(TestInfrahubDockerClient, SchemaFileObject):
+ """Async integration tests for FileObject functionality."""
+
+ async def test_create_file_object_with_upload(self, client: InfrahubClient, load_file_object_schema: None) -> None:
+ """Test creating FileObject nodes with both upload_from_bytes and upload_from_path."""
+ contract_bytes = await client.create(
+ kind=TESTING_FILE_CONTRACT,
+ contract_ref="CONTRACT-CREATE-BYTES-001",
+ description="Test contract with bytes upload",
+ )
+ contract_bytes.upload_from_bytes(content=PDF_MAGIC_BYTES, name="contract.pdf")
+ await contract_bytes.save()
+
+ fetched = await client.get(kind=TESTING_FILE_CONTRACT, id=contract_bytes.id)
+ assert fetched.contract_ref.value == "CONTRACT-CREATE-BYTES-001"
+ assert fetched.file_name.value == "contract.pdf"
+ assert fetched.file_size.value == len(PDF_MAGIC_BYTES)
+ assert fetched.checksum.value == hashlib.sha1(PDF_MAGIC_BYTES, usedforsecurity=False).hexdigest()
+ assert fetched.storage_id.value
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmp_path = Path(tmpdir) / "upload_test.txt"
+ tmp_path.write_bytes(TEXT_CONTENT)
+
+ contract_path = await client.create(
+ kind=TESTING_FILE_CONTRACT,
+ contract_ref="CONTRACT-CREATE-PATH-001",
+ description="Test contract from path",
+ )
+ contract_path.upload_from_path(path=tmp_path)
+ await contract_path.save()
+
+ fetched = await client.get(kind=TESTING_FILE_CONTRACT, id=contract_path.id)
+ assert fetched.file_name.value == tmp_path.name
+ assert fetched.file_size.value == len(TEXT_CONTENT)
+ assert fetched.checksum.value == hashlib.sha1(TEXT_CONTENT, usedforsecurity=False).hexdigest()
+ assert fetched.storage_id.value
+
+ async def test_update_file_object_with_new_file(
+ self, client: InfrahubClient, load_file_object_schema: None
+ ) -> None:
+ """Test updating a FileObject node with a new file."""
+ contract = await client.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-UPDATE-001", description="Initial contract"
+ )
+ contract.upload_from_bytes(content=PDF_MAGIC_BYTES, name="initial.pdf")
+ await contract.save()
+
+ contract_to_update = await client.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ contract_to_update.description.value = "Updated contract"
+ contract_to_update.upload_from_bytes(content=PNG_MAGIC_BYTES, name="updated.png")
+ await contract_to_update.save()
+
+ updated = await client.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ assert updated.description.value == "Updated contract"
+ assert updated.file_name.value == "updated.png"
+ assert updated.storage_id.value != contract.storage_id.value
+ assert updated.checksum.value != contract.checksum.value
+
+ async def test_upsert_file_object_update(self, client: InfrahubClient, load_file_object_schema: None) -> None:
+ """Test upserting an existing FileObject node updates it rather than creating a duplicate."""
+ contract = await client.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-UPSERT-001", description="Original"
+ )
+ contract.upload_from_bytes(content=PDF_MAGIC_BYTES, name="original.pdf")
+ await contract.save()
+
+ contract_upsert = await client.create(
+ kind=TESTING_FILE_CONTRACT,
+ contract_ref="CONTRACT-UPSERT-001",
+ description="Upserted update",
+ )
+ contract_upsert.upload_from_bytes(content=PNG_MAGIC_BYTES, name="upserted.png")
+ await contract_upsert.save(allow_upsert=True)
+ assert contract_upsert.id == contract.id
+
+ updated = await client.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ assert updated.description.value == "Upserted update"
+ assert updated.file_name.value == "upserted.png"
+ assert updated.storage_id.value != contract.storage_id.value
+
+ async def test_download_file(self, client: InfrahubClient, load_file_object_schema: None) -> None:
+ """Test downloading files to memory and to disk."""
+ contract = await client.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-DOWNLOAD-001", description="Download test"
+ )
+ contract.upload_from_bytes(content=TEXT_CONTENT, name="download_test.txt")
+ await contract.save()
+
+ fetched = await client.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ downloaded_content = await fetched.download_file()
+ assert downloaded_content == TEXT_CONTENT
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dest_path = Path(tmpdir) / "downloaded.txt"
+ bytes_written = await fetched.download_file(dest=dest_path)
+ assert bytes_written == len(TEXT_CONTENT)
+ assert dest_path.read_bytes() == TEXT_CONTENT
+
+ async def test_update_without_file_change(self, client: InfrahubClient, load_file_object_schema: None) -> None:
+ """Test updating FileObject attributes without replacing the file."""
+ contract = await client.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-META-001", description="Original description"
+ )
+ contract.upload_from_bytes(content=TEXT_CONTENT, name="unchanged.txt")
+ await contract.save()
+
+ contract_to_update = await client.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ contract_to_update.description.value = "Updated description"
+ await contract_to_update.save()
+
+ updated = await client.get(kind=TESTING_FILE_CONTRACT, id=contract.id)
+ assert updated.description.value == "Updated description"
+ assert updated.storage_id.value == contract_to_update.storage_id.value
+ assert updated.checksum.value == contract_to_update.checksum.value
+
+
+@pytest.mark.xfail(reason="Requires Infrahub 1.8+")
+class TestFileObjectSync(TestInfrahubDockerClient, SchemaFileObject):
+ """Sync integration tests for FileObject functionality."""
+
+ def test_create_file_object_with_upload_sync(
+ self, client_sync: InfrahubClientSync, load_file_object_schema_sync: None
+ ) -> None:
+ """Test creating FileObject nodes with both upload_from_bytes and upload_from_path (sync)."""
+ contract_bytes = client_sync.create(
+ kind=TESTING_FILE_CONTRACT,
+ contract_ref="CONTRACT-CREATE-BYTES-SYNC-001",
+ description="Test contract with bytes upload (sync)",
+ )
+ contract_bytes.upload_from_bytes(content=PDF_MAGIC_BYTES, name="contract_sync.pdf")
+ contract_bytes.save()
+
+ fetched = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract_bytes.id)
+ assert fetched.contract_ref.value == "CONTRACT-CREATE-BYTES-SYNC-001"
+ assert fetched.file_name.value == "contract_sync.pdf"
+ assert fetched.file_size.value == len(PDF_MAGIC_BYTES)
+ assert fetched.checksum.value == hashlib.sha1(PDF_MAGIC_BYTES, usedforsecurity=False).hexdigest()
+ assert fetched.storage_id.value
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmp_path = Path(tmpdir) / "upload_test_sync.txt"
+ tmp_path.write_bytes(TEXT_CONTENT)
+
+ contract_path = client_sync.create(
+ kind=TESTING_FILE_CONTRACT,
+ contract_ref="CONTRACT-CREATE-PATH-SYNC-001",
+ description="Test contract from path (sync)",
+ )
+ contract_path.upload_from_path(path=tmp_path)
+ contract_path.save()
+
+ fetched = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract_path.id)
+ assert fetched.file_name.value == tmp_path.name
+ assert fetched.file_size.value == len(TEXT_CONTENT)
+ assert fetched.checksum.value == hashlib.sha1(TEXT_CONTENT, usedforsecurity=False).hexdigest()
+ assert fetched.storage_id.value
+
+ def test_update_file_object_with_new_file_sync(
+ self, client_sync: InfrahubClientSync, load_file_object_schema_sync: None
+ ) -> None:
+ """Test updating a FileObject node with a new file (sync)."""
+ contract = client_sync.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-UPDATE-SYNC-001", description="Initial contract sync"
+ )
+ contract.upload_from_bytes(content=PDF_MAGIC_BYTES, name="initial_sync.pdf")
+ contract.save()
+
+ contract_to_update = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ contract_to_update.description.value = "Updated contract sync"
+ contract_to_update.upload_from_bytes(content=PNG_MAGIC_BYTES, name="updated_sync.png")
+ contract_to_update.save()
+
+ updated = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ assert updated.description.value == "Updated contract sync"
+ assert updated.file_name.value == "updated_sync.png"
+ assert updated.storage_id.value != contract.storage_id.value
+ assert updated.checksum.value != contract.checksum.value
+
+ def test_upsert_file_object_update_sync(
+ self, client_sync: InfrahubClientSync, load_file_object_schema_sync: None
+ ) -> None:
+ """Test upserting an existing FileObject node updates it rather than creating a duplicate (sync)."""
+ contract = client_sync.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-UPSERT-SYNC-001", description="Original sync"
+ )
+ contract.upload_from_bytes(content=PDF_MAGIC_BYTES, name="original_sync.pdf")
+ contract.save()
+
+ contract_upsert = client_sync.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-UPSERT-SYNC-001", description="Upserted update sync"
+ )
+ contract_upsert.upload_from_bytes(content=PNG_MAGIC_BYTES, name="upserted_sync.png")
+ contract_upsert.save(allow_upsert=True)
+ assert contract_upsert.id == contract.id
+
+ updated = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ assert updated.description.value == "Upserted update sync"
+ assert updated.file_name.value == "upserted_sync.png"
+ assert updated.storage_id.value != contract.storage_id.value
+
+ def test_download_file_sync(self, client_sync: InfrahubClientSync, load_file_object_schema_sync: None) -> None:
+ """Test downloading files to memory and to disk (sync)."""
+ contract = client_sync.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-DOWNLOAD-SYNC-001", description="Download test sync"
+ )
+ contract.upload_from_bytes(content=TEXT_CONTENT, name="download_sync.txt")
+ contract.save()
+
+ fetched = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ downloaded_content = fetched.download_file()
+ assert downloaded_content == TEXT_CONTENT
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dest_path = Path(tmpdir) / "downloaded_sync.txt"
+ bytes_written = fetched.download_file(dest=dest_path)
+ assert bytes_written == len(TEXT_CONTENT)
+ assert dest_path.read_bytes() == TEXT_CONTENT
+
+ def test_update_without_file_change_sync(
+ self, client_sync: InfrahubClientSync, load_file_object_schema_sync: None
+ ) -> None:
+ """Test updating FileObject attributes without replacing the file (sync)."""
+ contract = client_sync.create(
+ kind=TESTING_FILE_CONTRACT, contract_ref="CONTRACT-META-SYNC-001", description="Original description sync"
+ )
+ contract.upload_from_bytes(content=TEXT_CONTENT, name="unchanged_sync.txt")
+ contract.save()
+
+ contract_to_update = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract.id, populate_store=False)
+ contract_to_update.description.value = "Updated description sync"
+ contract_to_update.save()
+
+ updated = client_sync.get(kind=TESTING_FILE_CONTRACT, id=contract.id)
+ assert updated.description.value == "Updated description sync"
+ assert updated.storage_id.value == contract_to_update.storage_id.value
+ assert updated.checksum.value == contract_to_update.checksum.value
diff --git a/tests/unit/ctl/test_branch_report.py b/tests/unit/ctl/test_branch_report.py
index c9af76c5..201f1c92 100644
--- a/tests/unit/ctl/test_branch_report.py
+++ b/tests/unit/ctl/test_branch_report.py
@@ -341,19 +341,21 @@ def mock_branch_report_with_proposed_changes(httpx_mock: HTTPXMock) -> HTTPXMock
],
},
"rejected_by": {"count": 0, "edges": []},
+ },
+ "node_metadata": {
+ "created_at": "2025-11-10T14:30:00Z",
"created_by": {
- "node": {
- "id": "187895d8-723e-8f5d-3614-c517ac8e761c",
- "hfid": ["johndoe"],
- "display_label": "John Doe",
- "__typename": "CoreAccount",
- "name": {"value": "John Doe"},
- },
- "properties": {
- "updated_at": "2025-11-10T14:30:00Z",
- },
+ "id": "187895d8-723e-8f5d-3614-c517ac8e761c",
+ "__typename": "CoreAccount",
+ "display_label": "John Doe",
},
- }
+ "updated_at": "2025-11-10T14:30:00Z",
+ "updated_by": {
+ "id": "187895d8-723e-8f5d-3614-c517ac8e761c",
+ "__typename": "CoreAccount",
+ "display_label": "John Doe",
+ },
+ },
},
{
"node": {
@@ -392,19 +394,21 @@ def mock_branch_report_with_proposed_changes(httpx_mock: HTTPXMock) -> HTTPXMock
},
],
},
+ },
+ "node_metadata": {
+ "created_at": "2025-11-12T09:15:00Z",
"created_by": {
- "node": {
- "id": "287895d8-723e-8f5d-3614-c517ac8e762c",
- "hfid": ["janesmith"],
- "display_label": "Jane Smith",
- "__typename": "CoreAccount",
- "name": {"value": "Jane Smith"},
- },
- "properties": {
- "updated_at": "2025-11-10T14:30:00Z",
- },
+ "id": "287895d8-723e-8f5d-3614-c517ac8e762c",
+ "__typename": "CoreAccount",
+ "display_label": "Jane Smith",
+ },
+ "updated_at": "2025-11-12T09:15:00Z",
+ "updated_by": {
+ "id": "287895d8-723e-8f5d-3614-c517ac8e762c",
+ "__typename": "CoreAccount",
+ "display_label": "Jane Smith",
},
- }
+ },
},
],
}
diff --git a/tests/unit/ctl/test_render_app.py b/tests/unit/ctl/test_render_app.py
index 88fd32c9..f159dfa0 100644
--- a/tests/unit/ctl/test_render_app.py
+++ b/tests/unit/ctl/test_render_app.py
@@ -71,7 +71,7 @@ def test_validate_template_not_found(test_case: RenderAppFailure, httpx_mock: HT
@pytest.mark.parametrize(
- "cli_branch,env_branch,from_git,expected_branch",
+ ("cli_branch", "env_branch", "from_git", "expected_branch"),
[
("cli-branch", None, False, "cli-branch"),
(None, "env-branch", False, "env-branch"),
diff --git a/tests/unit/doc_generation/content_gen_methods/mdx/test_mdx_code_doc.py b/tests/unit/doc_generation/content_gen_methods/mdx/test_mdx_code_doc.py
index 5c505ed1..d08d7e1a 100644
--- a/tests/unit/doc_generation/content_gen_methods/mdx/test_mdx_code_doc.py
+++ b/tests/unit/doc_generation/content_gen_methods/mdx/test_mdx_code_doc.py
@@ -1,9 +1,9 @@
from __future__ import annotations
from pathlib import Path
+from unittest.mock import create_autospec
-from invoke import Result
-from invoke.context import MockContext
+from invoke import Context, Result
from docs.docs_generation.content_gen_methods import (
MdxCodeDocumentation,
@@ -13,8 +13,8 @@
def _make_mock_context(
module_files: dict[str, dict[str, str]],
calls: list[str] | None = None,
-) -> MockContext:
- """Build a ``MockContext`` whose ``run()`` writes files based on requested modules.
+) -> Context:
+ """Build a mock ``Context`` whose ``run()`` writes files based on requested modules.
Args:
module_files: Mapping of module name to its output files
@@ -24,7 +24,7 @@ def _make_mock_context(
calls: If provided, each executed command string is appended to this
list so the caller can verify how many times ``run()`` was invoked.
"""
- ctx = MockContext(run=Result())
+ ctx = create_autospec(Context, instance=True)
def fake_run(cmd: str, **kwargs: object) -> Result:
if calls is not None:
diff --git a/tests/unit/doc_generation/content_gen_methods/test_command_output_method.py b/tests/unit/doc_generation/content_gen_methods/test_command_output_method.py
index 8a3d56c8..9b1befcf 100644
--- a/tests/unit/doc_generation/content_gen_methods/test_command_output_method.py
+++ b/tests/unit/doc_generation/content_gen_methods/test_command_output_method.py
@@ -1,9 +1,9 @@
from __future__ import annotations
from pathlib import Path
+from unittest.mock import create_autospec
-from invoke import Result
-from invoke.context import MockContext
+from invoke import Context, Result
from docs.docs_generation import ACommand, CommandOutputDocContentGenMethod
@@ -29,7 +29,7 @@ def fake_run(cmd: str, **kwargs: object) -> Result:
output_path.write_text(output_content, encoding="utf-8")
return Result()
- mock_context = MockContext(run=Result())
+ mock_context = create_autospec(Context, instance=True)
mock_context.run.side_effect = fake_run
method = CommandOutputDocContentGenMethod(
@@ -55,7 +55,7 @@ def fake_run(cmd: str, **kwargs: object) -> Result:
Path(parts[1].strip()).write_text("", encoding="utf-8")
return Result()
- mock_context = MockContext(run=Result())
+ mock_context = create_autospec(Context, instance=True)
mock_context.run.side_effect = fake_run
method = CommandOutputDocContentGenMethod(
diff --git a/tests/unit/doc_generation/mdx/test_mdx_ordered_section.py b/tests/unit/doc_generation/mdx/test_mdx_ordered_section.py
index 1bc25948..f279ca8e 100644
--- a/tests/unit/doc_generation/mdx/test_mdx_ordered_section.py
+++ b/tests/unit/doc_generation/mdx/test_mdx_ordered_section.py
@@ -87,7 +87,7 @@ def _make_ordered(
child_heading_level: int = 3,
) -> OrderedMdxSection:
heading = "#" * heading_level + f" `{name}`"
- section = MdxSection(name=name, heading_level=heading_level, _lines=[heading] + children_lines)
+ section = MdxSection(name=name, heading_level=heading_level, _lines=[heading, *children_lines])
return OrderedMdxSection(
section=section,
priority=priority,
diff --git a/tests/unit/doc_generation/test_docs_validate.py b/tests/unit/doc_generation/test_docs_validate.py
index 52e259d3..b65f5bbf 100644
--- a/tests/unit/doc_generation/test_docs_validate.py
+++ b/tests/unit/doc_generation/test_docs_validate.py
@@ -3,16 +3,12 @@
import os
import subprocess # noqa: S404
from pathlib import Path
-from typing import TYPE_CHECKING
import pytest
from invoke import Context, Exit
import tasks
-if TYPE_CHECKING:
- from pytest import MonkeyPatch
-
_GIT_ENV = {
"GIT_AUTHOR_NAME": "test",
"GIT_AUTHOR_EMAIL": "test@test.com",
@@ -45,7 +41,7 @@ class TestDocsValidate:
"""Ensure docs_validate() detects drift between committed and regenerated documentation."""
def test_passes_when_generation_produces_no_changes(
- self, git_repo_with_docs: Path, monkeypatch: MonkeyPatch
+ self, git_repo_with_docs: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
# Arrange
monkeypatch.setattr(tasks, "docs_generate", lambda context: None) # noqa: ARG005
@@ -55,7 +51,7 @@ def test_passes_when_generation_produces_no_changes(
tasks.docs_validate(Context())
def test_fails_when_generation_modifies_existing_file(
- self, git_repo_with_docs: Path, monkeypatch: MonkeyPatch
+ self, git_repo_with_docs: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
# Arrange
def fake_generate(context: Context) -> None:
@@ -69,7 +65,7 @@ def fake_generate(context: Context) -> None:
tasks.docs_validate(Context())
def test_fails_when_generation_deletes_tracked_file(
- self, git_repo_with_docs: Path, monkeypatch: MonkeyPatch
+ self, git_repo_with_docs: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
# Arrange
def fake_generate(context: Context) -> None:
@@ -82,7 +78,9 @@ def fake_generate(context: Context) -> None:
with pytest.raises(Exit, match="Modified or deleted files"):
tasks.docs_validate(Context())
- def test_fails_when_generation_creates_new_file(self, git_repo_with_docs: Path, monkeypatch: MonkeyPatch) -> None:
+ def test_fails_when_generation_creates_new_file(
+ self, git_repo_with_docs: Path, monkeypatch: pytest.MonkeyPatch
+ ) -> None:
# Arrange
def fake_generate(context: Context) -> None:
(git_repo_with_docs / "docs" / "new_file.mdx").write_text("# New\n")
diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py
index d28c23b9..ad725532 100644
--- a/tests/unit/sdk/conftest.py
+++ b/tests/unit/sdk/conftest.py
@@ -2538,6 +2538,52 @@ async def nested_device_with_interfaces_schema() -> NodeSchemaAPI:
return NodeSchema(**data).convert_api()
+@pytest.fixture
+async def file_object_schema() -> NodeSchemaAPI:
+ """Schema for a node that inherits from CoreFileObject."""
+ data = {
+ "name": "CircuitContract",
+ "namespace": "Network",
+ "label": "Circuit Contract",
+ "default_filter": "file_name__value",
+ "inherit_from": ["CoreFileObject"],
+ "order_by": ["file_name__value"],
+ "display_labels": ["file_name__value"],
+ "attributes": [
+ # Simulate inherited attributes from CoreFileObject
+ {"name": "file_name", "kind": "Text", "read_only": True, "optional": False},
+ {"name": "checksum", "kind": "Text", "read_only": True, "optional": False},
+ {"name": "file_size", "kind": "Number", "read_only": True, "optional": False},
+ {"name": "file_type", "kind": "Text", "read_only": True, "optional": False},
+ {"name": "storage_id", "kind": "Text", "read_only": True, "optional": False},
+ {"name": "contract_start", "kind": "DateTime", "optional": False},
+ {"name": "contract_end", "kind": "DateTime", "optional": False},
+ ],
+ "relationships": [],
+ }
+ return NodeSchema(**data).convert_api()
+
+
+@pytest.fixture
+async def non_file_object_schema() -> NodeSchemaAPI:
+ """Schema for a regular node that does not inherit from CoreFileObject."""
+ data = {
+ "name": "Device",
+ "namespace": "Infra",
+ "label": "Device",
+ "default_filter": "name__value",
+ "inherit_from": [],
+ "order_by": ["name__value"],
+ "display_labels": ["name__value"],
+ "attributes": [
+ {"name": "name", "kind": "Text", "unique": True},
+ {"name": "description", "kind": "Text", "optional": True},
+ ],
+ "relationships": [],
+ }
+ return NodeSchema(**data).convert_api()
+
+
@pytest.fixture
async def vlan_schema() -> NodeSchemaAPI:
data = {
diff --git a/tests/unit/sdk/graphql/test_multipart.py b/tests/unit/sdk/graphql/test_multipart.py
new file mode 100644
index 00000000..35a0e58a
--- /dev/null
+++ b/tests/unit/sdk/graphql/test_multipart.py
@@ -0,0 +1,177 @@
+"""Unit tests for MultipartBuilder class."""
+
+from __future__ import annotations
+
+from io import BytesIO
+
+import ujson
+
+from infrahub_sdk.graphql import MultipartBuilder
+
+
+def test_build_operations_simple() -> None:
+ """Test building operations with simple query and variables."""
+ query = "mutation($file: Upload!) { upload(file: $file) { id } }"
+ variables = {"other": "value"}
+
+ result = MultipartBuilder.build_operations(query=query, variables=variables)
+
+ parsed = ujson.loads(result)
+ assert parsed["query"] == query
+ assert parsed["variables"] == variables
+
+
+def test_build_operations_empty_variables() -> None:
+ """Test building operations with empty variables."""
+ query = "mutation { doSomething { id } }"
+ variables: dict[str, str] = {}
+
+ result = MultipartBuilder.build_operations(query=query, variables=variables)
+
+ parsed = ujson.loads(result)
+ assert parsed["query"] == query
+ assert parsed["variables"] == {}
+
+
+def test_build_operations_complex_variables() -> None:
+ """Test building operations with nested variables."""
+ query = "mutation($input: CreateInput!) { create(input: $input) { id } }"
+ variables = {"input": {"name": "test", "nested": {"value": 123}, "list": [1, 2, 3]}}
+
+ result = MultipartBuilder.build_operations(query=query, variables=variables)
+
+ parsed = ujson.loads(result)
+ assert parsed["variables"]["input"]["name"] == "test"
+ assert parsed["variables"]["input"]["nested"]["value"] == 123
+ assert parsed["variables"]["input"]["list"] == [1, 2, 3]
+
+
+def test_build_file_map_defaults() -> None:
+ """Test building file map with default values."""
+ result = MultipartBuilder.build_file_map()
+
+ parsed = ujson.loads(result)
+ assert parsed == {"0": ["variables.file"]}
+
+
+def test_build_file_map_custom_key() -> None:
+ """Test building file map with custom file key."""
+ result = MultipartBuilder.build_file_map(file_key="1")
+
+ parsed = ujson.loads(result)
+ assert parsed == {"1": ["variables.file"]}
+
+
+def test_build_file_map_custom_path() -> None:
+ """Test building file map with custom variable path."""
+ result = MultipartBuilder.build_file_map(variable_path="variables.input.document")
+
+ parsed = ujson.loads(result)
+ assert parsed == {"0": ["variables.input.document"]}
+
+
+def test_build_file_map_both_custom() -> None:
+ """Test building file map with both custom values."""
+ result = MultipartBuilder.build_file_map(file_key="attachment", variable_path="variables.attachment")
+
+ parsed = ujson.loads(result)
+ assert parsed == {"attachment": ["variables.attachment"]}
+
+
+def test_build_payload_with_file() -> None:
+ """Test building complete payload with file content."""
+ query = "mutation($file: Upload!) { upload(file: $file) { id } }"
+ variables = {"other": "value"}
+ file_content = BytesIO(b"test file content")
+ file_name = "document.pdf"
+
+ result = MultipartBuilder.build_payload(
+ query=query, variables=variables, file_content=file_content, file_name=file_name
+ )
+
+ # Check operations
+ assert "operations" in result
+ assert result["operations"][0] is None # No filename for operations
+ operations_json = ujson.loads(result["operations"][1])
+ assert operations_json["query"] == query
+ assert operations_json["variables"]["other"] == "value"
+ assert operations_json["variables"]["file"] is None # File var should be null
+
+ # Check map
+ assert "map" in result
+ assert result["map"][0] is None
+ map_json = ujson.loads(result["map"][1])
+ assert map_json == {"0": ["variables.file"]}
+
+ # Check file
+ assert "0" in result
+ assert result["0"][0] == file_name
+ assert result["0"][1] is file_content
+
+
+def test_build_payload_without_file() -> None:
+ """Test building payload without file content."""
+ query = "mutation($file: Upload!) { upload(file: $file) { id } }"
+ variables = {"other": "value"}
+
+ result = MultipartBuilder.build_payload(query=query, variables=variables, file_content=None, file_name="unused.txt")
+
+ # Should have operations and map
+ assert "operations" in result
+ assert "map" in result
+
+ # Should NOT have file key
+ assert "0" not in result
+
+
+def test_build_payload_sets_file_var_to_null() -> None:
+ """Test that build_payload sets file variable to null per spec."""
+ query = "mutation($file: Upload!) { upload(file: $file) { id } }"
+ variables = {"file": "should_be_overwritten", "other": "value"}
+ file_content = BytesIO(b"content")
+
+ result = MultipartBuilder.build_payload(
+ query=query, variables=variables, file_content=file_content, file_name="test.txt"
+ )
+
+ operations_json = ujson.loads(result["operations"][1])
+ assert operations_json["variables"]["file"] is None
+ assert operations_json["variables"]["other"] == "value"
+
+
+def test_build_payload_default_filename() -> None:
+ """Test that default filename is used when not specified."""
+ query = "mutation($file: Upload!) { upload(file: $file) { id } }"
+ file_content = BytesIO(b"content")
+
+ result = MultipartBuilder.build_payload(
+ query=query,
+ variables={},
+ file_content=file_content,
+ )
+
+ assert result["0"][0] == "upload"
+
+
+def test_build_payload_preserves_existing_variables() -> None:
+ """Test that existing variables are preserved in the payload."""
+ query = "mutation($file: Upload!, $nodeId: ID!) { upload(file: $file, node: $nodeId) { id } }"
+ variables = {
+ "nodeId": "node-123",
+ "description": "A test file",
+ "nested": {"key": "value"},
+ }
+ file_content = BytesIO(b"content")
+
+ result = MultipartBuilder.build_payload(
+ query=query,
+ variables=variables,
+ file_content=file_content,
+ file_name="test.txt",
+ )
+
+ operations_json = ujson.loads(result["operations"][1])
+ assert operations_json["variables"]["nodeId"] == "node-123"
+ assert operations_json["variables"]["description"] == "A test file"
+ assert operations_json["variables"]["nested"] == {"key": "value"}
+ assert operations_json["variables"]["file"] is None # file is null per spec
diff --git a/tests/unit/sdk/pool/conftest.py b/tests/unit/sdk/pool/conftest.py
index e8276be6..9d0fd245 100644
--- a/tests/unit/sdk/pool/conftest.py
+++ b/tests/unit/sdk/pool/conftest.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from typing import Any
+
import pytest
from infrahub_sdk.schema import BranchSupportType, NodeSchema, NodeSchemaAPI
@@ -7,7 +9,7 @@
@pytest.fixture
async def ipaddress_pool_schema() -> NodeSchemaAPI:
- data = {
+ data: dict[str, Any] = {
"name": "IPAddressPool",
"namespace": "Core",
"description": "A pool of IP address resources",
@@ -57,7 +59,7 @@ async def ipaddress_pool_schema() -> NodeSchemaAPI:
@pytest.fixture
async def ipprefix_pool_schema() -> NodeSchemaAPI:
- data = {
+ data: dict[str, Any] = {
"name": "IPPrefixPool",
"namespace": "Core",
"description": "A pool of IP prefix resources",
diff --git a/tests/unit/sdk/pool/test_allocate.py b/tests/unit/sdk/pool/test_allocate.py
index 1ed53500..eacc1a7b 100644
--- a/tests/unit/sdk/pool/test_allocate.py
+++ b/tests/unit/sdk/pool/test_allocate.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
import pytest
@@ -11,6 +11,7 @@
from pytest_httpx import HTTPXMock
+ from infrahub_sdk.protocols_base import CoreNode, CoreNodeSync
from infrahub_sdk.schema import NodeSchemaAPI
from tests.unit.sdk.conftest import BothClients
@@ -83,7 +84,7 @@ async def test_allocate_next_ip_address(
},
)
ip_address = await clients.standard.allocate_next_ip_address(
- resource_pool=ip_pool,
+ resource_pool=cast("CoreNode", ip_pool),
identifier="test",
prefix_length=32,
address_type="IpamIPAddress",
@@ -105,7 +106,7 @@ async def test_allocate_next_ip_address(
},
)
ip_address = clients.sync.allocate_next_ip_address(
- resource_pool=ip_pool,
+ resource_pool=cast("CoreNodeSync", ip_pool),
identifier="test",
prefix_length=32,
address_type="IpamIPAddress",
@@ -114,8 +115,8 @@ async def test_allocate_next_ip_address(
)
assert ip_address
- assert str(ip_address.address.value) == "192.0.2.0/32"
- assert ip_address.description.value == "test"
+ assert str(cast("InfrahubNodeSync", ip_address).address.value) == "192.0.2.0/32"
+ assert cast("InfrahubNodeSync", ip_address).description.value == "test"
@pytest.mark.parametrize("client_type", client_types)
@@ -184,7 +185,7 @@ async def test_allocate_next_ip_prefix(
},
)
ip_prefix = await clients.standard.allocate_next_ip_prefix(
- resource_pool=ip_pool,
+ resource_pool=cast("CoreNode", ip_pool),
identifier="test",
prefix_length=31,
prefix_type="IpamIPPrefix",
@@ -206,7 +207,7 @@ async def test_allocate_next_ip_prefix(
},
)
ip_prefix = clients.sync.allocate_next_ip_prefix(
- resource_pool=ip_pool,
+ resource_pool=cast("CoreNodeSync", ip_pool),
identifier="test",
prefix_length=31,
prefix_type="IpamIPPrefix",
@@ -215,5 +216,5 @@ async def test_allocate_next_ip_prefix(
)
assert ip_prefix
- assert str(ip_prefix.prefix.value) == "192.0.2.0/31"
- assert ip_prefix.description.value == "test"
+ assert str(cast("InfrahubNodeSync", ip_prefix).prefix.value) == "192.0.2.0/31" # type: ignore[unresolved-attribute]
+ assert cast("InfrahubNodeSync", ip_prefix).description.value == "test" # type: ignore[unresolved-attribute]
diff --git a/tests/unit/sdk/pool/test_relationship_from_pool.py b/tests/unit/sdk/pool/test_relationship_from_pool.py
index 9ce543dc..f8c44c6b 100644
--- a/tests/unit/sdk/pool/test_relationship_from_pool.py
+++ b/tests/unit/sdk/pool/test_relationship_from_pool.py
@@ -9,15 +9,15 @@
if TYPE_CHECKING:
from typing import Any
- from infrahub_sdk import InfrahubClient
from infrahub_sdk.schema import NodeSchemaAPI
+ from tests.unit.sdk.conftest import BothClients
client_types = ["standard", "sync"]
@pytest.mark.parametrize("client_type", client_types)
async def test_create_input_data_with_resource_pool_relationship(
- client: InfrahubClient,
+ clients: BothClients,
ipaddress_pool_schema: NodeSchemaAPI,
ipam_ipprefix_schema: NodeSchemaAPI,
simple_device_schema: NodeSchemaAPI,
@@ -25,9 +25,9 @@ async def test_create_input_data_with_resource_pool_relationship(
client_type: str,
) -> None:
if client_type == "standard":
- ip_prefix = InfrahubNode(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
+ ip_prefix = InfrahubNode(client=clients.standard, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
ip_pool = InfrahubNode(
- client=client,
+ client=clients.standard,
schema=ipaddress_pool_schema,
data={
"id": "pppppppp-pppp-pppp-pppp-pppppppppppp",
@@ -39,14 +39,14 @@ async def test_create_input_data_with_resource_pool_relationship(
},
)
device = InfrahubNode(
- client=client,
+ client=clients.standard,
schema=simple_device_schema,
data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool},
)
else:
- ip_prefix = InfrahubNodeSync(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
+ ip_prefix = InfrahubNodeSync(client=clients.sync, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
ip_pool = InfrahubNodeSync(
- client=client,
+ client=clients.sync,
schema=ipaddress_pool_schema,
data={
"id": "pppppppp-pppp-pppp-pppp-pppppppppppp",
@@ -58,7 +58,7 @@ async def test_create_input_data_with_resource_pool_relationship(
},
)
device = InfrahubNodeSync(
- client=client,
+ client=clients.sync,
schema=simple_device_schema,
data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool},
)
@@ -74,7 +74,7 @@ async def test_create_input_data_with_resource_pool_relationship(
@pytest.mark.parametrize("client_type", client_types)
async def test_create_mutation_query_with_resource_pool_relationship(
- client: InfrahubClient,
+ clients: BothClients,
ipaddress_pool_schema: NodeSchemaAPI,
ipam_ipprefix_schema: NodeSchemaAPI,
simple_device_schema: NodeSchemaAPI,
@@ -82,9 +82,9 @@ async def test_create_mutation_query_with_resource_pool_relationship(
client_type: str,
) -> None:
if client_type == "standard":
- ip_prefix = InfrahubNode(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
+ ip_prefix = InfrahubNode(client=clients.standard, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
ip_pool = InfrahubNode(
- client=client,
+ client=clients.standard,
schema=ipaddress_pool_schema,
data={
"id": "pppppppp-pppp-pppp-pppp-pppppppppppp",
@@ -96,14 +96,14 @@ async def test_create_mutation_query_with_resource_pool_relationship(
},
)
device = InfrahubNode(
- client=client,
+ client=clients.standard,
schema=simple_device_schema,
data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool},
)
else:
- ip_prefix = InfrahubNodeSync(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
+ ip_prefix = InfrahubNodeSync(client=clients.sync, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data)
ip_pool = InfrahubNodeSync(
- client=client,
+ client=clients.sync,
schema=ipaddress_pool_schema,
data={
"id": "pppppppp-pppp-pppp-pppp-pppppppppppp",
@@ -114,8 +114,8 @@ async def test_create_mutation_query_with_resource_pool_relationship(
"resources": [ip_prefix],
},
)
- device = InfrahubNode(
- client=client,
+ device = InfrahubNodeSync(
+ client=clients.sync,
schema=simple_device_schema,
data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool},
)
diff --git a/tests/unit/sdk/spec/test_object.py b/tests/unit/sdk/spec/test_object.py
index 90b248b1..edb06891 100644
--- a/tests/unit/sdk/spec/test_object.py
+++ b/tests/unit/sdk/spec/test_object.py
@@ -236,7 +236,7 @@ async def test_validate_object_expansion_multiple_ranges_bad_syntax(
]
-@pytest.mark.parametrize("data,is_valid,format", get_relationship_info_testdata)
+@pytest.mark.parametrize(("data", "is_valid", "format"), get_relationship_info_testdata)
async def test_get_relationship_info_tags(
client_with_schema_01: InfrahubClient,
data: dict | list,
diff --git a/tests/unit/sdk/test_attribute_generate_input_data.py b/tests/unit/sdk/test_attribute_generate_input_data.py
index a50a2fe9..394623fc 100644
--- a/tests/unit/sdk/test_attribute_generate_input_data.py
+++ b/tests/unit/sdk/test_attribute_generate_input_data.py
@@ -9,6 +9,7 @@
from infrahub_sdk.node.attribute import Attribute
from infrahub_sdk.protocols_base import CoreNodeBase
from infrahub_sdk.schema import AttributeSchemaAPI
+from infrahub_sdk.schema.main import AttributeKind
# ──────────────────────────────────────────────
# Value resolution: from_pool (dict-based)
@@ -18,7 +19,7 @@
class TestFromPoolDict:
def test_from_pool_with_id(self) -> None:
pool_data = {"id": "pool-uuid-1"}
- attr = Attribute(name="vlan_id", schema=_make_schema("Number"), data={"from_pool": pool_data})
+ attr = Attribute(name="vlan_id", schema=_make_schema(AttributeKind.NUMBER), data={"from_pool": pool_data})
result = attr._generate_input_data()
@@ -27,7 +28,7 @@ def test_from_pool_with_id(self) -> None:
def test_from_pool_with_id_and_identifier(self) -> None:
pool_data = {"id": "pool-uuid-1", "identifier": "test"}
- attr = Attribute(name="vlan_id", schema=_make_schema("Number"), data={"from_pool": pool_data})
+ attr = Attribute(name="vlan_id", schema=_make_schema(AttributeKind.NUMBER), data={"from_pool": pool_data})
result = attr._generate_input_data()
@@ -37,7 +38,7 @@ def test_from_pool_with_id_and_identifier(self) -> None:
def test_from_pool_with_pool_name(self) -> None:
"""from_pool can be a plain string (pool name), e.g. from_pool: 'VLAN ID Pool'."""
attr = Attribute(
- name="vlan_id", schema=_make_schema("Number", optional=True), data={"from_pool": "VLAN ID Pool"}
+ name="vlan_id", schema=_make_schema(AttributeKind.NUMBER, optional=True), data={"from_pool": "VLAN ID Pool"}
)
result = attr._generate_input_data()
@@ -48,7 +49,9 @@ def test_from_pool_with_pool_name(self) -> None:
def test_from_pool_value_is_none(self) -> None:
"""from_pool pops 'from_pool' and sets Attribute.value to None; value should NOT appear in payload."""
- attr = Attribute(name="vlan_id", schema=_make_schema("Number"), data={"from_pool": {"id": "pool-uuid-1"}})
+ attr = Attribute(
+ name="vlan_id", schema=_make_schema(AttributeKind.NUMBER), data={"from_pool": {"id": "pool-uuid-1"}}
+ )
assert attr.value is None
result = attr._generate_input_data()
@@ -64,7 +67,7 @@ class TestFromPoolNode:
def test_pool_node_generates_from_pool(self) -> None:
pool_node = _FakeNode(node_id="node-pool-uuid", is_pool=True)
- attr = Attribute(name="vlan_id", schema=_make_schema("Number"), data=pool_node)
+ attr = Attribute(name="vlan_id", schema=_make_schema(AttributeKind.NUMBER), data=pool_node)
result = attr._generate_input_data()
@@ -74,7 +77,7 @@ def test_pool_node_generates_from_pool(self) -> None:
def test_non_pool_node_treated_as_regular_value(self) -> None:
"""A CoreNodeBase that is NOT a resource pool should go through the normal value path."""
node = _FakeNode(node_id="regular-node-uuid", is_pool=False)
- attr = Attribute(name="vlan_id", schema=_make_schema("Number"), data=node)
+ attr = Attribute(name="vlan_id", schema=_make_schema(AttributeKind.NUMBER), data=node)
result = attr._generate_input_data()
@@ -89,7 +92,7 @@ def test_non_pool_node_treated_as_regular_value(self) -> None:
class TestNullValue:
def test_null_value_not_mutated(self) -> None:
"""None value that was never mutated → empty payload, no properties."""
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data={"value": None})
+ attr = Attribute(name="test_attr", schema=_make_schema(AttributeKind.TEXT), data={"value": None})
result = attr._generate_input_data()
@@ -99,7 +102,9 @@ def test_null_value_not_mutated(self) -> None:
def test_null_value_mutated_optional(self) -> None:
"""None value on an optional attr that was mutated → explicit null."""
- attr = Attribute(name="test_attr", schema=_make_schema("Text", optional=True), data={"value": "initial"})
+ attr = Attribute(
+ name="test_attr", schema=_make_schema(AttributeKind.TEXT, optional=True), data={"value": "initial"}
+ )
attr.value = None # triggers value_has_been_mutated
result = attr._generate_input_data()
@@ -109,7 +114,9 @@ def test_null_value_mutated_optional(self) -> None:
def test_null_value_mutated_non_optional(self) -> None:
"""None value on a non-optional attr that was mutated → empty payload (same as not mutated)."""
- attr = Attribute(name="test_attr", schema=_make_schema("Text", optional=False), data={"value": "initial"})
+ attr = Attribute(
+ name="test_attr", schema=_make_schema(AttributeKind.TEXT, optional=False), data={"value": "initial"}
+ )
attr.value = None
result = attr._generate_input_data()
@@ -135,7 +142,7 @@ class TestStringValues:
],
)
def test_safe_string(self, value: str) -> None:
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data=value)
+ attr = Attribute(name="test_attr", schema=_make_schema(AttributeKind.TEXT), data=value)
result = attr._generate_input_data()
@@ -151,7 +158,7 @@ def test_safe_string(self, value: str) -> None:
],
)
def test_unsafe_string_uses_variable_binding(self, value: str) -> None:
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data=value)
+ attr = Attribute(name="test_attr", schema=_make_schema(AttributeKind.TEXT), data=value)
result = attr._generate_input_data()
@@ -171,7 +178,7 @@ def test_unsafe_string_uses_variable_binding(self, value: str) -> None:
class TestIPValues:
def test_ipv4_interface(self) -> None:
- attr = Attribute(name="address", schema=_make_schema("IPHost"), data={"value": "10.0.0.1/24"})
+ attr = Attribute(name="address", schema=_make_schema(AttributeKind.IPHOST), data={"value": "10.0.0.1/24"})
result = attr._generate_input_data()
@@ -179,21 +186,21 @@ def test_ipv4_interface(self) -> None:
assert result.variables == {}
def test_ipv6_interface(self) -> None:
- attr = Attribute(name="address", schema=_make_schema("IPHost"), data={"value": "2001:db8::1/64"})
+ attr = Attribute(name="address", schema=_make_schema(AttributeKind.IPHOST), data={"value": "2001:db8::1/64"})
result = attr._generate_input_data()
assert result.payload["value"] == "2001:db8::1/64"
def test_ipv4_network(self) -> None:
- attr = Attribute(name="network", schema=_make_schema("IPNetwork"), data={"value": "10.0.0.0/24"})
+ attr = Attribute(name="network", schema=_make_schema(AttributeKind.IPNETWORK), data={"value": "10.0.0.0/24"})
result = attr._generate_input_data()
assert result.payload["value"] == "10.0.0.0/24"
def test_ipv6_network(self) -> None:
- attr = Attribute(name="network", schema=_make_schema("IPNetwork"), data={"value": "2001:db8::/32"})
+ attr = Attribute(name="network", schema=_make_schema(AttributeKind.IPNETWORK), data={"value": "2001:db8::/32"})
result = attr._generate_input_data()
@@ -207,7 +214,7 @@ def test_ipv6_network(self) -> None:
class TestScalarValues:
def test_number_value(self) -> None:
- attr = Attribute(name="vlan_id", schema=_make_schema("Number"), data=42)
+ attr = Attribute(name="vlan_id", schema=_make_schema(AttributeKind.NUMBER), data=42)
result = attr._generate_input_data()
@@ -215,7 +222,7 @@ def test_number_value(self) -> None:
assert result.variables == {}
def test_boolean_value(self) -> None:
- attr = Attribute(name="enabled", schema=_make_schema("Boolean"), data=True)
+ attr = Attribute(name="enabled", schema=_make_schema(AttributeKind.BOOLEAN), data=True)
result = attr._generate_input_data()
@@ -230,14 +237,16 @@ def test_boolean_value(self) -> None:
class TestProperties:
def test_no_properties_set(self) -> None:
"""When no properties are set, payload only has the value."""
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data="hello")
+ attr = Attribute(name="test_attr", schema=_make_schema(AttributeKind.TEXT), data="hello")
result = attr._generate_input_data()
assert result.payload == {"value": "hello"}
def test_flag_property_is_protected(self) -> None:
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data={"value": "hello", "is_protected": True})
+ attr = Attribute(
+ name="test_attr", schema=_make_schema(AttributeKind.TEXT), data={"value": "hello", "is_protected": True}
+ )
result = attr._generate_input_data()
@@ -247,7 +256,7 @@ def test_flag_property_is_protected(self) -> None:
def test_object_property_source(self) -> None:
attr = Attribute(
name="test_attr",
- schema=_make_schema("Text"),
+ schema=_make_schema(AttributeKind.TEXT),
data={"value": "hello", "source": {"id": "source-uuid", "display_label": "Git", "__typename": "CoreGit"}},
)
@@ -259,7 +268,7 @@ def test_object_property_source(self) -> None:
def test_object_property_owner(self) -> None:
attr = Attribute(
name="test_attr",
- schema=_make_schema("Text"),
+ schema=_make_schema(AttributeKind.TEXT),
data={
"value": "hello",
"owner": {"id": "owner-uuid", "display_label": "Admin", "__typename": "CoreAccount"},
@@ -273,7 +282,7 @@ def test_object_property_owner(self) -> None:
def test_both_flag_and_object_properties(self) -> None:
attr = Attribute(
name="test_attr",
- schema=_make_schema("Text"),
+ schema=_make_schema(AttributeKind.TEXT),
data={
"value": "hello",
"is_protected": True,
@@ -291,7 +300,7 @@ def test_properties_not_appended_for_null_value(self) -> None:
"""When need_additional_properties is False (null non-mutated), properties are ignored."""
attr = Attribute(
name="test_attr",
- schema=_make_schema("Text"),
+ schema=_make_schema(AttributeKind.TEXT),
data={
"value": None,
"is_protected": True,
@@ -308,7 +317,7 @@ def test_properties_appended_for_from_pool(self) -> None:
"""from_pool payloads have need_additional_properties=True, so properties are included."""
attr = Attribute(
name="vlan_id",
- schema=_make_schema("Number"),
+ schema=_make_schema(AttributeKind.NUMBER),
data={"from_pool": {"id": "pool-uuid"}, "is_protected": True},
)
@@ -325,14 +334,14 @@ def test_properties_appended_for_from_pool(self) -> None:
class TestToDictIntegration:
def test_to_dict_simple_value(self) -> None:
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data="hello")
+ attr = Attribute(name="test_attr", schema=_make_schema(AttributeKind.TEXT), data="hello")
result = attr._generate_input_data().to_dict()
assert result == {"data": {"value": "hello"}, "variables": {}}
def test_to_dict_with_variables(self) -> None:
- attr = Attribute(name="test_attr", schema=_make_schema("Text"), data='has "quotes"')
+ attr = Attribute(name="test_attr", schema=_make_schema(AttributeKind.TEXT), data='has "quotes"')
result = attr._generate_input_data().to_dict()
@@ -344,7 +353,7 @@ def test_to_dict_with_variables(self) -> None:
assert result["data"]["value"] == f"${var_name}"
-def _make_schema(kind: str = "Text", optional: bool = False) -> AttributeSchemaAPI:
+def _make_schema(kind: AttributeKind = AttributeKind.TEXT, optional: bool = False) -> AttributeSchemaAPI:
return AttributeSchemaAPI(name="test_attr", kind=kind, optional=optional)
diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py
index 1e883f95..13346c02 100644
--- a/tests/unit/sdk/test_client.py
+++ b/tests/unit/sdk/test_client.py
@@ -278,7 +278,7 @@ async def test_method_all_multiple_pages(
assert len(repos) == 5
-@pytest.mark.parametrize("client_type, use_parallel", batch_client_types)
+@pytest.mark.parametrize(("client_type", "use_parallel"), batch_client_types)
async def test_method_all_batching(
clients: BothClients,
mock_query_location_batch_count: HTTPXMock,
diff --git a/tests/unit/sdk/test_file_handler.py b/tests/unit/sdk/test_file_handler.py
new file mode 100644
index 00000000..ae59c842
--- /dev/null
+++ b/tests/unit/sdk/test_file_handler.py
@@ -0,0 +1,305 @@
+from __future__ import annotations
+
+import tempfile
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import anyio
+import httpx
+import pytest
+
+from infrahub_sdk.exceptions import AuthenticationError, NodeNotFoundError
+from infrahub_sdk.file_handler import FileHandler, FileHandlerBase, FileHandlerSync, PreparedFile
+
+if TYPE_CHECKING:
+ from pytest_httpx import HTTPXMock
+
+ from tests.unit.sdk.conftest import BothClients
+
+
+FILE_CONTENT_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR..."
+NODE_ID = "test-node-123"
+
+
+async def test_prepare_upload_with_bytes() -> None:
+ """Test preparing upload with bytes content (async)."""
+ content = b"test file content"
+ prepared = await FileHandlerBase.prepare_upload(content=content, name="test.txt")
+
+ assert isinstance(prepared, PreparedFile)
+ assert prepared.file_object is not None
+ assert isinstance(prepared.file_object, BytesIO)
+ assert prepared.filename == "test.txt"
+ assert prepared.should_close is False
+ assert prepared.file_object.read() == content
+
+
+async def test_prepare_upload_with_bytes_default_name() -> None:
+ """Test preparing upload with bytes content and no name (async)."""
+ content = b"test file content"
+ prepared = await FileHandlerBase.prepare_upload(content=content)
+
+ assert prepared.file_object is not None
+ assert prepared.filename == "uploaded_file"
+ assert prepared.should_close is False
+
+
+async def test_prepare_upload_with_path() -> None:
+ """Test preparing upload with Path content (async, opens in thread pool)."""
+ with tempfile.NamedTemporaryFile(suffix=".txt") as tmp:
+ tmp.write(b"test content from file")
+ tmp.flush()
+ tmp_path = Path(tmp.name)
+
+ prepared = await FileHandlerBase.prepare_upload(content=tmp_path)
+
+ assert prepared.file_object is not None
+ assert prepared.filename == tmp_path.name
+ assert prepared.should_close is True
+ assert prepared.file_object.read() == b"test content from file"
+ prepared.file_object.close()
+
+
+async def test_prepare_upload_with_path_custom_name() -> None:
+ """Test preparing upload with Path content and custom name (async)."""
+ with tempfile.NamedTemporaryFile(suffix=".txt") as tmp:
+ tmp.write(b"test content")
+ tmp.flush()
+ tmp_path = Path(tmp.name)
+
+ prepared = await FileHandlerBase.prepare_upload(content=tmp_path, name="custom_name.txt")
+
+ assert prepared.filename == "custom_name.txt"
+ assert prepared.file_object
+ prepared.file_object.close()
+
+
+async def test_prepare_upload_with_binary_io() -> None:
+ """Test preparing upload with BinaryIO content (async)."""
+ content = BytesIO(b"binary io content")
+ prepared = await FileHandlerBase.prepare_upload(content=content, name="binary.bin")
+
+ assert prepared.file_object is content
+ assert prepared.filename == "binary.bin"
+ assert prepared.should_close is False
+
+
+async def test_prepare_upload_with_none() -> None:
+ """Test preparing upload with None content (async)."""
+ prepared = await FileHandlerBase.prepare_upload(content=None)
+
+ assert prepared.file_object is None
+ assert prepared.filename is None
+ assert prepared.should_close is False
+
+
+def test_prepare_upload_sync_with_bytes() -> None:
+ """Test preparing upload with bytes content (sync)."""
+ content = b"test file content"
+ prepared = FileHandlerBase.prepare_upload_sync(content=content, name="test.txt")
+
+ assert isinstance(prepared, PreparedFile)
+ assert prepared.file_object is not None
+ assert isinstance(prepared.file_object, BytesIO)
+ assert prepared.filename == "test.txt"
+ assert prepared.should_close is False
+ assert prepared.file_object.read() == content
+
+
+def test_prepare_upload_sync_with_bytes_default_name() -> None:
+ """Test preparing upload with bytes content and no name (sync)."""
+ content = b"test file content"
+ prepared = FileHandlerBase.prepare_upload_sync(content=content)
+
+ assert prepared.file_object is not None
+ assert prepared.filename == "uploaded_file"
+ assert prepared.should_close is False
+
+
+def test_prepare_upload_sync_with_path() -> None:
+ """Test preparing upload with Path content (sync)."""
+ with tempfile.NamedTemporaryFile(suffix=".txt") as tmp:
+ tmp.write(b"test content from file")
+ tmp.flush()
+ tmp_path = Path(tmp.name)
+
+ prepared = FileHandlerBase.prepare_upload_sync(content=tmp_path)
+
+ assert prepared.file_object is not None
+ assert prepared.filename == tmp_path.name
+ assert prepared.should_close is True
+ assert prepared.file_object.read() == b"test content from file"
+ prepared.file_object.close()
+
+
+def test_prepare_upload_sync_with_path_custom_name() -> None:
+ """Test preparing upload with Path content and custom name (sync)."""
+ with tempfile.NamedTemporaryFile(suffix=".txt") as tmp:
+ tmp.write(b"test content")
+ tmp.flush()
+ tmp_path = Path(tmp.name)
+
+ prepared = FileHandlerBase.prepare_upload_sync(content=tmp_path, name="custom_name.txt")
+
+ assert prepared.filename == "custom_name.txt"
+ assert prepared.file_object
+ prepared.file_object.close()
+
+
+def test_prepare_upload_sync_with_binary_io() -> None:
+ """Test preparing upload with BinaryIO content (sync)."""
+ content = BytesIO(b"binary io content")
+ prepared = FileHandlerBase.prepare_upload_sync(content=content, name="binary.bin")
+
+ assert prepared.file_object is content
+ assert prepared.filename == "binary.bin"
+ assert prepared.should_close is False
+
+
+def test_prepare_upload_sync_with_none() -> None:
+ """Test preparing upload with None content (sync)."""
+ prepared = FileHandlerBase.prepare_upload_sync(content=None)
+
+ assert prepared.file_object is None
+ assert prepared.filename is None
+ assert prepared.should_close is False
+
+
+def test_handle_error_response_401() -> None:
+ """Test handling 401 authentication error."""
+ response = httpx.Response(status_code=401, json={"errors": [{"message": "Invalid token"}]})
+ exc = httpx.HTTPStatusError(message="Unauthorized", request=httpx.Request("GET", "http://test"), response=response)
+
+ with pytest.raises(AuthenticationError) as excinfo:
+ FileHandlerBase.handle_error_response(exc=exc)
+
+ assert "Invalid token" in str(excinfo.value)
+
+
+def test_handle_error_response_403() -> None:
+ """Test handling 403 forbidden error."""
+ response = httpx.Response(status_code=403, json={"errors": [{"message": "Access denied"}]})
+ exc = httpx.HTTPStatusError(message="Forbidden", request=httpx.Request("GET", "http://test"), response=response)
+
+ with pytest.raises(AuthenticationError) as excinfo:
+ FileHandlerBase.handle_error_response(exc=exc)
+
+ assert "Access denied" in str(excinfo.value)
+
+
+def test_handle_error_response_404() -> None:
+ """Test handling 404 not found error."""
+ response = httpx.Response(status_code=404, json={"detail": "File not found with ID abc123"})
+ exc = httpx.HTTPStatusError(message="Not Found", request=httpx.Request("GET", "http://test"), response=response)
+
+ with pytest.raises(NodeNotFoundError) as excinfo:
+ FileHandlerBase.handle_error_response(exc=exc)
+
+ assert "File not found with ID abc123" in str(excinfo.value)
+
+
+def test_handle_error_response_500() -> None:
+ """Test handling 500 server error (re-raises)."""
+ response = httpx.Response(status_code=500, json={"error": "Internal server error"})
+ exc = httpx.HTTPStatusError(message="Server Error", request=httpx.Request("GET", "http://test"), response=response)
+
+ with pytest.raises(httpx.HTTPStatusError):
+ FileHandlerBase.handle_error_response(exc=exc)
+
+
+def test_handle_response_success() -> None:
+ """Test handling successful response."""
+ request = httpx.Request("GET", "http://test")
+ response = httpx.Response(status_code=200, content=FILE_CONTENT_BYTES, request=request)
+
+ result = FileHandlerBase.handle_response(resp=response)
+
+ assert result == FILE_CONTENT_BYTES
+
+
+@pytest.fixture
+def mock_download_success(httpx_mock: HTTPXMock) -> HTTPXMock:
+ """Mock successful file download."""
+ httpx_mock.add_response(
+ method="GET",
+ url="http://mock/api/storage/files/test-node-123?branch=main",
+ content=FILE_CONTENT_BYTES,
+ headers={"Content-Type": "application/octet-stream"},
+ )
+ return httpx_mock
+
+
+@pytest.fixture
+def mock_download_stream(httpx_mock: HTTPXMock) -> HTTPXMock:
+ """Mock successful streaming file download."""
+ httpx_mock.add_response(
+ method="GET",
+ url="http://mock/api/storage/files/stream-node?branch=main",
+ content=FILE_CONTENT_BYTES,
+ headers={"Content-Type": "application/octet-stream"},
+ )
+ return httpx_mock
+
+
+client_types = ["standard", "sync"]
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_file_handler_download_to_memory(
+ client_type: str, clients: BothClients, mock_download_success: HTTPXMock
+) -> None:
+ """Test downloading file to memory via FileHandler."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ handler = FileHandler(client=client)
+ content = await handler.download(node_id=NODE_ID, branch="main")
+ else:
+ handler = FileHandlerSync(client=client)
+ content = handler.download(node_id=NODE_ID, branch="main")
+
+ assert content == FILE_CONTENT_BYTES
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_file_handler_download_to_disk(
+ client_type: str, clients: BothClients, mock_download_stream: HTTPXMock
+) -> None:
+ """Test streaming file download to disk via FileHandler."""
+ client = getattr(clients, client_type)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dest_path = Path(tmpdir) / "downloaded.bin"
+
+ if client_type == "standard":
+ handler = FileHandler(client=client)
+ bytes_written = await handler.download(node_id="stream-node", branch="main", dest=dest_path)
+ else:
+ handler = FileHandlerSync(client=client)
+ bytes_written = handler.download(node_id="stream-node", branch="main", dest=dest_path)
+
+ assert bytes_written == len(FILE_CONTENT_BYTES)
+ assert await anyio.Path(dest_path).read_bytes() == FILE_CONTENT_BYTES
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_file_handler_build_url_with_branch(client_type: str, clients: BothClients) -> None:
+ """Test URL building with branch parameter."""
+ client = getattr(clients, client_type)
+
+ handler = FileHandler(client=client) if client_type == "standard" else FileHandlerSync(client=client)
+
+ url = handler._build_url(node_id="node-123", branch="feature-branch")
+ assert url == "http://mock/api/storage/files/node-123?branch=feature-branch"
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_file_handler_build_url_without_branch(client_type: str, clients: BothClients) -> None:
+ """Test URL building without branch parameter."""
+ client = getattr(clients, client_type)
+
+ handler = FileHandler(client=client) if client_type == "standard" else FileHandlerSync(client=client)
+
+ url = handler._build_url(node_id="node-456", branch=None)
+ assert url == "http://mock/api/storage/files/node-456"
diff --git a/tests/unit/sdk/test_file_object.py b/tests/unit/sdk/test_file_object.py
new file mode 100644
index 00000000..f6267003
--- /dev/null
+++ b/tests/unit/sdk/test_file_object.py
@@ -0,0 +1,295 @@
+import tempfile
+from pathlib import Path
+
+import anyio
+import pytest
+from pytest_httpx import HTTPXMock
+
+from infrahub_sdk.exceptions import FeatureNotSupportedError
+from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
+from infrahub_sdk.schema import NodeSchemaAPI
+from tests.unit.sdk.conftest import BothClients
+
+pytestmark = pytest.mark.httpx_mock(can_send_already_matched_responses=True)
+
+client_types = ["standard", "sync"]
+
+FILE_CONTENT = b"Test file content"
+FILE_NAME = "contract.pdf"
+FILE_MIME_TYPE = "application/pdf"
+
+
+@pytest.fixture
+def mock_node_create_with_file(httpx_mock: HTTPXMock) -> HTTPXMock:
+ """Mock the HTTP response for node create with file upload."""
+ httpx_mock.add_response(
+ method="POST",
+ json={
+ "data": {
+ "NetworkCircuitContractCreate": {
+ "ok": True,
+ "object": {
+ "id": "new-file-node-123",
+ "display_label": FILE_NAME,
+ "file_name": {"value": FILE_NAME},
+ "checksum": {"value": "abc123checksum"},
+ "file_size": {"value": len(FILE_CONTENT)},
+ "file_type": {"value": FILE_MIME_TYPE},
+ "storage_id": {"value": "storage-xyz-789"},
+ "contract_start": {"value": "2024-01-01T00:00:00Z"},
+ "contract_end": {"value": "2024-12-31T23:59:59Z"},
+ },
+ }
+ }
+ },
+ is_reusable=True,
+ )
+ return httpx_mock
+
+
+@pytest.fixture
+def mock_node_update_with_file(httpx_mock: HTTPXMock) -> HTTPXMock:
+ """Mock the HTTP response for node update with file upload."""
+ httpx_mock.add_response(
+ method="POST",
+ json={
+ "data": {
+ "NetworkCircuitContractUpdate": {
+ "ok": True,
+ "object": {
+ "id": "existing-file-node-456",
+ "display_label": FILE_NAME,
+ "file_name": {"value": FILE_NAME},
+ "checksum": {"value": "updated123checksum"},
+ "file_size": {"value": len(FILE_CONTENT)},
+ "file_type": {"value": FILE_MIME_TYPE},
+ "storage_id": {"value": "storage-updated-789"},
+ "contract_start": {"value": "2024-01-01T00:00:00Z"},
+ "contract_end": {"value": "2024-12-31T23:59:59Z"},
+ },
+ }
+ }
+ },
+ is_reusable=True,
+ )
+ return httpx_mock
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_create_with_file_uses_multipart(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI, mock_node_create_with_file: HTTPXMock
+) -> None:
+ """Test that node.save() for create with file content sends a multipart request."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+
+ node.contract_start.value = "2024-01-01T00:00:00Z" # type: ignore[union-attr]
+ node.contract_end.value = "2024-12-31T23:59:59Z" # type: ignore[union-attr]
+ node.upload_from_bytes(content=FILE_CONTENT, name=FILE_NAME)
+
+ if isinstance(node, InfrahubNode):
+ await node.save()
+ else:
+ node.save()
+
+ requests = mock_node_create_with_file.get_requests()
+ assert len(requests) == 1
+ assert requests[0].headers.get("x-infrahub-tracker") == "mutation-networkcircuitcontract-create"
+ assert requests[0].headers.get("content-type").startswith("multipart/form-data;")
+ assert b"Content-Disposition: form-data" in requests[0].content
+ assert f'filename="{FILE_NAME}"'.encode() in requests[0].content
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_update_with_file_uses_multipart(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI, mock_node_update_with_file: HTTPXMock
+) -> None:
+ """Test that node.save() for update with file content sends a multipart request."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+
+ # Simulate an existing node
+ node.id = "existing-file-node-456"
+ node._existing = True
+ node.contract_start.value = "2024-01-01T00:00:00Z" # type: ignore[union-attr]
+ node.contract_end.value = "2024-12-31T23:59:59Z" # type: ignore[union-attr]
+ node.upload_from_bytes(content=FILE_CONTENT, name=FILE_NAME)
+
+ if isinstance(node, InfrahubNode):
+ await node.save()
+ else:
+ node.save()
+
+ requests = mock_node_update_with_file.get_requests()
+ assert len(requests) == 1
+ assert requests[0].headers.get("x-infrahub-tracker") == "mutation-networkcircuitcontract-update"
+ assert requests[0].headers.get("content-type").startswith("multipart/form-data;")
+ assert b"Content-Disposition: form-data" in requests[0].content
+ assert f'filename="{FILE_NAME}"'.encode() in requests[0].content
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_create_file_object_without_file_raises(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that creating a FileObject node without file content raises an error."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+
+ node.contract_start.value = "2024-01-01T00:00:00Z" # type: ignore[union-attr]
+ node.contract_end.value = "2024-12-31T23:59:59Z" # type: ignore[union-attr]
+
+ with pytest.raises(ValueError, match=r"Cannot create .* without file content"):
+ if isinstance(node, InfrahubNode):
+ await node.save()
+ else:
+ node.save()
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_save_clears_file_after_upload(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI, mock_node_create_with_file: HTTPXMock
+) -> None:
+ """Test that file content is cleared after successful save."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+
+ node.contract_start.value = "2024-01-01T00:00:00Z" # type: ignore[union-attr]
+ node.contract_end.value = "2024-12-31T23:59:59Z" # type: ignore[union-attr]
+
+ node.upload_from_bytes(content=FILE_CONTENT, name=FILE_NAME)
+ assert node._file_content is not None
+ assert node._file_name is not None
+
+ if isinstance(node, InfrahubNode):
+ await node.save()
+ else:
+ node.save()
+
+ # File content should be cleared after save
+ assert node._file_content is None
+ assert node._file_name is None
+
+
+@pytest.fixture
+def mock_download_file(httpx_mock: HTTPXMock) -> HTTPXMock:
+ httpx_mock.add_response(
+ method="GET",
+ url="http://mock/api/storage/files/file-node-123?branch=main",
+ content=FILE_CONTENT,
+ headers={"Content-Type": FILE_MIME_TYPE, "Content-Disposition": f'attachment; filename="{FILE_NAME}"'},
+ )
+ return httpx_mock
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_download_file(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI, mock_download_file: HTTPXMock
+) -> None:
+ """Test downloading a file from a FileObject node."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+
+ node.id = "file-node-123"
+ if isinstance(node, InfrahubNode):
+ content = await node.download_file()
+ else:
+ content = node.download_file()
+
+ assert content == FILE_CONTENT
+
+
+@pytest.fixture
+def mock_download_file_to_disk(httpx_mock: HTTPXMock) -> HTTPXMock:
+ httpx_mock.add_response(
+ method="GET",
+ url="http://mock/api/storage/files/file-node-stream?branch=main",
+ content=FILE_CONTENT,
+ headers={"Content-Type": FILE_MIME_TYPE, "Content-Disposition": f'attachment; filename="{FILE_NAME}"'},
+ )
+ return httpx_mock
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_download_file_to_disk(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI, mock_download_file_to_disk: HTTPXMock
+) -> None:
+ """Test downloading a file from a FileObject node directly to disk."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+
+ node.id = "file-node-stream"
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dest_path = Path(tmpdir) / "downloaded.bin"
+
+ if isinstance(node, InfrahubNode):
+ bytes_written = await node.download_file(dest=dest_path)
+ else:
+ bytes_written = node.download_file(dest=dest_path)
+
+ assert bytes_written == len(FILE_CONTENT)
+ assert await anyio.Path(dest_path).read_bytes() == FILE_CONTENT
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_download_file_not_file_object_raises(
+ client_type: str, clients: BothClients, non_file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that download_file raises error on non-FileObject nodes."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=non_file_object_schema, branch="main")
+ with pytest.raises(
+ FeatureNotSupportedError,
+ match=r"calling download_file is only supported for nodes that inherit from CoreFileObject",
+ ):
+ await node.download_file()
+ else:
+ node = InfrahubNodeSync(client=client, schema=non_file_object_schema, branch="main")
+ with pytest.raises(
+ FeatureNotSupportedError,
+ match=r"calling download_file is only supported for nodes that inherit from CoreFileObject",
+ ):
+ node.download_file()
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_download_file_unsaved_node_raises(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that download_file raises error on unsaved nodes."""
+ client = getattr(clients, client_type)
+
+ if client_type == "standard":
+ node = InfrahubNode(client=client, schema=file_object_schema, branch="main")
+ with pytest.raises(ValueError, match=r"Cannot download file for a node that hasn't been saved yet"):
+ await node.download_file()
+ else:
+ node = InfrahubNodeSync(client=client, schema=file_object_schema, branch="main")
+ with pytest.raises(ValueError, match=r"Cannot download file for a node that hasn't been saved yet"):
+ node.download_file()
diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py
index 3db48edf..ad3d77eb 100644
--- a/tests/unit/sdk/test_node.py
+++ b/tests/unit/sdk/test_node.py
@@ -2,11 +2,14 @@
import inspect
import ipaddress
+import tempfile
+from io import BytesIO
+from pathlib import Path
from typing import TYPE_CHECKING
import pytest
-from infrahub_sdk.exceptions import NodeNotFoundError
+from infrahub_sdk.exceptions import FeatureNotSupportedError, NodeNotFoundError
from infrahub_sdk.node import (
InfrahubNode,
InfrahubNodeBase,
@@ -129,7 +132,9 @@ async def test_validate_method_signature(
)
-@pytest.mark.parametrize("hfid,expected_kind,expected_hfid", [("BuiltinLocation__JFK1", "BuiltinLocation", ["JFK1"])])
+@pytest.mark.parametrize(
+ ("hfid", "expected_kind", "expected_hfid"), [("BuiltinLocation__JFK1", "BuiltinLocation", ["JFK1"])]
+)
def test_parse_human_friendly_id(hfid: str, expected_kind: str, expected_hfid: list[str]) -> None:
kind, hfid = parse_human_friendly_id(hfid)
assert kind == expected_kind
@@ -2988,3 +2993,253 @@ def test_relationship_manager_generate_query_data_without_include_metadata() ->
assert "count" in data
assert "edges" in data
assert "node" in data["edges"]
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_is_file_object_true(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that is_file_object returns True for nodes inheriting from CoreFileObject."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ assert node.is_file_object()
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_is_file_object_false(
+ client_type: str, clients: BothClients, non_file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that is_file_object returns False for regular nodes."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=non_file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=non_file_object_schema, branch="main")
+
+ assert not node.is_file_object()
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_upload_from_bytes_with_bytes(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that upload_from_bytes works with bytes on FileObject nodes."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"PDF content here"
+ node.upload_from_bytes(content=file_content, name="contract.pdf")
+
+ assert node._file_content == file_content
+ assert node._file_name == "contract.pdf"
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_upload_from_path(client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI) -> None:
+ """Test that upload_from_path works with a Path object."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"Content from file path"
+ with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp:
+ tmp.write(file_content)
+ tmp.flush()
+ tmp_path = Path(tmp.name)
+
+ node.upload_from_path(path=tmp_path)
+ assert node._file_content == tmp_path
+ assert node._file_name == tmp_path.name
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_upload_from_bytes_with_binary_io(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that upload_from_bytes works with a BinaryIO object."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"Content from BinaryIO"
+ file_obj = BytesIO(file_content)
+
+ node.upload_from_bytes(content=file_obj, name="uploaded.pdf")
+
+ assert node._file_content == file_obj
+ assert node._file_name == "uploaded.pdf"
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_upload_from_bytes_on_non_file_object_raises(
+ client_type: str, clients: BothClients, non_file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that upload_from_bytes raises FeatureNotSupportedError on non-FileObject nodes."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=non_file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=non_file_object_schema, branch="main")
+
+ with pytest.raises(FeatureNotSupportedError, match=r"File upload is not supported"):
+ node.upload_from_bytes(content=b"some content", name="file.txt")
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_upload_from_path_on_non_file_object_raises(
+ client_type: str, clients: BothClients, non_file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test that upload_from_path raises FeatureNotSupportedError on non-FileObject nodes."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=non_file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=non_file_object_schema, branch="main")
+
+ with pytest.raises(FeatureNotSupportedError, match=r"File upload is not supported"):
+ node.upload_from_path(path=Path("/some/file.txt"))
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_clear_file(client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI) -> None:
+ """Test that clear_file removes pending file content."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"Test content"
+ file_name = "file.txt"
+
+ node.upload_from_bytes(content=file_content, name=file_name)
+ assert node._file_content == file_content
+ assert node._file_name == file_name
+
+ node.clear_file()
+ assert node._file_content is None
+ assert node._file_name is None
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_get_file_for_upload_bytes(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test _get_file_for_upload with bytes returns PreparedFile with BytesIO."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"Test content"
+ file_name = "test.txt"
+ node.upload_from_bytes(content=file_content, name=file_name)
+
+ if isinstance(node, InfrahubNode):
+ prepared = await node._get_file_for_upload()
+ else:
+ prepared = node._get_file_for_upload_sync()
+
+ assert prepared.file_object
+ assert prepared.filename == file_name
+ assert not prepared.should_close
+ assert prepared.file_object.read() == file_content
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_get_file_for_upload_path(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test _get_file_for_upload with Path returns PreparedFile with opened file handle."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"Content from path"
+ with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp:
+ tmp.write(file_content)
+ tmp.flush()
+ tmp_path = Path(tmp.name)
+
+ node.upload_from_path(path=tmp_path)
+
+ if isinstance(node, InfrahubNode):
+ prepared = await node._get_file_for_upload()
+ else:
+ prepared = node._get_file_for_upload_sync()
+
+ assert prepared.file_object
+ assert prepared.filename == tmp_path.name
+ assert prepared.should_close # Path files should be closed after upload
+ assert prepared.file_object.read() == file_content
+ prepared.file_object.close()
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_get_file_for_upload_binary_io(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test _get_file_for_upload with BinaryIO returns PreparedFile with the same object."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ file_content = b"Content from BinaryIO"
+ file_name = "test.bin"
+ file_obj_input = BytesIO(file_content)
+ node.upload_from_bytes(content=file_obj_input, name=file_name)
+
+ if isinstance(node, InfrahubNode):
+ prepared = await node._get_file_for_upload()
+ else:
+ prepared = node._get_file_for_upload_sync()
+
+ assert prepared.file_object is file_obj_input # Should be the same object
+ assert prepared.filename == file_name
+ assert not prepared.should_close # BinaryIO provided by user shouldn't be closed
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_get_file_for_upload_none(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test _get_file_for_upload with no file set returns PreparedFile with None values."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ if isinstance(node, InfrahubNode):
+ prepared = await node._get_file_for_upload()
+ else:
+ prepared = node._get_file_for_upload_sync()
+
+ assert prepared.file_object is None
+ assert prepared.filename is None
+ assert not prepared.should_close
+
+
+@pytest.mark.parametrize("client_type", client_types)
+async def test_node_generate_input_data_with_file(
+ client_type: str, clients: BothClients, file_object_schema: NodeSchemaAPI
+) -> None:
+ """Test _generate_input_data places file at mutation level, not inside data."""
+ if client_type == "standard":
+ node = InfrahubNode(client=clients.standard, schema=file_object_schema, branch="main")
+ else:
+ node = InfrahubNodeSync(client=clients.sync, schema=file_object_schema, branch="main")
+
+ node.upload_from_bytes(content=b"test content", name="test.txt")
+
+ input_data = node._generate_input_data()
+
+ assert "file" in input_data["data"], "file should be at mutation payload level"
+ assert input_data["data"]["file"] == "$file"
+ assert "file" not in input_data["data"]["data"], "file should not be inside nested data dict"
+ assert "file" in input_data["mutation_variables"]
+ assert input_data["mutation_variables"]["file"] is bytes
diff --git a/tests/unit/sdk/test_query_analyzer.py b/tests/unit/sdk/test_query_analyzer.py
index a4deefb2..ffa810ad 100644
--- a/tests/unit/sdk/test_query_analyzer.py
+++ b/tests/unit/sdk/test_query_analyzer.py
@@ -148,7 +148,7 @@ async def test_get_variables(query_01: str, query_04: str, query_05: str, query_
@pytest.mark.parametrize(
- "var_type,var_required",
+ ("var_type", "var_required"),
[("[ID]", False), ("[ID]!", True), ("[ID!]", False), ("[ID!]!", True)],
)
async def test_get_nested_variables(var_type: str, var_required: bool) -> None:
diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py
index 05302b11..149b584c 100644
--- a/tests/unit/sdk/test_schema.py
+++ b/tests/unit/sdk/test_schema.py
@@ -243,10 +243,7 @@ async def test_schema_wait_happy_path(clients: BothClients, client_type: list[st
@pytest.mark.parametrize("client_type", client_types)
async def test_schema_set_cache_dict(clients: BothClients, client_type: list[str], schema_query_01_data: dict) -> None:
- if client_type == "standard":
- client = clients.standard
- else:
- client = clients.sync
+ client = clients.standard if client_type == "standard" else clients.sync
client.schema.set_cache(schema_query_01_data, branch="branch1")
assert "branch1" in client.schema.cache
@@ -257,10 +254,7 @@ async def test_schema_set_cache_dict(clients: BothClients, client_type: list[str
async def test_schema_set_cache_branch_schema(
clients: BothClients, client_type: list[str], schema_query_01_data: dict
) -> None:
- if client_type == "standard":
- client = clients.standard
- else:
- client = clients.sync
+ client = clients.standard if client_type == "standard" else clients.sync
schema = BranchSchema.from_api_response(schema_query_01_data)
diff --git a/tests/unit/sdk/test_schema_export.py b/tests/unit/sdk/test_schema_export.py
new file mode 100644
index 00000000..ed2814e4
--- /dev/null
+++ b/tests/unit/sdk/test_schema_export.py
@@ -0,0 +1,269 @@
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Any
+
+import pytest
+
+from infrahub_sdk.schema import (
+ GenericSchemaAPI,
+ InfrahubSchemaBase,
+ NodeSchemaAPI,
+ ProfileSchemaAPI,
+ SchemaExport,
+ TemplateSchemaAPI,
+)
+
+if TYPE_CHECKING:
+ from pytest_httpx import HTTPXMock
+
+ from tests.unit.sdk.conftest import BothClients
+
+client_types = ["standard", "sync"]
+
# ---------------------------------------------------------------------------
# Minimal schema API response builders (reused from ctl tests)
# ---------------------------------------------------------------------------

# Baseline keyword arguments for NodeSchemaAPI; the Profile/Template builders
# below pass this same dict to their own classes. Tests layer only
# "namespace" and "name" on top; every other field keeps a neutral value.
_BASE_NODE: dict[str, Any] = {
    "id": None,
    "state": "present",
    "hash": None,
    "hierarchy": None,
    "label": None,
    "description": None,
    "include_in_menu": None,
    "menu_placement": None,
    "display_label": None,
    "display_labels": None,
    "human_friendly_id": None,
    "icon": None,
    "uniqueness_constraints": None,
    "documentation": None,
    "order_by": None,
    "inherit_from": [],
    "branch": "aware",
    "default_filter": None,
    "generate_profile": None,
    "generate_template": None,
    "parent": None,
    "children": None,
    "attributes": [],
    "relationships": [],
}

# Baseline keyword arguments for GenericSchemaAPI. Differs from _BASE_NODE by
# carrying "used_by" and omitting the node-only fields (hierarchy,
# inherit_from, branch, default_filter, generate_*, parent/children).
_BASE_GENERIC: dict[str, Any] = {
    "id": None,
    "state": "present",
    "hash": None,
    "used_by": [],
    "label": None,
    "description": None,
    "include_in_menu": None,
    "menu_placement": None,
    "display_label": None,
    "display_labels": None,
    "human_friendly_id": None,
    "icon": None,
    "uniqueness_constraints": None,
    "documentation": None,
    "order_by": None,
    "attributes": [],
    "relationships": [],
}
+
+
def _make_node_schema(namespace: str, name: str) -> NodeSchemaAPI:
    """Build a minimal NodeSchemaAPI for the given namespace/name pair."""
    return NodeSchemaAPI(**dict(_BASE_NODE, namespace=namespace, name=name))
+
+
def _make_generic_schema(namespace: str, name: str) -> GenericSchemaAPI:
    """Build a minimal GenericSchemaAPI for the given namespace/name pair."""
    return GenericSchemaAPI(**dict(_BASE_GENERIC, namespace=namespace, name=name))
+
+
def _make_profile_schema(namespace: str, name: str) -> ProfileSchemaAPI:
    """Build a minimal ProfileSchemaAPI for the given namespace/name pair."""
    return ProfileSchemaAPI(**dict(_BASE_NODE, namespace=namespace, name=name))
+
+
def _make_template_schema(namespace: str, name: str) -> TemplateSchemaAPI:
    """Build a minimal TemplateSchemaAPI for the given namespace/name pair."""
    return TemplateSchemaAPI(**dict(_BASE_NODE, namespace=namespace, name=name))
+
+
+# ---------------------------------------------------------------------------
+# _build_export_schemas tests
+# ---------------------------------------------------------------------------
+
+
class TestBuildExportSchemas:
    """Unit coverage for InfrahubSchemaBase._build_export_schemas."""

    def test_separates_nodes_and_generics(self) -> None:
        """Nodes and generics of one namespace land in separate buckets."""
        export = InfrahubSchemaBase._build_export_schemas(
            {
                "InfraDevice": _make_node_schema("Infra", "Device"),
                "InfraInterface": _make_generic_schema("Infra", "Interface"),
            }
        )
        assert isinstance(export, SchemaExport)
        assert "Infra" in export.namespaces
        infra = export.namespaces["Infra"]
        assert len(infra.nodes) == 1
        assert len(infra.generics) == 1
        assert infra.nodes[0]["name"] == "Device"
        assert infra.generics[0]["name"] == "Interface"

    def test_groups_by_namespace(self) -> None:
        """Schemas from different namespaces end up under their own keys."""
        export = InfrahubSchemaBase._build_export_schemas(
            {
                "InfraDevice": _make_node_schema("Infra", "Device"),
                "DcimRack": _make_node_schema("Dcim", "Rack"),
            }
        )
        assert set(export.namespaces) == {"Infra", "Dcim"}

    def test_filters_profiles_and_templates(self) -> None:
        """Profile and Template schemas never appear in the export."""
        export = InfrahubSchemaBase._build_export_schemas(
            {
                "InfraDevice": _make_node_schema("Infra", "Device"),
                "ProfileInfraDevice": _make_profile_schema("Profile", "InfraDevice"),
                "TemplateInfraDevice": _make_template_schema("Template", "InfraDevice"),
            }
        )
        assert "Infra" in export.namespaces
        assert "Profile" not in export.namespaces
        assert "Template" not in export.namespaces

    def test_filters_restricted_namespaces(self) -> None:
        """Core and Builtin namespaces are dropped from the export."""
        export = InfrahubSchemaBase._build_export_schemas(
            {
                "CoreRepository": _make_node_schema("Core", "Repository"),
                "BuiltinTag": _make_node_schema("Builtin", "Tag"),
                "InfraDevice": _make_node_schema("Infra", "Device"),
            }
        )
        assert "Core" not in export.namespaces
        assert "Builtin" not in export.namespaces
        assert "Infra" in export.namespaces

    def test_namespace_filter(self) -> None:
        """An explicit namespaces list restricts the export to those namespaces."""
        export = InfrahubSchemaBase._build_export_schemas(
            {
                "InfraDevice": _make_node_schema("Infra", "Device"),
                "DcimRack": _make_node_schema("Dcim", "Rack"),
            },
            namespaces=["Infra"],
        )
        assert "Infra" in export.namespaces
        assert "Dcim" not in export.namespaces

    def test_empty_when_no_user_schemas(self) -> None:
        """Only restricted schemas as input yields an empty export."""
        export = InfrahubSchemaBase._build_export_schemas(
            {"CoreRepository": _make_node_schema("Core", "Repository")}
        )
        assert export.namespaces == {}

    def test_warns_on_restricted_namespaces(self) -> None:
        """Requesting a restricted namespace emits one warning naming it."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            export = InfrahubSchemaBase._build_export_schemas(
                {"InfraDevice": _make_node_schema("Infra", "Device")},
                namespaces=["Infra", "Core"],
            )
        assert len(caught) == 1
        assert "Core" in str(caught[0].message)
        assert "Infra" in export.namespaces

    def test_to_dict(self) -> None:
        """to_dict() serializes the export into plain nested dicts."""
        export = InfrahubSchemaBase._build_export_schemas(
            {
                "InfraDevice": _make_node_schema("Infra", "Device"),
                "InfraInterface": _make_generic_schema("Infra", "Interface"),
            }
        )
        serialized = export.to_dict()
        assert isinstance(serialized, dict)
        assert "Infra" in serialized
        assert isinstance(serialized["Infra"], dict)
        assert len(serialized["Infra"]["nodes"]) == 1
        assert len(serialized["Infra"]["generics"]) == 1
+
+
+# ---------------------------------------------------------------------------
+# Integration tests for export() method on client.schema
+# ---------------------------------------------------------------------------
+
+
+def _schema_response(
+ nodes: list[dict] | None = None,
+ generics: list[dict] | None = None,
+ profiles: list[dict] | None = None,
+ templates: list[dict] | None = None,
+) -> dict:
+ return {
+ "main": "aabbccdd",
+ "nodes": nodes or [],
+ "generics": generics or [],
+ "profiles": profiles or [],
+ "templates": templates or [],
+ }
+
+
def _make_node_dict(namespace: str, name: str) -> dict[str, Any]:
    """Raw node-schema dict shaped like an API response entry."""
    return dict(_BASE_NODE, namespace=namespace, name=name)
+
+
def _make_generic_dict(namespace: str, name: str) -> dict[str, Any]:
    """Raw generic-schema dict shaped like an API response entry."""
    return dict(_BASE_GENERIC, namespace=namespace, name=name)
+
+
@pytest.mark.parametrize("client_type", client_types)
async def test_export_returns_user_schemas(httpx_mock: HTTPXMock, clients: BothClients, client_type: str) -> None:
    """export() groups user-defined nodes and generics by their namespace."""
    httpx_mock.add_response(
        method="GET",
        url="http://mock/api/schema?branch=main",
        json=_schema_response(
            nodes=[_make_node_dict("Infra", "Device"), _make_node_dict("Dcim", "Rack")],
            generics=[_make_generic_dict("Infra", "GenericInterface")],
        ),
    )

    if client_type == "standard":
        result = await clients.standard.schema.export(branch="main")
    else:
        result = clients.sync.schema.export(branch="main")

    assert isinstance(result, SchemaExport)
    assert "Infra" in result.namespaces
    assert "Dcim" in result.namespaces
    assert len(result.namespaces["Infra"].nodes) == 1
    assert len(result.namespaces["Infra"].generics) == 1
    assert len(result.namespaces["Dcim"].nodes) == 1
+
+
@pytest.mark.parametrize("client_type", client_types)
async def test_export_with_namespace_filter(httpx_mock: HTTPXMock, clients: BothClients, client_type: str) -> None:
    """The namespaces argument is forwarded as a query parameter and filters the result."""
    httpx_mock.add_response(
        method="GET",
        url="http://mock/api/schema?branch=main&namespaces=Infra",
        json=_schema_response(nodes=[_make_node_dict("Infra", "Device")]),
    )

    if client_type == "standard":
        result = await clients.standard.schema.export(branch="main", namespaces=["Infra"])
    else:
        result = clients.sync.schema.export(branch="main", namespaces=["Infra"])

    assert "Infra" in result.namespaces
    assert "Dcim" not in result.namespaces
+
+
@pytest.mark.parametrize("client_type", client_types)
async def test_export_empty_when_only_restricted(httpx_mock: HTTPXMock, clients: BothClients, client_type: str) -> None:
    """A schema payload holding only restricted namespaces exports nothing."""
    httpx_mock.add_response(
        method="GET",
        url="http://mock/api/schema?branch=main",
        json=_schema_response(nodes=[_make_node_dict("Core", "Repository")]),
    )

    result = (
        await clients.standard.schema.export(branch="main")
        if client_type == "standard"
        else clients.sync.schema.export(branch="main")
    )

    assert result.namespaces == {}
diff --git a/tests/unit/sdk/test_timestamp.py b/tests/unit/sdk/test_timestamp.py
index 713f2f72..fd189f4b 100644
--- a/tests/unit/sdk/test_timestamp.py
+++ b/tests/unit/sdk/test_timestamp.py
@@ -43,7 +43,7 @@ def test_parse_string() -> None:
@pytest.mark.parametrize(
- "input_str,expected_datetime",
+ ("input_str", "expected_datetime"),
[
pytest.param(
"2022-01-01T10:01:01.123000Z", datetime(2022, 1, 1, 10, 1, 1, 123000, tzinfo=UTC), id="milliseconds"
@@ -69,7 +69,7 @@ def test_to_datetime(input_str: str, expected_datetime: datetime) -> None:
@pytest.mark.parametrize(
- "input_str,expected_str,expected_str_no_z",
+ ("input_str", "expected_str", "expected_str_no_z"),
[
pytest.param(
"2022-01-01T10:01:01.123000Z",
diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py
index 3510f1a4..70a4064c 100644
--- a/tests/unit/test_tasks.py
+++ b/tests/unit/test_tasks.py
@@ -1,20 +1,15 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
import pytest
from invoke import Exit
import tasks
-if TYPE_CHECKING:
- from pytest import MonkeyPatch
-
class TestRequireTool:
"""Verify that require_tool() raises with a helpful message when an external tool is missing."""
- def test_raises_when_tool_missing(self, monkeypatch: MonkeyPatch) -> None:
+ def test_raises_when_tool_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
monkeypatch.setattr(tasks, "which", lambda _name: None)
@@ -22,14 +17,14 @@ def test_raises_when_tool_missing(self, monkeypatch: MonkeyPatch) -> None:
with pytest.raises(Exit, match="mytool is not installed"):
tasks.require_tool("mytool", "Install it with: pip install mytool")
- def test_passes_when_tool_installed(self, monkeypatch: MonkeyPatch) -> None:
+ def test_passes_when_tool_installed(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
monkeypatch.setattr(tasks, "which", lambda _name: "/usr/bin/mytool")
# Act / Assert — no exception means tool is found
tasks.require_tool("mytool", "Install it with: pip install mytool")
- def test_error_message_contains_install_hint(self, monkeypatch: MonkeyPatch) -> None:
+ def test_error_message_contains_install_hint(self, monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
monkeypatch.setattr(tasks, "which", lambda _name: None)
diff --git a/uv.lock b/uv.lock
index 0821d3b6..b832ef39 100644
--- a/uv.lock
+++ b/uv.lock
@@ -686,7 +686,7 @@ wheels = [
[[package]]
name = "infrahub-sdk"
-version = "1.18.1"
+version = "1.19.0"
source = { editable = "." }
dependencies = [
{ name = "dulwich" },