Skip to content

Commit 939c260

Browse files
author
Anders Brams
committed
feat: add ide completion tests
1 parent 37cd9a2 commit 939c260

2 files changed

Lines changed: 550 additions & 0 deletions

File tree

tests/ide/conftest.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import subprocess
5+
from collections.abc import Iterator
6+
from pathlib import Path
7+
from typing import Any
8+
9+
import pytest
10+
11+
CURSOR = "<CURSOR>"
12+
13+
14+
class TyServer:
15+
def __init__(self, workspace: Path) -> None:
16+
self.workspace = workspace
17+
self.process = subprocess.Popen(
18+
["ty", "server"],
19+
cwd=self.workspace,
20+
stdin=subprocess.PIPE,
21+
stdout=subprocess.PIPE,
22+
stderr=subprocess.DEVNULL,
23+
)
24+
stdin = self.process.stdin
25+
stdout = self.process.stdout
26+
if stdin is None or stdout is None:
27+
raise RuntimeError("failed to start ty server")
28+
29+
self.stdin = stdin
30+
self.stdout = stdout
31+
self._next_id = 1
32+
self._request(
33+
"initialize",
34+
{
35+
"processId": None,
36+
"rootUri": self.workspace.as_uri(),
37+
"capabilities": {},
38+
},
39+
)
40+
self._notify("initialized", {})
41+
42+
def close(self) -> None:
43+
if self.process.poll() is not None:
44+
return
45+
try:
46+
self._request("shutdown", None)
47+
self._notify("exit", {})
48+
self.process.wait(timeout=5)
49+
finally:
50+
if self.process.poll() is None:
51+
self.process.kill()
52+
self.process.wait(timeout=5)
53+
54+
def completion_labels(self, source: str) -> set[str]:
55+
if source.count(CURSOR) != 1:
56+
raise ValueError(f"source must contain exactly one {CURSOR!r} marker")
57+
58+
before, after = source.split(CURSOR)
59+
text = before + after
60+
uri = (self.workspace / "__ty_completion_probe.py").as_uri()
61+
62+
self._notify(
63+
"textDocument/didOpen",
64+
{
65+
"textDocument": {
66+
"uri": uri,
67+
"languageId": "python",
68+
"version": 1,
69+
"text": text,
70+
}
71+
},
72+
)
73+
try:
74+
response = self._request(
75+
"textDocument/completion",
76+
{
77+
"textDocument": {"uri": uri},
78+
"position": _position(before),
79+
"context": {"triggerKind": 1},
80+
},
81+
)
82+
result = response.get("result")
83+
if result is None:
84+
return set()
85+
items = result if isinstance(result, list) else result.get("items", [])
86+
return {
87+
str(item["label"])
88+
for item in items
89+
if isinstance(item, dict) and "label" in item
90+
}
91+
finally:
92+
self._notify("textDocument/didClose", {"textDocument": {"uri": uri}})
93+
94+
def _request(self, method: str, params: object) -> dict[str, Any]:
95+
request_id = self._next_id
96+
self._next_id += 1
97+
self._send(
98+
{"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}
99+
)
100+
while True:
101+
message = self._read_message()
102+
if message.get("id") != request_id:
103+
continue
104+
if "error" in message:
105+
raise AssertionError(
106+
f"ty server returned an error for {method}: {message['error']}"
107+
)
108+
return message
109+
110+
def _notify(self, method: str, params: object) -> None:
111+
self._send({"jsonrpc": "2.0", "method": method, "params": params})
112+
113+
def _send(self, message: dict[str, object]) -> None:
114+
payload = json.dumps(message).encode()
115+
self.stdin.write(f"Content-Length: {len(payload)}\r\n\r\n".encode() + payload)
116+
self.stdin.flush()
117+
118+
def _read_message(self) -> dict[str, Any]:
119+
header = bytearray()
120+
while b"\r\n\r\n" not in header:
121+
header.extend(self._read_exact(1))
122+
123+
content_length = None
124+
for line in bytes(header).split(b"\r\n"):
125+
if line.lower().startswith(b"content-length:"):
126+
content_length = int(line.split(b":", 1)[1].strip())
127+
break
128+
if content_length is None:
129+
raise AssertionError(f"ty server sent invalid LSP headers: {header!r}")
130+
131+
payload = self._read_exact(content_length)
132+
return json.loads(payload)
133+
134+
def _read_exact(self, length: int) -> bytes:
135+
chunks = bytearray()
136+
while len(chunks) < length:
137+
chunk = self.stdout.read(length - len(chunks))
138+
if not chunk:
139+
raise RuntimeError("ty server exited unexpectedly")
140+
chunks.extend(chunk)
141+
return bytes(chunks)
142+
143+
144+
def _position(text_before_cursor: str) -> dict[str, int]:
145+
current_line = text_before_cursor.rsplit("\n", 1)[-1]
146+
return {
147+
"line": text_before_cursor.count("\n"),
148+
"character": len(current_line.encode("utf-16-le")) // 2,
149+
}
150+
151+
152+
@pytest.fixture(scope="module")
153+
def ty_server(ide_workspace: Path) -> Iterator[TyServer]:
154+
server = TyServer(ide_workspace)
155+
try:
156+
yield server
157+
finally:
158+
server.close()

0 commit comments

Comments
 (0)