Skip to content
67 changes: 51 additions & 16 deletions webhook_server/libs/handlers/pull_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from github import GithubException
from github.PullRequest import PullRequest
from github.Repository import Repository
from timeout_sampler import TimeoutExpiredError, TimeoutSampler

from webhook_server.libs.handlers.check_run_handler import CheckRunHandler, CheckRunOutput
from webhook_server.libs.handlers.labels_handler import LabelsHandler
Expand Down Expand Up @@ -982,8 +983,9 @@ async def label_pull_request_by_merge_state(self, pull_request: PullRequest) ->

Simple flow:
1. Check pull_request.mergeable for conflicts
2. If has conflicts → add has-conflicts, exit
3. Else → remove has-conflicts, check Compare API for rebase status
2. If has conflicts → add has-conflicts label, exit
3. If mergeable unknown → skip has-conflicts update
4. If no conflicts → remove has-conflicts, check Compare API for rebase status

Uses both GitHub APIs for accurate labeling:
- has-conflicts: pull_request.mergeable == False (true merge conflict detection)
Expand All @@ -1004,25 +1006,58 @@ async def label_pull_request_by_merge_state(self, pull_request: PullRequest) ->
needs_rebase_label_exists = NEEDS_REBASE_LABEL_STR in current_labels

# Step 1: Check for conflicts first
# GitHub may return mergeable=None while computing - poll until definitive
mergeable = await asyncio.to_thread(lambda: pull_request.mergeable)
has_conflicts = mergeable is False

if has_conflicts:
# Has conflicts - add has-conflicts label and exit
self.logger.debug(f"{self.log_prefix} PR has conflicts. {mergeable=}")
if mergeable is None:
self.logger.debug(
f"{self.log_prefix} PR mergeable status is None, polling until GitHub computes status"
)
pr_number = pull_request.number
repository = self.github_webhook.repository

def _poll_mergeable() -> bool | None:
for sample in TimeoutSampler(
wait_timeout=30,
sleep=5,
func=lambda: repository.get_pull(pr_number).mergeable,
):
if sample is not None:
return sample
return None # pragma: no cover

try:
mergeable = await asyncio.to_thread(_poll_mergeable)
except asyncio.CancelledError:
raise
except TimeoutExpiredError:
self.logger.warning(
f"{self.log_prefix} PR mergeable status still None after retries, skipping label update"
)
if self.ctx:
self.ctx.complete_step("label_merge_state", mergeable_unknown=True)

if not has_conflicts_label_exists:
self.logger.debug(f"{self.log_prefix} Adding {HAS_CONFLICTS_LABEL_STR} label")
await self.labels_handler._add_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR)
if mergeable is not None:
has_conflicts = mergeable is False

if self.ctx:
self.ctx.complete_step("label_merge_state", has_conflicts=True, needs_rebase=False)
return # Exit early - conflicts take precedence
if has_conflicts:
# Has conflicts - add has-conflicts label and exit
self.logger.debug(f"{self.log_prefix} PR has conflicts. {mergeable=}")

# Step 2: No conflicts - remove has-conflicts label if present
if has_conflicts_label_exists:
self.logger.debug(f"{self.log_prefix} Removing {HAS_CONFLICTS_LABEL_STR} label")
await self.labels_handler._remove_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR)
if not has_conflicts_label_exists:
self.logger.debug(f"{self.log_prefix} Adding {HAS_CONFLICTS_LABEL_STR} label")
await self.labels_handler._add_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR)

if self.ctx:
self.ctx.complete_step("label_merge_state", has_conflicts=True)
return # Exit early - conflicts take precedence

# Step 2: No conflicts - remove has-conflicts label if present
if has_conflicts_label_exists:
self.logger.debug(f"{self.log_prefix} Removing {HAS_CONFLICTS_LABEL_STR} label")
await self.labels_handler._remove_label(pull_request=pull_request, label=HAS_CONFLICTS_LABEL_STR)
else:
self.logger.debug(f"{self.log_prefix} Mergeable status unknown, skipping has-conflicts label update")

