diff --git a/src/kimi_cli/soul/agent.py b/src/kimi_cli/soul/agent.py index 76985eed..678a72f4 100644 --- a/src/kimi_cli/soul/agent.py +++ b/src/kimi_cli/soul/agent.py @@ -11,6 +11,7 @@ from kimi_cli.session import Session from kimi_cli.soul.approval import Approval from kimi_cli.soul.denwarenji import DenwaRenji +from kimi_cli.soul.preview import Preview from kimi_cli.soul.runtime import BuiltinSystemPromptArgs, Runtime from kimi_cli.soul.toolset import CustomToolset from kimi_cli.tools import SkipThisTool @@ -55,6 +56,7 @@ async def load_agent( Session: runtime.session, DenwaRenji: runtime.denwa_renji, Approval: runtime.approval, + Preview: runtime.preview, } tools = agent_spec.tools if agent_spec.exclude_tools: diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index 69daff80..494b06b0 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -65,6 +65,7 @@ def __init__( self._runtime = runtime self._denwa_renji = runtime.denwa_renji self._approval = runtime.approval + self._preview = runtime.preview self._context = context self._loop_control = runtime.config.loop_control self._compaction = SimpleCompaction() # TODO: maybe configurable and composable @@ -160,10 +161,16 @@ async def _pipe_approval_to_wire(): request = await self._approval.fetch_request() wire_send(request) + async def _pipe_preview_to_wire(): + while True: + request = await self._preview.fetch_request() + wire_send(request) + step_no = 1 while True: wire_send(StepBegin(step_no)) approval_task = asyncio.create_task(_pipe_approval_to_wire()) + preview_task = asyncio.create_task(_pipe_preview_to_wire()) # FIXME: It's possible that a subagent's approval task steals approval request # from the main agent. We must ensure that the Task tool will redirect them # to the main wire. See `_SubWire` for more details. Later we need to figure @@ -194,6 +201,7 @@ async def _pipe_approval_to_wire(): raise finally: approval_task.cancel() # stop piping approval requests to the wire + preview_task.cancel() # stop piping preview requests to the wire if finished: return diff --git a/src/kimi_cli/soul/preview.py b/src/kimi_cli/soul/preview.py new file mode 100644 index 00000000..6ae7909f --- /dev/null +++ b/src/kimi_cli/soul/preview.py @@ -0,0 +1,50 @@ +import asyncio +import difflib + +from pygments.lexers import get_lexer_for_filename + +from kimi_cli.wire.message import PreviewChange + + +class Preview: + def __init__(self, yolo: bool = False): + self._preview_queue = asyncio.Queue[PreviewChange]() + self._yolo = yolo + + async def get_lexer(self, file_path: str): + try: + lexer = get_lexer_for_filename(file_path) + return lexer.name.lower() + except Exception: + return "text" + + async def preview_text( + self, file_path: str, content: str, content_type: str = "", style: str = "" + ): + if self._yolo: + return + + title = file_path + if not content_type: + content_type = await self.get_lexer(file_path) + + msg = PreviewChange(title, content, content_type, style) + self._preview_queue.put_nowait(msg) + await msg.wait() + + async def preview_diff(self, file_path: str, before: str, after: str): + if self._yolo: + return + + diff = difflib.unified_diff( + before.splitlines(keepends=True), after.splitlines(keepends=True) + ) + + content = "".join(diff) + content_type = await self.get_lexer(file_path) + msg = PreviewChange(file_path, content, content_type, "diff") + self._preview_queue.put_nowait(msg) + await msg.wait() + + async def fetch_request(self) -> PreviewChange: + return await self._preview_queue.get() diff --git a/src/kimi_cli/soul/runtime.py b/src/kimi_cli/soul/runtime.py index d26df051..4db364aa 100644 --- a/src/kimi_cli/soul/runtime.py +++ b/src/kimi_cli/soul/runtime.py @@ -10,6 +10,7 @@ from kimi_cli.session import Session from kimi_cli.soul.approval import Approval from kimi_cli.soul.denwarenji import DenwaRenji +from kimi_cli.soul.preview import Preview from kimi_cli.utils.logging import logger @@ -68,6 +69,7 @@ class Runtime(NamedTuple): builtin_args: BuiltinSystemPromptArgs denwa_renji: DenwaRenji approval: Approval + preview: Preview @staticmethod async def create( @@ -93,4 +95,5 @@ async def create( ), denwa_renji=DenwaRenji(), approval=Approval(yolo=yolo), + preview=Preview(), ) diff --git a/src/kimi_cli/tools/file/patch.py b/src/kimi_cli/tools/file/patch.py index 330e22a9..2187731d 100644 --- a/src/kimi_cli/tools/file/patch.py +++ b/src/kimi_cli/tools/file/patch.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field from kimi_cli.soul.approval import Approval +from kimi_cli.soul.preview import Preview from kimi_cli.soul.runtime import BuiltinSystemPromptArgs from kimi_cli.tools.file import FileActions from kimi_cli.tools.utils import ToolRejectedError, load_desc @@ -52,10 +53,17 @@ class PatchFile(CallableTool2[Params]): description: str = load_desc(Path(__file__).parent / "patch.md") params: type[Params] = Params - def __init__(self, builtin_args: BuiltinSystemPromptArgs, approval: Approval, **kwargs: Any): + def __init__( + self, + builtin_args: BuiltinSystemPromptArgs, + approval: Approval, + preview: Preview, + **kwargs: Any, + ): super().__init__(**kwargs) self._work_dir = builtin_args.KIMI_WORK_DIR self._approval = approval + self._preview = preview def _validate_path(self, path: Path) -> ToolError | None: """Validate that the path is safe to patch.""" @@ -104,6 +112,13 @@ async def __call__(self, params: Params) -> ToolReturnType: brief="Invalid path", ) + await self._preview.preview_text( + f"Patch file `{params.path}`", + params.diff, + "", + "diff", + ) + # Request approval if not await self._approval.request( self.name, diff --git a/src/kimi_cli/tools/file/replace.py b/src/kimi_cli/tools/file/replace.py index 4098202a..bf84b18f 100644 --- a/src/kimi_cli/tools/file/replace.py +++ b/src/kimi_cli/tools/file/replace.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from kimi_cli.soul.approval import Approval +from kimi_cli.soul.preview import Preview from kimi_cli.soul.runtime import BuiltinSystemPromptArgs from kimi_cli.tools.file import FileActions from kimi_cli.tools.utils import ToolRejectedError, load_desc @@ -32,10 +33,17 @@ class StrReplaceFile(CallableTool2[Params]): description: str = load_desc(Path(__file__).parent / "replace.md") params: type[Params] = Params - def __init__(self, builtin_args: BuiltinSystemPromptArgs, approval: Approval, **kwargs: Any): + def __init__( + self, + builtin_args: BuiltinSystemPromptArgs, + approval: Approval, + preview: Preview, + **kwargs: Any, + ): super().__init__(**kwargs) self._work_dir = builtin_args.KIMI_WORK_DIR self._approval = approval + self._preview = preview def _validate_path(self, path: Path) -> ToolError | None: """Validate that the path is safe to edit.""" @@ -91,14 +99,6 @@ async def __call__(self, params: Params) -> ToolReturnType: brief="Invalid path", ) - # Request approval - if not await self._approval.request( - self.name, - FileActions.EDIT, - f"Edit file `{params.path}`", - ): - return ToolRejectedError() - # Read the file content async with aiofiles.open(p, encoding="utf-8", errors="replace") as f: content = await f.read() @@ -117,6 +117,16 @@ async def __call__(self, params: Params) -> ToolReturnType: brief="No replacements made", ) + await self._preview.preview_diff(params.path, original_content, content) + + # Request approval + if not await self._approval.request( + self.name, + FileActions.EDIT, + f"Edit file `{params.path}`", + ): + return ToolRejectedError() + # Write the modified content back to the file async with aiofiles.open(p, mode="w", encoding="utf-8") as f: await f.write(content) diff --git a/src/kimi_cli/tools/file/write.py b/src/kimi_cli/tools/file/write.py index a75140ee..a6800ace 100644 --- a/src/kimi_cli/tools/file/write.py +++ b/src/kimi_cli/tools/file/write.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from kimi_cli.soul.approval import Approval +from kimi_cli.soul.preview import Preview from kimi_cli.soul.runtime import BuiltinSystemPromptArgs from kimi_cli.tools.file import FileActions from kimi_cli.tools.utils import ToolRejectedError, load_desc @@ -29,10 +30,17 @@ class WriteFile(CallableTool2[Params]): description: str = load_desc(Path(__file__).parent / "write.md") params: type[Params] = Params - def __init__(self, builtin_args: BuiltinSystemPromptArgs, approval: Approval, **kwargs: Any): + def __init__( + self, + builtin_args: BuiltinSystemPromptArgs, + approval: Approval, + preview: Preview, + **kwargs: Any, + ): super().__init__(**kwargs) self._work_dir = builtin_args.KIMI_WORK_DIR self._approval = approval + self._preview = preview def _validate_path(self, path: Path) -> ToolError | None: """Validate that the path is safe to write.""" @@ -89,6 +97,8 @@ async def __call__(self, params: Params) -> ToolReturnType: brief="Invalid write mode", ) + await self._preview.preview_text(params.path, params.content) + # Request approval if not await self._approval.request( self.name, diff --git a/src/kimi_cli/ui/shell/visualize.py b/src/kimi_cli/ui/shell/visualize.py index 6940bbc4..ab5f3545 100644 --- a/src/kimi_cli/ui/shell/visualize.py +++ b/src/kimi_cli/ui/shell/visualize.py @@ -1,4 +1,5 @@ import asyncio +import re from collections import deque from collections.abc import Callable from contextlib import asynccontextmanager, suppress @@ -11,7 +12,9 @@ from rich.live import Live from rich.markup import escape from rich.panel import Panel +from rich.rule import Rule from rich.spinner import Spinner +from rich.syntax import Syntax from rich.table import Table from rich.text import Text @@ -26,6 +29,7 @@ ApprovalResponse, CompactionBegin, CompactionEnd, + PreviewChange, StatusUpdate, StepBegin, StepInterrupted, @@ -390,6 +394,8 @@ def dispatch_wire_message(self, msg: WireMessage) -> None: self.request_approval(msg) case SubagentEvent(): self.handle_subagent_event(msg) + case PreviewChange(): + self.append_preview(msg) def dispatch_keyboard_event(self, event: KeyEvent) -> None: # handle ESC key to cancel the run @@ -559,6 +565,36 @@ def handle_subagent_event(self, event: SubagentEvent) -> None: # TODO: may need to handle multi-level nested subagents pass + def append_preview(self, msg: PreviewChange): + content_type = msg.content_type + if content_type == "markdown": + body = Markdown( + msg.content, + justify="left", + ) + elif content_type in {"text", "text only"}: + body = Text(msg.content) + elif msg.style == "diff": + body = _DiffView.get_view(msg.file_path, msg.content, msg.content_type) + else: + body = Syntax( + msg.content, + content_type, + theme="monokai", + line_numbers=True, + background_color="default", + padding=(0, 0), + ) + + width = int(console.width * 0.8) + panel = Panel( + body, + border_style="wheat4", + width=width, + ) + console.print(panel) + msg.resolve() + def _with_bullet( renderable: RenderableType, @@ -574,3 +610,71 @@ def _with_bullet( bullet = Text("•") table.add_row(bullet, renderable) return table + + +class _DiffView: + RED = "#3A0003" + GREEN = "#242F12" + GRAY = "grey50" + + @staticmethod + def parse_diff_header(diff_line: str) -> tuple[int, int, int, int]: + pattern = r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@" + match = re.match(pattern, diff_line) + + if not match: + return (0, 0, 0, 0) + + old_start, old_lines, new_start, new_lines = match.groups() + + old_lines = int(old_lines) if old_lines else 1 + new_lines = int(new_lines) if new_lines else 1 + + return (int(old_start), old_lines, int(new_start), new_lines) + + @staticmethod + def get_view(file: str, code: str, lexer: str) -> RenderableType: + if not code: + return Group( + Text(f"Edit {file}\n", style="grey50"), Text("nothing changed", style="grey50") + ) + + table = Table.grid(padding=(0, 1)) + table.add_column(style=_DiffView.GRAY, no_wrap=True) # line number + table.add_column(style=_DiffView.GRAY, no_wrap=True) # Diff marker + table.add_column() # code + + ln = oln = nln = 0 + for line in code.splitlines(): + if line.startswith("---") or line.startswith("+++"): + continue + + if line.startswith("@@"): + ln = oln = nln = _DiffView.parse_diff_header(line)[0] - 1 + if len(table.rows) > 0: + table.add_row(Rule(style=_DiffView.GRAY)) + continue + + if line.startswith("+"): + marker = Text("+", style="green") + syntax = Syntax(line[1:], lexer, theme="monokai", background_color=_DiffView.GREEN) + nln += 1 + ln = nln + elif line.startswith("-"): + syntax = Syntax(line[1:], lexer, theme="monokai", background_color=_DiffView.RED) + marker = Text("-", style="red") + oln += 1 + ln = oln + else: + marker = " " + syntax = Syntax( + line[1:], lexer, theme="monokai", word_wrap=True, background_color="default" + ) + oln += 1 + nln += 1 + ln += 1 + + table.add_row(Text(str(ln), style=_DiffView.GRAY), marker, syntax) + + text = Text(f"Edit {file}\n", style=_DiffView.GRAY) + return Group(text, table) diff --git a/src/kimi_cli/wire/message.py b/src/kimi_cli/wire/message.py index cde8a11e..5332e06b 100644 --- a/src/kimi_cli/wire/message.py +++ b/src/kimi_cli/wire/message.py @@ -48,7 +48,15 @@ class SubagentEvent(NamedTuple): type ControlFlowEvent = StepBegin | StepInterrupted | CompactionBegin | CompactionEnd | StatusUpdate -type Event = ControlFlowEvent | ContentPart | ToolCall | ToolCallPart | ToolResult | SubagentEvent +type Event = ( + ControlFlowEvent + | ContentPart + | ToolCall + | ToolCallPart + | ToolResult + | SubagentEvent + | PreviewChange +) class ApprovalResponse(Enum): @@ -94,7 +102,25 @@ def resolved(self) -> bool: return self._future.done() -type WireMessage = Event | ApprovalRequest +class PreviewChange: + def __init__( + self, file_path: str, content: str, content_type: str = "markdown", style: str = "auto" + ): + self.id = str(uuid.uuid4()) + self.file_path = file_path + self.content = content + self.content_type = content_type + self.style = style + self._future = asyncio.Future[bool]() + + async def wait(self) -> bool: + return await self._future + + def resolve(self) -> None: + self._future.set_result(True) + + +type WireMessage = Event | ApprovalRequest | PreviewChange def serialize_event(event: Event) -> dict[str, Any]: @@ -144,6 +170,12 @@ def serialize_event(event: Event) -> dict[str, Any]: }, } + case PreviewChange(): + return { + "type": "preview_request", + "payload": serialize_preview_request(event), + } + def serialize_approval_request(request: ApprovalRequest) -> dict[str, Any]: """ @@ -189,3 +221,13 @@ def _serialize_tool_output( return output.model_dump(mode="json", exclude_none=True) else: # Sequence[ContentPart] return [part.model_dump(mode="json", exclude_none=True) for part in output] + + +def serialize_preview_request(request: PreviewChange) -> dict[str, Any]: + return { + "id": request.id, + "file_path": request.file_path, + "content": request.content, + "content_type": request.content_type, + "style": request.style, + } diff --git a/tests/conftest.py b/tests/conftest.py index 9ef778db..89cf96c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from kimi_cli.session import Session from kimi_cli.soul.approval import Approval from kimi_cli.soul.denwarenji import DenwaRenji +from kimi_cli.soul.preview import Preview from kimi_cli.soul.runtime import BuiltinSystemPromptArgs, Runtime from kimi_cli.tools.bash import Bash from kimi_cli.tools.dmail import SendDMail @@ -99,23 +100,31 @@ def approval() -> Approval: return Approval(yolo=True) +@pytest.fixture +def preview() -> Preview: + """Create a Preview instance.""" + return Preview(yolo=True) + + @pytest.fixture def runtime( config: Config, llm: LLM, + session: Session, builtin_args: BuiltinSystemPromptArgs, denwa_renji: DenwaRenji, - session: Session, approval: Approval, + preview: Preview, ) -> Runtime: """Create a Runtime instance.""" return Runtime( config=config, llm=llm, + session=session, builtin_args=builtin_args, denwa_renji=denwa_renji, - session=session, approval=approval, + preview=preview, ) @@ -192,29 +201,35 @@ def grep_tool() -> Grep: @pytest.fixture def write_file_tool( - builtin_args: BuiltinSystemPromptArgs, approval: Approval + builtin_args: BuiltinSystemPromptArgs, + approval: Approval, + preview: Preview, ) -> Generator[WriteFile]: """Create a WriteFile tool instance.""" with tool_call_context("WriteFile"): - yield WriteFile(builtin_args, approval) + yield WriteFile(builtin_args, approval, preview) @pytest.fixture def str_replace_file_tool( - builtin_args: BuiltinSystemPromptArgs, approval: Approval + builtin_args: BuiltinSystemPromptArgs, + approval: Approval, + preview: Preview, ) -> Generator[StrReplaceFile]: """Create a StrReplaceFile tool instance.""" with tool_call_context("StrReplaceFile"): - yield StrReplaceFile(builtin_args, approval) + yield StrReplaceFile(builtin_args, approval, preview) @pytest.fixture def patch_file_tool( - builtin_args: BuiltinSystemPromptArgs, approval: Approval + builtin_args: BuiltinSystemPromptArgs, + approval: Approval, + preview: Preview, ) -> Generator[PatchFile]: """Create a PatchFile tool instance.""" with tool_call_context("PatchFile"): - yield PatchFile(builtin_args, approval) + yield PatchFile(builtin_args, approval, preview) @pytest.fixture