Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions backend/infrahub/webhook/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import hashlib
import hmac
import json
from typing import TYPE_CHECKING, Any
import logging
import os
from typing import TYPE_CHECKING, Any, Literal
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict, Field, computed_field
Expand All @@ -19,6 +21,8 @@
from infrahub.trigger.models import EventTrigger, ExecuteWorkflow, TriggerDefinition, TriggerType
from infrahub.workflows.catalogue import WEBHOOK_PROCESS

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from httpx import Response
from infrahub_sdk.client import InfrahubClient
Expand Down Expand Up @@ -113,12 +117,36 @@ def from_event(cls, event_id: str, event_type: str, event_occured_at: str, event
)


class WebhookHeaderResolutionError(Exception):
pass


class WebhookHeader(BaseModel):
key: str
value: str
kind: Literal["static", "environment"]

def resolve(self) -> str:
"""Resolve the header value based on its kind.

Raises WebhookHeaderResolutionError if the value cannot be resolved.
"""
if self.kind == "static":
return self.value

resolved = os.environ.get(self.value)
if resolved is None:
raise WebhookHeaderResolutionError(f"Environment variable '{self.value}' not found")
return resolved
Comment thread
polmichel marked this conversation as resolved.


class Webhook(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = Field(...)
url: str = Field(...)
event_type: str = Field(...)
validate_certificates: bool | None = Field(...)
custom_headers: list[WebhookHeader] = Field(default_factory=list)
_payload: Any = None
_headers: dict[str, Any] | None = None
shared_key: str | None = Field(default=None, description="Shared key for signing the webhook requests")
Expand All @@ -132,6 +160,12 @@ def _assign_headers(self, uuid: UUID | None = None, at: Timestamp | None = None)
"Content-Type": "application/json",
}

for header in self.custom_headers:
try:
self._headers[header.key] = header.resolve()
except WebhookHeaderResolutionError as exc:
logger.warning("Webhook '%s': %s, skipping header '%s'", self.name, exc, header.key)

if self.shared_key:
message_id = f"msg_{uuid.hex}" if uuid else f"msg_{uuid4().hex}"
timestamp = str(at.to_timestamp()) if at else str(Timestamp().to_timestamp())
Expand Down Expand Up @@ -184,25 +218,27 @@ class CustomWebhook(Webhook):
"""Custom webhook"""

@classmethod
def from_object(cls, obj: CoreCustomWebhook) -> Self:
def from_object(cls, obj: CoreCustomWebhook, custom_headers: list[WebhookHeader] | None = None) -> Self:
return cls(
name=obj.name.value,
url=obj.url.value,
event_type=obj.event_type.value,
validate_certificates=obj.validate_certificates.value or False,
shared_key=obj.shared_key.value,
custom_headers=custom_headers or [],
)


class StandardWebhook(Webhook):
@classmethod
def from_object(cls, obj: CoreStandardWebhook) -> Self:
def from_object(cls, obj: CoreStandardWebhook, custom_headers: list[WebhookHeader] | None = None) -> Self:
return cls(
name=obj.name.value,
url=obj.url.value,
event_type=obj.event_type.value,
validate_certificates=obj.validate_certificates.value or False,
shared_key=obj.shared_key.value,
custom_headers=custom_headers or [],
)


Expand Down Expand Up @@ -238,7 +274,9 @@ async def _prepare_payload(self, data: dict[str, Any], context: EventContext, cl
) # type: ignore[call-overload]

@classmethod
def from_object(cls, obj: CoreCustomWebhook, transform: CoreTransformPython) -> Self:
def from_object(
cls, obj: CoreCustomWebhook, transform: CoreTransformPython, custom_headers: list[WebhookHeader] | None = None
) -> Self:
return cls(
name=obj.name.value,
url=obj.url.value,
Expand All @@ -253,4 +291,5 @@ def from_object(cls, obj: CoreCustomWebhook, transform: CoreTransformPython) ->
transform_timeout=transform.timeout.value,
convert_query_response=transform.convert_query_response.value or False,
shared_key=obj.shared_key.value,
custom_headers=custom_headers or [],
)
34 changes: 28 additions & 6 deletions backend/infrahub/webhook/tasks/process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import ujson
from infrahub_sdk import InfrahubClient # noqa: TC002 needed for prefect flow
Expand All @@ -13,7 +13,7 @@
from infrahub.workers.dependencies import get_cache, get_client, get_http
from infrahub.workflows.utils import add_tags

from ..models import CustomWebhook, EventContext, StandardWebhook, TransformWebhook, Webhook
from ..models import CustomWebhook, EventContext, StandardWebhook, TransformWebhook, Webhook, WebhookHeader

