Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/kimi_cli/soul/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from kimi_cli.session import Session
from kimi_cli.soul.approval import Approval
from kimi_cli.soul.denwarenji import DenwaRenji
from kimi_cli.soul.preview import Preview
from kimi_cli.soul.runtime import BuiltinSystemPromptArgs, Runtime
from kimi_cli.soul.toolset import CustomToolset
from kimi_cli.tools import SkipThisTool
Expand Down Expand Up @@ -55,6 +56,7 @@ async def load_agent(
Session: runtime.session,
DenwaRenji: runtime.denwa_renji,
Approval: runtime.approval,
Preview: runtime.preview,
}
tools = agent_spec.tools
if agent_spec.exclude_tools:
Expand Down
8 changes: 8 additions & 0 deletions src/kimi_cli/soul/kimisoul.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self._runtime = runtime
self._denwa_renji = runtime.denwa_renji
self._approval = runtime.approval
self._preview = runtime.preview
self._context = context
self._loop_control = runtime.config.loop_control
self._compaction = SimpleCompaction() # TODO: maybe configurable and composable
Expand Down Expand Up @@ -160,10 +161,16 @@ async def _pipe_approval_to_wire():
request = await self._approval.fetch_request()
wire_send(request)

async def _pipe_preview_to_wire():
while True:
request = await self._preview.fetch_request()
wire_send(request)

step_no = 1
while True:
wire_send(StepBegin(step_no))
approval_task = asyncio.create_task(_pipe_approval_to_wire())
preview_task = asyncio.create_task(_pipe_preview_to_wire())
# FIXME: It's possible that a subagent's approval task steals approval request
# from the main agent. We must ensure that the Task tool will redirect them
# to the main wire. See `_SubWire` for more details. Later we need to figure
Expand Down Expand Up @@ -194,6 +201,7 @@ async def _pipe_approval_to_wire():
raise
finally:
approval_task.cancel() # stop piping approval requests to the wire
preview_task.cancel() # stop piping preview requests to the wire

if finished:
return
Expand Down
50 changes: 50 additions & 0 deletions src/kimi_cli/soul/preview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio
import difflib

from pygments.lexers import get_lexer_for_filename

from kimi_cli.wire.message import PreviewChange


class Preview:
def __init__(self, yolo: bool = False):
self._preview_queue = asyncio.Queue[PreviewChange]()
self._yolo = yolo

async def get_lexer(self, file_path: str):
try:
lexer = get_lexer_for_filename(file_path)
return lexer.name.lower()
except Exception:
return "text"

async def preview_text(
self, file_path: str, content: str, content_type: str = "", style: str = ""
):
if self._yolo:
return

title = file_path
if not content_type:
content_type = await self.get_lexer(file_path)

msg = PreviewChange(title, content, content_type, style)
self._preview_queue.put_nowait(msg)
await msg.wait()

async def preview_diff(self, file_path: str, before: str, after: str):
if self._yolo:
return

diff = difflib.unified_diff(
before.splitlines(keepends=True), after.splitlines(keepends=True)
)

content = "".join(diff)
content_type = await self.get_lexer(file_path)
msg = PreviewChange(file_path, content, content_type, "diff")
self._preview_queue.put_nowait(msg)
await msg.wait()

async def fetch_request(self) -> PreviewChange:
return await self._preview_queue.get()
3 changes: 3 additions & 0 deletions src/kimi_cli/soul/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from kimi_cli.session import Session
from kimi_cli.soul.approval import Approval
from kimi_cli.soul.denwarenji import DenwaRenji
from kimi_cli.soul.preview import Preview
from kimi_cli.utils.logging import logger


Expand Down Expand Up @@ -68,6 +69,7 @@ class Runtime(NamedTuple):
builtin_args: BuiltinSystemPromptArgs
denwa_renji: DenwaRenji
approval: Approval
preview: Preview

@staticmethod
async def create(
Expand All @@ -93,4 +95,5 @@ async def create(
),
denwa_renji=DenwaRenji(),
approval=Approval(yolo=yolo),
preview=Preview(),
)
17 changes: 16 additions & 1 deletion src/kimi_cli/tools/file/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel, Field

