Skip to content

Commit 181a32a

Browse files
committed
Add unit tests for various components and screens
- Implement tests for GitHelperMethods and issue commands in test_core.py. - Create tests for NL2Shell command translation and context management in test_nl2shell.py. - Enhance suggestion engine tests in test_suggestions.py to support async context retrieval. - Add tests for AgentScreen, ToolApprovalScreen, and other UI components in their respective test files. - Introduce tests for confirmation, disclaimer, help, save dialog, selection, and tools screens. - Ensure all new tests cover key functionalities and edge cases for better code reliability.
1 parent 6f64fa6 commit 181a32a

26 files changed

Lines changed: 2734 additions & 1713 deletions

ai-todo.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# AI Todo List - Documentation Polish
2+
3+
## 현재 진행 중인 작업
4+
- [x] Rewrite and polish `docs/user/planning.md`
5+
6+
## 할 일 목록
7+
- [x] Read `docs/user/planning.md` to understand current content
8+
- [x] Research `/plan` implementation to ensure accuracy (if needed)
9+
- [x] Rewrite `docs/user/planning.md` with MkDocs features (tabs, admonitions)
10+
- [x] Verify the documentation (links, consistency)

ai/oauth.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import base64
5+
import hashlib
6+
import json
7+
import os
8+
import secrets
9+
import webbrowser
10+
from dataclasses import dataclass
11+
from http.server import BaseHTTPRequestHandler, HTTPServer
12+
from pathlib import Path
13+
from threading import Thread
14+
from typing import Any
15+
from urllib.parse import parse_qs, urlparse
16+
17+
import httpx
18+
19+
CONFIG_DIR = Path.home() / ".null"
20+
OAUTH_TOKENS_FILE = CONFIG_DIR / "oauth_tokens.json"
21+
22+
23+
@dataclass
24+
class OAuthTokens:
25+
access_token: str
26+
refresh_token: str
27+
expires_at: float
28+
provider: str
29+
30+
def is_expired(self) -> bool:
31+
import time
32+
33+
return time.time() >= self.expires_at - 60
34+
35+
def to_dict(self) -> dict[str, Any]:
36+
return {
37+
"access_token": self.access_token,
38+
"refresh_token": self.refresh_token,
39+
"expires_at": self.expires_at,
40+
"provider": self.provider,
41+
}
42+
43+
@classmethod
44+
def from_dict(cls, data: dict[str, Any]) -> OAuthTokens:
45+
return cls(
46+
access_token=data["access_token"],
47+
refresh_token=data["refresh_token"],
48+
expires_at=data["expires_at"],
49+
provider=data["provider"],
50+
)
51+
52+
53+
class PKCEHelper:
54+
@staticmethod
55+
def generate() -> tuple[str, str]:
56+
verifier = secrets.token_urlsafe(32)
57+
challenge = (
58+
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest())
59+
.rstrip(b"=")
60+
.decode()
61+
)
62+
return verifier, challenge
63+
64+
65+
class OAuthCallbackHandler(BaseHTTPRequestHandler):
66+
auth_code: str | None = None
67+
state: str | None = None
68+
error: str | None = None
69+
70+
def log_message(self, format: str, *args: Any) -> None:
71+
pass
72+
73+
def do_GET(self) -> None:
74+
parsed = urlparse(self.path)
75+
params = parse_qs(parsed.query)
76+
77+
if "error" in params:
78+
OAuthCallbackHandler.error = params["error"][0]
79+
elif "code" in params:
80+
OAuthCallbackHandler.auth_code = params["code"][0]
81+
OAuthCallbackHandler.state = params.get("state", [None])[0]
82+
83+
self.send_response(200)
84+
self.send_header("Content-type", "text/html")
85+
self.end_headers()
86+
87+
if OAuthCallbackHandler.error:
88+
html = "<html><body><h1>Authentication Failed</h1><p>You can close this window.</p></body></html>"
89+
else:
90+
html = "<html><body><h1>Authentication Successful</h1><p>You can close this window and return to Null Terminal.</p></body></html>"
91+
self.wfile.write(html.encode())
92+
93+
94+
class OAuthTokenStore:
95+
@staticmethod
96+
def save(tokens: OAuthTokens) -> None:
97+
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
98+
99+
all_tokens: dict[str, Any] = {}
100+
if OAUTH_TOKENS_FILE.exists():
101+
try:
102+
all_tokens = json.loads(OAUTH_TOKENS_FILE.read_text())
103+
except (json.JSONDecodeError, OSError):
104+
pass
105+
106+
all_tokens[tokens.provider] = tokens.to_dict()
107+
OAUTH_TOKENS_FILE.write_text(json.dumps(all_tokens, indent=2))
108+
109+
@staticmethod
110+
def load(provider: str) -> OAuthTokens | None:
111+
if not OAUTH_TOKENS_FILE.exists():
112+
return None
113+
114+
try:
115+
all_tokens = json.loads(OAUTH_TOKENS_FILE.read_text())
116+
if provider in all_tokens:
117+
return OAuthTokens.from_dict(all_tokens[provider])
118+
except (json.JSONDecodeError, OSError, KeyError):
119+
pass
120+
121+
return None
122+
123+
@staticmethod
124+
def delete(provider: str) -> None:
125+
if not OAUTH_TOKENS_FILE.exists():
126+
return
127+
128+
try:
129+
all_tokens = json.loads(OAUTH_TOKENS_FILE.read_text())
130+
if provider in all_tokens:
131+
del all_tokens[provider]
132+
OAUTH_TOKENS_FILE.write_text(json.dumps(all_tokens, indent=2))
133+
except (json.JSONDecodeError, OSError):
134+
pass
135+
136+
137+
class BaseOAuthFlow:
138+
def __init__(self, callback_port: int = 51121):
139+
self.callback_port = callback_port
140+
self.callback_url = f"http://localhost:{callback_port}/oauth-callback"
141+
self._server: HTTPServer | None = None
142+
self._server_thread: Thread | None = None
143+
144+
def _start_callback_server(self) -> None:
145+
OAuthCallbackHandler.auth_code = None
146+
OAuthCallbackHandler.state = None
147+
OAuthCallbackHandler.error = None
148+
149+
self._server = HTTPServer(
150+
("localhost", self.callback_port), OAuthCallbackHandler
151+
)
152+
self._server_thread = Thread(target=self._server.handle_request, daemon=True)
153+
self._server_thread.start()
154+
155+
def _wait_for_callback(
156+
self, timeout: float = 120
157+
) -> tuple[str | None, str | None, str | None]:
158+
if self._server_thread:
159+
self._server_thread.join(timeout=timeout)
160+
161+
if self._server:
162+
self._server.server_close()
163+
164+
return (
165+
OAuthCallbackHandler.auth_code,
166+
OAuthCallbackHandler.state,
167+
OAuthCallbackHandler.error,
168+
)
169+
170+
def _open_browser(self, url: str) -> None:
171+
webbrowser.open(url)

0 commit comments

Comments
 (0)