diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..48a918c --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +.pytest_cache/ +.venv/ +*.egg-info/ +out/ diff --git a/README.md b/README.md index 5a82102..63a3868 100644 --- a/README.md +++ b/README.md @@ -89,4 +89,75 @@ The pipeline should treat **export formats** as first-class requirements so the 5. Specify **one SFT JSONL schema** and **one DPO JSONL schema** (and chat template) validated end-to-end with a **LoRA** dry run and a **small DPO** dry run on toy data. 6. Define **student model** constraints (context length, tool set) and a **filter + eval** plan for teacher-to-student parity before production swap. +--- +## Implementation approach + +This repository now includes a complete first implementation slice for the log-to-training pipeline: + +- file or directory ingestion for JSON/JSONL logs +- canonical event schema for Q&A, assistant turns, tool calls, tool results, errors, and feedback +- deterministic PII/secrets redaction with stable placeholders +- redaction report with audit samples that never include raw sensitive values +- session segmentation into Q&A units and agent trajectories +- trajectory complexity tags for staged training schedules +- deterministic train/validation/test splits with leakage buckets +- optional tool registry validation for declared tool names and required arguments +- SFT JSONL export for LoRA-compatible chat training +- DPO candidate export for human-approved preference data +- redacted canonical unit export for review and downstream transforms + +The implementation stays dependency-light so it can run in controlled environments, but it is structured as installable Python package code rather than ad hoc scripts. + +### Quickstart + +```bash +python -m pip install -e ".[dev]" +training-setup-logs examples/sample_agent_logs.jsonl \ + --tool-schema examples/tool_schema.json \ + --out-dir out +``` + +Outputs: + +- `out/redacted_units.jsonl`: canonical redacted units for audit and downstream transforms +- `out/sft.jsonl`: LoRA-ready supervised chat rows +- `out/dpo_candidates.jsonl`: preference-pair candidates that require human approval +- `out/redaction_report.json`: PII finding kinds and placeholders, without raw sensitive values +- `out/manifest.json`: PII counts, validation counts, and complexity distribution + +Run tests: + +```bash +python -m pytest +``` + +## Repository structure + +```text +src/training_setup_logs/ + audit.py audit samples for redacted data review + cli.py command line entrypoint + export.py SFT, DPO candidate, and redacted-unit exporters + ingest.py JSON/JSONL file and directory ingestion + pii.py deterministic PII and secret redaction + schemas.py canonical dataclasses + segment.py session-to-training-unit segmentation + split.py deterministic split and leakage-bucket assignment + tagging.py trajectory complexity and scheduling tags + tool_schema.py optional tool registry loading + validate.py trajectory and tool-use validation +examples/ + sample_agent_logs.jsonl + tool_schema.json +tests/ + test_pipeline.py +``` + +## Canonical outputs + +Each `redacted_units.jsonl` row contains `unit_id`, `session_id`, `unit_type`, redacted events, split metadata, complexity tags, and validation issues. `sft.jsonl` converts the same units into chat-style `messages`. `dpo_candidates.jsonl` only emits governed preference candidates, such as failed/error traces followed by a later recovery, and marks them as requiring human approval. + +## Privacy assumptions + +The default redactor covers common emails, phone numbers, Aadhaar-like IDs, IP addresses, bearer tokens, API-key-shaped secrets, and URL secret query parameters. Production use should add organization-specific dictionaries, policy approval, and human audit sampling before training artifacts are shipped. diff --git a/examples/sample_agent_logs.jsonl b/examples/sample_agent_logs.jsonl new file mode 100644 index 0000000..cb72bc2 --- /dev/null +++ b/examples/sample_agent_logs.jsonl @@ -0,0 +1,9 @@ +{"session_id":"s1","timestamp":"2026-04-28T08:00:00Z","role":"user","content":"My email is farmer.ravi@example.com. Please set up irrigation advice for plot 42."} +{"session_id":"s1","timestamp":"2026-04-28T08:00:02Z","role":"assistant","content":"I will inspect the available weather and crop profile before suggesting irrigation."} +{"session_id":"s1","timestamp":"2026-04-28T08:00:03Z","type":"tool_call","tool_name":"get_weather","tool_args":{"district":"Raipur","api_key":"sk_TEST1234567890abcdef"}} +{"session_id":"s1","timestamp":"2026-04-28T08:00:04Z","type":"tool_result","tool_result":{"rain_forecast_mm":0,"temperature_c":35}} +{"session_id":"s1","timestamp":"2026-04-28T08:00:05Z","type":"tool_call","tool_name":"get_crop_profile","tool_args":{"phone":"+91 98765 43210","crop":"paddy"}} +{"session_id":"s1","timestamp":"2026-04-28T08:00:06Z","type":"error","content":"crop profile lookup timed out"} +{"session_id":"s1","timestamp":"2026-04-28T08:00:08Z","role":"assistant","content":"The crop profile lookup failed, so I will provide a conservative recommendation and flag this for retry."} +{"session_id":"s2","timestamp":"2026-04-28T09:00:00Z","role":"user","content":"What is LoRA fine-tuning?"} +{"session_id":"s2","timestamp":"2026-04-28T09:00:02Z","role":"assistant","content":"LoRA fine-tuning updates small low-rank adapter matrices instead of all model weights, making supervised fine-tuning cheaper and easier to deploy."} diff --git a/examples/tool_schema.json b/examples/tool_schema.json new file mode 100644 index 0000000..6b677b0 --- /dev/null +++ b/examples/tool_schema.json @@ -0,0 +1,10 @@ +{ + "tools": { + "get_weather": { + "required_args": ["district"] + }, + "get_crop_profile": { + "required_args": ["crop"] + } + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9309866 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[project] +name = "training-setup-logs" +version = "0.1.0" +description = "Privacy-safe log-to-training-data pipeline for Q&A and agentic traces." +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "OpenAgriNet contributors" }] +dependencies = [] + +[project.optional-dependencies] +dev = ["pytest>=8"] + +[project.scripts] +training-setup-logs = "training_setup_logs.cli:main" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +pythonpath = ["src"] +testpaths = ["tests"] diff --git a/src/training_setup_logs/__init__.py b/src/training_setup_logs/__init__.py new file mode 100644 index 0000000..0e07f94 --- /dev/null +++ b/src/training_setup_logs/__init__.py @@ -0,0 +1,5 @@ +"""Privacy-safe log-to-training-data pipeline.""" + +from training_setup_logs.schemas import LogEvent, TrainingUnit + +__all__ = ["LogEvent", "TrainingUnit"] diff --git a/src/training_setup_logs/audit.py b/src/training_setup_logs/audit.py new file mode 100644 index 0000000..af775f8 --- /dev/null +++ b/src/training_setup_logs/audit.py @@ -0,0 +1,42 @@ +"""Audit helpers for privacy review and data quality checks.""" + +from __future__ import annotations + +from training_setup_logs.schemas import TrainingUnit +from training_setup_logs.validate import validate_unit + + +def build_audit_sample(units: list[TrainingUnit], limit: int = 10) -> list[dict[str, object]]: + """Return a compact redacted sample for human audit workflows.""" + + sample: list[dict[str, object]] = [] + for unit in units[:limit]: + sample.append( + { + "unit_id": unit.unit_id, + "session_id": unit.session_id, + "unit_type": unit.unit_type, + "event_count": len(unit.events), + "preview": _preview(unit), + "validation_issues": [issue.to_dict() for issue in validate_unit(unit)], + } + ) + return sample + + +def _preview(unit: TrainingUnit) -> list[dict[str, str]]: + preview: list[dict[str, str]] = [] + for event in unit.events[:6]: + text = event.content + if text is None and event.tool_name: + text = f"{event.type}: {event.tool_name}" + if text is None and event.tool_result is not None: + text = f"{event.type}: {type(event.tool_result).__name__}" + preview.append( + { + "event_id": event.event_id, + "type": event.type, + "text": (text or "")[:180], + } + ) + return preview diff --git a/src/training_setup_logs/cli.py b/src/training_setup_logs/cli.py new file mode 100644 index 0000000..57fb161 --- /dev/null +++ b/src/training_setup_logs/cli.py @@ -0,0 +1,105 @@ +"""Command line interface for the log-to-training pipeline.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from training_setup_logs.audit import build_audit_sample +from training_setup_logs.export import ( + to_dpo_candidate_row, + to_redacted_unit_row, + to_sft_row, + write_jsonl, +) +from training_setup_logs.ingest import load_events_from_path +from training_setup_logs.pii import PiiRedactor +from training_setup_logs.segment import segment_events +from training_setup_logs.schemas import TrainingUnit +from training_setup_logs.tagging import tag_unit +from training_setup_logs.tool_schema import ToolRegistry +from training_setup_logs.validate import validate_unit + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Build privacy-safe SFT and DPO JSONL from logs.") + parser.add_argument("input", type=Path, help="Input JSON/JSONL log file or directory.") + parser.add_argument("--out-dir", type=Path, default=Path("out"), help="Output directory.") + parser.add_argument("--tool-schema", type=Path, help="Optional JSON registry for tool validation.") + parser.add_argument("--audit-sample-size", type=int, default=10) + return parser + + +def run( + input_path: Path, + out_dir: Path, + tool_schema: Path | None = None, + audit_sample_size: int = 10, +) -> dict[str, object]: + events = load_events_from_path(input_path) + tool_registry = ToolRegistry.from_path(tool_schema) + redactor = PiiRedactor() + redacted_events = [redactor.redact_event(event) for event in events] + units = segment_events(redacted_events) + + sft_rows = [to_sft_row(unit) for unit in units] + dpo_rows = [row for unit in units if (row := to_dpo_candidate_row(unit)) is not None] + redacted_unit_rows = [to_redacted_unit_row(unit) for unit in units] + manifest = { + "input": str(input_path), + "unit_count": len(units), + "sft_rows": len(sft_rows), + "dpo_candidate_rows": len(dpo_rows), + "pii_counts": redactor.report.counts_by_kind(), + "validation_issue_count": sum(len(validate_unit(unit, tool_registry)) for unit in units), + "tool_registry_count": len(tool_registry.tools), + "split_summary": _split_summary(sft_rows), + "tag_summary": _tag_summary(units), + } + + out_dir.mkdir(parents=True, exist_ok=True) + write_jsonl(out_dir / "redacted_units.jsonl", redacted_unit_rows) + write_jsonl(out_dir / "sft.jsonl", sft_rows) + write_jsonl(out_dir / "dpo_candidates.jsonl", dpo_rows) + (out_dir / "redaction_report.json").write_text( + json.dumps( + { + "counts_by_kind": redactor.report.counts_by_kind(), + "findings": [finding.to_dict() for finding in redactor.report.findings], + "audit_sample": build_audit_sample(units, limit=audit_sample_size), + }, + indent=2, + sort_keys=True, + ), + encoding="utf-8", + ) + (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8") + return manifest + + +def main() -> None: + args = build_parser().parse_args() + manifest = run(args.input, args.out_dir, args.tool_schema, args.audit_sample_size) + print(json.dumps(manifest, indent=2, sort_keys=True)) + + +def _tag_summary(units: list[TrainingUnit]) -> dict[str, int]: + summary: dict[str, int] = {} + for unit in units: + complexity = str(tag_unit(unit)["complexity"]) + summary[complexity] = summary.get(complexity, 0) + 1 + return summary + + +def _split_summary(rows: list[dict[str, object]]) -> dict[str, int]: + summary: dict[str, int] = {} + for row in rows: + metadata = row.get("metadata", {}) + split = str(metadata.get("split", "unknown")) if isinstance(metadata, dict) else "unknown" + summary[split] = summary.get(split, 0) + 1 + return summary + + +if __name__ == "__main__": + main() diff --git a/src/training_setup_logs/export.py b/src/training_setup_logs/export.py new file mode 100644 index 0000000..f3c01fb --- /dev/null +++ b/src/training_setup_logs/export.py @@ -0,0 +1,134 @@ +"""Export redacted training units into trainer-friendly JSONL views.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from training_setup_logs.pii import redact_jsonish +from training_setup_logs.schemas import LogEvent, TrainingUnit +from training_setup_logs.split import split_metadata +from training_setup_logs.tagging import tag_unit +from training_setup_logs.validate import validate_unit + + +def write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + for row in rows: + handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n") + + +def to_sft_row(unit: TrainingUnit) -> dict[str, object]: + """Convert a training unit to a LoRA-ready chat SFT row.""" + + messages = _events_to_messages(unit.events) + return { + "id": unit.unit_id, + "messages": messages, + "metadata": { + **unit.metadata, + **split_metadata(unit), + **tag_unit(unit), + "validation_issues": [issue.to_dict() for issue in validate_unit(unit)], + }, + } + + +def to_dpo_candidate_row(unit: TrainingUnit) -> dict[str, object] | None: + """Create a DPO candidate row when feedback or error recovery gives a natural pair. + + This does not invent preference labels. It emits candidate pairs only when the + session contains a failed assistant/error suffix followed by a later assistant + recovery. Mentors can later approve, discard, or replace these pairs. + """ + + last_user_index = _last_index(unit.events, "user") + if last_user_index is None: + return None + + rejected = _first_event_after(unit.events, last_user_index, {"error"}) + chosen = _last_event_of_type(unit.events, "assistant") + if rejected is None or chosen is None or chosen.event_id == rejected.event_id: + return None + + prompt_events = unit.events[: last_user_index + 1] + return { + "id": f"{unit.unit_id}_dpo_candidate", + "prompt": _events_to_messages(prompt_events), + "chosen": [_event_to_message(chosen)], + "rejected": [_event_to_message(rejected)], + "metadata": { + **split_metadata(unit), + **tag_unit(unit), + "source_unit_id": unit.unit_id, + "requires_human_approval": True, + }, + } + + +def to_redacted_unit_row(unit: TrainingUnit) -> dict[str, object]: + """Return the canonical redacted unit for audit and downstream transforms.""" + + return { + **unit.to_dict(), + "metadata": { + **unit.metadata, + **split_metadata(unit), + **tag_unit(unit), + "validation_issues": [issue.to_dict() for issue in validate_unit(unit)], + }, + } + + +def _events_to_messages(events: list[LogEvent]) -> list[dict[str, object]]: + messages: list[dict[str, object]] = [] + for event in events: + message = _event_to_message(event) + if message is not None: + messages.append(message) + return messages + + +def _event_to_message(event: LogEvent) -> dict[str, object] | None: + if event.type in {"system", "user", "assistant"}: + return {"role": event.type, "content": event.content or ""} + if event.type == "tool_call": + return { + "role": "assistant", + "content": json.dumps( + { + "tool_call": { + "name": event.tool_name, + "arguments": event.tool_args or {}, + } + }, + sort_keys=True, + ), + } + if event.type == "tool_result": + return {"role": "tool", "name": event.tool_name or "unknown", "content": redact_jsonish(event.tool_result) or ""} + if event.type == "error": + return {"role": "assistant", "content": f"[ERROR] {event.content or redact_jsonish(event.tool_result) or ''}"} + return None + + +def _last_index(events: list[LogEvent], event_type: str) -> int | None: + for index in range(len(events) - 1, -1, -1): + if events[index].type == event_type: + return index + return None + + +def _first_event_after(events: list[LogEvent], start_index: int, event_types: set[str]) -> LogEvent | None: + for event in events[start_index + 1 :]: + if event.type in event_types: + return event + return None + + +def _last_event_of_type(events: list[LogEvent], event_type: str) -> LogEvent | None: + for event in reversed(events): + if event.type == event_type: + return event + return None diff --git a/src/training_setup_logs/ingest.py b/src/training_setup_logs/ingest.py new file mode 100644 index 0000000..1fb6208 --- /dev/null +++ b/src/training_setup_logs/ingest.py @@ -0,0 +1,157 @@ +"""Input parsers for heterogeneous JSON/JSONL logs.""" + +from __future__ import annotations + +import hashlib +import json +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +from training_setup_logs.schemas import EventType, LogEvent + + +TYPE_ALIASES: dict[str, EventType] = { + "human": "user", + "user": "user", + "ai": "assistant", + "assistant": "assistant", + "bot": "assistant", + "system": "system", + "tool": "tool_call", + "tool_call": "tool_call", + "function_call": "tool_call", + "tool_result": "tool_result", + "observation": "tool_result", + "error": "error", + "exception": "error", + "feedback": "feedback", +} + + +def read_json_records(path: Path) -> list[dict[str, Any]]: + """Read JSON, JSONL, or a JSON object containing an events/logs array.""" + + text = path.read_text(encoding="utf-8").strip() + if not text: + return [] + + if path.suffix == ".jsonl": + return [json.loads(line) for line in text.splitlines() if line.strip()] + + payload = json.loads(text) + if isinstance(payload, list): + return payload + if isinstance(payload, dict): + for key in ("events", "logs", "records"): + value = payload.get(key) + if isinstance(value, list): + return value + return [payload] + raise ValueError(f"Unsupported JSON payload in {path}") + + +def normalize_records(records: Iterable[dict[str, Any]]) -> list[LogEvent]: + """Normalize raw records into the canonical LogEvent schema.""" + + events: list[LogEvent] = [] + for index, record in enumerate(records): + session_id = str( + record.get("session_id") + or record.get("conversation_id") + or record.get("thread_id") + or "default" + ) + raw_type = str(record.get("type") or record.get("role") or record.get("event") or "assistant") + event_type = TYPE_ALIASES.get(raw_type.lower()) + if event_type is None: + event_type = "assistant" + + tool_name = record.get("tool_name") or record.get("name") + if event_type == "tool_call" and not tool_name and isinstance(record.get("tool"), str): + tool_name = record["tool"] + + event_id = str(record.get("event_id") or _stable_event_id(record, index)) + events.append( + LogEvent( + event_id=event_id, + session_id=session_id, + timestamp=record.get("timestamp") or record.get("time") or record.get("created_at"), + type=event_type, + content=_string_or_none(record.get("content") or record.get("message") or record.get("text")), + tool_name=_string_or_none(tool_name), + tool_args=_dict_or_none(record.get("tool_args") or record.get("arguments") or record.get("args")), + tool_result=record.get("tool_result") or record.get("result") or record.get("observation"), + metadata=_metadata_without_known_fields(record), + ) + ) + return sorted(events, key=lambda event: (event.session_id, event.timestamp or "", event.event_id)) + + +def load_events(path: Path) -> list[LogEvent]: + """Load and normalize records from one JSON or JSONL file.""" + + return normalize_records(read_json_records(path)) + + +def load_events_from_path(path: Path) -> list[LogEvent]: + """Load events from a file or directory of JSON/JSONL logs.""" + + if path.is_file(): + return load_events(path) + if not path.is_dir(): + raise FileNotFoundError(path) + + records: list[dict[str, Any]] = [] + for candidate in sorted(path.rglob("*")): + if candidate.suffix in {".json", ".jsonl"}: + records.extend(read_json_records(candidate)) + return normalize_records(records) + + +def _stable_event_id(record: dict[str, Any], index: int) -> str: + payload = json.dumps(record, sort_keys=True, default=str) + digest = hashlib.sha1(f"{index}:{payload}".encode("utf-8")).hexdigest()[:12] + return f"evt_{digest}" + + +def _string_or_none(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return json.dumps(value, sort_keys=True, default=str) + + +def _dict_or_none(value: Any) -> dict[str, Any] | None: + if isinstance(value, dict): + return value + return None + + +def _metadata_without_known_fields(record: dict[str, Any]) -> dict[str, Any]: + known = { + "session_id", + "conversation_id", + "thread_id", + "type", + "role", + "event", + "timestamp", + "time", + "created_at", + "content", + "message", + "text", + "tool_name", + "tool", + "name", + "tool_args", + "arguments", + "args", + "tool_result", + "result", + "observation", + "event_id", + } + return {key: value for key, value in record.items() if key not in known} diff --git a/src/training_setup_logs/pii.py b/src/training_setup_logs/pii.py new file mode 100644 index 0000000..e6a425d --- /dev/null +++ b/src/training_setup_logs/pii.py @@ -0,0 +1,100 @@ +"""Deterministic rule-based PII redaction.""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from typing import Any + +from training_setup_logs.schemas import LogEvent, PiiFinding + + +PII_PATTERNS: list[tuple[str, re.Pattern[str]]] = [ + ("bearer_token", re.compile(r"\bBearer\s+[A-Za-z0-9._~+/=-]{16,}\b", re.IGNORECASE)), + ("api_key", re.compile(r"\b(?:sk|pk|api|key|token|secret)[-_]?[A-Za-z0-9]{16,}\b", re.IGNORECASE)), + ("url_token", re.compile(r"([?&](?:token|key|secret|signature|auth)=)[^&\s]+", re.IGNORECASE)), + ("email", re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE)), + ("aadhaar", re.compile(r"(? None: + self.findings.append(PiiFinding(kind=kind, placeholder=placeholder)) + + def counts_by_kind(self) -> dict[str, int]: + counts: dict[str, int] = {} + for finding in self.findings: + counts[finding.kind] = counts.get(finding.kind, 0) + 1 + return counts + + +class PiiRedactor: + """Replace sensitive spans with stable placeholders within one run.""" + + def __init__(self) -> None: + self._seen: dict[tuple[str, str], str] = {} + self.report = RedactionReport() + + def redact_text(self, text: str | None) -> str | None: + if text is None: + return None + redacted = text + for kind, pattern in PII_PATTERNS: + redacted = pattern.sub(lambda match: self._placeholder(kind, match.group(0)), redacted) + return redacted + + def redact_value(self, value: Any) -> Any: + if isinstance(value, str): + return self.redact_text(value) + if isinstance(value, dict): + return {key: self.redact_value(item) for key, item in value.items()} + if isinstance(value, list): + return [self.redact_value(item) for item in value] + return value + + def redact_event(self, event: LogEvent) -> LogEvent: + metadata = dict(event.metadata) + metadata["pii_redacted"] = True + return LogEvent( + event_id=event.event_id, + session_id=event.session_id, + timestamp=event.timestamp, + type=event.type, + content=self.redact_text(event.content), + tool_name=event.tool_name, + tool_args=self.redact_value(event.tool_args), + tool_result=self.redact_value(event.tool_result), + metadata=metadata, + ) + + def _placeholder(self, kind: str, raw: str) -> str: + if kind == "url_token": + prefix = raw.split("=", 1)[0] + "=" + secret = raw.split("=", 1)[1] if "=" in raw else raw + return prefix + self._placeholder("url_secret", secret) + + key = (kind, raw) + if key not in self._seen: + placeholder = f"<{kind.upper()}_{len([k for k in self._seen if k[0] == kind]) + 1}>" + self._seen[key] = placeholder + self.report.add(kind, placeholder) + return self._seen[key] + + +def redact_jsonish(value: Any) -> str | None: + """Return a stable string representation for redacted tool payloads.""" + + if value is None: + return None + if isinstance(value, str): + return value + return json.dumps(value, sort_keys=True, default=str) diff --git a/src/training_setup_logs/schemas.py b/src/training_setup_logs/schemas.py new file mode 100644 index 0000000..4e23d70 --- /dev/null +++ b/src/training_setup_logs/schemas.py @@ -0,0 +1,89 @@ +"""Canonical schemas for normalized logs and training units.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +EventType = Literal[ + "system", + "user", + "assistant", + "tool_call", + "tool_result", + "error", + "feedback", +] + +UnitType = Literal["qa", "agent_trajectory"] + + +@dataclass(frozen=True) +class PiiFinding: + """A single PII finding after redaction.""" + + kind: str + placeholder: str + + def to_dict(self) -> dict[str, str]: + return {"kind": self.kind, "placeholder": self.placeholder} + + +@dataclass(frozen=True) +class LogEvent: + """A normalized event from application, chat, or agent logs.""" + + event_id: str + session_id: str + timestamp: str | None + type: EventType + content: str | None = None + tool_name: str | None = None + tool_args: dict[str, Any] | None = None + tool_result: Any | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "event_id": self.event_id, + "session_id": self.session_id, + "timestamp": self.timestamp, + "type": self.type, + "content": self.content, + "tool_name": self.tool_name, + "tool_args": self.tool_args, + "tool_result": self.tool_result, + "metadata": self.metadata, + } + + +@dataclass(frozen=True) +class TrainingUnit: + """A redacted session segment ready for tagging and export.""" + + unit_id: str + session_id: str + unit_type: UnitType + events: list[LogEvent] + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "unit_id": self.unit_id, + "session_id": self.session_id, + "unit_type": self.unit_type, + "events": [event.to_dict() for event in self.events], + "metadata": self.metadata, + } + + +@dataclass(frozen=True) +class ValidationIssue: + """Validation issue attached to a training unit.""" + + code: str + message: str + event_id: str | None = None + + def to_dict(self) -> dict[str, str | None]: + return {"code": self.code, "message": self.message, "event_id": self.event_id} diff --git a/src/training_setup_logs/segment.py b/src/training_setup_logs/segment.py new file mode 100644 index 0000000..1ca59de --- /dev/null +++ b/src/training_setup_logs/segment.py @@ -0,0 +1,39 @@ +"""Session segmentation into Q&A and agent trajectory training units.""" + +from __future__ import annotations + +import hashlib +from collections import defaultdict + +from training_setup_logs.schemas import LogEvent, TrainingUnit + + +def segment_events(events: list[LogEvent]) -> list[TrainingUnit]: + """Group normalized events into redacted training units by session.""" + + sessions: dict[str, list[LogEvent]] = defaultdict(list) + for event in events: + sessions[event.session_id].append(event) + + units: list[TrainingUnit] = [] + for session_id, session_events in sorted(sessions.items()): + unit_type = "agent_trajectory" if any( + event.type in {"tool_call", "tool_result", "error"} for event in session_events + ) else "qa" + unit_id = _unit_id(session_id, session_events) + units.append( + TrainingUnit( + unit_id=unit_id, + session_id=session_id, + unit_type=unit_type, + events=session_events, + metadata={"source_event_count": len(session_events)}, + ) + ) + return units + + +def _unit_id(session_id: str, events: list[LogEvent]) -> str: + joined = "|".join(event.event_id for event in events) + digest = hashlib.sha1(f"{session_id}:{joined}".encode("utf-8")).hexdigest()[:12] + return f"unit_{digest}" diff --git a/src/training_setup_logs/split.py b/src/training_setup_logs/split.py new file mode 100644 index 0000000..3a62798 --- /dev/null +++ b/src/training_setup_logs/split.py @@ -0,0 +1,51 @@ +"""Deterministic split assignment with simple leakage guards.""" + +from __future__ import annotations + +import hashlib +import re + +from training_setup_logs.schemas import TrainingUnit + + +SPACE_RE = re.compile(r"\s+") + + +def split_metadata(unit: TrainingUnit) -> dict[str, str]: + """Return stable split metadata for a unit. + + The leakage bucket is based on normalized user text instead of unit id, so + repeated or near-identical prompts are kept in the same split. + """ + + leakage_key = _leakage_key(unit) + split = _split_from_key(leakage_key) + return { + "split": split, + "leakage_bucket": hashlib.sha1(leakage_key.encode("utf-8")).hexdigest()[:16], + } + + +def _leakage_key(unit: TrainingUnit) -> str: + user_text = "\n".join( + _normalize_text(event.content or "") + for event in unit.events + if event.type == "user" + ) + tool_names = ",".join( + sorted({event.tool_name or "unknown" for event in unit.events if event.type == "tool_call"}) + ) + return f"{unit.unit_type}|{tool_names}|{user_text}" + + +def _normalize_text(text: str) -> str: + return SPACE_RE.sub(" ", text.casefold()).strip() + + +def _split_from_key(key: str) -> str: + value = int(hashlib.sha1(key.encode("utf-8")).hexdigest()[:8], 16) % 100 + if value < 80: + return "train" + if value < 90: + return "validation" + return "test" diff --git a/src/training_setup_logs/tagging.py b/src/training_setup_logs/tagging.py new file mode 100644 index 0000000..68f878b --- /dev/null +++ b/src/training_setup_logs/tagging.py @@ -0,0 +1,75 @@ +"""Trajectory complexity and scheduling metadata.""" + +from __future__ import annotations + +from training_setup_logs.schemas import TrainingUnit + + +AMBIGUITY_MARKERS = ("maybe", "not sure", "unclear", "ambiguous", "try again", "fallback") + + +def tag_unit(unit: TrainingUnit) -> dict[str, object]: + """Compute training-time scheduling tags for a unit.""" + + tool_calls = [event for event in unit.events if event.type == "tool_call"] + errors = [event for event in unit.events if event.type == "error"] + assistant_turns = [event for event in unit.events if event.type == "assistant"] + user_turns = [event for event in unit.events if event.type == "user"] + unique_tools = sorted({event.tool_name for event in tool_calls if event.tool_name}) + has_recovery = _has_recovery_after_error(unit) + has_ambiguity = any( + marker in (event.content or "").lower() + for event in unit.events + for marker in AMBIGUITY_MARKERS + ) + + complexity = _complexity( + event_count=len(unit.events), + tool_count=len(tool_calls), + error_count=len(errors), + has_recovery=has_recovery, + ) + + return { + "unit_type": unit.unit_type, + "complexity": complexity, + "event_count": len(unit.events), + "user_turns": len(user_turns), + "assistant_turns": len(assistant_turns), + "tool_call_count": len(tool_calls), + "unique_tools": unique_tools, + "error_count": len(errors), + "has_recovery": has_recovery, + "has_ambiguity": has_ambiguity, + "recommended_schedule_bucket": _schedule_bucket(complexity), + } + + +def _has_recovery_after_error(unit: TrainingUnit) -> bool: + seen_error = False + for event in unit.events: + if event.type == "error": + seen_error = True + elif seen_error and event.type in {"assistant", "tool_call"}: + return True + return False + + +def _complexity(event_count: int, tool_count: int, error_count: int, has_recovery: bool) -> str: + if has_recovery or error_count: + return "recovery" + if tool_count >= 2 or event_count >= 8: + return "multi_tool" + if tool_count == 1: + return "single_tool" + if event_count > 2: + return "multi_turn_qa" + return "single_turn_qa" + + +def _schedule_bucket(complexity: str) -> str: + if complexity in {"single_turn_qa", "multi_turn_qa"}: + return "phase_1_foundation" + if complexity == "single_tool": + return "phase_2_tool_use" + return "phase_3_complex_trajectories" diff --git a/src/training_setup_logs/tool_schema.py b/src/training_setup_logs/tool_schema.py new file mode 100644 index 0000000..5f9aff5 --- /dev/null +++ b/src/training_setup_logs/tool_schema.py @@ -0,0 +1,38 @@ +"""Optional tool schema validation for agent trajectories.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class ToolSpec: + name: str + required_args: set[str] + + +@dataclass(frozen=True) +class ToolRegistry: + tools: dict[str, ToolSpec] + + @classmethod + def from_path(cls, path: Path | None) -> "ToolRegistry": + if path is None: + return cls({}) + payload = json.loads(path.read_text(encoding="utf-8")) + tools: dict[str, ToolSpec] = {} + for name, spec in payload.get("tools", {}).items(): + tools[name] = ToolSpec(name=name, required_args=set(spec.get("required_args", []))) + return cls(tools) + + def has_tool(self, name: str) -> bool: + return name in self.tools + + def missing_required_args(self, name: str, args: dict[str, object] | None) -> set[str]: + spec = self.tools.get(name) + if spec is None: + return set() + actual = set((args or {}).keys()) + return spec.required_args - actual diff --git a/src/training_setup_logs/validate.py b/src/training_setup_logs/validate.py new file mode 100644 index 0000000..7961751 --- /dev/null +++ b/src/training_setup_logs/validate.py @@ -0,0 +1,87 @@ +"""Schema and trajectory validation.""" + +from __future__ import annotations + +from training_setup_logs.schemas import TrainingUnit, ValidationIssue +from training_setup_logs.tool_schema import ToolRegistry + + +def validate_unit(unit: TrainingUnit, tool_registry: ToolRegistry | None = None) -> list[ValidationIssue]: + """Validate basic event and tool trajectory consistency.""" + + issues: list[ValidationIssue] = [] + pending_tools: list[tuple[str, str]] = [] + registry = tool_registry or ToolRegistry({}) + + for event in unit.events: + if event.type == "tool_call": + if not event.tool_name: + issues.append( + ValidationIssue( + code="TOOL_NAME_MISSING", + message="Tool call event is missing tool_name.", + event_id=event.event_id, + ) + ) + else: + pending_tools.append((event.tool_name, event.event_id)) + if registry.tools and not registry.has_tool(event.tool_name): + issues.append( + ValidationIssue( + code="UNKNOWN_TOOL", + message=f"Tool {event.tool_name} is not declared in the registry.", + event_id=event.event_id, + ) + ) + missing_args = registry.missing_required_args(event.tool_name, event.tool_args) + if missing_args: + issues.append( + ValidationIssue( + code="MISSING_TOOL_ARGS", + message=f"Tool {event.tool_name} is missing required args: {sorted(missing_args)}.", + event_id=event.event_id, + ) + ) + + if event.type == "tool_result": + if not pending_tools: + issues.append( + ValidationIssue( + code="ORPHAN_TOOL_RESULT", + message="Tool result appears before any matching tool call.", + event_id=event.event_id, + ) + ) + else: + expected_tool, _ = pending_tools.pop(0) + if event.tool_name and event.tool_name != expected_tool: + issues.append( + ValidationIssue( + code="TOOL_RESULT_NAME_MISMATCH", + message=f"Tool result is for {event.tool_name}, expected {expected_tool}.", + event_id=event.event_id, + ) + ) + + if event.type == "error" and pending_tools: + pending_tools.pop(0) + + if event.type in {"user", "assistant"} and not event.content: + issues.append( + ValidationIssue( + code="EMPTY_TEXT_TURN", + message=f"{event.type} event has no content.", + event_id=event.event_id, + ) + ) + + for tool_name, event_id in pending_tools: + issues.append( + ValidationIssue( + code="MISSING_TOOL_OBSERVATION", + message=f"Tool call {tool_name} has no following result or error event.", + event_id=event_id, + ) + ) + + return issues diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..ee3f236 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,145 @@ +from pathlib import Path + +from training_setup_logs.cli import run +from training_setup_logs.ingest import load_events +from training_setup_logs.pii import PiiRedactor +from training_setup_logs.schemas import LogEvent, TrainingUnit +from training_setup_logs.segment import segment_events +from training_setup_logs.split import split_metadata +from training_setup_logs.tagging import tag_unit +from training_setup_logs.tool_schema import ToolRegistry +from training_setup_logs.validate import validate_unit + + +ROOT = Path(__file__).resolve().parents[1] + + +def test_sample_pipeline_exports_sft_and_dpo_candidates(tmp_path): + manifest = run(ROOT / "examples" / "sample_agent_logs.jsonl", tmp_path) + + assert manifest["unit_count"] == 2 + assert manifest["sft_rows"] == 2 + assert manifest["dpo_candidate_rows"] == 1 + assert manifest["pii_counts"]["email"] == 1 + assert manifest["pii_counts"]["phone"] == 1 + assert manifest["validation_issue_count"] == 0 + assert (tmp_path / "sft.jsonl").exists() + assert (tmp_path / "dpo_candidates.jsonl").exists() + assert (tmp_path / "redacted_units.jsonl").exists() + assert (tmp_path / "redaction_report.json").exists() + assert (tmp_path / "manifest.json").exists() + + +def test_redaction_is_deterministic_within_run(): + events = load_events(ROOT / "examples" / "sample_agent_logs.jsonl") + redactor = PiiRedactor() + redacted = [redactor.redact_event(event) for event in events] + + assert "" in (redacted[0].content or "") + assert redacted[4].tool_args["phone"] == "" + assert redacted[2].tool_args["api_key"] == "" + + +def test_agent_trace_receives_recovery_tag_and_validates(): + events = load_events(ROOT / "examples" / "sample_agent_logs.jsonl") + redactor = PiiRedactor() + units = segment_events([redactor.redact_event(event) for event in events]) + agent_unit = next(unit for unit in units if unit.unit_type == "agent_trajectory") + + tags = tag_unit(agent_unit) + + assert tags["complexity"] == "recovery" + assert tags["recommended_schedule_bucket"] == "phase_3_complex_trajectories" + assert validate_unit(agent_unit) == [] + + +def test_split_metadata_keeps_similar_prompts_together(): + first = TrainingUnit( + unit_id="u1", + session_id="s1", + unit_type="qa", + events=[ + LogEvent(event_id="e1", session_id="s1", timestamp=None, type="user", content=" What is LoRA? "), + LogEvent(event_id="e2", session_id="s1", timestamp=None, type="assistant", content="Answer"), + ], + ) + second = TrainingUnit( + unit_id="u2", + session_id="s2", + unit_type="qa", + events=[ + LogEvent(event_id="e3", session_id="s2", timestamp=None, type="user", content="what is lora?"), + LogEvent(event_id="e4", session_id="s2", timestamp=None, type="assistant", content="Answer"), + ], + ) + + assert split_metadata(first) == split_metadata(second) + + +def test_validation_flags_tool_mismatch_and_missing_observation(): + mismatch_unit = TrainingUnit( + unit_id="u3", + session_id="s3", + unit_type="agent_trajectory", + events=[ + LogEvent(event_id="e1", session_id="s3", timestamp=None, type="tool_call", tool_name="weather"), + LogEvent(event_id="e2", session_id="s3", timestamp=None, type="tool_result", tool_name="crop"), + ], + ) + missing_result_unit = TrainingUnit( + unit_id="u4", + session_id="s4", + unit_type="agent_trajectory", + events=[ + LogEvent(event_id="e3", session_id="s4", timestamp=None, type="tool_call", tool_name="weather"), + ], + ) + + assert [issue.code for issue in validate_unit(mismatch_unit)] == ["TOOL_RESULT_NAME_MISMATCH"] + assert [issue.code for issue in validate_unit(missing_result_unit)] == ["MISSING_TOOL_OBSERVATION"] + + +def test_tool_schema_validation_flags_unknown_tools_and_missing_args(): + registry = ToolRegistry.from_path(ROOT / "examples" / "tool_schema.json") + unit = TrainingUnit( + unit_id="u5", + session_id="s5", + unit_type="agent_trajectory", + events=[ + LogEvent( + event_id="e1", + session_id="s5", + timestamp=None, + type="tool_call", + tool_name="get_weather", + tool_args={}, + ), + LogEvent( + event_id="e2", + session_id="s5", + timestamp=None, + type="tool_result", + tool_name="get_weather", + tool_result={}, + ), + LogEvent( + event_id="e3", + session_id="s5", + timestamp=None, + type="tool_call", + tool_name="unknown_tool", + ), + LogEvent( + event_id="e4", + session_id="s5", + timestamp=None, + type="error", + content="failed", + ), + ], + ) + + codes = [issue.code for issue in validate_unit(unit, registry)] + + assert "MISSING_TOOL_ARGS" in codes + assert "UNKNOWN_TOOL" in codes