diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a882411 --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ +# Virtual environments +.venv/ +venv/ +env/ + +# Python cache +__pycache__/ +*.py[cod] + +# Pytest +.pytest_cache/ + +# Coverage +.coverage +htmlcov/ + +# Jupyter +.ipynb_checkpoints/ + +# OS files +.DS_Store +Thumbs.db + +# IDE/editor +.vscode/ +.idea/ + +# Build artifacts +build/ +dist/ +*.egg-info/ + +# Local test outputs +test_results.txt + +# Temporary files +*.tmp +*.log \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 0000000..3e6205c --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,235 @@ +""" +fixtures.py +----------- +Realistic test fixtures for multilingual and trajectory tests. +""" + +from __future__ import annotations + +from typing import List + +from training_setup_logs.trajectory.models import ( + MessageRole, + ToolCall, + ToolCallStatus, + Trajectory, + TurnMessage, +) + + +# --------------------------------------------------------------------------- +# Multilingual query fixtures +# --------------------------------------------------------------------------- + + +WEATHER_QUERIES_MULTILINGUAL: List[str] = [ + # Devanagari + "कल मौसम कैसा रहेगा", + # Transliterated variants + "kal mausam kaisa rahega", + "kal mosam kaisa rahega", + "kal mousam kaisa hoga", + # Code-switched + "kal weather kaisa hai", + "मौसम tomorrow kaisa hoga", + # English + "what will the weather be like tomorrow", + # Unrelated (should NOT cluster with weather queries) + "मुझे एक अच्छा रेस्टोरेंट बताओ", + "recommend a restaurant near me", + "find me a good book to read", +] + +HINDI_QUERIES: List[str] = [ + "आज का तापमान क्या है", + "कल बारिश होगी क्या", + "मुझे दिल्ली का मौसम बताओ", +] + +TRANSLITERATED_QUERIES: List[str] = [ + "aaj ka tapmaan kya hai", + "kal barish hogi kya", + "mujhe delhi ka mausam batao", +] + +CODE_SWITCHED_QUERIES: List[str] = [ + "aaj temperature kitna hai", + "kal rain hogi kya Delhi mein", + "मुझे weather update do", +] + +ENGLISH_QUERIES: List[str] = [ + "what is the temperature today", + "will it rain tomorrow", + "give me the weather update for Delhi", +] + + +# --------------------------------------------------------------------------- +# Trajectory fixtures +# --------------------------------------------------------------------------- + + +def make_clean_trajectory(tid: str = "traj_clean_001") -> Trajectory: + """A clean, efficient, successful trajectory.""" + return Trajectory( + trajectory_id=tid, + turns=[ + TurnMessage(role=MessageRole.USER, content="What is the weather in Delhi tomorrow?", language="en"), + TurnMessage(role=MessageRole.ASSISTANT, content="Let me check the weather for you.", language="en"), + TurnMessage(role=MessageRole.ASSISTANT, content="Tomorrow in Delhi: 32°C, partly cloudy.", language="en"), + ], + tool_calls=[ + ToolCall( + tool_name="weather_api", + arguments={"city": "Delhi", "date": "tomorrow"}, + return_value={"temp": 32, "condition": "partly cloudy"}, + status=ToolCallStatus.SUCCESS, + latency_ms=230.0, + ) + ], + ) + + +def make_retry_trajectory(tid: str = "traj_retry_001") -> Trajectory: + """Trajectory with a retry that eventually succeeds.""" + return Trajectory( + trajectory_id=tid, + turns=[ + TurnMessage(role=MessageRole.USER, content="kal mausam kaisa rahega", language="hi-latn"), + TurnMessage(role=MessageRole.ASSISTANT, content="Let me check. One moment.", language="en"), + TurnMessage(role=MessageRole.ASSISTANT, content="kal Delhi mein 30°C hoga.", language="hi-latn"), + ], + tool_calls=[ + ToolCall( + tool_name="weather_api", + arguments={"city": "Delhi", "date": "tomorrow"}, + return_value=None, + status=ToolCallStatus.FAILURE, + latency_ms=40.0, + ), + ToolCall( + tool_name="weather_api", + arguments={"city": "Delhi", "date": "tomorrow"}, + return_value={"temp": 30, "condition": "sunny"}, + status=ToolCallStatus.SUCCESS, + latency_ms=280.0, + retry_of=0, + ), + ], + ) + + +def make_redundant_trajectory(tid: str = "traj_redundant_001") -> Trajectory: + """Trajectory with redundant tool calls.""" + return Trajectory( + trajectory_id=tid, + turns=[ + TurnMessage(role=MessageRole.USER, content="Weather in Mumbai?", language="en"), + TurnMessage(role=MessageRole.ASSISTANT, content="Mumbai: 28°C, humid.", language="en"), + ], + tool_calls=[ + ToolCall( + tool_name="weather_api", + arguments={"city": "Mumbai"}, + return_value={"temp": 28}, + status=ToolCallStatus.SUCCESS, + latency_ms=200.0, + ), + # Identical call — redundant + ToolCall( + tool_name="weather_api", + arguments={"city": "Mumbai"}, + return_value={"temp": 28}, + status=ToolCallStatus.SUCCESS, + latency_ms=190.0, + ), + ], + ) + + +def make_incomplete_trajectory(tid: str = "traj_incomplete_001") -> Trajectory: + """Trajectory ending in a user turn (no assistant response).""" + return Trajectory( + trajectory_id=tid, + turns=[ + TurnMessage(role=MessageRole.USER, content="कल बारिश होगी क्या?", language="hi"), + TurnMessage(role=MessageRole.ASSISTANT, content="Let me check...", language="en"), + TurnMessage(role=MessageRole.USER, content="जल्दी बताओ", language="hi"), # last = user + ], + tool_calls=[ + ToolCall( + tool_name="weather_api", + arguments={"city": "unknown"}, + return_value=None, + status=ToolCallStatus.MISSING_RETURN, + latency_ms=5000.0, + ) + ], + ) + + +def make_multilingual_recovery_trajectory(tid: str = "traj_ml_recovery_001") -> Trajectory: + """ + Hard multilingual trajectory with: + - code-switched user query + - tool failure + - fallback tool success + - clarification loop + """ + return Trajectory( + trajectory_id=tid, + turns=[ + TurnMessage(role=MessageRole.USER, content="kal rain hogi kya Delhi mein?", language="hi-en-mixed"), + TurnMessage(role=MessageRole.ASSISTANT, content="Which area of Delhi?", language="en"), + TurnMessage(role=MessageRole.USER, content="South Delhi", language="en"), + TurnMessage(role=MessageRole.ASSISTANT, content="South Delhi: moderate rain expected.", language="en"), + ], + tool_calls=[ + ToolCall( + tool_name="rainfall_api", + arguments={"city": "Delhi"}, + return_value=None, + status=ToolCallStatus.FAILURE, + latency_ms=45.0, + ), + ToolCall( + tool_name="weather_api", + arguments={"city": "South Delhi", "date": "tomorrow"}, + return_value={"rain_prob": 0.75}, + status=ToolCallStatus.SUCCESS, + latency_ms=310.0, + is_fallback=True, + ), + ], + ) + + +def make_hallucinated_args_trajectory(tid: str = "traj_halluc_001") -> Trajectory: + """Trajectory where tool args contain placeholder/hallucinated values.""" + return Trajectory( + trajectory_id=tid, + turns=[ + TurnMessage(role=MessageRole.USER, content="Book me a flight to Goa", language="en"), + TurnMessage(role=MessageRole.ASSISTANT, content="I've booked a flight for you.", language="en"), + ], + tool_calls=[ + ToolCall( + tool_name="flight_booking", + arguments={"destination": "Goa", "departure": "", "passenger": "TODO"}, + return_value=None, + status=ToolCallStatus.HALLUCINATED, + latency_ms=120.0, + ) + ], + ) + + +ALL_TRAJECTORIES = [ + make_clean_trajectory(), + make_retry_trajectory(), + make_redundant_trajectory(), + make_incomplete_trajectory(), + make_multilingual_recovery_trajectory(), + make_hallucinated_args_trajectory(), +] \ No newline at end of file diff --git a/tests/test_leakage.py b/tests/test_leakage.py new file mode 100644 index 0000000..23969bb --- /dev/null +++ b/tests/test_leakage.py @@ -0,0 +1,77 @@ +""" +test_leakage.py +--------------- +Tests for train/eval split leakage detection. +""" + +import pytest + +from training_setup_logs.multilingual.leakage_detector import ( + LeakageReport, + detect_leakage, +) + + +class TestLeakageDetector: + def test_no_leakage_on_distinct_splits(self): + train = [ + "how do I apply for a passport", + "best restaurants in Kolkata", + "Python list comprehension tutorial", + ] + eval_ = [ + "how to file income tax return", + "train schedule from Delhi to Mumbai", + "machine learning overfitting explained", + ] + report = detect_leakage(train, eval_, skip_semantic=True) + assert report.total_leaks == 0 + assert report.leak_rate == 0.0 + + def test_exact_leak_detected(self): + train = ["कल मौसम कैसा रहेगा", "best hotel in Goa"] + eval_ = ["कल मौसम कैसा रहेगा", "completely different query"] + report = detect_leakage(train, eval_, skip_semantic=True) + assert len(report.exact_leaks) >= 1 + + def test_transliteration_leak_detected(self): + train = ["kal mausam kaisa rahega"] + eval_ = ["kal mosam kaisa rahega"] # spelling variant + report = detect_leakage(train, eval_, skip_semantic=True) + # After canonicalization these should match + assert report.total_leaks >= 1 + + def test_empty_splits(self): + report = detect_leakage([], []) + assert report.total_leaks == 0 + + def test_empty_train(self): + report = detect_leakage([], ["some query"]) + assert report.total_leaks == 0 + + def test_report_structure(self): + train = ["hello world"] + eval_ = ["hello world"] + report = detect_leakage(train, eval_, skip_semantic=True) + d = report.to_dict() + assert "train_size" in d + assert "eval_size" in d + assert "leak_rate" in d + assert "cross_split_leaks" in d + assert isinstance(d["cross_split_leaks"], list) + + def test_leak_rate_bounded(self): + train = ["q1", "q2", "q3"] + eval_ = ["q1", "q2", "q4"] + report = detect_leakage(train, eval_, skip_semantic=True) + assert 0.0 <= report.leak_rate <= 1.0 + + def test_full_leak_detected_semantic(self): + """ + Test the semantic (embedding-based) detection path. + Uses near-identical text that TF-IDF char-ngram similarity will also catch. + """ + train = ["delhi weather tomorrow forecast rain"] + eval_ = ["delhi weather tomorrow forecast rain sunny"] # superset + report = detect_leakage(train, eval_, semantic_threshold=0.50, skip_semantic=False) + assert report.total_leaks >= 1 \ No newline at end of file diff --git a/tests/test_multilingual.py b/tests/test_multilingual.py new file mode 100644 index 0000000..a8e5ba8 --- /dev/null +++ b/tests/test_multilingual.py @@ -0,0 +1,173 @@ +""" +test_multilingual.py +-------------------- +Tests for indic_normalizer, transliteration, and code-switch handling. +""" + +import pytest + +from training_setup_logs.multilingual.indic_normalizer import ( + is_devanagari, + normalize_devanagari, + normalize_devanagari_numerals, + normalize_indic_text, + normalize_matras, + normalize_unicode, + normalize_whitespace, + _detect_script, +) +from training_setup_logs.multilingual.transliteration import ( + canonicalize_transliteration, + detect_transliterated_hindi, + transliteration_key, +) + + +# --------------------------------------------------------------------------- +# indic_normalizer tests +# --------------------------------------------------------------------------- + + +class TestUnicodeNormalization: + def test_nfc_normalization(self): + # Composed vs decomposed 'का' + decomposed = "क" + "\u093e" # ka + aa-matra (decomposed) + composed = "का" + assert normalize_unicode(decomposed) == normalize_unicode(composed) + + def test_returns_string(self): + assert isinstance(normalize_unicode("hello"), str) + + def test_empty_string(self): + assert normalize_unicode("") == "" + + +class TestNumeralNormalization: + def test_devanagari_to_ascii(self): + assert normalize_devanagari_numerals("१२३") == "123" + + def test_mixed_numerals(self): + assert normalize_devanagari_numerals("तापमान ३२°C") == "तापमान 32°C" + + def test_pure_ascii_unchanged(self): + assert normalize_devanagari_numerals("123") == "123" + + +class TestWhitespaceNormalization: + def test_collapses_spaces(self): + assert normalize_whitespace("hello world") == "hello world" + + def test_strips_edges(self): + assert normalize_whitespace(" hi ") == "hi" + + def test_zero_width_chars_removed(self): + text = "hello\u200bworld" + assert normalize_whitespace(text) == "helloworld" + + def test_newlines_collapsed(self): + assert normalize_whitespace("hello\n\nworld") == "hello world" + + +class TestScriptDetection: + def test_devanagari_detected(self): + assert _detect_script("कल मौसम कैसा रहेगा") == "devanagari" + + def test_latin_detected(self): + assert _detect_script("kal mausam kaisa") == "latin" + + def test_mixed_detected(self): + assert _detect_script("मौसम tomorrow") == "mixed" + + +class TestIsDevanagari: + def test_pure_devanagari(self): + assert is_devanagari("नमस्ते") + + def test_pure_latin_false(self): + assert not is_devanagari("hello") + + def test_mixed_true(self): + assert is_devanagari("hello नमस्ते") + + +class TestNormalizeDevanagari: + def test_lowercases_result(self): + result = normalize_devanagari("कल मौसम") + assert result == result.lower() + + def test_numerals_converted(self): + result = normalize_devanagari("तापमान ३५") + assert "35" in result + + def test_whitespace_cleaned(self): + result = normalize_devanagari("कल मौसम") + assert " " not in result + + +class TestNormalizeIndicText: + def test_devanagari_input(self): + text = "कल मौसम कैसा रहेगा" + result = normalize_indic_text(text) + assert isinstance(result, str) + assert len(result) > 0 + + def test_latin_input_lowercased(self): + result = normalize_indic_text("Hello World") + assert result == "hello world" + + def test_mixed_input(self): + result = normalize_indic_text("मौसम tomorrow") + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# transliteration tests +# --------------------------------------------------------------------------- + + +class TestDetectTransliteratedHindi: + def test_known_token_detected(self): + assert detect_transliterated_hindi("kal mausam kaisa rahega") + + def test_english_not_detected(self): + assert not detect_transliterated_hindi("the weather is nice today") + + def test_empty_not_detected(self): + assert not detect_transliterated_hindi("") + + def test_code_switch_detected(self): + assert detect_transliterated_hindi("kal weather kaisa hai") + + +class TestCanonicalizeTransliteration: + def test_mausam_variants_normalize(self): + assert canonicalize_transliteration("mosam") == canonicalize_transliteration("mausam") + assert canonicalize_transliteration("mousam") == canonicalize_transliteration("mausam") + + def test_negation_normalizes(self): + assert canonicalize_transliteration("nahin") == canonicalize_transliteration("nahi") + + def test_lowercase_applied(self): + result = canonicalize_transliteration("MAUSAM") + assert result == result.lower() + + def test_repeated_chars_collapsed(self): + result = canonicalize_transliteration("achhhha") + assert "hhh" not in result + + def test_idempotent(self): + text = "kal mausam kaisa hai" + assert canonicalize_transliteration(text) == canonicalize_transliteration( + canonicalize_transliteration(text) + ) + + +class TestTransliterationKey: + def test_variants_produce_same_key(self): + key1 = transliteration_key("mosam") + key2 = transliteration_key("mausam") + assert key1 == key2 + + def test_devanagari_hint(self): + key = transliteration_key("कल मौसम", script_hint="devanagari") + assert isinstance(key, str) \ No newline at end of file diff --git a/tests/test_semantic_dedup.py b/tests/test_semantic_dedup.py new file mode 100644 index 0000000..6e99069 --- /dev/null +++ b/tests/test_semantic_dedup.py @@ -0,0 +1,97 @@ +""" +test_semantic_dedup.py +---------------------- +Tests for semantic deduplication (with TF-IDF fallback when +sentence-transformers is unavailable). +""" + +import pytest + +from training_setup_logs.multilingual.semantic_dedup import ( + DeduplicationResult, + deduplicate, + query_fingerprint, +) +from tests.fixtures import WEATHER_QUERIES_MULTILINGUAL + + +class TestQueryFingerprint: + def test_returns_string(self): + fp = query_fingerprint("कल मौसम कैसा रहेगा") + assert isinstance(fp, str) + assert len(fp) == 16 + + def test_same_text_same_fingerprint(self): + assert query_fingerprint("hello world") == query_fingerprint("hello world") + + def test_different_texts_different_fingerprint(self): + fp1 = query_fingerprint("कल मौसम कैसा रहेगा") + fp2 = query_fingerprint("मुझे एक रेस्टोरेंट बताओ") + assert fp1 != fp2 + + def test_transliteration_variants_similar_fingerprint(self): + # After normalization, "mosam" and "mausam" should have same fingerprint + fp1 = query_fingerprint("kal mausam kaisa rahega") + fp2 = query_fingerprint("kal mosam kaisa rahega") + # They should be equal after canonicalization + assert fp1 == fp2 + + +class TestDeduplicate: + def test_empty_list(self): + result = deduplicate([]) + assert result.total_inputs == 0 + assert result.unique_indices == [] + + def test_single_item(self): + result = deduplicate(["hello"]) + assert result.total_inputs == 1 + assert result.unique_indices == [0] + assert result.dedup_ratio == 0.0 + + def test_exact_duplicates_removed(self): + queries = ["hello world", "hello world", "hello world"] + result = deduplicate(queries, threshold=0.99) + assert len(result.unique_indices) == 1 + assert result.dedup_ratio > 0 + + def test_distinct_queries_all_kept(self): + queries = [ + "what is the weather today", + "मुझे एक रेस्टोरेंट बताओ", + "how do I book a flight", + ] + result = deduplicate(queries, threshold=0.90) + # These are semantically very different — should all survive + assert len(result.unique_indices) >= 2 # at least most kept + + def test_result_structure(self): + queries = ["hello", "world", "hello"] + result = deduplicate(queries, threshold=0.99) + assert isinstance(result, DeduplicationResult) + assert isinstance(result.unique_indices, list) + assert isinstance(result.cluster_map, dict) + assert isinstance(result.similarity_pairs, list) + assert 0.0 <= result.dedup_ratio <= 1.0 + + def test_cluster_map_coverage(self): + queries = ["foo", "bar", "baz"] + result = deduplicate(queries) + # Every index should be in cluster_map + for i in range(len(queries)): + assert i in result.cluster_map + + def test_weather_queries_multilingual(self): + """ + The four weather query variants (Devanagari, transliterated, code-switched, + English) should cluster with at least some deduplication happening. + """ + result = deduplicate(WEATHER_QUERIES_MULTILINGUAL, threshold=0.80) + assert result.total_inputs == len(WEATHER_QUERIES_MULTILINGUAL) + # Should find SOME duplicates among the weather variants + assert result.num_clusters < result.total_inputs + + def test_num_clusters_property(self): + result = deduplicate(["a", "b", "c"]) + assert isinstance(result.num_clusters, int) + assert result.num_clusters >= 1 \ No newline at end of file diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py new file mode 100644 index 0000000..8a346e3 --- /dev/null +++ b/tests/test_trajectory.py @@ -0,0 +1,265 @@ +""" +test_trajectory.py +------------------ +Tests for failure classification, recovery mining, difficulty scoring, +hard-example tagging, and the full analyzer pipeline. +""" + +import pytest + +from training_setup_logs.trajectory.analyzer import analyze_trajectory, batch_summary +from training_setup_logs.trajectory.difficulty import classify_difficulty, compute_difficulty_score +from training_setup_logs.trajectory.failure_classifier import classify_failures +from training_setup_logs.trajectory.hard_example_miner import is_hard_example, tag_hard_examples +from training_setup_logs.trajectory.models import ( + DifficultyLevel, + FailureType, + HardExampleTag, + ToolCallStatus, +) +from training_setup_logs.trajectory.recovery_patterns import ( + detect_clarification_loop, + detect_retry_patterns, + detect_fallback_patterns, + detect_self_correction, + mine_recovery_patterns, +) +from training_setup_logs.trajectory.metadata_enrichment import ( + compute_efficiency_score, + detect_trajectory_language, +) + +from tests.fixtures import ( + make_clean_trajectory, + make_retry_trajectory, + make_redundant_trajectory, + make_incomplete_trajectory, + make_multilingual_recovery_trajectory, + make_hallucinated_args_trajectory, + ALL_TRAJECTORIES, +) + + +# --------------------------------------------------------------------------- +# Failure classification +# --------------------------------------------------------------------------- + +class TestFailureClassifier: + def test_clean_trajectory_no_failures(self): + traj = make_clean_trajectory() + failures, severity, _ = classify_failures(traj) + assert len(failures) == 0 + assert severity == "none" + + def test_incomplete_trajectory_detected(self): + traj = make_incomplete_trajectory() + failures, severity, _ = classify_failures(traj) + assert FailureType.INCOMPLETE_TRAJECTORY in failures + + def test_redundant_tool_detected(self): + traj = make_redundant_trajectory() + failures, severity, _ = classify_failures(traj) + assert FailureType.REDUNDANT_TOOL in failures + + def test_hallucinated_args_detected(self): + traj = make_hallucinated_args_trajectory() + failures, severity, _ = classify_failures(traj) + assert FailureType.HALLUCINATED_ARGS in failures + + def test_repair_candidate_for_fixable_failure(self): + traj = make_redundant_trajectory() + failures, severity, repair = classify_failures(traj) + assert repair is True + + def test_no_repair_for_incomplete(self): + traj = make_incomplete_trajectory() + failures, severity, repair = classify_failures(traj) + assert repair is False # incomplete = not repair worthy + + def test_severity_high_for_multiple_failures(self): + traj = make_hallucinated_args_trajectory() + # Inject extra failures by making it also incomplete + traj.turns = [] # no turns → incomplete too + failures, severity, _ = classify_failures(traj) + assert severity in ("medium", "high") + + +# --------------------------------------------------------------------------- +# Recovery pattern mining +# --------------------------------------------------------------------------- + +class TestRecoveryPatterns: + def test_retry_success_detected(self): + traj = make_retry_trajectory() + patterns, success = mine_recovery_patterns(traj) + assert "retry_success" in patterns + assert success is True + + def test_fallback_success_detected(self): + traj = make_multilingual_recovery_trajectory() + patterns, success = mine_recovery_patterns(traj) + assert "fallback_tool_success" in patterns + assert success is True + + def test_clarification_loop_detected(self): + traj = make_multilingual_recovery_trajectory() + assert detect_clarification_loop(traj) is True + + def test_no_recovery_in_clean_trajectory(self): + traj = make_clean_trajectory() + patterns, success = mine_recovery_patterns(traj) + assert len(patterns) == 0 + assert success is False + + def test_retry_failure_pattern(self): + from training_setup_logs.trajectory.models import ToolCall, ToolCallStatus + tool_calls = [ + ToolCall(tool_name="foo", arguments={}, status=ToolCallStatus.FAILURE), + ToolCall(tool_name="foo", arguments={}, status=ToolCallStatus.FAILURE, retry_of=0), + ] + patterns = detect_retry_patterns(tool_calls) + assert "retry_failure" in patterns + + +# --------------------------------------------------------------------------- +# Difficulty classification +# --------------------------------------------------------------------------- + +class TestDifficulty: + def test_clean_simple_trajectory(self): + traj = make_clean_trajectory() + assert classify_difficulty(traj) == DifficultyLevel.SIMPLE + + def test_multilingual_recovery_is_hard(self): + traj = make_multilingual_recovery_trajectory() + diff = classify_difficulty(traj) + assert diff in (DifficultyLevel.MODERATE, DifficultyLevel.HARD) + + def test_score_is_float_in_range(self): + for traj in ALL_TRAJECTORIES: + score = compute_difficulty_score(traj) + assert 0.0 <= score <= 1.0, f"Score {score} out of range for {traj.trajectory_id}" + + def test_score_increases_with_complexity(self): + clean_score = compute_difficulty_score(make_clean_trajectory()) + complex_score = compute_difficulty_score(make_multilingual_recovery_trajectory()) + assert complex_score > clean_score + + +# --------------------------------------------------------------------------- +# Tool efficiency +# --------------------------------------------------------------------------- + +class TestEfficiency: + def test_clean_trajectory_efficient(self): + traj = make_clean_trajectory() + score = compute_efficiency_score(traj.tool_calls) + assert score == 1.0 + + def test_redundant_calls_reduce_efficiency(self): + traj = make_redundant_trajectory() + score = compute_efficiency_score(traj.tool_calls) + assert score < 1.0 + + def test_empty_tool_calls_score_one(self): + assert compute_efficiency_score([]) == 1.0 + + +# --------------------------------------------------------------------------- +# Hard example mining +# --------------------------------------------------------------------------- + +class TestHardExampleMiner: + def test_multilingual_recovery_tagged(self): + traj = make_multilingual_recovery_trajectory() + analyzed = analyze_trajectory(traj) + assert analyzed.analysis is not None + tags = analyzed.analysis.hard_example_tags + assert any(t in tags for t in [ + HardExampleTag.MULTILINGUAL_EDGE, + HardExampleTag.RECOVERY_HEAVY, + ]) + + def test_clean_trajectory_sft_worthy(self): + traj = make_clean_trajectory() + analyzed = analyze_trajectory(traj) + # Clean + efficient + non-trivial → SFT_WORTHY + tags = analyzed.analysis.hard_example_tags + # May be SFT_WORTHY or empty depending on difficulty threshold + assert isinstance(tags, list) + + def test_hallucinated_trajectory_repair_worthy(self): + traj = make_hallucinated_args_trajectory() + analyzed = analyze_trajectory(traj) + tags = analyzed.analysis.hard_example_tags + assert HardExampleTag.REPAIR_WORTHY in tags + + def test_is_hard_example_logic(self): + assert is_hard_example([HardExampleTag.RECOVERY_HEAVY]) is True + assert is_hard_example([HardExampleTag.MULTILINGUAL_EDGE]) is True + assert is_hard_example([]) is False + + +# --------------------------------------------------------------------------- +# Full analyzer pipeline +# --------------------------------------------------------------------------- + +class TestAnalyzer: + def test_all_trajectories_analyzed(self): + for traj in ALL_TRAJECTORIES: + analyzed = analyze_trajectory(traj) + assert analyzed.analysis is not None + + def test_analysis_fields_populated(self): + traj = make_multilingual_recovery_trajectory() + analyzed = analyze_trajectory(traj) + a = analyzed.analysis + assert a.tool_count >= 0 + assert isinstance(a.efficiency_score, float) + assert isinstance(a.hard_example, bool) + assert isinstance(a.failure_types, list) + assert isinstance(a.recovery_patterns, list) + + def test_batch_summary(self): + analyzed = [analyze_trajectory(t) for t in ALL_TRAJECTORIES] + summary = batch_summary(analyzed) + assert summary["total_trajectories"] == len(ALL_TRAJECTORIES) + assert "difficulty_distribution" in summary + assert "hard_examples" in summary + assert "avg_efficiency_score" in summary + + def test_language_detection(self): + traj = make_multilingual_recovery_trajectory() + lang = detect_trajectory_language(traj) + assert lang != "unknown" + + +# --------------------------------------------------------------------------- +# Language detection in multilingual metrics +# --------------------------------------------------------------------------- + +class TestMultilingualMetrics: + def test_language_distribution(self): + from training_setup_logs.multilingual.multilingual_metrics import ( + script_distribution, + diversity_score, + ) + from tests.fixtures import WEATHER_QUERIES_MULTILINGUAL + + dist = script_distribution(WEATHER_QUERIES_MULTILINGUAL) + assert isinstance(dist, dict) + assert sum(dist.values()) == len(WEATHER_QUERIES_MULTILINGUAL) + + def test_diversity_score_range(self): + from training_setup_logs.multilingual.multilingual_metrics import diversity_score + from tests.fixtures import WEATHER_QUERIES_MULTILINGUAL + + diverse_score = diversity_score(WEATHER_QUERIES_MULTILINGUAL) + mono_score = diversity_score(["the cat sat on the mat"] * 5) + assert diverse_score >= mono_score # multilingual set should be more diverse + assert 0.0 <= diverse_score <= 1.0 + assert 0.0 <= mono_score <= 1.0 + + def test_empty_diversity(self): + from training_setup_logs.multilingual.multilingual_metrics import diversity_score + assert diversity_score([]) == 0.0 \ No newline at end of file diff --git a/training_setup_logs/__init__.py b/training_setup_logs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training_setup_logs/multilingual/__init__.py b/training_setup_logs/multilingual/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training_setup_logs/multilingual/indic_normalizer.py b/training_setup_logs/multilingual/indic_normalizer.py new file mode 100644 index 0000000..14e8b86 --- /dev/null +++ b/training_setup_logs/multilingual/indic_normalizer.py @@ -0,0 +1,160 @@ +""" +indic_normalizer.py +------------------- +Normalization utilities for Indic (primarily Devanagari/Hindi) and +mixed-script text. Handles Unicode quirks, punctuation, numerals, +whitespace and common matra variants. + +Design note: we do *safe* normalization only — nothing that changes +lexical meaning. Transliteration-level canonicalization lives in +transliteration.py. +""" + +import re +import unicodedata +from typing import Optional + +# --------------------------------------------------------------------------- +# Devanagari Unicode ranges +# --------------------------------------------------------------------------- +DEVANAGARI_RANGE = (0x0900, 0x097F) +DEVANAGARI_EXTENDED_RANGE = (0x1CD0, 0x1CFF) + +# Devanagari digit to ASCII digit mapping +DEVA_DIGIT_MAP = str.maketrans("०१२३४५६७८९", "0123456789") + +# Common punctuation that should map to ASCII equivalents +PUNCTUATION_MAP = str.maketrans( + "\u2018\u2019\u201c\u201d\u2013\u2014\u00b7", + "''\"\"---", +) + +# Visually redundant / rarely meaningful matras we normalise away safely +# (anusvara variant, chandrabindu where it's stylistic, etc.) +# We keep nukta because it changes phoneme. +_SAFE_MATRA_COLLAPSES = { + "\u0902": "\u0902", # anusvara — keep as-is (normalize to canonical) + "\u0900": "\u0902", # inverted chandrabindu → anusvara (safe for NLP) + "\u0901": "\u0902", # chandrabindu → anusvara (safe approximation) +} + + +def normalize_unicode(text: str, form: str = "NFC") -> str: + """ + Apply Unicode normalization. + + Args: + text: Input string. + form: Unicode normalization form. NFC is the right default + for Devanagari (composed forms). + + Returns: + Normalized string. + """ + return unicodedata.normalize(form, text) + + +def normalize_devanagari_numerals(text: str) -> str: + """Replace Devanagari digits (०-९) with ASCII digits (0-9).""" + return text.translate(DEVA_DIGIT_MAP) + + +def normalize_punctuation(text: str) -> str: + """ + Map curly quotes, em-dashes, and other typographic characters to + their ASCII equivalents so downstream tokenizers are not surprised. + """ + return text.translate(PUNCTUATION_MAP) + + +def normalize_whitespace(text: str) -> str: + """ + Collapse all whitespace sequences (including zero-width joiners, + non-breaking spaces, etc.) to a single ASCII space and strip edges. + """ + # Replace zero-width chars silently + text = re.sub(r"[\u200b\u200c\u200d\ufeff]", "", text) + # Collapse whitespace + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def normalize_matras(text: str) -> str: + """ + Collapse stylistically equivalent matra variants to a single + canonical form. This is *safe* — it does not change meaning. + """ + result = [] + for ch in text: + result.append(_SAFE_MATRA_COLLAPSES.get(ch, ch)) + return "".join(result) + + +def is_devanagari(text: str) -> bool: + """Return True if the text contains at least one Devanagari codepoint.""" + return any(DEVANAGARI_RANGE[0] <= ord(c) <= DEVANAGARI_RANGE[1] for c in text) + + +def normalize_devanagari(text: str) -> str: + """ + Full Devanagari normalization pipeline: + 1. Unicode NFC + 2. Numeral normalization + 3. Punctuation cleanup + 4. Matra normalization + 5. Whitespace normalization + """ + text = normalize_unicode(text) + text = normalize_devanagari_numerals(text) + text = normalize_punctuation(text) + text = normalize_matras(text) + text = normalize_whitespace(text) + return text + + +def normalize_indic_text(text: str, script: Optional[str] = None) -> str: + """ + Entry-point normalizer that dispatches on script. + + Args: + text: Raw input. + script: Optional hint — 'devanagari', 'latin' (transliterated), + 'mixed', or None (auto-detect). + + Returns: + Normalized text ready for embedding or comparison. + """ + text = normalize_unicode(text) + text = normalize_punctuation(text) + text = normalize_whitespace(text) + + # Always normalize Devanagari numerals — safe even in mixed text + text = normalize_devanagari_numerals(text) + + detected_script = script or (_detect_script(text)) + + if detected_script in ("devanagari", "mixed"): + text = normalize_matras(text) + + return text.lower() + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _detect_script(text: str) -> str: + """ + Heuristic script detector. + + Returns: + 'devanagari' | 'latin' | 'mixed' + """ + has_deva = any(DEVANAGARI_RANGE[0] <= ord(c) <= DEVANAGARI_RANGE[1] for c in text) + has_latin = any("a" <= c.lower() <= "z" for c in text) + + if has_deva and has_latin: + return "mixed" + if has_deva: + return "devanagari" + return "latin" \ No newline at end of file diff --git a/training_setup_logs/multilingual/leakage_detector.py b/training_setup_logs/multilingual/leakage_detector.py new file mode 100644 index 0000000..42609a5 --- /dev/null +++ b/training_setup_logs/multilingual/leakage_detector.py @@ -0,0 +1,236 @@ +""" +leakage_detector.py +------------------- +Detect semantic and transliteration-level leakage between data splits +(e.g. train vs eval, train vs test). + +A "leak" happens when a sample in eval/test is semantically equivalent to +a sample in train — meaning the model has effectively seen it during +training. With multilingual data this is especially subtle because: + + train: "कल मौसम कैसा रहेगा" (Devanagari) + eval: "kal mausam kaisa rahega" (transliteration) + +are semantically the same but string-different. + +This module detects both kinds: +- Semantic leakage (embedding cosine similarity ≥ threshold) +- Transliteration leakage (same canonical form across scripts) +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from .indic_normalizer import normalize_indic_text +from .transliteration import canonicalize_transliteration, detect_transliterated_hindi +from .semantic_dedup import _cosine_matrix, _get_embed_model + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class LeakagePair: + train_idx: int + eval_idx: int + train_text: str + eval_text: str + similarity: float + leak_type: str # 'semantic' | 'transliteration' | 'exact' + + +@dataclass +class LeakageReport: + train_size: int + eval_size: int + semantic_threshold: float + + exact_leaks: List[LeakagePair] = field(default_factory=list) + transliteration_leaks: List[LeakagePair] = field(default_factory=list) + semantic_leaks: List[LeakagePair] = field(default_factory=list) + + @property + def total_leaks(self) -> int: + return len(self.exact_leaks) + len(self.transliteration_leaks) + len(self.semantic_leaks) + + @property + def leak_rate(self) -> float: + """Fraction of eval items that have at least one leak.""" + leaking_eval = set() + for lp in self.exact_leaks + self.transliteration_leaks + self.semantic_leaks: + leaking_eval.add(lp.eval_idx) + return len(leaking_eval) / self.eval_size if self.eval_size else 0.0 + + def to_dict(self) -> dict: + return { + "train_size": self.train_size, + "eval_size": self.eval_size, + "semantic_threshold": self.semantic_threshold, + "exact_leaks": len(self.exact_leaks), + "transliteration_leaks": len(self.transliteration_leaks), + "semantic_leaks": len(self.semantic_leaks), + "total_leaks": self.total_leaks, + "leak_rate": round(self.leak_rate, 4), + "cross_split_leaks": [ + { + "train_idx": lp.train_idx, + "eval_idx": lp.eval_idx, + "similarity": lp.similarity, + "leak_type": lp.leak_type, + "train_text": lp.train_text[:120], + "eval_text": lp.eval_text[:120], + } + for lp in ( + self.exact_leaks + self.transliteration_leaks + self.semantic_leaks + ) + ], + } + + +# --------------------------------------------------------------------------- +# Normalisation helpers +# --------------------------------------------------------------------------- + + +def _canonical(text: str) -> str: + """Produce script-agnostic canonical form for exact/transliteration matching.""" + base = normalize_indic_text(text) + if detect_transliterated_hindi(base): + base = canonicalize_transliteration(base) + return base + + +# --------------------------------------------------------------------------- +# Leakage detection +# --------------------------------------------------------------------------- + + +def detect_leakage( + train_queries: List[str], + eval_queries: List[str], + semantic_threshold: float = 0.85, + batch_size: int = 256, + skip_semantic: bool = False, +) -> LeakageReport: + """ + Detect train→eval leakage at three levels: + + 1. **Exact** — identical after normalization. + 2. **Transliteration** — same canonical transliteration key. + 3. **Semantic** — cosine similarity ≥ `semantic_threshold` (most powerful, + catches multilingual paraphrases). + + Args: + train_queries: List of training split queries. + eval_queries: List of eval/test split queries. + semantic_threshold: Cosine similarity above which two texts are + considered leaks. + batch_size: Embedding batch size. + skip_semantic: If True, skip embedding-based detection (faster + but misses paraphrase leaks). + + Returns: + LeakageReport with all detected leak pairs. + """ + report = LeakageReport( + train_size=len(train_queries), + eval_size=len(eval_queries), + semantic_threshold=semantic_threshold, + ) + + if not train_queries or not eval_queries: + return report + + # ------------------------------------------------------------------ + # Pass 1: Exact + transliteration leaks (O(n+m), cheap) + # ------------------------------------------------------------------ + train_canonical = [_canonical(q) for q in train_queries] + eval_canonical = [_canonical(q) for q in eval_queries] + + train_canon_set: dict[str, int] = {} # canonical → first train idx + for idx, canon in enumerate(train_canonical): + if canon not in train_canon_set: + train_canon_set[canon] = idx + + for eval_idx, (eq, ec) in enumerate(zip(eval_queries, eval_canonical)): + if ec in train_canon_set: + train_idx = train_canon_set[ec] + tq = train_queries[train_idx] + sim = 1.0 if tq == eq else 0.98 + + leak = LeakagePair( + train_idx=train_idx, + eval_idx=eval_idx, + train_text=tq, + eval_text=eq, + similarity=sim, + leak_type="exact" if tq == eq else "transliteration", + ) + if tq == eq: + report.exact_leaks.append(leak) + else: + report.transliteration_leaks.append(leak) + + # ------------------------------------------------------------------ + # Pass 2: Semantic leaks via multilingual embeddings + # ------------------------------------------------------------------ + if not skip_semantic: + # We only embed eval items not already caught above + already_leaked_eval = { + lp.eval_idx for lp in report.exact_leaks + report.transliteration_leaks + } + remaining_eval_indices = [ + i for i in range(len(eval_queries)) if i not in already_leaked_eval + ] + + if remaining_eval_indices: + remaining_eval_texts = [eval_queries[i] for i in remaining_eval_indices] + all_texts = train_queries + remaining_eval_texts + + model = _get_embed_model() + logger.info( + "Embedding %d texts for semantic leakage detection...", len(all_texts) + ) + embeddings = model.encode( + all_texts, batch_size=batch_size, show_progress_bar=False + ) + embeddings = np.array(embeddings, dtype=np.float32) + + train_emb = embeddings[: len(train_queries)] + eval_emb = embeddings[len(train_queries) :] + + # Cross-split similarity (train × eval subset) + train_norm = train_emb / np.maximum( + np.linalg.norm(train_emb, axis=1, keepdims=True), 1e-9 + ) + eval_norm = eval_emb / np.maximum( + np.linalg.norm(eval_emb, axis=1, keepdims=True), 1e-9 + ) + cross_sim = eval_norm @ train_norm.T # shape: (n_eval_remaining, n_train) + + for local_eval_idx, global_eval_idx in enumerate(remaining_eval_indices): + row = cross_sim[local_eval_idx] + best_train_idx = int(np.argmax(row)) + score = float(row[best_train_idx]) + if score >= semantic_threshold: + report.semantic_leaks.append( + LeakagePair( + train_idx=best_train_idx, + eval_idx=global_eval_idx, + train_text=train_queries[best_train_idx], + eval_text=eval_queries[global_eval_idx], + similarity=round(score, 4), + leak_type="semantic", + ) + ) + + return report \ No newline at end of file diff --git a/training_setup_logs/multilingual/multilingual_metrics.py b/training_setup_logs/multilingual/multilingual_metrics.py new file mode 100644 index 0000000..a976247 --- /dev/null +++ b/training_setup_logs/multilingual/multilingual_metrics.py @@ -0,0 +1,208 @@ +""" +multilingual_metrics.py +----------------------- +Compute and report dataset-level multilingual quality metrics. + +Provides: +- script_distribution() → breakdown by script/language +- diversity_score() → entropy-based diversity measure +- generate_dedup_report() → full deduplication summary +- generate_quality_report() → combined quality report dict +""" + +from __future__ import annotations + +import math +import re +from collections import Counter +from typing import Dict, List, Optional + +from .indic_normalizer import is_devanagari, _detect_script +from .transliteration import detect_transliterated_hindi +from .semantic_dedup import DeduplicationResult +from .leakage_detector import LeakageReport + + +# --------------------------------------------------------------------------- +# Script / Language detection +# --------------------------------------------------------------------------- + +_ENGLISH_COMMON = frozenset( + ["the", "a", "is", "are", "was", "were", "and", "or", "but", "in", "on", + "at", "to", "of", "for", "with", "how", "what", "when", "where", "why"] +) + + +def classify_query_language(text: str) -> str: + """ + Classify query into one of: + - 'hi' : Hindi (Devanagari script) + - 'hi-latn' : Transliterated Hindi (Latin script) + - 'hi-en-mixed': Code-switched Hindi+English + - 'en' : English + - 'other' : Unknown / other script + + Args: + text: Raw query text. + + Returns: + Language code string. + """ + script = _detect_script(text) + tokens = set(re.findall(r"\b[a-zA-Z]+\b", text.lower())) + + if script == "devanagari": + # May still have English tokens mixed in + if tokens - _ENGLISH_COMMON: # significant non-English Latin tokens + return "hi-en-mixed" + return "hi" + + if script == "mixed": + return "hi-en-mixed" + + # Latin script — is it transliterated Hindi or English? + if detect_transliterated_hindi(text): + english_ratio = len(tokens & _ENGLISH_COMMON) / max(len(tokens), 1) + if english_ratio > 0.4: + return "hi-en-mixed" + return "hi-latn" + + return "en" + + +def script_distribution(queries: List[str]) -> Dict[str, int]: + """ + Count queries per language/script category. + + Returns: + {language_code: count} + """ + counts: Counter = Counter() + for q in queries: + counts[classify_query_language(q)] += 1 + return dict(counts) + + +def diversity_score(queries: List[str]) -> float: + """ + Compute normalized entropy of script distribution as a diversity proxy. + + Score of 1.0 = maximally diverse; 0.0 = all same script. + """ + dist = script_distribution(queries) + total = sum(dist.values()) + if total == 0: + return 0.0 + + n_categories = len(dist) + if n_categories <= 1: + return 0.0 + + entropy = 0.0 + for count in dist.values(): + p = count / total + if p > 0: + entropy -= p * math.log2(p) + + max_entropy = math.log2(n_categories) + return round(entropy / max_entropy, 4) if max_entropy > 0 else 0.0 + + +# --------------------------------------------------------------------------- +# Report generators +# --------------------------------------------------------------------------- + + +def generate_dedup_report( + queries: List[str], + result: DeduplicationResult, +) -> dict: + """ + Generate a structured deduplication quality report. + + Args: + queries: Original query list. + result: Output of semantic_dedup.deduplicate(). + + Returns: + Report dict suitable for JSON serialization. + """ + lang_dist = script_distribution(queries) + diversity = diversity_score(queries) + + # Cluster size distribution + cluster_sizes: Counter = Counter() + for idx in range(result.total_inputs): + root = result.cluster_map.get(idx, idx) + cluster_sizes[root] += 1 + + size_hist: Counter = Counter(cluster_sizes.values()) + + return { + "total_queries": result.total_inputs, + "unique_queries": len(result.unique_indices), + "removed_duplicates": len(result.removed_indices), + "dedup_ratio": result.dedup_ratio, + "semantic_clusters": result.num_clusters, + "language_distribution": lang_dist, + "diversity_score": diversity, + "cluster_size_histogram": {str(k): v for k, v in sorted(size_hist.items())}, + "cross_language_duplicates": sum( + 1 + for i, j, _ in result.similarity_pairs + if classify_query_language(queries[i]) != classify_query_language(queries[j]) + ), + "transliteration_duplicates": sum( + 1 + for i, j, _ in result.similarity_pairs + if { + classify_query_language(queries[i]), + classify_query_language(queries[j]), + } + == {"hi", "hi-latn"} + ), + } + + +def generate_leakage_report(leakage: LeakageReport) -> dict: + """Wrap LeakageReport into a clean summary dict.""" + return { + "train_size": leakage.train_size, + "eval_size": leakage.eval_size, + "exact_leaks": len(leakage.exact_leaks), + "transliteration_leaks": len(leakage.transliteration_leaks), + "semantic_leaks": len(leakage.semantic_leaks), + "total_leaks": leakage.total_leaks, + "train_eval_leakage_risk": leakage.leak_rate, + } + + +def generate_quality_report( + queries: List[str], + dedup_result: Optional[DeduplicationResult] = None, + leakage_result: Optional[LeakageReport] = None, +) -> dict: + """ + Produce a full multilingual dataset quality report. + + Args: + queries: All queries in the dataset. + dedup_result: Output of semantic_dedup.deduplicate() (optional). + leakage_result: Output of leakage_detector.detect_leakage() (optional). + + Returns: + Nested report dict. + """ + report: dict = { + "total_queries": len(queries), + "language_distribution": script_distribution(queries), + "diversity_score": diversity_score(queries), + } + + if dedup_result is not None: + report["deduplication"] = generate_dedup_report(queries, dedup_result) + + if leakage_result is not None: + report["leakage"] = generate_leakage_report(leakage_result) + + return report \ No newline at end of file diff --git a/training_setup_logs/multilingual/semantic_dedup.py b/training_setup_logs/multilingual/semantic_dedup.py new file mode 100644 index 0000000..6299a54 --- /dev/null +++ b/training_setup_logs/multilingual/semantic_dedup.py @@ -0,0 +1,245 @@ +""" +semantic_dedup.py +----------------- +Multilingual semantic deduplication engine. + +Core responsibilities: +- Embed queries using a multilingual sentence encoder. +- Cluster near-duplicates above a configurable similarity threshold. +- Return deduplicated indices + cluster membership for every input. + +Supports: +- Hindi / Devanagari +- Transliterated Hindi +- Code-switched (Hindi+English) +- Pure English + +Model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 + (fallback: TF-IDF cosine for environments without GPU/sentence-transformers) +""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +import numpy as np + +from .indic_normalizer import normalize_indic_text +from .transliteration import canonicalize_transliteration, detect_transliterated_hindi + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class DeduplicationResult: + """Output of a deduplication pass.""" + + total_inputs: int + unique_indices: List[int] # indices kept as representatives + cluster_map: dict[int, int] # query_idx → cluster_id + cluster_representatives: dict[int, int] # cluster_id → representative_idx + similarity_pairs: List[Tuple[int, int, float]] # (i, j, score) for near-dups + removed_indices: List[int] + dedup_ratio: float # fraction removed + + @property + def num_clusters(self) -> int: + return len(self.cluster_representatives) + + +# --------------------------------------------------------------------------- +# Embedding backend (lazy load) +# --------------------------------------------------------------------------- + +_EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" +_embed_model = None # loaded on first use + + +def _get_embed_model(): + global _embed_model + if _embed_model is None: + try: + from sentence_transformers import SentenceTransformer + + logger.info("Loading multilingual embedding model: %s", _EMBED_MODEL_NAME) + _embed_model = SentenceTransformer(_EMBED_MODEL_NAME) + except ImportError: + logger.warning( + "sentence-transformers not available. " + "Falling back to TF-IDF cosine similarity." + ) + _embed_model = _TFIDFFallback() + return _embed_model + + +class _TFIDFFallback: + """ + Lightweight TF-IDF cosine similarity fallback. + Used when sentence-transformers is unavailable. + Not as accurate for multilingual content but avoids hard dependency. + """ + + def encode(self, texts: List[str], **kwargs) -> np.ndarray: + from sklearn.feature_extraction.text import TfidfVectorizer + + vec = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4), min_df=1) + matrix = vec.fit_transform(texts) + # Convert sparse → dense float32 + return matrix.toarray().astype(np.float32) + + +# --------------------------------------------------------------------------- +# Normalisation helpers +# --------------------------------------------------------------------------- + + +def _preprocess(text: str) -> str: + """ + Apply script-aware preprocessing before embedding. + Transliterated variants get canonicalized; Devanagari gets NFC + cleaned. + """ + if detect_transliterated_hindi(text): + text = canonicalize_transliteration(text) + return normalize_indic_text(text) + + +# --------------------------------------------------------------------------- +# Cosine similarity utilities +# --------------------------------------------------------------------------- + + +def _cosine_matrix(embeddings: np.ndarray) -> np.ndarray: + """Compute pairwise cosine similarity matrix.""" + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + norms = np.where(norms == 0, 1e-9, norms) + normed = embeddings / norms + return normed @ normed.T + + +# --------------------------------------------------------------------------- +# Main deduplication engine +# --------------------------------------------------------------------------- + + +def deduplicate( + queries: List[str], + threshold: float = 0.85, + batch_size: int = 256, +) -> DeduplicationResult: + """ + Detect and cluster near-duplicate queries. + + Algorithm + --------- + 1. Preprocess every query (script-aware normalization). + 2. Embed with multilingual sentence encoder (or TF-IDF fallback). + 3. Compute pairwise cosine similarity. + 4. Union-Find clustering: merge pairs above `threshold`. + 5. Pick earliest-index representative per cluster. + + Args: + queries: List of raw query strings. + threshold: Cosine similarity above which two queries are near-duplicates. + batch_size: Embedding batch size. + + Returns: + DeduplicationResult with cluster assignments and kept indices. + """ + if not queries: + return DeduplicationResult( + total_inputs=0, + unique_indices=[], + cluster_map={}, + cluster_representatives={}, + similarity_pairs=[], + removed_indices=[], + dedup_ratio=0.0, + ) + + n = len(queries) + processed = [_preprocess(q) for q in queries] + + # Embed + model = _get_embed_model() + embeddings = model.encode(processed, batch_size=batch_size, show_progress_bar=False) + embeddings = np.array(embeddings, dtype=np.float32) + + # Pairwise similarity + sim_matrix = _cosine_matrix(embeddings) + + # Collect near-duplicate pairs + similarity_pairs: List[Tuple[int, int, float]] = [] + for i in range(n): + for j in range(i + 1, n): + score = float(sim_matrix[i, j]) + if score >= threshold: + similarity_pairs.append((i, j, round(score, 4))) + + # Union-Find + parent = list(range(n)) + + def find(x: int) -> int: + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(x: int, y: int) -> None: + px, py = find(x), find(y) + if px != py: + # Always keep lower index as root + if px < py: + parent[py] = px + else: + parent[px] = py + + for i, j, _ in similarity_pairs: + union(i, j) + + # Build cluster assignments + cluster_map: dict[int, int] = {} # idx → cluster root + cluster_members: dict[int, list] = {} + + for idx in range(n): + root = find(idx) + cluster_map[idx] = root + cluster_members.setdefault(root, []).append(idx) + + cluster_representatives: dict[int, int] = { + root: min(members) for root, members in cluster_members.items() + } + + unique_indices = sorted(cluster_representatives.values()) + removed_indices = [i for i in range(n) if i not in set(unique_indices)] + dedup_ratio = len(removed_indices) / n if n else 0.0 + + return DeduplicationResult( + total_inputs=n, + unique_indices=unique_indices, + cluster_map=cluster_map, + cluster_representatives=cluster_representatives, + similarity_pairs=similarity_pairs, + removed_indices=removed_indices, + dedup_ratio=round(dedup_ratio, 4), + ) + + +# --------------------------------------------------------------------------- +# Stable hash helper (for caching / offline dedup) +# --------------------------------------------------------------------------- + + +def query_fingerprint(text: str) -> str: + """ + Produce a stable fingerprint for a query after normalization. + Useful for cheap exact-duplicate detection before running embeddings. + """ + normalized = _preprocess(text) + return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:16] \ No newline at end of file diff --git a/training_setup_logs/multilingual/transliteration.py b/training_setup_logs/multilingual/transliteration.py new file mode 100644 index 0000000..7c2c3ce --- /dev/null +++ b/training_setup_logs/multilingual/transliteration.py @@ -0,0 +1,172 @@ +""" +transliteration.py +------------------ +Utilities for detecting and canonicalizing transliterated Hindi/Indic text. + +Philosophy +---------- +We do NOT attempt perfect transliteration. We aim for *practical* +normalization so that "mausam", "mosam", "mousam" hash/cluster together +for deduplication purposes. + +The module exposes: +- detect_transliterated_hindi() → bool +- canonicalize_transliteration() → str +- transliteration_key() → str (stable key for grouping) +""" + +import re +from functools import lru_cache +from typing import Optional + +# --------------------------------------------------------------------------- +# Detection heuristics +# --------------------------------------------------------------------------- + +# High-frequency Hindi words that appear in Latin-script transliterations +_HINDI_TRANSLITERATION_TOKENS = frozenset( + [ + "hai", "hain", "kya", "kab", "kaise", "kaisa", "kaisi", "kal", + "aaj", "abhi", "yahan", "wahan", "mausam", "mosam", "mousam", + "accha", "theek", "nahi", "nahin", "hoga", "hogi", "rahega", + "batao", "bata", "karo", "karein", "chahiye", "chahie", + "kyun", "kyunki", "lekin", "aur", "par", "ya", "toh", + "matlab", "samajh", "samjha", "seedha", "sidha", + "bahut", "thoda", "zyada", "kam", "bilkul", "zaroor", + "weather", "mausam", + ] +) + +# Regex patterns for transliterated content +_DEVA_LIKE_PATTERNS = [ + r"\b(aa|ee|oo)\b", # long vowel markers (ai/au too common in English) + r"\b(kh|gh|jh|dh|ph|bh)\w+", # aspirated consonants (avoid 'th', 'ch', 'sh' — too English) +] +_DEVA_LIKE_RE = re.compile("|".join(_DEVA_LIKE_PATTERNS), re.IGNORECASE) + + +def detect_transliterated_hindi(text: str) -> bool: + """ + Return True if the text is likely transliterated Hindi/Indic in Latin script. + + Uses two signals: + 1. Token overlap with known Hindi transliteration vocabulary. + 2. Presence of phonemic patterns characteristic of Indic languages. + """ + if not text.strip(): + return False + + tokens = set(re.findall(r"\b\w+\b", text.lower())) + vocab_hit = bool(tokens & _HINDI_TRANSLITERATION_TOKENS) + pattern_hit = bool(_DEVA_LIKE_RE.search(text)) + + # "weather" alone in an English sentence should not trigger detection. + # Require at least one *non-English-function-word* Hindi token. + _HINDI_STRONG_TOKENS = frozenset([ + "hai", "hain", "kya", "kab", "kaise", "kaisa", "kaisi", "kal", + "aaj", "abhi", "yahan", "wahan", "mausam", "mosam", "mousam", + "rahega", "batao", "bata", "karo", "karein", "chahiye", + "kyun", "kyunki", "lekin", "matlab", "bahut", "thoda", + "zyada", "kam", "bilkul", "zaroor", + ]) + strong_vocab_hit = bool(tokens & _HINDI_STRONG_TOKENS) + + return strong_vocab_hit or (pattern_hit and len(tokens) >= 2) + + +# --------------------------------------------------------------------------- +# Spelling variant maps — curated, not exhaustive +# --------------------------------------------------------------------------- + +# Maps variant spellings → canonical form +_VARIANT_MAP: dict[str, str] = { + # Weather / common query words + "mosam": "mausam", + "mousam": "mausam", + "mausam": "mausam", + "mossam": "mausam", + # Negation + "nahin": "nahi", + "nahi": "nahi", + "nai": "nahi", + # Question words + "kaise": "kaisa", + "kaisi": "kaisa", + "kaisa": "kaisa", + # Copulas + "hain": "hai", + "hai": "hai", + # Future markers + "rahega": "rahega", + "rahegi": "rahega", + "rahenge": "rahega", + # Confirmations + "accha": "acha", + "acha": "acha", + "achha": "acha", + # Enough / okay + "theek": "thik", + "thik": "thik", + "thikhai": "thik hai", + # Long-vowel simplifications (common in casual typing) + "aa": "a", + "ee": "i", + "oo": "u", +} + + +@lru_cache(maxsize=4096) +def canonicalize_transliteration(text: str) -> str: + """ + Normalize spelling variants in transliterated Hindi. + + Steps: + 1. Lowercase. + 2. Replace known variant tokens with canonical forms. + 3. Collapse repeated characters (e.g. "aacccha" → "acha"). + 4. Strip extra whitespace. + + Args: + text: Transliterated Hindi string (Latin script). + + Returns: + Canonicalized string suitable for similarity comparison. + """ + text = text.lower().strip() + + # Token-level replacement + tokens = text.split() + canonical_tokens = [_VARIANT_MAP.get(tok, tok) for tok in tokens] + text = " ".join(canonical_tokens) + + # Collapse 3+ repeated chars → 2 (handles "aacchha" etc.) + text = re.sub(r"(.)\1{2,}", r"\1\1", text) + + return text.strip() + + +def transliteration_key(text: str, script_hint: Optional[str] = None) -> str: + """ + Produce a stable grouping key for a query regardless of transliteration + variant. + + This key is NOT reversible — it is used only for bucketing/deduplication. + + Args: + text: Input query (may be Devanagari, Latin, or mixed). + script_hint: 'devanagari' | 'latin' | 'mixed' | None + + Returns: + Normalized key string. + """ + # Detect if it looks like transliteration + if script_hint == "devanagari": + # For Devanagari, we just lowercase + collapse whitespace + return re.sub(r"\s+", " ", text.strip()).lower() + + if script_hint == "latin" or detect_transliterated_hindi(text): + return canonicalize_transliteration(text) + + # Mixed or unknown → apply both passes, take result + text = canonicalize_transliteration(text) + return text \ No newline at end of file diff --git a/training_setup_logs/trajectory/__init__.py b/training_setup_logs/trajectory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training_setup_logs/trajectory/analyzer.py b/training_setup_logs/trajectory/analyzer.py new file mode 100644 index 0000000..a643834 --- /dev/null +++ b/training_setup_logs/trajectory/analyzer.py @@ -0,0 +1,184 @@ +""" +analyzer.py +----------- +Orchestrates the trajectory intelligence pipeline. + +Given a Trajectory, the analyzer: +1. Detects and classifies failures. +2. Mines recovery patterns. +3. Classifies difficulty. +4. Computes tool-use efficiency. +5. Tags hard examples. +6. Attaches full TrajectoryAnalysis to the trajectory. + +Usage +----- + from training_setup_logs.trajectory.analyzer import analyze_trajectory, analyze_batch + + analyzed = analyze_trajectory(trajectory) + print(analyzed.analysis.model_dump()) +""" + +from __future__ import annotations + +import logging +from typing import List, Optional + +from .difficulty import classify_difficulty, compute_difficulty_score +from .failure_classifier import classify_failures +from .hard_example_miner import is_hard_example, tag_hard_examples +from .metadata_enrichment import ( + compute_efficiency_score, + detect_trajectory_language, + enrich_metadata, + has_transliteration_variant, + is_code_switched, + is_multilingual_trajectory, +) +from .models import Trajectory, TrajectoryAnalysis +from .recovery_patterns import mine_recovery_patterns + +logger = logging.getLogger(__name__) + + +def analyze_trajectory(trajectory: Trajectory) -> Trajectory: + """ + Run the full trajectory intelligence pipeline on a single trajectory. + + Mutates trajectory.analysis in-place (also returns the trajectory + for convenience). + + Args: + trajectory: A Trajectory object (tool_calls + turns populated). + + Returns: + Same Trajectory with .analysis populated. + """ + tc = trajectory.tool_calls + + # 1. Failure classification + failure_types, severity, repair_candidate = classify_failures(trajectory) + + # 2. Recovery patterns + recovery_patterns, recovery_success = mine_recovery_patterns(trajectory) + + # 3. Difficulty + difficulty = classify_difficulty(trajectory) + + # 4. Tool efficiency + tool_count = len(tc) + efficiency_score = compute_efficiency_score(tc) + from .metadata_enrichment import _estimate_optimal_calls + estimated_optimal = _estimate_optimal_calls(tc) if tc else 0 + + # 5. Language metadata + detected_language = detect_trajectory_language(trajectory) + is_multi = is_multilingual_trajectory(trajectory) + is_cs = is_code_switched(trajectory) + is_translit = has_transliteration_variant(trajectory) + + # 6. Build partial analysis (needed by hard_example_miner) + analysis = TrajectoryAnalysis( + detected_language=detected_language, + is_multilingual=is_multi, + is_code_switched=is_cs, + transliteration_variant=is_translit, + difficulty=difficulty, + has_failure=len(failure_types) > 0, + failure_types=failure_types, + failure_severity=severity if severity != "none" else None, + has_recovery=len(recovery_patterns) > 0, + recovery_patterns=recovery_patterns, + recovery_success=recovery_success if recovery_patterns else None, + tool_count=tool_count, + estimated_optimal_calls=estimated_optimal, + efficiency_score=efficiency_score, + repair_candidate=repair_candidate, + ) + + # 7. Hard-example tagging + tags = tag_hard_examples(trajectory, analysis) + analysis.hard_example_tags = tags + analysis.hard_example = is_hard_example(tags) + + trajectory.analysis = analysis + return trajectory + + +def analyze_batch( + trajectories: List[Trajectory], + semantic_cluster_ids: Optional[dict[str, str]] = None, +) -> List[Trajectory]: + """ + Analyze a batch of trajectories. + + Args: + trajectories: List of Trajectory objects. + semantic_cluster_ids: Optional mapping {trajectory_id → cluster_id} + from the multilingual dedup layer. + + Returns: + List of analyzed Trajectory objects. + """ + results = [] + for traj in trajectories: + try: + analyzed = analyze_trajectory(traj) + if semantic_cluster_ids and analyzed.analysis: + analyzed.analysis.semantic_cluster_id = semantic_cluster_ids.get( + traj.trajectory_id + ) + results.append(analyzed) + except Exception as exc: + logger.error( + "Failed to analyze trajectory %s: %s", traj.trajectory_id, exc + ) + results.append(traj) + return results + + +def batch_summary(trajectories: List[Trajectory]) -> dict: + """ + Produce an aggregate summary report for a batch of analyzed trajectories. + + Args: + trajectories: List of analyzed Trajectory objects. + + Returns: + Summary dict. + """ + from collections import Counter + + total = len(trajectories) + analyses = [t.analysis for t in trajectories if t.analysis] + + difficulty_dist = Counter(a.difficulty.value for a in analyses) + failure_dist: Counter = Counter() + for a in analyses: + for f in a.failure_types: + failure_dist[f.value] += 1 + + tag_dist: Counter = Counter() + for a in analyses: + for tag in a.hard_example_tags: + tag_dist[tag.value] += 1 + + avg_efficiency = ( + sum(a.efficiency_score for a in analyses) / len(analyses) + if analyses else 0.0 + ) + multilingual_count = sum(1 for a in analyses if a.is_multilingual) + hard_count = sum(1 for a in analyses if a.hard_example) + recovery_count = sum(1 for a in analyses if a.has_recovery) + + return { + "total_trajectories": total, + "analyzed": len(analyses), + "difficulty_distribution": dict(difficulty_dist), + "failure_type_counts": dict(failure_dist), + "hard_example_tag_counts": dict(tag_dist), + "hard_examples": hard_count, + "with_recovery": recovery_count, + "multilingual": multilingual_count, + "avg_efficiency_score": round(avg_efficiency, 4), + } \ No newline at end of file diff --git a/training_setup_logs/trajectory/difficulty.py b/training_setup_logs/trajectory/difficulty.py new file mode 100644 index 0000000..79f3bc0 --- /dev/null +++ b/training_setup_logs/trajectory/difficulty.py @@ -0,0 +1,103 @@ +""" +difficulty.py +------------- +Classify trajectory difficulty: simple | moderate | hard. + +Difficulty is a composite signal based on: +- Number of tool calls +- Retry count +- Ambiguity indicators +- Multilingual complexity +- Recovery behavior +- Follow-up depth (multi-turn) +- Failure presence +""" + +from __future__ import annotations + +from .models import DifficultyLevel, Trajectory, ToolCallStatus + + +# --------------------------------------------------------------------------- +# Feature extraction +# --------------------------------------------------------------------------- + + +def _retry_count(trajectory: Trajectory) -> int: + return sum(1 for tc in trajectory.tool_calls if tc.retry_of is not None) + + +def _failure_count(trajectory: Trajectory) -> int: + return sum( + 1 for tc in trajectory.tool_calls + if tc.status in (ToolCallStatus.FAILURE, ToolCallStatus.MISSING_RETURN, + ToolCallStatus.HALLUCINATED) + ) + + +def _turn_count(trajectory: Trajectory) -> int: + return len(trajectory.turns) + + +def _unique_tools(trajectory: Trajectory) -> int: + return len({tc.tool_name for tc in trajectory.tool_calls}) + + +def _is_multilingual(trajectory: Trajectory) -> bool: + langs = {t.language for t in trajectory.turns if t.language} + return len(langs) > 1 + + +def _has_code_switch(trajectory: Trajectory) -> bool: + return any( + t.language in ("hi-en-mixed",) for t in trajectory.turns if t.language + ) + + +# --------------------------------------------------------------------------- +# Scoring +# --------------------------------------------------------------------------- + + +def compute_difficulty_score(trajectory: Trajectory) -> float: + """ + Compute a raw difficulty score in [0, 1]. + + Each feature contributes weighted points: + - tool calls: weight 0.20 (normalised at 10 calls = max) + - unique tools: weight 0.15 + - retries: weight 0.20 + - failures: weight 0.15 + - multilingual: weight 0.15 + - code-switch: weight 0.10 + - turns > 4: weight 0.05 + """ + tc_score = min(len(trajectory.tool_calls) / 10.0, 1.0) * 0.20 + ut_score = min(_unique_tools(trajectory) / 5.0, 1.0) * 0.15 + retry_score = min(_retry_count(trajectory) / 5.0, 1.0) * 0.20 + fail_score = min(_failure_count(trajectory) / 3.0, 1.0) * 0.15 + ml_score = 0.15 if _is_multilingual(trajectory) else 0.0 + cs_score = 0.10 if _has_code_switch(trajectory) else 0.0 + turn_score = 0.05 if _turn_count(trajectory) > 4 else 0.0 + + return round( + tc_score + ut_score + retry_score + fail_score + ml_score + cs_score + turn_score, + 4, + ) + + +def classify_difficulty(trajectory: Trajectory) -> DifficultyLevel: + """ + Classify difficulty into: simple | moderate | hard. + + Thresholds: + - simple: score < 0.25 + - moderate: 0.25 ≤ score < 0.55 + - hard: score ≥ 0.55 + """ + score = compute_difficulty_score(trajectory) + if score < 0.25: + return DifficultyLevel.SIMPLE + if score < 0.55: + return DifficultyLevel.MODERATE + return DifficultyLevel.HARD \ No newline at end of file diff --git a/training_setup_logs/trajectory/failure_classifier.py b/training_setup_logs/trajectory/failure_classifier.py new file mode 100644 index 0000000..8c0f341 --- /dev/null +++ b/training_setup_logs/trajectory/failure_classifier.py @@ -0,0 +1,217 @@ +""" +failure_classifier.py +--------------------- +Classify failure patterns in a trajectory. + +Each classifier function inspects the trajectory and returns a list of +detected FailureType values. The aggregate result is used to tag +trajectories for downstream SFT/DPO pipeline decisions. + +Design: rule-based classifiers are fast, deterministic, and interpretable. +LLM-assisted classification can be layered on top later. +""" + +from __future__ import annotations + +import re +from typing import List, Tuple + +from .models import ( + FailureType, + ToolCall, + ToolCallStatus, + Trajectory, + TrajectoryAnalysis, +) + +# --------------------------------------------------------------------------- +# Individual failure detectors +# --------------------------------------------------------------------------- + + +def _detect_missing_tool_return(tool_calls: List[ToolCall]) -> bool: + """Any tool call with MISSING_RETURN status.""" + return any(tc.status == ToolCallStatus.MISSING_RETURN for tc in tool_calls) + + +def _detect_hallucinated_args(tool_calls: List[ToolCall]) -> bool: + """ + Detect hallucinated tool arguments. + + Signals: + - Status HALLUCINATED explicitly set. + - Arguments contain obviously fabricated values (None keys, nested + placeholder strings like "", etc.). + """ + for tc in tool_calls: + if tc.status == ToolCallStatus.HALLUCINATED: + return True + # Check for placeholder patterns in argument values + for v in tc.arguments.values(): + if isinstance(v, str) and re.search(r"<.*?>|\.\.\.|PLACEHOLDER|TODO", v, re.I): + return True + return False + + +def _detect_redundant_tool_usage(tool_calls: List[ToolCall]) -> bool: + """ + Detect redundant (repeated identical) tool calls. + + Two calls are considered redundant if they share the same tool name + and identical arguments and neither is a retry. + """ + seen: set = set() + for tc in tool_calls: + if tc.retry_of is not None: + continue # retries are not redundant — they're intentional + key = (tc.tool_name, str(sorted(tc.arguments.items()))) + if key in seen: + return True + seen.add(key) + return False + + +def _detect_excessive_retries(tool_calls: List[ToolCall], threshold: int = 3) -> bool: + """More than `threshold` retry calls in the trajectory.""" + retry_count = sum(1 for tc in tool_calls if tc.retry_of is not None) + return retry_count >= threshold + + +def _detect_contradiction(trajectory: Trajectory) -> bool: + """ + Detect if the assistant response contradicts a tool return. + + Heuristic: look for explicit tool failure indicators in tool returns + paired with an assistant turn that proceeds as if it succeeded. + """ + failed_tools: set[str] = { + tc.tool_name + for tc in trajectory.tool_calls + if tc.status in (ToolCallStatus.FAILURE, ToolCallStatus.MISSING_RETURN) + } + if not failed_tools: + return False + + # Check assistant turns after a failed tool call + for turn in trajectory.turns: + if turn.role.value == "assistant" and turn.content: + # If assistant claims success about a failed tool — rough heuristic + for tool_name in failed_tools: + if re.search( + rf"\b{re.escape(tool_name)}\b.*\b(success|done|completed|found|retrieved)\b", + turn.content, + re.I, + ): + return True + return False + + +def _detect_incomplete_trajectory(trajectory: Trajectory) -> bool: + """ + A trajectory is incomplete if: + - The last message is from the user (no assistant response), OR + - There are tool calls with no following assistant turn. + """ + if not trajectory.turns: + return True + last_role = trajectory.turns[-1].role.value + if last_role == "user": + return True + # Check for tool calls not followed by assistant + has_tool_after_last_assistant = False + seen_assistant = False + for turn in reversed(trajectory.turns): + if turn.role.value == "assistant": + seen_assistant = True + break + if turn.role.value == "tool": + has_tool_after_last_assistant = True + return has_tool_after_last_assistant and not seen_assistant + + +def _detect_wrong_tool(tool_calls: List[ToolCall]) -> bool: + """Placeholder: explicit FAILURE with short latency suggests wrong tool branch.""" + for tc in tool_calls: + if tc.status == ToolCallStatus.FAILURE: + # If failure happened very fast (<50ms), likely wrong tool chosen + if tc.latency_ms is not None and tc.latency_ms < 50: + return True + return False + + +def _detect_multilingual_inconsistency(trajectory: Trajectory) -> bool: + """ + Detect if the assistant responds in a different language than the user query. + Requires per-turn language tags. + """ + user_langs = { + t.language for t in trajectory.turns + if t.role.value == "user" and t.language + } + asst_langs = { + t.language for t in trajectory.turns + if t.role.value == "assistant" and t.language + } + if not user_langs or not asst_langs: + return False + # If all user turns are one language and assistant used a different one + return bool(user_langs) and bool(asst_langs) and not user_langs.intersection(asst_langs) + + +# --------------------------------------------------------------------------- +# Aggregate classifier +# --------------------------------------------------------------------------- + + +def classify_failures(trajectory: Trajectory) -> Tuple[List[FailureType], str, bool]: + """ + Run all failure detectors on a trajectory. + + Args: + trajectory: The trajectory to classify. + + Returns: + (failure_types, severity, repair_candidate) + - failure_types: list of detected FailureType values + - severity: 'low' | 'medium' | 'high' + - repair_candidate: whether the trajectory is worth repairing + """ + tc = trajectory.tool_calls + failures: List[FailureType] = [] + + if _detect_missing_tool_return(tc): + failures.append(FailureType.MISSING_TOOL_RETURN) + if _detect_hallucinated_args(tc): + failures.append(FailureType.HALLUCINATED_ARGS) + if _detect_redundant_tool_usage(tc): + failures.append(FailureType.REDUNDANT_TOOL) + if _detect_excessive_retries(tc): + failures.append(FailureType.EXCESSIVE_RETRIES) + if _detect_contradiction(trajectory): + failures.append(FailureType.CONTRADICTION) + if _detect_incomplete_trajectory(trajectory): + failures.append(FailureType.INCOMPLETE_TRAJECTORY) + if _detect_wrong_tool(tc): + failures.append(FailureType.WRONG_TOOL) + if _detect_multilingual_inconsistency(trajectory): + failures.append(FailureType.MULTILINGUAL_INCONSISTENCY) + + # Severity + n = len(failures) + if n == 0: + severity = "none" + elif n == 1 and failures[0] in (FailureType.REDUNDANT_TOOL, FailureType.EXCESSIVE_RETRIES): + severity = "low" + elif n <= 2: + severity = "medium" + else: + severity = "high" + + # Repair candidate: has failures but is not catastrophically broken + repair_candidate = ( + len(failures) > 0 + and FailureType.INCOMPLETE_TRAJECTORY not in failures + and FailureType.MULTILINGUAL_INCONSISTENCY not in failures + ) + + return failures, severity, repair_candidate \ No newline at end of file diff --git a/training_setup_logs/trajectory/hard_example_miner.py b/training_setup_logs/trajectory/hard_example_miner.py new file mode 100644 index 0000000..798161d --- /dev/null +++ b/training_setup_logs/trajectory/hard_example_miner.py @@ -0,0 +1,154 @@ +""" +hard_example_miner.py +--------------------- +Identify and tag hard examples for downstream SFT/DPO pipelines. + +IMPORTANT DESIGN PRINCIPLE +--------------------------- +We do NOT discard bad trajectories. Failures, retries, and recovery +behaviors are valuable training signal. This module decides *how* to +label each trajectory so the downstream consumer can choose: + +- SFT_WORTHY → clean example; include directly in supervised fine-tuning +- DPO_NEGATIVE → failed trajectory; use as rejected response in DPO +- REPAIR_WORTHY → fixable failure; queue for human/LLM repair +- MULTILINGUAL_EDGE → rare code-switch / transliteration example to preserve +- RECOVERY_HEAVY → rich recovery behavior; valuable for agent training +- SAFETY_SENSITIVE → flagged for safety review before use +- EVALUATION_WORTHY → complex / ambiguous; better for eval than training +""" + +from __future__ import annotations + +import re +from typing import List, Set + +from .models import ( + DifficultyLevel, + FailureType, + HardExampleTag, + ToolCallStatus, + Trajectory, + TrajectoryAnalysis, +) + +# --------------------------------------------------------------------------- +# Safety heuristics +# --------------------------------------------------------------------------- + +_SAFETY_PATTERNS = re.compile( + r"\b(harm|danger|weapon|illegal|drug|hack|exploit|bypass|jailbreak|" + r"violence|suicide|self-harm|nsfw)\b", + re.I, +) + + +def _has_safety_signal(trajectory: Trajectory) -> bool: + for turn in trajectory.turns: + if _SAFETY_PATTERNS.search(turn.content): + return True + return False + + +# --------------------------------------------------------------------------- +# Main tagging logic +# --------------------------------------------------------------------------- + + +def tag_hard_examples( + trajectory: Trajectory, + analysis: TrajectoryAnalysis, +) -> List[HardExampleTag]: + """ + Assign hard-example tags to a trajectory based on its analysis. + + Args: + trajectory: The raw trajectory. + analysis: Pre-computed TrajectoryAnalysis (from analyzer.py). + + Returns: + List of applicable HardExampleTag values. + """ + tags: Set[HardExampleTag] = set() + + # ------------------------------------------------------------------ + # SFT_WORTHY: clean, no failures, at least one tool call resolved + # ------------------------------------------------------------------ + if ( + not analysis.has_failure + and analysis.efficiency_score >= 0.6 + and analysis.difficulty != DifficultyLevel.SIMPLE + ): + tags.add(HardExampleTag.SFT_WORTHY) + + # ------------------------------------------------------------------ + # DPO_NEGATIVE: trajectory with clear failure, not repair-worthy + # ------------------------------------------------------------------ + if ( + analysis.has_failure + and not analysis.repair_candidate + and FailureType.INCOMPLETE_TRAJECTORY in analysis.failure_types + ): + tags.add(HardExampleTag.DPO_NEGATIVE) + + # ------------------------------------------------------------------ + # REPAIR_WORTHY: has fixable failures + # ------------------------------------------------------------------ + if analysis.repair_candidate: + tags.add(HardExampleTag.REPAIR_WORTHY) + + # ------------------------------------------------------------------ + # MULTILINGUAL_EDGE: code-switched or transliteration variant + # ------------------------------------------------------------------ + if analysis.is_code_switched or analysis.transliteration_variant or analysis.is_multilingual: + tags.add(HardExampleTag.MULTILINGUAL_EDGE) + + # ------------------------------------------------------------------ + # RECOVERY_HEAVY: multiple recovery patterns + # ------------------------------------------------------------------ + if analysis.has_recovery and len(analysis.recovery_patterns) >= 2: + tags.add(HardExampleTag.RECOVERY_HEAVY) + + # ------------------------------------------------------------------ + # SAFETY_SENSITIVE: safety signal detected + # ------------------------------------------------------------------ + if _has_safety_signal(trajectory): + tags.add(HardExampleTag.SAFETY_SENSITIVE) + + # ------------------------------------------------------------------ + # EVALUATION_WORTHY: hard but not broken — good for eval sets + # ------------------------------------------------------------------ + if ( + analysis.difficulty == DifficultyLevel.HARD + and not analysis.has_failure + and HardExampleTag.SFT_WORTHY not in tags + ): + tags.add(HardExampleTag.EVALUATION_WORTHY) + + # Alternatively: moderate difficulty with multilingual = good eval + if ( + analysis.difficulty == DifficultyLevel.MODERATE + and analysis.is_multilingual + and not analysis.has_failure + ): + tags.add(HardExampleTag.EVALUATION_WORTHY) + + return sorted(tags, key=lambda t: t.value) + + +def is_hard_example(tags: List[HardExampleTag]) -> bool: + """ + Return True if any tag indicates this is a 'hard' or high-value example. + + Simple examples with SFT_WORTHY as their only tag are not 'hard' per se, + but they are valuable. We mark as hard if there's any complexity signal. + """ + non_trivial_tags = { + HardExampleTag.DPO_NEGATIVE, + HardExampleTag.REPAIR_WORTHY, + HardExampleTag.MULTILINGUAL_EDGE, + HardExampleTag.RECOVERY_HEAVY, + HardExampleTag.SAFETY_SENSITIVE, + HardExampleTag.EVALUATION_WORTHY, + } + return bool(set(tags) & non_trivial_tags) \ No newline at end of file diff --git a/training_setup_logs/trajectory/metadata_enrichment.py b/training_setup_logs/trajectory/metadata_enrichment.py new file mode 100644 index 0000000..69c2c22 --- /dev/null +++ b/training_setup_logs/trajectory/metadata_enrichment.py @@ -0,0 +1,146 @@ +""" +metadata_enrichment.py +---------------------- +Enrich trajectories with full analysis metadata. + +This module computes tool-use efficiency and language-level metadata, +complementing the failure classifier and difficulty scorer. +""" + +from __future__ import annotations + +import math +from typing import List, Optional + +from .models import ToolCall, ToolCallStatus, Trajectory + + +# --------------------------------------------------------------------------- +# Tool efficiency +# --------------------------------------------------------------------------- + + +def _estimate_optimal_calls(tool_calls: List[ToolCall]) -> int: + """ + Estimate the minimum number of tool calls needed for this trajectory. + + Heuristic: + - Count unique tool names used successfully (ignoring retries/redundant). + - Add 1 for any fallback that was necessary. + """ + # Unique tools that ran successfully (non-retry, non-fallback) + primary_successes: set[str] = { + tc.tool_name + for tc in tool_calls + if tc.status == ToolCallStatus.SUCCESS + and tc.retry_of is None + and not tc.is_fallback + } + fallbacks_used = sum(1 for tc in tool_calls if tc.is_fallback and tc.status == ToolCallStatus.SUCCESS) + + return max(len(primary_successes) + fallbacks_used, 1) + + +def compute_efficiency_score(tool_calls: List[ToolCall]) -> float: + """ + Ratio of optimal calls to actual calls. + + Score 1.0 = perfectly efficient. + Score 0.5 = used twice as many calls as necessary. + """ + if not tool_calls: + return 1.0 + optimal = _estimate_optimal_calls(tool_calls) + actual = len(tool_calls) + return round(min(optimal / actual, 1.0), 4) + + +# --------------------------------------------------------------------------- +# Language metadata +# --------------------------------------------------------------------------- + + +def detect_trajectory_language(trajectory: Trajectory) -> str: + """ + Infer the primary language of a trajectory from per-turn language tags. + + Returns: + A language code string, e.g. 'hi', 'en', 'hi-en-mixed', 'unknown'. + """ + langs = [t.language for t in trajectory.turns if t.language] + if not langs: + return "unknown" + counter: dict[str, int] = {} + for lang in langs: + counter[lang] = counter.get(lang, 0) + 1 + # Majority language + primary = max(counter, key=counter.__getitem__) + has_multi = len(set(langs)) > 1 + if has_multi: + if "hi-en-mixed" in langs: + return "hi-en-mixed" + codes = sorted(set(langs)) + return "-".join(codes) + return primary + + +def is_multilingual_trajectory(trajectory: Trajectory) -> bool: + langs = {t.language for t in trajectory.turns if t.language} + return len(langs) > 1 + + +def is_code_switched(trajectory: Trajectory) -> bool: + return any(t.language == "hi-en-mixed" for t in trajectory.turns if t.language) + + +def has_transliteration_variant(trajectory: Trajectory) -> bool: + return any(t.language == "hi-latn" for t in trajectory.turns if t.language) + + +# --------------------------------------------------------------------------- +# Full enrichment +# --------------------------------------------------------------------------- + + +def enrich_metadata(trajectory: Trajectory, analysis) -> dict: + """ + Produce a flat metadata dict for a trajectory. + + This is the canonical representation attached to each trajectory + before export to SFT/DPO datasets. + + Args: + trajectory: The trajectory. + analysis: TrajectoryAnalysis (already populated by analyzer.py). + + Returns: + Flat dict ready for JSON serialization. + """ + return { + # Identity + "trajectory_id": trajectory.trajectory_id, + # Language + "language": analysis.detected_language, + "is_multilingual": analysis.is_multilingual, + "is_code_switched": analysis.is_code_switched, + "transliteration_variant": analysis.transliteration_variant, + # Quality signals + "difficulty": analysis.difficulty.value, + "has_failure": analysis.has_failure, + "failure_types": [f.value for f in analysis.failure_types], + "failure_severity": analysis.failure_severity, + # Recovery + "has_recovery": analysis.has_recovery, + "recovery_patterns": analysis.recovery_patterns, + "recovery_success": analysis.recovery_success, + # Tool efficiency + "tool_count": analysis.tool_count, + "estimated_optimal_calls": analysis.estimated_optimal_calls, + "efficiency_score": analysis.efficiency_score, + # Hard-example signals + "hard_example": analysis.hard_example, + "hard_example_tags": [t.value for t in analysis.hard_example_tags], + "repair_candidate": analysis.repair_candidate, + # Dedup + "semantic_cluster_id": analysis.semantic_cluster_id, + } \ No newline at end of file diff --git a/training_setup_logs/trajectory/models.py b/training_setup_logs/trajectory/models.py new file mode 100644 index 0000000..f4bd6ad --- /dev/null +++ b/training_setup_logs/trajectory/models.py @@ -0,0 +1,144 @@ +""" +models.py +--------- +Pydantic data models for trajectory representation. + +These are the shared types used across the trajectory intelligence layer. +They model the structure of a processed log entry from Langfuse / +Pydantic-based tracing systems. +""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class ToolCallStatus(str, Enum): + SUCCESS = "success" + FAILURE = "failure" + PARTIAL = "partial" + MISSING_RETURN = "missing_return" + HALLUCINATED = "hallucinated" + + +class MessageRole(str, Enum): + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + SYSTEM = "system" + + +class DifficultyLevel(str, Enum): + SIMPLE = "simple" + MODERATE = "moderate" + HARD = "hard" + + +class HardExampleTag(str, Enum): + SFT_WORTHY = "sft_worthy" + DPO_NEGATIVE = "dpo_negative_candidate" + REPAIR_WORTHY = "repair_worthy" + MULTILINGUAL_EDGE = "multilingual_edge_case" + RECOVERY_HEAVY = "recovery_heavy" + SAFETY_SENSITIVE = "safety_sensitive" + EVALUATION_WORTHY = "evaluation_worthy" + + +class FailureType(str, Enum): + WRONG_TOOL = "wrong_tool_selected" + HALLUCINATED_ARGS = "hallucinated_tool_args" + INCOMPLETE_TRAJECTORY = "incomplete_trajectory" + MISSING_TOOL_RETURN = "missing_tool_return" + CONTRADICTION = "contradiction_with_tool_output" + REDUNDANT_TOOL = "redundant_tool_usage" + EXCESSIVE_RETRIES = "excessive_retries" + RECOVERY_FAILURE = "recovery_failure" + PERSONA_MISMATCH = "persona_mismatch" + MULTILINGUAL_INCONSISTENCY = "multilingual_inconsistency" + + +# --------------------------------------------------------------------------- +# Core trajectory components +# --------------------------------------------------------------------------- + + +class ToolCall(BaseModel): + tool_name: str + arguments: Dict[str, Any] = Field(default_factory=dict) + return_value: Optional[Any] = None + status: ToolCallStatus = ToolCallStatus.SUCCESS + timestamp: Optional[datetime] = None + latency_ms: Optional[float] = None + retry_of: Optional[int] = None # index of earlier call this retries + is_fallback: bool = False # True if this replaced a failed tool + + +class TurnMessage(BaseModel): + role: MessageRole + content: str + timestamp: Optional[datetime] = None + language: Optional[str] = None # detected language code + + +class Trajectory(BaseModel): + """ + A single multi-turn trajectory from a production log. + """ + + trajectory_id: str + turns: List[TurnMessage] = Field(default_factory=list) + tool_calls: List[ToolCall] = Field(default_factory=list) + metadata: Dict[str, Any] = Field(default_factory=dict) + + # Filled in by the intelligence layer + analysis: Optional[TrajectoryAnalysis] = None + + +class TrajectoryAnalysis(BaseModel): + """ + Metadata produced by the trajectory intelligence layer. + Attached to each Trajectory after analysis. + """ + + # Language + detected_language: str = "unknown" + is_multilingual: bool = False + is_code_switched: bool = False + transliteration_variant: bool = False + + # Quality + difficulty: DifficultyLevel = DifficultyLevel.SIMPLE + has_failure: bool = False + failure_types: List[FailureType] = Field(default_factory=list) + failure_severity: Optional[str] = None # 'low' | 'medium' | 'high' + + # Recovery + has_recovery: bool = False + recovery_patterns: List[str] = Field(default_factory=list) + recovery_success: Optional[bool] = None + + # Tool efficiency + tool_count: int = 0 + estimated_optimal_calls: int = 0 + efficiency_score: float = 1.0 + + # Hard-example tags + hard_example: bool = False + hard_example_tags: List[HardExampleTag] = Field(default_factory=list) + + # Dedup + semantic_cluster_id: Optional[str] = None + repair_candidate: bool = False + + +# Resolve forward reference +Trajectory.model_rebuild() \ No newline at end of file diff --git a/training_setup_logs/trajectory/recovery_patterns.py b/training_setup_logs/trajectory/recovery_patterns.py new file mode 100644 index 0000000..c49aecc --- /dev/null +++ b/training_setup_logs/trajectory/recovery_patterns.py @@ -0,0 +1,171 @@ +""" +recovery_patterns.py +-------------------- +Detect and classify recovery behavior in trajectories. + +Recovery patterns are especially valuable training signal because they show: +- how agents handle failures, +- when fallback tools are used, +- whether self-correction leads to success. + +Detected patterns +----------------- +- retry_success : tool_fail → retry → success +- retry_failure : tool_fail → retry → failure (still) +- fallback_tool_success : tool_fail → different_tool → success +- fallback_tool_failure : tool_fail → different_tool → failure +- clarification_loop : user_query → assistant_clarification → user_answer +- self_correction : assistant contradicts its own prior turn and corrects +- graceful_degradation : acknowledged failure, provided partial answer +""" + +from __future__ import annotations + +import re +from typing import List, Tuple + +from .models import ToolCall, ToolCallStatus, Trajectory + + +# --------------------------------------------------------------------------- +# Pattern detection +# --------------------------------------------------------------------------- + + +def _get_failed_indices(tool_calls: List[ToolCall]) -> List[int]: + return [ + i for i, tc in enumerate(tool_calls) + if tc.status in (ToolCallStatus.FAILURE, ToolCallStatus.MISSING_RETURN, + ToolCallStatus.HALLUCINATED) + ] + + +def detect_retry_patterns(tool_calls: List[ToolCall]) -> List[str]: + """ + Detect retry sequences. + + Returns: + List of pattern strings: 'retry_success' | 'retry_failure' + """ + patterns = [] + for i, tc in enumerate(tool_calls): + if tc.retry_of is None: + continue + # This is a retry; check if it succeeded + if tc.status == ToolCallStatus.SUCCESS: + patterns.append("retry_success") + else: + patterns.append("retry_failure") + return patterns + + +def detect_fallback_patterns(tool_calls: List[ToolCall]) -> List[str]: + """ + Detect fallback-tool patterns. + + A fallback is flagged via tc.is_fallback == True. + + Returns: + List of pattern strings. + """ + patterns = [] + for tc in tool_calls: + if not tc.is_fallback: + continue + if tc.status == ToolCallStatus.SUCCESS: + patterns.append("fallback_tool_success") + else: + patterns.append("fallback_tool_failure") + return patterns + + +def detect_clarification_loop(trajectory: Trajectory) -> bool: + """ + Detect clarification loops: assistant asks a question, user answers. + + Heuristic: an assistant turn ending in '?' followed by a user turn. + """ + turns = trajectory.turns + for i in range(len(turns) - 1): + curr = turns[i] + nxt = turns[i + 1] + if ( + curr.role.value == "assistant" + and nxt.role.value == "user" + and curr.content.rstrip().endswith("?") + ): + return True + return False + + +def detect_self_correction(trajectory: Trajectory) -> bool: + """ + Detect self-correction: assistant explicitly walks back a prior statement. + + Heuristic: assistant turn contains correction markers. + """ + correction_markers = re.compile( + r"\b(actually|correction|i made an error|let me correct|i was wrong|" + r"apologies|my mistake|to clarify|i should clarify)\b", + re.I, + ) + asst_turns = [t for t in trajectory.turns if t.role.value == "assistant"] + # Only meaningful if there's more than one assistant turn + if len(asst_turns) < 2: + return False + return any(correction_markers.search(t.content) for t in asst_turns[1:]) + + +def detect_graceful_degradation(trajectory: Trajectory) -> bool: + """ + Detect graceful degradation: assistant acknowledges failure but provides + partial or alternative answer. + + Heuristic: failure + assistant turn with hedging language. + """ + failed_indices = _get_failed_indices(trajectory.tool_calls) + if not failed_indices: + return False + + hedging = re.compile( + r"\b(could not|wasn't able|unable to|unfortunately|" + r"however|alternatively|instead|partial|best i can)\b", + re.I, + ) + asst_turns = [t for t in trajectory.turns if t.role.value == "assistant"] + return any(hedging.search(t.content) for t in asst_turns) + + +# --------------------------------------------------------------------------- +# Aggregate miner +# --------------------------------------------------------------------------- + + +def mine_recovery_patterns(trajectory: Trajectory) -> Tuple[List[str], bool]: + """ + Mine all recovery patterns from a trajectory. + + Args: + trajectory: Input trajectory. + + Returns: + (patterns, recovery_success) + - patterns: list of detected pattern names + - recovery_success: True if at least one recovery was successful + """ + patterns: List[str] = [] + + patterns.extend(detect_retry_patterns(trajectory.tool_calls)) + patterns.extend(detect_fallback_patterns(trajectory.tool_calls)) + + if detect_clarification_loop(trajectory): + patterns.append("clarification_loop") + if detect_self_correction(trajectory): + patterns.append("self_correction") + if detect_graceful_degradation(trajectory): + patterns.append("graceful_degradation") + + success_patterns = {"retry_success", "fallback_tool_success"} + recovery_success = bool(patterns) and any(p in success_patterns for p in patterns) + + return patterns, recovery_success \ No newline at end of file