from kimi_cli.soul.approval import Approval
from kimi_cli.soul.preview import Preview
from kimi_cli.soul.runtime import BuiltinSystemPromptArgs
from kimi_cli.tools.file import FileActions
from kimi_cli.tools.utils import ToolRejectedError, load_desc
Expand Down Expand Up @@ -52,10 +53,17 @@ class PatchFile(CallableTool2[Params]):
description: str = load_desc(Path(__file__).parent / "patch.md")
params: type[Params] = Params

def __init__(self, builtin_args: BuiltinSystemPromptArgs, approval: Approval, **kwargs: Any):
def __init__(
self,
builtin_args: BuiltinSystemPromptArgs,
approval: Approval,
preview: Preview,
**kwargs: Any,
):
super().__init__(**kwargs)
self._work_dir = builtin_args.KIMI_WORK_DIR
self._approval = approval
self._preview = preview

def _validate_path(self, path: Path) -> ToolError | None:
"""Validate that the path is safe to patch."""
Expand Down Expand Up @@ -104,6 +112,13 @@ async def __call__(self, params: Params) -> ToolReturnType:
brief="Invalid path",
)

await self._preview.preview_text(
f"Patch file `{params.path}`",
params.diff,
"",
"diff",
)

# Request approval
if not await self._approval.request(
self.name,
Expand Down
28 changes: 19 additions & 9 deletions src/kimi_cli/tools/file/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel, Field

from kimi_cli.soul.approval import Approval
from kimi_cli.soul.preview import Preview
from kimi_cli.soul.runtime import BuiltinSystemPromptArgs
from kimi_cli.tools.file import FileActions
from kimi_cli.tools.utils import ToolRejectedError, load_desc
Expand All @@ -32,10 +33,17 @@ class StrReplaceFile(CallableTool2[Params]):
description: str = load_desc(Path(__file__).parent / "replace.md")
params: type[Params] = Params

def __init__(self, builtin_args: BuiltinSystemPromptArgs, approval: Approval, **kwargs: Any):
def __init__(
self,
builtin_args: BuiltinSystemPromptArgs,
approval: Approval,
preview: Preview,
**kwargs: Any,
):
super().__init__(**kwargs)
self._work_dir = builtin_args.KIMI_WORK_DIR
self._approval = approval
self._preview = preview

def _validate_path(self, path: Path) -> ToolError | None:
"""Validate that the path is safe to edit."""
Expand Down Expand Up @@ -91,14 +99,6 @@ async def __call__(self, params: Params) -> ToolReturnType:
brief="Invalid path",
)

# Request approval
if not await self._approval.request(
self.name,
FileActions.EDIT,
f"Edit file `{params.path}`",
):
return ToolRejectedError()

# Read the file content
async with aiofiles.open(p, encoding="utf-8", errors="replace") as f:
content = await f.read()
Expand All @@ -117,6 +117,16 @@ async def __call__(self, params: Params) -> ToolReturnType:
brief="No replacements made",
)

await self._preview.preview_diff(params.path, original_content, content)

# Request approval
if not await self._approval.request(
self.name,
FileActions.EDIT,
f"Edit file `{params.path}`",
):
return ToolRejectedError()

# Write the modified content back to the file
async with aiofiles.open(p, mode="w", encoding="utf-8") as f:
await f.write(content)
Expand Down
12 changes: 11 additions & 1 deletion src/kimi_cli/tools/file/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel, Field

from kimi_cli.soul.approval import Approval
from kimi_cli.soul.preview import Preview
from kimi_cli.soul.runtime import BuiltinSystemPromptArgs
from kimi_cli.tools.file import FileActions
from kimi_cli.tools.utils import ToolRejectedError, load_desc
Expand All @@ -29,10 +30,17 @@ class WriteFile(CallableTool2[Params]):
description: str = load_desc(Path(__file__).parent / "write.md")
params: type[Params] = Params

def __init__(self, builtin_args: BuiltinSystemPromptArgs, approval: Approval, **kwargs: Any):
def __init__(
self,
builtin_args: BuiltinSystemPromptArgs,
approval: Approval,
preview: Preview,
**kwargs: Any,
):
super().__init__(**kwargs)
self._work_dir = builtin_args.KIMI_WORK_DIR
self._approval = approval
self._preview = preview

def _validate_path(self, path: Path) -> ToolError | None:
"""Validate that the path is safe to write."""
Expand Down Expand Up @@ -89,6 +97,8 @@ async def __call__(self, params: Params) -> ToolReturnType:
brief="Invalid write mode",
)

