-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmain.py
More file actions
265 lines (212 loc) · 8.6 KB
/
main.py
File metadata and controls
265 lines (212 loc) · 8.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import argparse
import asyncio
import importlib
import importlib.util
import sys
from pathlib import Path
from tomllib import load as load_toml
from pydantic import ValidationError
import services.error # noqa: F401
import services.logger as log
import services.util as u
from services import config_io
from services.bridge import bridge
from services.config_schema import GlobalConfig
from services.db import db_target_version, init_db
from services.http_server import HttpServerManager
# Module-level logger obtained from the project's ``services.logger`` helper;
# shared by the functions defined below.
logger = log.get_logger()
def _load_project_version() -> str:
"""Load project version from pyproject.toml.
Returns:
The version string from [project].version.
Raises:
RuntimeError: If pyproject.toml cannot be read or version is missing.
"""
try:
with open("pyproject.toml", "rb") as f:
version = str(load_toml(f).get("project", {}).get("version", "")).strip()
except Exception as exc:
raise RuntimeError("Read version info failed") from exc
if not version:
raise RuntimeError("Missing [project].version in pyproject.toml")
return version
def _load_all_drivers(enabled_platforms: list[str]) -> None:
    """Import the driver module ``drivers.<platform>`` for each given platform.

    Importing a driver module is what populates the driver registry; each
    driver module is expected to register itself with ``drivers.registry`` at
    import time. Only the listed platforms are imported, not every module in
    the package. A platform with no locatable module is logged as a warning
    and skipped rather than aborting startup.

    Args:
        enabled_platforms: Platform names, each mapped to ``drivers.<name>``.
    """
    for platform in enabled_platforms:
        module_name = f"drivers.{platform}"
        try:
            spec = importlib.util.find_spec(module_name)
        except ModuleNotFoundError:
            # find_spec imports parent packages first, so a dotted platform
            # name whose parent does not exist (e.g. "foo.bar") raises here
            # instead of returning None. Treat it the same as "not found".
            spec = None
        if spec is None:
            logger.warning(
                f"Driver module for platform '{platform}' not found, skipping."
            )
            continue
        importlib.import_module(module_name)
def cmd_convert(src: str, dst: str) -> None:
    """Convert a config file into another format (CLI ``convert`` command).

    The source is read with ``config_io.load_config`` and written with
    ``config_io.save_config``; format handling lives in ``config_io``.

    Args:
        src: Path of the existing config file to read.
        dst: Path of the config file to write.

    On success prints a confirmation line and returns. Exits the process with
    status 1 if the source file does not exist or if reading or writing fails.
    """
    src_path = Path(src)
    dst_path = Path(dst)

    if not src_path.is_file():
        logger.error(f"Source file not found: {src_path}")
        sys.exit(1)

    try:
        data = config_io.load_config(src_path)
    except Exception:
        logger.opt(exception=True).critical(f"Error reading {src_path}")
        sys.exit(1)

    try:
        config_io.save_config(data, dst_path)
    except Exception:
        # This step writes the destination, so report it as a write failure.
        logger.opt(exception=True).critical(f"Error writing {dst_path}")
        sys.exit(1)

    print(f"Converted {src_path} → {dst_path}")
