Skip to content

Commit 46c3eb2

Browse files
authored
feat(explorer): allow client to control and inspect coding state (#104502)
The explorer client can now enable/disable coding tools, view returned file patches, and make and view PRs Part of AIML-1698 Requires getsentry/seer#4200
1 parent df95336 commit 46c3eb2

File tree

5 files changed

+419
-4
lines changed

5 files changed

+419
-4
lines changed

src/sentry/seer/endpoints/organization_seer_explorer_chat.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,12 @@ def post(
112112
on_page_context = validated_data.get("on_page_context")
113113

114114
try:
115-
client = SeerExplorerClient(organization, request.user, is_interactive=True)
115+
client = SeerExplorerClient(
116+
organization,
117+
request.user,
118+
is_interactive=True,
119+
enable_coding=True,
120+
)
116121
if run_id:
117122
# Continue existing conversation
118123
result_run_id = client.continue_run(

src/sentry/seer/explorer/client.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import time
45
from typing import Any, Literal
56

67
import orjson
@@ -117,6 +118,28 @@ def execute(cls, organization: Organization, run_id: int) -> None:
117118
on_completion=NotifyOnComplete
118119
)
119120
run_id = client.start_run("Analyze this issue")
121+
122+
# WITH CODE EDITING AND PR CREATION
123+
client = SeerExplorerClient(
124+
organization,
125+
user,
126+
enable_coding=True, # Enable code editing tools
127+
)
128+
129+
run_id = client.start_run("Fix the null pointer exception in auth.py")
130+
state = client.get_run(run_id, blocking=True)
131+
132+
# Check if agent made code changes and if they need to be pushed
133+
has_changes, is_synced = state.has_code_changes()
134+
if has_changes and not is_synced:
135+
# Push changes to PR (creates new PR or updates existing)
136+
state = client.push_changes(run_id)
137+
138+
# Get PR info for each repo
139+
for repo_name in state.get_file_patches_by_repo().keys():
140+
pr_state = state.get_pr_state(repo_name)
141+
if pr_state and pr_state.pr_url:
142+
print(f"PR created: {pr_state.pr_url}")
120143
```
121144
122145
Args:
@@ -128,6 +151,7 @@ def execute(cls, organization: Organization, run_id: int) -> None:
128151
on_completion_hook: Optional `ExplorerOnCompletionHook` class to call when the agent completes. The hook's execute() method receives the organization and run ID. This is called whether or not the agent was successful. Hook classes must be module-level (not nested classes).
129152
intelligence_level: Optionally set the intelligence level of the agent. Higher intelligence gives better result quality at the cost of significantly higher latency and cost.
130153
is_interactive: Enable full interactive, human-like features of the agent. Only enable if you support *all* available interactions in Seer. An example use of this is the explorer chat in Sentry UI.
154+
enable_coding: Enable code editing tools. When disabled, the agent cannot make code changes. Default is False.
131155
"""
132156

133157
def __init__(
@@ -140,6 +164,7 @@ def __init__(
140164
on_completion_hook: type[ExplorerOnCompletionHook] | None = None,
141165
intelligence_level: Literal["low", "medium", "high"] = "medium",
142166
is_interactive: bool = False,
167+
enable_coding: bool = False,
143168
):
144169
self.organization = organization
145170
self.user = user
@@ -149,6 +174,7 @@ def __init__(
149174
self.category_key = category_key
150175
self.category_value = category_value
151176
self.is_interactive = is_interactive
177+
self.enable_coding = enable_coding
152178

153179
# Validate that category_key and category_value are provided together
154180
if category_key == "" or category_value == "":
@@ -198,6 +224,7 @@ def start_run(
198224
"user_org_context": collect_user_org_context(self.user, self.organization),
199225
"intelligence_level": self.intelligence_level,
200226
"is_interactive": self.is_interactive,
227+
"enable_coding": self.enable_coding,
201228
}
202229

203230
# Add artifact key and schema if provided
@@ -273,6 +300,7 @@ def continue_run(
273300
"insert_index": insert_index,
274301
"on_page_context": on_page_context,
275302
"is_interactive": self.is_interactive,
303+
"enable_coding": self.enable_coding,
276304
}
277305

278306
# Add artifact key and schema if provided
@@ -384,3 +412,68 @@ def get_runs(
384412

385413
runs = [ExplorerRun(**run) for run in result.get("data", [])]
386414
return runs
415+
416+
def push_changes(
417+
self,
418+
run_id: int,
419+
repo_name: str | None = None,
420+
poll_interval: float = 2.0,
421+
poll_timeout: float = 120.0,
422+
) -> SeerRunState:
423+
"""
424+
Push code changes to PR(s) and wait for completion.
425+
426+
Creates new PRs or updates existing ones with current file patches.
427+
Polls until all PR operations complete.
428+
429+
Args:
430+
run_id: The run ID
431+
repo_name: Specific repo to push, or None for all repos with changes
432+
poll_interval: Seconds between polls
433+
poll_timeout: Maximum seconds to wait
434+
435+
Returns:
436+
SeerRunState: Final state with PR info
437+
438+
Raises:
439+
TimeoutError: If polling exceeds timeout
440+
requests.HTTPError: If the Seer API request fails
441+
"""
442+
# Trigger PR creation
443+
path = "/v1/automation/explorer/update"
444+
payload = {
445+
"run_id": run_id,
446+
"payload": {
447+
"type": "create_pr",
448+
"repo_name": repo_name,
449+
},
450+
}
451+
body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS)
452+
response = requests.post(
453+
f"{settings.SEER_AUTOFIX_URL}{path}",
454+
data=body,
455+
headers={
456+
"content-type": "application/json;charset=utf-8",
457+
**sign_with_seer_secret(body),
458+
},
459+
)
460+
response.raise_for_status()
461+
462+
# Poll until PR creation completes
463+
start_time = time.time()
464+
465+
while True:
466+
state = fetch_run_status(run_id, self.organization)
467+
468+
# Check if any PRs are still being created
469+
any_creating = any(
470+
pr.pr_creation_status == "creating" for pr in state.repo_pr_states.values()
471+
)
472+
473+
if not any_creating:
474+
return state
475+
476+
if time.time() - start_time > poll_timeout:
477+
raise TimeoutError(f"PR creation timed out after {poll_timeout}s")
478+
479+
time.sleep(poll_interval)

src/sentry/seer/explorer/client_models.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,46 @@ class Config:
4444
extra = "allow"
4545

4646

47+
class FilePatch(BaseModel):
48+
"""A file patch from code editing."""
49+
50+
path: str
51+
type: Literal["A", "M", "D"] # A=add, M=modify, D=delete
52+
added: int
53+
removed: int
54+
55+
class Config:
56+
extra = "allow"
57+
58+
59+
class ExplorerFilePatch(BaseModel):
60+
"""A file patch associated with a repository."""
61+
62+
repo_name: str
63+
patch: FilePatch
64+
65+
class Config:
66+
extra = "allow"
67+
68+
69+
class RepoPRState(BaseModel):
70+
"""PR state for a single repository."""
71+
72+
repo_name: str
73+
branch_name: str | None = None
74+
pr_number: int | None = None
75+
pr_url: str | None = None
76+
pr_id: int | None = None
77+
commit_sha: str | None = None
78+
pr_creation_status: Literal["creating", "completed", "error"] | None = None
79+
pr_creation_error: str | None = None
80+
title: str | None = None
81+
description: str | None = None
82+
83+
class Config:
84+
extra = "allow"
85+
86+
4787
class MemoryBlock(BaseModel):
4888
"""A block in the Explorer agent's conversation/memory."""
4989

@@ -52,6 +92,10 @@ class MemoryBlock(BaseModel):
5292
timestamp: str
5393
loading: bool = False
5494
artifacts: list[Artifact] = []
95+
file_patches: list[ExplorerFilePatch] = []
96+
pr_commit_shas: dict[str, str] | None = (
97+
None # repository name -> commit SHA. Used to track which commit was associated with each repo's PR at the time this block was created.
98+
)
5599

56100
class Config:
57101
extra = "allow"
@@ -76,6 +120,7 @@ class SeerRunState(BaseModel):
76120
status: Literal["processing", "completed", "error", "awaiting_user_input"]
77121
updated_at: str
78122
pending_user_input: PendingUserInput | None = None
123+
repo_pr_states: dict[str, RepoPRState] = {}
79124

80125
class Config:
81126
extra = "allow"
@@ -110,6 +155,52 @@ def get_artifact(self, key: str, schema: type[T]) -> T | None:
110155
return None
111156
return schema.parse_obj(artifact.data)
112157

158+
def get_file_patches_by_repo(self) -> dict[str, list[ExplorerFilePatch]]:
159+
"""Get file patches grouped by repository."""
160+
by_repo: dict[str, list[ExplorerFilePatch]] = {}
161+
for block in self.blocks:
162+
for fp in block.file_patches:
163+
if fp.repo_name not in by_repo:
164+
by_repo[fp.repo_name] = []
165+
by_repo[fp.repo_name].append(fp)
166+
return by_repo
167+
168+
def get_pr_state(self, repo_name: str) -> RepoPRState | None:
169+
"""Get PR state for a specific repository."""
170+
return self.repo_pr_states.get(repo_name)
171+
172+
def _is_repo_synced(self, repo_name: str) -> bool:
173+
"""Check if PR for a repo is in sync with latest changes."""
174+
pr_state = self.repo_pr_states.get(repo_name)
175+
if not pr_state or not pr_state.commit_sha:
176+
return False # No PR yet = not synced
177+
178+
# Find last block with patches for this repo
179+
for block in reversed(self.blocks):
180+
if any(fp.repo_name == repo_name for fp in block.file_patches):
181+
block_sha = (block.pr_commit_shas or {}).get(repo_name)
182+
return block_sha == pr_state.commit_sha
183+
return True # No patches found = synced
184+
185+
def has_code_changes(self) -> tuple[bool, bool]:
186+
"""
187+
Check if there are code changes and if all have been pushed to PRs.
188+
189+
Returns:
190+
(has_changes, all_changes_pushed):
191+
- has_changes: True if any file patches exist
192+
- all_changes_pushed: True if the current state of changes across all repos have all been pushed to PRs.
193+
"""
194+
patches_by_repo = self.get_file_patches_by_repo()
195+
has_changes = len(patches_by_repo) > 0
196+
197+
if not has_changes:
198+
return (False, True)
199+
200+
# Check if all repos with changes are synced
201+
all_changes_pushed = all(self._is_repo_synced(repo) for repo in patches_by_repo.keys())
202+
return (has_changes, all_changes_pushed)
203+
113204

114205
class CustomToolDefinition(BaseModel):
115206
"""Definition of a custom tool to be sent to Seer."""

tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def test_post_new_conversation_calls_client(self, mock_client_class: MagicMock):
6969
assert response.data == {"run_id": 456}
7070

7171
# Verify client was called correctly
72-
mock_client_class.assert_called_once_with(self.organization, ANY, is_interactive=True)
72+
mock_client_class.assert_called_once_with(
73+
self.organization, ANY, is_interactive=True, enable_coding=True
74+
)
7375
mock_client.start_run.assert_called_once_with(
7476
prompt="What is this error about?", on_page_context=None
7577
)
@@ -90,7 +92,9 @@ def test_post_continue_conversation_calls_client(self, mock_client_class: MagicM
9092
assert response.data == {"run_id": 789}
9193

9294
# Verify client was called correctly
93-
mock_client_class.assert_called_once_with(self.organization, ANY, is_interactive=True)
95+
mock_client_class.assert_called_once_with(
96+
self.organization, ANY, is_interactive=True, enable_coding=True
97+
)
9498
mock_client.continue_run.assert_called_once_with(
9599
run_id=789, prompt="Follow up question", insert_index=2, on_page_context=None
96100
)

0 commit comments

Comments
 (0)