diff --git a/.vale/styles/Infrahub/sentence-case.yml b/.vale/styles/Infrahub/sentence-case.yml index 126e18f6..c27cf7a1 100644 --- a/.vale/styles/Infrahub/sentence-case.yml +++ b/.vale/styles/Infrahub/sentence-case.yml @@ -52,6 +52,7 @@ exceptions: - Jinja - Jinja2 - JWT + - MDX - Namespace - NATS - Node diff --git a/.vale/styles/spelling-exceptions.txt b/.vale/styles/spelling-exceptions.txt index 0a4c0144..ecba179f 100644 --- a/.vale/styles/spelling-exceptions.txt +++ b/.vale/styles/spelling-exceptions.txt @@ -79,6 +79,7 @@ kbps Keycloak Loopbacks markdownlint +MDX max_count memgraph menu_placement diff --git a/CHANGELOG.md b/CHANGELOG.md index b6c418e0..a3200c39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,24 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the changes for the upcoming release can be found in . + +## [1.19.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.19.0) - 2026-03-16 + +### Added + +- Added support for FileObject nodes with file upload and download capabilities. New methods `upload_from_path(path)` and `upload_from_bytes(content, name)` allow setting file content before saving, while `download_file(dest)` enables downloading files to memory or streaming to disk for large files. ([#ihs193](https://github.com/opsmill/infrahub-sdk-python/issues/ihs193)) +- Python SDK API documentation is now generated directly from the docstrings of the classes, functions, and methods contained in the code. ([#201](https://github.com/opsmill/infrahub-sdk-python/issues/201)) +- Added a 'py.typed' file to the project. This is to enable type checking when the Infrahub SDK is imported from other projects. The addition of this file could cause new typing issues in external projects until all typing issues have been resolved. Adding it to the project now to better highlight remaining issues. + +### Changed + +- Updated branch report command to use node metadata for proposed change creator information instead of the deprecated relationship-based approach. Requires Infrahub 1.7 or above. + +### Fixed + +- Allow SDK tracking feature to continue after encountering delete errors due to impacted nodes having already been deleted by cascade delete. ([#265](https://github.com/opsmill/infrahub-sdk-python/issues/265)) +- Fixed Python SDK query generation regarding from_pool generated attribute value ([#497](https://github.com/opsmill/infrahub-sdk-python/issues/497)) + ## [1.18.1](https://github.com/opsmill/infrahub-sdk-python/tree/v1.18.1) - 2026-01-08 ### Fixed diff --git a/changelog/151.added.md b/changelog/151.added.md new file mode 100644 index 00000000..c770976f --- /dev/null +++ b/changelog/151.added.md @@ -0,0 +1 @@ +Add `infrahubctl schema export` command to export schemas from Infrahub. diff --git a/changelog/201.added.md b/changelog/201.added.md deleted file mode 100644 index bb3fbf00..00000000 --- a/changelog/201.added.md +++ /dev/null @@ -1 +0,0 @@ -Python SDK API documentation is now generated directly from the docstrings of the classes, functions, and methods contained in the code. \ No newline at end of file diff --git a/changelog/265.fixed.md b/changelog/265.fixed.md deleted file mode 100644 index 4e3c43a9..00000000 --- a/changelog/265.fixed.md +++ /dev/null @@ -1 +0,0 @@ -Allow SDK tracking feature to continue after encountering delete errors due to impacted nodes having already been deleted by cascade delete. diff --git a/changelog/497.fixed.md b/changelog/497.fixed.md deleted file mode 100644 index b32323d1..00000000 --- a/changelog/497.fixed.md +++ /dev/null @@ -1 +0,0 @@ -Fixed Python SDK query generation regarding from_pool generated attribute value diff --git a/docs/AGENTS.md b/docs/AGENTS.md index 23ae5f68..43e104c6 100644 --- a/docs/AGENTS.md +++ b/docs/AGENTS.md @@ -1,4 +1,4 @@ -# docs/AGENTS.md +# Documentation agents Docusaurus documentation following Diataxis framework. @@ -34,12 +34,12 @@ Sidebar navigation is dynamic: `sidebars-*.ts` files read the filesystem at buil No manual sidebar update is needed when adding a new `.mdx` file. However, to control the display order of a new page, add its doc ID to the ordered list in the corresponding `sidebars-*.ts` file. -## Adding Documentation +## Adding documentation 1. Create MDX file in appropriate directory 2. Add frontmatter with `title` -## MDX Pattern +## MDX pattern Use Tabs for async/sync examples, callouts for notes: diff --git a/docs/docs/infrahubctl/infrahubctl-schema.mdx b/docs/docs/infrahubctl/infrahubctl-schema.mdx index 7d569cc4..1467eae8 100644 --- a/docs/docs/infrahubctl/infrahubctl-schema.mdx +++ b/docs/docs/infrahubctl/infrahubctl-schema.mdx @@ -17,6 +17,7 @@ $ infrahubctl schema [OPTIONS] COMMAND [ARGS]... **Commands**: * `check`: Check if schema files are valid and what... +* `export`: Export the schema from Infrahub as YAML... * `load`: Load one or multiple schema files into... ## `infrahubctl schema check` @@ -40,6 +41,25 @@ $ infrahubctl schema check [OPTIONS] SCHEMAS... * `--config-file TEXT`: [env var: INFRAHUBCTL_CONFIG; default: infrahubctl.toml] * `--help`: Show this message and exit. +## `infrahubctl schema export` + +Export the schema from Infrahub as YAML files, one per namespace. + +**Usage**: + +```console +$ infrahubctl schema export [OPTIONS] +``` + +**Options**: + +* `--directory PATH`: Directory path to store schema files [default: (dynamic)] +* `--branch TEXT`: Branch from which to export the schema +* `--namespaces TEXT`: Namespace(s) to export (default: all user-defined) +* `--debug / --no-debug`: [default: no-debug] +* `--config-file TEXT`: [env var: INFRAHUBCTL_CONFIG; default: infrahubctl.toml] +* `--help`: Show this message and exit. + ## `infrahubctl schema load` Load one or multiple schema files into Infrahub. diff --git a/docs/docs/python-sdk/guides/client.mdx b/docs/docs/python-sdk/guides/client.mdx index 460036a1..90872a0f 100644 --- a/docs/docs/python-sdk/guides/client.mdx +++ b/docs/docs/python-sdk/guides/client.mdx @@ -251,7 +251,7 @@ Your client is now configured to use the specified default branch instead of `ma ## Hello world example -Let's create a simple "Hello World" example to verify your client configuration works correctly. This example will connect to your Infrahub instance and query the available accounts. +Let's create a "Hello World" example to verify your client configuration works correctly. This example will connect to your Infrahub instance and query the available accounts. 1. Create a new file called `hello_world.py`: diff --git a/docs/docs/python-sdk/guides/python-typing.mdx b/docs/docs/python-sdk/guides/python-typing.mdx index 9bc2c323..77780177 100644 --- a/docs/docs/python-sdk/guides/python-typing.mdx +++ b/docs/docs/python-sdk/guides/python-typing.mdx @@ -131,7 +131,7 @@ infrahubctl graphql generate-return-types queries/get_tags.gql ### Example workflow -1. **Create your GraphQL queries** in `.gql` files preferably in a directory (e.g., `queries/`): +1. **Create your GraphQL queries** in `.gql` files preferably in a directory (for example, `queries/`): ```graphql # queries/get_tags.gql diff --git a/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx b/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx index ff1b817f..e23120dd 100644 --- a/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx +++ b/docs/docs/python-sdk/sdk_ref/infrahub_sdk/node/node.mdx @@ -37,6 +37,44 @@ artifact_generate(self, name: str) -> None artifact_fetch(self, name: str) -> str | dict[str, Any] ``` +#### `download_file` + +```python +download_file(self, dest: Path | None = None) -> bytes | int +``` + +Download the file content from this FileObject node. + +This method is only available for nodes that inherit from CoreFileObject. +The node must have been saved (have an id) before calling this method. + +**Args:** + +- `dest`: Optional destination path. If provided, the file will be streamed + directly to this path (memory-efficient for large files) and the + number of bytes written will be returned. If not provided, the + file content will be returned as bytes. + +**Returns:** + +- If ``dest`` is None: The file content as bytes. +- If ``dest`` is provided: The number of bytes written to the file. + +**Raises:** + +- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject. +- `ValueError`: If the node hasn't been saved yet or file not found. +- `AuthenticationError`: If authentication fails. + +**Examples:** + +```python +>>> # Download to memory +>>> content = await contract.download_file() +>>> # Stream to file (memory-efficient for large files) +>>> bytes_written = await contract.download_file(dest=Path("/tmp/contract.pdf")) +``` + #### `delete` ```python @@ -180,6 +218,44 @@ artifact_generate(self, name: str) -> None artifact_fetch(self, name: str) -> str | dict[str, Any] ``` +#### `download_file` + +```python +download_file(self, dest: Path | None = None) -> bytes | int +``` + +Download the file content from this FileObject node. + +This method is only available for nodes that inherit from CoreFileObject. +The node must have been saved (have an id) before calling this method. + +**Args:** + +- `dest`: Optional destination path. If provided, the file will be streamed + directly to this path (memory-efficient for large files) and the + number of bytes written will be returned. If not provided, the + file content will be returned as bytes. + +**Returns:** + +- If ``dest`` is None: The file content as bytes. +- If ``dest`` is provided: The number of bytes written to the file. + +**Raises:** + +- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject. +- `ValueError`: If the node hasn't been saved yet or file not found. +- `AuthenticationError`: If authentication fails. + +**Examples:** + +```python +>>> # Download to memory +>>> content = contract.download_file() +>>> # Stream to file (memory-efficient for large files) +>>> bytes_written = contract.download_file(dest=Path("/tmp/contract.pdf")) +``` + #### `delete` ```python @@ -373,6 +449,70 @@ is_ip_address(self) -> bool is_resource_pool(self) -> bool ``` +#### `is_file_object` + +```python +is_file_object(self) -> bool +``` + +Check if this node inherits from CoreFileObject and supports file uploads. + +#### `upload_from_path` + +```python +upload_from_path(self, path: Path) -> None +``` + +Set a file from disk to be uploaded when saving this FileObject node. + +The file will be streamed during upload, avoiding loading the entire file into memory. + +**Args:** + +- `path`: Path to the file on disk. + +**Raises:** + +- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject. + +#### `upload_from_bytes` + +```python +upload_from_bytes(self, content: bytes | BinaryIO, name: str) -> None +``` + +Set content to be uploaded when saving this FileObject node. + +The content can be provided as bytes or a file-like object. +Using BinaryIO is recommended for large content to stream during upload. + +**Args:** + +- `content`: The file content as bytes or a file-like object. +- `name`: The filename to use for the uploaded file. + +**Raises:** + +- `FeatureNotSupportedError`: If this node doesn't inherit from CoreFileObject. + +**Examples:** + +```python +>>> # Using bytes (for small files) +>>> node.upload_from_bytes(content=b"file content", name="example.txt") +>>> # Using file-like object (for large files) +>>> with open("/path/to/file.bin", "rb") as f: +... node.upload_from_bytes(content=f, name="file.bin") +``` + +#### `clear_file` + +```python +clear_file(self) -> None +``` + +Clear any pending file content. + #### `get_raw_graphql_data` ```python diff --git a/docs/docs/python-sdk/topics/object_file.mdx b/docs/docs/python-sdk/topics/object_file.mdx index 751599f0..75744d89 100644 --- a/docs/docs/python-sdk/topics/object_file.mdx +++ b/docs/docs/python-sdk/topics/object_file.mdx @@ -68,13 +68,13 @@ spec: > Multiple documents in a single YAML file are also supported, each document will be loaded separately. Documents are separated by `---` -### Data Processing Parameters +### Data processing parameters The `parameters` field controls how the data in the object file is processed before loading into Infrahub: -| Parameter | Description | Default | -| -------------- | ------------------------------------------------------------------------------------------------------- | ------- | -| `expand_range` | When set to `true`, range patterns (e.g., `[1-5]`) in string fields are expanded into multiple objects. | `false` | +| Parameter | Description | Default | +| -------------- | -------------------------------------------------------------------------------------------------------------- | ------- | +| `expand_range` | When set to `true`, range patterns (for example, `[1-5]`) in string fields are expanded into multiple objects. | `false` | When `expand_range` is not specified, it defaults to `false`. @@ -208,9 +208,9 @@ Metadata support is planned for future releases. Currently, the Object file does 3. Validate object files before loading them into production environments. 4. Use comments in your YAML files to document complex relationships or dependencies. -## Range Expansion in Object Files +## Range expansion in object files -The Infrahub Python SDK supports **range expansion** for string fields in object files when the `parameters > expand_range` is set to `true`. This feature allows you to specify a range pattern (e.g., `[1-5]`) in any string value, and the SDK will automatically expand it into multiple objects during validation and processing. +The Infrahub Python SDK supports **range expansion** for string fields in object files when the `parameters > expand_range` is set to `true`. This feature allows you to specify a range pattern (for example, `[1-5]`) in any string value, and the SDK will automatically expand it into multiple objects during validation and processing. ```yaml --- @@ -225,7 +225,7 @@ spec: type: Country ``` -### How Range Expansion Works +### How range expansion works - Any string field containing a pattern like `[1-5]`, `[10-15]`, or `[1,3,5]` will be expanded into multiple objects. - If multiple fields in the same object use range expansion, **all expanded lists must have the same length**. If not, validation will fail. @@ -233,7 +233,7 @@ spec: ### Examples -#### Single Field Expansion +#### Single field expansion ```yaml spec: @@ -256,7 +256,7 @@ This will expand to: type: Country ``` -#### Multiple Field Expansion (Matching Lengths) +#### Multiple field expansion (matching lengths) ```yaml spec: @@ -283,7 +283,7 @@ This will expand to: type: Country ``` -#### Error: Mismatched Range Lengths +#### Error: mismatched range lengths If you use ranges of different lengths in multiple fields: diff --git a/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py b/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py index 95ea8f00..1f680d15 100644 --- a/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py +++ b/docs/docs_generation/content_gen_methods/mdx/mdx_collapsed_overload_code_doc.py @@ -44,7 +44,7 @@ def _collapse_overloads(self, mdx_file: MdxFile) -> MdxFile: h3_parsed = _parse_sections(h2.content, heading_level=3) new_lines = h3_parsed.reassembled(processed_h3) processed_h2.append( - MdxSection(name=h2.name, heading_level=h2.heading_level, _lines=[h2.heading] + new_lines) + MdxSection(name=h2.name, heading_level=h2.heading_level, _lines=[h2.heading, *new_lines]) ) new_content = "\n".join(parsed_h2.reassembled(processed_h2)) @@ -67,7 +67,7 @@ def _process_class_sections(self, h2_content: list[str]) -> list[ASection] | Non h4_parsed = _parse_sections(h3.content, heading_level=4) new_lines = h4_parsed.reassembled(collapsed_methods) processed.append( - MdxSection(name=h3.name, heading_level=h3.heading_level, _lines=[h3.heading] + new_lines) + MdxSection(name=h3.name, heading_level=h3.heading_level, _lines=[h3.heading, *new_lines]) ) return processed if any_collapsed else None diff --git a/docs/docs_generation/content_gen_methods/mdx/mdx_section.py b/docs/docs_generation/content_gen_methods/mdx/mdx_section.py index 6187f67a..ed4e25fa 100644 --- a/docs/docs_generation/content_gen_methods/mdx/mdx_section.py +++ b/docs/docs_generation/content_gen_methods/mdx/mdx_section.py @@ -19,7 +19,7 @@ def content(self) -> list[str]: ... @property def lines(self) -> list[str]: - return [self.heading] + self.content + return [self.heading, *self.content] @dataclass diff --git a/infrahub_sdk/branch.py b/infrahub_sdk/branch.py index 2c32a481..4c89bd41 100644 --- a/infrahub_sdk/branch.py +++ b/infrahub_sdk/branch.py @@ -19,6 +19,7 @@ class BranchStatus(str, Enum): NEED_REBASE = "NEED_REBASE" NEED_UPGRADE_REBASE = "NEED_UPGRADE_REBASE" DELETING = "DELETING" + MERGED = "MERGED" class BranchData(BaseModel): diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index f2ebfab7..988d30bd 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -5,18 +5,12 @@ import logging import time import warnings -from collections.abc import Callable, Coroutine, Mapping, MutableMapping +from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Mapping, MutableMapping +from contextlib import asynccontextmanager, contextmanager from datetime import datetime from functools import wraps from time import sleep -from typing import ( - TYPE_CHECKING, - Any, - Literal, - TypedDict, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypedDict, TypeVar, overload from urllib.parse import urlencode import httpx @@ -24,12 +18,7 @@ from typing_extensions import Self from .batch import InfrahubBatch, InfrahubBatchSync -from .branch import ( - MUTATION_QUERY_TASK, - BranchData, - InfrahubBranchManager, - InfrahubBranchManagerSync, -) +from .branch import MUTATION_QUERY_TASK, BranchData, InfrahubBranchManager, InfrahubBranchManagerSync from .config import Config from .constants import InfrahubClientMode from .convert_object_type import CONVERT_OBJECT_MUTATION, ConversionFieldInput @@ -44,11 +33,8 @@ ServerNotResponsiveError, URLNotFoundError, ) -from .graphql import Mutation, Query -from .node import ( - InfrahubNode, - InfrahubNodeSync, -) +from .graphql import MultipartBuilder, Mutation, Query +from .node import InfrahubNode, InfrahubNodeSync from .object_store import ObjectStore, ObjectStoreSync from .protocols_base import CoreNode, CoreNodeSync from .queries import QUERY_USER, get_commit_update_mutation @@ -994,7 +980,7 @@ async def execute_graphql( messages = [error.get("message") for error in errors] raise AuthenticationError(" | ".join(messages)) from exc if exc.response.status_code == 404: - raise URLNotFoundError(url=url) + raise URLNotFoundError(url=url) from exc if not resp: raise Error("Unexpected situation, resp hasn't been initialized.") @@ -1008,6 +994,128 @@ async def execute_graphql( # TODO add a special method to execute mutation that will check if the method returned OK + async def _execute_graphql_with_file( + self, + query: str, + variables: dict | None = None, + file_content: BinaryIO | None = None, + file_name: str | None = None, + branch_name: str | None = None, + timeout: int | None = None, + tracker: str | None = None, + ) -> dict: + """Execute a GraphQL mutation with a file upload using multipart/form-data. + + This method follows the GraphQL Multipart Request Spec for file uploads. + The file is attached to the 'file' variable in the mutation. + + Args: + query: GraphQL mutation query that includes a $file variable of type Upload! + variables: Variables to pass along with the GraphQL query. + file_content: The file content as a file-like object (BinaryIO). + file_name: The name of the file being uploaded. + branch_name: Name of the branch on which the mutation will be executed. + timeout: Timeout in seconds for the query. + tracker: Optional tracker for request tracing. + + Raises: + GraphQLError: When the GraphQL response contains errors. + + Returns: + dict: The GraphQL data payload (response["data"]). + """ + branch_name = branch_name or self.default_branch + url = self._graphql_url(branch_name=branch_name) + + # Prepare variables with file placeholder + variables = variables or {} + variables["file"] = None + + headers = copy.copy(self.headers or {}) + # Remove content-type header - httpx will set it for multipart + headers.pop("content-type", None) + if self.insert_tracker and tracker: + headers["X-Infrahub-Tracker"] = tracker + + self._echo(url=url, query=query, variables=variables) + + resp = await self._post_multipart( + url=url, + query=query, + variables=variables, + file_content=file_content, + file_name=file_name or "upload", + headers=headers, + timeout=timeout, + ) + + resp.raise_for_status() + response = decode_json(response=resp) + + if "errors" in response: + raise GraphQLError(errors=response["errors"], query=query, variables=variables) + + return response["data"] + + @handle_relogin + async def _post_multipart( + self, + url: str, + query: str, + variables: dict, + file_content: BinaryIO | None, + file_name: str, + headers: dict | None = None, + timeout: int | None = None, + ) -> httpx.Response: + """Execute a HTTP POST with multipart/form-data for GraphQL file uploads. + + The file_content is streamed directly from the file-like object, avoiding loading the entire file into memory for large files. + """ + await self.login() + + headers = headers or {} + base_headers = copy.copy(self.headers or {}) + # Remove content-type from base headers - httpx will set it for multipart + base_headers.pop("content-type", None) + headers.update(base_headers) + + # Build the multipart form data according to GraphQL Multipart Request Spec + files = MultipartBuilder.build_payload( + query=query, variables=variables, file_content=file_content, file_name=file_name + ) + + return await self._request_multipart( + url=url, headers=headers, timeout=timeout or self.default_timeout, files=files + ) + + def _build_proxy_config(self) -> ProxyConfig: + """Build proxy configuration for httpx AsyncClient.""" + proxy_config: ProxyConfig = {"proxy": None, "mounts": None} + if self.config.proxy: + proxy_config["proxy"] = self.config.proxy + elif self.config.proxy_mounts.is_set: + proxy_config["mounts"] = { + key: httpx.AsyncHTTPTransport(proxy=value) + for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items() + } + return proxy_config + + async def _request_multipart( + self, url: str, headers: dict[str, Any], timeout: int, files: dict[str, Any] + ) -> httpx.Response: + """Execute a multipart HTTP POST request.""" + async with httpx.AsyncClient(**self._build_proxy_config(), verify=self.config.tls_context) as client: + try: + response = await client.post(url=url, headers=headers, timeout=timeout, files=files) + except httpx.NetworkError as exc: + raise ServerNotReachableError(address=self.address) from exc + except httpx.ReadTimeout as exc: + raise ServerNotResponsiveError(url=url, timeout=timeout) from exc + + self._record(response) + return response + @handle_relogin async def _post( self, @@ -1057,6 +1165,36 @@ async def _get(self, url: str, headers: dict | None = None, timeout: int | None timeout=timeout or self.default_timeout, ) + @asynccontextmanager + async def _get_streaming( + self, url: str, headers: dict | None = None, timeout: int | None = None + ) -> AsyncIterator[httpx.Response]: + """Execute a streaming HTTP GET with HTTPX. + + Returns an async context manager that yields the streaming response. + Use this for downloading large files without loading into memory. + + Raises: + ServerNotReachableError if we are not able to connect to the server + ServerNotResponsiveError if the server didn't respond before the timeout expired + """ + await self.login() + + headers = headers or {} + base_headers = copy.copy(self.headers or {}) + headers.update(base_headers) + + async with httpx.AsyncClient(**self._build_proxy_config(), verify=self.config.tls_context) as client: + try: + async with client.stream( + method="GET", url=url, headers=headers, timeout=timeout or self.default_timeout + ) as response: + yield response + except httpx.NetworkError as exc: + raise ServerNotReachableError(address=self.address) from exc + except httpx.ReadTimeout as exc: + raise ServerNotResponsiveError(url=url, timeout=timeout or self.default_timeout) from exc + async def _request( self, url: str, @@ -1081,19 +1219,7 @@ async def _default_request_method( if payload: params["json"] = payload - proxy_config: ProxyConfig = {"proxy": None, "mounts": None} - if self.config.proxy: - proxy_config["proxy"] = self.config.proxy - elif self.config.proxy_mounts.is_set: - proxy_config["mounts"] = { - key: httpx.AsyncHTTPTransport(proxy=value) - for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items() - } - - async with httpx.AsyncClient( - **proxy_config, - verify=self.config.tls_context, - ) as client: + async with httpx.AsyncClient(**self._build_proxy_config(), verify=self.config.tls_context) as client: try: response = await client.request( method=method.value, @@ -1742,10 +1868,11 @@ async def convert_object_type( for more information. """ - if fields_mapping is None: - mapping_dict = {} - else: - mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()} + mapping_dict = ( + {} + if fields_mapping is None + else {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()} + ) branch_name = branch or self.default_branch response = await self.execute_graphql( @@ -1910,7 +2037,7 @@ def execute_graphql( messages = [error.get("message") for error in errors] raise AuthenticationError(" | ".join(messages)) from exc if exc.response.status_code == 404: - raise URLNotFoundError(url=url) + raise URLNotFoundError(url=url) from exc if not resp: raise Error("Unexpected situation, resp hasn't been initialized.") @@ -1924,6 +2051,126 @@ def execute_graphql( # TODO add a special method to execute mutation that will check if the method returned OK + def _execute_graphql_with_file( + self, + query: str, + variables: dict | None = None, + file_content: BinaryIO | None = None, + file_name: str | None = None, + branch_name: str | None = None, + timeout: int | None = None, + tracker: str | None = None, + ) -> dict: + """Execute a GraphQL mutation with a file upload using multipart/form-data. + + This method follows the GraphQL Multipart Request Spec for file uploads. + The file is attached to the 'file' variable in the mutation. + + Args: + query: GraphQL mutation query that includes a $file variable of type Upload! + variables: Variables to pass along with the GraphQL query. + file_content: The file content as a file-like object (BinaryIO). + file_name: The name of the file being uploaded. + branch_name: Name of the branch on which the mutation will be executed. + timeout: Timeout in seconds for the query. + tracker: Optional tracker for request tracing. + + Raises: + GraphQLError: When the GraphQL response contains errors. + + Returns: + dict: The GraphQL data payload (response["data"]). + """ + branch_name = branch_name or self.default_branch + url = self._graphql_url(branch_name=branch_name) + + # Prepare variables with file placeholder + variables = variables or {} + variables["file"] = None + + headers = copy.copy(self.headers or {}) + # Remove content-type header - httpx will set it for multipart + headers.pop("content-type", None) + if self.insert_tracker and tracker: + headers["X-Infrahub-Tracker"] = tracker + + self._echo(url=url, query=query, variables=variables) + + resp = self._post_multipart( + url=url, + query=query, + variables=variables, + file_content=file_content, + file_name=file_name or "upload", + headers=headers, + timeout=timeout, + ) + + resp.raise_for_status() + response = decode_json(response=resp) + + if "errors" in response: + raise GraphQLError(errors=response["errors"], query=query, variables=variables) + + return response["data"] + + @handle_relogin_sync + def _post_multipart( + self, + url: str, + query: str, + variables: dict, + file_content: BinaryIO | None, + file_name: str, + headers: dict | None = None, + timeout: int | None = None, + ) -> httpx.Response: + """Execute a HTTP POST with multipart/form-data for GraphQL file uploads. + + The file_content is streamed directly from the file-like object, avoiding loading the entire file into memory for large files. + """ + self.login() + + headers = headers or {} + base_headers = copy.copy(self.headers or {}) + # Remove content-type from base headers - httpx will set it for multipart + base_headers.pop("content-type", None) + headers.update(base_headers) + + # Build the multipart form data according to GraphQL Multipart Request Spec + files = MultipartBuilder.build_payload( + query=query, variables=variables, file_content=file_content, file_name=file_name + ) + + return self._request_multipart(url=url, headers=headers, timeout=timeout or self.default_timeout, files=files) + + def _build_proxy_config(self) -> ProxyConfigSync: + """Build proxy configuration for httpx Client.""" + proxy_config: ProxyConfigSync = {"proxy": None, "mounts": None} + if self.config.proxy: + proxy_config["proxy"] = self.config.proxy + elif self.config.proxy_mounts.is_set: + proxy_config["mounts"] = { + key: httpx.HTTPTransport(proxy=value) + for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items() + } + return proxy_config + + def _request_multipart( + self, url: str, headers: dict[str, Any], timeout: int, files: dict[str, Any] + ) -> httpx.Response: + """Execute a multipart HTTP POST request.""" + with httpx.Client(**self._build_proxy_config(), verify=self.config.tls_context) as client: + try: + response = client.post(url=url, headers=headers, timeout=timeout, files=files) + except httpx.NetworkError as exc: + raise ServerNotReachableError(address=self.address) from exc + except httpx.ReadTimeout as exc: + raise ServerNotResponsiveError(url=url, timeout=timeout) from exc + + self._record(response) + return response + def count( self, kind: str | type[SchemaType], @@ -2995,6 +3242,36 @@ def _get(self, url: str, headers: dict | None = None, timeout: int | None = None timeout=timeout or self.default_timeout, ) + @contextmanager + def _get_streaming( + self, url: str, headers: dict | None = None, timeout: int | None = None + ) -> Iterator[httpx.Response]: + """Execute a streaming HTTP GET with HTTPX. + + Returns a context manager that yields the streaming response. + Use this for downloading large files without loading into memory. + + Raises: + ServerNotReachableError if we are not able to connect to the server + ServerNotResponsiveError if the server didn't respond before the timeout expired + """ + self.login() + + headers = headers or {} + base_headers = copy.copy(self.headers or {}) + headers.update(base_headers) + + with httpx.Client(**self._build_proxy_config(), verify=self.config.tls_context) as client: + try: + with client.stream( + method="GET", url=url, headers=headers, timeout=timeout or self.default_timeout + ) as response: + yield response + except httpx.NetworkError as exc: + raise ServerNotReachableError(address=self.address) from exc + except httpx.ReadTimeout as exc: + raise ServerNotResponsiveError(url=url, timeout=timeout or self.default_timeout) from exc + @handle_relogin_sync def _post( self, @@ -3047,20 +3324,7 @@ def _default_request_method( if payload: params["json"] = payload - proxy_config: ProxyConfigSync = {"proxy": None, "mounts": None} - - if self.config.proxy: - proxy_config["proxy"] = self.config.proxy - elif self.config.proxy_mounts.is_set: - proxy_config["mounts"] = { - key: httpx.HTTPTransport(proxy=value) - for key, value in self.config.proxy_mounts.model_dump(by_alias=True).items() - } - - with httpx.Client( - **proxy_config, - verify=self.config.tls_context, - ) as client: + with httpx.Client(**self._build_proxy_config(), verify=self.config.tls_context) as client: try: response = client.request( method=method.value, @@ -3162,10 +3426,11 @@ def convert_object_type( for more information. """ - if fields_mapping is None: - mapping_dict = {} - else: - mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()} + mapping_dict = ( + {} + if fields_mapping is None + else {field_name: model.model_dump(mode="json") for field_name, model in fields_mapping.items()} + ) branch_name = branch or self.default_branch response = self.execute_graphql( diff --git a/infrahub_sdk/ctl/branch.py b/infrahub_sdk/ctl/branch.py index 60d67e86..d309c1d5 100644 --- a/infrahub_sdk/ctl/branch.py +++ b/infrahub_sdk/ctl/branch.py @@ -110,6 +110,10 @@ def generate_proposed_change_tables(proposed_changes: list[CoreProposedChange]) proposed_change_tables: list[Table] = [] for pc in proposed_changes: + metadata = pc.get_node_metadata() + created_by = metadata.created_by.display_label if metadata and metadata.created_by else "-" + created_at = format_timestamp(metadata.created_at) if metadata and metadata.created_at else "-" + # Create proposal table proposed_change_table = Table(show_header=False, box=None) proposed_change_table.add_column(justify="left") @@ -119,8 +123,8 @@ def generate_proposed_change_tables(proposed_changes: list[CoreProposedChange]) proposed_change_table.add_row("Name", pc.name.value) proposed_change_table.add_row("State", str(pc.state.value)) proposed_change_table.add_row("Is draft", "Yes" if pc.is_draft.value else "No") - proposed_change_table.add_row("Created by", pc.created_by.peer.name.value) # type: ignore[union-attr] - proposed_change_table.add_row("Created at", format_timestamp(str(pc.created_by.updated_at))) + proposed_change_table.add_row("Created by", created_by) + proposed_change_table.add_row("Created at", created_at) proposed_change_table.add_row("Approvals", str(len(pc.approved_by.peers))) proposed_change_table.add_row("Rejections", str(len(pc.rejected_by.peers))) @@ -295,9 +299,9 @@ async def report( proposed_changes = await client.filters( kind=CoreProposedChange, # type: ignore[type-abstract] source_branch__value=branch_name, - include=["created_by"], prefetch_relationships=True, property=True, + include_metadata=True, ) branch_table = generate_branch_report_table(branch=branch, diff_tree=diff_tree, git_files_changed=git_files_changed) diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 2b571723..d7a636ed 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -239,7 +239,7 @@ async def _run_transform( elif isinstance(error, str) and "Branch:" in error: console.print(f"[yellow] - {error}") console.print("[yellow] you can specify a different branch with --branch") - raise typer.Abort + raise typer.Abort from None if inspect.iscoroutinefunction(transform_func): output = await transform_func(response) @@ -350,10 +350,7 @@ def transform( # Run Transform result = asyncio.run(transform.run(data=data)) - if isinstance(result, str): - json_string = result - else: - json_string = ujson.dumps(result, indent=2, sort_keys=True) + json_string = result if isinstance(result, str) else ujson.dumps(result, indent=2, sort_keys=True) if out: write_to_file(Path(out), json_string) diff --git a/infrahub_sdk/ctl/config.py b/infrahub_sdk/ctl/config.py index b3d2a404..a5b522b2 100644 --- a/infrahub_sdk/ctl/config.py +++ b/infrahub_sdk/ctl/config.py @@ -90,7 +90,7 @@ def load_and_exit(self, config_file: str | Path = "infrahubctl.toml", config_dat for error in exc.errors(): loc_str = [str(item) for item in error["loc"]] print(f" {'/'.join(loc_str)} | {error['msg']} ({error['type']})") - raise typer.Abort + raise typer.Abort from None SETTINGS = ConfiguredSettings() diff --git a/infrahub_sdk/ctl/exporter.py b/infrahub_sdk/ctl/exporter.py index ae5e5d18..402b47ec 100644 --- a/infrahub_sdk/ctl/exporter.py +++ b/infrahub_sdk/ctl/exporter.py @@ -46,4 +46,4 @@ def dump( aiorun(exporter.export(export_directory=directory, namespaces=namespace, branch=branch, exclude=exclude)) except TransferError as exc: console.print(f"[red]{exc}") - raise typer.Exit(1) + raise typer.Exit(1) from None diff --git a/infrahub_sdk/ctl/importer.py b/infrahub_sdk/ctl/importer.py index 420c6d75..d3318eb5 100644 --- a/infrahub_sdk/ctl/importer.py +++ b/infrahub_sdk/ctl/importer.py @@ -50,4 +50,4 @@ def load( aiorun(importer.import_data(import_directory=directory, branch=branch)) except TransferError as exc: console.print(f"[red]{exc}") - raise typer.Exit(1) + raise typer.Exit(1) from None diff --git a/infrahub_sdk/ctl/schema.py b/infrahub_sdk/ctl/schema.py index 5a977b59..9532959e 100644 --- a/infrahub_sdk/ctl/schema.py +++ b/infrahub_sdk/ctl/schema.py @@ -2,6 +2,7 @@ import asyncio import time +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Any @@ -211,3 +212,49 @@ def _display_schema_warnings(console: Console, warnings: list[SchemaWarning]) -> console.print( f"[yellow] {warning.type.value}: {warning.message} [{', '.join([kind.display for kind in warning.kinds])}]" ) + + +def _default_export_directory() -> Path: + timestamp = datetime.now(timezone.utc).astimezone().strftime("%Y%m%d-%H%M%S") + return Path(f"infrahub-schema-export-{timestamp}") + + +@app.command() +@catch_exception(console=console) +async def export( + directory: Path = typer.Option(_default_export_directory, help="Directory path to store schema files"), + branch: str = typer.Option(None, help="Branch from which to export the schema"), + namespaces: list[str] = typer.Option([], help="Namespace(s) to export (default: all user-defined)"), + debug: bool = False, + _: str = CONFIG_PARAM, +) -> None: + """Export the schema from Infrahub as YAML files, one per namespace.""" + init_logging(debug=debug) + + client = initialize_client() + user_schemas = await client.schema.export( + branch=branch, + namespaces=namespaces or None, + ) + + if not user_schemas.namespaces: + console.print("[yellow]No user-defined schema found to export.") + return + + directory.mkdir(parents=True, exist_ok=True) + + for ns, data in sorted(user_schemas.namespaces.items()): + payload: dict[str, Any] = {"version": "1.0"} + if data.generics: + payload["generics"] = data.generics + if data.nodes: + payload["nodes"] = data.nodes + + output_file = directory / f"{ns.lower()}.yml" + output_file.write_text( + yaml.dump(payload, default_flow_style=False, sort_keys=False, allow_unicode=True), + encoding="utf-8", + ) + console.print(f"[green] Exported namespace '{ns}' to {output_file}") + + console.print(f"[green] Schema exported to {directory}") diff --git a/infrahub_sdk/ctl/validate.py b/infrahub_sdk/ctl/validate.py index 3ffbd85a..07256faf 100644 --- a/infrahub_sdk/ctl/validate.py +++ b/infrahub_sdk/ctl/validate.py @@ -48,7 +48,7 @@ async def validate_schema(schema: Path, _: str = CONFIG_PARAM) -> None: for error in exc.errors(): loc_str = [str(item) for item in error["loc"]] console.print(f" '{'/'.join(loc_str)}' | {error['msg']} ({error['type']})") - raise typer.Exit(1) + raise typer.Exit(1) from None console.print("[green]Schema is valid !!") diff --git a/infrahub_sdk/exceptions.py b/infrahub_sdk/exceptions.py index e1dba1d5..727239bf 100644 --- a/infrahub_sdk/exceptions.py +++ b/infrahub_sdk/exceptions.py @@ -136,7 +136,7 @@ def __init__(self, position: list[int | str], message: str) -> None: super().__init__(self.message) def __str__(self) -> str: - return f"{'.'.join(map(str, self.position))}: {self.message}" + return f"{'.'.join(str(p) for p in self.position)}: {self.message}" class AuthenticationError(Error): diff --git a/infrahub_sdk/file_handler.py b/infrahub_sdk/file_handler.py new file mode 100644 index 00000000..5d32441a --- /dev/null +++ b/infrahub_sdk/file_handler.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, BinaryIO, cast, overload + +import anyio +import httpx + +from .exceptions import AuthenticationError, NodeNotFoundError, ServerNotReachableError + +if TYPE_CHECKING: + from .client import InfrahubClient, InfrahubClientSync + + +@dataclass +class PreparedFile: + file_object: BinaryIO | None + filename: str | None + should_close: bool + + +class FileHandlerBase: + """Base class for file handling operations. + + Provides common functionality for both async and sync file handlers, including upload preparation and error handling. + """ + + @staticmethod + async def prepare_upload(content: bytes | Path | BinaryIO | None, name: str | None = None) -> PreparedFile: + """Prepare file content for upload (async version). + + Converts various content types to a consistent BinaryIO interface for streaming uploads. + For Path inputs, opens the file handle in a thread pool to avoid blocking the event loop. + The actual file reading is streamed by httpx during the HTTP request. + + Args: + content: The file content as bytes, a Path to a file, or a file-like object. + Can be None if no file is set. + name: Optional filename. If not provided and content is a Path, + the filename will be derived from the path. + + Returns: + A PreparedFile containing the file object, filename, and whether it should be closed. + """ + if content is None: + return PreparedFile(file_object=None, filename=None, should_close=False) + + if name is None and isinstance(content, Path): + name = content.name + + filename = name or "uploaded_file" + + if isinstance(content, bytes): + return PreparedFile(file_object=BytesIO(content), filename=filename, should_close=False) + if isinstance(content, Path): + # Open file in thread pool to avoid blocking the event loop + # Returns a sync file handle that httpx can stream from in chunks + file_obj = await anyio.to_thread.run_sync(content.open, "rb") + return PreparedFile(file_object=cast("BinaryIO", file_obj), filename=filename, should_close=True) + + # At this point, content must be a BinaryIO (file-like object) + return PreparedFile(file_object=cast("BinaryIO", content), filename=filename, should_close=False) + + @staticmethod + def prepare_upload_sync(content: bytes | Path | BinaryIO | None, name: str | None = None) -> PreparedFile: + """Prepare file content for upload (sync version). + + Converts various content types to a consistent BinaryIO interface for streaming uploads. + + Args: + content: The file content as bytes, a Path to a file, or a file-like object. + Can be None if no file is set. + name: Optional filename. If not provided and content is a Path, + the filename will be derived from the path. + + Returns: + A PreparedFile containing the file object, filename, and whether it should be closed. + """ + if content is None: + return PreparedFile(file_object=None, filename=None, should_close=False) + + if name is None and isinstance(content, Path): + name = content.name + + filename = name or "uploaded_file" + + if isinstance(content, bytes): + return PreparedFile(file_object=BytesIO(content), filename=filename, should_close=False) + if isinstance(content, Path): + return PreparedFile(file_object=content.open("rb"), filename=filename, should_close=True) + + # At this point, content must be a BinaryIO (file-like object) + return PreparedFile(file_object=cast("BinaryIO", content), filename=filename, should_close=False) + + @staticmethod + def handle_error_response(exc: httpx.HTTPStatusError) -> None: + """Handle HTTP error responses for file operations. + + Args: + exc: The HTTP status error from httpx. + + Raises: + AuthenticationError: If authentication fails (401/403). + NodeNotFoundError: If the file/node is not found (404). + httpx.HTTPStatusError: For other HTTP errors. + """ + if exc.response.status_code in {401, 403}: + response = exc.response.json() + errors = response.get("errors", []) + messages = [error.get("message") for error in errors] + raise AuthenticationError(" | ".join(messages)) from exc + if exc.response.status_code == 404: + response = exc.response.json() + detail = response.get("detail", "File not found") + raise NodeNotFoundError(node_type="FileObject", identifier=detail) from exc + raise exc + + @staticmethod + def handle_response(resp: httpx.Response) -> bytes: + """Handle the HTTP response and return file content as bytes. + + Args: + resp: The HTTP response from httpx. + + Returns: + The file content as bytes. + + Raises: + AuthenticationError: If authentication fails. + NodeNotFoundError: If the file is not found. + """ + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + FileHandlerBase.handle_error_response(exc=exc) + return resp.content + + +class FileHandler(FileHandlerBase): + """Async file handler for download operations. + + Handles file downloads with support for streaming to disk + for memory-efficient handling of large files. + """ + + def __init__(self, client: InfrahubClient) -> None: + """Initialize the async file handler. + + Args: + client: The async Infrahub client instance. + """ + self._client = client + + def _build_url(self, node_id: str, branch: str | None) -> str: + """Build the download URL for a file. + + Args: + node_id: The ID of the FileObject node. + branch: Optional branch name. + + Returns: + The complete URL for downloading the file. + """ + url = f"{self._client.address}/api/storage/files/{node_id}" + if branch: + url = f"{url}?branch={branch}" + return url + + @overload + async def download(self, node_id: str, branch: str | None) -> bytes: ... + + @overload + async def download(self, node_id: str, branch: str | None, dest: Path) -> int: ... + + @overload + async def download(self, node_id: str, branch: str | None, dest: None) -> bytes: ... + + async def download(self, node_id: str, branch: str | None, dest: Path | None = None) -> bytes | int: + """Download file content from a FileObject node. + + Args: + node_id: The ID of the FileObject node. + branch: Optional branch name. Uses client default if not provided. + dest: Optional destination path. If provided, streams to disk. + + Returns: + If dest is None: The file content as bytes. + If dest is provided: The number of bytes written. + + Raises: + ServerNotReachableError: If the server is not reachable. + AuthenticationError: If authentication fails. + NodeNotFoundError: If the node/file is not found. + """ + effective_branch = branch or self._client.default_branch + url = self._build_url(node_id=node_id, branch=effective_branch) + + if dest is not None: + return await self._stream_to_file(url=url, dest=dest) + + try: + resp = await self._client._get(url=url) + except ServerNotReachableError: + self._client.log.error(f"Unable to connect to {self._client.address}") + raise + + return self.handle_response(resp=resp) + + async def _stream_to_file(self, url: str, dest: Path) -> int: + """Stream download directly to a file without loading into memory. + + Args: + url: The URL to download from. + dest: The destination path to write to. + + Returns: + The number of bytes written to the file. + + Raises: + ServerNotReachableError: If the server is not reachable. + AuthenticationError: If authentication fails. + NodeNotFoundError: If the file is not found. + """ + try: + async with self._client._get_streaming(url=url) as resp: + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + # Need to read the response body for error details + await resp.aread() + self.handle_error_response(exc=exc) + + bytes_written = 0 + async with await anyio.Path(dest).open("wb") as f: + async for chunk in resp.aiter_bytes(chunk_size=65536): + await f.write(chunk) + bytes_written += len(chunk) + return bytes_written + except ServerNotReachableError: + self._client.log.error(f"Unable to connect to {self._client.address}") + raise + + +class FileHandlerSync(FileHandlerBase): + """Sync file handler for download operations. + + Handles file downloads with support for streaming to disk + for memory-efficient handling of large files. + """ + + def __init__(self, client: InfrahubClientSync) -> None: + """Initialize the sync file handler. + + Args: + client: The sync Infrahub client instance. + """ + self._client = client + + def _build_url(self, node_id: str, branch: str | None) -> str: + """Build the download URL for a file. + + Args: + node_id: The ID of the FileObject node. + branch: Optional branch name. + + Returns: + The complete URL for downloading the file. + """ + url = f"{self._client.address}/api/storage/files/{node_id}" + if branch: + url = f"{url}?branch={branch}" + return url + + @overload + def download(self, node_id: str, branch: str | None) -> bytes: ... + + @overload + def download(self, node_id: str, branch: str | None, dest: Path) -> int: ... + + @overload + def download(self, node_id: str, branch: str | None, dest: None) -> bytes: ... + + def download(self, node_id: str, branch: str | None, dest: Path | None = None) -> bytes | int: + """Download file content from a FileObject node. + + Args: + node_id: The ID of the FileObject node. + branch: Optional branch name. Uses client default if not provided. + dest: Optional destination path. If provided, streams to disk. + + Returns: + If dest is None: The file content as bytes. + If dest is provided: The number of bytes written. + + Raises: + ServerNotReachableError: If the server is not reachable. + AuthenticationError: If authentication fails. + NodeNotFoundError: If the node/file is not found. + """ + effective_branch = branch or self._client.default_branch + url = self._build_url(node_id=node_id, branch=effective_branch) + + if dest is not None: + return self._stream_to_file(url=url, dest=dest) + + try: + resp = self._client._get(url=url) + except ServerNotReachableError: + self._client.log.error(f"Unable to connect to {self._client.address}") + raise + + return self.handle_response(resp=resp) + + def _stream_to_file(self, url: str, dest: Path) -> int: + """Stream download directly to a file without loading into memory. + + Args: + url: The URL to download from. + dest: The destination path to write to. + + Returns: + The number of bytes written to the file. + + Raises: + ServerNotReachableError: If the server is not reachable. + AuthenticationError: If authentication fails. + NodeNotFoundError: If the file is not found. + """ + try: + with self._client._get_streaming(url=url) as resp: + try: + resp.raise_for_status() + except httpx.HTTPStatusError as exc: + # Need to read the response body for error details + resp.read() + self.handle_error_response(exc=exc) + + bytes_written = 0 + with dest.open("wb") as f: + for chunk in resp.iter_bytes(chunk_size=65536): + f.write(chunk) + bytes_written += len(chunk) + return bytes_written + except ServerNotReachableError: + self._client.log.error(f"Unable to connect to {self._client.address}") + raise diff --git a/infrahub_sdk/graphql/__init__.py b/infrahub_sdk/graphql/__init__.py index 33438e35..743919b6 100644 --- a/infrahub_sdk/graphql/__init__.py +++ b/infrahub_sdk/graphql/__init__.py @@ -1,9 +1,11 @@ from .constants import VARIABLE_TYPE_MAPPING +from .multipart import MultipartBuilder from .query import Mutation, Query from .renderers import render_input_block, render_query_block, render_variables_to_string __all__ = [ "VARIABLE_TYPE_MAPPING", + "MultipartBuilder", "Mutation", "Query", "render_input_block", diff --git a/infrahub_sdk/graphql/constants.py b/infrahub_sdk/graphql/constants.py index e2033155..0fed5c57 100644 --- a/infrahub_sdk/graphql/constants.py +++ b/infrahub_sdk/graphql/constants.py @@ -1,4 +1,6 @@ from datetime import datetime +from pathlib import Path +from typing import BinaryIO VARIABLE_TYPE_MAPPING = ( (str, "String!"), @@ -11,4 +13,7 @@ (bool | None, "Boolean"), (datetime, "DateTime!"), (datetime | None, "DateTime"), + (bytes, "Upload!"), + (Path, "Upload!"), + (BinaryIO, "Upload!"), ) diff --git a/infrahub_sdk/graphql/multipart.py b/infrahub_sdk/graphql/multipart.py new file mode 100644 index 00000000..bdb1f84e --- /dev/null +++ b/infrahub_sdk/graphql/multipart.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import ujson + +if TYPE_CHECKING: + from typing import BinaryIO + + +class MultipartBuilder: + """Builds multipart form data payloads for GraphQL file uploads. + + This class implements the GraphQL Multipart Request Spec for uploading files via GraphQL mutations. The spec defines a standard way to send files + alongside GraphQL operations using multipart/form-data. + + The payload structure follows the spec: + - operations: JSON containing the GraphQL query and variables + - map: JSON mapping file keys to variable paths + - 0, 1, ...: The actual file contents + + Example payload: + { + "operations": '{"query": "mutation($file: Upload!) {...}", "variables": {"file": null}}', + "map": '{"0": ["variables.file"]}', + "0": (filename, file_content) + } + """ + + @staticmethod + def build_operations(query: str, variables: dict[str, Any]) -> str: + """Build the operations JSON string. + + Args: + query: The GraphQL query string. + variables: The variables dict (file variable should be null). + + Returns: + JSON string containing the query and variables. + """ + return ujson.dumps({"query": query, "variables": variables}) + + @staticmethod + def build_file_map(file_key: str = "0", variable_path: str = "variables.file") -> str: + """Build the file map JSON string. + + Args: + file_key: The key used for the file in the multipart payload. + variable_path: The path to the file variable in the GraphQL variables. + + Returns: + JSON string mapping the file key to the variable path. + """ + return ujson.dumps({file_key: [variable_path]}) + + @staticmethod + def build_payload( + query: str, + variables: dict[str, Any], + file_content: BinaryIO | None = None, + file_name: str = "upload", + ) -> dict[str, Any]: + """Build the complete multipart form data payload. + + Constructs the payload according to the GraphQL Multipart Request Spec. The returned dict can be passed directly to httpx as the `files` + parameter. + + Args: + query: The GraphQL query string containing $file: Upload! variable. + variables: The variables dict. The 'file' key will be set to null. + file_content: The file content as a file-like object (BinaryIO). + If None, only the operations and map will be included. + file_name: The filename to use for the upload. + + Returns: + A dict suitable for httpx's `files` parameter in a POST request. + + Example: + >>> builder = MultipartBuilder() + >>> payload = builder.build_payload( + ... query="mutation($file: Upload!) { upload(file: $file) { id } }", + ... variables={"other": "value"}, + ... file_content=open("file.pdf", "rb"), + ... file_name="document.pdf", + ... ) + >>> # payload can be passed to httpx.post(..., files=payload) + """ + # Ensure file variable is null (spec requirement) + variables = {**variables, "file": None} + + operations = MultipartBuilder.build_operations(query=query, variables=variables) + file_map = MultipartBuilder.build_file_map() + + files: dict[str, Any] = {"operations": (None, operations), "map": (None, file_map)} + + if file_content is not None: + # httpx streams from file-like objects automatically + files["0"] = (file_name, file_content) + + return files diff --git a/infrahub_sdk/graphql/renderers.py b/infrahub_sdk/graphql/renderers.py index b0d2ab28..5b6c2c0f 100644 --- a/infrahub_sdk/graphql/renderers.py +++ b/infrahub_sdk/graphql/renderers.py @@ -3,7 +3,8 @@ import json from datetime import datetime from enum import Enum -from typing import Any +from pathlib import Path +from typing import Any, BinaryIO from pydantic import BaseModel @@ -88,7 +89,7 @@ def convert_to_graphql_as_string(value: Any, convert_enum: bool = False) -> str: return str(value) -GRAPHQL_VARIABLE_TYPES = type[str | int | float | bool | datetime | None] +GRAPHQL_VARIABLE_TYPES = type[str | int | float | bool | datetime | bytes | Path | BinaryIO | None] def render_variables_to_string(data: dict[str, GRAPHQL_VARIABLE_TYPES]) -> str: @@ -148,10 +149,7 @@ def render_query_block(data: dict, offset: int = 4, indentation: int = 4, conver elif isinstance(value, dict) and len(value) == 1 and alias_key in value and value[alias_key]: lines.append(f"{offset_str}{value[alias_key]}: {key}") elif isinstance(value, dict): - if value.get(alias_key): - key_str = f"{value[alias_key]}: {key}" - else: - key_str = key + key_str = f"{value[alias_key]}: {key}" if value.get(alias_key) else key if value.get(filters_key): filters_str = ", ".join( diff --git a/infrahub_sdk/node/constants.py b/infrahub_sdk/node/constants.py index 8d301115..7a0bc6fd 100644 --- a/infrahub_sdk/node/constants.py +++ b/infrahub_sdk/node/constants.py @@ -27,6 +27,9 @@ ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE = ( "calling generate is only supported for CoreArtifactDefinition nodes" ) +FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE = ( + "calling download_file is only supported for nodes that inherit from CoreFileObject" +) HIERARCHY_FETCH_FEATURE_NOT_SUPPORTED_MESSAGE = "Hierarchical fields are not supported for this node." diff --git a/infrahub_sdk/node/node.py b/infrahub_sdk/node/node.py index 9d024cbb..a47209dc 100644 --- a/infrahub_sdk/node/node.py +++ b/infrahub_sdk/node/node.py @@ -2,10 +2,12 @@ from collections.abc import Iterable from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO from ..constants import InfrahubClientMode from ..exceptions import FeatureNotSupportedError, NodeNotFoundError, ResourceNotDefinedError, SchemaNotFoundError +from ..file_handler import FileHandler, FileHandlerBase, FileHandlerSync, PreparedFile from ..graphql import Mutation, Query from ..schema import ( GenericSchemaAPI, @@ -21,6 +23,7 @@ ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE, ARTIFACT_FETCH_FEATURE_NOT_SUPPORTED_MESSAGE, ARTIFACT_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE, + FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE, PROPERTIES_OBJECT, ) from .metadata import NodeMetadata @@ -65,14 +68,15 @@ def __init__(self, schema: MainSchemaTypesAPI, branch: str, data: dict | None = self._attributes = [item.name for item in self._schema.attributes] self._relationships = [item.name for item in self._schema.relationships] - # GenericSchemaAPI doesn't have inherit_from, so we need to check the type first - if isinstance(schema, GenericSchemaAPI): - self._artifact_support = False - else: - inherit_from = getattr(schema, "inherit_from", None) or [] - self._artifact_support = "CoreArtifactTarget" in inherit_from + # GenericSchemaAPI doesn't have inherit_from + inherit_from: list[str] = getattr(schema, "inherit_from", None) or [] + self._artifact_support = "CoreArtifactTarget" in inherit_from + self._file_object_support = "CoreFileObject" in inherit_from self._artifact_definition_support = schema.kind == "CoreArtifactDefinition" + self._file_content: bytes | Path | BinaryIO | None = None + self._file_name: str | None = None + # Check if this node is hierarchical (supports parent/children and ancestors/descendants) if not isinstance(schema, (ProfileSchemaAPI, GenericSchemaAPI, TemplateSchemaAPI)): self._hierarchy_support = getattr(schema, "hierarchy", None) is not None @@ -143,7 +147,7 @@ def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | N if not hfid: return None if include_kind: - hfid = [self.get_kind()] + hfid + hfid = [self.get_kind(), *hfid] return "__".join(hfid) @property @@ -199,7 +203,7 @@ def get_kind(self) -> str: def get_all_kinds(self) -> list[str]: if inherit_from := getattr(self._schema, "inherit_from", None): - return [self._schema.kind] + inherit_from + return [self._schema.kind, *inherit_from] return [self._schema.kind] def is_ip_prefix(self) -> bool: @@ -213,6 +217,72 @@ def is_ip_address(self) -> bool: def is_resource_pool(self) -> bool: return hasattr(self._schema, "inherit_from") and "CoreResourcePool" in self._schema.inherit_from # type: ignore[union-attr] + def is_file_object(self) -> bool: + """Check if this node inherits from CoreFileObject and supports file uploads.""" + return self._file_object_support + + def upload_from_path(self, path: Path) -> None: + """Set a file from disk to be uploaded when saving this FileObject node. + + The file will be streamed during upload, avoiding loading the entire file into memory. + + Args: + path: Path to the file on disk. + + Raises: + FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject. + + Example: + node.upload_from_path(path=Path("/path/to/large_file.pdf")) + """ + if not self._file_object_support: + raise FeatureNotSupportedError( + f"File upload is not supported for {self._schema.kind}. Only nodes inheriting from CoreFileObject support file uploads." + ) + self._file_content = path + self._file_name = path.name + + def upload_from_bytes(self, content: bytes | BinaryIO, name: str) -> None: + """Set content to be uploaded when saving this FileObject node. + + The content can be provided as bytes or a file-like object. + Using BinaryIO is recommended for large content to stream during upload. + + Args: + content: The file content as bytes or a file-like object. + name: The filename to use for the uploaded file. + + Raises: + FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject. + + Examples: + >>> # Using bytes (for small files) + >>> node.upload_from_bytes(content=b"file content", name="example.txt") + + >>> # Using file-like object (for large files) + >>> with open("/path/to/file.bin", "rb") as f: + ... node.upload_from_bytes(content=f, name="file.bin") + """ + if not self._file_object_support: + raise FeatureNotSupportedError( + f"File upload is not supported for {self._schema.kind}. Only nodes inheriting from CoreFileObject support file uploads." + ) + self._file_content = content + self._file_name = name + + def clear_file(self) -> None: + """Clear any pending file content.""" + self._file_content = None + self._file_name = None + + async def _get_file_for_upload(self) -> PreparedFile: + """Get the file content as a file-like object for upload (async version).""" + return await FileHandlerBase.prepare_upload(content=self._file_content, name=self._file_name) + + def _get_file_for_upload_sync(self) -> PreparedFile: + """Get the file content as a file-like object for upload (sync version).""" + return FileHandlerBase.prepare_upload_sync(content=self._file_content, name=self._file_name) + def get_raw_graphql_data(self) -> dict | None: return self._data @@ -288,10 +358,16 @@ def _generate_input_data( # noqa: C901 elif self.hfid is not None and not exclude_hfid: data["hfid"] = self.hfid - mutation_payload = {"data": data} + mutation_payload: dict[str, Any] = {"data": data} if context_data := self._get_request_context(request_context=request_context): mutation_payload["context"] = context_data + # Add file variable for FileObject nodes with pending file content + # file is a mutation argument at the same level as data, not inside data + if self._file_object_support and self._file_content is not None: + mutation_payload["file"] = "$file" + mutation_variables["file"] = bytes + return { "data": mutation_payload, "variables": variables, @@ -417,6 +493,10 @@ def _validate_artifact_definition_support(self, message: str) -> None: if not self._artifact_definition_support: raise FeatureNotSupportedError(message) + def _validate_file_object_support(self, message: str) -> None: + if not self._file_object_support: + raise FeatureNotSupportedError(message) + def generate_query_data_init( self, filters: dict[str, Any] | None = None, @@ -506,6 +586,7 @@ def __init__( data: Optional data to initialize the node. """ self._client = client + self._file_handler = FileHandler(client=client) # Extract node_metadata before extracting node data (node_metadata is sibling to node in edges) node_metadata_data: dict | None = None @@ -558,10 +639,7 @@ def _init_relationships(self, data: dict | RelatedNode | None = None) -> None: ) if value is not None } - if peer_id_data: - rel_data = peer_id_data - else: - rel_data = None + rel_data = peer_id_data or None self._relationship_cardinality_one_data[rel_schema.name] = RelatedNode( name=rel_schema.name, branch=self._branch, client=self._client, schema=rel_schema, data=rel_data ) @@ -694,6 +772,41 @@ async def artifact_fetch(self, name: str) -> str | dict[str, Any]: artifact = await self._client.get(kind="CoreArtifact", name__value=name, object__ids=[self.id]) return await self._client.object_store.get(identifier=artifact._get_attribute(name="storage_id").value) + async def download_file(self, dest: Path | None = None) -> bytes | int: + """Download the file content from this FileObject node. + + This method is only available for nodes that inherit from CoreFileObject. + The node must have been saved (have an id) before calling this method. + + Args: + dest: Optional destination path. If provided, the file will be streamed + directly to this path (memory-efficient for large files) and the + number of bytes written will be returned. If not provided, the + file content will be returned as bytes. + + Returns: + If ``dest`` is None: The file content as bytes. + If ``dest`` is provided: The number of bytes written to the file. + + Raises: + FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject. + ValueError: If the node hasn't been saved yet or file not found. + AuthenticationError: If authentication fails. + + Examples: + >>> # Download to memory + >>> content = await contract.download_file() + + >>> # Stream to file (memory-efficient for large files) + >>> bytes_written = await contract.download_file(dest=Path("/tmp/contract.pdf")) + """ + self._validate_file_object_support(message=FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE) + + if not self.id: + raise ValueError("Cannot download file for a node that hasn't been saved yet.") + + return await self._file_handler.download(node_id=self.id, branch=self._branch, dest=dest) + async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: input_data = {"data": {"id": self.id}} if context_data := self._get_request_context(request_context=request_context): @@ -1024,6 +1137,12 @@ async def _process_mutation_result( async def create( self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None ) -> None: + if self._file_object_support and self._file_content is None: + raise ValueError( + f"Cannot create {self._schema.kind} without file content. Use upload_from_path() or upload_from_bytes() to provide " + "file content before saving." + ) + mutation_query = self._generate_mutation_query() # Upserting means we may want to create, meaning payload contains all mandatory fields required for a creation, @@ -1036,19 +1155,39 @@ async def create( input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context) mutation_name = f"{self._schema.kind}Create" tracker = f"mutation-{str(self._schema.kind).lower()}-create" + query = Mutation( mutation=mutation_name, input_data=input_data["data"], query=mutation_query, variables=input_data["mutation_variables"], ) - response = await self._client.execute_graphql( - query=query.render(), - branch_name=self._branch, - tracker=tracker, - variables=input_data["variables"], - timeout=timeout, - ) + + if "file" in input_data["mutation_variables"]: + prepared = await self._get_file_for_upload() + try: + response = await self._client._execute_graphql_with_file( + query=query.render(), + variables=input_data["variables"], + file_content=prepared.file_object, + file_name=prepared.filename, + branch_name=self._branch, + tracker=tracker, + timeout=timeout, + ) + finally: + if prepared.should_close and prepared.file_object: + prepared.file_object.close() + # Clear the file content after successful upload + self.clear_file() + else: + response = await self._client.execute_graphql( + query=query.render(), + branch_name=self._branch, + tracker=tracker, + variables=input_data["variables"], + timeout=timeout, + ) await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) async def update( @@ -1057,6 +1196,7 @@ async def update( input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context) mutation_query = self._generate_mutation_query() mutation_name = f"{self._schema.kind}Update" + tracker = f"mutation-{str(self._schema.kind).lower()}-update" query = Mutation( mutation=mutation_name, @@ -1064,13 +1204,32 @@ async def update( query=mutation_query, variables=input_data["mutation_variables"], ) - response = await self._client.execute_graphql( - query=query.render(), - branch_name=self._branch, - timeout=timeout, - tracker=f"mutation-{str(self._schema.kind).lower()}-update", - variables=input_data["variables"], - ) + + if "file" in input_data["mutation_variables"]: + prepared = await self._get_file_for_upload() + try: + response = await self._client._execute_graphql_with_file( + query=query.render(), + variables=input_data["variables"], + file_content=prepared.file_object, + file_name=prepared.filename, + branch_name=self._branch, + tracker=tracker, + timeout=timeout, + ) + finally: + if prepared.should_close and prepared.file_object: + prepared.file_object.close() + # Clear the file content after successful upload + self.clear_file() + else: + response = await self._client.execute_graphql( + query=query.render(), + branch_name=self._branch, + timeout=timeout, + tracker=tracker, + variables=input_data["variables"], + ) await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) async def _process_relationships( @@ -1310,6 +1469,7 @@ def __init__( data (Optional[dict]): Optional data to initialize the node. """ self._client = client + self._file_handler = FileHandlerSync(client=client) # Extract node_metadata before extracting node data (node_metadata is sibling to node in edges) node_metadata_data: dict | None = None @@ -1362,10 +1522,7 @@ def _init_relationships(self, data: dict | None = None) -> None: ) if value is not None } - if peer_id_data: - rel_data = peer_id_data - else: - rel_data = None + rel_data = peer_id_data or None self._relationship_cardinality_one_data[rel_schema.name] = RelatedNodeSync( name=rel_schema.name, branch=self._branch, client=self._client, schema=rel_schema, data=rel_data ) @@ -1499,6 +1656,41 @@ def artifact_fetch(self, name: str) -> str | dict[str, Any]: artifact = self._client.get(kind="CoreArtifact", name__value=name, object__ids=[self.id]) return self._client.object_store.get(identifier=artifact._get_attribute(name="storage_id").value) + def download_file(self, dest: Path | None = None) -> bytes | int: + """Download the file content from this FileObject node. + + This method is only available for nodes that inherit from CoreFileObject. + The node must have been saved (have an id) before calling this method. + + Args: + dest: Optional destination path. If provided, the file will be streamed + directly to this path (memory-efficient for large files) and the + number of bytes written will be returned. If not provided, the + file content will be returned as bytes. + + Returns: + If ``dest`` is None: The file content as bytes. + If ``dest`` is provided: The number of bytes written to the file. + + Raises: + FeatureNotSupportedError: If this node doesn't inherit from CoreFileObject. + ValueError: If the node hasn't been saved yet or file not found. + AuthenticationError: If authentication fails. + + Examples: + >>> # Download to memory + >>> content = contract.download_file() + + >>> # Stream to file (memory-efficient for large files) + >>> bytes_written = contract.download_file(dest=Path("/tmp/contract.pdf")) + """ + self._validate_file_object_support(message=FILE_DOWNLOAD_FEATURE_NOT_SUPPORTED_MESSAGE) + + if not self.id: + raise ValueError("Cannot download file for a node that hasn't been saved yet.") + + return self._file_handler.download(node_id=self.id, branch=self._branch, dest=dest) + def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: input_data = {"data": {"id": self.id}} if context_data := self._get_request_context(request_context=request_context): @@ -1828,6 +2020,12 @@ def _process_mutation_result( def create( self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None ) -> None: + if self._file_object_support and self._file_content is None: + raise ValueError( + f"Cannot create {self._schema.kind} without file content. Use upload_from_path() or upload_from_bytes() to provide " + "file content before saving." + ) + mutation_query = self._generate_mutation_query() if allow_upsert: @@ -1838,6 +2036,7 @@ def create( input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context) mutation_name = f"{self._schema.kind}Create" tracker = f"mutation-{str(self._schema.kind).lower()}-create" + query = Mutation( mutation=mutation_name, input_data=input_data["data"], @@ -1845,13 +2044,31 @@ def create( variables=input_data["mutation_variables"], ) - response = self._client.execute_graphql( - query=query.render(), - branch_name=self._branch, - tracker=tracker, - variables=input_data["variables"], - timeout=timeout, - ) + if "file" in input_data["mutation_variables"]: + prepared = self._get_file_for_upload_sync() + try: + response = self._client._execute_graphql_with_file( + query=query.render(), + variables=input_data["variables"], + file_content=prepared.file_object, + file_name=prepared.filename, + branch_name=self._branch, + tracker=tracker, + timeout=timeout, + ) + finally: + if prepared.should_close and prepared.file_object: + prepared.file_object.close() + # Clear the file content after successful upload + self.clear_file() + else: + response = self._client.execute_graphql( + query=query.render(), + branch_name=self._branch, + tracker=tracker, + variables=input_data["variables"], + timeout=timeout, + ) self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) def update( @@ -1860,6 +2077,7 @@ def update( input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context) mutation_query = self._generate_mutation_query() mutation_name = f"{self._schema.kind}Update" + tracker = f"mutation-{str(self._schema.kind).lower()}-update" query = Mutation( mutation=mutation_name, @@ -1868,13 +2086,31 @@ def update( variables=input_data["mutation_variables"], ) - response = self._client.execute_graphql( - query=query.render(), - branch_name=self._branch, - tracker=f"mutation-{str(self._schema.kind).lower()}-update", - variables=input_data["variables"], - timeout=timeout, - ) + if "file" in input_data["mutation_variables"]: + prepared = self._get_file_for_upload_sync() + try: + response = self._client._execute_graphql_with_file( + query=query.render(), + variables=input_data["variables"], + file_content=prepared.file_object, + file_name=prepared.filename, + branch_name=self._branch, + tracker=tracker, + timeout=timeout, + ) + finally: + if prepared.should_close and prepared.file_object: + prepared.file_object.close() + # Clear the file content after successful upload + self.clear_file() + else: + response = self._client.execute_graphql( + query=query.render(), + branch_name=self._branch, + tracker=tracker, + variables=input_data["variables"], + timeout=timeout, + ) self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) def _process_relationships( diff --git a/infrahub_sdk/protocols.py b/infrahub_sdk/protocols.py index b3752bed..c359ad5c 100644 --- a/infrahub_sdk/protocols.py +++ b/infrahub_sdk/protocols.py @@ -29,7 +29,6 @@ StringOptional, ) -# pylint: disable=too-many-ancestors # --------------------------------------------- # ASYNC @@ -108,6 +107,14 @@ class CoreCredential(CoreNode): description: StringOptional +class CoreFileObject(CoreNode): + file_name: String + checksum: String + file_size: Integer + file_type: String + storage_id: String + + class CoreGenericAccount(CoreNode): name: String password: HashedPassword @@ -227,6 +234,7 @@ class CoreValidator(CoreNode): class CoreWebhook(CoreNode): name: String event_type: Enum + active: Boolean branch_scope: Dropdown node_kind: StringOptional description: StringOptional @@ -499,7 +507,6 @@ class CoreProposedChange(CoreTaskTarget): approved_by: RelationshipManager rejected_by: RelationshipManager reviewers: RelationshipManager - created_by: RelatedNode comments: RelationshipManager threads: RelationshipManager validations: RelationshipManager @@ -665,6 +672,14 @@ class CoreCredentialSync(CoreNodeSync): description: StringOptional +class CoreFileObjectSync(CoreNodeSync): + file_name: String + checksum: String + file_size: Integer + file_type: String + storage_id: String + + class CoreGenericAccountSync(CoreNodeSync): name: String password: HashedPassword @@ -784,6 +799,7 @@ class CoreValidatorSync(CoreNodeSync): class CoreWebhookSync(CoreNodeSync): name: String event_type: Enum + active: Boolean branch_scope: Dropdown node_kind: StringOptional description: StringOptional @@ -1056,7 +1072,6 @@ class CoreProposedChangeSync(CoreTaskTargetSync): approved_by: RelationshipManagerSync rejected_by: RelationshipManagerSync reviewers: RelationshipManagerSync - created_by: RelatedNodeSync comments: RelationshipManagerSync threads: RelationshipManagerSync validations: RelationshipManagerSync diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py index 8a841b5b..7f6569ae 100644 --- a/infrahub_sdk/protocols_base.py +++ b/infrahub_sdk/protocols_base.py @@ -6,6 +6,7 @@ import ipaddress from .context import RequestContext + from .node.metadata import NodeMetadata from .schema import MainSchemaTypes @@ -203,6 +204,8 @@ def is_resource_pool(self) -> bool: ... def get_raw_graphql_data(self) -> dict | None: ... + def get_node_metadata(self) -> NodeMetadata | None: ... + @runtime_checkable class CoreNode(CoreNodeBase, Protocol): diff --git a/infrahub_sdk/protocols_generator/generator.py b/infrahub_sdk/protocols_generator/generator.py index e70e221c..38bc968f 100644 --- a/infrahub_sdk/protocols_generator/generator.py +++ b/infrahub_sdk/protocols_generator/generator.py @@ -59,13 +59,13 @@ def __init__(self, schema: dict[str, MainSchemaTypesAll]) -> None: not in {"TYPE_CHECKING", "CoreNode", "Optional", "Protocol", "Union", "annotations", "runtime_checkable"} ] - self.sorted_generics = self._sort_and_filter_models(self.generics, filters=["CoreNode"] + self.base_protocols) - self.sorted_nodes = self._sort_and_filter_models(self.nodes, filters=["CoreNode"] + self.base_protocols) + self.sorted_generics = self._sort_and_filter_models(self.generics, filters=["CoreNode", *self.base_protocols]) + self.sorted_nodes = self._sort_and_filter_models(self.nodes, filters=["CoreNode", *self.base_protocols]) self.sorted_profiles = self._sort_and_filter_models( - self.profiles, filters=["CoreProfile"] + self.base_protocols + self.profiles, filters=["CoreProfile", *self.base_protocols] ) self.sorted_templates = self._sort_and_filter_models( - self.templates, filters=["CoreObjectTemplate"] + self.base_protocols + self.templates, filters=["CoreObjectTemplate", *self.base_protocols] ) def render(self, sync: bool = True) -> str: diff --git a/infrahub_sdk/py.typed b/infrahub_sdk/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/infrahub_sdk/pytest_plugin/items/check.py b/infrahub_sdk/pytest_plugin/items/check.py index f42f4808..8f68b271 100644 --- a/infrahub_sdk/pytest_plugin/items/check.py +++ b/infrahub_sdk/pytest_plugin/items/check.py @@ -12,7 +12,7 @@ from .base import InfrahubItem if TYPE_CHECKING: - from pytest import ExceptionInfo + import pytest from ...checks import InfrahubCheck from ...schema.repository import InfrahubRepositoryConfigElement @@ -46,7 +46,7 @@ def run_check(self, variables: dict[str, Any]) -> Any: self.instantiate_check() return asyncio.run(self.check_instance.run(data=variables)) - def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str: + def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str: if isinstance(excinfo.value, HTTPStatusError): try: response_content = ujson.dumps(excinfo.value.response.json(), indent=4) diff --git a/infrahub_sdk/pytest_plugin/items/graphql_query.py b/infrahub_sdk/pytest_plugin/items/graphql_query.py index defb9fb9..bced5542 100644 --- a/infrahub_sdk/pytest_plugin/items/graphql_query.py +++ b/infrahub_sdk/pytest_plugin/items/graphql_query.py @@ -11,7 +11,7 @@ from .base import InfrahubItem if TYPE_CHECKING: - from pytest import ExceptionInfo + import pytest class InfrahubGraphQLQueryItem(InfrahubItem): @@ -25,7 +25,7 @@ def execute_query(self) -> Any: variables=self.test.spec.get_variables_data(), # type: ignore[union-attr] ) - def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str: + def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str: if isinstance(excinfo.value, HTTPStatusError): try: response_content = ujson.dumps(excinfo.value.response.json(), indent=4) diff --git a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py index 433309a4..fe54fd71 100644 --- a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py +++ b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py @@ -16,7 +16,7 @@ from .base import InfrahubItem if TYPE_CHECKING: - from pytest import ExceptionInfo + import pytest class InfrahubJinja2Item(InfrahubItem): @@ -57,7 +57,7 @@ def get_result_differences(self, computed: Any) -> str | None: ) return "\n".join(differences) - def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str: + def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str: if isinstance(excinfo.value, HTTPStatusError): try: response_content = ujson.dumps(excinfo.value.response.json(), indent=4, sort_keys=True) @@ -94,7 +94,7 @@ def runtest(self) -> None: if computed is not None and differences and self.test.expect == InfrahubTestExpectedResult.PASS: raise OutputMatchError(name=self.name, differences=differences) - def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str: + def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str: if isinstance(excinfo.value, (JinjaTemplateError)): return str(excinfo.value.message) diff --git a/infrahub_sdk/pytest_plugin/items/python_transform.py b/infrahub_sdk/pytest_plugin/items/python_transform.py index 0ec42052..f895e971 100644 --- a/infrahub_sdk/pytest_plugin/items/python_transform.py +++ b/infrahub_sdk/pytest_plugin/items/python_transform.py @@ -13,7 +13,7 @@ from .base import InfrahubItem if TYPE_CHECKING: - from pytest import ExceptionInfo + import pytest from ...schema.repository import InfrahubRepositoryConfigElement from ...transforms import InfrahubTransform @@ -48,7 +48,7 @@ def run_transform(self, variables: dict[str, Any]) -> Any: self.instantiate_transform() return asyncio.run(self.transform_instance.run(data=variables)) - def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str: + def repr_failure(self, excinfo: pytest.ExceptionInfo, style: str | None = None) -> str: if isinstance(excinfo.value, HTTPStatusError): try: response_content = ujson.dumps(excinfo.value.response.json(), indent=4) diff --git a/infrahub_sdk/pytest_plugin/loader.py b/infrahub_sdk/pytest_plugin/loader.py index 5912a0c3..c26b09d2 100644 --- a/infrahub_sdk/pytest_plugin/loader.py +++ b/infrahub_sdk/pytest_plugin/loader.py @@ -6,7 +6,6 @@ import pytest import yaml -from pytest import Item from .exceptions import InvalidResourceConfigError from .items import ( @@ -66,7 +65,7 @@ def get_resource_config(self, group: InfrahubTestGroup) -> Any | None: return resource_config - def collect_group(self, group: InfrahubTestGroup) -> Iterable[Item]: + def collect_group(self, group: InfrahubTestGroup) -> Iterable[pytest.Item]: """Collect all items for a group.""" marker = MARKER_MAPPING[group.resource] resource_config = self.get_resource_config(group) @@ -98,7 +97,7 @@ def collect_group(self, group: InfrahubTestGroup) -> Iterable[Item]: yield item - def collect(self) -> Iterable[Item]: + def collect(self) -> Iterable[pytest.Item]: raw = yaml.safe_load(self.path.open(encoding="utf-8")) if not raw or "infrahub_tests" not in raw: diff --git a/infrahub_sdk/pytest_plugin/plugin.py b/infrahub_sdk/pytest_plugin/plugin.py index 74148a05..258e7f9c 100644 --- a/infrahub_sdk/pytest_plugin/plugin.py +++ b/infrahub_sdk/pytest_plugin/plugin.py @@ -3,8 +3,7 @@ import os from pathlib import Path -from pytest import Collector, Config, Item, Parser, Session -from pytest import exit as exit_test +import pytest from .. import InfrahubClientSync from ..utils import is_valid_url @@ -12,7 +11,7 @@ from .utils import find_repository_config_file, load_repository_config -def pytest_addoption(parser: Parser) -> None: +def pytest_addoption(parser: pytest.Parser) -> None: group = parser.getgroup("pytest-infrahub") group.addoption( "--infrahub-repo-config", @@ -62,7 +61,7 @@ def pytest_addoption(parser: Parser) -> None: ) -def pytest_sessionstart(session: Session) -> None: +def pytest_sessionstart(session: pytest.Session) -> None: if session.config.option.infrahub_repo_config: session.infrahub_config_path = Path(session.config.option.infrahub_repo_config) # type: ignore[attr-defined] else: @@ -72,7 +71,7 @@ def pytest_sessionstart(session: Session) -> None: session.infrahub_repo_config = load_repository_config(repo_config_file=session.infrahub_config_path) # type: ignore[attr-defined] if not is_valid_url(session.config.option.infrahub_address): - exit_test("Infrahub test instance address is not a valid URL", returncode=1) + pytest.exit("Infrahub test instance address is not a valid URL", returncode=1) client_config = { "address": session.config.option.infrahub_address, @@ -89,13 +88,13 @@ def pytest_sessionstart(session: Session) -> None: session.infrahub_client = infrahub_client # type: ignore[attr-defined] -def pytest_collect_file(parent: Collector | Item, file_path: Path) -> InfrahubYamlFile | None: +def pytest_collect_file(parent: pytest.Collector | pytest.Item, file_path: Path) -> InfrahubYamlFile | None: if file_path.suffix in {".yml", ".yaml"} and file_path.name.startswith("test_"): return InfrahubYamlFile.from_parent(parent, path=file_path) return None -def pytest_configure(config: Config) -> None: +def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line("markers", "infrahub: Infrahub test") config.addinivalue_line("markers", "infrahub_smoke: Smoke test for an Infrahub resource") config.addinivalue_line("markers", "infrahub_unit: Unit test for an Infrahub resource, works without dependencies") diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 3e61ad2a..557e76f3 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -22,6 +22,7 @@ ) from ..graphql import Mutation from ..queries import SCHEMA_HASH_SYNC_STATUS +from .export import RESTRICTED_NAMESPACES, NamespaceExport, SchemaExport, schema_to_export_dict from .main import ( AttributeSchema, AttributeSchemaAPI, @@ -54,6 +55,7 @@ "BranchSupportType", "GenericSchema", "GenericSchemaAPI", + "NamespaceExport", "NodeSchema", "NodeSchemaAPI", "ProfileSchemaAPI", @@ -61,9 +63,11 @@ "RelationshipKind", "RelationshipSchema", "RelationshipSchemaAPI", + "SchemaExport", "SchemaRoot", "SchemaRootAPI", "TemplateSchemaAPI", + "schema_to_export_dict", ] @@ -118,6 +122,47 @@ def __init__(self, client: InfrahubClient | InfrahubClientSync) -> None: self.client = client self.cache = {} + @staticmethod + def _build_export_schemas( + schema_nodes: MutableMapping[str, MainSchemaTypesAPI], + namespaces: list[str] | None = None, + ) -> SchemaExport: + """Organize fetched schemas into a per-namespace export structure. + + Filters out system types (Profile/Template) and restricted namespaces + (see :data:`RESTRICTED_NAMESPACES`), and optionally limits to specific + namespaces. If the caller requests restricted namespaces they are + silently excluded and a :func:`warnings.warn` is emitted. + + Returns: + A :class:`SchemaExport` containing user-defined schemas by namespace. + """ + if namespaces: + restricted = set(namespaces) & set(RESTRICTED_NAMESPACES) + if restricted: + warnings.warn( + f"Restricted namespace(s) {sorted(restricted)} requested but will be excluded from export", + stacklevel=3, + ) + + ns_map: dict[str, NamespaceExport] = {} + for schema in schema_nodes.values(): + if isinstance(schema, (ProfileSchemaAPI, TemplateSchemaAPI)): + continue + if schema.namespace in RESTRICTED_NAMESPACES: + continue + if namespaces and schema.namespace not in namespaces: + continue + ns = schema.namespace + if ns not in ns_map: + ns_map[ns] = NamespaceExport() + schema_dict = schema_to_export_dict(schema) + if isinstance(schema, GenericSchemaAPI): + ns_map[ns].generics.append(schema_dict) + else: + ns_map[ns].nodes.append(schema_dict) + return SchemaExport(namespaces=ns_map) + def validate(self, data: dict[str, Any]) -> None: SchemaRoot(**data) @@ -497,6 +542,32 @@ async def fetch( return branch_schema.nodes + async def export( + self, + branch: str | None = None, + namespaces: list[str] | None = None, + ) -> SchemaExport: + """Export user-defined schemas organized by namespace. + + Fetches schemas from the server, filters out system types and + restricted namespaces (see :data:`RESTRICTED_NAMESPACES`), and returns + a :class:`SchemaExport` object with per-namespace data. Restricted + namespaces such as ``Core`` and ``Builtin`` are always excluded even if + explicitly listed in *namespaces*; a warning is emitted when this + happens. + + Args: + branch: Branch to export from. Defaults to default_branch. + namespaces: Optional list of namespaces to include. If empty/None, + all user-defined namespaces are exported. + + Returns: + A :class:`SchemaExport` containing user-defined schemas by namespace. + """ + branch = branch or self.client.default_branch + schema_nodes = await self.fetch(branch=branch, namespaces=namespaces, populate_cache=False) + return self._build_export_schemas(schema_nodes=schema_nodes, namespaces=namespaces) + async def get_graphql_schema(self, branch: str | None = None) -> str: """Get the GraphQL schema as a string. @@ -739,6 +810,32 @@ def fetch( return branch_schema.nodes + def export( + self, + branch: str | None = None, + namespaces: list[str] | None = None, + ) -> SchemaExport: + """Export user-defined schemas organized by namespace. + + Fetches schemas from the server, filters out system types and + restricted namespaces (see :data:`RESTRICTED_NAMESPACES`), and returns + a :class:`SchemaExport` object with per-namespace data. Restricted + namespaces such as ``Core`` and ``Builtin`` are always excluded even if + explicitly listed in *namespaces*; a warning is emitted when this + happens. + + Args: + branch: Branch to export from. Defaults to default_branch. + namespaces: Optional list of namespaces to include. If empty/None, + all user-defined namespaces are exported. + + Returns: + A :class:`SchemaExport` containing user-defined schemas by namespace. + """ + branch = branch or self.client.default_branch + schema_nodes = self.fetch(branch=branch, namespaces=namespaces, populate_cache=False) + return self._build_export_schemas(schema_nodes=schema_nodes, namespaces=namespaces) + def get_graphql_schema(self, branch: str | None = None) -> str: """Get the GraphQL schema as a string. diff --git a/infrahub_sdk/schema/export.py b/infrahub_sdk/schema/export.py new file mode 100644 index 00000000..d5a09c77 --- /dev/null +++ b/infrahub_sdk/schema/export.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +from .main import GenericSchemaAPI, NodeSchemaAPI + + +class NamespaceExport(BaseModel): + """Export data for a single namespace.""" + + nodes: list[dict[str, Any]] = Field(default_factory=list) + generics: list[dict[str, Any]] = Field(default_factory=list) + + +class SchemaExport(BaseModel): + """Result of a schema export, organized by namespace.""" + + namespaces: dict[str, NamespaceExport] = Field(default_factory=dict) + + def to_dict(self) -> dict[str, dict[str, list[dict[str, Any]]]]: + """Convert to plain dict for YAML serialization.""" + return {ns: data.model_dump(exclude_defaults=True) for ns, data in self.namespaces.items()} + + +# Namespaces reserved by the Infrahub server — mirrored from +# backend/infrahub/core/constants/__init__.py in the opsmill/infrahub repo. +RESTRICTED_NAMESPACES: list[str] = [ + "Account", + "Branch", + "Builtin", + "Core", + "Deprecated", + "Diff", + "Infrahub", + "Internal", + "Lineage", + "Schema", + "Profile", + "Template", +] + +_SCHEMA_EXPORT_EXCLUDE: set[str] = {"hash", "hierarchy", "used_by", "id", "state"} +# branch is inherited from the node and need not be repeated on each field +_FIELD_EXPORT_EXCLUDE: set[str] = {"inherited", "allow_override", "hierarchical", "id", "state", "branch"} + +# Attribute field values that match schema loading defaults — omitted for cleaner output +_ATTR_EXPORT_DEFAULTS: dict[str, Any] = { + "read_only": False, + "optional": False, +} + +# Relationship field values that match schema loading defaults — omitted for cleaner output +_REL_EXPORT_DEFAULTS: dict[str, Any] = { + "direction": "bidirectional", + "on_delete": "no-action", + "cardinality": "many", + "optional": True, + "min_count": 0, + "max_count": 0, + "read_only": False, +} + +# Relationship kinds that Infrahub generates automatically — never user-defined +_AUTO_GENERATED_REL_KINDS: frozenset[str] = frozenset({"Group", "Profile", "Hierarchy"}) + + +def schema_to_export_dict(schema: NodeSchemaAPI | GenericSchemaAPI) -> dict[str, Any]: + """Convert an API schema object to an export-ready dict (omits API-internal fields).""" + data = schema.model_dump(exclude=_SCHEMA_EXPORT_EXCLUDE, exclude_none=True) + + # Pop attrs/rels so they can be re-inserted last for better readability + data.pop("attributes", None) + data.pop("relationships", None) + + # Generics with Hierarchy relationships were defined with `hierarchical: true`. + # Restore that flag and drop the auto-generated rels so the schema round-trips cleanly. + if isinstance(schema, GenericSchemaAPI) and any( + rel.kind == "Hierarchy" for rel in schema.relationships if not rel.inherited + ): + data["hierarchical"] = True + + # Strip uniqueness_constraints that are auto-generated from `unique: true` attributes + # (single-field entries of the form ["__value"]). User-defined multi-field + # constraints are preserved. + unique_attr_suffixes = {f"{attr.name}__value" for attr in schema.attributes if attr.unique} + user_constraints = [ + c + for c in (data.pop("uniqueness_constraints", None) or []) + if not (len(c) == 1 and c[0] in unique_attr_suffixes) + ] + if user_constraints: + data["uniqueness_constraints"] = user_constraints + + attributes = [ + { + k: v + for k, v in attr.model_dump(exclude=_FIELD_EXPORT_EXCLUDE, exclude_none=True).items() + if k not in _ATTR_EXPORT_DEFAULTS or v != _ATTR_EXPORT_DEFAULTS[k] + } + for attr in schema.attributes + if not attr.inherited + ] + if attributes: + data["attributes"] = attributes + + relationships = [ + { + k: v + for k, v in rel.model_dump(exclude=_FIELD_EXPORT_EXCLUDE, exclude_none=True).items() + if k not in _REL_EXPORT_DEFAULTS or v != _REL_EXPORT_DEFAULTS[k] + } + for rel in schema.relationships + if not rel.inherited and rel.kind not in _AUTO_GENERATED_REL_KINDS + ] + if relationships: + data["relationships"] = relationships + + return data diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py index cf7a6fc3..30d9f93d 100644 --- a/infrahub_sdk/spec/object.py +++ b/infrahub_sdk/spec/object.py @@ -265,7 +265,7 @@ async def validate_object( # First validate if all mandatory fields are present errors.extend( - ObjectValidationError(position=position + [element], message=f"{element} is mandatory") + ObjectValidationError(position=[*position, element], message=f"{element} is mandatory") for element in schema.mandatory_input_names if not any([element in data, element in context]) ) @@ -275,7 +275,7 @@ async def validate_object( if key not in schema.attribute_names and key not in schema.relationship_names: errors.append( ObjectValidationError( - position=position + [key], + position=[*position, key], message=f"{key} is not a valid attribute or relationship for {schema.kind}", ) ) @@ -283,7 +283,7 @@ async def validate_object( if key in schema.attribute_names and not isinstance(value, (str, int, float, bool, list, dict)): errors.append( ObjectValidationError( - position=position + [key], + position=[*position, key], message=f"{key} must be a string, int, float, bool, list, or dict", ) ) @@ -295,7 +295,7 @@ async def validate_object( if not rel_info.is_valid: errors.append( ObjectValidationError( - position=position + [key], + position=[*position, key], message=rel_info.reason_relationship_not_valid or "Invalid relationship", ) ) @@ -303,7 +303,7 @@ async def validate_object( errors.extend( await cls.validate_related_nodes( client=client, - position=position + [key], + position=[*position, key], rel_info=rel_info, data=value, context=context, @@ -378,7 +378,7 @@ async def validate_related_nodes( errors.extend( await cls.validate_object( client=client, - position=position + [idx + 1], + position=[*position, idx + 1], schema=peer_schema, data=peer_data, context=context, @@ -403,7 +403,7 @@ async def validate_related_nodes( errors.extend( await cls.validate_object( client=client, - position=position + [idx + 1], + position=[*position, idx + 1], schema=peer_schema, data=item["data"], context=context, @@ -613,7 +613,7 @@ async def create_related_nodes( node = await cls.create_node( client=client, schema=peer_schema, - position=position + [rel_info.name, idx + 1], + position=[*position, rel_info.name, idx + 1], data=peer_data, context=context, branch=branch, @@ -639,7 +639,7 @@ async def create_related_nodes( node = await cls.create_node( client=client, schema=peer_schema, - position=position + [rel_info.name, idx + 1], + position=[*position, rel_info.name, idx + 1], data=item["data"], context=context, branch=branch, @@ -681,7 +681,7 @@ def spec(self) -> InfrahubObjectFileData: try: self._spec = InfrahubObjectFileData(**self.data.spec) except Exception as exc: - raise ValidationError(identifier=str(self.location), message=str(exc)) + raise ValidationError(identifier=str(self.location), message=str(exc)) from exc return self._spec def validate_content(self) -> None: @@ -691,7 +691,7 @@ def validate_content(self) -> None: try: self._spec = InfrahubObjectFileData(**self.data.spec) except Exception as exc: - raise ValidationError(identifier=str(self.location), message=str(exc)) + raise ValidationError(identifier=str(self.location), message=str(exc)) from exc async def validate_format(self, client: InfrahubClient, branch: str | None = None) -> None: self.validate_content() diff --git a/infrahub_sdk/template/__init__.py b/infrahub_sdk/template/__init__.py index ff866ecd..6a7f7fe2 100644 --- a/infrahub_sdk/template/__init__.py +++ b/infrahub_sdk/template/__init__.py @@ -64,14 +64,11 @@ def get_template(self) -> jinja2.Template: return self._template_definition try: - if self.is_string_based: - template = self._get_string_based_template() - else: - template = self._get_file_based_template() + template = self._get_string_based_template() if self.is_string_based else self._get_file_based_template() except jinja2.TemplateSyntaxError as exc: self._raise_template_syntax_error(error=exc) except jinja2.TemplateNotFound as exc: - raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name)) + raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name)) from exc return template @@ -119,19 +116,18 @@ async def render(self, variables: dict[str, Any]) -> str: try: output = await template.render_async(variables) except jinja2.exceptions.TemplateNotFound as exc: - raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name) + raise JinjaTemplateNotFoundError( + message=exc.message, filename=str(exc.name), base_template=template.name + ) from exc except jinja2.TemplateSyntaxError as exc: self._raise_template_syntax_error(error=exc) except jinja2.UndefinedError as exc: traceback = Traceback(show_locals=False) errors = _identify_faulty_jinja_code(traceback=traceback) - raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) + raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) from exc except Exception as exc: - if error_message := getattr(exc, "message", None): - message = error_message - else: - message = str(exc) - raise JinjaTemplateError(message=message or "Unknown template error") + message = error_message if (error_message := getattr(exc, "message", None)) else str(exc) + raise JinjaTemplateError(message=message or "Unknown template error") from exc return output @@ -195,10 +191,7 @@ def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3 # Extract only the Jinja related exception for frame in [frame for frame in traceback.trace.stacks[0].frames if not frame.filename.endswith(".py")]: code = "".join(linecache.getlines(frame.filename)) - if frame.filename == "