async def main():
    """Start NextBridge and run until its tasks finish or it is cancelled.

    Startup sequence:

    1. Read the project version.
    2. Locate and load the config file from the data directory.
    3. Import driver modules for the configured platforms.
    4. Validate the ``global`` section with ``GlobalConfig`` and apply the
       logging settings.
    5. Initialize the database.
    6. Validate each driver instance config with its platform's registered
       model.
    7. Start one task per driver instance, plus (optionally) the shared HTTP
       server.

    Any startup failure is logged (CRITICAL, or ERROR when there is simply
    nothing to run) and the coroutine returns early without raising.

    Once tasks are running, cancellation of this coroutine is treated as a
    shutdown request. Remaining tasks are cancelled and awaited, then shared
    media sessions are closed. That cleanup runs in ``finally`` so it also
    happens when every task finishes on its own.
    """
    try:
        version = _load_project_version()
    except Exception:
        logger.opt(exception=True).critical("Startup aborted: failed to load version")
        return

    config_path = config_io.find_config(Path(u.get_data_path()))
    if config_path is None:
        logger.critical(
            f"No config file found in: {u.get_data_path()} (tried config.json / .yaml / .toml)"
        )
        return

    bridge.load_rules()
    logger.info(f"Loading config from: {config_path}")
    # An unreadable/unparsable config is a startup failure like the others:
    # log it and return rather than letting the exception escape.
    try:
        raw = config_io.load_config(config_path)
    except Exception:
        logger.opt(exception=True).critical(
            f"Startup aborted: failed to load config from {config_path}"
        )
        return
    if not isinstance(raw, dict):
        logger.critical(
            f"Startup aborted: top level of {config_path} must be a mapping"
        )
        return
    bridge.load_sensitive_values(raw)

    # Every top-level key other than "global" names a platform whose driver
    # module should be imported.
    enabled_platforms = [key for key in raw if key != "global"]
    _load_all_drivers(enabled_platforms)
    # Imported after the driver modules above so the registry is populated.
    from drivers.registry import all_drivers

    logger.info("NextBridge starting...")

    # Validate the global configuration. A present-but-empty section (e.g. a
    # bare YAML ``global:`` key, which parses to None) is validated as {}.
    global_config = raw.get("global")
    if global_config is None:
        global_config = {}
    try:
        validated_global = GlobalConfig.model_validate(global_config)
    except ValidationError as exc:
        logger.opt(exception=exc).critical("Global configuration error")
        return
    # Set after validation: model_validate rejects non-mapping input, so the
    # ``.get`` below cannot fail on a malformed section.
    bridge.strict_echo_match = global_config.get("strict_echo_match", False)

    # Logging configuration
    log.set_console_level(validated_global.log.level)
    log.set_log_dir(validated_global.log.dir)
    log.set_log_rotation(
        rotation_size=validated_global.log.rotation_size,
        retention_days=validated_global.log.retention_days,
        compression=validated_global.log.compression,
        file_level=validated_global.log.file_level,
    )
    bridge.command_prefix = validated_global.command_prefix

    try:
        init_db()
        logger.info(
            f"Database initialized at startup with db_version target {db_target_version()}"
        )
    except Exception:
        logger.opt(exception=True).critical(
            "Startup aborted: database initialization failed"
        )
        return

    http_enable = validated_global.http.enable
    http_server = HttpServerManager(
        host=validated_global.http.host,
        port=validated_global.http.port,
        root_path=validated_global.http.root_path,
        log_level=validated_global.http.log_level,
        start_without_mounts=http_enable == "true",
        version=version,
    )

    # Validate each driver's per-instance configs via its registered model.
    registry = all_drivers()
    validated: dict[str, dict[str, object]] = {}
    config_ok = True
    for platform, (config_cls, _) in registry.items():
        instances = raw.get(platform)
        if instances is None:
            # Platform absent from the config, or present with an empty section.
            continue
        if not isinstance(instances, dict):
            # Report a malformed section (list, scalar, ...) as a config error
            # instead of crashing on ``.items()``.
            logger.critical(
                f"Config error in {platform}: expected a mapping of instance IDs to settings"
            )
            config_ok = False
            continue
        for inst_id, inst_raw in instances.items():
            try:
                validated.setdefault(platform, {})[inst_id] = (
                    config_cls.model_validate(inst_raw)
                )
            except ValidationError as exc:
                logger.opt(exception=exc).critical(
                    f"Config error in {platform}.{inst_id}"
                )
                config_ok = False
    if not config_ok:
        return

    def _on_task_done(task: asyncio.Task) -> None:
        # Log a crash as soon as the task finishes, not only when gather returns.
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.opt(exception=exc).error(f"Driver '{task.get_name()}' crashed")

    logger.info(f"========== NextBridge v{version} Starting ==========")

    driver_tasks: list[asyncio.Task] = []
    for platform, (_, driver_cls) in registry.items():
        for inst_id, cfg in validated.get(platform, {}).items():
            drv = driver_cls(inst_id, cfg, bridge)
            drv.attach_http_server(http_server)
            task = asyncio.create_task(drv.start(), name=f"{platform}/{inst_id}")
            task.add_done_callback(_on_task_done)
            driver_tasks.append(task)
            logger.info(f"Registered driver: {platform}/{inst_id}")

    if not driver_tasks:
        if http_enable != "true":
            logger.error("No drivers configured — nothing to do, exiting.")
            return
        logger.warning(
            "No drivers configured — starting HTTP server due to http.enable=true"
        )

    # Let drivers perform startup and register webhook sub-apps.
    # NOTE(review): a single sleep(0) only runs each new task up to its first
    # suspension point; mounts registered after that are not seen below.
    await asyncio.sleep(0)

    all_tasks = list(driver_tasks)
    if http_enable == "false":
        if http_server.has_mounts():
            logger.warning(
                "HTTP server is disabled by http.enable=false while drivers mounted "
                "webhook sub-apps; inbound webhook features are unavailable"
            )
        logger.info("Shared HTTP server disabled by configuration (http.enable=false)")
    elif http_server.should_start():
        http_task = asyncio.create_task(http_server.run(), name="http/shared")
        http_task.add_done_callback(_on_task_done)
        all_tasks.append(http_task)
    else:
        logger.info("No HTTP sub-app mounted; shared HTTP server disabled")

    try:
        results = await asyncio.gather(*all_tasks, return_exceptions=True)
        for task, result in zip(all_tasks, results):
            if isinstance(result, Exception):
                logger.error(f"Driver '{task.get_name()}' exited with error: {result}")
    except asyncio.CancelledError:
        logger.info("NextBridge shutting down...")
    finally:
        # Cleanup runs whether we were cancelled or every task finished.
        for task in all_tasks:
            if not task.done():
                task.cancel()
        # Wait for all drivers to clean up.
        await asyncio.gather(*all_tasks, return_exceptions=True)
        # Close shared sessions to avoid connection leaks.
        from services.media import close_all_sessions

        await close_all_sessions()
        logger.info("NextBridge stopped.")
if __name__ == "__main__":
    # Command-line entry point: either run the one-shot ``convert`` command,
    # or (with no sub-command) start the bridge event loop.
    cli = argparse.ArgumentParser(
        prog="nextbridge",
        description="NextBridge chat bridge",
    )
    commands = cli.add_subparsers(dest="command")
    convert_cmd = commands.add_parser(
        "convert",
        help="Convert a config file between formats (json/yaml/toml)",
    )
    convert_cmd.add_argument("src", help="Source config file (e.g. config.json)")
    convert_cmd.add_argument("dst", help="Destination config file (e.g. config.yaml)")
    cli_args = cli.parse_args()

    if cli_args.command != "convert":
        # Ctrl+C ends the bridge quietly instead of printing a traceback.
        try:
            asyncio.run(main())
        except KeyboardInterrupt:
            pass
    else:
        cmd_convert(cli_args.src, cli_args.dst)
        sys.exit(0)