-
Notifications
You must be signed in to change notification settings - Fork 47
Expand file tree
/
Copy pathsymphony.py
More file actions
2257 lines (1964 loc) · 99.1 KB
/
symphony.py
File metadata and controls
2257 lines (1964 loc) · 99.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Symphony 2.0 core execution engine (Path-1: centralized orchestrator).
✅ Symphony 2.0 features:
- Two-stage agent selection:
Stage-1: Top-L by static capability match_score (Symphony 1.0 compatible)
Stage-2: Global LinUCB selects within Top-L using dynamic_state features
- Multi-CoT execution and voting per subtask
- Online update after voting (winner bonus + latency penalty)
- Optional Symphony 1.0-style planning decomposition (multiple planners produce chains)
✅ This patched version additionally:
- Planner branch also performs online updates (closes the loop in planner mode)
- Planner weighted vote keys on extracted Final (not whole raw text)
- Never uses x[1] as match_score (build_x may normalize -> x[1] != raw match)
- Safer fallback: do not pick unavailable agent when pool is empty
- Optional correctness reward if gold label is provided in task/subtask context
Return modes:
- "aggregate": returns a multi-subtask report (string)
- "final": returns final answer only (string, BBH-friendly)
- "trace": returns dict with per-run traces (for debugging / saving)
"""
from __future__ import annotations
import json
import re
import threading
import time
import uuid
from collections import Counter, defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
# ---------------------- imports (package / local) ----------------------
try:
# Package mode
from symphony.protocol.task_contract import Task # type: ignore
from symphony.agents.agent import Agent # type: ignore
from symphony.core.linucb_selector import GlobalLinUCB, build_x # type: ignore
except Exception: # pragma: no cover
# Local mode
from protocol.task_contract import Task # type: ignore
from agents.agent import Agent # type: ignore
from core.linucb_selector import GlobalLinUCB, build_x # type: ignore
# Risk guard is optional
try:
from core.risk_guard import RiskAwareGuard, RiskGuardConfig # type: ignore
except Exception: # pragma: no cover
RiskAwareGuard = None # type: ignore
RiskGuardConfig = None # type: ignore
# ---------------------- orchestrator ----------------------
class SymphonyOrchestrator:
"""Main orchestrator for multi-agent task execution (Path-1)."""
def __init__(
self,
verbose: bool = False,
# ---- Dynamic Beacon Selection knobs (2.0) ----
use_dynamic: bool = True,
topL: int = 3,
linucb_alpha: float = 1.0,
linucb_l2: float = 1.0,
latency_scale_ms: float = 2000.0,
latency_penalty: float = 0.2,
win_bonus: float = 0.5,
# ---- Optional correctness reward (for better scores) ----
correctness_bonus: float = 0.0, # add when voted final matches gold
incorrect_penalty: float = 0.0, # subtract when voted final mismatches gold
# ---- Optional Symphony 1.0 planner ----
plan_k: int = 1,
# ---- Use planner to decompose even when plan_k == 1 ----
use_planner_decompose: bool = False,
# ---- Risk Guard ----
enable_risk_guard: bool = False,
# ---- Shared Blackboard / by-ref dispatch (Path-2 style, optional) ----
dispatch_mode: str = "local", # "local" | "shared_bb"
shared_timeout_s: float = 30.0, # wait result timeout
shared_poll_interval: float = 0.01,
requester_id: Optional[str] = None, # pick requester agent by id if provided
# ---- ✅ P0-1: Cold-start priors injection ----
priors: Optional[Dict[str, Dict[str, float]]] = None, # agent_id -> bucket -> prior
priors_path: Optional[str] = None, # path to priors JSON file
# ---- ✅ P0-3: Strict routing mode (experiment mode) ----
strict_routing: bool = False, # If True, routing failure raises instead of fallback
# ---- ✅ Eval-only (no selector updates; avoid test-phase leakage) ----
eval_only: bool = False,
) -> None:
self.lock = threading.Lock()
self.verbose = bool(verbose)
self.eval_only = bool(eval_only)
# agent registry
self.agents: List[Agent] = []
# dynamic selection knobs
self.use_dynamic = bool(use_dynamic)
self.topL = max(1, int(topL))
self.latency_scale_ms = float(latency_scale_ms)
self.latency_penalty = float(latency_penalty)
self.win_bonus = float(win_bonus)
# correctness reward knobs
self.correctness_bonus = float(correctness_bonus)
self.incorrect_penalty = float(incorrect_penalty)
# planner knobs (Symphony 1.0 style)
self.plan_k = max(1, int(plan_k))
self.use_planner_decompose = bool(use_planner_decompose)
# optional risk guard
self.enable_risk_guard = bool(enable_risk_guard)
self.risk_guard = None
if self.enable_risk_guard and RiskAwareGuard is not None and RiskGuardConfig is not None:
self.risk_guard = RiskAwareGuard(RiskGuardConfig())
# ✅ Global LinUCB (single global A,b)
# build_x returns a vector (commonly 6-dim): [1, match, load, lat_norm, rep, available]
self.selector: Optional[GlobalLinUCB] = None
if self.use_dynamic:
self.selector = GlobalLinUCB(d=6, l2=float(linucb_l2), alpha=float(linucb_alpha))
# ---- shared dispatch knobs ----
self.dispatch_mode = str(dispatch_mode or "local")
self.shared_timeout_s = float(shared_timeout_s)
self.shared_poll_interval = float(shared_poll_interval)
self.requester_id = requester_id
self.strict_routing = bool(strict_routing) # ✅ P0-3: Strict routing mode
# ✅ P0-1: Load and inject learned priors
_priors: Dict[str, Dict[str, float]] = {}
if priors is not None:
_priors = priors
elif priors_path:
try:
from core.cold_start import load_priors
_priors = load_priors(priors_path)
if self.verbose:
self._log(f"[ORCHESTRATOR] Loaded priors from {priors_path} ({len(_priors)} agents)")
except Exception as e:
if self.verbose:
self._log(f"[WARN] Failed to load priors from {priors_path}: {e}")
# Store priors for later injection (will be injected when agents are registered)
self._learned_priors = _priors
# ---------------------- logging ----------------------
def _log(self, msg: str) -> None:
if self.verbose:
print(msg, flush=True)
# ---------------------- lifecycle ----------------------
def register_agent(self, agent: Agent) -> None:
    """
    Add ``agent`` to the registry and inject its cold-start priors, if any.

    Registration is idempotent: an agent already present in ``self.agents`` (by ``==``)
    is left untouched, with no second priors injection and no second log line. The whole
    update runs under ``self.lock``.

    Priors are looked up in ``self._learned_priors`` by ``_resolve_agent_key(agent)``.
    On a hit they are assigned to ``agent.learned_priors``. Misses are logged for at most
    five distinct keys, to avoid spam.
    """
    with self.lock:
        if agent in self.agents:
            return
        self.agents.append(agent)

        agent_key = self._resolve_agent_key(agent)
        learned = self._learned_priors
        if agent_key and agent_key in learned:
            agent.learned_priors = learned[agent_key]
            bucket_count = len(learned[agent_key])
            self._log(f"[PRIORS] ✅ Injected agent={agent_key} buckets={bucket_count}")
        elif learned and agent_key:
            # __init__ normally creates this set; tolerate instances built without it.
            logged = getattr(self, "_priors_miss_logged", None)
            if logged is None:
                logged = set()
                self._priors_miss_logged = logged
            if agent_key not in logged and len(logged) < 5:
                logged.add(agent_key)
                available_keys = list(learned.keys())[:5]
                self._log(f"[PRIORS] ❌ Miss agent={agent_key} available_keys={available_keys}")
        self._log(f"[ORCHESTRATOR] Registered agent: {agent_key}")
def get_registered_agents(self) -> List[Agent]:
    """Return a shallow copy of the registry so callers cannot mutate ``self.agents``."""
    snapshot = [*self.agents]
    return snapshot
# ---------------------- P0-1: Unified agent key resolution ----------------------
@staticmethod
def _resolve_agent_key(agent: Agent) -> str:
"""
✅ P0-1: Resolve agent key for priors lookup.
Priority chain: agent_id -> node_id -> name -> id
Args:
agent: Agent object
Returns:
Agent key string (empty if none found)
"""
aid = (
str(getattr(agent, "agent_id", "")) or
str(getattr(agent, "node_id", "")) or
str(getattr(agent, "name", "")) or
str(getattr(agent, "id", "")) or
""
).strip()
return aid
# ---------------------- helpers: requirement normalization ----------------------
@staticmethod
def _normalize_requirement(req: str) -> str:
"""
Make matching slightly more robust:
- lowercase
- spaces/hyphens -> underscores
"""
r = (req or "").strip().lower()
r = re.sub(r"[\s\-]+", "_", r)
return r
# ---------------------- helpers: gold/correctness ----------------------
def _get_gold_from_context(self, ctx: Dict[str, Any]) -> Union[None, str, List[str]]:
"""
Look for gold label in context. Supported:
ctx["gold"] = "yes" or "A" or "42"
ctx["gold"] = ["A", "B"] (multiple acceptable)
"""
if not isinstance(ctx, dict):
return None
gold = ctx.get("gold", None)
if gold is None:
return None
if isinstance(gold, str):
g = gold.strip()
return g if g else None
if isinstance(gold, (list, tuple, set)):
out: List[str] = []
for x in gold:
if isinstance(x, str) and x.strip():
out.append(x.strip())
return out if out else None
return None
def _canon_answer(self, s: str) -> str:
"""
Canonicalize an answer for comparison:
- try extract_final first
- strip, collapse spaces, lowercase for yes/no/valid/invalid, keep letter uppercase
"""
t = (s or "").strip()
fin = self._extract_final_from_text(t)
t = (fin or t).strip()
t = re.sub(r"\s+", " ", t)
# normalize common BBH labels
low = t.lower()
if low in ("yes", "no", "valid", "invalid", "true", "false"):
return low
if re.fullmatch(r"[A-Za-z]", t):
return t.upper()
return t
def _is_correct(self, pred_text: str, gold: Union[str, List[str], None]) -> Optional[bool]:
if gold is None:
return None
pred = self._canon_answer(pred_text)
if isinstance(gold, str):
return pred == self._canon_answer(gold)
if isinstance(gold, list):
gold_set = {self._canon_answer(x) for x in gold if isinstance(x, str)}
return pred in gold_set if gold_set else None
return None
# ---------------------- P0-0: Unified feature building helper ----------------------
def _build_x_from_candidate_or_fallback(
    self,
    candidate: Dict[str, Any],
    agent: Agent,
    dynamic_state: Optional[Dict[str, Any]] = None,
) -> List[float]:
    """
    Build the LinUCB feature vector for ``agent`` from a Top-L candidate.

    Every selection branch (dynamic, non-dynamic, risk rerun, planner) goes through here,
    so the feature definitions stay consistent: ms = sim_emb, rep = prior_success.

    The preferred path is ``routing.build_x_from_candidate``, imported as
    ``symphony.core.routing`` first and then ``core.routing``. If it is unavailable or
    raises, the vector is built locally with ``build_x``.

    Args:
        candidate: Top-L entry; ``sim_emb`` and ``prior_success`` are read. Missing or
            None values default to the neutral 0.5.
        agent: the agent being scored (may differ from ``candidate["agent"]`` on fallback).
        dynamic_state: pre-fetched state; fetched with ``_agent_state`` when None.
            Missing or None ``load`` / ``latency_ms`` default to 0.0 / 500.0.

    Returns:
        Feature vector as returned by ``build_x`` / ``build_x_from_candidate``.
    """

    def _num(value: Any, default: float) -> float:
        # A None-valued feature would make float() raise TypeError.
        return default if value is None else float(value)

    builder = None
    try:
        from symphony.core.routing import build_x_from_candidate as builder  # type: ignore
    except Exception:
        try:
            from core.routing import build_x_from_candidate as builder  # type: ignore
        except Exception:
            builder = None

    if builder is not None:
        try:
            x, _, _ = builder(
                candidate=candidate,
                agent=agent,
                latency_scale_ms=self.latency_scale_ms,
            )
            return x
        except Exception:
            # Deliberate best-effort: fall through to the local builder below.
            pass

    # sim_emb fallback must be neutral (0.5), NOT match_score (a composite score).
    sim_emb_val = _num(candidate.get("sim_emb"), 0.5)
    prior_success_val = _num(candidate.get("prior_success"), 0.5)

    if dynamic_state is None:
        dynamic_state = self._agent_state(agent)

    return build_x(
        match_score=sim_emb_val,  # ms = sim_emb (not composite match_score)
        dynamic_state={
            "load": _num(dynamic_state.get("load"), 0.0),
            "latency_ms": _num(dynamic_state.get("latency_ms"), 500.0),
            "reputation": prior_success_val,  # rep = prior_success (not state reputation)
        },
        available=bool(dynamic_state.get("available", True)),
        latency_scale_ms=self.latency_scale_ms,
    )
# ---------------------- public entry ----------------------
def execute_task(
    self,
    task: Task,
    cot_count: int = 3,
    return_mode: str = "aggregate",  # "aggregate" | "final" | "trace"
) -> Any:
    """
    Run a task end-to-end and format the result according to ``return_mode``.

    Two paths:
    - Planner mode (``plan_k > 1`` or ``use_planner_decompose``): get plan chains from
      ``_plan_chains_v1`` and run each one with ``_run_plan_chain_v1``. The winner is
      chosen by a weighted vote over extracted final answers. Unless ``eval_only`` is
      set, every recorded selector step (``record["x"]``) of the winning chain(s) gets
      the winner bonus, adjusted by the correctness bonus or penalty.
    - Default mode: decompose by ``task.requirements``, route candidates per subtask,
      and run multi-CoT voting via ``_execute_with_cot_voting``.

    Args:
        task: task object; ``description``, ``context`` and ``requirements`` are read.
        cot_count: forwarded to ``_execute_with_cot_voting`` (default mode only).
        return_mode: "aggregate" (report string), "final" (answer string), or
            "trace" (dict with results and per-subtask traces).

    Side effects:
        When the context is a dict without ``_run_id``, a fresh run id is stored in it
        and the context is assigned back to ``task.context``.
    """
    task_text = getattr(task, "description", "") or ""
    ctx = getattr(task, "context", {}) or {}

    # Stable run_id for this top-level task (used in shared_bb mode).
    if isinstance(ctx, dict) and "_run_id" not in ctx:
        ctx["_run_id"] = f"run:{uuid.uuid4().hex[:8]}"
        try:
            task.context = ctx  # type: ignore[attr-defined]
        except Exception:
            pass  # read-only task objects keep the run id only in our local ctx

    if not self.agents:
        if return_mode == "trace":
            return {"error": "No agents registered", "results": {}, "traces": {}}
        return "[ERROR] No agents registered"

    # ---------- (A) Optional planner mode (Symphony 1.0-style, with 2.0 updates) ----------
    if self.plan_k > 1 or self.use_planner_decompose:
        m = self.plan_k if self.plan_k > 1 else 1
        plans = self._plan_chains_v1(task_text=task_text, ctx=ctx, m=m)

        plan_answers: List[str] = []
        plan_weights: List[float] = []
        plan_traces: List[Dict[str, Any]] = []
        for p in plans:
            ans, w, tr = self._run_plan_chain_v1(base_task=task_text, chain=p["chain"], base_ctx=ctx)
            plan_answers.append(ans)
            plan_weights.append(w)
            plan_traces.append({"planner": p.get("planner", ""), "w": w, "trace": tr, "chain": p["chain"]})

        # Vote on extracted final keys, not on whole raw texts.
        plan_keys = [(self._extract_final_from_text(a) or str(a).strip()) for a in plan_answers]
        win_key = self._weighted_vote(plan_keys, plan_weights) if plan_keys else ""

        # Representative full text: the highest-weight answer sharing win_key. A -inf
        # sentinel (not -1e18) so that very negative weights still pick a winning plan.
        best_i = 0
        best_w = float("-inf")
        for i, (k, w) in enumerate(zip(plan_keys, plan_weights)):
            if k == win_key and float(w) > best_w:
                best_w = float(w)
                best_i = i
        final_text = plan_answers[best_i] if plan_answers else ""

        # Optional correctness reward at planner level.
        gold = self._get_gold_from_context(ctx)
        correct = self._is_correct(win_key, gold)

        # Winner-bonus updates for every step of the winning trajectory(ies). Skipped in
        # eval_only (e.g. the test phase) to avoid label/data leakage.
        if self.use_dynamic and self.selector is not None and not self.eval_only:
            bonus = float(self.win_bonus)
            if correct is True:
                bonus += float(self.correctness_bonus)
            elif correct is False:
                bonus -= float(self.incorrect_penalty)
            for k, plan_trace in zip(plan_keys, plan_traces):
                if k != win_key:
                    continue
                tr = plan_trace.get("trace", {}) if isinstance(plan_trace, dict) else {}
                recs = tr.get("records", []) if isinstance(tr, dict) else []
                for rec in recs:
                    x = rec.get("x") if isinstance(rec, dict) else None
                    if isinstance(x, list):
                        self.selector.update(x, bonus)

        if return_mode == "trace":
            # Trace shape is compatible with non-planner mode (pretrain.py expects
            # {"traces": {"sub_1": {...}}}); runs are the winning plan's run records.
            all_runs: List[Dict[str, Any]] = []
            if best_i < len(plan_traces):
                winning_trace = plan_traces[best_i].get("trace", {})
                if isinstance(winning_trace, dict):
                    records = winning_trace.get("records", [])
                    if isinstance(records, list):
                        all_runs = records

            # Total vote weight per key. A plain dict comprehension kept only the last
            # weight for keys produced by several plans.
            weight_by_key: Dict[str, float] = {}
            for k, w in zip(plan_keys, plan_weights):
                weight_by_key[k] = weight_by_key.get(k, 0.0) + float(w)

            traces_dict: Dict[str, Any] = {}
            if all_runs or final_text:
                traces_dict["sub_1"] = {
                    "requirement": "planning",
                    "context": ctx,
                    "gold": gold,
                    "vote_count": dict(Counter(plan_keys)),
                    "vote_weight_by_match_score": weight_by_key,
                    "correct": correct,
                    "runs": all_runs,
                    "voted": final_text,
                    "voted_final": win_key,
                }
            return {
                "results": {"sub_1": final_text} if final_text else {},
                "traces": traces_dict,
                "final": win_key,
                "final_text": final_text,
                "answers": plan_answers,
                "keys": plan_keys,
                "weights": plan_weights,
                "plans": plan_traces,
                "gold": gold,
                "correct": correct,
            }

        if return_mode == "final":
            return (win_key or (self._extract_final_from_text(final_text) or final_text)).strip()

        # aggregate
        parts = ["## Symphony Planner Result\n\n", f"**Original Task**: {task_text}\n\n"]
        for i, (a, w) in enumerate(zip(plan_answers, plan_weights), 1):
            parts.append(f"{i}. (w={float(w):.3f}) {a}\n\n")
        parts.append(f"\n**Final answer**: {win_key}\n")
        return "".join(parts).strip()

    # ---------- (B) Default non-planner mode ----------
    subtasks = self._decompose_task(task)
    if not subtasks:
        subtasks = [self._mk_subtask(task_text, ctx, i=1, requirement="general-reasoning")]

    # Normalize every subtask into a dict with the fields downstream code reads.
    for i, st in enumerate(subtasks):
        if not isinstance(st, dict):
            st = {"input": str(st)}
            subtasks[i] = st
        st.setdefault("id", f"sub_{i + 1}")
        st.setdefault("requirement", "general-reasoning")
        st.setdefault("context", ctx)
        st.setdefault("original_task", task_text)
        st.setdefault("description", st.get("input") or st.get("description") or task_text)
        if not st.get("input"):
            st["input"] = st.get("description") or task_text
        st["requirement"] = self._normalize_requirement(str(st.get("requirement", "general-reasoning")))

    agent_assignments = self._find_suitable_agents(subtasks)
    out = self._execute_with_cot_voting(
        subtasks=subtasks,
        agent_assignments=agent_assignments,
        cot_count=cot_count,
        return_mode=return_mode,
    )

    if return_mode == "trace":
        results = out.get("results") if isinstance(out, dict) else None
        if isinstance(results, dict) and len(results) == 1:
            one = next(iter(results.values()))
            out["final"] = self._extract_final_from_text(str(one)) or str(one)
        return out

    if return_mode == "final":
        if isinstance(out, dict) and len(out) == 1:
            s = str(next(iter(out.values()))).strip()
            s = re.sub(r"(?is)</?\s*answer\s*>", "", s).strip()
            fin = self._extract_final_from_text(s)
            return (fin or s).strip()
        aggregated = self._aggregate_results(out, task)  # type: ignore[arg-type]
        fin = self._extract_final_from_text(aggregated)
        return fin.strip() if fin else aggregated.strip()

    return self._aggregate_results(out, task)  # type: ignore[arg-type]
# ---------------------- decomposition (simple baseline) ----------------------
def _mk_subtask(self, task_text: str, ctx: Dict[str, Any], i: int, requirement: str) -> Dict[str, Any]:
"""
✅ P0-2: Create subtask dict with benchmark/difficulty_bin for priors lookup.
These fields are required for routing.get_prior_success() to work correctly.
"""
st = {
"id": f"{uuid.uuid4().hex}_sub_{i}",
"requirement": self._normalize_requirement(requirement),
"input": task_text,
"description": task_text,
"context": ctx or {},
"original_task": task_text,
}
# ✅ P0-2: Inject benchmark and difficulty_bin for priors lookup
if isinstance(ctx, dict):
st["benchmark"] = str(ctx.get("benchmark", "")).strip()
st["difficulty_bin"] = str(ctx.get("difficulty_bin", ctx.get("difficulty", "unknown"))).strip() or "unknown"
else:
st["benchmark"] = ""
st["difficulty_bin"] = "unknown"
return st
def _decompose_task(self, task: Task) -> List[Dict[str, Any]]:
"""
Baseline decomposition: one subtask per requirement.
"""
reqs = list(getattr(task, "requirements", []) or [])
if not reqs:
reqs = ["general-reasoning"]
out: List[Dict[str, Any]] = []
for i, r in enumerate(reqs, 1):
out.append(
self._mk_subtask(
task_text=getattr(task, "description", "") or "",
ctx=getattr(task, "context", {}) or {},
i=i,
requirement=str(r),
)
)
return out
# ---------------------- agent matching ----------------------
def _agent_state(self, agent: Agent) -> Dict[str, Any]:
"""Symphony 2.0: read dynamic state if provided; else fallback defaults."""
if hasattr(agent, "get_dynamic_state"):
try:
st = agent.get_dynamic_state() # type: ignore[attr-defined]
if isinstance(st, dict):
st.setdefault("available", True)
st.setdefault("load", 0.0)
st.setdefault("latency_ms", 500.0)
st.setdefault("reputation", 0.5)
return st
except Exception:
pass
return {"available": True, "load": 0.0, "latency_ms": 500.0, "reputation": 0.5}
def _find_suitable_agents(self, subtasks: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
"""
✅ Return candidates per subtask with composite TopL score.
Delegates to routing.select_topL() for clean separation of concerns.
Returns:
[{"agent": Agent, "match_score": float, "sim_emb": float, "prior_success": float}, ...]
"""
try:
from core.routing import select_topL
use_routing = True
except ImportError:
use_routing = False
assignments: Dict[str, List[Dict[str, Any]]] = {}
for st in subtasks:
sid = st["id"]
# ✅ P0-3: Use routing module for clean TopL selection (must use routing)
if use_routing:
try:
candidates = select_topL(
agents=self.agents,
subtask=st,
topL=self.topL,
latency_scale_ms=float(self.latency_scale_ms),
use_embedding=True, # ✅ P0-3: Must use embedding
)
# ✅ P0-7: Validate candidate schema
for cand in candidates:
required_keys = {"agent", "match_score", "sim_emb", "prior_success"}
missing = required_keys - set(cand.keys())
if missing:
raise ValueError(
f"P0-7: Candidate missing required keys: {missing}. "
f"Candidate keys: {list(cand.keys())}"
)
assignments[sid] = candidates
continue
except Exception as e:
# ✅ P0-3: Strict routing mode: raise if routing fails (experiment mode)
if self.strict_routing:
raise RuntimeError(
f"P0-3: Routing.select_topL failed for subtask {sid} and strict_routing=True. "
f"Error: {e}"
)
# ✅ P0-3: Log routing failure but fallback to legacy (backward compatibility)
if self.verbose:
self._log(f"[WARN] Routing.select_topL failed for subtask {sid}: {e}. Falling back to legacy.")
# Fallback to legacy (for backward compatibility)
pass
# Legacy fallback: capability_manager.match() only (should not happen in normal flow)
req = self._normalize_requirement(str(st.get("requirement", "general-reasoning")))
cand: List[Dict[str, Any]] = []
for ag in self.agents:
ms = 0.5
if hasattr(ag, "capability_manager"):
try:
ms = float(ag.capability_manager.match(req)) # type: ignore[attr-defined]
except Exception:
ms = 0.5
cand.append({
"agent": ag,
"match_score": ms,
"sim_emb": ms, # Fallback: use match_score as sim_emb
"prior_success": 0.5, # Fallback: default
})
cand.sort(key=lambda x: float(x.get("match_score", 0.0)), reverse=True)
assignments[sid] = cand
return assignments
def _select_agent_dynamic(
self,
candidates: List[Dict[str, Any]],
used_ids: set,
) -> Tuple[Agent, List[float], Dict[str, Any], float]:
"""
Stage-1: Top-L by match_score (already sorted)
Stage-2: Global LinUCB selects within Top-L using build_x(match, dynamic_state, available)
Return: (agent, x, state, raw_match_score)
"""
topL = candidates[: self.topL]
pool: List[Tuple[str, List[float], Agent, Dict[str, Any], float]] = []
for c in topL:
agent = c["agent"]
# ✅ P0-1: Use unified agent key resolution (not just agent_id)
aid = self._resolve_agent_key(agent)
if not aid:
# Fallback: use object id if all keys are empty (should not happen)
aid = f"agent_{id(agent)}"
if aid in used_ids and len(used_ids) < len(topL):
continue
st = self._agent_state(agent)
if not bool(st.get("available", True)):
continue
# ✅ P1: UCB stage: use build_x_from_candidate() to reuse TopL results (avoid recompute)
try:
from core.routing import build_x_from_candidate
x, sim_emb_val, prior_success_val = build_x_from_candidate(
candidate=c,
agent=agent,
latency_scale_ms=float(self.latency_scale_ms),
)
raw_ms = float(c.get("match_score", 0.0)) # TopL composite score (for logging)
except Exception:
# ✅ P0-4: Fallback: use unified helper (ensures ms=sim_emb, rep=prior_success)
x = self._build_x_from_candidate_or_fallback(
candidate=c,
agent=agent,
dynamic_state=st,
)
raw_ms = float(c.get("match_score", 0.0))
pool.append((aid, x, agent, st, raw_ms))
# ✅ safer fallback: pick first AVAILABLE in topL; else pick topL[0] but mark available=False in x
if not pool:
for c in topL:
ag = c["agent"]
st0 = self._agent_state(ag)
if bool(st0.get("available", True)):
# ✅ P0-4: Use unified helper (ensures ms=sim_emb, rep=prior_success)
x0 = self._build_x_from_candidate_or_fallback(
candidate=c,
agent=ag,
dynamic_state=st0,
)
raw_ms0 = float(c.get("match_score", 0.0)) # For logging
return ag, x0, st0, raw_ms0
c0 = topL[0]
ag0 = c0["agent"]
st0 = self._agent_state(ag0)
# ✅ P0-4: Use unified helper (ensures ms=sim_emb, rep=prior_success)
x0 = self._build_x_from_candidate_or_fallback(
candidate=c0,
agent=ag0,
dynamic_state=st0,
)
raw_ms0 = float(c0.get("match_score", 0.0)) # For logging
return ag0, x0, st0, raw_ms0
assert self.selector is not None
chosen_id = self.selector.select([(aid, x) for (aid, x, _, _, _) in pool])
for (aid, x, agent, st, raw_ms) in pool:
if aid == chosen_id:
return agent, x, st, raw_ms
return pool[0][2], pool[0][1], pool[0][3], pool[0][4]
# ---------------------- core: multi-CoT + voting (+ trace) ----------------------
def _execute_with_cot_voting(
self,
subtasks: List[Dict[str, Any]],
agent_assignments: Dict[str, List[Dict[str, Any]]],
cot_count: int,
return_mode: str = "final", # "final" | "aggregate" | "trace"
) -> Any:
"""
For each subtask:
- run up to `cot_count` times (bounded by |TopL|)
- vote among outputs (keys on extracted Final)
- update LinUCB online after vote (winner bonus + latency penalty + optional correctness reward)
"""
results: Dict[str, str] = {}
traces_by_subtask: Dict[str, Any] = {}
for st in subtasks:
sid = st["id"]
# 改动4: 默认requirement改为math_reasoning(如果context中有benchmark=gsm8k)
ctx_check = st.get("context", {}) or {}
bench_check = str(ctx_check.get("benchmark", "")).strip().lower()
default_req = "math_reasoning" if bench_check in {"gsm8k", "gsm"} else "general-reasoning"
req = str(st.get("requirement", default_req))
candidates = agent_assignments.get(sid, [])
if not candidates:
err = f"[ERROR] No agents available for subtask: {req}"
results[sid] = err
if return_mode == "trace":
traces_by_subtask[sid] = {"error": err, "runs": [], "voted": err, "requirement": req}
continue
# ✅ Cold_start round-robin mode: check if context has _cold_start_task_index
ctx = st.get("context", {}) or {}
cold_start_task_index = ctx.get("_cold_start_task_index")
cold_start_agent_keys = ctx.get("_cold_start_agents", [])
if cold_start_task_index is not None and cold_start_agent_keys:
# ✅ Cold_start round-robin: each task -> exactly one agent (round-robin)
# Select agent by round-robin: task_index % len(agents)
agent_key_idx = int(cold_start_task_index) % len(cold_start_agent_keys) if cold_start_agent_keys else 0
target_agent_key = cold_start_agent_keys[agent_key_idx]
# Find agent by key
selected_agent = None
selected_candidate = None
for c in candidates:
ag = c["agent"]
aid = self._resolve_agent_key(ag)
if aid == target_agent_key:
selected_agent = ag
selected_candidate = c
break
if selected_agent is None:
# ✅ Fallback: still use round-robin from available candidates
# Match agents in candidates by key, then pick by round-robin index
available_candidates = []
candidate_keys = []
for c in candidates:
ag = c["agent"]
stt = self._agent_state(ag)
if bool(stt.get("available", True)):
aid = self._resolve_agent_key(ag)
available_candidates.append((aid, ag, c))
candidate_keys.append(aid)
if available_candidates:
# Find target_agent_key's position in sorted agent_keys list
# Use that position to select from available candidates
try:
target_idx_in_all = cold_start_agent_keys.index(target_agent_key)
# Find the first available candidate whose key matches any agent in sorted list at same relative position
# Simplified: use round-robin index directly on available candidates
fallback_idx = int(cold_start_task_index) % len(available_candidates) if available_candidates else 0
selected_agent = available_candidates[fallback_idx][1]
selected_candidate = available_candidates[fallback_idx][2]
except (ValueError, IndexError):
# If target not found or index error, use first available
selected_agent = available_candidates[0][1]
selected_candidate = available_candidates[0][2]
if selected_agent is None:
err = f"[ERROR] No agent found for cold_start round-robin (key={target_agent_key})"
results[sid] = err
if return_mode == "trace":
traces_by_subtask[sid] = {"error": err, "runs": [], "voted": err, "requirement": req}
continue
# Execute once with selected agent
aid = self._resolve_agent_key(selected_agent)
if not aid:
aid = f"agent_{id(selected_agent)}"
stt = self._agent_state(selected_agent)
match_score = float(selected_candidate.get("match_score", 0.0)) if selected_candidate else 0.0
x = self._build_x_from_candidate_or_fallback(
candidate=selected_candidate or {"agent": selected_agent, "sim_emb": 0.5, "prior_success": 0.5},
agent=selected_agent,
dynamic_state=stt,
)
t0 = time.time()
try:
if self.dispatch_mode == "shared_bb":
requester = self._get_requester_agent()
if requester is None:
text = "[ERROR] dispatch_mode=shared_bb but no requester"
else:
run_tag = str(ctx.get("_run_id", "run")) + f":{sid}:cold_start"
text = self._execute_subtask_via_shared_bb(requester, st, run_tag=run_tag)
else:
text = self._execute_subtask_on_agent(selected_agent, st)
except Exception as e:
text = f"[AGENT_ERROR] {str(e)}"
dt_ms = (time.time() - t0) * 1000.0
final_result = text
run_records = [{
"agent_id": aid,
"match_score": float(match_score),
"sim_emb": float(selected_candidate.get("sim_emb", 0.5)) if selected_candidate else 0.5,
"prior_success": float(selected_candidate.get("prior_success", 0.5)) if selected_candidate else 0.5,
"x": x,
"latency_ms": float(dt_ms),
"text": text,
"final": self._extract_final_from_text(text) or "",
}]
results[sid] = final_result
if return_mode == "trace":
traces_by_subtask[sid] = {
"requirement": req,
"context": ctx,
"runs": run_records,
"voted": final_result,
"voted_final": self._extract_final_from_text(final_result) or "",
}
continue
# Normal mode: filter available and use Top-L + Multi-CoT
# filter available
candidates_avail: List[Dict[str, Any]] = []
for c in candidates:
ag = c["agent"]
stt = self._agent_state(ag)
if bool(stt.get("available", True)):
candidates_avail.append(c)
if not candidates_avail:
err = f"[ERROR] No AVAILABLE agents for subtask: {req}"
results[sid] = err
if return_mode == "trace":
traces_by_subtask[sid] = {"error": err, "runs": [], "voted": err, "requirement": req}
continue
# ✅ Exploration constraint 1: Top-L must be unique (deduplicate by agent_id)
# Ensure no duplicate agents in topL (critical for exploration)
seen_agent_ids = set()
topL_unique: List[Dict[str, Any]] = []
for c in candidates_avail:
ag = c["agent"]
aid = self._resolve_agent_key(ag)
if not aid:
aid = f"agent_{id(ag)}"
if aid not in seen_agent_ids:
seen_agent_ids.add(aid)
topL_unique.append(c)
if len(topL_unique) >= self.topL:
break
# If not enough unique agents, use what we have (at least 1)
topL = topL_unique[: max(1, self.topL)] if topL_unique else candidates_avail[:1]
# ✅ Exploration constraint 2: Fixed exploration budget K (don't vary with topL length)
# 改动5: 对于GSM8K,实现self-consistency(同agent多次采样)
ctx_check = st.get("context", {}) or {}
bench_check = str(ctx_check.get("benchmark", "")).strip().lower()
is_gsm8k_self_consistency = bench_check in {"gsm8k", "gsm"} and cot_count >= 3
runs = int(cot_count) if cot_count > 0 else 0
if runs <= 0:
err = f"[ERROR] All agents filtered out for subtask: {req}"
results[sid] = err
if return_mode == "trace":
traces_by_subtask[sid] = {"error": err, "runs": [], "voted": err, "requirement": req}
continue
used_ids = set()
run_records: List[Dict[str, Any]] = []
cot_results: List[str] = []
# 改动5: self-consistency模式:选择第一个agent,然后多次采样
selected_agent = None
selected_candidate = None
selected_x = None
selected_stt = None
if is_gsm8k_self_consistency:
# 选择第一个agent(通过UCB或round-robin)
if self.use_dynamic and self.selector is not None:
selected_agent, selected_x, selected_stt, _ = self._select_agent_dynamic(topL, used_ids)
# 找到对应的candidate
for c in topL:
if c["agent"] == selected_agent:
selected_candidate = c
break
else:
if topL:
selected_candidate = topL[0]
selected_agent = selected_candidate["agent"]
selected_stt = self._agent_state(selected_agent)
selected_x = self._build_x_from_candidate_or_fallback(
candidate=selected_candidate,
agent=selected_agent,
dynamic_state=selected_stt,
)
for j in range(runs):
if is_gsm8k_self_consistency and selected_agent:
# Self-consistency: 使用同一个agent多次采样
agent = selected_agent
x = selected_x
stt = selected_stt
match_score = float(selected_candidate.get("match_score", 0.0)) if selected_candidate else 0.0
elif self.use_dynamic and self.selector is not None:
agent, x, _st, match_score = self._select_agent_dynamic(topL, used_ids)
else:
# ✅ cold_start: static Top-L but round-robin across agents (no repeat)
# Use j % len(topL) to cycle through topL candidates
candidate_idx = j % len(topL) if topL else 0
candidate = topL[candidate_idx]
agent = candidate["agent"]
stt = self._agent_state(agent)
# ✅ P0-4: Use unified helper (ensures ms=sim_emb, rep=prior_success)
match_score = float(candidate.get("match_score", 0.0)) # For logging
x = self._build_x_from_candidate_or_fallback(
candidate=candidate,
agent=agent,
dynamic_state=stt,
)
# ✅ P0-1: Use unified agent key resolution
aid = self._resolve_agent_key(agent)
if not aid:
aid = f"agent_{id(agent)}"
if not is_gsm8k_self_consistency:
used_ids.add(aid) # self-consistency模式下允许重复使用同一个agent
# 改动4: 零容忍解析 + invalid时自动重试一次
t0 = time.time()
text = ""
retry_count = 0
max_retries = 1
while retry_count <= max_retries:
try:
if self.dispatch_mode == "shared_bb":
requester = self._get_requester_agent()
if requester is None:
text = "[ERROR] dispatch_mode=shared_bb but no requester (agent with isep_client) registered"
else:
run_tag = str((st.get("context", {}) or {}).get("_run_id", "run")) + f":{st.get('id','sub')}:cot{len(run_records)+1}"
text = self._execute_subtask_via_shared_bb(requester, st, run_tag=run_tag)
else:
text = self._execute_subtask_on_agent(agent, st)
except Exception as e:
text = f"[AGENT_ERROR] {str(e)}"
# 对于GSM8K,检查输出是否valid
ctx_check = st.get("context", {}) or {}
bench_check = str(ctx_check.get("benchmark", "")).strip().lower()
if bench_check in {"gsm8k", "gsm"} and text and not text.startswith("[ERROR]") and not text.startswith("[AGENT_ERROR]"):
parsed, is_valid, err = self._parse_strict_json(text, benchmark=bench_check)
if parsed is None or not is_valid:
# invalid输出,重试一次(低温度)
if retry_count < max_retries: