From c96acc946a32d36823854e461385ad033862e9df Mon Sep 17 00:00:00 2001 From: Nick Galluzzo Date: Tue, 5 Aug 2025 14:13:35 +0700 Subject: [PATCH 1/2] feat(benchmarks): add stability benchmarking script for commit evaluation Adds stability testing capabilities using real commit data and curated examples. Includes CLI interface for running single, batch, and comparative stability tests with configurable parameters. Provides detailed reporting on evaluation consistency and model performance metrics. --- benchmarks/stability_benchmarks.py | 736 +++++++++++++++++++++++++++++ 1 file changed, 736 insertions(+) create mode 100644 benchmarks/stability_benchmarks.py diff --git a/benchmarks/stability_benchmarks.py b/benchmarks/stability_benchmarks.py new file mode 100644 index 0000000..0c4082f --- /dev/null +++ b/benchmarks/stability_benchmarks.py @@ -0,0 +1,736 @@ +""" +Stability Benchmarking Script for DiffMage + +This script uses your existing EvaluationBenchmarks class to test evaluation +stability across various real commit examples and scenarios. + +Usage: + python benchmark_stability.py --repo-path /path/to/repo --runs 10 + python benchmark_stability.py --commit-range "HEAD~20..HEAD" --runs 5 + python benchmark_stability.py --test-examples --model gpt-4 + python benchmark_stability.py --batch-test --variance-threshold 0.2 +""" + +import argparse +import json +import time +from pathlib import Path +from typing import List, Dict, Any, Tuple, Optional +from datetime import datetime +import git +import random + +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.progress import Progress +from rich import box + +from diffmage.evaluation.benchmarks import EvaluationBenchmarks, StabilityTestResult +from diffmage.evaluation.commit_message_evaluator import CommitMessageEvaluator +from diffmage.git.diff_parser import GitDiffParser + +console = Console() + + +class StabilityBenchmarkSuite: + """Comprehensive stability testing using real commit data""" + + def __init__(self, repo_path: str = ".", model_name: Optional[str] = None): + self.repo_path = Path(repo_path) + self.repo = git.Repo(repo_path) + self.diff_parser = GitDiffParser(repo_path) + self.evaluator = CommitMessageEvaluator(model_name=model_name) + self.benchmarks = EvaluationBenchmarks(self.evaluator) + self.console = console + + def get_real_commit_examples( + self, commit_range: Optional[str] = None, max_examples: int = 20 + ) -> List[Tuple[str, str, str]]: + """Extract real commit examples from repository + + Returns: + List of (commit_hash, commit_message, git_diff) tuples + """ + + if commit_range: + try: + commits = list(self.repo.iter_commits(commit_range)) + except Exception as e: + console.print( + f"[red]Error parsing commit range '{commit_range}': {e}[/red]" + ) + commits = list(self.repo.iter_commits("HEAD", max_count=max_examples)) + else: + commits = list(self.repo.iter_commits("HEAD", max_count=max_examples)) + + examples = [] + + console.print( + f"[blue]Extracting examples from {len(commits)} commits...[/blue]" + ) + + with Progress(console=self.console) as progress: + task = progress.add_task("Processing commits...", total=len(commits)) + + for commit in commits: + try: + message = str(commit.message).strip() + + if len(commit.parents) > 1 or not message or len(message) < 10: + progress.update(task, advance=1) + continue + + git_diff = self.diff_parser.parse_specific_commit(commit.hexsha) + + if not git_diff or len(git_diff) > 10000: + progress.update(task, advance=1) + continue + + examples.append((commit.hexsha, message, git_diff)) + + if len(examples) >= max_examples: + break + + except Exception as e: + console.print( + f"[dim]Warning: Skipped commit {commit.hexsha[:8]}: {e}[/dim]" + ) + + progress.update(task, advance=1) + + console.print( + f"[green]Successfully extracted {len(examples)} usable commit examples[/green]" + ) + return examples + + def get_curated_test_examples(self) -> List[Tuple[str, str, str]]: + """Get curated test examples with various commit patterns + + Returns: + List of (example_id, commit_message, git_diff) tuples + """ + + examples = [ + ( + "bug_fix_null_check", + "fix: resolve null pointer exception in user validation", + """--- a/src/auth/UserValidator.java ++++ b/src/auth/UserValidator.java +@@ -23,7 +23,7 @@ public class UserValidator { + } + + public boolean isValidUser(User user) { +- return user.getEmail().contains("@"); ++ return user != null && user.getEmail() != null && user.getEmail().contains("@"); + } + }""", + ), + ( + "feature_dark_mode", + "feat: implement dark mode toggle in user preferences", + """--- a/src/components/UserPreferences.jsx ++++ b/src/components/UserPreferences.jsx +@@ -8,6 +8,7 @@ const UserPreferences = () => { + const [language, setLanguage] = useState('en'); + const [notifications, setNotifications] = useState(true); ++ const [darkMode, setDarkMode] = useState(false); + + const savePreferences = async () => { + const prefs = { +@@ -15,6 +16,7 @@ const UserPreferences = () => { + language, + notifications, ++ darkMode, + }; + + await api.updatePreferences(prefs); +@@ -35,6 +37,13 @@ const UserPreferences = () => { + onChange={(e) => setNotifications(e.target.checked)} + /> + ++ ++
++ ++ setDarkMode(e.target.checked)} ++ /> ++
+ + );""", + ), + ( + "refactor_extract_method", + "refactor: extract user authentication logic into separate service", + """--- a/src/controllers/AuthController.js ++++ b/src/controllers/AuthController.js +@@ -1,4 +1,5 @@ + const bcrypt = require('bcrypt'); ++const AuthService = require('../services/AuthService'); + + class AuthController { + async login(req, res) { +@@ -6,15 +7,7 @@ class AuthController { + + try { + // Authenticate user +- const user = await User.findOne({ email }); +- if (!user) { +- return res.status(401).json({ error: 'Invalid credentials' }); +- } +- +- const isValidPassword = await bcrypt.compare(password, user.passwordHash); +- if (!isValidPassword) { +- return res.status(401).json({ error: 'Invalid credentials' }); +- } ++ const user = await AuthService.authenticateUser(email, password); + + const token = jwt.sign({ userId: user.id }, process.env.JWT_SECRET); + res.json({ token, user: { id: user.id, email: user.email } }); +@@ -24,4 +17,22 @@ class AuthController { + } + } + ++--- /dev/null +++++ b/src/services/AuthService.js +@@ -0,0 +1,22 @@ ++const bcrypt = require('bcrypt'); ++const User = require('../models/User'); ++ ++class AuthService { ++ static async authenticateUser(email, password) { ++ const user = await User.findOne({ email }); ++ if (!user) { ++ throw new Error('Invalid credentials'); ++ } ++ ++ const isValidPassword = await bcrypt.compare(password, user.passwordHash); ++ if (!isValidPassword) { ++ throw new Error('Invalid credentials'); ++ } ++ ++ return user; ++ } ++} ++ ++module.exports = AuthService;""", + ), + ( + "docs_api_update", + "docs: update API documentation for user endpoints", + """--- a/docs/api/users.md ++++ b/docs/api/users.md +@@ -23,6 +23,20 @@ Creates a new user account. + - `409 Conflict` - Email already exists + - `500 Internal Server Error` - Server error + ++### Request Example ++ ++```json ++{ ++ "email": "user@example.com", ++ "password": "securePassword123", ++ "firstName": "John", ++ "lastName": "Doe" ++} ++``` ++ ++### Response Example ++ ++```json + ## GET /api/users/:id + + Retrieves user information by ID. +@@ -35,3 +49,15 @@ Retrieves user information by ID. + - `200 OK` - User found + - `404 Not Found` - User not found + - `500 Internal Server Error` - Server error ++ ++### Response Example ++ ++```json ++{ ++ "id": "123e4567-e89b-12d3-a456-426614174000", ++ "email": "user@example.com", ++ "firstName": "John", ++ "lastName": "Doe", ++ "createdAt": "2024-01-15T10:30:00Z" ++} ++```""", + ), + ( + "performance_optimize_query", + "perf: optimize database query for user search with indexing", + """--- a/src/models/User.js ++++ b/src/models/User.js +@@ -15,8 +15,12 @@ const userSchema = new mongoose.Schema({ + } + }); + ++// Add indexes for better query performance ++userSchema.index({ email: 1 }); ++userSchema.index({ firstName: 1, lastName: 1 }); ++userSchema.index({ createdAt: -1 }); ++ + // Static method for user search +-userSchema.statics.searchUsers = function(searchTerm) { ++userSchema.statics.searchUsers = function(searchTerm, limit = 20) { + const regex = new RegExp(searchTerm, 'i'); + return this.find({ + $or: [ +@@ -24,7 +28,8 @@ userSchema.statics.searchUsers = function(searchTerm) { + { firstName: regex }, + { lastName: regex } + ] +- }); ++ }) ++ .limit(limit) ++ .sort({ createdAt: -1 }); + };""", + ), + ( + "test_add_validation", + "test: add unit tests for email validation utility", + """--- /dev/null ++++ b/tests/utils/emailValidator.test.js +@@ -0,0 +1,45 @@ ++const { validateEmail } = require('../../src/utils/emailValidator'); ++ ++describe('Email Validator', () => { ++ describe('valid emails', () => { ++ test('should validate standard email format', () => { ++ expect(validateEmail('user@example.com')).toBe(true); ++ expect(validateEmail('john.doe@company.org')).toBe(true); ++ expect(validateEmail('admin@subdomain.example.co.uk')).toBe(true); ++ }); ++ ++ test('should validate emails with plus signs', () => { ++ expect(validateEmail('user+tag@example.com')).toBe(true); ++ }); ++ ++ test('should validate emails with numbers', () => { ++ expect(validateEmail('user123@example123.com')).toBe(true); ++ }); ++ }); ++ ++ describe('invalid emails', () => { ++ test('should reject emails without @ symbol', () => { ++ expect(validateEmail('userexample.com')).toBe(false); ++ }); ++ ++ test('should reject emails without domain', () => { ++ expect(validateEmail('user@')).toBe(false); ++ }); ++ ++ test('should reject emails without user part', () => { ++ expect(validateEmail('@example.com')).toBe(false); ++ }); ++ ++ test('should reject empty strings', () => { ++ expect(validateEmail('')).toBe(false); ++ }); ++ ++ test('should reject null and undefined', () => { ++ expect(validateEmail(null)).toBe(false); ++ expect(validateEmail(undefined)).toBe(false); ++ }); ++ }); ++});""", + ), + ] + + console.print(f"[green]Loaded {len(examples)} curated test examples[/green]") + return examples + + def run_single_stability_test( + self, + commit_hash: str, + message: str, + git_diff: str, + runs: int = 5, + variance_threshold: float = 0.3, + ) -> StabilityTestResult: + """Run stability test on a single commit example""" + + console.print(f"[blue]Testing stability for: {message[:60]}...[/blue]") + console.print( + f"[dim]Commit: {commit_hash[:8] if len(commit_hash) > 8 else commit_hash}[/dim]" + ) + + result = self.benchmarks.stability_test( + message=message, + diff=git_diff, + runs=runs, + variance_threshold=variance_threshold, + ) + + return result + + def run_batch_stability_test( + self, + examples: List[Tuple[str, str, str]], + runs: int = 5, + variance_threshold: float = 0.3, + max_examples: Optional[int] = None, + ) -> Dict[str, Any]: + """Run stability tests on multiple examples""" + + if max_examples and len(examples) > max_examples: + examples = random.sample(examples, max_examples) + + console.print( + f"[blue]Running batch stability test on {len(examples)} examples...[/blue]" + ) + console.print( + f"[dim]Runs per example: {runs}, Variance threshold: {variance_threshold}[/dim]" + ) + + results = [] + stable_count = 0 + total_time = 0 + + with Progress(console=self.console) as progress: + task = progress.add_task("Testing stability...", total=len(examples)) + + for commit_hash, message, git_diff in examples: + start_time = time.time() + + try: + result = self.benchmarks.stability_test( + message=message, + diff=git_diff, + runs=runs, + variance_threshold=variance_threshold, + ) + + results.append(result) + + if result["is_stable"]: # type: ignore + stable_count += 1 + + except Exception as e: + console.print(f"[red]Error testing {commit_hash[:8]}: {e}[/red]") + continue + + total_time += time.time() - start_time + progress.update(task, advance=1) + + stability_rate = (stable_count / len(results)) * 100 if results else 0 + avg_time_per_test = total_time / len(results) if results else 0 + + summary = { + "total_examples": len(examples), + "successful_tests": len(results), + "stable_examples": stable_count, + "stability_rate": stability_rate, + "average_time_per_test": avg_time_per_test, + "total_test_time": total_time, + "runs_per_example": runs, + "variance_threshold": variance_threshold, + } + + self._display_batch_summary(summary) + + return { + "summary": summary, + "individual_results": results, + "timestamp": datetime.now().isoformat(), + } + + def _display_batch_summary(self, summary: Dict[str, Any]): + """Display batch test summary""" + + console.print( + Panel( + f"[bold]Batch Stability Test Results[/bold]\n" + f"Examples Tested: {summary['successful_tests']}/{summary['total_examples']}\n" + f"Stable Examples: {summary['stable_examples']}\n" + f"Stability Rate: {summary['stability_rate']:.1f}%", + title="πŸ“Š Summary", + border_style="green" + if summary["stability_rate"] >= 80 + else "yellow" + if summary["stability_rate"] >= 60 + else "red", + ) + ) + + # Detailed table + table = Table(title="Performance Metrics", box=box.SIMPLE) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="center") + table.add_column("Assessment", style="yellow") + + # Stability rate assessment + if summary["stability_rate"] >= 90: + stability_assessment = "🟒 Excellent" + elif summary["stability_rate"] >= 80: + stability_assessment = "🟑 Good" + elif summary["stability_rate"] >= 60: + stability_assessment = "🟠 Acceptable" + else: + stability_assessment = "πŸ”΄ Needs Improvement" + + # Time assessment + avg_time = summary["average_time_per_test"] + if avg_time < 30: + time_assessment = "🟒 Fast" + elif avg_time < 60: + time_assessment = "🟑 Reasonable" + else: + time_assessment = "πŸ”΄ Slow" + + table.add_row( + "Stability Rate", f"{summary['stability_rate']:.1f}%", stability_assessment + ) + table.add_row("Avg Time per Test", f"{avg_time:.1f}s", time_assessment) + table.add_row("Total Test Time", f"{summary['total_test_time']:.1f}s", "") + table.add_row("Runs per Example", str(summary["runs_per_example"]), "") + table.add_row("Variance Threshold", str(summary["variance_threshold"]), "") + + console.print(table) + + def run_comparative_stability_test( + self, examples: List[Tuple[str, str, str]], models: List[str], runs: int = 3 + ) -> Dict[str, Any]: + """Compare stability across different models""" + + console.print( + f"[blue]Running comparative stability test across {len(models)} models...[/blue]" + ) + + model_results = {} + + for model in models: + console.print(f"[yellow]Testing model: {model}[/yellow]") + + # Create new evaluator for this model + evaluator = CommitMessageEvaluator(model_name=model) + benchmarks = EvaluationBenchmarks(evaluator) + + model_stability_results = [] + + for i, (commit_hash, message, git_diff) in enumerate( + examples[:5] + ): # Limit for comparative test + try: + result = benchmarks.stability_test( + message=message, + diff=git_diff, + runs=runs, + variance_threshold=0.3, + ) + + model_stability_results.append( + { + "example_id": i, + "is_stable": result["is_stable"], + "max_variance": result["max_variance"], + "commit_hash": commit_hash, + } + ) + + except Exception as e: + console.print( + f"[red]Error with {model} on {commit_hash[:8]}: {e}[/red]" + ) + continue + + model_results[model] = model_stability_results + + # Calculate comparative statistics + comparative_stats = self._calculate_comparative_stats(model_results) + self._display_comparative_results(comparative_stats) + + return { + "model_results": model_results, + "comparative_stats": comparative_stats, + "timestamp": datetime.now().isoformat(), + } + + def _calculate_comparative_stats( + self, model_results: Dict[str, List] + ) -> Dict[str, Any]: + """Calculate comparative statistics across models""" + + stats = {} + + for model, results in model_results.items(): + if not results: + continue + + stable_count = sum(1 for r in results if r["is_stable"]) + total_count = len(results) + avg_variance = ( + sum(r["max_variance"] for r in results) / total_count + if total_count > 0 + else 0 + ) + + stats[model] = { + "stability_rate": (stable_count / total_count * 100) + if total_count > 0 + else 0, + "average_variance": avg_variance, + "stable_examples": stable_count, + "total_examples": total_count, + } + + return stats + + def _display_comparative_results(self, stats: Dict[str, Any]): + """Display comparative model results""" + + table = Table(title="Model Stability Comparison", box=box.SIMPLE) + table.add_column("Model", style="cyan") + table.add_column("Stability Rate", justify="center") + table.add_column("Avg Variance", justify="center") + table.add_column("Stable/Total", justify="center") + table.add_column("Assessment", style="yellow") + + for model, model_stats in stats.items(): + rate = model_stats["stability_rate"] + + if rate >= 80: + assessment = "🟒 Excellent" + elif rate >= 60: + assessment = "🟑 Good" + else: + assessment = "πŸ”΄ Needs Work" + + table.add_row( + model, + f"{rate:.1f}%", + f"{model_stats['average_variance']:.3f}", + f"{model_stats['stable_examples']}/{model_stats['total_examples']}", + assessment, + ) + + console.print(table) + + +def main(): + """Main CLI interface""" + + parser = argparse.ArgumentParser( + description="Stability Benchmarking Script for DiffMage" + ) + parser.add_argument("--repo-path", default=".", help="Path to git repository") + parser.add_argument("--model", help="AI model to test (default: uses your default)") + parser.add_argument("--runs", type=int, default=5, help="Number of runs per test") + parser.add_argument( + "--variance-threshold", + type=float, + default=0.3, + help="Variance threshold for stability", + ) + + # Test data options + group = parser.add_mutually_exclusive_group() + group.add_argument("--commit-range", help="Git commit range (e.g., HEAD~10..HEAD)") + group.add_argument( + "--test-examples", action="store_true", help="Use curated test examples" + ) + group.add_argument("--single-commit", help="Test single commit by hash") + + # Test options + parser.add_argument( + "--batch-test", action="store_true", help="Run batch stability test" + ) + parser.add_argument("--comparative-test", help="Compare models (comma-separated)") + parser.add_argument( + "--max-examples", type=int, default=10, help="Maximum examples to test" + ) + parser.add_argument("--output", help="Save results to JSON file") + + args = parser.parse_args() + + try: + suite = StabilityBenchmarkSuite(args.repo_path, args.model) + except Exception as e: + console.print(f"[red]Error initializing benchmark suite: {e}[/red]") + return 1 + + results = None + + try: + if args.single_commit: + # Test single commit + commits = list(suite.repo.iter_commits(args.single_commit, max_count=1)) + if not commits: + console.print(f"[red]Commit not found: {args.single_commit}[/red]") + return 1 + + commit = commits[0] + message = str(commit.message).strip() + git_diff = suite.diff_parser.parse_specific_commit(commit.hexsha)[1] + + results = suite.run_single_stability_test( + commit.hexsha, message, git_diff, args.runs, args.variance_threshold + ) + + elif args.test_examples: + # Use curated examples + examples = suite.get_curated_test_examples() + + if args.comparative_test: + models = [m.strip() for m in args.comparative_test.split(",")] + results = suite.run_comparative_stability_test( + examples, models, args.runs + ) + elif args.batch_test: + results = suite.run_batch_stability_test( + examples, args.runs, args.variance_threshold, args.max_examples + ) + else: + # Test first example + commit_hash, message, git_diff = examples[0] + results = suite.run_single_stability_test( + commit_hash, message, git_diff, args.runs, args.variance_threshold + ) + + else: + # Use real repository commits + examples = suite.get_real_commit_examples( + args.commit_range, args.max_examples + ) + + if not examples: + console.print("[red]No suitable commit examples found[/red]") + return 1 + + if args.comparative_test: + models = [m.strip() for m in args.comparative_test.split(",")] + results = suite.run_comparative_stability_test( + examples, models, args.runs + ) + elif args.batch_test: + results = suite.run_batch_stability_test( + examples, args.runs, args.variance_threshold, args.max_examples + ) + else: + # Test first example + commit_hash, message, git_diff = examples[0] + results = suite.run_single_stability_test( + commit_hash, message, git_diff, args.runs, args.variance_threshold + ) + + # Save results if requested + if args.output and results: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, default=str) + console.print(f"[green]Results saved to {args.output}[/green]") + + console.print("\n[green]Benchmarking completed successfully! βœ…[/green]") + + except Exception as e: + console.print(f"[red]Error during benchmarking: {e}[/red]") + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) From c2a42cafe7336f672d21c928fede5b604161ab25 Mon Sep 17 00:00:00 2001 From: Nick Galluzzo Date: Tue, 5 Aug 2025 14:14:20 +0700 Subject: [PATCH 2/2] feat(validation): add comprehensive smoke test suite for LLM evaluator Adds a full validation suite with obvious cases, ranking consistency, score distribution, and edge case tests to ensure the commit message evaluator works correctly before running benchmarks. Includes rich CLI output and detailed test reporting. --- benchmarks/validation_suite.py | 924 +++++++++++++++++++++++++++++++++ 1 file changed, 924 insertions(+) create mode 100644 benchmarks/validation_suite.py diff --git a/benchmarks/validation_suite.py b/benchmarks/validation_suite.py new file mode 100644 index 0000000..bd1ebff --- /dev/null +++ b/benchmarks/validation_suite.py @@ -0,0 +1,924 @@ +""" +Evaluation *Smoke Test* Validation Suite for DiffMage + +This script performs fundamental sanity checks on your LLM evaluator to ensure +it's working correctly before running extensive benchmarks. + +Usage: + python validation_suite.py --all + python validation_suite.py --test obvious-cases + python validation_suite.py --test ranking-consistency + python validation_suite.py --test score-distribution + python validation_suite.py --test edge-cases +""" + +import argparse +import statistics +from typing import List, Dict, Any, Tuple, Optional +from dataclasses import dataclass + +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.progress import Progress +from rich import box + +try: + from diffmage.evaluation.commit_message_evaluator import CommitMessageEvaluator +except ImportError as e: + print(f"Error importing DiffMage modules: {e}") + print("Make sure you're running from the project root and the package is installed") + exit(1) + +console = Console() + + +@dataclass +class ValidationCase: + """A test case for validation""" + + name: str + commit_message: str + git_diff: str + expected_score_range: Tuple[float, float] # (min, max) + expected_quality: str # "excellent", "good", "average", "poor", "very_poor" + description: str + + +class EvaluationValidator: + """Validates that the LLM evaluator is working correctly""" + + def __init__(self, model_name: Optional[str] = None): + self.evaluator = CommitMessageEvaluator(model_name=model_name) + self.console = console + + def get_obvious_test_cases(self) -> List[ValidationCase]: + """Get test cases with obvious expected outcomes""" + + return [ + # EXCELLENT cases (4.5-5.0) + ValidationCase( + name="security_fix", + commit_message="fix: resolve critical SQL injection vulnerability in user authentication", + git_diff="""--- a/src/auth/UserAuth.py ++++ b/src/auth/UserAuth.py +@@ -23,7 +23,8 @@ class UserAuth: + def authenticate_user(self, username, password): +- query = f"SELECT * FROM users WHERE username='{username}' AND password='{password}'" ++ query = "SELECT * FROM users WHERE username=? AND password=?" ++ return self.db.execute(query, (username, password)) +- return self.db.execute(query)""", + expected_score_range=(4.0, 5.0), + expected_quality="excellent", + description="Clear security fix with good explanation", + ), + ValidationCase( + name="feature_with_context", + commit_message="feat: implement user password reset with email verification\n\nAdds secure password reset flow:\n- Generate time-limited reset tokens\n- Send verification emails\n- Validate tokens before allowing reset\n- Log security events for auditing", + git_diff="""--- a/src/auth/PasswordReset.py ++++ b/src/auth/PasswordReset.py +@@ -0,0 +1,45 @@ ++import secrets ++import hashlib ++from datetime import datetime, timedelta ++ ++class PasswordReset: ++ def __init__(self, email_service, user_service): ++ self.email_service = email_service ++ self.user_service = user_service ++ ++ def request_reset(self, email): ++ user = self.user_service.find_by_email(email) ++ if not user: ++ return False # Don't reveal if email exists ++ ++ token = secrets.token_urlsafe(32) ++ expires_at = datetime.now() + timedelta(hours=1) ++ ++ self.user_service.store_reset_token(user.id, token, expires_at) ++ self.email_service.send_reset_email(email, token) ++ ++ return True""", + expected_score_range=(4.0, 5.0), + expected_quality="excellent", + description="Comprehensive feature with detailed explanation", + ), + # GOOD cases (3.5-4.0) + ValidationCase( + name="simple_bug_fix", + commit_message="fix: handle null values in user profile display", + git_diff="""--- a/src/components/UserProfile.jsx ++++ b/src/components/UserProfile.jsx +@@ -15,7 +15,7 @@ const UserProfile = ({ user }) => { +
+