# Step 3: Check if needs rebase via Compare API
base_ref, head_user_login, head_ref = await asyncio.gather(
Expand Down
153 changes: 123 additions & 30 deletions webhook_server/tests/test_pull_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from github import GithubException
from github.PullRequest import PullRequest
from timeout_sampler import TimeoutExpiredError

from webhook_server.libs.github_api import GithubWebhook
from webhook_server.libs.handlers.owners_files_handler import OwnersFileHandler
Expand Down Expand Up @@ -750,20 +751,14 @@ async def test_label_pull_request_by_merge_state_has_conflicts(
"""Test labeling pull request by merge state when has conflicts.

Uses pull_request.mergeable == False to detect conflicts.
When mergeable is False, ONLY has-conflicts label is set (conflicts take precedence over needs-rebase).
When mergeable is False, has-conflicts label is set and method returns early
without checking Compare API for needs-rebase.
"""
mock_pull_request.mergeable = False # Conflict detected via mergeable
mock_pull_request.base.ref = "main"
mock_pull_request.head.user.login = "test-user"
mock_pull_request.head.ref = "feature-branch"

# Mock existing labels - PR has no labels currently
mock_pull_request.labels = []

# Mock Compare API response - clean (no rebase needed)
mock_compare_data = {"behind_by": 0, "status": "ahead"}
pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data))

with (
patch.object(
pull_request_handler.labels_handler,
Expand All @@ -773,7 +768,7 @@ async def test_label_pull_request_by_merge_state_has_conflicts(
patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label,
):
await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request)
# When mergeable is False, only has-conflicts label is set (conflicts take precedence)
# When mergeable is False, has-conflicts label is set and method exits early
mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR)

