Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 189 additions & 9 deletions tests/test_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from webhook.messaging import create_or_update_message
from webhook.messaging import get_or_create_message_refs
from webhook.messaging import update_all_messages_transactional
from webhook.messaging import update_message_with_fingerprint


@pytest.fixture
Expand Down Expand Up @@ -60,6 +61,8 @@ def sample_mri():
mri.merge_request_ref_id = 1
mri.merge_request_payload.object_attributes.state = "opened"
mri.merge_request_payload.object_attributes.title = "Test MR"
mri.merge_request_payload.object_attributes.target_project_id = 42
mri.merge_request_payload.object_attributes.iid = 123
mri.merge_request_payload.project.path_with_namespace = "test/project"
return mri

Expand Down Expand Up @@ -218,13 +221,24 @@ async def test_error_on_create_raises_exception(self, mock_http_client, sample_m

mock_http_client.request.side_effect = httpx.HTTPError("Connection failed")

with pytest.raises(httpx.HTTPError):
with (
patch("webhook.messaging.logger") as mock_logger,
pytest.raises(httpx.HTTPError),
):
await create_or_update_message(
mock_http_client,
sample_mr_mess_ref,
card={"body": []},
project_id=42,
mr_iid=123,
)

# Error log must carry project_id/mr_iid for triage
mock_logger.error.assert_called_once()
call_kwargs = mock_logger.error.call_args[1]
assert call_kwargs["project_id"] == 42
assert call_kwargs["mr_iid"] == 123


class TestUpdateAllMessagesTransactional:
async def test_updates_all_messages_successfully(self, mock_database, sample_mri):
Expand Down Expand Up @@ -423,18 +437,26 @@ async def side_effect(*args, **kwargs):
mock_client_ctx.__aexit__ = AsyncMock()
mock_client_class.return_value = mock_client_ctx

count = await update_all_messages_transactional(
sample_mri,
{"body": []},
"summary",
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
"test-action",
)
with patch("webhook.messaging.logger") as mock_logger:
count = await update_all_messages_transactional(
sample_mri,
{"body": []},
"summary",
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
"test-action",
)

assert count == 1
assert mock_client_instance.request.call_count == 2

# Per-message error log must include project_id/mr_iid pulled from
# the mri payload (not threaded through kwargs in this code path).
assert mock_logger.error.call_count == 1
call_kwargs = mock_logger.error.call_args[1]
assert call_kwargs["project_id"] == 42
assert call_kwargs["mr_iid"] == 123

async def test_update_raises_error_on_http_failure(
self, mock_http_client, mock_database, sample_mr_mess_ref
):
Expand Down Expand Up @@ -639,6 +661,164 @@ async def test_newer_event_processed_in_transactional(self, mock_database):
mock_client_instance.request.assert_called_once()


class TestUpdateMessageWithFingerprint:
async def test_successful_update_stores_fingerprint(
self, mock_http_client, mock_database, sample_mr_mess_ref
):
mock_db, mock_conn = mock_database
response_mock = MagicMock()
response_mock.status_code = 200
response_mock.raise_for_status = MagicMock()
mock_http_client.request.return_value = response_mock
mock_conn.execute.return_value = "UPDATE 1"

card = {"type": "AdaptiveCard", "body": []}
fingerprint = "abc123"
updated_at = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

await update_message_with_fingerprint(
mock_http_client,
sample_mr_mess_ref,
card,
"Test summary",
fingerprint,
updated_at,
project_id=42,
mr_iid=123,
)

mock_http_client.request.assert_called_once()
call_args = mock_http_client.request.call_args
assert call_args[0][0] == "PATCH"
assert call_args[1]["json"]["message_id"] == str(sample_mr_mess_ref.message_id)

# Verify fingerprint and updated_at are passed to the conditional UPDATE
# (the function's whole reason for existing — name was lying without this).
mock_conn.execute.assert_called_once()
update_args = mock_conn.execute.call_args.args
assert update_args[1] == fingerprint
assert update_args[2] == updated_at
assert update_args[3] == sample_mr_mess_ref.merge_request_message_ref_id

async def test_skips_when_message_id_is_none(self, mock_http_client):
mrmsgref = MRMessRef(
merge_request_message_ref_id=1,
conversation_token=uuid.uuid4(),
message_id=None,
)

with patch("webhook.messaging.logger") as mock_logger:
await update_message_with_fingerprint(
mock_http_client,
mrmsgref,
{"body": []},
None,
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
project_id=42,
mr_iid=123,
)

mock_http_client.request.assert_not_called()
mock_logger.warning.assert_called_once()
assert "NULL message_id" in str(mock_logger.warning.call_args)

async def test_race_condition_logs_warning_with_project_info(
self, mock_http_client, mock_database, sample_mr_mess_ref
):
mock_db, mock_conn = mock_database
response_mock = MagicMock()
response_mock.status_code = 200
response_mock.raise_for_status = MagicMock()
mock_http_client.request.return_value = response_mock
mock_conn.execute.return_value = "UPDATE 0"

with patch("webhook.messaging.logger") as mock_logger:
await update_message_with_fingerprint(
mock_http_client,
sample_mr_mess_ref,
{"body": []},
None,
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
project_id=42,
mr_iid=123,
)

mock_logger.warning.assert_called_once()
call_kwargs = mock_logger.warning.call_args[1]
assert call_kwargs["project_id"] == 42
assert call_kwargs["mr_iid"] == 123
assert "race" in mock_logger.warning.call_args[0][0]

async def test_http_error_logs_with_project_info(self, mock_http_client, sample_mr_mess_ref):
mock_http_client.request.side_effect = httpx.HTTPError("Connection failed")

with (
patch("webhook.messaging.logger") as mock_logger,
pytest.raises(httpx.HTTPError),
):
await update_message_with_fingerprint(
mock_http_client,
sample_mr_mess_ref,
{"body": []},
None,
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
project_id=42,
mr_iid=123,
)

mock_logger.error.assert_called_once()
call_kwargs = mock_logger.error.call_args[1]
assert call_kwargs["project_id"] == 42
assert call_kwargs["mr_iid"] == 123

async def test_includes_summary_in_payload_when_provided(
self, mock_http_client, mock_database, sample_mr_mess_ref
):
mock_db, mock_conn = mock_database
response_mock = MagicMock()
response_mock.status_code = 200
response_mock.raise_for_status = MagicMock()
mock_http_client.request.return_value = response_mock
mock_conn.execute.return_value = "UPDATE 1"

await update_message_with_fingerprint(
mock_http_client,
sample_mr_mess_ref,
{"body": []},
"Test summary",
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
)

call_args = mock_http_client.request.call_args
assert call_args[1]["json"]["summary"] == "Test summary"

async def test_omits_summary_from_payload_when_none(
self, mock_http_client, mock_database, sample_mr_mess_ref
):
mock_db, mock_conn = mock_database
response_mock = MagicMock()
response_mock.status_code = 200
response_mock.raise_for_status = MagicMock()
mock_http_client.request.return_value = response_mock
mock_conn.execute.return_value = "UPDATE 1"

await update_message_with_fingerprint(
mock_http_client,
sample_mr_mess_ref,
{"body": []},
None,
"fingerprint",
datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
)

call_args = mock_http_client.request.call_args
assert "summary" not in call_args[1]["json"]


class TestMRMessRef:
def test_creates_with_all_fields(self):
ref_id = 123
Expand Down
6 changes: 6 additions & 0 deletions webhook/merge_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ async def merge_request(
)
or not participant_found
),
project_id=mr.object_attributes.target_project_id,
mr_iid=mr.object_attributes.iid,
)

