diff --git a/benchmarks/stability_benchmarks.py b/benchmarks/stability_benchmarks.py new file mode 100644 index 0000000..0c4082f --- /dev/null +++ b/benchmarks/stability_benchmarks.py @@ -0,0 +1,736 @@ +""" +Stability Benchmarking Script for DiffMage + +This script uses your existing EvaluationBenchmarks class to test evaluation +stability across various real commit examples and scenarios. + +Usage: + python benchmark_stability.py --repo-path /path/to/repo --runs 10 + python benchmark_stability.py --commit-range "HEAD~20..HEAD" --runs 5 + python benchmark_stability.py --test-examples --model gpt-4 + python benchmark_stability.py --batch-test --variance-threshold 0.2 +""" + +import argparse +import json +import time +from pathlib import Path +from typing import List, Dict, Any, Tuple, Optional +from datetime import datetime +import git +import random + +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.progress import Progress +from rich import box + +from diffmage.evaluation.benchmarks import EvaluationBenchmarks, StabilityTestResult +from diffmage.evaluation.commit_message_evaluator import CommitMessageEvaluator +from diffmage.git.diff_parser import GitDiffParser + +console = Console() + + +class StabilityBenchmarkSuite: + """Comprehensive stability testing using real commit data""" + + def __init__(self, repo_path: str = ".", model_name: Optional[str] = None): + self.repo_path = Path(repo_path) + self.repo = git.Repo(repo_path) + self.diff_parser = GitDiffParser(repo_path) + self.evaluator = CommitMessageEvaluator(model_name=model_name) + self.benchmarks = EvaluationBenchmarks(self.evaluator) + self.console = console + + def get_real_commit_examples( + self, commit_range: Optional[str] = None, max_examples: int = 20 + ) -> List[Tuple[str, str, str]]: + """Extract real commit examples from repository + + Returns: + List of (commit_hash, commit_message, git_diff) tuples + """ + + if commit_range: + try: + commits = list(self.repo.iter_commits(commit_range)) + except Exception as e: + console.print( + f"[red]Error parsing commit range '{commit_range}': {e}[/red]" + ) + commits = list(self.repo.iter_commits("HEAD", max_count=max_examples)) + else: + commits = list(self.repo.iter_commits("HEAD", max_count=max_examples)) + + examples = [] + + console.print( + f"[blue]Extracting examples from {len(commits)} commits...[/blue]" + ) + + with Progress(console=self.console) as progress: + task = progress.add_task("Processing commits...", total=len(commits)) + + for commit in commits: + try: + message = str(commit.message).strip() + + if len(commit.parents) > 1 or not message or len(message) < 10: + progress.update(task, advance=1) + continue + + git_diff = self.diff_parser.parse_specific_commit(commit.hexsha) + + if not git_diff or len(git_diff) > 10000: + progress.update(task, advance=1) + continue + + examples.append((commit.hexsha, message, git_diff)) + + if len(examples) >= max_examples: + break + + except Exception as e: + console.print( + f"[dim]Warning: Skipped commit {commit.hexsha[:8]}: {e}[/dim]" + ) + + progress.update(task, advance=1) + + console.print( + f"[green]Successfully extracted {len(examples)} usable commit examples[/green]" + ) + return examples + + def get_curated_test_examples(self) -> List[Tuple[str, str, str]]: + """Get curated test examples with various commit patterns + + Returns: + List of (example_id, commit_message, git_diff) tuples + """ + + examples = [ + ( + "bug_fix_null_check", + "fix: resolve null pointer 
exception in user validation", + """--- a/src/auth/UserValidator.java ++++ b/src/auth/UserValidator.java +@@ -23,7 +23,7 @@ public class UserValidator { + } + + public boolean isValidUser(User user) { +- return user.getEmail().contains("@"); ++ return user != null && user.getEmail() != null && user.getEmail().contains("@"); + } + }""", + ), + ( + "feature_dark_mode", + "feat: implement dark mode toggle in user preferences", + """--- a/src/components/UserPreferences.jsx ++++ b/src/components/UserPreferences.jsx +@@ -8,6 +8,7 @@ const UserPreferences = () => { + const [language, setLanguage] = useState('en'); + const [notifications, setNotifications] = useState(true); ++ const [darkMode, setDarkMode] = useState(false); + + const savePreferences = async () => { + const prefs = { +@@ -15,6 +16,7 @@ const UserPreferences = () => { + language, + notifications, ++ darkMode, + }; + + await api.updatePreferences(prefs); +@@ -35,6 +37,13 @@ const UserPreferences = () => { + onChange={(e) => setNotifications(e.target.checked)} + /> + ++ ++
++          type="checkbox"
++          checked={darkMode}
++          onChange={(e) => setDarkMode(e.target.checked)}
++        />
++        Dark Mode
++      </label>
+ + );""", + ), + ( + "refactor_extract_method", + "refactor: extract user authentication logic into separate service", + """--- a/src/controllers/AuthController.js ++++ b/src/controllers/AuthController.js +@@ -1,4 +1,5 @@ + const bcrypt = require('bcrypt'); ++const AuthService = require('../services/AuthService'); + + class AuthController { + async login(req, res) { +@@ -6,15 +7,7 @@ class AuthController { + + try { + // Authenticate user +- const user = await User.findOne({ email }); +- if (!user) { +- return res.status(401).json({ error: 'Invalid credentials' }); +- } +- +- const isValidPassword = await bcrypt.compare(password, user.passwordHash); +- if (!isValidPassword) { +- return res.status(401).json({ error: 'Invalid credentials' }); +- } ++ const user = await AuthService.authenticateUser(email, password); + + const token = jwt.sign({ userId: user.id }, process.env.JWT_SECRET); + res.json({ token, user: { id: user.id, email: user.email } }); +@@ -24,4 +17,22 @@ class AuthController { + } + } + ++--- /dev/null +++++ b/src/services/AuthService.js +@@ -0,0 +1,22 @@ ++const bcrypt = require('bcrypt'); ++const User = require('../models/User'); ++ ++class AuthService { ++ static async authenticateUser(email, password) { ++ const user = await User.findOne({ email }); ++ if (!user) { ++ throw new Error('Invalid credentials'); ++ } ++ ++ const isValidPassword = await bcrypt.compare(password, user.passwordHash); ++ if (!isValidPassword) { ++ throw new Error('Invalid credentials'); ++ } ++ ++ return user; ++ } ++} ++ ++module.exports = AuthService;""", + ), + ( + "docs_api_update", + "docs: update API documentation for user endpoints", + """--- a/docs/api/users.md ++++ b/docs/api/users.md +@@ -23,6 +23,20 @@ Creates a new user account. + - `409 Conflict` - Email already exists + - `500 Internal Server Error` - Server error + ++### Request Example ++ ++```json ++{ ++ "email": "user@example.com", ++ "password": "securePassword123", ++ "firstName": "John", ++ "lastName": "Doe" ++} ++``` ++ ++### Response Example ++ ++```json + ## GET /api/users/:id + + Retrieves user information by ID. +@@ -35,3 +49,15 @@ Retrieves user information by ID. 
+ - `200 OK` - User found + - `404 Not Found` - User not found + - `500 Internal Server Error` - Server error ++ ++### Response Example ++ ++```json ++{ ++ "id": "123e4567-e89b-12d3-a456-426614174000", ++ "email": "user@example.com", ++ "firstName": "John", ++ "lastName": "Doe", ++ "createdAt": "2024-01-15T10:30:00Z" ++} ++```""", + ), + ( + "performance_optimize_query", + "perf: optimize database query for user search with indexing", + """--- a/src/models/User.js ++++ b/src/models/User.js +@@ -15,8 +15,12 @@ const userSchema = new mongoose.Schema({ + } + }); + ++// Add indexes for better query performance ++userSchema.index({ email: 1 }); ++userSchema.index({ firstName: 1, lastName: 1 }); ++userSchema.index({ createdAt: -1 }); ++ + // Static method for user search +-userSchema.statics.searchUsers = function(searchTerm) { ++userSchema.statics.searchUsers = function(searchTerm, limit = 20) { + const regex = new RegExp(searchTerm, 'i'); + return this.find({ + $or: [ +@@ -24,7 +28,8 @@ userSchema.statics.searchUsers = function(searchTerm) { + { firstName: regex }, + { lastName: regex } + ] +- }); ++ }) ++ .limit(limit) ++ .sort({ createdAt: -1 }); + };""", + ), + ( + "test_add_validation", + "test: add unit tests for email validation utility", + """--- /dev/null ++++ b/tests/utils/emailValidator.test.js +@@ -0,0 +1,45 @@ ++const { validateEmail } = require('../../src/utils/emailValidator'); ++ ++describe('Email Validator', () => { ++ describe('valid emails', () => { ++ test('should validate standard email format', () => { ++ expect(validateEmail('user@example.com')).toBe(true); ++ expect(validateEmail('john.doe@company.org')).toBe(true); ++ expect(validateEmail('admin@subdomain.example.co.uk')).toBe(true); ++ }); ++ ++ test('should validate emails with plus signs', () => { ++ expect(validateEmail('user+tag@example.com')).toBe(true); ++ }); ++ ++ test('should validate emails with numbers', () => { ++ expect(validateEmail('user123@example123.com')).toBe(true); ++ }); ++ }); ++ ++ describe('invalid emails', () => { ++ test('should reject emails without @ symbol', () => { ++ expect(validateEmail('userexample.com')).toBe(false); ++ }); ++ ++ test('should reject emails without domain', () => { ++ expect(validateEmail('user@')).toBe(false); ++ }); ++ ++ test('should reject emails without user part', () => { ++ expect(validateEmail('@example.com')).toBe(false); ++ }); ++ ++ test('should reject empty strings', () => { ++ expect(validateEmail('')).toBe(false); ++ }); ++ ++ test('should reject null and undefined', () => { ++ expect(validateEmail(null)).toBe(false); ++ expect(validateEmail(undefined)).toBe(false); ++ }); ++ }); ++});""", + ), + ] + + console.print(f"[green]Loaded {len(examples)} curated test examples[/green]") + return examples + + def run_single_stability_test( + self, + commit_hash: str, + message: str, + git_diff: str, + runs: int = 5, + variance_threshold: float = 0.3, + ) -> StabilityTestResult: + """Run stability test on a single commit example""" + + console.print(f"[blue]Testing stability for: {message[:60]}...[/blue]") + console.print( + f"[dim]Commit: {commit_hash[:8] if len(commit_hash) > 8 else commit_hash}[/dim]" + ) + + result = self.benchmarks.stability_test( + message=message, + diff=git_diff, + runs=runs, + variance_threshold=variance_threshold, + ) + + return result + + def run_batch_stability_test( + self, + examples: List[Tuple[str, str, str]], + runs: int = 5, + variance_threshold: float = 0.3, + max_examples: Optional[int] = None, + ) -> Dict[str, Any]: + """Run 
stability tests on multiple examples""" + + if max_examples and len(examples) > max_examples: + examples = random.sample(examples, max_examples) + + console.print( + f"[blue]Running batch stability test on {len(examples)} examples...[/blue]" + ) + console.print( + f"[dim]Runs per example: {runs}, Variance threshold: {variance_threshold}[/dim]" + ) + + results = [] + stable_count = 0 + total_time = 0 + + with Progress(console=self.console) as progress: + task = progress.add_task("Testing stability...", total=len(examples)) + + for commit_hash, message, git_diff in examples: + start_time = time.time() + + try: + result = self.benchmarks.stability_test( + message=message, + diff=git_diff, + runs=runs, + variance_threshold=variance_threshold, + ) + + results.append(result) + + if result["is_stable"]: # type: ignore + stable_count += 1 + + except Exception as e: + console.print(f"[red]Error testing {commit_hash[:8]}: {e}[/red]") + continue + + total_time += time.time() - start_time + progress.update(task, advance=1) + + stability_rate = (stable_count / len(results)) * 100 if results else 0 + avg_time_per_test = total_time / len(results) if results else 0 + + summary = { + "total_examples": len(examples), + "successful_tests": len(results), + "stable_examples": stable_count, + "stability_rate": stability_rate, + "average_time_per_test": avg_time_per_test, + "total_test_time": total_time, + "runs_per_example": runs, + "variance_threshold": variance_threshold, + } + + self._display_batch_summary(summary) + + return { + "summary": summary, + "individual_results": results, + "timestamp": datetime.now().isoformat(), + } + + def _display_batch_summary(self, summary: Dict[str, Any]): + """Display batch test summary""" + + console.print( + Panel( + f"[bold]Batch Stability Test Results[/bold]\n" + f"Examples Tested: {summary['successful_tests']}/{summary['total_examples']}\n" + f"Stable Examples: {summary['stable_examples']}\n" + f"Stability Rate: {summary['stability_rate']:.1f}%", + title="πŸ“Š Summary", + border_style="green" + if summary["stability_rate"] >= 80 + else "yellow" + if summary["stability_rate"] >= 60 + else "red", + ) + ) + + # Detailed table + table = Table(title="Performance Metrics", box=box.SIMPLE) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="center") + table.add_column("Assessment", style="yellow") + + # Stability rate assessment + if summary["stability_rate"] >= 90: + stability_assessment = "🟒 Excellent" + elif summary["stability_rate"] >= 80: + stability_assessment = "🟑 Good" + elif summary["stability_rate"] >= 60: + stability_assessment = "🟠 Acceptable" + else: + stability_assessment = "πŸ”΄ Needs Improvement" + + # Time assessment + avg_time = summary["average_time_per_test"] + if avg_time < 30: + time_assessment = "🟒 Fast" + elif avg_time < 60: + time_assessment = "🟑 Reasonable" + else: + time_assessment = "πŸ”΄ Slow" + + table.add_row( + "Stability Rate", f"{summary['stability_rate']:.1f}%", stability_assessment + ) + table.add_row("Avg Time per Test", f"{avg_time:.1f}s", time_assessment) + table.add_row("Total Test Time", f"{summary['total_test_time']:.1f}s", "") + table.add_row("Runs per Example", str(summary["runs_per_example"]), "") + table.add_row("Variance Threshold", str(summary["variance_threshold"]), "") + + console.print(table) + + def run_comparative_stability_test( + self, examples: List[Tuple[str, str, str]], models: List[str], runs: int = 3 + ) -> Dict[str, Any]: + """Compare stability across different models""" + + 
console.print( + f"[blue]Running comparative stability test across {len(models)} models...[/blue]" + ) + + model_results = {} + + for model in models: + console.print(f"[yellow]Testing model: {model}[/yellow]") + + # Create new evaluator for this model + evaluator = CommitMessageEvaluator(model_name=model) + benchmarks = EvaluationBenchmarks(evaluator) + + model_stability_results = [] + + for i, (commit_hash, message, git_diff) in enumerate( + examples[:5] + ): # Limit for comparative test + try: + result = benchmarks.stability_test( + message=message, + diff=git_diff, + runs=runs, + variance_threshold=0.3, + ) + + model_stability_results.append( + { + "example_id": i, + "is_stable": result["is_stable"], + "max_variance": result["max_variance"], + "commit_hash": commit_hash, + } + ) + + except Exception as e: + console.print( + f"[red]Error with {model} on {commit_hash[:8]}: {e}[/red]" + ) + continue + + model_results[model] = model_stability_results + + # Calculate comparative statistics + comparative_stats = self._calculate_comparative_stats(model_results) + self._display_comparative_results(comparative_stats) + + return { + "model_results": model_results, + "comparative_stats": comparative_stats, + "timestamp": datetime.now().isoformat(), + } + + def _calculate_comparative_stats( + self, model_results: Dict[str, List] + ) -> Dict[str, Any]: + """Calculate comparative statistics across models""" + + stats = {} + + for model, results in model_results.items(): + if not results: + continue + + stable_count = sum(1 for r in results if r["is_stable"]) + total_count = len(results) + avg_variance = ( + sum(r["max_variance"] for r in results) / total_count + if total_count > 0 + else 0 + ) + + stats[model] = { + "stability_rate": (stable_count / total_count * 100) + if total_count > 0 + else 0, + "average_variance": avg_variance, + "stable_examples": stable_count, + "total_examples": total_count, + } + + return stats + + def _display_comparative_results(self, stats: Dict[str, Any]): + """Display comparative model results""" + + table = Table(title="Model Stability Comparison", box=box.SIMPLE) + table.add_column("Model", style="cyan") + table.add_column("Stability Rate", justify="center") + table.add_column("Avg Variance", justify="center") + table.add_column("Stable/Total", justify="center") + table.add_column("Assessment", style="yellow") + + for model, model_stats in stats.items(): + rate = model_stats["stability_rate"] + + if rate >= 80: + assessment = "🟒 Excellent" + elif rate >= 60: + assessment = "🟑 Good" + else: + assessment = "πŸ”΄ Needs Work" + + table.add_row( + model, + f"{rate:.1f}%", + f"{model_stats['average_variance']:.3f}", + f"{model_stats['stable_examples']}/{model_stats['total_examples']}", + assessment, + ) + + console.print(table) + + +def main(): + """Main CLI interface""" + + parser = argparse.ArgumentParser( + description="Stability Benchmarking Script for DiffMage" + ) + parser.add_argument("--repo-path", default=".", help="Path to git repository") + parser.add_argument("--model", help="AI model to test (default: uses your default)") + parser.add_argument("--runs", type=int, default=5, help="Number of runs per test") + parser.add_argument( + "--variance-threshold", + type=float, + default=0.3, + help="Variance threshold for stability", + ) + + # Test data options + group = parser.add_mutually_exclusive_group() + group.add_argument("--commit-range", help="Git commit range (e.g., HEAD~10..HEAD)") + group.add_argument( + "--test-examples", action="store_true", help="Use 
curated test examples" + ) + group.add_argument("--single-commit", help="Test single commit by hash") + + # Test options + parser.add_argument( + "--batch-test", action="store_true", help="Run batch stability test" + ) + parser.add_argument("--comparative-test", help="Compare models (comma-separated)") + parser.add_argument( + "--max-examples", type=int, default=10, help="Maximum examples to test" + ) + parser.add_argument("--output", help="Save results to JSON file") + + args = parser.parse_args() + + try: + suite = StabilityBenchmarkSuite(args.repo_path, args.model) + except Exception as e: + console.print(f"[red]Error initializing benchmark suite: {e}[/red]") + return 1 + + results = None + + try: + if args.single_commit: + # Test single commit + commits = list(suite.repo.iter_commits(args.single_commit, max_count=1)) + if not commits: + console.print(f"[red]Commit not found: {args.single_commit}[/red]") + return 1 + + commit = commits[0] + message = str(commit.message).strip() + git_diff = suite.diff_parser.parse_specific_commit(commit.hexsha)[1] + + results = suite.run_single_stability_test( + commit.hexsha, message, git_diff, args.runs, args.variance_threshold + ) + + elif args.test_examples: + # Use curated examples + examples = suite.get_curated_test_examples() + + if args.comparative_test: + models = [m.strip() for m in args.comparative_test.split(",")] + results = suite.run_comparative_stability_test( + examples, models, args.runs + ) + elif args.batch_test: + results = suite.run_batch_stability_test( + examples, args.runs, args.variance_threshold, args.max_examples + ) + else: + # Test first example + commit_hash, message, git_diff = examples[0] + results = suite.run_single_stability_test( + commit_hash, message, git_diff, args.runs, args.variance_threshold + ) + + else: + # Use real repository commits + examples = suite.get_real_commit_examples( + args.commit_range, args.max_examples + ) + + if not examples: + console.print("[red]No suitable commit examples found[/red]") + return 1 + + if args.comparative_test: + models = [m.strip() for m in args.comparative_test.split(",")] + results = suite.run_comparative_stability_test( + examples, models, args.runs + ) + elif args.batch_test: + results = suite.run_batch_stability_test( + examples, args.runs, args.variance_threshold, args.max_examples + ) + else: + # Test first example + commit_hash, message, git_diff = examples[0] + results = suite.run_single_stability_test( + commit_hash, message, git_diff, args.runs, args.variance_threshold + ) + + # Save results if requested + if args.output and results: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, default=str) + console.print(f"[green]Results saved to {args.output}[/green]") + + console.print("\n[green]Benchmarking completed successfully! βœ…[/green]") + + except Exception as e: + console.print(f"[red]Error during benchmarking: {e}[/red]") + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/benchmarks/validation_suite.py b/benchmarks/validation_suite.py new file mode 100644 index 0000000..bd1ebff --- /dev/null +++ b/benchmarks/validation_suite.py @@ -0,0 +1,924 @@ +""" +Evaluation *Smoke Test* Validation Suite for DiffMage + +This script performs fundamental sanity checks on your LLM evaluator to ensure +it's working correctly before running extensive benchmarks. 
+ +Usage: + python validation_suite.py --all + python validation_suite.py --test obvious-cases + python validation_suite.py --test ranking-consistency + python validation_suite.py --test score-distribution + python validation_suite.py --test edge-cases +""" + +import argparse +import statistics +from typing import List, Dict, Any, Tuple, Optional +from dataclasses import dataclass + +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.progress import Progress +from rich import box + +try: + from diffmage.evaluation.commit_message_evaluator import CommitMessageEvaluator +except ImportError as e: + print(f"Error importing DiffMage modules: {e}") + print("Make sure you're running from the project root and the package is installed") + exit(1) + +console = Console() + + +@dataclass +class ValidationCase: + """A test case for validation""" + + name: str + commit_message: str + git_diff: str + expected_score_range: Tuple[float, float] # (min, max) + expected_quality: str # "excellent", "good", "average", "poor", "very_poor" + description: str + + +class EvaluationValidator: + """Validates that the LLM evaluator is working correctly""" + + def __init__(self, model_name: Optional[str] = None): + self.evaluator = CommitMessageEvaluator(model_name=model_name) + self.console = console + + def get_obvious_test_cases(self) -> List[ValidationCase]: + """Get test cases with obvious expected outcomes""" + + return [ + # EXCELLENT cases (4.5-5.0) + ValidationCase( + name="security_fix", + commit_message="fix: resolve critical SQL injection vulnerability in user authentication", + git_diff="""--- a/src/auth/UserAuth.py ++++ b/src/auth/UserAuth.py +@@ -23,7 +23,8 @@ class UserAuth: + def authenticate_user(self, username, password): +- query = f"SELECT * FROM users WHERE username='{username}' AND password='{password}'" ++ query = "SELECT * FROM users WHERE username=? AND password=?" 
++ return self.db.execute(query, (username, password)) +- return self.db.execute(query)""", + expected_score_range=(4.0, 5.0), + expected_quality="excellent", + description="Clear security fix with good explanation", + ), + ValidationCase( + name="feature_with_context", + commit_message="feat: implement user password reset with email verification\n\nAdds secure password reset flow:\n- Generate time-limited reset tokens\n- Send verification emails\n- Validate tokens before allowing reset\n- Log security events for auditing", + git_diff="""--- a/src/auth/PasswordReset.py ++++ b/src/auth/PasswordReset.py +@@ -0,0 +1,45 @@ ++import secrets ++import hashlib ++from datetime import datetime, timedelta ++ ++class PasswordReset: ++ def __init__(self, email_service, user_service): ++ self.email_service = email_service ++ self.user_service = user_service ++ ++ def request_reset(self, email): ++ user = self.user_service.find_by_email(email) ++ if not user: ++ return False # Don't reveal if email exists ++ ++ token = secrets.token_urlsafe(32) ++ expires_at = datetime.now() + timedelta(hours=1) ++ ++ self.user_service.store_reset_token(user.id, token, expires_at) ++ self.email_service.send_reset_email(email, token) ++ ++ return True""", + expected_score_range=(4.0, 5.0), + expected_quality="excellent", + description="Comprehensive feature with detailed explanation", + ), + # GOOD cases (3.5-4.0) + ValidationCase( + name="simple_bug_fix", + commit_message="fix: handle null values in user profile display", + git_diff="""--- a/src/components/UserProfile.jsx ++++ b/src/components/UserProfile.jsx +@@ -15,7 +15,7 @@ const UserProfile = ({ user }) => { +
+      <h2>{user.name}</h2>
+-      <p>Email: {user.email}</p>
++      <p>Email: {user.email || 'Not provided'}</p>
+-      <p>Bio: {user.bio}</p>
++      <p>Bio: {user.bio || 'No bio available'}</p>
+    </div>
+ );""", + expected_score_range=(3.0, 4.5), + expected_quality="good", + description="Clear bug fix, could use more detail", + ), + # AVERAGE cases (2.5-3.5) + ValidationCase( + name="generic_update", + commit_message="update user component", + git_diff="""--- a/src/components/User.jsx ++++ b/src/components/User.jsx +@@ -10,6 +10,7 @@ const User = ({ userData }) => { + return ( +
+      <span>{userData.name}</span>
++      <span>{userData.role}</span>
+    </div>
+ );""", + expected_score_range=(2.0, 3.5), + expected_quality="average", + description="Vague message, minimal change", + ), + # POOR cases (1.5-2.5) + ValidationCase( + name="meaningless_message", + commit_message="fix stuff", + git_diff="""--- a/src/utils/helper.js ++++ b/src/utils/helper.js +@@ -5,7 +5,7 @@ function processData(data) { + if (!data) { + return null; + } +- return data.map(item => item.value); ++ return data.map(item => item.value || 0); + }""", + expected_score_range=(1.0, 2.5), + expected_quality="poor", + description="Meaningless commit message", + ), + # VERY POOR cases (1.0-2.0) + ValidationCase( + name="gibberish", + commit_message="asdf jkl; qwerty", + git_diff="""--- a/src/test.js ++++ b/src/test.js +@@ -1,3 +1,4 @@ + // Test file + console.log('hello'); ++console.log('world');""", + expected_score_range=(1.0, 2.0), + expected_quality="very_poor", + description="Nonsensical commit message", + ), + ] + + def get_edge_test_cases(self) -> List[ValidationCase]: + """Get edge cases that might break the evaluator""" + + return [ + ValidationCase( + name="empty_message", + commit_message="", + git_diff="--- a/file.txt\n+++ b/file.txt\n@@ -1 +1,2 @@\n hello\n+world", + expected_score_range=(1.0, 2.0), + expected_quality="very_poor", + description="Empty commit message", + ), + ValidationCase( + name="very_long_message", + commit_message="fix: " + + "very " * 100 + + "long commit message that goes on and on and provides way too much detail about a simple change " + * 10, + git_diff="--- a/file.txt\n+++ b/file.txt\n@@ -1 +1,2 @@\n hello\n+world", + expected_score_range=(1.0, 3.0), + expected_quality="poor", + description="Excessively long commit message", + ), + ValidationCase( + name="special_characters", + commit_message="fix: handle Γ©mojis πŸš€ and Γ±oΓ±-ASCII Γ§haracters in ΓΌser input", + git_diff="""--- a/src/input.py ++++ b/src/input.py +@@ -1,3 +1,4 @@ + def process_input(text): ++ text = text.encode('utf-8').decode('utf-8') + return text.strip()""", + expected_score_range=(3.0, 4.5), + expected_quality="good", + description="Special characters and emojis", + ), + ValidationCase( + name="no_diff", + commit_message="fix: important bug fix", + git_diff="", + expected_score_range=(1.0, 3.0), + expected_quality="poor", + description="No diff provided", + ), + ValidationCase( + name="merge_commit", + commit_message="Merge branch 'feature/user-auth' into main", + git_diff="""--- a/src/auth.py ++++ b/src/auth.py +@@ -1,10 +1,20 @@ + # Auth module ++# New authentication features + + def login(user, password): + return validate_credentials(user, password) ++ ++def logout(user): ++ clear_session(user) ++ ++def reset_password(email): ++ send_reset_email(email)""", + expected_score_range=(1.0, 3.0), + expected_quality="poor", + description="Merge commit (usually auto-generated)", + ), + ] + + def test_obvious_cases(self) -> Dict[str, Any]: + """Test if evaluator handles obviously good/bad cases correctly""" + + console.print( + Panel( + "[bold]Testing Obvious Cases[/bold]\n" + "Evaluator should clearly distinguish between excellent and poor commit messages", + title="🎯 Obvious Cases Test", + border_style="blue", + ) + ) + + test_cases = self.get_obvious_test_cases() + results = [] + + with Progress(console=self.console) as progress: + task = progress.add_task( + "Evaluating obvious cases...", total=len(test_cases) + ) + + for case in test_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + + # Check if score is in 
expected range + in_range = ( + case.expected_score_range[0] + <= result.overall_score + <= case.expected_score_range[1] + ) + + results.append( + { + "case": case, + "result": result, + "in_expected_range": in_range, + "score_deviation": self._calculate_deviation( + result.overall_score, case.expected_score_range + ), + } + ) + + except Exception as e: + console.print(f"[red]Error evaluating {case.name}: {e}[/red]") + results.append( + { + "case": case, + "result": None, + "error": str(e), + "in_expected_range": False, + "score_deviation": float("inf"), + } + ) + + progress.update(task, advance=1) + + # Analyze results + success_rate = ( + sum(1 for r in results if r.get("in_expected_range", False)) + / len(results) + * 100 + ) + avg_deviation = statistics.mean( + [ + r["score_deviation"] + for r in results + if r["score_deviation"] != float("inf") + ] + ) + + # Display results + self._display_obvious_cases_results(results, success_rate, avg_deviation) + + return { + "test_name": "obvious_cases", + "success_rate": success_rate, + "average_deviation": avg_deviation, + "results": results, + "passed": success_rate >= 70, # 70% of obvious cases should be correct + } + + def test_ranking_consistency(self) -> Dict[str, Any]: + """Test if evaluator consistently ranks messages in logical order""" + + console.print( + Panel( + "[bold]Testing Ranking Consistency[/bold]\n" + "Evaluator should rank obviously better messages higher than worse ones", + title="πŸ“Š Ranking Test", + border_style="green", + ) + ) + + # Get subset of test cases for ranking + test_cases = self.get_obvious_test_cases() + + # Sort by expected quality for comparison + expected_order = sorted( + test_cases, key=lambda x: x.expected_score_range[1], reverse=True + ) + + # Evaluate all cases + evaluated_cases = [] + + with Progress(console=self.console) as progress: + task = progress.add_task("Evaluating for ranking...", total=len(test_cases)) + + for case in test_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + evaluated_cases.append((case, result)) + except Exception as e: + console.print(f"[red]Error evaluating {case.name}: {e}[/red]") + continue + + progress.update(task, advance=1) + + # Sort by actual scores + actual_order = sorted( + evaluated_cases, key=lambda x: x[1].overall_score, reverse=True + ) + + # Calculate ranking consistency + ranking_violations = self._count_ranking_violations( + expected_order, actual_order + ) + total_pairs = len(evaluated_cases) * (len(evaluated_cases) - 1) // 2 + consistency_rate = ( + (total_pairs - ranking_violations) / total_pairs * 100 + if total_pairs > 0 + else 0 + ) + + # Display results + self._display_ranking_results( + expected_order, actual_order, consistency_rate, ranking_violations + ) + + return { + "test_name": "ranking_consistency", + "consistency_rate": consistency_rate, + "ranking_violations": ranking_violations, + "total_pairs": total_pairs, + "passed": consistency_rate >= 80, # 80% of rankings should be consistent + } + + def test_score_distribution(self) -> Dict[str, Any]: + """Test if evaluator uses the full score range appropriately""" + + console.print( + Panel( + "[bold]Testing Score Distribution[/bold]\n" + "Evaluator should use the full 1-5 scale and not cluster around one value", + title="πŸ“ˆ Distribution Test", + border_style="yellow", + ) + ) + + all_cases = self.get_obvious_test_cases() + self.get_edge_test_cases() + scores = [] + + with Progress(console=self.console) as progress: + task = 
progress.add_task("Collecting scores...", total=len(all_cases)) + + for case in all_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + scores.append( + { + "case_name": case.name, + "overall_score": result.overall_score, + "what_score": result.what_score, + "why_score": result.why_score, + "expected_quality": case.expected_quality, + } + ) + except Exception as e: + console.print(f"[red]Error evaluating {case.name}: {e}[/red]") + continue + + progress.update(task, advance=1) + + if not scores: + return { + "test_name": "score_distribution", + "passed": False, + "error": "No scores collected", + } + + # Analyze distribution + overall_scores = [s["overall_score"] for s in scores] + distribution_stats = { + "mean": statistics.mean(overall_scores), + "median": statistics.median(overall_scores), + "std_dev": statistics.stdev(overall_scores) + if len(overall_scores) > 1 + else 0, + "min": min(overall_scores), + "max": max(overall_scores), + "range": max(overall_scores) - min(overall_scores), + } + + # Check for problems + problems = [] + if distribution_stats["std_dev"] < 0.5: + problems.append("Low variance - scores too clustered") + if distribution_stats["range"] < 2.0: + problems.append("Narrow range - not using full scale") + if distribution_stats["mean"] > 4.0: + problems.append("Grade inflation - scores too high") + if distribution_stats["mean"] < 2.0: + problems.append("Grade deflation - scores too low") + + # Display results + self._display_distribution_results(distribution_stats, scores, problems) + + return { + "test_name": "score_distribution", + "distribution_stats": distribution_stats, + "problems": problems, + "scores": scores, + "passed": len(problems) == 0, + } + + def test_edge_cases(self) -> Dict[str, Any]: + """Test if evaluator handles edge cases gracefully""" + + console.print( + Panel( + "[bold]Testing Edge Cases[/bold]\n" + "Evaluator should handle unusual inputs without crashing", + title="⚠️ Edge Cases Test", + border_style="red", + ) + ) + + edge_cases = self.get_edge_test_cases() + results = [] + + with Progress(console=self.console) as progress: + task = progress.add_task("Testing edge cases...", total=len(edge_cases)) + + for case in edge_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + + # Check if result is reasonable + is_reasonable = ( + 1.0 <= result.overall_score <= 5.0 + and 1.0 <= result.what_score <= 5.0 + and 1.0 <= result.why_score <= 5.0 + and result.reasoning + and len(result.reasoning) > 10 + ) + + results.append( + { + "case": case, + "result": result, + "handled_gracefully": True, + "is_reasonable": is_reasonable, + "error": None, + } + ) + + except Exception as e: + results.append( + { + "case": case, + "result": None, + "handled_gracefully": False, + "is_reasonable": False, + "error": str(e), + } + ) + + progress.update(task, advance=1) + + # Analyze results + graceful_handling_rate = ( + sum(1 for r in results if r["handled_gracefully"]) / len(results) * 100 + ) + reasonable_results_rate = ( + sum(1 for r in results if r.get("is_reasonable", False)) + / len(results) + * 100 + ) + + # Display results + self._display_edge_cases_results( + results, graceful_handling_rate, reasonable_results_rate + ) + + return { + "test_name": "edge_cases", + "graceful_handling_rate": graceful_handling_rate, + "reasonable_results_rate": reasonable_results_rate, + "results": results, + "passed": graceful_handling_rate >= 90 and reasonable_results_rate >= 
70, + } + + def run_all_tests(self) -> Dict[str, Any]: + """Run all validation tests""" + + console.print( + Panel( + "[bold]Running Complete Validation Suite[/bold]\n" + "Testing fundamental evaluator functionality", + title="πŸ§ͺ Validation Suite", + border_style="cyan", + ) + ) + + all_results = {} + + # Run each test + tests = [ + ("obvious_cases", self.test_obvious_cases), + ("ranking_consistency", self.test_ranking_consistency), + ("score_distribution", self.test_score_distribution), + ("edge_cases", self.test_edge_cases), + ] + + for test_name, test_func in tests: + console.print( + f"\n[blue]Running {test_name.replace('_', ' ').title()} test...[/blue]" + ) + try: + all_results[test_name] = test_func() + except Exception as e: + console.print(f"[red]Test {test_name} failed with error: {e}[/red]") + all_results[test_name] = { + "test_name": test_name, + "passed": False, + "error": str(e), + } + + # Overall assessment + passed_tests = sum( + 1 for result in all_results.values() if result.get("passed", False) + ) + total_tests = len(all_results) + overall_pass_rate = passed_tests / total_tests * 100 + model_used: str = self.evaluator.model_name + + # Display summary + self._display_overall_summary(all_results, overall_pass_rate) + + return { + "overall_pass_rate": overall_pass_rate, + "passed_tests": passed_tests, + "total_tests": total_tests, + "model_used": model_used, + "individual_results": all_results, + "evaluator_ready": overall_pass_rate >= 75, # 75% of tests should pass + } + + # Helper methods for calculations and display + def _calculate_deviation( + self, actual_score: float, expected_range: Tuple[float, float] + ) -> float: + """Calculate how far actual score is from expected range""" + if expected_range[0] <= actual_score <= expected_range[1]: + return 0.0 + elif actual_score < expected_range[0]: + return expected_range[0] - actual_score + else: + return actual_score - expected_range[1] + + def _count_ranking_violations( + self, expected_order: List, actual_order: List + ) -> int: + """Count pairs that are ranked in wrong order""" + violations = 0 + actual_scores = { + case.name: score + for case, score in [(c, r.overall_score) for c, r in actual_order] + } + + for i, case1 in enumerate(expected_order): + for case2 in expected_order[i + 1 :]: + # case1 should score higher than case2 + if actual_scores.get(case1.name, 0) < actual_scores.get(case2.name, 0): + violations += 1 + + return violations + + def _display_obvious_cases_results( + self, results: List, success_rate: float, avg_deviation: float + ): + """Display obvious cases test results""" + + table = Table(title="Obvious Cases Results", box=box.SIMPLE) + table.add_column("Case", style="cyan") + table.add_column("Expected", justify="center") + table.add_column("Actual", justify="center") + table.add_column("Status", justify="center") + table.add_column("Deviation", justify="center") + + for r in results: + if r.get("result"): + status = "βœ… Pass" if r["in_expected_range"] else "❌ Fail" + status_color = "green" if r["in_expected_range"] else "red" + + expected_range = f"{r['case'].expected_score_range[0]:.1f}-{r['case'].expected_score_range[1]:.1f}" + actual_score = f"{r['result'].overall_score:.1f}" + deviation = f"{r['score_deviation']:.1f}" + + table.add_row( + r["case"].name, + expected_range, + actual_score, + f"[{status_color}]{status}[/{status_color}]", + deviation, + ) + else: + table.add_row( + r["case"].name, "N/A", "ERROR", "[red]❌ Error[/red]", "∞" + ) + + console.print(table) + + # Summary + 
summary_color = ( + "green" if success_rate >= 70 else "yellow" if success_rate >= 50 else "red" + ) + console.print( + f"\n[{summary_color}]Success Rate: {success_rate:.1f}% | Average Deviation: {avg_deviation:.2f}[/{summary_color}]" + ) + + def _display_ranking_results( + self, + expected_order: List, + actual_order: List, + consistency_rate: float, + violations: int, + ): + """Display ranking consistency results""" + + table = Table(title="Ranking Comparison", box=box.SIMPLE) + table.add_column("Rank", justify="center") + table.add_column("Expected", style="cyan") + table.add_column("Actual", style="yellow") + table.add_column("Score", justify="center") + + for i, (case, result) in enumerate(actual_order[:5]): # Show top 5 + expected_case = expected_order[i] if i < len(expected_order) else None + expected_name = expected_case.name if expected_case else "N/A" + + table.add_row( + str(i + 1), expected_name, case.name, f"{result.overall_score:.1f}" + ) + + console.print(table) + + # Summary + summary_color = ( + "green" + if consistency_rate >= 80 + else "yellow" + if consistency_rate >= 60 + else "red" + ) + console.print( + f"\n[{summary_color}]Ranking Consistency: {consistency_rate:.1f}% | Violations: {violations}[/{summary_color}]" + ) + + def _display_distribution_results(self, stats: Dict, scores: List, problems: List): + """Display score distribution results""" + + table = Table(title="Score Distribution Statistics", box=box.SIMPLE) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="center") + table.add_column("Assessment", style="yellow") + + table.add_row( + "Mean", + f"{stats['mean']:.2f}", + "Good" if 2.5 <= stats["mean"] <= 3.5 else "Check", + ) + table.add_row( + "Std Dev", + f"{stats['std_dev']:.2f}", + "Good" if stats["std_dev"] >= 0.5 else "Low variance", + ) + table.add_row( + "Range", + f"{stats['range']:.2f}", + "Good" if stats["range"] >= 2.0 else "Narrow range", + ) + table.add_row("Min", f"{stats['min']:.2f}", "") + table.add_row("Max", f"{stats['max']:.2f}", "") + + console.print(table) + + if problems: + console.print("\n[red]Issues Found:[/red]") + for problem in problems: + console.print(f" β€’ {problem}") + else: + console.print("\n[green]βœ… Score distribution looks healthy[/green]") + + def _display_edge_cases_results( + self, results: List, graceful_rate: float, reasonable_rate: float + ): + """Display edge cases test results""" + + table = Table(title="Edge Cases Results", box=box.SIMPLE) + table.add_column("Case", style="cyan") + table.add_column("Handled", justify="center") + table.add_column("Reasonable", justify="center") + table.add_column("Error", style="red") + + for r in results: + handled = "βœ… Yes" if r["handled_gracefully"] else "❌ No" + reasonable = "βœ… Yes" if r.get("is_reasonable", False) else "❌ No" + error = ( + r.get("error", "")[:50] + "..." 
+ if r.get("error") and len(r.get("error", "")) > 50 + else r.get("error", "") + ) + + table.add_row(r["case"].name, handled, reasonable, error) + + console.print(table) + + console.print( + f"\n[blue]Graceful Handling: {graceful_rate:.1f}% | Reasonable Results: {reasonable_rate:.1f}%[/blue]" + ) + + def _display_overall_summary(self, all_results: Dict, overall_pass_rate: float): + """Display overall validation summary""" + + console.print( + Panel( + f"[bold]Validation Complete[/bold]\n" + f"Overall Pass Rate: {overall_pass_rate:.1f}%", + title="πŸ“‹ Summary", + border_style="green" + if overall_pass_rate >= 75 + else "yellow" + if overall_pass_rate >= 50 + else "red", + ) + ) + + summary_table = Table(title="Test Summary", box=box.SIMPLE) + summary_table.add_column("Test", style="cyan") + summary_table.add_column("Status", justify="center") + summary_table.add_column("Key Metric", justify="center") + + for test_name, result in all_results.items(): + status = "βœ… Pass" if result.get("passed", False) else "❌ Fail" + status_color = "green" if result.get("passed", False) else "red" + + # Extract key metric based on test type + if test_name == "obvious_cases": + key_metric = f"{result.get('success_rate', 0):.1f}% correct" + elif test_name == "ranking_consistency": + key_metric = f"{result.get('consistency_rate', 0):.1f}% consistent" + elif test_name == "score_distribution": + problems = result.get("problems", []) + key_metric = f"{len(problems)} issues" + elif test_name == "edge_cases": + key_metric = f"{result.get('graceful_handling_rate', 0):.1f}% handled" + else: + key_metric = "N/A" + + summary_table.add_row( + test_name.replace("_", " ").title(), + f"[{status_color}]{status}[/{status_color}]", + key_metric, + ) + + console.print(summary_table) + + # Recommendations + if overall_pass_rate >= 75: + console.print( + "\n[green]πŸŽ‰ Evaluator validation passed! Ready for benchmarking and research.[/green]" + ) + elif overall_pass_rate >= 50: + console.print( + "\n[yellow]⚠️ Evaluator has some issues but may be usable. Review failed tests.[/yellow]" + ) + else: + console.print( + "\n[red]❌ Evaluator has significant issues. 
Fix core problems before proceeding.[/red]" + ) + + +def main(): + """Main CLI interface""" + + parser = argparse.ArgumentParser( + description="Validation Suite for DiffMage Evaluator" + ) + parser.add_argument("--model", help="AI model to test (default: uses your default)") + + test_group = parser.add_mutually_exclusive_group() + test_group.add_argument( + "--all", action="store_true", help="Run all validation tests" + ) + test_group.add_argument( + "--test", + choices=[ + "obvious-cases", + "ranking-consistency", + "score-distribution", + "edge-cases", + ], + help="Run specific test", + ) + + parser.add_argument("--output", help="Save results to JSON file") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + + args = parser.parse_args() + + if not args.test and not args.all: + args.all = True + + try: + validator = EvaluationValidator(model_name=args.model) + except Exception as e: + console.print(f"[red]Error initializing validator: {e}[/red]") + console.print( + "[dim]Make sure you're in the project root and DiffMage is properly installed[/dim]" + ) + return 1 + + results = None + + try: + if args.all: + console.print("[blue]Running complete validation suite...[/blue]") + results = validator.run_all_tests() + + elif args.test == "obvious-cases": + results = validator.test_obvious_cases() + + elif args.test == "ranking-consistency": + results = validator.test_ranking_consistency() + + elif args.test == "score-distribution": + results = validator.test_score_distribution() + + elif args.test == "edge-cases": + results = validator.test_edge_cases() + + if args.output and results: + import json + + with open(args.output, "w") as f: + json.dump(results, f, indent=2, default=str) + console.print(f"[green]Results saved to {args.output}[/green]") + + if results and results.get("evaluator_ready", results.get("passed", False)): + console.print("\n[green]βœ… Validation completed successfully![/green]") + return 0 + else: + console.print("\n[red]❌ Validation found issues that need attention[/red]") + return 1 + + except Exception as e: + console.print(f"[red]Error during validation: {e}[/red]") + if args.verbose: + import traceback + + console.print(f"[dim]{traceback.format_exc()}[/dim]") + return 1 + + +if __name__ == "__main__": + exit(main())