diff --git a/benchmarks/stability_benchmarks.py b/benchmarks/stability_benchmarks.py
new file mode 100644
index 0000000..0c4082f
--- /dev/null
+++ b/benchmarks/stability_benchmarks.py
@@ -0,0 +1,736 @@
+"""
+Stability Benchmarking Script for DiffMage
+
+This script uses the existing EvaluationBenchmarks class to test evaluation
+stability across real commit examples and curated scenarios.
+
+Usage:
+ python benchmarks/stability_benchmarks.py --repo-path /path/to/repo --runs 10
+ python benchmarks/stability_benchmarks.py --commit-range "HEAD~20..HEAD" --runs 5
+ python benchmarks/stability_benchmarks.py --test-examples --model gpt-4
+ python benchmarks/stability_benchmarks.py --batch-test --variance-threshold 0.2
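+
+Programmatic usage (a minimal sketch, assuming the benchmarks directory is on
+the import path and a default evaluation model is configured):
+
+ from stability_benchmarks import StabilityBenchmarkSuite
+
+ suite = StabilityBenchmarkSuite(repo_path=".", model_name=None)
+ examples = suite.get_curated_test_examples()
+ report = suite.run_batch_stability_test(examples, runs=3, variance_threshold=0.3)
+ print(report["summary"]["stability_rate"])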
+"""
+
+import argparse
+import json
+import time
+from pathlib import Path
+from typing import List, Dict, Any, Tuple, Optional
+from datetime import datetime
+import git
+import random
+
+from rich.console import Console
+from rich.table import Table
+from rich.panel import Panel
+from rich.progress import Progress
+from rich import box
+
+from diffmage.evaluation.benchmarks import EvaluationBenchmarks, StabilityTestResult
+from diffmage.evaluation.commit_message_evaluator import CommitMessageEvaluator
+from diffmage.git.diff_parser import GitDiffParser
+
+console = Console()
+
+
+class StabilityBenchmarkSuite:
+ """Comprehensive stability testing using real commit data"""
+
+ def __init__(self, repo_path: str = ".", model_name: Optional[str] = None):
+ self.repo_path = Path(repo_path)
+ self.repo = git.Repo(repo_path)
+ self.diff_parser = GitDiffParser(repo_path)
+ self.evaluator = CommitMessageEvaluator(model_name=model_name)
+ self.benchmarks = EvaluationBenchmarks(self.evaluator)
+ self.console = console
+
+ def get_real_commit_examples(
+ self, commit_range: Optional[str] = None, max_examples: int = 20
+ ) -> List[Tuple[str, str, str]]:
+ """Extract real commit examples from repository
+
+ Returns:
+ List of (commit_hash, commit_message, git_diff) tuples
+ """
+
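+ # Selection criteria (mirroring the checks in the loop below): merge commits
+ # are skipped, as are messages shorter than 10 characters and diffs longer
+ # than 10,000 characters.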
+ if commit_range:
+ try:
+ commits = list(self.repo.iter_commits(commit_range))
+ except Exception as e:
+ console.print(
+ f"[red]Error parsing commit range '{commit_range}': {e}[/red]"
+ )
+ commits = list(self.repo.iter_commits("HEAD", max_count=max_examples))
+ else:
+ commits = list(self.repo.iter_commits("HEAD", max_count=max_examples))
+
+ examples = []
+
+ console.print(
+ f"[blue]Extracting examples from {len(commits)} commits...[/blue]"
+ )
+
+ with Progress(console=self.console) as progress:
+ task = progress.add_task("Processing commits...", total=len(commits))
+
+ for commit in commits:
+ try:
+ message = str(commit.message).strip()
+
+ if len(commit.parents) > 1 or not message or len(message) < 10:
+ progress.update(task, advance=1)
+ continue
+
+ git_diff = self.diff_parser.parse_specific_commit(commit.hexsha)
+
+ if not git_diff or len(git_diff) > 10000:
+ progress.update(task, advance=1)
+ continue
+
+ examples.append((commit.hexsha, message, git_diff))
+
+ if len(examples) >= max_examples:
+ break
+
+ except Exception as e:
+ console.print(
+ f"[dim]Warning: Skipped commit {commit.hexsha[:8]}: {e}[/dim]"
+ )
+
+ progress.update(task, advance=1)
+
+ console.print(
+ f"[green]Successfully extracted {len(examples)} usable commit examples[/green]"
+ )
+ return examples
+
+ def get_curated_test_examples(self) -> List[Tuple[str, str, str]]:
+ """Get curated test examples with various commit patterns
+
+ Returns:
+ List of (example_id, commit_message, git_diff) tuples
+ """
+
+ examples = [
+ (
+ "bug_fix_null_check",
+ "fix: resolve null pointer exception in user validation",
+ """--- a/src/auth/UserValidator.java
++++ b/src/auth/UserValidator.java
+@@ -23,7 +23,7 @@ public class UserValidator {
+ }
+
+ public boolean isValidUser(User user) {
+- return user.getEmail().contains("@");
++ return user != null && user.getEmail() != null && user.getEmail().contains("@");
+ }
+ }""",
+ ),
+ (
+ "feature_dark_mode",
+ "feat: implement dark mode toggle in user preferences",
+ """--- a/src/components/UserPreferences.jsx
++++ b/src/components/UserPreferences.jsx
+@@ -8,6 +8,7 @@ const UserPreferences = () => {
+ const [language, setLanguage] = useState('en');
+ const [notifications, setNotifications] = useState(true);
++ const [darkMode, setDarkMode] = useState(false);
+
+ const savePreferences = async () => {
+ const prefs = {
+@@ -15,6 +16,7 @@ const UserPreferences = () => {
+ language,
+ notifications,
++ darkMode,
+ };
+
+ await api.updatePreferences(prefs);
+@@ -35,6 +37,13 @@ const UserPreferences = () => {
+ onChange={(e) => setNotifications(e.target.checked)}
+ />
+
++
++ <label>Dark Mode</label>
++ <input
++ type="checkbox"
++ checked={darkMode}
++ onChange={(e) => setDarkMode(e.target.checked)}
++ />
+
+ );""",
+ ),
+ (
+ "refactor_extract_method",
+ "refactor: extract user authentication logic into separate service",
+ """--- a/src/controllers/AuthController.js
++++ b/src/controllers/AuthController.js
+@@ -1,4 +1,5 @@
+ const bcrypt = require('bcrypt');
++const AuthService = require('../services/AuthService');
+
+ class AuthController {
+ async login(req, res) {
+@@ -6,15 +7,7 @@ class AuthController {
+
+ try {
+ // Authenticate user
+- const user = await User.findOne({ email });
+- if (!user) {
+- return res.status(401).json({ error: 'Invalid credentials' });
+- }
+-
+- const isValidPassword = await bcrypt.compare(password, user.passwordHash);
+- if (!isValidPassword) {
+- return res.status(401).json({ error: 'Invalid credentials' });
+- }
++ const user = await AuthService.authenticateUser(email, password);
+
+ const token = jwt.sign({ userId: user.id }, process.env.JWT_SECRET);
+ res.json({ token, user: { id: user.id, email: user.email } });
+@@ -24,4 +17,22 @@ class AuthController {
+ }
+ }
+
++--- /dev/null
+++++ b/src/services/AuthService.js
+@@ -0,0 +1,22 @@
++const bcrypt = require('bcrypt');
++const User = require('../models/User');
++
++class AuthService {
++ static async authenticateUser(email, password) {
++ const user = await User.findOne({ email });
++ if (!user) {
++ throw new Error('Invalid credentials');
++ }
++
++ const isValidPassword = await bcrypt.compare(password, user.passwordHash);
++ if (!isValidPassword) {
++ throw new Error('Invalid credentials');
++ }
++
++ return user;
++ }
++}
++
++module.exports = AuthService;""",
+ ),
+ (
+ "docs_api_update",
+ "docs: update API documentation for user endpoints",
+ """--- a/docs/api/users.md
++++ b/docs/api/users.md
+@@ -23,6 +23,20 @@ Creates a new user account.
+ - `409 Conflict` - Email already exists
+ - `500 Internal Server Error` - Server error
+
++### Request Example
++
++```json
++{
++ "email": "user@example.com",
++ "password": "securePassword123",
++ "firstName": "John",
++ "lastName": "Doe"
++}
++```
++
++### Response Example
++
++```json
+ ## GET /api/users/:id
+
+ Retrieves user information by ID.
+@@ -35,3 +49,15 @@ Retrieves user information by ID.
+ - `200 OK` - User found
+ - `404 Not Found` - User not found
+ - `500 Internal Server Error` - Server error
++
++### Response Example
++
++```json
++{
++ "id": "123e4567-e89b-12d3-a456-426614174000",
++ "email": "user@example.com",
++ "firstName": "John",
++ "lastName": "Doe",
++ "createdAt": "2024-01-15T10:30:00Z"
++}
++```""",
+ ),
+ (
+ "performance_optimize_query",
+ "perf: optimize database query for user search with indexing",
+ """--- a/src/models/User.js
++++ b/src/models/User.js
+@@ -15,8 +15,12 @@ const userSchema = new mongoose.Schema({
+ }
+ });
+
++// Add indexes for better query performance
++userSchema.index({ email: 1 });
++userSchema.index({ firstName: 1, lastName: 1 });
++userSchema.index({ createdAt: -1 });
++
+ // Static method for user search
+-userSchema.statics.searchUsers = function(searchTerm) {
++userSchema.statics.searchUsers = function(searchTerm, limit = 20) {
+ const regex = new RegExp(searchTerm, 'i');
+ return this.find({
+ $or: [
+@@ -24,7 +28,8 @@ userSchema.statics.searchUsers = function(searchTerm) {
+ { firstName: regex },
+ { lastName: regex }
+ ]
+- });
++ })
++ .limit(limit)
++ .sort({ createdAt: -1 });
+ };""",
+ ),
+ (
+ "test_add_validation",
+ "test: add unit tests for email validation utility",
+ """--- /dev/null
++++ b/tests/utils/emailValidator.test.js
+@@ -0,0 +1,45 @@
++const { validateEmail } = require('../../src/utils/emailValidator');
++
++describe('Email Validator', () => {
++ describe('valid emails', () => {
++ test('should validate standard email format', () => {
++ expect(validateEmail('user@example.com')).toBe(true);
++ expect(validateEmail('john.doe@company.org')).toBe(true);
++ expect(validateEmail('admin@subdomain.example.co.uk')).toBe(true);
++ });
++
++ test('should validate emails with plus signs', () => {
++ expect(validateEmail('user+tag@example.com')).toBe(true);
++ });
++
++ test('should validate emails with numbers', () => {
++ expect(validateEmail('user123@example123.com')).toBe(true);
++ });
++ });
++
++ describe('invalid emails', () => {
++ test('should reject emails without @ symbol', () => {
++ expect(validateEmail('userexample.com')).toBe(false);
++ });
++
++ test('should reject emails without domain', () => {
++ expect(validateEmail('user@')).toBe(false);
++ });
++
++ test('should reject emails without user part', () => {
++ expect(validateEmail('@example.com')).toBe(false);
++ });
++
++ test('should reject empty strings', () => {
++ expect(validateEmail('')).toBe(false);
++ });
++
++ test('should reject null and undefined', () => {
++ expect(validateEmail(null)).toBe(false);
++ expect(validateEmail(undefined)).toBe(false);
++ });
++ });
++});""",
+ ),
+ ]
+
+ console.print(f"[green]Loaded {len(examples)} curated test examples[/green]")
+ return examples
+
+ def run_single_stability_test(
+ self,
+ commit_hash: str,
+ message: str,
+ git_diff: str,
+ runs: int = 5,
+ variance_threshold: float = 0.3,
+ ) -> StabilityTestResult:
+ """Run stability test on a single commit example"""
+
+ console.print(f"[blue]Testing stability for: {message[:60]}...[/blue]")
+ console.print(
+ f"[dim]Commit: {commit_hash[:8] if len(commit_hash) > 8 else commit_hash}[/dim]"
+ )
+
+ result = self.benchmarks.stability_test(
+ message=message,
+ diff=git_diff,
+ runs=runs,
+ variance_threshold=variance_threshold,
+ )
+
+ return result
+
+ def run_batch_stability_test(
+ self,
+ examples: List[Tuple[str, str, str]],
+ runs: int = 5,
+ variance_threshold: float = 0.3,
+ max_examples: Optional[int] = None,
+ ) -> Dict[str, Any]:
+ """Run stability tests on multiple examples"""
+
+ if max_examples and len(examples) > max_examples:
+ examples = random.sample(examples, max_examples)
+
+ console.print(
+ f"[blue]Running batch stability test on {len(examples)} examples...[/blue]"
+ )
+ console.print(
+ f"[dim]Runs per example: {runs}, Variance threshold: {variance_threshold}[/dim]"
+ )
+
+ results = []
+ stable_count = 0
+ total_time = 0
+
+ with Progress(console=self.console) as progress:
+ task = progress.add_task("Testing stability...", total=len(examples))
+
+ for commit_hash, message, git_diff in examples:
+ start_time = time.time()
+
+ try:
+ result = self.benchmarks.stability_test(
+ message=message,
+ diff=git_diff,
+ runs=runs,
+ variance_threshold=variance_threshold,
+ )
+
+ results.append(result)
+
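+ # StabilityTestResult is treated as a mapping here; "is_stable" and
+ # "max_variance" are the only keys this suite relies on.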
+ if result["is_stable"]: # type: ignore
+ stable_count += 1
+
+ except Exception as e:
+ console.print(f"[red]Error testing {commit_hash[:8]}: {e}[/red]")
+ continue
+
+ total_time += time.time() - start_time
+ progress.update(task, advance=1)
+
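+ # stability_rate is computed over successfully tested examples
+ # (len(results)), not over the full example list.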
+ stability_rate = (stable_count / len(results)) * 100 if results else 0
+ avg_time_per_test = total_time / len(results) if results else 0
+
+ summary = {
+ "total_examples": len(examples),
+ "successful_tests": len(results),
+ "stable_examples": stable_count,
+ "stability_rate": stability_rate,
+ "average_time_per_test": avg_time_per_test,
+ "total_test_time": total_time,
+ "runs_per_example": runs,
+ "variance_threshold": variance_threshold,
+ }
+
+ self._display_batch_summary(summary)
+
+ return {
+ "summary": summary,
+ "individual_results": results,
+ "timestamp": datetime.now().isoformat(),
+ }
+
+ def _display_batch_summary(self, summary: Dict[str, Any]):
+ """Display batch test summary"""
+
+ console.print(
+ Panel(
+ f"[bold]Batch Stability Test Results[/bold]\n"
+ f"Examples Tested: {summary['successful_tests']}/{summary['total_examples']}\n"
+ f"Stable Examples: {summary['stable_examples']}\n"
+ f"Stability Rate: {summary['stability_rate']:.1f}%",
+ title="π Summary",
+ border_style="green"
+ if summary["stability_rate"] >= 80
+ else "yellow"
+ if summary["stability_rate"] >= 60
+ else "red",
+ )
+ )
+
+ # Detailed table
+ table = Table(title="Performance Metrics", box=box.SIMPLE)
+ table.add_column("Metric", style="cyan")
+ table.add_column("Value", justify="center")
+ table.add_column("Assessment", style="yellow")
+
+ # Stability rate assessment
+ if summary["stability_rate"] >= 90:
+ stability_assessment = "π’ Excellent"
+ elif summary["stability_rate"] >= 80:
+ stability_assessment = "π‘ Good"
+ elif summary["stability_rate"] >= 60:
+ stability_assessment = "π Acceptable"
+ else:
+ stability_assessment = "π΄ Needs Improvement"
+
+ # Time assessment
+ avg_time = summary["average_time_per_test"]
+ if avg_time < 30:
+ time_assessment = "π’ Fast"
+ elif avg_time < 60:
+ time_assessment = "π‘ Reasonable"
+ else:
+ time_assessment = "π΄ Slow"
+
+ table.add_row(
+ "Stability Rate", f"{summary['stability_rate']:.1f}%", stability_assessment
+ )
+ table.add_row("Avg Time per Test", f"{avg_time:.1f}s", time_assessment)
+ table.add_row("Total Test Time", f"{summary['total_test_time']:.1f}s", "")
+ table.add_row("Runs per Example", str(summary["runs_per_example"]), "")
+ table.add_row("Variance Threshold", str(summary["variance_threshold"]), "")
+
+ console.print(table)
+
+ def run_comparative_stability_test(
+ self, examples: List[Tuple[str, str, str]], models: List[str], runs: int = 3
+ ) -> Dict[str, Any]:
+ """Compare stability across different models"""
+
+ console.print(
+ f"[blue]Running comparative stability test across {len(models)} models...[/blue]"
+ )
+
+ model_results = {}
+
+ for model in models:
+ console.print(f"[yellow]Testing model: {model}[/yellow]")
+
+ # Create new evaluator for this model
+ evaluator = CommitMessageEvaluator(model_name=model)
+ benchmarks = EvaluationBenchmarks(evaluator)
+
+ model_stability_results = []
+
+ for i, (commit_hash, message, git_diff) in enumerate(
+ examples[:5]
+ ): # Limit for comparative test
+ try:
+ result = benchmarks.stability_test(
+ message=message,
+ diff=git_diff,
+ runs=runs,
+ variance_threshold=0.3,
+ )
+
+ model_stability_results.append(
+ {
+ "example_id": i,
+ "is_stable": result["is_stable"],
+ "max_variance": result["max_variance"],
+ "commit_hash": commit_hash,
+ }
+ )
+
+ except Exception as e:
+ console.print(
+ f"[red]Error with {model} on {commit_hash[:8]}: {e}[/red]"
+ )
+ continue
+
+ model_results[model] = model_stability_results
+
+ # Calculate comparative statistics
+ comparative_stats = self._calculate_comparative_stats(model_results)
+ self._display_comparative_results(comparative_stats)
+
+ return {
+ "model_results": model_results,
+ "comparative_stats": comparative_stats,
+ "timestamp": datetime.now().isoformat(),
+ }
+
+ def _calculate_comparative_stats(
+ self, model_results: Dict[str, List]
+ ) -> Dict[str, Any]:
+ """Calculate comparative statistics across models"""
+
+ stats = {}
+
+ for model, results in model_results.items():
+ if not results:
+ continue
+
+ stable_count = sum(1 for r in results if r["is_stable"])
+ total_count = len(results)
+ avg_variance = (
+ sum(r["max_variance"] for r in results) / total_count
+ if total_count > 0
+ else 0
+ )
+
+ stats[model] = {
+ "stability_rate": (stable_count / total_count * 100)
+ if total_count > 0
+ else 0,
+ "average_variance": avg_variance,
+ "stable_examples": stable_count,
+ "total_examples": total_count,
+ }
+
+ return stats
+
+ def _display_comparative_results(self, stats: Dict[str, Any]):
+ """Display comparative model results"""
+
+ table = Table(title="Model Stability Comparison", box=box.SIMPLE)
+ table.add_column("Model", style="cyan")
+ table.add_column("Stability Rate", justify="center")
+ table.add_column("Avg Variance", justify="center")
+ table.add_column("Stable/Total", justify="center")
+ table.add_column("Assessment", style="yellow")
+
+ for model, model_stats in stats.items():
+ rate = model_stats["stability_rate"]
+
+ if rate >= 80:
+ assessment = "🟢 Excellent"
+ elif rate >= 60:
+ assessment = "🟡 Good"
+ else:
+ assessment = "🔴 Needs Work"
+
+ table.add_row(
+ model,
+ f"{rate:.1f}%",
+ f"{model_stats['average_variance']:.3f}",
+ f"{model_stats['stable_examples']}/{model_stats['total_examples']}",
+ assessment,
+ )
+
+ console.print(table)
+
+
+def main():
+ """Main CLI interface"""
+
+ parser = argparse.ArgumentParser(
+ description="Stability Benchmarking Script for DiffMage"
+ )
+ parser.add_argument("--repo-path", default=".", help="Path to git repository")
+ parser.add_argument("--model", help="AI model to test (default: uses your default)")
+ parser.add_argument("--runs", type=int, default=5, help="Number of runs per test")
+ parser.add_argument(
+ "--variance-threshold",
+ type=float,
+ default=0.3,
+ help="Variance threshold for stability",
+ )
+
+ # Test data options
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument("--commit-range", help="Git commit range (e.g., HEAD~10..HEAD)")
+ group.add_argument(
+ "--test-examples", action="store_true", help="Use curated test examples"
+ )
+ group.add_argument("--single-commit", help="Test single commit by hash")
+
+ # Test options
+ parser.add_argument(
+ "--batch-test", action="store_true", help="Run batch stability test"
+ )
+ parser.add_argument("--comparative-test", help="Compare models (comma-separated)")
+ parser.add_argument(
+ "--max-examples", type=int, default=10, help="Maximum examples to test"
+ )
+ parser.add_argument("--output", help="Save results to JSON file")
+
+ args = parser.parse_args()
+
+ try:
+ suite = StabilityBenchmarkSuite(args.repo_path, args.model)
+ except Exception as e:
+ console.print(f"[red]Error initializing benchmark suite: {e}[/red]")
+ return 1
+
+ results = None
+
+ try:
+ if args.single_commit:
+ # Test single commit
+ commits = list(suite.repo.iter_commits(args.single_commit, max_count=1))
+ if not commits:
+ console.print(f"[red]Commit not found: {args.single_commit}[/red]")
+ return 1
+
+ commit = commits[0]
+ message = str(commit.message).strip()
+ # parse_specific_commit is used elsewhere in this suite as returning the
+ # diff text directly, so the same convention is assumed here.
+ git_diff = suite.diff_parser.parse_specific_commit(commit.hexsha)
+
+ results = suite.run_single_stability_test(
+ commit.hexsha, message, git_diff, args.runs, args.variance_threshold
+ )
+
+ elif args.test_examples:
+ # Use curated examples
+ examples = suite.get_curated_test_examples()
+
+ if args.comparative_test:
+ models = [m.strip() for m in args.comparative_test.split(",")]
+ results = suite.run_comparative_stability_test(
+ examples, models, args.runs
+ )
+ elif args.batch_test:
+ results = suite.run_batch_stability_test(
+ examples, args.runs, args.variance_threshold, args.max_examples
+ )
+ else:
+ # Test first example
+ commit_hash, message, git_diff = examples[0]
+ results = suite.run_single_stability_test(
+ commit_hash, message, git_diff, args.runs, args.variance_threshold
+ )
+
+ else:
+ # Use real repository commits
+ examples = suite.get_real_commit_examples(
+ args.commit_range, args.max_examples
+ )
+
+ if not examples:
+ console.print("[red]No suitable commit examples found[/red]")
+ return 1
+
+ if args.comparative_test:
+ models = [m.strip() for m in args.comparative_test.split(",")]
+ results = suite.run_comparative_stability_test(
+ examples, models, args.runs
+ )
+ elif args.batch_test:
+ results = suite.run_batch_stability_test(
+ examples, args.runs, args.variance_threshold, args.max_examples
+ )
+ else:
+ # Test first example
+ commit_hash, message, git_diff = examples[0]
+ results = suite.run_single_stability_test(
+ commit_hash, message, git_diff, args.runs, args.variance_threshold
+ )
+
+ # Save results if requested
+ if args.output and results:
+ with open(args.output, "w") as f:
+ json.dump(results, f, indent=2, default=str)
+ console.print(f"[green]Results saved to {args.output}[/green]")
+
+ console.print("\n[green]Benchmarking completed successfully! β
[/green]")
+
+ except Exception as e:
+ console.print(f"[red]Error during benchmarking: {e}[/red]")
+ return 1
+
+ return 0
+
+
+if __name__ == "__main__":
+ exit(main())
diff --git a/benchmarks/validation_suite.py b/benchmarks/validation_suite.py
new file mode 100644
index 0000000..bd1ebff
--- /dev/null
+++ b/benchmarks/validation_suite.py
@@ -0,0 +1,924 @@
+"""
+Evaluation *Smoke Test* Validation Suite for DiffMage
+
+This script performs fundamental sanity checks on the LLM evaluator to ensure
+it is working correctly before running extensive benchmarks.
+
+Usage:
+ python benchmarks/validation_suite.py --all
+ python benchmarks/validation_suite.py --test obvious-cases
+ python benchmarks/validation_suite.py --test ranking-consistency
+ python benchmarks/validation_suite.py --test score-distribution
+ python benchmarks/validation_suite.py --test edge-cases
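+
+Programmatic usage (a minimal sketch, assuming the benchmarks directory is on
+the import path and an evaluation model is configured):
+
+ from validation_suite import EvaluationValidator
+
+ validator = EvaluationValidator(model_name=None)
+ report = validator.run_all_tests()
+ print(report["overall_pass_rate"], report["evaluator_ready"])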
+"""
+
+import argparse
+import statistics
+from typing import List, Dict, Any, Tuple, Optional
+from dataclasses import dataclass
+
+from rich.console import Console
+from rich.table import Table
+from rich.panel import Panel
+from rich.progress import Progress
+from rich import box
+
+try:
+ from diffmage.evaluation.commit_message_evaluator import CommitMessageEvaluator
+except ImportError as e:
+ print(f"Error importing DiffMage modules: {e}")
+ print("Make sure you're running from the project root and the package is installed")
+ exit(1)
+
+console = Console()
+
+
+@dataclass
+class ValidationCase:
+ """A test case for validation"""
+
+ name: str
+ commit_message: str
+ git_diff: str
+ expected_score_range: Tuple[float, float] # (min, max)
+ expected_quality: str # "excellent", "good", "average", "poor", "very_poor"
+ description: str
+
+
+class EvaluationValidator:
+ """Validates that the LLM evaluator is working correctly"""
+
+ def __init__(self, model_name: Optional[str] = None):
+ self.evaluator = CommitMessageEvaluator(model_name=model_name)
+ self.console = console
+
+ def get_obvious_test_cases(self) -> List[ValidationCase]:
+ """Get test cases with obvious expected outcomes"""
+
+ return [
+ # EXCELLENT cases (4.5-5.0)
+ ValidationCase(
+ name="security_fix",
+ commit_message="fix: resolve critical SQL injection vulnerability in user authentication",
+ git_diff="""--- a/src/auth/UserAuth.py
++++ b/src/auth/UserAuth.py
+@@ -23,7 +23,8 @@ class UserAuth:
+ def authenticate_user(self, username, password):
+- query = f"SELECT * FROM users WHERE username='{username}' AND password='{password}'"
++ query = "SELECT * FROM users WHERE username=? AND password=?"
++ return self.db.execute(query, (username, password))
+- return self.db.execute(query)""",
+ expected_score_range=(4.0, 5.0),
+ expected_quality="excellent",
+ description="Clear security fix with good explanation",
+ ),
+ ValidationCase(
+ name="feature_with_context",
+ commit_message="feat: implement user password reset with email verification\n\nAdds secure password reset flow:\n- Generate time-limited reset tokens\n- Send verification emails\n- Validate tokens before allowing reset\n- Log security events for auditing",
+ git_diff="""--- a/src/auth/PasswordReset.py
++++ b/src/auth/PasswordReset.py
+@@ -0,0 +1,45 @@
++import secrets
++import hashlib
++from datetime import datetime, timedelta
++
++class PasswordReset:
++ def __init__(self, email_service, user_service):
++ self.email_service = email_service
++ self.user_service = user_service
++
++ def request_reset(self, email):
++ user = self.user_service.find_by_email(email)
++ if not user:
++ return False # Don't reveal if email exists
++
++ token = secrets.token_urlsafe(32)
++ expires_at = datetime.now() + timedelta(hours=1)
++
++ self.user_service.store_reset_token(user.id, token, expires_at)
++ self.email_service.send_reset_email(email, token)
++
++ return True""",
+ expected_score_range=(4.0, 5.0),
+ expected_quality="excellent",
+ description="Comprehensive feature with detailed explanation",
+ ),
+ # GOOD cases (3.5-4.0)
+ ValidationCase(
+ name="simple_bug_fix",
+ commit_message="fix: handle null values in user profile display",
+ git_diff="""--- a/src/components/UserProfile.jsx
++++ b/src/components/UserProfile.jsx
+@@ -15,7 +15,7 @@ const UserProfile = ({ user }) => {
+ return (
+ <div>
+ <h2>{user.name}</h2>
+- <p>Email: {user.email}</p>
++ <p>Email: {user.email || 'Not provided'}</p>
+- <p>Bio: {user.bio}</p>
++ <p>Bio: {user.bio || 'No bio available'}</p>
+ </div>
+ );""",
+ expected_score_range=(3.0, 4.5),
+ expected_quality="good",
+ description="Clear bug fix, could use more detail",
+ ),
+ # AVERAGE cases (2.5-3.5)
+ ValidationCase(
+ name="generic_update",
+ commit_message="update user component",
+ git_diff="""--- a/src/components/User.jsx
++++ b/src/components/User.jsx
+@@ -10,6 +10,7 @@ const User = ({ userData }) => {
+ return (
+ <div>
+ <span>{userData.name}</span>
++ <span>{userData.role}</span>
+ </div>
+ );""",
+ expected_score_range=(2.0, 3.5),
+ expected_quality="average",
+ description="Vague message, minimal change",
+ ),
+ # POOR cases (1.5-2.5)
+ ValidationCase(
+ name="meaningless_message",
+ commit_message="fix stuff",
+ git_diff="""--- a/src/utils/helper.js
++++ b/src/utils/helper.js
+@@ -5,7 +5,7 @@ function processData(data) {
+ if (!data) {
+ return null;
+ }
+- return data.map(item => item.value);
++ return data.map(item => item.value || 0);
+ }""",
+ expected_score_range=(1.0, 2.5),
+ expected_quality="poor",
+ description="Meaningless commit message",
+ ),
+ # VERY POOR cases (1.0-2.0)
+ ValidationCase(
+ name="gibberish",
+ commit_message="asdf jkl; qwerty",
+ git_diff="""--- a/src/test.js
++++ b/src/test.js
+@@ -1,3 +1,4 @@
+ // Test file
+ console.log('hello');
++console.log('world');""",
+ expected_score_range=(1.0, 2.0),
+ expected_quality="very_poor",
+ description="Nonsensical commit message",
+ ),
+ ]
+
+ def get_edge_test_cases(self) -> List[ValidationCase]:
+ """Get edge cases that might break the evaluator"""
+
+ return [
+ ValidationCase(
+ name="empty_message",
+ commit_message="",
+ git_diff="--- a/file.txt\n+++ b/file.txt\n@@ -1 +1,2 @@\n hello\n+world",
+ expected_score_range=(1.0, 2.0),
+ expected_quality="very_poor",
+ description="Empty commit message",
+ ),
+ ValidationCase(
+ name="very_long_message",
+ commit_message="fix: "
+ + "very " * 100
+ + "long commit message that goes on and on and provides way too much detail about a simple change "
+ * 10,
+ git_diff="--- a/file.txt\n+++ b/file.txt\n@@ -1 +1,2 @@\n hello\n+world",
+ expected_score_range=(1.0, 3.0),
+ expected_quality="poor",
+ description="Excessively long commit message",
+ ),
+ ValidationCase(
+ name="special_characters",
+ commit_message="fix: handle Γ©mojis π and Γ±oΓ±-ASCII Γ§haracters in ΓΌser input",
+ git_diff="""--- a/src/input.py
++++ b/src/input.py
+@@ -1,3 +1,4 @@
+ def process_input(text):
++ text = text.encode('utf-8').decode('utf-8')
+ return text.strip()""",
+ expected_score_range=(3.0, 4.5),
+ expected_quality="good",
+ description="Special characters and emojis",
+ ),
+ ValidationCase(
+ name="no_diff",
+ commit_message="fix: important bug fix",
+ git_diff="",
+ expected_score_range=(1.0, 3.0),
+ expected_quality="poor",
+ description="No diff provided",
+ ),
+ ValidationCase(
+ name="merge_commit",
+ commit_message="Merge branch 'feature/user-auth' into main",
+ git_diff="""--- a/src/auth.py
++++ b/src/auth.py
+@@ -1,10 +1,20 @@
+ # Auth module
++# New authentication features
+
+ def login(user, password):
+ return validate_credentials(user, password)
++
++def logout(user):
++ clear_session(user)
++
++def reset_password(email):
++ send_reset_email(email)""",
+ expected_score_range=(1.0, 3.0),
+ expected_quality="poor",
+ description="Merge commit (usually auto-generated)",
+ ),
+ ]
+
+ def test_obvious_cases(self) -> Dict[str, Any]:
+ """Test if evaluator handles obviously good/bad cases correctly"""
+
+ console.print(
+ Panel(
+ "[bold]Testing Obvious Cases[/bold]\n"
+ "Evaluator should clearly distinguish between excellent and poor commit messages",
+ title="π― Obvious Cases Test",
+ border_style="blue",
+ )
+ )
+
+ test_cases = self.get_obvious_test_cases()
+ results = []
+
+ with Progress(console=self.console) as progress:
+ task = progress.add_task(
+ "Evaluating obvious cases...", total=len(test_cases)
+ )
+
+ for case in test_cases:
+ try:
+ result = self.evaluator.evaluate_commit_message(
+ case.commit_message, case.git_diff
+ )
+
+ # Check if score is in expected range
+ in_range = (
+ case.expected_score_range[0]
+ <= result.overall_score
+ <= case.expected_score_range[1]
+ )
+
+ results.append(
+ {
+ "case": case,
+ "result": result,
+ "in_expected_range": in_range,
+ "score_deviation": self._calculate_deviation(
+ result.overall_score, case.expected_score_range
+ ),
+ }
+ )
+
+ except Exception as e:
+ console.print(f"[red]Error evaluating {case.name}: {e}[/red]")
+ results.append(
+ {
+ "case": case,
+ "result": None,
+ "error": str(e),
+ "in_expected_range": False,
+ "score_deviation": float("inf"),
+ }
+ )
+
+ progress.update(task, advance=1)
+
+ # Analyze results
+ success_rate = (
+ sum(1 for r in results if r.get("in_expected_range", False))
+ / len(results)
+ * 100
+ )
+ avg_deviation = statistics.mean(
+ [
+ r["score_deviation"]
+ for r in results
+ if r["score_deviation"] != float("inf")
+ ]
+ )
+
+ # Display results
+ self._display_obvious_cases_results(results, success_rate, avg_deviation)
+
+ return {
+ "test_name": "obvious_cases",
+ "success_rate": success_rate,
+ "average_deviation": avg_deviation,
+ "results": results,
+ "passed": success_rate >= 70, # 70% of obvious cases should be correct
+ }
+
+ def test_ranking_consistency(self) -> Dict[str, Any]:
+ """Test if evaluator consistently ranks messages in logical order"""
+
+ console.print(
+ Panel(
+ "[bold]Testing Ranking Consistency[/bold]\n"
+ "Evaluator should rank obviously better messages higher than worse ones",
+ title="π Ranking Test",
+ border_style="green",
+ )
+ )
+
+ # Get subset of test cases for ranking
+ test_cases = self.get_obvious_test_cases()
+
+ # Sort by expected quality for comparison
+ expected_order = sorted(
+ test_cases, key=lambda x: x.expected_score_range[1], reverse=True
+ )
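+ # Cases are ordered by the upper bound of their expected score range;
+ # sorted() is stable, so ties keep their original relative order.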
+
+ # Evaluate all cases
+ evaluated_cases = []
+
+ with Progress(console=self.console) as progress:
+ task = progress.add_task("Evaluating for ranking...", total=len(test_cases))
+
+ for case in test_cases:
+ try:
+ result = self.evaluator.evaluate_commit_message(
+ case.commit_message, case.git_diff
+ )
+ evaluated_cases.append((case, result))
+ except Exception as e:
+ console.print(f"[red]Error evaluating {case.name}: {e}[/red]")
+ continue
+
+ progress.update(task, advance=1)
+
+ # Sort by actual scores
+ actual_order = sorted(
+ evaluated_cases, key=lambda x: x[1].overall_score, reverse=True
+ )
+
+ # Calculate ranking consistency
+ ranking_violations = self._count_ranking_violations(
+ expected_order, actual_order
+ )
+ total_pairs = len(evaluated_cases) * (len(evaluated_cases) - 1) // 2
+ consistency_rate = (
+ (total_pairs - ranking_violations) / total_pairs * 100
+ if total_pairs > 0
+ else 0
+ )
+
+ # Display results
+ self._display_ranking_results(
+ expected_order, actual_order, consistency_rate, ranking_violations
+ )
+
+ return {
+ "test_name": "ranking_consistency",
+ "consistency_rate": consistency_rate,
+ "ranking_violations": ranking_violations,
+ "total_pairs": total_pairs,
+ "passed": consistency_rate >= 80, # 80% of rankings should be consistent
+ }
+
+ def test_score_distribution(self) -> Dict[str, Any]:
+ """Test if evaluator uses the full score range appropriately"""
+
+ console.print(
+ Panel(
+ "[bold]Testing Score Distribution[/bold]\n"
+ "Evaluator should use the full 1-5 scale and not cluster around one value",
+ title="π Distribution Test",
+ border_style="yellow",
+ )
+ )
+
+ all_cases = self.get_obvious_test_cases() + self.get_edge_test_cases()
+ scores = []
+
+ with Progress(console=self.console) as progress:
+ task = progress.add_task("Collecting scores...", total=len(all_cases))
+
+ for case in all_cases:
+ try:
+ result = self.evaluator.evaluate_commit_message(
+ case.commit_message, case.git_diff
+ )
+ scores.append(
+ {
+ "case_name": case.name,
+ "overall_score": result.overall_score,
+ "what_score": result.what_score,
+ "why_score": result.why_score,
+ "expected_quality": case.expected_quality,
+ }
+ )
+ except Exception as e:
+ console.print(f"[red]Error evaluating {case.name}: {e}[/red]")
+ continue
+
+ progress.update(task, advance=1)
+
+ if not scores:
+ return {
+ "test_name": "score_distribution",
+ "passed": False,
+ "error": "No scores collected",
+ }
+
+ # Analyze distribution
+ overall_scores = [s["overall_score"] for s in scores]
+ distribution_stats = {
+ "mean": statistics.mean(overall_scores),
+ "median": statistics.median(overall_scores),
+ "std_dev": statistics.stdev(overall_scores)
+ if len(overall_scores) > 1
+ else 0,
+ "min": min(overall_scores),
+ "max": max(overall_scores),
+ "range": max(overall_scores) - min(overall_scores),
+ }
+
+ # Check for problems
+ problems = []
+ if distribution_stats["std_dev"] < 0.5:
+ problems.append("Low variance - scores too clustered")
+ if distribution_stats["range"] < 2.0:
+ problems.append("Narrow range - not using full scale")
+ if distribution_stats["mean"] > 4.0:
+ problems.append("Grade inflation - scores too high")
+ if distribution_stats["mean"] < 2.0:
+ problems.append("Grade deflation - scores too low")
+
+ # Display results
+ self._display_distribution_results(distribution_stats, scores, problems)
+
+ return {
+ "test_name": "score_distribution",
+ "distribution_stats": distribution_stats,
+ "problems": problems,
+ "scores": scores,
+ "passed": len(problems) == 0,
+ }
+
+ def test_edge_cases(self) -> Dict[str, Any]:
+ """Test if evaluator handles edge cases gracefully"""
+
+ console.print(
+ Panel(
+ "[bold]Testing Edge Cases[/bold]\n"
+ "Evaluator should handle unusual inputs without crashing",
+ title="β οΈ Edge Cases Test",
+ border_style="red",
+ )
+ )
+
+ edge_cases = self.get_edge_test_cases()
+ results = []
+
+ with Progress(console=self.console) as progress:
+ task = progress.add_task("Testing edge cases...", total=len(edge_cases))
+
+ for case in edge_cases:
+ try:
+ result = self.evaluator.evaluate_commit_message(
+ case.commit_message, case.git_diff
+ )
+
+ # Check if result is reasonable
+ is_reasonable = (
+ 1.0 <= result.overall_score <= 5.0
+ and 1.0 <= result.what_score <= 5.0
+ and 1.0 <= result.why_score <= 5.0
+ and result.reasoning
+ and len(result.reasoning) > 10
+ )
+
+ results.append(
+ {
+ "case": case,
+ "result": result,
+ "handled_gracefully": True,
+ "is_reasonable": is_reasonable,
+ "error": None,
+ }
+ )
+
+ except Exception as e:
+ results.append(
+ {
+ "case": case,
+ "result": None,
+ "handled_gracefully": False,
+ "is_reasonable": False,
+ "error": str(e),
+ }
+ )
+
+ progress.update(task, advance=1)
+
+ # Analyze results
+ graceful_handling_rate = (
+ sum(1 for r in results if r["handled_gracefully"]) / len(results) * 100
+ )
+ reasonable_results_rate = (
+ sum(1 for r in results if r.get("is_reasonable", False))
+ / len(results)
+ * 100
+ )
+
+ # Display results
+ self._display_edge_cases_results(
+ results, graceful_handling_rate, reasonable_results_rate
+ )
+
+ return {
+ "test_name": "edge_cases",
+ "graceful_handling_rate": graceful_handling_rate,
+ "reasonable_results_rate": reasonable_results_rate,
+ "results": results,
+ "passed": graceful_handling_rate >= 90 and reasonable_results_rate >= 70,
+ }
+
+ def run_all_tests(self) -> Dict[str, Any]:
+ """Run all validation tests"""
+
+ console.print(
+ Panel(
+ "[bold]Running Complete Validation Suite[/bold]\n"
+ "Testing fundamental evaluator functionality",
+ title="π§ͺ Validation Suite",
+ border_style="cyan",
+ )
+ )
+
+ all_results = {}
+
+ # Run each test
+ tests = [
+ ("obvious_cases", self.test_obvious_cases),
+ ("ranking_consistency", self.test_ranking_consistency),
+ ("score_distribution", self.test_score_distribution),
+ ("edge_cases", self.test_edge_cases),
+ ]
+
+ for test_name, test_func in tests:
+ console.print(
+ f"\n[blue]Running {test_name.replace('_', ' ').title()} test...[/blue]"
+ )
+ try:
+ all_results[test_name] = test_func()
+ except Exception as e:
+ console.print(f"[red]Test {test_name} failed with error: {e}[/red]")
+ all_results[test_name] = {
+ "test_name": test_name,
+ "passed": False,
+ "error": str(e),
+ }
+
+ # Overall assessment
+ passed_tests = sum(
+ 1 for result in all_results.values() if result.get("passed", False)
+ )
+ total_tests = len(all_results)
+ overall_pass_rate = passed_tests / total_tests * 100
+ model_used: str = self.evaluator.model_name
+
+ # Display summary
+ self._display_overall_summary(all_results, overall_pass_rate)
+
+ return {
+ "overall_pass_rate": overall_pass_rate,
+ "passed_tests": passed_tests,
+ "total_tests": total_tests,
+ "model_used": model_used,
+ "individual_results": all_results,
+ "evaluator_ready": overall_pass_rate >= 75, # 75% of tests should pass
+ }
+
+ # Helper methods for calculations and display
+ def _calculate_deviation(
+ self, actual_score: float, expected_range: Tuple[float, float]
+ ) -> float:
+ """Calculate how far actual score is from expected range"""
+ if expected_range[0] <= actual_score <= expected_range[1]:
+ return 0.0
+ elif actual_score < expected_range[0]:
+ return expected_range[0] - actual_score
+ else:
+ return actual_score - expected_range[1]
+
+ def _count_ranking_violations(
+ self, expected_order: List, actual_order: List
+ ) -> int:
+ """Count pairs that are ranked in wrong order"""
+ violations = 0
+ actual_scores = {case.name: result.overall_score for case, result in actual_order}
+
+ for i, case1 in enumerate(expected_order):
+ for case2 in expected_order[i + 1 :]:
+ # case1 should score higher than case2
+ if actual_scores.get(case1.name, 0) < actual_scores.get(case2.name, 0):
+ violations += 1
+
+ return violations
+
+ def _display_obvious_cases_results(
+ self, results: List, success_rate: float, avg_deviation: float
+ ):
+ """Display obvious cases test results"""
+
+ table = Table(title="Obvious Cases Results", box=box.SIMPLE)
+ table.add_column("Case", style="cyan")
+ table.add_column("Expected", justify="center")
+ table.add_column("Actual", justify="center")
+ table.add_column("Status", justify="center")
+ table.add_column("Deviation", justify="center")
+
+ for r in results:
+ if r.get("result"):
+ status = "β
Pass" if r["in_expected_range"] else "β Fail"
+ status_color = "green" if r["in_expected_range"] else "red"
+
+ expected_range = f"{r['case'].expected_score_range[0]:.1f}-{r['case'].expected_score_range[1]:.1f}"
+ actual_score = f"{r['result'].overall_score:.1f}"
+ deviation = f"{r['score_deviation']:.1f}"
+
+ table.add_row(
+ r["case"].name,
+ expected_range,
+ actual_score,
+ f"[{status_color}]{status}[/{status_color}]",
+ deviation,
+ )
+ else:
+ table.add_row(
+ r["case"].name, "N/A", "ERROR", "[red]β Error[/red]", "β"
+ )
+
+ console.print(table)
+
+ # Summary
+ summary_color = (
+ "green" if success_rate >= 70 else "yellow" if success_rate >= 50 else "red"
+ )
+ console.print(
+ f"\n[{summary_color}]Success Rate: {success_rate:.1f}% | Average Deviation: {avg_deviation:.2f}[/{summary_color}]"
+ )
+
+ def _display_ranking_results(
+ self,
+ expected_order: List,
+ actual_order: List,
+ consistency_rate: float,
+ violations: int,
+ ):
+ """Display ranking consistency results"""
+
+ table = Table(title="Ranking Comparison", box=box.SIMPLE)
+ table.add_column("Rank", justify="center")
+ table.add_column("Expected", style="cyan")
+ table.add_column("Actual", style="yellow")
+ table.add_column("Score", justify="center")
+
+ for i, (case, result) in enumerate(actual_order[:5]): # Show top 5
+ expected_case = expected_order[i] if i < len(expected_order) else None
+ expected_name = expected_case.name if expected_case else "N/A"
+
+ table.add_row(
+ str(i + 1), expected_name, case.name, f"{result.overall_score:.1f}"
+ )
+
+ console.print(table)
+
+ # Summary
+ summary_color = (
+ "green"
+ if consistency_rate >= 80
+ else "yellow"
+ if consistency_rate >= 60
+ else "red"
+ )
+ console.print(
+ f"\n[{summary_color}]Ranking Consistency: {consistency_rate:.1f}% | Violations: {violations}[/{summary_color}]"
+ )
+
+ def _display_distribution_results(self, stats: Dict, scores: List, problems: List):
+ """Display score distribution results"""
+
+ table = Table(title="Score Distribution Statistics", box=box.SIMPLE)
+ table.add_column("Metric", style="cyan")
+ table.add_column("Value", justify="center")
+ table.add_column("Assessment", style="yellow")
+
+ table.add_row(
+ "Mean",
+ f"{stats['mean']:.2f}",
+ "Good" if 2.5 <= stats["mean"] <= 3.5 else "Check",
+ )
+ table.add_row(
+ "Std Dev",
+ f"{stats['std_dev']:.2f}",
+ "Good" if stats["std_dev"] >= 0.5 else "Low variance",
+ )
+ table.add_row(
+ "Range",
+ f"{stats['range']:.2f}",
+ "Good" if stats["range"] >= 2.0 else "Narrow range",
+ )
+ table.add_row("Min", f"{stats['min']:.2f}", "")
+ table.add_row("Max", f"{stats['max']:.2f}", "")
+
+ console.print(table)
+
+ if problems:
+ console.print("\n[red]Issues Found:[/red]")
+ for problem in problems:
+ console.print(f" β’ {problem}")
+ else:
+ console.print("\n[green]β
Score distribution looks healthy[/green]")
+
+ def _display_edge_cases_results(
+ self, results: List, graceful_rate: float, reasonable_rate: float
+ ):
+ """Display edge cases test results"""
+
+ table = Table(title="Edge Cases Results", box=box.SIMPLE)
+ table.add_column("Case", style="cyan")
+ table.add_column("Handled", justify="center")
+ table.add_column("Reasonable", justify="center")
+ table.add_column("Error", style="red")
+
+ for r in results:
+ handled = "β
Yes" if r["handled_gracefully"] else "β No"
+ reasonable = "β
Yes" if r.get("is_reasonable", False) else "β No"
+ error = (
+ r.get("error", "")[:50] + "..."
+ if r.get("error") and len(r.get("error", "")) > 50
+ else r.get("error", "")
+ )
+
+ table.add_row(r["case"].name, handled, reasonable, error)
+
+ console.print(table)
+
+ console.print(
+ f"\n[blue]Graceful Handling: {graceful_rate:.1f}% | Reasonable Results: {reasonable_rate:.1f}%[/blue]"
+ )
+
+ def _display_overall_summary(self, all_results: Dict, overall_pass_rate: float):
+ """Display overall validation summary"""
+
+ console.print(
+ Panel(
+ f"[bold]Validation Complete[/bold]\n"
+ f"Overall Pass Rate: {overall_pass_rate:.1f}%",
+ title="π Summary",
+ border_style="green"
+ if overall_pass_rate >= 75
+ else "yellow"
+ if overall_pass_rate >= 50
+ else "red",
+ )
+ )
+
+ summary_table = Table(title="Test Summary", box=box.SIMPLE)
+ summary_table.add_column("Test", style="cyan")
+ summary_table.add_column("Status", justify="center")
+ summary_table.add_column("Key Metric", justify="center")
+
+ for test_name, result in all_results.items():
+ status = "β
Pass" if result.get("passed", False) else "β Fail"
+ status_color = "green" if result.get("passed", False) else "red"
+
+ # Extract key metric based on test type
+ if test_name == "obvious_cases":
+ key_metric = f"{result.get('success_rate', 0):.1f}% correct"
+ elif test_name == "ranking_consistency":
+ key_metric = f"{result.get('consistency_rate', 0):.1f}% consistent"
+ elif test_name == "score_distribution":
+ problems = result.get("problems", [])
+ key_metric = f"{len(problems)} issues"
+ elif test_name == "edge_cases":
+ key_metric = f"{result.get('graceful_handling_rate', 0):.1f}% handled"
+ else:
+ key_metric = "N/A"
+
+ summary_table.add_row(
+ test_name.replace("_", " ").title(),
+ f"[{status_color}]{status}[/{status_color}]",
+ key_metric,
+ )
+
+ console.print(summary_table)
+
+ # Recommendations
+ if overall_pass_rate >= 75:
+ console.print(
+ "\n[green]π Evaluator validation passed! Ready for benchmarking and research.[/green]"
+ )
+ elif overall_pass_rate >= 50:
+ console.print(
+ "\n[yellow]β οΈ Evaluator has some issues but may be usable. Review failed tests.[/yellow]"
+ )
+ else:
+ console.print(
+ "\n[red]β Evaluator has significant issues. Fix core problems before proceeding.[/red]"
+ )
+
+
+def main():
+ """Main CLI interface"""
+
+ parser = argparse.ArgumentParser(
+ description="Validation Suite for DiffMage Evaluator"
+ )
+ parser.add_argument("--model", help="AI model to test (default: uses your default)")
+
+ test_group = parser.add_mutually_exclusive_group()
+ test_group.add_argument(
+ "--all", action="store_true", help="Run all validation tests"
+ )
+ test_group.add_argument(
+ "--test",
+ choices=[
+ "obvious-cases",
+ "ranking-consistency",
+ "score-distribution",
+ "edge-cases",
+ ],
+ help="Run specific test",
+ )
+
+ parser.add_argument("--output", help="Save results to JSON file")
+ parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
+
+ args = parser.parse_args()
+
+ if not args.test and not args.all:
+ args.all = True
+
+ try:
+ validator = EvaluationValidator(model_name=args.model)
+ except Exception as e:
+ console.print(f"[red]Error initializing validator: {e}[/red]")
+ console.print(
+ "[dim]Make sure you're in the project root and DiffMage is properly installed[/dim]"
+ )
+ return 1
+
+ results = None
+
+ try:
+ if args.all:
+ console.print("[blue]Running complete validation suite...[/blue]")
+ results = validator.run_all_tests()
+
+ elif args.test == "obvious-cases":
+ results = validator.test_obvious_cases()
+
+ elif args.test == "ranking-consistency":
+ results = validator.test_ranking_consistency()
+
+ elif args.test == "score-distribution":
+ results = validator.test_score_distribution()
+
+ elif args.test == "edge-cases":
+ results = validator.test_edge_cases()
+
+ if args.output and results:
+ import json
+
+ with open(args.output, "w") as f:
+ json.dump(results, f, indent=2, default=str)
+ console.print(f"[green]Results saved to {args.output}[/green]")
+
+ if results and results.get("evaluator_ready", results.get("passed", False)):
+ console.print("\n[green]β
Validation completed successfully![/green]")
+ return 0
+ else:
+ console.print("\n[red]β Validation found issues that need attention[/red]")
+ return 1
+
+ except Exception as e:
+ console.print(f"[red]Error during validation: {e}[/red]")
+ if args.verbose:
+ import traceback
+
+ console.print(f"[dim]{traceback.format_exc()}[/dim]")
+ return 1
+
+
+if __name__ == "__main__":
+ exit(main())