if mrmsgref.message_id is not None:
Expand Down Expand Up @@ -308,6 +310,8 @@ async def merge_request(
summary,
datasource_fingerprint,
payload_updated_at,
project_id=mr.object_attributes.target_project_id,
mr_iid=mr.object_attributes.iid,
)
else:
await update_message_with_fingerprint(
Expand All @@ -317,6 +321,8 @@ async def merge_request(
summary,
datasource_fingerprint,
payload_updated_at,
project_id=mr.object_attributes.target_project_id,
mr_iid=mr.object_attributes.iid,
)
messages_processed += 1
except Exception:
Expand Down
15 changes: 15 additions & 0 deletions webhook/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def create_or_update_message(
card: dict[str, Any] | None = None,
summary: str | None = None,
update_only: bool = False,
project_id: int | None = None,
mr_iid: int | None = None,
) -> uuid.UUID | None:
payload: dict[str, Any]
if message_text:
Expand Down Expand Up @@ -157,6 +159,8 @@ async def create_or_update_message(
url=config.ACTIVITY_API + "api/v1/message",
conversation_token=str(mrmsgref.conversation_token),
status_code=res.status_code if "res" in locals() else None,
project_id=project_id,
mr_iid=mr_iid,
exc_info=True,
)
raise
Expand Down Expand Up @@ -207,6 +211,8 @@ async def create_or_update_message(
url=config.ACTIVITY_API + "api/v1/message",
message_id=str(mrmsgref.message_id),
status_code=res.status_code if "res" in locals() else None,
project_id=project_id,
mr_iid=mr_iid,
exc_info=True,
)
raise
Expand All @@ -220,6 +226,9 @@ async def update_message_with_fingerprint(
summary: str | None,
payload_fingerprint: str,
payload_updated_at: datetime.datetime,
*,
project_id: int | None = None,
mr_iid: int | None = None,
) -> None:
"""
Update an existing message via Teams API and store the fingerprint.
Expand Down Expand Up @@ -266,6 +275,8 @@ async def update_message_with_fingerprint(
"message update skipped - newer timestamp already stored (race)",
message_id=str(mrmsgref.message_id),
payload_updated_at=payload_updated_at.isoformat(),
project_id=project_id,
mr_iid=mr_iid,
)
else:
logger.debug(
Expand All @@ -281,6 +292,8 @@ async def update_message_with_fingerprint(
url=config.ACTIVITY_API + "api/v1/message",
message_id=str(mrmsgref.message_id),
status_code=res.status_code if "res" in locals() else None,
project_id=project_id,
mr_iid=mr_iid,
exc_info=True,
)
raise
Expand Down Expand Up @@ -400,6 +413,8 @@ async def update_all_messages_transactional(
message_id=str(message_id),
mr_ref_id=mri.merge_request_ref_id,
status_code=res.status_code if "res" in locals() else None,
project_id=mri.merge_request_payload.object_attributes.target_project_id,
mr_iid=mri.merge_request_payload.object_attributes.iid,
exc_info=True,
)

Expand Down
Loading