@pytest.mark.asyncio
Expand Down Expand Up @@ -2443,21 +2438,20 @@ async def test_set_pull_request_automerge_exception(
async def test_label_pull_request_by_merge_state_unknown(
self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock
) -> None:
"""Test label_pull_request_by_merge_state when mergeable=None.
"""Test label_pull_request_by_merge_state when mergeable=None after retries.

When mergeable=None (not yet computed), has_conflicts is False.
If Compare API shows behind_by > 0, needs-rebase label should be added.
When mergeable=None (not yet computed by GitHub) and TimeoutSampler
times out, the has-conflicts label is left unchanged but the
needs-rebase check still runs via Compare API.
"""
mock_pull_request.mergeable = None # Not yet computed by GitHub
mock_pull_request.number = 123
mock_pull_request.base.ref = "main"
mock_pull_request.head.user.login = "test-user"
mock_pull_request.head.ref = "feature-branch"

# Mock existing labels - PR has no labels currently
mock_pull_request.labels = []

# Mock Compare API response - behind by 5 commits
mock_compare_data = {"behind_by": 5, "status": "behind"}
# Mock Compare API response - up-to-date (no rebase needed)
mock_compare_data = {"behind_by": 0, "status": "ahead"}
pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data))

with (
Expand All @@ -2466,11 +2460,117 @@ async def test_label_pull_request_by_merge_state_unknown(
"pull_request_labels_names",
new=AsyncMock(return_value=[]),
),
patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label,
patch.object(pull_request_handler.labels_handler, "_add_label", new=AsyncMock()) as mock_add_label,
patch.object(pull_request_handler.labels_handler, "_remove_label", new=AsyncMock()) as mock_remove_label,
patch(
"webhook_server.libs.handlers.pull_request_handler.TimeoutSampler",
side_effect=TimeoutExpiredError("Timed out"),
),
):
await pull_request_handler.label_pull_request_by_merge_state(mock_pull_request)
# Should add needs-rebase label since behind_by > 0 and no conflicts (mergeable=None means no conflict)
mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=NEEDS_REBASE_LABEL_STR)
# has-conflicts label unchanged (mergeable is None), no rebase needed
mock_add_label.assert_not_awaited()
mock_remove_label.assert_not_awaited()

@pytest.mark.asyncio
async def test_label_pull_request_by_merge_state_mergeable_none_with_existing_conflicts_label(
self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock
) -> None:
"""Test that has-conflicts label is NOT removed when mergeable is None after retries.

When mergeable=None (GitHub still computing) and has-conflicts label
already exists, the label must be preserved (not incorrectly removed)
even after TimeoutSampler times out. The needs-rebase check still runs.
"""
mock_pull_request.mergeable = None # Not yet computed by GitHub
mock_pull_request.number = 123
mock_pull_request.base.ref = "main"
mock_pull_request.head.user.login = "test-user"
mock_pull_request.head.ref = "feature-branch"

# Mock Compare API response - up-to-date (no rebase needed)
mock_compare_data = {"behind_by": 0, "status": "ahead"}
pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data))

with (
patch.object(
pull_request_handler.labels_handler,
"pull_request_labels_names",
new=AsyncMock(return_value=[HAS_CONFLICTS_LABEL_STR]),
),
patch.object(pull_request_handler.labels_handler, "_add_label", new=AsyncMock()) as mock_add_label,
patch.object(pull_request_handler.labels_handler, "_remove_label", new=AsyncMock()) as mock_remove_label,
patch(
"webhook_server.libs.handlers.pull_request_handler.TimeoutSampler",
side_effect=TimeoutExpiredError("Timed out"),
),
):
await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request)
# has-conflicts label preserved (mergeable is None), no rebase needed
mock_add_label.assert_not_awaited()
mock_remove_label.assert_not_awaited()

@pytest.mark.asyncio
async def test_label_pull_request_by_merge_state_polling_resolves_to_conflicts(
self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock
) -> None:
"""Test that polling resolves mergeable=False correctly adds has-conflicts label.

After adding has-conflicts, the method returns early without checking Compare API.
"""
mock_pull_request.mergeable = None # Triggers polling
mock_pull_request.number = 123

with (
patch.object(
pull_request_handler.labels_handler,
"pull_request_labels_names",
new=AsyncMock(return_value=[]),
),
patch.object(pull_request_handler.labels_handler, "_add_label", new=AsyncMock()) as mock_add,
patch(
"webhook_server.libs.handlers.pull_request_handler.TimeoutSampler",
return_value=iter([False]),
),
):
await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request)
# has-conflicts should be added (mergeable=False means conflicts), then early return
mock_add.assert_awaited_once_with(pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR)

@pytest.mark.asyncio
async def test_label_pull_request_by_merge_state_polling_resolves_to_mergeable(
self, pull_request_handler: PullRequestHandler, mock_pull_request: Mock
) -> None:
"""Test that polling resolves mergeable=True correctly removes has-conflicts label."""
mock_pull_request.mergeable = None # Triggers polling
mock_pull_request.number = 123
mock_pull_request.base.ref = "main"
mock_pull_request.head.user.login = "test-user"
mock_pull_request.head.ref = "feature-branch"

with (
patch.object(
pull_request_handler.labels_handler,
"pull_request_labels_names",
new=AsyncMock(return_value=[HAS_CONFLICTS_LABEL_STR]),
),
patch.object(pull_request_handler.labels_handler, "_add_label", new=AsyncMock()) as mock_add,
patch.object(pull_request_handler.labels_handler, "_remove_label", new=AsyncMock()) as mock_remove,
patch(
"webhook_server.libs.handlers.pull_request_handler.TimeoutSampler",
return_value=iter([True]),
),
patch.object(
pull_request_handler,
"_compare_branches",
new=AsyncMock(return_value={"behind_by": 0, "status": "identical"}),
),
):
await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request)
# has-conflicts should be removed (mergeable=True means no conflicts)
mock_remove.assert_awaited_once_with(pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR)
# No labels should be added (no conflicts, no rebase needed)
mock_add.assert_not_awaited()

@pytest.mark.asyncio
async def test_label_pull_request_by_merge_state_diverged(
Expand Down Expand Up @@ -2545,21 +2645,14 @@ async def test_label_pull_request_by_merge_state_behind_and_conflicts(
"""Test labeling pull request when behind and has conflicts.

Uses pull_request.mergeable == False to detect conflicts.
Uses Compare API status='diverged' to detect needs-rebase.
When both exist, ONLY has-conflicts label is set (conflicts take precedence over needs-rebase).
When conflicts are detected, has-conflicts label is set and method returns early
without checking Compare API for needs-rebase.
"""
mock_pull_request.mergeable = False # Conflict detected via mergeable
mock_pull_request.base.ref = "main"
mock_pull_request.head.user.login = "test-user"
mock_pull_request.head.ref = "feature-branch"

# Mock existing labels - PR has no labels currently
mock_pull_request.labels = []

# Mock Compare API response - diverged (needs rebase) + mergeable=False (conflicts)
mock_compare_data = {"behind_by": 2, "status": "diverged"}
pull_request_handler.repository._requester.requestJsonAndCheck = Mock(return_value=({}, mock_compare_data))

with (
patch.object(
pull_request_handler.labels_handler,
Expand All @@ -2569,7 +2662,7 @@ async def test_label_pull_request_by_merge_state_behind_and_conflicts(
patch.object(pull_request_handler.labels_handler, "_add_label", new_callable=AsyncMock) as mock_add_label,
):
await pull_request_handler.label_pull_request_by_merge_state(pull_request=mock_pull_request)
# When mergeable is False (conflicts), only has-conflicts label is set (conflicts take precedence)
# Only has-conflicts label is set; method returns early without checking Compare API
mock_add_label.assert_called_once_with(pull_request=mock_pull_request, label=HAS_CONFLICTS_LABEL_STR)

@pytest.mark.asyncio
Expand Down