23 changes: 23 additions & 0 deletions .gitignore
@@ -0,0 +1,23 @@

.venv/
venv/
env/
ENV/

__pycache__/
*.pyc
*.pyo
*.pyd
.pytest_cache/
.coverage
htmlcov/
.mypy_cache/
.ruff_cache/

training_setup_logs/
output/
*.log

.DS_Store
.vscode/
.idea/
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.14
117 changes: 117 additions & 0 deletions DATA_SCHEMA.md
@@ -0,0 +1,117 @@
# Schema & Pipeline Documentation

## Input Log Schema (`LogSession`)

| Field | Type | Description |
|---|---|---|
| `user_question` | string | Raw user query (PII-redacted before export) |
| `bot_response` | string | Final model response (PII-redacted before export) |
| `agent_turns` | list | Ordered tool-call / tool-return turns between question and response |
| `session_id` | string? | Optional session identifier (not exported to training artifacts) |
| `language` | string? | Optional language tag (e.g. `en`, `hi`) |
| `domain` | string? | Optional domain tag (e.g. `agri`, `weather`) |

### `AgentTurn.parts[].part_kind` values

| Value | Meaning |
|---|---|
| `tool-call` | Model requested a tool; has `tool_name`, `args`, `tool_call_id` |
| `tool-return` | Tool responded; has `tool_name`, `content`, `tool_call_id` |

---

## SFT JSONL Schema (`sft_train.jsonl`, `sft_eval.jsonl`)

Each line is a JSON object:

```json
{
  "messages": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": null, "tool_calls": [{"type": "function", "function": {"name": "...", "arguments": "{}"}, "id": "call_1"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "..."},
    {"role": "assistant", "content": "final response"}
  ],
  "metadata": {
    "tool_count": 2,
    "unique_tools": ["weather_forecast", "fetch_agristack_data"],
    "has_recovery": false,
    "complexity_tier": "moderate",
    "is_agentic": true
  }
}
```

- Loads directly with Hugging Face `datasets` and is compatible with TRL `SFTTrainer`.
- Tool calls follow the OpenAI function-calling message format, which the chat templates used by Gemma, Llama, and Qwen accept.
- All `content` fields are PII-redacted before writing.
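
A structural check on exported records can catch malformed lines before training. The sketch below is illustrative only: `check_sft_record` is a hypothetical helper that is not part of this PR, and it assumes exactly the message layout shown in the example above.

```python
import json


def check_sft_record(line: str) -> list[str]:
    """Return a list of problems found in one SFT JSONL line (empty if none)."""
    record = json.loads(line)
    messages = record.get("messages", [])
    problems: list[str] = []

    if not messages or messages[0].get("role") != "user":
        problems.append("first message must have role 'user'")
    if not messages or messages[-1].get("role") != "assistant":
        problems.append("last message must have role 'assistant'")

    # Every tool message must answer a tool call issued earlier in the list.
    issued: set[str] = set()
    for msg in messages:
        for call in msg.get("tool_calls") or []:
            issued.add(call["id"])
        if msg.get("role") == "tool" and msg.get("tool_call_id") not in issued:
            problems.append(f"unmatched tool_call_id: {msg.get('tool_call_id')}")
    return problems


example_line = json.dumps({
    "messages": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": None, "tool_calls": [
            {"type": "function",
             "function": {"name": "weather_forecast", "arguments": "{}"},
             "id": "call_1"},
        ]},
        {"role": "tool", "tool_call_id": "call_1", "content": "..."},
        {"role": "assistant", "content": "final response"},
    ],
    "metadata": {"tool_count": 1},
})
print(check_sft_record(example_line))
```

Running a check like this over every line of `sft_train.jsonl` and `sft_eval.jsonl` is cheap and independent of the trainer.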

---

## DPO JSONL Schema (`dpo_train.jsonl`, `dpo_eval.jsonl`)

Each line is a JSON object:

```json
{
  "prompt": [{"role": "user", "content": "..."}],
  "chosen": [{"role": "assistant", "content": "correct model response"}],
  "rejected": [{"role": "assistant", "content": "I cannot help with that. correct model response"}],
  "metadata": { "...same as SFT metadata..." },
  "synthetic": true
}
```

- `prompt`: conversation prefix up to and including the last user turn.
- `chosen`: PII-redacted actual bot response.
- `rejected`: synthetic negative (prefixed refusal). `"synthetic": true` marks these for filtering once real preference pairs (thumbs, edits, failure/success pairs) are available.
- Compatible with TRL `DPOTrainer` (`prompt`, `chosen`, `rejected` are the required fields).
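
The synthetic pairs described above can be sketched as follows. This is an assumption-laden illustration, not the pipeline's code: `make_synthetic_pair`, `drop_synthetic`, and `REFUSAL_PREFIX` are hypothetical names, and the refusal text is copied from the example record.

```python
# Hypothetical constant; the wording is taken from the example record above.
REFUSAL_PREFIX = "I cannot help with that. "


def make_synthetic_pair(user_question: str, bot_response: str, metadata: dict) -> dict:
    """Build one DPO record whose rejected answer is a prefixed refusal."""
    return {
        "prompt": [{"role": "user", "content": user_question}],
        "chosen": [{"role": "assistant", "content": bot_response}],
        "rejected": [
            {"role": "assistant", "content": REFUSAL_PREFIX + bot_response}
        ],
        "metadata": metadata,
        "synthetic": True,
    }


def drop_synthetic(records: list[dict]) -> list[dict]:
    """Keep only records that are not marked as synthetic negatives."""
    return [r for r in records if not r.get("synthetic", False)]


pair = make_synthetic_pair("When should I sow wheat?", "Sow in early November.", {})
print(pair["rejected"][0]["content"])
```

Once real preference pairs exist, a filter like `drop_synthetic` is the kind of step the `"synthetic"` flag is meant to enable.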

---

## Complexity Tags & Recommended Training Schedule

| `complexity_tier` | Condition | Recommended use |
|---|---|---|
| `simple` | 0 tool calls | Warm-up / instruction-following phase |
| `moderate` | 1–3 tool calls, no recovery | Main SFT bulk training |
| `complex` | 4+ tool calls **or** `has_recovery=true` | Later-stage fine-tuning; up-weight in DPO |

Trainers can filter or sample by `metadata.complexity_tier` for staged curriculum or flat mixture runs without any pipeline changes.
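
As a rough sketch of what filtering by tier might look like on the trainer side (the helper names and `STAGE_ORDER` are hypothetical, and records are assumed to be already-parsed JSONL dictionaries):

```python
from collections import defaultdict

# Hypothetical ordering for a staged curriculum: easiest tier first.
STAGE_ORDER = ["simple", "moderate", "complex"]


def group_by_tier(records: list[dict]) -> dict[str, list[dict]]:
    """Bucket records by their metadata.complexity_tier value."""
    buckets: dict[str, list[dict]] = defaultdict(list)
    for record in records:
        buckets[record["metadata"]["complexity_tier"]].append(record)
    return buckets


def staged_curriculum(records: list[dict]) -> list[dict]:
    """Return records ordered simple -> moderate -> complex."""
    buckets = group_by_tier(records)
    return [r for tier in STAGE_ORDER for r in buckets.get(tier, [])]


records = [
    {"id": "a", "metadata": {"complexity_tier": "complex"}},
    {"id": "b", "metadata": {"complexity_tier": "simple"}},
    {"id": "c", "metadata": {"complexity_tier": "moderate"}},
]
print([r["id"] for r in staged_curriculum(records)])  # ['b', 'c', 'a']
```

A flat mixture run would skip `staged_curriculum` and sample from the buckets returned by `group_by_tier` instead.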

---

## Split Strategy

Sessions are assigned to `train` or `eval` using a **hash of `user_question`** (MD5 mod 100), not a random index. This ensures:
- Repeated queries always land in the **same split** (near-duplicates only if their text is identical after whatever normalization precedes hashing).
- No prompt leakage across exports: a question assigned to `train` appears only in `sft_train` and `dpo_train`, never in `sft_eval` or `dpo_eval`.
- Splits are **deterministic and reproducible** across pipeline runs.

Default ratio: 80% train / 20% eval (configurable via `--split-ratio`).
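
A minimal sketch of a hash-based assignment along these lines is shown below. The function name `assign_split`, the `train_percent` parameter, the UTF-8 encoding, and the "bucket below threshold means train" rule are assumptions for illustration; the pipeline's actual implementation is not shown in this document.

```python
import hashlib


def assign_split(user_question: str, train_percent: int = 80) -> str:
    """Assign a session to 'train' or 'eval' from an MD5 hash of its question."""
    digest = hashlib.md5(user_question.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 100  # stable bucket in [0, 99]
    return "train" if bucket < train_percent else "eval"


question = "What is the weather forecast for Pune this week?"
# The same text always maps to the same split, run after run.
print(assign_split(question) == assign_split(question))  # True
```

Because the bucket depends only on the question text, the SFT and DPO exporters can each call the same function and agree on the split without sharing state.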

---

## PII Handling

Detected entity types: `PHONE_NUMBER`, `EMAIL_ADDRESS`, `URL`, `IN_AADHAAR`

- Replaced with indexed placeholders: `<PHONE_NUMBER_1>`, `<EMAIL_ADDRESS_1>`, `<IN_AADHAAR_1>`, etc.
- Placeholders are **consistent within a session** (same entity → same placeholder across SFT and DPO exports for that session).
- Raw PII is logged to an in-memory audit list accessible via `audit_sample()` in `anonymizer.py` for human review of false negatives.
- `PERSON` entity detection is intentionally disabled to avoid false positives on crop names and place names in agricultural context.

**Residual risk:** Indirect identification via rare location + crop combinations is not eliminated by redaction alone. Human audit sampling is required before production export.
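
The indexed, session-consistent placeholders described above can be sketched as follows. This is a simplified stand-in, not the contents of `anonymizer.py`: the regular expressions, the `SessionRedactor` class, and its method names are all assumptions made for illustration, and only two of the four entity types are covered.

```python
import re

# Simplified stand-in patterns (assumption); real detection is likely stricter.
PATTERNS = {
    "EMAIL_ADDRESS": re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"),
    "PHONE_NUMBER": re.compile(r"\b\d{10}\b"),
}


class SessionRedactor:
    """Replace detected entities with placeholders that stay stable per session."""

    def __init__(self) -> None:
        self._placeholders: dict[tuple[str, str], str] = {}
        self._counters: dict[str, int] = {}

    def _placeholder(self, entity_type: str, value: str) -> str:
        key = (entity_type, value)
        if key not in self._placeholders:
            # First sighting of this value: allocate the next index for its type.
            self._counters[entity_type] = self._counters.get(entity_type, 0) + 1
            index = self._counters[entity_type]
            self._placeholders[key] = f"<{entity_type}_{index}>"
        return self._placeholders[key]

    def redact(self, text: str) -> str:
        for entity_type, pattern in PATTERNS.items():
            text = pattern.sub(
                lambda m, t=entity_type: self._placeholder(t, m.group(0)), text
            )
        return text


redactor = SessionRedactor()  # one instance per session
print(redactor.redact("Call 9876543210 or mail a@b.com"))
# Call <PHONE_NUMBER_1> or mail <EMAIL_ADDRESS_1>
print(redactor.redact("Again 9876543210"))
# Again <PHONE_NUMBER_1>
```

Reusing one redactor instance for both the SFT and DPO export of a session is what keeps the same entity mapped to the same placeholder.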

---

## Student Model Criteria

Target: smallest dense model ≤ 32B parameters that achieves parity with the teacher on the held-out `sft_eval.jsonl` set.

Acceptance threshold (to be confirmed with mentor):
- Tool-call accuracy (correct tool name + valid args) ≥ teacher baseline on `sft_eval`
- Response quality score (persona adherence, checked via rule layer for verifiable rules) within 5% of teacher
- End-to-end latency: 4–5 tool calls + final response within ~4 seconds on 8×H100

Filter applied to student training data: exclude sessions where `tool_count > 4` or total token length exceeds the student model's context window.
148 changes: 148 additions & 0 deletions analyzer.py
@@ -0,0 +1,148 @@
"""Trajectory validation and complexity tagging for agent sessions."""
from __future__ import annotations

from types_def import LogSession, TrainingMetadata

# ── Tool name allowlist ───────────────────────────────────────────────────────
# Set to None to skip allowlist checking.
ALLOWED_TOOL_NAMES: set[str] | None = {
    "weather_forecast",
    "fetch_agristack_data",
    "crop_recommendation",
    "subsidy_lookup",
    "market_price",
    "pest_advisory",
    "soil_report",
    "irrigation_schedule",
    # Add further tool names here as needed.
}

# Words that signal an ambiguous or under-specified question
_AMBIGUITY_TOKENS: set[str] = {
    # English
    "what", "how", "when", "why", "which", "where", "any", "some",
    "maybe", "possibly", "sometimes", "generally",
    # Hindi / transliterated
    "kya", "kab", "kyun", "kaise", "kaun", "kuch", "koi",
    # Marathi
    "kay", "keva", "kasa",
}


def validate_trajectory(session: LogSession) -> bool:
    """Return True if the agent trajectory is structurally valid.

    Rules
    -----
    - Sessions with no agent turns are rejected.
    - tool-call parts must have a non-empty ``tool_name``.
    - If ``ALLOWED_TOOL_NAMES`` is set, ``tool_name`` must be in the allowlist.
    - tool-return parts that carry a ``tool_call_id`` must match a tool-call
      with the same id appearing earlier in the sequence.
    - tool-return with ``content = None`` is rejected; ``content = ""`` is
      accepted (explicit empty return is a valid signal).
    """
    if not session.agent_turns:
        # Reject empty or malformed sessions: there is no trajectory to validate.
        return False

    tool_calls_seen: set[str] = set()

    for turn in session.agent_turns:
        for part in turn.parts:
            if part.part_kind == "tool-call":
                if not part.tool_name:
                    return False
                if (
                    ALLOWED_TOOL_NAMES is not None
                    and part.tool_name not in ALLOWED_TOOL_NAMES
                ):
                    return False
                if part.tool_call_id:
                    tool_calls_seen.add(part.tool_call_id)

            elif part.part_kind == "tool-return":
                # content=None means the field was never set: reject.
                # content="" means the tool explicitly returned nothing: accept.
                if part.content is None:
                    return False
                if part.tool_call_id and part.tool_call_id not in tool_calls_seen:
                    return False

    return True


def _ambiguity_score(text: str) -> float:
    """Heuristic ambiguity score in [0, 1].

    Counts the fraction of tokens that are in :data:`_AMBIGUITY_TOKENS`.
    """
    if not text:
        return 0.0
    tokens = text.lower().split()
    if not tokens:
        return 0.0
    hits = sum(1 for t in tokens if t.strip("?.,!") in _AMBIGUITY_TOKENS)
    return round(min(hits / len(tokens), 1.0), 3)


def tag_complexity(session: LogSession, max_tokens: int = 8192) -> TrainingMetadata:
    """Compute and return :class:`TrainingMetadata` for *session*.

    Complexity tiers
    ----------------
    ``simple``
        0 tool calls: instruction-following warm-up data.
    ``moderate``
        1–3 tool calls with no error recovery: main SFT bulk.
    ``complex``
        4+ tool calls **or** any recovery step: late-stage / DPO up-weight.
    """
    num_tools = 0
    unique_tools: set[str] = set()
    has_recovery = False
    total_tokens = 0

    for turn in session.agent_turns:
        if turn.usage:
            total_tokens += (
                turn.usage.get("input_tokens", 0) + turn.usage.get("output_tokens", 0)
            )

        for p in turn.parts:
            if p.part_kind == "tool-call":
                num_tools += 1
                if p.tool_name:
                    unique_tools.add(p.tool_name)
            if p.part_kind == "tool-return" and p.content:
                # Heuristic: any error-like tool return marks the session as
                # containing a recovery step.
                content_lower = str(p.content).lower()
                if any(
                    err in content_lower
                    for err in ("error", "timeout", "failed", "no results")
                ):
                    has_recovery = True

    # Tier thresholds (see docstring)
    if num_tools == 0:
        complexity = "simple"
    elif num_tools <= 3 and not has_recovery:
        complexity = "moderate"
    else:
        complexity = "complex"

    # Student filter: at most 4 tool calls and within the token budget.
    student_eligible = not (num_tools > 4 or total_tokens > max_tokens)

    # multi_turn_depth: number of distinct user turns (always ≥1)
    # In the current schema every session is a single user question; this
    # counter is a placeholder for future multi-turn schema support.
    multi_turn_depth = 1

    return TrainingMetadata(
        tool_count=num_tools,
        # Sorted so the exported metadata is identical across runs.
        unique_tools=sorted(unique_tools),
        has_recovery=has_recovery,
        complexity_tier=complexity,
        is_agentic=num_tools > 0,
        total_tokens=total_tokens,
        student_eligible=student_eligible,
        multi_turn_depth=multi_turn_depth,
        ambiguity_score=_ambiguity_score(session.user_question),
    )