117 changes: 117 additions & 0 deletions SCHEMA.md
@@ -0,0 +1,117 @@
# Schema & Pipeline Documentation

## Input Log Schema (`LogSession`)

| Field | Type | Description |
|---|---|---|
| `user_question` | string | Raw user query (PII-redacted before export) |
| `bot_response` | string | Final model response (PII-redacted before export) |
| `agent_turns` | list | Ordered tool-call / tool-return turns between question and response |
| `session_id` | string? | Optional session identifier (not exported to training artifacts) |
| `language` | string? | Optional language tag (e.g. `en`, `hi`) |
| `domain` | string? | Optional domain tag (e.g. `agri`, `weather`) |

### `AgentTurn.parts[].part_kind` values

| Value | Meaning |
|---|---|
| `tool-call` | Model requested a tool; has `tool_name`, `args`, `tool_call_id` |
| `tool-return` | Tool responded; has `tool_name`, `content`, `tool_call_id` |
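The two tables above can be read as a small set of nested record types. The sketch below uses standard-library dataclasses rather than the Pydantic models in `schemas.py` (which is not shown here); the class name `TurnPart` and the helper `unmatched_call_ids` are hypothetical and exist only for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class TurnPart:
    part_kind: str                          # "tool-call" or "tool-return"
    tool_name: Optional[str] = None
    tool_call_id: Optional[str] = None
    args: Optional[dict[str, Any]] = None   # set on "tool-call" parts
    content: Optional[str] = None           # set on "tool-return" parts


@dataclass
class AgentTurn:
    parts: list[TurnPart] = field(default_factory=list)


@dataclass
class LogSession:
    user_question: str
    bot_response: str
    agent_turns: list[AgentTurn] = field(default_factory=list)
    session_id: Optional[str] = None
    language: Optional[str] = None
    domain: Optional[str] = None


def unmatched_call_ids(session: LogSession) -> set:
    """Return tool_call_ids that appear as a call without a return, or vice versa."""
    calls, returns = set(), set()
    for turn in session.agent_turns:
        for part in turn.parts:
            if part.part_kind == "tool-call":
                calls.add(part.tool_call_id)
            elif part.part_kind == "tool-return":
                returns.add(part.tool_call_id)
    # Symmetric difference: ids present on exactly one side.
    return calls ^ returns
```

A check like this is one plausible ingredient of trajectory validation; whether `validate_trajectory` in `segmenter.py` does the same is not confirmed by this diff.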

---

## SFT JSONL Schema (`sft_train.jsonl`, `sft_eval.jsonl`)

Each line is a JSON object:

```json
{
  "messages": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": null, "tool_calls": [{"type": "function", "function": {"name": "...", "arguments": "{}"}, "id": "call_1"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "..."},
    {"role": "assistant", "content": "final response"}
  ],
  "metadata": {
    "tool_count": 2,
    "unique_tools": ["weather_forecast", "fetch_agristack_data"],
    "has_recovery": false,
    "complexity_tier": "moderate",
    "is_agentic": true
  }
}
```

- Compatible with TRL `SFTTrainer` and Hugging Face `datasets` directly.
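- Tool calls use the OpenAI function-calling message layout (`tool_calls` on the assistant turn, `tool_call_id` on the `tool` turn). The chat template of the target model (Gemma, Llama, Qwen) should be checked to confirm it renders this layout before training.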
- All `content` fields are PII-redacted before writing.
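A quick structural check can catch malformed lines before they reach a trainer. The sketch below is an assumption about what is worth checking (first and last roles, JSON-encoded `arguments`, and `tool_call_id` references); `check_sft_record` is a hypothetical helper, not part of the pipeline.

```python
import json


def check_sft_record(line: str) -> list:
    """Return a list of structural problems found in one SFT JSONL line."""
    record = json.loads(line)
    messages = record.get("messages", [])
    problems = []
    if not messages or messages[0].get("role") != "user":
        problems.append("first message is not a user turn")
    if not messages or messages[-1].get("role") != "assistant":
        problems.append("last message is not an assistant turn")

    seen_call_ids = set()
    for i, msg in enumerate(messages):
        if msg.get("role") == "assistant":
            for call in msg.get("tool_calls") or []:
                seen_call_ids.add(call.get("id"))
                try:
                    # "arguments" is expected to be a JSON-encoded string.
                    json.loads(call["function"]["arguments"])
                except (KeyError, TypeError, ValueError):
                    problems.append(f"message {i}: arguments is not a JSON string")
        elif msg.get("role") == "tool":
            # Every tool turn should answer an earlier call.
            if msg.get("tool_call_id") not in seen_call_ids:
                problems.append(f"message {i}: tool_call_id has no preceding call")
    return problems
```

Running it over each line of `sft_train.jsonl` and reporting non-empty results is enough for a smoke test; it does not replace schema validation.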

---

## DPO JSONL Schema (`dpo_train.jsonl`, `dpo_eval.jsonl`)

Each line is a JSON object:

```json
{
  "prompt": [{"role": "user", "content": "..."}],
  "chosen": "correct model response",
  "rejected": "Yes, here is the answer. correct model response",
  "metadata": { "...same as SFT metadata..." },
  "synthetic": true
}
```

- `prompt`: conversation prefix up to and including the last user turn.
- `chosen`: PII-redacted actual bot response.
- `rejected`: synthetic negative formed by prepending a fixed off-persona opener to `chosen`, recorded as `rejection_type: "persona_violation"` in `metadata`. `"synthetic": true` marks these records for filtering once real preference pairs (thumbs, edits, failure/success pairs) are available.
- Compatible with TRL `DPOTrainer` (`prompt`, `chosen`, `rejected` are the required fields).
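Once real preference pairs exist, the `synthetic` flag makes it straightforward to drop the generated negatives. A minimal sketch, assuming the records are read line by line from a DPO JSONL file; `real_pairs_only` is a hypothetical helper name.

```python
import json
from typing import Iterable, Iterator


def real_pairs_only(lines: Iterable[str]) -> Iterator[dict]:
    """Yield DPO records that are not flagged as synthetic."""
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines in the file
        record = json.loads(line)
        # Assumption: a record without the flag is treated as a real pair.
        if not record.get("synthetic", False):
            yield record
```

Typical use would be `list(real_pairs_only(open("dpo_train.jsonl")))`, with the result passed to whatever loader the training run uses.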

---

## Complexity Tags & Recommended Training Schedule

| `complexity_tier` | Condition | Recommended use |
|---|---|---|
| `simple` | 0 tool calls | Warm-up / instruction-following phase |
| `moderate` | 1–2 tool calls, no recovery | Main SFT bulk training |
| `complex` | 3+ tool calls **or** `has_recovery=true` | Later-stage fine-tuning; up-weight in DPO |

Trainers can filter or sample by `metadata.complexity_tier` for staged curriculum or flat mixture runs without any pipeline changes.
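The tier rules in the table reduce to a short function, and a staged curriculum is then a grouping by tier. This is a sketch of the rules as documented, not the actual `tag_complexity` in `segmenter.py` (which is not shown here); `complexity_tier` and `staged_batches` are hypothetical names.

```python
from collections import defaultdict

TIER_ORDER = ["simple", "moderate", "complex"]


def complexity_tier(tool_count: int, has_recovery: bool) -> str:
    """Map a session's tool usage to the tier named in the table above."""
    if has_recovery or tool_count >= 3:
        return "complex"
    if tool_count >= 1:
        return "moderate"
    return "simple"


def staged_batches(records):
    """Group exported records by tier and return them in curriculum order."""
    by_tier = defaultdict(list)
    for rec in records:
        by_tier[rec["metadata"]["complexity_tier"]].append(rec)
    # Skip tiers that have no records.
    return [(tier, by_tier[tier]) for tier in TIER_ORDER if by_tier[tier]]
```

For a flat mixture run, the same grouping can be sampled with per-tier weights instead of being consumed in order.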

---

## Split Strategy

Sessions are assigned to `train` or `eval` using a **hash of `user_question`** (MD5 mod 100), not a random index. This ensures:
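- Exactly repeated queries always land in the **same split**. Near-duplicates that differ in case, whitespace, or wording hash independently, so they stay together only if the question text is normalized before hashing.
- SFT and DPO exports are built from the same split, so a prompt in `sft_train` or `dpo_train` never appears in `sft_eval` or `dpo_eval`.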
- Splits are **deterministic and reproducible** across pipeline runs.

Default ratio: 80% train / 20% eval (configurable via `--split-ratio`).
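The assignment rule can be sketched in a few lines. The bucket computation mirrors the MD5-mod-100 rule described above; the `normalize` step is an assumption added for illustration (the pipeline hashes the raw question) and shows one way to make trivially different duplicates share a bucket.

```python
import hashlib


def normalize(question: str) -> str:
    """Hypothetical normalization: lowercase and collapse whitespace."""
    return " ".join(question.lower().split())


def assign_split(user_question: str, train_ratio: float = 0.8) -> str:
    """Deterministically map a question to 'train' or 'eval'."""
    digest = hashlib.md5(user_question.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 100          # stable bucket in 0..99
    return "train" if bucket < train_ratio * 100 else "eval"
```

Calling `assign_split(normalize(q))` instead of `assign_split(q)` trades a small change in bucket assignment for robustness to casing and spacing differences. The realized ratio is only approximately 80/20 on small datasets, since it depends on how the hashes happen to fall.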

---

## PII Handling

Detected entity types: `PHONE_NUMBER`, `EMAIL_ADDRESS`, `URL`, `IN_AADHAAR`

- Replaced with indexed placeholders: `<PHONE_NUMBER_1>`, `<EMAIL_ADDRESS_1>`, `<IN_AADHAAR_1>`, etc.
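- Placeholder indices come from **one counter per entity type, shared across all text fields of a session**, so indices never collide within a session. Redaction is deterministic, so the SFT and DPO exports of a session carry identical placeholders. A value that recurs within a session receives a new index at each occurrence.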
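- Each detected span is recorded as an audit entry (`original_span`, `placeholder`, `entity_type`) returned by `redact_pii()` in `pii_redactor.py`. `exporter.py` writes these entries per session to `audit_log.jsonl` in the output directory for human review. This file contains raw PII and is not a training artifact.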
- `PERSON` entity detection is intentionally disabled to avoid false positives on crop names and place names in agricultural context.

**Residual risk:** Indirect identification via rare location + crop combinations is not eliminated by redaction alone. Human audit sampling is required before production export.
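One way to give a repeated value the same placeholder throughout a session is to cache the value-to-placeholder mapping next to the per-type counter. The sketch below uses simplified regexes rather than Presidio, and `SessionRedactor` is a hypothetical class that is not part of `pii_redactor.py`.

```python
import re
from collections import defaultdict

# Simplified patterns for illustration only; emails are handled first so that
# digits inside an address are not matched as a phone number.
_PATTERNS = {
    "EMAIL_ADDRESS": re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"),
    "IN_AADHAAR": re.compile(r"\b\d{4}[- ]?\d{4}[- ]?\d{4}\b"),
    "PHONE_NUMBER": re.compile(r"\b\d{10}\b"),
}


class SessionRedactor:
    """Create one instance per session and reuse it for every text field."""

    def __init__(self):
        self._counts = defaultdict(int)
        self._seen = {}  # (entity_type, value) -> placeholder

    def _placeholder(self, entity_type: str, value: str) -> str:
        key = (entity_type, value)
        if key not in self._seen:
            self._counts[entity_type] += 1
            self._seen[key] = f"<{entity_type}_{self._counts[entity_type]}>"
        return self._seen[key]

    def redact(self, text: str) -> str:
        for entity_type, pattern in _PATTERNS.items():
            text = pattern.sub(
                lambda m, t=entity_type: self._placeholder(t, m.group()), text
            )
        return text
```

With this design, indices follow first appearance in the session, and the cache itself holds raw values, so it should be discarded when the session is done.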

---

## Student Model Criteria

Target: smallest dense model ≤ 32B parameters that achieves parity with the teacher on the held-out `sft_eval.jsonl` set.

Acceptance threshold (to be confirmed with mentor):
- Tool-call accuracy (correct tool name + valid args) ≥ teacher baseline on `sft_eval`
- Response quality score (persona adherence, checked via rule layer for verifiable rules) within 5% of teacher
- End-to-end latency: 4–5 tool calls + final response within ~4 seconds on 8×H100

Filter applied to student training data: exclude sessions where `tool_count > 4` or total token length exceeds the student model's context window.
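The filter described above can be sketched as a predicate over exported SFT records. The token counter is passed in as a callable because the student tokenizer is not fixed here; `keep_for_student`, `count_tokens`, and `max_context_tokens` are hypothetical names, and summing per-message token counts ignores chat-template overhead, so a real run should leave some headroom.

```python
from typing import Callable


def keep_for_student(
    record: dict,
    count_tokens: Callable[[str], int],
    max_context_tokens: int,
    max_tool_count: int = 4,
) -> bool:
    """Return True if the record passes the student-data filter."""
    if record["metadata"]["tool_count"] > max_tool_count:
        return False

    total = 0
    for msg in record["messages"]:
        # Assistant tool-call turns carry content=None, so fall back to "".
        total += count_tokens(msg.get("content") or "")
        for call in msg.get("tool_calls") or []:
            total += count_tokens(call["function"]["arguments"])
    return total <= max_context_tokens
```

For a quick dry run, `count_tokens=lambda s: len(s.split())` gives a rough whitespace count; the student model's own tokenizer should be used for the actual export.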
141 changes: 141 additions & 0 deletions exporter.py
@@ -0,0 +1,141 @@
import json
import os
import hashlib
import argparse
import collections
from pii_redactor import redact_pii
from schemas import LogSession, DPORecord
from segmenter import validate_trajectory, tag_complexity

def split_sessions(sessions, train_ratio=0.8, seed=42):
    train_sessions = []
    eval_sessions = []
    for s in sessions:
        h = int(hashlib.md5(s.user_question.encode('utf-8')).hexdigest(), 16)
        if (h % 100) < (train_ratio * 100):
            train_sessions.append(s)
        else:
            eval_sessions.append(s)
    return train_sessions, eval_sessions

def _redact_session(session: LogSession):
    """Redact all text fields in a session using a single shared counter so placeholders are consistent."""
    counts = collections.defaultdict(int)
    all_audit_entries = []

    def r(text):
        if not text:
            return text
        redacted, audits = redact_pii(text, language=session.language or "en", _counts=counts)
        all_audit_entries.extend(audits)
        return redacted

    messages = [{"role": "user", "content": r(session.user_question)}]
    for turn in session.agent_turns:
        for part in turn.parts:
            if part.part_kind == "tool-call":
                messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [{
                        "type": "function",
                        "function": {
                            "name": part.tool_name or "",
                            "arguments": json.dumps({k: r(str(v)) for k, v in part.args.items()}) if part.args else "{}"
                        },
                        "id": part.tool_call_id or ""
                    }]
                })
            elif part.part_kind == "tool-return":
                messages.append({
                    "role": "tool",
                    "tool_call_id": part.tool_call_id or "",
                    "content": r(part.content)
                })
    messages.append({"role": "assistant", "content": r(session.bot_response)})
    return messages, all_audit_entries


def get_trace_id(session: LogSession) -> str:
    if session.session_id:
        return hashlib.md5(session.session_id.encode('utf-8')).hexdigest()[:12]
    return hashlib.md5(session.user_question.encode('utf-8')).hexdigest()[:12]

def create_sft_export(sessions, output_file, audit_file_path):
    with open(output_file, "w") as f, open(audit_file_path, "a") as af:
        for i, session in enumerate(sessions):
            if not validate_trajectory(session):
                continue
            messages, audit_entries = _redact_session(session)

            trace_id = get_trace_id(session)
            meta = tag_complexity(session)
            meta.trace_id = trace_id

            if audit_entries:
                af.write(json.dumps({"trace_id": trace_id, "findings": audit_entries}) + "\n")

            f.write(json.dumps({"messages": messages, "metadata": meta.model_dump()}) + "\n")

def generate_synthetic_rejection(messages):
    chosen = messages[-1]["content"]
    # Rejection type: persona_violation
    rejected = "Yes, here is the answer. " + chosen
    return rejected, "persona_violation"

def create_dpo_export(sessions, output_file, audit_file_path):
    with open(output_file, "w") as f, open(audit_file_path, "a") as af:
        for i, session in enumerate(sessions):
            if not session.agent_turns or not validate_trajectory(session):
                continue
            messages, audit_entries = _redact_session(session)
            trace_id = get_trace_id(session)

            if audit_entries:
                af.write(json.dumps({"trace_id": trace_id, "findings": audit_entries}) + "\n")

            prompt = [messages[0]]
            chosen = messages[-1]["content"]
            rejected, rejection_type = generate_synthetic_rejection(messages)

            meta = tag_complexity(session)
            meta.trace_id = trace_id
            meta.rejection_type = rejection_type

            dpo_record = DPORecord(
                prompt=prompt, chosen=chosen, rejected=rejected,
                metadata=meta, synthetic=True
            )
            f.write(json.dumps(dpo_record.model_dump()) + "\n")

def process_logs(input_file: str, output_dir: str, split_ratio: float):
    os.makedirs(output_dir, exist_ok=True)
    audit_file_path = os.path.join(output_dir, "audit_log.jsonl")
    # Clear audit log if it exists
    if os.path.exists(audit_file_path):
        os.remove(audit_file_path)

    with open(input_file, 'r') as f:
        data = json.load(f)

    sessions = [LogSession(**item) for item in data]

    train_sessions, eval_sessions = split_sessions(sessions, train_ratio=split_ratio)

    create_sft_export(train_sessions, os.path.join(output_dir, "sft_train.jsonl"), audit_file_path)
    create_sft_export(eval_sessions, os.path.join(output_dir, "sft_eval.jsonl"), audit_file_path)

    create_dpo_export(train_sessions, os.path.join(output_dir, "dpo_train.jsonl"), audit_file_path)
    create_dpo_export(eval_sessions, os.path.join(output_dir, "dpo_eval.jsonl"), audit_file_path)

    print(f"Processed {len(sessions)} sessions, exported to {output_dir}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default="sample_logs.json")
    parser.add_argument("--output", default="output/")
    parser.add_argument("--split-ratio", type=float, default=0.8)
    args = parser.parse_args()

    # Fail loudly instead of silently doing nothing when the input is missing.
    if not os.path.exists(args.input):
        parser.error(f"input file not found: {args.input}")
    process_logs(args.input, args.output, args.split_ratio)
64 changes: 64 additions & 0 deletions pii_redactor.py
@@ -0,0 +1,64 @@
import re
import collections
import spacy

try:
    spacy.util.get_package_path('en_core_web_lg')
except Exception:
    raise ImportError('Missing required spacy model. Please run: setup.sh or python -m spacy download en_core_web_lg')

from presidio_analyzer import AnalyzerEngine, PatternRecognizer, Pattern

analyzer = AnalyzerEngine()

# Indian Aadhaar: 12 digits optionally grouped as XXXX-XXXX-XXXX
_AADHAAR_PATTERN = Pattern(name="aadhaar", regex=r"\b\d{4}[- ]?\d{4}[- ]?\d{4}\b", score=0.85)
analyzer.registry.add_recognizer(
    PatternRecognizer(supported_entity="IN_AADHAAR", patterns=[_AADHAAR_PATTERN])
)

# Entities to detect — excludes PERSON to avoid crop/place false positives
_ENTITIES = ["PHONE_NUMBER", "EMAIL_ADDRESS", "URL", "IN_AADHAAR"]

def redact_pii(text: str, language: str = "en", _counts: dict = None):
    """Redact PII in text. Pass a shared _counts dict per session for consistent placeholders."""
    if not text:
        return text, []

    if _counts is None:
        _counts = collections.defaultdict(int)

    audit_entries = []

    if language != "en":
        # Regex-only fallback for non-English text (URLs are not detected on this path).
        results = []
        for match in re.finditer(r'\b\d{10}\b', text):
            results.append(("PHONE_NUMBER", match.start(), match.end()))
        for match in re.finditer(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}\b', text):
            results.append(("IN_AADHAAR", match.start(), match.end()))
        for match in re.finditer(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', text):
            results.append(("EMAIL_ADDRESS", match.start(), match.end()))
        # Replace from the end of the string so earlier offsets stay valid.
        results = sorted(results, key=lambda x: x[1], reverse=True)

        redacted_text = text
        for r_type, start, end in results:
            entity_value = text[start:end]
            _counts[r_type] += 1
            placeholder = f"<{r_type}_{_counts[r_type]}>"
            audit_entries.append({"original_span": entity_value, "placeholder": placeholder, "entity_type": r_type})
            redacted_text = redacted_text[:start] + placeholder + redacted_text[end:]
        return redacted_text, audit_entries

    results = analyzer.analyze(text=text, entities=_ENTITIES, language="en")
    # Replace from the end of the string so earlier offsets stay valid.
    results = sorted(results, key=lambda x: x.start, reverse=True)
    redacted_text = text

    for r in results:
        entity_value = text[r.start:r.end]
        _counts[r.entity_type] += 1
        placeholder = f"<{r.entity_type}_{_counts[r.entity_type]}>"
        audit_entries.append({"original_span": entity_value, "placeholder": placeholder, "entity_type": r.entity_type})
        redacted_text = redacted_text[:r.start] + placeholder + redacted_text[r.end:]

    return redacted_text, audit_entries
6 changes: 6 additions & 0 deletions requirements.txt
@@ -0,0 +1,6 @@
# Note: the spacy model en_core_web_lg must be downloaded separately.
# Run setup.sh to automatically install these requirements and download the model.
presidio-analyzer
presidio-anonymizer
pydantic>=2
pytest