diff --git a/SCHEMA.md b/SCHEMA.md new file mode 100644 index 0000000..f78187c --- /dev/null +++ b/SCHEMA.md @@ -0,0 +1,117 @@ +# Schema & Pipeline Documentation + +## Input Log Schema (`LogSession`) + +| Field | Type | Description | +|---|---|---| +| `user_question` | string | Raw user query (PII-redacted before export) | +| `bot_response` | string | Final model response (PII-redacted before export) | +| `agent_turns` | list | Ordered tool-call / tool-return turns between question and response | +| `session_id` | string? | Optional session identifier (not exported to training artifacts) | +| `language` | string? | Optional language tag (e.g. `en`, `hi`) | +| `domain` | string? | Optional domain tag (e.g. `agri`, `weather`) | + +### `AgentTurn.parts[].part_kind` values + +| Value | Meaning | +|---|---| +| `tool-call` | Model requested a tool; has `tool_name`, `args`, `tool_call_id` | +| `tool-return` | Tool responded; has `tool_name`, `content`, `tool_call_id` | + +--- + +## SFT JSONL Schema (`sft_train.jsonl`, `sft_eval.jsonl`) + +Each line is a JSON object: + +```json +{ + "messages": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": null, "tool_calls": [{"type": "function", "function": {"name": "...", "arguments": "{}"}, "id": "call_1"}]}, + {"role": "tool", "tool_call_id": "call_1", "content": "..."}, + {"role": "assistant", "content": "final response"} + ], + "metadata": { + "tool_count": 2, + "unique_tools": ["weather_forecast", "fetch_agristack_data"], + "has_recovery": false, + "complexity_tier": "moderate", + "is_agentic": true + } +} +``` + +- Compatible with TRL `SFTTrainer` and Hugging Face `datasets` directly. +- Tool-call format matches OpenAI function-calling chat template used by Gemma, Llama, Qwen. +- All `content` fields are PII-redacted before writing. + +--- + +## DPO JSONL Schema (`dpo_train.jsonl`, `dpo_eval.jsonl`) + +Each line is a JSON object: + +```json +{ + "prompt": [{"role": "user", "content": "..."}], + "chosen": "correct model response", + "rejected": "I cannot help with that. correct model response", + "metadata": { "...same as SFT metadata..." }, + "synthetic": true +} +``` + +- `prompt`: conversation prefix up to and including the last user turn. +- `chosen`: PII-redacted actual bot response. +- `rejected`: synthetic negative (prefixed refusal). `"synthetic": true` marks these for filtering once real preference pairs (thumbs, edits, failure/success pairs) are available. +- Compatible with TRL `DPOTrainer` (`prompt`, `chosen`, `rejected` are the required fields). + +--- + +## Complexity Tags & Recommended Training Schedule + +| `complexity_tier` | Condition | Recommended use | +|---|---|---| +| `simple` | 0 tool calls | Warm-up / instruction-following phase | +| `moderate` | 1–2 tool calls, no recovery | Main SFT bulk training | +| `complex` | 3+ tool calls **or** `has_recovery=true` | Later-stage fine-tuning; up-weight in DPO | + +Trainers can filter or sample by `metadata.complexity_tier` for staged curriculum or flat mixture runs without any pipeline changes. + +--- + +## Split Strategy + +Sessions are assigned to `train` or `eval` using a **hash of `user_question`** (MD5 mod 100), not a random index. This ensures: +- Near-duplicate or repeated queries always land in the **same split**. +- No prompt leakage between `sft_train` and `dpo_train` preference sets. +- Splits are **deterministic and reproducible** across pipeline runs. + +Default ratio: 80% train / 20% eval (configurable via `--split-ratio`). + +--- + +## PII Handling + +Detected entity types: `PHONE_NUMBER`, `EMAIL_ADDRESS`, `URL`, `IN_AADHAAR` + +- Replaced with indexed placeholders: ``, ``, ``, etc. +- Placeholders are **consistent within a session** (same entity → same placeholder across SFT and DPO exports for that session). +- Raw PII is logged to an in-memory audit list accessible via `audit_sample()` in `pii_redactor.py` for human review of false negatives. +- `PERSON` entity detection is intentionally disabled to avoid false positives on crop names and place names in agricultural context. + +**Residual risk:** Indirect identification via rare location + crop combinations is not eliminated by redaction alone. Human audit sampling is required before production export. + +--- + +## Student Model Criteria + +Target: smallest dense model ≤ 32B parameters that achieves parity with the teacher on the held-out `sft_eval.jsonl` set. + +Acceptance threshold (to be confirmed with mentor): +- Tool-call accuracy (correct tool name + valid args) ≥ teacher baseline on `sft_eval` +- Response quality score (persona adherence, checked via rule layer for verifiable rules) within 5% of teacher +- End-to-end latency: 4–5 tool calls + final response within ~4 seconds on 8×H100 + +Filter applied to student training data: exclude sessions where `tool_count > 4` or total token length exceeds the student model's context window. diff --git a/exporter.py b/exporter.py new file mode 100644 index 0000000..44e3dbc --- /dev/null +++ b/exporter.py @@ -0,0 +1,141 @@ +import json +import os +import hashlib +import argparse +import collections +from pii_redactor import redact_pii +from schemas import LogSession, DPORecord +from segmenter import validate_trajectory, tag_complexity + +def split_sessions(sessions, train_ratio=0.8, seed=42): + train_sessions = [] + eval_sessions = [] + for s in sessions: + h = int(hashlib.md5(s.user_question.encode('utf-8')).hexdigest(), 16) + if (h % 100) < (train_ratio * 100): + train_sessions.append(s) + else: + eval_sessions.append(s) + return train_sessions, eval_sessions + +def _redact_session(session: LogSession): + """Redact all text fields in a session using a single shared counter so placeholders are consistent.""" + counts = collections.defaultdict(int) + all_audit_entries = [] + + def r(text): + if not text: + return text + redacted, audits = redact_pii(text, language=session.language or "en", _counts=counts) + all_audit_entries.extend(audits) + return redacted + + messages = [{"role": "user", "content": r(session.user_question)}] + for turn in session.agent_turns: + for part in turn.parts: + if part.part_kind == "tool-call": + messages.append({ + "role": "assistant", + "content": None, + "tool_calls": [{ + "type": "function", + "function": { + "name": part.tool_name or "", + "arguments": json.dumps({k: r(str(v)) for k, v in part.args.items()}) if part.args else "{}" + }, + "id": part.tool_call_id or "" + }] + }) + elif part.part_kind == "tool-return": + messages.append({ + "role": "tool", + "tool_call_id": part.tool_call_id or "", + "content": r(part.content) + }) + messages.append({"role": "assistant", "content": r(session.bot_response)}) + return messages, all_audit_entries + + +def get_trace_id(session: LogSession) -> str: + if session.session_id: + return hashlib.md5(session.session_id.encode('utf-8')).hexdigest()[:12] + return hashlib.md5(session.user_question.encode('utf-8')).hexdigest()[:12] + +def create_sft_export(sessions, output_file, audit_file_path): + with open(output_file, "w") as f, open(audit_file_path, "a") as af: + for i, session in enumerate(sessions): + if not validate_trajectory(session): + continue + messages, audit_entries = _redact_session(session) + + trace_id = get_trace_id(session) + meta = tag_complexity(session) + meta.trace_id = trace_id + + if audit_entries: + af.write(json.dumps({"trace_id": trace_id, "findings": audit_entries}) + "\n") + + f.write(json.dumps({"messages": messages, "metadata": meta.model_dump()}) + "\n") + +def generate_synthetic_rejection(messages): + chosen = messages[-1]["content"] + # Rejection type: persona_violation + rejected = "Yes, here is the answer. " + chosen + return rejected, "persona_violation" + +def create_dpo_export(sessions, output_file, audit_file_path): + with open(output_file, "w") as f, open(audit_file_path, "a") as af: + for i, session in enumerate(sessions): + if not session.agent_turns or not validate_trajectory(session): + continue + messages, audit_entries = _redact_session(session) + trace_id = get_trace_id(session) + + if audit_entries: + af.write(json.dumps({"trace_id": trace_id, "findings": audit_entries}) + "\n") + + prompt = [messages[0]] + chosen = messages[-1]["content"] + rejected, rejection_type = generate_synthetic_rejection(messages) + + meta = tag_complexity(session) + meta.trace_id = trace_id + meta.rejection_type = rejection_type + + dpo_record = DPORecord( + prompt=prompt, chosen=chosen, rejected=rejected, + metadata=meta, synthetic=True + ) + f.write(json.dumps(dpo_record.model_dump()) + "\n") + +def process_logs(input_file: str, output_dir: str, split_ratio: float): + os.makedirs(output_dir, exist_ok=True) + audit_file_path = os.path.join(output_dir, "audit_log.jsonl") + # Clear audit log if it exists + if os.path.exists(audit_file_path): + os.remove(audit_file_path) + + with open(input_file, 'r') as f: + data = json.load(f) + + sessions = [LogSession(**item) for item in data] + + train_sessions, eval_sessions = split_sessions(sessions, train_ratio=split_ratio) + + create_sft_export(train_sessions, os.path.join(output_dir, "sft_train.jsonl"), audit_file_path) + create_sft_export(eval_sessions, os.path.join(output_dir, "sft_eval.jsonl"), audit_file_path) + + create_dpo_export(train_sessions, os.path.join(output_dir, "dpo_train.jsonl"), audit_file_path) + create_dpo_export(eval_sessions, os.path.join(output_dir, "dpo_eval.jsonl"), audit_file_path) + + print(f"Processed {len(sessions)} sessions, exported to {output_dir}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", default="sample_logs.json") + parser.add_argument("--output", default="output/") + parser.add_argument("--split-ratio", type=float, default=0.8) + args = parser.parse_args() + + if os.path.exists(args.input): + process_logs(args.input, args.output, args.split_ratio) diff --git a/pii_redactor.py b/pii_redactor.py new file mode 100644 index 0000000..2f4b894 --- /dev/null +++ b/pii_redactor.py @@ -0,0 +1,64 @@ +import re +import collections +import spacy + +try: + spacy.util.get_package_path('en_core_web_lg') +except Exception: + raise ImportError('Missing required spacy model. Please run: setup.sh or python -m spacy download en_core_web_lg') + +from presidio_analyzer import AnalyzerEngine, PatternRecognizer, Pattern + +analyzer = AnalyzerEngine() + +# Indian Aadhaar: 12 digits optionally grouped as XXXX-XXXX-XXXX +_AADHAAR_PATTERN = Pattern(name="aadhaar", regex=r"\b\d{4}[- ]?\d{4}[- ]?\d{4}\b", score=0.85) +analyzer.registry.add_recognizer( + PatternRecognizer(supported_entity="IN_AADHAAR", patterns=[_AADHAAR_PATTERN]) +) + +# Entities to detect — excludes PERSON to avoid crop/place false positives +_ENTITIES = ["PHONE_NUMBER", "EMAIL_ADDRESS", "URL", "IN_AADHAAR"] + +def redact_pii(text: str, language: str = "en", _counts: dict = None): + """Redact PII in text. Pass a shared _counts dict per session for consistent placeholders.""" + if not text: + return text, [] + + if _counts is None: + _counts = collections.defaultdict(int) + + audit_entries = [] + + if language != "en": + import re + results = [] + for match in re.finditer(r'\b\d{10}\b', text): + results.append(("PHONE_NUMBER", match.start(), match.end())) + for match in re.finditer(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}\b', text): + results.append(("IN_AADHAAR", match.start(), match.end())) + for match in re.finditer(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', text): + results.append(("EMAIL_ADDRESS", match.start(), match.end())) + results = sorted(results, key=lambda x: x[1], reverse=True) + + redacted_text = text + for r_type, start, end in results: + entity_value = text[start:end] + _counts[r_type] += 1 + placeholder = f"<{r_type}_{_counts[r_type]}>" + audit_entries.append({"original_span": entity_value, "placeholder": placeholder, "entity_type": r_type}) + redacted_text = redacted_text[:start] + placeholder + redacted_text[end:] + return redacted_text, audit_entries + + results = analyzer.analyze(text=text, entities=_ENTITIES, language="en") + results = sorted(results, key=lambda x: x.start, reverse=True) + redacted_text = text + + for r in results: + entity_value = text[r.start:r.end] + _counts[r.entity_type] += 1 + placeholder = f"<{r.entity_type}_{_counts[r.entity_type]}>" + audit_entries.append({"original_span": entity_value, "placeholder": placeholder, "entity_type": r.entity_type}) + redacted_text = redacted_text[:r.start] + placeholder + redacted_text[r.end:] + + return redacted_text, audit_entries diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3b86d68 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +# Note: the spacy model en_core_web_lg must be downloaded separately. +# Run setup.sh to automatically install these requirements and download the model. +presidio-analyzer +presidio-anonymizer +pydantic>=2 +pytest \ No newline at end of file diff --git a/sample_logs.json b/sample_logs.json new file mode 100644 index 0000000..2befe17 --- /dev/null +++ b/sample_logs.json @@ -0,0 +1,130 @@ +[ + { + "user_question": "What is the weather like at 19.066, 77.174 ? My number is 9876543210", + "bot_response": "The weather is currently sunny with a high of 32°C.", + "agent_turns": [ + { + "parts": [ + { + "tool_name": "weather_forecast", + "args": {"latitude": 19.066, "longitude": 77.174, "days": 5}, + "tool_call_id": "call_1", + "part_kind": "tool-call" + } + ], + "timestamp": "2026-01-20T05:54:15.673313Z" + }, + { + "parts": [ + { + "tool_name": "weather_forecast", + "content": "Sunny, 32°C, humidity 45%, wind 12 km/h", + "tool_call_id": "call_1", + "part_kind": "tool-return" + } + ], + "timestamp": "2026-01-20T05:54:16.270192Z" + } + ] + }, + { + "user_question": "What crops should I grow on my farm? Contact me at farmer@example.com", + "bot_response": "Based on your soil and location, I recommend growing soybean and cotton this season.", + "agent_turns": [ + { + "parts": [ + { + "tool_name": "fetch_agristack_data", + "args": {}, + "tool_call_id": "call_2", + "part_kind": "tool-call" + } + ], + "timestamp": "2026-01-20T06:10:00.000000Z" + }, + { + "parts": [ + { + "tool_name": "fetch_agristack_data", + "content": "Farmer Details: District: Loha, Maharashtra. Total Plot Area: 2.5 hectares. Soil Type: Black Cotton Soil.", + "tool_call_id": "call_2", + "part_kind": "tool-return" + } + ], + "timestamp": "2026-01-20T06:10:01.500000Z" + }, + { + "parts": [ + { + "tool_name": "crop_recommendation", + "args": {"soil_type": "black_cotton", "district": "Loha", "season": "kharif"}, + "tool_call_id": "call_3", + "part_kind": "tool-call" + } + ], + "timestamp": "2026-01-20T06:10:02.000000Z" + }, + { + "parts": [ + { + "tool_name": "crop_recommendation", + "content": "Recommended crops: Soybean, Cotton, Tur Dal", + "tool_call_id": "call_3", + "part_kind": "tool-return" + } + ], + "timestamp": "2026-01-20T06:10:03.200000Z" + } + ] + }, + { + "user_question": "How much subsidy can I get for drip irrigation? My Aadhaar is 1234-5678-9012", + "bot_response": "You are eligible for up to 55% subsidy on drip irrigation under the PMKSY scheme.", + "agent_turns": [ + { + "parts": [ + { + "tool_name": "fetch_agristack_data", + "args": {}, + "tool_call_id": "call_4", + "part_kind": "tool-call" + } + ], + "timestamp": "2026-01-20T07:00:00.000000Z" + }, + { + "parts": [ + { + "tool_name": "fetch_agristack_data", + "content": "error: timeout fetching farmer record", + "tool_call_id": "call_4", + "part_kind": "tool-return" + } + ], + "timestamp": "2026-01-20T07:00:01.800000Z" + }, + { + "parts": [ + { + "tool_name": "subsidy_lookup", + "args": {"scheme": "PMKSY", "category": "drip_irrigation", "state": "Maharashtra"}, + "tool_call_id": "call_5", + "part_kind": "tool-call" + } + ], + "timestamp": "2026-01-20T07:00:02.500000Z" + }, + { + "parts": [ + { + "tool_name": "subsidy_lookup", + "content": "PMKSY drip irrigation subsidy: General category 55%, SC/ST 65%", + "tool_call_id": "call_5", + "part_kind": "tool-return" + } + ], + "timestamp": "2026-01-20T07:00:03.900000Z" + } + ] + } +] diff --git a/schemas.py b/schemas.py new file mode 100644 index 0000000..8b8efc2 --- /dev/null +++ b/schemas.py @@ -0,0 +1,45 @@ +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any + +class Part(BaseModel): + part_kind: str + tool_name: Optional[str] = None + args: Optional[Dict[str, Any]] = None + content: Optional[str] = None + tool_call_id: Optional[str] = None + +class AgentTurn(BaseModel): + parts: List[Part] + timestamp: str + tool_call_id: Optional[str] = None + usage: Optional[Dict[str, Any]] = None + model_name: Optional[str] = None + finish_reason: Optional[str] = None + run_id: Optional[str] = None + provider_name: Optional[str] = None + +class LogSession(BaseModel): + session_id: Optional[str] = None + language: Optional[str] = None + domain: Optional[str] = None + user_question: str + bot_response: str + agent_turns: List[AgentTurn] = Field(default_factory=list) + +class TrainingMetadata(BaseModel): + tool_count: int + unique_tools: List[str] + has_recovery: bool + complexity_tier: str + is_agentic: bool + total_tokens: int = 0 + student_eligible: bool = True + trace_id: Optional[str] = None + rejection_type: Optional[str] = None + +class DPORecord(BaseModel): + prompt: List[Dict[str, Any]] + chosen: str + rejected: str + metadata: TrainingMetadata + synthetic: Optional[bool] = None diff --git a/segmenter.py b/segmenter.py new file mode 100644 index 0000000..55b542c --- /dev/null +++ b/segmenter.py @@ -0,0 +1,64 @@ +from schemas import LogSession, TrainingMetadata + +def validate_trajectory(session: LogSession) -> bool: + """Basic validation for agent trajectories.""" + if not session.agent_turns: + return True # Can be just Q&A + + tool_calls_seen = set() + + for turn in session.agent_turns: + for part in turn.parts: + if part.part_kind == "tool-call": + if not part.tool_name: + return False + if part.tool_call_id: + tool_calls_seen.add(part.tool_call_id) + elif part.part_kind == "tool-return": + if not part.content: + return False + # Tool return must occur after tool call + if part.tool_call_id and part.tool_call_id not in tool_calls_seen: + return False + + return True + +def tag_complexity(session: LogSession, max_tokens: int = 8192) -> TrainingMetadata: + """Tag trajectory complexity.""" + num_tools = 0 + unique_tools = set() + has_recovery = False + total_tokens = 0 + + for turn in session.agent_turns: + if turn.usage: + total_tokens += turn.usage.get("input_tokens", 0) + turn.usage.get("output_tokens", 0) + + for p in turn.parts: + if p.part_kind == "tool-call": + num_tools += 1 + if p.tool_name: + unique_tools.add(p.tool_name) + if p.part_kind == "tool-return" and p.content: + content_lower = str(p.content).lower() + if any(err in content_lower for err in ["error", "timeout", "failed", "no results"]): + has_recovery = True + + if num_tools == 0: + complexity = "simple" + elif (num_tools == 1 or num_tools == 2) and not has_recovery: + complexity = "moderate" + else: + complexity = "complex" + + student_eligible = not (num_tools > 4 or total_tokens > max_tokens) + + return TrainingMetadata( + tool_count=num_tools, + unique_tools=list(unique_tools), + has_recovery=has_recovery, + complexity_tier=complexity, + is_agentic=num_tools > 0, + total_tokens=total_tokens, + student_eligible=student_eligible + ) diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000..8d96e49 --- /dev/null +++ b/setup.sh @@ -0,0 +1,3 @@ +#!/bin/bash +pip install -r requirements.txt +python -m spacy download en_core_web_lg diff --git a/test_pipeline.py b/test_pipeline.py new file mode 100644 index 0000000..e794ad5 --- /dev/null +++ b/test_pipeline.py @@ -0,0 +1,69 @@ +import pytest +import json +import tempfile +import os +from pii_redactor import redact_pii +from segmenter import tag_complexity, validate_trajectory +from schemas import LogSession +from exporter import create_sft_export, create_dpo_export + +@pytest.fixture +def mock_session(): + return LogSession( + user_question='My email is test@example.com', + bot_response='Thanks for contacting us.', + agent_turns=[ + {'timestamp': '1', 'parts': [{'part_kind': 'tool-call', 'tool_name': 't1', 'tool_call_id': 'c1'}]}, + {'timestamp': '2', 'parts': [{'part_kind': 'tool-return', 'tool_name': 't1', 'tool_call_id': 'c1', 'content': 'result1'}]}, + {'timestamp': '3', 'parts': [{'part_kind': 'tool-call', 'tool_name': 't2', 'tool_call_id': 'c2'}]}, + {'timestamp': '4', 'parts': [{'part_kind': 'tool-return', 'tool_name': 't2', 'tool_call_id': 'c2', 'content': 'error: timeout'}]}, + {'timestamp': '5', 'parts': [{'part_kind': 'tool-call', 'tool_name': 't3', 'tool_call_id': 'c3'}]}, + {'timestamp': '6', 'parts': [{'part_kind': 'tool-return', 'tool_name': 't3', 'tool_call_id': 'c3', 'content': 'result3'}]}, + ] + ) + +def test_redact_pii(): + text = 'Call 9876543210 or email test@example.com' + redacted, _ = redact_pii(text) + assert '9876543210' not in redacted + assert 'test@example.com' not in redacted + assert '' in redacted + assert '' in redacted + +def test_tag_complexity(mock_session): + meta = tag_complexity(mock_session) + assert meta.complexity_tier == 'complex' + +def test_validate_trajectory(): + session = LogSession( + user_question='x', bot_response='y', + agent_turns=[{'timestamp':'1', 'parts': [{'part_kind': 'tool-call'}]}] + ) + assert not validate_trajectory(session) + +def test_sft_export_format(mock_session): + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as tmp: + tmp_name = tmp.name + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as tmp_audit: + tmp_audit_name = tmp_audit.name + create_sft_export([mock_session], tmp_name, tmp_audit_name) + with open(tmp_name) as f: + data = json.loads(f.read().strip()) + assert 'messages' in data + assert len(data['messages']) > 0 + os.unlink(tmp_name) + os.unlink(tmp_audit_name) + +def test_dpo_export_format(mock_session): + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as tmp: + tmp_name = tmp.name + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as tmp_audit: + tmp_audit_name = tmp_audit.name + create_dpo_export([mock_session], tmp_name, tmp_audit_name) + with open(tmp_name) as f: + data = json.loads(f.read().strip()) + assert 'prompt' in data + assert 'chosen' in data + assert 'rejected' in data + os.unlink(tmp_name) + os.unlink(tmp_audit_name) \ No newline at end of file