{user.name}

+-

Email: {user.email}

++

Email: {user.email || 'Not provided'}

+-

Bio: {user.bio}

++

Bio: {user.bio || 'No bio available'}

+
+ );""", + expected_score_range=(3.0, 4.5), + expected_quality="good", + description="Clear bug fix, could use more detail", + ), + # AVERAGE cases (2.5-3.5) + ValidationCase( + name="generic_update", + commit_message="update user component", + git_diff="""--- a/src/components/User.jsx ++++ b/src/components/User.jsx +@@ -10,6 +10,7 @@ const User = ({ userData }) => { + return ( +
+ {userData.name} ++ {userData.role} +
+ );""", + expected_score_range=(2.0, 3.5), + expected_quality="average", + description="Vague message, minimal change", + ), + # POOR cases (1.5-2.5) + ValidationCase( + name="meaningless_message", + commit_message="fix stuff", + git_diff="""--- a/src/utils/helper.js ++++ b/src/utils/helper.js +@@ -5,7 +5,7 @@ function processData(data) { + if (!data) { + return null; + } +- return data.map(item => item.value); ++ return data.map(item => item.value || 0); + }""", + expected_score_range=(1.0, 2.5), + expected_quality="poor", + description="Meaningless commit message", + ), + # VERY POOR cases (1.0-2.0) + ValidationCase( + name="gibberish", + commit_message="asdf jkl; qwerty", + git_diff="""--- a/src/test.js ++++ b/src/test.js +@@ -1,3 +1,4 @@ + // Test file + console.log('hello'); ++console.log('world');""", + expected_score_range=(1.0, 2.0), + expected_quality="very_poor", + description="Nonsensical commit message", + ), + ] + + def get_edge_test_cases(self) -> List[ValidationCase]: + """Get edge cases that might break the evaluator""" + + return [ + ValidationCase( + name="empty_message", + commit_message="", + git_diff="--- a/file.txt\n+++ b/file.txt\n@@ -1 +1,2 @@\n hello\n+world", + expected_score_range=(1.0, 2.0), + expected_quality="very_poor", + description="Empty commit message", + ), + ValidationCase( + name="very_long_message", + commit_message="fix: " + + "very " * 100 + + "long commit message that goes on and on and provides way too much detail about a simple change " + * 10, + git_diff="--- a/file.txt\n+++ b/file.txt\n@@ -1 +1,2 @@\n hello\n+world", + expected_score_range=(1.0, 3.0), + expected_quality="poor", + description="Excessively long commit message", + ), + ValidationCase( + name="special_characters", + commit_message="fix: handle Γ©mojis πŸš€ and Γ±oΓ±-ASCII Γ§haracters in ΓΌser input", + git_diff="""--- a/src/input.py ++++ b/src/input.py +@@ -1,3 +1,4 @@ + def process_input(text): ++ text = text.encode('utf-8').decode('utf-8') + return text.strip()""", + expected_score_range=(3.0, 4.5), + expected_quality="good", + description="Special characters and emojis", + ), + ValidationCase( + name="no_diff", + commit_message="fix: important bug fix", + git_diff="", + expected_score_range=(1.0, 3.0), + expected_quality="poor", + description="No diff provided", + ), + ValidationCase( + name="merge_commit", + commit_message="Merge branch 'feature/user-auth' into main", + git_diff="""--- a/src/auth.py ++++ b/src/auth.py +@@ -1,10 +1,20 @@ + # Auth module ++# New authentication features + + def login(user, password): + return validate_credentials(user, password) ++ ++def logout(user): ++ clear_session(user) ++ ++def reset_password(email): ++ send_reset_email(email)""", + expected_score_range=(1.0, 3.0), + expected_quality="poor", + description="Merge commit (usually auto-generated)", + ), + ] + + def test_obvious_cases(self) -> Dict[str, Any]: + """Test if evaluator handles obviously good/bad cases correctly""" + + console.print( + Panel( + "[bold]Testing Obvious Cases[/bold]\n" + "Evaluator should clearly distinguish between excellent and poor commit messages", + title="🎯 Obvious Cases Test", + border_style="blue", + ) + ) + + test_cases = self.get_obvious_test_cases() + results = [] + + with Progress(console=self.console) as progress: + task = progress.add_task( + "Evaluating obvious cases...", total=len(test_cases) + ) + + for case in test_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + + # Check if score is in expected range + in_range = ( + case.expected_score_range[0] + <= result.overall_score + <= case.expected_score_range[1] + ) + + results.append( + { + "case": case, + "result": result, + "in_expected_range": in_range, + "score_deviation": self._calculate_deviation( + result.overall_score, case.expected_score_range + ), + } + ) + + except Exception as e: + console.print(f"[red]Error evaluating {case.name}: {e}[/red]") + results.append( + { + "case": case, + "result": None, + "error": str(e), + "in_expected_range": False, + "score_deviation": float("inf"), + } + ) + + progress.update(task, advance=1) + + # Analyze results + success_rate = ( + sum(1 for r in results if r.get("in_expected_range", False)) + / len(results) + * 100 + ) + avg_deviation = statistics.mean( + [ + r["score_deviation"] + for r in results + if r["score_deviation"] != float("inf") + ] + ) + + # Display results + self._display_obvious_cases_results(results, success_rate, avg_deviation) + + return { + "test_name": "obvious_cases", + "success_rate": success_rate, + "average_deviation": avg_deviation, + "results": results, + "passed": success_rate >= 70, # 70% of obvious cases should be correct + } + + def test_ranking_consistency(self) -> Dict[str, Any]: + """Test if evaluator consistently ranks messages in logical order""" + + console.print( + Panel( + "[bold]Testing Ranking Consistency[/bold]\n" + "Evaluator should rank obviously better messages higher than worse ones", + title="πŸ“Š Ranking Test", + border_style="green", + ) + ) + + # Get subset of test cases for ranking + test_cases = self.get_obvious_test_cases() + + # Sort by expected quality for comparison + expected_order = sorted( + test_cases, key=lambda x: x.expected_score_range[1], reverse=True + ) + + # Evaluate all cases + evaluated_cases = [] + + with Progress(console=self.console) as progress: + task = progress.add_task("Evaluating for ranking...", total=len(test_cases)) + + for case in test_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + evaluated_cases.append((case, result)) + except Exception as e: + console.print(f"[red]Error evaluating {case.name}: {e}[/red]") + continue + + progress.update(task, advance=1) + + # Sort by actual scores + actual_order = sorted( + evaluated_cases, key=lambda x: x[1].overall_score, reverse=True + ) + + # Calculate ranking consistency + ranking_violations = self._count_ranking_violations( + expected_order, actual_order + ) + total_pairs = len(evaluated_cases) * (len(evaluated_cases) - 1) // 2 + consistency_rate = ( + (total_pairs - ranking_violations) / total_pairs * 100 + if total_pairs > 0 + else 0 + ) + + # Display results + self._display_ranking_results( + expected_order, actual_order, consistency_rate, ranking_violations + ) + + return { + "test_name": "ranking_consistency", + "consistency_rate": consistency_rate, + "ranking_violations": ranking_violations, + "total_pairs": total_pairs, + "passed": consistency_rate >= 80, # 80% of rankings should be consistent + } + + def test_score_distribution(self) -> Dict[str, Any]: + """Test if evaluator uses the full score range appropriately""" + + console.print( + Panel( + "[bold]Testing Score Distribution[/bold]\n" + "Evaluator should use the full 1-5 scale and not cluster around one value", + title="πŸ“ˆ Distribution Test", + border_style="yellow", + ) + ) + + all_cases = self.get_obvious_test_cases() + self.get_edge_test_cases() + scores = [] + + with Progress(console=self.console) as progress: + task = progress.add_task("Collecting scores...", total=len(all_cases)) + + for case in all_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + scores.append( + { + "case_name": case.name, + "overall_score": result.overall_score, + "what_score": result.what_score, + "why_score": result.why_score, + "expected_quality": case.expected_quality, + } + ) + except Exception as e: + console.print(f"[red]Error evaluating {case.name}: {e}[/red]") + continue + + progress.update(task, advance=1) + + if not scores: + return { + "test_name": "score_distribution", + "passed": False, + "error": "No scores collected", + } + + # Analyze distribution + overall_scores = [s["overall_score"] for s in scores] + distribution_stats = { + "mean": statistics.mean(overall_scores), + "median": statistics.median(overall_scores), + "std_dev": statistics.stdev(overall_scores) + if len(overall_scores) > 1 + else 0, + "min": min(overall_scores), + "max": max(overall_scores), + "range": max(overall_scores) - min(overall_scores), + } + + # Check for problems + problems = [] + if distribution_stats["std_dev"] < 0.5: + problems.append("Low variance - scores too clustered") + if distribution_stats["range"] < 2.0: + problems.append("Narrow range - not using full scale") + if distribution_stats["mean"] > 4.0: + problems.append("Grade inflation - scores too high") + if distribution_stats["mean"] < 2.0: + problems.append("Grade deflation - scores too low") + + # Display results + self._display_distribution_results(distribution_stats, scores, problems) + + return { + "test_name": "score_distribution", + "distribution_stats": distribution_stats, + "problems": problems, + "scores": scores, + "passed": len(problems) == 0, + } + + def test_edge_cases(self) -> Dict[str, Any]: + """Test if evaluator handles edge cases gracefully""" + + console.print( + Panel( + "[bold]Testing Edge Cases[/bold]\n" + "Evaluator should handle unusual inputs without crashing", + title="⚠️ Edge Cases Test", + border_style="red", + ) + ) + + edge_cases = self.get_edge_test_cases() + results = [] + + with Progress(console=self.console) as progress: + task = progress.add_task("Testing edge cases...", total=len(edge_cases)) + + for case in edge_cases: + try: + result = self.evaluator.evaluate_commit_message( + case.commit_message, case.git_diff + ) + + # Check if result is reasonable + is_reasonable = ( + 1.0 <= result.overall_score <= 5.0 + and 1.0 <= result.what_score <= 5.0 + and 1.0 <= result.why_score <= 5.0 + and result.reasoning + and len(result.reasoning) > 10 + ) + + results.append( + { + "case": case, + "result": result, + "handled_gracefully": True, + "is_reasonable": is_reasonable, + "error": None, + } + ) + + except Exception as e: + results.append( + { + "case": case, + "result": None, + "handled_gracefully": False, + "is_reasonable": False, + "error": str(e), + } + ) + + progress.update(task, advance=1) + + # Analyze results + graceful_handling_rate = ( + sum(1 for r in results if r["handled_gracefully"]) / len(results) * 100 + ) + reasonable_results_rate = ( + sum(1 for r in results if r.get("is_reasonable", False)) + / len(results) + * 100 + ) + + # Display results + self._display_edge_cases_results( + results, graceful_handling_rate, reasonable_results_rate + ) + + return { + "test_name": "edge_cases", + "graceful_handling_rate": graceful_handling_rate, + "reasonable_results_rate": reasonable_results_rate, + "results": results, + "passed": graceful_handling_rate >= 90 and reasonable_results_rate >= 70, + } + + def run_all_tests(self) -> Dict[str, Any]: + """Run all validation tests""" + + console.print( + Panel( + "[bold]Running Complete Validation Suite[/bold]\n" + "Testing fundamental evaluator functionality", + title="πŸ§ͺ Validation Suite", + border_style="cyan", + ) + ) + + all_results = {} + + # Run each test + tests = [ + ("obvious_cases", self.test_obvious_cases), + ("ranking_consistency", self.test_ranking_consistency), + ("score_distribution", self.test_score_distribution), + ("edge_cases", self.test_edge_cases), + ] + + for test_name, test_func in tests: + console.print( + f"\n[blue]Running {test_name.replace('_', ' ').title()} test...[/blue]" + ) + try: + all_results[test_name] = test_func() + except Exception as e: + console.print(f"[red]Test {test_name} failed with error: {e}[/red]") + all_results[test_name] = { + "test_name": test_name, + "passed": False, + "error": str(e), + } + + # Overall assessment + passed_tests = sum( + 1 for result in all_results.values() if result.get("passed", False) + ) + total_tests = len(all_results) + overall_pass_rate = passed_tests / total_tests * 100 + model_used: str = self.evaluator.model_name + + # Display summary + self._display_overall_summary(all_results, overall_pass_rate) + + return { + "overall_pass_rate": overall_pass_rate, + "passed_tests": passed_tests, + "total_tests": total_tests, + "model_used": model_used, + "individual_results": all_results, + "evaluator_ready": overall_pass_rate >= 75, # 75% of tests should pass + } + + # Helper methods for calculations and display + def _calculate_deviation( + self, actual_score: float, expected_range: Tuple[float, float] + ) -> float: + """Calculate how far actual score is from expected range""" + if expected_range[0] <= actual_score <= expected_range[1]: + return 0.0 + elif actual_score < expected_range[0]: + return expected_range[0] - actual_score + else: + return actual_score - expected_range[1] + + def _count_ranking_violations( + self, expected_order: List, actual_order: List + ) -> int: + """Count pairs that are ranked in wrong order""" + violations = 0 + actual_scores = { + case.name: score + for case, score in [(c, r.overall_score) for c, r in actual_order] + } + + for i, case1 in enumerate(expected_order): + for case2 in expected_order[i + 1 :]: + # case1 should score higher than case2 + if actual_scores.get(case1.name, 0) < actual_scores.get(case2.name, 0): + violations += 1 + + return violations + + def _display_obvious_cases_results( + self, results: List, success_rate: float, avg_deviation: float + ): + """Display obvious cases test results""" + + table = Table(title="Obvious Cases Results", box=box.SIMPLE) + table.add_column("Case", style="cyan") + table.add_column("Expected", justify="center") + table.add_column("Actual", justify="center") + table.add_column("Status", justify="center") + table.add_column("Deviation", justify="center") + + for r in results: + if r.get("result"): + status = "βœ… Pass" if r["in_expected_range"] else "❌ Fail" + status_color = "green" if r["in_expected_range"] else "red" + + expected_range = f"{r['case'].expected_score_range[0]:.1f}-{r['case'].expected_score_range[1]:.1f}" + actual_score = f"{r['result'].overall_score:.1f}" + deviation = f"{r['score_deviation']:.1f}" + + table.add_row( + r["case"].name, + expected_range, + actual_score, + f"[{status_color}]{status}[/{status_color}]", + deviation, + ) + else: + table.add_row( + r["case"].name, "N/A", "ERROR", "[red]❌ Error[/red]", "∞" + ) + + console.print(table) + + # Summary + summary_color = ( + "green" if success_rate >= 70 else "yellow" if success_rate >= 50 else "red" + ) + console.print( + f"\n[{summary_color}]Success Rate: {success_rate:.1f}% | Average Deviation: {avg_deviation:.2f}[/{summary_color}]" + ) + + def _display_ranking_results( + self, + expected_order: List, + actual_order: List, + consistency_rate: float, + violations: int, + ): + """Display ranking consistency results""" + + table = Table(title="Ranking Comparison", box=box.SIMPLE) + table.add_column("Rank", justify="center") + table.add_column("Expected", style="cyan") + table.add_column("Actual", style="yellow") + table.add_column("Score", justify="center") + + for i, (case, result) in enumerate(actual_order[:5]): # Show top 5 + expected_case = expected_order[i] if i < len(expected_order) else None + expected_name = expected_case.name if expected_case else "N/A" + + table.add_row( + str(i + 1), expected_name, case.name, f"{result.overall_score:.1f}" + ) + + console.print(table) + + # Summary + summary_color = ( + "green" + if consistency_rate >= 80 + else "yellow" + if consistency_rate >= 60 + else "red" + ) + console.print( + f"\n[{summary_color}]Ranking Consistency: {consistency_rate:.1f}% | Violations: {violations}[/{summary_color}]" + ) + + def _display_distribution_results(self, stats: Dict, scores: List, problems: List): + """Display score distribution results""" + + table = Table(title="Score Distribution Statistics", box=box.SIMPLE) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="center") + table.add_column("Assessment", style="yellow") + + table.add_row( + "Mean", + f"{stats['mean']:.2f}", + "Good" if 2.5 <= stats["mean"] <= 3.5 else "Check", + ) + table.add_row( + "Std Dev", + f"{stats['std_dev']:.2f}", + "Good" if stats["std_dev"] >= 0.5 else "Low variance", + ) + table.add_row( + "Range", + f"{stats['range']:.2f}", + "Good" if stats["range"] >= 2.0 else "Narrow range", + ) + table.add_row("Min", f"{stats['min']:.2f}", "") + table.add_row("Max", f"{stats['max']:.2f}", "") + + console.print(table) + + if problems: + console.print("\n[red]Issues Found:[/red]") + for problem in problems: + console.print(f" β€’ {problem}") + else: + console.print("\n[green]βœ… Score distribution looks healthy[/green]") + + def _display_edge_cases_results( + self, results: List, graceful_rate: float, reasonable_rate: float + ): + """Display edge cases test results""" + + table = Table(title="Edge Cases Results", box=box.SIMPLE) + table.add_column("Case", style="cyan") + table.add_column("Handled", justify="center") + table.add_column("Reasonable", justify="center") + table.add_column("Error", style="red") + + for r in results: + handled = "βœ… Yes" if r["handled_gracefully"] else "❌ No" + reasonable = "βœ… Yes" if r.get("is_reasonable", False) else "❌ No" + error = ( + r.get("error", "")[:50] + "..." + if r.get("error") and len(r.get("error", "")) > 50 + else r.get("error", "") + ) + + table.add_row(r["case"].name, handled, reasonable, error) + + console.print(table) + + console.print( + f"\n[blue]Graceful Handling: {graceful_rate:.1f}% | Reasonable Results: {reasonable_rate:.1f}%[/blue]" + ) + + def _display_overall_summary(self, all_results: Dict, overall_pass_rate: float): + """Display overall validation summary""" + + console.print( + Panel( + f"[bold]Validation Complete[/bold]\n" + f"Overall Pass Rate: {overall_pass_rate:.1f}%", + title="πŸ“‹ Summary", + border_style="green" + if overall_pass_rate >= 75 + else "yellow" + if overall_pass_rate >= 50 + else "red", + ) + ) + + summary_table = Table(title="Test Summary", box=box.SIMPLE) + summary_table.add_column("Test", style="cyan") + summary_table.add_column("Status", justify="center") + summary_table.add_column("Key Metric", justify="center") + + for test_name, result in all_results.items(): + status = "βœ… Pass" if result.get("passed", False) else "❌ Fail" + status_color = "green" if result.get("passed", False) else "red" + + # Extract key metric based on test type + if test_name == "obvious_cases": + key_metric = f"{result.get('success_rate', 0):.1f}% correct" + elif test_name == "ranking_consistency": + key_metric = f"{result.get('consistency_rate', 0):.1f}% consistent" + elif test_name == "score_distribution": + problems = result.get("problems", []) + key_metric = f"{len(problems)} issues" + elif test_name == "edge_cases": + key_metric = f"{result.get('graceful_handling_rate', 0):.1f}% handled" + else: + key_metric = "N/A" + + summary_table.add_row( + test_name.replace("_", " ").title(), + f"[{status_color}]{status}[/{status_color}]", + key_metric, + ) + + console.print(summary_table) + + # Recommendations + if overall_pass_rate >= 75: + console.print( + "\n[green]πŸŽ‰ Evaluator validation passed! Ready for benchmarking and research.[/green]" + ) + elif overall_pass_rate >= 50: + console.print( + "\n[yellow]⚠️ Evaluator has some issues but may be usable. Review failed tests.[/yellow]" + ) + else: + console.print( + "\n[red]❌ Evaluator has significant issues. Fix core problems before proceeding.[/red]" + ) + + +def main(): + """Main CLI interface""" + + parser = argparse.ArgumentParser( + description="Validation Suite for DiffMage Evaluator" + ) + parser.add_argument("--model", help="AI model to test (default: uses your default)") + + test_group = parser.add_mutually_exclusive_group() + test_group.add_argument( + "--all", action="store_true", help="Run all validation tests" + ) + test_group.add_argument( + "--test", + choices=[ + "obvious-cases", + "ranking-consistency", + "score-distribution", + "edge-cases", + ], + help="Run specific test", + ) + + parser.add_argument("--output", help="Save results to JSON file") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + + args = parser.parse_args() + + if not args.test and not args.all: + args.all = True + + try: + validator = EvaluationValidator(model_name=args.model) + except Exception as e: + console.print(f"[red]Error initializing validator: {e}[/red]") + console.print( + "[dim]Make sure you're in the project root and DiffMage is properly installed[/dim]" + ) + return 1 + + results = None + + try: + if args.all: + console.print("[blue]Running complete validation suite...[/blue]") + results = validator.run_all_tests() + + elif args.test == "obvious-cases": + results = validator.test_obvious_cases() + + elif args.test == "ranking-consistency": + results = validator.test_ranking_consistency() + + elif args.test == "score-distribution": + results = validator.test_score_distribution() + + elif args.test == "edge-cases": + results = validator.test_edge_cases() + + if args.output and results: + import json + + with open(args.output, "w") as f: + json.dump(results, f, indent=2, default=str) + console.print(f"[green]Results saved to {args.output}[/green]") + + if results and results.get("evaluator_ready", results.get("passed", False)): + console.print("\n[green]βœ… Validation completed successfully![/green]") + return 0 + else: + console.print("\n[red]❌ Validation found issues that need attention[/red]") + return 1 + + except Exception as e: + console.print(f"[red]Error during validation: {e}[/red]") + if args.verbose: + import traceback + + console.print(f"[dim]{traceback.format_exc()}[/dim]") + return 1 + + +if __name__ == "__main__": + exit(main())