if TYPE_CHECKING:
from httpx import Response
Expand All @@ -36,15 +36,37 @@ async def webhook_send(webhook: Webhook, context: EventContext, event_data: dict
return response


KIND_MAP: dict[str, Literal["static", "environment"]] = {
"CoreStaticKeyValue": "static",
"CoreEnvKeyValue": "environment",
}


def _extract_custom_headers(webhook_node: CoreWebhook) -> list[WebhookHeader]:
"""Extract WebhookHeader list from a webhook node's headers relationship."""
if not hasattr(webhook_node, "headers"):
return []
headers: list[WebhookHeader] = []
for related in webhook_node.headers.peers:
peer = related.peer
kind = KIND_MAP.get(peer.get_kind())
if kind is None:
continue
Comment thread
polmichel marked this conversation as resolved.
headers.append(WebhookHeader(key=peer.key.value, value=peer.value.value, kind=kind))
return headers


@task(name="webhook-convert-node", task_run_name="Convert node to webhook", cache_policy=NONE)
async def convert_node_to_webhook(webhook_node: CoreWebhook, client: InfrahubClient) -> Webhook:
webhook_kind = webhook_node.get_kind()

if webhook_kind not in ["CoreStandardWebhook", "CoreCustomWebhook"]:
raise ValueError(f"Unsupported webhook kind: {webhook_kind}")
Comment thread
polmichel marked this conversation as resolved.

custom_headers = _extract_custom_headers(webhook_node)

if webhook_kind == "CoreStandardWebhook":
return StandardWebhook.from_object(obj=webhook_node)
return StandardWebhook.from_object(obj=webhook_node, custom_headers=custom_headers)

# Processing Custom Webhook
if webhook_node.transformation.id:
Expand All @@ -54,9 +76,9 @@ async def convert_node_to_webhook(webhook_node: CoreWebhook, client: InfrahubCli
prefetch_relationships=True,
include=["name", "class_name", "file_path", "repository"],
)
return TransformWebhook.from_object(obj=webhook_node, transform=transform)
return TransformWebhook.from_object(obj=webhook_node, transform=transform, custom_headers=custom_headers)

return CustomWebhook.from_object(obj=webhook_node)
return CustomWebhook.from_object(obj=webhook_node, custom_headers=custom_headers)


@flow(name="webhook-process", flow_run_name="Send webhook for {webhook_name}")
Expand All @@ -81,7 +103,7 @@ async def webhook_process(
webhook_data_str = await cache.get(key=f"webhook:{webhook_id}")
if not webhook_data_str:
log.info(f"Webhook {webhook_id} not found in cache")
webhook_node = await client.get(kind=webhook_kind, id=webhook_id)
webhook_node = await client.get(kind=webhook_kind, id=webhook_id, prefetch_relationships=True)
webhook = await convert_node_to_webhook(webhook_node=webhook_node, client=client)
webhook_data = webhook.to_cache()
await cache.set(key=f"webhook:{webhook_id}", value=ujson.dumps(webhook_data), expires=KVTTL.TWO_HOURS)
Expand Down
25 changes: 25 additions & 0 deletions backend/tests/functional/webhook/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,31 @@ async def webhook4(db: InfrahubDatabase, initial_dataset: None, client: Infrahub
return webhook


@pytest.fixture(scope="class")
async def webhook_with_headers(db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient) -> Node:
static_header = await Node.init(schema=InfrahubKind.STATICKEYVALUE, db=db)
await static_header.new(db=db, name="x-custom-token", key="X-Custom-Token", value="secret123")
await static_header.save(db=db)

env_header = await Node.init(schema=InfrahubKind.ENVKEYVALUE, db=db)
await env_header.new(db=db, name="x-env-key", key="X-Env-Key", value="MY_ENV_VAR")
await env_header.save(db=db)

webhook = await Node.init(schema=InfrahubKind.STANDARDWEBHOOK, db=db)
await webhook.new(
db=db,
name="WebhookWithHeaders",
url="https://url.mock",
shared_key="1234567890",
validate_certificates=False,
event_type="infrahub.branch.created",
branch_scope="all_branches",
headers=[static_header, env_header],
)
await webhook.save(db=db)
return webhook


@pytest.fixture(scope="class")
async def inactive_webhook(db: InfrahubDatabase, initial_dataset: None, client: InfrahubClient) -> Node:
webhook = await Node.init(schema=InfrahubKind.STANDARDWEBHOOK, db=db)
Expand Down
26 changes: 26 additions & 0 deletions backend/tests/functional/webhook/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ async def test_convert_node_to_webhook_standard(
"event_type": "infrahub.branch.created",
"validate_certificates": False,
"shared_key": "1234567890",
"custom_headers": [],
"webhook_type": "StandardWebhook",
}

async def test_convert_node_to_webhook_with_headers(
self,
db: InfrahubDatabase,
webhook_with_headers: Node,
client: InfrahubClient,
) -> None:
webhook = await client.get(
kind=InfrahubKind.STANDARDWEBHOOK, id=webhook_with_headers.id, prefetch_relationships=True
)
converted_webhook = await convert_node_to_webhook(webhook_node=webhook, client=client)

assert converted_webhook.model_dump() == {
"name": "WebhookWithHeaders",
"url": "https://url.mock",
"event_type": "infrahub.branch.created",
"validate_certificates": False,
"shared_key": "1234567890",
"custom_headers": [
{"key": "X-Custom-Token", "value": "secret123", "kind": "static"},
{"key": "X-Env-Key", "value": "MY_ENV_VAR", "kind": "environment"},
],
"webhook_type": "StandardWebhook",
}

Expand All @@ -63,6 +88,7 @@ async def test_convert_node_to_webhook_transform(
"transform_timeout": 5,
"url": "https://url.mock",
"validate_certificates": False,
"custom_headers": [],
"webhook_type": "TransformWebhook",
}

Expand Down
105 changes: 104 additions & 1 deletion backend/tests/unit/webhook/test_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
from unittest.mock import patch
from uuid import UUID

import pytest

from infrahub.core.timestamp import Timestamp
from infrahub.webhook.models import StandardWebhook
from infrahub.webhook.models import CustomWebhook, StandardWebhook, WebhookHeader


def test_standard_webhook() -> None:
Expand All @@ -15,6 +19,7 @@ def test_standard_webhook() -> None:
"url": "http://test.com",
"event_type": "test",
"validate_certificates": True,
"custom_headers": [],
"shared_key": "test",
"webhook_type": "StandardWebhook",
}
Expand All @@ -40,6 +45,104 @@ def test_standard_webhook_header() -> None:
}


def test_assign_headers_with_static_custom_header() -> None:
"""_assign_headers() with static custom header."""
webhook = CustomWebhook(
name="test",
url="http://test.com",
event_type="test",
validate_certificates=True,
custom_headers=[WebhookHeader(key="Authorization", value="Bearer token", kind="static")],
)
webhook._assign_headers()

assert webhook._headers is not None
assert webhook._headers["Authorization"] == "Bearer token"
assert webhook._headers["Accept"] == "application/json"
assert webhook._headers["Content-Type"] == "application/json"


def test_custom_header_overrides_default() -> None:
"""Custom header overrides default header."""
webhook = CustomWebhook(
name="test",
url="http://test.com",
event_type="test",
validate_certificates=True,
custom_headers=[WebhookHeader(key="Content-Type", value="text/plain", kind="static")],
)
webhook._assign_headers()

assert webhook._headers is not None
assert webhook._headers["Content-Type"] == "text/plain"


def test_cache_roundtrip_preserves_custom_headers() -> None:
"""to_cache()/from_cache() roundtrip preserves custom_headers."""
headers = [
WebhookHeader(key="X-Source", value="infrahub", kind="static"),
WebhookHeader(key="Y-Source", value="opsmill", kind="static"),
]
webhook = StandardWebhook(
name="test",
url="http://test.com",
event_type="test",
validate_certificates=True,
shared_key="key123",
custom_headers=headers,
)

cache_data = webhook.to_cache()
restored = StandardWebhook.from_cache(cache_data)

assert restored.custom_headers == headers
assert len(restored.custom_headers) == 2
assert restored.custom_headers[0].key == "X-Source"
assert restored.custom_headers[0].kind == "static"


def test_assign_headers_resolves_environment_variable() -> None:
"""Environment variable header resolves from os.environ at send time."""
webhook = CustomWebhook(
name="test",
url="http://test.com",
event_type="test",
validate_certificates=True,
custom_headers=[WebhookHeader(key="X-API-Key", value="MY_API_KEY", kind="environment")],
)

with patch.dict("os.environ", {"MY_API_KEY": "secret123"}):
webhook._assign_headers()

assert webhook._headers is not None
assert webhook._headers["X-API-Key"] == "secret123"
assert webhook._headers["Accept"] == "application/json"


def test_assign_headers_skips_missing_environment_variable(caplog: pytest.LogCaptureFixture) -> None:
"""Missing environment variable is skipped with a warning, no exception raised."""
webhook = CustomWebhook(
name="test",
url="http://test.com",
event_type="test",
validate_certificates=True,
custom_headers=[
WebhookHeader(key="X-API-Key", value="MISSING_VAR", kind="environment"),
WebhookHeader(key="X-Source", value="infrahub", kind="static"),
],
)

with patch.dict("os.environ", {}, clear=True), caplog.at_level(logging.WARNING, logger="infrahub.webhook.models"):
webhook._assign_headers()

assert webhook._headers is not None
assert "X-API-Key" not in webhook._headers
assert webhook._headers["X-Source"] == "infrahub"
assert "MISSING_VAR" in caplog.text
assert "X-API-Key" in caplog.text
assert "test" in caplog.text # webhook name included in warning


def test_webhook_signature_with_payload() -> None:
"""Signature is computed on compact JSON of the payload, not str(dict) or spaced JSON.

Expand Down
Loading