await self._preview.preview_text(params.path, params.content)

# Request approval
if not await self._approval.request(
self.name,
Expand Down
104 changes: 104 additions & 0 deletions src/kimi_cli/ui/shell/visualize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import re
from collections import deque
from collections.abc import Callable
from contextlib import asynccontextmanager, suppress
Expand All @@ -11,7 +12,9 @@
from rich.live import Live
from rich.markup import escape
from rich.panel import Panel
from rich.rule import Rule
from rich.spinner import Spinner
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text

Expand All @@ -26,6 +29,7 @@
ApprovalResponse,
CompactionBegin,
CompactionEnd,
PreviewChange,
StatusUpdate,
StepBegin,
StepInterrupted,
Expand Down Expand Up @@ -390,6 +394,8 @@ def dispatch_wire_message(self, msg: WireMessage) -> None:
self.request_approval(msg)
case SubagentEvent():
self.handle_subagent_event(msg)
case PreviewChange():
self.append_preview(msg)

def dispatch_keyboard_event(self, event: KeyEvent) -> None:
# handle ESC key to cancel the run
Expand Down Expand Up @@ -559,6 +565,36 @@ def handle_subagent_event(self, event: SubagentEvent) -> None:
# TODO: may need to handle multi-level nested subagents
pass

def append_preview(self, msg: PreviewChange):
content_type = msg.content_type
if content_type == "markdown":
body = Markdown(
msg.content,
justify="left",
)
elif content_type in {"text", "text only"}:
body = Text(msg.content)
elif msg.style == "diff":
body = _DiffView.get_view(msg.file_path, msg.content, msg.content_type)
else:
body = Syntax(
msg.content,
content_type,
theme="monokai",
line_numbers=True,
background_color="default",
padding=(0, 0),
)

width = int(console.width * 0.8)
panel = Panel(
body,
border_style="wheat4",
width=width,
)
console.print(panel)
msg.resolve()


def _with_bullet(
renderable: RenderableType,
Expand All @@ -574,3 +610,71 @@ def _with_bullet(
bullet = Text("•")
table.add_row(bullet, renderable)
return table


class _DiffView:
RED = "#3A0003"
GREEN = "#242F12"
GRAY = "grey50"

@staticmethod
def parse_diff_header(diff_line: str) -> tuple[int, int, int, int]:
pattern = r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@"
match = re.match(pattern, diff_line)

if not match:
return (0, 0, 0, 0)

old_start, old_lines, new_start, new_lines = match.groups()

old_lines = int(old_lines) if old_lines else 1
new_lines = int(new_lines) if new_lines else 1

return (int(old_start), old_lines, int(new_start), new_lines)

@staticmethod
def get_view(file: str, code: str, lexer: str) -> RenderableType:
if not code:
return Group(
Text(f"Edit {file}\n", style="grey50"), Text("nothing changed", style="grey50")
)

table = Table.grid(padding=(0, 1))
table.add_column(style=_DiffView.GRAY, no_wrap=True) # line number
table.add_column(style=_DiffView.GRAY, no_wrap=True) # Diff marker
table.add_column() # code

ln = oln = nln = 0
for line in code.splitlines():
if line.startswith("---") or line.startswith("+++"):
continue

if line.startswith("@@"):
ln = oln = nln = _DiffView.parse_diff_header(line)[0] - 1
if len(table.rows) > 0:
table.add_row(Rule(style=_DiffView.GRAY))
continue

if line.startswith("+"):
marker = Text("+", style="green")
syntax = Syntax(line[1:], lexer, theme="monokai", background_color=_DiffView.GREEN)
nln += 1
ln = nln
elif line.startswith("-"):
syntax = Syntax(line[1:], lexer, theme="monokai", background_color=_DiffView.RED)
marker = Text("-", style="red")
oln += 1
ln = oln
else:
marker = " "
syntax = Syntax(
line[1:], lexer, theme="monokai", word_wrap=True, background_color="default"
)
oln += 1
nln += 1
ln += 1

table.add_row(Text(str(ln), style=_DiffView.GRAY), marker, syntax)

text = Text(f"Edit {file}\n", style=_DiffView.GRAY)
return Group(text, table)
Loading