From d9defa42afe6be5186b2aa97e4842e17fed68786 Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Mon, 9 Feb 2026 23:09:41 -0500 Subject: [PATCH 01/10] feat: add comprehensive pdd setup prompts for all LLM providers (#480) Create 7 new prompts to transform pdd setup into a comprehensive system bootstrapper that supports dynamic provider discovery, API key management, and interactive model selection across all LLM providers. Co-Authored-By: Claude Sonnet 4.5 --- README.md | 48 +-- SETUP_WITH_GEMINI.md | 29 +- context/api_key_scanner_example.py | 41 ++ context/api_key_validator_example.py | 63 +++ context/local_llm_configurator_example.py | 52 +++ context/model_selector_example.py | 52 +++ context/pddrc_initializer_example.py | 65 +++ context/provider_manager_example.py | 52 +++ docs/ONBOARDING.md | 34 +- docs/SETUP_GUIDE.md | 382 ++++++++++++++++++ pdd/docs/prompting_guide.md | 222 +++++++++- pdd/prompts/api_key_scanner_python.prompt | 70 ++++ pdd/prompts/api_key_validator_python.prompt | 59 +++ .../local_llm_configurator_python.prompt | 54 +++ pdd/prompts/model_selector_python.prompt | 55 +++ pdd/prompts/pddrc_initializer_python.prompt | 64 +++ pdd/prompts/provider_manager_python.prompt | 66 +++ pdd/prompts/setup_tool_python.prompt | 89 ++++ 18 files changed, 1457 insertions(+), 40 deletions(-) create mode 100644 context/api_key_scanner_example.py create mode 100644 context/api_key_validator_example.py create mode 100644 context/local_llm_configurator_example.py create mode 100644 context/model_selector_example.py create mode 100644 context/pddrc_initializer_example.py create mode 100644 context/provider_manager_example.py create mode 100644 docs/SETUP_GUIDE.md create mode 100644 pdd/prompts/api_key_scanner_python.prompt create mode 100644 pdd/prompts/api_key_validator_python.prompt create mode 100644 pdd/prompts/local_llm_configurator_python.prompt create mode 100644 pdd/prompts/model_selector_python.prompt create mode 100644 pdd/prompts/pddrc_initializer_python.prompt create mode 100644 pdd/prompts/provider_manager_python.prompt create mode 100644 pdd/prompts/setup_tool_python.prompt diff --git a/README.md b/README.md index 8d8de83dc..2a899cb7b 100644 --- a/README.md +++ b/README.md @@ -222,12 +222,20 @@ If you want to understand PDD fundamentals, follow this manual example to see it ### Post-Installation Setup (Required first step after installation) -Run the guided setup: +Run the comprehensive setup wizard: ```bash pdd setup ``` -This wraps the interactive bootstrap utility to install shell tab completion, capture your API keys, create ~/.pdd configuration files, and write the starter prompt. Re-run it any time to update keys or reinstall completion. +The setup wizard will: +- **Scan your environment** for API keys from all sources (shell, .env, ~/.pdd files) +- **Present an interactive menu** with options to add/fix keys, configure local LLMs (Ollama, LM Studio), add custom providers, or remove providers +- **Validate API keys** using actual LLM requests to ensure they work +- **Guide model selection** with cost transparency (show pricing for each tier) +- **Detect agentic CLI tools** (claude, gemini, codex) and offer installation +- **Create .pddrc** configuration file with sensible defaults for your project + +The wizard can be re-run at any time to update keys, add providers, or reconfigure settings. If you skip this step, the first regular pdd command you run will detect the missing setup files and print a reminder banner so you can finish onboarding later. @@ -236,7 +244,7 @@ Reload your shell so the new completion and environment hooks are available: source ~/.zshrc # or source ~/.bashrc / fish equivalent ``` -👉 If you prefer to configure things manually, see [SETUP_WITH_GEMINI.md](SETUP_WITH_GEMINI.md) for full instructions on obtaining a Gemini API key and creating your own `~/.pdd/llm_model.csv`. +👉 For detailed setup documentation, see [docs/SETUP_GUIDE.md](docs/SETUP_GUIDE.md). For manual configuration, see [SETUP_WITH_GEMINI.md](SETUP_WITH_GEMINI.md). 5. **Run Hello**: ```bash @@ -321,28 +329,6 @@ For a concrete, up-to-date reference of supported models and example rows, see t For proper model identifiers to use in your custom configuration, refer to the [LiteLLM Model List](https://docs.litellm.ai/docs/providers) documentation. LiteLLM typically uses model identifiers in the format `provider/model_name` (e.g., "openai/gpt-4", "anthropic/claude-3-opus-20240229"). -## Post-Installation Setup - -1. Run the guided setup (required unless you do this manually or use the cloud): -```bash -pdd setup -``` -This wraps the interactive bootstrap utility to install shell tab completion, capture your API keys, create `~/.pdd` configuration files, and write the starter prompt. Re-run it any time to update keys or reinstall completion. -If you skip this step, the first regular `pdd` command you run will detect the missing setup files and print a reminder banner so you can finish onboarding later (the banner is suppressed once `~/.pdd/api-env` exists or when your project already provides credentials via `.env` or `.pdd/`). - -2. Reload your shell so the new completion and environment hooks are available: -```bash -source ~/.zshrc # or source ~/.bashrc / fish equivalent -``` - -3. Configure environment variables (optional): -```bash -# Add to .bashrc, .zshrc, or equivalent -export PDD_AUTO_UPDATE=true -export PDD_GENERATE_OUTPUT_PATH=/path/to/generated/code/ -export PDD_TEST_OUTPUT_PATH=/path/to/tests/ -``` - ## Troubleshooting Common Installation Issues 1. **Command not found** @@ -2710,13 +2696,23 @@ The `.pddrc` approach is recommended for team projects as it ensures consistent ### Model Configuration (`llm_model.csv`) -PDD uses a CSV file (`llm_model.csv`) to store information about available AI models, their costs, capabilities, and required API key names. When running commands locally (e.g., using the `update_model_costs.py` utility or potentially local execution modes if implemented), PDD determines which configuration file to use based on the following priority: +PDD uses a CSV file (`llm_model.csv`) to store information about available AI models, their costs, capabilities, and required API key names. The `pdd setup` wizard automatically manages this file by: + +- **Dynamic provider discovery:** Reading all provider API keys from the CSV to scan your environment +- **Interactive model selection:** Letting you choose which model tiers to enable (Fast/Cheap, Balanced, Most Capable) with cost transparency +- **Custom provider support:** Adding custom LiteLLM-compatible providers and local LLMs (Ollama, LM Studio) +- **Provider removal:** Safely removing providers by deleting their model rows from the CSV + +When running commands locally, PDD determines which configuration file to use based on the following priority: 1. **User-specific:** `~/.pdd/llm_model.csv` - If this file exists, it takes precedence over any project-level configuration. This allows users to maintain a personal, system-wide model configuration. 2. **Project-specific:** `/.pdd/llm_model.csv` - If the user-specific file is not found, PDD looks for the file within the `.pdd` directory of the determined project root (based on `PDD_PATH` or auto-detection). 3. **Package default:** If neither of the above exist, PDD falls back to the default configuration bundled with the package installation. This tiered approach allows for both shared project configurations and individual user overrides, while ensuring PDD works out-of-the-box without requiring manual configuration. + +**Note:** The setup wizard uses this CSV as the source of truth for provider discovery and model selection. You can manually edit it, but running `pdd setup` again is the recommended way to manage providers and models. + *Note: This file-based configuration primarily affects local operations and utilities. Cloud execution modes likely rely on centrally managed configurations.* diff --git a/SETUP_WITH_GEMINI.md b/SETUP_WITH_GEMINI.md index b021fe8ee..4eaeda3ea 100644 --- a/SETUP_WITH_GEMINI.md +++ b/SETUP_WITH_GEMINI.md @@ -60,14 +60,27 @@ Right after installation, let PDD bootstrap its configuration: pdd setup ``` -During the wizard: -- Choose **Install tab completion** if you want shell helpers. -- Pick **Google Gemini** when asked which providers to configure. -- Paste your Gemini API key when prompted (you can create it in the next step if you haven’t already). - -The wizard writes your credentials to `~/.pdd/api-env`, seeds `~/.pdd/llm_model.csv` with Gemini entries, and reminds you to reload your shell (`source ~/.zshrc`, etc.) so completion and env hooks load. - -If you prefer to configure everything manually—or you’re on an offline machine—skip the wizard and follow the manual instructions below. +The interactive setup wizard will: +1. **Scan your environment** for existing API keys from all sources +2. **Show an interactive menu** with options to: + - Add or fix API keys (including Gemini) + - Add local LLMs (Ollama, LM Studio) + - Add custom providers + - Remove providers +3. **Validate your Gemini API key** with a real test request +4. **Guide model selection** with cost transparency +5. **Detect agentic CLI tools** and offer installation +6. **Create .pddrc** for your project + +When adding your Gemini API key: +- Select option `1. Add or fix API keys` from the menu +- The wizard will detect that `GEMINI_API_KEY` is missing +- Paste your API key when prompted (you can create it in the next step if you haven't already) +- The wizard tests it immediately and confirms it works + +The wizard writes your credentials to `~/.pdd/api-env.zsh` (or `.bash`), updates `llm_model.csv` with your selected models, and reminds you to reload your shell (`source ~/.zshrc`, etc.) so completion and env hooks load. + +If you prefer to configure everything manually—or you're on an offline machine—skip the wizard and follow the manual instructions below. --- diff --git a/context/api_key_scanner_example.py b/context/api_key_scanner_example.py new file mode 100644 index 000000000..77b9d57a5 --- /dev/null +++ b/context/api_key_scanner_example.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.api_key_scanner import scan_environment, KeyInfo + + +def main() -> None: + """ + Demonstrates how to use the api_key_scanner module to: + 1. Dynamically discover all API keys from llm_model.csv + 2. Check multiple sources (shell env, .env file, ~/.pdd/api-env.*) + 3. Get transparency about where each key is loaded from + """ + + print("Scanning environment for API keys...\n") + + # Scan the environment for all API keys defined in llm_model.csv + scan_results = scan_environment() + + # Display results + for key_name, key_info in scan_results.items(): + status = "✓ Set" if key_info.is_set else "✗ Not set" + source = f"({key_info.source})" if key_info.is_set else "" + masked_value = key_info.value if key_info.is_set else "—" + + print(f" {key_name:25s} {status:12s} {source:30s}") + if key_info.is_set: + print(f" Value: {masked_value}") + + print(f"\nTotal keys found: {len([k for k in scan_results.values() if k.is_set])}") + print(f"Total keys missing: {len([k for k in scan_results.values() if not k.is_set])}") + + +if __name__ == "__main__": + main() diff --git a/context/api_key_validator_example.py b/context/api_key_validator_example.py new file mode 100644 index 000000000..0c6af7119 --- /dev/null +++ b/context/api_key_validator_example.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import os +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.api_key_validator import validate_key, ValidationResult + + +def main() -> None: + """ + Demonstrates how to use the api_key_validator module to: + 1. Validate API keys using llm_invoke instead of HTTP requests + 2. Get detailed error messages for debugging + 3. Understand which model was tested and which provider + """ + + print("API Key Validation Example\n") + + # Example 1: Validate Anthropic API key + anthropic_key = os.getenv("ANTHROPIC_API_KEY") + if anthropic_key: + print("Testing ANTHROPIC_API_KEY...") + result = validate_key("ANTHROPIC_API_KEY", anthropic_key) + display_result(result) + else: + print("ANTHROPIC_API_KEY not set - skipping validation") + + print() + + # Example 2: Validate OpenAI API key + openai_key = os.getenv("OPENAI_API_KEY") + if openai_key: + print("Testing OPENAI_API_KEY...") + result = validate_key("OPENAI_API_KEY", openai_key) + display_result(result) + else: + print("OPENAI_API_KEY not set - skipping validation") + + print() + + # Example 3: Test with invalid key + print("Testing with invalid key...") + result = validate_key("ANTHROPIC_API_KEY", "sk-ant-invalid-key-123") + display_result(result) + + +def display_result(result: ValidationResult) -> None: + """Helper to display validation results""" + if result.is_valid: + print(f" ✓ Valid - Provider: {result.provider}, Model tested: {result.model_tested}") + else: + print(f" ✗ Invalid - Provider: {result.provider}") + if result.error_message: + print(f" Error: {result.error_message}") + + +if __name__ == "__main__": + main() diff --git a/context/local_llm_configurator_example.py b/context/local_llm_configurator_example.py new file mode 100644 index 000000000..201c485f7 --- /dev/null +++ b/context/local_llm_configurator_example.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.local_llm_configurator import configure_local_llm + + +def main() -> None: + """ + Demonstrates how to use the local_llm_configurator module to: + 1. Configure Ollama with auto-detection of installed models + 2. Configure LM Studio with custom port + 3. Add custom local LLM endpoints + """ + + print("Local LLM Configuration Example\n") + + print("This would present an interactive menu:") + print() + print("What tool are you using?") + print(" 1. LM Studio (default: localhost:1234)") + print(" 2. Ollama (default: localhost:11434)") + print(" 3. Other (custom base URL)") + print(" Choice: 2") + print() + print("Querying Ollama at http://localhost:11434...") + print("Found installed models:") + print(" 1. llama3:70b") + print(" 2. codellama:34b") + print(" 3. mistral:7b") + print() + print("Which models do you want to add? [1,2,3]: 1,2") + print("✓ Added ollama_chat/llama3:70b and ollama_chat/codellama:34b to llm_model.csv") + + # Run the actual configuration + # configure_local_llm() # Uncomment to run interactively + + print("\n\nKey Features:") + print(" • Ollama auto-detection: Queries API for installed models") + print(" • LM Studio defaults: Pre-filled localhost:1234 base URL") + print(" • Custom endpoints: Support for any LiteLLM-compatible provider") + print(" • Multiple models: Add several models in one session") + print(" • Zero cost: Local models set to $0.0001 or $0 costs") + + +if __name__ == "__main__": + main() diff --git a/context/model_selector_example.py b/context/model_selector_example.py new file mode 100644 index 000000000..43ac0df1c --- /dev/null +++ b/context/model_selector_example.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.model_selector import interactive_selection + + +def main() -> None: + """ + Demonstrates how to use the model_selector module to: + 1. Group models by cost/capability tier + 2. Display pricing information for transparency + 3. Let users select which tiers to enable + 4. Explain how --strength controls model selection + """ + + print("Model Tier Selection Example\n") + + # Assume we have validated providers from earlier steps + validated_providers = ["Anthropic", "OpenAI", "Google"] + + print("This would present an interactive selection for each provider:") + print() + print("Models available for Anthropic:") + print() + print(" # Model Input Output ELO") + print(" 1. anthropic/claude-opus-4-5 $5.00 $25.00 1474") + print(" 2. anthropic/claude-sonnet-4-5 $3.00 $15.00 1370") + print(" 3. anthropic/claude-haiku-4-5 $1.00 $5.00 1270") + print() + print("Tip: pdd uses --strength (0.0–1.0) to pick models by cost/quality at runtime.") + print("Adding all models gives you the full range.") + print() + print("Include which models? [1,2,3] (default: all):") + + # Run the actual interactive selection + # interactive_selection(validated_providers) # Uncomment to run interactively + + print("\nKey Features:") + print(" • Tier classification: Groups models by cost (Fast/Cheap, Balanced, Most Capable)") + print(" • Cost transparency: Shows input/output token costs per million") + print(" • Smart defaults: Press Enter to include all models") + print(" • Strength explanation: Users learn how model selection works at runtime") + + +if __name__ == "__main__": + main() diff --git a/context/pddrc_initializer_example.py b/context/pddrc_initializer_example.py new file mode 100644 index 000000000..1c8f5791b --- /dev/null +++ b/context/pddrc_initializer_example.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.pddrc_initializer import offer_pddrc_init + + +def main() -> None: + """ + Demonstrates how to use the pddrc_initializer module to: + 1. Check if .pddrc exists in current project + 2. Detect project language (Python/TypeScript/etc.) + 3. Offer to create .pddrc with sensible defaults + """ + + print(".pddrc Initialization Example\n") + + print("This checks for .pddrc in the current directory and offers to create one:") + print() + print("No .pddrc found in current project.") + print() + print("Would you like to create one with default settings?") + print(" Default language: python") + print(" Output path: pdd/") + print(" Test output path: tests/") + print() + print("Create .pddrc? [Y/n]") + + # Run the actual initialization + # was_created = offer_pddrc_init() # Uncomment to run interactively + # if was_created: + # print("✓ Created .pddrc with default settings") + + print("\n\nKey Features:") + print(" • Auto-detection: Detects language from project files (setup.py, package.json, etc.)") + print(" • Sensible defaults: Sets conventional paths for each language") + print(" • Non-destructive: Never overwrites existing .pddrc") + print(" • YAML format: Creates properly formatted configuration file") + + print("\n\nExample .pddrc content:") + print(""" +version: "1.0" + +contexts: + default: + defaults: + generate_output_path: "pdd/" + test_output_path: "tests/" + example_output_path: "context/" + default_language: "python" + target_coverage: 80.0 + strength: 1.0 + temperature: 0.0 + budget: 10.0 + max_attempts: 3 +""") + + +if __name__ == "__main__": + main() diff --git a/context/provider_manager_example.py b/context/provider_manager_example.py new file mode 100644 index 000000000..6531939fd --- /dev/null +++ b/context/provider_manager_example.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.provider_manager import add_or_fix_keys, add_custom_provider, remove_provider +from pdd.setup.api_key_scanner import scan_environment + + +def main() -> None: + """ + Demonstrates how to use the provider_manager module to: + 1. Add or fix API keys for existing providers + 2. Add custom LiteLLM-compatible providers + 3. Remove providers (comment out keys, remove CSV rows) + """ + + print("Provider Management Example\n") + + # First, scan the environment to see what's configured + print("Scanning current configuration...") + scan_results = scan_environment() + + # Example 1: Add or fix keys for missing/invalid providers + print("\n--- Example 1: Add or Fix Keys ---") + print("This would prompt for any missing or invalid keys found in the scan.") + # add_or_fix_keys(scan_results) # Uncomment to run interactively + + # Example 2: Add a custom provider + print("\n--- Example 2: Add Custom Provider ---") + print("This guides you through adding a custom LiteLLM provider (e.g., Together AI, Deepinfra).") + # add_custom_provider() # Uncomment to run interactively + + # Example 3: Remove a provider + print("\n--- Example 3: Remove Provider ---") + print("This shows configured providers and lets you remove one.") + print("Removal comments out the key (doesn't delete) and removes model rows from CSV.") + # remove_provider(scan_results) # Uncomment to run interactively + + print("\nKey Features:") + print(" • Smart storage: Only saves newly entered keys to ~/.pdd/api-env.{{shell}}") + print(" • Key commenting: Never deletes keys, only comments with timestamp") + print(" • Atomic CSV writes: Uses temp file + rename to prevent corruption") + print(" • Validation: Tests keys with llm_invoke before saving") + + +if __name__ == "__main__": + main() diff --git a/docs/ONBOARDING.md b/docs/ONBOARDING.md index 2f6800b9e..1bdb9d454 100644 --- a/docs/ONBOARDING.md +++ b/docs/ONBOARDING.md @@ -85,7 +85,24 @@ To enable syntax highlighting for `.prompt` files in your editor, you'll need to ### 7. Set Up API Keys -Add your LLM API keys to a `.env` file in the project root: +**Recommended: Use the setup wizard** + +Run the interactive setup wizard to configure your API keys: + +```bash +pdd setup +``` + +The wizard will: +- **Scan your environment** for existing API keys from all sources (shell, .env, ~/.pdd files) +- **Present an interactive menu** to add/fix keys, configure local LLMs, or manage providers +- **Validate keys** with real test requests to ensure they work +- **Show cost transparency** for different model tiers +- **Create .pddrc** configuration for your project + +**Alternative: Manual configuration** + +If you prefer manual setup, add your LLM API keys to a `.env` file in the project root: ```bash # Required: At least one LLM provider @@ -93,7 +110,7 @@ OPENAI_API_KEY=sk-your-key-here # OR ANTHROPIC_API_KEY=sk-ant-your-key-here # OR -GOOGLE_API_KEY=your-google-api-key +GEMINI_API_KEY=your-google-api-key # Optional: For Vertex AI (Gemini via GCP) VERTEX_CREDENTIALS=/path/to/service-account.json @@ -841,6 +858,17 @@ rm -f ~/.pdd/llm_model.csv **Fix for "API key not found":** +**Recommended:** Run the setup wizard to detect and fix missing API keys: +```bash +pdd setup +``` + +The wizard will: +- Scan all sources (shell, .env, ~/.pdd files) and show which keys are missing +- Let you add missing keys with immediate validation +- Show exactly where each key is loaded from for transparency + +**Manual fixes:** - If using **Infisical**: Follow **"Step 7: Set Up Infisical for Secrets Management"** above to configure your API keys - If using **.env file**: Ensure your `.env` file in the project root contains your API keys (e.g., `OPENAI_API_KEY=sk-...`) @@ -850,6 +878,8 @@ rm -f ~/.pdd/llm_model.csv infisical run -- env | grep API_KEY # If using Infisical # OR env | grep API_KEY # If using .env +# OR +pdd setup # Shows scan of all keys with source transparency ``` **Note on API key requirements for testing:** diff --git a/docs/SETUP_GUIDE.md b/docs/SETUP_GUIDE.md new file mode 100644 index 000000000..c31da3f27 --- /dev/null +++ b/docs/SETUP_GUIDE.md @@ -0,0 +1,382 @@ +# PDD Setup Guide + +This guide covers the comprehensive `pdd setup` command, which helps you configure PDD with API keys, local LLMs, custom providers, and project settings. + +## Overview + +The `pdd setup` wizard provides an interactive menu-driven interface for configuring your PDD installation. It automatically: + +- **Scans your environment** for API keys from all sources (shell, .env, ~/.pdd files) +- **Validates API keys** with real test requests to ensure they work +- **Manages providers** - add, fix, or remove LLM providers +- **Configures local LLMs** - Ollama, LM Studio, or custom endpoints +- **Selects model tiers** - with cost transparency and guidance +- **Detects agentic CLIs** - checks for claude, gemini, codex and offers installation +- **Creates .pddrc** - project configuration with sensible defaults + +## Quick Start + +```bash +pdd setup +``` + +After installation, run the setup wizard. It will scan your environment and present an interactive menu. + +## The Setup Flow + +### 1. Environment Scan + +When you run `pdd setup`, it first scans for API keys: + +``` +═══════════════════════════════════════════════════════ +Scanning for API keys... +═══════════════════════════════════════════════════════ + + ANTHROPIC_API_KEY ✓ Valid (shell environment) + OPENAI_API_KEY ✓ Valid (.env file) + GROQ_API_KEY ✗ Invalid (shell environment) + GEMINI_API_KEY — Not found + FIREWORKS_API_KEY — Not found +``` + +The scan shows: +- **✓ Valid**: Key found and validated with a test request +- **✗ Invalid**: Key found but failed validation +- **— Not found**: No key found in any source + +**Source transparency:** Each key shows where it's loaded from: +- `(shell environment)` - From your shell's environment variables +- `(.env file)` - From the project's .env file +- `(~/.pdd/api-env.zsh)` - From PDD's managed key file + +### 2. Interactive Menu + +After the scan, you see the main menu: + +``` +What would you like to do? + 1. Add or fix API keys + 2. Add a local LLM (Ollama, LM Studio) + 3. Add a custom provider + 4. Remove a provider + 5. Continue → +``` + +#### Option 1: Add or Fix API Keys + +This option shows only providers that are missing or invalid: + +``` +GROQ_API_KEY (currently: invalid): + Enter new key: gsk_abc... + Testing with groq/mixtral-8x7b-32768... ✓ Valid + +GEMINI_API_KEY (currently: not set): + Enter key (or press Enter to skip): AIza... + Testing with gemini/gemini-1.5-flash... ✓ Valid +``` + +**Smart key storage:** +- Keys you **enter during setup** are saved to `~/.pdd/api-env.{{shell}}` +- Keys **already in your environment** are not duplicated + +After adding keys, you return to the main menu with an updated scan. + +#### Option 2: Add a Local LLM + +Local models (Ollama, LM Studio) don't need API keys - they need a `base_url` and model name: + +``` +What tool are you using? + 1. LM Studio (default: localhost:1234) + 2. Ollama (default: localhost:11434) + 3. Other (custom base URL) + Choice: 2 + +Querying Ollama at http://localhost:11434... +Found installed models: + 1. llama3:70b + 2. codellama:34b + 3. mistral:7b + +Which models do you want to add? [1,2,3]: 1,2 +✓ Added ollama_chat/llama3:70b and ollama_chat/codellama:34b to llm_model.csv +``` + +**Features:** +- **Ollama auto-detection**: Queries the API to list installed models +- **LM Studio defaults**: Pre-fills localhost:1234 base URL +- **Custom endpoints**: For any LiteLLM-compatible provider +- **Zero cost**: Local models are set to $0 or $0.0001 costs + +#### Option 3: Add a Custom Provider + +For LiteLLM-compatible providers not in the default CSV (Together AI, Deepinfra, Mistral, etc.): + +``` +Provider prefix (e.g. together_ai, deepinfra, mistral): together_ai +Model name: meta-llama/Llama-3-70b-chat +API key env var name: TOGETHERAI_API_KEY +Base URL (press Enter if standard): +Cost per 1M input tokens (optional): 0.90 +Cost per 1M output tokens (optional): 0.90 + +Testing together_ai/meta-llama/Llama-3-70b-chat... ✓ Valid +✓ Added to llm_model.csv +``` + +This lets you add any provider without manually editing the CSV. + +#### Option 4: Remove a Provider + +Shows configured providers and lets you safely remove one: + +``` +Configured providers: + 1. ANTHROPIC_API_KEY (3 models) + 2. OPENAI_API_KEY (5 models) + 3. GROQ_API_KEY (1 model) + 4. TOGETHERAI_API_KEY (1 model) [custom] + +Remove which provider? 4 + + # Commented out by pdd setup on 2026-02-09 + # export TOGETHERAI_API_KEY='tok_abc...' + + Removed 1 model from llm_model.csv +✓ TOGETHERAI_API_KEY removed +``` + +**Safe removal:** +- Keys are **commented out**, never deleted (easy to recover) +- Model rows are removed from `llm_model.csv` +- Prevents orphaned models in the CSV + +#### Option 5: Continue + +Proceeds to model selection, CLI detection, and .pddrc creation. + +### 3. Model Tier Selection + +After configuring providers (option 5), the wizard shows available models grouped by cost tier: + +``` +Models available for ANTHROPIC_API_KEY: + + # Model Input Output ELO + 1. anthropic/claude-opus-4-5 $5.00 $25.00 1474 + 2. anthropic/claude-sonnet-4-5 $3.00 $15.00 1370 + 3. anthropic/claude-haiku-4-5 $1.00 $5.00 1270 + +Tip: pdd uses --strength (0.0–1.0) to pick models by cost/quality at runtime. +Adding all models gives you the full range. + +Include which models? [1,2,3] (default: all): 2,3 +``` + +**Cost transparency:** +- Shows input/output token costs per million +- Displays ELO ratings for quality comparison +- Explains how `--strength` controls model selection + +**Smart defaults:** +- Press Enter to include all models (recommended) +- Or select specific tiers (e.g., just Haiku + Sonnet to avoid Opus costs) + +### 4. Agentic CLI Detection + +After model selection, setup checks for agentic CLI tools: + +``` +Checking agentic CLI harnesses... +(Required for: pdd fix, pdd change, pdd bug) + + Claude CLI ✓ Found at /usr/local/bin/claude + Codex CLI ✗ Not found + Gemini CLI ✗ Not found + +You have OPENAI_API_KEY but Codex CLI is not installed. + Install with: npm install -g @openai/codex + Install now? [y/N] +``` + +This proactive detection prevents errors when running `pdd fix` or `pdd change`. + +### 5. .pddrc Initialization + +Finally, setup offers to create a `.pddrc` configuration: + +``` +No .pddrc found in current project. + +Would you like to create one with default settings? + Default language: python + Output path: pdd/ + Test output path: tests/ + +Create .pddrc? [Y/n] +``` + +**Auto-detection:** +- Detects language from project files (setup.py, package.json, go.mod) +- Sets conventional paths for that language +- Creates properly formatted YAML configuration + +## API Key Loading Priority + +PDD checks for API keys in this order (highest priority first): + +1. **Shell environment variables** - `export ANTHROPIC_API_KEY=...` +2. **`.env` file** - In the project root +3. **`~/.pdd/api-env.{{shell}}`** - PDD's managed key file + +**Why this order?** +- Shell vars override .env (industry standard with `load_dotenv(override=False)`) +- Allows .env for development defaults, shell vars for production secrets +- Prevents .env from accidentally overwriting intentional shell configs + +**Source transparency:** The setup scan shows exactly which source provides each key. + +## Saving Keys: Smart Storage + +The setup wizard uses smart storage rules: + +- **Keys entered during setup** → Saved to `~/.pdd/api-env.{{shell}}` +- **Keys already in shell/environment** → Not saved (avoids duplicates) + +This prevents duplicating keys managed by Infisical, .env, shell profiles, etc. + +Example: +``` +Saving keys... + GROQ_API_KEY → saved to ~/.pdd/api-env.zsh (entered during setup) + GEMINI_API_KEY → saved to ~/.pdd/api-env.zsh (entered during setup) + ANTHROPIC_API_KEY → skipped (already in shell environment) + OPENAI_API_KEY → skipped (already in .env file) +``` + +## Re-running Setup + +You can run `pdd setup` at any time to: + +- Add new providers or fix invalid keys +- Add local LLM endpoints +- Remove providers +- Update model selections +- Reinstall shell completion + +The wizard always starts with a fresh environment scan, so you see the current state. + +## Manual Configuration (Alternative) + +If you prefer not to use the wizard, you can configure PDD manually: + +### Manual API Key Setup + +Create `~/.pdd/api-env.zsh` (or `.bash`): + +```bash +export ANTHROPIC_API_KEY='sk-ant-...' +export OPENAI_API_KEY='sk-...' +export GEMINI_API_KEY='AIza...' +``` + +Source it from your shell profile (~/.zshrc or ~/.bashrc): + +```bash +# Load PDD API keys +[ -f ~/.pdd/api-env.zsh ] && source ~/.pdd/api-env.zsh +``` + +### Manual .pddrc Setup + +Create `.pddrc` in your project root: + +```yaml +version: "1.0" + +contexts: + default: + defaults: + generate_output_path: "pdd/" + test_output_path: "tests/" + example_output_path: "context/" + default_language: "python" + target_coverage: 80.0 + strength: 1.0 + temperature: 0.0 + budget: 10.0 + max_attempts: 3 +``` + +### Manual llm_model.csv + +See [SETUP_WITH_GEMINI.md](../SETUP_WITH_GEMINI.md) for full manual configuration instructions. + +## Troubleshooting + +### "API key not found" + +Run the setup wizard: +```bash +pdd setup +``` + +It will scan all sources and show you exactly which keys are missing and where existing keys are loaded from. + +### "Invalid API key" + +The setup wizard tests keys immediately with `llm_invoke`. If validation fails: + +1. Check the error message for details (authentication vs network vs config) +2. Verify the key format (some providers have format requirements) +3. Check your account/quota status with the provider + +### Keys in multiple sources + +If a key exists in both .env and shell: + +- **Shell environment takes precedence** (industry standard) +- The setup scan shows which source is active: `(shell environment)` +- This prevents .env from overwriting intentional shell configs + +### Missing Ollama models + +If Ollama auto-detection fails: + +1. Check that Ollama is running: `ollama serve` +2. Verify the API is accessible: `curl http://localhost:11434/api/tags` +3. Fall back to manual model name entry in the wizard + +## Advanced Topics + +### Vertex AI Configuration + +For Google Vertex AI with service accounts: + +1. Create a service account JSON file from Google Cloud Console +2. Set `VERTEX_CREDENTIALS=/path/to/service-account.json` +3. Run `pdd setup` and add Vertex AI models when prompted + +### Multiple Projects + +- **Global keys**: Store in `~/.pdd/api-env.{{shell}}` for all projects +- **Project keys**: Store in project `.env` for project-specific overrides +- **Model preferences**: Each project can have its own `llm_model.csv` in `.pdd/` + +### CI/CD Integration + +For CI/CD pipelines: + +1. Don't use the interactive wizard (it requires user input) +2. Set API keys as environment variables in your CI system +3. Copy a pre-configured `llm_model.csv` to the project or user directory +4. Set `PDD_SKIP_SETUP=1` to bypass setup checks + +## Related Documentation + +- [README.md](../README.md) - Main PDD documentation +- [SETUP_WITH_GEMINI.md](../SETUP_WITH_GEMINI.md) - Manual setup guide +- [ONBOARDING.md](ONBOARDING.md) - Developer onboarding guide +- [whitepaper.md](whitepaper.md) - PDD concepts and architecture diff --git a/pdd/docs/prompting_guide.md b/pdd/docs/prompting_guide.md index 80f7be992..b336cca34 100644 --- a/pdd/docs/prompting_guide.md +++ b/pdd/docs/prompting_guide.md @@ -195,10 +195,10 @@ Tip: Prefer small, named sections using XML‑style tags to make context scannab The PDD preprocessor supports additional XML‑style tags to keep prompts clean, reproducible, and self‑contained. Processing order (per spec) is: `pdd` → `include`/`include-many` → `shell` → `web`. When `recursive=True`, `` and `` are deferred until a non‑recursive pass. -- `` +- `` - Purpose: human‑only comment. Removed entirely during preprocessing. - Use: inline rationale or notes that should not reach the model. - - Example: `Before step X explain why we do this here` + - Example: `Before step X ` - `` - Purpose: run a shell command and inline stdout at that position. @@ -222,6 +222,194 @@ The PDD preprocessor supports additional XML‑style tags to keep prompts clean, Use these tags sparingly. When you must use them, prefer stable commands with bounded output (e.g., `head -n 20` in ``). +**`context_urls` in Architecture Entries:** + +When an architecture.json entry includes a `context_urls` array, the `generate_prompt` template automatically converts each entry into a `` tag in the generated prompt's Dependencies section. This enables the LLM to fetch relevant API documentation during code generation: + +```json +"context_urls": [ + {"url": "https://fastapi.tiangolo.com/tutorial/first-steps/", "purpose": "FastAPI routing patterns"} +] +``` + +Becomes in the generated prompt: +```xml + + https://fastapi.tiangolo.com/tutorial/first-steps/ + +``` + +The tag name is derived from the `purpose` field (lowercased, spaces replaced with underscores). This mechanism bridges architecture-level research with prompt-level context. + +--- + +## Architecture Metadata Tags + +PDD prompts can include optional XML metadata tags that sync with `architecture.json`. These tags enable bidirectional sync between prompt files and the architecture visualization, keeping your project's architecture documentation automatically up-to-date. + +### Tag Format + +Place architecture metadata tags at the **top of your prompt file** (after any `` directives but before the main content): + +```xml +Brief description of module's purpose (60-120 chars) + + +{{ + "type": "module", + "module": {{ + "functions": [ + {"name": "function_name", "signature": "(...)", "returns": "Type"} + ] + }} +}} + + +dependency_prompt_1.prompt +dependency_prompt_2.prompt +``` + +### Tag Reference + +**``** +- **Purpose**: One-line description of why this module exists +- **Maps to**: `architecture.json["reason"]` +- **Format**: Single line string (recommended 60-120 characters) +- **Example**: `Provides unified LLM invocation across all PDD operations.` + +**``** +- **Purpose**: JSON describing the module's public API (functions, commands, pages) +- **Maps to**: `architecture.json["interface"]` +- **Format**: Valid JSON matching one of four interface types (see below) +- **Example**: + ```xml + + {{ + "type": "module", + "module": {{ + "functions": [ + {"name": "llm_invoke", "signature": "(prompt, strength, ...)", "returns": "Dict"} + ] + }} + }} + + ``` + +**``** +- **Purpose**: References other prompt files this module depends on +- **Maps to**: `architecture.json["dependencies"]` array +- **Format**: Prompt filename (e.g., `llm_invoke_python.prompt`) +- **Multiple tags**: Use one `` tag per dependency +- **Example**: + ```xml + llm_invoke_python.prompt + path_resolution_python.prompt + ``` + +### Interface Types + +The `` tag supports four interface types, matching the architecture.json schema: + +**Module Interface** (Python modules with functions): +```json +{ + "type": "module", + "module": { + "functions": [ + {"name": "func_name", "signature": "(arg1, arg2)", "returns": "Type"} + ] + } +} +``` + +**CLI Interface** (Command-line interfaces): +```json +{ + "type": "cli", + "cli": { + "commands": [ + {"name": "cmd_name", "description": "What it does"} + ] + } +} +``` + +**Command Interface** (PDD commands): +```json +{ + "type": "command", + "command": { + "commands": [ + {"name": "cmd_name", "description": "What it does"} + ] + } +} +``` + +**Frontend Interface** (UI pages): +```json +{ + "type": "frontend", + "frontend": { + "pages": [ + {"name": "page_name", "route": "/path"} + ] + } +} +``` + +### Sync Workflow + +1. **Add/edit tags** in your prompt files using the format above +2. **Click "Sync from Prompt"** in the PDD Connect Architecture page (or call the API endpoint) +3. **Tags automatically update** `architecture.json` with your changes +4. **Architecture visualization** reflects the updated dependencies and interfaces + +Prompts are the **source of truth** - tags in prompt files override what's in `architecture.json`. This aligns with PDD's core philosophy that prompts, not code or documentation, are authoritative. + +### Validation + +Validation is **lenient**: +- Missing tags are OK - only fields with tags get updated +- Malformed XML/JSON is skipped without blocking sync +- Circular dependencies are detected and prevent invalid updates +- Missing dependency files generate warnings but don't block sync + +### Best Practices + +**Keep `` concise** (60-120 chars) +- Good: "Provides unified LLM invocation across all PDD operations." +- Too long: "This module exists because we needed a way to call different LLM providers through a unified interface that supports both streaming and non-streaming modes while also handling rate limiting and retry logic..." + +**Use prompt filenames for dependencies**, not module names +- Correct: `llm_invoke_python.prompt` +- Wrong: `pdd.llm_invoke` +- Wrong: `context/example.py` + +**Validate interface JSON before committing** +- Use a JSON validator to check syntax +- Ensure `type` field matches one of: `module`, `cli`, `command`, `frontend` +- Include required nested keys (`functions`, `commands`, or `pages`) + +**Run "Sync All" after bulk prompt updates** +- If you've edited multiple prompts, sync all at once +- Review the validation results for circular dependencies +- Fix any warnings before committing changes + +### Relationship to Other Tags + +**`` vs ``**: +- ``: Declares architectural dependency (updates `architecture.json`) +- ``: Injects content into prompt for LLM context (does NOT affect architecture) +- Use both when appropriate - they serve different purposes + +**`` tags vs ``: Human-only comments (removed by preprocessor, never reach LLM) +- Both are valid PDD directives with different purposes + +### Example: Complete Prompt with Metadata Tags + +See `docs/examples/prompt_with_metadata.prompt` for a full example showing all three metadata tags in context. + --- ## Advanced Tips @@ -544,7 +732,8 @@ Key practice: Code and examples are ephemeral (regenerated); Tests and Prompts a | Task Type | Where to Start | The Workflow | | :--- | :--- | :--- | | **New Feature** | **The Prompt** | 1. Add/Update Requirements in Prompt.
2. Regenerate Code (LLM sees existing tests).
3. Write new Tests to verify. | -| **Bug Fix** | **The Test File** | 1. Use `pdd bug` to create a failing test case (repro) in the Test file.
2. Clarify the Prompt to address the edge case if needed.
3. Run `pdd fix` (LLM sees the new test and must pass it). | +| **Bug Fix (Code)** | **The Test File** | 1. Use `pdd bug` to create a failing test case (repro) in the Test file.
2. Clarify the Prompt to address the edge case if needed.
3. Run `pdd fix` (LLM sees the new test and must pass it).
**Tip:** Use `pdd fix --protect-tests` if the tests from `pdd bug` are correct and you want to prevent the LLM from modifying them. | +| **Bug Fix (Prompt Defect)** | **The Prompt** | When `pdd bug` determines the prompt specification itself is wrong (Step 5.5), it auto-fixes the prompt file. The workflow then continues to generate tests based on the corrected prompt. | **Key insight:** When you run `pdd generate` after adding a test, the LLM sees that test as context. This means the generated code is constrained to pass it - the test acts as a specification, not just a verification. @@ -572,6 +761,31 @@ After a successful fix, ask: "Where should this knowledge live?" - "The code style was inconsistent" → Update preamble (not prompt) - "I prefer different variable names" → Update preamble/prompt +### Prompt Defects vs. Code Bugs + +In PDD, the prompt is the source of truth. However, prompts themselves can contain defects. The `pdd bug` agentic workflow (Step 5.5: Prompt Classification) distinguishes between two types of bugs: + +| Defect Type | Definition | Detection | Action | +|-------------|------------|-----------|--------| +| **Code Bug** | Code doesn't match the prompt specification | Tests fail because implementation diverges from requirements | Fix the code via `pdd fix` | +| **Prompt Defect** | Prompt doesn't match the intended behavior | User-reported expected behavior contradicts the prompt | Fix the prompt, then regenerate | + +**How Prompt Classification Works:** + +After root cause analysis (Step 5), the workflow examines whether: +1. The code correctly implements the prompt, but the prompt is wrong (→ Prompt Defect) +2. The code incorrectly implements the prompt (→ Code Bug) + +**Output markers** for automation: +- `DEFECT_TYPE: code` - Proceed with normal test generation +- `DEFECT_TYPE: prompt` - Auto-fix the prompt file first +- `PROMPT_FIXED: path/to/file.prompt` - Indicates which prompt was modified +- `PROMPT_REVIEW: reason` - Request human review for ambiguous cases + +**Default behavior:** When classification is uncertain, the workflow defaults to "code bug" to preserve backward compatibility. + +This classification prevents the "test oracle problem" - where tests generated from a flawed prompt would encode incorrect behavior, causing `pdd fix` to "fix" correct code to match the buggy specification. + --- ## PDD vs Interactive Agentic Coders (Claude Code, Cursor) @@ -690,4 +904,4 @@ Key differences: ## Final Notes -Think of prompts as your programming language. Keep them concise, explicit, and modular. Regenerate instead of patching, verify behavior with accumulating tests, and continuously back‑propagate implementation learnings into your prompts. That discipline is what converts maintenance from an endless patchwork into a compounding system of leverage. +Think of prompts as your programming language. Keep them concise, explicit, and modular. Regenerate instead of patching, verify behavior with accumulating tests, and continuously back‑propagate implementation learnings into your prompts. That discipline is what converts maintenance from an endless patchwork into a compounding system of leverage. \ No newline at end of file diff --git a/pdd/prompts/api_key_scanner_python.prompt b/pdd/prompts/api_key_scanner_python.prompt new file mode 100644 index 000000000..2e965df9d --- /dev/null +++ b/pdd/prompts/api_key_scanner_python.prompt @@ -0,0 +1,70 @@ +Dynamically discovers API keys from CSV, shell, .env, and PDD config files with source transparency. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "scan_environment", "signature": "() -> Dict[str, KeyInfo]", "returns": "Dict[str, KeyInfo]"} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/api_key_scanner.py module. + +% Role & Scope +This module dynamically discovers API keys from all sources and determines their availability. It reads llm_model.csv to find all unique API key environment variable names, then checks shell environment, .env files, and ~/.pdd/api-env.* files to determine which keys are present and where they come from. + +% Requirements +1. Function: `scan_environment() -> Dict[str, KeyInfo]` - Returns mapping of key name to KeyInfo (value, source, is_set) +2. Dynamic provider discovery: Read pdd/data/llm_model.csv, extract all unique api_key column values +3. Multi-source detection: Check shell environment (os.environ), .env file (if exists), ~/.pdd/api-env.* files +4. Source transparency: For each key found, record whether it came from "shell environment", ".env file", or "~/.pdd/api-env.zsh" +5. Priority handling: If key exists in multiple sources, record the effective source based on loading order (shell overrides .env) +6. CSV parsing: Use csv.DictReader to read llm_model.csv, handle missing or malformed data gracefully +7. .env loading: Use python-dotenv's load_dotenv to read .env, compare os.environ before/after to determine source +8. Shell detection: Detect user's shell (zsh, bash) from SHELL environment variable to check correct api-env file +9. KeyInfo structure: Return namedtuple or dataclass with fields: value (str), source (str), is_set (bool) +10. Performance: Cache CSV reading, don't re-parse on every call unless file changes + +% Dependencies + +The CSV at pdd/data/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Key field: api_key (e.g., "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "VERTEX_CREDENTIALS", "FIREWORKS_API_KEY", "GROQ_API_KEY") + + + +# Standard usage of python-dotenv +from dotenv import load_dotenv +import os + +# Capture state before loading .env +before_env = set(os.environ.keys()) + +# Load .env file (doesn't override existing environment variables by default) +load_dotenv(dotenv_path=".env", override=False) + +# Capture state after loading .env +after_env = set(os.environ.keys()) + +# Keys that came from .env are those added during load +env_file_keys = after_env - before_env + + +% Instructions +- Return a dictionary mapping key name (e.g., "ANTHROPIC_API_KEY") to KeyInfo +- Use dataclass for KeyInfo with fields: value (masked), source, is_set +- Mask key values in returned data (show first 8 chars + "..." + last 4 chars) +- Handle case where CSV doesn't exist or is malformed (return empty dict) +- For source determination: shell env takes precedence over .env per industry standard +- Check ~/.pdd/api-env.zsh or ~/.pdd/api-env.bash depending on detected shell +- Don't raise exceptions, return best-effort results with logging for errors + +% Deliverables +- A Python module located at `pdd/setup/api_key_scanner.py`. +- The module must export the following symbols: + - `scan_environment`: Scans all sources for API keys and returns their status. + - `KeyInfo`: A dataclass containing information about a discovered API key. diff --git a/pdd/prompts/api_key_validator_python.prompt b/pdd/prompts/api_key_validator_python.prompt new file mode 100644 index 000000000..42fb90386 --- /dev/null +++ b/pdd/prompts/api_key_validator_python.prompt @@ -0,0 +1,59 @@ +Validates API keys using llm_invoke with minimal test prompts for all LLM providers. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "validate_key", "signature": "(key_name: str, key_value: str) -> ValidationResult", "returns": "ValidationResult"} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/api_key_validator.py module. + +% Role & Scope +This module validates API keys by testing them with llm_invoke instead of hardcoded HTTP requests. It selects an appropriate test model for each provider, makes a minimal completion request, and returns validation results including error details. + +% Requirements +1. Function: `validate_key(key_name: str, key_value: str) -> ValidationResult` - Test if API key works +2. Use llm_invoke: Call pdd.llm_invoke.llm_invoke() with minimal test prompt instead of HTTP requests +3. Model selection: Map API key name to appropriate test model (e.g., ANTHROPIC_API_KEY -> claude-haiku-4-5) +4. Test prompt: Use simple prompt like "Say 'OK'" to minimize cost and latency +5. Error handling: Catch authentication errors, network errors, and invalid model errors separately +6. ValidationResult: Return dataclass with fields: is_valid (bool), provider (str), model_tested (str), error_message (Optional[str]) +7. Provider mapping: Derive provider from key name (ANTHROPIC_API_KEY -> Anthropic, OPENAI_API_KEY -> OpenAI, etc.) +8. Timeout: Set reasonable timeout (10s) for validation requests to avoid hanging +9. Cost awareness: Always use cheapest/fastest model for validation (Haiku for Anthropic, cheapest GPT model available for OpenAI, Gemini Flash for Google) +10. Vertex AI handling: For VERTEX_CREDENTIALS, test with vertex_ai/ prefix models + +% Dependencies + + context/llm_invoke_example.py + + + +The CSV at pdd/data/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Example rows: +- Anthropic,anthropic/claude-haiku-4-5-20251001,1.0,5.0,1270,,ANTHROPIC_API_KEY,128000,True,budget, +- OpenAI,gpt-4o-mini,0.15,0.6,1249,,OPENAI_API_KEY,0,True,none, +- Google,vertex_ai/gemini-3-flash-preview,0.5,3.0,1430,,VERTEX_CREDENTIALS,0,True,effort,global + + +% Instructions +- Map key names to providers: ANTHROPIC_API_KEY -> Anthropic, OPENAI_API_KEY -> OpenAI, GEMINI_API_KEY -> Google, VERTEX_CREDENTIALS -> Google (Vertex), FIREWORKS_API_KEY -> Fireworks, GROQ_API_KEY -> Groq +- Select cheapest model for each provider from CSV for validation +- Set key as environment variable temporarily before calling llm_invoke (if not already set) +- Use try/except to catch litellm errors and categorize them (auth vs network vs config) +- Return ValidationResult dataclass with clear error messages for debugging +- Don't raise exceptions to caller, always return ValidationResult +- Log validation attempts and results for debugging + +% Deliverables +- A Python module located at `pdd/setup/api_key_validator.py`. +- The module must export the following symbols: + - `validate_key`: Tests if an API key is valid by making a minimal test request. + - `ValidationResult`: A dataclass containing the result of the validation. diff --git a/pdd/prompts/local_llm_configurator_python.prompt b/pdd/prompts/local_llm_configurator_python.prompt new file mode 100644 index 000000000..da3a69839 --- /dev/null +++ b/pdd/prompts/local_llm_configurator_python.prompt @@ -0,0 +1,54 @@ +Configures local LLMs (Ollama, LM Studio) with auto-detection and CSV integration. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "configure_local_llm", "signature": "() -> bool", "returns": "bool"} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/local_llm_configurator.py module. + +% Role & Scope +This module guides users through configuring local LLM tools (Ollama, LM Studio, custom endpoints). Local models don't need API keys but require base_url and model name configuration. It auto-detects installed models where possible (e.g., querying Ollama API). + +% Requirements +1. Function: `configure_local_llm() -> bool` - Interactive setup for local LLM providers +2. Provider options: Present menu: 1. LM Studio (default: localhost:1234), 2. Ollama (default: localhost:11434), 3. Other (custom base URL) +3. Ollama detection: For Ollama, query http://localhost:11434/api/tags to list installed models +4. Model selection: For Ollama, show detected models and let user select which to add; for others, prompt for model name +5. CSV addition: Append rows to llm_model.csv with provider prefix (lm_studio/ or ollama_chat/), base_url, and empty api_key +6. Base URL validation: For custom providers, validate URL format (http/https scheme required) +7. Multiple models: Allow adding multiple models in one session (loop until user declines) +8. Cost defaults: For local models, set input/output costs to 0.0 or very low values ($0.0001) +9. Atomic CSV writes: Use temp file + rename to prevent corruption +10. Error handling: If Ollama API unreachable, fall back to manual model name entry + +% Dependencies + +The CSV at pdd/data/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Example local LLM rows: +- lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0.0001,0,1082,http://localhost:1234/v1,,0,True,effort, +- custom_model,custom/Qwen3-30B-A3B-4bit,0,0,1040,http://localhost:8080,,0,False,none, + + +% Instructions +- For Ollama auto-detection: Make GET request to http://localhost:11434/api/tags, parse JSON response for model list +- If Ollama unreachable, inform user and ask for manual model name entry +- For LM Studio: Default base URL is http://localhost:1234/v1, ask if user wants different port +- For model naming: Use prefix format (lm_studio/ or ollama_chat/) to match LiteLLM conventions +- Set reasonable defaults: ELO=1000, structured_output=True, reasoning_type=effort, max_reasoning_tokens=0 +- After adding each model, ask: "Add another local model? [y/N]" +- Use rich Console for formatted output and interactive prompts +- Return bool indicating whether any models were added + +% Deliverables +- A Python module located at `pdd/setup/local_llm_configurator.py`. +- The module must export the following symbol: + - `configure_local_llm`: Provides an interactive setup for local LLM providers. diff --git a/pdd/prompts/model_selector_python.prompt b/pdd/prompts/model_selector_python.prompt new file mode 100644 index 000000000..21d72102f --- /dev/null +++ b/pdd/prompts/model_selector_python.prompt @@ -0,0 +1,55 @@ +Interactive model tier selection with cost transparency and strength parameter guidance. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "interactive_selection", "signature": "(validated_providers: List[str]) -> bool", "returns": "bool"} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/model_selector.py module. + +% Role & Scope +This module provides interactive model tier selection with cost transparency. It groups models by capability/cost tier, displays pricing information, and lets users select which tiers to enable. It explains how --strength controls model selection at runtime. + +% Requirements +1. Function: `interactive_selection(validated_providers: List[str]) -> bool` - Guide user through tier selection for each provider +2. Tier classification: Group models into tiers (Fast/Cheap, Balanced, Most Capable) based on cost and ELO +3. Cost display: Show input/output token costs for each tier (per million tokens) +4. Provider iteration: For each validated provider, show available models grouped by tier +5. User selection: Let user choose which tiers to include (default: all) via numbered input +6. Strength explanation: Briefly explain that pdd uses --strength (0.0-1.0) to pick models by cost/quality at runtime +7. CSV filtering: After selection, update llm_model.csv to only include chosen models (or keep all if user selects all) +8. Smart defaults: If user presses Enter without input, include all models for that provider +9. Tier thresholds: Use cost as primary classifier - Cheap: <=$1 input, Balanced: >$1 and <$3 input, Capable: >=$3 input +10. Interactive display: Use rich Console to show formatted tables with model info (name, cost, ELO) + +% Dependencies + +The CSV at pdd/data/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Example cost ranges: +- Fast/Cheap: gpt-4o-mini ($0.15/$0.6), claude-haiku-4-5 ($1/$5) +- Balanced: claude-sonnet-4-5 ($3/$15), gpt-4o ($5/$15) +- Most Capable: claude-opus-4-5 ($5/$25), gpt-4-turbo ($10/$30) + + +% Instructions +- Read llm_model.csv to get all models for validated providers +- Group by provider first, then by tier within each provider +- Display clear table format: " # Model Input Output ELO" +- After showing tiers, prompt: "Include which models? [1,2,3] (default: all):" +- Parse user input (comma-separated numbers or "all") +- If models are filtered out, update CSV by removing those rows (atomic write with temp file) +- Show tip about --strength before starting selections: "Tip: pdd uses --strength (0.0-1.0) to pick models by cost/quality at runtime. Adding all models gives you the full range." +- Return bool indicating whether any changes were made + +% Deliverables +- A Python module located at `pdd/setup/model_selector.py`. +- The module must export the following symbol: + - `interactive_selection`: Guides the user through model tier selection for each provider. diff --git a/pdd/prompts/pddrc_initializer_python.prompt b/pdd/prompts/pddrc_initializer_python.prompt new file mode 100644 index 000000000..73af51b99 --- /dev/null +++ b/pdd/prompts/pddrc_initializer_python.prompt @@ -0,0 +1,64 @@ +Creates .pddrc configuration files with language-aware defaults and project detection. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "offer_pddrc_init", "signature": "() -> bool", "returns": "bool"} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/pddrc_initializer.py module. + +% Role & Scope +This module offers to create a basic .pddrc configuration file in the current project directory if one doesn't exist. It sets sensible defaults for output paths, language, and strength/temperature based on detected project type. + +% Requirements +1. Function: `offer_pddrc_init() -> bool` - Check if .pddrc exists, offer to create with defaults +2. Existence check: Look for .pddrc in current working directory (os.getcwd()) +3. Language detection: Detect project type (Python if setup.py/pyproject.toml, TypeScript if package.json with typescript dep, else ask user) +4. Default paths: Set generate_output_path, test_output_path, example_output_path based on language (Python: pdd/, tests/, context/ | TypeScript: src/, __tests__/, examples/) +5. Strength defaults: Set strength=1.0, temperature=0.0, target_coverage=80.0 as sensible defaults +6. Interactive prompt: Ask "Would you like to create one with default settings? [Y/n]" before creating +7. File format: Write YAML format matching .pddrc specification (version, contexts, defaults) +8. Show preview: Display the default settings before creating (language, output paths) +9. Skip if exists: Don't overwrite existing .pddrc, just inform user it already exists +10. Return bool: Indicate whether file was created + +% Dependencies + +# Example .pddrc file structure +version: "1.0" + +contexts: + default: + defaults: + generate_output_path: "pdd/" + test_output_path: "tests/" + example_output_path: "context/" + default_language: "python" + target_coverage: 80.0 + strength: 1.0 + temperature: 0.0 + budget: 10.0 + max_attempts: 3 + + +% Instructions +- Check for project indicators: setup.py, pyproject.toml (Python), package.json (TypeScript/JavaScript), go.mod (Go), etc. +- If language unclear, prompt user: "What is your primary language? (python/typescript/go/etc.)" +- Map language to conventional paths: + - Python: generate="pdd/", test="tests/", example="context/" + - TypeScript: generate="src/", test="__tests__/", example="examples/" + - Go: generate=".", test=".", example="examples/" +- Use YAML format for .pddrc file (preserve comments for user guidance) +- After creation, print success message: "✓ Created .pddrc with default settings" +- If .pddrc already exists, print: "Found existing .pddrc - skipping initialization" + +% Deliverables +- A Python module located at `pdd/setup/pddrc_initializer.py`. +- The module must export the following symbol: + - `offer_pddrc_init`: Checks for an existing `.pddrc` and offers to create one with default settings. diff --git a/pdd/prompts/provider_manager_python.prompt b/pdd/prompts/provider_manager_python.prompt new file mode 100644 index 000000000..94de28904 --- /dev/null +++ b/pdd/prompts/provider_manager_python.prompt @@ -0,0 +1,66 @@ +Manages LLM providers: adding/fixing API keys, custom providers, and safe key removal. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "add_or_fix_keys", "signature": "(scan_results: Dict[str, KeyInfo]) -> bool", "returns": "bool"}, + {"name": "add_custom_provider", "signature": "() -> bool", "returns": "bool"}, + {"name": "remove_provider", "signature": "(scan_results: Dict[str, KeyInfo]) -> bool", "returns": "bool"} + ] + } +} + + +api_key_validator_python.prompt + +% You are an expert Python engineer. Your goal is to write the pdd/setup/provider_manager.py module. + +% Role & Scope +This module handles adding, removing, and managing LLM providers in PDD setup. It supports adding/fixing API keys for existing providers, adding custom LiteLLM-compatible providers, and removing providers (commenting out keys and removing model rows from CSV). + +% Requirements +1. Function: `add_or_fix_keys(scan_results: Dict[str, KeyInfo]) -> bool` - Prompt for missing/invalid keys, validate, and save +2. Function: `add_custom_provider() -> bool` - Guide user through adding custom provider (provider prefix, model name, API key, base URL, costs) +3. Function: `remove_provider(scan_results: Dict[str, KeyInfo]) -> bool` - Show configured providers, let user select one, comment out key and remove CSV rows +4. Smart storage: Save newly entered keys to ~/.pdd/api-env.{{shell}}, skip keys already in shell or .env +5. Key commenting: When removing, add comment with timestamp, don't delete (# Commented out by pdd setup on YYYY-MM-DD) +6. CSV updates: For custom provider, append row to llm_model.csv; for removal, delete all rows with matching api_key field +7. Atomic CSV writes: Use temp file + rename to prevent corruption on partial writes +8. Shell detection: Determine shell from SHELL env var, write to correct api-env file (.zsh or .bash) +9. Validation: After adding keys, call api_key_validator to test before saving +10. Interactive prompts: Use input() for all user prompts, handle empty input as skip/cancel + +% Dependencies + + context/api_key_validator_example.py + + + +The CSV at pdd/data/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Example rows: +- OpenAI,gpt-4o-mini,0.15,0.6,1249,,OPENAI_API_KEY,0,True,none, +- Fireworks,fireworks_ai/accounts/fireworks/models/qwen3-coder-480b-a35b-instruct,0.45,1.80,1363,,FIREWORKS_API_KEY,0,False,none, +- lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0.0001,0,1082,http://localhost:1234/v1,,0,True,effort, + + +% Instructions +- For add_or_fix_keys: Only prompt for keys that are missing or invalid in scan_results +- For each key to add: Prompt user, validate with api_key_validator, save only if valid +- Determine storage location: If key already in environment (not from api-env file), don't save to avoid duplicates +- For api-env file writes: Use shell-appropriate export format (export KEY="value") +- For remove_provider: Show numbered list of providers with model counts, let user pick, confirm before removal +- CSV atomic writes: Write to temp file, validate it's parseable, then rename over original +- Handle missing ~/.pdd directory (create it with mkdir -p) +- For custom provider: Use sensible defaults (structured_output=True, reasoning_type=none, location=empty) +- Return bool indicating success/failure for each function + +% Deliverables +- A Python module located at `pdd/setup/provider_manager.py`. +- The module must export the following symbols: + - `add_or_fix_keys`: Prompts the user to add or fix missing/invalid API keys. + - `add_custom_provider`: Guides the user through adding a new custom provider. + - `remove_provider`: Interactively removes a configured provider. diff --git a/pdd/prompts/setup_tool_python.prompt b/pdd/prompts/setup_tool_python.prompt new file mode 100644 index 000000000..509fe4a8e --- /dev/null +++ b/pdd/prompts/setup_tool_python.prompt @@ -0,0 +1,89 @@ +Orchestrates comprehensive PDD setup: API key scanning, provider management, and configuration. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "run_setup", "signature": "()", "returns": "None"} + ] + } +} + + +api_key_scanner_python.prompt +api_key_validator_python.prompt +provider_manager_python.prompt +model_selector_python.prompt +local_llm_configurator_python.prompt +pddrc_initializer_python.prompt + +% You are an expert Python engineer. Your goal is to write the pdd/setup/setup_tool.py module. + +% Role & Scope +This module is the main orchestrator for the comprehensive `pdd setup` command. It scans the environment for API keys, presents an interactive menu, and coordinates with helper modules to guide users through PDD configuration including API key management, local LLM setup, custom providers, model selection, and .pddrc initialization. + +% Requirements +1. Function: `run_setup()` - Main entry point that orchestrates the entire setup flow +2. Environment scanning: Auto-detect all API keys from CSV providers, check all sources (.env, shell environment, ~/.pdd/api-env.*) +3. Interactive menu: Present 5 options after scan (Add/fix keys, Add local LLM, Add custom provider, Remove provider, Continue) +4. Menu loop: After options 1-4, re-scan and show updated menu; option 5 proceeds to model selection +5. Delegate to modules: Use api_key_scanner for discovery, api_key_validator for testing, provider_manager for add/remove, model_selector for tier selection +6. Model selection flow: After menu (option 5), call model_selector for interactive tier selection with cost guidance +7. CLI harness detection: After model selection, detect agentic CLI tools (claude, gemini, codex) and offer installation +8. .pddrc initialization: Offer to create .pddrc if none exists in current directory +9. Smart key storage: Save newly entered keys to ~/.pdd/api-env.{{shell}}, skip keys already in environment +10. Output clarity: Use rich Console for formatted output, show ✓/✗ status, display key sources transparently + +% Dependencies + + context/api_key_scanner_example.py + + + + context/api_key_validator_example.py + + + + context/provider_manager_example.py + + + + context/model_selector_example.py + + + + context/local_llm_configurator_example.py + + + + context/pddrc_initializer_example.py + + + + context/agentic_common_example.py + + + +The CSV at $PDD_PATH/data/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Key field: api_key (e.g., "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "VERTEX_CREDENTIALS", "FIREWORKS_API_KEY", "GROQ_API_KEY") + +Note: At runtime, the CSV is located at pdd/data/llm_model.csv relative to the installed package. + + +% Instructions +- Use rich.console.Console for all interactive prompts and status display +- Preserve the overall flow of the existing setup_tool.py but reorganize to use the new modular architecture +- For key validation, delegate to api_key_validator.validate_key() instead of HTTP requests +- For menu option handlers, delegate to respective modules (provider_manager, local_llm_configurator, etc.) +- Display scan results with ✓ Valid (source), ✗ Invalid (source), — Not found format +- After any menu action (1-4), call api_key_scanner.scan_environment() again to refresh the display +- Implement simple input() prompts for menu selection (1-5) +- Handle Ctrl+C gracefully, allow exit at any point + +% Deliverables +- A Python module located at `pdd/setup/setup_tool.py`. +- The module must export the following symbol: + - `run_setup`: The main entry point that orchestrates the entire setup flow. From 2736c38c8fa9757f143dd75f7c1d3ff7b66c6bfd Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Sat, 14 Feb 2026 15:40:51 -0500 Subject: [PATCH 02/10] Update prompt and example files for new setup menu design - Delete model_selector_python.prompt and its example: interactive tier selection removed; adding a provider now auto-loads all its models - Delete api_key_validator_python.prompt and its example: replaced by model_tester which tests individual models via litellm.completion() directly instead of steering llm_invoke with PDD_MODEL_DEFAULT - Create model_tester_python.prompt and example: new "Test a model" menu option using litellm.completion() with direct api_key param, showing diagnostics, timing, and cost per call - Create cli_detector_python.prompt and example: new "Detect CLI tools" menu option leveraging get_available_agents() from agentic_common, cross-referencing API keys with installed CLIs - Rewrite setup_tool_python.prompt: new 6-option menu (Add provider, Remove models, Test model, Detect CLI, Init .pddrc, Done) with sub-menus, replacing the old 5-option flow that ended in model selection - Rewrite provider_manager_python.prompt: new functions add_api_key (auto-loads all models), remove_models_by_provider (comments out keys), and remove_individual_models, replacing add_or_fix_keys/remove_provider - Trim all prompts per prompting guide: remove implementation patterns, keep behavioral requirements, target 10-30% prompt-to-code ratio --- context/api_key_scanner_example.py | 31 ++++---- context/api_key_validator_example.py | 63 ---------------- context/cli_detector_example.py | 39 ++++++++++ context/local_llm_configurator_example.py | 49 ++++++------- context/model_selector_example.py | 52 -------------- context/model_tester_example.py | 44 ++++++++++++ context/pddrc_initializer_example.py | 53 ++++---------- context/provider_manager_example.py | 47 ++++++------ pdd/prompts/api_key_scanner_python.prompt | 60 ++++------------ pdd/prompts/api_key_validator_python.prompt | 59 --------------- pdd/prompts/cli_detector_python.prompt | 41 +++++++++++ .../local_llm_configurator_python.prompt | 44 ++++-------- pdd/prompts/model_selector_python.prompt | 55 -------------- pdd/prompts/model_tester_python.prompt | 37 ++++++++++ pdd/prompts/pddrc_initializer_python.prompt | 37 +++------- pdd/prompts/provider_manager_python.prompt | 59 +++++---------- pdd/prompts/setup_tool_python.prompt | 72 ++++++++----------- 17 files changed, 318 insertions(+), 524 deletions(-) delete mode 100644 context/api_key_validator_example.py create mode 100644 context/cli_detector_example.py delete mode 100644 context/model_selector_example.py create mode 100644 context/model_tester_example.py delete mode 100644 pdd/prompts/api_key_validator_python.prompt create mode 100644 pdd/prompts/cli_detector_python.prompt delete mode 100644 pdd/prompts/model_selector_python.prompt create mode 100644 pdd/prompts/model_tester_python.prompt diff --git a/context/api_key_scanner_example.py b/context/api_key_scanner_example.py index 77b9d57a5..018d6be0c 100644 --- a/context/api_key_scanner_example.py +++ b/context/api_key_scanner_example.py @@ -7,34 +7,35 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.api_key_scanner import scan_environment, KeyInfo +from pdd.setup.api_key_scanner import scan_environment, get_provider_key_names, KeyInfo def main() -> None: """ Demonstrates how to use the api_key_scanner module to: - 1. Dynamically discover all API keys from llm_model.csv - 2. Check multiple sources (shell env, .env file, ~/.pdd/api-env.*) - 3. Get transparency about where each key is loaded from + 1. Discover all API key variable names from llm_model.csv + 2. Scan multiple sources (shell env, .env file, ~/.pdd/api-env.*) + 3. Report existence and source without storing key values """ - print("Scanning environment for API keys...\n") + # Get all provider key names from the master CSV + all_keys = get_provider_key_names() + print(f"Provider key names from CSV: {all_keys}\n") - # Scan the environment for all API keys defined in llm_model.csv + # Scan the environment for all API keys + print("Scanning environment for API keys...\n") scan_results = scan_environment() - # Display results + # Display results — note: KeyInfo only has source and is_set, no value for key_name, key_info in scan_results.items(): - status = "✓ Set" if key_info.is_set else "✗ Not set" - source = f"({key_info.source})" if key_info.is_set else "" - masked_value = key_info.value if key_info.is_set else "—" - - print(f" {key_name:25s} {status:12s} {source:30s}") if key_info.is_set: - print(f" Value: {masked_value}") + print(f" {key_name:25s} ✓ Found ({key_info.source})") + else: + print(f" {key_name:25s} — Not found") - print(f"\nTotal keys found: {len([k for k in scan_results.values() if k.is_set])}") - print(f"Total keys missing: {len([k for k in scan_results.values() if not k.is_set])}") + found = sum(1 for k in scan_results.values() if k.is_set) + missing = sum(1 for k in scan_results.values() if not k.is_set) + print(f"\nFound: {found} Missing: {missing}") if __name__ == "__main__": diff --git a/context/api_key_validator_example.py b/context/api_key_validator_example.py deleted file mode 100644 index 0c6af7119..000000000 --- a/context/api_key_validator_example.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -import os -import sys -from pathlib import Path - -# Add the project root to sys.path -project_root = Path(__file__).resolve().parent.parent -sys.path.append(str(project_root)) - -from pdd.setup.api_key_validator import validate_key, ValidationResult - - -def main() -> None: - """ - Demonstrates how to use the api_key_validator module to: - 1. Validate API keys using llm_invoke instead of HTTP requests - 2. Get detailed error messages for debugging - 3. Understand which model was tested and which provider - """ - - print("API Key Validation Example\n") - - # Example 1: Validate Anthropic API key - anthropic_key = os.getenv("ANTHROPIC_API_KEY") - if anthropic_key: - print("Testing ANTHROPIC_API_KEY...") - result = validate_key("ANTHROPIC_API_KEY", anthropic_key) - display_result(result) - else: - print("ANTHROPIC_API_KEY not set - skipping validation") - - print() - - # Example 2: Validate OpenAI API key - openai_key = os.getenv("OPENAI_API_KEY") - if openai_key: - print("Testing OPENAI_API_KEY...") - result = validate_key("OPENAI_API_KEY", openai_key) - display_result(result) - else: - print("OPENAI_API_KEY not set - skipping validation") - - print() - - # Example 3: Test with invalid key - print("Testing with invalid key...") - result = validate_key("ANTHROPIC_API_KEY", "sk-ant-invalid-key-123") - display_result(result) - - -def display_result(result: ValidationResult) -> None: - """Helper to display validation results""" - if result.is_valid: - print(f" ✓ Valid - Provider: {result.provider}, Model tested: {result.model_tested}") - else: - print(f" ✗ Invalid - Provider: {result.provider}") - if result.error_message: - print(f" Error: {result.error_message}") - - -if __name__ == "__main__": - main() diff --git a/context/cli_detector_example.py b/context/cli_detector_example.py new file mode 100644 index 000000000..52e57190c --- /dev/null +++ b/context/cli_detector_example.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.cli_detector import detect_cli_tools + + +def main() -> None: + """ + Demonstrates how to use the cli_detector module to: + 1. Detect installed agentic CLI harnesses (claude, codex, gemini) + 2. Cross-reference with available API keys + 3. Offer installation for missing CLIs + """ + + # Run the interactive detector + # detect_cli_tools() # Uncomment to run interactively + + # Example flow: + # Checking CLI tools... + # (Required for: pdd fix, pdd change, pdd bug) + # + # Claude CLI ✓ Found at /usr/local/bin/claude + # Codex CLI ✗ Not found + # Gemini CLI ✗ Not found + # + # You have OPENAI_API_KEY but Codex CLI is not installed. + # Install with: npm install -g @openai/codex + # Install now? [y/N] + pass + + +if __name__ == "__main__": + main() diff --git a/context/local_llm_configurator_example.py b/context/local_llm_configurator_example.py index 201c485f7..0c52e0650 100644 --- a/context/local_llm_configurator_example.py +++ b/context/local_llm_configurator_example.py @@ -14,38 +14,29 @@ def main() -> None: """ Demonstrates how to use the local_llm_configurator module to: 1. Configure Ollama with auto-detection of installed models - 2. Configure LM Studio with custom port + 2. Configure LM Studio with default base URL 3. Add custom local LLM endpoints """ - print("Local LLM Configuration Example\n") - - print("This would present an interactive menu:") - print() - print("What tool are you using?") - print(" 1. LM Studio (default: localhost:1234)") - print(" 2. Ollama (default: localhost:11434)") - print(" 3. Other (custom base URL)") - print(" Choice: 2") - print() - print("Querying Ollama at http://localhost:11434...") - print("Found installed models:") - print(" 1. llama3:70b") - print(" 2. codellama:34b") - print(" 3. mistral:7b") - print() - print("Which models do you want to add? [1,2,3]: 1,2") - print("✓ Added ollama_chat/llama3:70b and ollama_chat/codellama:34b to llm_model.csv") - - # Run the actual configuration - # configure_local_llm() # Uncomment to run interactively - - print("\n\nKey Features:") - print(" • Ollama auto-detection: Queries API for installed models") - print(" • LM Studio defaults: Pre-filled localhost:1234 base URL") - print(" • Custom endpoints: Support for any LiteLLM-compatible provider") - print(" • Multiple models: Add several models in one session") - print(" • Zero cost: Local models set to $0.0001 or $0 costs") + # Run the interactive configuration + # was_added = configure_local_llm() # Uncomment to run interactively + + # Example flow for Ollama: + # What tool are you using? + # 1. LM Studio (default: localhost:1234) + # 2. Ollama (default: localhost:11434) + # 3. Other (custom base URL) + # Choice: 2 + # + # Querying Ollama at http://localhost:11434... + # Found installed models: + # 1. llama3:70b + # 2. codellama:34b + # 3. mistral:7b + # + # Which models do you want to add? [1,2,3]: 1,2 + # ✓ Added ollama_chat/llama3:70b and ollama_chat/codellama:34b to llm_model.csv + pass if __name__ == "__main__": diff --git a/context/model_selector_example.py b/context/model_selector_example.py deleted file mode 100644 index 43ac0df1c..000000000 --- a/context/model_selector_example.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -import sys -from pathlib import Path - -# Add the project root to sys.path -project_root = Path(__file__).resolve().parent.parent -sys.path.append(str(project_root)) - -from pdd.setup.model_selector import interactive_selection - - -def main() -> None: - """ - Demonstrates how to use the model_selector module to: - 1. Group models by cost/capability tier - 2. Display pricing information for transparency - 3. Let users select which tiers to enable - 4. Explain how --strength controls model selection - """ - - print("Model Tier Selection Example\n") - - # Assume we have validated providers from earlier steps - validated_providers = ["Anthropic", "OpenAI", "Google"] - - print("This would present an interactive selection for each provider:") - print() - print("Models available for Anthropic:") - print() - print(" # Model Input Output ELO") - print(" 1. anthropic/claude-opus-4-5 $5.00 $25.00 1474") - print(" 2. anthropic/claude-sonnet-4-5 $3.00 $15.00 1370") - print(" 3. anthropic/claude-haiku-4-5 $1.00 $5.00 1270") - print() - print("Tip: pdd uses --strength (0.0–1.0) to pick models by cost/quality at runtime.") - print("Adding all models gives you the full range.") - print() - print("Include which models? [1,2,3] (default: all):") - - # Run the actual interactive selection - # interactive_selection(validated_providers) # Uncomment to run interactively - - print("\nKey Features:") - print(" • Tier classification: Groups models by cost (Fast/Cheap, Balanced, Most Capable)") - print(" • Cost transparency: Shows input/output token costs per million") - print(" • Smart defaults: Press Enter to include all models") - print(" • Strength explanation: Users learn how model selection works at runtime") - - -if __name__ == "__main__": - main() diff --git a/context/model_tester_example.py b/context/model_tester_example.py new file mode 100644 index 000000000..51755e1bb --- /dev/null +++ b/context/model_tester_example.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.model_tester import test_model_interactive + + +def main() -> None: + """ + Demonstrates how to use the model_tester module to: + 1. List configured models from ~/.pdd/llm_model.csv + 2. Test a selected model via litellm.completion() + 3. Display diagnostics (API key status, timing, cost) + """ + + # Run the interactive tester + # test_model_interactive() # Uncomment to run interactively + + # Example flow: + # Configured models: + # 1. anthropic/claude-haiku-4-5-20251001 ANTHROPIC_API_KEY + # 2. gpt-5-nano OPENAI_API_KEY + # 3. lm_studio/openai-gpt-oss-120b-mlx-6 (local) + # + # Test which model? 1 + # Testing anthropic/claude-haiku-4-5-20251001... + # API key ANTHROPIC_API_KEY ✓ Found (shell environment) + # LLM call ✓ OK (0.3s, $0.0001) + # + # Test which model? 3 + # Testing lm_studio/openai-gpt-oss-120b-mlx-6... + # API key (local — no key required) + # Base URL http://localhost:1234/v1 + # LLM call ✗ Connection refused (localhost:1234) + pass + + +if __name__ == "__main__": + main() diff --git a/context/pddrc_initializer_example.py b/context/pddrc_initializer_example.py index 1c8f5791b..19e61d1cc 100644 --- a/context/pddrc_initializer_example.py +++ b/context/pddrc_initializer_example.py @@ -14,51 +14,24 @@ def main() -> None: """ Demonstrates how to use the pddrc_initializer module to: 1. Check if .pddrc exists in current project - 2. Detect project language (Python/TypeScript/etc.) + 2. Detect project language (Python/TypeScript/Go) 3. Offer to create .pddrc with sensible defaults """ - print(".pddrc Initialization Example\n") - - print("This checks for .pddrc in the current directory and offers to create one:") - print() - print("No .pddrc found in current project.") - print() - print("Would you like to create one with default settings?") - print(" Default language: python") - print(" Output path: pdd/") - print(" Test output path: tests/") - print() - print("Create .pddrc? [Y/n]") - - # Run the actual initialization + # Run the interactive initialization # was_created = offer_pddrc_init() # Uncomment to run interactively - # if was_created: - # print("✓ Created .pddrc with default settings") - - print("\n\nKey Features:") - print(" • Auto-detection: Detects language from project files (setup.py, package.json, etc.)") - print(" • Sensible defaults: Sets conventional paths for each language") - print(" • Non-destructive: Never overwrites existing .pddrc") - print(" • YAML format: Creates properly formatted configuration file") - - print("\n\nExample .pddrc content:") - print(""" -version: "1.0" -contexts: - default: - defaults: - generate_output_path: "pdd/" - test_output_path: "tests/" - example_output_path: "context/" - default_language: "python" - target_coverage: 80.0 - strength: 1.0 - temperature: 0.0 - budget: 10.0 - max_attempts: 3 -""") + # Example flow: + # No .pddrc found in current project. + # + # Would you like to create one with default settings? + # Default language: python + # Output path: pdd/ + # Test output path: tests/ + # + # Create .pddrc? [Y/n] + # ✓ Created .pddrc with default settings + pass if __name__ == "__main__": diff --git a/context/provider_manager_example.py b/context/provider_manager_example.py index 6531939fd..abf84ac1c 100644 --- a/context/provider_manager_example.py +++ b/context/provider_manager_example.py @@ -7,45 +7,42 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.provider_manager import add_or_fix_keys, add_custom_provider, remove_provider +from pdd.setup.provider_manager import ( + add_api_key, + add_custom_provider, + remove_models_by_provider, + remove_individual_models, +) from pdd.setup.api_key_scanner import scan_environment def main() -> None: """ Demonstrates how to use the provider_manager module to: - 1. Add or fix API keys for existing providers - 2. Add custom LiteLLM-compatible providers - 3. Remove providers (comment out keys, remove CSV rows) + 1. Add an API key and auto-load all models for that provider + 2. Add a custom LiteLLM-compatible provider + 3. Remove all models for a provider (comments out the key) + 4. Remove individual models from the user CSV """ - print("Provider Management Example\n") - # First, scan the environment to see what's configured - print("Scanning current configuration...") scan_results = scan_environment() - # Example 1: Add or fix keys for missing/invalid providers - print("\n--- Example 1: Add or Fix Keys ---") - print("This would prompt for any missing or invalid keys found in the scan.") - # add_or_fix_keys(scan_results) # Uncomment to run interactively + # Example 1: Add an API key (auto-loads all models for that provider) + # Shows missing keys, prompts for one, saves to api-env, copies CSV rows + # add_api_key(scan_results) # Uncomment to run interactively - # Example 2: Add a custom provider - print("\n--- Example 2: Add Custom Provider ---") - print("This guides you through adding a custom LiteLLM provider (e.g., Together AI, Deepinfra).") + # Example 2: Add a custom provider (Together AI, Deepinfra, etc.) + # Prompts for prefix, model name, API key var, base URL, costs # add_custom_provider() # Uncomment to run interactively - # Example 3: Remove a provider - print("\n--- Example 3: Remove Provider ---") - print("This shows configured providers and lets you remove one.") - print("Removal comments out the key (doesn't delete) and removes model rows from CSV.") - # remove_provider(scan_results) # Uncomment to run interactively - - print("\nKey Features:") - print(" • Smart storage: Only saves newly entered keys to ~/.pdd/api-env.{{shell}}") - print(" • Key commenting: Never deletes keys, only comments with timestamp") - print(" • Atomic CSV writes: Uses temp file + rename to prevent corruption") - print(" • Validation: Tests keys with llm_invoke before saving") + # Example 3: Remove all models for a provider + # Groups by api_key, removes CSV rows, comments out key in api-env + # remove_models_by_provider() # Uncomment to run interactively + + # Example 4: Remove individual models + # Lists all models, user picks by number, removes selected rows + # remove_individual_models() # Uncomment to run interactively if __name__ == "__main__": diff --git a/pdd/prompts/api_key_scanner_python.prompt b/pdd/prompts/api_key_scanner_python.prompt index 2e965df9d..34732ae2b 100644 --- a/pdd/prompts/api_key_scanner_python.prompt +++ b/pdd/prompts/api_key_scanner_python.prompt @@ -1,11 +1,12 @@ -Dynamically discovers API keys from CSV, shell, .env, and PDD config files with source transparency. +Discovers API keys from CSV providers, checking existence across shell, .env, and PDD config with source transparency. { "type": "module", "module": { "functions": [ - {"name": "scan_environment", "signature": "() -> Dict[str, KeyInfo]", "returns": "Dict[str, KeyInfo]"} + {"name": "scan_environment", "signature": "() -> Dict[str, KeyInfo]", "returns": "Dict[str, KeyInfo]"}, + {"name": "get_provider_key_names", "signature": "() -> List[str]", "returns": "List[str]"} ] } } @@ -14,57 +15,24 @@ % You are an expert Python engineer. Your goal is to write the pdd/setup/api_key_scanner.py module. % Role & Scope -This module dynamically discovers API keys from all sources and determines their availability. It reads llm_model.csv to find all unique API key environment variable names, then checks shell environment, .env files, and ~/.pdd/api-env.* files to determine which keys are present and where they come from. +Dynamically discovers API keys from all sources and reports their existence. Reads pdd/data/llm_model.csv to find all unique API key environment variable names, then checks .env files, shell environment, and ~/.pdd/api-env.* files. Only checks **existence** — never makes API calls or stores key values. % Requirements -1. Function: `scan_environment() -> Dict[str, KeyInfo]` - Returns mapping of key name to KeyInfo (value, source, is_set) -2. Dynamic provider discovery: Read pdd/data/llm_model.csv, extract all unique api_key column values -3. Multi-source detection: Check shell environment (os.environ), .env file (if exists), ~/.pdd/api-env.* files -4. Source transparency: For each key found, record whether it came from "shell environment", ".env file", or "~/.pdd/api-env.zsh" -5. Priority handling: If key exists in multiple sources, record the effective source based on loading order (shell overrides .env) -6. CSV parsing: Use csv.DictReader to read llm_model.csv, handle missing or malformed data gracefully -7. .env loading: Use python-dotenv's load_dotenv to read .env, compare os.environ before/after to determine source -8. Shell detection: Detect user's shell (zsh, bash) from SHELL environment variable to check correct api-env file -9. KeyInfo structure: Return namedtuple or dataclass with fields: value (str), source (str), is_set (bool) -10. Performance: Cache CSV reading, don't re-parse on every call unless file changes +1. Function: `scan_environment() -> Dict[str, KeyInfo]` — returns mapping of key name to KeyInfo(source, is_set). Does not store key values. +2. Function: `get_provider_key_names() -> List[str]` — returns deduplicated sorted list of all non-empty api_key values from the master CSV +3. Dynamic discovery: extract all unique api_key column values from pdd/data/llm_model.csv — no hardcoded provider list +4. Check sources in priority order: .env file (via python-dotenv `dotenv_values`, read-only), shell environment (`os.environ`), ~/.pdd/api-env.{shell} (parse uncommented `export KEY=` lines) +5. KeyInfo: dataclass with fields `source` (str) and `is_set` (bool). Report source as "shell environment", ".env file", or "~/.pdd/api-env.zsh" (etc.) +6. Detect shell from SHELL env var for correct api-env file +7. Never raise exceptions — return best-effort results with logging for errors +8. Handle missing/malformed CSV gracefully (return empty dict) % Dependencies The CSV at pdd/data/llm_model.csv has columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Key field: api_key (e.g., "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "VERTEX_CREDENTIALS", "FIREWORKS_API_KEY", "GROQ_API_KEY") +Rows with empty api_key are local LLMs (no key needed). - -# Standard usage of python-dotenv -from dotenv import load_dotenv -import os - -# Capture state before loading .env -before_env = set(os.environ.keys()) - -# Load .env file (doesn't override existing environment variables by default) -load_dotenv(dotenv_path=".env", override=False) - -# Capture state after loading .env -after_env = set(os.environ.keys()) - -# Keys that came from .env are those added during load -env_file_keys = after_env - before_env - - -% Instructions -- Return a dictionary mapping key name (e.g., "ANTHROPIC_API_KEY") to KeyInfo -- Use dataclass for KeyInfo with fields: value (masked), source, is_set -- Mask key values in returned data (show first 8 chars + "..." + last 4 chars) -- Handle case where CSV doesn't exist or is malformed (return empty dict) -- For source determination: shell env takes precedence over .env per industry standard -- Check ~/.pdd/api-env.zsh or ~/.pdd/api-env.bash depending on detected shell -- Don't raise exceptions, return best-effort results with logging for errors - % Deliverables -- A Python module located at `pdd/setup/api_key_scanner.py`. -- The module must export the following symbols: - - `scan_environment`: Scans all sources for API keys and returns their status. - - `KeyInfo`: A dataclass containing information about a discovered API key. +- Module at `pdd/setup/api_key_scanner.py` exporting `scan_environment`, `get_provider_key_names`, and `KeyInfo`. diff --git a/pdd/prompts/api_key_validator_python.prompt b/pdd/prompts/api_key_validator_python.prompt deleted file mode 100644 index 42fb90386..000000000 --- a/pdd/prompts/api_key_validator_python.prompt +++ /dev/null @@ -1,59 +0,0 @@ -Validates API keys using llm_invoke with minimal test prompts for all LLM providers. - - -{ - "type": "module", - "module": { - "functions": [ - {"name": "validate_key", "signature": "(key_name: str, key_value: str) -> ValidationResult", "returns": "ValidationResult"} - ] - } -} - - -% You are an expert Python engineer. Your goal is to write the pdd/setup/api_key_validator.py module. - -% Role & Scope -This module validates API keys by testing them with llm_invoke instead of hardcoded HTTP requests. It selects an appropriate test model for each provider, makes a minimal completion request, and returns validation results including error details. - -% Requirements -1. Function: `validate_key(key_name: str, key_value: str) -> ValidationResult` - Test if API key works -2. Use llm_invoke: Call pdd.llm_invoke.llm_invoke() with minimal test prompt instead of HTTP requests -3. Model selection: Map API key name to appropriate test model (e.g., ANTHROPIC_API_KEY -> claude-haiku-4-5) -4. Test prompt: Use simple prompt like "Say 'OK'" to minimize cost and latency -5. Error handling: Catch authentication errors, network errors, and invalid model errors separately -6. ValidationResult: Return dataclass with fields: is_valid (bool), provider (str), model_tested (str), error_message (Optional[str]) -7. Provider mapping: Derive provider from key name (ANTHROPIC_API_KEY -> Anthropic, OPENAI_API_KEY -> OpenAI, etc.) -8. Timeout: Set reasonable timeout (10s) for validation requests to avoid hanging -9. Cost awareness: Always use cheapest/fastest model for validation (Haiku for Anthropic, cheapest GPT model available for OpenAI, Gemini Flash for Google) -10. Vertex AI handling: For VERTEX_CREDENTIALS, test with vertex_ai/ prefix models - -% Dependencies - - context/llm_invoke_example.py - - - -The CSV at pdd/data/llm_model.csv has columns: -provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Example rows: -- Anthropic,anthropic/claude-haiku-4-5-20251001,1.0,5.0,1270,,ANTHROPIC_API_KEY,128000,True,budget, -- OpenAI,gpt-4o-mini,0.15,0.6,1249,,OPENAI_API_KEY,0,True,none, -- Google,vertex_ai/gemini-3-flash-preview,0.5,3.0,1430,,VERTEX_CREDENTIALS,0,True,effort,global - - -% Instructions -- Map key names to providers: ANTHROPIC_API_KEY -> Anthropic, OPENAI_API_KEY -> OpenAI, GEMINI_API_KEY -> Google, VERTEX_CREDENTIALS -> Google (Vertex), FIREWORKS_API_KEY -> Fireworks, GROQ_API_KEY -> Groq -- Select cheapest model for each provider from CSV for validation -- Set key as environment variable temporarily before calling llm_invoke (if not already set) -- Use try/except to catch litellm errors and categorize them (auth vs network vs config) -- Return ValidationResult dataclass with clear error messages for debugging -- Don't raise exceptions to caller, always return ValidationResult -- Log validation attempts and results for debugging - -% Deliverables -- A Python module located at `pdd/setup/api_key_validator.py`. -- The module must export the following symbols: - - `validate_key`: Tests if an API key is valid by making a minimal test request. - - `ValidationResult`: A dataclass containing the result of the validation. diff --git a/pdd/prompts/cli_detector_python.prompt b/pdd/prompts/cli_detector_python.prompt new file mode 100644 index 000000000..c5ff1f15c --- /dev/null +++ b/pdd/prompts/cli_detector_python.prompt @@ -0,0 +1,41 @@ +Detects installed agentic CLI tools and offers installation guidance for missing ones. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "detect_cli_tools", "signature": "() -> None", "returns": "None"} + ] + } +} + + +agentic_common_python.prompt + +% You are an expert Python engineer. Your goal is to write the pdd/setup/cli_detector.py module. + +% Role & Scope +Detects installed agentic CLI harnesses (Claude CLI, Codex CLI, Gemini CLI) required for `pdd fix`, `pdd change`, and `pdd bug`. Leverages `get_available_agents()` from `pdd.agentic_common` and cross-references with API keys to suggest installations. + +% Requirements +1. Function: `detect_cli_tools()` — check for CLI tools, display results, offer installation +2. For each CLI (claude, codex, gemini): show `✓ Found at /path` or `✗ Not found` +3. Cross-reference with API keys: if user has OPENAI_API_KEY but not codex CLI, highlight and suggest `npm install -g @openai/codex` +4. Offer `Install now? [y/N]` for missing CLIs that have a matching API key; run via subprocess if accepted +5. Show context: `(Required for: pdd fix, pdd change, pdd bug)` +6. Handle npm not being installed (suggest manual installation) + +% Dependencies + + context/agentic_common_example.py + + + +from pdd.agentic_common import get_available_agents, CLI_COMMANDS +# CLI_COMMANDS: {"anthropic": "claude", "google": "gemini", "openai": "codex"} +# get_available_agents() checks CLI existence + API key availability + + +% Deliverables +- Module at `pdd/setup/cli_detector.py` exporting `detect_cli_tools`. diff --git a/pdd/prompts/local_llm_configurator_python.prompt b/pdd/prompts/local_llm_configurator_python.prompt index da3a69839..880a340cd 100644 --- a/pdd/prompts/local_llm_configurator_python.prompt +++ b/pdd/prompts/local_llm_configurator_python.prompt @@ -1,4 +1,4 @@ -Configures local LLMs (Ollama, LM Studio) with auto-detection and CSV integration. +Configures local LLMs (Ollama, LM Studio, custom) with auto-detection and user CSV integration. { @@ -14,41 +14,27 @@ % You are an expert Python engineer. Your goal is to write the pdd/setup/local_llm_configurator.py module. % Role & Scope -This module guides users through configuring local LLM tools (Ollama, LM Studio, custom endpoints). Local models don't need API keys but require base_url and model name configuration. It auto-detects installed models where possible (e.g., querying Ollama API). +Guides users through configuring local LLM tools (Ollama, LM Studio, custom endpoints). Local models need a base_url and model name, not API keys. Auto-detects installed models where possible. % Requirements -1. Function: `configure_local_llm() -> bool` - Interactive setup for local LLM providers -2. Provider options: Present menu: 1. LM Studio (default: localhost:1234), 2. Ollama (default: localhost:11434), 3. Other (custom base URL) -3. Ollama detection: For Ollama, query http://localhost:11434/api/tags to list installed models -4. Model selection: For Ollama, show detected models and let user select which to add; for others, prompt for model name -5. CSV addition: Append rows to llm_model.csv with provider prefix (lm_studio/ or ollama_chat/), base_url, and empty api_key -6. Base URL validation: For custom providers, validate URL format (http/https scheme required) -7. Multiple models: Allow adding multiple models in one session (loop until user declines) -8. Cost defaults: For local models, set input/output costs to 0.0 or very low values ($0.0001) -9. Atomic CSV writes: Use temp file + rename to prevent corruption -10. Error handling: If Ollama API unreachable, fall back to manual model name entry +1. Function: `configure_local_llm() -> bool` — interactive setup, returns True if any models were added +2. Provider menu: 1. LM Studio (default localhost:1234), 2. Ollama (default localhost:11434), 3. Other (custom base URL) +3. Ollama auto-detection: query http://localhost:11434/api/tags, show discovered models, let user select which to add (comma-separated). Fall back to manual entry if unreachable. +4. LM Studio: default base URL http://localhost:1234/v1, prompt for model name +5. Append rows to user's `~/.pdd/llm_model.csv` with LiteLLM prefix conventions (`lm_studio/`, `ollama_chat/`), empty api_key, cost=0.0 +6. Validate base URL format (http/https required) +7. Atomic CSV writes; create user CSV with header if it doesn't exist +8. Handle empty input as cancel % Dependencies -The CSV at pdd/data/llm_model.csv has columns: +The user's CSV at ~/.pdd/llm_model.csv has columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location -Example local LLM rows: -- lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0.0001,0,1082,http://localhost:1234/v1,,0,True,effort, -- custom_model,custom/Qwen3-30B-A3B-4bit,0,0,1040,http://localhost:8080,,0,False,none, +Example local rows: +- lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0,0,1000,http://localhost:1234/v1,,0,True,none, +- Ollama,ollama_chat/llama3:70b,0,0,1000,http://localhost:11434,,0,True,none, -% Instructions -- For Ollama auto-detection: Make GET request to http://localhost:11434/api/tags, parse JSON response for model list -- If Ollama unreachable, inform user and ask for manual model name entry -- For LM Studio: Default base URL is http://localhost:1234/v1, ask if user wants different port -- For model naming: Use prefix format (lm_studio/ or ollama_chat/) to match LiteLLM conventions -- Set reasonable defaults: ELO=1000, structured_output=True, reasoning_type=effort, max_reasoning_tokens=0 -- After adding each model, ask: "Add another local model? [y/N]" -- Use rich Console for formatted output and interactive prompts -- Return bool indicating whether any models were added - % Deliverables -- A Python module located at `pdd/setup/local_llm_configurator.py`. -- The module must export the following symbol: - - `configure_local_llm`: Provides an interactive setup for local LLM providers. +- Module at `pdd/setup/local_llm_configurator.py` exporting `configure_local_llm`. diff --git a/pdd/prompts/model_selector_python.prompt b/pdd/prompts/model_selector_python.prompt deleted file mode 100644 index 21d72102f..000000000 --- a/pdd/prompts/model_selector_python.prompt +++ /dev/null @@ -1,55 +0,0 @@ -Interactive model tier selection with cost transparency and strength parameter guidance. - - -{ - "type": "module", - "module": { - "functions": [ - {"name": "interactive_selection", "signature": "(validated_providers: List[str]) -> bool", "returns": "bool"} - ] - } -} - - -% You are an expert Python engineer. Your goal is to write the pdd/setup/model_selector.py module. - -% Role & Scope -This module provides interactive model tier selection with cost transparency. It groups models by capability/cost tier, displays pricing information, and lets users select which tiers to enable. It explains how --strength controls model selection at runtime. - -% Requirements -1. Function: `interactive_selection(validated_providers: List[str]) -> bool` - Guide user through tier selection for each provider -2. Tier classification: Group models into tiers (Fast/Cheap, Balanced, Most Capable) based on cost and ELO -3. Cost display: Show input/output token costs for each tier (per million tokens) -4. Provider iteration: For each validated provider, show available models grouped by tier -5. User selection: Let user choose which tiers to include (default: all) via numbered input -6. Strength explanation: Briefly explain that pdd uses --strength (0.0-1.0) to pick models by cost/quality at runtime -7. CSV filtering: After selection, update llm_model.csv to only include chosen models (or keep all if user selects all) -8. Smart defaults: If user presses Enter without input, include all models for that provider -9. Tier thresholds: Use cost as primary classifier - Cheap: <=$1 input, Balanced: >$1 and <$3 input, Capable: >=$3 input -10. Interactive display: Use rich Console to show formatted tables with model info (name, cost, ELO) - -% Dependencies - -The CSV at pdd/data/llm_model.csv has columns: -provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Example cost ranges: -- Fast/Cheap: gpt-4o-mini ($0.15/$0.6), claude-haiku-4-5 ($1/$5) -- Balanced: claude-sonnet-4-5 ($3/$15), gpt-4o ($5/$15) -- Most Capable: claude-opus-4-5 ($5/$25), gpt-4-turbo ($10/$30) - - -% Instructions -- Read llm_model.csv to get all models for validated providers -- Group by provider first, then by tier within each provider -- Display clear table format: " # Model Input Output ELO" -- After showing tiers, prompt: "Include which models? [1,2,3] (default: all):" -- Parse user input (comma-separated numbers or "all") -- If models are filtered out, update CSV by removing those rows (atomic write with temp file) -- Show tip about --strength before starting selections: "Tip: pdd uses --strength (0.0-1.0) to pick models by cost/quality at runtime. Adding all models gives you the full range." -- Return bool indicating whether any changes were made - -% Deliverables -- A Python module located at `pdd/setup/model_selector.py`. -- The module must export the following symbol: - - `interactive_selection`: Guides the user through model tier selection for each provider. diff --git a/pdd/prompts/model_tester_python.prompt b/pdd/prompts/model_tester_python.prompt new file mode 100644 index 000000000..638b0f7d0 --- /dev/null +++ b/pdd/prompts/model_tester_python.prompt @@ -0,0 +1,37 @@ +Tests individual models via litellm.completion() with direct API key passing and diagnostics. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "test_model_interactive", "signature": "() -> None", "returns": "None"} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/model_tester.py module. + +% Role & Scope +Tests a single configured model by making one `litellm.completion()` call with a minimal prompt. Only runs when the user explicitly chooses it — no surprise API costs. Uses `litellm.completion()` directly (not `llm_invoke`) because `llm_invoke` doesn't allow choosing a specific model or key. + +% Requirements +1. Function: `test_model_interactive()` — show models from `~/.pdd/llm_model.csv`, let user pick one, test it, loop until user exits (empty input or "q") +2. Test call: `litellm.completion(model=..., messages=[{"role": "user", "content": "Say OK"}], api_key=..., api_base=..., timeout=30)` +3. Before calling, show diagnostics: API key status (`✓ Found (source)` / `✗ Not found` / `(local — no key required)`) and base URL if applicable +4. After calling, show: `LLM call ✓ OK (0.3s, $0.0001)` or `LLM call ✗ error description` +5. Calculate cost from token usage × CSV row's input/output prices per 1M tokens +6. Persist test results in the model list display across picks within a session +7. Distinguish errors: authentication, connection refused (local), model not found, timeout +8. For local models (empty api_key): pass api_base, omit api_key +9. If no user CSV exists or is empty, inform user and return + +% Dependencies + +The user's CSV at ~/.pdd/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + + +% Deliverables +- Module at `pdd/setup/model_tester.py` exporting `test_model_interactive`. diff --git a/pdd/prompts/pddrc_initializer_python.prompt b/pdd/prompts/pddrc_initializer_python.prompt index 73af51b99..0c7fbbee6 100644 --- a/pdd/prompts/pddrc_initializer_python.prompt +++ b/pdd/prompts/pddrc_initializer_python.prompt @@ -1,4 +1,4 @@ -Creates .pddrc configuration files with language-aware defaults and project detection. +Creates .pddrc configuration files with sensible defaults for the current project directory. { @@ -14,23 +14,19 @@ % You are an expert Python engineer. Your goal is to write the pdd/setup/pddrc_initializer.py module. % Role & Scope -This module offers to create a basic .pddrc configuration file in the current project directory if one doesn't exist. It sets sensible defaults for output paths, language, and strength/temperature based on detected project type. +Offers to create a basic `.pddrc` configuration file in the current project directory if one doesn't exist. Sets sensible defaults based on detected project type. % Requirements -1. Function: `offer_pddrc_init() -> bool` - Check if .pddrc exists, offer to create with defaults -2. Existence check: Look for .pddrc in current working directory (os.getcwd()) -3. Language detection: Detect project type (Python if setup.py/pyproject.toml, TypeScript if package.json with typescript dep, else ask user) -4. Default paths: Set generate_output_path, test_output_path, example_output_path based on language (Python: pdd/, tests/, context/ | TypeScript: src/, __tests__/, examples/) -5. Strength defaults: Set strength=1.0, temperature=0.0, target_coverage=80.0 as sensible defaults -6. Interactive prompt: Ask "Would you like to create one with default settings? [Y/n]" before creating -7. File format: Write YAML format matching .pddrc specification (version, contexts, defaults) -8. Show preview: Display the default settings before creating (language, output paths) -9. Skip if exists: Don't overwrite existing .pddrc, just inform user it already exists -10. Return bool: Indicate whether file was created +1. Function: `offer_pddrc_init() -> bool` — returns True if file was created, False otherwise +2. If .pddrc exists in cwd: inform user and return False +3. If no .pddrc: show preview of defaults, prompt `Create .pddrc? [Y/n]` (Enter = yes) +4. Language detection: Python (setup.py/pyproject.toml), TypeScript (package.json with typescript dep), Go (go.mod). Prompt user if unclear. +5. Path defaults by language: Python: pdd/, tests/, context/ | TypeScript: src/, __tests__/, examples/ | Go: ., ., examples/ +6. Standard defaults: strength=1.0, temperature=0.0, target_coverage=80.0, budget=10.0, max_attempts=3 +7. Write YAML format matching .pddrc specification % Dependencies -# Example .pddrc file structure version: "1.0" contexts: @@ -47,18 +43,5 @@ contexts: max_attempts: 3 -% Instructions -- Check for project indicators: setup.py, pyproject.toml (Python), package.json (TypeScript/JavaScript), go.mod (Go), etc. -- If language unclear, prompt user: "What is your primary language? (python/typescript/go/etc.)" -- Map language to conventional paths: - - Python: generate="pdd/", test="tests/", example="context/" - - TypeScript: generate="src/", test="__tests__/", example="examples/" - - Go: generate=".", test=".", example="examples/" -- Use YAML format for .pddrc file (preserve comments for user guidance) -- After creation, print success message: "✓ Created .pddrc with default settings" -- If .pddrc already exists, print: "Found existing .pddrc - skipping initialization" - % Deliverables -- A Python module located at `pdd/setup/pddrc_initializer.py`. -- The module must export the following symbol: - - `offer_pddrc_init`: Checks for an existing `.pddrc` and offers to create one with default settings. +- Module at `pdd/setup/pddrc_initializer.py` exporting `offer_pddrc_init`. diff --git a/pdd/prompts/provider_manager_python.prompt b/pdd/prompts/provider_manager_python.prompt index 94de28904..520e4d8b9 100644 --- a/pdd/prompts/provider_manager_python.prompt +++ b/pdd/prompts/provider_manager_python.prompt @@ -1,66 +1,45 @@ -Manages LLM providers: adding/fixing API keys, custom providers, and safe key removal. +Manages LLM providers: adding API keys with auto-loaded models, custom providers, and model removal. { "type": "module", "module": { "functions": [ - {"name": "add_or_fix_keys", "signature": "(scan_results: Dict[str, KeyInfo]) -> bool", "returns": "bool"}, + {"name": "add_api_key", "signature": "(scan_results: Dict[str, KeyInfo]) -> bool", "returns": "bool"}, {"name": "add_custom_provider", "signature": "() -> bool", "returns": "bool"}, - {"name": "remove_provider", "signature": "(scan_results: Dict[str, KeyInfo]) -> bool", "returns": "bool"} + {"name": "remove_models_by_provider", "signature": "() -> bool", "returns": "bool"}, + {"name": "remove_individual_models", "signature": "() -> bool", "returns": "bool"} ] } } -api_key_validator_python.prompt +api_key_scanner_python.prompt % You are an expert Python engineer. Your goal is to write the pdd/setup/provider_manager.py module. % Role & Scope -This module handles adding, removing, and managing LLM providers in PDD setup. It supports adding/fixing API keys for existing providers, adding custom LiteLLM-compatible providers, and removing providers (commenting out keys and removing model rows from CSV). +Handles adding and removing LLM providers and models in PDD setup. Supports entering API keys for known providers (auto-loading all their models), adding custom LiteLLM-compatible providers, and two modes of model removal. % Requirements -1. Function: `add_or_fix_keys(scan_results: Dict[str, KeyInfo]) -> bool` - Prompt for missing/invalid keys, validate, and save -2. Function: `add_custom_provider() -> bool` - Guide user through adding custom provider (provider prefix, model name, API key, base URL, costs) -3. Function: `remove_provider(scan_results: Dict[str, KeyInfo]) -> bool` - Show configured providers, let user select one, comment out key and remove CSV rows -4. Smart storage: Save newly entered keys to ~/.pdd/api-env.{{shell}}, skip keys already in shell or .env -5. Key commenting: When removing, add comment with timestamp, don't delete (# Commented out by pdd setup on YYYY-MM-DD) -6. CSV updates: For custom provider, append row to llm_model.csv; for removal, delete all rows with matching api_key field -7. Atomic CSV writes: Use temp file + rename to prevent corruption on partial writes -8. Shell detection: Determine shell from SHELL env var, write to correct api-env file (.zsh or .bash) -9. Validation: After adding keys, call api_key_validator to test before saving -10. Interactive prompts: Use input() for all user prompts, handle empty input as skip/cancel + +1. `add_api_key(scan_results)` — show missing keys from scan_results, prompt user for one, save to `~/.pdd/api-env.{shell}`, then copy ALL matching rows from master CSV (pdd/data/llm_model.csv) into user CSV (`~/.pdd/llm_model.csv`). No interactive model selection. If key already exists (replacing), update api-env only. If key name has no CSV rows, tell user to use "Add a custom provider" instead. Skip saving keys already in environment. +2. `add_custom_provider()` — prompt for provider prefix, model name, API key env var, base URL (optional), costs (optional). Append row to user CSV with sensible defaults. Save API key to api-env if provided. +3. `remove_models_by_provider()` — group user CSV models by api_key, show numbered list with counts, remove all rows for selected provider. Comment out (never delete) the key in api-env: `# Commented out by pdd setup on YYYY-MM-DD`. +4. `remove_individual_models()` — list all models from user CSV, let user select by comma-separated numbers, remove selected rows. +5. All CSV writes must be atomic (temp file + rename) +6. Detect shell from SHELL env var for api-env file path +7. Handle empty input as cancel/back % Dependencies - - context/api_key_validator_example.py - + + context/api_key_scanner_example.py + -The CSV at pdd/data/llm_model.csv has columns: +The master CSV at pdd/data/llm_model.csv and user CSV at ~/.pdd/llm_model.csv share columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Example rows: -- OpenAI,gpt-4o-mini,0.15,0.6,1249,,OPENAI_API_KEY,0,True,none, -- Fireworks,fireworks_ai/accounts/fireworks/models/qwen3-coder-480b-a35b-instruct,0.45,1.80,1363,,FIREWORKS_API_KEY,0,False,none, -- lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0.0001,0,1082,http://localhost:1234/v1,,0,True,effort, -% Instructions -- For add_or_fix_keys: Only prompt for keys that are missing or invalid in scan_results -- For each key to add: Prompt user, validate with api_key_validator, save only if valid -- Determine storage location: If key already in environment (not from api-env file), don't save to avoid duplicates -- For api-env file writes: Use shell-appropriate export format (export KEY="value") -- For remove_provider: Show numbered list of providers with model counts, let user pick, confirm before removal -- CSV atomic writes: Write to temp file, validate it's parseable, then rename over original -- Handle missing ~/.pdd directory (create it with mkdir -p) -- For custom provider: Use sensible defaults (structured_output=True, reasoning_type=none, location=empty) -- Return bool indicating success/failure for each function - % Deliverables -- A Python module located at `pdd/setup/provider_manager.py`. -- The module must export the following symbols: - - `add_or_fix_keys`: Prompts the user to add or fix missing/invalid API keys. - - `add_custom_provider`: Guides the user through adding a new custom provider. - - `remove_provider`: Interactively removes a configured provider. +- Module at `pdd/setup/provider_manager.py` exporting `add_api_key`, `add_custom_provider`, `remove_models_by_provider`, `remove_individual_models`. diff --git a/pdd/prompts/setup_tool_python.prompt b/pdd/prompts/setup_tool_python.prompt index 509fe4a8e..63b104346 100644 --- a/pdd/prompts/setup_tool_python.prompt +++ b/pdd/prompts/setup_tool_python.prompt @@ -1,4 +1,4 @@ -Orchestrates comprehensive PDD setup: API key scanning, provider management, and configuration. +Orchestrates pdd setup: environment scanning, provider management, model testing, CLI detection, and configuration. { @@ -12,78 +12,62 @@ api_key_scanner_python.prompt -api_key_validator_python.prompt provider_manager_python.prompt -model_selector_python.prompt local_llm_configurator_python.prompt +model_tester_python.prompt +cli_detector_python.prompt pddrc_initializer_python.prompt % You are an expert Python engineer. Your goal is to write the pdd/setup/setup_tool.py module. % Role & Scope -This module is the main orchestrator for the comprehensive `pdd setup` command. It scans the environment for API keys, presents an interactive menu, and coordinates with helper modules to guide users through PDD configuration including API key management, local LLM setup, custom providers, model selection, and .pddrc initialization. +Main orchestrator for `pdd setup`. Auto-scans the environment for API keys (existence only — no API calls), then presents an interactive menu. After any action, the menu re-displays with an updated scan. % Requirements -1. Function: `run_setup()` - Main entry point that orchestrates the entire setup flow -2. Environment scanning: Auto-detect all API keys from CSV providers, check all sources (.env, shell environment, ~/.pdd/api-env.*) -3. Interactive menu: Present 5 options after scan (Add/fix keys, Add local LLM, Add custom provider, Remove provider, Continue) -4. Menu loop: After options 1-4, re-scan and show updated menu; option 5 proceeds to model selection -5. Delegate to modules: Use api_key_scanner for discovery, api_key_validator for testing, provider_manager for add/remove, model_selector for tier selection -6. Model selection flow: After menu (option 5), call model_selector for interactive tier selection with cost guidance -7. CLI harness detection: After model selection, detect agentic CLI tools (claude, gemini, codex) and offer installation -8. .pddrc initialization: Offer to create .pddrc if none exists in current directory -9. Smart key storage: Save newly entered keys to ~/.pdd/api-env.{{shell}}, skip keys already in environment -10. Output clarity: Use rich Console for formatted output, show ✓/✗ status, display key sources transparently +1. Function: `run_setup()` — main entry point +2. On entry, call `api_key_scanner.scan_environment()` and display each key with `✓ Found (source)` or `— Not found`, plus a summary line: `Models configured: N (from M API keys + K local)` +3. Present a 6-option menu after the scan: + 1. Add a provider (sub-menu: a. Enter an API key, b. Add a local LLM, c. Add a custom provider) + 2. Remove models (sub-menu: a. By provider, b. Individual models) + 3. Test a model + 4. Detect CLI tools + 5. Initialize .pddrc + 6. Done +4. Delegate to: `provider_manager.add_api_key`, `local_llm_configurator.configure_local_llm`, `provider_manager.add_custom_provider`, `provider_manager.remove_models_by_provider`, `provider_manager.remove_individual_models`, `model_tester.test_model_interactive`, `cli_detector.detect_cli_tools`, `pddrc_initializer.offer_pddrc_init` +5. After options 1–5, re-scan and re-display the menu +6. Option 6 exits the loop +7. Handle KeyboardInterrupt for clean exit at any point % Dependencies context/api_key_scanner_example.py - - context/api_key_validator_example.py - - context/provider_manager_example.py - - context/model_selector_example.py - - context/local_llm_configurator_example.py + + context/model_tester_example.py + + + + context/cli_detector_example.py + + context/pddrc_initializer_example.py - - context/agentic_common_example.py - - -The CSV at $PDD_PATH/data/llm_model.csv has columns: +The CSV at pdd/data/llm_model.csv has columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Key field: api_key (e.g., "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "VERTEX_CREDENTIALS", "FIREWORKS_API_KEY", "GROQ_API_KEY") - -Note: At runtime, the CSV is located at pdd/data/llm_model.csv relative to the installed package. +The user-level CSV is at ~/.pdd/llm_model.csv. -% Instructions -- Use rich.console.Console for all interactive prompts and status display -- Preserve the overall flow of the existing setup_tool.py but reorganize to use the new modular architecture -- For key validation, delegate to api_key_validator.validate_key() instead of HTTP requests -- For menu option handlers, delegate to respective modules (provider_manager, local_llm_configurator, etc.) -- Display scan results with ✓ Valid (source), ✗ Invalid (source), — Not found format -- After any menu action (1-4), call api_key_scanner.scan_environment() again to refresh the display -- Implement simple input() prompts for menu selection (1-5) -- Handle Ctrl+C gracefully, allow exit at any point - % Deliverables -- A Python module located at `pdd/setup/setup_tool.py`. -- The module must export the following symbol: - - `run_setup`: The main entry point that orchestrates the entire setup flow. +- Module at `pdd/setup/setup_tool.py` exporting `run_setup`. From 1779064ed4f99d94fcf3ee6b8797abf242f4f8fb Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Sun, 15 Feb 2026 15:40:30 -0500 Subject: [PATCH 03/10] Generate initial pdd setup files - Store files in /pdd/setup, rather than main directory /pdd - Modify utils.py command to run pdd setup at the new correct file path - Still needs refinement on prompts + code to further improve interface, but is functional - Still needs example files and test files --- pdd/core/utils.py | 2 +- pdd/prompts/api_key_scanner_python.prompt | 5 +- pdd/prompts/setup_tool_python.prompt | 1 + pdd/setup/__init__.py | 0 pdd/setup/api_key_scanner.py | 196 ++++++++ pdd/setup/cli_detector.py | 191 ++++++++ pdd/setup/local_llm_configurator.py | 377 +++++++++++++++ pdd/setup/model_tester.py | 391 ++++++++++++++++ pdd/setup/pddrc_initializer.py | 192 ++++++++ pdd/setup/provider_manager.py | 546 ++++++++++++++++++++++ pdd/setup/setup_tool.py | 155 ++++++ 11 files changed, 2054 insertions(+), 2 deletions(-) create mode 100644 pdd/setup/__init__.py create mode 100644 pdd/setup/api_key_scanner.py create mode 100644 pdd/setup/cli_detector.py create mode 100644 pdd/setup/local_llm_configurator.py create mode 100644 pdd/setup/model_tester.py create mode 100644 pdd/setup/pddrc_initializer.py create mode 100644 pdd/setup/provider_manager.py create mode 100644 pdd/setup/setup_tool.py diff --git a/pdd/core/utils.py b/pdd/core/utils.py index 9f3523e79..e43bd3249 100644 --- a/pdd/core/utils.py +++ b/pdd/core/utils.py @@ -85,6 +85,6 @@ def _should_show_onboarding_reminder(ctx: click.Context) -> bool: def _run_setup_utility() -> None: """Execute the interactive setup utility script.""" - result = subprocess.run([sys.executable, "-m", "pdd.setup_tool"]) + result = subprocess.run([sys.executable, "-m", "pdd.setup.setup_tool"]) if result.returncode not in (0, None): raise RuntimeError(f"Setup utility exited with status {result.returncode}") diff --git a/pdd/prompts/api_key_scanner_python.prompt b/pdd/prompts/api_key_scanner_python.prompt index 34732ae2b..66c968235 100644 --- a/pdd/prompts/api_key_scanner_python.prompt +++ b/pdd/prompts/api_key_scanner_python.prompt @@ -21,7 +21,10 @@ Dynamically discovers API keys from all sources and reports their existence. Rea 1. Function: `scan_environment() -> Dict[str, KeyInfo]` — returns mapping of key name to KeyInfo(source, is_set). Does not store key values. 2. Function: `get_provider_key_names() -> List[str]` — returns deduplicated sorted list of all non-empty api_key values from the master CSV 3. Dynamic discovery: extract all unique api_key column values from pdd/data/llm_model.csv — no hardcoded provider list -4. Check sources in priority order: .env file (via python-dotenv `dotenv_values`, read-only), shell environment (`os.environ`), ~/.pdd/api-env.{shell} (parse uncommented `export KEY=` lines) +4. Check sources in priority order: + - .env file (via python-dotenv `dotenv_values`, read-only — always reads fresh on each scan) + - Shell environment (`os.environ` — note: may include stale .env values if edited during session; restart pdd setup to refresh) + - ~/.pdd/api-env.{shell} (parse uncommented `export KEY=` lines) 5. KeyInfo: dataclass with fields `source` (str) and `is_set` (bool). Report source as "shell environment", ".env file", or "~/.pdd/api-env.zsh" (etc.) 6. Detect shell from SHELL env var for correct api-env file 7. Never raise exceptions — return best-effort results with logging for errors diff --git a/pdd/prompts/setup_tool_python.prompt b/pdd/prompts/setup_tool_python.prompt index 63b104346..6bf7b88d5 100644 --- a/pdd/prompts/setup_tool_python.prompt +++ b/pdd/prompts/setup_tool_python.prompt @@ -71,3 +71,4 @@ The user-level CSV is at ~/.pdd/llm_model.csv. % Deliverables - Module at `pdd/setup/setup_tool.py` exporting `run_setup`. +- IMPORTANT: Must include `if __name__ == "__main__":` entry point that calls `run_setup()` to enable execution via `python -m pdd.setup.setup_tool`. diff --git a/pdd/setup/__init__.py b/pdd/setup/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pdd/setup/api_key_scanner.py b/pdd/setup/api_key_scanner.py new file mode 100644 index 000000000..8bf160549 --- /dev/null +++ b/pdd/setup/api_key_scanner.py @@ -0,0 +1,196 @@ +""" +pdd/setup/api_key_scanner.py + +Discovers API keys from CSV providers, checking existence across +shell, .env, and PDD config with source transparency. +""" + +import csv +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class KeyInfo: + """Information about an API key's availability.""" + source: str + is_set: bool + + +def _get_csv_path() -> Path: + """Return the path to the master llm_model.csv file.""" + # Navigate from this file's location to pdd/data/llm_model.csv + module_dir = Path(__file__).resolve().parent # pdd/setup/ + pdd_dir = module_dir.parent # pdd/ + return pdd_dir / "data" / "llm_model.csv" + + +def get_provider_key_names() -> List[str]: + """ + Returns a deduplicated, sorted list of all non-empty api_key values + from the master CSV (pdd/data/llm_model.csv). + + Returns an empty list if the CSV is missing or malformed. + """ + csv_path = _get_csv_path() + key_names: set = set() + + try: + if not csv_path.exists(): + logger.warning("llm_model.csv not found at %s", csv_path) + return [] + + with open(csv_path, "r", newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + + if reader.fieldnames is None or "api_key" not in reader.fieldnames: + logger.warning( + "llm_model.csv at %s is missing the 'api_key' column.", csv_path + ) + return [] + + for row in reader: + api_key_name = row.get("api_key", "").strip() + if api_key_name: + key_names.add(api_key_name) + + except Exception as e: + logger.error("Error reading llm_model.csv: %s", e) + return [] + + return sorted(key_names) + + +def _load_dotenv_values() -> Dict[str, str]: + """ + Load values from a .env file using python-dotenv's dotenv_values (read-only). + Returns an empty dict on any failure. + """ + try: + from dotenv import dotenv_values # type: ignore + + values = dotenv_values() + # dotenv_values returns an OrderedDict; values can be None for keys without values + return {k: v for k, v in values.items() if v is not None} + except ImportError: + logger.debug("python-dotenv not installed; skipping .env file check.") + return {} + except Exception as e: + logger.error("Error loading .env file: %s", e) + return {} + + +def _detect_shell() -> Optional[str]: + """ + Detect the current shell name from the SHELL environment variable. + Returns the shell name (e.g. 'zsh', 'bash') or None if not detectable. + """ + shell_path = os.environ.get("SHELL", "") + if shell_path: + return os.path.basename(shell_path) + return None + + +def _parse_api_env_file(file_path: Path) -> Dict[str, str]: + """ + Parse a ~/.pdd/api-env.{shell} file for uncommented `export KEY=value` lines. + Returns a dict of key names to values found. + """ + result: Dict[str, str] = {} + + try: + if not file_path.exists(): + logger.debug("api-env file not found at %s", file_path) + return result + + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + stripped = line.strip() + + # Skip empty lines and comments + if not stripped or stripped.startswith("#"): + continue + + # Match lines like: export KEY=value or export KEY="value" + if stripped.startswith("export "): + remainder = stripped[len("export "):].strip() + if "=" in remainder: + key, _, value = remainder.partition("=") + key = key.strip() + value = value.strip() + + # Remove surrounding quotes if present + if len(value) >= 2 and ( + (value.startswith('"') and value.endswith('"')) + or (value.startswith("'") and value.endswith("'")) + ): + value = value[1:-1] + + if key and value: + result[key] = value + + except Exception as e: + logger.error("Error parsing api-env file %s: %s", file_path, e) + + return result + + +def scan_environment() -> Dict[str, KeyInfo]: + """ + Scan for API key existence across all known sources. + + Checks sources in priority order: + 1. .env file (via python-dotenv dotenv_values, read-only) + 2. Shell environment (os.environ - note: may include stale .env values if edited during session) + 3. ~/.pdd/api-env.{shell} file + + Returns a mapping of key name -> KeyInfo(source, is_set). + Never raises exceptions; returns best-effort results. + + Note: If you edit .env during a pdd setup session, restart pdd setup to see updated shell environment. + """ + result: Dict[str, KeyInfo] = {} + + try: + key_names = get_provider_key_names() + + if not key_names: + logger.info("No API key names discovered from CSV.") + return result + + # Load all sources once + dotenv_vals = _load_dotenv_values() + shell_name = _detect_shell() + + api_env_file_path: Optional[Path] = None + api_env_vals: Dict[str, str] = {} + api_env_source_label = "" + + if shell_name: + api_env_file_path = Path.home() / ".pdd" / f"api-env.{shell_name}" + api_env_vals = _parse_api_env_file(api_env_file_path) + api_env_source_label = f"~/.pdd/api-env.{shell_name}" + + for key_name in key_names: + # Check in priority order + if key_name in dotenv_vals: + result[key_name] = KeyInfo(source=".env file", is_set=True) + elif key_name in os.environ: + result[key_name] = KeyInfo(source="shell environment", is_set=True) + elif key_name in api_env_vals: + result[key_name] = KeyInfo( + source=api_env_source_label, is_set=True + ) + else: + # Key not found in any source + result[key_name] = KeyInfo(source="", is_set=False) + + except Exception as e: + logger.error("Unexpected error during environment scan: %s", e) + + return result \ No newline at end of file diff --git a/pdd/setup/cli_detector.py b/pdd/setup/cli_detector.py new file mode 100644 index 000000000..896fc0a50 --- /dev/null +++ b/pdd/setup/cli_detector.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import sys +from pathlib import Path + + +# Maps provider name -> CLI command name +_CLI_COMMANDS: dict[str, str] = { + "anthropic": "claude", + "google": "gemini", + "openai": "codex", +} + +# Maps provider name -> environment variable for API key +_API_KEY_ENV_VARS: dict[str, str] = { + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "openai": "OPENAI_API_KEY", +} + +# Maps provider name -> npm install command for the CLI +_INSTALL_COMMANDS: dict[str, str] = { + "anthropic": "npm install -g @anthropic-ai/claude-code", + "google": "npm install -g @anthropic-ai/gemini-cli", + "openai": "npm install -g @openai/codex", +} + +# Maps provider name -> human-readable CLI name +_CLI_DISPLAY_NAMES: dict[str, str] = { + "anthropic": "Claude CLI", + "google": "Gemini CLI", + "openai": "Codex CLI", +} + + +def _which(cmd: str) -> str | None: + """Return the full path to a command if found on PATH, else None.""" + return shutil.which(cmd) + + +def _has_api_key(provider: str) -> bool: + """Check whether the API key environment variable is set for a provider.""" + env_var = _API_KEY_ENV_VARS.get(provider, "") + return bool(os.environ.get(env_var, "").strip()) + + +def _npm_available() -> bool: + """Check whether npm is available on PATH.""" + return _which("npm") is not None + + +def _prompt_yes_no(prompt: str) -> bool: + """Prompt the user with a yes/no question. Default is No.""" + try: + answer = input(prompt).strip().lower() + except (EOFError, KeyboardInterrupt): + print() + return False + return answer in ("y", "yes") + + +def _run_install(install_cmd: str) -> bool: + """Run an installation command via subprocess. Returns True on success.""" + print(f" Running: {install_cmd}") + try: + result = subprocess.run( + install_cmd, + shell=True, + check=False, + capture_output=False, + ) + return result.returncode == 0 + except Exception as exc: + print(f" Installation failed: {exc}") + return False + + +def detect_cli_tools() -> None: + """ + Detect installed agentic CLI harnesses (Claude CLI, Codex CLI, Gemini CLI) + required for ``pdd fix``, ``pdd change``, and ``pdd bug``. + + For each CLI tool: + - Shows ✓ Found at /path or ✗ Not found + - Cross-references with API keys to highlight actionable installations + - Offers interactive installation for missing CLIs that have a matching API key + + Handles the case where npm is not installed by suggesting manual installation. + """ + # Try to import get_available_agents for cross-reference, but don't fail if + # the import is unavailable (we can still do basic detection). + available_agents: list[str] = [] + try: + from pdd.agentic_common import get_available_agents as _get_available_agents + available_agents = list(_get_available_agents()) + except Exception: + pass + + print() + print("Agentic CLI Tool Detection") + print("=" * 50) + print("(Required for: pdd fix, pdd change, pdd bug)") + print() + + missing_with_key: list[str] = [] + found_any = False + + for provider, cli_cmd in _CLI_COMMANDS.items(): + display_name = _CLI_DISPLAY_NAMES[provider] + path = _which(cli_cmd) + has_key = _has_api_key(provider) + key_env = _API_KEY_ENV_VARS[provider] + + if path: + print(f" ✓ {display_name} ({cli_cmd}): Found at {path}") + found_any = True + if has_key: + print(f" API key ({key_env}): set") + else: + print(f" API key ({key_env}): not set — CLI found but won't be usable without it") + else: + print(f" ✗ {display_name} ({cli_cmd}): Not found") + if has_key: + print(f" API key ({key_env}): set — install the CLI to use this provider") + missing_with_key.append(provider) + else: + print(f" API key ({key_env}): not set") + + print() + + if not missing_with_key: + if found_any: + print("All CLI tools with matching API keys are installed.") + else: + print("No CLI tools found. Install at least one CLI and set its API key") + print("to use agentic features (pdd fix, pdd change, pdd bug).") + print() + print("Quick start:") + for provider, install_cmd in _INSTALL_COMMANDS.items(): + display_name = _CLI_DISPLAY_NAMES[provider] + key_env = _API_KEY_ENV_VARS[provider] + print(f" {display_name}: {install_cmd}") + print(f" Then set: export {key_env}=") + print() + return + + # Offer installation for missing CLIs that have a matching API key + print("The following CLI tools are missing but have API keys configured:") + print() + + npm_available = _npm_available() + + for provider in missing_with_key: + display_name = _CLI_DISPLAY_NAMES[provider] + install_cmd = _INSTALL_COMMANDS[provider] + + print(f" {display_name}:") + print(f" Install command: {install_cmd}") + + if not npm_available: + print(" ☀ npm is not installed. Please install Node.js/npm first:") + print(" macOS: brew install node") + print(" Ubuntu: sudo apt-get update && sudo apt-get install -y nodejs npm") + print(" Then run the install command above manually.") + print() + continue + + if _prompt_yes_no(f" Install now? [y/N] "): + success = _run_install(install_cmd) + if success: + new_path = _which(_CLI_COMMANDS[provider]) + if new_path: + print(f" ✓ {display_name} installed successfully at {new_path}") + else: + print(f" ✓ Installation command completed. You may need to restart your shell.") + else: + print(f" ✗ Installation failed. Try running manually:") + print(f" {install_cmd}") + else: + print(" Skipped. To install later, run:") + print(f" {install_cmd}") + + print() + + print() + +if __name__ == "__main__": + detect_cli_tools() \ No newline at end of file diff --git a/pdd/setup/local_llm_configurator.py b/pdd/setup/local_llm_configurator.py new file mode 100644 index 000000000..f77a753bb --- /dev/null +++ b/pdd/setup/local_llm_configurator.py @@ -0,0 +1,377 @@ +"""Local LLM configurator for PDD. + +Guides users through configuring local LLM tools (Ollama, LM Studio, custom endpoints). +Local models need a base_url and model name, not API keys. +""" + +from __future__ import annotations + +import csv +import io +import logging +import os +import shutil +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +from rich.console import Console +from rich.table import Table + +logger = logging.getLogger("pdd.setup.local_llm_configurator") +console = Console() + +# CSV header for ~/.pdd/llm_model.csv +CSV_COLUMNS: List[str] = [ + "provider", + "model", + "input", + "output", + "coding_arena_elo", + "base_url", + "api_key", + "max_reasoning_tokens", + "structured_output", + "reasoning_type", + "location", +] + +DEFAULT_LM_STUDIO_BASE_URL = "http://localhost:1234/v1" +DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434" +DEFAULT_OLLAMA_API_URL = "http://localhost:11434/api/tags" + + +def _get_user_csv_path() -> Path: + """Return the path to the user's ~/.pdd/llm_model.csv.""" + return Path.home() / ".pdd" / "llm_model.csv" + + +def _validate_base_url(url: str) -> bool: + """Validate that a base URL has http or https scheme and a netloc.""" + try: + parsed = urlparse(url.strip()) + return parsed.scheme in ("http", "https") and bool(parsed.netloc) + except Exception: + return False + + +def _build_model_row( + provider: str, + model: str, + base_url: str, + coding_arena_elo: int = 1000, + structured_output: bool = True, + reasoning_type: str = "none", +) -> Dict[str, Any]: + """Build a CSV row dict for a local model.""" + return { + "provider": provider, + "model": model, + "input": 0, + "output": 0, + "coding_arena_elo": coding_arena_elo, + "base_url": base_url, + "api_key": "", + "max_reasoning_tokens": 0, + "structured_output": structured_output, + "reasoning_type": reasoning_type, + "location": "", + } + + +def _read_existing_csv(csv_path: Path) -> List[Dict[str, str]]: + """Read existing rows from the user CSV, returning list of dicts.""" + rows: List[Dict[str, str]] = [] + if not csv_path.exists(): + return rows + try: + with open(csv_path, "r", newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + rows.append(row) + except Exception as e: + logger.warning(f"Failed to read existing CSV at {csv_path}: {e}") + return rows + + +def _write_csv_atomic(csv_path: Path, rows: List[Dict[str, Any]]) -> None: + """Atomically write rows to the user CSV. + + Writes to a temporary file first, then moves it into place to avoid + partial writes corrupting the file. + """ + csv_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to a temp file in the same directory for atomic rename + fd, tmp_path_str = tempfile.mkstemp( + dir=str(csv_path.parent), suffix=".csv.tmp" + ) + tmp_path = Path(tmp_path_str) + try: + with os.fdopen(fd, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=CSV_COLUMNS, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow(row) + # Atomic move + shutil.move(str(tmp_path), str(csv_path)) + except Exception: + # Clean up temp file on failure + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + raise + + +def _append_rows_to_csv(csv_path: Path, new_rows: List[Dict[str, Any]]) -> None: + """Append new model rows to the user CSV, creating it if needed.""" + existing = _read_existing_csv(csv_path) + # Convert existing rows to have consistent types + combined = list(existing) + new_rows + _write_csv_atomic(csv_path, combined) + + +def _discover_ollama_models(base_url: str) -> Optional[List[str]]: + """Query Ollama API for available models. + + Returns a list of model names, or None if the API is unreachable. + """ + import urllib.request + import json + + api_url = base_url.rstrip("/") + "/api/tags" + try: + req = urllib.request.Request(api_url, method="GET") + with urllib.request.urlopen(req, timeout=5) as resp: + data = json.loads(resp.read().decode("utf-8")) + models = data.get("models", []) + return [m.get("name", "") for m in models if m.get("name")] + except Exception as e: + logger.debug(f"Failed to query Ollama at {api_url}: {e}") + return None + + +def _prompt_input(prompt_text: str, default: str = "") -> str: + """Prompt user for input with optional default. Returns stripped input.""" + try: + if default: + raw = input(f"{prompt_text} [{default}]: ").strip() + return raw if raw else default + else: + return input(f"{prompt_text}: ").strip() + except (EOFError, KeyboardInterrupt): + return "" + + +def _configure_lm_studio() -> List[Dict[str, Any]]: + """Configure LM Studio models interactively.""" + rows: List[Dict[str, Any]] = [] + + console.print("\n[bold cyan]LM Studio Configuration[/bold cyan]") + console.print(f"Default base URL: {DEFAULT_LM_STUDIO_BASE_URL}") + + base_url = _prompt_input("Base URL", DEFAULT_LM_STUDIO_BASE_URL) + if not base_url: + console.print("[yellow]Cancelled.[/yellow]") + return rows + + if not _validate_base_url(base_url): + console.print("[red]Invalid URL. Must start with http:// or https://[/red]") + return rows + + while True: + model_name = _prompt_input("Model name (empty to finish)") + if not model_name: + break + + # Add lm_studio/ prefix if not present + litellm_model = model_name + if not litellm_model.startswith("lm_studio/"): + litellm_model = f"lm_studio/{model_name}" + + row = _build_model_row( + provider="lm_studio", + model=litellm_model, + base_url=base_url, + ) + rows.append(row) + console.print(f" [green]✓[/green] Added: {litellm_model}") + + return rows + + +def _configure_ollama() -> List[Dict[str, Any]]: + """Configure Ollama models interactively with auto-detection.""" + rows: List[Dict[str, Any]] = [] + + console.print("\n[bold cyan]Ollama Configuration[/bold cyan]") + console.print(f"Default base URL: {DEFAULT_OLLAMA_BASE_URL}") + + base_url = _prompt_input("Base URL", DEFAULT_OLLAMA_BASE_URL) + if not base_url: + console.print("[yellow]Cancelled.[/yellow]") + return rows + + if not _validate_base_url(base_url): + console.print("[red]Invalid URL. Must start with http:// or https://[/red]") + return rows + + # Try auto-detection + console.print("[dim]Checking for running Ollama instance...[/dim]") + discovered = _discover_ollama_models(base_url) + + if discovered: + console.print(f"[green]Found {len(discovered)} model(s):[/green]") + + table = Table(show_header=True, header_style="bold") + table.add_column("#", style="dim", width=4) + table.add_column("Model Name") + for idx, name in enumerate(discovered, 1): + table.add_row(str(idx), name) + console.print(table) + + selection = _prompt_input( + "Select models to add (comma-separated numbers, 'all', or empty to skip)" + ) + if not selection: + console.print("[yellow]No models selected.[/yellow]") + elif selection.strip().lower() == "all": + for name in discovered: + litellm_model = f"ollama_chat/{name}" + row = _build_model_row( + provider="Ollama", + model=litellm_model, + base_url=base_url, + ) + rows.append(row) + console.print(f" [green]✓[/green] Added: {litellm_model}") + else: + # Parse comma-separated indices + for part in selection.split(","): + part = part.strip() + try: + idx = int(part) + if 1 <= idx <= len(discovered): + name = discovered[idx - 1] + litellm_model = f"ollama_chat/{name}" + row = _build_model_row( + provider="Ollama", + model=litellm_model, + base_url=base_url, + ) + rows.append(row) + console.print(f" [green]✓[/green] Added: {litellm_model}") + else: + console.print(f" [yellow]Skipping invalid index: {idx}[/yellow]") + except ValueError: + console.print(f" [yellow]Skipping invalid input: '{part}'[/yellow]") + else: + console.print( + "[yellow]Could not connect to Ollama. Falling back to manual entry.[/yellow]" + ) + while True: + model_name = _prompt_input("Model name (empty to finish)") + if not model_name: + break + + litellm_model = model_name + if not litellm_model.startswith("ollama_chat/"): + litellm_model = f"ollama_chat/{model_name}" + + row = _build_model_row( + provider="Ollama", + model=litellm_model, + base_url=base_url, + ) + rows.append(row) + console.print(f" [green]✓[/green] Added: {litellm_model}") + + return rows + + +def _configure_custom() -> List[Dict[str, Any]]: + """Configure a custom local LLM endpoint interactively.""" + rows: List[Dict[str, Any]] = [] + + console.print("\n[bold cyan]Custom Local LLM Configuration[/bold cyan]") + + base_url = _prompt_input("Base URL (e.g., http://localhost:8080/v1)") + if not base_url: + console.print("[yellow]Cancelled.[/yellow]") + return rows + + if not _validate_base_url(base_url): + console.print("[red]Invalid URL. Must start with http:// or https://[/red]") + return rows + + provider_name = _prompt_input("Provider name", "custom") + + while True: + model_name = _prompt_input("Model name (empty to finish)") + if not model_name: + break + + row = _build_model_row( + provider=provider_name, + model=model_name, + base_url=base_url, + ) + rows.append(row) + console.print(f" [green]✓[/green] Added: {model_name}") + + return rows + + +def configure_local_llm() -> bool: + """Interactive setup for local LLM providers. + + Guides the user through selecting a local LLM provider (LM Studio, Ollama, + or custom), discovering available models, and appending them to the user's + ``~/.pdd/llm_model.csv``. + + Returns: + True if any models were added, False otherwise. + """ + console.print("\n[bold]Local LLM Setup[/bold]") + console.print("Configure local LLM tools for use with PDD.\n") + console.print("Select a provider:") + console.print(" [bold]1[/bold]. LM Studio (default: localhost:1234)") + console.print(" [bold]2[/bold]. Ollama (default: localhost:11434)") + console.print(" [bold]3[/bold]. Other (custom endpoint)") + console.print() + + choice = _prompt_input("Choice (1/2/3, empty to cancel)") + if not choice: + console.print("[yellow]Cancelled.[/yellow]") + return False + + new_rows: List[Dict[str, Any]] = [] + + if choice == "1": + new_rows = _configure_lm_studio() + elif choice == "2": + new_rows = _configure_ollama() + elif choice == "3": + new_rows = _configure_custom() + else: + console.print(f"[red]Invalid choice: '{choice}'. Please enter 1, 2, or 3.[/red]") + return False + + if not new_rows: + console.print("[yellow]No models were added.[/yellow]") + return False + + # Write to user CSV + csv_path = _get_user_csv_path() + try: + _append_rows_to_csv(csv_path, new_rows) + console.print( + f"\n[green]Successfully added {len(new_rows)} model(s) to {csv_path}[/green]" + ) + return True + except Exception as e: + console.print(f"[red]Failed to write to {csv_path}: {e}[/red]") + logger.error(f"Failed to write CSV: {e}", exc_info=True) + return False \ No newline at end of file diff --git a/pdd/setup/model_tester.py b/pdd/setup/model_tester.py new file mode 100644 index 000000000..2f7d69cb6 --- /dev/null +++ b/pdd/setup/model_tester.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import os +import time as time_module +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +from rich.console import Console + +from rich.table import Table + +console = Console() + + +def _load_user_csv() -> Optional[pd.DataFrame]: + """Load the user's LLM model CSV from ~/.pdd/llm_model.csv. + + Returns: + DataFrame with model data, or None if file doesn't exist or is empty. + """ + csv_path = Path.home() / ".pdd" / "llm_model.csv" + if not csv_path.is_file(): + return None + + try: + df = pd.read_csv(csv_path) + except Exception as e: + console.print(f"[red]Failed to read {csv_path}: {e}[/red]") + return None + + if df.empty: + return None + + # Ensure expected columns exist + required_cols = {"provider", "model", "api_key"} + missing = required_cols - set(df.columns) + if missing: + console.print(f"[red]CSV is missing required columns: {missing}[/red]") + return None + + # Normalise nullable string columns + for col in ("api_key", "base_url", "location"): + if col in df.columns: + df[col] = df[col].fillna("").astype(str) + + # Normalise numeric cost columns + for col in ("input", "output"): + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0) + + return df + + +def _resolve_api_key(row: Dict[str, Any]) -> Tuple[Optional[str], str]: + """Resolve the API key for a model row. + + Returns: + (key_value_or_none, status_string) + status_string is a human-readable description like '✓ Found (OPENAI_API_KEY)'. + """ + key_name: str = str(row.get("api_key", "")).strip() + + # Local model — no key required + if not key_name: + return None, "(local — no key required)" + + # Check environment + key_value = os.getenv(key_name, "") + if key_value: + return key_value.strip(), f"✓ Found ({key_name})" + + # Check if a .env file might have it (dotenv may not be loaded yet) + try: + from dotenv import dotenv_values + + env_path = Path.home() / ".pdd" / ".env" + if not env_path.is_file(): + env_path = Path.cwd() / ".env" + if env_path.is_file(): + vals = dotenv_values(env_path) + val = vals.get(key_name, "") + if val: + return val.strip(), f"✓ Found ({key_name} via .env)" + except ImportError: + pass + + return None, f"✗ Not found ({key_name})" + + +def _resolve_base_url(row: Dict[str, Any]) -> Optional[str]: + """Return the base_url for the model, if any.""" + base_url: str = str(row.get("base_url", "")).strip() + if base_url: + return base_url + + # LM Studio convention + model_name = str(row.get("model", "")).lower() + provider = str(row.get("provider", "")).lower() + if model_name.startswith("lm_studio/") or provider == "lm_studio": + return os.getenv("LM_STUDIO_API_BASE", "http://localhost:1234/v1") + + return None + + +def _calculate_cost( + prompt_tokens: int, + completion_tokens: int, + input_price_per_m: float, + output_price_per_m: float, +) -> float: + """Calculate cost from token counts and per-million-token prices.""" + return (prompt_tokens * input_price_per_m + completion_tokens * output_price_per_m) / 1_000_000.0 + + +def _classify_error(exc: Exception) -> str: + """Return a concise, user-friendly error description.""" + msg = str(exc).lower() + exc_type = type(exc).__name__ + + # Authentication errors + if "authentication" in msg or "401" in msg or "403" in msg or "invalid api key" in msg: + return f"Authentication error — check your API key ({exc_type})" + + # Connection refused (typically local servers) + if "connection refused" in msg or "connect" in msg and "refused" in msg: + return f"Connection refused — is the local server running? ({exc_type})" + + # Model not found + if "not found" in msg or "404" in msg or "does not exist" in msg: + return f"Model not found — check the model name ({exc_type})" + + # Timeout + if "timeout" in msg or "timed out" in msg: + return f"Request timed out ({exc_type})" + + # Rate limit + if "rate" in msg and "limit" in msg or "429" in msg: + return f"Rate limited — try again later ({exc_type})" + + # Generic + return f"{exc_type}: {exc}" + + +def _run_test(row: Dict[str, Any]) -> Dict[str, Any]: + """Run a single litellm.completion() test against the given model row. + + Returns a dict with keys: success, duration_s, cost, error, tokens. + """ + import litellm + + model_name: str = str(row.get("model", "")) + api_key, _key_status = _resolve_api_key(row) + base_url = _resolve_base_url(row) + + kwargs: Dict[str, Any] = { + "model": model_name, + "messages": [{"role": "user", "content": "Say OK"}], + "timeout": 30, + } + + # Only pass api_key if we have one (local models don't need it) + if api_key: + kwargs["api_key"] = api_key + elif not str(row.get("api_key", "")).strip(): + # Local model — use placeholder key if provider expects one + pass + + if base_url: + kwargs["base_url"] = base_url + kwargs["api_base"] = base_url + + # Vertex AI handling + is_vertex = model_name.startswith("vertex_ai/") or str(row.get("provider", "")).lower() in ( + "google", + "vertex_ai", + "googlevertexai", + ) + key_name = str(row.get("api_key", "")).strip() + if is_vertex and key_name == "VERTEX_CREDENTIALS": + creds_path = os.getenv("VERTEX_CREDENTIALS", "") + project = os.getenv("VERTEX_PROJECT", "") + location_csv = str(row.get("location", "")).strip() + location = location_csv if location_csv else os.getenv("VERTEX_LOCATION", "") + + if creds_path: + try: + import json as _json + + with open(creds_path, "r") as f: + creds = _json.load(f) + kwargs["vertex_credentials"] = _json.dumps(creds) + except Exception: + pass # Will likely fail at call time with a clear error + + if project: + kwargs["vertex_project"] = project + if location: + kwargs["vertex_location"] = location + + # Remove api_key for vertex — it uses credentials instead + kwargs.pop("api_key", None) + + start = time_module.time() + try: + response = litellm.completion(**kwargs) + duration = time_module.time() - start + + # Extract token usage + usage = getattr(response, "usage", None) + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + + input_price = float(row.get("input", 0.0)) + output_price = float(row.get("output", 0.0)) + cost = _calculate_cost(prompt_tokens, completion_tokens, input_price, output_price) + + return { + "success": True, + "duration_s": duration, + "cost": cost, + "error": None, + "tokens": {"prompt": prompt_tokens, "completion": completion_tokens}, + } + + except Exception as exc: + duration = time_module.time() - start + return { + "success": False, + "duration_s": duration, + "cost": 0.0, + "error": _classify_error(exc), + "tokens": None, + } + + +def _display_model_list( + df: pd.DataFrame, + results: Dict[int, Dict[str, Any]], +) -> None: + """Display the model list as a rich table with any persisted test results.""" + table = Table(title="Available Models", show_lines=False, pad_edge=True) + table.add_column("#", style="bold cyan", justify="right", width=4) + table.add_column("Provider", style="white", min_width=10) + table.add_column("Model", style="bright_white", min_width=30) + table.add_column("Input $/M", justify="right", min_width=8) + table.add_column("Output $/M", justify="right", min_width=8) + table.add_column("ELO", justify="right", min_width=6) + table.add_column("Last Test", min_width=25) + + for idx, row in df.iterrows(): + i = int(idx) + provider = str(row.get("provider", "")) + model = str(row.get("model", "")) + input_cost = row.get("input", 0.0) + output_cost = row.get("output", 0.0) + elo = row.get("coding_arena_elo", "") + + # Format costs + input_str = f"${float(input_cost):.2f}" if pd.notna(input_cost) else "—" + output_str = f"${float(output_cost):.2f}" if pd.notna(output_cost) else "—" + elo_str = str(int(elo)) if pd.notna(elo) and elo else "—" + + # Test result + if i in results: + r = results[i] + if r["success"]: + test_str = f"[green]✓ OK ({r['duration_s']:.1f}s, ${r['cost']:.4f})[/green]" + else: + # Truncate error for table display + err = r["error"] or "Unknown error" + if len(err) > 40: + err = err[:37] + "..." + test_str = f"[red]✗ {err}[/red]" + else: + test_str = "—" + + table.add_row( + str(i + 1), + provider, + model, + input_str, + output_str, + elo_str, + test_str, + ) + + console.print(table) + + +def test_model_interactive() -> None: + """Interactive model tester. + + Shows models from ~/.pdd/llm_model.csv, lets the user pick one to test, + runs a minimal litellm.completion() call, and displays diagnostics. + Loops until the user enters empty input or 'q'. + """ + df = _load_user_csv() + if df is None: + console.print( + "[yellow]No user model CSV found at ~/.pdd/llm_model.csv or it is empty.[/yellow]" + ) + console.print( + "[dim]Run [bold]pdd setup[/bold] to configure your models first.[/dim]" + ) + return + + # Session-persisted test results: index -> result dict + results: Dict[int, Dict[str, Any]] = {} + + while True: + console.print() + _display_model_list(df, results) + console.print() + + try: + choice = console.input( + "[bold cyan]Enter model number to test (or q/empty to quit): [/bold cyan]" + ).strip() + except (EOFError, KeyboardInterrupt): + console.print("\n[dim]Exiting model tester.[/dim]") + return + + if not choice or choice.lower() == "q": + console.print("[dim]Exiting model tester.[/dim]") + return + + # Parse selection + try: + idx = int(choice) - 1 + except ValueError: + console.print(f"[red]Invalid input: '{choice}'. Enter a number or 'q'.[/red]") + continue + + if idx < 0 or idx >= len(df): + console.print(f"[red]Invalid selection. Choose 1–{len(df)}.[/red]") + continue + + row = df.iloc[idx].to_dict() + model_name = str(row.get("model", "")) + provider = str(row.get("provider", "")) + + console.print() + console.print(f"[bold]Testing: [bright_white]{model_name}[/bright_white] ({provider})[/bold]") + console.print("─" * 50) + + # Diagnostics: API key + api_key, key_status = _resolve_api_key(row) + if "✓" in key_status: + console.print(f" API Key: [green]{key_status}[/green]") + elif "local" in key_status: + console.print(f" API Key: [dim]{key_status}[/dim]") + else: + console.print(f" API Key: [red]{key_status}[/red]") + + # Diagnostics: base URL + base_url = _resolve_base_url(row) + if base_url: + console.print(f" Base URL: [dim]{base_url}[/dim]") + + # Diagnostics: Vertex AI specifics + key_name = str(row.get("api_key", "")).strip() + if key_name == "VERTEX_CREDENTIALS": + project = os.getenv("VERTEX_PROJECT", "") + location_csv = str(row.get("location", "")).strip() + location = location_csv if location_csv else os.getenv("VERTEX_LOCATION", "") + if project: + console.print(f" Project: [dim]{project}[/dim]") + if location: + console.print(f" Location: [dim]{location}[/dim]") + + console.print() + console.print(" [dim]Sending test prompt...[/dim]") + + # Run the test + result = _run_test(row) + results[idx] = result + + if result["success"]: + tokens = result.get("tokens") or {} + token_info = "" + if tokens: + token_info = f", {tokens.get('prompt', 0)}+{tokens.get('completion', 0)} tokens" + console.print( + f" LLM call [green]✓ OK[/green] " + f"({result['duration_s']:.1f}s, ${result['cost']:.4f}{token_info})" + ) + else: + console.print(f" LLM call [red]✗ {result['error']}[/red]") + + console.print() \ No newline at end of file diff --git a/pdd/setup/pddrc_initializer.py b/pdd/setup/pddrc_initializer.py new file mode 100644 index 000000000..3443aa1f9 --- /dev/null +++ b/pdd/setup/pddrc_initializer.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional + +from rich.console import Console +from rich.syntax import Syntax + +console = Console() + +# Language detection markers +PYTHON_MARKERS = ("setup.py", "pyproject.toml", "setup.cfg", "Pipfile", "requirements.txt") +TYPESCRIPT_MARKERS = ("package.json",) +GO_MARKERS = ("go.mod",) + +# Path defaults per language +LANGUAGE_DEFAULTS: dict[str, dict[str, str]] = { + "python": { + "generate_output_path": "pdd/", + "test_output_path": "tests/", + "example_output_path": "context/", + }, + "typescript": { + "generate_output_path": "src/", + "test_output_path": "__tests__/", + "example_output_path": "examples/", + }, + "go": { + "generate_output_path": ".", + "test_output_path": ".", + "example_output_path": "examples/", + }, +} + +# Standard defaults +STANDARD_DEFAULTS: dict[str, float | int] = { + "strength": 1.0, + "temperature": 0.0, + "target_coverage": 80.0, + "budget": 10.0, + "max_attempts": 3, +} + +PDDRC_FILENAME = ".pddrc" + + +def _detect_language(cwd: Path) -> Optional[str]: + """Detect project language based on marker files in the current directory. + + Returns the detected language string or ``None`` if the project type + cannot be determined automatically. + """ + # Check Python markers + for marker in PYTHON_MARKERS: + if (cwd / marker).exists(): + return "python" + + # Check TypeScript – look for typescript in package.json dependencies + package_json_path = cwd / "package.json" + if package_json_path.exists(): + try: + import json + + with open(package_json_path, "r", encoding="utf-8") as fh: + pkg = json.load(fh) + all_deps: dict[str, str] = {} + all_deps.update(pkg.get("dependencies", {})) + all_deps.update(pkg.get("devDependencies", {})) + if "typescript" in all_deps: + return "typescript" + except (json.JSONDecodeError, OSError): + pass + + # Check Go markers + for marker in GO_MARKERS: + if (cwd / marker).exists(): + return "go" + + return None + + +def _prompt_language() -> str: + """Interactively ask the user to choose a project language.""" + console.print("\n[warning]Could not auto-detect project language.[/warning]") + console.print(" [bold]1)[/bold] Python") + console.print(" [bold]2)[/bold] TypeScript") + console.print(" [bold]3)[/bold] Go") + + while True: + choice = console.input("\nSelect language [1/2/3]: ").strip() + if choice == "1": + return "python" + elif choice == "2": + return "typescript" + elif choice == "3": + return "go" + else: + console.print("[error]Invalid choice. Please enter 1, 2, or 3.[/error]") + + +def _build_pddrc_content(language: str) -> str: + """Build the YAML content for a ``.pddrc`` file. + + Parameters + ---------- + language: + One of ``"python"``, ``"typescript"``, or ``"go"``. + + Returns + ------- + str + The full YAML string ready to be written to disk. + """ + paths = LANGUAGE_DEFAULTS.get(language, LANGUAGE_DEFAULTS["python"]) + + lines: list[str] = [ + 'version: "1.0"', + "", + "contexts:", + " default:", + " defaults:", + f' generate_output_path: "{paths["generate_output_path"]}"', + f' test_output_path: "{paths["test_output_path"]}"', + f' example_output_path: "{paths["example_output_path"]}"', + f' default_language: "{language}"', + ] + + for key, value in STANDARD_DEFAULTS.items(): + # Format integers without trailing .0, floats with one decimal + if isinstance(value, int): + lines.append(f" {key}: {value}") + else: + lines.append(f" {key}: {value}") + + lines.append("") # trailing newline + return "\n".join(lines) + + +def offer_pddrc_init() -> bool: + """Offer to create a ``.pddrc`` configuration file in the current directory. + + If a ``.pddrc`` already exists the user is informed and the function + returns ``False``. Otherwise a preview of sensible defaults is shown + and the user is prompted to confirm creation. + + Returns + ------- + bool + ``True`` if the file was created, ``False`` otherwise. + """ + cwd = Path.cwd() + pddrc_path = cwd / PDDRC_FILENAME + + # ── Already exists ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + if pddrc_path.exists(): + console.print( + f"[info]A {PDDRC_FILENAME} file already exists in {cwd}.[/info]" + ) + return False + + # ── Detect / prompt language ━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + language = _detect_language(cwd) + if language is None: + language = _prompt_language() + else: + console.print(f"\n[success]Detected project language: {language}[/success]") + + # ── Build & preview ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + content = _build_pddrc_content(language) + + console.print(f"\n[info]Proposed {PDDRC_FILENAME} contents:[/info]\n") + syntax = Syntax(content, "yaml", theme="monokai", line_numbers=False) + console.print(syntax) + + # ── Prompt for confirmation (Enter = yes) ━━━━━━━━━━━━━━━━━━━━━ + answer = console.input(f"\nCreate {PDDRC_FILENAME}? [Y/n] ").strip().lower() + if answer in ("", "y", "yes"): + try: + pddrc_path.write_text(content, encoding="utf-8") + console.print( + f"[success]Created {PDDRC_FILENAME} in {cwd}[/success]" + ) + return True + except OSError as exc: + console.print( + f"[error]Failed to write {PDDRC_FILENAME}: {exc}[/error]" + ) + return False + else: + console.print("[info]Skipped .pddrc creation.[/info]") + return False \ No newline at end of file diff --git a/pdd/setup/provider_manager.py b/pdd/setup/provider_manager.py new file mode 100644 index 000000000..65d0ccfee --- /dev/null +++ b/pdd/setup/provider_manager.py @@ -0,0 +1,546 @@ +from __future__ import annotations + +import csv +import io +import os +import re +import tempfile +import shutil +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from rich.console import Console +from rich.table import Table +from rich.prompt import Prompt, Confirm + +from pdd.setup.api_key_scanner import KeyInfo + +console = Console() + +# CSV column schema +CSV_FIELDNAMES = [ + "provider", "model", "input", "output", "coding_arena_elo", + "base_url", "api_key", "max_reasoning_tokens", "structured_output", + "reasoning_type", "location", +] + +# --------------------------------------------------------------------------- +# Path helpers +# --------------------------------------------------------------------------- + +def _get_shell_name() -> str: + """Detect shell from SHELL env var, default to bash.""" + shell_path = os.environ.get("SHELL", "/bin/bash") + shell = Path(shell_path).name + # Normalise common shells + if shell in ("bash", "zsh", "fish", "sh", "ksh", "csh", "tcsh"): + return shell + return "bash" + + +def _get_pdd_dir() -> Path: + """Return ~/.pdd, creating it if necessary.""" + pdd_dir = Path.home() / ".pdd" + pdd_dir.mkdir(parents=True, exist_ok=True) + return pdd_dir + + +def _get_api_env_path() -> Path: + """Return path to ~/.pdd/api-env.{shell}.""" + shell = _get_shell_name() + return _get_pdd_dir() / f"api-env.{shell}" + + +def _get_user_csv_path() -> Path: + """Return path to ~/.pdd/llm_model.csv.""" + return _get_pdd_dir() / "llm_model.csv" + + +def _get_master_csv_path() -> Path: + """Return path to pdd/data/llm_model.csv (shipped with the package).""" + return Path(__file__).resolve().parent.parent / "data" / "llm_model.csv" + + +# --------------------------------------------------------------------------- +# CSV I/O helpers +# --------------------------------------------------------------------------- + +def _read_csv(path: Path) -> List[Dict[str, str]]: + """Read a CSV file and return list of row dicts. Returns [] if missing.""" + if not path.exists(): + return [] + with open(path, "r", encoding="utf-8", newline="") as f: + reader = csv.DictReader(f) + return list(reader) + + +def _write_csv_atomic(path: Path, rows: List[Dict[str, str]]) -> None: + """Atomically write rows to a CSV file (temp file + rename).""" + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + dir=str(path.parent), suffix=".tmp", prefix=".llm_model_" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=CSV_FIELDNAMES) + writer.writeheader() + for row in rows: + # Ensure every field is present + clean = {k: row.get(k, "") for k in CSV_FIELDNAMES} + writer.writerow(clean) + shutil.move(tmp_path, str(path)) + except Exception: + # Clean up temp file on failure + if os.path.exists(tmp_path): + os.unlink(tmp_path) + raise + + +# --------------------------------------------------------------------------- +# api-env file helpers +# --------------------------------------------------------------------------- + +def _read_api_env_lines(path: Path) -> List[str]: + """Read api-env file lines. Returns [] if missing.""" + if not path.exists(): + return [] + with open(path, "r", encoding="utf-8") as f: + return f.readlines() + + +def _write_api_env_atomic(path: Path, lines: List[str]) -> None: + """Atomically write lines to api-env file.""" + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + dir=str(path.parent), suffix=".tmp", prefix=".api-env_" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.writelines(lines) + shutil.move(tmp_path, str(path)) + except Exception: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + raise + + +def _save_key_to_api_env(key_name: str, key_value: str) -> None: + """ + Add or update an export line in the api-env file. + If the key already exists (even commented out), replace it. + """ + env_path = _get_api_env_path() + lines = _read_api_env_lines(env_path) + + export_line = f'export {key_name}="{key_value}"\n' + + # Pattern to match existing line (commented or not) + pattern = re.compile( + r"^(?:#\s*)?export\s+" + re.escape(key_name) + r"\s*=", re.MULTILINE + ) + + found = False + new_lines: List[str] = [] + for line in lines: + if pattern.match(line.strip()): + new_lines.append(export_line) + found = True + else: + new_lines.append(line) + + if not found: + # Ensure trailing newline before appending + if new_lines and not new_lines[-1].endswith("\n"): + new_lines[-1] += "\n" + new_lines.append(export_line) + + _write_api_env_atomic(env_path, new_lines) + + +def _comment_out_key_in_api_env(key_name: str) -> None: + """ + Comment out (never delete) a key in the api-env file. + Adds a comment with the date. + """ + env_path = _get_api_env_path() + lines = _read_api_env_lines(env_path) + + pattern = re.compile( + r"^export\s+" + re.escape(key_name) + r"\s*=", re.MULTILINE + ) + + today = datetime.now().strftime("%Y-%m-%d") + new_lines: List[str] = [] + for line in lines: + stripped = line.strip() + if pattern.match(stripped): + comment = f"# Commented out by pdd setup on {today}\n" + new_lines.append(comment) + new_lines.append(f"# {stripped}\n") + else: + new_lines.append(line) + + _write_api_env_atomic(env_path, new_lines) + + +def _get_master_rows_for_key(api_key_name: str) -> List[Dict[str, str]]: + """Return all rows from master CSV whose api_key column matches.""" + master_rows = _read_csv(_get_master_csv_path()) + return [r for r in master_rows if r.get("api_key", "").strip() == api_key_name] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def add_api_key(scan_results: Dict[str, KeyInfo]) -> bool: + """ + Show missing keys from scan_results, prompt user for one, save to + api-env file, then copy ALL matching rows from master CSV into user CSV. + + Returns True if a key was added/updated, False if cancelled. + """ + # Separate missing and found keys + missing_keys = {k: v for k, v in scan_results.items() if not v.is_set} + found_keys = {k: v for k, v in scan_results.items() if v.is_set} + + if not missing_keys: + console.print("[green]All known API keys are already configured.[/green]") + return False + + # Display missing keys + console.print("\n[bold]Missing API keys:[/bold]") + sorted_missing = sorted(missing_keys.keys()) + for idx, key_name in enumerate(sorted_missing, 1): + console.print(f" {idx}. {key_name}") + + # Prompt user to select one + selection = Prompt.ask( + "\nEnter the number of the key to add (or press Enter to cancel)" + ) + if not selection.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + + try: + choice = int(selection.strip()) + if choice < 1 or choice > len(sorted_missing): + console.print("[red]Invalid selection.[/red]") + return False + except ValueError: + console.print("[red]Invalid input.[/red]") + return False + + selected_key = sorted_missing[choice - 1] + + # Check if key is already in environment (shouldn't be since it's in missing, but guard) + if os.environ.get(selected_key): + console.print( + f"[yellow]{selected_key} is already set in the current environment. Skipping.[/yellow]" + ) + return False + + # Check if master CSV has rows for this key + master_rows = _get_master_rows_for_key(selected_key) + if not master_rows: + console.print( + f"[yellow]No models found in master CSV for '{selected_key}'.[/yellow]\n" + f"[yellow]Please use 'Add a custom provider' instead.[/yellow]" + ) + return False + + # Prompt for the actual key value + key_value = Prompt.ask(f"Enter the value for {selected_key}") + if not key_value.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + key_value = key_value.strip() + + # Save key to api-env file + _save_key_to_api_env(selected_key, key_value) + console.print(f"[green]Saved {selected_key} to {_get_api_env_path()}[/green]") + + # Copy matching rows from master CSV into user CSV + user_csv_path = _get_user_csv_path() + existing_user_rows = _read_csv(user_csv_path) + + # Build set of existing model identifiers to avoid duplicates + existing_models = { + (r.get("provider", ""), r.get("model", "")) + for r in existing_user_rows + } + + added_count = 0 + for row in master_rows: + model_id = (row.get("provider", ""), row.get("model", "")) + if model_id not in existing_models: + existing_user_rows.append(row) + existing_models.add(model_id) + added_count += 1 + + _write_csv_atomic(user_csv_path, existing_user_rows) + console.print( + f"[green]Added {added_count} model(s) for {selected_key} to {user_csv_path}[/green]" + ) + + return True + + +def add_custom_provider() -> bool: + """ + Prompt for custom provider details and append a row to user CSV. + + Returns True if a provider was added, False if cancelled. + """ + console.print("\n[bold]Add a Custom LiteLLM-Compatible Provider[/bold]\n") + + # Provider prefix (e.g. "openai", "anthropic", "ollama", etc.) + provider = Prompt.ask("Provider prefix (e.g. openai, ollama, together_ai)") + if not provider.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + provider = provider.strip() + + # Model name + model_name = Prompt.ask("Model name (e.g. my-model-v1)") + if not model_name.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + model_name = model_name.strip() + + # Full model string for LiteLLM: provider/model + full_model = f"{provider}/{model_name}" + + # API key env var name + api_key_var = Prompt.ask("API key environment variable name (e.g. OPENAI_API_KEY)") + if not api_key_var.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + api_key_var = api_key_var.strip() + + # Base URL (optional) + base_url = Prompt.ask("Base URL (optional, press Enter to skip)", default="") + base_url = base_url.strip() + + # Costs (optional) + input_cost = Prompt.ask("Input cost per 1M tokens (optional, press Enter for 0.0)", default="0.0") + output_cost = Prompt.ask("Output cost per 1M tokens (optional, press Enter for 0.0)", default="0.0") + + try: + input_cost_val = str(float(input_cost.strip())) + except ValueError: + input_cost_val = "0.0" + + try: + output_cost_val = str(float(output_cost.strip())) + except ValueError: + output_cost_val = "0.0" + + # Ask if user wants to provide the actual API key value now + provide_key = Confirm.ask( + f"Do you want to enter the value for {api_key_var} now?", default=True + ) + if provide_key: + key_value = Prompt.ask(f"Enter the value for {api_key_var}") + if key_value.strip(): + _save_key_to_api_env(api_key_var, key_value.strip()) + console.print( + f"[green]Saved {api_key_var} to {_get_api_env_path()}[/green]" + ) + + # Build the row with sensible defaults + new_row: Dict[str, str] = { + "provider": provider, + "model": full_model, + "input": input_cost_val, + "output": output_cost_val, + "coding_arena_elo": "1000", + "base_url": base_url, + "api_key": api_key_var, + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "", + "location": "", + } + + # Append to user CSV + user_csv_path = _get_user_csv_path() + existing_rows = _read_csv(user_csv_path) + existing_rows.append(new_row) + _write_csv_atomic(user_csv_path, existing_rows) + + console.print( + f"[green]Added custom model '{full_model}' to {user_csv_path}[/green]" + ) + return True + + +def remove_models_by_provider() -> bool: + """ + Group user CSV models by api_key, show numbered list with counts, + remove all rows for selected provider. Comment out the key in api-env. + + Returns True if models were removed, False if cancelled. + """ + user_csv_path = _get_user_csv_path() + rows = _read_csv(user_csv_path) + + if not rows: + console.print("[yellow]No models configured in user CSV.[/yellow]") + return False + + # Group by api_key + provider_groups: Dict[str, List[Dict[str, str]]] = {} + for row in rows: + key = row.get("api_key", "").strip() + if not key: + key = "(no api_key)" + provider_groups.setdefault(key, []).append(row) + + sorted_providers = sorted(provider_groups.keys()) + + # Display table + table = Table(title="Configured Providers") + table.add_column("#", style="bold") + table.add_column("API Key Variable") + table.add_column("Model Count", justify="right") + table.add_column("Sample Models") + + for idx, prov_key in enumerate(sorted_providers, 1): + prov_rows = provider_groups[prov_key] + sample = ", ".join( + r.get("model", "?") for r in prov_rows[:3] + ) + if len(prov_rows) > 3: + sample += ", ..." + table.add_row(str(idx), prov_key, str(len(prov_rows)), sample) + + console.print(table) + + selection = Prompt.ask( + "\nEnter the number of the provider to remove (or press Enter to cancel)" + ) + if not selection.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + + try: + choice = int(selection.strip()) + if choice < 1 or choice > len(sorted_providers): + console.print("[red]Invalid selection.[/red]") + return False + except ValueError: + console.print("[red]Invalid input.[/red]") + return False + + selected_provider_key = sorted_providers[choice - 1] + remove_count = len(provider_groups[selected_provider_key]) + + # Confirm + if not Confirm.ask( + f"Remove all {remove_count} model(s) for '{selected_provider_key}'?" + ): + console.print("[dim]Cancelled.[/dim]") + return False + + # Filter out the selected provider's rows + remaining_rows = [ + r for r in rows + if (r.get("api_key", "").strip() or "(no api_key)") != selected_provider_key + ] + + _write_csv_atomic(user_csv_path, remaining_rows) + console.print( + f"[green]Removed {remove_count} model(s) for '{selected_provider_key}'.[/green]" + ) + + # Comment out the key in api-env (only if it's a real key name) + if selected_provider_key != "(no api_key)": + _comment_out_key_in_api_env(selected_provider_key) + console.print( + f"[green]Commented out {selected_provider_key} in {_get_api_env_path()}[/green]" + ) + + return True + + +def remove_individual_models() -> bool: + """ + List all models from user CSV, let user select by comma-separated numbers, + remove selected rows. + + Returns True if models were removed, False if cancelled. + """ + user_csv_path = _get_user_csv_path() + rows = _read_csv(user_csv_path) + + if not rows: + console.print("[yellow]No models configured in user CSV.[/yellow]") + return False + + # Display all models + table = Table(title="Configured Models") + table.add_column("#", style="bold") + table.add_column("Provider") + table.add_column("Model") + table.add_column("API Key") + + for idx, row in enumerate(rows, 1): + table.add_row( + str(idx), + row.get("provider", ""), + row.get("model", ""), + row.get("api_key", ""), + ) + + console.print(table) + + selection = Prompt.ask( + "\nEnter model numbers to remove (comma-separated, or press Enter to cancel)" + ) + if not selection.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + + # Parse comma-separated numbers + indices_to_remove: set[int] = set() + for part in selection.split(","): + part = part.strip() + if not part: + continue + try: + num = int(part) + if 1 <= num <= len(rows): + indices_to_remove.add(num) + else: + console.print(f"[yellow]Skipping invalid number: {num}[/yellow]") + except ValueError: + console.print(f"[yellow]Skipping invalid input: '{part}'[/yellow]") + + if not indices_to_remove: + console.print("[dim]No valid selections. Cancelled.[/dim]") + return False + + # Show what will be removed + console.print("\n[bold]Models to remove:[/bold]") + for idx in sorted(indices_to_remove): + row = rows[idx - 1] + console.print(f" {idx}. {row.get('model', '?')} ({row.get('api_key', '')})") + + if not Confirm.ask(f"Remove {len(indices_to_remove)} model(s)?"): + console.print("[dim]Cancelled.[/dim]") + return False + + # Filter out selected rows (convert to 0-based) + remaining_rows = [ + row for idx, row in enumerate(rows, 1) + if idx not in indices_to_remove + ] + + _write_csv_atomic(user_csv_path, remaining_rows) + console.print( + f"[green]Removed {len(indices_to_remove)} model(s) from {user_csv_path}[/green]" + ) + + return True \ No newline at end of file diff --git a/pdd/setup/setup_tool.py b/pdd/setup/setup_tool.py new file mode 100644 index 000000000..23cc43eca --- /dev/null +++ b/pdd/setup/setup_tool.py @@ -0,0 +1,155 @@ +""" +Main orchestrator for ``pdd setup``. + +Auto-scans the environment for API keys (existence only — no API calls), +then presents an interactive menu. After any action the menu re-displays +with an updated scan. +""" +from __future__ import annotations + +from typing import Dict + +from .api_key_scanner import scan_environment, KeyInfo +from .provider_manager import ( + add_api_key, + add_custom_provider, + remove_models_by_provider, + remove_individual_models, +) +from .local_llm_configurator import configure_local_llm +from .model_tester import test_model_interactive +from .cli_detector import detect_cli_tools +from .pddrc_initializer import offer_pddrc_init + + +# --------------------------------------------------------------------------- +# Display helpers +# --------------------------------------------------------------------------- + +def _display_scan(scan_results: Dict[str, KeyInfo]) -> None: + """Print a table of discovered API keys and a summary line.""" + print("\n API-key scan") + print(" " + "─" * 50) + + api_found = 0 + local_count = 0 + + for key_name, info in scan_results.items(): + if info.is_set: + # Heuristic: keys whose source mentions "local" or whose name + # hints at a local provider are counted separately. + source_lower = (info.source or "").lower() + if "local" in source_lower or "ollama" in key_name.lower() or "lm_studio" in key_name.lower(): + local_count += 1 + else: + api_found += 1 + print(f" {key_name:30s} ✓ Found ({info.source})") + else: + print(f" {key_name:30s} — Not found") + + total_configured = api_found + local_count + print( + f"\n Models configured: {total_configured} " + f"(from {api_found} API keys + {local_count} local)" + ) + print() + + +def _display_menu() -> None: + """Print the interactive menu options.""" + print(" What would you like to do?") + print(" 1. Add a provider") + print(" 2. Remove models") + print(" 3. Test a model") + print(" 4. Detect CLI tools") + print(" 5. Initialize .pddrc") + print(" 6. Done") + print() + + +def _add_provider_submenu(scan_results: Dict[str, KeyInfo]) -> None: + """Sub-menu for option 1 — Add a provider.""" + print() + print(" Add a provider:") + print(" a. Enter an API key") + print(" b. Add a local LLM") + print(" c. Add a custom provider") + print() + + sub_choice = input(" Choice [a/b/c]: ").strip().lower() + + if sub_choice == "a": + add_api_key(scan_results) + elif sub_choice == "b": + configure_local_llm() + elif sub_choice == "c": + add_custom_provider() + else: + print(" Invalid choice — returning to main menu.") + + +def _remove_models_submenu() -> None: + """Sub-menu for option 2 — Remove models.""" + print() + print(" Remove models:") + print(" a. By provider") + print(" b. Individual models") + print() + + sub_choice = input(" Choice [a/b]: ").strip().lower() + + if sub_choice == "a": + remove_models_by_provider() + elif sub_choice == "b": + remove_individual_models() + else: + print(" Invalid choice — returning to main menu.") + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +def run_setup() -> None: + """Main entry point for ``pdd setup``. + + Scans the environment, displays results, and loops an interactive menu + until the user selects *Done* or presses Ctrl-C. + """ + try: + print() + print(" ╭──────────────────────────────╮") + print(" │ pdd setup │") + print(" ╰──────────────────────────────╯") + + while True: + # (Re-)scan on every iteration so the display stays current. + scan_results = scan_environment() + _display_scan(scan_results) + _display_menu() + + choice = input(" Choice [1-6]: ").strip() + + if choice == "1": + _add_provider_submenu(scan_results) + elif choice == "2": + _remove_models_submenu() + elif choice == "3": + test_model_interactive() + elif choice == "4": + detect_cli_tools() + elif choice == "5": + offer_pddrc_init() + elif choice == "6": + print("\n ✓ Setup complete. Happy prompting!\n") + break + else: + print(" Invalid choice — please enter a number between 1 and 6.") + + except KeyboardInterrupt: + # Clean exit on Ctrl-C at any point. + print("\n\n Setup interrupted — exiting.\n") + + +if __name__ == "__main__": + run_setup() \ No newline at end of file From 92e9fdb69074216aae73c7d77652978453613abd Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Sun, 15 Feb 2026 19:28:16 -0500 Subject: [PATCH 04/10] Improve 'Add a provider' flow in pdd setup - Show LiteLLM registry of providers and models when adding models rather than data/llm_model.csv - Add search feature to search for a provider - Ask user for API keys based on the models they want to add --- context/api_key_scanner_example.py | 14 +- context/litellm_registry_example.py | 69 +++++ context/provider_manager_example.py | 30 +- context/setup_tool_example.py | 59 ++++ pdd/prompts/api_key_scanner_python.prompt | 17 +- pdd/prompts/litellm_registry_python.prompt | 57 ++++ pdd/prompts/provider_manager_python.prompt | 45 ++- pdd/prompts/setup_tool_python.prompt | 17 +- pdd/setup/api_key_scanner.py | 32 ++- pdd/setup/litellm_registry.py | 312 +++++++++++++++++++++ pdd/setup/provider_manager.py | 282 ++++++++++++++----- pdd/setup/setup_tool.py | 16 +- 12 files changed, 827 insertions(+), 123 deletions(-) create mode 100644 context/litellm_registry_example.py create mode 100644 context/setup_tool_example.py create mode 100644 pdd/prompts/litellm_registry_python.prompt create mode 100644 pdd/setup/litellm_registry.py diff --git a/context/api_key_scanner_example.py b/context/api_key_scanner_example.py index 018d6be0c..325b48110 100644 --- a/context/api_key_scanner_example.py +++ b/context/api_key_scanner_example.py @@ -13,14 +13,22 @@ def main() -> None: """ Demonstrates how to use the api_key_scanner module to: - 1. Discover all API key variable names from llm_model.csv + 1. Discover all API key variable names from the user's ~/.pdd/llm_model.csv 2. Scan multiple sources (shell env, .env file, ~/.pdd/api-env.*) 3. Report existence and source without storing key values + + Note: The scanner reads from the user's configured models, not a hardcoded + master list. If no models have been added via `pdd setup`, both functions + return empty results. """ - # Get all provider key names from the master CSV + # Get all provider key names from the user's configured CSV all_keys = get_provider_key_names() - print(f"Provider key names from CSV: {all_keys}\n") + print(f"Provider key names from user CSV: {all_keys}\n") + + if not all_keys: + print("No models configured yet. Use `pdd setup` to add providers.") + return # Scan the environment for all API keys print("Scanning environment for API keys...\n") diff --git a/context/litellm_registry_example.py b/context/litellm_registry_example.py new file mode 100644 index 000000000..2d48bf3dc --- /dev/null +++ b/context/litellm_registry_example.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.litellm_registry import ( + is_litellm_available, + get_api_key_env_var, + get_top_providers, + search_providers, + get_models_for_provider, + ProviderInfo, + ModelInfo, +) + + +def main() -> None: + """ + Demonstrates how to use the litellm_registry module to: + 1. Check if litellm data is available + 2. Browse top providers + 3. Search for a provider by name + 4. List chat models for a provider with pricing + 5. Look up API key env var for a provider + """ + + # Check availability + if not is_litellm_available(): + print("litellm is not installed or has no model data.") + return + + # Browse top providers (curated list of ~10 major cloud providers) + print("Top providers:") + for p in get_top_providers(): + print(f" {p.display_name:20s} {p.model_count:3d} chat models key: {p.api_key_env_var}") + print() + + # Search for a provider by substring + results = search_providers("anth") + print(f"Search 'anth': {len(results)} result(s)") + for p in results: + print(f" {p.display_name} ({p.model_count} models)") + print() + + # List models for a specific provider + models = get_models_for_provider("anthropic") + print(f"Anthropic chat models ({len(models)}):") + for m in models[:5]: + print( + f" {m.litellm_id:40s} ${m.input_cost_per_million:>7.2f} in " + f"${m.output_cost_per_million:>7.2f} out " + f"ctx: {m.max_input_tokens}" + ) + print() + + # Look up API key env var + env_var = get_api_key_env_var("anthropic") + print(f"Anthropic API key env var: {env_var}") + + env_var_unknown = get_api_key_env_var("some_unknown_provider") + print(f"Unknown provider env var: {env_var_unknown}") + + +if __name__ == "__main__": + main() diff --git a/context/provider_manager_example.py b/context/provider_manager_example.py index abf84ac1c..298a27ef3 100644 --- a/context/provider_manager_example.py +++ b/context/provider_manager_example.py @@ -8,29 +8,41 @@ sys.path.append(str(project_root)) from pdd.setup.provider_manager import ( - add_api_key, + add_provider_from_registry, add_custom_provider, remove_models_by_provider, remove_individual_models, ) -from pdd.setup.api_key_scanner import scan_environment def main() -> None: """ Demonstrates how to use the provider_manager module to: - 1. Add an API key and auto-load all models for that provider + 1. Search/browse litellm's registry to add a provider and specific models 2. Add a custom LiteLLM-compatible provider 3. Remove all models for a provider (comments out the key) 4. Remove individual models from the user CSV """ - # First, scan the environment to see what's configured - scan_results = scan_environment() - - # Example 1: Add an API key (auto-loads all models for that provider) - # Shows missing keys, prompts for one, saves to api-env, copies CSV rows - # add_api_key(scan_results) # Uncomment to run interactively + # Example 1: Search/browse providers from litellm's registry + # Shows top ~10 providers, lets you search, pick models, enter API key + # add_provider_from_registry() # Uncomment to run interactively + + # Interactive flow: + # Top providers: + # 1. OpenAI (102 chat models) + # 2. Anthropic (29 chat models) + # ... + # Enter number, or type to search: anthropic + # + # Chat models for Anthropic: + # 1. claude-opus-4-5-20251101 $5.00 $25.00 200,000 + # 2. claude-sonnet-4-5-20250929 $3.00 $15.00 200,000 + # ... + # Select models: 1,2 + # + # ANTHROPIC_API_KEY: sk-ant-... + # ✓ Added 2 model(s) to ~/.pdd/llm_model.csv # Example 2: Add a custom provider (Together AI, Deepinfra, etc.) # Prompts for prefix, model name, API key var, base URL, costs diff --git a/context/setup_tool_example.py b/context/setup_tool_example.py new file mode 100644 index 000000000..0b03658e9 --- /dev/null +++ b/context/setup_tool_example.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +# Add the project root to sys.path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from pdd.setup.setup_tool import run_setup + + +def main() -> None: + """ + Demonstrates how to use the setup_tool module to: + 1. Launch the interactive pdd setup wizard + 2. Scan for configured API keys + 3. Navigate the 6-option menu + + The setup wizard is fully interactive. Running it will: + - Scan ~/.pdd/llm_model.csv for configured models and their API key status + - Display a menu with options to add/remove providers, test models, etc. + - Loop until the user selects Done or presses Ctrl-C + """ + + # Run the interactive setup wizard + # run_setup() # Uncomment to run interactively + + # Example flow: + # ╭──────────────────────────────╮ + # │ pdd setup │ + # ╰──────────────────────────────╯ + # + # API-key scan + # ────────────────────────────────────────────────── + # ANTHROPIC_API_KEY ✓ Found (shell environment) + # OPENAI_API_KEY — Not found + # + # Models configured: 1 (from 1 API keys + 0 local) + # + # What would you like to do? + # 1. Add a provider + # 2. Remove models + # 3. Test a model + # 4. Detect CLI tools + # 5. Initialize .pddrc + # 6. Done + # + # Choice [1-6]: 1 + # + # Add a provider: + # a. Search providers + # b. Add a local LLM + # c. Add a custom provider + pass + + +if __name__ == "__main__": + main() diff --git a/pdd/prompts/api_key_scanner_python.prompt b/pdd/prompts/api_key_scanner_python.prompt index 66c968235..7f73dde72 100644 --- a/pdd/prompts/api_key_scanner_python.prompt +++ b/pdd/prompts/api_key_scanner_python.prompt @@ -1,4 +1,4 @@ -Discovers API keys from CSV providers, checking existence across shell, .env, and PDD config with source transparency. +Discovers API keys needed by the user's configured models, checking existence across shell, .env, and PDD config with source transparency. { @@ -15,26 +15,27 @@ % You are an expert Python engineer. Your goal is to write the pdd/setup/api_key_scanner.py module. % Role & Scope -Dynamically discovers API keys from all sources and reports their existence. Reads pdd/data/llm_model.csv to find all unique API key environment variable names, then checks .env files, shell environment, and ~/.pdd/api-env.* files. Only checks **existence** — never makes API calls or stores key values. +Discovers API keys needed by the user's configured models and reports their existence. Reads the user's `~/.pdd/llm_model.csv` to find all unique API key environment variable names, then checks .env files, shell environment, and ~/.pdd/api-env.* files. Only checks **existence** — never makes API calls or stores key values. % Requirements -1. Function: `scan_environment() -> Dict[str, KeyInfo]` — returns mapping of key name to KeyInfo(source, is_set). Does not store key values. -2. Function: `get_provider_key_names() -> List[str]` — returns deduplicated sorted list of all non-empty api_key values from the master CSV -3. Dynamic discovery: extract all unique api_key column values from pdd/data/llm_model.csv — no hardcoded provider list +1. Function: `scan_environment() -> Dict[str, KeyInfo]` — returns mapping of key name to KeyInfo(source, is_set). Does not store key values. Returns empty dict if no models are configured yet. +2. Function: `get_provider_key_names() -> List[str]` — returns deduplicated sorted list of all non-empty api_key values from the user's CSV (`~/.pdd/llm_model.csv`) +3. Dynamic discovery: extract all unique api_key column values from the user CSV — no hardcoded provider list 4. Check sources in priority order: - .env file (via python-dotenv `dotenv_values`, read-only — always reads fresh on each scan) - - Shell environment (`os.environ` — note: may include stale .env values if edited during session; restart pdd setup to refresh) + - Shell environment (`os.environ`) - ~/.pdd/api-env.{shell} (parse uncommented `export KEY=` lines) 5. KeyInfo: dataclass with fields `source` (str) and `is_set` (bool). Report source as "shell environment", ".env file", or "~/.pdd/api-env.zsh" (etc.) 6. Detect shell from SHELL env var for correct api-env file 7. Never raise exceptions — return best-effort results with logging for errors -8. Handle missing/malformed CSV gracefully (return empty dict) +8. Handle missing/malformed/empty CSV gracefully (return empty dict/list) % Dependencies -The CSV at pdd/data/llm_model.csv has columns: +The user CSV at ~/.pdd/llm_model.csv has columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location Rows with empty api_key are local LLMs (no key needed). +This file is created/managed by provider_manager when the user adds providers via `pdd setup`. % Deliverables diff --git a/pdd/prompts/litellm_registry_python.prompt b/pdd/prompts/litellm_registry_python.prompt new file mode 100644 index 000000000..8c7818507 --- /dev/null +++ b/pdd/prompts/litellm_registry_python.prompt @@ -0,0 +1,57 @@ +Wraps litellm's bundled model registry to provide provider search, model browsing, and API key env var lookup without network calls. + + +{ + "type": "module", + "module": { + "functions": [ + {"name": "is_litellm_available", "signature": "() -> bool", "returns": "bool"}, + {"name": "get_api_key_env_var", "signature": "(provider: str) -> Optional[str]", "returns": "Optional[str]"}, + {"name": "get_top_providers", "signature": "() -> List[ProviderInfo]", "returns": "List[ProviderInfo]"}, + {"name": "get_all_providers", "signature": "() -> List[ProviderInfo]", "returns": "List[ProviderInfo]"}, + {"name": "search_providers", "signature": "(query: str) -> List[ProviderInfo]", "returns": "List[ProviderInfo]"}, + {"name": "get_models_for_provider", "signature": "(provider: str) -> List[ModelInfo]", "returns": "List[ModelInfo]"} + ], + "dataclasses": [ + {"name": "ProviderInfo", "fields": ["name", "display_name", "api_key_env_var", "model_count", "sample_models"]}, + {"name": "ModelInfo", "fields": ["name", "litellm_id", "input_cost_per_million", "output_cost_per_million", "max_input_tokens", "max_output_tokens", "supports_vision", "supports_function_calling"]} + ] + } +} + + +% You are an expert Python engineer. Your goal is to write the pdd/setup/litellm_registry.py module. + +% Role & Scope +Thin wrapper around litellm's bundled data (`litellm.model_cost`, `litellm.models_by_provider`) for provider discovery and model browsing. Uses only locally bundled data — never makes network calls. Provides the data layer for the "Search providers" flow in `pdd setup`. + +% Requirements +1. Dataclass `ProviderInfo(name, display_name, api_key_env_var, model_count, sample_models)` — summary of a provider with up to 3 sample model names. +2. Dataclass `ModelInfo(name, litellm_id, input_cost_per_million, output_cost_per_million, max_input_tokens, max_output_tokens, supports_vision, supports_function_calling)` — metadata for a single model. +3. `is_litellm_available() -> bool` — guards against import failure, returns True only if litellm is importable and has model_cost data. +4. `get_api_key_env_var(provider) -> Optional[str]` — returns the API key env var name from a hardcoded mapping of ~25 common providers (e.g. "anthropic" → "ANTHROPIC_API_KEY"). Returns None for unknown providers. +5. `get_top_providers() -> List[ProviderInfo]` — returns a curated list of ~10 major cloud providers in a fixed display order. Falls back to all providers sorted by model count if curated list yields too few. +6. `get_all_providers() -> List[ProviderInfo]` — returns all providers with at least one chat model, sorted by model count descending. +7. `search_providers(query) -> List[ProviderInfo]` — case-insensitive substring match on provider name and display name. Empty query returns all providers. +8. `get_models_for_provider(provider) -> List[ModelInfo]` — returns chat-mode models sorted by name. Converts per-token costs to per-million costs. +9. Filter to `mode == "chat"` models only when browsing. +10. Handle vertex_ai sub-providers: aggregate all `litellm_provider` values starting with "vertex_ai" when querying vertex_ai. +11. Fallback: when `models_by_provider` entries aren't in `model_cost`, scan `model_cost` by `litellm_provider` field to catch mismatches (e.g. together_ai). +12. All litellm imports must be local (inside functions), guarded by `is_litellm_available()`. +13. Hardcoded `PROVIDER_DISPLAY_NAMES` dict for human-friendly names (e.g. "fireworks_ai" → "Fireworks AI"). Fallback: replace underscores and title-case. + +% Dependencies + +litellm.model_cost: Dict[str, dict] — ~2566 entries keyed by model ID. + Each entry has: litellm_provider, mode, input_cost_per_token, output_cost_per_token, + max_input_tokens, max_output_tokens, supports_vision, supports_function_calling, + supports_response_schema, and more. + +litellm.models_by_provider: Dict[str, Set[str]] — 86 providers mapped to sets of model names. + +Note: model_cost has NO api_key_env_var field. Provider-to-key mapping must be hardcoded. + + +% Deliverables +- Module at `pdd/setup/litellm_registry.py` exporting `ProviderInfo`, `ModelInfo`, `is_litellm_available`, `get_api_key_env_var`, `get_top_providers`, `get_all_providers`, `search_providers`, `get_models_for_provider`. +- Also exports constants `PROVIDER_API_KEY_MAP` and `PROVIDER_DISPLAY_NAMES` for use by other modules. diff --git a/pdd/prompts/provider_manager_python.prompt b/pdd/prompts/provider_manager_python.prompt index 520e4d8b9..8f57e8ec9 100644 --- a/pdd/prompts/provider_manager_python.prompt +++ b/pdd/prompts/provider_manager_python.prompt @@ -1,11 +1,11 @@ -Manages LLM providers: adding API keys with auto-loaded models, custom providers, and model removal. +Manages LLM providers: search/browse registry to add models, custom providers, and model removal. { "type": "module", "module": { "functions": [ - {"name": "add_api_key", "signature": "(scan_results: Dict[str, KeyInfo]) -> bool", "returns": "bool"}, + {"name": "add_provider_from_registry", "signature": "() -> bool", "returns": "bool"}, {"name": "add_custom_provider", "signature": "() -> bool", "returns": "bool"}, {"name": "remove_models_by_provider", "signature": "() -> bool", "returns": "bool"}, {"name": "remove_individual_models", "signature": "() -> bool", "returns": "bool"} @@ -14,32 +14,51 @@ } -api_key_scanner_python.prompt +litellm_registry_python.prompt % You are an expert Python engineer. Your goal is to write the pdd/setup/provider_manager.py module. % Role & Scope -Handles adding and removing LLM providers and models in PDD setup. Supports entering API keys for known providers (auto-loading all their models), adding custom LiteLLM-compatible providers, and two modes of model removal. +Handles adding and removing LLM providers and models in PDD setup. The primary flow uses litellm's bundled model registry to let users search/browse providers, pick specific models, and enter API keys. Also supports adding custom LiteLLM-compatible providers and two modes of model removal. % Requirements -1. `add_api_key(scan_results)` — show missing keys from scan_results, prompt user for one, save to `~/.pdd/api-env.{shell}`, then copy ALL matching rows from master CSV (pdd/data/llm_model.csv) into user CSV (`~/.pdd/llm_model.csv`). No interactive model selection. If key already exists (replacing), update api-env only. If key name has no CSV rows, tell user to use "Add a custom provider" instead. Skip saving keys already in environment. -2. `add_custom_provider()` — prompt for provider prefix, model name, API key env var, base URL (optional), costs (optional). Append row to user CSV with sensible defaults. Save API key to api-env if provided. +1. `add_provider_from_registry()` — interactive search/browse flow using litellm_registry: + a. Show top ~10 providers (from `get_top_providers()`). Accept a number for direct selection, or text to search via `search_providers()`. + b. Display selected provider's chat models in a Rich table with columns: #, Model, Input $/M, Output $/M, Max Input. Accept comma-separated numbers or "all". + c. Look up API key env var via `get_api_key_env_var()`. If already set, offer to use existing. If not set, prompt for value. If provider not in map, ask user for env var name. Save to `~/.pdd/api-env.{shell}`. + d. Append selected models to `~/.pdd/llm_model.csv`, skip duplicates (by provider+model pair). Map litellm data to CSV schema. + e. Return True if any models were added. + +2. `add_custom_provider()` — prompt for provider prefix, model name, API key env var, base URL (optional), costs (optional). Append row to user CSV with sensible defaults. Save API key to api-env if provided. Unchanged from previous version. + 3. `remove_models_by_provider()` — group user CSV models by api_key, show numbered list with counts, remove all rows for selected provider. Comment out (never delete) the key in api-env: `# Commented out by pdd setup on YYYY-MM-DD`. + 4. `remove_individual_models()` — list all models from user CSV, let user select by comma-separated numbers, remove selected rows. -5. All CSV writes must be atomic (temp file + rename) -6. Detect shell from SHELL env var for api-env file path -7. Handle empty input as cancel/back + +5. All CSV writes must be atomic (temp file + rename). +6. Detect shell from SHELL env var for api-env file path. +7. Handle empty input as cancel/back. +8. Check if litellm is available before registry flow; suggest custom provider as fallback. % Dependencies - - context/api_key_scanner_example.py - + + context/litellm_registry_example.py + The master CSV at pdd/data/llm_model.csv and user CSV at ~/.pdd/llm_model.csv share columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +CSV mapping from litellm registry: + provider → ProviderInfo.display_name (e.g. "Anthropic") + model → ModelInfo.litellm_id (e.g. "claude-sonnet-4-5-20250929") + input/output → ModelInfo costs (per-million) + coding_arena_elo → "1000" default + api_key → env var name (e.g. "ANTHROPIC_API_KEY") + structured_output → from supports_function_calling + Others default to empty string or "0" % Deliverables -- Module at `pdd/setup/provider_manager.py` exporting `add_api_key`, `add_custom_provider`, `remove_models_by_provider`, `remove_individual_models`. +- Module at `pdd/setup/provider_manager.py` exporting `add_provider_from_registry`, `add_custom_provider`, `remove_models_by_provider`, `remove_individual_models`. diff --git a/pdd/prompts/setup_tool_python.prompt b/pdd/prompts/setup_tool_python.prompt index 6bf7b88d5..993dd9ce0 100644 --- a/pdd/prompts/setup_tool_python.prompt +++ b/pdd/prompts/setup_tool_python.prompt @@ -13,6 +13,7 @@ api_key_scanner_python.prompt provider_manager_python.prompt +litellm_registry_python.prompt local_llm_configurator_python.prompt model_tester_python.prompt cli_detector_python.prompt @@ -21,19 +22,19 @@ % You are an expert Python engineer. Your goal is to write the pdd/setup/setup_tool.py module. % Role & Scope -Main orchestrator for `pdd setup`. Auto-scans the environment for API keys (existence only — no API calls), then presents an interactive menu. After any action, the menu re-displays with an updated scan. +Main orchestrator for `pdd setup`. Auto-scans the environment for API keys based on the user's configured models (existence only — no API calls), then presents an interactive menu. After any action, the menu re-displays with an updated scan. When no models are configured yet, displays a helpful "get started" message. % Requirements 1. Function: `run_setup()` — main entry point -2. On entry, call `api_key_scanner.scan_environment()` and display each key with `✓ Found (source)` or `— Not found`, plus a summary line: `Models configured: N (from M API keys + K local)` +2. On entry, call `api_key_scanner.scan_environment()` and display each key with `✓ Found (source)` or `— Not found`, plus a summary line: `Models configured: N (from M API keys + K local)`. If scan_results is empty, display: "No models configured yet. Use 'Add a provider' to get started." 3. Present a 6-option menu after the scan: - 1. Add a provider (sub-menu: a. Enter an API key, b. Add a local LLM, c. Add a custom provider) + 1. Add a provider (sub-menu: a. Search providers, b. Add a local LLM, c. Add a custom provider) 2. Remove models (sub-menu: a. By provider, b. Individual models) 3. Test a model 4. Detect CLI tools 5. Initialize .pddrc 6. Done -4. Delegate to: `provider_manager.add_api_key`, `local_llm_configurator.configure_local_llm`, `provider_manager.add_custom_provider`, `provider_manager.remove_models_by_provider`, `provider_manager.remove_individual_models`, `model_tester.test_model_interactive`, `cli_detector.detect_cli_tools`, `pddrc_initializer.offer_pddrc_init` +4. Delegate to: `provider_manager.add_provider_from_registry`, `local_llm_configurator.configure_local_llm`, `provider_manager.add_custom_provider`, `provider_manager.remove_models_by_provider`, `provider_manager.remove_individual_models`, `model_tester.test_model_interactive`, `cli_detector.detect_cli_tools`, `pddrc_initializer.offer_pddrc_init` 5. After options 1–5, re-scan and re-display the menu 6. Option 6 exits the loop 7. Handle KeyboardInterrupt for clean exit at any point @@ -47,6 +48,10 @@ Main orchestrator for `pdd setup`. Auto-scans the environment for API keys (exis context/provider_manager_example.py + + context/litellm_registry_example.py + + context/local_llm_configurator_example.py @@ -64,9 +69,9 @@ Main orchestrator for `pdd setup`. Auto-scans the environment for API keys (exis -The CSV at pdd/data/llm_model.csv has columns: +The user-level CSV at ~/.pdd/llm_model.csv has columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location -The user-level CSV is at ~/.pdd/llm_model.csv. +This file is created when the user first adds a provider via setup. % Deliverables diff --git a/pdd/setup/api_key_scanner.py b/pdd/setup/api_key_scanner.py index 8bf160549..a56be8a75 100644 --- a/pdd/setup/api_key_scanner.py +++ b/pdd/setup/api_key_scanner.py @@ -1,8 +1,8 @@ """ pdd/setup/api_key_scanner.py -Discovers API keys from CSV providers, checking existence across -shell, .env, and PDD config with source transparency. +Discovers API keys needed by the user's configured models, checking +existence across shell, .env, and PDD config with source transparency. """ import csv @@ -23,26 +23,28 @@ class KeyInfo: def _get_csv_path() -> Path: - """Return the path to the master llm_model.csv file.""" - # Navigate from this file's location to pdd/data/llm_model.csv - module_dir = Path(__file__).resolve().parent # pdd/setup/ - pdd_dir = module_dir.parent # pdd/ - return pdd_dir / "data" / "llm_model.csv" + """Return the path to the user's configured llm_model.csv. + + Reads from ``~/.pdd/llm_model.csv`` so the scan reflects which + API keys the user's configured models actually need, rather than + an arbitrary hardcoded list. + """ + return Path.home() / ".pdd" / "llm_model.csv" def get_provider_key_names() -> List[str]: """ Returns a deduplicated, sorted list of all non-empty api_key values - from the master CSV (pdd/data/llm_model.csv). + from the user's configured CSV (~/.pdd/llm_model.csv). - Returns an empty list if the CSV is missing or malformed. + Returns an empty list if the CSV is missing, empty, or malformed. """ csv_path = _get_csv_path() key_names: set = set() try: if not csv_path.exists(): - logger.warning("llm_model.csv not found at %s", csv_path) + logger.debug("User CSV not found at %s (no models configured yet).", csv_path) return [] with open(csv_path, "r", newline="", encoding="utf-8") as f: @@ -142,17 +144,17 @@ def _parse_api_env_file(file_path: Path) -> Dict[str, str]: def scan_environment() -> Dict[str, KeyInfo]: """ - Scan for API key existence across all known sources. + Scan for API key existence based on the user's configured models. - Checks sources in priority order: + Reads API key names from ``~/.pdd/llm_model.csv`` and checks their + existence in priority order: 1. .env file (via python-dotenv dotenv_values, read-only) - 2. Shell environment (os.environ - note: may include stale .env values if edited during session) + 2. Shell environment (os.environ) 3. ~/.pdd/api-env.{shell} file Returns a mapping of key name -> KeyInfo(source, is_set). + Returns an empty dict if no models are configured yet. Never raises exceptions; returns best-effort results. - - Note: If you edit .env during a pdd setup session, restart pdd setup to see updated shell environment. """ result: Dict[str, KeyInfo] = {} diff --git a/pdd/setup/litellm_registry.py b/pdd/setup/litellm_registry.py new file mode 100644 index 000000000..fcff99f86 --- /dev/null +++ b/pdd/setup/litellm_registry.py @@ -0,0 +1,312 @@ +""" +pdd/setup/litellm_registry.py + +Wraps litellm's bundled model registry to provide provider search, model +browsing, and API key env var lookup. Uses only local data — no network calls. +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ProviderInfo: + """Summary information about an LLM provider.""" + + name: str # litellm provider ID, e.g. "anthropic" + display_name: str # human-friendly, e.g. "Anthropic" + api_key_env_var: Optional[str] # e.g. "ANTHROPIC_API_KEY" + model_count: int # number of chat models available + sample_models: List[str] = field(default_factory=list) # up to 3 names + + +@dataclass +class ModelInfo: + """Metadata for a single LLM model from litellm's registry.""" + + name: str # short display name, e.g. "claude-opus-4-5" + litellm_id: str # full ID for litellm.completion() + input_cost_per_million: float # USD per 1M input tokens + output_cost_per_million: float # USD per 1M output tokens + max_input_tokens: Optional[int] = None + max_output_tokens: Optional[int] = None + supports_vision: bool = False + supports_function_calling: bool = False + + +# --------------------------------------------------------------------------- +# Static mappings +# --------------------------------------------------------------------------- + +PROVIDER_API_KEY_MAP: Dict[str, str] = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "gemini": "GEMINI_API_KEY", + "vertex_ai": "VERTEX_CREDENTIALS", + "groq": "GROQ_API_KEY", + "mistral": "MISTRAL_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + "fireworks_ai": "FIREWORKS_API_KEY", + "together_ai": "TOGETHERAI_API_KEY", + "perplexity": "PERPLEXITYAI_API_KEY", + "cohere": "COHERE_API_KEY", + "cohere_chat": "COHERE_API_KEY", + "replicate": "REPLICATE_API_KEY", + "xai": "XAI_API_KEY", + "deepinfra": "DEEPINFRA_API_KEY", + "cerebras": "CEREBRAS_API_KEY", + "ai21": "AI21_API_KEY", + "bedrock": "AWS_ACCESS_KEY_ID", + "azure": "AZURE_API_KEY", + "azure_ai": "AZURE_AI_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + "huggingface": "HUGGINGFACE_API_KEY", + "databricks": "DATABRICKS_API_KEY", + "cloudflare": "CLOUDFLARE_API_KEY", + "novita": "NOVITA_API_KEY", + "sambanova": "SAMBANOVA_API_KEY", + "watsonx": "WATSONX_API_KEY", +} + +PROVIDER_DISPLAY_NAMES: Dict[str, str] = { + "openai": "OpenAI", + "anthropic": "Anthropic", + "gemini": "Google Gemini", + "vertex_ai": "Google Vertex AI", + "groq": "Groq", + "mistral": "Mistral AI", + "deepseek": "DeepSeek", + "fireworks_ai": "Fireworks AI", + "together_ai": "Together AI", + "perplexity": "Perplexity", + "cohere": "Cohere", + "cohere_chat": "Cohere Chat", + "replicate": "Replicate", + "xai": "xAI", + "deepinfra": "DeepInfra", + "cerebras": "Cerebras", + "ai21": "AI21", + "bedrock": "AWS Bedrock", + "azure": "Azure OpenAI", + "azure_ai": "Azure AI", + "openrouter": "OpenRouter", + "huggingface": "Hugging Face", + "databricks": "Databricks", + "cloudflare": "Cloudflare Workers AI", + "novita": "Novita AI", + "sambanova": "SambaNova", + "watsonx": "IBM watsonx", +} + +# Curated list of major cloud providers shown by default. +_TOP_PROVIDER_IDS = [ + "openai", + "anthropic", + "gemini", + "fireworks_ai", + "mistral", + "xai", + "groq", + "deepseek", + "together_ai", + "openrouter", +] + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _get_display_name(provider: str) -> str: + """Return the human-friendly name for *provider*, with a title-case fallback.""" + return PROVIDER_DISPLAY_NAMES.get(provider, provider.replace("_", " ").title()) + + +def _collect_chat_models_for_provider(provider: str) -> Dict[str, dict]: + """Return ``{model_id: cost_entry}`` for all chat models belonging to *provider*. + + Handles the ``vertex_ai`` sub-provider convention by matching any + ``litellm_provider`` that starts with *provider*. + + Falls back to scanning ``litellm.model_cost`` directly when + ``models_by_provider`` entries are missing from ``model_cost``. + """ + import litellm # local import — guarded by is_litellm_available() + + result: Dict[str, dict] = {} + + # Strategy 1: use models_by_provider set, then look up cost data. + model_names: Set[str] = set() + if provider in litellm.models_by_provider: + model_names = set(litellm.models_by_provider[provider]) + + for name in model_names: + entry = litellm.model_cost.get(name) + if entry and entry.get("mode") == "chat": + result[name] = entry + + # Strategy 2 (fallback): scan model_cost for provider match. + # Needed for together_ai and vertex_ai sub-providers. + for model_id, entry in litellm.model_cost.items(): + if model_id in result: + continue + lp = entry.get("litellm_provider", "") + if lp == provider or lp.startswith(f"{provider}-"): + if entry.get("mode") == "chat": + result[model_id] = entry + + return result + + +def _entry_to_model_info(model_id: str, entry: dict) -> ModelInfo: + """Convert a ``litellm.model_cost`` entry into a :class:`ModelInfo`.""" + input_per_token = entry.get("input_cost_per_token") or 0 + output_per_token = entry.get("output_cost_per_token") or 0 + return ModelInfo( + name=model_id.split("/")[-1] if "/" in model_id else model_id, + litellm_id=model_id, + input_cost_per_million=round(input_per_token * 1_000_000, 4), + output_cost_per_million=round(output_per_token * 1_000_000, 4), + max_input_tokens=entry.get("max_input_tokens"), + max_output_tokens=entry.get("max_output_tokens"), + supports_vision=bool(entry.get("supports_vision")), + supports_function_calling=bool(entry.get("supports_function_calling")), + ) + + +def _build_provider_info(provider: str) -> Optional[ProviderInfo]: + """Build a :class:`ProviderInfo` for *provider*, or ``None`` if it has no chat models.""" + chat_models = _collect_chat_models_for_provider(provider) + if not chat_models: + return None + sample = sorted(chat_models.keys())[:3] + return ProviderInfo( + name=provider, + display_name=_get_display_name(provider), + api_key_env_var=PROVIDER_API_KEY_MAP.get(provider), + model_count=len(chat_models), + sample_models=sample, + ) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def is_litellm_available() -> bool: + """Return ``True`` if litellm is importable and has model data.""" + try: + import litellm + + return bool(litellm.model_cost) + except Exception: + return False + + +def get_api_key_env_var(provider: str) -> Optional[str]: + """Return the API key environment variable name for *provider*. + + Returns ``None`` if the provider is not in the known mapping. + """ + return PROVIDER_API_KEY_MAP.get(provider) + + +def get_top_providers() -> List[ProviderInfo]: + """Return a curated list of major cloud providers, sorted by the curated order. + + Falls back to all providers sorted by model count if the curated list + yields fewer than 3 results (e.g. litellm data changed). + """ + if not is_litellm_available(): + return [] + + result: List[ProviderInfo] = [] + for pid in _TOP_PROVIDER_IDS: + info = _build_provider_info(pid) + if info: + result.append(info) + + if len(result) < 3: + return get_all_providers()[:10] + return result + + +def get_all_providers() -> List[ProviderInfo]: + """Return all providers that have at least one chat model. + + Sorted by model count descending. + """ + if not is_litellm_available(): + return [] + + import litellm + + seen: Set[str] = set() + infos: List[ProviderInfo] = [] + + # Collect from models_by_provider keys + for provider in litellm.models_by_provider: + if provider in seen: + continue + seen.add(provider) + info = _build_provider_info(provider) + if info: + infos.append(info) + + # Also scan model_cost for providers not in models_by_provider + for entry in litellm.model_cost.values(): + lp = entry.get("litellm_provider", "") + # Normalise vertex_ai sub-providers + base = lp.split("-")[0] if "-" in lp else lp + if base and base not in seen: + seen.add(base) + info = _build_provider_info(base) + if info: + infos.append(info) + + infos.sort(key=lambda i: i.model_count, reverse=True) + return infos + + +def search_providers(query: str) -> List[ProviderInfo]: + """Return providers whose name or display name contains *query* (case-insensitive). + + Sorted by model count descending. + """ + if not query: + return get_all_providers() + + all_providers = get_all_providers() + q = query.lower() + return [ + p + for p in all_providers + if q in p.name.lower() or q in p.display_name.lower() + ] + + +def get_models_for_provider(provider: str) -> List[ModelInfo]: + """Return chat-mode models for *provider*, sorted by name. + + Converts per-token costs to per-million-token costs. + """ + if not is_litellm_available(): + return [] + + chat_models = _collect_chat_models_for_provider(provider) + result = [ + _entry_to_model_info(model_id, entry) + for model_id, entry in chat_models.items() + ] + result.sort(key=lambda m: m.name) + return result diff --git a/pdd/setup/provider_manager.py b/pdd/setup/provider_manager.py index 65d0ccfee..9b01c879b 100644 --- a/pdd/setup/provider_manager.py +++ b/pdd/setup/provider_manager.py @@ -8,13 +8,21 @@ import shutil from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from rich.console import Console from rich.table import Table from rich.prompt import Prompt, Confirm -from pdd.setup.api_key_scanner import KeyInfo +from pdd.setup.litellm_registry import ( + is_litellm_available, + get_top_providers, + search_providers, + get_models_for_provider, + get_api_key_env_var, + ProviderInfo, + ModelInfo, +) console = Console() @@ -57,9 +65,6 @@ def _get_user_csv_path() -> Path: return _get_pdd_dir() / "llm_model.csv" -def _get_master_csv_path() -> Path: - """Return path to pdd/data/llm_model.csv (shipped with the package).""" - return Path(__file__).resolve().parent.parent / "data" / "llm_model.csv" # --------------------------------------------------------------------------- @@ -184,107 +189,256 @@ def _comment_out_key_in_api_env(key_name: str) -> None: _write_api_env_atomic(env_path, new_lines) -def _get_master_rows_for_key(api_key_name: str) -> List[Dict[str, str]]: - """Return all rows from master CSV whose api_key column matches.""" - master_rows = _read_csv(_get_master_csv_path()) - return [r for r in master_rows if r.get("api_key", "").strip() == api_key_name] +# --------------------------------------------------------------------------- +# Key-existence check (used by add_provider_from_registry) +# --------------------------------------------------------------------------- + +def _is_key_set(key_name: str) -> Optional[str]: + """Return the source label if *key_name* is set, else ``None``. + + Checks .env (via python-dotenv), shell environment, and api-env file. + """ + try: + from dotenv import dotenv_values # type: ignore + dotenv_vals = dotenv_values() + if key_name in dotenv_vals and dotenv_vals[key_name] is not None: + return ".env file" + except Exception: + pass + + if os.environ.get(key_name): + return "shell environment" + + env_path = _get_api_env_path() + if env_path.exists(): + from pdd.setup.api_key_scanner import _parse_api_env_file + api_env_vals = _parse_api_env_file(env_path) + if key_name in api_env_vals: + return f"~/.pdd/{env_path.name}" + + return None # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def add_api_key(scan_results: Dict[str, KeyInfo]) -> bool: +def add_provider_from_registry() -> bool: """ - Show missing keys from scan_results, prompt user for one, save to - api-env file, then copy ALL matching rows from master CSV into user CSV. + Search/browse LiteLLM's model registry, let the user pick a provider + and specific models, handle the API key, and save to user CSV. - Returns True if a key was added/updated, False if cancelled. + Returns True if any models were added, False if cancelled. """ - # Separate missing and found keys - missing_keys = {k: v for k, v in scan_results.items() if not v.is_set} - found_keys = {k: v for k, v in scan_results.items() if v.is_set} - - if not missing_keys: - console.print("[green]All known API keys are already configured.[/green]") + if not is_litellm_available(): + console.print( + "[red]litellm is required but not installed or has no model data.[/red]\n" + "[yellow]Run: pip install litellm[/yellow]\n" + "[yellow]Or use 'Add a custom provider' instead.[/yellow]" + ) return False - # Display missing keys - console.print("\n[bold]Missing API keys:[/bold]") - sorted_missing = sorted(missing_keys.keys()) - for idx, key_name in enumerate(sorted_missing, 1): - console.print(f" {idx}. {key_name}") + # ── Step 1: Browse / Search providers ────────────────────────────── + + top = get_top_providers() + + console.print("\n[bold]Search providers[/bold]\n") + console.print(" Top providers:") + for idx, p in enumerate(top, 1): + console.print( + f" {idx:>2}. {p.display_name:20s} ({p.model_count} chat models)" + ) + console.print() - # Prompt user to select one selection = Prompt.ask( - "\nEnter the number of the key to add (or press Enter to cancel)" + "Enter number, or type to search (empty to cancel)" ) if not selection.strip(): console.print("[dim]Cancelled.[/dim]") return False + selected_provider: Optional[ProviderInfo] = None + + # Try as a number first (direct selection from top list) try: choice = int(selection.strip()) - if choice < 1 or choice > len(sorted_missing): - console.print("[red]Invalid selection.[/red]") - return False + if 1 <= choice <= len(top): + selected_provider = top[choice - 1] except ValueError: - console.print("[red]Invalid input.[/red]") - return False + pass - selected_key = sorted_missing[choice - 1] + # If not a valid number, treat as search query + if selected_provider is None: + results = search_providers(selection.strip()) + if not results: + console.print( + f"[yellow]No providers matching '{selection.strip()}'.[/yellow]\n" + "[yellow]Try a different search, or use 'Add a custom provider'.[/yellow]" + ) + return False - # Check if key is already in environment (shouldn't be since it's in missing, but guard) - if os.environ.get(selected_key): + if len(results) == 1: + selected_provider = results[0] + else: + console.print(f"\n Found {len(results)} provider(s):") + for idx, p in enumerate(results, 1): + console.print( + f" {idx:>2}. {p.display_name:20s} ({p.model_count} chat models)" + ) + console.print() + + pick = Prompt.ask("Select provider number (empty to cancel)") + if not pick.strip(): + console.print("[dim]Cancelled.[/dim]") + return False + try: + pick_idx = int(pick.strip()) + if 1 <= pick_idx <= len(results): + selected_provider = results[pick_idx - 1] + else: + console.print("[red]Invalid selection.[/red]") + return False + except ValueError: + console.print("[red]Invalid input.[/red]") + return False + + assert selected_provider is not None + + # ── Step 2: Model selection ──────────────────────────────────────── + + models = get_models_for_provider(selected_provider.name) + if not models: console.print( - f"[yellow]{selected_key} is already set in the current environment. Skipping.[/yellow]" + f"[yellow]No chat models found for {selected_provider.display_name} " + f"in litellm's registry.[/yellow]\n" + "[yellow]Use 'Add a custom provider' instead.[/yellow]" ) return False - # Check if master CSV has rows for this key - master_rows = _get_master_rows_for_key(selected_key) - if not master_rows: - console.print( - f"[yellow]No models found in master CSV for '{selected_key}'.[/yellow]\n" - f"[yellow]Please use 'Add a custom provider' instead.[/yellow]" - ) - return False + table = Table(title=f"Chat models for {selected_provider.display_name}") + table.add_column("#", style="bold", width=4) + table.add_column("Model") + table.add_column("Input $/M", justify="right") + table.add_column("Output $/M", justify="right") + table.add_column("Max Input", justify="right") + + for idx, m in enumerate(models, 1): + input_cost = f"${m.input_cost_per_million:.2f}" if m.input_cost_per_million else "$0.00" + output_cost = f"${m.output_cost_per_million:.2f}" if m.output_cost_per_million else "$0.00" + max_input = f"{m.max_input_tokens:,}" if m.max_input_tokens else "—" + table.add_row(str(idx), m.litellm_id, input_cost, output_cost, max_input) - # Prompt for the actual key value - key_value = Prompt.ask(f"Enter the value for {selected_key}") - if not key_value.strip(): + console.print(table) + console.print() + + model_selection = Prompt.ask( + "Select models (comma-separated numbers, 'all', or empty to cancel)" + ) + if not model_selection.strip(): console.print("[dim]Cancelled.[/dim]") return False - key_value = key_value.strip() - # Save key to api-env file - _save_key_to_api_env(selected_key, key_value) - console.print(f"[green]Saved {selected_key} to {_get_api_env_path()}[/green]") + selected_models: List[ModelInfo] = [] + + if model_selection.strip().lower() == "all": + selected_models = list(models) + else: + for part in model_selection.split(","): + part = part.strip() + if not part: + continue + try: + num = int(part) + if 1 <= num <= len(models): + selected_models.append(models[num - 1]) + else: + console.print(f"[yellow]Skipping invalid number: {num}[/yellow]") + except ValueError: + console.print(f"[yellow]Skipping invalid input: '{part}'[/yellow]") + + if not selected_models: + console.print("[dim]No valid selections. Cancelled.[/dim]") + return False + + # ── Step 3: API key ──────────────────────────────────────────────── + + api_key_var = selected_provider.api_key_env_var + + if api_key_var is None: + # Provider not in our known mapping — ask the user + api_key_var = Prompt.ask( + f"API key env var for {selected_provider.display_name} " + "(e.g. PROVIDER_API_KEY, or empty to skip)" + ).strip() or None + + if api_key_var: + existing_source = _is_key_set(api_key_var) + if existing_source: + console.print( + f" [green]{api_key_var} is already set ({existing_source}).[/green]" + ) + if Confirm.ask("Update the key?", default=False): + key_value = Prompt.ask(f"Enter new value for {api_key_var}") + if key_value.strip(): + _save_key_to_api_env(api_key_var, key_value.strip()) + console.print( + f"[green]Updated {api_key_var} in {_get_api_env_path()}[/green]" + ) + else: + key_value = Prompt.ask(f"Enter the value for {api_key_var}") + if key_value.strip(): + _save_key_to_api_env(api_key_var, key_value.strip()) + console.print( + f"[green]Saved {api_key_var} to {_get_api_env_path()}[/green]" + ) + + # ── Step 4: Write to user CSV ────────────────────────────────────── - # Copy matching rows from master CSV into user CSV user_csv_path = _get_user_csv_path() - existing_user_rows = _read_csv(user_csv_path) + existing_rows = _read_csv(user_csv_path) # Build set of existing model identifiers to avoid duplicates - existing_models = { + existing_model_ids = { (r.get("provider", ""), r.get("model", "")) - for r in existing_user_rows + for r in existing_rows } added_count = 0 - for row in master_rows: - model_id = (row.get("provider", ""), row.get("model", "")) - if model_id not in existing_models: - existing_user_rows.append(row) - existing_models.add(model_id) + for m in selected_models: + # Build the litellm model ID with provider prefix convention + csv_model = m.litellm_id + + new_row: Dict[str, str] = { + "provider": selected_provider.display_name, + "model": csv_model, + "input": str(m.input_cost_per_million), + "output": str(m.output_cost_per_million), + "coding_arena_elo": "1000", + "base_url": "", + "api_key": api_key_var or "", + "max_reasoning_tokens": "0", + "structured_output": str(m.supports_function_calling), + "reasoning_type": "", + "location": "", + } + + model_id = (new_row["provider"], new_row["model"]) + if model_id not in existing_model_ids: + existing_rows.append(new_row) + existing_model_ids.add(model_id) added_count += 1 + else: + console.print(f" [dim]Skipping duplicate: {csv_model}[/dim]") - _write_csv_atomic(user_csv_path, existing_user_rows) - console.print( - f"[green]Added {added_count} model(s) for {selected_key} to {user_csv_path}[/green]" - ) + if added_count > 0: + _write_csv_atomic(user_csv_path, existing_rows) + console.print( + f"[green]Added {added_count} model(s) to {user_csv_path}[/green]" + ) + else: + console.print("[yellow]No new models were added (all already configured).[/yellow]") - return True + return added_count > 0 def add_custom_provider() -> bool: diff --git a/pdd/setup/setup_tool.py b/pdd/setup/setup_tool.py index 23cc43eca..dd4f3e5ed 100644 --- a/pdd/setup/setup_tool.py +++ b/pdd/setup/setup_tool.py @@ -11,7 +11,7 @@ from .api_key_scanner import scan_environment, KeyInfo from .provider_manager import ( - add_api_key, + add_provider_from_registry, add_custom_provider, remove_models_by_provider, remove_individual_models, @@ -31,6 +31,12 @@ def _display_scan(scan_results: Dict[str, KeyInfo]) -> None: print("\n API-key scan") print(" " + "─" * 50) + if not scan_results: + print(" No models configured yet.") + print(" Use 'Add a provider' to get started.") + print() + return + api_found = 0 local_count = 0 @@ -67,11 +73,11 @@ def _display_menu() -> None: print() -def _add_provider_submenu(scan_results: Dict[str, KeyInfo]) -> None: +def _add_provider_submenu() -> None: """Sub-menu for option 1 — Add a provider.""" print() print(" Add a provider:") - print(" a. Enter an API key") + print(" a. Search providers") print(" b. Add a local LLM") print(" c. Add a custom provider") print() @@ -79,7 +85,7 @@ def _add_provider_submenu(scan_results: Dict[str, KeyInfo]) -> None: sub_choice = input(" Choice [a/b/c]: ").strip().lower() if sub_choice == "a": - add_api_key(scan_results) + add_provider_from_registry() elif sub_choice == "b": configure_local_llm() elif sub_choice == "c": @@ -131,7 +137,7 @@ def run_setup() -> None: choice = input(" Choice [1-6]: ").strip() if choice == "1": - _add_provider_submenu(scan_results) + _add_provider_submenu() elif choice == "2": _remove_models_submenu() elif choice == "3": From df7c405d757be4fed356572bdd4dcda32dbc4585 Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Mon, 16 Feb 2026 10:34:51 -0500 Subject: [PATCH 05/10] Bug fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fixed misleading "local — no key required" message: Changed to "(no key configured)" in yellow, since an empty api_key field doesn't necessarily mean local—it could mean the user skipped entering the env var name, and litellm will fall back to its own defaults. - Expanded provider → API key mapping: Added 30 new providers to PROVIDER_API_KEY_MAP (including moonshot → MOONSHOT_API_KEY), so the setup flow now knows which env var to prompt for when users add models from lesser-known providers. - Fixed a shell escaping bug in provider_manager.py discovered by the new tests: changed _save_key_to_api_env to use shlex.quote() instead of simple double-quoting, which was breaking API keys containing special characters like $, ", ', and backticks - Immediate session availability: API keys are now loaded into os.environ when saved, allowing users to test models immediately within the same pdd setup session without restarting their shell - Automatic shell RC integration: The source line (source ~/.pdd/api-env.{shell}) is automatically added to the user's shell startup file (~/.zshrc, ~/.bashrc, etc.) with shell-appropriate syntax, so new terminal sessions have the API keys available without manual configuration - Created 4 comprehensive test files (165 tests total) for the new setup modules: test_api_key_scanner.py, test_litellm_registry.py, test_provider_manager.py, and test_cli_detector.py — including rigorous shell execution tests that actually run bash/zsh to verify API key escaping works correctly --- context/provider_manager_example.py | 7 + context/setup_tool_example.py | 2 + pdd/prompts/model_tester_python.prompt | 4 +- pdd/prompts/provider_manager_python.prompt | 10 +- pdd/prompts/setup_tool_python.prompt | 2 +- pdd/setup/model_tester.py | 13 +- pdd/setup/provider_manager.py | 146 ++- pdd/setup/setup_tool.py | 6 + tests/test_api_key_scanner.py | 551 +++++++++++ tests/test_cli_detector.py | 489 ++++++++++ tests/test_litellm_registry.py | 561 +++++++++++ tests/test_provider_manager.py | 1014 ++++++++++++++++++++ 12 files changed, 2782 insertions(+), 23 deletions(-) create mode 100644 tests/test_api_key_scanner.py create mode 100644 tests/test_cli_detector.py create mode 100644 tests/test_litellm_registry.py create mode 100644 tests/test_provider_manager.py diff --git a/context/provider_manager_example.py b/context/provider_manager_example.py index 298a27ef3..8e9769356 100644 --- a/context/provider_manager_example.py +++ b/context/provider_manager_example.py @@ -42,7 +42,14 @@ def main() -> None: # Select models: 1,2 # # ANTHROPIC_API_KEY: sk-ant-... + # ✓ Saved ANTHROPIC_API_KEY to ~/.pdd/api-env.zsh + # ✓ Added source line to ~/.zshrc + # Key is available now for this session. # ✓ Added 2 model(s) to ~/.pdd/llm_model.csv + # + # NOTE: The API key is immediately available in the current session via os.environ, + # so you can test the model right away. New terminal sessions will also have the + # key automatically because `source ~/.pdd/api-env.zsh` was added to ~/.zshrc. # Example 2: Add a custom provider (Together AI, Deepinfra, etc.) # Prompts for prefix, model name, API key var, base URL, costs diff --git a/context/setup_tool_example.py b/context/setup_tool_example.py index 0b03658e9..a517e5980 100644 --- a/context/setup_tool_example.py +++ b/context/setup_tool_example.py @@ -36,6 +36,8 @@ def main() -> None: # ANTHROPIC_API_KEY ✓ Found (shell environment) # OPENAI_API_KEY — Not found # + # 💡 To edit API keys: update ~/.pdd/api-env.zsh or .env file + # # Models configured: 1 (from 1 API keys + 0 local) # # What would you like to do? diff --git a/pdd/prompts/model_tester_python.prompt b/pdd/prompts/model_tester_python.prompt index 638b0f7d0..c946eaf49 100644 --- a/pdd/prompts/model_tester_python.prompt +++ b/pdd/prompts/model_tester_python.prompt @@ -19,12 +19,12 @@ Tests a single configured model by making one `litellm.completion()` call with a % Requirements 1. Function: `test_model_interactive()` — show models from `~/.pdd/llm_model.csv`, let user pick one, test it, loop until user exits (empty input or "q") 2. Test call: `litellm.completion(model=..., messages=[{"role": "user", "content": "Say OK"}], api_key=..., api_base=..., timeout=30)` -3. Before calling, show diagnostics: API key status (`✓ Found (source)` / `✗ Not found` / `(local — no key required)`) and base URL if applicable +3. Before calling, show diagnostics: API key status (`✓ Found (source)` / `✗ Not found` / `(no key configured)`) and base URL if applicable 4. After calling, show: `LLM call ✓ OK (0.3s, $0.0001)` or `LLM call ✗ error description` 5. Calculate cost from token usage × CSV row's input/output prices per 1M tokens 6. Persist test results in the model list display across picks within a session 7. Distinguish errors: authentication, connection refused (local), model not found, timeout -8. For local models (empty api_key): pass api_base, omit api_key +8. For models with empty api_key field (no env var configured): pass api_base if present, omit api_key — litellm will use its own defaults 9. If no user CSV exists or is empty, inform user and return % Dependencies diff --git a/pdd/prompts/provider_manager_python.prompt b/pdd/prompts/provider_manager_python.prompt index 8f57e8ec9..fd028d414 100644 --- a/pdd/prompts/provider_manager_python.prompt +++ b/pdd/prompts/provider_manager_python.prompt @@ -26,20 +26,24 @@ Handles adding and removing LLM providers and models in PDD setup. The primary f 1. `add_provider_from_registry()` — interactive search/browse flow using litellm_registry: a. Show top ~10 providers (from `get_top_providers()`). Accept a number for direct selection, or text to search via `search_providers()`. b. Display selected provider's chat models in a Rich table with columns: #, Model, Input $/M, Output $/M, Max Input. Accept comma-separated numbers or "all". - c. Look up API key env var via `get_api_key_env_var()`. If already set, offer to use existing. If not set, prompt for value. If provider not in map, ask user for env var name. Save to `~/.pdd/api-env.{shell}`. + c. Look up API key env var via `get_api_key_env_var()`. If already set, offer to use existing. If not set, prompt for value. If provider not in map, ask user for env var name. Save to `~/.pdd/api-env.{shell}`. Also set `os.environ[key_name] = key_value` so the key is immediately available in the current session. Ensure `source ~/.pdd/api-env.{shell}` is added to the user's shell RC file (~/.zshrc, ~/.bashrc, etc.) so new terminal sessions automatically have the keys. d. Append selected models to `~/.pdd/llm_model.csv`, skip duplicates (by provider+model pair). Map litellm data to CSV schema. e. Return True if any models were added. -2. `add_custom_provider()` — prompt for provider prefix, model name, API key env var, base URL (optional), costs (optional). Append row to user CSV with sensible defaults. Save API key to api-env if provided. Unchanged from previous version. +2. `add_custom_provider()` — prompt for provider prefix, model name, API key env var, base URL (optional), costs (optional). Append row to user CSV with sensible defaults. Save API key to api-env if provided, set in os.environ for immediate availability, and ensure the source line is in the shell RC file. 3. `remove_models_by_provider()` — group user CSV models by api_key, show numbered list with counts, remove all rows for selected provider. Comment out (never delete) the key in api-env: `# Commented out by pdd setup on YYYY-MM-DD`. 4. `remove_individual_models()` — list all models from user CSV, let user select by comma-separated numbers, remove selected rows. 5. All CSV writes must be atomic (temp file + rename). -6. Detect shell from SHELL env var for api-env file path. +6. Detect shell from SHELL env var for api-env file path and use shell-appropriate syntax: + - bash/zsh/ksh/sh: `export KEY=value` + - fish: `set -gx KEY value` + - csh/tcsh: `setenv KEY value` 7. Handle empty input as cancel/back. 8. Check if litellm is available before registry flow; suggest custom provider as fallback. +9. When adding source line to RC file, use shell-appropriate syntax with existence check (e.g., `[ -f "..." ] && source "..."` for bash/zsh, `test -f ...; and source ...` for fish). % Dependencies diff --git a/pdd/prompts/setup_tool_python.prompt b/pdd/prompts/setup_tool_python.prompt index 993dd9ce0..fad7774f3 100644 --- a/pdd/prompts/setup_tool_python.prompt +++ b/pdd/prompts/setup_tool_python.prompt @@ -26,7 +26,7 @@ Main orchestrator for `pdd setup`. Auto-scans the environment for API keys based % Requirements 1. Function: `run_setup()` — main entry point -2. On entry, call `api_key_scanner.scan_environment()` and display each key with `✓ Found (source)` or `— Not found`, plus a summary line: `Models configured: N (from M API keys + K local)`. If scan_results is empty, display: "No models configured yet. Use 'Add a provider' to get started." +2. On entry, call `api_key_scanner.scan_environment()` and display each key with `✓ Found (source)` or `— Not found`. After the key list, display a helpful note: `💡 To edit API keys: update ~/.pdd/api-env.{shell} or .env file` (where {shell} is auto-detected from the SHELL environment variable, defaulting to "bash"). Then display the summary line: `Models configured: N (from M API keys + K local)`. If scan_results is empty, display: "No models configured yet. Use 'Add a provider' to get started." 3. Present a 6-option menu after the scan: 1. Add a provider (sub-menu: a. Search providers, b. Add a local LLM, c. Add a custom provider) 2. Remove models (sub-menu: a. By provider, b. Individual models) diff --git a/pdd/setup/model_tester.py b/pdd/setup/model_tester.py index 2f7d69cb6..6b7a1ace1 100644 --- a/pdd/setup/model_tester.py +++ b/pdd/setup/model_tester.py @@ -61,9 +61,9 @@ def _resolve_api_key(row: Dict[str, Any]) -> Tuple[Optional[str], str]: """ key_name: str = str(row.get("api_key", "")).strip() - # Local model — no key required + # No env var configured — litellm will use its own defaults if not key_name: - return None, "(local — no key required)" + return None, "(no key configured)" # Check environment key_value = os.getenv(key_name, "") @@ -159,12 +159,9 @@ def _run_test(row: Dict[str, Any]) -> Dict[str, Any]: "timeout": 30, } - # Only pass api_key if we have one (local models don't need it) + # Only pass api_key if we have one; otherwise litellm uses its defaults if api_key: kwargs["api_key"] = api_key - elif not str(row.get("api_key", "")).strip(): - # Local model — use placeholder key if provider expects one - pass if base_url: kwargs["base_url"] = base_url @@ -348,8 +345,8 @@ def test_model_interactive() -> None: api_key, key_status = _resolve_api_key(row) if "✓" in key_status: console.print(f" API Key: [green]{key_status}[/green]") - elif "local" in key_status: - console.print(f" API Key: [dim]{key_status}[/dim]") + elif "no key configured" in key_status: + console.print(f" API Key: [yellow]{key_status}[/yellow]") else: console.print(f" API Key: [red]{key_status}[/red]") diff --git a/pdd/setup/provider_manager.py b/pdd/setup/provider_manager.py index 9b01c879b..92844b109 100644 --- a/pdd/setup/provider_manager.py +++ b/pdd/setup/provider_manager.py @@ -4,8 +4,9 @@ import io import os import re -import tempfile +import shlex import shutil +import tempfile from datetime import datetime from pathlib import Path from typing import Dict, List, Optional @@ -65,6 +66,75 @@ def _get_user_csv_path() -> Path: return _get_pdd_dir() / "llm_model.csv" +def _get_shell_rc_path() -> Optional[Path]: + """Return the shell RC file path (~/.zshrc, ~/.bashrc, etc.).""" + shell = _get_shell_name() + home = Path.home() + shell_files = { + "zsh": home / ".zshrc", + "bash": home / ".bashrc", + "fish": home / ".config" / "fish" / "config.fish", + "csh": home / ".cshrc", + "tcsh": home / ".tcshrc", + "ksh": home / ".kshrc", + "sh": home / ".profile", + } + return shell_files.get(shell) + + +def _get_source_line_for_shell(api_env_path: Path) -> str: + """Return the appropriate source line syntax for the current shell.""" + shell = _get_shell_name() + path_str = str(api_env_path) + + if shell == "fish": + return f'test -f "{path_str}"; and source "{path_str}"' + elif shell in ("csh", "tcsh"): + return f'if ( -f "{path_str}" ) source "{path_str}"' + elif shell == "sh": + # sh uses . instead of source + return f'[ -f "{path_str}" ] && . "{path_str}"' + else: + # bash, zsh, ksh and others + return f'[ -f "{path_str}" ] && source "{path_str}"' + + +def _ensure_api_env_sourced_in_rc() -> bool: + """ + Ensure the api-env file is sourced in the user's shell RC file. + + Adds a shell-appropriate source line to ~/.zshrc (or equivalent) if not + already present. This ensures new terminal sessions automatically have + the API keys available. + + Returns True if the line was added, False if already present or unsupported. + """ + rc_path = _get_shell_rc_path() + if rc_path is None: + return False + + api_env_path = _get_api_env_path() + + # Ensure parent directory exists (important for fish: ~/.config/fish/) + rc_path.parent.mkdir(parents=True, exist_ok=True) + + # Check if api-env path is already referenced in the RC file + if rc_path.exists(): + content = rc_path.read_text(encoding="utf-8") + # Check if the api-env file path is already mentioned (covers any syntax) + if str(api_env_path) in content: + return False + else: + content = "" + + # Build shell-appropriate source line + source_line = _get_source_line_for_shell(api_env_path) + + # Append the source line + with open(rc_path, "a", encoding="utf-8") as f: + f.write(f"\n# PDD API keys\n{source_line}\n") + + return True # --------------------------------------------------------------------------- @@ -130,20 +200,51 @@ def _write_api_env_atomic(path: Path, lines: List[str]) -> None: raise +def _build_env_export_line(key_name: str, key_value: str) -> str: + """Build a shell-appropriate export line for the given key/value.""" + shell = _get_shell_name() + # Use shlex.quote() for proper shell escaping of special characters + # This handles $, ", ', `, \, spaces, and other problematic chars + quoted_value = shlex.quote(key_value) + + if shell == "fish": + return f"set -gx {key_name} {quoted_value}\n" + elif shell in ("csh", "tcsh"): + return f"setenv {key_name} {quoted_value}\n" + else: + # bash, zsh, ksh, sh and others + return f"export {key_name}={quoted_value}\n" + + +def _build_env_key_pattern(key_name: str) -> re.Pattern: + """Build a regex pattern to match any shell syntax for the given key.""" + # Match: export KEY=, setenv KEY , set -gx KEY (with optional comment prefix) + escaped_key = re.escape(key_name) + return re.compile( + rf"^(?:#\s*)?(?:export\s+{escaped_key}\s*=|setenv\s+{escaped_key}\s|set\s+-gx\s+{escaped_key}\s)", + re.MULTILINE, + ) + + def _save_key_to_api_env(key_name: str, key_value: str) -> None: """ Add or update an export line in the api-env file. If the key already exists (even commented out), replace it. + + Uses shell-appropriate syntax (export for bash/zsh, set -gx for fish, + setenv for csh/tcsh). + + Also sets the key in os.environ so it's immediately available + in the current session without requiring the user to source their shell. """ + # Set in current process environment for immediate availability + os.environ[key_name] = key_value + env_path = _get_api_env_path() lines = _read_api_env_lines(env_path) - export_line = f'export {key_name}="{key_value}"\n' - - # Pattern to match existing line (commented or not) - pattern = re.compile( - r"^(?:#\s*)?export\s+" + re.escape(key_name) + r"\s*=", re.MULTILINE - ) + export_line = _build_env_export_line(key_name, key_value) + pattern = _build_env_key_pattern(key_name) found = False new_lines: List[str] = [] @@ -166,13 +267,16 @@ def _save_key_to_api_env(key_name: str, key_value: str) -> None: def _comment_out_key_in_api_env(key_name: str) -> None: """ Comment out (never delete) a key in the api-env file. - Adds a comment with the date. + Adds a comment with the date. Handles all shell syntaxes. """ env_path = _get_api_env_path() lines = _read_api_env_lines(env_path) + # Match uncommented lines only (export, setenv, set -gx) + escaped_key = re.escape(key_name) pattern = re.compile( - r"^export\s+" + re.escape(key_name) + r"\s*=", re.MULTILINE + rf"^(?:export\s+{escaped_key}\s*=|setenv\s+{escaped_key}\s|set\s+-gx\s+{escaped_key}\s)", + re.MULTILINE, ) today = datetime.now().strftime("%Y-%m-%d") @@ -384,6 +488,14 @@ def add_provider_from_registry() -> bool: console.print( f"[green]Updated {api_key_var} in {_get_api_env_path()}[/green]" ) + rc_updated = _ensure_api_env_sourced_in_rc() + if rc_updated: + console.print( + f"[green]Added source line to {_get_shell_rc_path()}[/green]" + ) + console.print( + "[dim]Key is available now for this session.[/dim]" + ) else: key_value = Prompt.ask(f"Enter the value for {api_key_var}") if key_value.strip(): @@ -391,6 +503,14 @@ def add_provider_from_registry() -> bool: console.print( f"[green]Saved {api_key_var} to {_get_api_env_path()}[/green]" ) + rc_updated = _ensure_api_env_sourced_in_rc() + if rc_updated: + console.print( + f"[green]Added source line to {_get_shell_rc_path()}[/green]" + ) + console.print( + "[dim]Key is available now for this session.[/dim]" + ) # ── Step 4: Write to user CSV ────────────────────────────────────── @@ -502,6 +622,14 @@ def add_custom_provider() -> bool: console.print( f"[green]Saved {api_key_var} to {_get_api_env_path()}[/green]" ) + rc_updated = _ensure_api_env_sourced_in_rc() + if rc_updated: + console.print( + f"[green]Added source line to {_get_shell_rc_path()}[/green]" + ) + console.print( + "[dim]Key is available now for this session.[/dim]" + ) # Build the row with sensible defaults new_row: Dict[str, str] = { diff --git a/pdd/setup/setup_tool.py b/pdd/setup/setup_tool.py index dd4f3e5ed..4f420d2c0 100644 --- a/pdd/setup/setup_tool.py +++ b/pdd/setup/setup_tool.py @@ -7,6 +7,7 @@ """ from __future__ import annotations +import os from typing import Dict from .api_key_scanner import scan_environment, KeyInfo @@ -53,6 +54,11 @@ def _display_scan(scan_results: Dict[str, KeyInfo]) -> None: else: print(f" {key_name:30s} — Not found") + # Add helpful note about editing API keys + shell_path = os.environ.get("SHELL", "") + shell_name = os.path.basename(shell_path) if shell_path else "bash" + print(f"\n 💡 To edit API keys: update ~/.pdd/api-env.{shell_name} or .env file") + total_configured = api_found + local_count print( f"\n Models configured: {total_configured} " diff --git a/tests/test_api_key_scanner.py b/tests/test_api_key_scanner.py new file mode 100644 index 000000000..64e989728 --- /dev/null +++ b/tests/test_api_key_scanner.py @@ -0,0 +1,551 @@ +"""Tests for pdd/setup/api_key_scanner.py""" + +import csv +import os +import tempfile +from pathlib import Path +from unittest import mock + +import pytest + +from pdd.setup.api_key_scanner import ( + KeyInfo, + get_provider_key_names, + scan_environment, + _get_csv_path, + _load_dotenv_values, + _detect_shell, + _parse_api_env_file, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def temp_home(tmp_path, monkeypatch): + """Create a temporary home directory with .pdd folder.""" + pdd_dir = tmp_path / ".pdd" + pdd_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + return tmp_path + + +@pytest.fixture +def sample_csv(temp_home): + """Create a sample llm_model.csv with various providers.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + rows = [ + { + "provider": "OpenAI", + "model": "gpt-4", + "input": "30.0", + "output": "60.0", + "coding_arena_elo": "1000", + "base_url": "", + "api_key": "OPENAI_API_KEY", + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "", + "location": "", + }, + { + "provider": "Anthropic", + "model": "claude-3-opus", + "input": "15.0", + "output": "75.0", + "coding_arena_elo": "1000", + "base_url": "", + "api_key": "ANTHROPIC_API_KEY", + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "", + "location": "", + }, + { + "provider": "Local", + "model": "ollama/llama2", + "input": "0.0", + "output": "0.0", + "coding_arena_elo": "1000", + "base_url": "http://localhost:11434", + "api_key": "", # Local LLM - no key needed + "max_reasoning_tokens": "0", + "structured_output": "False", + "reasoning_type": "", + "location": "", + }, + ] + + fieldnames = [ + "provider", "model", "input", "output", "coding_arena_elo", + "base_url", "api_key", "max_reasoning_tokens", "structured_output", + "reasoning_type", "location", + ] + + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + return csv_path + + +# --------------------------------------------------------------------------- +# Tests for get_provider_key_names +# --------------------------------------------------------------------------- + + +class TestGetProviderKeyNames: + """Tests for get_provider_key_names function.""" + + def test_returns_sorted_unique_keys(self, sample_csv): + """Should return deduplicated, sorted list of API key names.""" + result = get_provider_key_names() + assert result == ["ANTHROPIC_API_KEY", "OPENAI_API_KEY"] + + def test_returns_empty_list_when_csv_missing(self, temp_home): + """Should return empty list when CSV doesn't exist.""" + result = get_provider_key_names() + assert result == [] + + def test_returns_empty_list_when_csv_empty(self, temp_home): + """Should return empty list when CSV exists but is empty.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + csv_path.touch() + result = get_provider_key_names() + assert result == [] + + def test_handles_csv_without_api_key_column(self, temp_home): + """Should return empty list when CSV lacks api_key column.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=["provider", "model"]) + writer.writeheader() + writer.writerow({"provider": "OpenAI", "model": "gpt-4"}) + + result = get_provider_key_names() + assert result == [] + + def test_handles_csv_with_only_empty_api_keys(self, temp_home): + """Should return empty list when all api_key values are empty.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + fieldnames = ["provider", "model", "api_key"] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerow({"provider": "Local", "model": "llama2", "api_key": ""}) + writer.writerow({"provider": "Local2", "model": "mistral", "api_key": " "}) + + result = get_provider_key_names() + assert result == [] + + def test_deduplicates_same_key_multiple_providers(self, temp_home): + """Should deduplicate when multiple rows use the same API key.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + fieldnames = ["provider", "model", "api_key"] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerow({"provider": "OpenAI", "model": "gpt-4", "api_key": "OPENAI_API_KEY"}) + writer.writerow({"provider": "OpenAI", "model": "gpt-3.5", "api_key": "OPENAI_API_KEY"}) + writer.writerow({"provider": "Together", "model": "llama", "api_key": "TOGETHER_API_KEY"}) + + result = get_provider_key_names() + assert result == ["OPENAI_API_KEY", "TOGETHER_API_KEY"] + + def test_handles_malformed_csv(self, temp_home): + """Should return empty list for malformed CSV without raising.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + csv_path.write_text("this is not,a valid\ncsv file with\"broken quotes") + + result = get_provider_key_names() + # Should handle gracefully - either empty or partial results + assert isinstance(result, list) + + +# --------------------------------------------------------------------------- +# Tests for _detect_shell +# --------------------------------------------------------------------------- + + +class TestDetectShell: + """Tests for _detect_shell function.""" + + def test_detects_zsh(self, monkeypatch): + """Should detect zsh shell.""" + monkeypatch.setenv("SHELL", "/bin/zsh") + assert _detect_shell() == "zsh" + + def test_detects_bash(self, monkeypatch): + """Should detect bash shell.""" + monkeypatch.setenv("SHELL", "/bin/bash") + assert _detect_shell() == "bash" + + def test_detects_fish(self, monkeypatch): + """Should detect fish shell.""" + monkeypatch.setenv("SHELL", "/usr/local/bin/fish") + assert _detect_shell() == "fish" + + def test_returns_none_when_shell_not_set(self, monkeypatch): + """Should return None when SHELL env var is not set.""" + monkeypatch.delenv("SHELL", raising=False) + assert _detect_shell() is None + + def test_handles_unusual_shell_paths(self, monkeypatch): + """Should extract shell name from unusual paths.""" + monkeypatch.setenv("SHELL", "/opt/homebrew/bin/zsh") + assert _detect_shell() == "zsh" + + +# --------------------------------------------------------------------------- +# Tests for _parse_api_env_file +# --------------------------------------------------------------------------- + + +class TestParseApiEnvFile: + """Tests for _parse_api_env_file function.""" + + def test_parses_simple_exports(self, tmp_path): + """Should parse simple export KEY=value lines.""" + env_file = tmp_path / "api-env.bash" + env_file.write_text( + "export OPENAI_API_KEY=sk-12345\n" + "export ANTHROPIC_API_KEY=ant-67890\n" + ) + + result = _parse_api_env_file(env_file) + assert result == { + "OPENAI_API_KEY": "sk-12345", + "ANTHROPIC_API_KEY": "ant-67890", + } + + def test_parses_quoted_values(self, tmp_path): + """Should parse quoted export values.""" + env_file = tmp_path / "api-env.bash" + env_file.write_text( + 'export OPENAI_API_KEY="sk-12345"\n' + "export ANTHROPIC_API_KEY='ant-67890'\n" + ) + + result = _parse_api_env_file(env_file) + assert result == { + "OPENAI_API_KEY": "sk-12345", + "ANTHROPIC_API_KEY": "ant-67890", + } + + def test_skips_commented_lines(self, tmp_path): + """Should skip lines starting with #.""" + env_file = tmp_path / "api-env.bash" + env_file.write_text( + "# This is a comment\n" + "export OPENAI_API_KEY=sk-12345\n" + "# export ANTHROPIC_API_KEY=ant-67890\n" + ) + + result = _parse_api_env_file(env_file) + assert result == {"OPENAI_API_KEY": "sk-12345"} + + def test_skips_empty_lines(self, tmp_path): + """Should skip empty lines.""" + env_file = tmp_path / "api-env.bash" + env_file.write_text( + "export OPENAI_API_KEY=sk-12345\n" + "\n" + " \n" + "export ANTHROPIC_API_KEY=ant-67890\n" + ) + + result = _parse_api_env_file(env_file) + assert len(result) == 2 + + def test_returns_empty_dict_for_missing_file(self, tmp_path): + """Should return empty dict when file doesn't exist.""" + env_file = tmp_path / "nonexistent" + result = _parse_api_env_file(env_file) + assert result == {} + + def test_handles_special_characters_in_values(self, tmp_path): + """Should handle API keys with special characters.""" + # Characters that might appear in API keys + special_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + env_file = tmp_path / "api-env.bash" + # Note: The file might have various quoting styles + env_file.write_text(f"export TEST_KEY='{special_key}'\n") + + result = _parse_api_env_file(env_file) + assert result.get("TEST_KEY") == special_key + + def test_ignores_non_export_lines(self, tmp_path): + """Should ignore lines that don't start with 'export '.""" + env_file = tmp_path / "api-env.bash" + env_file.write_text( + "OPENAI_API_KEY=sk-12345\n" # No export + "export ANTHROPIC_API_KEY=ant-67890\n" + "echo 'hello'\n" + ) + + result = _parse_api_env_file(env_file) + assert result == {"ANTHROPIC_API_KEY": "ant-67890"} + + def test_handles_whitespace_around_equals(self, tmp_path): + """Should handle whitespace around equals sign.""" + env_file = tmp_path / "api-env.bash" + env_file.write_text("export OPENAI_API_KEY = sk-12345\n") + + # The current implementation uses partition("="), check behavior + result = _parse_api_env_file(env_file) + # Result may vary based on implementation - just ensure no crash + assert isinstance(result, dict) + + +# --------------------------------------------------------------------------- +# Tests for scan_environment +# --------------------------------------------------------------------------- + + +class TestScanEnvironment: + """Tests for scan_environment function.""" + + def test_returns_empty_dict_when_no_models_configured(self, temp_home): + """Should return empty dict when no models in CSV.""" + result = scan_environment() + assert result == {} + + def test_detects_key_in_shell_environment(self, sample_csv, monkeypatch): + """Should detect keys set in shell environment.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-test123") + # Don't set ANTHROPIC_API_KEY + + result = scan_environment() + + assert "OPENAI_API_KEY" in result + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == "shell environment" + + assert "ANTHROPIC_API_KEY" in result + assert result["ANTHROPIC_API_KEY"].is_set is False + + def test_detects_key_in_api_env_file(self, sample_csv, temp_home, monkeypatch): + """Should detect keys in ~/.pdd/api-env.{shell} file.""" + monkeypatch.setenv("SHELL", "/bin/bash") + # Clear any existing env vars + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + api_env_path = temp_home / ".pdd" / "api-env.bash" + api_env_path.write_text("export OPENAI_API_KEY=sk-from-api-env\n") + + result = scan_environment() + + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == "~/.pdd/api-env.bash" + assert result["ANTHROPIC_API_KEY"].is_set is False + + def test_priority_order_dotenv_first(self, sample_csv, temp_home, monkeypatch): + """Should check .env file first (highest priority).""" + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setenv("OPENAI_API_KEY", "sk-from-shell") + + # Create api-env file too + api_env_path = temp_home / ".pdd" / "api-env.bash" + api_env_path.write_text("export OPENAI_API_KEY=sk-from-api-env\n") + + # Mock dotenv to return a value + with mock.patch( + "pdd.setup.api_key_scanner._load_dotenv_values", + return_value={"OPENAI_API_KEY": "sk-from-dotenv"}, + ): + result = scan_environment() + + assert result["OPENAI_API_KEY"].source == ".env file" + + def test_priority_order_shell_before_api_env(self, sample_csv, temp_home, monkeypatch): + """Shell environment should have priority over api-env file.""" + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setenv("OPENAI_API_KEY", "sk-from-shell") + + api_env_path = temp_home / ".pdd" / "api-env.bash" + api_env_path.write_text("export OPENAI_API_KEY=sk-from-api-env\n") + + # Mock dotenv to return empty (no .env file) + with mock.patch( + "pdd.setup.api_key_scanner._load_dotenv_values", + return_value={}, + ): + result = scan_environment() + + assert result["OPENAI_API_KEY"].source == "shell environment" + + def test_keyinfo_structure(self, sample_csv, monkeypatch): + """Should return KeyInfo dataclass with correct fields.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + + result = scan_environment() + key_info = result["OPENAI_API_KEY"] + + assert isinstance(key_info, KeyInfo) + assert hasattr(key_info, "source") + assert hasattr(key_info, "is_set") + + def test_handles_exception_gracefully(self, monkeypatch, temp_home): + """Should return best-effort results on errors without raising.""" + # Create a CSV that will cause issues + csv_path = temp_home / ".pdd" / "llm_model.csv" + csv_path.write_text("provider,model,api_key\nTest,test,TEST_KEY\n") + + # Mock get_provider_key_names to raise + with mock.patch( + "pdd.setup.api_key_scanner.get_provider_key_names", + side_effect=Exception("Test error"), + ): + result = scan_environment() + + # Should return empty dict, not raise + assert result == {} + + def test_different_shells_use_different_api_env_files(self, sample_csv, temp_home, monkeypatch): + """Should use api-env file matching the detected shell.""" + # Create both bash and zsh api-env files with different keys + (temp_home / ".pdd" / "api-env.bash").write_text( + "export OPENAI_API_KEY=sk-bash\n" + ) + (temp_home / ".pdd" / "api-env.zsh").write_text( + "export ANTHROPIC_API_KEY=ant-zsh\n" + ) + + # Clear shell env + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + # Test with bash shell + monkeypatch.setenv("SHELL", "/bin/bash") + with mock.patch("pdd.setup.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() + + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == "~/.pdd/api-env.bash" + assert result["ANTHROPIC_API_KEY"].is_set is False + + +# --------------------------------------------------------------------------- +# Tests for _load_dotenv_values +# --------------------------------------------------------------------------- + + +class TestLoadDotenvValues: + """Tests for _load_dotenv_values function.""" + + def test_returns_empty_dict_when_dotenv_not_installed(self, monkeypatch): + """Should return empty dict if python-dotenv is not available.""" + # Mock the import to fail + import builtins + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "dotenv": + raise ImportError("No module named 'dotenv'") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mock_import) + + result = _load_dotenv_values() + assert result == {} + + def test_filters_out_none_values(self): + """Should filter out keys with None values from dotenv.""" + # Mock dotenv_values to return some None values + with mock.patch("dotenv.dotenv_values", return_value={ + "KEY1": "value1", + "KEY2": None, + "KEY3": "value3", + }): + result = _load_dotenv_values() + + assert result == {"KEY1": "value1", "KEY3": "value3"} + + +# --------------------------------------------------------------------------- +# Edge case tests +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_handles_unicode_in_csv(self, temp_home): + """Should handle unicode characters in CSV.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + fieldnames = ["provider", "model", "api_key"] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerow({ + "provider": "Tëst Provider", + "model": "模型", + "api_key": "UNICODE_KEY_名前", + }) + + result = get_provider_key_names() + assert "UNICODE_KEY_名前" in result + + def test_handles_very_long_api_key_names(self, temp_home): + """Should handle very long API key names.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + fieldnames = ["provider", "model", "api_key"] + long_key_name = "A" * 1000 + "_API_KEY" + + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerow({ + "provider": "Test", + "model": "test", + "api_key": long_key_name, + }) + + result = get_provider_key_names() + assert long_key_name in result + + def test_handles_api_key_with_special_shell_characters(self, temp_home, monkeypatch): + """Should handle API key names with characters problematic for shells.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + fieldnames = ["provider", "model", "api_key"] + + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerow({ + "provider": "Test", + "model": "test", + "api_key": "MY_SPECIAL_KEY", + }) + + # Set the env var + monkeypatch.setenv("MY_SPECIAL_KEY", "value_with_$pecial_chars") + monkeypatch.setenv("SHELL", "/bin/bash") + + with mock.patch("pdd.setup.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() + + assert result["MY_SPECIAL_KEY"].is_set is True + + def test_handles_permission_error_on_csv(self, temp_home, monkeypatch): + """Should handle permission errors gracefully.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + csv_path.write_text("provider,model,api_key\nTest,test,KEY\n") + + # Mock open to raise PermissionError + original_open = open + + def mock_open_with_permission_error(file, *args, **kwargs): + if str(file) == str(csv_path): + raise PermissionError("Access denied") + return original_open(file, *args, **kwargs) + + with mock.patch("builtins.open", side_effect=mock_open_with_permission_error): + result = get_provider_key_names() + + assert result == [] diff --git a/tests/test_cli_detector.py b/tests/test_cli_detector.py new file mode 100644 index 000000000..89c7db857 --- /dev/null +++ b/tests/test_cli_detector.py @@ -0,0 +1,489 @@ +"""Tests for pdd/setup/cli_detector.py""" + +import subprocess +from unittest import mock + +import pytest + +from pdd.setup.cli_detector import ( + _CLI_COMMANDS, + _API_KEY_ENV_VARS, + _INSTALL_COMMANDS, + _CLI_DISPLAY_NAMES, + _which, + _has_api_key, + _npm_available, + _prompt_yes_no, + _run_install, + detect_cli_tools, +) + + +# --------------------------------------------------------------------------- +# Tests for constants +# --------------------------------------------------------------------------- + + +class TestConstants: + """Tests for static mappings.""" + + def test_cli_commands_has_expected_providers(self): + """Should have CLI commands for known providers.""" + assert "anthropic" in _CLI_COMMANDS + assert "google" in _CLI_COMMANDS + assert "openai" in _CLI_COMMANDS + + def test_cli_commands_values(self): + """CLI command values should be correct.""" + assert _CLI_COMMANDS["anthropic"] == "claude" + assert _CLI_COMMANDS["google"] == "gemini" + assert _CLI_COMMANDS["openai"] == "codex" + + def test_api_key_env_vars_has_expected_providers(self): + """Should have API key env vars for known providers.""" + assert "anthropic" in _API_KEY_ENV_VARS + assert "google" in _API_KEY_ENV_VARS + assert "openai" in _API_KEY_ENV_VARS + + def test_api_key_env_vars_values(self): + """API key env var values should be correct.""" + assert _API_KEY_ENV_VARS["anthropic"] == "ANTHROPIC_API_KEY" + assert _API_KEY_ENV_VARS["google"] == "GOOGLE_API_KEY" + assert _API_KEY_ENV_VARS["openai"] == "OPENAI_API_KEY" + + def test_install_commands_has_expected_providers(self): + """Should have install commands for known providers.""" + assert "anthropic" in _INSTALL_COMMANDS + assert "google" in _INSTALL_COMMANDS + assert "openai" in _INSTALL_COMMANDS + + def test_install_commands_are_npm_commands(self): + """Install commands should be npm install commands.""" + for provider, cmd in _INSTALL_COMMANDS.items(): + assert cmd.startswith("npm install -g ") + + def test_cli_display_names_has_expected_providers(self): + """Should have display names for known providers.""" + assert "anthropic" in _CLI_DISPLAY_NAMES + assert "google" in _CLI_DISPLAY_NAMES + assert "openai" in _CLI_DISPLAY_NAMES + + def test_cli_display_names_are_human_readable(self): + """Display names should be human-readable.""" + assert _CLI_DISPLAY_NAMES["anthropic"] == "Claude CLI" + assert _CLI_DISPLAY_NAMES["google"] == "Gemini CLI" + assert _CLI_DISPLAY_NAMES["openai"] == "Codex CLI" + + def test_all_providers_have_consistent_mappings(self): + """All providers should have entries in all mappings.""" + providers = set(_CLI_COMMANDS.keys()) + + assert providers == set(_API_KEY_ENV_VARS.keys()) + assert providers == set(_INSTALL_COMMANDS.keys()) + assert providers == set(_CLI_DISPLAY_NAMES.keys()) + + +# --------------------------------------------------------------------------- +# Tests for _which +# --------------------------------------------------------------------------- + + +class TestWhich: + """Tests for _which function.""" + + def test_returns_path_for_existing_command(self): + """Should return path for commands that exist.""" + # 'ls' should exist on all Unix-like systems + result = _which("ls") + assert result is not None + assert "ls" in result + + def test_returns_none_for_nonexistent_command(self): + """Should return None for commands that don't exist.""" + result = _which("nonexistent_command_xyz_12345") + assert result is None + + def test_returns_none_for_empty_string(self): + """Should return None for empty command string.""" + result = _which("") + assert result is None + + +# --------------------------------------------------------------------------- +# Tests for _has_api_key +# --------------------------------------------------------------------------- + + +class TestHasApiKey: + """Tests for _has_api_key function.""" + + def test_returns_true_when_key_set(self, monkeypatch): + """Should return True when API key is set.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-value") + assert _has_api_key("anthropic") is True + + def test_returns_false_when_key_not_set(self, monkeypatch): + """Should return False when API key is not set.""" + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + assert _has_api_key("anthropic") is False + + def test_returns_false_when_key_empty(self, monkeypatch): + """Should return False when API key is empty string.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "") + assert _has_api_key("anthropic") is False + + def test_returns_false_when_key_whitespace(self, monkeypatch): + """Should return False when API key is only whitespace.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", " ") + assert _has_api_key("anthropic") is False + + def test_returns_false_for_unknown_provider(self, monkeypatch): + """Should return False for unknown providers.""" + # Unknown provider won't be in _API_KEY_ENV_VARS + assert _has_api_key("unknown_provider") is False + + +# --------------------------------------------------------------------------- +# Tests for _npm_available +# --------------------------------------------------------------------------- + + +class TestNpmAvailable: + """Tests for _npm_available function.""" + + def test_returns_bool(self): + """Should return a boolean.""" + result = _npm_available() + assert isinstance(result, bool) + + def test_uses_which_internally(self): + """Should use _which to find npm.""" + with mock.patch("pdd.setup.cli_detector._which") as mock_which: + mock_which.return_value = "/usr/bin/npm" + result = _npm_available() + mock_which.assert_called_once_with("npm") + assert result is True + + def test_returns_false_when_npm_not_found(self): + """Should return False when npm is not installed.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + assert _npm_available() is False + + +# --------------------------------------------------------------------------- +# Tests for _prompt_yes_no +# --------------------------------------------------------------------------- + + +class TestPromptYesNo: + """Tests for _prompt_yes_no function.""" + + def test_returns_true_for_y(self): + """Should return True for 'y' input.""" + with mock.patch("builtins.input", return_value="y"): + assert _prompt_yes_no("Test? ") is True + + def test_returns_true_for_yes(self): + """Should return True for 'yes' input.""" + with mock.patch("builtins.input", return_value="yes"): + assert _prompt_yes_no("Test? ") is True + + def test_returns_true_for_Y_uppercase(self): + """Should return True for uppercase 'Y' input.""" + with mock.patch("builtins.input", return_value="Y"): + assert _prompt_yes_no("Test? ") is True + + def test_returns_true_for_YES_uppercase(self): + """Should return True for uppercase 'YES' input.""" + with mock.patch("builtins.input", return_value="YES"): + assert _prompt_yes_no("Test? ") is True + + def test_returns_false_for_n(self): + """Should return False for 'n' input.""" + with mock.patch("builtins.input", return_value="n"): + assert _prompt_yes_no("Test? ") is False + + def test_returns_false_for_no(self): + """Should return False for 'no' input.""" + with mock.patch("builtins.input", return_value="no"): + assert _prompt_yes_no("Test? ") is False + + def test_returns_false_for_empty(self): + """Should return False for empty input (default is No).""" + with mock.patch("builtins.input", return_value=""): + assert _prompt_yes_no("Test? ") is False + + def test_returns_false_for_random_input(self): + """Should return False for unrecognized input.""" + with mock.patch("builtins.input", return_value="maybe"): + assert _prompt_yes_no("Test? ") is False + + def test_handles_eof_error(self): + """Should return False on EOFError.""" + with mock.patch("builtins.input", side_effect=EOFError()): + assert _prompt_yes_no("Test? ") is False + + def test_handles_keyboard_interrupt(self): + """Should return False on KeyboardInterrupt.""" + with mock.patch("builtins.input", side_effect=KeyboardInterrupt()): + assert _prompt_yes_no("Test? ") is False + + def test_strips_whitespace(self): + """Should strip whitespace from input.""" + with mock.patch("builtins.input", return_value=" y "): + assert _prompt_yes_no("Test? ") is True + + +# --------------------------------------------------------------------------- +# Tests for _run_install +# --------------------------------------------------------------------------- + + +class TestRunInstall: + """Tests for _run_install function.""" + + def test_returns_true_on_success(self): + """Should return True when command succeeds.""" + with mock.patch("subprocess.run") as mock_run: + mock_run.return_value = mock.MagicMock(returncode=0) + result = _run_install("echo test") + assert result is True + + def test_returns_false_on_failure(self): + """Should return False when command fails.""" + with mock.patch("subprocess.run") as mock_run: + mock_run.return_value = mock.MagicMock(returncode=1) + result = _run_install("exit 1") + assert result is False + + def test_returns_false_on_exception(self): + """Should return False on subprocess exception.""" + with mock.patch("subprocess.run", side_effect=Exception("Test error")): + result = _run_install("failing command") + assert result is False + + def test_runs_command_with_shell(self): + """Should run command with shell=True.""" + with mock.patch("subprocess.run") as mock_run: + mock_run.return_value = mock.MagicMock(returncode=0) + _run_install("npm install -g test") + mock_run.assert_called_once() + call_kwargs = mock_run.call_args[1] + assert call_kwargs["shell"] is True + + +# --------------------------------------------------------------------------- +# Tests for detect_cli_tools +# --------------------------------------------------------------------------- + + +class TestDetectCliTools: + """Tests for detect_cli_tools function.""" + + def test_prints_header(self, capsys): + """Should print the detection header.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "Agentic CLI Tool Detection" in captured.out + assert "pdd fix, pdd change, pdd bug" in captured.out + + def test_shows_found_cli(self, capsys): + """Should show checkmark for found CLI tools.""" + with mock.patch("pdd.setup.cli_detector._which") as mock_which: + mock_which.side_effect = lambda cmd: "/usr/bin/claude" if cmd == "claude" else None + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "✓" in captured.out + assert "Claude CLI" in captured.out + + def test_shows_not_found_cli(self, capsys): + """Should show X for missing CLI tools.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "✗" in captured.out + assert "Not found" in captured.out + + def test_shows_api_key_status_when_cli_found(self, capsys): + """Should show API key status when CLI is found.""" + with mock.patch("pdd.setup.cli_detector._which", return_value="/usr/bin/claude"): + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + mock_has_key.return_value = True + detect_cli_tools() + + captured = capsys.readouterr() + assert "set" in captured.out + + def test_warns_when_cli_found_but_no_key(self, capsys): + """Should warn when CLI found but API key not set.""" + with mock.patch("pdd.setup.cli_detector._which", return_value="/usr/bin/claude"): + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "not set" in captured.out + assert "won't be usable" in captured.out + + def test_suggests_install_when_key_but_no_cli(self, capsys): + """Should suggest installation when API key is set but CLI is missing.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + # Only anthropic has key set + mock_has_key.side_effect = lambda p: p == "anthropic" + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "install the CLI" in captured.out + + def test_offers_installation_when_npm_available(self, capsys): + """Should offer installation when npm is available.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + mock_has_key.side_effect = lambda p: p == "anthropic" + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "Install now?" in captured.out or "Install command" in captured.out + + def test_shows_npm_not_installed_message(self, capsys): + """Should show message when npm is not installed.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + mock_has_key.side_effect = lambda p: p == "anthropic" + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "npm is not installed" in captured.out + + def test_runs_installation_on_yes(self, capsys): + """Should run installation when user says yes.""" + with mock.patch("pdd.setup.cli_detector._which") as mock_which: + mock_which.side_effect = [None, None, None, "/usr/bin/claude"] # Found after install + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + mock_has_key.side_effect = lambda p: p == "anthropic" + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=True): + with mock.patch("pdd.setup.cli_detector._run_install", return_value=True): + detect_cli_tools() + + captured = capsys.readouterr() + assert "installed successfully" in captured.out or "completed" in captured.out + + def test_shows_failure_message_on_install_fail(self, capsys): + """Should show failure message when installation fails.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + mock_has_key.side_effect = lambda p: p == "anthropic" + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=True): + with mock.patch("pdd.setup.cli_detector._run_install", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "failed" in captured.out.lower() or "manually" in captured.out.lower() + + def test_shows_skip_message_on_no(self, capsys): + """Should show skip message when user declines installation.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + mock_has_key.side_effect = lambda p: p == "anthropic" + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "Skipped" in captured.out or "later" in captured.out + + def test_shows_quick_start_when_nothing_installed(self, capsys): + """Should show quick start guide when no CLIs are installed.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + assert "Quick start" in captured.out or "No CLI tools found" in captured.out + + def test_shows_all_installed_message(self, capsys): + """Should show success message when all CLIs with keys are installed.""" + with mock.patch("pdd.setup.cli_detector._which", return_value="/usr/bin/cli"): + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=True): + detect_cli_tools() + + captured = capsys.readouterr() + assert "All CLI tools with matching API keys are installed" in captured.out + + +# --------------------------------------------------------------------------- +# Integration tests +# --------------------------------------------------------------------------- + + +class TestIntegration: + """Integration tests for CLI detector.""" + + def test_detect_cli_tools_handles_import_error(self, capsys): + """Should handle missing agentic_common gracefully.""" + with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + # The function imports get_available_agents but handles import errors + detect_cli_tools() + + # Should complete without error + captured = capsys.readouterr() + assert "Agentic CLI Tool Detection" in captured.out + + def test_detect_cli_tools_complete_flow(self, capsys): + """Test complete detection flow with mixed results.""" + def mock_which(cmd): + return "/usr/bin/claude" if cmd == "claude" else None + + def mock_has_key(provider): + return provider in ["anthropic", "openai"] + + with mock.patch("pdd.setup.cli_detector._which", side_effect=mock_which): + with mock.patch("pdd.setup.cli_detector._has_api_key", side_effect=mock_has_key): + with mock.patch("pdd.setup.cli_detector._npm_available", return_value=False): + detect_cli_tools() + + captured = capsys.readouterr() + # Claude should show as found + assert "Claude CLI" in captured.out + assert "✓" in captured.out + # Others should show as not found + assert "✗" in captured.out + + +# --------------------------------------------------------------------------- +# Edge case tests +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_handles_subprocess_timeout(self): + """Should handle subprocess timeout gracefully.""" + with mock.patch("subprocess.run", side_effect=subprocess.TimeoutExpired("cmd", 30)): + result = _run_install("slow command") + assert result is False + + def test_empty_env_var_treated_as_not_set(self, monkeypatch): + """Empty string env vars should be treated as not set.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "") + assert _has_api_key("anthropic") is False + + def test_whitespace_only_env_var_treated_as_not_set(self, monkeypatch): + """Whitespace-only env vars should be treated as not set.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", " \t\n ") + assert _has_api_key("anthropic") is False diff --git a/tests/test_litellm_registry.py b/tests/test_litellm_registry.py new file mode 100644 index 000000000..067b7633c --- /dev/null +++ b/tests/test_litellm_registry.py @@ -0,0 +1,561 @@ +"""Tests for pdd/setup/litellm_registry.py""" + +from unittest import mock + +import pytest + +from pdd.setup.litellm_registry import ( + ProviderInfo, + ModelInfo, + PROVIDER_API_KEY_MAP, + PROVIDER_DISPLAY_NAMES, + is_litellm_available, + get_api_key_env_var, + get_top_providers, + get_all_providers, + search_providers, + get_models_for_provider, + _get_display_name, + _entry_to_model_info, +) + + +# --------------------------------------------------------------------------- +# Tests for constants +# --------------------------------------------------------------------------- + + +class TestConstants: + """Tests for static mappings.""" + + def test_provider_api_key_map_has_common_providers(self): + """Should have API key mappings for common providers.""" + assert "openai" in PROVIDER_API_KEY_MAP + assert "anthropic" in PROVIDER_API_KEY_MAP + assert "gemini" in PROVIDER_API_KEY_MAP + assert "groq" in PROVIDER_API_KEY_MAP + assert "mistral" in PROVIDER_API_KEY_MAP + + def test_provider_api_key_map_values_are_strings(self): + """All API key env var names should be strings.""" + for key_name in PROVIDER_API_KEY_MAP.values(): + assert isinstance(key_name, str) + assert len(key_name) > 0 + # Should look like an env var (uppercase with underscores) + assert key_name.isupper() or "_" in key_name + + def test_provider_display_names_has_common_providers(self): + """Should have display names for common providers.""" + assert "openai" in PROVIDER_DISPLAY_NAMES + assert PROVIDER_DISPLAY_NAMES["openai"] == "OpenAI" + assert PROVIDER_DISPLAY_NAMES["anthropic"] == "Anthropic" + assert PROVIDER_DISPLAY_NAMES["gemini"] == "Google Gemini" + + def test_provider_display_names_are_human_readable(self): + """Display names should be human-readable (not snake_case).""" + for provider, display_name in PROVIDER_DISPLAY_NAMES.items(): + assert isinstance(display_name, str) + assert len(display_name) > 0 + # Should not be all lowercase with underscores + if "_" in provider: + assert "_" not in display_name or display_name != provider + + +# --------------------------------------------------------------------------- +# Tests for dataclasses +# --------------------------------------------------------------------------- + + +class TestDataclasses: + """Tests for ProviderInfo and ModelInfo dataclasses.""" + + def test_provider_info_fields(self): + """ProviderInfo should have all required fields.""" + info = ProviderInfo( + name="openai", + display_name="OpenAI", + api_key_env_var="OPENAI_API_KEY", + model_count=10, + sample_models=["gpt-4", "gpt-3.5-turbo"], + ) + assert info.name == "openai" + assert info.display_name == "OpenAI" + assert info.api_key_env_var == "OPENAI_API_KEY" + assert info.model_count == 10 + assert info.sample_models == ["gpt-4", "gpt-3.5-turbo"] + + def test_provider_info_defaults(self): + """ProviderInfo sample_models should default to empty list.""" + info = ProviderInfo( + name="test", + display_name="Test", + api_key_env_var=None, + model_count=0, + ) + assert info.sample_models == [] + + def test_model_info_fields(self): + """ModelInfo should have all required fields.""" + info = ModelInfo( + name="gpt-4", + litellm_id="gpt-4", + input_cost_per_million=30.0, + output_cost_per_million=60.0, + max_input_tokens=128000, + max_output_tokens=8192, + supports_vision=True, + supports_function_calling=True, + ) + assert info.name == "gpt-4" + assert info.litellm_id == "gpt-4" + assert info.input_cost_per_million == 30.0 + assert info.output_cost_per_million == 60.0 + assert info.max_input_tokens == 128000 + assert info.max_output_tokens == 8192 + assert info.supports_vision is True + assert info.supports_function_calling is True + + def test_model_info_defaults(self): + """ModelInfo should have sensible defaults.""" + info = ModelInfo( + name="test", + litellm_id="test", + input_cost_per_million=0.0, + output_cost_per_million=0.0, + ) + assert info.max_input_tokens is None + assert info.max_output_tokens is None + assert info.supports_vision is False + assert info.supports_function_calling is False + + +# --------------------------------------------------------------------------- +# Tests for is_litellm_available +# --------------------------------------------------------------------------- + + +class TestIsLitellmAvailable: + """Tests for is_litellm_available function.""" + + def test_returns_true_when_litellm_installed(self): + """Should return True when litellm is importable with model data.""" + # If litellm is installed in test environment, this should return True + # We'll mock it to ensure consistent behavior + mock_litellm = mock.MagicMock() + mock_litellm.model_cost = {"gpt-4": {"mode": "chat"}} + + with mock.patch.dict("sys.modules", {"litellm": mock_litellm}): + # Need to reimport or call the function after mocking + result = is_litellm_available() + # Either True (if litellm is installed) or we need to mock + assert isinstance(result, bool) + + def test_returns_false_when_litellm_not_installed(self): + """Should return False when litellm import fails.""" + with mock.patch.dict("sys.modules", {"litellm": None}): + # Force ImportError + def raise_import_error(): + raise ImportError("No module named 'litellm'") + + with mock.patch( + "pdd.setup.litellm_registry.is_litellm_available", + side_effect=raise_import_error, + ): + # The actual function should handle this gracefully + pass + + def test_returns_false_when_model_cost_empty(self): + """Should return False when litellm.model_cost is empty.""" + mock_litellm = mock.MagicMock() + mock_litellm.model_cost = {} + + with mock.patch("pdd.setup.litellm_registry.is_litellm_available") as mock_fn: + mock_fn.return_value = False + assert mock_fn() is False + + +# --------------------------------------------------------------------------- +# Tests for get_api_key_env_var +# --------------------------------------------------------------------------- + + +class TestGetApiKeyEnvVar: + """Tests for get_api_key_env_var function.""" + + def test_returns_key_for_known_providers(self): + """Should return correct API key env var for known providers.""" + assert get_api_key_env_var("openai") == "OPENAI_API_KEY" + assert get_api_key_env_var("anthropic") == "ANTHROPIC_API_KEY" + assert get_api_key_env_var("gemini") == "GEMINI_API_KEY" + assert get_api_key_env_var("groq") == "GROQ_API_KEY" + + def test_returns_none_for_unknown_providers(self): + """Should return None for providers not in the mapping.""" + assert get_api_key_env_var("unknown_provider") is None + assert get_api_key_env_var("") is None + assert get_api_key_env_var("my_custom_llm") is None + + def test_case_sensitive(self): + """Provider name lookup should be case-sensitive.""" + assert get_api_key_env_var("openai") == "OPENAI_API_KEY" + assert get_api_key_env_var("OpenAI") is None + assert get_api_key_env_var("OPENAI") is None + + +# --------------------------------------------------------------------------- +# Tests for _get_display_name +# --------------------------------------------------------------------------- + + +class TestGetDisplayName: + """Tests for _get_display_name helper function.""" + + def test_returns_mapped_name_for_known_providers(self): + """Should return the mapped display name for known providers.""" + assert _get_display_name("openai") == "OpenAI" + assert _get_display_name("fireworks_ai") == "Fireworks AI" + assert _get_display_name("together_ai") == "Together AI" + + def test_falls_back_to_title_case_for_unknown(self): + """Should fallback to title-case with underscore replacement.""" + assert _get_display_name("my_custom_provider") == "My Custom Provider" + assert _get_display_name("unknown") == "Unknown" + + def test_handles_empty_string(self): + """Should handle empty string gracefully.""" + result = _get_display_name("") + assert result == "" + + +# --------------------------------------------------------------------------- +# Tests for _entry_to_model_info +# --------------------------------------------------------------------------- + + +class TestEntryToModelInfo: + """Tests for _entry_to_model_info helper function.""" + + def test_converts_basic_entry(self): + """Should convert a model_cost entry to ModelInfo.""" + entry = { + "input_cost_per_token": 0.00003, + "output_cost_per_token": 0.00006, + "max_input_tokens": 128000, + "max_output_tokens": 8192, + "supports_vision": True, + "supports_function_calling": True, + } + + result = _entry_to_model_info("gpt-4", entry) + + assert result.name == "gpt-4" + assert result.litellm_id == "gpt-4" + assert result.input_cost_per_million == 30.0 + assert result.output_cost_per_million == 60.0 + assert result.max_input_tokens == 128000 + assert result.max_output_tokens == 8192 + assert result.supports_vision is True + assert result.supports_function_calling is True + + def test_handles_provider_prefix_in_model_id(self): + """Should extract model name from provider/model format.""" + entry = {"input_cost_per_token": 0, "output_cost_per_token": 0} + + result = _entry_to_model_info("anthropic/claude-3-opus", entry) + + assert result.name == "claude-3-opus" + assert result.litellm_id == "anthropic/claude-3-opus" + + def test_handles_missing_cost_fields(self): + """Should handle entries with missing cost fields.""" + entry = {} + + result = _entry_to_model_info("test-model", entry) + + assert result.input_cost_per_million == 0.0 + assert result.output_cost_per_million == 0.0 + + def test_handles_none_cost_values(self): + """Should handle None cost values.""" + entry = { + "input_cost_per_token": None, + "output_cost_per_token": None, + } + + result = _entry_to_model_info("test-model", entry) + + assert result.input_cost_per_million == 0.0 + assert result.output_cost_per_million == 0.0 + + def test_converts_per_token_to_per_million(self): + """Should correctly convert per-token costs to per-million.""" + entry = { + "input_cost_per_token": 0.000001, # $1 per million + "output_cost_per_token": 0.000002, # $2 per million + } + + result = _entry_to_model_info("test", entry) + + assert result.input_cost_per_million == 1.0 + assert result.output_cost_per_million == 2.0 + + +# --------------------------------------------------------------------------- +# Tests for get_top_providers (with mocking) +# --------------------------------------------------------------------------- + + +class TestGetTopProviders: + """Tests for get_top_providers function.""" + + def test_returns_empty_list_when_litellm_unavailable(self): + """Should return empty list when litellm is not available.""" + with mock.patch( + "pdd.setup.litellm_registry.is_litellm_available", return_value=False + ): + result = get_top_providers() + assert result == [] + + def test_returns_list_of_provider_info(self): + """Should return a list of ProviderInfo objects.""" + # Only test if litellm is actually available + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_top_providers() + + assert isinstance(result, list) + if result: + assert isinstance(result[0], ProviderInfo) + + def test_includes_major_providers(self): + """Top providers should include major cloud providers.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_top_providers() + provider_names = [p.name for p in result] + + # At least some major providers should be present + major_providers = {"openai", "anthropic", "gemini", "mistral"} + found = set(provider_names) & major_providers + assert len(found) > 0, f"Expected some major providers, got {provider_names}" + + +# --------------------------------------------------------------------------- +# Tests for get_all_providers (with mocking) +# --------------------------------------------------------------------------- + + +class TestGetAllProviders: + """Tests for get_all_providers function.""" + + def test_returns_empty_list_when_litellm_unavailable(self): + """Should return empty list when litellm is not available.""" + with mock.patch( + "pdd.setup.litellm_registry.is_litellm_available", return_value=False + ): + result = get_all_providers() + assert result == [] + + def test_returns_sorted_by_model_count(self): + """Should return providers sorted by model count descending.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_all_providers() + + if len(result) > 1: + for i in range(len(result) - 1): + assert result[i].model_count >= result[i + 1].model_count + + def test_all_providers_have_at_least_one_model(self): + """All returned providers should have at least one chat model.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_all_providers() + + for provider in result: + assert provider.model_count >= 1 + + +# --------------------------------------------------------------------------- +# Tests for search_providers +# --------------------------------------------------------------------------- + + +class TestSearchProviders: + """Tests for search_providers function.""" + + def test_empty_query_returns_all_providers(self): + """Empty query should return all providers.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + all_providers = get_all_providers() + search_result = search_providers("") + + assert len(search_result) == len(all_providers) + + def test_case_insensitive_search(self): + """Search should be case-insensitive.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result_lower = search_providers("openai") + result_upper = search_providers("OPENAI") + result_mixed = search_providers("OpenAI") + + # All should return the same results + assert len(result_lower) == len(result_upper) == len(result_mixed) + + def test_searches_in_name_and_display_name(self): + """Should search in both provider name and display name.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + # Search by display name component + result = search_providers("Gemini") + provider_names = [p.name for p in result] + + # Should find gemini provider + assert any("gemini" in name.lower() for name in provider_names) + + def test_returns_empty_for_no_match(self): + """Should return empty list when no providers match.""" + result = search_providers("xyznonexistentprovider123") + assert result == [] + + def test_partial_match(self): + """Should match partial strings.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + # "open" should match "openai" + result = search_providers("open") + if result: + assert any("open" in p.name.lower() for p in result) + + +# --------------------------------------------------------------------------- +# Tests for get_models_for_provider +# --------------------------------------------------------------------------- + + +class TestGetModelsForProvider: + """Tests for get_models_for_provider function.""" + + def test_returns_empty_list_when_litellm_unavailable(self): + """Should return empty list when litellm is not available.""" + with mock.patch( + "pdd.setup.litellm_registry.is_litellm_available", return_value=False + ): + result = get_models_for_provider("openai") + assert result == [] + + def test_returns_list_of_model_info(self): + """Should return a list of ModelInfo objects.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_models_for_provider("openai") + + assert isinstance(result, list) + if result: + assert isinstance(result[0], ModelInfo) + + def test_returns_sorted_by_name(self): + """Models should be sorted by name.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_models_for_provider("openai") + + if len(result) > 1: + names = [m.name for m in result] + assert names == sorted(names) + + def test_returns_empty_for_unknown_provider(self): + """Should return empty list for unknown provider.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_models_for_provider("nonexistent_provider_xyz") + assert result == [] + + def test_model_info_has_litellm_id(self): + """Each model should have a litellm_id for API calls.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_models_for_provider("anthropic") + + for model in result: + assert model.litellm_id is not None + assert len(model.litellm_id) > 0 + + def test_handles_vertex_ai_subproviders(self): + """Should aggregate vertex_ai sub-providers.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + result = get_models_for_provider("vertex_ai") + + # Should return some models (vertex_ai has many) + # The exact count depends on litellm version + assert isinstance(result, list) + + +# --------------------------------------------------------------------------- +# Integration tests +# --------------------------------------------------------------------------- + + +class TestIntegration: + """Integration tests that verify module works end-to-end.""" + + def test_workflow_search_to_models(self): + """Test typical workflow: search provider -> get models.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + # Search for anthropic + providers = search_providers("anthropic") + assert len(providers) > 0 + + # Get the first matching provider + provider = providers[0] + assert provider.name == "anthropic" or "anthropic" in provider.name + + # Get models for that provider + models = get_models_for_provider(provider.name) + assert len(models) > 0 + + # Each model should have required fields + for model in models: + assert model.litellm_id + assert isinstance(model.input_cost_per_million, (int, float)) + assert isinstance(model.output_cost_per_million, (int, float)) + + def test_top_providers_have_valid_data(self): + """Top providers should all have valid, complete data.""" + if not is_litellm_available(): + pytest.skip("litellm not installed") + + top = get_top_providers() + + for provider in top: + # Each provider should have a name and display name + assert provider.name + assert provider.display_name + + # Model count should be positive + assert provider.model_count > 0 + + # Sample models should not exceed 3 + assert len(provider.sample_models) <= 3 + + # Can get models for this provider + models = get_models_for_provider(provider.name) + assert len(models) == provider.model_count diff --git a/tests/test_provider_manager.py b/tests/test_provider_manager.py new file mode 100644 index 000000000..34ffaf317 --- /dev/null +++ b/tests/test_provider_manager.py @@ -0,0 +1,1014 @@ +"""Tests for pdd/setup/provider_manager.py""" + +import csv +import os +import tempfile +from pathlib import Path +from unittest import mock + +import pytest + +from pdd.setup.provider_manager import ( + CSV_FIELDNAMES, + _get_shell_name, + _get_pdd_dir, + _get_api_env_path, + _get_user_csv_path, + _read_csv, + _write_csv_atomic, + _read_api_env_lines, + _write_api_env_atomic, + _save_key_to_api_env, + _comment_out_key_in_api_env, + _is_key_set, + add_provider_from_registry, + add_custom_provider, + remove_models_by_provider, + remove_individual_models, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def temp_home(tmp_path, monkeypatch): + """Create a temporary home directory with .pdd folder.""" + pdd_dir = tmp_path / ".pdd" + pdd_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + monkeypatch.setenv("SHELL", "/bin/bash") + return tmp_path + + +@pytest.fixture +def sample_csv(temp_home): + """Create a sample llm_model.csv with test data.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + rows = [ + { + "provider": "OpenAI", + "model": "gpt-4", + "input": "30.0", + "output": "60.0", + "coding_arena_elo": "1000", + "base_url": "", + "api_key": "OPENAI_API_KEY", + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "", + "location": "", + }, + { + "provider": "OpenAI", + "model": "gpt-3.5-turbo", + "input": "0.5", + "output": "1.5", + "coding_arena_elo": "1000", + "base_url": "", + "api_key": "OPENAI_API_KEY", + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "", + "location": "", + }, + { + "provider": "Anthropic", + "model": "claude-3-opus", + "input": "15.0", + "output": "75.0", + "coding_arena_elo": "1000", + "base_url": "", + "api_key": "ANTHROPIC_API_KEY", + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "", + "location": "", + }, + ] + + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=CSV_FIELDNAMES) + writer.writeheader() + writer.writerows(rows) + + return csv_path + + +@pytest.fixture +def sample_api_env(temp_home): + """Create a sample api-env.bash file.""" + api_env_path = temp_home / ".pdd" / "api-env.bash" + api_env_path.write_text( + "export OPENAI_API_KEY=sk-test123\n" + "export ANTHROPIC_API_KEY=ant-test456\n" + ) + return api_env_path + + +# --------------------------------------------------------------------------- +# Tests for path helpers +# --------------------------------------------------------------------------- + + +class TestPathHelpers: + """Tests for path helper functions.""" + + def test_get_shell_name_bash(self, monkeypatch): + """Should detect bash shell.""" + monkeypatch.setenv("SHELL", "/bin/bash") + assert _get_shell_name() == "bash" + + def test_get_shell_name_zsh(self, monkeypatch): + """Should detect zsh shell.""" + monkeypatch.setenv("SHELL", "/usr/local/bin/zsh") + assert _get_shell_name() == "zsh" + + def test_get_shell_name_fish(self, monkeypatch): + """Should detect fish shell.""" + monkeypatch.setenv("SHELL", "/opt/homebrew/bin/fish") + assert _get_shell_name() == "fish" + + def test_get_shell_name_defaults_to_bash(self, monkeypatch): + """Should default to bash for unknown shells.""" + monkeypatch.setenv("SHELL", "/bin/unknown_shell") + assert _get_shell_name() == "bash" + + def test_get_shell_name_no_shell_var(self, monkeypatch): + """Should default to bash when SHELL not set.""" + monkeypatch.delenv("SHELL", raising=False) + # Implementation defaults to /bin/bash when SHELL is not set + result = _get_shell_name() + assert result == "bash" + + def test_get_pdd_dir_creates_directory(self, tmp_path, monkeypatch): + """Should create ~/.pdd if it doesn't exist.""" + monkeypatch.setattr(Path, "home", lambda: tmp_path) + pdd_dir = tmp_path / ".pdd" + + # Directory shouldn't exist yet + assert not pdd_dir.exists() + + result = _get_pdd_dir() + + assert result == pdd_dir + assert pdd_dir.exists() + + def test_get_api_env_path(self, temp_home, monkeypatch): + """Should return correct api-env path for shell.""" + monkeypatch.setenv("SHELL", "/bin/zsh") + result = _get_api_env_path() + assert result == temp_home / ".pdd" / "api-env.zsh" + + def test_get_user_csv_path(self, temp_home): + """Should return correct user CSV path.""" + result = _get_user_csv_path() + assert result == temp_home / ".pdd" / "llm_model.csv" + + +# --------------------------------------------------------------------------- +# Tests for CSV I/O helpers +# --------------------------------------------------------------------------- + + +class TestCsvHelpers: + """Tests for CSV read/write functions.""" + + def test_read_csv_returns_list_of_dicts(self, sample_csv): + """Should read CSV and return list of row dictionaries.""" + result = _read_csv(sample_csv) + + assert isinstance(result, list) + assert len(result) == 3 + assert result[0]["provider"] == "OpenAI" + assert result[0]["model"] == "gpt-4" + + def test_read_csv_missing_file(self, temp_home): + """Should return empty list for missing file.""" + result = _read_csv(temp_home / ".pdd" / "nonexistent.csv") + assert result == [] + + def test_write_csv_atomic_creates_file(self, temp_home): + """Should create CSV file with correct content.""" + csv_path = temp_home / ".pdd" / "test.csv" + rows = [ + {"provider": "Test", "model": "test-model", "input": "1.0", "output": "2.0"}, + ] + + _write_csv_atomic(csv_path, rows) + + assert csv_path.exists() + result = _read_csv(csv_path) + assert len(result) == 1 + assert result[0]["provider"] == "Test" + + def test_write_csv_atomic_is_atomic(self, temp_home): + """Write should be atomic - no partial writes on failure.""" + csv_path = temp_home / ".pdd" / "test.csv" + + # Write initial content + _write_csv_atomic(csv_path, [{"provider": "Original"}]) + + # Verify temp files are cleaned up + pdd_dir = temp_home / ".pdd" + temp_files = list(pdd_dir.glob(".llm_model_*.tmp")) + assert len(temp_files) == 0 + + def test_write_csv_atomic_fills_missing_fields(self, temp_home): + """Should fill missing fields with empty strings.""" + csv_path = temp_home / ".pdd" / "test.csv" + rows = [{"provider": "Test", "model": "test-model"}] # Missing many fields + + _write_csv_atomic(csv_path, rows) + + result = _read_csv(csv_path) + # All CSV_FIELDNAMES should be present + for field in CSV_FIELDNAMES: + assert field in result[0] + + +# --------------------------------------------------------------------------- +# Tests for api-env file helpers +# --------------------------------------------------------------------------- + + +class TestApiEnvHelpers: + """Tests for api-env file read/write functions.""" + + def test_read_api_env_lines(self, sample_api_env): + """Should read api-env file lines.""" + result = _read_api_env_lines(sample_api_env) + + assert len(result) == 2 + assert "OPENAI_API_KEY" in result[0] + + def test_read_api_env_lines_missing_file(self, temp_home): + """Should return empty list for missing file.""" + result = _read_api_env_lines(temp_home / ".pdd" / "nonexistent") + assert result == [] + + def test_write_api_env_atomic(self, temp_home): + """Should write api-env file atomically.""" + env_path = temp_home / ".pdd" / "api-env.bash" + lines = ["export TEST_KEY=value\n"] + + _write_api_env_atomic(env_path, lines) + + assert env_path.exists() + content = env_path.read_text() + assert "TEST_KEY" in content + + def test_save_key_to_api_env_new_key(self, temp_home, monkeypatch): + """Should add new key to api-env file.""" + monkeypatch.setenv("SHELL", "/bin/bash") + env_path = temp_home / ".pdd" / "api-env.bash" + + _save_key_to_api_env("NEW_KEY", "new-value") + + content = env_path.read_text() + # shlex.quote() doesn't quote simple values without special chars + assert 'export NEW_KEY=' in content + assert 'new-value' in content + + def test_save_key_to_api_env_updates_existing(self, sample_api_env, monkeypatch): + """Should update existing key in api-env file.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + _save_key_to_api_env("OPENAI_API_KEY", "sk-updated") + + content = sample_api_env.read_text() + # shlex.quote() doesn't quote simple values without special chars + assert 'export OPENAI_API_KEY=' in content + assert 'sk-updated' in content + # Should not have duplicate entries + assert content.count("OPENAI_API_KEY") == 1 + + def test_save_key_to_api_env_uncomments_commented_key(self, temp_home, monkeypatch): + """Should replace commented key with new value.""" + monkeypatch.setenv("SHELL", "/bin/bash") + env_path = temp_home / ".pdd" / "api-env.bash" + env_path.write_text("# export OLD_KEY=old-value\n") + + _save_key_to_api_env("OLD_KEY", "new-value") + + content = env_path.read_text() + # shlex.quote() doesn't quote simple values without special chars + assert 'export OLD_KEY=' in content + assert 'new-value' in content + assert "# export OLD_KEY" not in content + + def test_save_key_with_special_characters(self, temp_home, monkeypatch): + """Should handle API keys with special characters.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + # Key with special shell characters + special_value = 'key$with"special\'chars' + _save_key_to_api_env("SPECIAL_KEY", special_value) + + env_path = temp_home / ".pdd" / "api-env.bash" + content = env_path.read_text() + assert "SPECIAL_KEY" in content + + def test_comment_out_key_in_api_env(self, sample_api_env, monkeypatch): + """Should comment out key with date annotation.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + _comment_out_key_in_api_env("OPENAI_API_KEY") + + content = sample_api_env.read_text() + assert "# Commented out by pdd setup on" in content + assert "# export OPENAI_API_KEY" in content + + def test_comment_out_preserves_other_keys(self, sample_api_env, monkeypatch): + """Should preserve other keys when commenting out one.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + _comment_out_key_in_api_env("OPENAI_API_KEY") + + content = sample_api_env.read_text() + # ANTHROPIC_API_KEY should still be active + assert "export ANTHROPIC_API_KEY=ant-test456" in content + + +# --------------------------------------------------------------------------- +# Tests for _is_key_set +# --------------------------------------------------------------------------- + + +class TestIsKeySet: + """Tests for _is_key_set function.""" + + def test_detects_key_in_shell_env(self, temp_home, monkeypatch): + """Should detect key set in shell environment.""" + monkeypatch.setenv("TEST_KEY", "test-value") + + result = _is_key_set("TEST_KEY") + + assert result == "shell environment" + + def test_detects_key_in_api_env_file(self, sample_api_env, monkeypatch): + """Should detect key in api-env file.""" + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + result = _is_key_set("OPENAI_API_KEY") + + assert "api-env.bash" in result + + def test_returns_none_when_key_not_set(self, temp_home, monkeypatch): + """Should return None when key is not set anywhere.""" + monkeypatch.delenv("NONEXISTENT_KEY", raising=False) + + result = _is_key_set("NONEXISTENT_KEY") + + assert result is None + + def test_dotenv_priority_over_shell(self, temp_home, monkeypatch): + """Should check .env file first.""" + monkeypatch.setenv("TEST_KEY", "shell-value") + + with mock.patch("dotenv.dotenv_values", return_value={"TEST_KEY": "dotenv-value"}): + result = _is_key_set("TEST_KEY") + + assert result == ".env file" + + +# --------------------------------------------------------------------------- +# Tests for add_provider_from_registry (mocked) +# --------------------------------------------------------------------------- + + +class TestAddProviderFromRegistry: + """Tests for add_provider_from_registry function.""" + + def test_returns_false_when_litellm_unavailable(self, temp_home): + """Should return False when litellm is not available.""" + with mock.patch( + "pdd.setup.provider_manager.is_litellm_available", return_value=False + ): + with mock.patch("pdd.setup.provider_manager.console"): + result = add_provider_from_registry() + + assert result is False + + def test_returns_false_on_empty_selection(self, temp_home): + """Should return False when user enters empty selection.""" + with mock.patch( + "pdd.setup.provider_manager.is_litellm_available", return_value=True + ): + with mock.patch( + "pdd.setup.provider_manager.get_top_providers", + return_value=[], + ): + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "" + with mock.patch("pdd.setup.provider_manager.console"): + result = add_provider_from_registry() + + assert result is False + + def test_adds_models_to_csv(self, temp_home): + """Should add selected models to user CSV.""" + from pdd.setup.litellm_registry import ProviderInfo, ModelInfo + + mock_provider = ProviderInfo( + name="test_provider", + display_name="Test Provider", + api_key_env_var="TEST_API_KEY", + model_count=2, + sample_models=["model1", "model2"], + ) + + mock_models = [ + ModelInfo( + name="model1", + litellm_id="test_provider/model1", + input_cost_per_million=1.0, + output_cost_per_million=2.0, + supports_function_calling=True, + ), + ] + + with mock.patch( + "pdd.setup.provider_manager.is_litellm_available", return_value=True + ): + with mock.patch( + "pdd.setup.provider_manager.get_top_providers", + return_value=[mock_provider], + ): + with mock.patch( + "pdd.setup.provider_manager.get_models_for_provider", + return_value=mock_models, + ): + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + # Select provider 1, then model 1 + mock_prompt.ask.side_effect = ["1", "1", "test-api-key"] + + with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = False + + with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch( + "pdd.setup.provider_manager._is_key_set", + return_value=None, + ): + result = add_provider_from_registry() + + # Check that model was added to CSV + csv_path = temp_home / ".pdd" / "llm_model.csv" + if csv_path.exists(): + rows = _read_csv(csv_path) + assert any(r["model"] == "test_provider/model1" for r in rows) + + +# --------------------------------------------------------------------------- +# Tests for add_custom_provider (mocked) +# --------------------------------------------------------------------------- + + +class TestAddCustomProvider: + """Tests for add_custom_provider function.""" + + def test_returns_false_on_empty_provider(self, temp_home): + """Should return False when provider name is empty.""" + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "" + with mock.patch("pdd.setup.provider_manager.console"): + result = add_custom_provider() + + assert result is False + + def test_adds_custom_provider_to_csv(self, temp_home): + """Should add custom provider to user CSV.""" + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.side_effect = [ + "custom_llm", # provider prefix + "my-model", # model name + "CUSTOM_API_KEY", # api key env var + "", # base url (optional) + "1.0", # input cost + "2.0", # output cost + ] + with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = False # Don't enter key value now + with mock.patch("pdd.setup.provider_manager.console"): + result = add_custom_provider() + + assert result is True + + # Verify CSV was updated + csv_path = temp_home / ".pdd" / "llm_model.csv" + rows = _read_csv(csv_path) + assert len(rows) == 1 + assert rows[0]["provider"] == "custom_llm" + assert rows[0]["model"] == "custom_llm/my-model" + assert rows[0]["api_key"] == "CUSTOM_API_KEY" + + def test_saves_api_key_when_provided(self, temp_home, monkeypatch): + """Should save API key to api-env when user provides it.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.side_effect = [ + "custom_llm", + "my-model", + "CUSTOM_API_KEY", + "", + "0.0", + "0.0", + "sk-my-secret-key", # API key value + ] + with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = True # Yes, enter key value + with mock.patch("pdd.setup.provider_manager.console"): + result = add_custom_provider() + + assert result is True + + # Verify api-env was updated + env_path = temp_home / ".pdd" / "api-env.bash" + content = env_path.read_text() + assert "CUSTOM_API_KEY" in content + + +# --------------------------------------------------------------------------- +# Tests for remove_models_by_provider (mocked) +# --------------------------------------------------------------------------- + + +class TestRemoveModelsByProvider: + """Tests for remove_models_by_provider function.""" + + def test_returns_false_when_no_models(self, temp_home): + """Should return False when no models configured.""" + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_models_by_provider() + + assert result is False + + def test_returns_false_on_cancel(self, sample_csv, temp_home): + """Should return False when user cancels.""" + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "" # Empty = cancel + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_models_by_provider() + + assert result is False + + def test_removes_all_models_for_provider(self, sample_csv, temp_home, monkeypatch): + """Should remove all models with matching api_key.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "1" # Select first provider + with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = True # Confirm removal + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_models_by_provider() + + assert result is True + + # Check that models were removed + rows = _read_csv(sample_csv) + # One provider should have been removed + assert len(rows) < 3 + + +# --------------------------------------------------------------------------- +# Tests for remove_individual_models (mocked) +# --------------------------------------------------------------------------- + + +class TestRemoveIndividualModels: + """Tests for remove_individual_models function.""" + + def test_returns_false_when_no_models(self, temp_home): + """Should return False when no models configured.""" + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_individual_models() + + assert result is False + + def test_returns_false_on_cancel(self, sample_csv, temp_home): + """Should return False when user cancels.""" + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "" # Empty = cancel + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_individual_models() + + assert result is False + + def test_removes_selected_models(self, sample_csv, temp_home): + """Should remove only selected models.""" + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "1" # Remove first model + with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = True # Confirm + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_individual_models() + + assert result is True + + # Check that one model was removed + rows = _read_csv(sample_csv) + assert len(rows) == 2 + + def test_removes_multiple_models(self, sample_csv, temp_home): + """Should handle comma-separated model selection.""" + with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "1, 2" # Remove first two models + with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = True # Confirm + with mock.patch("pdd.setup.provider_manager.console"): + result = remove_individual_models() + + assert result is True + + # Check that two models were removed + rows = _read_csv(sample_csv) + assert len(rows) == 1 + + +# --------------------------------------------------------------------------- +# Edge case tests +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_csv_write_atomic_cleans_up_on_error(self, temp_home): + """Should clean up temp file on write error.""" + csv_path = temp_home / ".pdd" / "test.csv" + + # Create a scenario where write might fail + with mock.patch("os.fdopen", side_effect=IOError("Simulated error")): + with pytest.raises(IOError): + _write_csv_atomic(csv_path, [{"provider": "Test"}]) + + # No temp files should remain + temp_files = list((temp_home / ".pdd").glob(".llm_model_*.tmp")) + assert len(temp_files) == 0 + + def test_handles_special_characters_in_api_keys(self, temp_home, monkeypatch): + """Should handle API key values with special shell characters.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + # Characters that might cause issues in shell scripts + special_values = [ + 'key$with$dollars', + 'key"with"quotes', + "key'with'single", + 'key`with`backticks', + 'key\\with\\backslashes', + 'key with spaces', + 'key;with;semicolons', + ] + + for i, value in enumerate(special_values): + key_name = f"SPECIAL_KEY_{i}" + _save_key_to_api_env(key_name, value) + + env_path = temp_home / ".pdd" / "api-env.bash" + content = env_path.read_text() + + # All keys should be present + for i in range(len(special_values)): + assert f"SPECIAL_KEY_{i}" in content + + def test_handles_unicode_in_model_names(self, temp_home): + """Should handle unicode characters in model names.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + rows = [ + { + "provider": "Tëst", + "model": "模型-émoji-🤖", + "input": "1.0", + "output": "2.0", + } + ] + + _write_csv_atomic(csv_path, rows) + + result = _read_csv(csv_path) + assert result[0]["provider"] == "Tëst" + assert "模型" in result[0]["model"] + + def test_handles_empty_csv_fields(self, temp_home): + """Should handle rows with empty optional fields.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + rows = [ + { + "provider": "Test", + "model": "test-model", + # All other fields will be filled with empty strings + } + ] + + _write_csv_atomic(csv_path, rows) + + result = _read_csv(csv_path) + assert result[0]["provider"] == "Test" + assert result[0]["api_key"] == "" + + def test_concurrent_writes_safe(self, temp_home): + """Atomic writes should be safe for concurrent access.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + + # Write initial data + _write_csv_atomic(csv_path, [{"provider": "Initial"}]) + + # Simulate concurrent write + _write_csv_atomic(csv_path, [{"provider": "Updated"}]) + + result = _read_csv(csv_path) + assert len(result) == 1 + assert result[0]["provider"] == "Updated" + + +# --------------------------------------------------------------------------- +# Shell script execution tests (following test_setup_tool.py pattern) +# --------------------------------------------------------------------------- + + +class TestApiEnvShellExecution: + """ + Tests that verify generated api-env scripts can be sourced and + correctly preserve API key values, especially with special characters. + + These tests follow the rigorous pattern from test_setup_tool.py, + actually executing shell scripts to verify correctness. + """ + + def _shell_available(self, shell: str) -> bool: + """Check if a shell is available on the system.""" + import shutil + return shutil.which(shell) is not None + + def test_api_env_script_valid_bash_syntax(self, temp_home, monkeypatch): + """ + Generated api-env script should have valid bash syntax. + This test catches quoting errors that would break shell parsing. + """ + monkeypatch.setenv("SHELL", "/bin/bash") + + # Save a key with special characters + special_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + _save_key_to_api_env("TEST_KEY", special_key) + + env_path = temp_home / ".pdd" / "api-env.bash" + + # Run bash syntax check + import subprocess + result = subprocess.run( + ["bash", "-n", str(env_path)], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0, ( + f"Generated script has bash syntax errors: {result.stderr}\n" + f"Script content:\n{env_path.read_text()}" + ) + + def test_api_env_script_valid_zsh_syntax(self, temp_home, monkeypatch): + """Generated api-env script should have valid zsh syntax.""" + if not self._shell_available("zsh"): + pytest.skip("zsh not available") + + monkeypatch.setenv("SHELL", "/bin/zsh") + + special_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + _save_key_to_api_env("TEST_KEY", special_key) + + env_path = temp_home / ".pdd" / "api-env.zsh" + + import subprocess + result = subprocess.run( + ["zsh", "-n", str(env_path)], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0, ( + f"Generated script has zsh syntax errors: {result.stderr}\n" + f"Script content:\n{env_path.read_text()}" + ) + + def test_api_env_script_can_be_sourced_bash(self, temp_home, monkeypatch): + """Script should be sourceable in bash without errors.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + special_key = 'key"with\'many$special`characters' + _save_key_to_api_env("TEST_KEY", special_key) + + env_path = temp_home / ".pdd" / "api-env.bash" + + import subprocess + result = subprocess.run( + ["bash", "-c", f"source {env_path} && exit 0"], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0, ( + f"Cannot source script in bash: {result.stderr}\n" + f"Script content:\n{env_path.read_text()}" + ) + + def test_api_env_preserves_key_value_bash(self, temp_home, monkeypatch): + """ + API key value should be preserved exactly when script is sourced. + This is the critical test - verifies the key survives shell escaping. + """ + monkeypatch.setenv("SHELL", "/bin/bash") + + original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + _save_key_to_api_env("TEST_KEY", original_key) + + env_path = temp_home / ".pdd" / "api-env.bash" + + import subprocess + result = subprocess.run( + [ + "bash", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('TEST_KEY', ''))\"" + ], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0, ( + f"Failed to source script and read env var: {result.stderr}\n" + f"Script content:\n{env_path.read_text()}" + ) + + extracted_key = result.stdout.strip() + assert extracted_key == original_key, ( + f"Key value was corrupted during shell escaping.\n" + f"Original: {repr(original_key)}\n" + f"Extracted: {repr(extracted_key)}\n" + f"Script content:\n{env_path.read_text()}" + ) + + def test_api_env_preserves_key_value_zsh(self, temp_home, monkeypatch): + """API key value should be preserved exactly in zsh.""" + if not self._shell_available("zsh"): + pytest.skip("zsh not available") + + monkeypatch.setenv("SHELL", "/bin/zsh") + + original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + _save_key_to_api_env("TEST_KEY", original_key) + + env_path = temp_home / ".pdd" / "api-env.zsh" + + import subprocess + result = subprocess.run( + [ + "zsh", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('TEST_KEY', ''))\"" + ], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0, ( + f"Failed to source script: {result.stderr}\n" + f"Script content:\n{env_path.read_text()}" + ) + + extracted_key = result.stdout.strip() + assert extracted_key == original_key, ( + f"Key value was corrupted in zsh.\n" + f"Original: {repr(original_key)}\n" + f"Extracted: {repr(extracted_key)}\n" + f"Script content:\n{env_path.read_text()}" + ) + + def test_api_env_with_various_problematic_characters(self, temp_home, monkeypatch): + """ + Test with various characters that commonly cause shell escaping issues. + Each character tested individually to identify specific failures. + """ + monkeypatch.setenv("SHELL", "/bin/bash") + + problematic_chars = [ + ('dollar', 'key$value'), + ('double_quote', 'key"value'), + ('single_quote', "key'value"), + ('backtick', 'key`value'), + ('backslash', 'key\\value'), + ('space', 'key value'), + ('semicolon', 'key;value'), + ('ampersand', 'key&value'), + ('pipe', 'key|value'), + ('newline', 'key\nvalue'), + ('tab', 'key\tvalue'), + ] + + import subprocess + + for name, test_value in problematic_chars: + key_name = f"TEST_{name.upper()}" + _save_key_to_api_env(key_name, test_value) + + env_path = temp_home / ".pdd" / "api-env.bash" + + # Verify syntax is valid + syntax_result = subprocess.run( + ["bash", "-n", str(env_path)], + capture_output=True, + text=True, + timeout=5, + ) + + assert syntax_result.returncode == 0, ( + f"Syntax error with '{name}' character: {syntax_result.stderr}\n" + f"Script:\n{env_path.read_text()}" + ) + + # Verify value is preserved + extract_result = subprocess.run( + [ + "bash", "-c", + f"source {env_path} && python3 -c \"import os; print(repr(os.environ.get('{key_name}', '')))\"" + ], + capture_output=True, + text=True, + timeout=5, + ) + + if extract_result.returncode == 0: + extracted = eval(extract_result.stdout.strip()) + assert extracted == test_value, ( + f"Value corrupted for '{name}' character.\n" + f"Expected: {repr(test_value)}\n" + f"Got: {repr(extracted)}" + ) + + def test_multiple_keys_preserved(self, temp_home, monkeypatch): + """Multiple keys should all be preserved correctly.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + keys = { + "OPENAI_API_KEY": "sk-test123", + "ANTHROPIC_API_KEY": "ant-key$special", + "GEMINI_API_KEY": 'gem"quoted\'key', + } + + for key_name, key_value in keys.items(): + _save_key_to_api_env(key_name, key_value) + + env_path = temp_home / ".pdd" / "api-env.bash" + + import subprocess + + for key_name, expected_value in keys.items(): + result = subprocess.run( + [ + "bash", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('{key_name}', ''))\"" + ], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0 + extracted = result.stdout.strip() + assert extracted == expected_value, ( + f"{key_name} was corrupted.\n" + f"Expected: {repr(expected_value)}\n" + f"Got: {repr(extracted)}" + ) + + def test_normal_key_still_works(self, temp_home, monkeypatch): + """Normal keys without special characters should still work.""" + monkeypatch.setenv("SHELL", "/bin/bash") + + normal_key = "sk-proj-abcdef1234567890ABCDEF" + _save_key_to_api_env("OPENAI_API_KEY", normal_key) + + env_path = temp_home / ".pdd" / "api-env.bash" + + import subprocess + result = subprocess.run( + [ + "bash", "-c", + f"source {env_path} && echo $OPENAI_API_KEY" + ], + capture_output=True, + text=True, + timeout=5, + ) + + assert result.returncode == 0 + assert result.stdout.strip() == normal_key From 6558d2dc9d7cf1d01bb451aa2726c8c4678ba256 Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Tue, 17 Feb 2026 02:12:18 -0500 Subject: [PATCH 06/10] Restructure pdd setup to more hands-off flow - Starts by boostrapping agentic CLI - Then utomatically scans for API keys, populates llm_model.csv, checks for local LLMs, initializes .pddrc, and tests a model - Removed /setup directory, all setup-related files now exist in root /pdd directory - Users mostly just need to hit 'enter' to finish setup - Add testing files --- context/api_key_scanner_example.py | 2 +- context/cli_detector_example.py | 33 +- context/litellm_registry_example.py | 2 +- context/local_llm_configurator_example.py | 43 - context/model_tester_example.py | 2 +- context/pddrc_initializer_example.py | 2 +- context/setup_tool_example.py | 86 +- pdd/{setup => }/api_key_scanner.py | 2 +- pdd/cli_detector.py | 495 +++++++ pdd/{setup => }/litellm_registry.py | 2 +- pdd/{setup => }/model_tester.py | 0 pdd/{setup => }/pddrc_initializer.py | 0 .../agentic_setup_autoconfig_LLM.prompt | 229 ++++ pdd/prompts/api_key_scanner_python.prompt | 4 +- pdd/prompts/cli_detector_python.prompt | 79 +- pdd/prompts/litellm_registry_python.prompt | 4 +- .../local_llm_configurator_python.prompt | 40 - pdd/prompts/model_tester_python.prompt | 4 +- pdd/prompts/pddrc_initializer_python.prompt | 4 +- pdd/prompts/provider_manager_python.prompt | 4 +- pdd/prompts/setup_tool_python.prompt | 180 ++- pdd/{setup => }/provider_manager.py | 4 +- pdd/setup/__init__.py | 0 pdd/setup/cli_detector.py | 191 --- pdd/setup/local_llm_configurator.py | 377 ------ pdd/setup/setup_tool.py | 167 --- pdd/setup_tool.py | 1200 ++++++++--------- tests/test_api_key_scanner.py | 14 +- tests/test_cli_detector.py | 415 +++++- tests/test_litellm_registry.py | 14 +- tests/test_model_tester.py | 232 ++++ tests/test_pddrc_initializer.py | 207 +++ tests/test_provider_manager.py | 78 +- tests/test_setup_tool.py | 944 +++++-------- 34 files changed, 2827 insertions(+), 2233 deletions(-) delete mode 100644 context/local_llm_configurator_example.py rename pdd/{setup => }/api_key_scanner.py (99%) create mode 100644 pdd/cli_detector.py rename pdd/{setup => }/litellm_registry.py (99%) rename pdd/{setup => }/model_tester.py (100%) rename pdd/{setup => }/pddrc_initializer.py (100%) create mode 100644 pdd/prompts/agentic_setup_autoconfig_LLM.prompt delete mode 100644 pdd/prompts/local_llm_configurator_python.prompt rename pdd/{setup => }/provider_manager.py (99%) delete mode 100644 pdd/setup/__init__.py delete mode 100644 pdd/setup/cli_detector.py delete mode 100644 pdd/setup/local_llm_configurator.py delete mode 100644 pdd/setup/setup_tool.py create mode 100644 tests/test_model_tester.py create mode 100644 tests/test_pddrc_initializer.py diff --git a/context/api_key_scanner_example.py b/context/api_key_scanner_example.py index 325b48110..687d858ca 100644 --- a/context/api_key_scanner_example.py +++ b/context/api_key_scanner_example.py @@ -7,7 +7,7 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.api_key_scanner import scan_environment, get_provider_key_names, KeyInfo +from pdd.api_key_scanner import scan_environment, get_provider_key_names, KeyInfo def main() -> None: diff --git a/context/cli_detector_example.py b/context/cli_detector_example.py index 52e57190c..f0b4f92e8 100644 --- a/context/cli_detector_example.py +++ b/context/cli_detector_example.py @@ -7,31 +7,40 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.cli_detector import detect_cli_tools +from pdd.cli_detector import detect_and_bootstrap_cli, detect_cli_tools def main() -> None: """ Demonstrates how to use the cli_detector module to: - 1. Detect installed agentic CLI harnesses (claude, codex, gemini) - 2. Cross-reference with available API keys - 3. Offer installation for missing CLIs + 1. Bootstrap an agentic CLI for pdd setup (detect_and_bootstrap_cli) + 2. Detect installed CLI harnesses (claude, codex, gemini) + 3. Cross-reference with available API keys + 4. Offer installation for missing CLIs """ - # Run the interactive detector + # Primary entry point used by pdd setup Phase 1: + # result = detect_and_bootstrap_cli() + # result.cli_name -> "claude" | "codex" | "gemini" | "" + # result.provider -> "Anthropic" | "OpenAI" | "Google" | "" + # result.api_key_configured -> True | False + + # Legacy function for detection only: # detect_cli_tools() # Uncomment to run interactively - # Example flow: + # Example flow (detect_and_bootstrap_cli): # Checking CLI tools... # (Required for: pdd fix, pdd change, pdd bug) # - # Claude CLI ✓ Found at /usr/local/bin/claude - # Codex CLI ✗ Not found - # Gemini CLI ✗ Not found + # Claude CLI Found at /usr/local/bin/claude + # Codex CLI Not found + # Gemini CLI Not found + # + # Using Claude CLI (Anthropic). + # API key: configured # - # You have OPENAI_API_KEY but Codex CLI is not installed. - # Install with: npm install -g @openai/codex - # Install now? [y/N] + # Returns CliBootstrapResult(cli_name="claude", provider="Anthropic", + # api_key_configured=True) pass diff --git a/context/litellm_registry_example.py b/context/litellm_registry_example.py index 2d48bf3dc..38a6d9592 100644 --- a/context/litellm_registry_example.py +++ b/context/litellm_registry_example.py @@ -7,7 +7,7 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.litellm_registry import ( +from pdd.litellm_registry import ( is_litellm_available, get_api_key_env_var, get_top_providers, diff --git a/context/local_llm_configurator_example.py b/context/local_llm_configurator_example.py deleted file mode 100644 index 0c52e0650..000000000 --- a/context/local_llm_configurator_example.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -import sys -from pathlib import Path - -# Add the project root to sys.path -project_root = Path(__file__).resolve().parent.parent -sys.path.append(str(project_root)) - -from pdd.setup.local_llm_configurator import configure_local_llm - - -def main() -> None: - """ - Demonstrates how to use the local_llm_configurator module to: - 1. Configure Ollama with auto-detection of installed models - 2. Configure LM Studio with default base URL - 3. Add custom local LLM endpoints - """ - - # Run the interactive configuration - # was_added = configure_local_llm() # Uncomment to run interactively - - # Example flow for Ollama: - # What tool are you using? - # 1. LM Studio (default: localhost:1234) - # 2. Ollama (default: localhost:11434) - # 3. Other (custom base URL) - # Choice: 2 - # - # Querying Ollama at http://localhost:11434... - # Found installed models: - # 1. llama3:70b - # 2. codellama:34b - # 3. mistral:7b - # - # Which models do you want to add? [1,2,3]: 1,2 - # ✓ Added ollama_chat/llama3:70b and ollama_chat/codellama:34b to llm_model.csv - pass - - -if __name__ == "__main__": - main() diff --git a/context/model_tester_example.py b/context/model_tester_example.py index 51755e1bb..2c42c3126 100644 --- a/context/model_tester_example.py +++ b/context/model_tester_example.py @@ -7,7 +7,7 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.model_tester import test_model_interactive +from pdd.model_tester import test_model_interactive def main() -> None: diff --git a/context/pddrc_initializer_example.py b/context/pddrc_initializer_example.py index 19e61d1cc..99dbee780 100644 --- a/context/pddrc_initializer_example.py +++ b/context/pddrc_initializer_example.py @@ -7,7 +7,7 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.pddrc_initializer import offer_pddrc_init +from pdd.pddrc_initializer import offer_pddrc_init def main() -> None: diff --git a/context/setup_tool_example.py b/context/setup_tool_example.py index a517e5980..81903a9cc 100644 --- a/context/setup_tool_example.py +++ b/context/setup_tool_example.py @@ -7,53 +7,77 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.setup_tool import run_setup +from pdd.setup_tool import run_setup def main() -> None: """ Demonstrates how to use the setup_tool module to: - 1. Launch the interactive pdd setup wizard - 2. Scan for configured API keys - 3. Navigate the 6-option menu - - The setup wizard is fully interactive. Running it will: - - Scan ~/.pdd/llm_model.csv for configured models and their API key status - - Display a menu with options to add/remove providers, test models, etc. - - Loop until the user selects Done or presses Ctrl-C + 1. Launch the two-phase pdd setup flow + 2. Phase 1: Bootstrap an agentic CLI (Claude/Gemini/Codex) + 3. Phase 2: Auto-configure API keys, models, local LLMs, and .pddrc + + The setup flow is mostly automatic. Phase 1 asks 0-2 questions + (which CLI to use), then Phase 2 runs 4 deterministic steps + with "Press Enter" pauses between them. """ - # Run the interactive setup wizard + # Run the setup flow # run_setup() # Uncomment to run interactively # Example flow: - # ╭──────────────────────────────╮ - # │ pdd setup │ - # ╰──────────────────────────────╯ + # +------------------------------+ + # | pdd setup | + # +------------------------------+ + # + # Phase 1 -- CLI Bootstrap + # Detected: claude (Anthropic) + # API key: configured + # + # Ready to auto-configure PDD. Press Enter to continue... + # + # [Step 1/4] Scanning for API keys... + # ANTHROPIC_API_KEY shell environment + # GEMINI_API_KEY shell environment + # + # 2 API key(s) found. + # + # Press Enter to continue to the next step... + # + # [Step 2/4] Configuring models... + # 3 new model(s) added to ~/.pdd/llm_model.csv + # 4 cloud model(s) configured + # Anthropic: 3 models + # Google: 1 model + # + # Press Enter to continue to the next step... + # + # [Step 3/4] Checking local LLMs and .pddrc... + # Ollama running -- found llama3.2:3b, openhermes:latest + # LM Studio not running (skip) + # .pddrc already exists at /path/to/project/.pddrc # - # API-key scan - # ────────────────────────────────────────────────── - # ANTHROPIC_API_KEY ✓ Found (shell environment) - # OPENAI_API_KEY — Not found + # Press Enter to continue to the next step... # - # 💡 To edit API keys: update ~/.pdd/api-env.zsh or .env file + # [Step 4/4] Testing and summarizing... + # Testing anthropic/claude-sonnet-4-5-20250929... + # claude-sonnet-4-5-20250929 responded OK (1.2s) # - # Models configured: 1 (from 1 API keys + 0 local) + # =============================================== + # PDD Setup Complete + # =============================================== # - # What would you like to do? - # 1. Add a provider - # 2. Remove models - # 3. Test a model - # 4. Detect CLI tools - # 5. Initialize .pddrc - # 6. Done + # API Keys: 2 found + # Models: 4 configured (Anthropic: 3, Google: 1) + # Local: Ollama -- llama3.2:3b, openhermes:latest + # .pddrc: exists + # Test: OK # - # Choice [1-6]: 1 + # =============================================== + # Run 'pdd generate' or 'pdd sync' to start. + # =============================================== # - # Add a provider: - # a. Search providers - # b. Add a local LLM - # c. Add a custom provider + # Setup complete. Happy prompting! pass diff --git a/pdd/setup/api_key_scanner.py b/pdd/api_key_scanner.py similarity index 99% rename from pdd/setup/api_key_scanner.py rename to pdd/api_key_scanner.py index a56be8a75..fcb30d5ef 100644 --- a/pdd/setup/api_key_scanner.py +++ b/pdd/api_key_scanner.py @@ -1,5 +1,5 @@ """ -pdd/setup/api_key_scanner.py +pdd/api_key_scanner.py Discovers API keys needed by the user's configured models, checking existence across shell, .env, and PDD config with source transparency. diff --git a/pdd/cli_detector.py b/pdd/cli_detector.py new file mode 100644 index 000000000..aa495f203 --- /dev/null +++ b/pdd/cli_detector.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from rich.console import Console + +# Maps provider name -> CLI command name +_CLI_COMMANDS: dict[str, str] = { + "anthropic": "claude", + "google": "gemini", + "openai": "codex", +} + +# Maps provider name -> environment variable for API key +_API_KEY_ENV_VARS: dict[str, str] = { + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "openai": "OPENAI_API_KEY", +} + +# Maps provider name -> npm install command for the CLI +_INSTALL_COMMANDS: dict[str, str] = { + "anthropic": "npm install -g @anthropic-ai/claude-code", + "google": "npm install -g @google/gemini-cli", + "openai": "npm install -g @openai/codex", +} + +# Maps provider name -> human-readable CLI name +_CLI_DISPLAY_NAMES: dict[str, str] = { + "anthropic": "Claude CLI", + "google": "Gemini CLI", + "openai": "Codex CLI", +} + +# Provider -> primary key env var name (used when saving) +PROVIDER_PRIMARY_KEY: Dict[str, str] = { + "anthropic": "ANTHROPIC_API_KEY", + "google": "GEMINI_API_KEY", + "openai": "OPENAI_API_KEY", +} + +# Provider -> display name +PROVIDER_DISPLAY: Dict[str, str] = { + "anthropic": "Anthropic", + "google": "Google (Gemini)", + "openai": "OpenAI", +} + +# CLI preference order (claude first because it supports subscription auth) +CLI_PREFERENCE: List[str] = ["gemini", "claude", "codex"] + +# Ordered list for the numbered selection table: (provider, cli_name, display_name) +_TABLE_ORDER: List[Tuple[str, str, str]] = [ + ("anthropic", "claude", "Claude CLI"), + ("openai", "codex", "Codex CLI"), + ("google", "gemini", "Gemini CLI"), +] + +# Shell -> RC file path (relative to home) +SHELL_RC_MAP: Dict[str, str] = { + "bash": ".bashrc", + "zsh": ".zshrc", + "fish": os.path.join(".config", "fish", "config.fish"), +} + +# Common installation paths for CLI tools (fallback) +_COMMON_CLI_PATHS: Dict[str, List[Path]] = { + "claude": [ + Path.home() / ".local" / "bin" / "claude", + Path("/usr/local/bin/claude"), + Path("/opt/homebrew/bin/claude"), + ], + "codex": [ + Path.home() / ".local" / "bin" / "codex", + Path("/usr/local/bin/codex"), + Path("/opt/homebrew/bin/codex"), + ], + "gemini": [ + Path.home() / ".local" / "bin" / "gemini", + Path("/usr/local/bin/gemini"), + Path("/opt/homebrew/bin/gemini"), + ], +} + +console = Console() + +@dataclass +class CliBootstrapResult: + """Result of CLI detection and bootstrapping.""" + cli_name: str = "" + provider: str = "" + cli_path: str = "" + api_key_configured: bool = False + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _which(cmd: str) -> str | None: + """Return the full path to a command if found on PATH, else None.""" + if not cmd: + return None + return shutil.which(cmd) + +def _has_api_key(provider: str) -> bool: + """Check whether the API key environment variable is set for a provider.""" + env_var = _API_KEY_ENV_VARS.get(provider, "") + if not env_var: + # Also check fallback keys + if provider == "google": + val = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") + return bool(val and val.strip()) + return False + val = os.environ.get(env_var) + if val and val.strip(): + return True + # Fallback for google: also check GEMINI_API_KEY + if provider == "google": + val = os.environ.get("GEMINI_API_KEY") + return bool(val and val.strip()) + return False + +def _get_display_key_name(provider: str) -> str: + """Return the key name to display for a provider, checking which is actually set.""" + if provider == "google": + # Prefer GEMINI_API_KEY for display if set, else GOOGLE_API_KEY if set, else GEMINI_API_KEY + if os.environ.get("GEMINI_API_KEY", "").strip(): + return "GEMINI_API_KEY" + if os.environ.get("GOOGLE_API_KEY", "").strip(): + return "GOOGLE_API_KEY" + return "GEMINI_API_KEY" + return _API_KEY_ENV_VARS.get(provider, "") + +def _npm_available() -> bool: + """Check whether npm is available on PATH.""" + return _which("npm") is not None + +def _prompt_input(prompt_text: str) -> str: + """Wrapper around input() for testability.""" + return input(prompt_text) + +def _prompt_yes_no(prompt: str) -> bool: + """Prompt the user with a yes/no question. Default is No.""" + try: + answer = _prompt_input(prompt).strip().lower() + except (EOFError, KeyboardInterrupt): + return False + return answer in ("y", "yes") + +def _run_install(install_cmd: str) -> bool: + """Run an installation command via subprocess. Returns True on success.""" + try: + result = subprocess.run( + install_cmd, + shell=True, + capture_output=True, + text=True, + timeout=120 + ) + return result.returncode == 0 + except Exception: + return False + +def _detect_shell() -> str: + """Detect the user's shell from the SHELL environment variable.""" + shell_path = os.environ.get("SHELL", "/bin/bash") + return os.path.basename(shell_path) + +def _get_rc_file_path(shell: str) -> Path: + """Return the absolute path to the shell's RC file.""" + rc_relative = SHELL_RC_MAP.get(shell, SHELL_RC_MAP["bash"]) + if shell == "fish": + return Path.home() / ".config" / "fish" / "config.fish" + return Path.home() / rc_relative + +def _get_api_env_file_path(shell: str) -> Path: + """Return the path to ~/.pdd/api-env.{shell}.""" + return Path.home() / ".pdd" / f"api-env.{shell}" + +def _find_cli_binary(cli_name: str) -> Optional[str]: + """Find a CLI binary by name, including fallbacks.""" + # Use shutil.which first + result = shutil.which(cli_name) + if result: + return result + + # Try common paths + paths = _COMMON_CLI_PATHS.get(cli_name, []) + for path in paths: + if path.exists() and os.access(path, os.X_OK): + return str(path) + + # Try nvm fallback for node-based CLIs + nvm_node = Path.home() / ".nvm" / "versions" / "node" + if nvm_node.exists(): + try: + for version_dir in sorted(nvm_node.iterdir(), reverse=True): + bin_candidate = version_dir / "bin" / cli_name + if bin_candidate.is_file() and os.access(bin_candidate, os.X_OK): + return str(bin_candidate) + except OSError: + pass + + return None + +def _format_export_line(key_name: str, key_value: str, shell: str) -> str: + """Return the shell-appropriate export line.""" + if shell == "fish": + return f"set -gx {key_name} {key_value}" + return f"export {key_name}={key_value}" + +def _format_source_line(api_env_path: Path, shell: str) -> str: + """Return the shell-appropriate source line.""" + path_str = str(api_env_path) + if shell == "fish": + return f"test -f {path_str} ; and source {path_str}" + return f"source {path_str}" + +def _save_api_key(key_name: str, key_value: str, shell: str) -> bool: + """Save API key and update shell RC.""" + pdd_dir = Path.home() / ".pdd" + api_env_path = _get_api_env_file_path(shell) + rc_path = _get_rc_file_path(shell) + + try: + pdd_dir.mkdir(parents=True, exist_ok=True) + + # Append or create api-env file + existing_content = "" + if api_env_path.exists(): + existing_content = api_env_path.read_text(encoding="utf-8") + + export_line = _format_export_line(key_name, key_value, shell) + lines = existing_content.splitlines() + # Filter out existing entries for this key + filtered = [ln for ln in lines if key_name not in ln] + filtered.append(export_line) + + api_env_path.write_text("\n".join(filtered) + "\n", encoding="utf-8") + + # Update RC file + source_line = _format_source_line(api_env_path, shell) + rc_content = "" + if rc_path.exists(): + rc_content = rc_path.read_text(encoding="utf-8") + + if source_line not in rc_content: + with open(rc_path, "a", encoding="utf-8") as f: + f.write(f"\n# pdd CLI API keys\n{source_line}\n") + + os.environ[key_name] = key_value + return True + except Exception as e: + console.print(f"[red]Error saving API key: {e}[/red]") + return False + +def _prompt_api_key(provider: str, shell: str) -> bool: + """Prompt user for API key and save it.""" + key_name = PROVIDER_PRIMARY_KEY.get(provider, "") + if not key_name: + return False + + display = PROVIDER_DISPLAY.get(provider, provider) + try: + key_value = _prompt_input(f" Enter your {display} API key (or press Enter to skip): ").strip() + except (EOFError, KeyboardInterrupt): + return False + + if not key_value: + if provider == "anthropic": + console.print(" [dim]Note: Claude CLI may still work with subscription auth.[/dim]") + return False + + return _save_api_key(key_name, key_value, shell) + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def detect_and_bootstrap_cli() -> CliBootstrapResult: + """Phase 1 entry point for pdd setup. + + Shows a numbered selection table of all three CLI options with their + install and API-key status, lets the user choose, and walks through + installation and key configuration as needed. + """ + console.print("\nChecking CLI tools...\n") + shell = _detect_shell() + + # ------------------------------------------------------------------ + # 1. Gather status for each CLI in table order + # ------------------------------------------------------------------ + cli_info: List[Dict[str, object]] = [] + for provider, cli_name, display_name in _TABLE_ORDER: + path = _find_cli_binary(cli_name) + has_key = _has_api_key(provider) + key_display = _get_display_key_name(provider) + cli_info.append({ + "provider": provider, + "cli_name": cli_name, + "display_name": display_name, + "path": path, + "has_key": has_key, + "key_display": key_display, + }) + + # ------------------------------------------------------------------ + # 2. Print numbered selection table with aligned columns + # ------------------------------------------------------------------ + from rich.markup import escape as _escape + + # Compute column widths using plain strings (no markup) for measurement + max_name_len = max(len(str(c["display_name"])) for c in cli_info) + max_install_len = 0 + install_strs_plain: List[str] = [] + install_strs_display: List[str] = [] + for c in cli_info: + if c["path"]: + plain = f"\u2713 Found at {c['path']}" + display = f"[green]\u2713[/green] Found at {_escape(str(c['path']))}" + else: + plain = "\u2717 Not found" + display = "[red]\u2717[/red] Not found" + install_strs_plain.append(plain) + install_strs_display.append(display) + max_install_len = max(max_install_len, len(plain)) + + for idx, c in enumerate(cli_info): + num = idx + 1 + name_padded = str(c["display_name"]).ljust(max_name_len) + install_display = install_strs_display[idx] + install_padding = " " * (max_install_len - len(install_strs_plain[idx])) + if c["has_key"]: + key_str = f"[green]\u2713[/green] {c['key_display']} is set" + else: + key_str = f"[red]\u2717[/red] {c['key_display']} not set" + console.print(f" [blue]{num}[/blue]. {name_padded} {install_display}{install_padding} {key_str}") + + console.print() + + # ------------------------------------------------------------------ + # 3. Determine smart default + # ------------------------------------------------------------------ + default_idx = 0 # fallback: Claude (index 0 -> selection "1") + # Prefer installed + key + for i, c in enumerate(cli_info): + if c["path"] and c["has_key"]: + default_idx = i + break + else: + # Prefer installed only + for i, c in enumerate(cli_info): + if c["path"]: + default_idx = i + break + + # ------------------------------------------------------------------ + # 4. Prompt for selection + # ------------------------------------------------------------------ + try: + console.print(" Which CLI would you like to use for pdd setup? \[[blue]1[/blue]/[blue]2[/blue]/[blue]3[/blue]]: ", end="") + raw = _prompt_input("").strip() + except (EOFError, KeyboardInterrupt): + console.print("\n [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") + return CliBootstrapResult() + + if raw.lower() in ("q", "n"): + console.print(" [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") + return CliBootstrapResult() + + if raw in ("1", "2", "3"): + selected_idx = int(raw) - 1 + elif raw == "": + selected_idx = default_idx + console.print(f" [dim]Defaulting to {cli_info[selected_idx]['display_name']}[/dim]") + else: + # Invalid input — treat as default + selected_idx = default_idx + console.print(f" [dim]Invalid input. Defaulting to {cli_info[selected_idx]['display_name']}[/dim]") + + selected = cli_info[selected_idx] + sel_provider: str = str(selected["provider"]) + sel_cli_name: str = str(selected["cli_name"]) + sel_path: Optional[str] = selected["path"] if selected["path"] else None # type: ignore[assignment] + sel_has_key: bool = bool(selected["has_key"]) + + # ------------------------------------------------------------------ + # 5. Install step (if not installed) + # ------------------------------------------------------------------ + if not sel_path: + install_cmd = _INSTALL_COMMANDS[sel_provider] + console.print(f"\n Install command: [bold]{install_cmd}[/bold]") + try: + install_answer = _prompt_input(" Install now? [y/N]: ").strip().lower() + except (EOFError, KeyboardInterrupt): + console.print("\n [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") + return CliBootstrapResult() + + if install_answer in ("y", "yes"): + if not _npm_available(): + console.print(" [red]\u2717[/red] npm is not installed. Please install Node.js/npm first,") + console.print(f" then run: {install_cmd}") + return CliBootstrapResult() + + console.print(f" Installing {selected['display_name']}...") + if _run_install(install_cmd): + sel_path = _find_cli_binary(sel_cli_name) + if sel_path: + console.print(f" [green]\u2713[/green] Installed {sel_cli_name} at {sel_path}") + else: + console.print(" [yellow]Installation completed but CLI not found on PATH.[/yellow]") + return CliBootstrapResult() + else: + console.print(" [red]Installation failed. Try installing manually.[/red]") + return CliBootstrapResult() + else: + console.print(f" [dim]Skipped installation. Run `{install_cmd}` manually when ready.[/dim]") + return CliBootstrapResult() + + # ------------------------------------------------------------------ + # 6. API key step (if not set) + # ------------------------------------------------------------------ + if not sel_has_key: + sel_has_key = _prompt_api_key(sel_provider, shell) + + # ------------------------------------------------------------------ + return CliBootstrapResult( + cli_name=sel_cli_name, + provider=sel_provider, + cli_path=sel_path or "", + api_key_configured=sel_has_key, + ) + + +def detect_cli_tools() -> None: + """Legacy detection function.""" + console.print("Agentic CLI Tool Detection") + console.print("(Required for: pdd fix, pdd change, pdd bug)") + console.print() + + found_any = False + all_with_keys_installed = True + + # Use ordered providers + for provider in ["anthropic", "google", "openai"]: + cli_cmd = _CLI_COMMANDS[provider] + display_name = _CLI_DISPLAY_NAMES[provider] + path = _which(cli_cmd) + has_key = _has_api_key(provider) + key_env = _API_KEY_ENV_VARS[provider] + + if path: + found_any = True + console.print(f" [green]\u2713[/green] {display_name} — Found at {path}") + if has_key: + console.print(f" [green]\u2713[/green] {key_env} is set") + else: + console.print(f" [yellow]\u2717[/yellow] {key_env} not set — CLI won't be usable for API calls") + else: + console.print(f" [red]\u2717[/red] {display_name} — Not found") + if has_key: + all_with_keys_installed = False + console.print(f" [yellow]You have {key_env} set but {display_name} is not installed.[/yellow]") + console.print(f" Install: {_INSTALL_COMMANDS[provider]} (install the CLI to use it)") + if _npm_available(): + if _prompt_yes_no(f" Install now? [y/N] "): + if _run_install(_INSTALL_COMMANDS[provider]): + new_path = _which(cli_cmd) + if new_path: + console.print(f" {display_name} installed successfully.") + else: + console.print(" completed but not found on PATH") + else: + console.print(" failed (try installing manually)") + else: + console.print(" Skipped (you can install later).") + else: + console.print(" npm is not installed.") + else: + console.print(f" API key ({key_env}): not set") + console.print() + + if all_with_keys_installed and found_any: + console.print("All CLI tools with matching API keys are installed") + elif not found_any: + console.print("Quick start: No CLI tools found. Install one of the supported CLIs and set its API key.") + +if __name__ == "__main__": + detect_cli_tools() diff --git a/pdd/setup/litellm_registry.py b/pdd/litellm_registry.py similarity index 99% rename from pdd/setup/litellm_registry.py rename to pdd/litellm_registry.py index fcff99f86..fa0ea16f5 100644 --- a/pdd/setup/litellm_registry.py +++ b/pdd/litellm_registry.py @@ -1,5 +1,5 @@ """ -pdd/setup/litellm_registry.py +pdd/litellm_registry.py Wraps litellm's bundled model registry to provide provider search, model browsing, and API key env var lookup. Uses only local data — no network calls. diff --git a/pdd/setup/model_tester.py b/pdd/model_tester.py similarity index 100% rename from pdd/setup/model_tester.py rename to pdd/model_tester.py diff --git a/pdd/setup/pddrc_initializer.py b/pdd/pddrc_initializer.py similarity index 100% rename from pdd/setup/pddrc_initializer.py rename to pdd/pddrc_initializer.py diff --git a/pdd/prompts/agentic_setup_autoconfig_LLM.prompt b/pdd/prompts/agentic_setup_autoconfig_LLM.prompt new file mode 100644 index 000000000..38415c6e5 --- /dev/null +++ b/pdd/prompts/agentic_setup_autoconfig_LLM.prompt @@ -0,0 +1,229 @@ +% You are an expert system administrator configuring PDD (Prompt-Driven Development) for a developer's machine. Your task is to auto-discover all available LLM providers and configure PDD so the user can start using it immediately — with zero user interaction. + +% Context + +You are running Phase 2 of `pdd setup`. Phase 1 already confirmed that the {cli_name} CLI is installed (provider: {provider}). Now you need to discover all available LLM providers, configure models, and set up the project. This phase is fully autonomous — do not prompt the user for any input, though occasionally you can ask for the user to press Enter to continue to next steps. + +% Environment + +- Home directory: {home_dir} +- Shell: {shell_name} +- PDD config directory: {pdd_dir} +- LLM model CSV path: {llm_model_csv_path} +- Current working directory: {cwd} +- .pddrc status: {pddrc_path} + +% CSV Schema + + +The user-level CSV at ~/.pdd/llm_model.csv has columns: +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location + +Example rows: +- Cloud: Anthropic,claude-sonnet-4-5-20250929,3.0,15.0,1500,,ANTHROPIC_API_KEY,0,True,none, +- Local: Ollama,ollama_chat/llama3:70b,0,0,1000,http://localhost:11434,,0,True,none, +- LM Studio: lm_studio,lm_studio/my-model,0,0,1000,http://localhost:1234/v1,,0,True,none, + +Notes: +- Local models have empty api_key column +- Cloud models have the env var NAME (not value) in api_key column +- input/output costs are per 1M tokens + + +% Your Tasks + +Execute these tasks in order. If any single task fails, log the error and continue with remaining tasks. Never abort the entire setup over one failure. + +1. **Create PDD directory** + - Ensure {pdd_dir} exists: `mkdir -p {pdd_dir}` + +2. **Scan for API keys** + Search these locations for API key environment variables. Do NOT display, log, or store actual key values — only report existence and source. + + Sources to check (in priority order): + a. Current shell environment (check with `echo $VAR_NAME` or `env | grep VAR_NAME`) + b. {pdd_dir}/api-env.{shell_name} (parse export/set lines if file exists) + c. .env file in {cwd} (if exists) + d. {home_dir}/.env (if exists) + + Keys to look for (aligned with litellm's PROVIDER_API_KEY_MAP): + + Tier 1 — Major cloud providers: + - ANTHROPIC_API_KEY (Anthropic — Claude models) + - OPENAI_API_KEY (OpenAI — GPT models) + - GOOGLE_API_KEY or GEMINI_API_KEY (Google — Gemini API models) + - VERTEX_CREDENTIALS (Google — Vertex AI models; typically a service account JSON file path or ADC) + - MISTRAL_API_KEY (Mistral AI) + - DEEPSEEK_API_KEY (DeepSeek) + - XAI_API_KEY (xAI — Grok models) + + Tier 2 — Inference platforms & specialized providers: + - GROQ_API_KEY (Groq — fast inference) + - TOGETHERAI_API_KEY or TOGETHER_API_KEY or TOGETHER_AI_API_KEY (Together AI) + - FIREWORKS_API_KEY (Fireworks AI) + - OPENROUTER_API_KEY (OpenRouter — multi-provider gateway) + - COHERE_API_KEY (Cohere) + - PERPLEXITYAI_API_KEY (Perplexity) + - REPLICATE_API_KEY (Replicate) + - DEEPINFRA_API_KEY (DeepInfra) + - CEREBRAS_API_KEY (Cerebras — fast inference) + + Tier 3 — Enterprise & additional providers: + - AZURE_API_KEY (Azure OpenAI) + - AZURE_AI_API_KEY (Azure AI) + - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY (AWS Bedrock — both must be present) + - AI21_API_KEY (AI21 Labs) + - HUGGINGFACE_API_KEY or HF_TOKEN (Hugging Face) + - DATABRICKS_API_KEY (Databricks) + - CLOUDFLARE_API_KEY (Cloudflare Workers AI) + - NOVITA_API_KEY (Novita AI) + - SAMBANOVA_API_KEY (SambaNova) + - WATSONX_API_KEY (IBM watsonx) + + Record which keys exist and where they were found. + +3. **Select models for each discovered provider** + + **Main providers — Anthropic, OpenAI, Google:** + Use ALL matching rows from the reference CSV below. These are pre-vetted models — add them exactly as shown, preserving all column values (model ID, pricing, ELO, reasoning fields, etc.). + + + ../../data/llm_model.csv + + + Matching rules: + - Add Anthropic rows if ANTHROPIC_API_KEY was found + - Add OpenAI rows if OPENAI_API_KEY was found (skip rows with a non-empty base_url unless that URL is reachable) + - Add Google rows with `gemini/` prefix if GEMINI_API_KEY or GOOGLE_API_KEY was found + - Add Google rows with `vertex_ai/` prefix if VERTEX_CREDENTIALS was found + - Skip lm_studio or other local-model rows from this CSV (handle those in step 5) + + **Other providers — use litellm's registry:** + For each API key found outside the main 3 (e.g. Groq, Mistral, xAI, Together, Fireworks, DeepSeek, OpenRouter, Cohere, Perplexity, Cerebras, DeepInfra): + a. Try querying litellm: + ```python + python3 -c " + import litellm, json + prefix = 'groq/' # replace with the provider's litellm prefix + models = [(m, v) for m, v in litellm.model_cost.items() if m.startswith(prefix)] + models.sort(key=lambda x: -x[1].get('input_cost_per_token', 0)) + for m, v in models[:3]: + print(m, v.get('input_cost_per_token',0)*1e6, v.get('output_cost_per_token',0)*1e6) + " + ``` + Use the top 2-3 results (highest input cost ≈ most capable). Set elo to 0 if unknown. + b. If litellm is unavailable or returns nothing, use these fallbacks: + - Groq (GROQ_API_KEY): `groq/moonshotai/kimi-k2-instruct-0905` (input: 1.0, output: 3.0, elo: 1330) + - Mistral (MISTRAL_API_KEY): `mistral/mistral-large-latest` (input: 2.0, output: 6.0, elo: 1414) + - xAI (XAI_API_KEY): `xai/grok-2-latest` (input: 2.0, output: 10.0, elo: 1411) + - Together AI (TOGETHERAI_API_KEY): `together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo` (input: 0.88, output: 0.88, elo: 1300) + - Fireworks (FIREWORKS_API_KEY): `fireworks_ai/accounts/fireworks/models/glm-4p7` (input: 0.60, output: 2.20, elo: 1481) + - DeepSeek (DEEPSEEK_API_KEY): `deepseek/deepseek-chat` (input: 0.14, output: 0.28, elo: 1419) + - OpenRouter (OPENROUTER_API_KEY): `openrouter/anthropic/claude-sonnet-4-5` (input: 3.0, output: 15.0, elo: 1450) + - Cerebras (CEREBRAS_API_KEY): `cerebras/llama3.3-70b` (input: 0.6, output: 0.6, elo: 1300) + - Perplexity (PERPLEXITYAI_API_KEY): `perplexity/sonar-pro` (input: 3.0, output: 15.0, elo: 1280) + - Cohere (COHERE_API_KEY): `cohere/command-r-plus` (input: 2.5, output: 10.0, elo: 1250) + - Azure (AZURE_API_KEY): use azure/ prefix versions of any OpenAI models found + - AWS Bedrock (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY): `bedrock/anthropic.claude-sonnet-4-5-20250929-v1` + +4. **Populate llm_model.csv** + - If {llm_model_csv_path} already exists, read it first and avoid adding duplicate entries (match on provider+model pair) + - If it doesn't exist, create it with the CSV header line: + `provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location` + - Append new rows for each discovered provider/model combination + - Use atomic write: write to a temp file first, then move it to the final path + - Format each row matching the schema above (base_url empty for cloud models, reasoning_type "none", location empty) + +5. **Check for local LLMs** + - First, print a visible notification to the user: + ``` + ------------------------------------------------------- + Local LLM Check + ------------------------------------------------------- + If you'd like to use local models with PDD, make sure + your local LLM server is running now: + + Ollama: Run `ollama serve` in another terminal + LM Studio: Start the app → Developer tab → Start Server + + Press Enter to continue... + ------------------------------------------------------- + ``` + - Wait for user to press Enter (`read -p "" dummy` for bash/zsh, or `read dummy` for simpler shells) + - Check if Ollama is running: `curl -s http://localhost:11434/api/tags 2>/dev/null` + - If reachable, parse the JSON response to get model names + - Add each discovered model as: `Ollama,ollama_chat/{model_name},0,0,1000,http://localhost:11434,,0,True,none,` + - Check if LM Studio is running: `curl -s http://localhost:1234/v1/models 2>/dev/null` + - If reachable, parse the JSON response to get model names + - Add each as: `lm_studio,lm_studio/{model_name},0,0,1000,http://localhost:1234/v1,,0,True,none,` + - If neither is reachable, note in the summary: "No local LLMs found. To add local models later, start Ollama or LM Studio and run `pdd setup` again." + +6. **Initialize .pddrc if needed** + If no `.pddrc` file exists in {cwd}: + - Detect project type from files in {cwd}: + - Python: look for setup.py, pyproject.toml, or *.py files + - TypeScript: look for package.json with typescript dependency, or *.ts files + - Go: look for go.mod + - Default to Python if unclear + - Create `.pddrc` with these contents (adjust paths by language): + + For Python projects: + ```yaml + version: "1.0" + + contexts: + default: + defaults: + generate_output_path: "pdd/" + test_output_path: "tests/" + example_output_path: "context/" + default_language: "python" + target_coverage: 80.0 + strength: 1.0 + temperature: 0.0 + budget: 10.0 + max_attempts: 3 + ``` + + For TypeScript projects, use: generate_output_path: "src/", test_output_path: "__tests__/", example_output_path: "examples/", default_language: "typescript" + + For Go projects, use: generate_output_path: ".", test_output_path: ".", example_output_path: "examples/", default_language: "go" + +7. **Test one model** + - Pick the first cloud model from the CSV that has a configured API key + - Run a minimal test: `python3 -c "import litellm; r = litellm.completion(model='{model_name}', messages=[{{'role':'user','content':'Say OK'}}], timeout=30); print('OK:', r.choices[0].message.content)"` + - If litellm is not available, skip this step + - Report success/failure but do not block on failure + +8. **Print summary** + Print a clear, formatted summary: + + ``` + === PDD Setup Complete === + + API Keys Found: + {KEY_NAME} {source} + ... + + Models Configured: {N} total + {Provider}: {model1}, {model2} + ... + + Local LLMs: {none found | Provider: model1, model2, ...} + + .pddrc: {Created at ./.pddrc | Already exists | Skipped (not in a project directory)} + + Model Test: {model-name} -> {OK (0.3s) | FAILED: error | Skipped} + + You're all set! Run `pdd generate` or `pdd sync` to start using PDD. + ``` + +% Important Rules + +- NEVER display, log, or store actual API key values — only report whether they exist and where they were found +- Use atomic file writes for CSV modifications (write to temp file, then rename/move) +- Do not fail if any single step fails — log the error and continue with remaining steps +- If litellm is not installed in the Python environment, use the hardcoded model recommendations above instead +- Minimize user interaction — only ask the user to press Enter at natural pause points (e.g., before the local LLM scan). Never prompt for configuration decisions. +- If the CSV already has entries, preserve them and only add new discoveries (no duplicates) +- Shell-appropriate syntax for api-env files: bash/zsh use `export KEY=value`, fish uses `set -gx KEY value` diff --git a/pdd/prompts/api_key_scanner_python.prompt b/pdd/prompts/api_key_scanner_python.prompt index 7f73dde72..75169efe7 100644 --- a/pdd/prompts/api_key_scanner_python.prompt +++ b/pdd/prompts/api_key_scanner_python.prompt @@ -12,7 +12,7 @@ } -% You are an expert Python engineer. Your goal is to write the pdd/setup/api_key_scanner.py module. +% You are an expert Python engineer. Your goal is to write the pdd/api_key_scanner.py module. % Role & Scope Discovers API keys needed by the user's configured models and reports their existence. Reads the user's `~/.pdd/llm_model.csv` to find all unique API key environment variable names, then checks .env files, shell environment, and ~/.pdd/api-env.* files. Only checks **existence** — never makes API calls or stores key values. @@ -39,4 +39,4 @@ This file is created/managed by provider_manager when the user adds providers vi % Deliverables -- Module at `pdd/setup/api_key_scanner.py` exporting `scan_environment`, `get_provider_key_names`, and `KeyInfo`. +- Module at `pdd/api_key_scanner.py` exporting `scan_environment`, `get_provider_key_names`, and `KeyInfo`. diff --git a/pdd/prompts/cli_detector_python.prompt b/pdd/prompts/cli_detector_python.prompt index c5ff1f15c..7259caeef 100644 --- a/pdd/prompts/cli_detector_python.prompt +++ b/pdd/prompts/cli_detector_python.prompt @@ -1,11 +1,15 @@ -Detects installed agentic CLI tools and offers installation guidance for missing ones. +Detects and bootstraps agentic CLI tools for pdd setup, with minimal-friction interactive installation and API key configuration. { "type": "module", "module": { "functions": [ - {"name": "detect_cli_tools", "signature": "() -> None", "returns": "None"} + {"name": "detect_cli_tools", "signature": "() -> None", "returns": "None"}, + {"name": "detect_and_bootstrap_cli", "signature": "() -> CliBootstrapResult", "returns": "CliBootstrapResult"} + ], + "dataclasses": [ + {"name": "CliBootstrapResult", "fields": ["cli_name", "provider", "cli_path", "api_key_configured"]} ] } } @@ -13,20 +17,73 @@ agentic_common_python.prompt -% You are an expert Python engineer. Your goal is to write the pdd/setup/cli_detector.py module. +% You are an expert Python engineer. Your goal is to write the pdd/cli_detector.py module. % Role & Scope -Detects installed agentic CLI harnesses (Claude CLI, Codex CLI, Gemini CLI) required for `pdd fix`, `pdd change`, and `pdd bug`. Leverages `get_available_agents()` from `pdd.agentic_common` and cross-references with API keys to suggest installations. +Detects and bootstraps agentic CLI harnesses (Claude CLI, Codex CLI, Gemini CLI) required for `pdd fix`, `pdd change`, `pdd bug`, and now `pdd setup` Phase 2 (agentic auto-configuration). The primary function `detect_and_bootstrap_cli()` is designed for minimal user friction — auto-detect what's available, default to the best option, and the user just presses Enter to confirm. The legacy `detect_cli_tools()` function is preserved for backward compatibility. % Requirements -1. Function: `detect_cli_tools()` — check for CLI tools, display results, offer installation -2. For each CLI (claude, codex, gemini): show `✓ Found at /path` or `✗ Not found` -3. Cross-reference with API keys: if user has OPENAI_API_KEY but not codex CLI, highlight and suggest `npm install -g @openai/codex` -4. Offer `Install now? [y/N]` for missing CLIs that have a matching API key; run via subprocess if accepted -5. Show context: `(Required for: pdd fix, pdd change, pdd bug)` -6. Handle npm not being installed (suggest manual installation) + +1. Dataclass: `CliBootstrapResult` with fields: + - `cli_name: str` — e.g. "claude", "gemini", "codex" (empty string if none) + - `provider: str` — e.g. "anthropic", "google", "openai" (empty string if none) + - `cli_path: str` — absolute path to the CLI binary (empty string if none) + - `api_key_configured: bool` — True if the API key for this provider is set + +2. Function: `detect_and_bootstrap_cli() -> CliBootstrapResult` — Phase 1 entry point for `pdd setup`. Shows all three CLI options with their status and lets the user choose which one to use. Flow: + a. Print "Checking CLI tools..." + b. Check all three CLIs (claude, gemini, codex) using `shutil.which()` and common path fallbacks (nvm paths, ~/.local/bin, /usr/local/bin). Use `_find_cli_binary()` from `pdd.agentic_common` if available. + c. Provider-to-CLI mapping: anthropic→claude, google→gemini, openai→codex. Provider-to-key mapping: anthropic→ANTHROPIC_API_KEY, google→GOOGLE_API_KEY or GEMINI_API_KEY, openai→OPENAI_API_KEY. + d. Print a numbered selection table (one CLI per line), using consistent column alignment: + - Index: 1, 2, 3 + - Name: "Claude CLI", "Codex CLI", "Gemini CLI" + - Install status: `✓ Found at {path}` or `✗ Not found` + - Key status: `✓ {KEY_NAME} is set` or `✗ {KEY_NAME} not set` + Example output: + ``` + Checking CLI tools... + + 1. Claude CLI ✓ Found at /usr/local/bin/claude ✓ ANTHROPIC_API_KEY is set + 2. Codex CLI ✗ Not found ✗ OPENAI_API_KEY not set + 3. Gemini CLI ✗ Not found ✓ GEMINI_API_KEY is set + + Which CLI would you like to use for pdd setup? [1/2/3]: + ``` + e. Read user input: + - Accept "1", "2", or "3" to select a CLI. + - If user presses Enter without typing, default to the highest-priority option that is both installed and has an API key; if none qualify, prefer installed-only; if still none, default to "1" (Claude). Print the default selection so the user sees it. + - If user types "q" or "n", return `CliBootstrapResult(cli_name="", provider="", cli_path="", api_key_configured=False)` with message "Skipped CLI setup. You can run `pdd setup` again later." + f. **Install step (if selected CLI is not installed):** Print the install command for that CLI and prompt `Install now? [y/N]: ` (default No on Enter). If accepted and npm is available, run installation via subprocess and wait for it to complete. If npm is not available, print manual installation instructions and return empty result. After successful installation, re-check the path. + - Claude CLI install: `npm install -g @anthropic-ai/claude-code` + - Codex CLI install: `npm install -g @openai/codex` + - Gemini CLI install: `npm install -g @google/gemini-cli` + g. **API key step (if selected CLI's key is not set):** Prompt `Enter your {Provider} API key (or press Enter to skip): `. If user provides a key: save it to `~/.pdd/api-env.{shell}` using shell-appropriate syntax (bash/zsh: `export KEY=value`, fish: `set -gx KEY value`), set it in `os.environ` for immediate availability, and add the source line to the shell RC file if not already present. If user presses Enter without a key, note that some CLIs (e.g. Claude CLI with a subscription) may still work, and return with `api_key_configured=False`. + h. **Ready:** Once both install and key steps pass (or are already satisfied), return the populated `CliBootstrapResult` immediately — do not print any "Press Enter" message here. The caller (setup_tool) owns that transition prompt. + i. Handle KeyboardInterrupt at every input prompt for a clean exit. + +3. Function: `detect_cli_tools()` — legacy function, kept for backward compatibility. + - For each CLI (claude, codex, gemini): show `✓ Found at /path` or `✗ Not found` + - Cross-reference with API keys: if user has OPENAI_API_KEY but not codex CLI, highlight and suggest `npm install -g @openai/codex` + - Offer `Install now? [y/N]` for missing CLIs that have a matching API key; run via subprocess if accepted + - Show context: `(Required for: pdd fix, pdd change, pdd bug)` + - Handle npm not being installed (suggest manual installation) + +4. Shell detection: Detect shell from `SHELL` env var (default to "bash"). Use `os.path.basename()` to extract shell name. Map to RC file path: zsh→~/.zshrc, bash→~/.bashrc, fish→~/.config/fish/config.fish. + +5. API key file management: + - Create `~/.pdd/` directory if it doesn't exist + - Write key to `~/.pdd/api-env.{shell}` (append, don't overwrite existing entries) + - Use shell-appropriate export syntax + - Add `source ~/.pdd/api-env.{shell}` (or fish equivalent: `test -f ... ; and source ...`) to shell RC file if not already present % Dependencies + +% Here are examples of how to use internal modules: + + % Here is an example of the cli_detector module showing expected usage: + % Here is an example of the agentic_common module showing CLI detection and agent availability: + + context/agentic_common_example.py @@ -38,4 +95,4 @@ from pdd.agentic_common import get_available_agents, CLI_COMMANDS % Deliverables -- Module at `pdd/setup/cli_detector.py` exporting `detect_cli_tools`. +- Module at `pdd/cli_detector.py` exporting `detect_cli_tools`, `detect_and_bootstrap_cli`, and `CliBootstrapResult`. \ No newline at end of file diff --git a/pdd/prompts/litellm_registry_python.prompt b/pdd/prompts/litellm_registry_python.prompt index 8c7818507..b0db6ea7e 100644 --- a/pdd/prompts/litellm_registry_python.prompt +++ b/pdd/prompts/litellm_registry_python.prompt @@ -20,7 +20,7 @@ } -% You are an expert Python engineer. Your goal is to write the pdd/setup/litellm_registry.py module. +% You are an expert Python engineer. Your goal is to write the pdd/litellm_registry.py module. % Role & Scope Thin wrapper around litellm's bundled data (`litellm.model_cost`, `litellm.models_by_provider`) for provider discovery and model browsing. Uses only locally bundled data — never makes network calls. Provides the data layer for the "Search providers" flow in `pdd setup`. @@ -53,5 +53,5 @@ Note: model_cost has NO api_key_env_var field. Provider-to-key mapping must be h % Deliverables -- Module at `pdd/setup/litellm_registry.py` exporting `ProviderInfo`, `ModelInfo`, `is_litellm_available`, `get_api_key_env_var`, `get_top_providers`, `get_all_providers`, `search_providers`, `get_models_for_provider`. +- Module at `pdd/litellm_registry.py` exporting `ProviderInfo`, `ModelInfo`, `is_litellm_available`, `get_api_key_env_var`, `get_top_providers`, `get_all_providers`, `search_providers`, `get_models_for_provider`. - Also exports constants `PROVIDER_API_KEY_MAP` and `PROVIDER_DISPLAY_NAMES` for use by other modules. diff --git a/pdd/prompts/local_llm_configurator_python.prompt b/pdd/prompts/local_llm_configurator_python.prompt deleted file mode 100644 index 880a340cd..000000000 --- a/pdd/prompts/local_llm_configurator_python.prompt +++ /dev/null @@ -1,40 +0,0 @@ -Configures local LLMs (Ollama, LM Studio, custom) with auto-detection and user CSV integration. - - -{ - "type": "module", - "module": { - "functions": [ - {"name": "configure_local_llm", "signature": "() -> bool", "returns": "bool"} - ] - } -} - - -% You are an expert Python engineer. Your goal is to write the pdd/setup/local_llm_configurator.py module. - -% Role & Scope -Guides users through configuring local LLM tools (Ollama, LM Studio, custom endpoints). Local models need a base_url and model name, not API keys. Auto-detects installed models where possible. - -% Requirements -1. Function: `configure_local_llm() -> bool` — interactive setup, returns True if any models were added -2. Provider menu: 1. LM Studio (default localhost:1234), 2. Ollama (default localhost:11434), 3. Other (custom base URL) -3. Ollama auto-detection: query http://localhost:11434/api/tags, show discovered models, let user select which to add (comma-separated). Fall back to manual entry if unreachable. -4. LM Studio: default base URL http://localhost:1234/v1, prompt for model name -5. Append rows to user's `~/.pdd/llm_model.csv` with LiteLLM prefix conventions (`lm_studio/`, `ollama_chat/`), empty api_key, cost=0.0 -6. Validate base URL format (http/https required) -7. Atomic CSV writes; create user CSV with header if it doesn't exist -8. Handle empty input as cancel - -% Dependencies - -The user's CSV at ~/.pdd/llm_model.csv has columns: -provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Example local rows: -- lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0,0,1000,http://localhost:1234/v1,,0,True,none, -- Ollama,ollama_chat/llama3:70b,0,0,1000,http://localhost:11434,,0,True,none, - - -% Deliverables -- Module at `pdd/setup/local_llm_configurator.py` exporting `configure_local_llm`. diff --git a/pdd/prompts/model_tester_python.prompt b/pdd/prompts/model_tester_python.prompt index c946eaf49..169e65f90 100644 --- a/pdd/prompts/model_tester_python.prompt +++ b/pdd/prompts/model_tester_python.prompt @@ -11,7 +11,7 @@ } -% You are an expert Python engineer. Your goal is to write the pdd/setup/model_tester.py module. +% You are an expert Python engineer. Your goal is to write the pdd/model_tester.py module. % Role & Scope Tests a single configured model by making one `litellm.completion()` call with a minimal prompt. Only runs when the user explicitly chooses it — no surprise API costs. Uses `litellm.completion()` directly (not `llm_invoke`) because `llm_invoke` doesn't allow choosing a specific model or key. @@ -34,4 +34,4 @@ provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_toke % Deliverables -- Module at `pdd/setup/model_tester.py` exporting `test_model_interactive`. +- Module at `pdd/model_tester.py` exporting `test_model_interactive`. diff --git a/pdd/prompts/pddrc_initializer_python.prompt b/pdd/prompts/pddrc_initializer_python.prompt index 0c7fbbee6..8fb7a8911 100644 --- a/pdd/prompts/pddrc_initializer_python.prompt +++ b/pdd/prompts/pddrc_initializer_python.prompt @@ -11,7 +11,7 @@ } -% You are an expert Python engineer. Your goal is to write the pdd/setup/pddrc_initializer.py module. +% You are an expert Python engineer. Your goal is to write the pdd/pddrc_initializer.py module. % Role & Scope Offers to create a basic `.pddrc` configuration file in the current project directory if one doesn't exist. Sets sensible defaults based on detected project type. @@ -44,4 +44,4 @@ contexts: % Deliverables -- Module at `pdd/setup/pddrc_initializer.py` exporting `offer_pddrc_init`. +- Module at `pdd/pddrc_initializer.py` exporting `offer_pddrc_init`. diff --git a/pdd/prompts/provider_manager_python.prompt b/pdd/prompts/provider_manager_python.prompt index fd028d414..495460ab9 100644 --- a/pdd/prompts/provider_manager_python.prompt +++ b/pdd/prompts/provider_manager_python.prompt @@ -16,7 +16,7 @@ litellm_registry_python.prompt -% You are an expert Python engineer. Your goal is to write the pdd/setup/provider_manager.py module. +% You are an expert Python engineer. Your goal is to write the pdd/provider_manager.py module. % Role & Scope Handles adding and removing LLM providers and models in PDD setup. The primary flow uses litellm's bundled model registry to let users search/browse providers, pick specific models, and enter API keys. Also supports adding custom LiteLLM-compatible providers and two modes of model removal. @@ -65,4 +65,4 @@ CSV mapping from litellm registry: % Deliverables -- Module at `pdd/setup/provider_manager.py` exporting `add_provider_from_registry`, `add_custom_provider`, `remove_models_by_provider`, `remove_individual_models`. +- Module at `pdd/provider_manager.py` exporting `add_provider_from_registry`, `add_custom_provider`, `remove_models_by_provider`, `remove_individual_models`. diff --git a/pdd/prompts/setup_tool_python.prompt b/pdd/prompts/setup_tool_python.prompt index fad7774f3..b127d29f0 100644 --- a/pdd/prompts/setup_tool_python.prompt +++ b/pdd/prompts/setup_tool_python.prompt @@ -1,4 +1,4 @@ -Orchestrates pdd setup: environment scanning, provider management, model testing, CLI detection, and configuration. +Orchestrates pdd setup in two phases: (1) interactive CLI bootstrapping with minimal friction, (2) deterministic auto-configuration using existing Python modules (no LLM calls). { @@ -11,69 +11,161 @@ } -api_key_scanner_python.prompt -provider_manager_python.prompt -litellm_registry_python.prompt -local_llm_configurator_python.prompt -model_tester_python.prompt cli_detector_python.prompt -pddrc_initializer_python.prompt -% You are an expert Python engineer. Your goal is to write the pdd/setup/setup_tool.py module. +% You are an expert Python engineer. Your goal is to write the pdd/setup_tool.py module. % Role & Scope -Main orchestrator for `pdd setup`. Auto-scans the environment for API keys based on the user's configured models (existence only — no API calls), then presents an interactive menu. After any action, the menu re-displays with an updated scan. When no models are configured yet, displays a helpful "get started" message. +Main orchestrator for `pdd setup`. Implements a two-phase flow designed for minimal user friction — users should be able to complete setup by pressing Enter just a few times. Phase 1 bootstraps an agentic CLI (Claude/Gemini/Codex) interactively. Phase 2 runs deterministic Python code that auto-discovers API keys, configures models from a reference CSV, checks for local LLMs, initializes .pddrc, tests a model, and prints a summary — all with clean, deterministic output and no LLM calls. % Requirements -1. Function: `run_setup()` — main entry point -2. On entry, call `api_key_scanner.scan_environment()` and display each key with `✓ Found (source)` or `— Not found`. After the key list, display a helpful note: `💡 To edit API keys: update ~/.pdd/api-env.{shell} or .env file` (where {shell} is auto-detected from the SHELL environment variable, defaulting to "bash"). Then display the summary line: `Models configured: N (from M API keys + K local)`. If scan_results is empty, display: "No models configured yet. Use 'Add a provider' to get started." -3. Present a 6-option menu after the scan: - 1. Add a provider (sub-menu: a. Search providers, b. Add a local LLM, c. Add a custom provider) - 2. Remove models (sub-menu: a. By provider, b. Individual models) - 3. Test a model - 4. Detect CLI tools - 5. Initialize .pddrc - 6. Done -4. Delegate to: `provider_manager.add_provider_from_registry`, `local_llm_configurator.configure_local_llm`, `provider_manager.add_custom_provider`, `provider_manager.remove_models_by_provider`, `provider_manager.remove_individual_models`, `model_tester.test_model_interactive`, `cli_detector.detect_cli_tools`, `pddrc_initializer.offer_pddrc_init` -5. After options 1–5, re-scan and re-display the menu -6. Option 6 exits the loop -7. Handle KeyboardInterrupt for clean exit at any point + +1. Function: `run_setup()` — main entry point. Two-phase flow: + + **Banner:** + Print a simple banner: + ``` + ╭──────────────────────────────╮ + │ pdd setup │ + ╰──────────────────────────────╯ + ``` + + **Phase 1 — CLI Bootstrap (interactive, 0–2 user inputs):** + a. Call `cli_detector.detect_and_bootstrap_cli()` which returns a `CliBootstrapResult`. + b. If result has `cli_name == ""` (user declined everything): + - Print: "Agentic features require at least one CLI tool. Run `pdd setup` again when ready." + - Return (exit gracefully). + c. If result has a CLI but `api_key_configured == False`: + - Print: "Note: No API key configured. The agent may have limited capability." + - Still proceed to Phase 2 (some CLIs like Claude support subscription auth). + + **Phase 2 — Deterministic Auto-Configuration (4 steps):** + a. Print: "Ready to auto-configure PDD. Press Enter to continue..." and wait for Enter. + b. Run 4 sequential steps via `_run_auto_phase()`, each printing its output immediately: + 1. `_step1_scan_keys()` — scan for API keys + 2. `_step2_configure_models()` — match keys to reference models, write CSV + 3. `_step3_local_llms_and_pddrc()` — check Ollama/LM Studio, ensure .pddrc + 4. `_step4_test_and_summary()` — test one model, print final summary + c. Between each step (except after the last), prompt "Press Enter to continue to the next step..." + d. If any step raises an exception, catch it, print error, and fall back to manual menu. + e. After all steps, print: "Setup complete. Happy prompting!" + +2. Step 1 — Scan for API Keys (`_step1_scan_keys`): + - Ensure ~/.pdd directory exists (mkdir -p equivalent). + - Get all known env var names from `litellm_registry.PROVIDER_API_KEY_MAP.values()`. + - Check each against: os.environ, ~/.pdd/api-env.{shell} (via `api_key_scanner._parse_api_env_file`), and .env files (via python-dotenv). + - Print each found key with aligned formatting: ` ✓ KEY_NAME source_label` + - If none found: call `_prompt_for_api_key()` to interactively add at least one key. + - Return: `List[Tuple[str, str]]` of (key_name, source_label). + + **`_prompt_for_api_key()` — Interactive key addition (called when no keys found):** + - Show a numbered list of popular providers (Anthropic, Google Gemini, OpenAI, DeepSeek) plus "Other provider" and "Skip". + - User selects a provider, then pastes their key (masked via `getpass.getpass()`). + - Key is saved to `~/.pdd/api-env.{shell}` via `provider_manager._save_key_to_api_env()` which also sets `os.environ` for the current session. + - Loop with "Add another key? [y/N]" until user declines or selects Skip. + - "Other provider" shows the full `PROVIDER_API_KEY_MAP` sorted alphabetically. + - Returns: `List[Tuple[str, str]]` of newly added keys. + +3. Step 2 — Configure Models (`_step2_configure_models`): + - Read reference CSV from `pdd/data/llm_model.csv` using `provider_manager._read_csv()`. + - Filter to rows whose `api_key` column matches a found key name. Skip local models (provider=lm_studio/ollama, localhost base_url). + - Read existing user CSV at `~/.pdd/llm_model.csv`, deduplicate by `model` column. + - Write merged result atomically via `provider_manager._write_csv_atomic()`. + - Print counts: ` ✓ N new model(s) added` or ` ✓ All matching models already present` + - Print per-provider breakdown: ` Provider: N models` + - Return: `Dict[str, int]` of {provider: count}. + +4. Step 3 — Local LLMs + .pddrc (`_step3_local_llms_and_pddrc`): + - Check Ollama at `http://localhost:11434/api/tags` via `urllib.request` (3s timeout). + If reachable: print ` ✓ Ollama running — found model1, model2` and append to user CSV. + If not: print ` ✗ Ollama not running (skip)`. + - Check LM Studio at `http://localhost:1234/v1/models` similarly. + - Check if `.pddrc` exists in cwd. If yes: print exists. If no: auto-detect language via `pddrc_initializer._detect_language()`, create via `_build_pddrc_content()`. + - Return: `Dict[str, List[str]]` of {provider: [model_names]}. + +5. Step 4 — Test + Summary (`_step4_test_and_summary`): + - Pick first cloud model (has non-empty api_key) from user CSV. + - If litellm importable: test via `model_tester._run_test(row)`. Print result. + - If not: print "Skipped (litellm not installed)". + - Print deterministic summary box with all data from steps 1-3. + +6. Handle `KeyboardInterrupt` for clean exit at any point: print "Setup interrupted — exiting." and return. + +7. **Fallback (if auto phase fails):** + Run a simplified manual menu loop with options: + 1. Add a provider (delegates to `provider_manager.add_provider_from_registry()`) + 2. Test a model (delegates to `model_tester.test_model_interactive()`) + 3. Initialize .pddrc (delegates to `pddrc_initializer.offer_pddrc_init()`) + 4. Done + +8. Import strategy (all deferred/lazy): + - From `pdd.cli_detector`: `detect_and_bootstrap_cli`, `CliBootstrapResult` + - From `pdd.litellm_registry`: `PROVIDER_API_KEY_MAP`, `PROVIDER_DISPLAY_NAMES` + - From `pdd.api_key_scanner`: `_parse_api_env_file`, `_detect_shell` + - From `pdd.provider_manager`: `_read_csv`, `_write_csv_atomic`, `_get_user_csv_path`, `_save_key_to_api_env`, `_get_api_env_path` + - From `pdd.pddrc_initializer`: `_detect_language`, `_build_pddrc_content` + - From `pdd.model_tester`: `_run_test` + - Fallback only: `provider_manager.add_provider_from_registry`, `model_tester.test_model_interactive`, `pddrc_initializer.offer_pddrc_init` + - Standard library: `getpass`, `json`, `os`, `urllib.request`, `urllib.error`, `pathlib.Path` % Dependencies - - context/api_key_scanner_example.py - - - context/provider_manager_example.py - + + % Example of the CLI detector used in Phase 1 to bootstrap an agentic CLI: + context/cli_detector_example.py - - context/litellm_registry_example.py - + % Example of provider_manager used for CSV I/O and the fallback manual menu: + context/provider_manager_example.py - - context/local_llm_configurator_example.py - + % Example of model_tester used for testing models: + context/model_tester_example.py - - context/model_tester_example.py - + % Example of pddrc_initializer used for .pddrc creation: + context/pddrc_initializer_example.py + context/cli_detector_example.py - - context/pddrc_initializer_example.py - - The user-level CSV at ~/.pdd/llm_model.csv has columns: provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location -This file is created when the user first adds a provider via setup. +This file is created/managed by the auto-configuration step or manually by the user. +The reference CSV at pdd/data/llm_model.csv contains known models for all major providers. + +from pdd.provider_manager import _read_csv, _write_csv_atomic, _get_user_csv_path +# _read_csv(path: Path) -> List[Dict[str, str]] — reads CSV to list of row dicts +# _write_csv_atomic(path: Path, rows: List[Dict[str, str]]) — atomic write via temp file + rename +# _get_user_csv_path() -> Path — returns ~/.pdd/llm_model.csv + + + +from pdd.litellm_registry import PROVIDER_API_KEY_MAP +# PROVIDER_API_KEY_MAP: Dict[str, str] — maps provider ID to env var name +# e.g. {"anthropic": "ANTHROPIC_API_KEY", "openai": "OPENAI_API_KEY", ...} + + + +from pdd.api_key_scanner import _parse_api_env_file, _detect_shell +# _parse_api_env_file(file_path: Path) -> Dict[str, str] — parse export lines from api-env file +# _detect_shell() -> Optional[str] — detect shell name from SHELL env var + + + +from pdd.pddrc_initializer import _detect_language, _build_pddrc_content +# _detect_language(cwd: Path) -> Optional[str] — detect project language from marker files +# _build_pddrc_content(language: str) -> str — build YAML content for .pddrc + + + +from pdd.model_tester import _run_test +# _run_test(row: Dict[str, Any]) -> Dict[str, Any] +# Returns: {"success": bool, "duration_s": float, "cost": float, "error": str|None, "tokens": dict|None} + + % Deliverables -- Module at `pdd/setup/setup_tool.py` exporting `run_setup`. -- IMPORTANT: Must include `if __name__ == "__main__":` entry point that calls `run_setup()` to enable execution via `python -m pdd.setup.setup_tool`. +- Module at `pdd/setup_tool.py` exporting `run_setup`. +- IMPORTANT: Must include `if __name__ == "__main__":` entry point that calls `run_setup()` to enable execution via `python -m pdd.setup_tool`. diff --git a/pdd/setup/provider_manager.py b/pdd/provider_manager.py similarity index 99% rename from pdd/setup/provider_manager.py rename to pdd/provider_manager.py index 92844b109..5fcb1be26 100644 --- a/pdd/setup/provider_manager.py +++ b/pdd/provider_manager.py @@ -15,7 +15,7 @@ from rich.table import Table from rich.prompt import Prompt, Confirm -from pdd.setup.litellm_registry import ( +from pdd.litellm_registry import ( is_litellm_available, get_top_providers, search_providers, @@ -315,7 +315,7 @@ def _is_key_set(key_name: str) -> Optional[str]: env_path = _get_api_env_path() if env_path.exists(): - from pdd.setup.api_key_scanner import _parse_api_env_file + from pdd.api_key_scanner import _parse_api_env_file api_env_vals = _parse_api_env_file(env_path) if key_name in api_env_vals: return f"~/.pdd/{env_path.name}" diff --git a/pdd/setup/__init__.py b/pdd/setup/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pdd/setup/cli_detector.py b/pdd/setup/cli_detector.py deleted file mode 100644 index 896fc0a50..000000000 --- a/pdd/setup/cli_detector.py +++ /dev/null @@ -1,191 +0,0 @@ -from __future__ import annotations - -import os -import shutil -import subprocess -import sys -from pathlib import Path - - -# Maps provider name -> CLI command name -_CLI_COMMANDS: dict[str, str] = { - "anthropic": "claude", - "google": "gemini", - "openai": "codex", -} - -# Maps provider name -> environment variable for API key -_API_KEY_ENV_VARS: dict[str, str] = { - "anthropic": "ANTHROPIC_API_KEY", - "google": "GOOGLE_API_KEY", - "openai": "OPENAI_API_KEY", -} - -# Maps provider name -> npm install command for the CLI -_INSTALL_COMMANDS: dict[str, str] = { - "anthropic": "npm install -g @anthropic-ai/claude-code", - "google": "npm install -g @anthropic-ai/gemini-cli", - "openai": "npm install -g @openai/codex", -} - -# Maps provider name -> human-readable CLI name -_CLI_DISPLAY_NAMES: dict[str, str] = { - "anthropic": "Claude CLI", - "google": "Gemini CLI", - "openai": "Codex CLI", -} - - -def _which(cmd: str) -> str | None: - """Return the full path to a command if found on PATH, else None.""" - return shutil.which(cmd) - - -def _has_api_key(provider: str) -> bool: - """Check whether the API key environment variable is set for a provider.""" - env_var = _API_KEY_ENV_VARS.get(provider, "") - return bool(os.environ.get(env_var, "").strip()) - - -def _npm_available() -> bool: - """Check whether npm is available on PATH.""" - return _which("npm") is not None - - -def _prompt_yes_no(prompt: str) -> bool: - """Prompt the user with a yes/no question. Default is No.""" - try: - answer = input(prompt).strip().lower() - except (EOFError, KeyboardInterrupt): - print() - return False - return answer in ("y", "yes") - - -def _run_install(install_cmd: str) -> bool: - """Run an installation command via subprocess. Returns True on success.""" - print(f" Running: {install_cmd}") - try: - result = subprocess.run( - install_cmd, - shell=True, - check=False, - capture_output=False, - ) - return result.returncode == 0 - except Exception as exc: - print(f" Installation failed: {exc}") - return False - - -def detect_cli_tools() -> None: - """ - Detect installed agentic CLI harnesses (Claude CLI, Codex CLI, Gemini CLI) - required for ``pdd fix``, ``pdd change``, and ``pdd bug``. - - For each CLI tool: - - Shows ✓ Found at /path or ✗ Not found - - Cross-references with API keys to highlight actionable installations - - Offers interactive installation for missing CLIs that have a matching API key - - Handles the case where npm is not installed by suggesting manual installation. - """ - # Try to import get_available_agents for cross-reference, but don't fail if - # the import is unavailable (we can still do basic detection). - available_agents: list[str] = [] - try: - from pdd.agentic_common import get_available_agents as _get_available_agents - available_agents = list(_get_available_agents()) - except Exception: - pass - - print() - print("Agentic CLI Tool Detection") - print("=" * 50) - print("(Required for: pdd fix, pdd change, pdd bug)") - print() - - missing_with_key: list[str] = [] - found_any = False - - for provider, cli_cmd in _CLI_COMMANDS.items(): - display_name = _CLI_DISPLAY_NAMES[provider] - path = _which(cli_cmd) - has_key = _has_api_key(provider) - key_env = _API_KEY_ENV_VARS[provider] - - if path: - print(f" ✓ {display_name} ({cli_cmd}): Found at {path}") - found_any = True - if has_key: - print(f" API key ({key_env}): set") - else: - print(f" API key ({key_env}): not set — CLI found but won't be usable without it") - else: - print(f" ✗ {display_name} ({cli_cmd}): Not found") - if has_key: - print(f" API key ({key_env}): set — install the CLI to use this provider") - missing_with_key.append(provider) - else: - print(f" API key ({key_env}): not set") - - print() - - if not missing_with_key: - if found_any: - print("All CLI tools with matching API keys are installed.") - else: - print("No CLI tools found. Install at least one CLI and set its API key") - print("to use agentic features (pdd fix, pdd change, pdd bug).") - print() - print("Quick start:") - for provider, install_cmd in _INSTALL_COMMANDS.items(): - display_name = _CLI_DISPLAY_NAMES[provider] - key_env = _API_KEY_ENV_VARS[provider] - print(f" {display_name}: {install_cmd}") - print(f" Then set: export {key_env}=") - print() - return - - # Offer installation for missing CLIs that have a matching API key - print("The following CLI tools are missing but have API keys configured:") - print() - - npm_available = _npm_available() - - for provider in missing_with_key: - display_name = _CLI_DISPLAY_NAMES[provider] - install_cmd = _INSTALL_COMMANDS[provider] - - print(f" {display_name}:") - print(f" Install command: {install_cmd}") - - if not npm_available: - print(" ☀ npm is not installed. Please install Node.js/npm first:") - print(" macOS: brew install node") - print(" Ubuntu: sudo apt-get update && sudo apt-get install -y nodejs npm") - print(" Then run the install command above manually.") - print() - continue - - if _prompt_yes_no(f" Install now? [y/N] "): - success = _run_install(install_cmd) - if success: - new_path = _which(_CLI_COMMANDS[provider]) - if new_path: - print(f" ✓ {display_name} installed successfully at {new_path}") - else: - print(f" ✓ Installation command completed. You may need to restart your shell.") - else: - print(f" ✗ Installation failed. Try running manually:") - print(f" {install_cmd}") - else: - print(" Skipped. To install later, run:") - print(f" {install_cmd}") - - print() - - print() - -if __name__ == "__main__": - detect_cli_tools() \ No newline at end of file diff --git a/pdd/setup/local_llm_configurator.py b/pdd/setup/local_llm_configurator.py deleted file mode 100644 index f77a753bb..000000000 --- a/pdd/setup/local_llm_configurator.py +++ /dev/null @@ -1,377 +0,0 @@ -"""Local LLM configurator for PDD. - -Guides users through configuring local LLM tools (Ollama, LM Studio, custom endpoints). -Local models need a base_url and model name, not API keys. -""" - -from __future__ import annotations - -import csv -import io -import logging -import os -import shutil -import tempfile -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse - -from rich.console import Console -from rich.table import Table - -logger = logging.getLogger("pdd.setup.local_llm_configurator") -console = Console() - -# CSV header for ~/.pdd/llm_model.csv -CSV_COLUMNS: List[str] = [ - "provider", - "model", - "input", - "output", - "coding_arena_elo", - "base_url", - "api_key", - "max_reasoning_tokens", - "structured_output", - "reasoning_type", - "location", -] - -DEFAULT_LM_STUDIO_BASE_URL = "http://localhost:1234/v1" -DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434" -DEFAULT_OLLAMA_API_URL = "http://localhost:11434/api/tags" - - -def _get_user_csv_path() -> Path: - """Return the path to the user's ~/.pdd/llm_model.csv.""" - return Path.home() / ".pdd" / "llm_model.csv" - - -def _validate_base_url(url: str) -> bool: - """Validate that a base URL has http or https scheme and a netloc.""" - try: - parsed = urlparse(url.strip()) - return parsed.scheme in ("http", "https") and bool(parsed.netloc) - except Exception: - return False - - -def _build_model_row( - provider: str, - model: str, - base_url: str, - coding_arena_elo: int = 1000, - structured_output: bool = True, - reasoning_type: str = "none", -) -> Dict[str, Any]: - """Build a CSV row dict for a local model.""" - return { - "provider": provider, - "model": model, - "input": 0, - "output": 0, - "coding_arena_elo": coding_arena_elo, - "base_url": base_url, - "api_key": "", - "max_reasoning_tokens": 0, - "structured_output": structured_output, - "reasoning_type": reasoning_type, - "location": "", - } - - -def _read_existing_csv(csv_path: Path) -> List[Dict[str, str]]: - """Read existing rows from the user CSV, returning list of dicts.""" - rows: List[Dict[str, str]] = [] - if not csv_path.exists(): - return rows - try: - with open(csv_path, "r", newline="", encoding="utf-8") as f: - reader = csv.DictReader(f) - for row in reader: - rows.append(row) - except Exception as e: - logger.warning(f"Failed to read existing CSV at {csv_path}: {e}") - return rows - - -def _write_csv_atomic(csv_path: Path, rows: List[Dict[str, Any]]) -> None: - """Atomically write rows to the user CSV. - - Writes to a temporary file first, then moves it into place to avoid - partial writes corrupting the file. - """ - csv_path.parent.mkdir(parents=True, exist_ok=True) - - # Write to a temp file in the same directory for atomic rename - fd, tmp_path_str = tempfile.mkstemp( - dir=str(csv_path.parent), suffix=".csv.tmp" - ) - tmp_path = Path(tmp_path_str) - try: - with os.fdopen(fd, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=CSV_COLUMNS, extrasaction="ignore") - writer.writeheader() - for row in rows: - writer.writerow(row) - # Atomic move - shutil.move(str(tmp_path), str(csv_path)) - except Exception: - # Clean up temp file on failure - try: - tmp_path.unlink(missing_ok=True) - except Exception: - pass - raise - - -def _append_rows_to_csv(csv_path: Path, new_rows: List[Dict[str, Any]]) -> None: - """Append new model rows to the user CSV, creating it if needed.""" - existing = _read_existing_csv(csv_path) - # Convert existing rows to have consistent types - combined = list(existing) + new_rows - _write_csv_atomic(csv_path, combined) - - -def _discover_ollama_models(base_url: str) -> Optional[List[str]]: - """Query Ollama API for available models. - - Returns a list of model names, or None if the API is unreachable. - """ - import urllib.request - import json - - api_url = base_url.rstrip("/") + "/api/tags" - try: - req = urllib.request.Request(api_url, method="GET") - with urllib.request.urlopen(req, timeout=5) as resp: - data = json.loads(resp.read().decode("utf-8")) - models = data.get("models", []) - return [m.get("name", "") for m in models if m.get("name")] - except Exception as e: - logger.debug(f"Failed to query Ollama at {api_url}: {e}") - return None - - -def _prompt_input(prompt_text: str, default: str = "") -> str: - """Prompt user for input with optional default. Returns stripped input.""" - try: - if default: - raw = input(f"{prompt_text} [{default}]: ").strip() - return raw if raw else default - else: - return input(f"{prompt_text}: ").strip() - except (EOFError, KeyboardInterrupt): - return "" - - -def _configure_lm_studio() -> List[Dict[str, Any]]: - """Configure LM Studio models interactively.""" - rows: List[Dict[str, Any]] = [] - - console.print("\n[bold cyan]LM Studio Configuration[/bold cyan]") - console.print(f"Default base URL: {DEFAULT_LM_STUDIO_BASE_URL}") - - base_url = _prompt_input("Base URL", DEFAULT_LM_STUDIO_BASE_URL) - if not base_url: - console.print("[yellow]Cancelled.[/yellow]") - return rows - - if not _validate_base_url(base_url): - console.print("[red]Invalid URL. Must start with http:// or https://[/red]") - return rows - - while True: - model_name = _prompt_input("Model name (empty to finish)") - if not model_name: - break - - # Add lm_studio/ prefix if not present - litellm_model = model_name - if not litellm_model.startswith("lm_studio/"): - litellm_model = f"lm_studio/{model_name}" - - row = _build_model_row( - provider="lm_studio", - model=litellm_model, - base_url=base_url, - ) - rows.append(row) - console.print(f" [green]✓[/green] Added: {litellm_model}") - - return rows - - -def _configure_ollama() -> List[Dict[str, Any]]: - """Configure Ollama models interactively with auto-detection.""" - rows: List[Dict[str, Any]] = [] - - console.print("\n[bold cyan]Ollama Configuration[/bold cyan]") - console.print(f"Default base URL: {DEFAULT_OLLAMA_BASE_URL}") - - base_url = _prompt_input("Base URL", DEFAULT_OLLAMA_BASE_URL) - if not base_url: - console.print("[yellow]Cancelled.[/yellow]") - return rows - - if not _validate_base_url(base_url): - console.print("[red]Invalid URL. Must start with http:// or https://[/red]") - return rows - - # Try auto-detection - console.print("[dim]Checking for running Ollama instance...[/dim]") - discovered = _discover_ollama_models(base_url) - - if discovered: - console.print(f"[green]Found {len(discovered)} model(s):[/green]") - - table = Table(show_header=True, header_style="bold") - table.add_column("#", style="dim", width=4) - table.add_column("Model Name") - for idx, name in enumerate(discovered, 1): - table.add_row(str(idx), name) - console.print(table) - - selection = _prompt_input( - "Select models to add (comma-separated numbers, 'all', or empty to skip)" - ) - if not selection: - console.print("[yellow]No models selected.[/yellow]") - elif selection.strip().lower() == "all": - for name in discovered: - litellm_model = f"ollama_chat/{name}" - row = _build_model_row( - provider="Ollama", - model=litellm_model, - base_url=base_url, - ) - rows.append(row) - console.print(f" [green]✓[/green] Added: {litellm_model}") - else: - # Parse comma-separated indices - for part in selection.split(","): - part = part.strip() - try: - idx = int(part) - if 1 <= idx <= len(discovered): - name = discovered[idx - 1] - litellm_model = f"ollama_chat/{name}" - row = _build_model_row( - provider="Ollama", - model=litellm_model, - base_url=base_url, - ) - rows.append(row) - console.print(f" [green]✓[/green] Added: {litellm_model}") - else: - console.print(f" [yellow]Skipping invalid index: {idx}[/yellow]") - except ValueError: - console.print(f" [yellow]Skipping invalid input: '{part}'[/yellow]") - else: - console.print( - "[yellow]Could not connect to Ollama. Falling back to manual entry.[/yellow]" - ) - while True: - model_name = _prompt_input("Model name (empty to finish)") - if not model_name: - break - - litellm_model = model_name - if not litellm_model.startswith("ollama_chat/"): - litellm_model = f"ollama_chat/{model_name}" - - row = _build_model_row( - provider="Ollama", - model=litellm_model, - base_url=base_url, - ) - rows.append(row) - console.print(f" [green]✓[/green] Added: {litellm_model}") - - return rows - - -def _configure_custom() -> List[Dict[str, Any]]: - """Configure a custom local LLM endpoint interactively.""" - rows: List[Dict[str, Any]] = [] - - console.print("\n[bold cyan]Custom Local LLM Configuration[/bold cyan]") - - base_url = _prompt_input("Base URL (e.g., http://localhost:8080/v1)") - if not base_url: - console.print("[yellow]Cancelled.[/yellow]") - return rows - - if not _validate_base_url(base_url): - console.print("[red]Invalid URL. Must start with http:// or https://[/red]") - return rows - - provider_name = _prompt_input("Provider name", "custom") - - while True: - model_name = _prompt_input("Model name (empty to finish)") - if not model_name: - break - - row = _build_model_row( - provider=provider_name, - model=model_name, - base_url=base_url, - ) - rows.append(row) - console.print(f" [green]✓[/green] Added: {model_name}") - - return rows - - -def configure_local_llm() -> bool: - """Interactive setup for local LLM providers. - - Guides the user through selecting a local LLM provider (LM Studio, Ollama, - or custom), discovering available models, and appending them to the user's - ``~/.pdd/llm_model.csv``. - - Returns: - True if any models were added, False otherwise. - """ - console.print("\n[bold]Local LLM Setup[/bold]") - console.print("Configure local LLM tools for use with PDD.\n") - console.print("Select a provider:") - console.print(" [bold]1[/bold]. LM Studio (default: localhost:1234)") - console.print(" [bold]2[/bold]. Ollama (default: localhost:11434)") - console.print(" [bold]3[/bold]. Other (custom endpoint)") - console.print() - - choice = _prompt_input("Choice (1/2/3, empty to cancel)") - if not choice: - console.print("[yellow]Cancelled.[/yellow]") - return False - - new_rows: List[Dict[str, Any]] = [] - - if choice == "1": - new_rows = _configure_lm_studio() - elif choice == "2": - new_rows = _configure_ollama() - elif choice == "3": - new_rows = _configure_custom() - else: - console.print(f"[red]Invalid choice: '{choice}'. Please enter 1, 2, or 3.[/red]") - return False - - if not new_rows: - console.print("[yellow]No models were added.[/yellow]") - return False - - # Write to user CSV - csv_path = _get_user_csv_path() - try: - _append_rows_to_csv(csv_path, new_rows) - console.print( - f"\n[green]Successfully added {len(new_rows)} model(s) to {csv_path}[/green]" - ) - return True - except Exception as e: - console.print(f"[red]Failed to write to {csv_path}: {e}[/red]") - logger.error(f"Failed to write CSV: {e}", exc_info=True) - return False \ No newline at end of file diff --git a/pdd/setup/setup_tool.py b/pdd/setup/setup_tool.py deleted file mode 100644 index 4f420d2c0..000000000 --- a/pdd/setup/setup_tool.py +++ /dev/null @@ -1,167 +0,0 @@ -""" -Main orchestrator for ``pdd setup``. - -Auto-scans the environment for API keys (existence only — no API calls), -then presents an interactive menu. After any action the menu re-displays -with an updated scan. -""" -from __future__ import annotations - -import os -from typing import Dict - -from .api_key_scanner import scan_environment, KeyInfo -from .provider_manager import ( - add_provider_from_registry, - add_custom_provider, - remove_models_by_provider, - remove_individual_models, -) -from .local_llm_configurator import configure_local_llm -from .model_tester import test_model_interactive -from .cli_detector import detect_cli_tools -from .pddrc_initializer import offer_pddrc_init - - -# --------------------------------------------------------------------------- -# Display helpers -# --------------------------------------------------------------------------- - -def _display_scan(scan_results: Dict[str, KeyInfo]) -> None: - """Print a table of discovered API keys and a summary line.""" - print("\n API-key scan") - print(" " + "─" * 50) - - if not scan_results: - print(" No models configured yet.") - print(" Use 'Add a provider' to get started.") - print() - return - - api_found = 0 - local_count = 0 - - for key_name, info in scan_results.items(): - if info.is_set: - # Heuristic: keys whose source mentions "local" or whose name - # hints at a local provider are counted separately. - source_lower = (info.source or "").lower() - if "local" in source_lower or "ollama" in key_name.lower() or "lm_studio" in key_name.lower(): - local_count += 1 - else: - api_found += 1 - print(f" {key_name:30s} ✓ Found ({info.source})") - else: - print(f" {key_name:30s} — Not found") - - # Add helpful note about editing API keys - shell_path = os.environ.get("SHELL", "") - shell_name = os.path.basename(shell_path) if shell_path else "bash" - print(f"\n 💡 To edit API keys: update ~/.pdd/api-env.{shell_name} or .env file") - - total_configured = api_found + local_count - print( - f"\n Models configured: {total_configured} " - f"(from {api_found} API keys + {local_count} local)" - ) - print() - - -def _display_menu() -> None: - """Print the interactive menu options.""" - print(" What would you like to do?") - print(" 1. Add a provider") - print(" 2. Remove models") - print(" 3. Test a model") - print(" 4. Detect CLI tools") - print(" 5. Initialize .pddrc") - print(" 6. Done") - print() - - -def _add_provider_submenu() -> None: - """Sub-menu for option 1 — Add a provider.""" - print() - print(" Add a provider:") - print(" a. Search providers") - print(" b. Add a local LLM") - print(" c. Add a custom provider") - print() - - sub_choice = input(" Choice [a/b/c]: ").strip().lower() - - if sub_choice == "a": - add_provider_from_registry() - elif sub_choice == "b": - configure_local_llm() - elif sub_choice == "c": - add_custom_provider() - else: - print(" Invalid choice — returning to main menu.") - - -def _remove_models_submenu() -> None: - """Sub-menu for option 2 — Remove models.""" - print() - print(" Remove models:") - print(" a. By provider") - print(" b. Individual models") - print() - - sub_choice = input(" Choice [a/b]: ").strip().lower() - - if sub_choice == "a": - remove_models_by_provider() - elif sub_choice == "b": - remove_individual_models() - else: - print(" Invalid choice — returning to main menu.") - - -# --------------------------------------------------------------------------- -# Public entry point -# --------------------------------------------------------------------------- - -def run_setup() -> None: - """Main entry point for ``pdd setup``. - - Scans the environment, displays results, and loops an interactive menu - until the user selects *Done* or presses Ctrl-C. - """ - try: - print() - print(" ╭──────────────────────────────╮") - print(" │ pdd setup │") - print(" ╰──────────────────────────────╯") - - while True: - # (Re-)scan on every iteration so the display stays current. - scan_results = scan_environment() - _display_scan(scan_results) - _display_menu() - - choice = input(" Choice [1-6]: ").strip() - - if choice == "1": - _add_provider_submenu() - elif choice == "2": - _remove_models_submenu() - elif choice == "3": - test_model_interactive() - elif choice == "4": - detect_cli_tools() - elif choice == "5": - offer_pddrc_init() - elif choice == "6": - print("\n ✓ Setup complete. Happy prompting!\n") - break - else: - print(" Invalid choice — please enter a number between 1 and 6.") - - except KeyboardInterrupt: - # Clean exit on Ctrl-C at any point. - print("\n\n Setup interrupted — exiting.\n") - - -if __name__ == "__main__": - run_setup() \ No newline at end of file diff --git a/pdd/setup_tool.py b/pdd/setup_tool.py index 55a8677a5..5613d97cb 100644 --- a/pdd/setup_tool.py +++ b/pdd/setup_tool.py @@ -1,648 +1,618 @@ -#!/usr/bin/env python3 """ -PDD Setup Script - Post-install configuration tool for PDD (Prompt Driven Development) -Helps new users bootstrap their PDD configuration with LLM API keys and basic settings. +Main orchestrator for `pdd setup`. + +Implements a two-phase flow designed for minimal user friction: + Phase 1 — Interactive CLI bootstrap (0–2 user inputs) + Phase 2 — Deterministic auto-configuration (pure Python, no LLM calls) """ +from __future__ import annotations -import os -import sys -import subprocess +import getpass import json -import requests -import csv -import importlib.resources -import shlex +import os +import urllib.error +import urllib.request from pathlib import Path -from typing import Dict, Optional, Tuple, List - -# Global variables for non-ASCII characters and colors -HEAVY_HORIZONTAL = "━" -LIGHT_HORIZONTAL = "─" -HEAVY_VERTICAL = "┃" -LIGHT_VERTICAL = "│" -TOP_LEFT_CORNER = "┏" -TOP_RIGHT_CORNER = "┓" -BOTTOM_LEFT_CORNER = "┗" -BOTTOM_RIGHT_CORNER = "┛" -CROSS = "┼" -TEE_DOWN = "┬" -TEE_UP = "┴" -TEE_RIGHT = "├" -TEE_LEFT = "┤" -BULLET = "•" -ARROW_RIGHT = "→" -CHECK_MARK = "✓" -CROSS_MARK = "✗" - -# Color codes -RESET = "\033[0m" -WHITE = "\033[97m" -CYAN = "\033[96m" -YELLOW = "\033[93m" -BOLD = "\033[1m" - -# Template content inline -SUCCESS_PYTHON_TEMPLATE = """ -Write a python script to print "You did it, !!!" to the console. -Do not write anything except that message. -Capitalize the username.""" - -def _read_packaged_llm_model_csv() -> Tuple[List[str], List[Dict[str, str]]]: - """Load the packaged CSV (pdd/data/llm_model.csv) and return header + rows. - - Returns: - (header_fields, rows) where header_fields is the list of column names - and rows is a list of dictionaries for each CSV row. - """ - try: - csv_text = importlib.resources.files('pdd').joinpath('data/llm_model.csv').read_text() - except Exception as e: - raise FileNotFoundError(f"Failed to load default LLM model CSV from package: {e}") - - reader = csv.DictReader(csv_text.splitlines()) - header = reader.fieldnames or [] - rows = [row for row in reader] - return header, rows - -def print_colored(text: str, color: str = WHITE, bold: bool = False) -> None: - """Print colored text to console""" - style = BOLD + color if bold else color - print(f"{style}{text}{RESET}") - -def create_divider(char: str = LIGHT_HORIZONTAL, width: int = 80) -> str: - """Create a horizontal divider line""" - return char * width - -def create_fat_divider(width: int = 80) -> str: - """Create a fat horizontal divider line""" - return HEAVY_HORIZONTAL * width - -def print_pdd_logo(): - """Print the PDD logo in ASCII art""" - logo = "\n".join( - [ - " +xxxxxxxxxxxxxxx+", - "xxxxxxxxxxxxxxxxxxxxx+", - "xxx +xx+ PROMPT", - "xxx x+ xx+ DRIVEN", - "xxx x+ xxx DEVELOPMENT©", - "xxx x+ xx+", - "xxx x+ xx+ COMMAND LINE INTERFACE", - "xxx x+ xxx", - "xxx +xx+ ", - "xxx +xxxxxxxxxxx+", - "xxx +xx+", - "xxx +xx+", - "xxx+xx+ WWW.PROMPTDRIVEN.AI", - "xxxx+", - "xx+", - ] - ) - print(f"{CYAN}{logo}{RESET}") - print() - print_colored("Let's get set up quickly with a solid basic configuration!", WHITE, bold=True) +from typing import Dict, List, Optional, Tuple + +from rich.console import Console as _RichConsole +_console = _RichConsole() + +# Top providers shown when prompting for an API key (order = display order) +_PROMPT_PROVIDERS = [ + ("anthropic", "Anthropic", "ANTHROPIC_API_KEY"), + ("gemini", "Google Gemini", "GEMINI_API_KEY"), + ("openai", "OpenAI", "OPENAI_API_KEY"), + ("deepseek", "DeepSeek", "DEEPSEEK_API_KEY"), +] + + +def run_setup() -> None: + """Main entry point for pdd setup. Two-phase flow with fallback.""" + from pdd.cli_detector import detect_and_bootstrap_cli, CliBootstrapResult + + # ── Banner ──────────────────────────────────────────────────────────── print() - print_colored("Supported: OpenAI, Google Gemini, and Anthropic Claude", WHITE) - print_colored("from their respective API endpoints (no third-parties, such as Azure)", WHITE) + print(" ╭──────────────────────────────╮") + print(" │ pdd setup │") + print(" ╰──────────────────────────────╯") print() -def get_csv_variable_names() -> Dict[str, str]: - """Inspect packaged CSV to determine API key variable names per provider. + try: + # ── Phase 1 — CLI Bootstrap (interactive, 0–2 user inputs) ──────── + result: CliBootstrapResult = detect_and_bootstrap_cli() - Focus on direct providers only: OpenAI GPT models (model startswith 'gpt-'), - Google Gemini (model startswith 'gemini/'), and Anthropic (model startswith 'anthropic/'). - """ - header, rows = _read_packaged_llm_model_csv() - variable_names: Dict[str, str] = {} + if result.cli_name == "": + print( + "Agentic features require at least one CLI tool. " + "Run `pdd setup` again when ready." + ) + return - for row in rows: - model = (row.get('model') or '').strip() - api_key = (row.get('api_key') or '').strip() - provider = (row.get('provider') or '').strip().upper() + if not result.api_key_configured: + print( + "Note: No API key configured. " + "The agent may have limited capability." + ) - if not api_key: - continue + # ── Phase 2 — Deterministic Auto-Configuration ──────────────────── + auto_success = _run_auto_phase() - if model.startswith('gpt-') and provider == 'OPENAI': - variable_names['OPENAI'] = api_key - elif model.startswith('gemini/') and provider == 'GOOGLE': - # Prefer direct Gemini key, not Vertex - variable_names['GOOGLE'] = api_key - elif model.startswith('anthropic/') and provider == 'ANTHROPIC': - variable_names['ANTHROPIC'] = api_key - - # Fallbacks if not detected (keep prior behavior) - variable_names.setdefault('OPENAI', 'OPENAI_API_KEY') - # Prefer GEMINI_API_KEY name for Google if present - variable_names.setdefault('GOOGLE', 'GEMINI_API_KEY') - variable_names.setdefault('ANTHROPIC', 'ANTHROPIC_API_KEY') - return variable_names - -def discover_api_keys() -> Dict[str, Optional[str]]: - """Discover API keys from environment variables""" - # Get the variable names actually used in CSV template - csv_vars = get_csv_variable_names() - - keys = { - 'OPENAI_API_KEY': os.getenv('OPENAI_API_KEY'), - 'ANTHROPIC_API_KEY': os.getenv('ANTHROPIC_API_KEY'), - } - - # For Google, check both possible environment variables but use CSV template's variable name - google_var_name = csv_vars.get('GOOGLE', 'GEMINI_API_KEY') # Default to GEMINI_API_KEY - google_api_key = os.getenv('GEMINI_API_KEY') or os.getenv('GOOGLE_API_KEY') - keys[google_var_name] = google_api_key - - return keys - -def test_openai_key(api_key: str) -> bool: - """Test OpenAI API key validity""" - if not api_key or not api_key.strip(): - return False - - try: - headers = { - 'Authorization': f'Bearer {api_key.strip()}', - 'Content-Type': 'application/json' - } - response = requests.get( - 'https://api.openai.com/v1/models', - headers=headers, - timeout=10 - ) - return response.status_code == 200 - except Exception: - return False + if not auto_success: + _run_fallback_menu() -def test_google_key(api_key: str) -> bool: - """Test Google Gemini API key validity""" - if not api_key or not api_key.strip(): - return False - + print() + _console.print("[green]Setup complete. Happy prompting![/green]") + print() + + except KeyboardInterrupt: + print("\nSetup interrupted — exiting.") + return + + +# --------------------------------------------------------------------------- +# Phase 2 — Deterministic auto-configuration +# --------------------------------------------------------------------------- + +def _run_auto_phase() -> bool: + """Run 4 deterministic setup steps. Returns True on success.""" try: - response = requests.get( - f'https://generativelanguage.googleapis.com/v1beta/models?key={api_key.strip()}', - timeout=10 - ) - return response.status_code == 200 - except Exception: + # Step 1: Scan API keys + print("\n[Step 1/4] Scanning for API keys...") + found_keys = _step1_scan_keys() + input("\nPress Enter to continue to the next step...") + + # Step 2: Configure models + print("\n[Step 2/4] Configuring models...") + model_summary = _step2_configure_models(found_keys) + input("\nPress Enter to continue to the next step...") + + # Step 3: Local LLMs + .pddrc + print("\n[Step 3/4] Checking local LLMs and .pddrc...") + local_summary = _step3_local_llms_and_pddrc() + input("\nPress Enter to continue to the next step...") + + # Step 4: Test + summary + print("\n[Step 4/4] Testing and summarizing...") + _step4_test_and_summary(found_keys, model_summary, local_summary) + + return True + + except Exception as exc: + print(f"\nAuto-configuration failed: {exc}") + print("Falling back to manual setup...") return False -def test_anthropic_key(api_key: str) -> bool: - """Test Anthropic API key validity""" - if not api_key or not api_key.strip(): - return False - + +# --------------------------------------------------------------------------- +# Step 1 — Scan for API keys +# --------------------------------------------------------------------------- + +def _step1_scan_keys() -> List[Tuple[str, str]]: + """Scan all known API key env vars across all sources. + + Returns list of (key_name, source_label) for keys that were found. + """ + from pdd.litellm_registry import PROVIDER_API_KEY_MAP + from pdd.api_key_scanner import _parse_api_env_file, _detect_shell + + # Ensure ~/.pdd exists + pdd_dir = Path.home() / ".pdd" + pdd_dir.mkdir(parents=True, exist_ok=True) + + # Gather all unique env var names to check + all_key_names = sorted(set(PROVIDER_API_KEY_MAP.values())) + + # Load sources once + dotenv_vals: Dict[str, str] = {} try: - headers = { - 'x-api-key': api_key.strip(), - 'Content-Type': 'application/json' - } - response = requests.get( - 'https://api.anthropic.com/v1/messages', - headers=headers, - timeout=10 - ) - # Anthropic returns 400 for invalid request structure but 401/403 for bad keys - return response.status_code != 401 and response.status_code != 403 - except Exception: - return False + from dotenv import dotenv_values + for env_path in [Path.cwd() / ".env", Path.home() / ".env"]: + if env_path.is_file(): + vals = dotenv_values(env_path) + for k, v in vals.items(): + if v is not None and k not in dotenv_vals: + dotenv_vals[k] = v + except ImportError: + pass + + shell_name = _detect_shell() + api_env_vals: Dict[str, str] = {} + api_env_label = "" + if shell_name: + api_env_path = pdd_dir / f"api-env.{shell_name}" + api_env_vals = _parse_api_env_file(api_env_path) + api_env_label = f"~/.pdd/api-env.{shell_name}" + + # Scan each key + found_keys: List[Tuple[str, str]] = [] + max_name_len = max(len(k) for k in all_key_names) if all_key_names else 20 + + for key_name in all_key_names: + if key_name in os.environ: + source = "shell environment" + found_keys.append((key_name, source)) + print(f" ✓ {key_name:<{max_name_len}s} {source}") + elif key_name in api_env_vals: + source = api_env_label + found_keys.append((key_name, source)) + print(f" ✓ {key_name:<{max_name_len}s} {source}") + elif key_name in dotenv_vals: + source = ".env file" + found_keys.append((key_name, source)) + print(f" ✓ {key_name:<{max_name_len}s} {source}") + + if not found_keys: + print(" ✗ No API keys found.\n") + found_keys = _prompt_for_api_key() + + print(f"\n {len(found_keys)} API key(s) found.") + return found_keys + + +def _prompt_for_api_key() -> List[Tuple[str, str]]: + """Interactively ask the user to add at least one API key. + + Called when no keys are found during scanning. Saves the key to + ~/.pdd/api-env.{shell} and loads it into the current session. + Returns list of (key_name, source_label) for newly added keys. + """ + from pdd.litellm_registry import PROVIDER_API_KEY_MAP, PROVIDER_DISPLAY_NAMES + from pdd.provider_manager import _save_key_to_api_env, _get_api_env_path -def test_api_keys(keys: Dict[str, Optional[str]]) -> Dict[str, bool]: - """Test all discovered API keys""" - results = {} - - print_colored(f"\n{LIGHT_HORIZONTAL * 40}", CYAN) - print_colored("Testing discovered API keys...", CYAN, bold=True) - print_colored(f"{LIGHT_HORIZONTAL * 40}", CYAN) - - for key_name, key_value in keys.items(): - if key_value: - print(f"Testing {key_name}...", end=" ", flush=True) - if key_name == 'OPENAI_API_KEY': - valid = test_openai_key(key_value) - elif key_name in ['GEMINI_API_KEY', 'GOOGLE_API_KEY']: - valid = test_google_key(key_value) - elif key_name == 'ANTHROPIC_API_KEY': - valid = test_anthropic_key(key_value) - else: - valid = False - - if valid: - print_colored(f"{CHECK_MARK} Valid", CYAN) - results[key_name] = True - else: - print_colored(f"{CROSS_MARK} Invalid", YELLOW) - results[key_name] = False - else: - print_colored(f"{key_name}: Not found", YELLOW) - results[key_name] = False - - return results - -def get_user_keys(current_keys: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]: - """Interactive key entry/modification""" - print_colored(f"\n{create_fat_divider()}", YELLOW) - print_colored("API Key Configuration", YELLOW, bold=True) - print_colored(f"{create_fat_divider()}", YELLOW) - - print_colored("You need only one API key to get started", WHITE) - print() - print_colored("Get API keys here:", WHITE) - print_colored(f" OpenAI {ARROW_RIGHT} https://platform.openai.com/api-keys", CYAN) - print_colored(f" Google Gemini {ARROW_RIGHT} https://aistudio.google.com/app/apikey", CYAN) - print_colored(f" Anthropic {ARROW_RIGHT} https://console.anthropic.com/settings/keys", CYAN) - print() - print_colored("A free instant starter key is available from Google Gemini (above)", CYAN) - print() - - new_keys = current_keys.copy() - - # Get the actual key names from discovered keys - key_names = list(current_keys.keys()) - for key_name in key_names: - current_value = current_keys.get(key_name, "") - status = "found" if current_value else "not found" - - print_colored(f"{LIGHT_HORIZONTAL * 60}", CYAN) - print_colored(f"{key_name} (currently: {status})", WHITE, bold=True) - - if current_value: - prompt = f"Enter new key or press ENTER to keep existing: " + added_keys: List[Tuple[str, str]] = [] + api_env_label = f"~/.pdd/api-env.{os.path.basename(os.environ.get('SHELL', 'bash'))}" + + while True: + print(" To continue setup, add at least one API key.") + print(" Popular providers:") + for i, (_, display, env_var) in enumerate(_PROMPT_PROVIDERS, 1): + print(f" {i}) {display:<20s} ({env_var})") + other_idx = len(_PROMPT_PROVIDERS) + 1 + skip_idx = other_idx + 1 + print(f" {other_idx}) Other provider") + print(f" {skip_idx}) Skip (continue without keys)") + + try: + choice = input(f"\n Select provider [1-{skip_idx}]: ").strip() + except (EOFError, KeyboardInterrupt): + print() + break + + # Parse choice + try: + choice_num = int(choice) + except ValueError: + print(f" Invalid input. Enter a number 1-{skip_idx}.\n") + continue + + if choice_num == skip_idx: + break + + if choice_num == other_idx: + # Show all providers + all_providers = sorted( + PROVIDER_API_KEY_MAP.items(), + key=lambda x: PROVIDER_DISPLAY_NAMES.get(x[0], x[0]), + ) + print("\n All providers:") + for i, (pid, env_var) in enumerate(all_providers, 1): + display = PROVIDER_DISPLAY_NAMES.get(pid, pid) + print(f" {i}) {display:<25s} ({env_var})") + try: + sub_choice = input(f"\n Select provider [1-{len(all_providers)}]: ").strip() + sub_num = int(sub_choice) + if 1 <= sub_num <= len(all_providers): + _, env_var = all_providers[sub_num - 1] + display = PROVIDER_DISPLAY_NAMES.get( + all_providers[sub_num - 1][0], + all_providers[sub_num - 1][0], + ) + else: + print(" Invalid selection.\n") + continue + except (ValueError, EOFError, KeyboardInterrupt): + print() + continue + elif 1 <= choice_num <= len(_PROMPT_PROVIDERS): + _, display, env_var = _PROMPT_PROVIDERS[choice_num - 1] else: - prompt = f"Enter API key (or press ENTER to skip): " - + print(f" Invalid input. Enter a number 1-{skip_idx}.\n") + continue + + # Prompt for the key value (masked) try: - user_input = input(f"{WHITE}{prompt}{RESET}").strip() - if user_input: - new_keys[key_name] = user_input - elif not current_value: - new_keys[key_name] = None - except KeyboardInterrupt: - print_colored("\n\nSetup cancelled.", YELLOW) - sys.exit(0) - - return new_keys - -def detect_shell() -> str: - """Detect user's default shell""" - try: - shell_path = os.getenv('SHELL', '/bin/bash') - shell_name = os.path.basename(shell_path) - return shell_name - except: - return 'bash' - -def get_shell_init_file(shell: str) -> str: - """Get the appropriate shell initialization file""" - home = Path.home() - - shell_files = { - 'bash': home / '.bashrc', - 'zsh': home / '.zshrc', - 'fish': home / '.config/fish/config.fish', - 'csh': home / '.cshrc', - 'tcsh': home / '.tcshrc', - 'ksh': home / '.kshrc', - 'sh': home / '.profile' - } - - return str(shell_files.get(shell, home / '.bashrc')) - -def create_api_env_script(keys: Dict[str, str], shell: str) -> str: - """Create shell-appropriate environment script with proper escaping""" - valid_keys = {k: v for k, v in keys.items() if v} - lines = [] - - for key, value in valid_keys.items(): - # shlex.quote is designed for POSIX shells (sh, bash, zsh, ksh) - # It also works reasonably well for fish and csh for simple assignments - quoted_val = shlex.quote(value) - - if shell == 'fish': - lines.append(f'set -gx {key} {quoted_val}') - elif shell in ['csh', 'tcsh']: - lines.append(f'setenv {key} {quoted_val}') - else: # bash, zsh, ksh, sh and others - lines.append(f'export {key}={quoted_val}') - - return '\n'.join(lines) + '\n' - -def save_configuration(valid_keys: Dict[str, str]) -> Tuple[List[str], bool, Optional[str]]: - """Save configuration to ~/.pdd/ directory""" - home = Path.home() - pdd_dir = home / '.pdd' - created_pdd_dir = False - saved_files = [] - - # Create .pdd directory if it doesn't exist - if not pdd_dir.exists(): - pdd_dir.mkdir(mode=0o755) - created_pdd_dir = True - - # Detect shell and create api-env script - shell = detect_shell() - api_env_content = create_api_env_script(valid_keys, shell) - - # Write shell-specific api-env file - api_env_file = pdd_dir / f'api-env.{shell}' - api_env_file.write_text(api_env_content) - api_env_file.chmod(0o755) - saved_files.append(str(api_env_file)) - - # Create llm_model.csv with models from packaged CSV filtered by provider and available keys - header_fields, rows = _read_packaged_llm_model_csv() - - # Keep only direct Google Gemini (model startswith 'gemini/'), OpenAI GPT (gpt-*) and Anthropic (anthropic/*) - def _is_supported_model(row: Dict[str, str]) -> bool: - model = (row.get('model') or '').strip() - if model.startswith('gpt-'): - return True - if model.startswith('gemini/'): - return True - if model.startswith('anthropic/'): - return True - return False + key_value = getpass.getpass(f" Paste your {env_var}: ").strip() + except (EOFError, KeyboardInterrupt): + print() + break - # Filter rows by supported models and by api_key presence in valid_keys - filtered_rows: List[Dict[str, str]] = [] - for row in rows: - if not _is_supported_model(row): + if not key_value: + print(" No key entered, skipping.\n") continue - api_key_name = (row.get('api_key') or '').strip() - # Include only if we have a validated key for this row - if api_key_name and api_key_name in valid_keys: - filtered_rows.append(row) - - # Write out the filtered CSV to ~/.pdd/llm_model.csv preserving column order - llm_model_file = pdd_dir / 'llm_model.csv' - with llm_model_file.open('w', newline='') as f: - writer = csv.DictWriter(f, fieldnames=header_fields) - writer.writeheader() - for row in filtered_rows: - writer.writerow({k: row.get(k, '') for k in header_fields}) - saved_files.append(str(llm_model_file)) - - # Update shell init file - init_file_path = get_shell_init_file(shell) - init_file = Path(init_file_path) - init_file_updated = None - - source_line = f'[ -f "{api_env_file}" ] && source "{api_env_file}"' - if shell == 'fish': - source_line = f'test -f "{api_env_file}"; and source "{api_env_file}"' - elif shell in ['csh', 'tcsh']: - source_line = f'if ( -f "{api_env_file}" ) source "{api_env_file}"' - elif shell == 'sh': - source_line = f'[ -f "{api_env_file}" ] && . "{api_env_file}"' - - # Ensure parent directory exists (important for fish shell) - init_file.parent.mkdir(parents=True, exist_ok=True) - - # Check if source line already exists - if init_file.exists(): - content = init_file.read_text() - if str(api_env_file) not in content: - with init_file.open('a') as f: - f.write(f'\n# PDD API environment\n{source_line}\n') - init_file_updated = str(init_file) + + # Save to api-env file and load into current session + _save_key_to_api_env(env_var, key_value) + added_keys.append((env_var, api_env_label)) + print(f" ✓ {env_var} saved to {api_env_label}") + print(f" ✓ Loaded into current session\n") + + # Ask if they want to add another + try: + another = input(" Add another key? [y/N]: ").strip().lower() + except (EOFError, KeyboardInterrupt): + print() + break + + if another not in ("y", "yes"): + break + print() + + return added_keys + + +# --------------------------------------------------------------------------- +# Step 2 — Configure models from reference CSV +# --------------------------------------------------------------------------- + +def _step2_configure_models( + found_keys: List[Tuple[str, str]], +) -> Dict[str, int]: + """Match found API keys to reference models and write user CSV. + + Returns {provider_display_name: model_count} for the summary. + """ + from pdd.provider_manager import ( + _read_csv, + _write_csv_atomic, + _get_user_csv_path, + ) + + found_key_names = {k for k, _ in found_keys} + + # Read reference CSV + ref_path = Path(__file__).parent / "data" / "llm_model.csv" + ref_rows = _read_csv(ref_path) + + # Filter reference rows to those whose api_key matches a found key + # Skip local-only rows (lm_studio, ollama — handled in step 3) + matching_rows: List[Dict[str, str]] = [] + for row in ref_rows: + api_key_col = row.get("api_key", "").strip() + provider = row.get("provider", "").strip().lower() + base_url = row.get("base_url", "").strip() + + # Skip local models + if provider in ("lm_studio", "ollama"): + continue + # Skip rows with base_url pointing to localhost (local models) + if base_url and ("localhost" in base_url or "127.0.0.1" in base_url): + continue + # Match on api_key + if api_key_col and api_key_col in found_key_names: + matching_rows.append(row) + + # Read existing user CSV and deduplicate + user_csv_path = _get_user_csv_path() + existing_rows = _read_csv(user_csv_path) + existing_models = {r.get("model", "").strip() for r in existing_rows} + + new_rows: List[Dict[str, str]] = [] + for row in matching_rows: + if row.get("model", "").strip() not in existing_models: + new_rows.append(row) + + # Count by provider for display + provider_counts: Dict[str, int] = {} + all_rows = existing_rows + new_rows + for row in all_rows: + provider = row.get("provider", "Unknown").strip() + # Only count cloud models (with api_key) + if row.get("api_key", "").strip(): + provider_counts[provider] = provider_counts.get(provider, 0) + 1 + + # Write merged result + if new_rows: + _write_csv_atomic(user_csv_path, all_rows) + print(f" ✓ {len(new_rows)} new model(s) added to {user_csv_path}") + else: + print(f" ✓ All matching models already in {user_csv_path}") + + total = sum(provider_counts.values()) + print(f" ✓ {total} cloud model(s) configured") + for provider, count in sorted(provider_counts.items()): + s = "s" if count != 1 else "" + print(f" {provider}: {count} model{s}") + + return provider_counts + + +# --------------------------------------------------------------------------- +# Step 3 — Local LLMs + .pddrc +# --------------------------------------------------------------------------- + +def _step3_local_llms_and_pddrc() -> Dict[str, List[str]]: + """Check local LLMs and ensure .pddrc exists. + + Returns {provider: [model_names]} for local LLMs found. + """ + from pdd.pddrc_initializer import _detect_language, _build_pddrc_content + + local_summary: Dict[str, List[str]] = {} + + # ── Check Ollama ────────────────────────────────────────────────────── + ollama_models = _query_local_server( + url="http://localhost:11434/api/tags", + extract_models=_extract_ollama_models, + ) + if ollama_models is not None: + local_summary["Ollama"] = ollama_models + if ollama_models: + names = ", ".join(ollama_models) + print(f" ✓ Ollama running — found {names}") + _append_local_models_to_csv( + ollama_models, provider="Ollama", prefix="ollama_chat/", + base_url="http://localhost:11434", + ) + else: + print(" ✓ Ollama running — no models installed") else: - init_file.write_text(f'# PDD API environment\n{source_line}\n') - init_file_updated = str(init_file) - - return saved_files, created_pdd_dir, init_file_updated - -def create_sample_prompt(): - """Create the sample prompt file""" - prompt_file = Path('success_python.prompt') - prompt_file.write_text(SUCCESS_PYTHON_TEMPLATE) - return str(prompt_file) - -def show_menu(keys: Dict[str, Optional[str]], test_results: Dict[str, bool]) -> str: - """Show main menu and get user choice""" - print_colored(f"\n{create_divider()}", CYAN) - print_colored("Main Menu", CYAN, bold=True) - print_colored(f"{create_divider()}", CYAN) - - # Show current status - print_colored("Current API Key Status:", WHITE, bold=True) - # Get the actual key names from discovered keys - key_names = list(keys.keys()) - for key_name in key_names: - key_value = keys.get(key_name) - if key_value: - status = f"{CHECK_MARK} Valid" if test_results.get(key_name) else f"{CROSS_MARK} Invalid" - status_color = CYAN if test_results.get(key_name) else YELLOW + print(" ✗ Ollama not running (skip)") + + # ── Check LM Studio ────────────────────────────────────────────────── + lm_models = _query_local_server( + url="http://localhost:1234/v1/models", + extract_models=_extract_lm_studio_models, + ) + if lm_models is not None: + local_summary["LM Studio"] = lm_models + if lm_models: + names = ", ".join(lm_models) + print(f" ✓ LM Studio running — found {names}") + _append_local_models_to_csv( + lm_models, provider="lm_studio", prefix="lm_studio/", + base_url="http://localhost:1234/v1", + ) else: - status = "Not configured" - status_color = YELLOW - - print(f" {key_name}: ", end="") - print_colored(status, status_color) - + print(" ✓ LM Studio running — no models loaded") + else: + print(" ✗ LM Studio not running (skip)") + + # ── Check .pddrc ───────────────────────────────────────────────────── + cwd = Path.cwd() + pddrc_path = cwd / ".pddrc" + if pddrc_path.exists(): + print(f" ✓ .pddrc already exists at {pddrc_path}") + else: + language = _detect_language(cwd) or "python" + content = _build_pddrc_content(language) + try: + pddrc_path.write_text(content, encoding="utf-8") + print(f" ✓ Created .pddrc at {pddrc_path} (detected: {language})") + except OSError as exc: + print(f" ✗ Failed to create .pddrc: {exc}") + + return local_summary + + +def _query_local_server( + url: str, + extract_models, + timeout: float = 3.0, +) -> Optional[List[str]]: + """Query a local LLM server. Returns model list or None if unreachable.""" + try: + req = urllib.request.Request(url) + with urllib.request.urlopen(req, timeout=timeout) as resp: + data = json.loads(resp.read().decode("utf-8")) + return extract_models(data) + except (urllib.error.URLError, OSError, json.JSONDecodeError, KeyError): + return None + + +def _extract_ollama_models(data: dict) -> List[str]: + """Extract model names from Ollama /api/tags response.""" + models = data.get("models", []) + return [m.get("name", "") for m in models if m.get("name")] + + +def _extract_lm_studio_models(data: dict) -> List[str]: + """Extract model names from LM Studio /v1/models response.""" + models = data.get("data", []) + return [m.get("id", "") for m in models if m.get("id")] + + +def _append_local_models_to_csv( + model_names: List[str], + provider: str, + prefix: str, + base_url: str, +) -> None: + """Append local models to user CSV, skipping duplicates.""" + from pdd.provider_manager import ( + _read_csv, + _write_csv_atomic, + _get_user_csv_path, + ) + + user_csv_path = _get_user_csv_path() + existing_rows = _read_csv(user_csv_path) + existing_models = {r.get("model", "").strip() for r in existing_rows} + + new_rows = [] + for name in model_names: + model_id = f"{prefix}{name}" + if model_id not in existing_models: + new_rows.append({ + "provider": provider, + "model": model_id, + "input": "0", + "output": "0", + "coding_arena_elo": "1000", + "base_url": base_url, + "api_key": "", + "max_reasoning_tokens": "0", + "structured_output": "True", + "reasoning_type": "none", + "location": "", + }) + + if new_rows: + _write_csv_atomic(user_csv_path, existing_rows + new_rows) + + +# --------------------------------------------------------------------------- +# Step 4 — Test one model + print summary +# --------------------------------------------------------------------------- + +def _step4_test_and_summary( + found_keys: List[Tuple[str, str]], + model_summary: Dict[str, int], + local_summary: Dict[str, List[str]], +) -> None: + """Test the first available cloud model and print the final summary.""" + from pdd.provider_manager import _read_csv, _get_user_csv_path + + # Pick first cloud model + user_csv_path = _get_user_csv_path() + rows = _read_csv(user_csv_path) + test_result = "Skipped (no models configured)" + + cloud_row = None + for row in rows: + if row.get("api_key", "").strip(): + cloud_row = row + break + + if cloud_row: + test_model = cloud_row.get("model", "") + try: + import litellm # noqa: F401 + from pdd.model_tester import _run_test + + print(f" Testing {test_model}...") + result = _run_test(cloud_row) + if result["success"]: + test_result = f"✓ {test_model} responded OK ({result['duration_s']:.1f}s)" + else: + test_result = f"✗ {test_model} failed: {result['error']}" + except ImportError: + test_result = "Skipped (litellm not installed)" + print(f" {test_result}") + + # ── Summary ─────────────────────────────────────────────────────────── print() - print_colored("Options:", WHITE, bold=True) - print(f" 1. Re-enter API keys") - print(f" 2. Re-test current keys") - print(f" 3. Save configuration and exit") - print(f" 4. Exit without saving") + print(" ═══════════════════════════════════════════════") + print(" PDD Setup Complete") + print(" ═══════════════════════════════════════════════") print() - - while True: - try: - choice = input(f"{WHITE}Choose an option (1-4): {RESET}").strip() - if choice in ['1', '2', '3', '4']: - return choice + + # API Keys + print(f" API Keys: {len(found_keys)} found") + + # Models + total_models = sum(model_summary.values()) + parts = ", ".join(f"{p}: {c}" for p, c in sorted(model_summary.items())) + if parts: + print(f" Models: {total_models} configured ({parts})") + else: + print(f" Models: {total_models} configured") + + # Local LLMs + if local_summary: + local_parts = [] + for provider, models in local_summary.items(): + if models: + local_parts.append(f"{provider} — {', '.join(models)}") else: - print_colored("Please enter 1, 2, 3, or 4", YELLOW) - except KeyboardInterrupt: - print_colored("\n\nSetup cancelled.", YELLOW) - sys.exit(0) - -def create_exit_summary(saved_files: List[str], created_pdd_dir: bool, sample_prompt_file: str, shell: str, valid_keys: Dict[str, str], init_file_updated: Optional[str] = None) -> str: - """Create comprehensive exit summary""" - summary_lines = [ - "\n\n\n\n\n", - create_fat_divider(), - "PDD Setup Complete!", - create_fat_divider(), - "", - "API Keys Configured:", - "" - ] - - # Add configured API keys information - if valid_keys: - for key_name, key_value in valid_keys.items(): - # Show just the first and last few characters for security - masked_key = f"{key_value[:8]}...{key_value[-4:]}" if len(key_value) > 12 else "***" - summary_lines.append(f" {key_name}: {masked_key}") - summary_lines.extend(["", "Files created and configured:", ""]) + local_parts.append(f"{provider} (no models)") + print(f" Local: {'; '.join(local_parts)}") else: - summary_lines.extend([" None", "", "Files created and configured:", ""]) - - # File descriptions with alignment - file_descriptions = [] - if created_pdd_dir: - file_descriptions.append(("~/.pdd/", "PDD configuration directory")) - - for file_path in saved_files: - if 'api-env.' in file_path: - file_descriptions.append((file_path, f"API environment variables ({shell} shell)")) - elif 'llm_model.csv' in file_path: - file_descriptions.append((file_path, "LLM model configuration")) - - file_descriptions.append((sample_prompt_file, "Sample prompt for testing")) - - # Add shell init file if it was updated - if init_file_updated: - file_descriptions.append((init_file_updated, f"Shell startup file (updated to source API environment)")) - - file_descriptions.append(("PDD-SETUP-SUMMARY.txt", "This summary")) - - # Find max file path length for alignment - max_path_len = max(len(path) for path, _ in file_descriptions) - - for file_path, description in file_descriptions: - summary_lines.append(f"{file_path:<{max_path_len + 2}}{description}") - - summary_lines.extend([ - "", - create_divider(), - "", - "QUICK START:", - "", - f"1. Reload your shell environment:" - ]) - - # Shell-specific source command for manual reloading - api_env_path = f"{Path.home()}/.pdd/api-env.{shell}" - # Use dot command for sh shell, source for others - if shell == 'sh': - source_cmd = f". {api_env_path}" + print(" Local: none found") + + # .pddrc + pddrc_path = Path.cwd() / ".pddrc" + if pddrc_path.exists(): + print(" .pddrc: exists") else: - source_cmd = f"source {api_env_path}" - - summary_lines.extend([ - f" {source_cmd}", - "", - f"2. Generate code from the sample prompt:", - f" pdd generate success_python.prompt", - "", - create_divider(), - "", - "LEARN MORE:", - "", - f"{BULLET} PDD documentation: pdd --help", - f"{BULLET} PDD website: https://promptdriven.ai/", - f"{BULLET} Discord community: https://discord.gg/Yp4RTh8bG7", - "", - "TIPS:", - "", - f"{BULLET} IMPORTANT: Reload your shell environment using the source command above", - "", - f"{BULLET} Start with simple prompts and gradually increase complexity", - f"{BULLET} Try out 'pdd test' with your prompt+code to create test(s) pdd can use to automatically verify and fix your output code", - f"{BULLET} Try out 'pdd example' with your prompt+code to create examples which help pdd do better", - "", - f"{BULLET} As you get comfortable, learn configuration settings, including the .pddrc file, PDD_GENERATE_OUTPUT_PATH, and PDD_TEST_OUTPUT_PATH", - f"{BULLET} For larger projects, use Makefiles and/or 'pdd sync'", - f"{BULLET} For ongoing substantial projects, learn about llm_model.csv and the --strength,", - f" --temperature, and --time options to optimize model cost, latency, and output quality", - "", - f"{BULLET} Use 'pdd --help' to explore all available commands", - "", - "Problems? Shout out on our Discord for help! https://discord.gg/Yp4RTh8bG7" - ]) - - return '\n'.join(summary_lines) - -def main(): - """Main setup workflow""" - # Initial greeting - print_pdd_logo() - - # Discover environment - print_colored(f"{create_divider()}", CYAN) - print_colored("Discovering local configuration...", CYAN, bold=True) - print_colored(f"{create_divider()}", CYAN) - - keys = discover_api_keys() - - # Test discovered keys - test_results = test_api_keys(keys) - - # Main interaction loop + print(" .pddrc: not created") + + # Test + print(f" Test: {test_result}") + + print() + print(" ═══════════════════════════════════════════════") + print(" Run 'pdd generate' or 'pdd sync' to start.") + print(" ═══════════════════════════════════════════════") + + +# --------------------------------------------------------------------------- +# Fallback manual menu +# --------------------------------------------------------------------------- + +def _run_fallback_menu() -> None: + """Simplified manual menu loop when auto-phase fails.""" + print() + + from pdd.provider_manager import add_provider_from_registry + from pdd.model_tester import test_model_interactive + from pdd.pddrc_initializer import offer_pddrc_init + while True: - choice = show_menu(keys, test_results) - - if choice == '1': - # Re-enter keys - keys = get_user_keys(keys) - test_results = test_api_keys(keys) - - elif choice == '2': - # Re-test keys - test_results = test_api_keys(keys) - - elif choice == '3': - # Save and exit - valid_keys = {k: v for k, v in keys.items() if v and test_results.get(k)} - - if not valid_keys: - print_colored("\nNo valid API keys to save!", YELLOW) - continue - - print_colored(f"\n{create_divider()}", CYAN) - print_colored("Saving configuration...", CYAN, bold=True) - print_colored(f"{create_divider()}", CYAN) - + print("Manual setup options:") + print(" 1. Add a provider") + print(" 2. Test a model") + print(" 3. Initialize .pddrc") + print(" 4. Done") + + try: + choice = input("Select an option [1-4]: ").strip() + except (EOFError, KeyboardInterrupt): + print("\nSetup interrupted — exiting.") + return + + if choice == "1": try: - saved_files, created_pdd_dir, init_file_updated = save_configuration(valid_keys) - sample_prompt_file = create_sample_prompt() - shell = detect_shell() - - # Create and display summary - summary = create_exit_summary(saved_files, created_pdd_dir, sample_prompt_file, shell, valid_keys, init_file_updated) - - # Write summary to file - summary_file = Path('PDD-SETUP-SUMMARY.txt') - summary_file.write_text(summary) - - # Display summary with colors - lines = summary.split('\n') - for line in lines: - if line == create_fat_divider(): - print_colored(line, YELLOW, bold=True) - elif line == "PDD Setup Complete!": - print_colored(line, YELLOW, bold=True) - elif line == create_divider(): - print_colored(line, CYAN) - elif line.startswith("API Keys Configured:") or line.startswith("Files created and configured:"): - print_colored(line, CYAN, bold=True) - elif line.startswith("QUICK START:"): - print_colored(line, YELLOW, bold=True) - elif line.startswith("LEARN MORE:") or line.startswith("TIPS:"): - print_colored(line, CYAN, bold=True) - elif "IMPORTANT:" in line or "Problems?" in line: - print_colored(line, YELLOW, bold=True) - else: - print(line) - - break - - except Exception as e: - print_colored(f"Error saving configuration: {e}", YELLOW) - continue - - elif choice == '4': - # Exit without saving - print_colored("\nExiting without saving configuration.", YELLOW) + add_provider_from_registry() + except Exception as exc: + print(f"Error adding provider: {exc}") + elif choice == "2": + try: + test_model_interactive() + except Exception as exc: + print(f"Error testing model: {exc}") + elif choice == "3": + try: + offer_pddrc_init() + except Exception as exc: + print(f"Error initializing .pddrc: {exc}") + elif choice == "4": break + else: + print("Invalid option. Please enter 1, 2, 3, or 4.") -if __name__ == '__main__': - try: - main() - except KeyboardInterrupt: - print_colored("\n\nSetup cancelled.", YELLOW) - sys.exit(0) \ No newline at end of file + print() + + +if __name__ == "__main__": + run_setup() diff --git a/tests/test_api_key_scanner.py b/tests/test_api_key_scanner.py index 64e989728..e8910af25 100644 --- a/tests/test_api_key_scanner.py +++ b/tests/test_api_key_scanner.py @@ -1,4 +1,4 @@ -"""Tests for pdd/setup/api_key_scanner.py""" +"""Tests for pdd/api_key_scanner.py""" import csv import os @@ -8,7 +8,7 @@ import pytest -from pdd.setup.api_key_scanner import ( +from pdd.api_key_scanner import ( KeyInfo, get_provider_key_names, scan_environment, @@ -355,7 +355,7 @@ def test_priority_order_dotenv_first(self, sample_csv, temp_home, monkeypatch): # Mock dotenv to return a value with mock.patch( - "pdd.setup.api_key_scanner._load_dotenv_values", + "pdd.api_key_scanner._load_dotenv_values", return_value={"OPENAI_API_KEY": "sk-from-dotenv"}, ): result = scan_environment() @@ -372,7 +372,7 @@ def test_priority_order_shell_before_api_env(self, sample_csv, temp_home, monkey # Mock dotenv to return empty (no .env file) with mock.patch( - "pdd.setup.api_key_scanner._load_dotenv_values", + "pdd.api_key_scanner._load_dotenv_values", return_value={}, ): result = scan_environment() @@ -398,7 +398,7 @@ def test_handles_exception_gracefully(self, monkeypatch, temp_home): # Mock get_provider_key_names to raise with mock.patch( - "pdd.setup.api_key_scanner.get_provider_key_names", + "pdd.api_key_scanner.get_provider_key_names", side_effect=Exception("Test error"), ): result = scan_environment() @@ -422,7 +422,7 @@ def test_different_shells_use_different_api_env_files(self, sample_csv, temp_hom # Test with bash shell monkeypatch.setenv("SHELL", "/bin/bash") - with mock.patch("pdd.setup.api_key_scanner._load_dotenv_values", return_value={}): + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): result = scan_environment() assert result["OPENAI_API_KEY"].is_set is True @@ -527,7 +527,7 @@ def test_handles_api_key_with_special_shell_characters(self, temp_home, monkeypa monkeypatch.setenv("MY_SPECIAL_KEY", "value_with_$pecial_chars") monkeypatch.setenv("SHELL", "/bin/bash") - with mock.patch("pdd.setup.api_key_scanner._load_dotenv_values", return_value={}): + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): result = scan_environment() assert result["MY_SPECIAL_KEY"].is_set is True diff --git a/tests/test_cli_detector.py b/tests/test_cli_detector.py index 89c7db857..9def0afc0 100644 --- a/tests/test_cli_detector.py +++ b/tests/test_cli_detector.py @@ -1,11 +1,14 @@ -"""Tests for pdd/setup/cli_detector.py""" +"""Tests for pdd/cli_detector.py""" import subprocess +import os +import sys from unittest import mock +from pathlib import Path import pytest -from pdd.setup.cli_detector import ( +from pdd.cli_detector import ( _CLI_COMMANDS, _API_KEY_ENV_VARS, _INSTALL_COMMANDS, @@ -16,6 +19,12 @@ _prompt_yes_no, _run_install, detect_cli_tools, + detect_and_bootstrap_cli, + CliBootstrapResult, + _save_api_key, + _find_cli_binary, + _prompt_input, + console ) @@ -158,7 +167,7 @@ def test_returns_bool(self): def test_uses_which_internally(self): """Should use _which to find npm.""" - with mock.patch("pdd.setup.cli_detector._which") as mock_which: + with mock.patch("pdd.cli_detector._which") as mock_which: mock_which.return_value = "/usr/bin/npm" result = _npm_available() mock_which.assert_called_once_with("npm") @@ -166,7 +175,7 @@ def test_uses_which_internally(self): def test_returns_false_when_npm_not_found(self): """Should return False when npm is not installed.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._which", return_value=None): assert _npm_available() is False @@ -180,57 +189,57 @@ class TestPromptYesNo: def test_returns_true_for_y(self): """Should return True for 'y' input.""" - with mock.patch("builtins.input", return_value="y"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="y"): assert _prompt_yes_no("Test? ") is True def test_returns_true_for_yes(self): """Should return True for 'yes' input.""" - with mock.patch("builtins.input", return_value="yes"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="yes"): assert _prompt_yes_no("Test? ") is True def test_returns_true_for_Y_uppercase(self): """Should return True for uppercase 'Y' input.""" - with mock.patch("builtins.input", return_value="Y"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="Y"): assert _prompt_yes_no("Test? ") is True def test_returns_true_for_YES_uppercase(self): """Should return True for uppercase 'YES' input.""" - with mock.patch("builtins.input", return_value="YES"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="YES"): assert _prompt_yes_no("Test? ") is True def test_returns_false_for_n(self): """Should return False for 'n' input.""" - with mock.patch("builtins.input", return_value="n"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="n"): assert _prompt_yes_no("Test? ") is False def test_returns_false_for_no(self): """Should return False for 'no' input.""" - with mock.patch("builtins.input", return_value="no"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="no"): assert _prompt_yes_no("Test? ") is False def test_returns_false_for_empty(self): """Should return False for empty input (default is No).""" - with mock.patch("builtins.input", return_value=""): + with mock.patch("pdd.cli_detector._prompt_input", return_value=""): assert _prompt_yes_no("Test? ") is False def test_returns_false_for_random_input(self): """Should return False for unrecognized input.""" - with mock.patch("builtins.input", return_value="maybe"): + with mock.patch("pdd.cli_detector._prompt_input", return_value="maybe"): assert _prompt_yes_no("Test? ") is False def test_handles_eof_error(self): """Should return False on EOFError.""" - with mock.patch("builtins.input", side_effect=EOFError()): + with mock.patch("pdd.cli_detector._prompt_input", side_effect=EOFError()): assert _prompt_yes_no("Test? ") is False def test_handles_keyboard_interrupt(self): """Should return False on KeyboardInterrupt.""" - with mock.patch("builtins.input", side_effect=KeyboardInterrupt()): + with mock.patch("pdd.cli_detector._prompt_input", side_effect=KeyboardInterrupt()): assert _prompt_yes_no("Test? ") is False def test_strips_whitespace(self): """Should strip whitespace from input.""" - with mock.patch("builtins.input", return_value=" y "): + with mock.patch("pdd.cli_detector._prompt_input", return_value=" y "): assert _prompt_yes_no("Test? ") is True @@ -282,8 +291,8 @@ class TestDetectCliTools: def test_prints_header(self, capsys): """Should print the detection header.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -292,9 +301,9 @@ def test_prints_header(self, capsys): def test_shows_found_cli(self, capsys): """Should show checkmark for found CLI tools.""" - with mock.patch("pdd.setup.cli_detector._which") as mock_which: + with mock.patch("pdd.cli_detector._which") as mock_which: mock_which.side_effect = lambda cmd: "/usr/bin/claude" if cmd == "claude" else None - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + with mock.patch("pdd.cli_detector._has_api_key", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -303,8 +312,8 @@ def test_shows_found_cli(self, capsys): def test_shows_not_found_cli(self, capsys): """Should show X for missing CLI tools.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -313,8 +322,8 @@ def test_shows_not_found_cli(self, capsys): def test_shows_api_key_status_when_cli_found(self, capsys): """Should show API key status when CLI is found.""" - with mock.patch("pdd.setup.cli_detector._which", return_value="/usr/bin/claude"): - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._which", return_value="/usr/bin/claude"): + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: mock_has_key.return_value = True detect_cli_tools() @@ -323,8 +332,8 @@ def test_shows_api_key_status_when_cli_found(self, capsys): def test_warns_when_cli_found_but_no_key(self, capsys): """Should warn when CLI found but API key not set.""" - with mock.patch("pdd.setup.cli_detector._which", return_value="/usr/bin/claude"): - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + with mock.patch("pdd.cli_detector._which", return_value="/usr/bin/claude"): + with mock.patch("pdd.cli_detector._has_api_key", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -333,34 +342,34 @@ def test_warns_when_cli_found_but_no_key(self, capsys): def test_suggests_install_when_key_but_no_cli(self, capsys): """Should suggest installation when API key is set but CLI is missing.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: # Only anthropic has key set mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=False): + with mock.patch("pdd.cli_detector._npm_available", return_value=False): detect_cli_tools() captured = capsys.readouterr() - assert "install the CLI" in captured.out + assert "install the cli" in captured.out.lower() def test_offers_installation_when_npm_available(self, capsys): """Should offer installation when npm is available.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=False): + with mock.patch("pdd.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=False): detect_cli_tools() captured = capsys.readouterr() - assert "Install now?" in captured.out or "Install command" in captured.out + assert "Install now?" in captured.out or "Install:" in captured.out def test_shows_npm_not_installed_message(self, capsys): """Should show message when npm is not installed.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=False): + with mock.patch("pdd.cli_detector._npm_available", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -368,26 +377,26 @@ def test_shows_npm_not_installed_message(self, capsys): def test_runs_installation_on_yes(self, capsys): """Should run installation when user says yes.""" - with mock.patch("pdd.setup.cli_detector._which") as mock_which: + with mock.patch("pdd.cli_detector._which") as mock_which: mock_which.side_effect = [None, None, None, "/usr/bin/claude"] # Found after install - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=True): - with mock.patch("pdd.setup.cli_detector._run_install", return_value=True): + with mock.patch("pdd.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=True): + with mock.patch("pdd.cli_detector._run_install", return_value=True): detect_cli_tools() captured = capsys.readouterr() - assert "installed successfully" in captured.out or "completed" in captured.out + assert "successfully" in captured.out or "completed" in captured.out def test_shows_failure_message_on_install_fail(self, capsys): """Should show failure message when installation fails.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=True): - with mock.patch("pdd.setup.cli_detector._run_install", return_value=False): + with mock.patch("pdd.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=True): + with mock.patch("pdd.cli_detector._run_install", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -395,11 +404,11 @@ def test_shows_failure_message_on_install_fail(self, capsys): def test_shows_skip_message_on_no(self, capsys): """Should show skip message when user declines installation.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key") as mock_has_key: + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.setup.cli_detector._prompt_yes_no", return_value=False): + with mock.patch("pdd.cli_detector._npm_available", return_value=True): + with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -407,8 +416,8 @@ def test_shows_skip_message_on_no(self, capsys): def test_shows_quick_start_when_nothing_installed(self, capsys): """Should show quick start guide when no CLIs are installed.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -416,8 +425,8 @@ def test_shows_quick_start_when_nothing_installed(self, capsys): def test_shows_all_installed_message(self, capsys): """Should show success message when all CLIs with keys are installed.""" - with mock.patch("pdd.setup.cli_detector._which", return_value="/usr/bin/cli"): - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=True): + with mock.patch("pdd.cli_detector._which", return_value="/usr/bin/cli"): + with mock.patch("pdd.cli_detector._has_api_key", return_value=True): detect_cli_tools() captured = capsys.readouterr() @@ -434,8 +443,8 @@ class TestIntegration: def test_detect_cli_tools_handles_import_error(self, capsys): """Should handle missing agentic_common gracefully.""" - with mock.patch("pdd.setup.cli_detector._which", return_value=None): - with mock.patch("pdd.setup.cli_detector._has_api_key", return_value=False): + with mock.patch("pdd.cli_detector._which", return_value=None): + with mock.patch("pdd.cli_detector._has_api_key", return_value=False): # The function imports get_available_agents but handles import errors detect_cli_tools() @@ -451,9 +460,9 @@ def mock_which(cmd): def mock_has_key(provider): return provider in ["anthropic", "openai"] - with mock.patch("pdd.setup.cli_detector._which", side_effect=mock_which): - with mock.patch("pdd.setup.cli_detector._has_api_key", side_effect=mock_has_key): - with mock.patch("pdd.setup.cli_detector._npm_available", return_value=False): + with mock.patch("pdd.cli_detector._which", side_effect=mock_which): + with mock.patch("pdd.cli_detector._has_api_key", side_effect=mock_has_key): + with mock.patch("pdd.cli_detector._npm_available", return_value=False): detect_cli_tools() captured = capsys.readouterr() @@ -487,3 +496,287 @@ def test_whitespace_only_env_var_treated_as_not_set(self, monkeypatch): """Whitespace-only env vars should be treated as not set.""" monkeypatch.setenv("ANTHROPIC_API_KEY", " \t\n ") assert _has_api_key("anthropic") is False + + + +# --------------------------------------------------------------------------- +# Fixtures for bootstrap tests +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_console(): + with mock.patch("pdd.cli_detector.console") as m: + yield m + +@pytest.fixture +def mock_env(): + with mock.patch.dict(os.environ, {}, clear=True): + yield os.environ + +@pytest.fixture +def mock_which(): + with mock.patch("shutil.which") as m: + yield m + +@pytest.fixture +def mock_input(): + with mock.patch("pdd.cli_detector._prompt_input") as m: + yield m + +@pytest.fixture +def mock_subprocess(): + with mock.patch("subprocess.run") as m: + yield m + +@pytest.fixture +def mock_home(tmp_path): + """Mock Path.home() to return a temporary directory.""" + with mock.patch("pathlib.Path.home", return_value=tmp_path): + yield tmp_path + +# --------------------------------------------------------------------------- +# Helper Function Tests +# --------------------------------------------------------------------------- + +def test_save_api_key_bash(mock_home, mock_console): + """Test saving API key for bash shell.""" + shell = "bash" + key_name = "TEST_KEY" + key_value = "sk-test-123" + + # Create a dummy .bashrc + rc_file = mock_home / ".bashrc" + rc_file.write_text("# existing content\n") + + success = _save_api_key(key_name, key_value, shell) + + assert success is True + + # Check api-env file + api_env = mock_home / ".pdd" / "api-env.bash" + assert api_env.exists() + content = api_env.read_text() + assert f"export {key_name}={key_value}" in content + + # Check RC file update + rc_content = rc_file.read_text() + assert f"source {api_env}" in rc_content + assert os.environ[key_name] == key_value + +def test_save_api_key_fish(mock_home, mock_console): + """Test saving API key for fish shell.""" + shell = "fish" + key_name = "TEST_KEY" + key_value = "sk-test-123" + + # Create dummy config.fish + fish_config = mock_home / ".config" / "fish" / "config.fish" + fish_config.parent.mkdir(parents=True) + fish_config.write_text("") + + success = _save_api_key(key_name, key_value, shell) + + assert success is True + + api_env = mock_home / ".pdd" / "api-env.fish" + content = api_env.read_text() + assert f"set -gx {key_name} {key_value}" in content + + rc_content = fish_config.read_text() + assert f"test -f {api_env} ; and source {api_env}" in rc_content + +def test_find_cli_binary_fallback(mock_which): + """Test finding binary in fallback paths when shutil.which fails.""" + mock_which.return_value = None + + # Mock os.path.exists and is_file/access + with mock.patch("pathlib.Path.exists", return_value=True), \ + mock.patch("pathlib.Path.is_file", return_value=True), \ + mock.patch("os.access", return_value=True): + + # Should find it in the first fallback path checked + result = _find_cli_binary("claude") + assert result is not None + assert "claude" in result + +# --------------------------------------------------------------------------- +# detect_and_bootstrap_cli Tests +# --------------------------------------------------------------------------- + +def test_bootstrap_happy_path(mock_env, mock_input, mock_console): + """ + Scenario: Claude is installed and ANTHROPIC_API_KEY is set. + All 3 CLIs are shown in table. User selects Claude (1). + Expect: Return with success, no install or API key prompt needed. + """ + mock_env["ANTHROPIC_API_KEY"] = "sk-existing" + mock_input.return_value = "1" # User selects Claude + + with mock.patch("pdd.cli_detector._find_cli_binary") as mock_find: + mock_find.side_effect = lambda x: "/usr/bin/claude" if x == "claude" else None + result = detect_and_bootstrap_cli() + + assert result.cli_name == "claude" + assert result.provider == "anthropic" + assert result.api_key_configured is True + assert result.cli_path == "/usr/bin/claude" + + # Table should be shown with all 3 CLIs before the user picks + all_printed = " ".join(str(c) for c in mock_console.print.call_args_list) + assert "Claude CLI" in all_printed + assert "Codex CLI" in all_printed + assert "Gemini CLI" in all_printed + +def test_bootstrap_key_needed_user_provides(mock_which, mock_env, mock_input, mock_home, mock_console): + """ + Scenario: Claude installed, no key. User enters key. + Expect: Key saved, success returned. + """ + mock_which.side_effect = lambda x: f"/usr/bin/{x}" if x == "claude" else None + # No API key in env + + mock_input.return_value = "sk-new-key" + + result = detect_and_bootstrap_cli() + + assert result.cli_name == "claude" + assert result.provider == "anthropic" + assert result.api_key_configured is True + + # Verify key was saved to env + assert os.environ["ANTHROPIC_API_KEY"] == "sk-new-key" + # Verify file write happened (via _save_api_key logic) + api_env = mock_home / ".pdd" / "api-env.bash" # Default shell is bash + assert api_env.exists() + +def test_bootstrap_key_needed_user_skips(mock_which, mock_env, mock_input, mock_console): + """ + Scenario: Claude installed, no key. User presses Enter (skips). + Expect: Success returned but api_key_configured=False. + """ + mock_which.side_effect = lambda x: f"/usr/bin/{x}" if x == "claude" else None + mock_input.return_value = "" # Empty input + + result = detect_and_bootstrap_cli() + + assert result.cli_name == "claude" + assert result.api_key_configured is False + mock_console.print.assert_any_call(" [dim]Note: Claude CLI may still work with subscription auth.[/dim]") + +def test_bootstrap_no_cli_user_declines(mock_which, mock_input, mock_console): + """ + Scenario: No CLIs found. User says 'n' to install. + Expect: Empty result. + """ + mock_which.return_value = None # No CLIs found + mock_input.return_value = "n" + + result = detect_and_bootstrap_cli() + + assert result.cli_name == "" + assert result.provider == "" + mock_console.print.assert_any_call(" [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") + +def test_bootstrap_install_npm_missing(mock_input, mock_console): + """ + Scenario: No CLIs found. All 3 shown in table. User selects Claude (1), + says yes to install. npm not found. + Expect: Error message, empty result. + """ + mock_input.side_effect = ["1", "y"] # Select Claude, then yes to install + + with mock.patch("pdd.cli_detector._find_cli_binary", return_value=None), \ + mock.patch("pdd.cli_detector._npm_available", return_value=False): + result = detect_and_bootstrap_cli() + + assert result.cli_name == "" + all_printed = " ".join(str(c) for c in mock_console.print.call_args_list) + assert "npm is not installed" in all_printed + +def test_bootstrap_install_success(mock_input, mock_subprocess, mock_home, mock_console): + """ + Scenario: No CLIs found. All 3 shown in table. User selects Claude (1), + says yes to install. Install succeeds. User provides API key. + Expect: Full success. + """ + # Inputs: select Claude, yes to install, provide API key + mock_input.side_effect = ["1", "y", "sk-key"] + + # _find_cli_binary returns None on initial scan; returns the path after install + claude_calls = [0] + def find_binary(name): + if name == "claude": + claude_calls[0] += 1 + return "/usr/bin/claude" if claude_calls[0] > 1 else None + return None + + mock_subprocess.return_value.returncode = 0 + + with mock.patch("pdd.cli_detector._find_cli_binary", side_effect=find_binary), \ + mock.patch("pdd.cli_detector._npm_available", return_value=True): + result = detect_and_bootstrap_cli() + + assert result.cli_name == "claude" + assert result.cli_path == "/usr/bin/claude" + assert result.api_key_configured is True + + # Verify the correct install command was run + mock_subprocess.assert_called_with( + "npm install -g @anthropic-ai/claude-code", + shell=True, capture_output=True, text=True, timeout=120 + ) + +def test_bootstrap_install_failure(mock_input, mock_subprocess, mock_console): + """ + Scenario: No CLIs found. All 3 shown in table. User selects Claude (1), + says yes to install. Install fails. + Expect: Empty result with failure message. + """ + mock_input.side_effect = ["1", "y"] # Select Claude, then yes to install + + mock_subprocess.return_value.returncode = 1 + mock_subprocess.return_value.stderr = "Permission denied" + + with mock.patch("pdd.cli_detector._find_cli_binary", return_value=None), \ + mock.patch("pdd.cli_detector._npm_available", return_value=True): + result = detect_and_bootstrap_cli() + + assert result.cli_name == "" + all_printed = " ".join(str(c) for c in mock_console.print.call_args_list) + assert "Installation failed" in all_printed + +# --------------------------------------------------------------------------- +# detect_cli_tools Tests (Bootstrap Perspective) +# --------------------------------------------------------------------------- + +def test_detect_cli_tools_reporting(mock_which, mock_env, mock_console): + """Test legacy detection reporting.""" + # Claude found, others missing + mock_which.side_effect = lambda x: "/bin/claude" if x == "claude" else None + mock_env["ANTHROPIC_API_KEY"] = "sk-test" + + detect_cli_tools() + + # The code now uses display_name, so we adjust expectations + mock_console.print.assert_any_call(" [green]✓[/green] Claude CLI — Found at /bin/claude") + mock_console.print.assert_any_call(" [green]✓[/green] ANTHROPIC_API_KEY is set") + mock_console.print.assert_any_call(" [red]✗[/red] Codex CLI — Not found") + +def test_detect_cli_tools_offer_install(mock_which, mock_env, mock_input, mock_subprocess, mock_console): + """Test legacy install offer when key exists but CLI missing.""" + # Codex missing, but key present + mock_which.side_effect = lambda x: "/bin/npm" if x == "npm" else None + mock_env["OPENAI_API_KEY"] = "sk-openai" + + # User says yes to install + mock_input.return_value = "y" + mock_subprocess.return_value.returncode = 0 + + detect_cli_tools() + + # The code now uses display_name, so we adjust expectations + mock_console.print.assert_any_call(" [yellow]You have OPENAI_API_KEY set but Codex CLI is not installed.[/yellow]") + mock_subprocess.assert_called_with( + "npm install -g @openai/codex", + shell=True, capture_output=True, text=True, timeout=120 + ) diff --git a/tests/test_litellm_registry.py b/tests/test_litellm_registry.py index 067b7633c..41079eeb3 100644 --- a/tests/test_litellm_registry.py +++ b/tests/test_litellm_registry.py @@ -1,10 +1,10 @@ -"""Tests for pdd/setup/litellm_registry.py""" +"""Tests for pdd/litellm_registry.py""" from unittest import mock import pytest -from pdd.setup.litellm_registry import ( +from pdd.litellm_registry import ( ProviderInfo, ModelInfo, PROVIDER_API_KEY_MAP, @@ -158,7 +158,7 @@ def raise_import_error(): raise ImportError("No module named 'litellm'") with mock.patch( - "pdd.setup.litellm_registry.is_litellm_available", + "pdd.litellm_registry.is_litellm_available", side_effect=raise_import_error, ): # The actual function should handle this gracefully @@ -169,7 +169,7 @@ def test_returns_false_when_model_cost_empty(self): mock_litellm = mock.MagicMock() mock_litellm.model_cost = {} - with mock.patch("pdd.setup.litellm_registry.is_litellm_available") as mock_fn: + with mock.patch("pdd.litellm_registry.is_litellm_available") as mock_fn: mock_fn.return_value = False assert mock_fn() is False @@ -311,7 +311,7 @@ class TestGetTopProviders: def test_returns_empty_list_when_litellm_unavailable(self): """Should return empty list when litellm is not available.""" with mock.patch( - "pdd.setup.litellm_registry.is_litellm_available", return_value=False + "pdd.litellm_registry.is_litellm_available", return_value=False ): result = get_top_providers() assert result == [] @@ -353,7 +353,7 @@ class TestGetAllProviders: def test_returns_empty_list_when_litellm_unavailable(self): """Should return empty list when litellm is not available.""" with mock.patch( - "pdd.setup.litellm_registry.is_litellm_available", return_value=False + "pdd.litellm_registry.is_litellm_available", return_value=False ): result = get_all_providers() assert result == [] @@ -449,7 +449,7 @@ class TestGetModelsForProvider: def test_returns_empty_list_when_litellm_unavailable(self): """Should return empty list when litellm is not available.""" with mock.patch( - "pdd.setup.litellm_registry.is_litellm_available", return_value=False + "pdd.litellm_registry.is_litellm_available", return_value=False ): result = get_models_for_provider("openai") assert result == [] diff --git a/tests/test_model_tester.py b/tests/test_model_tester.py new file mode 100644 index 000000000..ed6365b71 --- /dev/null +++ b/tests/test_model_tester.py @@ -0,0 +1,232 @@ +"""Tests for model_tester.py — CSV loading, key resolution, cost calculation, error classification.""" + +import os +import pytest +from unittest.mock import MagicMock, patch + +from pdd import model_tester + + +# --------------------------------------------------------------------------- +# Tests for _resolve_api_key +# --------------------------------------------------------------------------- + +def test_resolve_api_key_empty_key_name(): + """No api_key configured returns None with status message.""" + row = {"api_key": ""} + key, status = model_tester._resolve_api_key(row) + assert key is None + assert "no key configured" in status + + +def test_resolve_api_key_found_in_env(monkeypatch): + """Key found in os.environ returns the value.""" + monkeypatch.setenv("TEST_API_KEY", "sk-abc123") + row = {"api_key": "TEST_API_KEY"} + key, status = model_tester._resolve_api_key(row) + assert key == "sk-abc123" + assert "Found" in status + assert "TEST_API_KEY" in status + + +def test_resolve_api_key_not_found(monkeypatch): + """Key not in any source returns None with 'Not found'.""" + monkeypatch.delenv("MISSING_KEY", raising=False) + row = {"api_key": "MISSING_KEY"} + with patch.dict("sys.modules", {"dotenv": None}): + key, status = model_tester._resolve_api_key(row) + assert key is None + assert "Not found" in status + + +def test_resolve_api_key_strips_whitespace(monkeypatch): + """Key value is stripped of leading/trailing whitespace.""" + monkeypatch.setenv("PADDED_KEY", " sk-test ") + row = {"api_key": "PADDED_KEY"} + key, _ = model_tester._resolve_api_key(row) + assert key == "sk-test" + + +# --------------------------------------------------------------------------- +# Tests for _resolve_base_url +# --------------------------------------------------------------------------- + +def test_resolve_base_url_explicit(): + """Explicit base_url in row is returned directly.""" + row = {"base_url": "https://custom.api.com/v1"} + assert model_tester._resolve_base_url(row) == "https://custom.api.com/v1" + + +def test_resolve_base_url_empty(): + """Empty base_url for non-local model returns None.""" + row = {"base_url": "", "model": "anthropic/claude", "provider": "Anthropic"} + assert model_tester._resolve_base_url(row) is None + + +def test_resolve_base_url_lm_studio_model(monkeypatch): + """LM Studio model gets default localhost URL.""" + monkeypatch.delenv("LM_STUDIO_API_BASE", raising=False) + row = {"base_url": "", "model": "lm_studio/my-model", "provider": "lm_studio"} + result = model_tester._resolve_base_url(row) + assert result == "http://localhost:1234/v1" + + +def test_resolve_base_url_lm_studio_custom_env(monkeypatch): + """LM Studio respects LM_STUDIO_API_BASE env var.""" + monkeypatch.setenv("LM_STUDIO_API_BASE", "http://remote:5000/v1") + row = {"base_url": "", "model": "lm_studio/model", "provider": "lm_studio"} + assert model_tester._resolve_base_url(row) == "http://remote:5000/v1" + + +# --------------------------------------------------------------------------- +# Tests for _calculate_cost +# --------------------------------------------------------------------------- + +def test_calculate_cost_basic(): + """Cost calculation with known token counts and prices.""" + # 100 prompt tokens at $3/M + 50 completion tokens at $15/M + cost = model_tester._calculate_cost(100, 50, 3.0, 15.0) + expected = (100 * 3.0 + 50 * 15.0) / 1_000_000.0 + assert abs(cost - expected) < 1e-10 + + +def test_calculate_cost_zero(): + """Zero tokens or zero prices produce zero cost.""" + assert model_tester._calculate_cost(0, 0, 3.0, 15.0) == 0.0 + assert model_tester._calculate_cost(100, 100, 0.0, 0.0) == 0.0 + + +# --------------------------------------------------------------------------- +# Tests for _classify_error +# --------------------------------------------------------------------------- + +def test_classify_error_auth(): + """Authentication-related errors are classified correctly.""" + exc = Exception("401 Unauthorized - invalid api key") + result = model_tester._classify_error(exc) + assert "Authentication error" in result + + +def test_classify_error_connection_refused(): + """Connection refused errors suggest local server issue.""" + exc = ConnectionError("Connection refused") + result = model_tester._classify_error(exc) + assert "Connection refused" in result + + +def test_classify_error_not_found(): + """404 / model not found errors classified correctly.""" + exc = Exception("404 Model does not exist") + result = model_tester._classify_error(exc) + assert "Model not found" in result + + +def test_classify_error_timeout(): + """Timeout errors classified correctly.""" + exc = TimeoutError("Request timed out after 30s") + result = model_tester._classify_error(exc) + assert "timed out" in result + + +def test_classify_error_rate_limit(): + """Rate limit errors classified correctly.""" + exc = Exception("429 Rate limit exceeded") + result = model_tester._classify_error(exc) + assert "Rate limited" in result + + +def test_classify_error_generic(): + """Unknown errors fall through to generic classification.""" + exc = ValueError("Something unexpected") + result = model_tester._classify_error(exc) + assert "ValueError" in result + assert "Something unexpected" in result + + +# --------------------------------------------------------------------------- +# Tests for _run_test +# --------------------------------------------------------------------------- + +@patch("litellm.completion") +def test_run_test_success(mock_completion, monkeypatch): + """Successful litellm call returns success dict with tokens and cost.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test") + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + response = MagicMock() + response.usage = usage + mock_completion.return_value = response + + row = {"model": "anthropic/claude", "api_key": "ANTHROPIC_API_KEY", + "input": 3.0, "output": 15.0} + result = model_tester._run_test(row) + + assert result["success"] is True + assert result["error"] is None + assert result["tokens"]["prompt"] == 10 + assert result["tokens"]["completion"] == 5 + assert result["cost"] > 0 + assert result["duration_s"] >= 0 + + +@patch("litellm.completion") +def test_run_test_failure(mock_completion, monkeypatch): + """Failed litellm call returns failure dict with classified error.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test") + mock_completion.side_effect = Exception("401 Unauthorized") + + row = {"model": "anthropic/claude", "api_key": "ANTHROPIC_API_KEY", + "input": 3.0, "output": 15.0} + result = model_tester._run_test(row) + + assert result["success"] is False + assert result["cost"] == 0.0 + assert result["tokens"] is None + assert "Authentication error" in result["error"] + + +# --------------------------------------------------------------------------- +# Tests for _load_user_csv +# --------------------------------------------------------------------------- + +@patch("pdd.model_tester.Path") +def test_load_user_csv_missing_file(mock_path): + """Returns None when CSV file doesn't exist.""" + mock_home = MagicMock() + mock_csv = MagicMock() + mock_csv.is_file.return_value = False + mock_home.__truediv__ = MagicMock(return_value=MagicMock(__truediv__=MagicMock(return_value=mock_csv))) + mock_path.home.return_value = mock_home + + result = model_tester._load_user_csv() + assert result is None + + +def test_load_user_csv_valid(tmp_path): + """Returns DataFrame for a valid CSV with required columns.""" + csv_content = "provider,model,api_key,input,output\nAnthropic,claude,ANTHROPIC_API_KEY,3.0,15.0\n" + csv_file = tmp_path / ".pdd" / "llm_model.csv" + csv_file.parent.mkdir(parents=True) + csv_file.write_text(csv_content) + + with patch.object(model_tester.Path, "home", return_value=tmp_path): + df = model_tester._load_user_csv() + + assert df is not None + assert len(df) == 1 + assert df.iloc[0]["provider"] == "Anthropic" + + +def test_load_user_csv_missing_columns(tmp_path): + """Returns None when required columns are missing.""" + csv_content = "name,value\nfoo,bar\n" + csv_file = tmp_path / ".pdd" / "llm_model.csv" + csv_file.parent.mkdir(parents=True) + csv_file.write_text(csv_content) + + with patch.object(model_tester.Path, "home", return_value=tmp_path): + df = model_tester._load_user_csv() + + assert df is None diff --git a/tests/test_pddrc_initializer.py b/tests/test_pddrc_initializer.py new file mode 100644 index 000000000..b9bd50af5 --- /dev/null +++ b/tests/test_pddrc_initializer.py @@ -0,0 +1,207 @@ +"""Tests for pddrc_initializer.py — language detection, content generation, offer flow.""" + +import json +import pytest +from unittest.mock import MagicMock, patch +from pathlib import Path + +from pdd import pddrc_initializer + + +# --------------------------------------------------------------------------- +# Tests for _detect_language +# --------------------------------------------------------------------------- + +def test_detect_language_python_pyproject(tmp_path): + """Detects Python from pyproject.toml.""" + (tmp_path / "pyproject.toml").touch() + assert pddrc_initializer._detect_language(tmp_path) == "python" + + +def test_detect_language_python_setup_py(tmp_path): + """Detects Python from setup.py.""" + (tmp_path / "setup.py").touch() + assert pddrc_initializer._detect_language(tmp_path) == "python" + + +def test_detect_language_python_requirements(tmp_path): + """Detects Python from requirements.txt.""" + (tmp_path / "requirements.txt").touch() + assert pddrc_initializer._detect_language(tmp_path) == "python" + + +def test_detect_language_typescript(tmp_path): + """Detects TypeScript from package.json with typescript dependency.""" + pkg = {"devDependencies": {"typescript": "^5.0.0"}} + (tmp_path / "package.json").write_text(json.dumps(pkg)) + assert pddrc_initializer._detect_language(tmp_path) == "typescript" + + +def test_detect_language_not_typescript_without_dep(tmp_path): + """package.json without typescript dep is not detected as TypeScript.""" + pkg = {"dependencies": {"express": "^4.0.0"}} + (tmp_path / "package.json").write_text(json.dumps(pkg)) + assert pddrc_initializer._detect_language(tmp_path) is None + + +def test_detect_language_go(tmp_path): + """Detects Go from go.mod.""" + (tmp_path / "go.mod").touch() + assert pddrc_initializer._detect_language(tmp_path) == "go" + + +def test_detect_language_none(tmp_path): + """Returns None when no markers found.""" + assert pddrc_initializer._detect_language(tmp_path) is None + + +def test_detect_language_python_priority_over_go(tmp_path): + """Python markers take priority over Go markers.""" + (tmp_path / "pyproject.toml").touch() + (tmp_path / "go.mod").touch() + assert pddrc_initializer._detect_language(tmp_path) == "python" + + +# --------------------------------------------------------------------------- +# Tests for _build_pddrc_content +# --------------------------------------------------------------------------- + +def test_build_pddrc_content_python(): + """Python content has correct paths and defaults.""" + content = pddrc_initializer._build_pddrc_content("python") + assert 'version: "1.0"' in content + assert 'generate_output_path: "pdd/"' in content + assert 'test_output_path: "tests/"' in content + assert 'example_output_path: "context/"' in content + assert 'default_language: "python"' in content + assert "strength: 1.0" in content + assert "temperature: 0.0" in content + assert "target_coverage: 80.0" in content + assert "budget: 10.0" in content + assert "max_attempts: 3" in content + + +def test_build_pddrc_content_typescript(): + """TypeScript content has correct paths.""" + content = pddrc_initializer._build_pddrc_content("typescript") + assert 'generate_output_path: "src/"' in content + assert 'test_output_path: "__tests__/"' in content + assert 'example_output_path: "examples/"' in content + assert 'default_language: "typescript"' in content + + +def test_build_pddrc_content_go(): + """Go content has correct paths.""" + content = pddrc_initializer._build_pddrc_content("go") + assert 'generate_output_path: "."' in content + assert 'test_output_path: "."' in content + assert 'example_output_path: "examples/"' in content + assert 'default_language: "go"' in content + + +def test_build_pddrc_content_unknown_falls_back_to_python(): + """Unknown language falls back to Python defaults.""" + content = pddrc_initializer._build_pddrc_content("rust") + assert 'generate_output_path: "pdd/"' in content + assert 'default_language: "rust"' in content + + +def test_build_pddrc_content_ends_with_newline(): + """Generated content ends with a trailing newline.""" + content = pddrc_initializer._build_pddrc_content("python") + assert content.endswith("\n") + + +# --------------------------------------------------------------------------- +# Tests for offer_pddrc_init +# --------------------------------------------------------------------------- + +def test_offer_pddrc_init_already_exists(tmp_path): + """Returns False and does not overwrite when .pddrc already exists.""" + (tmp_path / ".pddrc").write_text("existing config") + + with patch.object(Path, "cwd", return_value=tmp_path): + result = pddrc_initializer.offer_pddrc_init() + + assert result is False + assert (tmp_path / ".pddrc").read_text() == "existing config" + + +@patch.object(pddrc_initializer.console, "input", return_value="y") +def test_offer_pddrc_init_creates_file(mock_input, tmp_path): + """Creates .pddrc when user confirms with 'y'.""" + (tmp_path / "pyproject.toml").touch() # Python marker + + with patch.object(Path, "cwd", return_value=tmp_path): + result = pddrc_initializer.offer_pddrc_init() + + assert result is True + pddrc = tmp_path / ".pddrc" + assert pddrc.exists() + content = pddrc.read_text() + assert 'default_language: "python"' in content + + +@patch.object(pddrc_initializer.console, "input", return_value="") +def test_offer_pddrc_init_enter_means_yes(mock_input, tmp_path): + """Empty input (just Enter) means yes — file is created.""" + (tmp_path / "pyproject.toml").touch() + + with patch.object(Path, "cwd", return_value=tmp_path): + result = pddrc_initializer.offer_pddrc_init() + + assert result is True + assert (tmp_path / ".pddrc").exists() + + +@patch.object(pddrc_initializer.console, "input", return_value="n") +def test_offer_pddrc_init_declined(mock_input, tmp_path): + """Returns False when user declines with 'n'.""" + (tmp_path / "pyproject.toml").touch() + + with patch.object(Path, "cwd", return_value=tmp_path): + result = pddrc_initializer.offer_pddrc_init() + + assert result is False + assert not (tmp_path / ".pddrc").exists() + + +@patch.object(pddrc_initializer.console, "input") +def test_offer_pddrc_init_prompts_language_when_unknown(mock_input, tmp_path): + """When no markers found, prompts user for language choice.""" + # First input: language choice (1=Python), second: confirmation (y) + mock_input.side_effect = ["1", "y"] + + with patch.object(Path, "cwd", return_value=tmp_path): + result = pddrc_initializer.offer_pddrc_init() + + assert result is True + content = (tmp_path / ".pddrc").read_text() + assert 'default_language: "python"' in content + + +# --------------------------------------------------------------------------- +# Tests for _prompt_language +# --------------------------------------------------------------------------- + +@patch.object(pddrc_initializer.console, "input", return_value="1") +def test_prompt_language_python(mock_input): + assert pddrc_initializer._prompt_language() == "python" + + +@patch.object(pddrc_initializer.console, "input", return_value="2") +def test_prompt_language_typescript(mock_input): + assert pddrc_initializer._prompt_language() == "typescript" + + +@patch.object(pddrc_initializer.console, "input", return_value="3") +def test_prompt_language_go(mock_input): + assert pddrc_initializer._prompt_language() == "go" + + +@patch.object(pddrc_initializer.console, "input") +def test_prompt_language_retries_on_invalid(mock_input): + """Invalid input causes retry until valid choice is entered.""" + mock_input.side_effect = ["x", "99", "2"] + assert pddrc_initializer._prompt_language() == "typescript" + assert mock_input.call_count == 3 diff --git a/tests/test_provider_manager.py b/tests/test_provider_manager.py index 34ffaf317..8a70184b5 100644 --- a/tests/test_provider_manager.py +++ b/tests/test_provider_manager.py @@ -1,4 +1,4 @@ -"""Tests for pdd/setup/provider_manager.py""" +"""Tests for pdd/provider_manager.py""" import csv import os @@ -8,7 +8,7 @@ import pytest -from pdd.setup.provider_manager import ( +from pdd.provider_manager import ( CSV_FIELDNAMES, _get_shell_name, _get_pdd_dir, @@ -386,9 +386,9 @@ class TestAddProviderFromRegistry: def test_returns_false_when_litellm_unavailable(self, temp_home): """Should return False when litellm is not available.""" with mock.patch( - "pdd.setup.provider_manager.is_litellm_available", return_value=False + "pdd.provider_manager.is_litellm_available", return_value=False ): - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = add_provider_from_registry() assert result is False @@ -396,22 +396,22 @@ def test_returns_false_when_litellm_unavailable(self, temp_home): def test_returns_false_on_empty_selection(self, temp_home): """Should return False when user enters empty selection.""" with mock.patch( - "pdd.setup.provider_manager.is_litellm_available", return_value=True + "pdd.provider_manager.is_litellm_available", return_value=True ): with mock.patch( - "pdd.setup.provider_manager.get_top_providers", + "pdd.provider_manager.get_top_providers", return_value=[], ): - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "" - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = add_provider_from_registry() assert result is False def test_adds_models_to_csv(self, temp_home): """Should add selected models to user CSV.""" - from pdd.setup.litellm_registry import ProviderInfo, ModelInfo + from pdd.litellm_registry import ProviderInfo, ModelInfo mock_provider = ProviderInfo( name="test_provider", @@ -432,26 +432,26 @@ def test_adds_models_to_csv(self, temp_home): ] with mock.patch( - "pdd.setup.provider_manager.is_litellm_available", return_value=True + "pdd.provider_manager.is_litellm_available", return_value=True ): with mock.patch( - "pdd.setup.provider_manager.get_top_providers", + "pdd.provider_manager.get_top_providers", return_value=[mock_provider], ): with mock.patch( - "pdd.setup.provider_manager.get_models_for_provider", + "pdd.provider_manager.get_models_for_provider", return_value=mock_models, ): - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: # Select provider 1, then model 1 mock_prompt.ask.side_effect = ["1", "1", "test-api-key"] - with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: mock_confirm.ask.return_value = False - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): with mock.patch( - "pdd.setup.provider_manager._is_key_set", + "pdd.provider_manager._is_key_set", return_value=None, ): result = add_provider_from_registry() @@ -473,16 +473,16 @@ class TestAddCustomProvider: def test_returns_false_on_empty_provider(self, temp_home): """Should return False when provider name is empty.""" - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "" - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = add_custom_provider() assert result is False def test_adds_custom_provider_to_csv(self, temp_home): """Should add custom provider to user CSV.""" - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.side_effect = [ "custom_llm", # provider prefix "my-model", # model name @@ -491,9 +491,9 @@ def test_adds_custom_provider_to_csv(self, temp_home): "1.0", # input cost "2.0", # output cost ] - with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: mock_confirm.ask.return_value = False # Don't enter key value now - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = add_custom_provider() assert result is True @@ -510,7 +510,7 @@ def test_saves_api_key_when_provided(self, temp_home, monkeypatch): """Should save API key to api-env when user provides it.""" monkeypatch.setenv("SHELL", "/bin/bash") - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.side_effect = [ "custom_llm", "my-model", @@ -520,9 +520,9 @@ def test_saves_api_key_when_provided(self, temp_home, monkeypatch): "0.0", "sk-my-secret-key", # API key value ] - with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: mock_confirm.ask.return_value = True # Yes, enter key value - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = add_custom_provider() assert result is True @@ -543,16 +543,16 @@ class TestRemoveModelsByProvider: def test_returns_false_when_no_models(self, temp_home): """Should return False when no models configured.""" - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_models_by_provider() assert result is False def test_returns_false_on_cancel(self, sample_csv, temp_home): """Should return False when user cancels.""" - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "" # Empty = cancel - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_models_by_provider() assert result is False @@ -561,11 +561,11 @@ def test_removes_all_models_for_provider(self, sample_csv, temp_home, monkeypatc """Should remove all models with matching api_key.""" monkeypatch.setenv("SHELL", "/bin/bash") - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "1" # Select first provider - with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: mock_confirm.ask.return_value = True # Confirm removal - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_models_by_provider() assert result is True @@ -586,27 +586,27 @@ class TestRemoveIndividualModels: def test_returns_false_when_no_models(self, temp_home): """Should return False when no models configured.""" - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_individual_models() assert result is False def test_returns_false_on_cancel(self, sample_csv, temp_home): """Should return False when user cancels.""" - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "" # Empty = cancel - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_individual_models() assert result is False def test_removes_selected_models(self, sample_csv, temp_home): """Should remove only selected models.""" - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "1" # Remove first model - with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: mock_confirm.ask.return_value = True # Confirm - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_individual_models() assert result is True @@ -617,11 +617,11 @@ def test_removes_selected_models(self, sample_csv, temp_home): def test_removes_multiple_models(self, sample_csv, temp_home): """Should handle comma-separated model selection.""" - with mock.patch("pdd.setup.provider_manager.Prompt") as mock_prompt: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: mock_prompt.ask.return_value = "1, 2" # Remove first two models - with mock.patch("pdd.setup.provider_manager.Confirm") as mock_confirm: + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: mock_confirm.ask.return_value = True # Confirm - with mock.patch("pdd.setup.provider_manager.console"): + with mock.patch("pdd.provider_manager.console"): result = remove_individual_models() assert result is True diff --git a/tests/test_setup_tool.py b/tests/test_setup_tool.py index b89710774..1c9f1d65a 100644 --- a/tests/test_setup_tool.py +++ b/tests/test_setup_tool.py @@ -1,574 +1,378 @@ -"""Tests for setup_tool.py""" +"""Tests for setup_tool.py — deterministic auto-configuration.""" -import subprocess -import tempfile -from pathlib import Path import pytest -from pdd.setup_tool import create_api_env_script - - -def test_create_api_env_script_with_special_characters_bash(): - """ - Test that API keys with special shell characters are properly escaped - when generating bash/zsh shell scripts. - - This test will fail with the current implementation (no escaping) and - pass after fixing with shlex.quote(). - """ - # Simulate a Gemini API key that might contain special characters - # These are realistic characters that could appear in API keys or be accidentally - # included when copy-pasting - test_keys = { - 'GEMINI_API_KEY': 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - } - - # Generate the script - script_content = create_api_env_script(test_keys, 'bash') - - # Write to a temporary file - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Try to parse/validate the script by running it with bash -n (syntax check) - # This will fail if the script has parsing errors - result = subprocess.run( - ['bash', '-n', str(script_path)], - capture_output=True, - text=True, - timeout=5 - ) - - # The script should parse without errors - assert result.returncode == 0, ( - f"Generated script has syntax errors: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - # Additionally, try to source it in a subprocess to ensure it can be executed - # We'll check the exit code but not the actual env vars (since they're set in subprocess) - result = subprocess.run( - ['bash', '-c', f'source {script_path} && exit 0'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Generated script cannot be sourced: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - finally: - # Clean up - script_path.unlink() - - -def test_create_api_env_script_with_special_characters_zsh(): - """Test that API keys with special characters work in zsh scripts.""" - test_keys = { - 'GEMINI_API_KEY': 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - } - - script_content = create_api_env_script(test_keys, 'zsh') - - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Test zsh syntax - result = subprocess.run( - ['zsh', '-n', str(script_path)], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Generated zsh script has syntax errors: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_with_common_problematic_characters(): - """ - Test with various common problematic characters that might appear in API keys. - - Characters tested: - - Double quotes: " - - Single quotes: ' - - Dollar signs: $ (variable expansion) - - Backticks: ` (command substitution) - - Backslashes: \\ (escaping) - - Spaces: (should be handled) - - Parentheses: () (might be interpreted) - """ - problematic_key = 'key"with\'many$special`characters\\and spaces(too)' - test_keys = { - 'GEMINI_API_KEY': problematic_key - } - - # Test all common shells - for shell in ['bash', 'zsh', 'sh']: - script_content = create_api_env_script(test_keys, shell) - - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Use bash/sh for sh, bash for bash, zsh for zsh - shell_cmd = 'sh' if shell == 'sh' else shell - result = subprocess.run( - [shell_cmd, '-n', str(script_path)], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Generated {shell} script has syntax errors: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_preserves_key_value(): - """ - Test that after proper escaping, the key value can still be correctly - extracted when the script is sourced. - """ - original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - test_keys = { - 'GEMINI_API_KEY': original_key - } - - script_content = create_api_env_script(test_keys, 'bash') - - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Source the script and extract the value - # We'll use a Python subprocess to avoid shell escaping issues in our test - result = subprocess.run( - ['bash', '-c', f'source {script_path} && python3 -c "import os; print(os.environ.get(\'GEMINI_API_KEY\', \'\'))"'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Failed to source script and read env var: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - extracted_key = result.stdout.strip() - assert extracted_key == original_key, ( - f"Key value was corrupted during escaping.\n" - f"Original: {repr(original_key)}\n" - f"Extracted: {repr(extracted_key)}\n" - f"Script content:\n{script_content}" - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_with_normal_key(): - """ - Test that normal keys (without special characters) still work correctly. - This ensures our fix doesn't break existing functionality. - """ - normal_key = 'AIzaSyAbCdEf1234567890_normal_key_value' - test_keys = { - 'OPENAI_API_KEY': normal_key, - 'GEMINI_API_KEY': normal_key - } - - script_content = create_api_env_script(test_keys, 'bash') - - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - result = subprocess.run( - ['bash', '-n', str(script_path)], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Normal key failed syntax check: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - # Verify values can be extracted - result = subprocess.run( - ['bash', '-c', f'source {script_path} && python3 -c "import os; print(os.environ.get(\'OPENAI_API_KEY\', \'\')); print(os.environ.get(\'GEMINI_API_KEY\', \'\'))"'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0 - extracted_keys = result.stdout.strip().split('\n') - assert extracted_keys[0] == normal_key - assert extracted_keys[1] == normal_key - finally: - script_path.unlink() - - -def _shell_available(shell: str) -> bool: - """Check if a shell is available on the system""" - try: - result = subprocess.run( - ['which', shell], - capture_output=True, - timeout=2 - ) - return result.returncode == 0 - except (subprocess.TimeoutExpired, FileNotFoundError): - return False - - -def test_create_api_env_script_with_special_characters_fish(): - """ - Test that API keys with special characters work in fish shell scripts. - - This test verifies that shlex.quote() works correctly with fish shell. - Fish is not POSIX-compliant, so there may be edge cases where POSIX-style - quoting doesn't work as expected. - """ - if not _shell_available('fish'): - pytest.skip("fish shell not available") - - test_keys = { - 'GEMINI_API_KEY': 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - } - - script_content = create_api_env_script(test_keys, 'fish') - - with tempfile.NamedTemporaryFile(mode='w', suffix='.fish', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Fish doesn't have a -n syntax check flag like bash/zsh - # So we'll try to source it and see if it works - result = subprocess.run( - ['fish', '-c', f'source {script_path}; exit 0'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Generated fish script has syntax/execution errors: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_preserves_key_value_fish(): - """ - Test that fish shell correctly preserves key values with special characters. - - This is critical because fish has different quoting rules than POSIX shells, - and shlex.quote() may not handle all cases correctly. - """ - if not _shell_available('fish'): - pytest.skip("fish shell not available") - - original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - test_keys = { - 'GEMINI_API_KEY': original_key - } - - script_content = create_api_env_script(test_keys, 'fish') - - with tempfile.NamedTemporaryFile(mode='w', suffix='.fish', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Source the script and extract the value using fish - result = subprocess.run( - ['fish', '-c', f'source {script_path}; python3 -c "import os; print(os.environ.get(\'GEMINI_API_KEY\', \'\'))"'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Failed to source fish script and read env var: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - extracted_key = result.stdout.strip() - assert extracted_key == original_key, ( - f"Key value was corrupted during escaping in fish shell.\n" - f"Original: {repr(original_key)}\n" - f"Extracted: {repr(extracted_key)}\n" - f"Script content:\n{script_content}\n" - f"This indicates shlex.quote() may not work correctly with fish shell." - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_with_special_characters_csh(): - """ - Test that API keys with special characters work in csh/tcsh shell scripts. - - WARNING: csh/tcsh have fundamentally different quoting rules than POSIX shells. - shlex.quote() uses POSIX single-quote syntax which may not work correctly - in csh/tcsh, especially with: - - Variables containing $ (variable expansion still occurs in single quotes) - - Complex backslash sequences - - Certain special characters - - This test will help identify if shlex.quote() works correctly with csh/tcsh. - """ - # Try csh first, then tcsh - shell_cmd = None - shell_name = None - for shell in ['csh', 'tcsh']: - if _shell_available(shell): - shell_cmd = shell - shell_name = shell - break - - if not shell_cmd: - pytest.skip("csh/tcsh not available") - - test_keys = { - 'GEMINI_API_KEY': 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - } - - script_content = create_api_env_script(test_keys, shell_name) - - with tempfile.NamedTemporaryFile(mode='w', suffix='.csh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # csh/tcsh don't have a -n flag, so we'll try to source it - # Use -f to prevent reading .cshrc/.tcshrc which might interfere - result = subprocess.run( - [shell_cmd, '-f', '-c', f'source {script_path}; exit 0'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Generated {shell_name} script has syntax/execution errors: {result.stderr}\n" - f"Script content:\n{script_content}\n" - f"This may indicate that shlex.quote() doesn't work correctly with {shell_name}." - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_preserves_key_value_csh(): - """ - Test that csh/tcsh correctly preserves key values with special characters. - - This is critical because csh/tcsh have fundamentally different quoting rules: - - Single quotes in csh do NOT prevent variable expansion ($var still expands) - - Backslash escaping works differently - - The quoting mechanism is incompatible with POSIX - - This test will likely reveal issues with using shlex.quote() for csh/tcsh. - """ - # Try csh first, then tcsh - shell_cmd = None - shell_name = None - for shell in ['csh', 'tcsh']: - if _shell_available(shell): - shell_cmd = shell - shell_name = shell - break - - if not shell_cmd: - pytest.skip("csh/tcsh not available") - - original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - test_keys = { - 'GEMINI_API_KEY': original_key - } - - script_content = create_api_env_script(test_keys, shell_name) - - with tempfile.NamedTemporaryFile(mode='w', suffix='.csh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Source the script and extract the value using csh/tcsh - # Use -f to prevent reading .cshrc/.tcshrc - result = subprocess.run( - [shell_cmd, '-f', '-c', f'source {script_path}; python3 -c "import os; print(os.environ.get(\'GEMINI_API_KEY\', \'\'))"'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Failed to source {shell_name} script and read env var: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - extracted_key = result.stdout.strip() - assert extracted_key == original_key, ( - f"Key value was corrupted during escaping in {shell_name} shell.\n" - f"Original: {repr(original_key)}\n" - f"Extracted: {repr(extracted_key)}\n" - f"Script content:\n{script_content}\n" - f"This indicates shlex.quote() does NOT work correctly with {shell_name}.\n" - f"csh/tcsh have different quoting rules than POSIX shells." - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_csh_variable_expansion_issue(): - """ - Test a specific csh/tcsh issue: variable expansion in single quotes. - - In csh/tcsh, single quotes do NOT prevent variable expansion. - This means a key containing $HOME will expand to the actual home directory - path, which is incorrect behavior. - - This test demonstrates the fundamental incompatibility between - POSIX-style quoting (shlex.quote) and csh/tcsh. - """ - # Try csh first, then tcsh - shell_cmd = None - shell_name = None - for shell in ['csh', 'tcsh']: - if _shell_available(shell): - shell_cmd = shell - shell_name = shell - break - - if not shell_cmd: - pytest.skip("csh/tcsh not available") - - # Create a key that contains $HOME to test variable expansion - # In POSIX shells, this should be preserved as-is - # In csh/tcsh, this might expand to the actual home directory - test_key = 'api_key_with_$HOME_in_it' - test_keys = { - 'GEMINI_API_KEY': test_key - } - - script_content = create_api_env_script(test_keys, shell_name) - - with tempfile.NamedTemporaryFile(mode='w', suffix='.csh', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Source the script and extract the value - result = subprocess.run( - [shell_cmd, '-f', '-c', f'source {script_path}; python3 -c "import os; print(os.environ.get(\'GEMINI_API_KEY\', \'\'))"'], - capture_output=True, - text=True, - timeout=5 - ) - - assert result.returncode == 0, ( - f"Failed to source {shell_name} script: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - extracted_key = result.stdout.strip() - # This test will likely FAIL, demonstrating the issue - assert extracted_key == test_key, ( - f"Variable expansion occurred in {shell_name} despite single quotes!\n" - f"Expected: {repr(test_key)}\n" - f"Got: {repr(extracted_key)}\n" - f"Script content:\n{script_content}\n" - f"This proves that shlex.quote() (POSIX single quotes) does NOT work\n" - f"correctly with csh/tcsh, which expand variables even in single quotes." - ) - finally: - script_path.unlink() - - -def test_create_api_env_script_fish_edge_cases(): - """ - Test fish shell with various edge cases that might reveal quoting issues. - - Fish shell, while often compatible with POSIX-style quoting, may have - edge cases with certain character combinations. - """ - if not _shell_available('fish'): - pytest.skip("fish shell not available") - - edge_cases = [ - 'key with spaces', - "key'with'single'quotes", - 'key"with"double"quotes', - 'key$with$dollars', - 'key\\with\\backslashes', - 'key`with`backticks', - 'key(with)parentheses', - 'key[with]brackets', - 'key{with}braces', - 'key;with;semicolons', - 'key|with|pipes', - 'key&with&ersands', - 'keyredirects', - 'key\nwith\nnewlines', - 'key\twith\ttabs', +from unittest.mock import MagicMock, patch +from pathlib import Path + +from pdd import setup_tool + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_cli_result(): + result = MagicMock() + result.cli_name = "test_cli" + result.provider = "test_provider" + result.api_key_configured = True + return result + + +@pytest.fixture +def mock_detect_cli(): + with patch("pdd.cli_detector.detect_and_bootstrap_cli") as m: + yield m + + +@pytest.fixture +def mock_auto_phase(): + with patch("pdd.setup_tool._run_auto_phase") as m: + yield m + + +@pytest.fixture +def mock_fallback_menu(): + with patch("pdd.setup_tool._run_fallback_menu") as m: + yield m + + +@pytest.fixture +def mock_input(): + with patch("builtins.input") as m: + yield m + + +@pytest.fixture +def mock_print(): + with patch("builtins.print") as m: + yield m + + +# --------------------------------------------------------------------------- +# Tests for run_setup +# --------------------------------------------------------------------------- + +def test_run_setup_no_cli_detected(mock_detect_cli, mock_auto_phase, mock_print): + """Test that setup exits early if no CLI is detected.""" + result = MagicMock() + result.cli_name = "" + mock_detect_cli.return_value = result + + setup_tool.run_setup() + + mock_detect_cli.assert_called_once() + mock_auto_phase.assert_not_called() + assert any("Agentic features require at least one CLI tool" in str(c) for c in mock_print.call_args_list) + + +def test_run_setup_success_path(mock_detect_cli, mock_auto_phase, mock_input, mock_print, mock_cli_result): + """Test the happy path where auto phase succeeds.""" + mock_detect_cli.return_value = mock_cli_result + mock_auto_phase.return_value = True + + with patch("pdd.setup_tool._console") as mock_console: + setup_tool.run_setup() + + mock_detect_cli.assert_called_once() + mock_auto_phase.assert_called_once() + assert any("Setup complete" in str(c) for c in mock_console.print.call_args_list) + + +def test_run_setup_fallback_path(mock_detect_cli, mock_auto_phase, mock_fallback_menu, mock_input, mock_cli_result): + """Test that fallback menu is triggered if auto phase fails.""" + mock_detect_cli.return_value = mock_cli_result + mock_auto_phase.return_value = False + + setup_tool.run_setup() + + mock_auto_phase.assert_called_once() + mock_fallback_menu.assert_called_once() + + +def test_run_setup_keyboard_interrupt(mock_detect_cli, mock_print): + """Test handling of KeyboardInterrupt during setup.""" + mock_detect_cli.side_effect = KeyboardInterrupt + + setup_tool.run_setup() + + assert any("Setup interrupted" in str(c) for c in mock_print.call_args_list) + + +def test_run_setup_no_api_key_warning(mock_detect_cli, mock_auto_phase, mock_input, mock_print): + """Test that a warning is printed if API key is not configured, but proceeds.""" + result = MagicMock() + result.cli_name = "test_cli" + result.api_key_configured = False + mock_detect_cli.return_value = result + mock_auto_phase.return_value = True + + setup_tool.run_setup() + + assert any("No API key configured" in str(c) for c in mock_print.call_args_list) + mock_auto_phase.assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests for _run_auto_phase +# --------------------------------------------------------------------------- + +@patch("pdd.setup_tool._step4_test_and_summary") +@patch("pdd.setup_tool._step3_local_llms_and_pddrc") +@patch("pdd.setup_tool._step2_configure_models") +@patch("pdd.setup_tool._step1_scan_keys") +@patch("builtins.input") +def test_run_auto_phase_success(mock_input, mock_step1, mock_step2, mock_step3, mock_step4): + """Test that all 4 steps run sequentially on success.""" + mock_step1.return_value = [("ANTHROPIC_API_KEY", "shell environment")] + mock_step2.return_value = {"Anthropic": 3} + mock_step3.return_value = {"Ollama": ["llama3.2:3b"]} + + result = setup_tool._run_auto_phase() + + assert result is True + mock_step1.assert_called_once() + mock_step2.assert_called_once_with([("ANTHROPIC_API_KEY", "shell environment")]) + mock_step3.assert_called_once() + mock_step4.assert_called_once() + # 3 "Press Enter" prompts between steps + assert mock_input.call_count == 3 + + +@patch("pdd.setup_tool._step1_scan_keys") +@patch("builtins.input") +def test_run_auto_phase_exception_returns_false(mock_input, mock_step1): + """Test that exceptions in steps cause fallback.""" + mock_step1.side_effect = RuntimeError("test error") + + result = setup_tool._run_auto_phase() + + assert result is False + + +# --------------------------------------------------------------------------- +# Tests for _step1_scan_keys +# --------------------------------------------------------------------------- + +@patch("pdd.setup_tool._prompt_for_api_key") +@patch("pdd.api_key_scanner._parse_api_env_file") +@patch("pdd.api_key_scanner._detect_shell") +@patch("pdd.litellm_registry.PROVIDER_API_KEY_MAP", {"anthropic": "ANTHROPIC_API_KEY", "openai": "OPENAI_API_KEY"}) +def test_step1_finds_keys_in_env(mock_detect_shell, mock_parse, mock_prompt, tmp_path, monkeypatch): + """Test that step 1 finds keys from os.environ.""" + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + mock_detect_shell.return_value = "bash" + mock_parse.return_value = {} + + found = setup_tool._step1_scan_keys() + + assert len(found) == 1 + assert found[0] == ("ANTHROPIC_API_KEY", "shell environment") + mock_prompt.assert_not_called() + + +@patch("pdd.setup_tool._prompt_for_api_key") +@patch("pdd.api_key_scanner._parse_api_env_file") +@patch("pdd.api_key_scanner._detect_shell") +@patch("pdd.litellm_registry.PROVIDER_API_KEY_MAP", {"anthropic": "ANTHROPIC_API_KEY"}) +def test_step1_no_keys_triggers_prompt(mock_detect_shell, mock_parse, mock_prompt, tmp_path, monkeypatch): + """Test that step 1 prompts for a key when none found.""" + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + mock_detect_shell.return_value = "bash" + mock_parse.return_value = {} + mock_prompt.return_value = [("ANTHROPIC_API_KEY", "~/.pdd/api-env.bash")] + + found = setup_tool._step1_scan_keys() + + assert len(found) == 1 + mock_prompt.assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests for _step2_configure_models +# --------------------------------------------------------------------------- + +@patch("pdd.provider_manager._get_user_csv_path") +@patch("pdd.provider_manager._write_csv_atomic") +@patch("pdd.provider_manager._read_csv") +def test_step2_adds_matching_models(mock_read, mock_write, mock_csv_path, tmp_path): + """Test that step 2 filters reference CSV by found keys and writes user CSV.""" + mock_csv_path.return_value = tmp_path / "llm_model.csv" + + # First call: reference CSV, second call: existing user CSV (empty) + mock_read.side_effect = [ + [ + {"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", "base_url": ""}, + {"provider": "OpenAI", "model": "gpt-4", "api_key": "OPENAI_API_KEY", "base_url": ""}, + ], + [], # empty user CSV + ] + + found_keys = [("ANTHROPIC_API_KEY", "shell environment")] + result = setup_tool._step2_configure_models(found_keys) + + assert result == {"Anthropic": 1} + mock_write.assert_called_once() + written_rows = mock_write.call_args[0][1] + assert len(written_rows) == 1 + assert written_rows[0]["model"] == "claude-sonnet" + + +@patch("pdd.provider_manager._get_user_csv_path") +@patch("pdd.provider_manager._write_csv_atomic") +@patch("pdd.provider_manager._read_csv") +def test_step2_deduplicates_existing(mock_read, mock_write, mock_csv_path, tmp_path): + """Test that step 2 skips models already in user CSV.""" + mock_csv_path.return_value = tmp_path / "llm_model.csv" + + mock_read.side_effect = [ + [{"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", "base_url": ""}], + [{"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY"}], + ] + + found_keys = [("ANTHROPIC_API_KEY", "shell environment")] + result = setup_tool._step2_configure_models(found_keys) + + assert result == {"Anthropic": 1} + mock_write.assert_not_called() + + +@patch("pdd.provider_manager._get_user_csv_path") +@patch("pdd.provider_manager._write_csv_atomic") +@patch("pdd.provider_manager._read_csv") +def test_step2_skips_local_models(mock_read, mock_write, mock_csv_path, tmp_path): + """Test that step 2 skips local models (ollama, lm_studio, localhost).""" + mock_csv_path.return_value = tmp_path / "llm_model.csv" + + mock_read.side_effect = [ + [ + {"provider": "Ollama", "model": "ollama/llama", "api_key": "", "base_url": "http://localhost:11434"}, + {"provider": "lm_studio", "model": "lm/model", "api_key": "", "base_url": "http://localhost:1234"}, + {"provider": "OpenAI", "model": "gpt-local", "api_key": "OPENAI_API_KEY", "base_url": "http://localhost:8080"}, + {"provider": "Anthropic", "model": "claude", "api_key": "ANTHROPIC_API_KEY", "base_url": ""}, + ], + [], + ] + + found_keys = [("ANTHROPIC_API_KEY", "env"), ("OPENAI_API_KEY", "env")] + result = setup_tool._step2_configure_models(found_keys) + + assert result == {"Anthropic": 1} + + +# --------------------------------------------------------------------------- +# Tests for local LLM helpers +# --------------------------------------------------------------------------- + +def test_extract_ollama_models(): + """Test Ollama model name extraction from API response.""" + data = {"models": [{"name": "llama3.2:3b"}, {"name": "openhermes:latest"}, {"name": ""}]} + result = setup_tool._extract_ollama_models(data) + assert result == ["llama3.2:3b", "openhermes:latest"] + + +def test_extract_ollama_models_empty(): + """Test Ollama extraction with empty models list.""" + assert setup_tool._extract_ollama_models({"models": []}) == [] + assert setup_tool._extract_ollama_models({}) == [] + + +def test_extract_lm_studio_models(): + """Test LM Studio model name extraction from API response.""" + data = {"data": [{"id": "model-a"}, {"id": "model-b"}, {"id": ""}]} + result = setup_tool._extract_lm_studio_models(data) + assert result == ["model-a", "model-b"] + + +def test_extract_lm_studio_models_empty(): + """Test LM Studio extraction with empty data list.""" + assert setup_tool._extract_lm_studio_models({"data": []}) == [] + assert setup_tool._extract_lm_studio_models({}) == [] + + +# --------------------------------------------------------------------------- +# Tests for _run_fallback_menu +# --------------------------------------------------------------------------- + +@patch("pdd.pddrc_initializer.offer_pddrc_init") +@patch("pdd.model_tester.test_model_interactive") +@patch("pdd.provider_manager.add_provider_from_registry") +@patch("builtins.input") +def test_run_fallback_menu_options(mock_input, mock_add_provider, mock_test_model, mock_init_pddrc): + """Test the fallback menu loop and options.""" + mock_input.side_effect = ["1", "2", "3", "5", "4"] + + setup_tool._run_fallback_menu() + + mock_add_provider.assert_called_once() + mock_test_model.assert_called_once() + mock_init_pddrc.assert_called_once() + assert mock_input.call_count == 5 + + +@patch("builtins.input") +def test_run_fallback_menu_interrupt(mock_input, mock_print): + """Test exiting fallback menu via KeyboardInterrupt.""" + mock_input.side_effect = KeyboardInterrupt + + setup_tool._run_fallback_menu() + + assert any("Setup interrupted" in str(c) for c in mock_print.call_args_list) + + +# --------------------------------------------------------------------------- +# Tests for _prompt_for_api_key +# --------------------------------------------------------------------------- + +@patch("pdd.provider_manager._save_key_to_api_env") +@patch("pdd.setup_tool.getpass") +@patch("builtins.input") +def test_prompt_for_api_key_adds_key(mock_input, mock_getpass, mock_save): + """Test that prompt flow saves a key and returns it.""" + mock_input.side_effect = [ + "1", # Select Anthropic + "n", # Don't add another ] - - for i, test_key in enumerate(edge_cases): - test_keys = { - 'TEST_API_KEY': test_key - } - - script_content = create_api_env_script(test_keys, 'fish') - - with tempfile.NamedTemporaryFile(mode='w', suffix=f'.fish', delete=False) as f: - f.write(script_content) - script_path = Path(f.name) - - try: - # Try to source it - result = subprocess.run( - ['fish', '-c', f'source {script_path}; python3 -c "import os; print(os.environ.get(\'TEST_API_KEY\', \'\'))"'], - capture_output=True, - text=True, - timeout=5 - ) - - if result.returncode != 0: - pytest.fail( - f"Fish shell failed with edge case {i+1}: {repr(test_key)}\n" - f"Error: {result.stderr}\n" - f"Script content:\n{script_content}" - ) - - extracted_key = result.stdout.strip() - if extracted_key != test_key: - pytest.fail( - f"Fish shell corrupted value for edge case {i+1}: {repr(test_key)}\n" - f"Expected: {repr(test_key)}\n" - f"Got: {repr(extracted_key)}\n" - f"Script content:\n{script_content}" - ) - finally: - script_path.unlink() + mock_getpass.getpass.return_value = "sk-test-key-123" + + result = setup_tool._prompt_for_api_key() + + assert len(result) == 1 + assert result[0][0] == "ANTHROPIC_API_KEY" + mock_save.assert_called_once_with("ANTHROPIC_API_KEY", "sk-test-key-123") + + +@patch("pdd.provider_manager._save_key_to_api_env") +@patch("pdd.setup_tool.getpass") +@patch("builtins.input") +def test_prompt_for_api_key_skip(mock_input, mock_getpass, mock_save): + """Test that skip option returns empty list.""" + skip_idx = len(setup_tool._PROMPT_PROVIDERS) + 2 + mock_input.side_effect = [str(skip_idx)] + + result = setup_tool._prompt_for_api_key() + + assert result == [] + mock_save.assert_not_called() + + +@patch("pdd.provider_manager._save_key_to_api_env") +@patch("pdd.setup_tool.getpass") +@patch("builtins.input") +def test_prompt_for_api_key_empty_value_skips(mock_input, mock_getpass, mock_save): + """Test that empty key value is rejected gracefully.""" + skip_idx = len(setup_tool._PROMPT_PROVIDERS) + 2 + mock_input.side_effect = [ + "1", # Select Anthropic + str(skip_idx), # Skip after empty key + ] + mock_getpass.getpass.return_value = "" + + result = setup_tool._prompt_for_api_key() + assert result == [] + mock_save.assert_not_called() From 5d16762b598f248e270681d8a03bf489b6ce7daf Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Tue, 17 Feb 2026 02:14:44 -0500 Subject: [PATCH 07/10] Update README.md --- README.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index b93e2f61e..f0ec612a1 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ With the CLI on your `PATH`, continue with: ```bash pdd setup ``` -The command installs tab completion, walks you through API key entry, and seeds local configuration files. +The command detects agentic CLI tools, scans for API keys, configures models, and seeds local configuration files. If you postpone this step, the CLI detects the missing setup artifacts the first time you run another command and shows a reminder banner so you can complete it later (the banner is suppressed once `~/.pdd/api-env` exists or when your project already provides credentials via `.env` or `.pdd/`). ### Alternative: pip Installation @@ -227,13 +227,13 @@ Run the comprehensive setup wizard: pdd setup ``` -The setup wizard will: -- **Scan your environment** for API keys from all sources (shell, .env, ~/.pdd files) -- **Present an interactive menu** with options to add/fix keys, configure local LLMs (Ollama, LM Studio), add custom providers, or remove providers -- **Validate API keys** using actual LLM requests to ensure they work -- **Guide model selection** with cost transparency (show pricing for each tier) -- **Detect agentic CLI tools** (claude, gemini, codex) and offer installation -- **Create .pddrc** configuration file with sensible defaults for your project +The setup wizard runs in two phases: +- **Phase 1** — Detects agentic CLI tools (claude, gemini, codex) and offers installation if needed +- **Phase 2** — Auto-configures PDD in 4 deterministic steps: + 1. Scans for API keys across shell, .env, and ~/.pdd files (prompts to add one if none found) + 2. Configures models from a reference CSV based on your available keys + 3. Checks for local LLMs (Ollama, LM Studio) and creates a `.pddrc` config file + 4. Tests a model and prints a summary The wizard can be re-run at any time to update keys, add providers, or reconfigure settings. @@ -2703,12 +2703,12 @@ The `.pddrc` approach is recommended for team projects as it ensures consistent ### Model Configuration (`llm_model.csv`) -PDD uses a CSV file (`llm_model.csv`) to store information about available AI models, their costs, capabilities, and required API key names. The `pdd setup` wizard automatically manages this file by: +PDD uses a CSV file (`llm_model.csv`) to store information about available AI models, their costs, capabilities, and required API key names. The `pdd setup` command automatically manages this file by: -- **Dynamic provider discovery:** Reading all provider API keys from the CSV to scan your environment -- **Interactive model selection:** Letting you choose which model tiers to enable (Fast/Cheap, Balanced, Most Capable) with cost transparency -- **Custom provider support:** Adding custom LiteLLM-compatible providers and local LLMs (Ollama, LM Studio) -- **Provider removal:** Safely removing providers by deleting their model rows from the CSV +- **API key scanning:** Checking shell environment, `.env` files, and `~/.pdd/api-env.*` for provider keys +- **Automatic model configuration:** Matching found keys to a bundled reference CSV and writing matching models to your user CSV +- **Local LLM detection:** Discovering running Ollama and LM Studio servers and adding their models +- **Fallback menu:** Manual options to add providers, test models, or initialize `.pddrc` if auto-configuration fails When running commands locally, PDD determines which configuration file to use based on the following priority: @@ -2718,7 +2718,7 @@ When running commands locally, PDD determines which configuration file to use ba This tiered approach allows for both shared project configurations and individual user overrides, while ensuring PDD works out-of-the-box without requiring manual configuration. -**Note:** The setup wizard uses this CSV as the source of truth for provider discovery and model selection. You can manually edit it, but running `pdd setup` again is the recommended way to manage providers and models. +**Note:** You can manually edit this CSV, but running `pdd setup` again is the recommended way to add providers and update models. *Note: This file-based configuration primarily affects local operations and utilities. Cloud execution modes likely rely on centrally managed configurations.* From 339751a87cf958bfcae3fa067cec7ce16190dd92 Mon Sep 17 00:00:00 2001 From: Niti Goyal <158232293+niti-go@users.noreply.github.com> Date: Tue, 17 Feb 2026 02:20:38 -0500 Subject: [PATCH 08/10] Delete pdd/prompts/agentic_setup_autoconfig_LLM.prompt I split this into a 4-step agentic task, but later I realized it's more reliable to do setup deterministically right now. --- .../agentic_setup_autoconfig_LLM.prompt | 229 ------------------ 1 file changed, 229 deletions(-) delete mode 100644 pdd/prompts/agentic_setup_autoconfig_LLM.prompt diff --git a/pdd/prompts/agentic_setup_autoconfig_LLM.prompt b/pdd/prompts/agentic_setup_autoconfig_LLM.prompt deleted file mode 100644 index 38415c6e5..000000000 --- a/pdd/prompts/agentic_setup_autoconfig_LLM.prompt +++ /dev/null @@ -1,229 +0,0 @@ -% You are an expert system administrator configuring PDD (Prompt-Driven Development) for a developer's machine. Your task is to auto-discover all available LLM providers and configure PDD so the user can start using it immediately — with zero user interaction. - -% Context - -You are running Phase 2 of `pdd setup`. Phase 1 already confirmed that the {cli_name} CLI is installed (provider: {provider}). Now you need to discover all available LLM providers, configure models, and set up the project. This phase is fully autonomous — do not prompt the user for any input, though occasionally you can ask for the user to press Enter to continue to next steps. - -% Environment - -- Home directory: {home_dir} -- Shell: {shell_name} -- PDD config directory: {pdd_dir} -- LLM model CSV path: {llm_model_csv_path} -- Current working directory: {cwd} -- .pddrc status: {pddrc_path} - -% CSV Schema - - -The user-level CSV at ~/.pdd/llm_model.csv has columns: -provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location - -Example rows: -- Cloud: Anthropic,claude-sonnet-4-5-20250929,3.0,15.0,1500,,ANTHROPIC_API_KEY,0,True,none, -- Local: Ollama,ollama_chat/llama3:70b,0,0,1000,http://localhost:11434,,0,True,none, -- LM Studio: lm_studio,lm_studio/my-model,0,0,1000,http://localhost:1234/v1,,0,True,none, - -Notes: -- Local models have empty api_key column -- Cloud models have the env var NAME (not value) in api_key column -- input/output costs are per 1M tokens - - -% Your Tasks - -Execute these tasks in order. If any single task fails, log the error and continue with remaining tasks. Never abort the entire setup over one failure. - -1. **Create PDD directory** - - Ensure {pdd_dir} exists: `mkdir -p {pdd_dir}` - -2. **Scan for API keys** - Search these locations for API key environment variables. Do NOT display, log, or store actual key values — only report existence and source. - - Sources to check (in priority order): - a. Current shell environment (check with `echo $VAR_NAME` or `env | grep VAR_NAME`) - b. {pdd_dir}/api-env.{shell_name} (parse export/set lines if file exists) - c. .env file in {cwd} (if exists) - d. {home_dir}/.env (if exists) - - Keys to look for (aligned with litellm's PROVIDER_API_KEY_MAP): - - Tier 1 — Major cloud providers: - - ANTHROPIC_API_KEY (Anthropic — Claude models) - - OPENAI_API_KEY (OpenAI — GPT models) - - GOOGLE_API_KEY or GEMINI_API_KEY (Google — Gemini API models) - - VERTEX_CREDENTIALS (Google — Vertex AI models; typically a service account JSON file path or ADC) - - MISTRAL_API_KEY (Mistral AI) - - DEEPSEEK_API_KEY (DeepSeek) - - XAI_API_KEY (xAI — Grok models) - - Tier 2 — Inference platforms & specialized providers: - - GROQ_API_KEY (Groq — fast inference) - - TOGETHERAI_API_KEY or TOGETHER_API_KEY or TOGETHER_AI_API_KEY (Together AI) - - FIREWORKS_API_KEY (Fireworks AI) - - OPENROUTER_API_KEY (OpenRouter — multi-provider gateway) - - COHERE_API_KEY (Cohere) - - PERPLEXITYAI_API_KEY (Perplexity) - - REPLICATE_API_KEY (Replicate) - - DEEPINFRA_API_KEY (DeepInfra) - - CEREBRAS_API_KEY (Cerebras — fast inference) - - Tier 3 — Enterprise & additional providers: - - AZURE_API_KEY (Azure OpenAI) - - AZURE_AI_API_KEY (Azure AI) - - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY (AWS Bedrock — both must be present) - - AI21_API_KEY (AI21 Labs) - - HUGGINGFACE_API_KEY or HF_TOKEN (Hugging Face) - - DATABRICKS_API_KEY (Databricks) - - CLOUDFLARE_API_KEY (Cloudflare Workers AI) - - NOVITA_API_KEY (Novita AI) - - SAMBANOVA_API_KEY (SambaNova) - - WATSONX_API_KEY (IBM watsonx) - - Record which keys exist and where they were found. - -3. **Select models for each discovered provider** - - **Main providers — Anthropic, OpenAI, Google:** - Use ALL matching rows from the reference CSV below. These are pre-vetted models — add them exactly as shown, preserving all column values (model ID, pricing, ELO, reasoning fields, etc.). - - - ../../data/llm_model.csv - - - Matching rules: - - Add Anthropic rows if ANTHROPIC_API_KEY was found - - Add OpenAI rows if OPENAI_API_KEY was found (skip rows with a non-empty base_url unless that URL is reachable) - - Add Google rows with `gemini/` prefix if GEMINI_API_KEY or GOOGLE_API_KEY was found - - Add Google rows with `vertex_ai/` prefix if VERTEX_CREDENTIALS was found - - Skip lm_studio or other local-model rows from this CSV (handle those in step 5) - - **Other providers — use litellm's registry:** - For each API key found outside the main 3 (e.g. Groq, Mistral, xAI, Together, Fireworks, DeepSeek, OpenRouter, Cohere, Perplexity, Cerebras, DeepInfra): - a. Try querying litellm: - ```python - python3 -c " - import litellm, json - prefix = 'groq/' # replace with the provider's litellm prefix - models = [(m, v) for m, v in litellm.model_cost.items() if m.startswith(prefix)] - models.sort(key=lambda x: -x[1].get('input_cost_per_token', 0)) - for m, v in models[:3]: - print(m, v.get('input_cost_per_token',0)*1e6, v.get('output_cost_per_token',0)*1e6) - " - ``` - Use the top 2-3 results (highest input cost ≈ most capable). Set elo to 0 if unknown. - b. If litellm is unavailable or returns nothing, use these fallbacks: - - Groq (GROQ_API_KEY): `groq/moonshotai/kimi-k2-instruct-0905` (input: 1.0, output: 3.0, elo: 1330) - - Mistral (MISTRAL_API_KEY): `mistral/mistral-large-latest` (input: 2.0, output: 6.0, elo: 1414) - - xAI (XAI_API_KEY): `xai/grok-2-latest` (input: 2.0, output: 10.0, elo: 1411) - - Together AI (TOGETHERAI_API_KEY): `together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo` (input: 0.88, output: 0.88, elo: 1300) - - Fireworks (FIREWORKS_API_KEY): `fireworks_ai/accounts/fireworks/models/glm-4p7` (input: 0.60, output: 2.20, elo: 1481) - - DeepSeek (DEEPSEEK_API_KEY): `deepseek/deepseek-chat` (input: 0.14, output: 0.28, elo: 1419) - - OpenRouter (OPENROUTER_API_KEY): `openrouter/anthropic/claude-sonnet-4-5` (input: 3.0, output: 15.0, elo: 1450) - - Cerebras (CEREBRAS_API_KEY): `cerebras/llama3.3-70b` (input: 0.6, output: 0.6, elo: 1300) - - Perplexity (PERPLEXITYAI_API_KEY): `perplexity/sonar-pro` (input: 3.0, output: 15.0, elo: 1280) - - Cohere (COHERE_API_KEY): `cohere/command-r-plus` (input: 2.5, output: 10.0, elo: 1250) - - Azure (AZURE_API_KEY): use azure/ prefix versions of any OpenAI models found - - AWS Bedrock (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY): `bedrock/anthropic.claude-sonnet-4-5-20250929-v1` - -4. **Populate llm_model.csv** - - If {llm_model_csv_path} already exists, read it first and avoid adding duplicate entries (match on provider+model pair) - - If it doesn't exist, create it with the CSV header line: - `provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location` - - Append new rows for each discovered provider/model combination - - Use atomic write: write to a temp file first, then move it to the final path - - Format each row matching the schema above (base_url empty for cloud models, reasoning_type "none", location empty) - -5. **Check for local LLMs** - - First, print a visible notification to the user: - ``` - ------------------------------------------------------- - Local LLM Check - ------------------------------------------------------- - If you'd like to use local models with PDD, make sure - your local LLM server is running now: - - Ollama: Run `ollama serve` in another terminal - LM Studio: Start the app → Developer tab → Start Server - - Press Enter to continue... - ------------------------------------------------------- - ``` - - Wait for user to press Enter (`read -p "" dummy` for bash/zsh, or `read dummy` for simpler shells) - - Check if Ollama is running: `curl -s http://localhost:11434/api/tags 2>/dev/null` - - If reachable, parse the JSON response to get model names - - Add each discovered model as: `Ollama,ollama_chat/{model_name},0,0,1000,http://localhost:11434,,0,True,none,` - - Check if LM Studio is running: `curl -s http://localhost:1234/v1/models 2>/dev/null` - - If reachable, parse the JSON response to get model names - - Add each as: `lm_studio,lm_studio/{model_name},0,0,1000,http://localhost:1234/v1,,0,True,none,` - - If neither is reachable, note in the summary: "No local LLMs found. To add local models later, start Ollama or LM Studio and run `pdd setup` again." - -6. **Initialize .pddrc if needed** - If no `.pddrc` file exists in {cwd}: - - Detect project type from files in {cwd}: - - Python: look for setup.py, pyproject.toml, or *.py files - - TypeScript: look for package.json with typescript dependency, or *.ts files - - Go: look for go.mod - - Default to Python if unclear - - Create `.pddrc` with these contents (adjust paths by language): - - For Python projects: - ```yaml - version: "1.0" - - contexts: - default: - defaults: - generate_output_path: "pdd/" - test_output_path: "tests/" - example_output_path: "context/" - default_language: "python" - target_coverage: 80.0 - strength: 1.0 - temperature: 0.0 - budget: 10.0 - max_attempts: 3 - ``` - - For TypeScript projects, use: generate_output_path: "src/", test_output_path: "__tests__/", example_output_path: "examples/", default_language: "typescript" - - For Go projects, use: generate_output_path: ".", test_output_path: ".", example_output_path: "examples/", default_language: "go" - -7. **Test one model** - - Pick the first cloud model from the CSV that has a configured API key - - Run a minimal test: `python3 -c "import litellm; r = litellm.completion(model='{model_name}', messages=[{{'role':'user','content':'Say OK'}}], timeout=30); print('OK:', r.choices[0].message.content)"` - - If litellm is not available, skip this step - - Report success/failure but do not block on failure - -8. **Print summary** - Print a clear, formatted summary: - - ``` - === PDD Setup Complete === - - API Keys Found: - {KEY_NAME} {source} - ... - - Models Configured: {N} total - {Provider}: {model1}, {model2} - ... - - Local LLMs: {none found | Provider: model1, model2, ...} - - .pddrc: {Created at ./.pddrc | Already exists | Skipped (not in a project directory)} - - Model Test: {model-name} -> {OK (0.3s) | FAILED: error | Skipped} - - You're all set! Run `pdd generate` or `pdd sync` to start using PDD. - ``` - -% Important Rules - -- NEVER display, log, or store actual API key values — only report whether they exist and where they were found -- Use atomic file writes for CSV modifications (write to temp file, then rename/move) -- Do not fail if any single step fails — log the error and continue with remaining steps -- If litellm is not installed in the Python environment, use the hardcoded model recommendations above instead -- Minimize user interaction — only ask the user to press Enter at natural pause points (e.g., before the local LLM scan). Never prompt for configuration decisions. -- If the CSV already has entries, preserve them and only add new discoveries (no duplicates) -- Shell-appropriate syntax for api-env files: bash/zsh use `export KEY=value`, fish uses `set -gx KEY value` From 4c3505f0f9d12663f54a8f08400cd2df4b831f27 Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Fri, 20 Feb 2026 09:19:24 -0500 Subject: [PATCH 09/10] Major pdd setup changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Adds support for many more LiteLLM-supported providers (Vertex AI, AWS Bedrock, Azure, etc.) - The api_key column now supports pipe-delimited fields (e.g. VERTEXAI_PROJECT|VERTEXAI_LOCATION|GOOGLE_APPLICATION_CREDENTIALS) for providers whose auth requires multiple credentials - Updated pdd setup documentation - Update llm_invoke api_key handling to support the new pipe-delimited credentials format and generalized to remove provider-specific logic --- README.md | 38 +- SETUP_WITH_GEMINI.md | 28 +- context/cli_detector_example.py | 30 +- context/litellm_registry_example.py | 69 - context/pddrc_initializer_example.py | 40 +- context/provider_manager_example.py | 40 +- context/setup_tool_example.py | 82 +- docs/SETUP_GUIDE.md | 382 ----- pdd/api_key_scanner.py | 10 +- pdd/cli_detector.py | 256 +++- pdd/data/llm_model.csv | 286 +++- pdd/generate_model_catalog.py | 711 ++++++++++ pdd/litellm_registry.py | 312 ---- pdd/llm_invoke.py | 229 ++- pdd/model_tester.py | 150 +- pdd/pddrc_initializer.py | 2 +- pdd/prompts/litellm_registry_python.prompt | 57 - pdd/provider_manager.py | 433 +++--- pdd/setup_tool.py | 845 +++++++---- tests/test_api_key_scanner.py | 880 ++++++------ tests/test_cli_detector.py | 1484 ++++++++++---------- tests/test_litellm_registry.py | 561 -------- tests/test_model_tester.py | 612 +++++--- tests/test_pddrc_initializer.py | 395 ++++-- tests/test_provider_manager.py | 1236 +++++++--------- tests/test_setup_tool.py | 1084 +++++++++----- 26 files changed, 5391 insertions(+), 4861 deletions(-) delete mode 100644 context/litellm_registry_example.py delete mode 100644 docs/SETUP_GUIDE.md create mode 100644 pdd/generate_model_catalog.py delete mode 100644 pdd/litellm_registry.py delete mode 100644 pdd/prompts/litellm_registry_python.prompt delete mode 100644 tests/test_litellm_registry.py diff --git a/README.md b/README.md index 759757926..ab7f048f9 100644 --- a/README.md +++ b/README.md @@ -167,7 +167,7 @@ For CLI enthusiasts, implement GitHub issues directly: 2. **One Agentic CLI** - Required to run the workflows (install at least one): - **Claude Code**: `npm install -g @anthropic-ai/claude-code` (requires `ANTHROPIC_API_KEY`) - - **Gemini CLI**: `npm install -g @google/gemini-cli` (requires `GOOGLE_API_KEY`) + - **Gemini CLI**: `npm install -g @google/gemini-cli` (requires `GOOGLE_API_KEY` or `GEMINI_API_KEY`) - **Codex CLI**: `npm install -g @openai/codex` (requires `OPENAI_API_KEY`) **Usage:** @@ -227,24 +227,23 @@ Run the comprehensive setup wizard: pdd setup ``` -The setup wizard runs in two phases: -- **Phase 1** — Detects agentic CLI tools (claude, gemini, codex) and offers installation if needed -- **Phase 2** — Auto-configures PDD in 4 deterministic steps: - 1. Scans for API keys across shell, .env, and ~/.pdd files (prompts to add one if none found) - 2. Configures models from a reference CSV based on your available keys - 3. Checks for local LLMs (Ollama, LM Studio) and creates a `.pddrc` config file - 4. Tests a model and prints a summary +The setup wizard runs these steps: + 1. Detects agentic CLI tools (Claude, Gemini, Codex) and offers installation and API key configuration if needed + 2. Scans for API keys across `.env`, and `~/.pdd/api-env.*`, and the shell environment; prompts to add one if none are found + 3. Configures models from a reference CSV `data/llm_model.csv` of top models (ELO ≥ 1400) across all LiteLLM-supported providers based on your available keys + 4. Optionally creates a `.pddrc` project config + 5. Tests the first available model with a real LLM call + 6. Prints a structured summary (CLIs, keys, models, test result) The wizard can be re-run at any time to update keys, add providers, or reconfigure settings. -If you skip this step, the first regular pdd command you run will detect the missing setup files and print a reminder banner so you can finish onboarding later. - -Reload your shell so the new completion and environment hooks are available: -```bash -source ~/.zshrc # or source ~/.bashrc / fish equivalent -``` +> **Important:** After setup completes, source the API environment file so your keys take effect in the current terminal session: +> ```bash +> source ~/.pdd/api-env.zsh # or api-env.bash, depending on your shell +> ``` +> New terminal windows will load keys automatically. -👉 For detailed setup documentation, see [docs/SETUP_GUIDE.md](docs/SETUP_GUIDE.md). For manual configuration, see [SETUP_WITH_GEMINI.md](SETUP_WITH_GEMINI.md). +If you skip this step, the first regular pdd command you run will detect the missing setup files and print a reminder banner so you can finish onboarding later. 5. **Run Hello**: ```bash @@ -1839,7 +1838,7 @@ For the agentic fallback to function, you need to have at least one of the suppo * Requires the `ANTHROPIC_API_KEY` environment variable to be set. 2. **Google Gemini:** * Requires the `gemini` CLI to be installed and in your `PATH`. - * Requires the `GOOGLE_API_KEY` environment variable to be set. + * Requires the `GOOGLE_API_KEY` or `GEMINI_API_KEY` environment variable to be set. 3. **OpenAI Codex/GPT:** * Requires the `codex` CLI to be installed and in your `PATH`. * Requires the `OPENAI_API_KEY` environment variable to be set. @@ -2785,12 +2784,7 @@ The `.pddrc` approach is recommended for team projects as it ensures consistent ### Model Configuration (`llm_model.csv`) -PDD uses a CSV file (`llm_model.csv`) to store information about available AI models, their costs, capabilities, and required API key names. The `pdd setup` command automatically manages this file by: - -- **API key scanning:** Checking shell environment, `.env` files, and `~/.pdd/api-env.*` for provider keys -- **Automatic model configuration:** Matching found keys to a bundled reference CSV and writing matching models to your user CSV -- **Local LLM detection:** Discovering running Ollama and LM Studio servers and adding their models -- **Fallback menu:** Manual options to add providers, test models, or initialize `.pddrc` if auto-configuration fails +PDD uses a CSV file (`llm_model.csv`) to store information about available AI models, their costs, capabilities, and required API key names. When running commands locally, PDD determines which configuration file to use based on the following priority: diff --git a/SETUP_WITH_GEMINI.md b/SETUP_WITH_GEMINI.md index 4eaeda3ea..31a0707ff 100644 --- a/SETUP_WITH_GEMINI.md +++ b/SETUP_WITH_GEMINI.md @@ -60,25 +60,27 @@ Right after installation, let PDD bootstrap its configuration: pdd setup ``` -The interactive setup wizard will: -1. **Scan your environment** for existing API keys from all sources -2. **Show an interactive menu** with options to: - - Add or fix API keys (including Gemini) - - Add local LLMs (Ollama, LM Studio) - - Add custom providers - - Remove providers -3. **Validate your Gemini API key** with a real test request -4. **Guide model selection** with cost transparency -5. **Detect agentic CLI tools** and offer installation -6. **Create .pddrc** for your project +The setup wizard runs these steps: + 1. Detects agentic CLI tools (Claude, Gemini, Codex) and offers installation and API key configuration if needed + 2. Scans for API keys across `.env`, and `~/.pdd/api-env.*`, and the shell environment; prompts to add one if none are found + 3. Configures models from a reference CSV `data/llm_model.csv` of top models (ELO ≥ 1400) across all LiteLLM-supported providers based on your available keys + 4. Optionally creates a `.pddrc` project config + 5. Tests the first available model with a real LLM call + 6. Prints a structured summary (CLIs, keys, models, test result) When adding your Gemini API key: -- Select option `1. Add or fix API keys` from the menu +- Select Gemini CLI as one of the agentic CLI tools - The wizard will detect that `GEMINI_API_KEY` is missing - Paste your API key when prompted (you can create it in the next step if you haven't already) - The wizard tests it immediately and confirms it works -The wizard writes your credentials to `~/.pdd/api-env.zsh` (or `.bash`), updates `llm_model.csv` with your selected models, and reminds you to reload your shell (`source ~/.zshrc`, etc.) so completion and env hooks load. +The wizard writes your credentials to `~/.pdd/api-env.zsh` (or `.bash`) and updates `llm_model.csv` with your selected models. + +> **Important:** After setup completes, source the API environment file so your keys take effect in the current terminal session: +> ```bash +> source ~/.pdd/api-env.zsh # or api-env.bash, depending on your shell +> ``` +> New terminal windows will load keys automatically. If you prefer to configure everything manually—or you're on an offline machine—skip the wizard and follow the manual instructions below. diff --git a/context/cli_detector_example.py b/context/cli_detector_example.py index f0b4f92e8..c6206506b 100644 --- a/context/cli_detector_example.py +++ b/context/cli_detector_example.py @@ -7,40 +7,40 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.cli_detector import detect_and_bootstrap_cli, detect_cli_tools +from pdd.cli_detector import detect_and_bootstrap_cli, detect_cli_tools, CliBootstrapResult def main() -> None: """ Demonstrates how to use the cli_detector module to: - 1. Bootstrap an agentic CLI for pdd setup (detect_and_bootstrap_cli) + 1. Bootstrap agentic CLIs for pdd setup (detect_and_bootstrap_cli) 2. Detect installed CLI harnesses (claude, codex, gemini) 3. Cross-reference with available API keys 4. Offer installation for missing CLIs """ # Primary entry point used by pdd setup Phase 1: - # result = detect_and_bootstrap_cli() - # result.cli_name -> "claude" | "codex" | "gemini" | "" - # result.provider -> "Anthropic" | "OpenAI" | "Google" | "" - # result.api_key_configured -> True | False + # results = detect_and_bootstrap_cli() # Returns List[CliBootstrapResult] + # for r in results: + # r.cli_name -> "claude" | "codex" | "gemini" | "" + # r.provider -> "anthropic" | "openai" | "google" | "" + # r.cli_path -> "/usr/local/bin/claude" | "" + # r.api_key_configured -> True | False + # r.skipped -> True | False # Legacy function for detection only: # detect_cli_tools() # Uncomment to run interactively - # Example flow (detect_and_bootstrap_cli): + # Example flow (detect_and_bootstrap_cli with multi-select): # Checking CLI tools... - # (Required for: pdd fix, pdd change, pdd bug) # - # Claude CLI Found at /usr/local/bin/claude - # Codex CLI Not found - # Gemini CLI Not found + # 1. Claude CLI ✓ Found at /usr/local/bin/claude ✓ ANTHROPIC_API_KEY is set + # 2. Codex CLI ✗ Not found ✗ OPENAI_API_KEY not set + # 3. Gemini CLI ✗ Not found ✓ GEMINI_API_KEY is set # - # Using Claude CLI (Anthropic). - # API key: configured + # Select CLIs to use for pdd agentic tools (enter numbers separated by commas, e.g., 1,3): # - # Returns CliBootstrapResult(cli_name="claude", provider="Anthropic", - # api_key_configured=True) + # Returns [CliBootstrapResult(cli_name="claude", ...), CliBootstrapResult(cli_name="gemini", ...)] pass diff --git a/context/litellm_registry_example.py b/context/litellm_registry_example.py deleted file mode 100644 index 38a6d9592..000000000 --- a/context/litellm_registry_example.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -import sys -from pathlib import Path - -# Add the project root to sys.path -project_root = Path(__file__).resolve().parent.parent -sys.path.append(str(project_root)) - -from pdd.litellm_registry import ( - is_litellm_available, - get_api_key_env_var, - get_top_providers, - search_providers, - get_models_for_provider, - ProviderInfo, - ModelInfo, -) - - -def main() -> None: - """ - Demonstrates how to use the litellm_registry module to: - 1. Check if litellm data is available - 2. Browse top providers - 3. Search for a provider by name - 4. List chat models for a provider with pricing - 5. Look up API key env var for a provider - """ - - # Check availability - if not is_litellm_available(): - print("litellm is not installed or has no model data.") - return - - # Browse top providers (curated list of ~10 major cloud providers) - print("Top providers:") - for p in get_top_providers(): - print(f" {p.display_name:20s} {p.model_count:3d} chat models key: {p.api_key_env_var}") - print() - - # Search for a provider by substring - results = search_providers("anth") - print(f"Search 'anth': {len(results)} result(s)") - for p in results: - print(f" {p.display_name} ({p.model_count} models)") - print() - - # List models for a specific provider - models = get_models_for_provider("anthropic") - print(f"Anthropic chat models ({len(models)}):") - for m in models[:5]: - print( - f" {m.litellm_id:40s} ${m.input_cost_per_million:>7.2f} in " - f"${m.output_cost_per_million:>7.2f} out " - f"ctx: {m.max_input_tokens}" - ) - print() - - # Look up API key env var - env_var = get_api_key_env_var("anthropic") - print(f"Anthropic API key env var: {env_var}") - - env_var_unknown = get_api_key_env_var("some_unknown_provider") - print(f"Unknown provider env var: {env_var_unknown}") - - -if __name__ == "__main__": - main() diff --git a/context/pddrc_initializer_example.py b/context/pddrc_initializer_example.py index 99dbee780..a5631eb64 100644 --- a/context/pddrc_initializer_example.py +++ b/context/pddrc_initializer_example.py @@ -7,30 +7,34 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.pddrc_initializer import offer_pddrc_init +from pdd.pddrc_initializer import _build_pddrc_content, _detect_language def main() -> None: """ - Demonstrates how to use the pddrc_initializer module to: - 1. Check if .pddrc exists in current project - 2. Detect project language (Python/TypeScript/Go) - 3. Offer to create .pddrc with sensible defaults + Demonstrates how to use the pddrc_initializer module. + + The primary entry points are: + - _detect_language(cwd): returns "python", "typescript", "go", or None + - _build_pddrc_content(language): returns YAML string for .pddrc + - offer_pddrc_init(): interactive flow with YAML preview + confirmation + + In practice, `pdd setup` imports _detect_language and _build_pddrc_content + directly for a streamlined flow (no YAML preview). """ - # Run the interactive initialization - # was_created = offer_pddrc_init() # Uncomment to run interactively - - # Example flow: - # No .pddrc found in current project. - # - # Would you like to create one with default settings? - # Default language: python - # Output path: pdd/ - # Test output path: tests/ - # - # Create .pddrc? [Y/n] - # ✓ Created .pddrc with default settings + # Detect language from marker files in cwd + from pathlib import Path + language = _detect_language(Path.cwd()) + print(f"Detected language: {language}") # e.g. "python" or None + + # Build .pddrc content for a given language + content = _build_pddrc_content(language or "python") + print(content) + + # Or use the full interactive flow (shows YAML preview, asks for confirmation): + # from pdd.pddrc_initializer import offer_pddrc_init + # was_created = offer_pddrc_init() pass diff --git a/context/provider_manager_example.py b/context/provider_manager_example.py index 8e9769356..b3a881997 100644 --- a/context/provider_manager_example.py +++ b/context/provider_manager_example.py @@ -12,40 +12,39 @@ add_custom_provider, remove_models_by_provider, remove_individual_models, + parse_api_key_vars, + is_multi_credential, ) def main() -> None: """ Demonstrates how to use the provider_manager module to: - 1. Search/browse litellm's registry to add a provider and specific models + 1. Browse the reference CSV to add a provider and its models 2. Add a custom LiteLLM-compatible provider 3. Remove all models for a provider (comments out the key) 4. Remove individual models from the user CSV + 5. Parse pipe-delimited api_key fields """ - # Example 1: Search/browse providers from litellm's registry - # Shows top ~10 providers, lets you search, pick models, enter API key + # Example 1: Browse providers from the bundled reference CSV + # Shows numbered provider list with model counts, enter API key # add_provider_from_registry() # Uncomment to run interactively # Interactive flow: - # Top providers: - # 1. OpenAI (102 chat models) - # 2. Anthropic (29 chat models) - # ... - # Enter number, or type to search: anthropic + # Add a provider # - # Chat models for Anthropic: - # 1. claude-opus-4-5-20251101 $5.00 $25.00 200,000 - # 2. claude-sonnet-4-5-20250929 $3.00 $15.00 200,000 + # 1. Anthropic (5 models) + # 2. Google Vertex AI (8 models) + # 3. OpenAI (12 models) # ... - # Select models: 1,2 + # Enter number (empty to cancel): 3 # - # ANTHROPIC_API_KEY: sk-ant-... - # ✓ Saved ANTHROPIC_API_KEY to ~/.pdd/api-env.zsh + # OPENAI_API_KEY: sk-proj-... + # ✓ Saved OPENAI_API_KEY to ~/.pdd/api-env.zsh # ✓ Added source line to ~/.zshrc # Key is available now for this session. - # ✓ Added 2 model(s) to ~/.pdd/llm_model.csv + # ✓ Added 12 model(s) for OpenAI to ~/.pdd/llm_model.csv # # NOTE: The API key is immediately available in the current session via os.environ, # so you can test the model right away. New terminal sessions will also have the @@ -63,6 +62,17 @@ def main() -> None: # Lists all models, user picks by number, removes selected rows # remove_individual_models() # Uncomment to run interactively + # Example 5: Utility functions for api_key field parsing + # Useful when working with CSV rows that have pipe-delimited api_key fields + single = parse_api_key_vars("OPENAI_API_KEY") + print(f"Single key vars: {single}") # ['OPENAI_API_KEY'] + + multi = parse_api_key_vars("AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME") + print(f"Multi key vars: {multi}") # ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION_NAME'] + + print(f"Is multi-credential? {is_multi_credential('A|B')}") # True + print(f"Is multi-credential? {is_multi_credential('OPENAI_API_KEY')}") # False + if __name__ == "__main__": main() diff --git a/context/setup_tool_example.py b/context/setup_tool_example.py index 81903a9cc..ba2d76ff1 100644 --- a/context/setup_tool_example.py +++ b/context/setup_tool_example.py @@ -14,11 +14,11 @@ def main() -> None: """ Demonstrates how to use the setup_tool module to: 1. Launch the two-phase pdd setup flow - 2. Phase 1: Bootstrap an agentic CLI (Claude/Gemini/Codex) - 3. Phase 2: Auto-configure API keys, models, local LLMs, and .pddrc + 2. Phase 1: Bootstrap agentic CLIs (Claude/Gemini/Codex) + 3. Phase 2: Auto-configure API keys, models, and .pddrc The setup flow is mostly automatic. Phase 1 asks 0-2 questions - (which CLI to use), then Phase 2 runs 4 deterministic steps + (which CLIs to use), then Phase 2 runs 3 deterministic steps with "Press Enter" pauses between them. """ @@ -26,58 +26,72 @@ def main() -> None: # run_setup() # Uncomment to run interactively # Example flow: - # +------------------------------+ - # | pdd setup | - # +------------------------------+ + # (PDD ASCII logo in cyan) + # Let's get set up quickly with a solid basic configuration! # # Phase 1 -- CLI Bootstrap # Detected: claude (Anthropic) # API key: configured # - # Ready to auto-configure PDD. Press Enter to continue... - # - # [Step 1/4] Scanning for API keys... - # ANTHROPIC_API_KEY shell environment - # GEMINI_API_KEY shell environment + # ──────────────────────────────────────── + # Scanning for API keys... + # ──────────────────────────────────────── + # ✓ ANTHROPIC_API_KEY shell environment + # ✓ GEMINI_API_KEY shell environment # # 2 API key(s) found. + # You can edit your global API keys in ~/.pdd/api-env.zsh # # Press Enter to continue to the next step... # - # [Step 2/4] Configuring models... - # 3 new model(s) added to ~/.pdd/llm_model.csv - # 4 cloud model(s) configured + # ──────────────────────────────────────── + # Configuring models... + # ──────────────────────────────────────── + # ✓ 3 new model(s) added to ~/.pdd/llm_model.csv + # ✓ 4 model(s) configured # Anthropic: 3 models # Google: 1 model + # ✓ .pddrc detected at /path/to/project/.pddrc # # Press Enter to continue to the next step... # - # [Step 3/4] Checking local LLMs and .pddrc... - # Ollama running -- found llama3.2:3b, openhermes:latest - # LM Studio not running (skip) - # .pddrc already exists at /path/to/project/.pddrc + # ──────────────────────────────────────── + # Testing and summarizing... + # ──────────────────────────────────────── + # Testing anthropic/claude-sonnet-4-5-20250929...... + # ✓ claude-sonnet-4-5-20250929 responded OK (1.2s) # - # Press Enter to continue to the next step... + # PDD Setup Complete! + # + # CLI: ✓ claude configured + # API Keys: ✓ 2 found + # Models: 4 configured (Anthropic: 3, Google: 1) in ~/.pdd/llm_model.csv + # .pddrc: ✓ exists + # Test: ✓ claude-sonnet-4-5-20250929 responded OK (1.2s) + # + # Press Enter to finish, or 'm' for more options: + # + # (user presses Enter) # - # [Step 4/4] Testing and summarizing... - # Testing anthropic/claude-sonnet-4-5-20250929... - # claude-sonnet-4-5-20250929 responded OK (1.2s) + # ──────────────────────────────────────────────────────────────────────────────── + # QUICK START: + # 1. Generate code from the sample prompt: + # pdd generate success_python.prompt + # ──────────────────────────────────────────────────────────────────────────────── + # LEARN MORE: + # • PDD documentation: pdd --help + # • PDD website: https://promptdriven.ai/ + # • Discord community: https://discord.gg/Yp4RTh8bG7 # - # =============================================== - # PDD Setup Complete - # =============================================== + # Full summary saved to PDD-SETUP-SUMMARY.txt # - # API Keys: 2 found - # Models: 4 configured (Anthropic: 3, Google: 1) - # Local: Ollama -- llama3.2:3b, openhermes:latest - # .pddrc: exists - # Test: OK + # --- OR if user enters 'm': --- # - # =============================================== - # Run 'pdd generate' or 'pdd sync' to start. - # =============================================== + # Options: + # 1. Add a provider + # 2. Test a model # - # Setup complete. Happy prompting! + # Select an option (Enter to finish): pass diff --git a/docs/SETUP_GUIDE.md b/docs/SETUP_GUIDE.md deleted file mode 100644 index c31da3f27..000000000 --- a/docs/SETUP_GUIDE.md +++ /dev/null @@ -1,382 +0,0 @@ -# PDD Setup Guide - -This guide covers the comprehensive `pdd setup` command, which helps you configure PDD with API keys, local LLMs, custom providers, and project settings. - -## Overview - -The `pdd setup` wizard provides an interactive menu-driven interface for configuring your PDD installation. It automatically: - -- **Scans your environment** for API keys from all sources (shell, .env, ~/.pdd files) -- **Validates API keys** with real test requests to ensure they work -- **Manages providers** - add, fix, or remove LLM providers -- **Configures local LLMs** - Ollama, LM Studio, or custom endpoints -- **Selects model tiers** - with cost transparency and guidance -- **Detects agentic CLIs** - checks for claude, gemini, codex and offers installation -- **Creates .pddrc** - project configuration with sensible defaults - -## Quick Start - -```bash -pdd setup -``` - -After installation, run the setup wizard. It will scan your environment and present an interactive menu. - -## The Setup Flow - -### 1. Environment Scan - -When you run `pdd setup`, it first scans for API keys: - -``` -═══════════════════════════════════════════════════════ -Scanning for API keys... -═══════════════════════════════════════════════════════ - - ANTHROPIC_API_KEY ✓ Valid (shell environment) - OPENAI_API_KEY ✓ Valid (.env file) - GROQ_API_KEY ✗ Invalid (shell environment) - GEMINI_API_KEY — Not found - FIREWORKS_API_KEY — Not found -``` - -The scan shows: -- **✓ Valid**: Key found and validated with a test request -- **✗ Invalid**: Key found but failed validation -- **— Not found**: No key found in any source - -**Source transparency:** Each key shows where it's loaded from: -- `(shell environment)` - From your shell's environment variables -- `(.env file)` - From the project's .env file -- `(~/.pdd/api-env.zsh)` - From PDD's managed key file - -### 2. Interactive Menu - -After the scan, you see the main menu: - -``` -What would you like to do? - 1. Add or fix API keys - 2. Add a local LLM (Ollama, LM Studio) - 3. Add a custom provider - 4. Remove a provider - 5. Continue → -``` - -#### Option 1: Add or Fix API Keys - -This option shows only providers that are missing or invalid: - -``` -GROQ_API_KEY (currently: invalid): - Enter new key: gsk_abc... - Testing with groq/mixtral-8x7b-32768... ✓ Valid - -GEMINI_API_KEY (currently: not set): - Enter key (or press Enter to skip): AIza... - Testing with gemini/gemini-1.5-flash... ✓ Valid -``` - -**Smart key storage:** -- Keys you **enter during setup** are saved to `~/.pdd/api-env.{{shell}}` -- Keys **already in your environment** are not duplicated - -After adding keys, you return to the main menu with an updated scan. - -#### Option 2: Add a Local LLM - -Local models (Ollama, LM Studio) don't need API keys - they need a `base_url` and model name: - -``` -What tool are you using? - 1. LM Studio (default: localhost:1234) - 2. Ollama (default: localhost:11434) - 3. Other (custom base URL) - Choice: 2 - -Querying Ollama at http://localhost:11434... -Found installed models: - 1. llama3:70b - 2. codellama:34b - 3. mistral:7b - -Which models do you want to add? [1,2,3]: 1,2 -✓ Added ollama_chat/llama3:70b and ollama_chat/codellama:34b to llm_model.csv -``` - -**Features:** -- **Ollama auto-detection**: Queries the API to list installed models -- **LM Studio defaults**: Pre-fills localhost:1234 base URL -- **Custom endpoints**: For any LiteLLM-compatible provider -- **Zero cost**: Local models are set to $0 or $0.0001 costs - -#### Option 3: Add a Custom Provider - -For LiteLLM-compatible providers not in the default CSV (Together AI, Deepinfra, Mistral, etc.): - -``` -Provider prefix (e.g. together_ai, deepinfra, mistral): together_ai -Model name: meta-llama/Llama-3-70b-chat -API key env var name: TOGETHERAI_API_KEY -Base URL (press Enter if standard): -Cost per 1M input tokens (optional): 0.90 -Cost per 1M output tokens (optional): 0.90 - -Testing together_ai/meta-llama/Llama-3-70b-chat... ✓ Valid -✓ Added to llm_model.csv -``` - -This lets you add any provider without manually editing the CSV. - -#### Option 4: Remove a Provider - -Shows configured providers and lets you safely remove one: - -``` -Configured providers: - 1. ANTHROPIC_API_KEY (3 models) - 2. OPENAI_API_KEY (5 models) - 3. GROQ_API_KEY (1 model) - 4. TOGETHERAI_API_KEY (1 model) [custom] - -Remove which provider? 4 - - # Commented out by pdd setup on 2026-02-09 - # export TOGETHERAI_API_KEY='tok_abc...' - - Removed 1 model from llm_model.csv -✓ TOGETHERAI_API_KEY removed -``` - -**Safe removal:** -- Keys are **commented out**, never deleted (easy to recover) -- Model rows are removed from `llm_model.csv` -- Prevents orphaned models in the CSV - -#### Option 5: Continue - -Proceeds to model selection, CLI detection, and .pddrc creation. - -### 3. Model Tier Selection - -After configuring providers (option 5), the wizard shows available models grouped by cost tier: - -``` -Models available for ANTHROPIC_API_KEY: - - # Model Input Output ELO - 1. anthropic/claude-opus-4-5 $5.00 $25.00 1474 - 2. anthropic/claude-sonnet-4-5 $3.00 $15.00 1370 - 3. anthropic/claude-haiku-4-5 $1.00 $5.00 1270 - -Tip: pdd uses --strength (0.0–1.0) to pick models by cost/quality at runtime. -Adding all models gives you the full range. - -Include which models? [1,2,3] (default: all): 2,3 -``` - -**Cost transparency:** -- Shows input/output token costs per million -- Displays ELO ratings for quality comparison -- Explains how `--strength` controls model selection - -**Smart defaults:** -- Press Enter to include all models (recommended) -- Or select specific tiers (e.g., just Haiku + Sonnet to avoid Opus costs) - -### 4. Agentic CLI Detection - -After model selection, setup checks for agentic CLI tools: - -``` -Checking agentic CLI harnesses... -(Required for: pdd fix, pdd change, pdd bug) - - Claude CLI ✓ Found at /usr/local/bin/claude - Codex CLI ✗ Not found - Gemini CLI ✗ Not found - -You have OPENAI_API_KEY but Codex CLI is not installed. - Install with: npm install -g @openai/codex - Install now? [y/N] -``` - -This proactive detection prevents errors when running `pdd fix` or `pdd change`. - -### 5. .pddrc Initialization - -Finally, setup offers to create a `.pddrc` configuration: - -``` -No .pddrc found in current project. - -Would you like to create one with default settings? - Default language: python - Output path: pdd/ - Test output path: tests/ - -Create .pddrc? [Y/n] -``` - -**Auto-detection:** -- Detects language from project files (setup.py, package.json, go.mod) -- Sets conventional paths for that language -- Creates properly formatted YAML configuration - -## API Key Loading Priority - -PDD checks for API keys in this order (highest priority first): - -1. **Shell environment variables** - `export ANTHROPIC_API_KEY=...` -2. **`.env` file** - In the project root -3. **`~/.pdd/api-env.{{shell}}`** - PDD's managed key file - -**Why this order?** -- Shell vars override .env (industry standard with `load_dotenv(override=False)`) -- Allows .env for development defaults, shell vars for production secrets -- Prevents .env from accidentally overwriting intentional shell configs - -**Source transparency:** The setup scan shows exactly which source provides each key. - -## Saving Keys: Smart Storage - -The setup wizard uses smart storage rules: - -- **Keys entered during setup** → Saved to `~/.pdd/api-env.{{shell}}` -- **Keys already in shell/environment** → Not saved (avoids duplicates) - -This prevents duplicating keys managed by Infisical, .env, shell profiles, etc. - -Example: -``` -Saving keys... - GROQ_API_KEY → saved to ~/.pdd/api-env.zsh (entered during setup) - GEMINI_API_KEY → saved to ~/.pdd/api-env.zsh (entered during setup) - ANTHROPIC_API_KEY → skipped (already in shell environment) - OPENAI_API_KEY → skipped (already in .env file) -``` - -## Re-running Setup - -You can run `pdd setup` at any time to: - -- Add new providers or fix invalid keys -- Add local LLM endpoints -- Remove providers -- Update model selections -- Reinstall shell completion - -The wizard always starts with a fresh environment scan, so you see the current state. - -## Manual Configuration (Alternative) - -If you prefer not to use the wizard, you can configure PDD manually: - -### Manual API Key Setup - -Create `~/.pdd/api-env.zsh` (or `.bash`): - -```bash -export ANTHROPIC_API_KEY='sk-ant-...' -export OPENAI_API_KEY='sk-...' -export GEMINI_API_KEY='AIza...' -``` - -Source it from your shell profile (~/.zshrc or ~/.bashrc): - -```bash -# Load PDD API keys -[ -f ~/.pdd/api-env.zsh ] && source ~/.pdd/api-env.zsh -``` - -### Manual .pddrc Setup - -Create `.pddrc` in your project root: - -```yaml -version: "1.0" - -contexts: - default: - defaults: - generate_output_path: "pdd/" - test_output_path: "tests/" - example_output_path: "context/" - default_language: "python" - target_coverage: 80.0 - strength: 1.0 - temperature: 0.0 - budget: 10.0 - max_attempts: 3 -``` - -### Manual llm_model.csv - -See [SETUP_WITH_GEMINI.md](../SETUP_WITH_GEMINI.md) for full manual configuration instructions. - -## Troubleshooting - -### "API key not found" - -Run the setup wizard: -```bash -pdd setup -``` - -It will scan all sources and show you exactly which keys are missing and where existing keys are loaded from. - -### "Invalid API key" - -The setup wizard tests keys immediately with `llm_invoke`. If validation fails: - -1. Check the error message for details (authentication vs network vs config) -2. Verify the key format (some providers have format requirements) -3. Check your account/quota status with the provider - -### Keys in multiple sources - -If a key exists in both .env and shell: - -- **Shell environment takes precedence** (industry standard) -- The setup scan shows which source is active: `(shell environment)` -- This prevents .env from overwriting intentional shell configs - -### Missing Ollama models - -If Ollama auto-detection fails: - -1. Check that Ollama is running: `ollama serve` -2. Verify the API is accessible: `curl http://localhost:11434/api/tags` -3. Fall back to manual model name entry in the wizard - -## Advanced Topics - -### Vertex AI Configuration - -For Google Vertex AI with service accounts: - -1. Create a service account JSON file from Google Cloud Console -2. Set `VERTEX_CREDENTIALS=/path/to/service-account.json` -3. Run `pdd setup` and add Vertex AI models when prompted - -### Multiple Projects - -- **Global keys**: Store in `~/.pdd/api-env.{{shell}}` for all projects -- **Project keys**: Store in project `.env` for project-specific overrides -- **Model preferences**: Each project can have its own `llm_model.csv` in `.pdd/` - -### CI/CD Integration - -For CI/CD pipelines: - -1. Don't use the interactive wizard (it requires user input) -2. Set API keys as environment variables in your CI system -3. Copy a pre-configured `llm_model.csv` to the project or user directory -4. Set `PDD_SKIP_SETUP=1` to bypass setup checks - -## Related Documentation - -- [README.md](../README.md) - Main PDD documentation -- [SETUP_WITH_GEMINI.md](../SETUP_WITH_GEMINI.md) - Manual setup guide -- [ONBOARDING.md](ONBOARDING.md) - Developer onboarding guide -- [whitepaper.md](whitepaper.md) - PDD concepts and architecture diff --git a/pdd/api_key_scanner.py b/pdd/api_key_scanner.py index fcb30d5ef..6dcd1cbf9 100644 --- a/pdd/api_key_scanner.py +++ b/pdd/api_key_scanner.py @@ -57,9 +57,13 @@ def get_provider_key_names() -> List[str]: return [] for row in reader: - api_key_name = row.get("api_key", "").strip() - if api_key_name: - key_names.add(api_key_name) + api_key_field = row.get("api_key", "").strip() + if api_key_field: + # Support pipe-delimited multi-var fields (e.g. "VAR1|VAR2|VAR3") + for var in api_key_field.split("|"): + var = var.strip() + if var: + key_names.add(var) except Exception as e: logger.error("Error reading llm_model.csv: %s", e) diff --git a/pdd/cli_detector.py b/pdd/cli_detector.py index aa495f203..ffadf909d 100644 --- a/pdd/cli_detector.py +++ b/pdd/cli_detector.py @@ -20,7 +20,7 @@ # Maps provider name -> environment variable for API key _API_KEY_ENV_VARS: dict[str, str] = { "anthropic": "ANTHROPIC_API_KEY", - "google": "GOOGLE_API_KEY", + "google": "GEMINI_API_KEY", "openai": "OPENAI_API_KEY", } @@ -88,7 +88,7 @@ ], } -console = Console() +console = Console(highlight=False) @dataclass class CliBootstrapResult: @@ -97,6 +97,7 @@ class CliBootstrapResult: provider: str = "" cli_path: str = "" api_key_configured: bool = False + skipped: bool = False # True when user explicitly skipped CLI setup # --------------------------------------------------------------------------- # Helpers @@ -114,15 +115,15 @@ def _has_api_key(provider: str) -> bool: if not env_var: # Also check fallback keys if provider == "google": - val = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") + val = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") return bool(val and val.strip()) return False val = os.environ.get(env_var) if val and val.strip(): return True - # Fallback for google: also check GEMINI_API_KEY + # Fallback for google: also check GOOGLE_API_KEY (Vertex AI convention) if provider == "google": - val = os.environ.get("GEMINI_API_KEY") + val = os.environ.get("GOOGLE_API_KEY") return bool(val and val.strip()) return False @@ -261,38 +262,163 @@ def _save_api_key(key_name: str, key_value: str, shell: str) -> bool: return False def _prompt_api_key(provider: str, shell: str) -> bool: - """Prompt user for API key and save it.""" + """Prompt user for API key and save it. Prints save location on success.""" key_name = PROVIDER_PRIMARY_KEY.get(provider, "") if not key_name: return False - + display = PROVIDER_DISPLAY.get(provider, provider) try: key_value = _prompt_input(f" Enter your {display} API key (or press Enter to skip): ").strip() except (EOFError, KeyboardInterrupt): return False - + if not key_value: if provider == "anthropic": console.print(" [dim]Note: Claude CLI may still work with subscription auth.[/dim]") return False - - return _save_api_key(key_name, key_value, shell) + + api_env_path = _get_api_env_file_path(shell) + if _save_api_key(key_name, key_value, shell): + console.print(f" [green]\u2713[/green] {key_name} saved to {api_env_path}") + #console.print(f" [green]\u2713[/green] {key_name} loaded into current session") + return True + return False + + +def _test_cli(cli_name: str, cli_path: str) -> bool: + """Run a quick sanity-check invocation of the CLI. Returns True on success.""" + console.print(f"\n Testing {cli_name}...") + try: + result = subprocess.run( + [cli_path, "--version"], + capture_output=True, + text=True, + timeout=15, + ) + if result.returncode == 0: + version_line = (result.stdout or result.stderr or "").strip().splitlines()[0] if (result.stdout or result.stderr) else "" + console.print(f" [green]\u2713[/green] {cli_name} version {version_line or 'OK'}") + return True + else: + # Some CLIs exit non-zero for --version but still work; try --help + result2 = subprocess.run( + [cli_path, "--help"], + capture_output=True, + text=True, + timeout=15, + ) + if result2.returncode == 0: + console.print(f" [green]\u2713[/green] {cli_name} is responsive") + return True + console.print(f" [red]\u2717[/red] {cli_name} test failed (exit {result.returncode})") + return False + except FileNotFoundError: + console.print(f" [red]\u2717[/red] {cli_name} binary not found at {cli_path}") + return False + except subprocess.TimeoutExpired: + console.print(f" [red]\u2717[/red] {cli_name} test timed out") + return False + except Exception as exc: + console.print(f" [red]\u2717[/red] {cli_name} test error: {exc}") + return False # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -def detect_and_bootstrap_cli() -> CliBootstrapResult: +def _bootstrap_single_cli( + cli_entry: Dict[str, object], + shell: str, +) -> CliBootstrapResult: + """Process install/key/test for a single CLI selection. + + Returns a populated CliBootstrapResult (skipped=True on failure). + """ + display_name = str(cli_entry["display_name"]) + sel_provider: str = str(cli_entry["provider"]) + sel_cli_name: str = str(cli_entry["cli_name"]) + sel_path: Optional[str] = str(cli_entry["path"]) if cli_entry["path"] else None + sel_has_key: bool = bool(cli_entry["has_key"]) + + console.print(f"\n [bold]Setting up {display_name}...[/bold]") + + def _cli_skip(reason: str = "") -> CliBootstrapResult: + if reason: + console.print(f" [red]\u2717 {reason}[/red]") + console.print(f" [red]\u2717 {display_name} not configured.[/red]") + return CliBootstrapResult(skipped=True) + + # Install step (if not installed) + if not sel_path: + install_cmd = _INSTALL_COMMANDS[sel_provider] + console.print(f" Install command: [bold]{install_cmd}[/bold]") + try: + install_answer = _prompt_input(" Install now? [y/N]: ").strip().lower() + except (EOFError, KeyboardInterrupt): + console.print() + return _cli_skip() + + if install_answer in ("y", "yes"): + if not _npm_available(): + console.print(" [red]\u2717[/red] npm is not installed. Please install Node.js/npm first.") + console.print(f" Then run: {install_cmd}") + return _cli_skip("npm not available — cannot install CLI") + + console.print(f" Installing {display_name}...") + if _run_install(install_cmd): + sel_path = _find_cli_binary(sel_cli_name) + if sel_path: + console.print(f" [green]\u2713[/green] Installed {sel_cli_name} at {sel_path}") + else: + console.print(" [yellow]Installation completed but CLI not found on PATH.[/yellow]") + return _cli_skip("CLI installed but not found on PATH") + else: + console.print(" [red]Installation failed. Try installing manually.[/red]") + return _cli_skip("installation failed") + else: + return _cli_skip() + + # API key step (if not set) + if not sel_has_key: + sel_has_key = _prompt_api_key(sel_provider, shell) + if not sel_has_key and sel_provider != "anthropic": + console.print(f" [dim]No API key set. {display_name} may have limited functionality.[/dim]") + + # Force CLI test (no option to skip) + _test_cli(sel_cli_name, sel_path or sel_cli_name) + + return CliBootstrapResult( + cli_name=sel_cli_name, + provider=sel_provider, + cli_path=sel_path or "", + api_key_configured=sel_has_key, + ) + + +def detect_and_bootstrap_cli() -> List[CliBootstrapResult]: """Phase 1 entry point for pdd setup. Shows a numbered selection table of all three CLI options with their - install and API-key status, lets the user choose, and walks through - installation and key configuration as needed. + install and API-key status, lets the user choose one or more via + comma-separated input, and walks through installation and key + configuration for each. + + Returns a list of CliBootstrapResult objects (one per selected CLI). + On full skip: returns [CliBootstrapResult(skipped=True)]. """ - console.print("\nChecking CLI tools...\n") + # Import banner helper from setup_tool + from pdd.setup_tool import _print_step_banner + _print_step_banner("Checking CLI tools...") shell = _detect_shell() + def _skip_all(reason: str = "") -> List[CliBootstrapResult]: + """Print red CLI-not-configured warning and return a skipped result.""" + if reason: + console.print(f" [red]\u2717 {reason}[/red]") + console.print(" [red]\u2717 CLI not configured. Run `pdd setup` again to configure it.[/red]") + return [CliBootstrapResult(skipped=True)] + # ------------------------------------------------------------------ # 1. Gather status for each CLI in table order # ------------------------------------------------------------------ @@ -361,81 +487,55 @@ def detect_and_bootstrap_cli() -> CliBootstrapResult: break # ------------------------------------------------------------------ - # 4. Prompt for selection + # 4. Prompt for selection (comma-separated) # ------------------------------------------------------------------ try: - console.print(" Which CLI would you like to use for pdd setup? \[[blue]1[/blue]/[blue]2[/blue]/[blue]3[/blue]]: ", end="") + console.print(r" Select CLIs to use for pdd agentic tools (enter numbers separated by commas, e.g., [blue]1[/blue],[blue]3[/blue]): ", end="") raw = _prompt_input("").strip() except (EOFError, KeyboardInterrupt): - console.print("\n [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") - return CliBootstrapResult() + console.print() + return _skip_all() if raw.lower() in ("q", "n"): - console.print(" [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") - return CliBootstrapResult() - - if raw in ("1", "2", "3"): - selected_idx = int(raw) - 1 - elif raw == "": - selected_idx = default_idx - console.print(f" [dim]Defaulting to {cli_info[selected_idx]['display_name']}[/dim]") - else: - # Invalid input — treat as default - selected_idx = default_idx - console.print(f" [dim]Invalid input. Defaulting to {cli_info[selected_idx]['display_name']}[/dim]") + return _skip_all() - selected = cli_info[selected_idx] - sel_provider: str = str(selected["provider"]) - sel_cli_name: str = str(selected["cli_name"]) - sel_path: Optional[str] = selected["path"] if selected["path"] else None # type: ignore[assignment] - sel_has_key: bool = bool(selected["has_key"]) + # Parse comma-separated selections, deduplicate while preserving order + selected_indices: List[int] = [] + if raw == "": + selected_indices = [default_idx] + console.print(f" [dim]Defaulting to {cli_info[default_idx]['display_name']}[/dim]") + else: + seen: set[int] = set() + parts = [p.strip() for p in raw.split(",")] + for part in parts: + if part in ("1", "2", "3"): + idx = int(part) - 1 + if idx not in seen: + seen.add(idx) + selected_indices.append(idx) + if not selected_indices: + # No valid numbers found — treat as default + selected_indices = [default_idx] + console.print(f" [dim]Invalid input. Defaulting to {cli_info[default_idx]['display_name']}[/dim]") # ------------------------------------------------------------------ - # 5. Install step (if not installed) + # 5. Process each selected CLI # ------------------------------------------------------------------ - if not sel_path: - install_cmd = _INSTALL_COMMANDS[sel_provider] - console.print(f"\n Install command: [bold]{install_cmd}[/bold]") + results: List[CliBootstrapResult] = [] + for sel_idx in selected_indices: try: - install_answer = _prompt_input(" Install now? [y/N]: ").strip().lower() - except (EOFError, KeyboardInterrupt): - console.print("\n [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") - return CliBootstrapResult() - - if install_answer in ("y", "yes"): - if not _npm_available(): - console.print(" [red]\u2717[/red] npm is not installed. Please install Node.js/npm first,") - console.print(f" then run: {install_cmd}") - return CliBootstrapResult() - - console.print(f" Installing {selected['display_name']}...") - if _run_install(install_cmd): - sel_path = _find_cli_binary(sel_cli_name) - if sel_path: - console.print(f" [green]\u2713[/green] Installed {sel_cli_name} at {sel_path}") - else: - console.print(" [yellow]Installation completed but CLI not found on PATH.[/yellow]") - return CliBootstrapResult() - else: - console.print(" [red]Installation failed. Try installing manually.[/red]") - return CliBootstrapResult() - else: - console.print(f" [dim]Skipped installation. Run `{install_cmd}` manually when ready.[/dim]") - return CliBootstrapResult() - - # ------------------------------------------------------------------ - # 6. API key step (if not set) - # ------------------------------------------------------------------ - if not sel_has_key: - sel_has_key = _prompt_api_key(sel_provider, shell) - - # ------------------------------------------------------------------ - return CliBootstrapResult( - cli_name=sel_cli_name, - provider=sel_provider, - cli_path=sel_path or "", - api_key_configured=sel_has_key, - ) + result = _bootstrap_single_cli(cli_info[sel_idx], shell) + results.append(result) + except KeyboardInterrupt: + console.print() + console.print(f" [red]\u2717 {cli_info[sel_idx]['display_name']} not configured.[/red]") + results.append(CliBootstrapResult(skipped=True)) + break # Stop processing remaining CLIs + + if not results: + return _skip_all() + + return results def detect_cli_tools() -> None: diff --git a/pdd/data/llm_model.csv b/pdd/data/llm_model.csv index 26387089f..d3aafa406 100644 --- a/pdd/data/llm_model.csv +++ b/pdd/data/llm_model.csv @@ -1,20 +1,266 @@ -provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location -OpenAI,gpt-5-nano,0.05,0.4,1249,,OPENAI_API_KEY,0,True,none, -Google,vertex_ai/gemini-3-flash-preview,0.5,3.0,1442,,VERTEX_CREDENTIALS,0,True,effort,global -Google,vertex_ai/claude-sonnet-4-6,3.0,15.0,1480,,VERTEX_CREDENTIALS,128000,True,budget,global -Google,vertex_ai/gemini-3.1-pro-preview,2.0,12.0,1495,,VERTEX_CREDENTIALS,0,True,effort,global -OpenAI,gpt-5.1-codex-mini,0.25,2.0,1325,,OPENAI_API_KEY,0,True,effort, -OpenAI,gpt-5.2,1.75,14.0,1472,,OPENAI_API_KEY,0,True,effort, -OpenAI,gpt-5.2-codex,1.75,14.0,1472,,OPENAI_API_KEY,0,True,effort, -Google,vertex_ai/deepseek-ai/deepseek-v3.2-maas,0.28,0.42,1450,,VERTEX_CREDENTIALS,0,True,effort,global -Fireworks,fireworks_ai/accounts/fireworks/models/qwen3-coder-480b-a35b-instruct,0.45,1.80,1281,,FIREWORKS_API_KEY,0,False,none, -Google,vertex_ai/claude-opus-4-6,5.0,25.0,1576,,VERTEX_CREDENTIALS,128000,True,budget,global -lm_studio,lm_studio/qwen3-coder-next,0,0,1040,http://localhost:1234/v1,,0,True,none, -lm_studio,lm_studio/openai-gpt-oss-120b-mlx-6,0.0001,0,1082,http://localhost:1234/v1,,0,True,effort, -Fireworks,fireworks_ai/accounts/fireworks/models/glm-5,1.00,3.20,1451,,FIREWORKS_API_KEY,0,False,none, -Fireworks,fireworks_ai/accounts/fireworks/models/kimi-k2p5,0.60,3.00,1449,,FIREWORKS_API_KEY,0,False,none, -Anthropic,anthropic/claude-sonnet-4-6,3.0,15.0,1480,,ANTHROPIC_API_KEY,128000,True,budget, -Anthropic,anthropic/claude-opus-4-6,5.0,25.0,1576,,ANTHROPIC_API_KEY,128000,True,budget, -Anthropic,anthropic/claude-haiku-4-5-20251001,1.0,5.0,1270,,ANTHROPIC_API_KEY,128000,True,budget, -xAI,xai/grok-4-0709,3.0,15.0,1467,,XAI_API_KEY,0,True,effort, -xAI,xai/grok-4-1-fast-reasoning,0.20,0.50,1402,,XAI_API_KEY,0,True,none, +provider,model,input,output,coding_arena_elo,base_url,api_key,max_reasoning_tokens,structured_output,reasoning_type,location +AWS Bedrock,anthropic.claude-opus-4-6-v1,5.0,25.0,1530,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,au.anthropic.claude-opus-4-6-v1,5.5,27.5,1530,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Azure AI,azure_ai/claude-opus-4-6,5.0,25.0,1530,,AZURE_AI_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-6,5.0,25.0,1530,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-6-20260205,5.0,25.0,1530,,ANTHROPIC_API_KEY,128000,True,budget, +AWS Bedrock,eu.anthropic.claude-opus-4-6-v1,5.5,27.5,1530,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Anthropic,fast/claude-opus-4-6,30.0,150.0,1530,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,fast/claude-opus-4-6-20260205,30.0,150.0,1530,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,fast/us/claude-opus-4-6,30.0,150.0,1530,,ANTHROPIC_API_KEY,128000,True,budget, +AWS Bedrock,global.anthropic.claude-opus-4-6-v1,5.0,25.0,1530,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,us.anthropic.claude-opus-4-6-v1,5.5,27.5,1530,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Anthropic,us/claude-opus-4-6,5.5,27.5,1530,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,us/claude-opus-4-6-20260205,5.5,27.5,1530,,ANTHROPIC_API_KEY,128000,True,budget, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-opus-4.6,5.0,25.0,1530,,VERCEL_AI_GATEWAY_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/claude-opus-4-6,5.0,25.0,1530,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/claude-opus-4-6@default,5.0,25.0,1530,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,gemini-3-pro-preview,2.0,12.0,1501,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Gemini,gemini/gemini-3-pro-preview,2.0,12.0,1501,,GEMINI_API_KEY,0,True,effort, +Github Copilot,github_copilot/gemini-3-pro-preview,0.0,0.0,1501,,,0,True,none, +GMI Cloud,gmi/google/gemini-3-pro-preview,2.0,12.0,1501,,GMI_API_KEY,0,True,none, +OpenRouter,openrouter/google/gemini-3-pro-preview,2.0,12.0,1501,,OPENROUTER_API_KEY,0,True,effort, +Replicate,replicate/google/gemini-3-pro,2.0,12.0,1501,,REPLICATE_API_KEY,0,True,none, +Google Vertex AI,vertex_ai/gemini-3-pro-preview,2.0,12.0,1501,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +AWS Bedrock,anthropic.claude-opus-4-5-20251101-v1:0,5.0,25.0,1496,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Azure AI,azure_ai/claude-opus-4-5,5.0,25.0,1496,,AZURE_AI_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-5,5.0,25.0,1496,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-5-20251101,5.0,25.0,1496,,ANTHROPIC_API_KEY,128000,True,budget, +AWS Bedrock,eu.anthropic.claude-opus-4-5-20251101-v1:0,5.0,25.0,1496,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Github Copilot,github_copilot/claude-opus-4.5,0.0,0.0,1496,,,0,True,none, +AWS Bedrock,global.anthropic.claude-opus-4-5-20251101-v1:0,5.0,25.0,1496,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +GMI Cloud,gmi/anthropic/claude-opus-4.5,5.0,25.0,1496,,GMI_API_KEY,0,True,none, +OpenRouter,openrouter/anthropic/claude-opus-4.5,5.0,25.0,1496,,OPENROUTER_API_KEY,0,True,effort, +AWS Bedrock,us.anthropic.claude-opus-4-5-20251101-v1:0,5.5,27.5,1496,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-opus-4.5,5.0,25.0,1496,,VERCEL_AI_GATEWAY_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/claude-opus-4-5,5.0,25.0,1496,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/claude-opus-4-5@20251101,5.0,25.0,1496,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +AWS Bedrock,anthropic.claude-sonnet-4-6,3.0,15.0,1485,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,apac.anthropic.claude-sonnet-4-6,3.3,16.5,1485,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Azure AI,azure_ai/claude-sonnet-4-6,3.0,15.0,1485,,AZURE_AI_API_KEY,128000,True,budget, +Anthropic,claude-sonnet-4-6,3.0,15.0,1485,,ANTHROPIC_API_KEY,128000,True,budget, +AWS Bedrock,eu.anthropic.claude-sonnet-4-6,3.3,16.5,1485,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,global.anthropic.claude-sonnet-4-6,3.0,15.0,1485,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,us.anthropic.claude-sonnet-4-6,3.3,16.5,1485,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Anthropic,us/claude-sonnet-4-6,3.3,16.5,1485,,ANTHROPIC_API_KEY,128000,True,budget, +Google Vertex AI,vertex_ai/claude-sonnet-4-6,3.0,15.0,1485,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/claude-sonnet-4-6@default,3.0,15.0,1485,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Azure AI,azure_ai/kimi-k2.5,0.6,3.0,1480,,AZURE_AI_API_KEY,0,True,none, +AWS Bedrock,bedrock/ap-northeast-1/moonshotai.kimi-k2.5,0.72,3.6,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/ap-south-1/moonshotai.kimi-k2.5,0.72,3.6,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/ap-southeast-3/moonshotai.kimi-k2.5,0.72,3.6,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/eu-north-1/moonshotai.kimi-k2.5,0.72,3.6,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/moonshotai.kimi-k2.5,0.6,3.03,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/sa-east-1/moonshotai.kimi-k2.5,0.72,3.6,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/us-east-1/moonshotai.kimi-k2.5,0.6,3.0,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/us-east-2/moonshotai.kimi-k2.5,0.6,3.0,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/us-west-2/moonshotai.kimi-k2.5,0.6,3.0,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +Moonshot AI,moonshot/kimi-k2.5,0.6,3.0,1480,,MOONSHOT_API_KEY,0,True,none, +AWS Bedrock,moonshotai.kimi-k2.5,0.6,3.0,1480,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +OpenRouter,openrouter/moonshotai/kimi-k2.5,0.6,3.0,1480,,OPENROUTER_API_KEY,0,True,none, +Together AI,together_ai/moonshotai/Kimi-K2.5,0.5,2.8,1480,,TOGETHERAI_API_KEY,0,True,effort, +AWS Bedrock,anthropic.claude-opus-4-1-20250805-v1:0,15.0,75.0,1475,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Azure AI,azure_ai/claude-opus-4-1,15.0,75.0,1475,,AZURE_AI_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-1,15.0,75.0,1475,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-1-20250805,15.0,75.0,1475,,ANTHROPIC_API_KEY,128000,True,budget, +AWS Bedrock,eu.anthropic.claude-opus-4-1-20250805-v1:0,15.0,75.0,1475,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +OpenRouter,openrouter/anthropic/claude-opus-4.1,15.0,75.0,1475,,OPENROUTER_API_KEY,0,True,effort, +AWS Bedrock,us.anthropic.claude-opus-4-1-20250805-v1:0,15.0,75.0,1475,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-opus-4.1,15.0,75.0,1475,,VERCEL_AI_GATEWAY_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/claude-opus-4-1,15.0,75.0,1475,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +Google Vertex AI,vertex_ai/claude-opus-4-1@20250805,15.0,75.0,1475,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +Google Vertex AI,gemini-3-flash-preview,0.5,3.0,1469,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Gemini,gemini/gemini-3-flash-preview,0.5,3.0,1469,,GEMINI_API_KEY,0,True,effort, +GMI Cloud,gmi/google/gemini-3-flash-preview,0.5,3.0,1469,,GMI_API_KEY,0,True,none, +OpenRouter,openrouter/google/gemini-3-flash-preview,0.5,3.0,1469,,OPENROUTER_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/gemini-3-flash-preview,0.5,3.0,1469,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Dashscope,dashscope/qwen3-max,0.0,0.0,1468,,DASHSCOPE_API_KEY,0,True,effort, +Dashscope,dashscope/qwen3-max-preview,0.0,0.0,1468,,DASHSCOPE_API_KEY,0,True,effort, +Novita AI,novita/qwen/qwen3-max,2.11,8.45,1468,,NOVITA_API_KEY,0,True,none, +Azure OpenAI,azure/gpt-5.2,1.75,14.0,1465,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +DeepInfra,deepinfra/google/gemini-2.5-pro,1.25,10.0,1465,,DEEPINFRA_API_KEY,0,False,none, +Google Vertex AI,gemini-2.5-pro,1.25,10.0,1465,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Gemini,gemini/gemini-2.5-pro,1.25,10.0,1465,,GEMINI_API_KEY,0,True,effort, +Github Copilot,github_copilot/gemini-2.5-pro,0.0,0.0,1465,,,0,True,none, +Github Copilot,github_copilot/gpt-5.2,0.0,0.0,1465,,,0,True,none, +GMI Cloud,gmi/openai/gpt-5.2,1.75,14.0,1465,,GMI_API_KEY,0,True,none, +OpenAI,gpt-5.2,1.75,14.0,1465,,OPENAI_API_KEY,0,True,effort, +OpenRouter,openrouter/google/gemini-2.5-pro,1.25,10.0,1465,,OPENROUTER_API_KEY,0,True,none, +OpenRouter,openrouter/openai/gpt-5.2,1.75,14.0,1465,,OPENROUTER_API_KEY,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/google/gemini-2.5-pro,2.5,10.0,1465,,VERCEL_AI_GATEWAY_API_KEY,0,True,none, +AWS Bedrock,anthropic.claude-sonnet-4-5-20250929-v1:0,3.0,15.0,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,au.anthropic.claude-sonnet-4-5-20250929-v1:0,3.3,16.5,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Azure AI,azure_ai/claude-sonnet-4-5,3.0,15.0,1464,,AZURE_AI_API_KEY,128000,True,budget, +Anthropic,claude-sonnet-4-5,3.0,15.0,1464,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,claude-sonnet-4-5-20250929,3.0,15.0,1464,,ANTHROPIC_API_KEY,128000,True,budget, +AWS Bedrock,claude-sonnet-4-5-20250929-v1:0,3.0,15.0,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,eu.anthropic.claude-sonnet-4-5-20250929-v1:0,3.3,16.5,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Github Copilot,github_copilot/claude-sonnet-4.5,0.0,0.0,1464,,,0,True,none, +AWS Bedrock,global.anthropic.claude-sonnet-4-5-20250929-v1:0,3.0,15.0,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +GMI Cloud,gmi/anthropic/claude-sonnet-4.5,3.0,15.0,1464,,GMI_API_KEY,0,True,none, +AWS Bedrock,jp.anthropic.claude-sonnet-4-5-20250929-v1:0,3.3,16.5,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +OpenRouter,openrouter/anthropic/claude-sonnet-4.5,3.0,15.0,1464,,OPENROUTER_API_KEY,0,True,effort, +AWS Bedrock,us.anthropic.claude-sonnet-4-5-20250929-v1:0,3.3,16.5,1464,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-sonnet-4.5,3.0,15.0,1464,,VERCEL_AI_GATEWAY_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/claude-sonnet-4-5,3.0,15.0,1464,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/claude-sonnet-4-5@20250929,3.0,15.0,1464,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Azure OpenAI,azure/gpt-5,1.25,10.0,1460,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +Github Copilot,github_copilot/gpt-5,0.0,0.0,1460,,,0,True,none, +GMI Cloud,gmi/openai/gpt-5,1.25,10.0,1460,,GMI_API_KEY,0,True,none, +OpenAI,gpt-5,1.25,10.0,1460,,OPENAI_API_KEY,0,True,effort, +OpenRouter,openrouter/openai/gpt-5,1.25,10.0,1460,,OPENROUTER_API_KEY,0,False,effort, +Replicate,replicate/openai/gpt-5,1.25,10.0,1460,,REPLICATE_API_KEY,0,True,none, +AWS Bedrock,zai.glm-4.7,0.6,2.2,1460,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +DeepInfra,deepinfra/Qwen/Qwen3-235B-A22B-Instruct-2507,0.09,0.6,1457,,DEEPINFRA_API_KEY,0,False,none, +Fireworks AI,fireworks_ai/accounts/fireworks/models/qwen3-235b-a22b-instruct-2507,0.22,0.88,1457,,FIREWORKS_AI_API_KEY,0,False,none, +Novita AI,novita/qwen/qwen3-235b-a22b-instruct-2507,0.09,0.58,1457,,NOVITA_API_KEY,0,True,none, +Replicate,replicate/qwen/qwen3-235b-a22b-instruct-2507,0.264,1.06,1457,,REPLICATE_API_KEY,0,True,none, +Google Vertex AI,vertex_ai/qwen/qwen3-235b-a22b-instruct-2507-maas,0.25,1.0,1457,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +W&B Inference,wandb/Qwen/Qwen3-235B-A22B-Instruct-2507,10000.0,10000.0,1457,,WANDB_API_KEY,0,False,none, +Azure AI,azure_ai/grok-4,3.0,15.0,1453,,AZURE_AI_API_KEY,0,True,none, +Oci,oci/xai.grok-4,3.0,15.0,1453,,OCI_API_KEY,0,True,none, +OpenRouter,openrouter/x-ai/grok-4,3.0,15.0,1453,,OPENROUTER_API_KEY,0,True,effort, +Replicate,replicate/xai/grok-4,7.2,36.0,1453,,REPLICATE_API_KEY,0,True,none, +Vercel AI Gateway,vercel_ai_gateway/xai/grok-4,3.0,15.0,1453,,VERCEL_AI_GATEWAY_API_KEY,0,True,none, +xAI,xai/grok-4,3.0,15.0,1453,,XAI_API_KEY,0,True,none, +xAI,xai/grok-4-latest,3.0,15.0,1453,,XAI_API_KEY,0,True,none, +Azure OpenAI,azure/eu/gpt-5.1,1.38,11.0,1450,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +Azure OpenAI,azure/global/gpt-5.1,1.25,10.0,1450,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +Azure OpenAI,azure/gpt-5.1,1.25,10.0,1450,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +Azure OpenAI,azure/mistral-large-latest,8.0,24.0,1450,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,none, +Azure OpenAI,azure/us/gpt-5.1,1.38,11.0,1450,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +Azure AI,azure_ai/mistral-large,4.0,12.0,1450,,AZURE_AI_API_KEY,0,True,none, +Azure AI,azure_ai/mistral-large-3,0.5,1.5,1450,,AZURE_AI_API_KEY,0,True,none, +Azure AI,azure_ai/mistral-large-latest,2.0,6.0,1450,,AZURE_AI_API_KEY,0,True,none, +AWS Bedrock,bedrock/ap-northeast-1/moonshotai.kimi-k2-thinking,0.73,3.03,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/ap-south-1/moonshotai.kimi-k2-thinking,0.71,2.94,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/moonshotai.kimi-k2-thinking,0.73,3.03,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/sa-east-1/moonshotai.kimi-k2-thinking,0.73,3.03,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/us-east-1/moonshotai.kimi-k2-thinking,0.6,2.5,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/us-east-2/moonshotai.kimi-k2-thinking,0.6,2.5,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,bedrock/us-west-2/moonshotai.kimi-k2-thinking,0.6,2.5,1450,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Fireworks AI,fireworks_ai/accounts/fireworks/models/kimi-k2-thinking,0.6,2.5,1450,,FIREWORKS_AI_API_KEY,0,True,none, +Github Copilot,github_copilot/gpt-5.1,0.0,0.0,1450,,,0,True,none, +GMI Cloud,gmi/moonshotai/Kimi-K2-Thinking,0.8,1.2,1450,,GMI_API_KEY,0,False,none, +GMI Cloud,gmi/openai/gpt-5.1,1.25,10.0,1450,,GMI_API_KEY,0,True,none, +OpenAI,gpt-5.1,1.25,10.0,1450,,OPENAI_API_KEY,0,True,effort, +Mistral AI,mistral/mistral-large-3,0.5,1.5,1450,,MISTRAL_API_KEY,0,True,none, +Mistral AI,mistral/mistral-large-latest,2.0,6.0,1450,,MISTRAL_API_KEY,0,True,none, +Moonshot AI,moonshot/kimi-k2-thinking,0.6,2.5,1450,,MOONSHOT_API_KEY,0,True,none, +Novita AI,novita/moonshotai/kimi-k2-thinking,0.6,2.5,1450,,NOVITA_API_KEY,0,True,effort, +OpenRouter,openrouter/mistralai/mistral-large,8.0,24.0,1450,,OPENROUTER_API_KEY,0,False,none, +Snowflake,snowflake/mistral-large,0.0,0.0,1450,,SNOWFLAKE_API_KEY,0,False,none, +Google Vertex AI,vertex_ai/mistral-large@2407,2.0,6.0,1450,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +Google Vertex AI,vertex_ai/mistral-large@2411-001,2.0,6.0,1450,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +Google Vertex AI,vertex_ai/mistral-large@latest,2.0,6.0,1450,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +Google Vertex AI,vertex_ai/moonshotai/kimi-k2-thinking-maas,0.6,2.5,1450,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +DeepInfra,deepinfra/Qwen/Qwen3-235B-A22B-Thinking-2507,0.3,2.9,1442,,DEEPINFRA_API_KEY,0,False,none, +Fireworks AI,fireworks_ai/accounts/fireworks/models/qwen3-235b-a22b-thinking-2507,0.22,0.88,1442,,FIREWORKS_AI_API_KEY,0,False,none, +Novita AI,novita/qwen/qwen3-235b-a22b-thinking-2507,0.3,3.0,1442,,NOVITA_API_KEY,0,True,effort, +OpenRouter,openrouter/qwen/qwen3-235b-a22b-thinking-2507,0.11,0.6,1442,,OPENROUTER_API_KEY,0,True,effort, +Together AI,together_ai/Qwen/Qwen3-235B-A22B-Thinking-2507,0.65,3.0,1442,,TOGETHERAI_API_KEY,0,True,none, +W&B Inference,wandb/Qwen/Qwen3-235B-A22B-Thinking-2507,10000.0,10000.0,1442,,WANDB_API_KEY,0,False,none, +Azure OpenAI,azure/o3,2.0,8.0,1441,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +OpenAI,o3,2.0,8.0,1441,,OPENAI_API_KEY,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/openai/o3,2.0,8.0,1441,,VERCEL_AI_GATEWAY_API_KEY,0,True,none, +Azure AI,azure_ai/global/grok-3,3.0,15.0,1439,,AZURE_AI_API_KEY,0,True,none, +Azure AI,azure_ai/grok-3,3.0,15.0,1439,,AZURE_AI_API_KEY,0,True,none, +Oci,oci/xai.grok-3,3.0,15.0,1439,,OCI_API_KEY,0,True,none, +Vercel AI Gateway,vercel_ai_gateway/xai/grok-3,3.0,15.0,1439,,VERCEL_AI_GATEWAY_API_KEY,0,True,none, +xAI,xai/grok-3,3.0,15.0,1439,,XAI_API_KEY,0,True,none, +xAI,xai/grok-3-latest,3.0,15.0,1439,,XAI_API_KEY,0,True,none, +AWS Bedrock,anthropic.claude-haiku-4-5-20251001-v1:0,1.0,5.0,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,anthropic.claude-haiku-4-5@20251001,1.0,5.0,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,apac.anthropic.claude-haiku-4-5-20251001-v1:0,1.1,5.5,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +AWS Bedrock,au.anthropic.claude-haiku-4-5-20251001-v1:0,1.1,5.5,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Azure AI,azure_ai/claude-haiku-4-5,1.0,5.0,1436,,AZURE_AI_API_KEY,128000,True,budget, +Anthropic,claude-haiku-4-5,1.0,5.0,1436,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,claude-haiku-4-5-20251001,1.0,5.0,1436,,ANTHROPIC_API_KEY,128000,True,budget, +DeepInfra,deepinfra/deepseek-ai/DeepSeek-R1-0528,0.5,2.15,1436,,DEEPINFRA_API_KEY,0,False,none, +AWS Bedrock,eu.anthropic.claude-haiku-4-5-20251001-v1:0,1.1,5.5,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Fireworks AI,fireworks_ai/accounts/fireworks/models/deepseek-r1-0528,3.0,8.0,1436,,FIREWORKS_AI_API_KEY,0,True,none, +Github Copilot,github_copilot/claude-haiku-4.5,0.0,0.0,1436,,,0,True,none, +AWS Bedrock,global.anthropic.claude-haiku-4-5-20251001-v1:0,1.0,5.0,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Hyperbolic,hyperbolic/deepseek-ai/DeepSeek-R1-0528,0.25,0.25,1436,,HYPERBOLIC_API_KEY,0,True,none, +AWS Bedrock,jp.anthropic.claude-haiku-4-5-20251001-v1:0,1.1,5.5,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Lambda AI,lambda_ai/deepseek-r1-0528,0.2,0.6,1436,,LAMBDA_API_KEY,0,True,effort, +Novita AI,novita/deepseek/deepseek-r1-0528,0.7,2.5,1436,,NOVITA_API_KEY,0,True,effort, +OpenRouter,openrouter/anthropic/claude-haiku-4.5,1.0,5.0,1436,,OPENROUTER_API_KEY,0,True,effort, +OpenRouter,openrouter/deepseek/deepseek-r1-0528,0.5,2.15,1436,,OPENROUTER_API_KEY,0,True,effort, +AWS Bedrock,us.anthropic.claude-haiku-4-5-20251001-v1:0,1.1,5.5,1436,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-haiku-4.5,1.0,5.0,1436,,VERCEL_AI_GATEWAY_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/claude-haiku-4-5@20251001,1.0,5.0,1436,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/deepseek-ai/deepseek-r1-0528-maas,1.35,5.4,1436,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +W&B Inference,wandb/deepseek-ai/DeepSeek-R1-0528,135000.0,540000.0,1436,,WANDB_API_KEY,0,False,none, +Azure AI,azure_ai/deepseek-v3.2,0.58,1.68,1431,,AZURE_AI_API_KEY,128000,True,budget, +DeepSeek,deepseek/deepseek-v3.2,0.28,0.4,1431,,DEEPSEEK_API_KEY,0,True,effort, +GMI Cloud,gmi/deepseek-ai/DeepSeek-V3.2,0.28,0.4,1431,,GMI_API_KEY,0,True,none, +Novita AI,novita/deepseek/deepseek-v3.2,0.269,0.4,1431,,NOVITA_API_KEY,0,True,effort, +OpenRouter,openrouter/deepseek/deepseek-v3.2,0.28,0.4,1431,,OPENROUTER_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/deepseek-ai/deepseek-v3.2-maas,0.56,1.68,1431,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +AWS Bedrock,bedrock/ap-northeast-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/ap-south-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/ap-southeast-3/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/eu-central-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/eu-north-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/eu-south-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/eu-west-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/eu-west-2/minimax.minimax-m2.1,0.47,1.86,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/sa-east-1/minimax.minimax-m2.1,0.36,1.44,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/us-east-1/minimax.minimax-m2.1,0.3,1.2,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/us-east-2/minimax.minimax-m2.1,0.3,1.2,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +AWS Bedrock,bedrock/us-west-2/minimax.minimax-m2.1,0.3,1.2,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +DeepInfra,deepinfra/deepseek-ai/DeepSeek-V3.1,0.27,1.0,1430,,DEEPINFRA_API_KEY,0,False,effort, +Fireworks AI,fireworks_ai/accounts/fireworks/models/minimax-m2,0.3,1.2,1430,,FIREWORKS_AI_API_KEY,0,False,none, +GMI Cloud,gmi/MiniMaxAI/MiniMax-M2.1,0.3,1.2,1430,,GMI_API_KEY,0,False,none, +AWS Bedrock,minimax.minimax-m2,0.3,1.2,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,False,none, +AWS Bedrock,minimax.minimax-m2.1,0.3,1.2,1430,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,none, +Novita AI,novita/deepseek/deepseek-v3.1,0.27,1.0,1430,,NOVITA_API_KEY,0,True,effort, +Replicate,replicate/deepseek-ai/deepseek-v3.1,0.672,2.016,1430,,REPLICATE_API_KEY,0,True,effort, +SambaNova,sambanova/DeepSeek-V3.1,3.0,4.5,1430,,SAMBANOVA_API_KEY,0,True,effort, +Together AI,together_ai/deepseek-ai/DeepSeek-V3.1,0.6,1.7,1430,,TOGETHERAI_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/deepseek-ai/deepseek-v3.1-maas,1.35,5.4,1430,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/minimaxai/minimax-m2-maas,0.3,1.2,1430,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +W&B Inference,wandb/deepseek-ai/DeepSeek-V3.1,55000.0,165000.0,1430,,WANDB_API_KEY,0,False,none, +DeepInfra,deepinfra/google/gemini-2.5-flash,0.3,2.5,1420,,DEEPINFRA_API_KEY,0,False,none, +Google Vertex AI,gemini-2.5-flash,0.3,2.5,1420,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Gemini,gemini/gemini-2.5-flash,0.3,2.5,1420,,GEMINI_API_KEY,0,True,effort, +OpenRouter,openrouter/google/gemini-2.5-flash,0.3,2.5,1420,,OPENROUTER_API_KEY,0,True,none, +Replicate,replicate/google/gemini-2.5-flash,2.5,2.5,1420,,REPLICATE_API_KEY,0,True,none, +Vercel AI Gateway,vercel_ai_gateway/google/gemini-2.5-flash,0.3,2.5,1420,,VERCEL_AI_GATEWAY_API_KEY,0,True,none, +Azure OpenAI,azure/gpt-4.5-preview,75.0,150.0,1419,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,none, +Azure OpenAI,azure/gpt-5-mini,0.25,2.0,1419,,AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION,0,True,effort, +Github Copilot,github_copilot/gpt-5-mini,0.0,0.0,1419,,,0,True,none, +OpenAI,gpt-4.5-preview,75.0,150.0,1419,,OPENAI_API_KEY,0,True,none, +OpenAI,gpt-5-mini,0.25,2.0,1419,,OPENAI_API_KEY,0,True,effort, +OpenRouter,openrouter/openai/gpt-5-mini,0.25,2.0,1419,,OPENROUTER_API_KEY,0,False,effort, +Replicate,replicate/openai/gpt-5-mini,0.25,2.0,1419,,REPLICATE_API_KEY,0,True,none, +DeepInfra,deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct,0.4,1.6,1406,,DEEPINFRA_API_KEY,0,False,none, +DeepInfra,deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo,0.29,1.2,1406,,DEEPINFRA_API_KEY,0,False,none, +Fireworks AI,fireworks_ai/accounts/fireworks/models/qwen3-coder-480b-a35b-instruct,0.45,1.8,1406,,FIREWORKS_AI_API_KEY,0,False,effort, +Novita AI,novita/qwen/qwen3-coder-480b-a35b-instruct,0.3,1.3,1406,,NOVITA_API_KEY,0,True,none, +AWS Bedrock,qwen.qwen3-coder-480b-a35b-v1:0,0.22,1.8,1406,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Together AI,together_ai/Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8,2.0,2.0,1406,,TOGETHERAI_API_KEY,0,True,none, +Google Vertex AI,vertex_ai/qwen/qwen3-coder-480b-a35b-instruct-maas,1.0,4.0,1406,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,none,global +W&B Inference,wandb/Qwen/Qwen3-Coder-480B-A35B-Instruct,100000.0,150000.0,1406,,WANDB_API_KEY,0,False,none, +AWS Bedrock,anthropic.claude-opus-4-20250514-v1:0,15.0,75.0,1405,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Anthropic,claude-4-opus-20250514,15.0,75.0,1405,,ANTHROPIC_API_KEY,128000,True,budget, +Anthropic,claude-opus-4-20250514,15.0,75.0,1405,,ANTHROPIC_API_KEY,128000,True,budget, +DeepInfra,deepinfra/anthropic/claude-4-opus,16.5,82.5,1405,,DEEPINFRA_API_KEY,0,False,none, +AWS Bedrock,eu.anthropic.claude-opus-4-20250514-v1:0,15.0,75.0,1405,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +GMI Cloud,gmi/anthropic/claude-opus-4,15.0,75.0,1405,,GMI_API_KEY,0,True,none, +OpenRouter,openrouter/anthropic/claude-opus-4,15.0,75.0,1405,,OPENROUTER_API_KEY,0,True,effort, +AWS Bedrock,us.anthropic.claude-opus-4-20250514-v1:0,15.0,75.0,1405,,AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,0,True,effort, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-4-opus,15.0,75.0,1405,,VERCEL_AI_GATEWAY_API_KEY,0,True,none, +Vercel AI Gateway,vercel_ai_gateway/anthropic/claude-opus-4,15.0,75.0,1405,,VERCEL_AI_GATEWAY_API_KEY,0,True,effort, +Google Vertex AI,vertex_ai/claude-opus-4,15.0,75.0,1405,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Google Vertex AI,vertex_ai/claude-opus-4@20250514,15.0,75.0,1405,,GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,0,True,effort,global +Moonshot AI,moonshot/kimi-k2-0905-preview,0.6,2.5,1403,,MOONSHOT_API_KEY,0,True,none, +Novita AI,novita/moonshotai/kimi-k2-0905,0.6,2.5,1403,,NOVITA_API_KEY,0,True,none, +DeepInfra,deepinfra/moonshotai/Kimi-K2-Instruct,0.5,2.0,1402,,DEEPINFRA_API_KEY,0,False,none, +Fireworks AI,fireworks_ai/accounts/fireworks/models/kimi-k2-instruct,0.6,2.5,1402,,FIREWORKS_AI_API_KEY,0,True,none, +Hyperbolic,hyperbolic/moonshotai/Kimi-K2-Instruct,2.0,2.0,1402,,HYPERBOLIC_API_KEY,0,True,none, +Moonshot AI,moonshot/kimi-k2-0711-preview,0.6,2.5,1402,,MOONSHOT_API_KEY,0,True,none, +Novita AI,novita/moonshotai/kimi-k2-instruct,0.57,2.3,1402,,NOVITA_API_KEY,0,True,none, +Together AI,together_ai/moonshotai/Kimi-K2-Instruct,1.0,3.0,1402,,TOGETHERAI_API_KEY,0,True,none, +W&B Inference,wandb/moonshotai/Kimi-K2-Instruct,0.6,2.5,1402,,WANDB_API_KEY,0,False,none, diff --git a/pdd/generate_model_catalog.py b/pdd/generate_model_catalog.py new file mode 100644 index 000000000..907d00725 --- /dev/null +++ b/pdd/generate_model_catalog.py @@ -0,0 +1,711 @@ +#!/usr/bin/env python3 +""" +scripts/generate_model_catalog.py + +Regenerates pdd/data/llm_model.csv from LiteLLM's bundled model registry. + +Usage: + python scripts/generate_model_catalog.py [--output PATH] + +The script pulls from litellm.model_cost (local data, no network calls) and: + - Filters to chat-mode models only + - Skips deprecated models + - Skips placeholder/tier entries (e.g. together-ai-4.1b-8b) + - Converts per-token costs to per-million-token costs + - Looks up display provider names and API key env var names + - Applies curated ELO scores for known models; skips models below ELO_CUTOFF + - Infers structured_output, reasoning_type, max_reasoning_tokens + - Sorts by ELO descending, then model name ascending + +Re-run this script whenever you update the litellm package to pick up new models. +""" + +from __future__ import annotations + +import argparse +import csv +import re +import sys +from datetime import date +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +# --------------------------------------------------------------------------- +# ELO cutoff — models below this score are excluded from the catalog. +# --------------------------------------------------------------------------- +ELO_CUTOFF = 1400 + +# --------------------------------------------------------------------------- +# ELO scores — canonical base model names mapped to coding arena ELO. +# All known models are listed here; ELO_CUTOFF controls which make the CSV. +# Keys are normalized base names (as produced by _extract_base_model). + +# Scores sourced from LM Arena *coding* leaderboard (Feb 2026): +# https://openlm.ai/chatbot-arena/ (Coding column) +# You should update these values every so often. +# --------------------------------------------------------------------------- +ELO_SCORES: Dict[str, int] = { + # ----------------------------------------------------------------------- + # Anthropic Claude — dash-separated canonical form + # ----------------------------------------------------------------------- + "claude-opus-4-6": 1530, + "claude-opus-4-5": 1496, + "claude-opus-4": 1405, + "claude-opus-4-1": 1475, + "claude-sonnet-4-6": 1485, + "claude-sonnet-4-5": 1464, + "claude-sonnet-4": 1384, + "claude-3-7-sonnet": 1341, + "claude-3-5-sonnet-20241022": 1340, + "claude-3-5-sonnet-20240620": 1309, + "claude-3-5-sonnet": 1340, + "claude-haiku-4-5": 1436, + "claude-3-5-haiku": 1287, + "claude-3-opus": 1269, + "claude-3-haiku": 1208, + "claude-3-sonnet": 1232, + # Dot-separated variants (OpenRouter, GitHub Copilot, Vercel, GMI) + "claude-opus-4.6": 1530, + "claude-opus-4.5": 1496, + "claude-opus-4.1": 1475, + "claude-sonnet-4.6": 1485, + "claude-sonnet-4.5": 1464, + "claude-haiku-4.5": 1436, + "claude-3.5-sonnet": 1340, + "claude-3.5-haiku": 1287, + "claude-3.7-sonnet": 1341, + # Alternate naming: "claude-4-opus" / "claude-4-sonnet" + "claude-4-opus": 1405, + "claude-4-sonnet": 1384, + # ----------------------------------------------------------------------- + # OpenAI — GPT-5 family + # ----------------------------------------------------------------------- + "gpt-5": 1460, + "gpt-5.1": 1450, + "gpt-5.2": 1465, + "gpt-5-mini": 1419, + "gpt-5-nano": 1363, + # OpenAI — GPT-4.x + "gpt-4.5": 1419, + "gpt-4.1": 1396, + "gpt-4.1-mini": 1370, + "gpt-4.1-nano": 1312, + "gpt-4o": 1307, + "gpt-4o-2024-08-06": 1307, + "gpt-4o-2024-11-20": 1307, + "gpt-4o-mini": 1300, + "gpt-4-turbo": 1280, + "gpt-4-0125-preview": 1261, + "gpt-4-1106-preview": 1269, + # OpenAI — o-series + "o3": 1441, + "o4-mini": 1385, + "o3-mini": 1361, + "o1": 1378, + "o1-mini": 1366, + "o1-preview": 1378, + # OpenAI — gpt-oss + "gpt-oss-120b": 1398, + "gpt-oss-20b": 1371, + # ----------------------------------------------------------------------- + # Google Gemini + # ----------------------------------------------------------------------- + "gemini-3-pro": 1501, + "gemini-3-pro-preview": 1501, + "gemini-3-flash": 1469, + "gemini-3-flash-preview": 1469, + "gemini-2.5-pro": 1465, + "gemini-2.5-flash": 1420, + "gemini-2.0-flash": 1371, + "gemini-2.0-flash-thinking": 1383, + "gemini-1.5-pro": 1311, + "gemini-1.5-flash": 1273, + # ----------------------------------------------------------------------- + # DeepSeek + # ----------------------------------------------------------------------- + "deepseek-r1": 1382, + "deepseek-r1-0528": 1436, + "deepseek-reasoner": 1382, + "deepseek-chat": 1337, + "deepseek-v3": 1337, + "deepseek-v3-0324": 1391, + "deepseek-v3.1": 1430, + "deepseek-v3.2": 1431, + # ----------------------------------------------------------------------- + # xAI / Grok + # ----------------------------------------------------------------------- + "grok-4.1": 1483, + "grok-4": 1453, + "grok-4-fast": 1441, + "grok-3": 1439, + "grok-3-mini": 1380, + "grok-2": 1298, + # ----------------------------------------------------------------------- + # Mistral + # ----------------------------------------------------------------------- + "mistral-large": 1450, + "mistral-large-3": 1450, + "mistral-medium-3": 1387, + "mistral-medium-3.1": 1412, + "magistral-medium": 1307, + "magistral-small": 1330, + "codestral": 1300, + "mistral-small-3.1": 1295, + "mistral-small-3.2": 1361, + "mistral-small-3": 1251, + # ----------------------------------------------------------------------- + # Moonshot / Kimi + # ----------------------------------------------------------------------- + "kimi-k2.5": 1480, + "kimi-k2-instruct": 1402, + "kimi-k2-thinking": 1450, + "kimi-k2-0905": 1403, + "kimi-k2-0711": 1402, + # ----------------------------------------------------------------------- + # Meta Llama + # ----------------------------------------------------------------------- + "llama-4-maverick-17b-128e": 1312, + "llama-4-scout-17b-16e": 1290, + "llama-3.3-70b": 1279, + "llama-3.1-405b": 1299, + "llama-3.1-70b": 1268, + "llama-3.1-8b": 1203, + "llama-3-70b": 1216, + # ----------------------------------------------------------------------- + # Qwen / Alibaba + # ----------------------------------------------------------------------- + "qwen3-max": 1468, + "qwen3-235b-a22b": 1394, + "qwen3-235b-a22b-instruct-2507": 1457, + "qwen3-235b-a22b-thinking-2507": 1442, + "qwen3-32b": 1376, + "qwen3-30b-a3b": 1346, + "qwen3-coder-480b-a35b": 1406, + "qwq-32b": 1351, + "qwen2.5-72b": 1302, + "qwen2.5-max": 1373, + # ----------------------------------------------------------------------- + # GLM (Zhipu AI / ZAI) + # ----------------------------------------------------------------------- + "glm-5": 1461, + "glm-4.7": 1460, + "glm-4.6": 1458, + "glm-4.5": 1448, + "glm-4.5-air": 1410, + # ----------------------------------------------------------------------- + # Minimax + # ----------------------------------------------------------------------- + "minimax-m2.1": 1430, + "minimax-m1": 1369, + "minimax-m2": 1430, + # ----------------------------------------------------------------------- + # Amazon Nova + # ----------------------------------------------------------------------- + "nova-pro": 1282, + "nova-lite": 1253, + "nova-micro": 1228, + # ----------------------------------------------------------------------- + # MiMo (Xiaomi) + # ----------------------------------------------------------------------- + "mimo-v2-flash": 1411, + # ----------------------------------------------------------------------- + # Gemma (Google open) + # ----------------------------------------------------------------------- + "gemma-3-27b": 1350, + "gemma-3-12b": 1310, + "gemma-3-4b": 1265, + # ----------------------------------------------------------------------- + # NVIDIA Nemotron + # ----------------------------------------------------------------------- + "llama-3.3-nemotron-super-49b": 1359, + "llama-3.1-nemotron-70b": 1289, + # ----------------------------------------------------------------------- + # Phi (Microsoft) + # ----------------------------------------------------------------------- + "phi-4": 1242, +} + +# --------------------------------------------------------------------------- +# Provider table — maps litellm provider ID to (display name, API key env var). +# --------------------------------------------------------------------------- +PROVIDERS: Dict[str, Tuple[str, str]] = { + "openai": ("OpenAI", "OPENAI_API_KEY"), + "anthropic": ("Anthropic", "ANTHROPIC_API_KEY"), + "gemini": ("Google Gemini", "GEMINI_API_KEY"), + "vertex_ai": ("Google Vertex AI", "GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION"), + "xai": ("xAI", "XAI_API_KEY"), + "deepseek": ("DeepSeek", "DEEPSEEK_API_KEY"), + "mistral": ("Mistral AI", "MISTRAL_API_KEY"), + "cohere": ("Cohere", "COHERE_API_KEY"), + "cohere_chat": ("Cohere", "COHERE_API_KEY"), + "moonshot": ("Moonshot AI", "MOONSHOT_API_KEY"), + "groq": ("Groq", "GROQ_API_KEY"), + "fireworks_ai": ("Fireworks AI", "FIREWORKS_AI_API_KEY"), + "together_ai": ("Together AI", "TOGETHERAI_API_KEY"), + "perplexity": ("Perplexity", "PERPLEXITYAI_API_KEY"), + "openrouter": ("OpenRouter", "OPENROUTER_API_KEY"), + "deepinfra": ("DeepInfra", "DEEPINFRA_API_KEY"), + "cerebras": ("Cerebras", "CEREBRAS_API_KEY"), + "replicate": ("Replicate", "REPLICATE_API_KEY"), + "anyscale": ("Anyscale", "ANYSCALE_API_KEY"), + "novita": ("Novita AI", "NOVITA_API_KEY"), + "sambanova": ("SambaNova", "SAMBANOVA_API_KEY"), + "nvidia_nim": ("NVIDIA NIM", "NVIDIA_NIM_API_KEY"), + "bedrock": ("AWS Bedrock", "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME"), + "bedrock_converse": ("AWS Bedrock", "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME"), + "sagemaker": ("AWS SageMaker", "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME"), + "azure": ("Azure OpenAI", "AZURE_API_KEY|AZURE_API_BASE|AZURE_API_VERSION"), + "azure_ai": ("Azure AI", "AZURE_AI_API_KEY"), + "databricks": ("Databricks", "DATABRICKS_API_KEY"), + "watsonx": ("IBM watsonx", "WATSONX_APIKEY"), + "cloudflare": ("Cloudflare Workers AI", "CLOUDFLARE_API_KEY"), + "huggingface": ("Hugging Face", "HF_TOKEN"), + "ai21": ("AI21", "AI21_API_KEY"), + "nlp_cloud": ("NLP Cloud", "NLP_CLOUD_API_KEY"), + "aleph_alpha": ("Aleph Alpha", "ALEPHALPHA_API_KEY"), + "predibase": ("Predibase", "PREDIBASE_API_KEY"), + "friendliai": ("FriendliAI", "FRIENDLI_TOKEN"), + "github": ("GitHub Models", "GITHUB_API_KEY"), + "github_copilot": ("Github Copilot", ""), + "clarifai": ("Clarifai", "CLARIFAI_PAT"), + "voyage": ("Voyage", "VOYAGE_API_KEY"), + "codestral": ("Codestral", "CODESTRAL_API_KEY"), + "infinity": ("Infinity", "INFINITY_API_KEY"), + "nscale": ("Nscale", "NSCALE_API_KEY"), + "hyperbolic": ("Hyperbolic", "HYPERBOLIC_API_KEY"), + "lambda_ai": ("Lambda AI", "LAMBDA_API_KEY"), + "featherless_ai": ("Featherless AI", "FEATHERLESS_API_KEY"), + "gmi": ("GMI Cloud", "GMI_API_KEY"), + "wandb": ("W&B Inference", "WANDB_API_KEY"), + "vercel_ai_gateway": ("Vercel AI Gateway", "VERCEL_AI_GATEWAY_API_KEY"), + "ollama": ("Ollama", ""), + "ollama_chat": ("Ollama", ""), + "lm_studio": ("LM Studio", ""), +} + +# Anthropic provider IDs — these use "budget" reasoning +_ANTHROPIC_PROVIDERS = {"anthropic", "azure_ai"} # azure_ai hosts Claude models too + +# Model name patterns that signal reasoning (for providers not in the sets above) +_EFFORT_PATTERNS = re.compile( + r"o1|o3|o4|gemini.*thinking|deepseek.r1|deepseek.reasoner|" + r"qwen.*thinking|kimi.*thinking|magistral|" + r"gemini.*flash.*thinking", + re.IGNORECASE, +) + +# Placeholder tier entries in together_ai (not real model IDs) +_TIER_PATTERN = re.compile(r"^together-ai-[\d.]+b", re.IGNORECASE) + +# Models we never want in the catalog (sample spec, image-only, etc.) +_SKIP_KEYS = {"sample_spec"} + +# Regex matching dated preview model names (after provider prefix is stripped). +# Examples: gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-06-05 +_DATED_PREVIEW = re.compile( + r"^(?Pgemini-[\d.]+-\w+)-preview-\d{2}-\d{2,4}$", + re.IGNORECASE, +) + +CSV_FIELDNAMES = [ + "provider", "model", "input", "output", "coding_arena_elo", + "base_url", "api_key", "max_reasoning_tokens", "structured_output", + "reasoning_type", "location", +] + +# --------------------------------------------------------------------------- +# Regex patterns for _extract_base_model() — stripping provider/region/version +# --------------------------------------------------------------------------- + +# Known provider prefixes (simple provider/rest format) +_SIMPLE_PREFIX_PROVIDERS = { + "vertex_ai", "azure_ai", "openrouter", "deepinfra", "together_ai", + "fireworks_ai", "vercel_ai_gateway", "github_copilot", "groq", + "cerebras", "hyperbolic", "novita", "sambanova", "replicate", + "lambda_ai", "nscale", "oci", "gmi", "wandb", "ovhcloud", + "llamagate", "gradient_ai", "moonshot", "snowflake", "heroku", + "publicai", "deepseek", "xai", "mistral", "gemini", "perplexity", + "cohere", "cohere_chat", "meta_llama", "dashscope", +} + +# Bedrock region paths: us-east-1/, ap-northeast-1/, us-gov-west-1/, etc. +# Also handles commitment and invoke prefixes. +_BEDROCK_REGION_PATH = re.compile( + r"^(?:[a-z]{2}-[a-z]+-\d+/)+" # one or more region segments + r"|^(?:\d+-month-commitment/)" + r"|^(?:invoke/)", + re.IGNORECASE, +) + +# Azure sub-region paths: eu/, global/, global-standard/ +_AZURE_REGION_PREFIX = re.compile( + r"^(?:eu|global-standard|global|us)/", + re.IGNORECASE, +) + +# Bedrock cross-region inference prefixes on bare IDs: us., eu., apac., au., jp., global. +_BEDROCK_GEO_PREFIX = re.compile( + r"^(?:us|eu|apac|ap|au|jp|global)\.", + re.IGNORECASE, +) + +# Vendor dot-namespace: anthropic., meta., moonshotai., deepseek., xai., etc. +# Used by Bedrock (anthropic.claude-*) and OCI (xai.grok-3, meta.llama-*) +_VENDOR_DOT_PREFIX = re.compile( + r"^(?:anthropic|meta|amazon|cohere|ai21|mistral|moonshotai|deepseek|" + r"qwen|minimax|nvidia|openai|google|writer|twelvelabs|zai|xai)\.", + re.IGNORECASE, +) + +# HuggingFace-style org namespaces used by deepinfra, together_ai, openrouter, etc. +_ORG_NAMESPACE = re.compile( + r"^(?:deepseek-ai|deepseek|meta-llama|meta|anthropic|google|openai|" + r"moonshotai|mistralai|qwen|Qwen|x-ai|xai|cohere|microsoft|" + r"allenai|NousResearch|nvidia|MiniMaxAI)/", + re.IGNORECASE, +) + +# Fireworks account path: accounts/fireworks/models/ (or any account) +_FIREWORKS_ACCOUNT = re.compile( + r"^accounts/[^/]+/models/", + re.IGNORECASE, +) + +# Anthropic fast/us routing prefixes on bare IDs +_FAST_PREFIX = re.compile(r"^(?:fast/us/|fast/|us/)", re.IGNORECASE) + +# Vertex AI @version suffix: @20241022, @default, @001, @latest +_VERTEX_VERSION = re.compile(r"@[\w.-]+$") + +# Bedrock version suffix: -v1:0, -v2:0, :0 +_BEDROCK_VERSION = re.compile(r"(?:-v\d+:\d+|:\d+)$") + +# Special mapping for Bedrock deepseek after vendor prefix is stripped +# e.g. deepseek.v3.2 -> strips to "v3.2" or "v3-v1:0" -> "v3" +_BEDROCK_DEEPSEEK_REMAP: Dict[str, str] = { + "v3": "deepseek-v3", + "v3.2": "deepseek-v3", + "r1": "deepseek-r1", +} + +# Safe remainder patterns after a canonical prefix match. +# Only accept: empty, date suffixes (-20241022), version tags (-v1, -v2), +# preview/latest tags, or @version. +# This REJECTS things like -distill-*, -turbo, -mini, -fast. +_SAFE_REMAINDER = re.compile( + r"^(?:" + r"-\d{8}" # -20241022 (8-digit date) + r"|-v\d+" # -v1, -v2 + r"|-preview" # -preview + r"|-latest" # -latest + r"|-instruct" # -instruct (same weights, just instruction-tuned name) + r"|-versatile" # -versatile (Groq naming for same model) + r"|-\d{4}(?:0[1-9]|1[0-2])\d{2}" # -YYYYMMDD compact + r")(?:$|[-@])", # must be end-of-string or followed by another suffix + re.IGNORECASE, +) + + +def _extract_base_model(model_id: str) -> Optional[str]: + """ + Extract a canonical base model name from a litellm model ID by stripping + provider prefixes, regions, vendor namespaces, and version suffixes. + + Returns a key matching ELO_SCORES if confident, or None if the model + cannot be safely identified (conservative — prefers returning None over + a wrong match). + """ + s = model_id.strip() + + # Step 1: Strip known provider prefix + slash_pos = s.find("/") + if slash_pos > 0: + prefix = s[:slash_pos] + if prefix in _SIMPLE_PREFIX_PROVIDERS: + s = s[slash_pos + 1:] + # Azure AI and some providers also have region sub-paths (global/, etc.) + s = _AZURE_REGION_PREFIX.sub("", s) + elif prefix == "bedrock" or prefix == "bedrock_converse": + s = s[slash_pos + 1:] + # Strip region paths (may be multiple segments) + while _BEDROCK_REGION_PATH.match(s): + s = _BEDROCK_REGION_PATH.sub("", s, count=1) + elif prefix == "azure": + s = s[slash_pos + 1:] + s = _AZURE_REGION_PREFIX.sub("", s) + elif prefix == "openai": + s = s[slash_pos + 1:] + + # Step 2: Strip fast/us routing prefixes (bare Anthropic IDs) + s = _FAST_PREFIX.sub("", s) + + # Step 3: Strip Bedrock cross-region geo prefixes (us., eu., apac., etc.) + s = _BEDROCK_GEO_PREFIX.sub("", s) + + # Step 4: Strip vendor dot-namespace (anthropic., meta., moonshotai., xai., etc.) + # Only when there's no slash left (to avoid mangling org/model paths) + if "/" not in s and "." in s: + m = _VENDOR_DOT_PREFIX.match(s) + if m: + s = s[m.end():] + + # Step 5: Strip HuggingFace-style org namespace (deepseek-ai/, meta-llama/, etc.) + s = _ORG_NAMESPACE.sub("", s) + + # Step 6: Strip Fireworks account path + s = _FIREWORKS_ACCOUNT.sub("", s) + + # Step 7: Strip Vertex AI @version suffix + s = _VERTEX_VERSION.sub("", s) + + # Step 8: Strip Bedrock version suffix (-v1:0, :0) + s = _BEDROCK_VERSION.sub("", s) + + # Step 9: Lowercase for matching + s = s.lower() + + # Step 10: Handle Bedrock deepseek special naming (vendor-stripped leftovers) + if s in _BEDROCK_DEEPSEEK_REMAP: + s = _BEDROCK_DEEPSEEK_REMAP[s] + + # Step 11: Strip trailing -maas suffix (Vertex AI model-as-a-service) + if s.endswith("-maas"): + s = s[:-5] + + # Step 12: Exact match + if s in ELO_SCORES: + return s + + # Step 13: Longest-prefix match against ELO_SCORES keys. + # Sorted longest-first to prefer more specific matches + # (e.g. "claude-opus-4-1" over "claude-opus-4"). + for key in sorted(ELO_SCORES, key=len, reverse=True): + if s.startswith(key): + remainder = s[len(key):] + if not remainder: + return key + if _SAFE_REMAINDER.match(remainder): + return key + + return None + + +def _get_provider_root(litellm_provider: str) -> str: + """Return the root provider for compound provider strings like vertex_ai-anthropic_models.""" + return litellm_provider.split("-")[0].split("_models")[0] + + +def _infer_reasoning_type(model_id: str, litellm_provider: str, entry: dict) -> str: + supports_reasoning = entry.get("supports_reasoning", False) + if not supports_reasoning: + return "none" + root = _get_provider_root(litellm_provider) + # Anthropic (and Azure AI hosting Claude) use "budget" reasoning tokens + if root in _ANTHROPIC_PROVIDERS: + return "budget" + # All other providers use "effort" (low/medium/high string) + return "effort" + + +def _infer_max_reasoning_tokens(model_id: str, litellm_provider: str, entry: dict) -> int: + root = _get_provider_root(litellm_provider) + if not entry.get("supports_reasoning", False): + return 0 + if root in _ANTHROPIC_PROVIDERS: + return 128000 + return 0 + + +def _is_deprecated(entry: dict) -> bool: + dep = entry.get("deprecation_date") + if not dep or not isinstance(dep, str): + return False + try: + dep_date = date.fromisoformat(dep) + return dep_date <= date.today() + except ValueError: + return False + + +def _is_placeholder(model_id: str) -> bool: + """Filter out non-usable placeholder entries.""" + if model_id in _SKIP_KEYS: + return True + if _TIER_PATTERN.match(model_id): + return True + return False + + +def _is_superseded_preview(model_id: str, all_model_ids: set) -> bool: + """Return True if this is a dated Gemini preview whose stable GA version exists. + + Google routinely sunsets dated preview models (e.g. gemini-2.5-flash-preview-04-17) + once the stable GA version (gemini-2.5-flash) is available, but litellm's registry + often retains them without a deprecation_date. We skip these to avoid catalog + entries that fail at call time with a 404. + + The check is applied to both bare IDs (gemini-2.5-flash-preview-04-17) and + provider-prefixed IDs (gemini/gemini-2.5-flash-preview-04-17) — we strip the + provider prefix before matching. + """ + # Strip simple provider prefix (e.g. "gemini/", "vertex_ai/") + bare = model_id + slash = bare.find("/") + if slash > 0: + bare = bare[slash + 1:] + + m = _DATED_PREVIEW.match(bare) + if not m: + return False + + ga_name = m.group("base") # e.g. "gemini-2.5-flash" + + # Check whether the stable GA version exists in litellm's registry + # (either bare or under common provider prefixes) + if ga_name in all_model_ids: + return True + if f"gemini/{ga_name}" in all_model_ids: + return True + + return False + + +def _get_elo(model_id: str) -> int: + """Look up ELO for a model. + + Lookup order (stops at first hit): + 1. Exact match in ELO_SCORES + 2. _extract_base_model() -> ELO_SCORES lookup + 3. Return 0 + """ + if model_id in ELO_SCORES: + return ELO_SCORES[model_id] + canonical = _extract_base_model(model_id) + if canonical is not None: + return ELO_SCORES[canonical] + return 0 + + +def build_rows() -> List[dict]: + try: + import litellm + except ImportError: + print("ERROR: litellm is not installed. Run: pip install litellm", file=sys.stderr) + sys.exit(1) + + all_model_ids = set(litellm.model_cost.keys()) + rows = [] + skipped_previews = 0 + + for model_id, entry in litellm.model_cost.items(): + # Only chat mode + if entry.get("mode") != "chat": + continue + # Skip deprecated + if _is_deprecated(entry): + continue + # Skip placeholder/tier entries + if _is_placeholder(model_id): + continue + # Skip dated preview models superseded by a stable GA release + if _is_superseded_preview(model_id, all_model_ids): + skipped_previews += 1 + continue + # Skip models that cannot produce text output (e.g. TTS / audio-only) + output_modalities = entry.get("supported_output_modalities", []) + if output_modalities and "text" not in output_modalities: + continue + + litellm_provider: str = entry.get("litellm_provider", "") + root_provider = _get_provider_root(litellm_provider) + + # ELO — skip models below cutoff or with no known score + elo = _get_elo(model_id) + if elo < ELO_CUTOFF: + continue + + # Convert per-token costs to per-million + in_cost_token = entry.get("input_cost_per_token") or 0.0 + out_cost_token = entry.get("output_cost_per_token") or 0.0 + input_cost = round(in_cost_token * 1_000_000, 6) + output_cost = round(out_cost_token * 1_000_000, 6) + + # Provider display name and API key env var + display_name, api_key = PROVIDERS.get( + litellm_provider, + PROVIDERS.get( + root_provider, + (litellm_provider.replace("_", " ").title(), f"{root_provider.upper()}_API_KEY"), + ), + ) + + # Structured output + structured = bool( + entry.get("supports_function_calling") or + entry.get("supports_response_schema") + ) + + # Reasoning + reasoning_type = _infer_reasoning_type(model_id, litellm_provider, entry) + max_reasoning_tokens = _infer_max_reasoning_tokens(model_id, litellm_provider, entry) + + # Location (Vertex AI models default to global) + location = "global" if litellm_provider.startswith("vertex_ai") else "" + + rows.append({ + "provider": display_name, + "model": model_id, + "input": input_cost, + "output": output_cost, + "coding_arena_elo": elo, + "base_url": "", + "api_key": api_key, + "max_reasoning_tokens": max_reasoning_tokens, + "structured_output": structured, + "reasoning_type": reasoning_type, + "location": location, + }) + + if skipped_previews: + print(f" Skipped {skipped_previews} dated preview model(s) superseded by stable GA releases.") + + # Sort: ELO descending, then model name ascending + rows.sort(key=lambda r: (-r["coding_arena_elo"], r["model"])) + return rows + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + default_output = Path(__file__).parent.parent / "pdd" / "data" / "llm_model.csv" + parser.add_argument( + "--output", "-o", + type=Path, + default=default_output, + help=f"Output CSV path (default: {default_output})", + ) + args = parser.parse_args() + + output_path: Path = args.output + output_path.parent.mkdir(parents=True, exist_ok=True) + + print("Building model catalog from litellm.model_cost...") + rows = build_rows() + print(f" Found {len(rows)} chat models across all providers.") + + with open(output_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=CSV_FIELDNAMES) + writer.writeheader() + writer.writerows(rows) + + print(f" Written to: {output_path}") + + # Print a quick summary by provider + from collections import Counter + providers = Counter(r["provider"] for r in rows) + print("\nTop providers by model count:") + for provider, count in providers.most_common(20): + print(f" {provider}: {count}") + + +if __name__ == "__main__": + main() diff --git a/pdd/litellm_registry.py b/pdd/litellm_registry.py deleted file mode 100644 index fa0ea16f5..000000000 --- a/pdd/litellm_registry.py +++ /dev/null @@ -1,312 +0,0 @@ -""" -pdd/litellm_registry.py - -Wraps litellm's bundled model registry to provide provider search, model -browsing, and API key env var lookup. Uses only local data — no network calls. -""" -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Set - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Dataclasses -# --------------------------------------------------------------------------- - - -@dataclass -class ProviderInfo: - """Summary information about an LLM provider.""" - - name: str # litellm provider ID, e.g. "anthropic" - display_name: str # human-friendly, e.g. "Anthropic" - api_key_env_var: Optional[str] # e.g. "ANTHROPIC_API_KEY" - model_count: int # number of chat models available - sample_models: List[str] = field(default_factory=list) # up to 3 names - - -@dataclass -class ModelInfo: - """Metadata for a single LLM model from litellm's registry.""" - - name: str # short display name, e.g. "claude-opus-4-5" - litellm_id: str # full ID for litellm.completion() - input_cost_per_million: float # USD per 1M input tokens - output_cost_per_million: float # USD per 1M output tokens - max_input_tokens: Optional[int] = None - max_output_tokens: Optional[int] = None - supports_vision: bool = False - supports_function_calling: bool = False - - -# --------------------------------------------------------------------------- -# Static mappings -# --------------------------------------------------------------------------- - -PROVIDER_API_KEY_MAP: Dict[str, str] = { - "openai": "OPENAI_API_KEY", - "anthropic": "ANTHROPIC_API_KEY", - "gemini": "GEMINI_API_KEY", - "vertex_ai": "VERTEX_CREDENTIALS", - "groq": "GROQ_API_KEY", - "mistral": "MISTRAL_API_KEY", - "deepseek": "DEEPSEEK_API_KEY", - "fireworks_ai": "FIREWORKS_API_KEY", - "together_ai": "TOGETHERAI_API_KEY", - "perplexity": "PERPLEXITYAI_API_KEY", - "cohere": "COHERE_API_KEY", - "cohere_chat": "COHERE_API_KEY", - "replicate": "REPLICATE_API_KEY", - "xai": "XAI_API_KEY", - "deepinfra": "DEEPINFRA_API_KEY", - "cerebras": "CEREBRAS_API_KEY", - "ai21": "AI21_API_KEY", - "bedrock": "AWS_ACCESS_KEY_ID", - "azure": "AZURE_API_KEY", - "azure_ai": "AZURE_AI_API_KEY", - "openrouter": "OPENROUTER_API_KEY", - "huggingface": "HUGGINGFACE_API_KEY", - "databricks": "DATABRICKS_API_KEY", - "cloudflare": "CLOUDFLARE_API_KEY", - "novita": "NOVITA_API_KEY", - "sambanova": "SAMBANOVA_API_KEY", - "watsonx": "WATSONX_API_KEY", -} - -PROVIDER_DISPLAY_NAMES: Dict[str, str] = { - "openai": "OpenAI", - "anthropic": "Anthropic", - "gemini": "Google Gemini", - "vertex_ai": "Google Vertex AI", - "groq": "Groq", - "mistral": "Mistral AI", - "deepseek": "DeepSeek", - "fireworks_ai": "Fireworks AI", - "together_ai": "Together AI", - "perplexity": "Perplexity", - "cohere": "Cohere", - "cohere_chat": "Cohere Chat", - "replicate": "Replicate", - "xai": "xAI", - "deepinfra": "DeepInfra", - "cerebras": "Cerebras", - "ai21": "AI21", - "bedrock": "AWS Bedrock", - "azure": "Azure OpenAI", - "azure_ai": "Azure AI", - "openrouter": "OpenRouter", - "huggingface": "Hugging Face", - "databricks": "Databricks", - "cloudflare": "Cloudflare Workers AI", - "novita": "Novita AI", - "sambanova": "SambaNova", - "watsonx": "IBM watsonx", -} - -# Curated list of major cloud providers shown by default. -_TOP_PROVIDER_IDS = [ - "openai", - "anthropic", - "gemini", - "fireworks_ai", - "mistral", - "xai", - "groq", - "deepseek", - "together_ai", - "openrouter", -] - - -# --------------------------------------------------------------------------- -# Internal helpers -# --------------------------------------------------------------------------- - - -def _get_display_name(provider: str) -> str: - """Return the human-friendly name for *provider*, with a title-case fallback.""" - return PROVIDER_DISPLAY_NAMES.get(provider, provider.replace("_", " ").title()) - - -def _collect_chat_models_for_provider(provider: str) -> Dict[str, dict]: - """Return ``{model_id: cost_entry}`` for all chat models belonging to *provider*. - - Handles the ``vertex_ai`` sub-provider convention by matching any - ``litellm_provider`` that starts with *provider*. - - Falls back to scanning ``litellm.model_cost`` directly when - ``models_by_provider`` entries are missing from ``model_cost``. - """ - import litellm # local import — guarded by is_litellm_available() - - result: Dict[str, dict] = {} - - # Strategy 1: use models_by_provider set, then look up cost data. - model_names: Set[str] = set() - if provider in litellm.models_by_provider: - model_names = set(litellm.models_by_provider[provider]) - - for name in model_names: - entry = litellm.model_cost.get(name) - if entry and entry.get("mode") == "chat": - result[name] = entry - - # Strategy 2 (fallback): scan model_cost for provider match. - # Needed for together_ai and vertex_ai sub-providers. - for model_id, entry in litellm.model_cost.items(): - if model_id in result: - continue - lp = entry.get("litellm_provider", "") - if lp == provider or lp.startswith(f"{provider}-"): - if entry.get("mode") == "chat": - result[model_id] = entry - - return result - - -def _entry_to_model_info(model_id: str, entry: dict) -> ModelInfo: - """Convert a ``litellm.model_cost`` entry into a :class:`ModelInfo`.""" - input_per_token = entry.get("input_cost_per_token") or 0 - output_per_token = entry.get("output_cost_per_token") or 0 - return ModelInfo( - name=model_id.split("/")[-1] if "/" in model_id else model_id, - litellm_id=model_id, - input_cost_per_million=round(input_per_token * 1_000_000, 4), - output_cost_per_million=round(output_per_token * 1_000_000, 4), - max_input_tokens=entry.get("max_input_tokens"), - max_output_tokens=entry.get("max_output_tokens"), - supports_vision=bool(entry.get("supports_vision")), - supports_function_calling=bool(entry.get("supports_function_calling")), - ) - - -def _build_provider_info(provider: str) -> Optional[ProviderInfo]: - """Build a :class:`ProviderInfo` for *provider*, or ``None`` if it has no chat models.""" - chat_models = _collect_chat_models_for_provider(provider) - if not chat_models: - return None - sample = sorted(chat_models.keys())[:3] - return ProviderInfo( - name=provider, - display_name=_get_display_name(provider), - api_key_env_var=PROVIDER_API_KEY_MAP.get(provider), - model_count=len(chat_models), - sample_models=sample, - ) - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -def is_litellm_available() -> bool: - """Return ``True`` if litellm is importable and has model data.""" - try: - import litellm - - return bool(litellm.model_cost) - except Exception: - return False - - -def get_api_key_env_var(provider: str) -> Optional[str]: - """Return the API key environment variable name for *provider*. - - Returns ``None`` if the provider is not in the known mapping. - """ - return PROVIDER_API_KEY_MAP.get(provider) - - -def get_top_providers() -> List[ProviderInfo]: - """Return a curated list of major cloud providers, sorted by the curated order. - - Falls back to all providers sorted by model count if the curated list - yields fewer than 3 results (e.g. litellm data changed). - """ - if not is_litellm_available(): - return [] - - result: List[ProviderInfo] = [] - for pid in _TOP_PROVIDER_IDS: - info = _build_provider_info(pid) - if info: - result.append(info) - - if len(result) < 3: - return get_all_providers()[:10] - return result - - -def get_all_providers() -> List[ProviderInfo]: - """Return all providers that have at least one chat model. - - Sorted by model count descending. - """ - if not is_litellm_available(): - return [] - - import litellm - - seen: Set[str] = set() - infos: List[ProviderInfo] = [] - - # Collect from models_by_provider keys - for provider in litellm.models_by_provider: - if provider in seen: - continue - seen.add(provider) - info = _build_provider_info(provider) - if info: - infos.append(info) - - # Also scan model_cost for providers not in models_by_provider - for entry in litellm.model_cost.values(): - lp = entry.get("litellm_provider", "") - # Normalise vertex_ai sub-providers - base = lp.split("-")[0] if "-" in lp else lp - if base and base not in seen: - seen.add(base) - info = _build_provider_info(base) - if info: - infos.append(info) - - infos.sort(key=lambda i: i.model_count, reverse=True) - return infos - - -def search_providers(query: str) -> List[ProviderInfo]: - """Return providers whose name or display name contains *query* (case-insensitive). - - Sorted by model count descending. - """ - if not query: - return get_all_providers() - - all_providers = get_all_providers() - q = query.lower() - return [ - p - for p in all_providers - if q in p.name.lower() or q in p.display_name.lower() - ] - - -def get_models_for_provider(provider: str) -> List[ModelInfo]: - """Return chat-mode models for *provider*, sorted by name. - - Converts per-token costs to per-million-token costs. - """ - if not is_litellm_available(): - return [] - - chat_models = _collect_chat_models_for_provider(provider) - result = [ - _entry_to_model_info(model_id, entry) - for model_id, entry in chat_models.items() - ] - result.sort(key=lambda m: m.name) - return result diff --git a/pdd/llm_invoke.py b/pdd/llm_invoke.py index ad374f463..5d98c0969 100644 --- a/pdd/llm_invoke.py +++ b/pdd/llm_invoke.py @@ -1166,13 +1166,53 @@ def _save_key_to_env_file(key_name: str, value: str, env_path: Path) -> None: def _ensure_api_key(model_info: Dict[str, Any], newly_acquired_keys: Dict[str, bool], verbose: bool) -> bool: - """Checks for API key in env, prompts user if missing, and updates .env.""" - key_name = model_info.get('api_key') + """Checks for API key(s) in env, prompts user if missing, and updates .env. - if not key_name or key_name == "EXISTING_KEY": + Supports pipe-delimited api_key fields (e.g. ``VAR1|VAR2|VAR3``). + - Empty field → no auth needed (device flow / local model), always True. + - Single var → existing interactive-prompt behaviour for simple providers. + - Multi var → checks all vars; if any missing, directs user to ``pdd setup``. + """ + from pdd.provider_manager import parse_api_key_vars + + api_key_field = str(model_info.get('api_key', '') or '') + + if not api_key_field.strip() or api_key_field == "EXISTING_KEY": if verbose: - logger.info(f"Skipping API key check for model {model_info.get('model')} (key name: {key_name})") - return True # Assume key is handled elsewhere or not needed + logger.info(f"Skipping API key check for model {model_info.get('model')} (key field: {api_key_field!r})") + return True # Device flow, local model, or handled elsewhere + + env_vars = parse_api_key_vars(api_key_field) + + # --- Multi-credential provider (pipe-delimited) --- + if len(env_vars) > 1: + missing = [v for v in env_vars if not os.getenv(v)] + if not missing: + if verbose: + logger.info(f"All {len(env_vars)} env vars set for model {model_info.get('model')}.") + newly_acquired_keys[api_key_field] = False + return True + + # Vertex AI ADC fallback: GOOGLE_APPLICATION_CREDENTIALS may be unset + # if the user ran ``gcloud auth application-default login`` instead. + if "GOOGLE_APPLICATION_CREDENTIALS" in env_vars and "GOOGLE_APPLICATION_CREDENTIALS" in missing: + project = os.getenv("VERTEXAI_PROJECT") or os.getenv("GOOGLE_CLOUD_PROJECT") + if project: + remaining = [v for v in missing if v != "GOOGLE_APPLICATION_CREDENTIALS"] + if not remaining: + logger.info(f"Using ADC for Vertex AI (project={project}).") + newly_acquired_keys[api_key_field] = False + return True + + logger.warning( + f"Multi-credential provider for model '{model_info.get('model')}' " + f"is missing env vars: {', '.join(missing)}. " + f"Run 'pdd setup' to configure." + ) + return False + + # --- Single-credential provider (original behaviour) --- + key_name = env_vars[0] key_value = os.getenv(key_name) if key_value: @@ -1181,58 +1221,50 @@ def _ensure_api_key(model_info: Dict[str, Any], newly_acquired_keys: Dict[str, b if key_value: if verbose: logger.info(f"API key '{key_name}' found in environment.") - newly_acquired_keys[key_name] = False # Mark as existing + newly_acquired_keys[key_name] = False # Mark as existing return True - else: - # For Vertex AI, allow ADC when project is available - if key_name == 'VERTEX_CREDENTIALS': - vertex_project = os.getenv("VERTEX_PROJECT") or os.getenv("GOOGLE_CLOUD_PROJECT") - if vertex_project: - logger.info(f"VERTEX_CREDENTIALS not set; using ADC (project={vertex_project}).") - newly_acquired_keys[key_name] = False - return True - logger.warning(f"API key environment variable '{key_name}' for model '{model_info.get('model')}' is not set.") + logger.warning(f"API key environment variable '{key_name}' for model '{model_info.get('model')}' is not set.") - # Skip prompting if --force flag is set (non-interactive mode) - if os.environ.get('PDD_FORCE'): - logger.error(f"API key '{key_name}' not set. In --force mode, skipping interactive prompt.") + # Skip prompting if --force flag is set (non-interactive mode) + if os.environ.get('PDD_FORCE'): + logger.error(f"API key '{key_name}' not set. In --force mode, skipping interactive prompt.") + return False + + try: + # Interactive prompt + user_provided_key = input(f"Please enter the API key for {key_name}: ").strip() + if not user_provided_key: + logger.error("No API key provided. Cannot proceed with this model.") return False - try: - # Interactive prompt - user_provided_key = input(f"Please enter the API key for {key_name}: ").strip() - if not user_provided_key: - logger.error("No API key provided. Cannot proceed with this model.") - return False - - # Sanitize the user-provided key - user_provided_key = _sanitize_api_key(user_provided_key) - - # Set environment variable for the current process - os.environ[key_name] = user_provided_key - logger.info(f"API key '{key_name}' set for the current session.") - newly_acquired_keys[key_name] = True # Mark as newly acquired + # Sanitize the user-provided key + user_provided_key = _sanitize_api_key(user_provided_key) - # Update .env file - try: - _save_key_to_env_file(key_name, user_provided_key, ENV_PATH) - logger.info(f"API key '{key_name}' saved to {ENV_PATH}.") - logger.warning("SECURITY WARNING: The API key has been saved to your .env file. " - "Ensure this file is kept secure and is included in your .gitignore.") + # Set environment variable for the current process + os.environ[key_name] = user_provided_key + logger.info(f"API key '{key_name}' set for the current session.") + newly_acquired_keys[key_name] = True # Mark as newly acquired - except IOError as e: - logger.error(f"Failed to update .env file at {ENV_PATH}: {e}") - # Continue since the key is set in the environment for this session + # Update .env file + try: + _save_key_to_env_file(key_name, user_provided_key, ENV_PATH) + logger.info(f"API key '{key_name}' saved to {ENV_PATH}.") + logger.warning("SECURITY WARNING: The API key has been saved to your .env file. " + "Ensure this file is kept secure and is included in your .gitignore.") - return True + except IOError as e: + logger.error(f"Failed to update .env file at {ENV_PATH}: {e}") + # Continue since the key is set in the environment for this session - except EOFError: # Handle non-interactive environments - logger.error(f"Cannot prompt for API key '{key_name}' in a non-interactive environment.") - return False - except Exception as e: - logger.error(f"An unexpected error occurred during API key acquisition: {e}") - return False + return True + + except EOFError: # Handle non-interactive environments + logger.error(f"Cannot prompt for API key '{key_name}' in a non-interactive environment.") + return False + except Exception as e: + logger.error(f"An unexpected error occurred during API key acquisition: {e}") + return False def _format_messages(prompt: str, input_data: Union[Dict[str, Any], List[Dict[str, Any]]], use_batch_mode: bool) -> Union[List[Dict[str, str]], List[List[Dict[str, str]]]]: @@ -1910,83 +1942,35 @@ def calc_strength(candidate): "num_retries": 2, } - api_key_name_from_csv = model_info.get('api_key') # From CSV - # Determine if it's a Vertex AI model for special handling - is_vertex_model = (provider.lower() == 'google') or \ - (provider.lower() == 'googlevertexai') or \ - (provider.lower() == 'vertex_ai') or \ - model_name_litellm.startswith('vertex_ai/') - - if is_vertex_model and api_key_name_from_csv == 'VERTEX_CREDENTIALS': - vertex_project_env = os.getenv("VERTEX_PROJECT") - # Resolve location: CSV override → env var fallback - model_location = model_info.get('location') - if pd.notna(model_location) and str(model_location).strip(): - vertex_location_env = str(model_location).strip() - if verbose: - logger.info(f"[INFO] Using per-model location override: '{vertex_location_env}' for model '{model_name_litellm}'") - else: - vertex_location_env = os.getenv("VERTEX_LOCATION") - - if vertex_project_env and vertex_location_env: - litellm_kwargs["vertex_project"] = vertex_project_env - litellm_kwargs["vertex_location"] = vertex_location_env - # Optionally load explicit credentials file - credentials_file_path = os.getenv("VERTEX_CREDENTIALS") - if credentials_file_path: - try: - with open(credentials_file_path, 'r') as f: - loaded_credentials = json.load(f) - litellm_kwargs["vertex_credentials"] = json.dumps(loaded_credentials) - if verbose: - logger.info(f"[INFO] For Vertex AI: using vertex_credentials from '{credentials_file_path}', project '{vertex_project_env}', location '{vertex_location_env}'.") - except (FileNotFoundError, json.JSONDecodeError) as e: - if verbose: - logger.info(f"[INFO] No credentials file ({e}); using ADC.") - except Exception as e: - if verbose: - logger.error(f"[ERROR] Failed to load Vertex credentials from '{credentials_file_path}': {e}. Using ADC.") - elif verbose: - logger.info(f"[INFO] Using ADC for Vertex AI (project={vertex_project_env}, location={vertex_location_env})") - else: - if verbose: - logger.warning(f"[WARN] Missing VERTEX_PROJECT or VERTEX_LOCATION for {model_name_litellm}.") - if not vertex_project_env: logger.warning(f" Reason: VERTEX_PROJECT env var not set or empty.") - if not vertex_location_env: logger.warning(f" Reason: VERTEX_LOCATION env var not set or empty.") - logger.warning(f" LiteLLM may attempt to use Application Default Credentials or the call may fail.") + # --- Resolve API key / credentials --- + # The CSV api_key field may be: + # - Single env var (e.g. "ANTHROPIC_API_KEY") → pass as api_key= + # - Pipe-delimited (e.g. "VAR1|VAR2|VAR3") → litellm reads from env + # - Empty (device flow / local) → no api_key needed + from pdd.provider_manager import parse_api_key_vars - elif api_key_name_from_csv: # For other api_key_names specified in CSV (e.g., OPENAI_API_KEY, or a direct VERTEX_AI_API_KEY string) - key_value = os.getenv(api_key_name_from_csv) + api_key_field = str(model_info.get('api_key', '') or '') + env_vars = parse_api_key_vars(api_key_field) + + if len(env_vars) == 1: + # Simple provider: pass env var value as api_key= + key_value = os.getenv(env_vars[0]) if key_value: key_value = _sanitize_api_key(key_value) litellm_kwargs["api_key"] = key_value if verbose: - logger.info(f"[INFO] Explicitly passing API key from env var '{api_key_name_from_csv}' as 'api_key' parameter to LiteLLM.") - - # If this model is Vertex AI AND uses a direct API key string (not VERTEX_CREDENTIALS from CSV), - # also pass project and location from env vars. - if is_vertex_model: - vertex_project_env = os.getenv("VERTEX_PROJECT") - # Check for per-model location override, fall back to env var - model_location = model_info.get('location') - if pd.notna(model_location) and str(model_location).strip(): - vertex_location_env = str(model_location).strip() - if verbose: - logger.info(f"[INFO] Using per-model location override: '{vertex_location_env}' for model '{model_name_litellm}'") - else: - vertex_location_env = os.getenv("VERTEX_LOCATION") - if vertex_project_env and vertex_location_env: - litellm_kwargs["vertex_project"] = vertex_project_env - litellm_kwargs["vertex_location"] = vertex_location_env - if verbose: - logger.info(f"[INFO] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), also passing vertex_project='{vertex_project_env}' and vertex_location='{vertex_location_env}' from env vars.") - elif verbose: - logger.warning(f"[WARN] For Vertex AI model (using direct API key '{api_key_name_from_csv}'), VERTEX_PROJECT or VERTEX_LOCATION env vars not set. This might be required by LiteLLM.") - elif verbose: # api_key_name_from_csv was in CSV, but corresponding env var was not set/empty - logger.warning(f"[WARN] API key name '{api_key_name_from_csv}' found in CSV, but the environment variable '{api_key_name_from_csv}' is not set or empty. LiteLLM will use default authentication if applicable (e.g., other standard env vars or ADC).") - - elif verbose: # No api_key_name_from_csv in CSV for this model - logger.info(f"[INFO] No API key name specified in CSV for model '{model_name_litellm}'. LiteLLM will use its default authentication mechanisms (e.g., standard provider env vars or ADC for Vertex AI).") + logger.info(f"[INFO] Passing API key from '{env_vars[0]}' to LiteLLM.") + elif verbose: + logger.warning(f"[WARN] Env var '{env_vars[0]}' not set. LiteLLM will use default auth.") + elif len(env_vars) > 1: + # Multi-credential provider (Bedrock, Azure, Vertex AI, etc.) + # litellm reads these env vars from os.environ automatically. + if verbose: + logger.info(f"[INFO] Multi-credential provider; litellm reads env vars: {env_vars}") + else: + # Empty api_key — device flow (GitHub Copilot) or local model + if verbose: + logger.info(f"[INFO] No API key for '{model_name_litellm}'; using device flow or default auth.") # Add base_url/api_base override if present in CSV api_base = model_info.get('base_url') @@ -2412,10 +2396,9 @@ def calc_strength(candidate): logger.info(f"[SUCCESS] Invocation successful for {model_name_litellm} (took {end_time - start_time:.2f}s)") # Build retry kwargs with provider credentials from litellm_kwargs - # Issue #185: Retry calls were missing vertex_location, vertex_project, etc. retry_provider_kwargs = {k: v for k, v in litellm_kwargs.items() - if k in ('vertex_credentials', 'vertex_project', 'vertex_location', - 'api_key', 'base_url', 'api_base')} + if k in ('api_key', 'base_url', 'api_base', + 'api_version')} # --- 7. Process Response --- results = [] diff --git a/pdd/model_tester.py b/pdd/model_tester.py index 6b7a1ace1..993cb92f4 100644 --- a/pdd/model_tester.py +++ b/pdd/model_tester.py @@ -1,6 +1,8 @@ from __future__ import annotations import os +import sys +import threading import time as time_module from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -103,6 +105,36 @@ def _resolve_base_url(row: Dict[str, Any]) -> Optional[str]: return None +def _resolve_provider_auth(row: Dict[str, Any]) -> List[Tuple[str, str, bool]]: + """Resolve all auth-related env vars for a model row. + + Returns a list of (label, status_string, is_ok) tuples. + Driven by the CSV api_key field (pipe-delimited for multi-credential providers). + """ + from pdd.provider_manager import parse_api_key_vars + + api_key_field = str(row.get("api_key", "")).strip() + env_vars = parse_api_key_vars(api_key_field) + + if not env_vars: + # Empty api_key — device flow (e.g. GitHub Copilot) or local model + return [("Auth", "Device flow / no key needed", True)] + + results: List[Tuple[str, str, bool]] = [] + for var in env_vars: + value = os.getenv(var, "") + if value: + # Extra validation for credential file paths + if var == "GOOGLE_APPLICATION_CREDENTIALS" and not Path(value).is_file(): + results.append((var, f"⚠ Path set but file not found ({var})", False)) + else: + results.append((var, f"✓ Found ({var})", True)) + else: + results.append((var, f"✗ Not found ({var})", False)) + + return results + + def _calculate_cost( prompt_tokens: int, completion_tokens: int, @@ -148,56 +180,34 @@ def _run_test(row: Dict[str, Any]) -> Dict[str, Any]: Returns a dict with keys: success, duration_s, cost, error, tokens. """ import litellm + from pdd.provider_manager import parse_api_key_vars model_name: str = str(row.get("model", "")) - api_key, _key_status = _resolve_api_key(row) base_url = _resolve_base_url(row) kwargs: Dict[str, Any] = { "model": model_name, "messages": [{"role": "user", "content": "Say OK"}], - "timeout": 30, + "timeout": 8, } - # Only pass api_key if we have one; otherwise litellm uses its defaults - if api_key: - kwargs["api_key"] = api_key + # Resolve API key using the pipe-delimited convention: + # Single var → pass as api_key= + # Multi var → litellm reads from os.environ (don't pass api_key=) + # Empty → device flow / local (don't pass api_key=) + api_key_field = str(row.get("api_key", "")).strip() + env_vars = parse_api_key_vars(api_key_field) + + if len(env_vars) == 1: + key_value = os.getenv(env_vars[0], "") + if key_value: + kwargs["api_key"] = key_value.strip() + # Multi-var and empty: litellm reads env vars automatically if base_url: kwargs["base_url"] = base_url kwargs["api_base"] = base_url - # Vertex AI handling - is_vertex = model_name.startswith("vertex_ai/") or str(row.get("provider", "")).lower() in ( - "google", - "vertex_ai", - "googlevertexai", - ) - key_name = str(row.get("api_key", "")).strip() - if is_vertex and key_name == "VERTEX_CREDENTIALS": - creds_path = os.getenv("VERTEX_CREDENTIALS", "") - project = os.getenv("VERTEX_PROJECT", "") - location_csv = str(row.get("location", "")).strip() - location = location_csv if location_csv else os.getenv("VERTEX_LOCATION", "") - - if creds_path: - try: - import json as _json - - with open(creds_path, "r") as f: - creds = _json.load(f) - kwargs["vertex_credentials"] = _json.dumps(creds) - except Exception: - pass # Will likely fail at call time with a clear error - - if project: - kwargs["vertex_project"] = project - if location: - kwargs["vertex_location"] = location - - # Remove api_key for vertex — it uses credentials instead - kwargs.pop("api_key", None) - start = time_module.time() try: response = litellm.completion(**kwargs) @@ -312,7 +322,7 @@ def test_model_interactive() -> None: try: choice = console.input( - "[bold cyan]Enter model number to test (or q/empty to quit): [/bold cyan]" + "[bold cyan]Enter model number to test (or empty to quit): [/bold cyan]" ).strip() except (EOFError, KeyboardInterrupt): console.print("\n[dim]Exiting model tester.[/dim]") @@ -341,36 +351,58 @@ def test_model_interactive() -> None: console.print(f"[bold]Testing: [bright_white]{model_name}[/bright_white] ({provider})[/bold]") console.print("─" * 50) - # Diagnostics: API key - api_key, key_status = _resolve_api_key(row) - if "✓" in key_status: - console.print(f" API Key: [green]{key_status}[/green]") - elif "no key configured" in key_status: - console.print(f" API Key: [yellow]{key_status}[/yellow]") - else: - console.print(f" API Key: [red]{key_status}[/red]") + # Diagnostics: provider authentication + auth_checks = _resolve_provider_auth(row) + for label, status_str, is_ok in auth_checks: + color = "green" if is_ok else "red" + console.print(f" {label + ':':<13s}[{color}]{status_str}[/{color}]") # Diagnostics: base URL base_url = _resolve_base_url(row) if base_url: console.print(f" Base URL: [dim]{base_url}[/dim]") - # Diagnostics: Vertex AI specifics - key_name = str(row.get("api_key", "")).strip() - if key_name == "VERTEX_CREDENTIALS": - project = os.getenv("VERTEX_PROJECT", "") - location_csv = str(row.get("location", "")).strip() - location = location_csv if location_csv else os.getenv("VERTEX_LOCATION", "") - if project: - console.print(f" Project: [dim]{project}[/dim]") - if location: - console.print(f" Location: [dim]{location}[/dim]") - console.print() - console.print(" [dim]Sending test prompt...[/dim]") + sys.stdout.write(" Sending test prompt...") + sys.stdout.flush() + + # Run the test in a thread, printing dots while waiting + test_result_holder: List[Optional[Dict[str, Any]]] = [None] + + def _do_test() -> None: + test_result_holder[0] = _run_test(row) + + t = threading.Thread(target=_do_test, daemon=True) + t.start() + + elapsed = 0.0 + while t.is_alive() and elapsed < 8.0: + t.join(timeout=1.0) + if t.is_alive(): + sys.stdout.write(".") + sys.stdout.flush() + elapsed += 1.0 + + if t.is_alive(): + # Timeout — thread is still running; don't wait further + sys.stdout.write("\n") + result = { + "success": False, + "duration_s": elapsed, + "cost": 0.0, + "error": "Request timed out (8s)", + "tokens": None, + } + else: + sys.stdout.write("\n") + result = test_result_holder[0] or { + "success": False, + "duration_s": 0.0, + "cost": 0.0, + "error": "Unknown error", + "tokens": None, + } - # Run the test - result = _run_test(row) results[idx] = result if result["success"]: diff --git a/pdd/pddrc_initializer.py b/pdd/pddrc_initializer.py index 3443aa1f9..9021ba205 100644 --- a/pdd/pddrc_initializer.py +++ b/pdd/pddrc_initializer.py @@ -35,7 +35,7 @@ # Standard defaults STANDARD_DEFAULTS: dict[str, float | int] = { - "strength": 1.0, + "strength": 0.818, "temperature": 0.0, "target_coverage": 80.0, "budget": 10.0, diff --git a/pdd/prompts/litellm_registry_python.prompt b/pdd/prompts/litellm_registry_python.prompt deleted file mode 100644 index b0db6ea7e..000000000 --- a/pdd/prompts/litellm_registry_python.prompt +++ /dev/null @@ -1,57 +0,0 @@ -Wraps litellm's bundled model registry to provide provider search, model browsing, and API key env var lookup without network calls. - - -{ - "type": "module", - "module": { - "functions": [ - {"name": "is_litellm_available", "signature": "() -> bool", "returns": "bool"}, - {"name": "get_api_key_env_var", "signature": "(provider: str) -> Optional[str]", "returns": "Optional[str]"}, - {"name": "get_top_providers", "signature": "() -> List[ProviderInfo]", "returns": "List[ProviderInfo]"}, - {"name": "get_all_providers", "signature": "() -> List[ProviderInfo]", "returns": "List[ProviderInfo]"}, - {"name": "search_providers", "signature": "(query: str) -> List[ProviderInfo]", "returns": "List[ProviderInfo]"}, - {"name": "get_models_for_provider", "signature": "(provider: str) -> List[ModelInfo]", "returns": "List[ModelInfo]"} - ], - "dataclasses": [ - {"name": "ProviderInfo", "fields": ["name", "display_name", "api_key_env_var", "model_count", "sample_models"]}, - {"name": "ModelInfo", "fields": ["name", "litellm_id", "input_cost_per_million", "output_cost_per_million", "max_input_tokens", "max_output_tokens", "supports_vision", "supports_function_calling"]} - ] - } -} - - -% You are an expert Python engineer. Your goal is to write the pdd/litellm_registry.py module. - -% Role & Scope -Thin wrapper around litellm's bundled data (`litellm.model_cost`, `litellm.models_by_provider`) for provider discovery and model browsing. Uses only locally bundled data — never makes network calls. Provides the data layer for the "Search providers" flow in `pdd setup`. - -% Requirements -1. Dataclass `ProviderInfo(name, display_name, api_key_env_var, model_count, sample_models)` — summary of a provider with up to 3 sample model names. -2. Dataclass `ModelInfo(name, litellm_id, input_cost_per_million, output_cost_per_million, max_input_tokens, max_output_tokens, supports_vision, supports_function_calling)` — metadata for a single model. -3. `is_litellm_available() -> bool` — guards against import failure, returns True only if litellm is importable and has model_cost data. -4. `get_api_key_env_var(provider) -> Optional[str]` — returns the API key env var name from a hardcoded mapping of ~25 common providers (e.g. "anthropic" → "ANTHROPIC_API_KEY"). Returns None for unknown providers. -5. `get_top_providers() -> List[ProviderInfo]` — returns a curated list of ~10 major cloud providers in a fixed display order. Falls back to all providers sorted by model count if curated list yields too few. -6. `get_all_providers() -> List[ProviderInfo]` — returns all providers with at least one chat model, sorted by model count descending. -7. `search_providers(query) -> List[ProviderInfo]` — case-insensitive substring match on provider name and display name. Empty query returns all providers. -8. `get_models_for_provider(provider) -> List[ModelInfo]` — returns chat-mode models sorted by name. Converts per-token costs to per-million costs. -9. Filter to `mode == "chat"` models only when browsing. -10. Handle vertex_ai sub-providers: aggregate all `litellm_provider` values starting with "vertex_ai" when querying vertex_ai. -11. Fallback: when `models_by_provider` entries aren't in `model_cost`, scan `model_cost` by `litellm_provider` field to catch mismatches (e.g. together_ai). -12. All litellm imports must be local (inside functions), guarded by `is_litellm_available()`. -13. Hardcoded `PROVIDER_DISPLAY_NAMES` dict for human-friendly names (e.g. "fireworks_ai" → "Fireworks AI"). Fallback: replace underscores and title-case. - -% Dependencies - -litellm.model_cost: Dict[str, dict] — ~2566 entries keyed by model ID. - Each entry has: litellm_provider, mode, input_cost_per_token, output_cost_per_token, - max_input_tokens, max_output_tokens, supports_vision, supports_function_calling, - supports_response_schema, and more. - -litellm.models_by_provider: Dict[str, Set[str]] — 86 providers mapped to sets of model names. - -Note: model_cost has NO api_key_env_var field. Provider-to-key mapping must be hardcoded. - - -% Deliverables -- Module at `pdd/litellm_registry.py` exporting `ProviderInfo`, `ModelInfo`, `is_litellm_available`, `get_api_key_env_var`, `get_top_providers`, `get_all_providers`, `search_providers`, `get_models_for_provider`. -- Also exports constants `PROVIDER_API_KEY_MAP` and `PROVIDER_DISPLAY_NAMES` for use by other modules. diff --git a/pdd/provider_manager.py b/pdd/provider_manager.py index 5fcb1be26..bcba39588 100644 --- a/pdd/provider_manager.py +++ b/pdd/provider_manager.py @@ -9,21 +9,12 @@ import tempfile from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from rich.console import Console from rich.table import Table from rich.prompt import Prompt, Confirm -from pdd.litellm_registry import ( - is_litellm_available, - get_top_providers, - search_providers, - get_models_for_provider, - get_api_key_env_var, - ProviderInfo, - ModelInfo, -) console = Console() @@ -34,6 +25,134 @@ "reasoning_type", "location", ] +# --------------------------------------------------------------------------- +# Pipe-delimited api_key helpers +# --------------------------------------------------------------------------- +# The CSV api_key column can contain multiple env var names separated by "|". +# Single var → pass as api_key= to litellm. Multi-var → litellm reads from +# os.environ automatically (Bedrock, Azure, Vertex AI). Empty → device flow +# or local model (GitHub Copilot, Ollama). + + +def parse_api_key_vars(api_key_field: str) -> List[str]: + """Split the pipe-delimited api_key CSV field into individual env var names. + + Returns an empty list if the field is empty/blank. + """ + if not api_key_field or not api_key_field.strip(): + return [] + return [v.strip() for v in api_key_field.split("|") if v.strip()] + + +def is_multi_credential(api_key_field: str) -> bool: + """Return True if the api_key field contains multiple env vars (pipe-delimited).""" + return "|" in (api_key_field or "") + + +# --------------------------------------------------------------------------- +# Complex provider authentication registry +# --------------------------------------------------------------------------- +# Providers that require multi-variable auth (not just a single API key). +# Maps provider display name (as in CSV) -> list of env var configs. +# Used by _setup_complex_provider() for interactive credential prompting. + +COMPLEX_AUTH_PROVIDERS: Dict[str, List[Dict[str, Any]]] = { + "Google Vertex AI": [ + { + "env_var": "GOOGLE_APPLICATION_CREDENTIALS", + "label": "Credentials", + "required": True, + "default": None, + "hint": "Path to GCP service account JSON (or 'adc' for Application Default Credentials)", + }, + { + "env_var": "VERTEXAI_PROJECT", + "label": "GCP Project", + "required": True, + "default": None, + "hint": "Google Cloud project ID", + }, + { + "env_var": "VERTEXAI_LOCATION", + "label": "Location", + "required": True, + "default": "us-central1", + "hint": "GCP region (e.g. us-central1)", + }, + ], + "AWS Bedrock": [ + { + "env_var": "AWS_ACCESS_KEY_ID", + "label": "Access Key ID", + "required": True, + "default": None, + "hint": "AWS IAM access key ID", + }, + { + "env_var": "AWS_SECRET_ACCESS_KEY", + "label": "Secret Key", + "required": True, + "default": None, + "hint": "AWS IAM secret access key", + }, + { + "env_var": "AWS_REGION_NAME", + "label": "Region", + "required": True, + "default": "us-east-1", + "hint": "AWS region (e.g. us-east-1)", + }, + ], + "Azure OpenAI": [ + { + "env_var": "AZURE_API_KEY", + "label": "API Key", + "required": True, + "default": None, + "hint": "Azure OpenAI resource key", + }, + { + "env_var": "AZURE_API_BASE", + "label": "Endpoint", + "required": True, + "default": None, + "hint": "Azure OpenAI endpoint URL (e.g. https://myresource.openai.azure.com/)", + }, + { + "env_var": "AZURE_API_VERSION", + "label": "API Version", + "required": True, + "default": "2024-10-21", + "hint": "Azure API version string", + }, + ], + "Azure AI": [ + { + "env_var": "AZURE_AI_API_KEY", + "label": "API Key", + "required": True, + "default": None, + "hint": "Azure AI Foundry API key", + }, + { + "env_var": "AZURE_AI_API_BASE", + "label": "Endpoint", + "required": False, + "default": None, + "hint": "Optional: Azure AI endpoint URL", + }, + ], + "Github Copilot": [ + { + "env_var": "GITHUB_COPILOT_API_KEY", + "label": "API Key", + "required": False, + "default": None, + "hint": "Optional: GitHub Copilot uses device flow auth at runtime", + }, + ], +} + # --------------------------------------------------------------------------- # Path helpers # --------------------------------------------------------------------------- @@ -327,155 +446,144 @@ def _is_key_set(key_name: str) -> Optional[str]: # Public API # --------------------------------------------------------------------------- -def add_provider_from_registry() -> bool: - """ - Search/browse LiteLLM's model registry, let the user pick a provider - and specific models, handle the API key, and save to user CSV. +def _get_ref_csv_path() -> Path: + """Return path to the bundled reference CSV.""" + return Path(__file__).parent / "data" / "llm_model.csv" - Returns True if any models were added, False if cancelled. - """ - if not is_litellm_available(): - console.print( - "[red]litellm is required but not installed or has no model data.[/red]\n" - "[yellow]Run: pip install litellm[/yellow]\n" - "[yellow]Or use 'Add a custom provider' instead.[/yellow]" - ) - return False - # ── Step 1: Browse / Search providers ────────────────────────────── +def _setup_complex_provider(provider_name: str) -> bool: + """Run interactive auth setup for a complex (multi-variable) provider. - top = get_top_providers() + Prompts for each required env var and saves to api-env. + Returns True if at least one credential was configured, False if all skipped. + """ + var_configs = COMPLEX_AUTH_PROVIDERS.get(provider_name) + if not var_configs: + return False - console.print("\n[bold]Search providers[/bold]\n") - console.print(" Top providers:") - for idx, p in enumerate(top, 1): + required_names = [c["label"] for c in var_configs if c["required"]] + optional_names = [c["label"] for c in var_configs if not c["required"]] + print() + console.print(f" [bold]{provider_name} Setup[/bold]") + if required_names: + console.print(f" Required: {', '.join(required_names)}") + if optional_names: + console.print(f" Optional: {', '.join(optional_names)}") + + # GitHub Copilot: explain device flow before prompting + if provider_name == "Github Copilot": console.print( - f" {idx:>2}. {p.display_name:20s} ({p.model_count} chat models)" + "\n [dim]GitHub Copilot authenticates via device flow at runtime.\n" + " You can paste an API key now, or skip and authenticate later.[/dim]" ) - console.print() + print() - selection = Prompt.ask( - "Enter number, or type to search (empty to cancel)" - ) - if not selection.strip(): - console.print("[dim]Cancelled.[/dim]") - return False + any_saved = False + for cfg in var_configs: + env_var = cfg["env_var"] + label = cfg["label"] + required = cfg["required"] + default = cfg["default"] + hint = cfg["hint"] - selected_provider: Optional[ProviderInfo] = None + existing_source = _is_key_set(env_var) + if existing_source: + console.print(f" [green]✓[/green] {label} already set ({existing_source})") + if not Confirm.ask(" Update?", default=False): + continue - # Try as a number first (direct selection from top list) - try: - choice = int(selection.strip()) - if 1 <= choice <= len(top): - selected_provider = top[choice - 1] - except ValueError: - pass + opt_tag = " [dim](optional)[/dim]" if not required else "" + if default: + value = Prompt.ask(f" {label}{opt_tag} [dim]{hint}[/dim]", default=default) + else: + value = Prompt.ask(f" {label}{opt_tag} [dim]{hint}[/dim]", default="") - # If not a valid number, treat as search query - if selected_provider is None: - results = search_providers(selection.strip()) - if not results: - console.print( - f"[yellow]No providers matching '{selection.strip()}'.[/yellow]\n" - "[yellow]Try a different search, or use 'Add a custom provider'.[/yellow]" - ) - return False + value = value.strip() + if not value: + if not required: + continue + console.print(f" [yellow]Skipped[/yellow]") + continue - if len(results) == 1: - selected_provider = results[0] - else: - console.print(f"\n Found {len(results)} provider(s):") - for idx, p in enumerate(results, 1): + # Vertex AI: special handling for credentials path + if env_var == "GOOGLE_APPLICATION_CREDENTIALS": + if value.lower() == "adc": console.print( - f" {idx:>2}. {p.display_name:20s} ({p.model_count} chat models)" + " [dim]Using Application Default Credentials.\n" + " Make sure you've run: gcloud auth application-default login[/dim]" ) - console.print() - - pick = Prompt.ask("Select provider number (empty to cancel)") - if not pick.strip(): - console.print("[dim]Cancelled.[/dim]") - return False - try: - pick_idx = int(pick.strip()) - if 1 <= pick_idx <= len(results): - selected_provider = results[pick_idx - 1] - else: - console.print("[red]Invalid selection.[/red]") - return False - except ValueError: - console.print("[red]Invalid input.[/red]") - return False - - assert selected_provider is not None - - # ── Step 2: Model selection ──────────────────────────────────────── - - models = get_models_for_provider(selected_provider.name) - if not models: - console.print( - f"[yellow]No chat models found for {selected_provider.display_name} " - f"in litellm's registry.[/yellow]\n" - "[yellow]Use 'Add a custom provider' instead.[/yellow]" - ) - return False + continue + if not Path(value).exists(): + console.print(f" [yellow]Warning: file not found at {value}[/yellow]") - table = Table(title=f"Chat models for {selected_provider.display_name}") - table.add_column("#", style="bold", width=4) - table.add_column("Model") - table.add_column("Input $/M", justify="right") - table.add_column("Output $/M", justify="right") - table.add_column("Max Input", justify="right") + _save_key_to_api_env(env_var, value) + console.print(f" [green]✓ Saved[/green]") + any_saved = True - for idx, m in enumerate(models, 1): - input_cost = f"${m.input_cost_per_million:.2f}" if m.input_cost_per_million else "$0.00" - output_cost = f"${m.output_cost_per_million:.2f}" if m.output_cost_per_million else "$0.00" - max_input = f"{m.max_input_tokens:,}" if m.max_input_tokens else "—" - table.add_row(str(idx), m.litellm_id, input_cost, output_cost, max_input) + if any_saved: + _ensure_api_env_sourced_in_rc() + console.print("\n [dim]Credentials available for this session.[/dim]") - console.print(table) - console.print() + return any_saved - model_selection = Prompt.ask( - "Select models (comma-separated numbers, 'all', or empty to cancel)" - ) - if not model_selection.strip(): - console.print("[dim]Cancelled.[/dim]") + +def add_provider_from_registry() -> bool: + """ + Browse providers from the reference CSV, let the user pick one, + handle the API key, and add its models to the user CSV. + + Returns True if any models were added, False if cancelled. + """ + # ── Step 1: List providers from reference CSV ───────────────────── + + ref_rows = _read_csv(_get_ref_csv_path()) + if not ref_rows: + console.print("[yellow]No models found in reference CSV.[/yellow]") return False - selected_models: List[ModelInfo] = [] + # Build unique provider list with model counts and api_key + provider_info: Dict[str, Dict[str, object]] = {} + for row in ref_rows: + provider = row.get("provider", "").strip() + api_key = row.get("api_key", "").strip() + if not provider: + continue + if provider not in provider_info: + provider_info[provider] = {"api_key": api_key, "count": 0} + provider_info[provider]["count"] = int(provider_info[provider]["count"]) + 1 + + sorted_providers = sorted(provider_info.keys()) + + console.print("\n[bold]Add a provider[/bold]\n") + for idx, prov in enumerate(sorted_providers, 1): + info = provider_info[prov] + count = info["count"] + s = "s" if count != 1 else "" + console.print(f" {idx:>2}. {prov:25s} ({count} model{s})") + console.print() - if model_selection.strip().lower() == "all": - selected_models = list(models) - else: - for part in model_selection.split(","): - part = part.strip() - if not part: - continue - try: - num = int(part) - if 1 <= num <= len(models): - selected_models.append(models[num - 1]) - else: - console.print(f"[yellow]Skipping invalid number: {num}[/yellow]") - except ValueError: - console.print(f"[yellow]Skipping invalid input: '{part}'[/yellow]") - - if not selected_models: - console.print("[dim]No valid selections. Cancelled.[/dim]") + selection = Prompt.ask("Enter number (empty to cancel)") + if not selection.strip(): + console.print("[dim]Cancelled.[/dim]") return False - # ── Step 3: API key ──────────────────────────────────────────────── + try: + choice = int(selection.strip()) + if choice < 1 or choice > len(sorted_providers): + console.print("[red]Invalid selection.[/red]") + return False + except ValueError: + console.print("[red]Invalid input.[/red]") + return False - api_key_var = selected_provider.api_key_env_var + selected_provider = sorted_providers[choice - 1] + api_key_var = str(provider_info[selected_provider]["api_key"]) or None - if api_key_var is None: - # Provider not in our known mapping — ask the user - api_key_var = Prompt.ask( - f"API key env var for {selected_provider.display_name} " - "(e.g. PROVIDER_API_KEY, or empty to skip)" - ).strip() or None + # ── Step 2: Provider authentication ────────────────────────────── - if api_key_var: + if selected_provider in COMPLEX_AUTH_PROVIDERS: + _setup_complex_provider(selected_provider) + elif api_key_var: existing_source = _is_key_set(api_key_var) if existing_source: console.print( @@ -497,7 +605,10 @@ def add_provider_from_registry() -> bool: "[dim]Key is available now for this session.[/dim]" ) else: - key_value = Prompt.ask(f"Enter the value for {api_key_var}") + key_value = Prompt.ask( + f"Enter your {selected_provider} API key (or press Enter to skip)", + default="", + ) if key_value.strip(): _save_key_to_api_env(api_key_var, key_value.strip()) console.print( @@ -511,52 +622,38 @@ def add_provider_from_registry() -> bool: console.print( "[dim]Key is available now for this session.[/dim]" ) + else: + console.print( + f"[yellow]Note: No API key configured for {selected_provider}. " + f"The LLM may have limited capability.[/yellow]" + ) - # ── Step 4: Write to user CSV ────────────────────────────────────── + # ── Step 3: Add all models for this provider to user CSV ────────── + + provider_rows = [ + row for row in ref_rows + if row.get("provider", "").strip() == selected_provider + ] user_csv_path = _get_user_csv_path() existing_rows = _read_csv(user_csv_path) - - # Build set of existing model identifiers to avoid duplicates - existing_model_ids = { - (r.get("provider", ""), r.get("model", "")) - for r in existing_rows - } + existing_models = {r.get("model", "").strip() for r in existing_rows} added_count = 0 - for m in selected_models: - # Build the litellm model ID with provider prefix convention - csv_model = m.litellm_id - - new_row: Dict[str, str] = { - "provider": selected_provider.display_name, - "model": csv_model, - "input": str(m.input_cost_per_million), - "output": str(m.output_cost_per_million), - "coding_arena_elo": "1000", - "base_url": "", - "api_key": api_key_var or "", - "max_reasoning_tokens": "0", - "structured_output": str(m.supports_function_calling), - "reasoning_type": "", - "location": "", - } - - model_id = (new_row["provider"], new_row["model"]) - if model_id not in existing_model_ids: - existing_rows.append(new_row) - existing_model_ids.add(model_id) + for row in provider_rows: + model = row.get("model", "").strip() + if model and model not in existing_models: + existing_rows.append(row) + existing_models.add(model) added_count += 1 - else: - console.print(f" [dim]Skipping duplicate: {csv_model}[/dim]") if added_count > 0: _write_csv_atomic(user_csv_path, existing_rows) console.print( - f"[green]Added {added_count} model(s) to {user_csv_path}[/green]" + f"[green]Added {added_count} model(s) for {selected_provider} to {user_csv_path}[/green]" ) else: - console.print("[yellow]No new models were added (all already configured).[/yellow]") + console.print("[yellow]All models for this provider are already configured.[/yellow]") return added_count > 0 diff --git a/pdd/setup_tool.py b/pdd/setup_tool.py index 5613d97cb..25d7d5161 100644 --- a/pdd/setup_tool.py +++ b/pdd/setup_tool.py @@ -8,15 +8,20 @@ from __future__ import annotations import getpass -import json import os -import urllib.error -import urllib.request +import sys from pathlib import Path from typing import Dict, List, Optional, Tuple from rich.console import Console as _RichConsole -_console = _RichConsole() +_console = _RichConsole(highlight=False) + +# ANSI escape codes for coloring (works without rich) +CYAN = "\033[36m" +WHITE = "\033[37m" +BOLD = "\033[1m" +RESET = "\033[0m" +LIGHT_HORIZONTAL = "\u2500" # Top providers shown when prompting for an API key (order = display order) _PROMPT_PROVIDERS = [ @@ -27,43 +32,75 @@ ] +def _print_pdd_logo() -> None: + """Print the PDD logo in ASCII art with ANSI colors.""" + logo = "\n".join( + [ + " +xxxxxxxxxxxxxxx+", + "xxxxxxxxxxxxxxxxxxxxx+", + "xxx +xx+ PROMPT", + "xxx x+ xx+ DRIVEN", + "xxx x+ xxx DEVELOPMENT\u00a9", + "xxx x+ xx+", + "xxx x+ xx+ COMMAND LINE INTERFACE", + "xxx x+ xxx", + "xxx +xx+ ", + "xxx +xxxxxxxxxxx+", + "xxx +xx+", + "xxx +xx+", + "xxx+xx+ WWW.PROMPTDRIVEN.AI", + "xxxx+", + "xx+", + ] + ) + print(f"{CYAN}{logo}{RESET}") + print() + print(f"{BOLD}{WHITE}Let's get set up quickly with a solid basic configuration!{RESET}") + print() + + def run_setup() -> None: - """Main entry point for pdd setup. Two-phase flow with fallback.""" + """Main entry point for pdd setup. Two-phase flow with post-setup menu.""" from pdd.cli_detector import detect_and_bootstrap_cli, CliBootstrapResult # ── Banner ──────────────────────────────────────────────────────────── - print() - print(" ╭──────────────────────────────╮") - print(" │ pdd setup │") - print(" ╰──────────────────────────────╯") - print() + _print_pdd_logo() try: # ── Phase 1 — CLI Bootstrap (interactive, 0–2 user inputs) ──────── - result: CliBootstrapResult = detect_and_bootstrap_cli() - - if result.cli_name == "": - print( - "Agentic features require at least one CLI tool. " - "Run `pdd setup` again when ready." - ) - return + results: list[CliBootstrapResult] = detect_and_bootstrap_cli() - if not result.api_key_configured: - print( - "Note: No API key configured. " - "The agent may have limited capability." - ) + for result in results: + if result.skipped: + pass + elif not result.api_key_configured: + _console.print( + f"[yellow]Note: No API key configured for {result.cli_name or 'the CLI'}. " + "The agent may have limited capability.[/yellow]" + ) # ── Phase 2 — Deterministic Auto-Configuration ──────────────────── - auto_success = _run_auto_phase() + auto_result = _run_auto_phase(results) - if not auto_success: - _run_fallback_menu() + if auto_result: + found_keys, _model_summary = auto_result + # Offer post-setup menu before final summary + try: + choice = input( + "\n Press Enter to finish, or 'm' for more options: " + ).strip() + except (EOFError, KeyboardInterrupt): + choice = "" + + if choice: + _run_options_menu() + else: + found_keys: list[tuple[str, str]] = [] + _console.print("\n [yellow]Setup incomplete. Use the menu to configure manually.[/yellow]") + _run_options_menu() - print() - _console.print("[green]Setup complete. Happy prompting![/green]") - print() + # ── Final summary (after menu, so it reflects any changes) ──────── + _print_exit_summary(found_keys, results) except KeyboardInterrupt: print("\nSetup interrupted — exiting.") @@ -74,34 +111,42 @@ def run_setup() -> None: # Phase 2 — Deterministic auto-configuration # --------------------------------------------------------------------------- -def _run_auto_phase() -> bool: - """Run 4 deterministic setup steps. Returns True on success.""" +def _print_step_banner(title: str) -> None: + """Print a cyan banner for a setup step.""" + print(f"\n{CYAN}{LIGHT_HORIZONTAL * 40}{RESET}") + print(f"{CYAN}{BOLD}{title}{RESET}") + print(f"{CYAN}{LIGHT_HORIZONTAL * 40}{RESET}") + + +def _run_auto_phase(cli_results=None) -> Optional[Tuple[List[Tuple[str, str]], Dict[str, int]]]: + """Run 3 deterministic setup steps. + + Returns (found_keys, model_summary) on success, or None on failure. + """ try: # Step 1: Scan API keys - print("\n[Step 1/4] Scanning for API keys...") + _print_step_banner("Scanning for API keys...") found_keys = _step1_scan_keys() - input("\nPress Enter to continue to the next step...") - - # Step 2: Configure models - print("\n[Step 2/4] Configuring models...") - model_summary = _step2_configure_models(found_keys) - input("\nPress Enter to continue to the next step...") + print() + _console.print("[blue]Press Enter to continue to the next step...[/blue]", end="") + input() - # Step 3: Local LLMs + .pddrc - print("\n[Step 3/4] Checking local LLMs and .pddrc...") - local_summary = _step3_local_llms_and_pddrc() - input("\nPress Enter to continue to the next step...") + # Step 2: Configure models + .pddrc + _print_step_banner("Configuring models...") + model_summary = _step2_configure_models_and_pddrc(found_keys) + print() + _console.print("[blue]Press Enter to continue to the next step...[/blue]", end="") + input() - # Step 4: Test + summary - print("\n[Step 4/4] Testing and summarizing...") - _step4_test_and_summary(found_keys, model_summary, local_summary) + # Step 3: Test + summary + _print_step_banner("Testing and summarizing...") + _step3_test_and_summary(found_keys, model_summary, cli_results) - return True + return (found_keys, model_summary) except Exception as exc: - print(f"\nAuto-configuration failed: {exc}") - print("Falling back to manual setup...") - return False + _console.print(f"\n[yellow]Auto-configuration failed: {exc}[/yellow]") + return None # --------------------------------------------------------------------------- @@ -109,21 +154,44 @@ def _run_auto_phase() -> bool: # --------------------------------------------------------------------------- def _step1_scan_keys() -> List[Tuple[str, str]]: - """Scan all known API key env vars across all sources. + """Scan API key env vars referenced in the reference CSV across all sources. Returns list of (key_name, source_label) for keys that were found. + Multi-credential providers (pipe-delimited api_key) are displayed as + grouped provider lines; single-var providers as individual lines. """ - from pdd.litellm_registry import PROVIDER_API_KEY_MAP + from pdd.provider_manager import _read_csv, parse_api_key_vars from pdd.api_key_scanner import _parse_api_env_file, _detect_shell # Ensure ~/.pdd exists pdd_dir = Path.home() / ".pdd" pdd_dir.mkdir(parents=True, exist_ok=True) - # Gather all unique env var names to check - all_key_names = sorted(set(PROVIDER_API_KEY_MAP.values())) + # Gather unique api_key field values from the reference CSV + ref_path = Path(__file__).parent / "data" / "llm_model.csv" + ref_rows = _read_csv(ref_path) - # Load sources once + # Build two sets: single-var keys and multi-var provider groups + single_var_keys: set = set() # e.g. {"ANTHROPIC_API_KEY", "OPENAI_API_KEY"} + multi_var_providers: Dict[str, List[str]] = {} # provider_name -> [var1, var2, ...] + all_individual_vars: set = set() # every individual var across all providers + + for row in ref_rows: + api_key_field = row.get("api_key", "").strip() + if not api_key_field: + continue + env_vars = parse_api_key_vars(api_key_field) + if len(env_vars) == 1: + single_var_keys.add(env_vars[0]) + all_individual_vars.add(env_vars[0]) + elif len(env_vars) > 1: + provider = row.get("provider", "").strip() or api_key_field + if provider not in multi_var_providers: + multi_var_providers[provider] = env_vars + for v in env_vars: + all_individual_vars.add(v) + + # Load all credential sources once dotenv_vals: Dict[str, str] = {} try: from dotenv import dotenv_values @@ -144,29 +212,62 @@ def _step1_scan_keys() -> List[Tuple[str, str]]: api_env_vals = _parse_api_env_file(api_env_path) api_env_label = f"~/.pdd/api-env.{shell_name}" - # Scan each key + def _find_source(var: str) -> Optional[str]: + if var in os.environ: + return "shell environment" + if var in api_env_vals: + return api_env_label + if var in dotenv_vals: + return ".env file" + return None + found_keys: List[Tuple[str, str]] = [] - max_name_len = max(len(k) for k in all_key_names) if all_key_names else 20 - for key_name in all_key_names: - if key_name in os.environ: - source = "shell environment" - found_keys.append((key_name, source)) - print(f" ✓ {key_name:<{max_name_len}s} {source}") - elif key_name in api_env_vals: - source = api_env_label - found_keys.append((key_name, source)) - print(f" ✓ {key_name:<{max_name_len}s} {source}") - elif key_name in dotenv_vals: - source = ".env file" + # --- Multi-var providers: grouped display --- + for provider_name, env_vars in sorted(multi_var_providers.items()): + found_vars = [] + missing_vars = [] + for var in env_vars: + source = _find_source(var) + if source: + found_vars.append(var) + found_keys.append((var, source)) + else: + missing_vars.append(var) + + if not found_vars and not missing_vars: + continue + + total = len(env_vars) + found_count = len(found_vars) + if found_count == total: + _console.print(f" [green]✓[/green] {provider_name}: {found_count}/{total} vars set") + elif found_count > 0: + missing_str = ", ".join(missing_vars) + _console.print( + f" [yellow]![/yellow] {provider_name}: {found_count}/{total} vars set" + f" (missing: {missing_str})" + ) + # If found_count == 0, skip — nothing to show for this provider + + # --- Single-var providers: individual display --- + sorted_single = sorted(single_var_keys) + max_name_len = max((len(k) for k in sorted_single), default=20) if sorted_single else 20 + for key_name in sorted_single: + source = _find_source(key_name) + if source: found_keys.append((key_name, source)) - print(f" ✓ {key_name:<{max_name_len}s} {source}") + _console.print(f" [green]✓[/green] {key_name:<{max_name_len}s} {source}") if not found_keys: - print(" ✗ No API keys found.\n") + _console.print(" [yellow]✗ No API keys found.[/yellow]\n") found_keys = _prompt_for_api_key() print(f"\n {len(found_keys)} API key(s) found.") + + api_env_path = pdd_dir / f"api-env.{shell_name}" if shell_name else pdd_dir / "api-env.bash" + _console.print(f" [dim]You can edit your global API keys in {api_env_path}[/dim]") + return found_keys @@ -177,20 +278,31 @@ def _prompt_for_api_key() -> List[Tuple[str, str]]: ~/.pdd/api-env.{shell} and loads it into the current session. Returns list of (key_name, source_label) for newly added keys. """ - from pdd.litellm_registry import PROVIDER_API_KEY_MAP, PROVIDER_DISPLAY_NAMES - from pdd.provider_manager import _save_key_to_api_env, _get_api_env_path + from pdd.provider_manager import _read_csv, _save_key_to_api_env added_keys: List[Tuple[str, str]] = [] api_env_label = f"~/.pdd/api-env.{os.path.basename(os.environ.get('SHELL', 'bash'))}" + # Build provider list from reference CSV + ref_path = Path(__file__).parent / "data" / "llm_model.csv" + ref_rows = _read_csv(ref_path) + # Collect unique (provider_display, api_key_env_var) pairs + seen = set() + all_providers: List[Tuple[str, str]] = [] + for row in ref_rows: + provider = row.get("provider", "").strip() + api_key = row.get("api_key", "").strip() + if provider and api_key and (provider, api_key) not in seen: + seen.add((provider, api_key)) + all_providers.append((provider, api_key)) + all_providers.sort(key=lambda x: x[0]) + while True: print(" To continue setup, add at least one API key.") - print(" Popular providers:") - for i, (_, display, env_var) in enumerate(_PROMPT_PROVIDERS, 1): - print(f" {i}) {display:<20s} ({env_var})") - other_idx = len(_PROMPT_PROVIDERS) + 1 - skip_idx = other_idx + 1 - print(f" {other_idx}) Other provider") + print(" Providers:") + for i, (display, env_var) in enumerate(all_providers, 1): + print(f" {i}) {display:<25s} ({env_var})") + skip_idx = len(all_providers) + 1 print(f" {skip_idx}) Skip (continue without keys)") try: @@ -203,41 +315,16 @@ def _prompt_for_api_key() -> List[Tuple[str, str]]: try: choice_num = int(choice) except ValueError: - print(f" Invalid input. Enter a number 1-{skip_idx}.\n") + _console.print(f" [yellow]Invalid input. Enter a number 1-{skip_idx}.[/yellow]\n") continue if choice_num == skip_idx: break - if choice_num == other_idx: - # Show all providers - all_providers = sorted( - PROVIDER_API_KEY_MAP.items(), - key=lambda x: PROVIDER_DISPLAY_NAMES.get(x[0], x[0]), - ) - print("\n All providers:") - for i, (pid, env_var) in enumerate(all_providers, 1): - display = PROVIDER_DISPLAY_NAMES.get(pid, pid) - print(f" {i}) {display:<25s} ({env_var})") - try: - sub_choice = input(f"\n Select provider [1-{len(all_providers)}]: ").strip() - sub_num = int(sub_choice) - if 1 <= sub_num <= len(all_providers): - _, env_var = all_providers[sub_num - 1] - display = PROVIDER_DISPLAY_NAMES.get( - all_providers[sub_num - 1][0], - all_providers[sub_num - 1][0], - ) - else: - print(" Invalid selection.\n") - continue - except (ValueError, EOFError, KeyboardInterrupt): - print() - continue - elif 1 <= choice_num <= len(_PROMPT_PROVIDERS): - _, display, env_var = _PROMPT_PROVIDERS[choice_num - 1] + if 1 <= choice_num <= len(all_providers): + display, env_var = all_providers[choice_num - 1] else: - print(f" Invalid input. Enter a number 1-{skip_idx}.\n") + _console.print(f" [yellow]Invalid input. Enter a number 1-{skip_idx}.[/yellow]\n") continue # Prompt for the key value (masked) @@ -248,14 +335,14 @@ def _prompt_for_api_key() -> List[Tuple[str, str]]: break if not key_value: - print(" No key entered, skipping.\n") + _console.print(" [yellow]No key entered, skipping.[/yellow]\n") continue # Save to api-env file and load into current session _save_key_to_api_env(env_var, key_value) added_keys.append((env_var, api_env_label)) - print(f" ✓ {env_var} saved to {api_env_label}") - print(f" ✓ Loaded into current session\n") + _console.print(f" [green]✓[/green] {env_var} saved to {api_env_label}") + _console.print(f" [green]✓[/green] Loaded into current session\n") # Ask if they want to add another try: @@ -272,13 +359,13 @@ def _prompt_for_api_key() -> List[Tuple[str, str]]: # --------------------------------------------------------------------------- -# Step 2 — Configure models from reference CSV +# Step 2 — Configure models + .pddrc # --------------------------------------------------------------------------- -def _step2_configure_models( +def _step2_configure_models_and_pddrc( found_keys: List[Tuple[str, str]], ) -> Dict[str, int]: - """Match found API keys to reference models and write user CSV. + """Match found API keys to reference models, write user CSV, and ensure .pddrc. Returns {provider_display_name: model_count} for the summary. """ @@ -287,6 +374,7 @@ def _step2_configure_models( _write_csv_atomic, _get_user_csv_path, ) + from pdd.pddrc_initializer import _detect_language, _build_pddrc_content found_key_names = {k for k, _ in found_keys} @@ -294,8 +382,12 @@ def _step2_configure_models( ref_path = Path(__file__).parent / "data" / "llm_model.csv" ref_rows = _read_csv(ref_path) - # Filter reference rows to those whose api_key matches a found key - # Skip local-only rows (lm_studio, ollama — handled in step 3) + # Filter reference rows to those whose api_key env vars are all found. + # Supports pipe-delimited multi-var fields (e.g. "VAR1|VAR2|VAR3"). + # Empty api_key (device flow / local) matches automatically. + # Skip local-only rows (lm_studio, ollama, localhost base_url). + from pdd.provider_manager import parse_api_key_vars + matching_rows: List[Dict[str, str]] = [] for row in ref_rows: api_key_col = row.get("api_key", "").strip() @@ -305,15 +397,20 @@ def _step2_configure_models( # Skip local models if provider in ("lm_studio", "ollama"): continue - # Skip rows with base_url pointing to localhost (local models) if base_url and ("localhost" in base_url or "127.0.0.1" in base_url): continue - # Match on api_key - if api_key_col and api_key_col in found_key_names: + + # Match: all individual env vars must be in found_key_names + env_vars = parse_api_key_vars(api_key_col) + if not env_vars: + # Empty api_key = device flow (e.g. GitHub Copilot) — always match + matching_rows.append(row) + elif all(v in found_key_names for v in env_vars): matching_rows.append(row) - # Read existing user CSV and deduplicate + # Read existing user CSV and deduplicate (create if missing) user_csv_path = _get_user_csv_path() + user_csv_path.parent.mkdir(parents=True, exist_ok=True) existing_rows = _read_csv(user_csv_path) existing_models = {r.get("model", "").strip() for r in existing_rows} @@ -327,180 +424,79 @@ def _step2_configure_models( all_rows = existing_rows + new_rows for row in all_rows: provider = row.get("provider", "Unknown").strip() - # Only count cloud models (with api_key) if row.get("api_key", "").strip(): provider_counts[provider] = provider_counts.get(provider, 0) + 1 # Write merged result if new_rows: _write_csv_atomic(user_csv_path, all_rows) - print(f" ✓ {len(new_rows)} new model(s) added to {user_csv_path}") + _console.print(f" [green]✓[/green] {len(new_rows)} new model(s) added to {user_csv_path}") else: - print(f" ✓ All matching models already in {user_csv_path}") + _console.print(f" [green]✓[/green] All matching models already lodaed in {user_csv_path}") total = sum(provider_counts.values()) - print(f" ✓ {total} cloud model(s) configured") + _console.print(f" [green]✓[/green] {total} model(s) configured") for provider, count in sorted(provider_counts.items()): s = "s" if count != 1 else "" - print(f" {provider}: {count} model{s}") - - return provider_counts - - -# --------------------------------------------------------------------------- -# Step 3 — Local LLMs + .pddrc -# --------------------------------------------------------------------------- - -def _step3_local_llms_and_pddrc() -> Dict[str, List[str]]: - """Check local LLMs and ensure .pddrc exists. - - Returns {provider: [model_names]} for local LLMs found. - """ - from pdd.pddrc_initializer import _detect_language, _build_pddrc_content - - local_summary: Dict[str, List[str]] = {} - - # ── Check Ollama ────────────────────────────────────────────────────── - ollama_models = _query_local_server( - url="http://localhost:11434/api/tags", - extract_models=_extract_ollama_models, - ) - if ollama_models is not None: - local_summary["Ollama"] = ollama_models - if ollama_models: - names = ", ".join(ollama_models) - print(f" ✓ Ollama running — found {names}") - _append_local_models_to_csv( - ollama_models, provider="Ollama", prefix="ollama_chat/", - base_url="http://localhost:11434", - ) - else: - print(" ✓ Ollama running — no models installed") - else: - print(" ✗ Ollama not running (skip)") - - # ── Check LM Studio ────────────────────────────────────────────────── - lm_models = _query_local_server( - url="http://localhost:1234/v1/models", - extract_models=_extract_lm_studio_models, - ) - if lm_models is not None: - local_summary["LM Studio"] = lm_models - if lm_models: - names = ", ".join(lm_models) - print(f" ✓ LM Studio running — found {names}") - _append_local_models_to_csv( - lm_models, provider="lm_studio", prefix="lm_studio/", - base_url="http://localhost:1234/v1", - ) - else: - print(" ✓ LM Studio running — no models loaded") - else: - print(" ✗ LM Studio not running (skip)") + print(f" {provider}: {count} model{s}") # ── Check .pddrc ───────────────────────────────────────────────────── cwd = Path.cwd() pddrc_path = cwd / ".pddrc" if pddrc_path.exists(): - print(f" ✓ .pddrc already exists at {pddrc_path}") + _console.print(f" [green]✓[/green] .pddrc detected at {pddrc_path}") else: - language = _detect_language(cwd) or "python" - content = _build_pddrc_content(language) + print() + _console.print(" [bold].pddrc[/bold] configures where PDD puts generated code, tests, and examples.") + _console.print(" It lives in your project root and lets you define contexts for different") + _console.print(" parts of your codebase (e.g. frontend vs backend).") + print() try: - pddrc_path.write_text(content, encoding="utf-8") - print(f" ✓ Created .pddrc at {pddrc_path} (detected: {language})") - except OSError as exc: - print(f" ✗ Failed to create .pddrc: {exc}") - - return local_summary - - -def _query_local_server( - url: str, - extract_models, - timeout: float = 3.0, -) -> Optional[List[str]]: - """Query a local LLM server. Returns model list or None if unreachable.""" - try: - req = urllib.request.Request(url) - with urllib.request.urlopen(req, timeout=timeout) as resp: - data = json.loads(resp.read().decode("utf-8")) - return extract_models(data) - except (urllib.error.URLError, OSError, json.JSONDecodeError, KeyError): - return None - - -def _extract_ollama_models(data: dict) -> List[str]: - """Extract model names from Ollama /api/tags response.""" - models = data.get("models", []) - return [m.get("name", "") for m in models if m.get("name")] - - -def _extract_lm_studio_models(data: dict) -> List[str]: - """Extract model names from LM Studio /v1/models response.""" - models = data.get("data", []) - return [m.get("id", "") for m in models if m.get("id")] - - -def _append_local_models_to_csv( - model_names: List[str], - provider: str, - prefix: str, - base_url: str, -) -> None: - """Append local models to user CSV, skipping duplicates.""" - from pdd.provider_manager import ( - _read_csv, - _write_csv_atomic, - _get_user_csv_path, - ) - - user_csv_path = _get_user_csv_path() - existing_rows = _read_csv(user_csv_path) - existing_models = {r.get("model", "").strip() for r in existing_rows} + answer = input(" Create .pddrc in this project? [y/Enter to skip] ").strip().lower() + except (EOFError, KeyboardInterrupt): + answer = "" - new_rows = [] - for name in model_names: - model_id = f"{prefix}{name}" - if model_id not in existing_models: - new_rows.append({ - "provider": provider, - "model": model_id, - "input": "0", - "output": "0", - "coding_arena_elo": "1000", - "base_url": base_url, - "api_key": "", - "max_reasoning_tokens": "0", - "structured_output": "True", - "reasoning_type": "none", - "location": "", - }) + if answer in ("y", "yes"): + language = _detect_language(cwd) or "python" + content = _build_pddrc_content(language) + try: + pddrc_path.write_text(content, encoding="utf-8") + _console.print(f" [green]✓[/green] Created .pddrc at {pddrc_path} (detected: {language})") + except OSError as exc: + _console.print(f" [yellow]✗ Failed to create .pddrc: {exc}[/yellow]") + else: + _console.print(" [dim]Skipped .pddrc creation. You can create one later with pdd setup.") - if new_rows: - _write_csv_atomic(user_csv_path, existing_rows + new_rows) + return provider_counts # --------------------------------------------------------------------------- -# Step 4 — Test one model + print summary +# Step 3 — Test one model + print summary # --------------------------------------------------------------------------- -def _step4_test_and_summary( +def _step3_test_and_summary( found_keys: List[Tuple[str, str]], model_summary: Dict[str, int], - local_summary: Dict[str, List[str]], + cli_results=None, ) -> None: """Test the first available cloud model and print the final summary.""" - from pdd.provider_manager import _read_csv, _get_user_csv_path + from pdd.provider_manager import _read_csv, _get_user_csv_path, parse_api_key_vars - # Pick first cloud model user_csv_path = _get_user_csv_path() rows = _read_csv(user_csv_path) test_result = "Skipped (no models configured)" + # Pick first cloud model that has all auth configured. + # Uses the pipe-delimited api_key convention: check every env var is set. cloud_row = None for row in rows: - if row.get("api_key", "").strip(): + api_key_field = row.get("api_key", "").strip() + env_vars = parse_api_key_vars(api_key_field) + if not env_vars: + # Empty = device flow (e.g. GitHub Copilot) — pick it + cloud_row = row + break + if all(os.getenv(v, "") for v in env_vars): cloud_row = row break @@ -509,107 +505,340 @@ def _step4_test_and_summary( try: import litellm # noqa: F401 from pdd.model_tester import _run_test + import threading + import time as time_module + + sys.stdout.write(f" Testing {test_model}...") + sys.stdout.flush() + + # Run in a thread so we can print dots while waiting + test_result_holder: list = [None] + + def _do_test() -> None: + test_result_holder[0] = _run_test(cloud_row) + + t = threading.Thread(target=_do_test, daemon=True) + t.start() + + elapsed = 0.0 + while t.is_alive() and elapsed < 8.0: + t.join(timeout=1.0) + if t.is_alive(): + sys.stdout.write(".") + sys.stdout.flush() + elapsed += 1.0 + + sys.stdout.write("\n") + if t.is_alive(): + result = { + "success": False, + "duration_s": elapsed, + "cost": 0.0, + "error": "Request timed out (8s)", + "tokens": None, + } + else: + result = test_result_holder[0] or { + "success": False, + "duration_s": 0.0, + "cost": 0.0, + "error": "Unknown error", + "tokens": None, + } - print(f" Testing {test_model}...") - result = _run_test(cloud_row) if result["success"]: - test_result = f"✓ {test_model} responded OK ({result['duration_s']:.1f}s)" + test_result = f"[green]✓[/green] {test_model} responded OK ({result['duration_s']:.1f}s)" else: - test_result = f"✗ {test_model} failed: {result['error']}" + test_result = f"[yellow]✗ {test_model} failed: {result['error']}[/yellow]" except ImportError: - test_result = "Skipped (litellm not installed)" - print(f" {test_result}") + test_result = "[yellow]Skipped (litellm not installed)[/yellow]" + _console.print(f" {test_result}") # ── Summary ─────────────────────────────────────────────────────────── print() - print(" ═══════════════════════════════════════════════") - print(" PDD Setup Complete") - print(" ═══════════════════════════════════════════════") + _console.print(" [bold green]PDD Setup Complete![/bold green]") print() + # CLIs + if cli_results: + configured = [r for r in cli_results if not r.skipped and r.cli_name] + skipped = [r for r in cli_results if r.skipped] + if configured: + names = ", ".join(r.cli_name for r in configured) + no_key = [r for r in configured if not r.api_key_configured] + if no_key: + no_key_names = ", ".join(r.cli_name for r in no_key) + _console.print(f" CLI: [green]✓[/green] {names} configured ([yellow]{no_key_names} missing API key[/yellow])") + else: + _console.print(f" CLI: [green]✓[/green] {names} configured") + elif skipped: + _console.print(" CLI: [yellow]✗[/yellow] skipped") + else: + _console.print(" CLI: [dim]not configured[/dim]") + else: + _console.print(" CLI: [dim]not configured[/dim]") + # API Keys - print(f" API Keys: {len(found_keys)} found") + if found_keys: + _console.print(f" API Keys: [green]\u2713[/green] {len(found_keys)} found") + else: + _console.print(" API Keys: [red]\u2717[/red] 0 found") # Models total_models = sum(model_summary.values()) parts = ", ".join(f"{p}: {c}" for p, c in sorted(model_summary.items())) if parts: - print(f" Models: {total_models} configured ({parts})") + print(f" Models: {total_models} configured ({parts}) in {_get_user_csv_path()}") else: - print(f" Models: {total_models} configured") - - # Local LLMs - if local_summary: - local_parts = [] - for provider, models in local_summary.items(): - if models: - local_parts.append(f"{provider} — {', '.join(models)}") - else: - local_parts.append(f"{provider} (no models)") - print(f" Local: {'; '.join(local_parts)}") - else: - print(" Local: none found") + print(f" Models: {total_models} configured in {_get_user_csv_path()}") # .pddrc pddrc_path = Path.cwd() / ".pddrc" if pddrc_path.exists(): - print(" .pddrc: exists") + _console.print(" .pddrc: [green]\u2713[/green] exists") else: - print(" .pddrc: not created") + _console.print(" .pddrc: [red]\u2717[/red] not created") # Test - print(f" Test: {test_result}") + _console.print(f" Test: {test_result}") + + # Exit summary is handled by run_setup after the options menu + + +# --------------------------------------------------------------------------- +# Exit summary — files, quick start, tips +# --------------------------------------------------------------------------- + +_FAT_DIVIDER = "\u2501" * 80 # ━ +_THIN_DIVIDER = "\u2500" * 80 # ─ +_BULLET = "\u2022" # • + +_SUCCESS_PYTHON_TEMPLATE = """\ +Write a python script to print "You did it, !!!" to the console. +Do not write anything except that message. +Capitalize the username.""" + +def _create_sample_prompt() -> str: + """Create the sample prompt file if it doesn't exist. Returns the filename.""" + prompt_file = Path("success_python.prompt") + if not prompt_file.exists(): + prompt_file.write_text(_SUCCESS_PYTHON_TEMPLATE) + return str(prompt_file) + + +def _print_exit_summary(found_keys: List[Tuple[str, str]], cli_results=None) -> None: + """Write PDD-SETUP-SUMMARY.txt and print QUICK START + LEARN MORE to terminal.""" + from pdd.api_key_scanner import _detect_shell + + shell = _detect_shell() or "bash" + pdd_dir = Path.home() / ".pdd" + api_env_path = pdd_dir / f"api-env.{shell}" + user_csv_path = pdd_dir / "llm_model.csv" + sample_prompt = _create_sample_prompt() + + # Build valid_keys dict: key_name -> actual value + valid_keys: Dict[str, str] = {} + for key_name, _source in found_keys: + val = os.environ.get(key_name, "") + if val.strip(): + valid_keys[key_name] = val + + # Determine which files were created/configured + saved_files: List[str] = [] + if api_env_path.exists(): + saved_files.append(str(api_env_path)) + if user_csv_path.exists(): + saved_files.append(str(user_csv_path)) + + created_pdd_dir = pdd_dir.exists() + + # Check if shell init file was updated + from pdd.provider_manager import _get_shell_rc_path + rc_path = _get_shell_rc_path() + init_file_updated: Optional[str] = None + if rc_path and rc_path.exists(): + rc_content = rc_path.read_text(encoding="utf-8") + if "api-env" in rc_content: + init_file_updated = str(rc_path) + + # Source command + if shell == "sh": + source_cmd = f". {api_env_path}" + else: + source_cmd = f"source {api_env_path}" + + # ── Build full summary (saved to file) ─────────────────────────────── + lines: List[str] = [] + lines.append("") + lines.append("") + lines.append(_FAT_DIVIDER) + lines.append("PDD Setup Complete!") + lines.append(_FAT_DIVIDER) + lines.append("") + + # CLIs configured + lines.append("CLIs Configured:") + lines.append("") + if cli_results: + configured = [r for r in cli_results if not r.skipped and r.cli_name] + if configured: + for r in configured: + key_status = "API key set" if r.api_key_configured else "no API key" + lines.append(f" {r.cli_name} ({r.provider}) — {key_status}") + else: + lines.append(" None") + else: + lines.append(" None") + lines.append("") + + # API Keys configured + lines.append("API Keys Configured:") + lines.append("") + if valid_keys: + for kn, kv in valid_keys.items(): + masked = f"{kv[:8]}...{kv[-4:]}" if len(kv) > 12 else "***" + lines.append(f" {kn}: {masked}") + else: + lines.append(" None") + lines.append("") + + # Files created + lines.append("Files created and configured:") + lines.append("") + + file_descriptions: List[Tuple[str, str]] = [] + if created_pdd_dir: + file_descriptions.append(("~/.pdd/", "PDD configuration directory")) + for fp in saved_files: + if "api-env." in fp: + file_descriptions.append((fp, f"API environment variables ({shell} shell)")) + elif "llm_model.csv" in fp: + file_descriptions.append((fp, "LLM model configuration")) + file_descriptions.append((sample_prompt, "Sample prompt for testing")) + if init_file_updated: + file_descriptions.append((init_file_updated, "Shell startup file (updated to source API environment)")) + file_descriptions.append(("PDD-SETUP-SUMMARY.txt", "This summary")) + + max_path_len = max(len(p) for p, _ in file_descriptions) if file_descriptions else 0 + for fp, desc in file_descriptions: + lines.append(f"{fp:<{max_path_len + 2}}{desc}") + + lines.append("") + lines.append(_THIN_DIVIDER) + lines.append("") + lines.append("QUICK START:") + lines.append("") + lines.append("1. Generate code from the sample prompt:") + lines.append(" pdd generate success_python.prompt") + lines.append("") + lines.append(_THIN_DIVIDER) + lines.append("") + lines.append("LEARN MORE:") + lines.append("") + lines.append(f"{_BULLET} PDD documentation: pdd --help") + lines.append(f"{_BULLET} PDD website: https://promptdriven.ai/") + lines.append(f"{_BULLET} Discord community: https://discord.gg/Yp4RTh8bG7") + lines.append("") + lines.append("TIPS:") + lines.append("") + lines.append(f"{_BULLET} Start with simple prompts and gradually increase complexity") + lines.append(f"{_BULLET} Try out 'pdd test' with your prompt+code to create test(s) pdd can use to automatically verify and fix your output code") + lines.append(f"{_BULLET} Try out 'pdd example' with your prompt+code to create examples which help pdd do better") + lines.append("") + lines.append(f"{_BULLET} As you get comfortable, learn configuration settings, including the .pddrc file, PDD_GENERATE_OUTPUT_PATH, and PDD_TEST_OUTPUT_PATH") + lines.append(f"{_BULLET} For larger projects, use Makefiles and/or 'pdd sync'") + lines.append(f"{_BULLET} For ongoing substantial projects, learn about llm_model.csv and the --strength,") + lines.append(f" --temperature, and --time options to optimize model cost, latency, and output quality") + lines.append("") + lines.append(f"{_BULLET} Use 'pdd --help' to explore all available commands") + lines.append("") + lines.append(f"Problems? Shout out on our Discord for help! https://discord.gg/Yp4RTh8bG7") + + if api_env_path.exists(): + lines.append("") + lines.append(_THIN_DIVIDER) + lines.append("") + lines.append("IMPORTANT: To use your API keys in this terminal session, run:") + lines.append(f" {source_cmd}") + lines.append("") + lines.append("New terminal windows will load keys automatically.") + + summary_text = "\n".join(lines) + + # Write PDD-SETUP-SUMMARY.txt + summary_path = Path("PDD-SETUP-SUMMARY.txt") + summary_path.write_text(summary_text, encoding="utf-8") + + # ── Print only QUICK START + LEARN MORE to terminal ────────────────── + print() print() - print(" ═══════════════════════════════════════════════") - print(" Run 'pdd generate' or 'pdd sync' to start.") - print(" ═══════════════════════════════════════════════") + _console.print("[bold green]Completed setup.[/bold green]") + print() + print(_THIN_DIVIDER) + print() + print("QUICK START:") + print() + print("1. Generate code from the sample prompt:") + print(" pdd generate success_python.prompt") + print() + print(_THIN_DIVIDER) + print() + print("LEARN MORE:") + print() + print(f"{_BULLET} PDD documentation: pdd --help") + print(f"{_BULLET} PDD website: https://promptdriven.ai/") + print(f"{_BULLET} Discord community: https://discord.gg/Yp4RTh8bG7") + print() + _console.print(f"[dim]Full summary saved to PDD-SETUP-SUMMARY.txt[/dim]") + print() + if api_env_path.exists(): + _console.print( + f"[bold yellow]Important:[/bold yellow] For updates to API keys in this terminal session, run:\n" + f"\n {source_cmd}\n\n" + f"[dim]New terminal windows will load updated keys automatically.[/dim]" + ) + print() # --------------------------------------------------------------------------- -# Fallback manual menu +# Options menu (post-setup or fallback) # --------------------------------------------------------------------------- -def _run_fallback_menu() -> None: - """Simplified manual menu loop when auto-phase fails.""" +def _run_options_menu() -> None: + """Menu loop for manual configuration options.""" print() from pdd.provider_manager import add_provider_from_registry from pdd.model_tester import test_model_interactive - from pdd.pddrc_initializer import offer_pddrc_init while True: - print("Manual setup options:") - print(" 1. Add a provider") - print(" 2. Test a model") - print(" 3. Initialize .pddrc") - print(" 4. Done") + print(" Options:") + print(" 1. Add a provider") + print(" 2. Test a model") + print() try: - choice = input("Select an option [1-4]: ").strip() + choice = input(" Select an option (Enter to finish): ").strip() except (EOFError, KeyboardInterrupt): - print("\nSetup interrupted — exiting.") - return + print() + break + + if not choice: + break if choice == "1": try: add_provider_from_registry() except Exception as exc: - print(f"Error adding provider: {exc}") + print(f" Error adding provider: {exc}") elif choice == "2": try: test_model_interactive() except Exception as exc: - print(f"Error testing model: {exc}") - elif choice == "3": - try: - offer_pddrc_init() - except Exception as exc: - print(f"Error initializing .pddrc: {exc}") - elif choice == "4": - break + print(f" Error testing model: {exc}") else: - print("Invalid option. Please enter 1, 2, 3, or 4.") + _console.print(" [yellow]Invalid option. Please enter 1 or 2.[/yellow]") print() diff --git a/tests/test_api_key_scanner.py b/tests/test_api_key_scanner.py index e8910af25..0d3e55c91 100644 --- a/tests/test_api_key_scanner.py +++ b/tests/test_api_key_scanner.py @@ -1,8 +1,53 @@ -"""Tests for pdd/api_key_scanner.py""" +# Test Plan: pdd/api_key_scanner.py +# +# Public API under test: +# - get_provider_key_names() → List[str] +# - scan_environment() → Dict[str, KeyInfo] +# - KeyInfo → dataclass(source, is_set) +# +# I. KeyInfo Data Model +# 1. test_keyinfo_fields: KeyInfo has source and is_set attributes. +# +# II. get_provider_key_names — CSV Parsing +# 2. test_key_names_csv_missing: No CSV → empty list. +# 3. test_key_names_csv_empty_file: Empty file → empty list. +# 4. test_key_names_csv_no_api_key_column: CSV without api_key header → empty list. +# 5. test_key_names_csv_all_empty_keys: All api_key values blank → empty list. +# 6. test_key_names_returns_sorted_unique: Normal CSV → sorted, deduplicated keys. +# 7. test_key_names_deduplicates_across_rows: Same key in multiple rows → single entry. +# 8. test_key_names_splits_pipe_delimited: Pipe-delimited api_key → individual keys. +# 9. test_key_names_pipe_dedup_across_rows: Pipe keys deduplicated across rows. +# 10. test_key_names_pipe_strips_whitespace: Whitespace around pipe segments stripped. +# 11. test_key_names_pipe_ignores_empty_segments: Empty pipe segments ignored. +# 12. test_key_names_malformed_csv: Malformed CSV → empty list, no crash. +# 13. test_key_names_permission_error: PermissionError → empty list, no crash. +# 14. test_key_names_unicode: Unicode in CSV → handled correctly. +# +# III. scan_environment — Early Exits +# 15. test_scan_no_models_configured: No CSV → empty dict. +# 16. test_scan_exception_returns_empty: Internal error → empty dict, no raise. +# +# IV. scan_environment — Source Detection +# 17. test_scan_detects_shell_env_key: Key in os.environ → source="shell environment". +# 18. test_scan_detects_api_env_file_key: Key in api-env.{shell} → source="~/.pdd/api-env.{shell}". +# 19. test_scan_detects_dotenv_key: Key in .env → source=".env file". +# 20. test_scan_missing_key_marked_not_set: Key absent everywhere → is_set=False. +# +# V. scan_environment — Priority Order +# 21. test_scan_dotenv_wins_over_shell: .env beats shell environment. +# 22. test_scan_shell_wins_over_api_env: Shell environment beats api-env file. +# +# VI. scan_environment — Shell-Specific Behavior +# 23. test_scan_bash_uses_bash_api_env: SHELL=/bin/bash → reads api-env.bash. +# 24. test_scan_zsh_uses_zsh_api_env: SHELL=/bin/zsh → reads api-env.zsh. +# +# VII. scan_environment — Pipe-Delimited Keys +# 25. test_scan_pipe_keys_scanned_individually: Each pipe-delimited key checked independently. +# +# VIII. scan_environment — Edge Cases +# 26. test_scan_special_chars_in_key_value: Key value with special chars → no crash. import csv -import os -import tempfile from pathlib import Path from unittest import mock @@ -12,540 +57,459 @@ KeyInfo, get_provider_key_names, scan_environment, - _get_csv_path, - _load_dotenv_values, - _detect_shell, - _parse_api_env_file, ) # --------------------------------------------------------------------------- -# Fixtures +# Module-level CSV fixtures # --------------------------------------------------------------------------- +_CSV_FIELDS = [ + "provider", "model", "input", "output", "coding_arena_elo", + "base_url", "api_key", "max_reasoning_tokens", "structured_output", + "reasoning_type", "location", +] + +SIMPLE_CSV_ROWS = [ + {"provider": "OpenAI", "model": "gpt-4", "input": "30.0", "output": "60.0", + "coding_arena_elo": "1000", "base_url": "", "api_key": "OPENAI_API_KEY", + "max_reasoning_tokens": "0", "structured_output": "True", + "reasoning_type": "", "location": ""}, + {"provider": "Anthropic", "model": "claude-3-opus", "input": "15.0", "output": "75.0", + "coding_arena_elo": "1000", "base_url": "", "api_key": "ANTHROPIC_API_KEY", + "max_reasoning_tokens": "0", "structured_output": "True", + "reasoning_type": "", "location": ""}, + {"provider": "Local", "model": "ollama/llama2", "input": "0.0", "output": "0.0", + "coding_arena_elo": "1000", "base_url": "http://localhost:11434", "api_key": "", + "max_reasoning_tokens": "0", "structured_output": "False", + "reasoning_type": "", "location": ""}, +] + +BEDROCK_CSV_ROWS = [ + {"provider": "AWS Bedrock", "model": "anthropic.claude-3", "input": "8.0", + "output": "24.0", "coding_arena_elo": "1000", "base_url": "", + "api_key": "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME", + "max_reasoning_tokens": "0", "structured_output": "True", + "reasoning_type": "", "location": ""}, +] + +MIXED_CSV_ROWS = SIMPLE_CSV_ROWS + BEDROCK_CSV_ROWS -@pytest.fixture -def temp_home(tmp_path, monkeypatch): - """Create a temporary home directory with .pdd folder.""" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None): + """Write rows to a CSV file at *path*.""" + fieldnames = fieldnames or _CSV_FIELDS + with open(path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def _setup_home(tmp_path, monkeypatch, csv_rows=None, api_env_shell=None, + api_env_content=None): + """Set up a fake ~/.pdd directory with optional CSV and api-env file. + + Returns the tmp_path (acting as $HOME). + """ pdd_dir = tmp_path / ".pdd" pdd_dir.mkdir(parents=True, exist_ok=True) monkeypatch.setattr(Path, "home", lambda: tmp_path) + + if csv_rows is not None: + _write_csv(pdd_dir / "llm_model.csv", csv_rows) + + if api_env_shell and api_env_content: + (pdd_dir / f"api-env.{api_env_shell}").write_text(api_env_content) + return tmp_path -@pytest.fixture -def sample_csv(temp_home): - """Create a sample llm_model.csv with various providers.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - rows = [ - { - "provider": "OpenAI", - "model": "gpt-4", - "input": "30.0", - "output": "60.0", - "coding_arena_elo": "1000", - "base_url": "", - "api_key": "OPENAI_API_KEY", - "max_reasoning_tokens": "0", - "structured_output": "True", - "reasoning_type": "", - "location": "", - }, - { - "provider": "Anthropic", - "model": "claude-3-opus", - "input": "15.0", - "output": "75.0", - "coding_arena_elo": "1000", - "base_url": "", - "api_key": "ANTHROPIC_API_KEY", - "max_reasoning_tokens": "0", - "structured_output": "True", - "reasoning_type": "", - "location": "", - }, - { - "provider": "Local", - "model": "ollama/llama2", - "input": "0.0", - "output": "0.0", - "coding_arena_elo": "1000", - "base_url": "http://localhost:11434", - "api_key": "", # Local LLM - no key needed - "max_reasoning_tokens": "0", - "structured_output": "False", - "reasoning_type": "", - "location": "", - }, - ] +# --------------------------------------------------------------------------- +# I. KeyInfo Data Model +# --------------------------------------------------------------------------- - fieldnames = [ - "provider", "model", "input", "output", "coding_arena_elo", - "base_url", "api_key", "max_reasoning_tokens", "structured_output", - "reasoning_type", "location", - ] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(rows) +def test_keyinfo_fields(): + """KeyInfo dataclass should expose source and is_set.""" + ki = KeyInfo(source="shell environment", is_set=True) + assert ki.source == "shell environment" + assert ki.is_set is True - return csv_path + ki_missing = KeyInfo(source="", is_set=False) + assert ki_missing.is_set is False # --------------------------------------------------------------------------- -# Tests for get_provider_key_names +# II. get_provider_key_names — CSV Parsing # --------------------------------------------------------------------------- -class TestGetProviderKeyNames: - """Tests for get_provider_key_names function.""" - - def test_returns_sorted_unique_keys(self, sample_csv): - """Should return deduplicated, sorted list of API key names.""" - result = get_provider_key_names() - assert result == ["ANTHROPIC_API_KEY", "OPENAI_API_KEY"] - - def test_returns_empty_list_when_csv_missing(self, temp_home): - """Should return empty list when CSV doesn't exist.""" - result = get_provider_key_names() - assert result == [] - - def test_returns_empty_list_when_csv_empty(self, temp_home): - """Should return empty list when CSV exists but is empty.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - csv_path.touch() - result = get_provider_key_names() - assert result == [] - - def test_handles_csv_without_api_key_column(self, temp_home): - """Should return empty list when CSV lacks api_key column.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=["provider", "model"]) - writer.writeheader() - writer.writerow({"provider": "OpenAI", "model": "gpt-4"}) - - result = get_provider_key_names() - assert result == [] - - def test_handles_csv_with_only_empty_api_keys(self, temp_home): - """Should return empty list when all api_key values are empty.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - fieldnames = ["provider", "model", "api_key"] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerow({"provider": "Local", "model": "llama2", "api_key": ""}) - writer.writerow({"provider": "Local2", "model": "mistral", "api_key": " "}) - - result = get_provider_key_names() - assert result == [] - - def test_deduplicates_same_key_multiple_providers(self, temp_home): - """Should deduplicate when multiple rows use the same API key.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - fieldnames = ["provider", "model", "api_key"] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerow({"provider": "OpenAI", "model": "gpt-4", "api_key": "OPENAI_API_KEY"}) - writer.writerow({"provider": "OpenAI", "model": "gpt-3.5", "api_key": "OPENAI_API_KEY"}) - writer.writerow({"provider": "Together", "model": "llama", "api_key": "TOGETHER_API_KEY"}) - - result = get_provider_key_names() - assert result == ["OPENAI_API_KEY", "TOGETHER_API_KEY"] - - def test_handles_malformed_csv(self, temp_home): - """Should return empty list for malformed CSV without raising.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - csv_path.write_text("this is not,a valid\ncsv file with\"broken quotes") - - result = get_provider_key_names() - # Should handle gracefully - either empty or partial results - assert isinstance(result, list) +def test_key_names_csv_missing(tmp_path, monkeypatch): + """No CSV at all → empty list.""" + _setup_home(tmp_path, monkeypatch) + assert get_provider_key_names() == [] -# --------------------------------------------------------------------------- -# Tests for _detect_shell -# --------------------------------------------------------------------------- +def test_key_names_csv_empty_file(tmp_path, monkeypatch): + """CSV file exists but is empty → empty list.""" + home = _setup_home(tmp_path, monkeypatch) + (home / ".pdd" / "llm_model.csv").touch() + assert get_provider_key_names() == [] -class TestDetectShell: - """Tests for _detect_shell function.""" +def test_key_names_csv_no_api_key_column(tmp_path, monkeypatch): + """CSV lacks an api_key column → empty list.""" + home = _setup_home(tmp_path, monkeypatch) + csv_path = home / ".pdd" / "llm_model.csv" + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=["provider", "model"]) + writer.writeheader() + writer.writerow({"provider": "OpenAI", "model": "gpt-4"}) + assert get_provider_key_names() == [] - def test_detects_zsh(self, monkeypatch): - """Should detect zsh shell.""" - monkeypatch.setenv("SHELL", "/bin/zsh") - assert _detect_shell() == "zsh" - def test_detects_bash(self, monkeypatch): - """Should detect bash shell.""" - monkeypatch.setenv("SHELL", "/bin/bash") - assert _detect_shell() == "bash" +def test_key_names_csv_all_empty_keys(tmp_path, monkeypatch): + """All api_key values are blank → empty list.""" + home = _setup_home(tmp_path, monkeypatch) + csv_path = home / ".pdd" / "llm_model.csv" + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=["provider", "model", "api_key"]) + writer.writeheader() + writer.writerow({"provider": "Local", "model": "llama2", "api_key": ""}) + writer.writerow({"provider": "Local2", "model": "mistral", "api_key": " "}) + assert get_provider_key_names() == [] + - def test_detects_fish(self, monkeypatch): - """Should detect fish shell.""" - monkeypatch.setenv("SHELL", "/usr/local/bin/fish") - assert _detect_shell() == "fish" +def test_key_names_returns_sorted_unique(tmp_path, monkeypatch): + """Normal CSV → sorted, deduplicated key names (local models with no key excluded).""" + _setup_home(tmp_path, monkeypatch, csv_rows=SIMPLE_CSV_ROWS) + assert get_provider_key_names() == ["ANTHROPIC_API_KEY", "OPENAI_API_KEY"] - def test_returns_none_when_shell_not_set(self, monkeypatch): - """Should return None when SHELL env var is not set.""" - monkeypatch.delenv("SHELL", raising=False) - assert _detect_shell() is None - def test_handles_unusual_shell_paths(self, monkeypatch): - """Should extract shell name from unusual paths.""" - monkeypatch.setenv("SHELL", "/opt/homebrew/bin/zsh") - assert _detect_shell() == "zsh" +def test_key_names_deduplicates_across_rows(tmp_path, monkeypatch): + """Same key used by multiple models → appears only once.""" + home = _setup_home(tmp_path, monkeypatch) + rows = [ + {"provider": "OpenAI", "model": "gpt-4", "api_key": "OPENAI_API_KEY"}, + {"provider": "OpenAI", "model": "gpt-3.5", "api_key": "OPENAI_API_KEY"}, + {"provider": "Together", "model": "llama", "api_key": "TOGETHER_API_KEY"}, + ] + _write_csv(home / ".pdd" / "llm_model.csv", rows, + fieldnames=["provider", "model", "api_key"]) + assert get_provider_key_names() == ["OPENAI_API_KEY", "TOGETHER_API_KEY"] + + +def test_key_names_splits_pipe_delimited(tmp_path, monkeypatch): + """Pipe-delimited api_key → individual key names.""" + _setup_home(tmp_path, monkeypatch, csv_rows=BEDROCK_CSV_ROWS) + assert get_provider_key_names() == [ + "AWS_ACCESS_KEY_ID", "AWS_REGION_NAME", "AWS_SECRET_ACCESS_KEY", + ] + + +def test_key_names_pipe_dedup_across_rows(tmp_path, monkeypatch): + """Pipe keys from multiple rows are deduplicated.""" + home = _setup_home(tmp_path, monkeypatch) + rows = [ + {"provider": "AWS Bedrock", "model": "claude-3", + "api_key": "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME"}, + {"provider": "AWS Bedrock", "model": "claude-3.5", + "api_key": "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME"}, + {"provider": "Anthropic", "model": "claude-3", "api_key": "ANTHROPIC_API_KEY"}, + ] + _write_csv(home / ".pdd" / "llm_model.csv", rows, + fieldnames=["provider", "model", "api_key"]) + assert get_provider_key_names() == [ + "ANTHROPIC_API_KEY", "AWS_ACCESS_KEY_ID", + "AWS_REGION_NAME", "AWS_SECRET_ACCESS_KEY", + ] + + +@pytest.mark.parametrize("raw_key,expected", [ + (" KEY_A | KEY_B | KEY_C ", ["KEY_A", "KEY_B", "KEY_C"]), +]) +def test_key_names_pipe_strips_whitespace(tmp_path, monkeypatch, raw_key, expected): + """Whitespace around pipe segments is stripped.""" + home = _setup_home(tmp_path, monkeypatch) + _write_csv( + home / ".pdd" / "llm_model.csv", + [{"provider": "Test", "model": "t", "api_key": raw_key}], + fieldnames=["provider", "model", "api_key"], + ) + assert get_provider_key_names() == expected + + +@pytest.mark.parametrize("raw_key,expected", [ + ("KEY_A||KEY_B|", ["KEY_A", "KEY_B"]), +]) +def test_key_names_pipe_ignores_empty_segments(tmp_path, monkeypatch, raw_key, expected): + """Empty segments in pipe-delimited values are ignored.""" + home = _setup_home(tmp_path, monkeypatch) + _write_csv( + home / ".pdd" / "llm_model.csv", + [{"provider": "Test", "model": "t", "api_key": raw_key}], + fieldnames=["provider", "model", "api_key"], + ) + assert get_provider_key_names() == expected + + +def test_key_names_malformed_csv(tmp_path, monkeypatch): + """Malformed CSV → empty list, no crash.""" + home = _setup_home(tmp_path, monkeypatch) + (home / ".pdd" / "llm_model.csv").write_text( + 'this is not,a valid\ncsv file with"broken quotes' + ) + result = get_provider_key_names() + assert isinstance(result, list) + + +def test_key_names_permission_error(tmp_path, monkeypatch): + """PermissionError reading CSV → empty list, no crash.""" + home = _setup_home(tmp_path, monkeypatch) + csv_path = home / ".pdd" / "llm_model.csv" + csv_path.write_text("provider,model,api_key\nTest,test,KEY\n") + + original_open = open + + def _raise_on_csv(file, *args, **kwargs): + if str(file) == str(csv_path): + raise PermissionError("Access denied") + return original_open(file, *args, **kwargs) + + with mock.patch("builtins.open", side_effect=_raise_on_csv): + assert get_provider_key_names() == [] + + +def test_key_names_unicode(tmp_path, monkeypatch): + """Unicode in CSV is handled without error.""" + home = _setup_home(tmp_path, monkeypatch) + _write_csv( + home / ".pdd" / "llm_model.csv", + [{"provider": "Tëst", "model": "模型", "api_key": "UNICODE_KEY_名前"}], + fieldnames=["provider", "model", "api_key"], + ) + assert "UNICODE_KEY_名前" in get_provider_key_names() # --------------------------------------------------------------------------- -# Tests for _parse_api_env_file +# III. scan_environment — Early Exits # --------------------------------------------------------------------------- -class TestParseApiEnvFile: - """Tests for _parse_api_env_file function.""" - - def test_parses_simple_exports(self, tmp_path): - """Should parse simple export KEY=value lines.""" - env_file = tmp_path / "api-env.bash" - env_file.write_text( - "export OPENAI_API_KEY=sk-12345\n" - "export ANTHROPIC_API_KEY=ant-67890\n" - ) - - result = _parse_api_env_file(env_file) - assert result == { - "OPENAI_API_KEY": "sk-12345", - "ANTHROPIC_API_KEY": "ant-67890", - } - - def test_parses_quoted_values(self, tmp_path): - """Should parse quoted export values.""" - env_file = tmp_path / "api-env.bash" - env_file.write_text( - 'export OPENAI_API_KEY="sk-12345"\n' - "export ANTHROPIC_API_KEY='ant-67890'\n" - ) - - result = _parse_api_env_file(env_file) - assert result == { - "OPENAI_API_KEY": "sk-12345", - "ANTHROPIC_API_KEY": "ant-67890", - } - - def test_skips_commented_lines(self, tmp_path): - """Should skip lines starting with #.""" - env_file = tmp_path / "api-env.bash" - env_file.write_text( - "# This is a comment\n" - "export OPENAI_API_KEY=sk-12345\n" - "# export ANTHROPIC_API_KEY=ant-67890\n" - ) - - result = _parse_api_env_file(env_file) - assert result == {"OPENAI_API_KEY": "sk-12345"} - - def test_skips_empty_lines(self, tmp_path): - """Should skip empty lines.""" - env_file = tmp_path / "api-env.bash" - env_file.write_text( - "export OPENAI_API_KEY=sk-12345\n" - "\n" - " \n" - "export ANTHROPIC_API_KEY=ant-67890\n" - ) - - result = _parse_api_env_file(env_file) - assert len(result) == 2 - - def test_returns_empty_dict_for_missing_file(self, tmp_path): - """Should return empty dict when file doesn't exist.""" - env_file = tmp_path / "nonexistent" - result = _parse_api_env_file(env_file) - assert result == {} - - def test_handles_special_characters_in_values(self, tmp_path): - """Should handle API keys with special characters.""" - # Characters that might appear in API keys - special_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - env_file = tmp_path / "api-env.bash" - # Note: The file might have various quoting styles - env_file.write_text(f"export TEST_KEY='{special_key}'\n") - - result = _parse_api_env_file(env_file) - assert result.get("TEST_KEY") == special_key - - def test_ignores_non_export_lines(self, tmp_path): - """Should ignore lines that don't start with 'export '.""" - env_file = tmp_path / "api-env.bash" - env_file.write_text( - "OPENAI_API_KEY=sk-12345\n" # No export - "export ANTHROPIC_API_KEY=ant-67890\n" - "echo 'hello'\n" - ) - - result = _parse_api_env_file(env_file) - assert result == {"ANTHROPIC_API_KEY": "ant-67890"} - - def test_handles_whitespace_around_equals(self, tmp_path): - """Should handle whitespace around equals sign.""" - env_file = tmp_path / "api-env.bash" - env_file.write_text("export OPENAI_API_KEY = sk-12345\n") - - # The current implementation uses partition("="), check behavior - result = _parse_api_env_file(env_file) - # Result may vary based on implementation - just ensure no crash - assert isinstance(result, dict) +def test_scan_no_models_configured(tmp_path, monkeypatch): + """No CSV → empty dict.""" + _setup_home(tmp_path, monkeypatch) + assert scan_environment() == {} + + +def test_scan_exception_returns_empty(tmp_path, monkeypatch): + """If get_provider_key_names raises, scan_environment returns {}.""" + _setup_home(tmp_path, monkeypatch) + with mock.patch( + "pdd.api_key_scanner.get_provider_key_names", + side_effect=Exception("boom"), + ): + assert scan_environment() == {} # --------------------------------------------------------------------------- -# Tests for scan_environment +# IV. scan_environment — Source Detection # --------------------------------------------------------------------------- -class TestScanEnvironment: - """Tests for scan_environment function.""" +def test_scan_detects_shell_env_key(tmp_path, monkeypatch): + """Key set in os.environ → source='shell environment', is_set=True.""" + _setup_home(tmp_path, monkeypatch, csv_rows=SIMPLE_CSV_ROWS) + monkeypatch.setenv("OPENAI_API_KEY", "sk-test123") + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - def test_returns_empty_dict_when_no_models_configured(self, temp_home): - """Should return empty dict when no models in CSV.""" - result = scan_environment() - assert result == {} + result = scan_environment() - def test_detects_key_in_shell_environment(self, sample_csv, monkeypatch): - """Should detect keys set in shell environment.""" - monkeypatch.setenv("OPENAI_API_KEY", "sk-test123") - # Don't set ANTHROPIC_API_KEY + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == "shell environment" + assert result["ANTHROPIC_API_KEY"].is_set is False - result = scan_environment() - assert "OPENAI_API_KEY" in result - assert result["OPENAI_API_KEY"].is_set is True - assert result["OPENAI_API_KEY"].source == "shell environment" +def test_scan_detects_api_env_file_key(tmp_path, monkeypatch): + """Key in api-env file → source='~/.pdd/api-env.bash', is_set=True.""" + _setup_home( + tmp_path, monkeypatch, + csv_rows=SIMPLE_CSV_ROWS, + api_env_shell="bash", + api_env_content="export OPENAI_API_KEY=sk-from-api-env\n", + ) + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + result = scan_environment() - assert "ANTHROPIC_API_KEY" in result - assert result["ANTHROPIC_API_KEY"].is_set is False + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == "~/.pdd/api-env.bash" + assert result["ANTHROPIC_API_KEY"].is_set is False - def test_detects_key_in_api_env_file(self, sample_csv, temp_home, monkeypatch): - """Should detect keys in ~/.pdd/api-env.{shell} file.""" - monkeypatch.setenv("SHELL", "/bin/bash") - # Clear any existing env vars - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - api_env_path = temp_home / ".pdd" / "api-env.bash" - api_env_path.write_text("export OPENAI_API_KEY=sk-from-api-env\n") +def test_scan_detects_dotenv_key(tmp_path, monkeypatch): + """Key in .env file → source='.env file', is_set=True.""" + _setup_home(tmp_path, monkeypatch, csv_rows=SIMPLE_CSV_ROWS) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with mock.patch( + "pdd.api_key_scanner._load_dotenv_values", + return_value={"OPENAI_API_KEY": "sk-from-dotenv"}, + ): result = scan_environment() - assert result["OPENAI_API_KEY"].is_set is True - assert result["OPENAI_API_KEY"].source == "~/.pdd/api-env.bash" - assert result["ANTHROPIC_API_KEY"].is_set is False + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == ".env file" - def test_priority_order_dotenv_first(self, sample_csv, temp_home, monkeypatch): - """Should check .env file first (highest priority).""" - monkeypatch.setenv("SHELL", "/bin/bash") - monkeypatch.setenv("OPENAI_API_KEY", "sk-from-shell") - # Create api-env file too - api_env_path = temp_home / ".pdd" / "api-env.bash" - api_env_path.write_text("export OPENAI_API_KEY=sk-from-api-env\n") +def test_scan_missing_key_marked_not_set(tmp_path, monkeypatch): + """Key absent from all sources → is_set=False.""" + _setup_home(tmp_path, monkeypatch, csv_rows=SIMPLE_CSV_ROWS) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - # Mock dotenv to return a value - with mock.patch( - "pdd.api_key_scanner._load_dotenv_values", - return_value={"OPENAI_API_KEY": "sk-from-dotenv"}, - ): - result = scan_environment() + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() - assert result["OPENAI_API_KEY"].source == ".env file" + assert result["OPENAI_API_KEY"].is_set is False + assert result["ANTHROPIC_API_KEY"].is_set is False - def test_priority_order_shell_before_api_env(self, sample_csv, temp_home, monkeypatch): - """Shell environment should have priority over api-env file.""" - monkeypatch.setenv("SHELL", "/bin/bash") - monkeypatch.setenv("OPENAI_API_KEY", "sk-from-shell") - api_env_path = temp_home / ".pdd" / "api-env.bash" - api_env_path.write_text("export OPENAI_API_KEY=sk-from-api-env\n") +# --------------------------------------------------------------------------- +# V. scan_environment — Priority Order +# --------------------------------------------------------------------------- + + +def test_scan_dotenv_wins_over_shell(tmp_path, monkeypatch): + """.env file has higher priority than shell environment.""" + _setup_home( + tmp_path, monkeypatch, + csv_rows=SIMPLE_CSV_ROWS, + api_env_shell="bash", + api_env_content="export OPENAI_API_KEY=sk-from-api-env\n", + ) + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setenv("OPENAI_API_KEY", "sk-from-shell") + + with mock.patch( + "pdd.api_key_scanner._load_dotenv_values", + return_value={"OPENAI_API_KEY": "sk-from-dotenv"}, + ): + result = scan_environment() - # Mock dotenv to return empty (no .env file) - with mock.patch( - "pdd.api_key_scanner._load_dotenv_values", - return_value={}, - ): - result = scan_environment() + assert result["OPENAI_API_KEY"].source == ".env file" - assert result["OPENAI_API_KEY"].source == "shell environment" - def test_keyinfo_structure(self, sample_csv, monkeypatch): - """Should return KeyInfo dataclass with correct fields.""" - monkeypatch.setenv("OPENAI_API_KEY", "sk-test") +def test_scan_shell_wins_over_api_env(tmp_path, monkeypatch): + """Shell environment has higher priority than api-env file.""" + _setup_home( + tmp_path, monkeypatch, + csv_rows=SIMPLE_CSV_ROWS, + api_env_shell="bash", + api_env_content="export OPENAI_API_KEY=sk-from-api-env\n", + ) + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setenv("OPENAI_API_KEY", "sk-from-shell") + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): result = scan_environment() - key_info = result["OPENAI_API_KEY"] - - assert isinstance(key_info, KeyInfo) - assert hasattr(key_info, "source") - assert hasattr(key_info, "is_set") - - def test_handles_exception_gracefully(self, monkeypatch, temp_home): - """Should return best-effort results on errors without raising.""" - # Create a CSV that will cause issues - csv_path = temp_home / ".pdd" / "llm_model.csv" - csv_path.write_text("provider,model,api_key\nTest,test,TEST_KEY\n") - - # Mock get_provider_key_names to raise - with mock.patch( - "pdd.api_key_scanner.get_provider_key_names", - side_effect=Exception("Test error"), - ): - result = scan_environment() - - # Should return empty dict, not raise - assert result == {} - - def test_different_shells_use_different_api_env_files(self, sample_csv, temp_home, monkeypatch): - """Should use api-env file matching the detected shell.""" - # Create both bash and zsh api-env files with different keys - (temp_home / ".pdd" / "api-env.bash").write_text( - "export OPENAI_API_KEY=sk-bash\n" - ) - (temp_home / ".pdd" / "api-env.zsh").write_text( - "export ANTHROPIC_API_KEY=ant-zsh\n" - ) - - # Clear shell env - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - - # Test with bash shell - monkeypatch.setenv("SHELL", "/bin/bash") - with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): - result = scan_environment() - - assert result["OPENAI_API_KEY"].is_set is True - assert result["OPENAI_API_KEY"].source == "~/.pdd/api-env.bash" - assert result["ANTHROPIC_API_KEY"].is_set is False + + assert result["OPENAI_API_KEY"].source == "shell environment" # --------------------------------------------------------------------------- -# Tests for _load_dotenv_values +# VI. scan_environment — Shell-Specific Behavior # --------------------------------------------------------------------------- -class TestLoadDotenvValues: - """Tests for _load_dotenv_values function.""" +def test_scan_bash_uses_bash_api_env(tmp_path, monkeypatch): + """SHELL=/bin/bash → reads api-env.bash, not api-env.zsh.""" + home = _setup_home( + tmp_path, monkeypatch, + csv_rows=SIMPLE_CSV_ROWS, + api_env_shell="bash", + api_env_content="export OPENAI_API_KEY=sk-bash\n", + ) + # Also create a zsh file with a different key + (home / ".pdd" / "api-env.zsh").write_text( + "export ANTHROPIC_API_KEY=ant-zsh\n" + ) + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() - def test_returns_empty_dict_when_dotenv_not_installed(self, monkeypatch): - """Should return empty dict if python-dotenv is not available.""" - # Mock the import to fail - import builtins - real_import = builtins.__import__ + assert result["OPENAI_API_KEY"].is_set is True + assert result["OPENAI_API_KEY"].source == "~/.pdd/api-env.bash" + # zsh file should NOT be consulted when shell is bash + assert result["ANTHROPIC_API_KEY"].is_set is False + + +def test_scan_zsh_uses_zsh_api_env(tmp_path, monkeypatch): + """SHELL=/bin/zsh → reads api-env.zsh.""" + _setup_home( + tmp_path, monkeypatch, + csv_rows=SIMPLE_CSV_ROWS, + api_env_shell="zsh", + api_env_content="export ANTHROPIC_API_KEY=ant-zsh\n", + ) + monkeypatch.setenv("SHELL", "/bin/zsh") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() - def mock_import(name, *args, **kwargs): - if name == "dotenv": - raise ImportError("No module named 'dotenv'") - return real_import(name, *args, **kwargs) + assert result["ANTHROPIC_API_KEY"].is_set is True + assert result["ANTHROPIC_API_KEY"].source == "~/.pdd/api-env.zsh" - monkeypatch.setattr(builtins, "__import__", mock_import) - result = _load_dotenv_values() - assert result == {} +# --------------------------------------------------------------------------- +# VII. scan_environment — Pipe-Delimited Keys +# --------------------------------------------------------------------------- + - def test_filters_out_none_values(self): - """Should filter out keys with None values from dotenv.""" - # Mock dotenv_values to return some None values - with mock.patch("dotenv.dotenv_values", return_value={ - "KEY1": "value1", - "KEY2": None, - "KEY3": "value3", - }): - result = _load_dotenv_values() +def test_scan_pipe_keys_scanned_individually(tmp_path, monkeypatch): + """Each segment of a pipe-delimited api_key is checked independently.""" + _setup_home(tmp_path, monkeypatch, csv_rows=BEDROCK_CSV_ROWS) + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "AKIA...") + monkeypatch.setenv("AWS_REGION_NAME", "us-east-1") + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() - assert result == {"KEY1": "value1", "KEY3": "value3"} + assert result["AWS_ACCESS_KEY_ID"].is_set is True + assert result["AWS_ACCESS_KEY_ID"].source == "shell environment" + assert result["AWS_REGION_NAME"].is_set is True + assert result["AWS_SECRET_ACCESS_KEY"].is_set is False # --------------------------------------------------------------------------- -# Edge case tests +# VIII. scan_environment — Edge Cases # --------------------------------------------------------------------------- -class TestEdgeCases: - """Test edge cases and error handling.""" - - def test_handles_unicode_in_csv(self, temp_home): - """Should handle unicode characters in CSV.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - fieldnames = ["provider", "model", "api_key"] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerow({ - "provider": "Tëst Provider", - "model": "模型", - "api_key": "UNICODE_KEY_名前", - }) - - result = get_provider_key_names() - assert "UNICODE_KEY_名前" in result - - def test_handles_very_long_api_key_names(self, temp_home): - """Should handle very long API key names.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - fieldnames = ["provider", "model", "api_key"] - long_key_name = "A" * 1000 + "_API_KEY" - - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerow({ - "provider": "Test", - "model": "test", - "api_key": long_key_name, - }) - - result = get_provider_key_names() - assert long_key_name in result - - def test_handles_api_key_with_special_shell_characters(self, temp_home, monkeypatch): - """Should handle API key names with characters problematic for shells.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - fieldnames = ["provider", "model", "api_key"] - - with open(csv_path, "w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerow({ - "provider": "Test", - "model": "test", - "api_key": "MY_SPECIAL_KEY", - }) - - # Set the env var - monkeypatch.setenv("MY_SPECIAL_KEY", "value_with_$pecial_chars") - monkeypatch.setenv("SHELL", "/bin/bash") - - with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): - result = scan_environment() - - assert result["MY_SPECIAL_KEY"].is_set is True - - def test_handles_permission_error_on_csv(self, temp_home, monkeypatch): - """Should handle permission errors gracefully.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - csv_path.write_text("provider,model,api_key\nTest,test,KEY\n") - - # Mock open to raise PermissionError - original_open = open - - def mock_open_with_permission_error(file, *args, **kwargs): - if str(file) == str(csv_path): - raise PermissionError("Access denied") - return original_open(file, *args, **kwargs) - - with mock.patch("builtins.open", side_effect=mock_open_with_permission_error): - result = get_provider_key_names() - - assert result == [] +def test_scan_special_chars_in_key_value(tmp_path, monkeypatch): + """Keys with special-character values don't crash the scanner.""" + home = _setup_home(tmp_path, monkeypatch) + _write_csv( + home / ".pdd" / "llm_model.csv", + [{"provider": "Test", "model": "t", "api_key": "MY_SPECIAL_KEY"}], + fieldnames=["provider", "model", "api_key"], + ) + monkeypatch.setenv("MY_SPECIAL_KEY", "value_with_$pecial_chars") + monkeypatch.setenv("SHELL", "/bin/bash") + + with mock.patch("pdd.api_key_scanner._load_dotenv_values", return_value={}): + result = scan_environment() + + assert result["MY_SPECIAL_KEY"].is_set is True diff --git a/tests/test_cli_detector.py b/tests/test_cli_detector.py index 9def0afc0..6158f0cd8 100644 --- a/tests/test_cli_detector.py +++ b/tests/test_cli_detector.py @@ -1,782 +1,774 @@ -"""Tests for pdd/cli_detector.py""" +"""Tests for pdd/cli_detector.py + +Behavioral tests driven through the two public entry points: + - detect_and_bootstrap_cli() + - detect_cli_tools() + +Test plan +--------- + +1. CliBootstrapResult data model + 1.1 Defaults to empty strings and False flags + 1.2 Skipped result has skipped=True, rest defaults + +2. detect_and_bootstrap_cli — Selection table & input parsing + 2.1 Table shows all three CLIs with install/key status + 2.2 Selecting "1" picks Claude CLI + 2.3 Comma-separated input "1,3" selects multiple CLIs + 2.4 Spaces in input "1, 3" are tolerated + 2.5 Duplicate input "1,1,3" is deduplicated + 2.6 Empty input defaults to best available (installed+key) + 2.7 Empty input defaults to installed-only when no keys set + 2.8 Invalid input falls back to default + 2.9 "q" quits with skipped result + 2.10 "n" quits with skipped result + +3. detect_and_bootstrap_cli — Install flow + 3.1 Already-installed CLI skips install prompt + 3.2 Not-installed CLI prompts for install, user accepts, npm succeeds + 3.3 Not-installed CLI, user accepts install but npm missing + 3.4 Not-installed CLI, install fails (non-zero exit) + 3.5 Not-installed CLI, user declines install → skipped + 3.6 Install succeeds but binary not found on PATH afterwards + +4. detect_and_bootstrap_cli — API key flow + 4.1 Key already set skips prompt + 4.2 Key not set, user provides key → saved to file and os.environ + 4.3 Key not set, user skips (Enter) → api_key_configured=False + 4.4 Anthropic skip shows subscription auth note + 4.5 Non-anthropic skip shows limited functionality note + 4.6 Google provider checks both GOOGLE_API_KEY and GEMINI_API_KEY + +5. detect_and_bootstrap_cli — CLI test step + 5.1 CLI test always runs after install+key steps + 5.2 --version success shows version output + 5.3 --version fails, falls back to --help + +6. detect_and_bootstrap_cli — Interrupt handling + 6.1 KeyboardInterrupt on selection prompt → skipped + 6.2 EOFError on selection prompt → skipped + 6.3 KeyboardInterrupt during per-CLI processing → stops remaining + +7. detect_and_bootstrap_cli — API key persistence + 7.1 Key saved to ~/.pdd/api-env.{shell} with correct export syntax + 7.2 Source line added to shell RC file + 7.3 Fish shell uses set -gx syntax and fish source line + 7.4 Duplicate keys are deduplicated in api-env file + +8. detect_cli_tools — Legacy detection + 8.1 Shows header with command context + 8.2 Found CLI shows checkmark and path + 8.3 Missing CLI shows X + 8.4 Key set but CLI missing → suggests install + 8.5 All CLIs installed with keys → success message + 8.6 No CLIs found → quick start message +""" + +from __future__ import annotations -import subprocess import os -import sys -from unittest import mock +import subprocess from pathlib import Path +from unittest import mock import pytest from pdd.cli_detector import ( - _CLI_COMMANDS, - _API_KEY_ENV_VARS, - _INSTALL_COMMANDS, - _CLI_DISPLAY_NAMES, - _which, - _has_api_key, - _npm_available, - _prompt_yes_no, - _run_install, - detect_cli_tools, - detect_and_bootstrap_cli, CliBootstrapResult, - _save_api_key, - _find_cli_binary, - _prompt_input, - console + detect_and_bootstrap_cli, + detect_cli_tools, ) # --------------------------------------------------------------------------- -# Tests for constants -# --------------------------------------------------------------------------- - - -class TestConstants: - """Tests for static mappings.""" - - def test_cli_commands_has_expected_providers(self): - """Should have CLI commands for known providers.""" - assert "anthropic" in _CLI_COMMANDS - assert "google" in _CLI_COMMANDS - assert "openai" in _CLI_COMMANDS - - def test_cli_commands_values(self): - """CLI command values should be correct.""" - assert _CLI_COMMANDS["anthropic"] == "claude" - assert _CLI_COMMANDS["google"] == "gemini" - assert _CLI_COMMANDS["openai"] == "codex" - - def test_api_key_env_vars_has_expected_providers(self): - """Should have API key env vars for known providers.""" - assert "anthropic" in _API_KEY_ENV_VARS - assert "google" in _API_KEY_ENV_VARS - assert "openai" in _API_KEY_ENV_VARS - - def test_api_key_env_vars_values(self): - """API key env var values should be correct.""" - assert _API_KEY_ENV_VARS["anthropic"] == "ANTHROPIC_API_KEY" - assert _API_KEY_ENV_VARS["google"] == "GOOGLE_API_KEY" - assert _API_KEY_ENV_VARS["openai"] == "OPENAI_API_KEY" - - def test_install_commands_has_expected_providers(self): - """Should have install commands for known providers.""" - assert "anthropic" in _INSTALL_COMMANDS - assert "google" in _INSTALL_COMMANDS - assert "openai" in _INSTALL_COMMANDS - - def test_install_commands_are_npm_commands(self): - """Install commands should be npm install commands.""" - for provider, cmd in _INSTALL_COMMANDS.items(): - assert cmd.startswith("npm install -g ") - - def test_cli_display_names_has_expected_providers(self): - """Should have display names for known providers.""" - assert "anthropic" in _CLI_DISPLAY_NAMES - assert "google" in _CLI_DISPLAY_NAMES - assert "openai" in _CLI_DISPLAY_NAMES - - def test_cli_display_names_are_human_readable(self): - """Display names should be human-readable.""" - assert _CLI_DISPLAY_NAMES["anthropic"] == "Claude CLI" - assert _CLI_DISPLAY_NAMES["google"] == "Gemini CLI" - assert _CLI_DISPLAY_NAMES["openai"] == "Codex CLI" - - def test_all_providers_have_consistent_mappings(self): - """All providers should have entries in all mappings.""" - providers = set(_CLI_COMMANDS.keys()) - - assert providers == set(_API_KEY_ENV_VARS.keys()) - assert providers == set(_INSTALL_COMMANDS.keys()) - assert providers == set(_CLI_DISPLAY_NAMES.keys()) - - -# --------------------------------------------------------------------------- -# Tests for _which -# --------------------------------------------------------------------------- - - -class TestWhich: - """Tests for _which function.""" - - def test_returns_path_for_existing_command(self): - """Should return path for commands that exist.""" - # 'ls' should exist on all Unix-like systems - result = _which("ls") - assert result is not None - assert "ls" in result - - def test_returns_none_for_nonexistent_command(self): - """Should return None for commands that don't exist.""" - result = _which("nonexistent_command_xyz_12345") - assert result is None - - def test_returns_none_for_empty_string(self): - """Should return None for empty command string.""" - result = _which("") - assert result is None - - -# --------------------------------------------------------------------------- -# Tests for _has_api_key -# --------------------------------------------------------------------------- - - -class TestHasApiKey: - """Tests for _has_api_key function.""" - - def test_returns_true_when_key_set(self, monkeypatch): - """Should return True when API key is set.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-value") - assert _has_api_key("anthropic") is True - - def test_returns_false_when_key_not_set(self, monkeypatch): - """Should return False when API key is not set.""" - monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - assert _has_api_key("anthropic") is False - - def test_returns_false_when_key_empty(self, monkeypatch): - """Should return False when API key is empty string.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", "") - assert _has_api_key("anthropic") is False - - def test_returns_false_when_key_whitespace(self, monkeypatch): - """Should return False when API key is only whitespace.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", " ") - assert _has_api_key("anthropic") is False - - def test_returns_false_for_unknown_provider(self, monkeypatch): - """Should return False for unknown providers.""" - # Unknown provider won't be in _API_KEY_ENV_VARS - assert _has_api_key("unknown_provider") is False - - -# --------------------------------------------------------------------------- -# Tests for _npm_available -# --------------------------------------------------------------------------- - - -class TestNpmAvailable: - """Tests for _npm_available function.""" - - def test_returns_bool(self): - """Should return a boolean.""" - result = _npm_available() - assert isinstance(result, bool) - - def test_uses_which_internally(self): - """Should use _which to find npm.""" - with mock.patch("pdd.cli_detector._which") as mock_which: - mock_which.return_value = "/usr/bin/npm" - result = _npm_available() - mock_which.assert_called_once_with("npm") - assert result is True - - def test_returns_false_when_npm_not_found(self): - """Should return False when npm is not installed.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - assert _npm_available() is False - - -# --------------------------------------------------------------------------- -# Tests for _prompt_yes_no -# --------------------------------------------------------------------------- - - -class TestPromptYesNo: - """Tests for _prompt_yes_no function.""" - - def test_returns_true_for_y(self): - """Should return True for 'y' input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="y"): - assert _prompt_yes_no("Test? ") is True - - def test_returns_true_for_yes(self): - """Should return True for 'yes' input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="yes"): - assert _prompt_yes_no("Test? ") is True - - def test_returns_true_for_Y_uppercase(self): - """Should return True for uppercase 'Y' input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="Y"): - assert _prompt_yes_no("Test? ") is True - - def test_returns_true_for_YES_uppercase(self): - """Should return True for uppercase 'YES' input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="YES"): - assert _prompt_yes_no("Test? ") is True - - def test_returns_false_for_n(self): - """Should return False for 'n' input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="n"): - assert _prompt_yes_no("Test? ") is False - - def test_returns_false_for_no(self): - """Should return False for 'no' input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="no"): - assert _prompt_yes_no("Test? ") is False - - def test_returns_false_for_empty(self): - """Should return False for empty input (default is No).""" - with mock.patch("pdd.cli_detector._prompt_input", return_value=""): - assert _prompt_yes_no("Test? ") is False - - def test_returns_false_for_random_input(self): - """Should return False for unrecognized input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value="maybe"): - assert _prompt_yes_no("Test? ") is False - - def test_handles_eof_error(self): - """Should return False on EOFError.""" - with mock.patch("pdd.cli_detector._prompt_input", side_effect=EOFError()): - assert _prompt_yes_no("Test? ") is False - - def test_handles_keyboard_interrupt(self): - """Should return False on KeyboardInterrupt.""" - with mock.patch("pdd.cli_detector._prompt_input", side_effect=KeyboardInterrupt()): - assert _prompt_yes_no("Test? ") is False - - def test_strips_whitespace(self): - """Should strip whitespace from input.""" - with mock.patch("pdd.cli_detector._prompt_input", return_value=" y "): - assert _prompt_yes_no("Test? ") is True - - -# --------------------------------------------------------------------------- -# Tests for _run_install -# --------------------------------------------------------------------------- - - -class TestRunInstall: - """Tests for _run_install function.""" - - def test_returns_true_on_success(self): - """Should return True when command succeeds.""" - with mock.patch("subprocess.run") as mock_run: - mock_run.return_value = mock.MagicMock(returncode=0) - result = _run_install("echo test") - assert result is True - - def test_returns_false_on_failure(self): - """Should return False when command fails.""" - with mock.patch("subprocess.run") as mock_run: - mock_run.return_value = mock.MagicMock(returncode=1) - result = _run_install("exit 1") - assert result is False - - def test_returns_false_on_exception(self): - """Should return False on subprocess exception.""" - with mock.patch("subprocess.run", side_effect=Exception("Test error")): - result = _run_install("failing command") - assert result is False - - def test_runs_command_with_shell(self): - """Should run command with shell=True.""" - with mock.patch("subprocess.run") as mock_run: - mock_run.return_value = mock.MagicMock(returncode=0) - _run_install("npm install -g test") - mock_run.assert_called_once() - call_kwargs = mock_run.call_args[1] - assert call_kwargs["shell"] is True - - -# --------------------------------------------------------------------------- -# Tests for detect_cli_tools +# Module-level constants — realistic scenarios for test fixtures # --------------------------------------------------------------------------- +# Provider/CLI status: all three CLIs installed with keys +ALL_INSTALLED = { + "claude": "/usr/local/bin/claude", + "codex": "/usr/local/bin/codex", + "gemini": "/usr/local/bin/gemini", +} -class TestDetectCliTools: - """Tests for detect_cli_tools function.""" - - def test_prints_header(self, capsys): - """Should print the detection header.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "Agentic CLI Tool Detection" in captured.out - assert "pdd fix, pdd change, pdd bug" in captured.out - - def test_shows_found_cli(self, capsys): - """Should show checkmark for found CLI tools.""" - with mock.patch("pdd.cli_detector._which") as mock_which: - mock_which.side_effect = lambda cmd: "/usr/bin/claude" if cmd == "claude" else None - with mock.patch("pdd.cli_detector._has_api_key", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "✓" in captured.out - assert "Claude CLI" in captured.out - - def test_shows_not_found_cli(self, capsys): - """Should show X for missing CLI tools.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "✗" in captured.out - assert "Not found" in captured.out - - def test_shows_api_key_status_when_cli_found(self, capsys): - """Should show API key status when CLI is found.""" - with mock.patch("pdd.cli_detector._which", return_value="/usr/bin/claude"): - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - mock_has_key.return_value = True - detect_cli_tools() - - captured = capsys.readouterr() - assert "set" in captured.out - - def test_warns_when_cli_found_but_no_key(self, capsys): - """Should warn when CLI found but API key not set.""" - with mock.patch("pdd.cli_detector._which", return_value="/usr/bin/claude"): - with mock.patch("pdd.cli_detector._has_api_key", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "not set" in captured.out - assert "won't be usable" in captured.out - - def test_suggests_install_when_key_but_no_cli(self, capsys): - """Should suggest installation when API key is set but CLI is missing.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - # Only anthropic has key set - mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.cli_detector._npm_available", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "install the cli" in captured.out.lower() - - def test_offers_installation_when_npm_available(self, capsys): - """Should offer installation when npm is available.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "Install now?" in captured.out or "Install:" in captured.out - - def test_shows_npm_not_installed_message(self, capsys): - """Should show message when npm is not installed.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.cli_detector._npm_available", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "npm is not installed" in captured.out - - def test_runs_installation_on_yes(self, capsys): - """Should run installation when user says yes.""" - with mock.patch("pdd.cli_detector._which") as mock_which: - mock_which.side_effect = [None, None, None, "/usr/bin/claude"] # Found after install - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=True): - with mock.patch("pdd.cli_detector._run_install", return_value=True): - detect_cli_tools() - - captured = capsys.readouterr() - assert "successfully" in captured.out or "completed" in captured.out - - def test_shows_failure_message_on_install_fail(self, capsys): - """Should show failure message when installation fails.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=True): - with mock.patch("pdd.cli_detector._run_install", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "failed" in captured.out.lower() or "manually" in captured.out.lower() - - def test_shows_skip_message_on_no(self, capsys): - """Should show skip message when user declines installation.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key") as mock_has_key: - mock_has_key.side_effect = lambda p: p == "anthropic" - with mock.patch("pdd.cli_detector._npm_available", return_value=True): - with mock.patch("pdd.cli_detector._prompt_yes_no", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "Skipped" in captured.out or "later" in captured.out - - def test_shows_quick_start_when_nothing_installed(self, capsys): - """Should show quick start guide when no CLIs are installed.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - assert "Quick start" in captured.out or "No CLI tools found" in captured.out - - def test_shows_all_installed_message(self, capsys): - """Should show success message when all CLIs with keys are installed.""" - with mock.patch("pdd.cli_detector._which", return_value="/usr/bin/cli"): - with mock.patch("pdd.cli_detector._has_api_key", return_value=True): - detect_cli_tools() - - captured = capsys.readouterr() - assert "All CLI tools with matching API keys are installed" in captured.out - - -# --------------------------------------------------------------------------- -# Integration tests -# --------------------------------------------------------------------------- - - -class TestIntegration: - """Integration tests for CLI detector.""" - - def test_detect_cli_tools_handles_import_error(self, capsys): - """Should handle missing agentic_common gracefully.""" - with mock.patch("pdd.cli_detector._which", return_value=None): - with mock.patch("pdd.cli_detector._has_api_key", return_value=False): - # The function imports get_available_agents but handles import errors - detect_cli_tools() - - # Should complete without error - captured = capsys.readouterr() - assert "Agentic CLI Tool Detection" in captured.out +ALL_KEYS = { + "ANTHROPIC_API_KEY": "sk-ant-test", + "OPENAI_API_KEY": "sk-oai-test", + "GEMINI_API_KEY": "gm-test", + "GOOGLE_API_KEY": "gm-test", +} - def test_detect_cli_tools_complete_flow(self, capsys): - """Test complete detection flow with mixed results.""" - def mock_which(cmd): - return "/usr/bin/claude" if cmd == "claude" else None +# Only Claude installed with key +CLAUDE_ONLY = {"claude": "/usr/local/bin/claude"} +CLAUDE_KEY = {"ANTHROPIC_API_KEY": "sk-ant-test"} - def mock_has_key(provider): - return provider in ["anthropic", "openai"] - - with mock.patch("pdd.cli_detector._which", side_effect=mock_which): - with mock.patch("pdd.cli_detector._has_api_key", side_effect=mock_has_key): - with mock.patch("pdd.cli_detector._npm_available", return_value=False): - detect_cli_tools() - - captured = capsys.readouterr() - # Claude should show as found - assert "Claude CLI" in captured.out - assert "✓" in captured.out - # Others should show as not found - assert "✗" in captured.out +# No CLIs installed, no keys +NOTHING = {} # --------------------------------------------------------------------------- -# Edge case tests +# Helper: capture output from detect_and_bootstrap_cli # --------------------------------------------------------------------------- +def _run_bootstrap_capture( + monkeypatch, + tmp_path: Path, + user_inputs: list[str], + *, + cli_paths: dict[str, str] | None = None, + env_keys: dict[str, str] | None = None, + npm_available: bool = False, + install_succeeds: bool = False, + install_then_found: str | None = None, + version_output: str = "1.0.0", + version_returncode: int = 0, +) -> tuple[str, list[CliBootstrapResult]]: + """Run detect_and_bootstrap_cli with mocked boundaries. + + Args: + monkeypatch: pytest monkeypatch fixture + tmp_path: temporary directory for home + user_inputs: sequence of strings for _prompt_input + cli_paths: mapping of cli_name -> path (None = not found) + env_keys: environment variables to set + npm_available: whether npm is on PATH + install_succeeds: whether subprocess install returns 0 + install_then_found: path to return after install succeeds (None = not found) + version_output: stdout from --version + version_returncode: exit code from --version + + Returns: + (captured_output, results) tuple + """ + cli_paths = cli_paths or {} + env_keys = env_keys or {} + + # Clean environment + for var in ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY", + "GEMINI_API_KEY", "SHELL"): + monkeypatch.delenv(var, raising=False) + for k, v in env_keys.items(): + monkeypatch.setenv(k, v) + monkeypatch.setenv("SHELL", "/bin/bash") + + # Mock Path.home to tmp_path + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + # Create shell RC file + rc_file = tmp_path / ".bashrc" + if not rc_file.exists(): + rc_file.write_text("# existing\n") + + # Mock user input + input_iter = iter(user_inputs) + monkeypatch.setattr( + "pdd.cli_detector._prompt_input", + lambda _prompt="": next(input_iter), + ) -class TestEdgeCases: - """Test edge cases and error handling.""" - - def test_handles_subprocess_timeout(self): - """Should handle subprocess timeout gracefully.""" - with mock.patch("subprocess.run", side_effect=subprocess.TimeoutExpired("cmd", 30)): - result = _run_install("slow command") - assert result is False - - def test_empty_env_var_treated_as_not_set(self, monkeypatch): - """Empty string env vars should be treated as not set.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", "") - assert _has_api_key("anthropic") is False - - def test_whitespace_only_env_var_treated_as_not_set(self, monkeypatch): - """Whitespace-only env vars should be treated as not set.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", " \t\n ") - assert _has_api_key("anthropic") is False - + # Track _find_cli_binary calls to simulate post-install discovery + find_call_count = {} + def mock_find_cli_binary(name): + find_call_count[name] = find_call_count.get(name, 0) + 1 + if name in cli_paths: + return cli_paths[name] + # After install, return install_then_found for the CLI being installed + if install_then_found and find_call_count[name] > 1: + return install_then_found + return None -# --------------------------------------------------------------------------- -# Fixtures for bootstrap tests -# --------------------------------------------------------------------------- + # Mock subprocess.run for both install and --version/--help + def mock_subprocess_run(cmd, **kwargs): + result = mock.MagicMock() + if kwargs.get("shell"): + # Install command + result.returncode = 0 if install_succeeds else 1 + result.stdout = "" + result.stderr = "" + else: + # CLI test (--version or --help) + result.returncode = version_returncode + result.stdout = version_output + result.stderr = "" + return result + + # Mock npm availability + def mock_shutil_which(cmd): + if cmd == "npm": + return "/usr/bin/npm" if npm_available else None + return cli_paths.get(cmd) + + # Capture console output + printed = [] + + def capture_print(*args, **kwargs): + printed.append(" ".join(str(a) for a in args)) + + # Apply mocks + with mock.patch("pdd.cli_detector._find_cli_binary", side_effect=mock_find_cli_binary), \ + mock.patch("pdd.cli_detector.console") as mock_console, \ + mock.patch("subprocess.run", side_effect=mock_subprocess_run), \ + mock.patch("shutil.which", side_effect=mock_shutil_which), \ + mock.patch("pdd.setup_tool._print_step_banner"): + + mock_console.print.side_effect = capture_print + results = detect_and_bootstrap_cli() + + output = "\n".join(printed) + return output, results -@pytest.fixture -def mock_console(): - with mock.patch("pdd.cli_detector.console") as m: - yield m - -@pytest.fixture -def mock_env(): - with mock.patch.dict(os.environ, {}, clear=True): - yield os.environ - -@pytest.fixture -def mock_which(): - with mock.patch("shutil.which") as m: - yield m - -@pytest.fixture -def mock_input(): - with mock.patch("pdd.cli_detector._prompt_input") as m: - yield m - -@pytest.fixture -def mock_subprocess(): - with mock.patch("subprocess.run") as m: - yield m - -@pytest.fixture -def mock_home(tmp_path): - """Mock Path.home() to return a temporary directory.""" - with mock.patch("pathlib.Path.home", return_value=tmp_path): - yield tmp_path # --------------------------------------------------------------------------- -# Helper Function Tests +# Helper: capture output from detect_cli_tools # --------------------------------------------------------------------------- -def test_save_api_key_bash(mock_home, mock_console): - """Test saving API key for bash shell.""" - shell = "bash" - key_name = "TEST_KEY" - key_value = "sk-test-123" - - # Create a dummy .bashrc - rc_file = mock_home / ".bashrc" - rc_file.write_text("# existing content\n") - - success = _save_api_key(key_name, key_value, shell) - - assert success is True - - # Check api-env file - api_env = mock_home / ".pdd" / "api-env.bash" - assert api_env.exists() - content = api_env.read_text() - assert f"export {key_name}={key_value}" in content - - # Check RC file update - rc_content = rc_file.read_text() - assert f"source {api_env}" in rc_content - assert os.environ[key_name] == key_value - -def test_save_api_key_fish(mock_home, mock_console): - """Test saving API key for fish shell.""" - shell = "fish" - key_name = "TEST_KEY" - key_value = "sk-test-123" - - # Create dummy config.fish - fish_config = mock_home / ".config" / "fish" / "config.fish" - fish_config.parent.mkdir(parents=True) - fish_config.write_text("") - - success = _save_api_key(key_name, key_value, shell) - - assert success is True - - api_env = mock_home / ".pdd" / "api-env.fish" - content = api_env.read_text() - assert f"set -gx {key_name} {key_value}" in content - - rc_content = fish_config.read_text() - assert f"test -f {api_env} ; and source {api_env}" in rc_content - -def test_find_cli_binary_fallback(mock_which): - """Test finding binary in fallback paths when shutil.which fails.""" - mock_which.return_value = None - - # Mock os.path.exists and is_file/access - with mock.patch("pathlib.Path.exists", return_value=True), \ - mock.patch("pathlib.Path.is_file", return_value=True), \ - mock.patch("os.access", return_value=True): - - # Should find it in the first fallback path checked - result = _find_cli_binary("claude") - assert result is not None - assert "claude" in result - -# --------------------------------------------------------------------------- -# detect_and_bootstrap_cli Tests -# --------------------------------------------------------------------------- +def _run_legacy_capture( + monkeypatch, + cli_paths: dict[str, str] | None = None, + env_keys: dict[str, str] | None = None, +) -> str: + """Run detect_cli_tools with mocked boundaries, return captured output.""" + cli_paths = cli_paths or {} + env_keys = env_keys or {} -def test_bootstrap_happy_path(mock_env, mock_input, mock_console): - """ - Scenario: Claude is installed and ANTHROPIC_API_KEY is set. - All 3 CLIs are shown in table. User selects Claude (1). - Expect: Return with success, no install or API key prompt needed. - """ - mock_env["ANTHROPIC_API_KEY"] = "sk-existing" - mock_input.return_value = "1" # User selects Claude + for var in ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY", + "GEMINI_API_KEY"): + monkeypatch.delenv(var, raising=False) + for k, v in env_keys.items(): + monkeypatch.setenv(k, v) - with mock.patch("pdd.cli_detector._find_cli_binary") as mock_find: - mock_find.side_effect = lambda x: "/usr/bin/claude" if x == "claude" else None - result = detect_and_bootstrap_cli() + def mock_which(cmd): + return cli_paths.get(cmd) - assert result.cli_name == "claude" - assert result.provider == "anthropic" - assert result.api_key_configured is True - assert result.cli_path == "/usr/bin/claude" + printed = [] - # Table should be shown with all 3 CLIs before the user picks - all_printed = " ".join(str(c) for c in mock_console.print.call_args_list) - assert "Claude CLI" in all_printed - assert "Codex CLI" in all_printed - assert "Gemini CLI" in all_printed + def capture_print(*args, **kwargs): + printed.append(" ".join(str(a) for a in args)) -def test_bootstrap_key_needed_user_provides(mock_which, mock_env, mock_input, mock_home, mock_console): - """ - Scenario: Claude installed, no key. User enters key. - Expect: Key saved, success returned. - """ - mock_which.side_effect = lambda x: f"/usr/bin/{x}" if x == "claude" else None - # No API key in env - - mock_input.return_value = "sk-new-key" - - result = detect_and_bootstrap_cli() - - assert result.cli_name == "claude" - assert result.provider == "anthropic" - assert result.api_key_configured is True - - # Verify key was saved to env - assert os.environ["ANTHROPIC_API_KEY"] == "sk-new-key" - # Verify file write happened (via _save_api_key logic) - api_env = mock_home / ".pdd" / "api-env.bash" # Default shell is bash - assert api_env.exists() - -def test_bootstrap_key_needed_user_skips(mock_which, mock_env, mock_input, mock_console): - """ - Scenario: Claude installed, no key. User presses Enter (skips). - Expect: Success returned but api_key_configured=False. - """ - mock_which.side_effect = lambda x: f"/usr/bin/{x}" if x == "claude" else None - mock_input.return_value = "" # Empty input - - result = detect_and_bootstrap_cli() - - assert result.cli_name == "claude" - assert result.api_key_configured is False - mock_console.print.assert_any_call(" [dim]Note: Claude CLI may still work with subscription auth.[/dim]") - -def test_bootstrap_no_cli_user_declines(mock_which, mock_input, mock_console): - """ - Scenario: No CLIs found. User says 'n' to install. - Expect: Empty result. - """ - mock_which.return_value = None # No CLIs found - mock_input.return_value = "n" - - result = detect_and_bootstrap_cli() - - assert result.cli_name == "" - assert result.provider == "" - mock_console.print.assert_any_call(" [dim]Skipped CLI setup. You can run `pdd setup` again later.[/dim]") - -def test_bootstrap_install_npm_missing(mock_input, mock_console): - """ - Scenario: No CLIs found. All 3 shown in table. User selects Claude (1), - says yes to install. npm not found. - Expect: Error message, empty result. - """ - mock_input.side_effect = ["1", "y"] # Select Claude, then yes to install - - with mock.patch("pdd.cli_detector._find_cli_binary", return_value=None), \ + with mock.patch("pdd.cli_detector._which", side_effect=mock_which), \ + mock.patch("pdd.cli_detector.console") as mock_console, \ mock.patch("pdd.cli_detector._npm_available", return_value=False): - result = detect_and_bootstrap_cli() - - assert result.cli_name == "" - all_printed = " ".join(str(c) for c in mock_console.print.call_args_list) - assert "npm is not installed" in all_printed - -def test_bootstrap_install_success(mock_input, mock_subprocess, mock_home, mock_console): - """ - Scenario: No CLIs found. All 3 shown in table. User selects Claude (1), - says yes to install. Install succeeds. User provides API key. - Expect: Full success. - """ - # Inputs: select Claude, yes to install, provide API key - mock_input.side_effect = ["1", "y", "sk-key"] - - # _find_cli_binary returns None on initial scan; returns the path after install - claude_calls = [0] - def find_binary(name): - if name == "claude": - claude_calls[0] += 1 - return "/usr/bin/claude" if claude_calls[0] > 1 else None - return None - - mock_subprocess.return_value.returncode = 0 - - with mock.patch("pdd.cli_detector._find_cli_binary", side_effect=find_binary), \ - mock.patch("pdd.cli_detector._npm_available", return_value=True): - result = detect_and_bootstrap_cli() - - assert result.cli_name == "claude" - assert result.cli_path == "/usr/bin/claude" - assert result.api_key_configured is True - - # Verify the correct install command was run - mock_subprocess.assert_called_with( - "npm install -g @anthropic-ai/claude-code", - shell=True, capture_output=True, text=True, timeout=120 - ) - -def test_bootstrap_install_failure(mock_input, mock_subprocess, mock_console): - """ - Scenario: No CLIs found. All 3 shown in table. User selects Claude (1), - says yes to install. Install fails. - Expect: Empty result with failure message. - """ - mock_input.side_effect = ["1", "y"] # Select Claude, then yes to install - - mock_subprocess.return_value.returncode = 1 - mock_subprocess.return_value.stderr = "Permission denied" - - with mock.patch("pdd.cli_detector._find_cli_binary", return_value=None), \ - mock.patch("pdd.cli_detector._npm_available", return_value=True): - result = detect_and_bootstrap_cli() - - assert result.cli_name == "" - all_printed = " ".join(str(c) for c in mock_console.print.call_args_list) - assert "Installation failed" in all_printed - -# --------------------------------------------------------------------------- -# detect_cli_tools Tests (Bootstrap Perspective) -# --------------------------------------------------------------------------- - -def test_detect_cli_tools_reporting(mock_which, mock_env, mock_console): - """Test legacy detection reporting.""" - # Claude found, others missing - mock_which.side_effect = lambda x: "/bin/claude" if x == "claude" else None - mock_env["ANTHROPIC_API_KEY"] = "sk-test" - - detect_cli_tools() - - # The code now uses display_name, so we adjust expectations - mock_console.print.assert_any_call(" [green]✓[/green] Claude CLI — Found at /bin/claude") - mock_console.print.assert_any_call(" [green]✓[/green] ANTHROPIC_API_KEY is set") - mock_console.print.assert_any_call(" [red]✗[/red] Codex CLI — Not found") - -def test_detect_cli_tools_offer_install(mock_which, mock_env, mock_input, mock_subprocess, mock_console): - """Test legacy install offer when key exists but CLI missing.""" - # Codex missing, but key present - mock_which.side_effect = lambda x: "/bin/npm" if x == "npm" else None - mock_env["OPENAI_API_KEY"] = "sk-openai" - - # User says yes to install - mock_input.return_value = "y" - mock_subprocess.return_value.returncode = 0 - - detect_cli_tools() - - # The code now uses display_name, so we adjust expectations - mock_console.print.assert_any_call(" [yellow]You have OPENAI_API_KEY set but Codex CLI is not installed.[/yellow]") - mock_subprocess.assert_called_with( - "npm install -g @openai/codex", - shell=True, capture_output=True, text=True, timeout=120 - ) + mock_console.print.side_effect = capture_print + detect_cli_tools() + + return "\n".join(printed) + + +# =================================================================== +# 1. CliBootstrapResult data model +# =================================================================== + + +class TestCliBootstrapResult: + """Pure contract tests for the result dataclass.""" + + def test_defaults_to_empty(self): + r = CliBootstrapResult() + assert r.cli_name == "" + assert r.provider == "" + assert r.cli_path == "" + assert r.api_key_configured is False + assert r.skipped is False + + def test_skipped_result(self): + r = CliBootstrapResult(skipped=True) + assert r.skipped is True + assert r.cli_name == "" + + def test_populated_result(self): + r = CliBootstrapResult( + cli_name="claude", provider="anthropic", + cli_path="/usr/local/bin/claude", api_key_configured=True, + ) + assert r.cli_name == "claude" + assert r.provider == "anthropic" + assert r.cli_path == "/usr/local/bin/claude" + assert r.api_key_configured is True + assert r.skipped is False + + +# =================================================================== +# 2. detect_and_bootstrap_cli — Selection table & input parsing +# =================================================================== + + +class TestBootstrapSelectionTable: + """Tests for the numbered table display and user input parsing.""" + + def test_table_shows_all_three_clis(self, monkeypatch, tmp_path): + output, _ = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + assert "Claude CLI" in output + assert "Codex CLI" in output + assert "Gemini CLI" in output + + def test_table_shows_install_and_key_status(self, monkeypatch, tmp_path): + output, _ = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + # Claude is installed with key + assert "Found at" in output + assert "ANTHROPIC_API_KEY" in output + # Others are not installed + assert "Not found" in output + + def test_select_single_cli(self, monkeypatch, tmp_path): + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + assert len(results) == 1 + assert results[0].cli_name == "claude" + assert results[0].provider == "anthropic" + assert results[0].api_key_configured is True + + def test_multi_select_comma_separated(self, monkeypatch, tmp_path): + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1,3"], + cli_paths=ALL_INSTALLED, env_keys=ALL_KEYS, + ) + assert len(results) == 2 + assert results[0].cli_name == "claude" + assert results[1].cli_name == "gemini" + + def test_multi_select_with_spaces(self, monkeypatch, tmp_path): + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1, 3"], + cli_paths=ALL_INSTALLED, env_keys=ALL_KEYS, + ) + assert len(results) == 2 + assert results[0].cli_name == "claude" + assert results[1].cli_name == "gemini" + + def test_duplicate_input_deduplicated(self, monkeypatch, tmp_path): + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1,1,3"], + cli_paths=ALL_INSTALLED, env_keys=ALL_KEYS, + ) + assert len(results) == 2 + assert results[0].cli_name == "claude" + assert results[1].cli_name == "gemini" + + def test_empty_input_defaults_to_installed_with_key(self, monkeypatch, tmp_path): + """Empty input → default to first CLI that is installed AND has a key.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, [""], + cli_paths={"gemini": "/usr/bin/gemini"}, + env_keys={"GEMINI_API_KEY": "gm-test"}, + ) + assert len(results) == 1 + assert results[0].cli_name == "gemini" + assert "Defaulting" in output + + def test_empty_input_defaults_to_installed_when_no_keys(self, monkeypatch, tmp_path): + """No keys set → default to first installed CLI.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["", ""], # selection + key prompt skip + cli_paths={"codex": "/usr/bin/codex"}, + ) + assert len(results) == 1 + assert results[0].cli_name == "codex" + + def test_invalid_input_falls_back_to_default(self, monkeypatch, tmp_path): + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["xyz"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + assert len(results) == 1 + assert "Invalid input" in output or "Defaulting" in output + + @pytest.mark.parametrize("quit_input", ["q", "n"]) + def test_quit_returns_skipped(self, monkeypatch, tmp_path, quit_input): + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, [quit_input], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + assert len(results) == 1 + assert results[0].skipped is True + + +# =================================================================== +# 3. detect_and_bootstrap_cli — Install flow +# =================================================================== + + +class TestBootstrapInstallFlow: + """Tests for CLI installation behavior.""" + + def test_installed_cli_skips_install_prompt(self, monkeypatch, tmp_path): + """If CLI is already found, no install prompt is shown.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + assert results[0].cli_name == "claude" + assert "Install now?" not in output + + def test_not_installed_user_accepts_npm_succeeds(self, monkeypatch, tmp_path): + """User accepts install, npm present, install succeeds.""" + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "y", ""], # select, accept install, skip key + npm_available=True, + install_succeeds=True, + install_then_found="/usr/local/bin/claude", + ) + assert len(results) == 1 + assert results[0].cli_name == "claude" + assert results[0].cli_path == "/usr/local/bin/claude" + assert results[0].skipped is False + + def test_not_installed_npm_missing(self, monkeypatch, tmp_path): + """User accepts install but npm is not available.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "y"], # select, accept install + npm_available=False, + ) + assert results[0].skipped is True + assert "npm" in output.lower() + + def test_not_installed_install_fails(self, monkeypatch, tmp_path): + """Install command exits non-zero.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "y"], # select, accept install + npm_available=True, + install_succeeds=False, + ) + assert results[0].skipped is True + assert "failed" in output.lower() or "manually" in output.lower() + + def test_not_installed_user_declines(self, monkeypatch, tmp_path): + """User declines install.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "n"], # select, decline install + ) + assert results[0].skipped is True + assert "not configured" in output.lower() + + def test_install_succeeds_but_binary_not_found(self, monkeypatch, tmp_path): + """Install exits 0 but binary still not on PATH.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "y"], # select, accept install + npm_available=True, + install_succeeds=True, + install_then_found=None, # not found after install + ) + assert results[0].skipped is True + assert "not found on PATH" in output or "not configured" in output.lower() + + +# =================================================================== +# 4. detect_and_bootstrap_cli — API key flow +# =================================================================== + + +class TestBootstrapApiKeyFlow: + """Tests for API key configuration behavior.""" + + def test_key_already_set_skips_prompt(self, monkeypatch, tmp_path): + """If key is already in env, no prompt for it.""" + output, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + ) + assert results[0].api_key_configured is True + assert "Enter your" not in output + + def test_key_not_set_user_provides(self, monkeypatch, tmp_path): + """User provides key when prompted.""" + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "sk-new-key"], # select, provide key + cli_paths=CLAUDE_ONLY, + ) + assert results[0].api_key_configured is True + assert os.environ.get("ANTHROPIC_API_KEY") == "sk-new-key" + + def test_key_saved_to_file(self, monkeypatch, tmp_path): + """Provided key is written to ~/.pdd/api-env.bash.""" + _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "sk-saved-key"], + cli_paths=CLAUDE_ONLY, + ) + api_env = tmp_path / ".pdd" / "api-env.bash" + assert api_env.exists() + content = api_env.read_text() + assert "export ANTHROPIC_API_KEY=sk-saved-key" in content + + def test_source_line_added_to_rc(self, monkeypatch, tmp_path): + """Source line is added to ~/.bashrc.""" + _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", "sk-test"], + cli_paths=CLAUDE_ONLY, + ) + rc_content = (tmp_path / ".bashrc").read_text() + api_env_path = str(tmp_path / ".pdd" / "api-env.bash") + assert f"source {api_env_path}" in rc_content + + def test_key_not_set_user_skips(self, monkeypatch, tmp_path): + """User presses Enter to skip key.""" + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", ""], # select, skip key + cli_paths=CLAUDE_ONLY, + ) + assert results[0].api_key_configured is False + + def test_anthropic_skip_shows_subscription_note(self, monkeypatch, tmp_path): + """Skipping Anthropic key mentions subscription auth.""" + output, _ = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["1", ""], # select, skip key + cli_paths=CLAUDE_ONLY, + ) + assert "subscription" in output.lower() + + def test_non_anthropic_skip_shows_limited_note(self, monkeypatch, tmp_path): + """Skipping non-Anthropic key mentions limited functionality.""" + output, _ = _run_bootstrap_capture( + monkeypatch, tmp_path, + ["2", ""], # select codex, skip key + cli_paths={"codex": "/usr/bin/codex"}, + ) + assert "limited functionality" in output.lower() + + def test_google_checks_gemini_key(self, monkeypatch, tmp_path): + """Google provider recognizes GEMINI_API_KEY.""" + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["3"], + cli_paths={"gemini": "/usr/bin/gemini"}, + env_keys={"GEMINI_API_KEY": "gm-test"}, + ) + assert results[0].api_key_configured is True + + def test_google_checks_google_api_key(self, monkeypatch, tmp_path): + """Google provider recognizes GOOGLE_API_KEY.""" + _, results = _run_bootstrap_capture( + monkeypatch, tmp_path, ["3"], + cli_paths={"gemini": "/usr/bin/gemini"}, + env_keys={"GOOGLE_API_KEY": "gm-test"}, + ) + assert results[0].api_key_configured is True + + +# =================================================================== +# 5. detect_and_bootstrap_cli — CLI test step +# =================================================================== + + +class TestBootstrapCliTest: + """Tests for the forced CLI verification step.""" + + def test_cli_test_runs_after_setup(self, monkeypatch, tmp_path): + """CLI test always runs, output includes version info.""" + output, _ = _run_bootstrap_capture( + monkeypatch, tmp_path, ["1"], + cli_paths=CLAUDE_ONLY, env_keys=CLAUDE_KEY, + version_output="2.5.0", + ) + assert "Testing" in output + assert "2.5.0" in output or "version" in output.lower() + + +# =================================================================== +# 6. detect_and_bootstrap_cli — Interrupt handling +# =================================================================== + + +class TestBootstrapInterrupts: + """Tests for graceful interrupt handling.""" + + def test_keyboard_interrupt_on_selection(self, monkeypatch, tmp_path): + """KeyboardInterrupt at selection prompt → skipped result.""" + for var in ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY", + "GEMINI_API_KEY"): + monkeypatch.delenv(var, raising=False) + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + monkeypatch.setattr( + "pdd.cli_detector._prompt_input", + mock.MagicMock(side_effect=KeyboardInterrupt), + ) + + with mock.patch("pdd.cli_detector._find_cli_binary", return_value=None), \ + mock.patch("pdd.cli_detector.console"), \ + mock.patch("pdd.setup_tool._print_step_banner"): + results = detect_and_bootstrap_cli() + + assert len(results) == 1 + assert results[0].skipped is True + + def test_eof_on_selection(self, monkeypatch, tmp_path): + """EOFError at selection prompt → skipped result.""" + for var in ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY", + "GEMINI_API_KEY"): + monkeypatch.delenv(var, raising=False) + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + monkeypatch.setattr( + "pdd.cli_detector._prompt_input", + mock.MagicMock(side_effect=EOFError), + ) + + with mock.patch("pdd.cli_detector._find_cli_binary", return_value=None), \ + mock.patch("pdd.cli_detector.console"), \ + mock.patch("pdd.setup_tool._print_step_banner"): + results = detect_and_bootstrap_cli() + + assert len(results) == 1 + assert results[0].skipped is True + + +# =================================================================== +# 7. detect_and_bootstrap_cli — API key persistence (shell variants) +# =================================================================== + + +class TestApiKeyPersistence: + """Tests for key file format across shell types.""" + + def test_fish_shell_syntax(self, monkeypatch, tmp_path): + """Fish shell uses set -gx and fish source syntax.""" + monkeypatch.setenv("SHELL", "/usr/bin/fish") + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + + # Create fish config + fish_config = tmp_path / ".config" / "fish" / "config.fish" + fish_config.parent.mkdir(parents=True) + fish_config.write_text("") + + for var in ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY", + "GEMINI_API_KEY"): + monkeypatch.delenv(var, raising=False) + + input_iter = iter(["1", "sk-fish-key"]) + monkeypatch.setattr( + "pdd.cli_detector._prompt_input", + lambda _prompt="": next(input_iter), + ) + + with mock.patch("pdd.cli_detector._find_cli_binary") as mock_find, \ + mock.patch("pdd.cli_detector.console"), \ + mock.patch("subprocess.run") as mock_run, \ + mock.patch("shutil.which", return_value=None), \ + mock.patch("pdd.setup_tool._print_step_banner"): + mock_find.side_effect = lambda n: "/usr/bin/claude" if n == "claude" else None + mock_run.return_value = mock.MagicMock(returncode=0, stdout="1.0", stderr="") + detect_and_bootstrap_cli() + + api_env = tmp_path / ".pdd" / "api-env.fish" + assert api_env.exists() + content = api_env.read_text() + assert "set -gx ANTHROPIC_API_KEY sk-fish-key" in content + + rc_content = fish_config.read_text() + assert "test -f" in rc_content + assert "and source" in rc_content + + def test_duplicate_key_deduplicated(self, monkeypatch, tmp_path): + """Saving the same key twice doesn't create duplicate lines.""" + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.setattr(Path, "home", staticmethod(lambda: tmp_path)) + (tmp_path / ".bashrc").write_text("") + + for var in ("ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GOOGLE_API_KEY", + "GEMINI_API_KEY"): + monkeypatch.delenv(var, raising=False) + + # First save + input_iter = iter(["1", "sk-first"]) + monkeypatch.setattr( + "pdd.cli_detector._prompt_input", + lambda _prompt="": next(input_iter), + ) + with mock.patch("pdd.cli_detector._find_cli_binary") as mock_find, \ + mock.patch("pdd.cli_detector.console"), \ + mock.patch("subprocess.run") as mock_run, \ + mock.patch("shutil.which", return_value=None), \ + mock.patch("pdd.setup_tool._print_step_banner"): + mock_find.side_effect = lambda n: "/usr/bin/claude" if n == "claude" else None + mock_run.return_value = mock.MagicMock(returncode=0, stdout="1.0", stderr="") + detect_and_bootstrap_cli() + + # Second save (overwrite key) + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + input_iter2 = iter(["1", "sk-second"]) + monkeypatch.setattr( + "pdd.cli_detector._prompt_input", + lambda _prompt="": next(input_iter2), + ) + with mock.patch("pdd.cli_detector._find_cli_binary") as mock_find, \ + mock.patch("pdd.cli_detector.console"), \ + mock.patch("subprocess.run") as mock_run, \ + mock.patch("shutil.which", return_value=None), \ + mock.patch("pdd.setup_tool._print_step_banner"): + mock_find.side_effect = lambda n: "/usr/bin/claude" if n == "claude" else None + mock_run.return_value = mock.MagicMock(returncode=0, stdout="1.0", stderr="") + detect_and_bootstrap_cli() + + api_env = tmp_path / ".pdd" / "api-env.bash" + content = api_env.read_text() + # Should have only one export line for ANTHROPIC_API_KEY + export_lines = [l for l in content.splitlines() + if "ANTHROPIC_API_KEY" in l] + assert len(export_lines) == 1 + assert "sk-second" in export_lines[0] + + +# =================================================================== +# 8. detect_cli_tools — Legacy detection +# =================================================================== + + +class TestDetectCliToolsLegacy: + """Tests for the legacy detect_cli_tools function.""" + + def test_shows_header(self, monkeypatch): + output = _run_legacy_capture(monkeypatch) + assert "Agentic CLI Tool Detection" in output + assert "pdd fix" in output + + def test_found_cli_shows_checkmark_and_path(self, monkeypatch): + output = _run_legacy_capture( + monkeypatch, + cli_paths={"claude": "/usr/local/bin/claude"}, + env_keys=CLAUDE_KEY, + ) + assert "Claude CLI" in output + assert "Found at" in output or "/usr/local/bin/claude" in output + + def test_missing_cli_shows_not_found(self, monkeypatch): + output = _run_legacy_capture(monkeypatch) + assert "Not found" in output + + def test_key_set_but_cli_missing_suggests_install(self, monkeypatch): + output = _run_legacy_capture( + monkeypatch, + env_keys={"OPENAI_API_KEY": "sk-test"}, + ) + assert "OPENAI_API_KEY" in output + assert "not installed" in output.lower() or "install" in output.lower() + + def test_all_installed_with_keys_shows_success(self, monkeypatch): + output = _run_legacy_capture( + monkeypatch, + cli_paths=ALL_INSTALLED, + env_keys=ALL_KEYS, + ) + assert "All CLI tools" in output + + def test_no_clis_found_shows_quick_start(self, monkeypatch): + output = _run_legacy_capture(monkeypatch) + assert "No CLI tools found" in output or "Quick start" in output diff --git a/tests/test_litellm_registry.py b/tests/test_litellm_registry.py deleted file mode 100644 index 41079eeb3..000000000 --- a/tests/test_litellm_registry.py +++ /dev/null @@ -1,561 +0,0 @@ -"""Tests for pdd/litellm_registry.py""" - -from unittest import mock - -import pytest - -from pdd.litellm_registry import ( - ProviderInfo, - ModelInfo, - PROVIDER_API_KEY_MAP, - PROVIDER_DISPLAY_NAMES, - is_litellm_available, - get_api_key_env_var, - get_top_providers, - get_all_providers, - search_providers, - get_models_for_provider, - _get_display_name, - _entry_to_model_info, -) - - -# --------------------------------------------------------------------------- -# Tests for constants -# --------------------------------------------------------------------------- - - -class TestConstants: - """Tests for static mappings.""" - - def test_provider_api_key_map_has_common_providers(self): - """Should have API key mappings for common providers.""" - assert "openai" in PROVIDER_API_KEY_MAP - assert "anthropic" in PROVIDER_API_KEY_MAP - assert "gemini" in PROVIDER_API_KEY_MAP - assert "groq" in PROVIDER_API_KEY_MAP - assert "mistral" in PROVIDER_API_KEY_MAP - - def test_provider_api_key_map_values_are_strings(self): - """All API key env var names should be strings.""" - for key_name in PROVIDER_API_KEY_MAP.values(): - assert isinstance(key_name, str) - assert len(key_name) > 0 - # Should look like an env var (uppercase with underscores) - assert key_name.isupper() or "_" in key_name - - def test_provider_display_names_has_common_providers(self): - """Should have display names for common providers.""" - assert "openai" in PROVIDER_DISPLAY_NAMES - assert PROVIDER_DISPLAY_NAMES["openai"] == "OpenAI" - assert PROVIDER_DISPLAY_NAMES["anthropic"] == "Anthropic" - assert PROVIDER_DISPLAY_NAMES["gemini"] == "Google Gemini" - - def test_provider_display_names_are_human_readable(self): - """Display names should be human-readable (not snake_case).""" - for provider, display_name in PROVIDER_DISPLAY_NAMES.items(): - assert isinstance(display_name, str) - assert len(display_name) > 0 - # Should not be all lowercase with underscores - if "_" in provider: - assert "_" not in display_name or display_name != provider - - -# --------------------------------------------------------------------------- -# Tests for dataclasses -# --------------------------------------------------------------------------- - - -class TestDataclasses: - """Tests for ProviderInfo and ModelInfo dataclasses.""" - - def test_provider_info_fields(self): - """ProviderInfo should have all required fields.""" - info = ProviderInfo( - name="openai", - display_name="OpenAI", - api_key_env_var="OPENAI_API_KEY", - model_count=10, - sample_models=["gpt-4", "gpt-3.5-turbo"], - ) - assert info.name == "openai" - assert info.display_name == "OpenAI" - assert info.api_key_env_var == "OPENAI_API_KEY" - assert info.model_count == 10 - assert info.sample_models == ["gpt-4", "gpt-3.5-turbo"] - - def test_provider_info_defaults(self): - """ProviderInfo sample_models should default to empty list.""" - info = ProviderInfo( - name="test", - display_name="Test", - api_key_env_var=None, - model_count=0, - ) - assert info.sample_models == [] - - def test_model_info_fields(self): - """ModelInfo should have all required fields.""" - info = ModelInfo( - name="gpt-4", - litellm_id="gpt-4", - input_cost_per_million=30.0, - output_cost_per_million=60.0, - max_input_tokens=128000, - max_output_tokens=8192, - supports_vision=True, - supports_function_calling=True, - ) - assert info.name == "gpt-4" - assert info.litellm_id == "gpt-4" - assert info.input_cost_per_million == 30.0 - assert info.output_cost_per_million == 60.0 - assert info.max_input_tokens == 128000 - assert info.max_output_tokens == 8192 - assert info.supports_vision is True - assert info.supports_function_calling is True - - def test_model_info_defaults(self): - """ModelInfo should have sensible defaults.""" - info = ModelInfo( - name="test", - litellm_id="test", - input_cost_per_million=0.0, - output_cost_per_million=0.0, - ) - assert info.max_input_tokens is None - assert info.max_output_tokens is None - assert info.supports_vision is False - assert info.supports_function_calling is False - - -# --------------------------------------------------------------------------- -# Tests for is_litellm_available -# --------------------------------------------------------------------------- - - -class TestIsLitellmAvailable: - """Tests for is_litellm_available function.""" - - def test_returns_true_when_litellm_installed(self): - """Should return True when litellm is importable with model data.""" - # If litellm is installed in test environment, this should return True - # We'll mock it to ensure consistent behavior - mock_litellm = mock.MagicMock() - mock_litellm.model_cost = {"gpt-4": {"mode": "chat"}} - - with mock.patch.dict("sys.modules", {"litellm": mock_litellm}): - # Need to reimport or call the function after mocking - result = is_litellm_available() - # Either True (if litellm is installed) or we need to mock - assert isinstance(result, bool) - - def test_returns_false_when_litellm_not_installed(self): - """Should return False when litellm import fails.""" - with mock.patch.dict("sys.modules", {"litellm": None}): - # Force ImportError - def raise_import_error(): - raise ImportError("No module named 'litellm'") - - with mock.patch( - "pdd.litellm_registry.is_litellm_available", - side_effect=raise_import_error, - ): - # The actual function should handle this gracefully - pass - - def test_returns_false_when_model_cost_empty(self): - """Should return False when litellm.model_cost is empty.""" - mock_litellm = mock.MagicMock() - mock_litellm.model_cost = {} - - with mock.patch("pdd.litellm_registry.is_litellm_available") as mock_fn: - mock_fn.return_value = False - assert mock_fn() is False - - -# --------------------------------------------------------------------------- -# Tests for get_api_key_env_var -# --------------------------------------------------------------------------- - - -class TestGetApiKeyEnvVar: - """Tests for get_api_key_env_var function.""" - - def test_returns_key_for_known_providers(self): - """Should return correct API key env var for known providers.""" - assert get_api_key_env_var("openai") == "OPENAI_API_KEY" - assert get_api_key_env_var("anthropic") == "ANTHROPIC_API_KEY" - assert get_api_key_env_var("gemini") == "GEMINI_API_KEY" - assert get_api_key_env_var("groq") == "GROQ_API_KEY" - - def test_returns_none_for_unknown_providers(self): - """Should return None for providers not in the mapping.""" - assert get_api_key_env_var("unknown_provider") is None - assert get_api_key_env_var("") is None - assert get_api_key_env_var("my_custom_llm") is None - - def test_case_sensitive(self): - """Provider name lookup should be case-sensitive.""" - assert get_api_key_env_var("openai") == "OPENAI_API_KEY" - assert get_api_key_env_var("OpenAI") is None - assert get_api_key_env_var("OPENAI") is None - - -# --------------------------------------------------------------------------- -# Tests for _get_display_name -# --------------------------------------------------------------------------- - - -class TestGetDisplayName: - """Tests for _get_display_name helper function.""" - - def test_returns_mapped_name_for_known_providers(self): - """Should return the mapped display name for known providers.""" - assert _get_display_name("openai") == "OpenAI" - assert _get_display_name("fireworks_ai") == "Fireworks AI" - assert _get_display_name("together_ai") == "Together AI" - - def test_falls_back_to_title_case_for_unknown(self): - """Should fallback to title-case with underscore replacement.""" - assert _get_display_name("my_custom_provider") == "My Custom Provider" - assert _get_display_name("unknown") == "Unknown" - - def test_handles_empty_string(self): - """Should handle empty string gracefully.""" - result = _get_display_name("") - assert result == "" - - -# --------------------------------------------------------------------------- -# Tests for _entry_to_model_info -# --------------------------------------------------------------------------- - - -class TestEntryToModelInfo: - """Tests for _entry_to_model_info helper function.""" - - def test_converts_basic_entry(self): - """Should convert a model_cost entry to ModelInfo.""" - entry = { - "input_cost_per_token": 0.00003, - "output_cost_per_token": 0.00006, - "max_input_tokens": 128000, - "max_output_tokens": 8192, - "supports_vision": True, - "supports_function_calling": True, - } - - result = _entry_to_model_info("gpt-4", entry) - - assert result.name == "gpt-4" - assert result.litellm_id == "gpt-4" - assert result.input_cost_per_million == 30.0 - assert result.output_cost_per_million == 60.0 - assert result.max_input_tokens == 128000 - assert result.max_output_tokens == 8192 - assert result.supports_vision is True - assert result.supports_function_calling is True - - def test_handles_provider_prefix_in_model_id(self): - """Should extract model name from provider/model format.""" - entry = {"input_cost_per_token": 0, "output_cost_per_token": 0} - - result = _entry_to_model_info("anthropic/claude-3-opus", entry) - - assert result.name == "claude-3-opus" - assert result.litellm_id == "anthropic/claude-3-opus" - - def test_handles_missing_cost_fields(self): - """Should handle entries with missing cost fields.""" - entry = {} - - result = _entry_to_model_info("test-model", entry) - - assert result.input_cost_per_million == 0.0 - assert result.output_cost_per_million == 0.0 - - def test_handles_none_cost_values(self): - """Should handle None cost values.""" - entry = { - "input_cost_per_token": None, - "output_cost_per_token": None, - } - - result = _entry_to_model_info("test-model", entry) - - assert result.input_cost_per_million == 0.0 - assert result.output_cost_per_million == 0.0 - - def test_converts_per_token_to_per_million(self): - """Should correctly convert per-token costs to per-million.""" - entry = { - "input_cost_per_token": 0.000001, # $1 per million - "output_cost_per_token": 0.000002, # $2 per million - } - - result = _entry_to_model_info("test", entry) - - assert result.input_cost_per_million == 1.0 - assert result.output_cost_per_million == 2.0 - - -# --------------------------------------------------------------------------- -# Tests for get_top_providers (with mocking) -# --------------------------------------------------------------------------- - - -class TestGetTopProviders: - """Tests for get_top_providers function.""" - - def test_returns_empty_list_when_litellm_unavailable(self): - """Should return empty list when litellm is not available.""" - with mock.patch( - "pdd.litellm_registry.is_litellm_available", return_value=False - ): - result = get_top_providers() - assert result == [] - - def test_returns_list_of_provider_info(self): - """Should return a list of ProviderInfo objects.""" - # Only test if litellm is actually available - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_top_providers() - - assert isinstance(result, list) - if result: - assert isinstance(result[0], ProviderInfo) - - def test_includes_major_providers(self): - """Top providers should include major cloud providers.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_top_providers() - provider_names = [p.name for p in result] - - # At least some major providers should be present - major_providers = {"openai", "anthropic", "gemini", "mistral"} - found = set(provider_names) & major_providers - assert len(found) > 0, f"Expected some major providers, got {provider_names}" - - -# --------------------------------------------------------------------------- -# Tests for get_all_providers (with mocking) -# --------------------------------------------------------------------------- - - -class TestGetAllProviders: - """Tests for get_all_providers function.""" - - def test_returns_empty_list_when_litellm_unavailable(self): - """Should return empty list when litellm is not available.""" - with mock.patch( - "pdd.litellm_registry.is_litellm_available", return_value=False - ): - result = get_all_providers() - assert result == [] - - def test_returns_sorted_by_model_count(self): - """Should return providers sorted by model count descending.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_all_providers() - - if len(result) > 1: - for i in range(len(result) - 1): - assert result[i].model_count >= result[i + 1].model_count - - def test_all_providers_have_at_least_one_model(self): - """All returned providers should have at least one chat model.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_all_providers() - - for provider in result: - assert provider.model_count >= 1 - - -# --------------------------------------------------------------------------- -# Tests for search_providers -# --------------------------------------------------------------------------- - - -class TestSearchProviders: - """Tests for search_providers function.""" - - def test_empty_query_returns_all_providers(self): - """Empty query should return all providers.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - all_providers = get_all_providers() - search_result = search_providers("") - - assert len(search_result) == len(all_providers) - - def test_case_insensitive_search(self): - """Search should be case-insensitive.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result_lower = search_providers("openai") - result_upper = search_providers("OPENAI") - result_mixed = search_providers("OpenAI") - - # All should return the same results - assert len(result_lower) == len(result_upper) == len(result_mixed) - - def test_searches_in_name_and_display_name(self): - """Should search in both provider name and display name.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - # Search by display name component - result = search_providers("Gemini") - provider_names = [p.name for p in result] - - # Should find gemini provider - assert any("gemini" in name.lower() for name in provider_names) - - def test_returns_empty_for_no_match(self): - """Should return empty list when no providers match.""" - result = search_providers("xyznonexistentprovider123") - assert result == [] - - def test_partial_match(self): - """Should match partial strings.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - # "open" should match "openai" - result = search_providers("open") - if result: - assert any("open" in p.name.lower() for p in result) - - -# --------------------------------------------------------------------------- -# Tests for get_models_for_provider -# --------------------------------------------------------------------------- - - -class TestGetModelsForProvider: - """Tests for get_models_for_provider function.""" - - def test_returns_empty_list_when_litellm_unavailable(self): - """Should return empty list when litellm is not available.""" - with mock.patch( - "pdd.litellm_registry.is_litellm_available", return_value=False - ): - result = get_models_for_provider("openai") - assert result == [] - - def test_returns_list_of_model_info(self): - """Should return a list of ModelInfo objects.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_models_for_provider("openai") - - assert isinstance(result, list) - if result: - assert isinstance(result[0], ModelInfo) - - def test_returns_sorted_by_name(self): - """Models should be sorted by name.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_models_for_provider("openai") - - if len(result) > 1: - names = [m.name for m in result] - assert names == sorted(names) - - def test_returns_empty_for_unknown_provider(self): - """Should return empty list for unknown provider.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_models_for_provider("nonexistent_provider_xyz") - assert result == [] - - def test_model_info_has_litellm_id(self): - """Each model should have a litellm_id for API calls.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_models_for_provider("anthropic") - - for model in result: - assert model.litellm_id is not None - assert len(model.litellm_id) > 0 - - def test_handles_vertex_ai_subproviders(self): - """Should aggregate vertex_ai sub-providers.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - result = get_models_for_provider("vertex_ai") - - # Should return some models (vertex_ai has many) - # The exact count depends on litellm version - assert isinstance(result, list) - - -# --------------------------------------------------------------------------- -# Integration tests -# --------------------------------------------------------------------------- - - -class TestIntegration: - """Integration tests that verify module works end-to-end.""" - - def test_workflow_search_to_models(self): - """Test typical workflow: search provider -> get models.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - # Search for anthropic - providers = search_providers("anthropic") - assert len(providers) > 0 - - # Get the first matching provider - provider = providers[0] - assert provider.name == "anthropic" or "anthropic" in provider.name - - # Get models for that provider - models = get_models_for_provider(provider.name) - assert len(models) > 0 - - # Each model should have required fields - for model in models: - assert model.litellm_id - assert isinstance(model.input_cost_per_million, (int, float)) - assert isinstance(model.output_cost_per_million, (int, float)) - - def test_top_providers_have_valid_data(self): - """Top providers should all have valid, complete data.""" - if not is_litellm_available(): - pytest.skip("litellm not installed") - - top = get_top_providers() - - for provider in top: - # Each provider should have a name and display name - assert provider.name - assert provider.display_name - - # Model count should be positive - assert provider.model_count > 0 - - # Sample models should not exceed 3 - assert len(provider.sample_models) <= 3 - - # Can get models for this provider - models = get_models_for_provider(provider.name) - assert len(models) == provider.model_count diff --git a/tests/test_model_tester.py b/tests/test_model_tester.py index ed6365b71..dd691308a 100644 --- a/tests/test_model_tester.py +++ b/tests/test_model_tester.py @@ -1,6 +1,46 @@ -"""Tests for model_tester.py — CSV loading, key resolution, cost calculation, error classification.""" +# Test Plan: +# I. No-CSV Edge Cases (test_model_interactive exits early) +# 1. test_no_csv_file: No ~/.pdd/llm_model.csv → prints guidance message and returns. +# 2. test_empty_csv: CSV exists but has no data rows → same early exit. +# 3. test_csv_missing_required_columns: CSV exists but lacks provider/model/api_key → early exit. +# +# II. Interactive Flow — User Input Handling +# 4. test_quit_with_empty_input: User presses Enter immediately → exits cleanly. +# 5. test_quit_with_q: User types "q" → exits cleanly. +# 6. test_invalid_input_then_quit: User types "abc", sees error, then quits. +# 7. test_out_of_range_then_quit: User types "99" (out of range), sees error, then quits. +# 8. test_eof_exits_gracefully: EOFError during input → exits without crashing. +# +# III. Successful Model Test (end-to-end through test_model_interactive) +# 9. test_successful_test_shows_ok: User picks model 1, LLM returns OK → output shows ✓ OK with cost/tokens. +# 10. test_successful_test_passes_api_key_for_single_var: Single api_key var → passed as api_key= to litellm. +# 11. test_multi_var_provider_no_api_key_kwarg: Bedrock (pipe-delimited) → api_key= NOT passed to litellm. +# 12. test_device_flow_no_api_key_kwarg: Empty api_key → api_key= NOT passed to litellm. +# +# IV. Failed Model Test (end-to-end through test_model_interactive) +# 13. test_auth_error_shows_classified_message: LLM raises 401 → output shows "Authentication error". +# 14. test_connection_refused_shows_local_server_hint: LLM raises connection error → output suggests local server. +# +# V. Diagnostics Displayed Before Test +# 15. test_diagnostics_show_key_found: API key in env → output includes "✓ Found". +# 16. test_diagnostics_show_key_missing: API key not in env → output includes "✗ Not found". +# 17. test_diagnostics_show_base_url_for_lm_studio: LM Studio model → base URL shown in output. +# 18. test_diagnostics_bedrock_checks_all_vars: Bedrock model → all three env vars checked in output. +# 19. test_diagnostics_vertex_bad_creds_file: GOOGLE_APPLICATION_CREDENTIALS path invalid → warns in output. +# 20. test_diagnostics_device_flow_no_key_needed: Empty api_key → output indicates no key needed. +# +# VI. Session Persistence +# 21. test_results_persist_across_picks: User tests model 1 then model 2 → both results shown in table. +# +# VII. CSV Loading Normalization +# 22. test_csv_normalizes_nan_strings_and_bad_numerics: NaN strings → "", bad numbers → 0.0. +# +# VIII. Pure Function Contracts +# 23-28. _classify_error: auth, connection refused, not found, timeout, rate limit, generic. +# 29-30. _calculate_cost: basic math, zero tokens. + +"""Tests for model_tester.py — behavioral tests driven through test_model_interactive().""" -import os import pytest from unittest.mock import MagicMock, patch @@ -8,131 +48,421 @@ # --------------------------------------------------------------------------- -# Tests for _resolve_api_key +# Helpers # --------------------------------------------------------------------------- -def test_resolve_api_key_empty_key_name(): - """No api_key configured returns None with status message.""" - row = {"api_key": ""} - key, status = model_tester._resolve_api_key(row) - assert key is None - assert "no key configured" in status +def _make_csv(tmp_path, content): + """Write a CSV file at the expected ~/.pdd/llm_model.csv location.""" + csv_file = tmp_path / ".pdd" / "llm_model.csv" + csv_file.parent.mkdir(parents=True, exist_ok=True) + csv_file.write_text(content) + return csv_file -def test_resolve_api_key_found_in_env(monkeypatch): - """Key found in os.environ returns the value.""" - monkeypatch.setenv("TEST_API_KEY", "sk-abc123") - row = {"api_key": "TEST_API_KEY"} - key, status = model_tester._resolve_api_key(row) - assert key == "sk-abc123" - assert "Found" in status - assert "TEST_API_KEY" in status +def _mock_litellm_success(prompt_tokens=10, completion_tokens=5): + """Return a mock litellm response with token usage.""" + usage = MagicMock() + usage.prompt_tokens = prompt_tokens + usage.completion_tokens = completion_tokens + response = MagicMock() + response.usage = usage + return response -def test_resolve_api_key_not_found(monkeypatch): - """Key not in any source returns None with 'Not found'.""" - monkeypatch.delenv("MISSING_KEY", raising=False) - row = {"api_key": "MISSING_KEY"} - with patch.dict("sys.modules", {"dotenv": None}): - key, status = model_tester._resolve_api_key(row) - assert key is None - assert "Not found" in status +def _run_interactive(tmp_path, csv_content, user_inputs, monkeypatch, + mock_completion=None, env_vars=None): + """Run test_model_interactive with mocked CSV, user input, and litellm. + Returns the captured console output as a string. + """ + _make_csv(tmp_path, csv_content) -def test_resolve_api_key_strips_whitespace(monkeypatch): - """Key value is stripped of leading/trailing whitespace.""" - monkeypatch.setenv("PADDED_KEY", " sk-test ") - row = {"api_key": "PADDED_KEY"} - key, _ = model_tester._resolve_api_key(row) - assert key == "sk-test" + for k, v in (env_vars or {}).items(): + monkeypatch.setenv(k, v) + input_iter = iter(user_inputs) + mock_console_input = MagicMock(side_effect=input_iter) -# --------------------------------------------------------------------------- -# Tests for _resolve_base_url -# --------------------------------------------------------------------------- + if mock_completion is None: + mock_completion = MagicMock(return_value=_mock_litellm_success()) -def test_resolve_base_url_explicit(): - """Explicit base_url in row is returned directly.""" - row = {"base_url": "https://custom.api.com/v1"} - assert model_tester._resolve_base_url(row) == "https://custom.api.com/v1" + with patch.object(model_tester.Path, "home", return_value=tmp_path), \ + patch.object(model_tester.console, "input", mock_console_input), \ + patch("litellm.completion", mock_completion), \ + patch("sys.stdout"): # suppress dot-printing from thread + model_tester.test_model_interactive() + # Collect all console.print() calls into a single string for assertions. + # Each call may contain rich markup; we join them for substring matching. + output_parts = [] + for c in model_tester.console.print.call_args_list if hasattr(model_tester.console.print, "call_args_list") else []: + for arg in c.args: + output_parts.append(str(arg)) + return "\n".join(output_parts), mock_completion -def test_resolve_base_url_empty(): - """Empty base_url for non-local model returns None.""" - row = {"base_url": "", "model": "anthropic/claude", "provider": "Anthropic"} - assert model_tester._resolve_base_url(row) is None +def _run_interactive_capture(tmp_path, csv_content, user_inputs, monkeypatch, + mock_completion=None, env_vars=None): + """Like _run_interactive but patches console.print to capture output.""" + _make_csv(tmp_path, csv_content) -def test_resolve_base_url_lm_studio_model(monkeypatch): - """LM Studio model gets default localhost URL.""" - monkeypatch.delenv("LM_STUDIO_API_BASE", raising=False) - row = {"base_url": "", "model": "lm_studio/my-model", "provider": "lm_studio"} - result = model_tester._resolve_base_url(row) - assert result == "http://localhost:1234/v1" + for k, v in (env_vars or {}).items(): + monkeypatch.setenv(k, v) + input_iter = iter(user_inputs) -def test_resolve_base_url_lm_studio_custom_env(monkeypatch): - """LM Studio respects LM_STUDIO_API_BASE env var.""" - monkeypatch.setenv("LM_STUDIO_API_BASE", "http://remote:5000/v1") - row = {"base_url": "", "model": "lm_studio/model", "provider": "lm_studio"} - assert model_tester._resolve_base_url(row) == "http://remote:5000/v1" + if mock_completion is None: + mock_completion = MagicMock(return_value=_mock_litellm_success()) + captured = [] -# --------------------------------------------------------------------------- -# Tests for _calculate_cost -# --------------------------------------------------------------------------- + def _capture_print(*args, **kwargs): + for a in args: + captured.append(str(a)) -def test_calculate_cost_basic(): - """Cost calculation with known token counts and prices.""" - # 100 prompt tokens at $3/M + 50 completion tokens at $15/M - cost = model_tester._calculate_cost(100, 50, 3.0, 15.0) - expected = (100 * 3.0 + 50 * 15.0) / 1_000_000.0 - assert abs(cost - expected) < 1e-10 + with patch.object(model_tester.Path, "home", return_value=tmp_path), \ + patch.object(model_tester.console, "input", side_effect=input_iter), \ + patch.object(model_tester.console, "print", side_effect=_capture_print), \ + patch("litellm.completion", mock_completion), \ + patch("sys.stdout"): + model_tester.test_model_interactive() + return "\n".join(captured), mock_completion -def test_calculate_cost_zero(): - """Zero tokens or zero prices produce zero cost.""" - assert model_tester._calculate_cost(0, 0, 3.0, 15.0) == 0.0 - assert model_tester._calculate_cost(100, 100, 0.0, 0.0) == 0.0 +SIMPLE_CSV = "provider,model,api_key,input,output\nOpenAI,gpt-5,OPENAI_API_KEY,3.0,15.0\n" -# --------------------------------------------------------------------------- -# Tests for _classify_error -# --------------------------------------------------------------------------- +TWO_MODEL_CSV = ( + "provider,model,api_key,input,output\n" + "OpenAI,gpt-5,OPENAI_API_KEY,3.0,15.0\n" + "Anthropic,claude-sonnet,ANTHROPIC_API_KEY,3.0,15.0\n" +) -def test_classify_error_auth(): - """Authentication-related errors are classified correctly.""" - exc = Exception("401 Unauthorized - invalid api key") - result = model_tester._classify_error(exc) - assert "Authentication error" in result +BEDROCK_CSV = ( + "provider,model,api_key,input,output\n" + "AWS Bedrock,bedrock/anthropic.claude-v1," + "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME,1.0,5.0\n" +) +DEVICE_FLOW_CSV = ( + "provider,model,api_key,input,output\n" + "Github Copilot,github_copilot/gpt-5,,0.0,0.0\n" +) -def test_classify_error_connection_refused(): - """Connection refused errors suggest local server issue.""" - exc = ConnectionError("Connection refused") - result = model_tester._classify_error(exc) - assert "Connection refused" in result +LM_STUDIO_CSV = ( + "provider,model,api_key,input,output,base_url\n" + "lm_studio,lm_studio/local-model,,0.0,0.0,\n" +) +VERTEX_CSV = ( + "provider,model,api_key,input,output\n" + "Google Vertex AI,vertex_ai/gemini-2.5-pro," + "GOOGLE_APPLICATION_CREDENTIALS|VERTEXAI_PROJECT|VERTEXAI_LOCATION,1.0,5.0\n" +) -def test_classify_error_not_found(): - """404 / model not found errors classified correctly.""" - exc = Exception("404 Model does not exist") - result = model_tester._classify_error(exc) - assert "Model not found" in result +# =========================================================================== +# I. No-CSV Edge Cases +# =========================================================================== -def test_classify_error_timeout(): - """Timeout errors classified correctly.""" - exc = TimeoutError("Request timed out after 30s") - result = model_tester._classify_error(exc) - assert "timed out" in result +def test_no_csv_file(tmp_path): + """No ~/.pdd/llm_model.csv → prints guidance and returns.""" + with patch.object(model_tester.Path, "home", return_value=tmp_path): + captured = [] + with patch.object(model_tester.console, "print", + side_effect=lambda *a, **kw: captured.extend(str(x) for x in a)): + model_tester.test_model_interactive() + output = "\n".join(captured) + assert "No user model CSV" in output + assert "pdd setup" in output + + +def test_empty_csv(tmp_path): + """CSV with headers but no rows → same early exit.""" + _make_csv(tmp_path, "provider,model,api_key,input,output\n") + with patch.object(model_tester.Path, "home", return_value=tmp_path): + captured = [] + with patch.object(model_tester.console, "print", + side_effect=lambda *a, **kw: captured.extend(str(x) for x in a)): + model_tester.test_model_interactive() + output = "\n".join(captured) + assert "No user model CSV" in output + + +def test_csv_missing_required_columns(tmp_path): + """CSV with wrong columns → early exit.""" + _make_csv(tmp_path, "name,value\nfoo,bar\n") + with patch.object(model_tester.Path, "home", return_value=tmp_path): + captured = [] + with patch.object(model_tester.console, "print", + side_effect=lambda *a, **kw: captured.extend(str(x) for x in a)): + model_tester.test_model_interactive() + output = "\n".join(captured) + assert "No user model CSV" in output or "missing required columns" in output + + +# =========================================================================== +# II. Interactive Flow — User Input Handling +# =========================================================================== + +def test_quit_with_empty_input(tmp_path, monkeypatch): + """User presses Enter → exits cleanly.""" + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, [""], monkeypatch, + env_vars={"OPENAI_API_KEY": "sk-test"}, + ) + assert "Exiting" in output + + +def test_quit_with_q(tmp_path, monkeypatch): + """User types 'q' → exits cleanly.""" + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["q"], monkeypatch, + env_vars={"OPENAI_API_KEY": "sk-test"}, + ) + assert "Exiting" in output + + +def test_invalid_input_then_quit(tmp_path, monkeypatch): + """User types 'abc' → error message, then quits.""" + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["abc", "q"], monkeypatch, + env_vars={"OPENAI_API_KEY": "sk-test"}, + ) + assert "Invalid input" in output + + +def test_out_of_range_then_quit(tmp_path, monkeypatch): + """User types '99' → out-of-range error, then quits.""" + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["99", "q"], monkeypatch, + env_vars={"OPENAI_API_KEY": "sk-test"}, + ) + assert "Invalid selection" in output + + +def test_eof_exits_gracefully(tmp_path, monkeypatch): + """EOFError during input → exits without crashing.""" + _make_csv(tmp_path, SIMPLE_CSV) + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + + with patch.object(model_tester.Path, "home", return_value=tmp_path), \ + patch.object(model_tester.console, "input", side_effect=EOFError), \ + patch.object(model_tester.console, "print"), \ + patch("sys.stdout"): + # Should not raise + model_tester.test_model_interactive() + + +# =========================================================================== +# III. Successful Model Test +# =========================================================================== + +def test_successful_test_shows_ok(tmp_path, monkeypatch): + """User picks model 1, LLM succeeds → output shows ✓ OK with cost info.""" + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["1", "q"], monkeypatch, + mock_completion=MagicMock(return_value=_mock_litellm_success(10, 5)), + env_vars={"OPENAI_API_KEY": "sk-test"}, + ) + assert "✓ OK" in output + + +def test_successful_test_passes_api_key_for_single_var(tmp_path, monkeypatch): + """Single-var provider → api_key= passed to litellm.completion.""" + mock_comp = MagicMock(return_value=_mock_litellm_success()) + _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["1", "q"], monkeypatch, + mock_completion=mock_comp, + env_vars={"OPENAI_API_KEY": "sk-test123"}, + ) + call_kwargs = mock_comp.call_args[1] + assert call_kwargs["api_key"] == "sk-test123" + + +def test_multi_var_provider_no_api_key_kwarg(tmp_path, monkeypatch): + """Bedrock (pipe-delimited api_key) → api_key= NOT passed to litellm.""" + mock_comp = MagicMock(return_value=_mock_litellm_success()) + _run_interactive_capture( + tmp_path, BEDROCK_CSV, ["1", "q"], monkeypatch, + mock_completion=mock_comp, + env_vars={ + "AWS_ACCESS_KEY_ID": "AKIAEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "secret", + "AWS_REGION_NAME": "us-east-1", + }, + ) + call_kwargs = mock_comp.call_args[1] + assert "api_key" not in call_kwargs + + +def test_device_flow_no_api_key_kwarg(tmp_path, monkeypatch): + """Device flow (empty api_key) → api_key= NOT passed to litellm.""" + mock_comp = MagicMock(return_value=_mock_litellm_success()) + _run_interactive_capture( + tmp_path, DEVICE_FLOW_CSV, ["1", "q"], monkeypatch, + mock_completion=mock_comp, + ) + call_kwargs = mock_comp.call_args[1] + assert "api_key" not in call_kwargs + + +# =========================================================================== +# IV. Failed Model Test +# =========================================================================== + +def test_auth_error_shows_classified_message(tmp_path, monkeypatch): + """LLM raises 401 → output shows 'Authentication error'.""" + mock_comp = MagicMock(side_effect=Exception("401 Unauthorized")) + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["1", "q"], monkeypatch, + mock_completion=mock_comp, + env_vars={"OPENAI_API_KEY": "sk-bad"}, + ) + assert "Authentication error" in output + + +def test_connection_refused_shows_local_server_hint(tmp_path, monkeypatch): + """LLM raises connection error → output suggests local server.""" + mock_comp = MagicMock(side_effect=ConnectionError("Connection refused")) + output, _ = _run_interactive_capture( + tmp_path, LM_STUDIO_CSV, ["1", "q"], monkeypatch, + mock_completion=mock_comp, + ) + assert "Connection refused" in output + assert "local server" in output + + +# =========================================================================== +# V. Diagnostics Displayed Before Test +# =========================================================================== + +def test_diagnostics_show_key_found(tmp_path, monkeypatch): + """API key in env → diagnostics show ✓ Found.""" + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["1", "q"], monkeypatch, + env_vars={"OPENAI_API_KEY": "sk-test"}, + ) + assert "✓ Found" in output + assert "OPENAI_API_KEY" in output + + +def test_diagnostics_show_key_missing(tmp_path, monkeypatch): + """API key NOT in env → diagnostics show ✗ Not found.""" + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + output, _ = _run_interactive_capture( + tmp_path, SIMPLE_CSV, ["1", "q"], monkeypatch, + ) + assert "✗ Not found" in output + + +def test_diagnostics_show_base_url_for_lm_studio(tmp_path, monkeypatch): + """LM Studio model → base URL shown in diagnostics.""" + monkeypatch.delenv("LM_STUDIO_API_BASE", raising=False) + output, _ = _run_interactive_capture( + tmp_path, LM_STUDIO_CSV, ["1", "q"], monkeypatch, + ) + assert "localhost:1234" in output + + +def test_diagnostics_bedrock_checks_all_vars(tmp_path, monkeypatch): + """Bedrock model → all three env vars appear in diagnostics.""" + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "AKIAEXAMPLE") + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + monkeypatch.setenv("AWS_REGION_NAME", "us-east-1") + output, _ = _run_interactive_capture( + tmp_path, BEDROCK_CSV, ["1", "q"], monkeypatch, + ) + assert "AWS_ACCESS_KEY_ID" in output + assert "AWS_SECRET_ACCESS_KEY" in output + assert "AWS_REGION_NAME" in output + # One should be found, one missing + assert "✓ Found" in output + assert "✗ Not found" in output + + +def test_diagnostics_vertex_bad_creds_file(tmp_path, monkeypatch): + """GOOGLE_APPLICATION_CREDENTIALS pointing to nonexistent file → warns.""" + monkeypatch.setenv("GOOGLE_APPLICATION_CREDENTIALS", "/nonexistent/creds.json") + monkeypatch.setenv("VERTEXAI_PROJECT", "my-project") + monkeypatch.setenv("VERTEXAI_LOCATION", "us-central1") + output, _ = _run_interactive_capture( + tmp_path, VERTEX_CSV, ["1", "q"], monkeypatch, + ) + assert "file not found" in output + + +def test_diagnostics_device_flow_no_key_needed(tmp_path, monkeypatch): + """Device flow provider → diagnostics say no key needed.""" + output, _ = _run_interactive_capture( + tmp_path, DEVICE_FLOW_CSV, ["1", "q"], monkeypatch, + ) + assert "Device flow" in output or "no key needed" in output + + +# =========================================================================== +# VI. Session Persistence +# =========================================================================== + +def test_results_persist_across_picks(tmp_path, monkeypatch): + """User tests model 1 then model 2 → second table render includes first result.""" + call_count = [0] + def _completion_side_effect(**kwargs): + call_count[0] += 1 + return _mock_litellm_success() + + mock_comp = MagicMock(side_effect=_completion_side_effect) + output, _ = _run_interactive_capture( + tmp_path, TWO_MODEL_CSV, ["1", "2", "q"], monkeypatch, + mock_completion=mock_comp, + env_vars={"OPENAI_API_KEY": "sk-test", "ANTHROPIC_API_KEY": "sk-test"}, + ) + # litellm.completion should have been called twice (once per model) + assert mock_comp.call_count == 2 + + +# =========================================================================== +# VII. CSV Loading Normalization +# =========================================================================== + +def test_csv_normalizes_nan_strings_and_bad_numerics(tmp_path): + """NaN string columns → empty string; non-numeric cost → 0.0.""" + csv_content = ( + "provider,model,api_key,base_url,location,input,output\n" + "OpenAI,gpt-5,,,us-east,bad,3.0\n" + ) + _make_csv(tmp_path, csv_content) + with patch.object(model_tester.Path, "home", return_value=tmp_path): + df = model_tester._load_user_csv() -def test_classify_error_rate_limit(): - """Rate limit errors classified correctly.""" - exc = Exception("429 Rate limit exceeded") + assert df is not None + row = df.iloc[0] + assert row["api_key"] == "" + assert row["base_url"] == "" + assert row["input"] == 0.0 + assert row["output"] == 3.0 + + +# =========================================================================== +# VIII. Pure Function Contracts — _classify_error +# These are kept as direct tests because _classify_error is a pure function +# with clear sub-contract semantics (like ExtractedCode in test_postprocess.py). +# =========================================================================== + +@pytest.mark.parametrize("message,expected_fragment", [ + ("401 Unauthorized - invalid api key", "Authentication error"), + ("403 Forbidden - access denied", "Authentication error"), + ("Connection refused", "Connection refused"), + ("404 Model does not exist", "Model not found"), + ("Request timed out after 30s", "timed out"), + ("429 Rate limit exceeded", "Rate limited"), +]) +def test_classify_error_categories(message, expected_fragment): + """Error messages are classified into user-friendly categories.""" + exc = Exception(message) result = model_tester._classify_error(exc) - assert "Rate limited" in result + assert expected_fragment in result def test_classify_error_generic(): @@ -143,90 +473,18 @@ def test_classify_error_generic(): assert "Something unexpected" in result -# --------------------------------------------------------------------------- -# Tests for _run_test -# --------------------------------------------------------------------------- +# =========================================================================== +# IX. Pure Function Contracts — _calculate_cost +# =========================================================================== -@patch("litellm.completion") -def test_run_test_success(mock_completion, monkeypatch): - """Successful litellm call returns success dict with tokens and cost.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test") - - usage = MagicMock() - usage.prompt_tokens = 10 - usage.completion_tokens = 5 - response = MagicMock() - response.usage = usage - mock_completion.return_value = response - - row = {"model": "anthropic/claude", "api_key": "ANTHROPIC_API_KEY", - "input": 3.0, "output": 15.0} - result = model_tester._run_test(row) - - assert result["success"] is True - assert result["error"] is None - assert result["tokens"]["prompt"] == 10 - assert result["tokens"]["completion"] == 5 - assert result["cost"] > 0 - assert result["duration_s"] >= 0 - - -@patch("litellm.completion") -def test_run_test_failure(mock_completion, monkeypatch): - """Failed litellm call returns failure dict with classified error.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test") - mock_completion.side_effect = Exception("401 Unauthorized") - - row = {"model": "anthropic/claude", "api_key": "ANTHROPIC_API_KEY", - "input": 3.0, "output": 15.0} - result = model_tester._run_test(row) - - assert result["success"] is False - assert result["cost"] == 0.0 - assert result["tokens"] is None - assert "Authentication error" in result["error"] - - -# --------------------------------------------------------------------------- -# Tests for _load_user_csv -# --------------------------------------------------------------------------- - -@patch("pdd.model_tester.Path") -def test_load_user_csv_missing_file(mock_path): - """Returns None when CSV file doesn't exist.""" - mock_home = MagicMock() - mock_csv = MagicMock() - mock_csv.is_file.return_value = False - mock_home.__truediv__ = MagicMock(return_value=MagicMock(__truediv__=MagicMock(return_value=mock_csv))) - mock_path.home.return_value = mock_home - - result = model_tester._load_user_csv() - assert result is None - - -def test_load_user_csv_valid(tmp_path): - """Returns DataFrame for a valid CSV with required columns.""" - csv_content = "provider,model,api_key,input,output\nAnthropic,claude,ANTHROPIC_API_KEY,3.0,15.0\n" - csv_file = tmp_path / ".pdd" / "llm_model.csv" - csv_file.parent.mkdir(parents=True) - csv_file.write_text(csv_content) - - with patch.object(model_tester.Path, "home", return_value=tmp_path): - df = model_tester._load_user_csv() - - assert df is not None - assert len(df) == 1 - assert df.iloc[0]["provider"] == "Anthropic" - - -def test_load_user_csv_missing_columns(tmp_path): - """Returns None when required columns are missing.""" - csv_content = "name,value\nfoo,bar\n" - csv_file = tmp_path / ".pdd" / "llm_model.csv" - csv_file.parent.mkdir(parents=True) - csv_file.write_text(csv_content) +def test_calculate_cost_basic(): + """Cost = (prompt_tokens * input_price + completion_tokens * output_price) / 1M.""" + cost = model_tester._calculate_cost(100, 50, 3.0, 15.0) + expected = (100 * 3.0 + 50 * 15.0) / 1_000_000.0 + assert abs(cost - expected) < 1e-10 - with patch.object(model_tester.Path, "home", return_value=tmp_path): - df = model_tester._load_user_csv() - assert df is None +def test_calculate_cost_zero(): + """Zero tokens or zero prices produce zero cost.""" + assert model_tester._calculate_cost(0, 0, 3.0, 15.0) == 0.0 + assert model_tester._calculate_cost(100, 100, 0.0, 0.0) == 0.0 diff --git a/tests/test_pddrc_initializer.py b/tests/test_pddrc_initializer.py index b9bd50af5..a05306fc7 100644 --- a/tests/test_pddrc_initializer.py +++ b/tests/test_pddrc_initializer.py @@ -1,207 +1,356 @@ -"""Tests for pddrc_initializer.py — language detection, content generation, offer flow.""" +# Test Plan: +# I. Early Exits +# 1. test_already_exists_returns_false: .pddrc exists → returns False, file untouched +# 2. test_already_exists_shows_message: .pddrc exists → output mentions "already exists" +# +# II. Language Detection (pure function contract — stable sub-contract) +# 3. test_detect_language_python_markers: pyproject.toml, setup.py, requirements.txt → "python" +# 4. test_detect_language_typescript: package.json with typescript dep → "typescript" +# 5. test_detect_language_not_typescript_without_dep: package.json without typescript → None +# 6. test_detect_language_go: go.mod → "go" +# 7. test_detect_language_none: empty dir → None +# 8. test_detect_language_python_priority: both pyproject.toml and go.mod → "python" +# +# III. Content Generation (pure function contract — stable sub-contract) +# 9. test_build_content_language_paths: each language gets correct output paths +# 10. test_build_content_standard_defaults: strength, temperature, etc. present +# 11. test_build_content_unknown_language_fallback: unknown lang falls back to Python paths +# 12. test_build_content_ends_with_newline: trailing newline +# +# IV. Success Path — File Created +# 13. test_creates_file_on_confirm_yes: user types "y" → file created, returns True +# 14. test_creates_file_on_enter: empty input (Enter) → file created, returns True +# 15. test_created_file_has_correct_content: created file contains detected language defaults +# 16. test_prompts_language_when_undetected: no markers → asks for language, then confirms +# +# V. User Declines +# 17. test_declined_returns_false: user types "n" → returns False, no file +# +# VI. Language Prompt with Invalid Input +# 18. test_language_prompt_retries_on_invalid: invalid then valid → correct language used +# +# VII. Output / Display +# 19. test_detected_language_shown: auto-detected language appears in output +# 20. test_preview_shown_before_confirmation: YAML preview shown before user asked to confirm +# 21. test_creation_success_message: "Created .pddrc" message appears after creation +# 22. test_skip_message_on_decline: "Skipped" message appears when user declines +# +# VIII. Filesystem Error +# 23. test_write_error_returns_false: OSError on write → returns False, error shown import json import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import patch, MagicMock from pathlib import Path +from io import StringIO from pdd import pddrc_initializer +from pdd.pddrc_initializer import _detect_language, _build_pddrc_content # --------------------------------------------------------------------------- -# Tests for _detect_language +# Module-level fixtures / constants # --------------------------------------------------------------------------- -def test_detect_language_python_pyproject(tmp_path): - """Detects Python from pyproject.toml.""" - (tmp_path / "pyproject.toml").touch() - assert pddrc_initializer._detect_language(tmp_path) == "python" +PYTHON_PROJECT_MARKERS = ["pyproject.toml", "setup.py", "requirements.txt"] + +TS_PACKAGE_JSON = json.dumps({ + "devDependencies": {"typescript": "^5.0.0"} +}) + +NON_TS_PACKAGE_JSON = json.dumps({ + "dependencies": {"express": "^4.0.0"} +}) + + +# --------------------------------------------------------------------------- +# Helper: run offer_pddrc_init and capture output +# --------------------------------------------------------------------------- + +def _run_offer_capture(tmp_path, monkeypatch, user_inputs, *, marker_files=None): + """Run offer_pddrc_init() in tmp_path, capturing printed output. + + Parameters + ---------- + tmp_path : Path + Working directory for the test. + monkeypatch : pytest.MonkeyPatch + Used to patch cwd. + user_inputs : list[str] + Sequence of strings returned by console.input() calls. + marker_files : dict[str, str | None] | None + Files to create in tmp_path before running. Keys are filenames, + values are contents (None → touch). + + Returns + ------- + tuple[bool, str] + (return_value, captured_output_text) + """ + # Set up marker files + if marker_files: + for name, content in marker_files.items(): + path = tmp_path / name + if content is not None: + path.write_text(content) + else: + path.touch() + + # Mock console.input to feed user inputs + input_iter = iter(user_inputs) + + # Capture console.print output + captured = [] + + original_print = pddrc_initializer.console.print + + def fake_print(*args, **kwargs): + # Convert to plain string for assertion + buf = StringIO() + temp_console = pddrc_initializer.Console(file=buf, force_terminal=False, no_color=True) + temp_console.print(*args, **kwargs) + captured.append(buf.getvalue()) + + with patch.object(Path, "cwd", return_value=tmp_path), \ + patch.object(pddrc_initializer.console, "input", side_effect=input_iter), \ + patch.object(pddrc_initializer.console, "print", side_effect=fake_print): + result = pddrc_initializer.offer_pddrc_init() + output = "".join(captured) + return result, output -def test_detect_language_python_setup_py(tmp_path): - """Detects Python from setup.py.""" - (tmp_path / "setup.py").touch() - assert pddrc_initializer._detect_language(tmp_path) == "python" +# --------------------------------------------------------------------------- +# I. Early Exits +# --------------------------------------------------------------------------- -def test_detect_language_python_requirements(tmp_path): - """Detects Python from requirements.txt.""" - (tmp_path / "requirements.txt").touch() - assert pddrc_initializer._detect_language(tmp_path) == "python" +def test_already_exists_returns_false(tmp_path, monkeypatch): + """When .pddrc already exists, offer_pddrc_init returns False.""" + (tmp_path / ".pddrc").write_text("existing config") + result, _ = _run_offer_capture(tmp_path, monkeypatch, []) + assert result is False + assert (tmp_path / ".pddrc").read_text() == "existing config" + + +def test_already_exists_shows_message(tmp_path, monkeypatch): + """When .pddrc already exists, user sees 'already exists' message.""" + (tmp_path / ".pddrc").write_text("existing config") + _, output = _run_offer_capture(tmp_path, monkeypatch, []) + assert "already exists" in output + + +# --------------------------------------------------------------------------- +# II. Language Detection (pure function — stable sub-contract) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("marker", PYTHON_PROJECT_MARKERS) +def test_detect_language_python_markers(tmp_path, marker): + """Python marker files are detected correctly.""" + (tmp_path / marker).touch() + assert _detect_language(tmp_path) == "python" def test_detect_language_typescript(tmp_path): - """Detects TypeScript from package.json with typescript dependency.""" - pkg = {"devDependencies": {"typescript": "^5.0.0"}} - (tmp_path / "package.json").write_text(json.dumps(pkg)) - assert pddrc_initializer._detect_language(tmp_path) == "typescript" + """package.json with typescript dependency → 'typescript'.""" + (tmp_path / "package.json").write_text(TS_PACKAGE_JSON) + assert _detect_language(tmp_path) == "typescript" def test_detect_language_not_typescript_without_dep(tmp_path): - """package.json without typescript dep is not detected as TypeScript.""" - pkg = {"dependencies": {"express": "^4.0.0"}} - (tmp_path / "package.json").write_text(json.dumps(pkg)) - assert pddrc_initializer._detect_language(tmp_path) is None + """package.json without typescript dep → None.""" + (tmp_path / "package.json").write_text(NON_TS_PACKAGE_JSON) + assert _detect_language(tmp_path) is None def test_detect_language_go(tmp_path): - """Detects Go from go.mod.""" + """go.mod → 'go'.""" (tmp_path / "go.mod").touch() - assert pddrc_initializer._detect_language(tmp_path) == "go" + assert _detect_language(tmp_path) == "go" def test_detect_language_none(tmp_path): - """Returns None when no markers found.""" - assert pddrc_initializer._detect_language(tmp_path) is None + """Empty directory → None.""" + assert _detect_language(tmp_path) is None -def test_detect_language_python_priority_over_go(tmp_path): +def test_detect_language_python_priority(tmp_path): """Python markers take priority over Go markers.""" (tmp_path / "pyproject.toml").touch() (tmp_path / "go.mod").touch() - assert pddrc_initializer._detect_language(tmp_path) == "python" + assert _detect_language(tmp_path) == "python" # --------------------------------------------------------------------------- -# Tests for _build_pddrc_content +# III. Content Generation (pure function — stable sub-contract) # --------------------------------------------------------------------------- -def test_build_pddrc_content_python(): - """Python content has correct paths and defaults.""" - content = pddrc_initializer._build_pddrc_content("python") - assert 'version: "1.0"' in content - assert 'generate_output_path: "pdd/"' in content - assert 'test_output_path: "tests/"' in content - assert 'example_output_path: "context/"' in content - assert 'default_language: "python"' in content - assert "strength: 1.0" in content +@pytest.mark.parametrize("language, gen_path, test_path, example_path", [ + ("python", "pdd/", "tests/", "context/"), + ("typescript", "src/", "__tests__/", "examples/"), + ("go", ".", ".", "examples/"), +]) +def test_build_content_language_paths(language, gen_path, test_path, example_path): + """Each language gets correct output paths in generated content.""" + content = _build_pddrc_content(language) + assert f'generate_output_path: "{gen_path}"' in content + assert f'test_output_path: "{test_path}"' in content + assert f'example_output_path: "{example_path}"' in content + assert f'default_language: "{language}"' in content + + +def test_build_content_standard_defaults(): + """Generated content includes all standard defaults.""" + content = _build_pddrc_content("python") + assert "strength: 0.818" in content assert "temperature: 0.0" in content assert "target_coverage: 80.0" in content assert "budget: 10.0" in content assert "max_attempts: 3" in content + assert 'version: "1.0"' in content -def test_build_pddrc_content_typescript(): - """TypeScript content has correct paths.""" - content = pddrc_initializer._build_pddrc_content("typescript") - assert 'generate_output_path: "src/"' in content - assert 'test_output_path: "__tests__/"' in content - assert 'example_output_path: "examples/"' in content - assert 'default_language: "typescript"' in content - - -def test_build_pddrc_content_go(): - """Go content has correct paths.""" - content = pddrc_initializer._build_pddrc_content("go") - assert 'generate_output_path: "."' in content - assert 'test_output_path: "."' in content - assert 'example_output_path: "examples/"' in content - assert 'default_language: "go"' in content - - -def test_build_pddrc_content_unknown_falls_back_to_python(): - """Unknown language falls back to Python defaults.""" - content = pddrc_initializer._build_pddrc_content("rust") +def test_build_content_unknown_language_fallback(): + """Unknown language falls back to Python paths but uses given language name.""" + content = _build_pddrc_content("rust") assert 'generate_output_path: "pdd/"' in content assert 'default_language: "rust"' in content -def test_build_pddrc_content_ends_with_newline(): +def test_build_content_ends_with_newline(): """Generated content ends with a trailing newline.""" - content = pddrc_initializer._build_pddrc_content("python") + content = _build_pddrc_content("python") assert content.endswith("\n") # --------------------------------------------------------------------------- -# Tests for offer_pddrc_init +# IV. Success Path — File Created # --------------------------------------------------------------------------- -def test_offer_pddrc_init_already_exists(tmp_path): - """Returns False and does not overwrite when .pddrc already exists.""" - (tmp_path / ".pddrc").write_text("existing config") - - with patch.object(Path, "cwd", return_value=tmp_path): - result = pddrc_initializer.offer_pddrc_init() - - assert result is False - assert (tmp_path / ".pddrc").read_text() == "existing config" - - -@patch.object(pddrc_initializer.console, "input", return_value="y") -def test_offer_pddrc_init_creates_file(mock_input, tmp_path): - """Creates .pddrc when user confirms with 'y'.""" - (tmp_path / "pyproject.toml").touch() # Python marker +def test_creates_file_on_confirm_yes(tmp_path, monkeypatch): + """User confirms with 'y' → .pddrc created, returns True.""" + result, _ = _run_offer_capture( + tmp_path, monkeypatch, ["y"], + marker_files={"pyproject.toml": None}, + ) + assert result is True + assert (tmp_path / ".pddrc").exists() - with patch.object(Path, "cwd", return_value=tmp_path): - result = pddrc_initializer.offer_pddrc_init() +def test_creates_file_on_enter(tmp_path, monkeypatch): + """Empty input (Enter) means yes → .pddrc created, returns True.""" + result, _ = _run_offer_capture( + tmp_path, monkeypatch, [""], + marker_files={"pyproject.toml": None}, + ) assert result is True - pddrc = tmp_path / ".pddrc" - assert pddrc.exists() - content = pddrc.read_text() - assert 'default_language: "python"' in content + assert (tmp_path / ".pddrc").exists() -@patch.object(pddrc_initializer.console, "input", return_value="") -def test_offer_pddrc_init_enter_means_yes(mock_input, tmp_path): - """Empty input (just Enter) means yes — file is created.""" - (tmp_path / "pyproject.toml").touch() +def test_created_file_has_correct_content(tmp_path, monkeypatch): + """Created .pddrc contains language-appropriate defaults.""" + _run_offer_capture( + tmp_path, monkeypatch, ["y"], + marker_files={"pyproject.toml": None}, + ) + content = (tmp_path / ".pddrc").read_text() + assert 'default_language: "python"' in content + assert 'generate_output_path: "pdd/"' in content + assert "strength: 0.818" in content - with patch.object(Path, "cwd", return_value=tmp_path): - result = pddrc_initializer.offer_pddrc_init() +def test_prompts_language_when_undetected(tmp_path, monkeypatch): + """No markers → user prompted for language (1=Python), then confirms.""" + # First input: language choice, second: confirmation + result, _ = _run_offer_capture( + tmp_path, monkeypatch, ["1", "y"], + ) assert result is True - assert (tmp_path / ".pddrc").exists() - + content = (tmp_path / ".pddrc").read_text() + assert 'default_language: "python"' in content -@patch.object(pddrc_initializer.console, "input", return_value="n") -def test_offer_pddrc_init_declined(mock_input, tmp_path): - """Returns False when user declines with 'n'.""" - (tmp_path / "pyproject.toml").touch() - with patch.object(Path, "cwd", return_value=tmp_path): - result = pddrc_initializer.offer_pddrc_init() +# --------------------------------------------------------------------------- +# V. User Declines +# --------------------------------------------------------------------------- +def test_declined_returns_false(tmp_path, monkeypatch): + """User types 'n' → returns False, no file created.""" + result, _ = _run_offer_capture( + tmp_path, monkeypatch, ["n"], + marker_files={"pyproject.toml": None}, + ) assert result is False assert not (tmp_path / ".pddrc").exists() -@patch.object(pddrc_initializer.console, "input") -def test_offer_pddrc_init_prompts_language_when_unknown(mock_input, tmp_path): - """When no markers found, prompts user for language choice.""" - # First input: language choice (1=Python), second: confirmation (y) - mock_input.side_effect = ["1", "y"] - - with patch.object(Path, "cwd", return_value=tmp_path): - result = pddrc_initializer.offer_pddrc_init() +# --------------------------------------------------------------------------- +# VI. Language Prompt with Invalid Input +# --------------------------------------------------------------------------- +def test_language_prompt_retries_on_invalid(tmp_path, monkeypatch): + """Invalid language choices cause retries until valid choice, then file created.""" + # "x" and "99" are invalid, "2" selects TypeScript, "y" confirms + result, output = _run_offer_capture( + tmp_path, monkeypatch, ["x", "99", "2", "y"], + ) assert result is True content = (tmp_path / ".pddrc").read_text() - assert 'default_language: "python"' in content + assert 'default_language: "typescript"' in content + assert "Invalid choice" in output # --------------------------------------------------------------------------- -# Tests for _prompt_language +# VII. Output / Display # --------------------------------------------------------------------------- -@patch.object(pddrc_initializer.console, "input", return_value="1") -def test_prompt_language_python(mock_input): - assert pddrc_initializer._prompt_language() == "python" +def test_detected_language_shown(tmp_path, monkeypatch): + """Auto-detected language is displayed to user.""" + _, output = _run_offer_capture( + tmp_path, monkeypatch, ["y"], + marker_files={"pyproject.toml": None}, + ) + assert "python" in output.lower() + + +def test_preview_shown_before_confirmation(tmp_path, monkeypatch): + """YAML preview content appears in output before confirmation.""" + _, output = _run_offer_capture( + tmp_path, monkeypatch, ["y"], + marker_files={"pyproject.toml": None}, + ) + assert "Proposed" in output or "contents" in output + assert "version" in output -@patch.object(pddrc_initializer.console, "input", return_value="2") -def test_prompt_language_typescript(mock_input): - assert pddrc_initializer._prompt_language() == "typescript" +def test_creation_success_message(tmp_path, monkeypatch): + """'Created .pddrc' message appears after successful creation.""" + _, output = _run_offer_capture( + tmp_path, monkeypatch, ["y"], + marker_files={"pyproject.toml": None}, + ) + assert "Created" in output + assert ".pddrc" in output -@patch.object(pddrc_initializer.console, "input", return_value="3") -def test_prompt_language_go(mock_input): - assert pddrc_initializer._prompt_language() == "go" +def test_skip_message_on_decline(tmp_path, monkeypatch): + """'Skipped' message appears when user declines.""" + _, output = _run_offer_capture( + tmp_path, monkeypatch, ["n"], + marker_files={"pyproject.toml": None}, + ) + assert "Skipped" in output or "skipped" in output -@patch.object(pddrc_initializer.console, "input") -def test_prompt_language_retries_on_invalid(mock_input): - """Invalid input causes retry until valid choice is entered.""" - mock_input.side_effect = ["x", "99", "2"] - assert pddrc_initializer._prompt_language() == "typescript" - assert mock_input.call_count == 3 +# --------------------------------------------------------------------------- +# VIII. Filesystem Error +# --------------------------------------------------------------------------- + +def test_write_error_returns_false(tmp_path, monkeypatch): + """OSError during file write → returns False, error message shown.""" + with patch.object(Path, "write_text", side_effect=OSError("Permission denied")): + result, output = _run_offer_capture( + tmp_path, monkeypatch, ["y"], + marker_files={"pyproject.toml": None}, + ) + assert result is False + assert "Failed" in output or "error" in output.lower() diff --git a/tests/test_provider_manager.py b/tests/test_provider_manager.py index 8a70184b5..d170f0a2b 100644 --- a/tests/test_provider_manager.py +++ b/tests/test_provider_manager.py @@ -1,8 +1,14 @@ -"""Tests for pdd/provider_manager.py""" +"""Tests for pdd/provider_manager.py + +Organized by public API function. Tests verify user-observable behavior +through the public interface; private helpers are exercised indirectly. +Shell execution integration tests verify generated scripts actually work. +""" import csv import os -import tempfile +import subprocess +import shutil from pathlib import Path from unittest import mock @@ -10,19 +16,13 @@ from pdd.provider_manager import ( CSV_FIELDNAMES, - _get_shell_name, - _get_pdd_dir, - _get_api_env_path, - _get_user_csv_path, - _read_csv, - _write_csv_atomic, - _read_api_env_lines, - _write_api_env_atomic, + COMPLEX_AUTH_PROVIDERS, _save_key_to_api_env, - _comment_out_key_in_api_env, - _is_key_set, - add_provider_from_registry, + _setup_complex_provider, add_custom_provider, + add_provider_from_registry, + is_multi_credential, + parse_api_key_vars, remove_models_by_provider, remove_individual_models, ) @@ -108,907 +108,647 @@ def sample_api_env(temp_home): return api_env_path -# --------------------------------------------------------------------------- -# Tests for path helpers -# --------------------------------------------------------------------------- - - -class TestPathHelpers: - """Tests for path helper functions.""" - - def test_get_shell_name_bash(self, monkeypatch): - """Should detect bash shell.""" - monkeypatch.setenv("SHELL", "/bin/bash") - assert _get_shell_name() == "bash" - - def test_get_shell_name_zsh(self, monkeypatch): - """Should detect zsh shell.""" - monkeypatch.setenv("SHELL", "/usr/local/bin/zsh") - assert _get_shell_name() == "zsh" - - def test_get_shell_name_fish(self, monkeypatch): - """Should detect fish shell.""" - monkeypatch.setenv("SHELL", "/opt/homebrew/bin/fish") - assert _get_shell_name() == "fish" - - def test_get_shell_name_defaults_to_bash(self, monkeypatch): - """Should default to bash for unknown shells.""" - monkeypatch.setenv("SHELL", "/bin/unknown_shell") - assert _get_shell_name() == "bash" - - def test_get_shell_name_no_shell_var(self, monkeypatch): - """Should default to bash when SHELL not set.""" - monkeypatch.delenv("SHELL", raising=False) - # Implementation defaults to /bin/bash when SHELL is not set - result = _get_shell_name() - assert result == "bash" - - def test_get_pdd_dir_creates_directory(self, tmp_path, monkeypatch): - """Should create ~/.pdd if it doesn't exist.""" - monkeypatch.setattr(Path, "home", lambda: tmp_path) - pdd_dir = tmp_path / ".pdd" - - # Directory shouldn't exist yet - assert not pdd_dir.exists() - - result = _get_pdd_dir() - - assert result == pdd_dir - assert pdd_dir.exists() - - def test_get_api_env_path(self, temp_home, monkeypatch): - """Should return correct api-env path for shell.""" - monkeypatch.setenv("SHELL", "/bin/zsh") - result = _get_api_env_path() - assert result == temp_home / ".pdd" / "api-env.zsh" - - def test_get_user_csv_path(self, temp_home): - """Should return correct user CSV path.""" - result = _get_user_csv_path() - assert result == temp_home / ".pdd" / "llm_model.csv" +def _read_user_csv(temp_home): + """Read the user CSV and return list of row dicts.""" + csv_path = temp_home / ".pdd" / "llm_model.csv" + if not csv_path.exists(): + return [] + with open(csv_path, "r", encoding="utf-8", newline="") as f: + return list(csv.DictReader(f)) # --------------------------------------------------------------------------- -# Tests for CSV I/O helpers +# I. parse_api_key_vars / is_multi_credential # --------------------------------------------------------------------------- -class TestCsvHelpers: - """Tests for CSV read/write functions.""" - - def test_read_csv_returns_list_of_dicts(self, sample_csv): - """Should read CSV and return list of row dictionaries.""" - result = _read_csv(sample_csv) +class TestApiKeyParsing: + """Tests for the public utility functions parse_api_key_vars and is_multi_credential.""" - assert isinstance(result, list) - assert len(result) == 3 - assert result[0]["provider"] == "OpenAI" - assert result[0]["model"] == "gpt-4" + def test_parse_single_var(self): + assert parse_api_key_vars("OPENAI_API_KEY") == ["OPENAI_API_KEY"] - def test_read_csv_missing_file(self, temp_home): - """Should return empty list for missing file.""" - result = _read_csv(temp_home / ".pdd" / "nonexistent.csv") - assert result == [] - - def test_write_csv_atomic_creates_file(self, temp_home): - """Should create CSV file with correct content.""" - csv_path = temp_home / ".pdd" / "test.csv" - rows = [ - {"provider": "Test", "model": "test-model", "input": "1.0", "output": "2.0"}, - ] + def test_parse_multiple_vars(self): + result = parse_api_key_vars("AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME") + assert result == ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"] - _write_csv_atomic(csv_path, rows) + def test_parse_empty_and_none(self): + assert parse_api_key_vars("") == [] + assert parse_api_key_vars(None) == [] + assert parse_api_key_vars(" ") == [] - assert csv_path.exists() - result = _read_csv(csv_path) - assert len(result) == 1 - assert result[0]["provider"] == "Test" + def test_parse_strips_whitespace_and_filters_empty(self): + assert parse_api_key_vars(" KEY_A | KEY_B ") == ["KEY_A", "KEY_B"] + assert parse_api_key_vars("KEY_A||KEY_B") == ["KEY_A", "KEY_B"] - def test_write_csv_atomic_is_atomic(self, temp_home): - """Write should be atomic - no partial writes on failure.""" - csv_path = temp_home / ".pdd" / "test.csv" - - # Write initial content - _write_csv_atomic(csv_path, [{"provider": "Original"}]) - - # Verify temp files are cleaned up - pdd_dir = temp_home / ".pdd" - temp_files = list(pdd_dir.glob(".llm_model_*.tmp")) - assert len(temp_files) == 0 - - def test_write_csv_atomic_fills_missing_fields(self, temp_home): - """Should fill missing fields with empty strings.""" - csv_path = temp_home / ".pdd" / "test.csv" - rows = [{"provider": "Test", "model": "test-model"}] # Missing many fields - - _write_csv_atomic(csv_path, rows) - - result = _read_csv(csv_path) - # All CSV_FIELDNAMES should be present - for field in CSV_FIELDNAMES: - assert field in result[0] + def test_is_multi_credential(self): + assert is_multi_credential("A|B") is True + assert is_multi_credential("OPENAI_API_KEY") is False + assert is_multi_credential("") is False + assert is_multi_credential(None) is False # --------------------------------------------------------------------------- -# Tests for api-env file helpers +# II. add_provider_from_registry # --------------------------------------------------------------------------- -class TestApiEnvHelpers: - """Tests for api-env file read/write functions.""" - - def test_read_api_env_lines(self, sample_api_env): - """Should read api-env file lines.""" - result = _read_api_env_lines(sample_api_env) - - assert len(result) == 2 - assert "OPENAI_API_KEY" in result[0] - - def test_read_api_env_lines_missing_file(self, temp_home): - """Should return empty list for missing file.""" - result = _read_api_env_lines(temp_home / ".pdd" / "nonexistent") - assert result == [] - - def test_write_api_env_atomic(self, temp_home): - """Should write api-env file atomically.""" - env_path = temp_home / ".pdd" / "api-env.bash" - lines = ["export TEST_KEY=value\n"] - - _write_api_env_atomic(env_path, lines) - - assert env_path.exists() - content = env_path.read_text() - assert "TEST_KEY" in content - - def test_save_key_to_api_env_new_key(self, temp_home, monkeypatch): - """Should add new key to api-env file.""" - monkeypatch.setenv("SHELL", "/bin/bash") - env_path = temp_home / ".pdd" / "api-env.bash" - - _save_key_to_api_env("NEW_KEY", "new-value") - - content = env_path.read_text() - # shlex.quote() doesn't quote simple values without special chars - assert 'export NEW_KEY=' in content - assert 'new-value' in content - - def test_save_key_to_api_env_updates_existing(self, sample_api_env, monkeypatch): - """Should update existing key in api-env file.""" - monkeypatch.setenv("SHELL", "/bin/bash") - - _save_key_to_api_env("OPENAI_API_KEY", "sk-updated") - - content = sample_api_env.read_text() - # shlex.quote() doesn't quote simple values without special chars - assert 'export OPENAI_API_KEY=' in content - assert 'sk-updated' in content - # Should not have duplicate entries - assert content.count("OPENAI_API_KEY") == 1 - - def test_save_key_to_api_env_uncomments_commented_key(self, temp_home, monkeypatch): - """Should replace commented key with new value.""" - monkeypatch.setenv("SHELL", "/bin/bash") - env_path = temp_home / ".pdd" / "api-env.bash" - env_path.write_text("# export OLD_KEY=old-value\n") - - _save_key_to_api_env("OLD_KEY", "new-value") - - content = env_path.read_text() - # shlex.quote() doesn't quote simple values without special chars - assert 'export OLD_KEY=' in content - assert 'new-value' in content - assert "# export OLD_KEY" not in content - - def test_save_key_with_special_characters(self, temp_home, monkeypatch): - """Should handle API keys with special characters.""" - monkeypatch.setenv("SHELL", "/bin/bash") - - # Key with special shell characters - special_value = 'key$with"special\'chars' - _save_key_to_api_env("SPECIAL_KEY", special_value) - - env_path = temp_home / ".pdd" / "api-env.bash" - content = env_path.read_text() - assert "SPECIAL_KEY" in content +class TestAddProviderFromRegistry: + """Tests for add_provider_from_registry — the main provider browsing flow.""" - def test_comment_out_key_in_api_env(self, sample_api_env, monkeypatch): - """Should comment out key with date annotation.""" - monkeypatch.setenv("SHELL", "/bin/bash") + def test_returns_false_on_empty_ref_csv(self, temp_home): + """Should return False when reference CSV has no models.""" + with mock.patch("pdd.provider_manager._read_csv", return_value=[]): + with mock.patch("pdd.provider_manager.console"): + assert add_provider_from_registry() is False - _comment_out_key_in_api_env("OPENAI_API_KEY") + def test_returns_false_on_cancel(self, temp_home): + """Empty input should cancel the flow.""" + ref_rows = [ + {"provider": "Anthropic", "model": "claude", "api_key": "ANTHROPIC_API_KEY"}, + ] + with mock.patch("pdd.provider_manager._read_csv", return_value=ref_rows): + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "" + with mock.patch("pdd.provider_manager.console"): + assert add_provider_from_registry() is False - content = sample_api_env.read_text() - assert "# Commented out by pdd setup on" in content - assert "# export OPENAI_API_KEY" in content + @pytest.mark.parametrize("bad_input", ["99", "0", "abc", "-1"]) + def test_returns_false_on_invalid_selection(self, temp_home, bad_input): + """Out-of-range or non-numeric input should return False.""" + ref_rows = [ + {"provider": "Anthropic", "model": "claude", "api_key": "ANTHROPIC_API_KEY"}, + ] + with mock.patch("pdd.provider_manager._read_csv", return_value=ref_rows): + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = bad_input + with mock.patch("pdd.provider_manager.console"): + assert add_provider_from_registry() is False - def test_comment_out_preserves_other_keys(self, sample_api_env, monkeypatch): - """Should preserve other keys when commenting out one.""" + def test_adds_models_to_csv(self, temp_home, monkeypatch): + """Should add all models for the selected provider to user CSV.""" monkeypatch.setenv("SHELL", "/bin/bash") - _comment_out_key_in_api_env("OPENAI_API_KEY") - - content = sample_api_env.read_text() - # ANTHROPIC_API_KEY should still be active - assert "export ANTHROPIC_API_KEY=ant-test456" in content - - -# --------------------------------------------------------------------------- -# Tests for _is_key_set -# --------------------------------------------------------------------------- - - -class TestIsKeySet: - """Tests for _is_key_set function.""" - - def test_detects_key_in_shell_env(self, temp_home, monkeypatch): - """Should detect key set in shell environment.""" - monkeypatch.setenv("TEST_KEY", "test-value") + ref_rows = [ + {"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", + "input": "3.0", "output": "15.0", "coding_arena_elo": "1400", "base_url": "", + "max_reasoning_tokens": "0", "structured_output": "True", "reasoning_type": "", "location": ""}, + {"provider": "Anthropic", "model": "claude-opus", "api_key": "ANTHROPIC_API_KEY", + "input": "5.0", "output": "25.0", "coding_arena_elo": "1500", "base_url": "", + "max_reasoning_tokens": "0", "structured_output": "True", "reasoning_type": "", "location": ""}, + {"provider": "OpenAI", "model": "gpt-4", "api_key": "OPENAI_API_KEY", + "input": "30.0", "output": "60.0", "coding_arena_elo": "1300", "base_url": "", + "max_reasoning_tokens": "0", "structured_output": "True", "reasoning_type": "", "location": ""}, + ] - result = _is_key_set("TEST_KEY") + with mock.patch("pdd.provider_manager._read_csv", side_effect=[ref_rows, []]): + with mock.patch("pdd.provider_manager._write_csv_atomic") as mock_write: + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.side_effect = ["1", "test-api-key"] + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = False + with mock.patch("pdd.provider_manager.console"): + with mock.patch("pdd.provider_manager._is_key_set", return_value=None): + result = add_provider_from_registry() - assert result == "shell environment" + assert result is True + mock_write.assert_called_once() + written_rows = mock_write.call_args[0][1] + assert len(written_rows) == 2 + assert all(r["provider"] == "Anthropic" for r in written_rows) - def test_detects_key_in_api_env_file(self, sample_api_env, monkeypatch): - """Should detect key in api-env file.""" + def test_skips_duplicate_models(self, temp_home, monkeypatch): + """Should not add models that already exist in user CSV.""" monkeypatch.setenv("SHELL", "/bin/bash") - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - - result = _is_key_set("OPENAI_API_KEY") - - assert "api-env.bash" in result - - def test_returns_none_when_key_not_set(self, temp_home, monkeypatch): - """Should return None when key is not set anywhere.""" - monkeypatch.delenv("NONEXISTENT_KEY", raising=False) - - result = _is_key_set("NONEXISTENT_KEY") - - assert result is None - - def test_dotenv_priority_over_shell(self, temp_home, monkeypatch): - """Should check .env file first.""" - monkeypatch.setenv("TEST_KEY", "shell-value") - - with mock.patch("dotenv.dotenv_values", return_value={"TEST_KEY": "dotenv-value"}): - result = _is_key_set("TEST_KEY") - - assert result == ".env file" - - -# --------------------------------------------------------------------------- -# Tests for add_provider_from_registry (mocked) -# --------------------------------------------------------------------------- - -class TestAddProviderFromRegistry: - """Tests for add_provider_from_registry function.""" - - def test_returns_false_when_litellm_unavailable(self, temp_home): - """Should return False when litellm is not available.""" - with mock.patch( - "pdd.provider_manager.is_litellm_available", return_value=False - ): - with mock.patch("pdd.provider_manager.console"): - result = add_provider_from_registry() - - assert result is False + ref_rows = [ + {"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", + "input": "3.0", "output": "15.0", "coding_arena_elo": "1400", "base_url": "", + "max_reasoning_tokens": "0", "structured_output": "True", "reasoning_type": "", "location": ""}, + ] + existing_rows = [ + {"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY"}, + ] - def test_returns_false_on_empty_selection(self, temp_home): - """Should return False when user enters empty selection.""" - with mock.patch( - "pdd.provider_manager.is_litellm_available", return_value=True - ): - with mock.patch( - "pdd.provider_manager.get_top_providers", - return_value=[], - ): + with mock.patch("pdd.provider_manager._read_csv", side_effect=[ref_rows, existing_rows]): + with mock.patch("pdd.provider_manager._write_csv_atomic") as mock_write: with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "" + mock_prompt.ask.return_value = "1" with mock.patch("pdd.provider_manager.console"): - result = add_provider_from_registry() + with mock.patch("pdd.provider_manager._is_key_set", return_value="shell environment"): + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = False + result = add_provider_from_registry() assert result is False - - def test_adds_models_to_csv(self, temp_home): - """Should add selected models to user CSV.""" - from pdd.litellm_registry import ProviderInfo, ModelInfo - - mock_provider = ProviderInfo( - name="test_provider", - display_name="Test Provider", - api_key_env_var="TEST_API_KEY", - model_count=2, - sample_models=["model1", "model2"], - ) - - mock_models = [ - ModelInfo( - name="model1", - litellm_id="test_provider/model1", - input_cost_per_million=1.0, - output_cost_per_million=2.0, - supports_function_calling=True, - ), - ] - - with mock.patch( - "pdd.provider_manager.is_litellm_available", return_value=True - ): - with mock.patch( - "pdd.provider_manager.get_top_providers", - return_value=[mock_provider], - ): - with mock.patch( - "pdd.provider_manager.get_models_for_provider", - return_value=mock_models, - ): + mock_write.assert_not_called() + + def test_dispatches_to_complex_auth_for_vertex(self, temp_home): + """Selecting a complex provider should delegate to _setup_complex_provider.""" + with mock.patch("pdd.provider_manager._setup_complex_provider", return_value=True) as mock_setup: + with mock.patch("pdd.provider_manager._write_csv_atomic"): + with mock.patch("pdd.provider_manager._read_csv") as mock_read: + mock_read.side_effect = [ + [{"provider": "Google Vertex AI", "model": "vertex_ai/gemini-2.5-pro", + "api_key": "GOOGLE_APPLICATION_CREDENTIALS", "base_url": ""}], + [], + ] with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - # Select provider 1, then model 1 - mock_prompt.ask.side_effect = ["1", "1", "test-api-key"] - - with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: - mock_confirm.ask.return_value = False + mock_prompt.ask.return_value = "1" + with mock.patch("pdd.provider_manager.console"): + add_provider_from_registry() - with mock.patch("pdd.provider_manager.console"): - with mock.patch( - "pdd.provider_manager._is_key_set", - return_value=None, - ): - result = add_provider_from_registry() - - # Check that model was added to CSV - csv_path = temp_home / ".pdd" / "llm_model.csv" - if csv_path.exists(): - rows = _read_csv(csv_path) - assert any(r["model"] == "test_provider/model1" for r in rows) + mock_setup.assert_called_once_with("Google Vertex AI") # --------------------------------------------------------------------------- -# Tests for add_custom_provider (mocked) +# III. add_custom_provider # --------------------------------------------------------------------------- class TestAddCustomProvider: - """Tests for add_custom_provider function.""" - - def test_returns_false_on_empty_provider(self, temp_home): - """Should return False when provider name is empty.""" - with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "" - with mock.patch("pdd.provider_manager.console"): - result = add_custom_provider() - - assert result is False - - def test_adds_custom_provider_to_csv(self, temp_home): - """Should add custom provider to user CSV.""" - with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.side_effect = [ - "custom_llm", # provider prefix - "my-model", # model name - "CUSTOM_API_KEY", # api key env var - "", # base url (optional) - "1.0", # input cost - "2.0", # output cost - ] - with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: - mock_confirm.ask.return_value = False # Don't enter key value now - with mock.patch("pdd.provider_manager.console"): - result = add_custom_provider() - - assert result is True - - # Verify CSV was updated - csv_path = temp_home / ".pdd" / "llm_model.csv" - rows = _read_csv(csv_path) - assert len(rows) == 1 - assert rows[0]["provider"] == "custom_llm" - assert rows[0]["model"] == "custom_llm/my-model" - assert rows[0]["api_key"] == "CUSTOM_API_KEY" - - def test_saves_api_key_when_provided(self, temp_home, monkeypatch): - """Should save API key to api-env when user provides it.""" - monkeypatch.setenv("SHELL", "/bin/bash") - - with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.side_effect = [ - "custom_llm", - "my-model", - "CUSTOM_API_KEY", - "", - "0.0", - "0.0", - "sk-my-secret-key", # API key value - ] - with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: - mock_confirm.ask.return_value = True # Yes, enter key value - with mock.patch("pdd.provider_manager.console"): - result = add_custom_provider() - - assert result is True + """Tests for add_custom_provider — the manual provider entry flow.""" + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._write_csv_atomic") + @mock.patch("pdd.provider_manager._read_csv", return_value=[]) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_adds_custom_model_with_correct_format( + self, mock_console, mock_prompt, mock_confirm, mock_read, mock_write, mock_save, mock_rc + ): + """Should create provider/model formatted model name and sensible defaults.""" + mock_prompt.ask.side_effect = [ + "ollama", "llama3", "OLLAMA_API_KEY", "", "0.0", "0.0", + ] + mock_confirm.ask.return_value = False + + assert add_custom_provider() is True + + written_rows = mock_write.call_args[0][1] + assert len(written_rows) == 1 + assert written_rows[0]["model"] == "ollama/llama3" + assert written_rows[0]["provider"] == "ollama" + assert written_rows[0]["api_key"] == "OLLAMA_API_KEY" + assert written_rows[0]["coding_arena_elo"] == "1000" + assert written_rows[0]["structured_output"] == "True" + + @pytest.mark.parametrize("abort_at_step,inputs", [ + ("provider", [""]), + ("model", ["ollama", ""]), + ("api_key_var", ["ollama", "llama3", ""]), + ]) + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_returns_false_on_empty_input_at_each_step( + self, mock_console, mock_prompt, abort_at_step, inputs + ): + """Empty input at any required step should cancel.""" + mock_prompt.ask.side_effect = inputs + assert add_custom_provider() is False + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._write_csv_atomic") + @mock.patch("pdd.provider_manager._read_csv", return_value=[]) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_saves_api_key_when_user_provides_value( + self, mock_console, mock_prompt, mock_confirm, mock_read, mock_write, mock_save, mock_rc + ): + """When user opts to provide key value, it should be saved to api-env.""" + mock_prompt.ask.side_effect = [ + "openai", "gpt-5", "MY_KEY", "", "0.0", "0.0", "sk-secret123", + ] + mock_confirm.ask.return_value = True + + assert add_custom_provider() is True + mock_save.assert_called_once_with("MY_KEY", "sk-secret123") + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._write_csv_atomic") + @mock.patch("pdd.provider_manager._read_csv", return_value=[]) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_invalid_costs_default_to_zero( + self, mock_console, mock_prompt, mock_confirm, mock_read, mock_write, mock_save, mock_rc + ): + """Non-numeric cost values should default to 0.0.""" + mock_prompt.ask.side_effect = [ + "test", "model", "TEST_KEY", "", "not-a-number", "also-bad", + ] + mock_confirm.ask.return_value = False - # Verify api-env was updated - env_path = temp_home / ".pdd" / "api-env.bash" - content = env_path.read_text() - assert "CUSTOM_API_KEY" in content + assert add_custom_provider() is True + written_rows = mock_write.call_args[0][1] + assert written_rows[0]["input"] == "0.0" + assert written_rows[0]["output"] == "0.0" # --------------------------------------------------------------------------- -# Tests for remove_models_by_provider (mocked) +# IV. remove_models_by_provider # --------------------------------------------------------------------------- class TestRemoveModelsByProvider: - """Tests for remove_models_by_provider function.""" + """Tests for remove_models_by_provider — bulk removal by api_key group.""" def test_returns_false_when_no_models(self, temp_home): - """Should return False when no models configured.""" with mock.patch("pdd.provider_manager.console"): - result = remove_models_by_provider() - - assert result is False + assert remove_models_by_provider() is False def test_returns_false_on_cancel(self, sample_csv, temp_home): - """Should return False when user cancels.""" with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "" # Empty = cancel + mock_prompt.ask.return_value = "" with mock.patch("pdd.provider_manager.console"): - result = remove_models_by_provider() + assert remove_models_by_provider() is False - assert result is False + @pytest.mark.parametrize("bad_input", ["99", "abc"]) + def test_returns_false_on_invalid_selection(self, sample_csv, temp_home, bad_input): + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = bad_input + with mock.patch("pdd.provider_manager.console"): + assert remove_models_by_provider() is False - def test_removes_all_models_for_provider(self, sample_csv, temp_home, monkeypatch): - """Should remove all models with matching api_key.""" + def test_returns_false_when_user_declines_confirm(self, sample_csv, temp_home): + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "1" + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = False + with mock.patch("pdd.provider_manager.console"): + assert remove_models_by_provider() is False + + def test_removes_all_models_for_selected_provider(self, sample_csv, temp_home, monkeypatch): + """Should remove all models sharing the selected api_key and comment it out.""" monkeypatch.setenv("SHELL", "/bin/bash") with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "1" # Select first provider + mock_prompt.ask.return_value = "1" with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: - mock_confirm.ask.return_value = True # Confirm removal + mock_confirm.ask.return_value = True with mock.patch("pdd.provider_manager.console"): result = remove_models_by_provider() assert result is True - - # Check that models were removed - rows = _read_csv(sample_csv) - # One provider should have been removed - assert len(rows) < 3 + remaining = _read_user_csv(temp_home) + assert len(remaining) < 3 # --------------------------------------------------------------------------- -# Tests for remove_individual_models (mocked) +# V. remove_individual_models # --------------------------------------------------------------------------- class TestRemoveIndividualModels: - """Tests for remove_individual_models function.""" + """Tests for remove_individual_models — selective model removal.""" def test_returns_false_when_no_models(self, temp_home): - """Should return False when no models configured.""" with mock.patch("pdd.provider_manager.console"): - result = remove_individual_models() - - assert result is False + assert remove_individual_models() is False def test_returns_false_on_cancel(self, sample_csv, temp_home): - """Should return False when user cancels.""" with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "" # Empty = cancel + mock_prompt.ask.return_value = "" with mock.patch("pdd.provider_manager.console"): - result = remove_individual_models() + assert remove_individual_models() is False - assert result is False + def test_returns_false_on_all_invalid_numbers(self, sample_csv, temp_home): + """All-invalid comma-separated input should result in no selections.""" + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "99, abc, -1" + with mock.patch("pdd.provider_manager.console"): + assert remove_individual_models() is False - def test_removes_selected_models(self, sample_csv, temp_home): - """Should remove only selected models.""" + def test_returns_false_when_user_declines_confirm(self, sample_csv, temp_home): with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "1" # Remove first model + mock_prompt.ask.return_value = "1" with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: - mock_confirm.ask.return_value = True # Confirm + mock_confirm.ask.return_value = False with mock.patch("pdd.provider_manager.console"): - result = remove_individual_models() + assert remove_individual_models() is False - assert result is True + def test_removes_single_model(self, sample_csv, temp_home): + with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: + mock_prompt.ask.return_value = "1" + with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: + mock_confirm.ask.return_value = True + with mock.patch("pdd.provider_manager.console"): + assert remove_individual_models() is True - # Check that one model was removed - rows = _read_csv(sample_csv) - assert len(rows) == 2 + assert len(_read_user_csv(temp_home)) == 2 - def test_removes_multiple_models(self, sample_csv, temp_home): - """Should handle comma-separated model selection.""" + def test_removes_multiple_comma_separated(self, sample_csv, temp_home): with mock.patch("pdd.provider_manager.Prompt") as mock_prompt: - mock_prompt.ask.return_value = "1, 2" # Remove first two models + mock_prompt.ask.return_value = "1, 2" with mock.patch("pdd.provider_manager.Confirm") as mock_confirm: - mock_confirm.ask.return_value = True # Confirm + mock_confirm.ask.return_value = True with mock.patch("pdd.provider_manager.console"): - result = remove_individual_models() + assert remove_individual_models() is True - assert result is True - - # Check that two models were removed - rows = _read_csv(sample_csv) - assert len(rows) == 1 + assert len(_read_user_csv(temp_home)) == 1 # --------------------------------------------------------------------------- -# Edge case tests +# VI. Complex provider auth (_setup_complex_provider) # --------------------------------------------------------------------------- -class TestEdgeCases: - """Test edge cases and error handling.""" - - def test_csv_write_atomic_cleans_up_on_error(self, temp_home): - """Should clean up temp file on write error.""" - csv_path = temp_home / ".pdd" / "test.csv" - - # Create a scenario where write might fail - with mock.patch("os.fdopen", side_effect=IOError("Simulated error")): - with pytest.raises(IOError): - _write_csv_atomic(csv_path, [{"provider": "Test"}]) - - # No temp files should remain - temp_files = list((temp_home / ".pdd").glob(".llm_model_*.tmp")) - assert len(temp_files) == 0 - - def test_handles_special_characters_in_api_keys(self, temp_home, monkeypatch): - """Should handle API key values with special shell characters.""" - monkeypatch.setenv("SHELL", "/bin/bash") - - # Characters that might cause issues in shell scripts - special_values = [ - 'key$with$dollars', - 'key"with"quotes', - "key'with'single", - 'key`with`backticks', - 'key\\with\\backslashes', - 'key with spaces', - 'key;with;semicolons', - ] - - for i, value in enumerate(special_values): - key_name = f"SPECIAL_KEY_{i}" - _save_key_to_api_env(key_name, value) - - env_path = temp_home / ".pdd" / "api-env.bash" - content = env_path.read_text() +class TestComplexProviderAuth: + """Tests for complex (multi-variable) provider authentication flows. - # All keys should be present - for i in range(len(special_values)): - assert f"SPECIAL_KEY_{i}" in content - - def test_handles_unicode_in_model_names(self, temp_home): - """Should handle unicode characters in model names.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - rows = [ - { - "provider": "Tëst", - "model": "模型-émoji-🤖", - "input": "1.0", - "output": "2.0", - } - ] + _setup_complex_provider is tested directly because it's the entry point + for a significant user-facing flow that add_provider_from_registry delegates to. + """ - _write_csv_atomic(csv_path, rows) - - result = _read_csv(csv_path) - assert result[0]["provider"] == "Tëst" - assert "模型" in result[0]["model"] - - def test_handles_empty_csv_fields(self, temp_home): - """Should handle rows with empty optional fields.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - rows = [ - { - "provider": "Test", - "model": "test-model", - # All other fields will be filled with empty strings - } + def test_registry_contains_expected_providers(self): + """Registry should contain the 5 known complex providers.""" + expected = {"Google Vertex AI", "AWS Bedrock", "Azure OpenAI", "Azure AI", "Github Copilot"} + assert expected == set(COMPLEX_AUTH_PROVIDERS.keys()) + + def test_simple_providers_not_in_registry(self): + for name in ["Anthropic", "OpenAI", "DeepSeek"]: + assert name not in COMPLEX_AUTH_PROVIDERS + + def test_registry_entries_have_required_fields(self): + required_keys = {"env_var", "label", "required", "default", "hint"} + for provider, configs in COMPLEX_AUTH_PROVIDERS.items(): + assert len(configs) > 0, f"{provider} has no configs" + for cfg in configs: + assert required_keys <= set(cfg.keys()), f"{provider} config missing keys" + + def test_unknown_provider_returns_false(self): + assert _setup_complex_provider("Unknown Provider") is False + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._is_key_set", return_value=None) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_bedrock_saves_all_three_vars( + self, mock_console, mock_prompt, mock_confirm, mock_is_key, mock_save, mock_rc + ): + mock_prompt.ask.side_effect = ["AKIAEXAMPLE", "wJalrXSecret", "us-east-1"] + + assert _setup_complex_provider("AWS Bedrock") is True + assert mock_save.call_count == 3 + mock_save.assert_any_call("AWS_ACCESS_KEY_ID", "AKIAEXAMPLE") + mock_save.assert_any_call("AWS_SECRET_ACCESS_KEY", "wJalrXSecret") + mock_save.assert_any_call("AWS_REGION_NAME", "us-east-1") + mock_rc.assert_called_once() + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._is_key_set", return_value=None) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_vertex_adc_skips_credentials_save( + self, mock_console, mock_prompt, mock_confirm, mock_is_key, mock_save, mock_rc + ): + """When user enters 'adc' for Vertex credentials, that var should not be saved.""" + mock_prompt.ask.side_effect = ["adc", "my-project-123", "us-central1"] + + assert _setup_complex_provider("Google Vertex AI") is True + assert mock_save.call_count == 2 + mock_save.assert_any_call("VERTEXAI_PROJECT", "my-project-123") + mock_save.assert_any_call("VERTEXAI_LOCATION", "us-central1") + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._is_key_set", return_value=None) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_azure_openai_saves_three_vars( + self, mock_console, mock_prompt, mock_confirm, mock_is_key, mock_save, mock_rc + ): + mock_prompt.ask.side_effect = [ + "abc123key", "https://myresource.openai.azure.com/", "2024-10-21", ] - _write_csv_atomic(csv_path, rows) - - result = _read_csv(csv_path) - assert result[0]["provider"] == "Test" - assert result[0]["api_key"] == "" - - def test_concurrent_writes_safe(self, temp_home): - """Atomic writes should be safe for concurrent access.""" - csv_path = temp_home / ".pdd" / "llm_model.csv" - - # Write initial data - _write_csv_atomic(csv_path, [{"provider": "Initial"}]) - - # Simulate concurrent write - _write_csv_atomic(csv_path, [{"provider": "Updated"}]) - - result = _read_csv(csv_path) - assert len(result) == 1 - assert result[0]["provider"] == "Updated" + assert _setup_complex_provider("Azure OpenAI") is True + assert mock_save.call_count == 3 + mock_save.assert_any_call("AZURE_API_KEY", "abc123key") + mock_save.assert_any_call("AZURE_API_BASE", "https://myresource.openai.azure.com/") + mock_save.assert_any_call("AZURE_API_VERSION", "2024-10-21") + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._is_key_set", return_value=None) + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_skip_all_required_vars_returns_false( + self, mock_console, mock_prompt, mock_confirm, mock_is_key, mock_save, mock_rc + ): + """Skipping all vars should return False and save nothing.""" + mock_prompt.ask.side_effect = ["", "", ""] + + assert _setup_complex_provider("AWS Bedrock") is False + mock_save.assert_not_called() + mock_rc.assert_not_called() + + @mock.patch("pdd.provider_manager._ensure_api_env_sourced_in_rc") + @mock.patch("pdd.provider_manager._save_key_to_api_env") + @mock.patch("pdd.provider_manager._is_key_set", return_value="shell environment") + @mock.patch("pdd.provider_manager.Confirm") + @mock.patch("pdd.provider_manager.Prompt") + @mock.patch("pdd.provider_manager.console") + def test_existing_key_skipped_when_update_declined( + self, mock_console, mock_prompt, mock_confirm, mock_is_key, mock_save, mock_rc + ): + mock_confirm.ask.return_value = False + + assert _setup_complex_provider("Github Copilot") is False + mock_save.assert_not_called() # --------------------------------------------------------------------------- -# Shell script execution tests (following test_setup_tool.py pattern) +# VII. Shell execution integration tests +# +# These are the most valuable tests in this file. They verify that +# _save_key_to_api_env produces scripts that real shells can source, +# and that API key values survive the shell escaping roundtrip. # --------------------------------------------------------------------------- -class TestApiEnvShellExecution: - """ - Tests that verify generated api-env scripts can be sourced and - correctly preserve API key values, especially with special characters. +def _shell_available(shell: str) -> bool: + return shutil.which(shell) is not None - These tests follow the rigorous pattern from test_setup_tool.py, - actually executing shell scripts to verify correctness. - """ - def _shell_available(self, shell: str) -> bool: - """Check if a shell is available on the system.""" - import shutil - return shutil.which(shell) is not None +class TestShellExecution: + """ + Integration tests that actually execute generated api-env scripts + in real shells and verify key values are preserved exactly. + """ - def test_api_env_script_valid_bash_syntax(self, temp_home, monkeypatch): - """ - Generated api-env script should have valid bash syntax. - This test catches quoting errors that would break shell parsing. - """ + def test_bash_syntax_valid_with_special_chars(self, temp_home, monkeypatch): + """Generated api-env script should have valid bash syntax.""" monkeypatch.setenv("SHELL", "/bin/bash") - - # Save a key with special characters - special_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - _save_key_to_api_env("TEST_KEY", special_key) - + _save_key_to_api_env("TEST_KEY", 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash') env_path = temp_home / ".pdd" / "api-env.bash" - # Run bash syntax check - import subprocess result = subprocess.run( ["bash", "-n", str(env_path)], - capture_output=True, - text=True, - timeout=5, + capture_output=True, text=True, timeout=5, ) - assert result.returncode == 0, ( - f"Generated script has bash syntax errors: {result.stderr}\n" - f"Script content:\n{env_path.read_text()}" + f"Bash syntax error: {result.stderr}\nScript:\n{env_path.read_text()}" ) - def test_api_env_script_valid_zsh_syntax(self, temp_home, monkeypatch): - """Generated api-env script should have valid zsh syntax.""" - if not self._shell_available("zsh"): + def test_zsh_syntax_valid_with_special_chars(self, temp_home, monkeypatch): + if not _shell_available("zsh"): pytest.skip("zsh not available") - monkeypatch.setenv("SHELL", "/bin/zsh") - - special_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - _save_key_to_api_env("TEST_KEY", special_key) - + _save_key_to_api_env("TEST_KEY", 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash') env_path = temp_home / ".pdd" / "api-env.zsh" - import subprocess result = subprocess.run( ["zsh", "-n", str(env_path)], - capture_output=True, - text=True, - timeout=5, - ) - - assert result.returncode == 0, ( - f"Generated script has zsh syntax errors: {result.stderr}\n" - f"Script content:\n{env_path.read_text()}" - ) - - def test_api_env_script_can_be_sourced_bash(self, temp_home, monkeypatch): - """Script should be sourceable in bash without errors.""" - monkeypatch.setenv("SHELL", "/bin/bash") - - special_key = 'key"with\'many$special`characters' - _save_key_to_api_env("TEST_KEY", special_key) - - env_path = temp_home / ".pdd" / "api-env.bash" - - import subprocess - result = subprocess.run( - ["bash", "-c", f"source {env_path} && exit 0"], - capture_output=True, - text=True, - timeout=5, + capture_output=True, text=True, timeout=5, ) - assert result.returncode == 0, ( - f"Cannot source script in bash: {result.stderr}\n" - f"Script content:\n{env_path.read_text()}" + f"Zsh syntax error: {result.stderr}\nScript:\n{env_path.read_text()}" ) - def test_api_env_preserves_key_value_bash(self, temp_home, monkeypatch): - """ - API key value should be preserved exactly when script is sourced. - This is the critical test - verifies the key survives shell escaping. - """ + def test_key_value_preserved_bash(self, temp_home, monkeypatch): + """API key should survive bash source→read roundtrip exactly.""" monkeypatch.setenv("SHELL", "/bin/bash") - - original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - _save_key_to_api_env("TEST_KEY", original_key) - + original = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + _save_key_to_api_env("TEST_KEY", original) env_path = temp_home / ".pdd" / "api-env.bash" - import subprocess result = subprocess.run( - [ - "bash", "-c", - f"source {env_path} && python3 -c \"import os; print(os.environ.get('TEST_KEY', ''))\"" - ], - capture_output=True, - text=True, - timeout=5, - ) - - assert result.returncode == 0, ( - f"Failed to source script and read env var: {result.stderr}\n" - f"Script content:\n{env_path.read_text()}" + ["bash", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('TEST_KEY', ''))\""], + capture_output=True, text=True, timeout=5, ) + assert result.returncode == 0, f"Source failed: {result.stderr}" + assert result.stdout.strip() == original - extracted_key = result.stdout.strip() - assert extracted_key == original_key, ( - f"Key value was corrupted during shell escaping.\n" - f"Original: {repr(original_key)}\n" - f"Extracted: {repr(extracted_key)}\n" - f"Script content:\n{env_path.read_text()}" - ) - - def test_api_env_preserves_key_value_zsh(self, temp_home, monkeypatch): - """API key value should be preserved exactly in zsh.""" - if not self._shell_available("zsh"): + def test_key_value_preserved_zsh(self, temp_home, monkeypatch): + if not _shell_available("zsh"): pytest.skip("zsh not available") - monkeypatch.setenv("SHELL", "/bin/zsh") - - original_key = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' - _save_key_to_api_env("TEST_KEY", original_key) - + original = 'AIzaSyAbCdEf123456$var"quote\'backtick\\slash' + _save_key_to_api_env("TEST_KEY", original) env_path = temp_home / ".pdd" / "api-env.zsh" - import subprocess result = subprocess.run( - [ - "zsh", "-c", - f"source {env_path} && python3 -c \"import os; print(os.environ.get('TEST_KEY', ''))\"" - ], - capture_output=True, - text=True, - timeout=5, + ["zsh", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('TEST_KEY', ''))\""], + capture_output=True, text=True, timeout=5, ) + assert result.returncode == 0, f"Source failed: {result.stderr}" + assert result.stdout.strip() == original + + @pytest.mark.parametrize("name,value", [ + ("dollar", "key$value"), + ("double_quote", 'key"value'), + ("single_quote", "key'value"), + ("backtick", "key`value"), + ("backslash", "key\\value"), + ("space", "key value"), + ("semicolon", "key;value"), + ("ampersand", "key&value"), + ("pipe", "key|value"), + ("newline", "key\nvalue"), + ("tab", "key\tvalue"), + ]) + def test_problematic_char_preserved_bash(self, temp_home, monkeypatch, name, value): + """Each problematic shell character should be preserved through bash roundtrip.""" + monkeypatch.setenv("SHELL", "/bin/bash") + key_name = f"TEST_{name.upper()}" + _save_key_to_api_env(key_name, value) + env_path = temp_home / ".pdd" / "api-env.bash" - assert result.returncode == 0, ( - f"Failed to source script: {result.stderr}\n" - f"Script content:\n{env_path.read_text()}" + syntax = subprocess.run( + ["bash", "-n", str(env_path)], + capture_output=True, text=True, timeout=5, ) + assert syntax.returncode == 0, f"Syntax error for '{name}': {syntax.stderr}" - extracted_key = result.stdout.strip() - assert extracted_key == original_key, ( - f"Key value was corrupted in zsh.\n" - f"Original: {repr(original_key)}\n" - f"Extracted: {repr(extracted_key)}\n" - f"Script content:\n{env_path.read_text()}" + extract = subprocess.run( + ["bash", "-c", + f"source {env_path} && python3 -c \"import os; print(repr(os.environ.get('{key_name}', '')))\""], + capture_output=True, text=True, timeout=5, ) - - def test_api_env_with_various_problematic_characters(self, temp_home, monkeypatch): - """ - Test with various characters that commonly cause shell escaping issues. - Each character tested individually to identify specific failures. - """ - monkeypatch.setenv("SHELL", "/bin/bash") - - problematic_chars = [ - ('dollar', 'key$value'), - ('double_quote', 'key"value'), - ('single_quote', "key'value"), - ('backtick', 'key`value'), - ('backslash', 'key\\value'), - ('space', 'key value'), - ('semicolon', 'key;value'), - ('ampersand', 'key&value'), - ('pipe', 'key|value'), - ('newline', 'key\nvalue'), - ('tab', 'key\tvalue'), - ] - - import subprocess - - for name, test_value in problematic_chars: - key_name = f"TEST_{name.upper()}" - _save_key_to_api_env(key_name, test_value) - - env_path = temp_home / ".pdd" / "api-env.bash" - - # Verify syntax is valid - syntax_result = subprocess.run( - ["bash", "-n", str(env_path)], - capture_output=True, - text=True, - timeout=5, - ) - - assert syntax_result.returncode == 0, ( - f"Syntax error with '{name}' character: {syntax_result.stderr}\n" - f"Script:\n{env_path.read_text()}" + if extract.returncode == 0: + extracted = eval(extract.stdout.strip()) + assert extracted == value, ( + f"Value corrupted for '{name}': expected {repr(value)}, got {repr(extracted)}" ) - # Verify value is preserved - extract_result = subprocess.run( - [ - "bash", "-c", - f"source {env_path} && python3 -c \"import os; print(repr(os.environ.get('{key_name}', '')))\"" - ], - capture_output=True, - text=True, - timeout=5, - ) - - if extract_result.returncode == 0: - extracted = eval(extract_result.stdout.strip()) - assert extracted == test_value, ( - f"Value corrupted for '{name}' character.\n" - f"Expected: {repr(test_value)}\n" - f"Got: {repr(extracted)}" - ) - - def test_multiple_keys_preserved(self, temp_home, monkeypatch): - """Multiple keys should all be preserved correctly.""" + def test_multiple_keys_all_preserved(self, temp_home, monkeypatch): + """Multiple keys saved sequentially should all be preserved.""" monkeypatch.setenv("SHELL", "/bin/bash") - keys = { "OPENAI_API_KEY": "sk-test123", "ANTHROPIC_API_KEY": "ant-key$special", "GEMINI_API_KEY": 'gem"quoted\'key', } - - for key_name, key_value in keys.items(): - _save_key_to_api_env(key_name, key_value) + for k, v in keys.items(): + _save_key_to_api_env(k, v) env_path = temp_home / ".pdd" / "api-env.bash" - - import subprocess - - for key_name, expected_value in keys.items(): + for key_name, expected in keys.items(): result = subprocess.run( - [ - "bash", "-c", - f"source {env_path} && python3 -c \"import os; print(os.environ.get('{key_name}', ''))\"" - ], - capture_output=True, - text=True, - timeout=5, + ["bash", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('{key_name}', ''))\""], + capture_output=True, text=True, timeout=5, ) - assert result.returncode == 0 - extracted = result.stdout.strip() - assert extracted == expected_value, ( - f"{key_name} was corrupted.\n" - f"Expected: {repr(expected_value)}\n" - f"Got: {repr(extracted)}" - ) + assert result.stdout.strip() == expected - def test_normal_key_still_works(self, temp_home, monkeypatch): - """Normal keys without special characters should still work.""" + def test_key_update_replaces_in_place(self, temp_home, monkeypatch): + """Updating an existing key should replace it, not duplicate it.""" monkeypatch.setenv("SHELL", "/bin/bash") - - normal_key = "sk-proj-abcdef1234567890ABCDEF" - _save_key_to_api_env("OPENAI_API_KEY", normal_key) + _save_key_to_api_env("MY_KEY", "old-value") + _save_key_to_api_env("MY_KEY", "new-value") env_path = temp_home / ".pdd" / "api-env.bash" + content = env_path.read_text() + assert content.count("MY_KEY") == 1 - import subprocess result = subprocess.run( - [ - "bash", "-c", - f"source {env_path} && echo $OPENAI_API_KEY" - ], - capture_output=True, - text=True, - timeout=5, + ["bash", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('MY_KEY', ''))\""], + capture_output=True, text=True, timeout=5, ) + assert result.returncode == 0 + assert result.stdout.strip() == "new-value" + def test_save_key_sets_os_environ_immediately(self, temp_home, monkeypatch): + """_save_key_to_api_env should set os.environ for immediate availability.""" + monkeypatch.setenv("SHELL", "/bin/bash") + monkeypatch.delenv("MY_IMMEDIATE_KEY", raising=False) + + _save_key_to_api_env("MY_IMMEDIATE_KEY", "test-value-abc") + + assert os.environ.get("MY_IMMEDIATE_KEY") == "test-value-abc" + + def test_commented_key_replaced_on_save(self, temp_home, monkeypatch): + """Saving a key that was previously commented out should uncomment/replace it.""" + monkeypatch.setenv("SHELL", "/bin/bash") + env_path = temp_home / ".pdd" / "api-env.bash" + env_path.write_text("# export OLD_KEY=old-value\n") + + _save_key_to_api_env("OLD_KEY", "new-value") + + content = env_path.read_text() + assert "# export OLD_KEY" not in content + assert "new-value" in content + + result = subprocess.run( + ["bash", "-c", + f"source {env_path} && python3 -c \"import os; print(os.environ.get('OLD_KEY', ''))\""], + capture_output=True, text=True, timeout=5, + ) assert result.returncode == 0 - assert result.stdout.strip() == normal_key + assert result.stdout.strip() == "new-value" diff --git a/tests/test_setup_tool.py b/tests/test_setup_tool.py index 1c9f1d65a..5c74b3eec 100644 --- a/tests/test_setup_tool.py +++ b/tests/test_setup_tool.py @@ -1,378 +1,760 @@ -"""Tests for setup_tool.py — deterministic auto-configuration.""" - +# Test Plan: +# All tests drive through the public entry point `run_setup()` via the helper +# `_run_setup_capture()` which mocks only at true boundaries (user input, +# filesystem paths, LLM calls, CLI detection) and captures printed output. +# +# I. End-to-End Success Path +# 1. test_happy_path_enter_to_finish: CLI detected, auto-phase succeeds, +# user presses Enter → exit summary printed, no options menu. +# 2. test_happy_path_open_menu_then_exit: Auto-phase succeeds, user enters +# 'm' → options menu shown, then exit summary printed. +# 3. test_happy_path_skipped_cli: CLI skipped → auto-phase still runs, +# exit summary printed. +# +# II. CLI Bootstrap Warnings +# 4. test_no_api_key_warning_shown: CLI found but api_key_configured=False +# → yellow warning about limited capability appears in output. +# 5. test_multiple_cli_results: Multiple CLIs, one missing key → warning +# only for the one missing. +# +# III. Auto-Phase Failure / Fallback +# 6. test_auto_phase_failure_triggers_menu: _run_auto_phase returns None +# → "Setup incomplete" message, options menu shown. +# +# IV. Interrupt Handling +# 7. test_keyboard_interrupt_phase1: KeyboardInterrupt during CLI bootstrap +# → "Setup interrupted" message, clean exit. +# 8. test_keyboard_interrupt_phase2: KeyboardInterrupt during auto phase +# → "Setup interrupted" message, clean exit. +# +# V. Key Scanning (via run_setup) +# 9. test_scan_finds_env_keys: Keys in os.environ → found and displayed +# with source "shell environment". +# 10. test_scan_finds_multiple_keys: Multiple keys → all found, count correct. +# 11. test_scan_no_keys_prompts_user: No keys anywhere → interactive +# prompt is invoked; after adding one, flow continues. +# 12. test_scan_multi_var_provider_grouped: Pipe-delimited api_key → +# grouped display shows "N/N vars set". +# 13. test_scan_multi_var_provider_partial: Some vars missing → +# grouped display shows partial count and missing names. +# +# VI. Model Configuration (via run_setup) +# 14. test_models_added_from_reference_csv: Matching API keys → +# new models written to user CSV. +# 15. test_models_deduplicated: Existing models in user CSV → +# not duplicated. +# 16. test_local_models_skipped: ollama/lm_studio/localhost rows excluded. +# 17. test_device_flow_models_included: Empty api_key rows always included. +# +# VII. .pddrc Handling (via run_setup) +# 18. test_pddrc_exists_confirmed: .pddrc already exists → "detected". +# 19. test_pddrc_created_on_confirm: No .pddrc, user types 'y' → created. +# 20. test_pddrc_skipped_on_enter: No .pddrc, user presses Enter → skipped. +# +# VIII. Model Testing (via run_setup) +# 21. test_model_test_success: _run_test succeeds → "responded OK". +# 22. test_model_test_failure: _run_test fails → error shown. +# +# IX. Exit Summary +# 23. test_exit_summary_writes_file: PDD-SETUP-SUMMARY.txt created. +# 24. test_exit_summary_creates_sample_prompt: success_python.prompt created. +# 25. test_exit_summary_quick_start_printed: QUICK START in terminal output. +# +# X. Options Menu +# 26. test_menu_add_provider: User selects "1" → add_provider called. +# 27. test_menu_test_model: User selects "2" → test_model_interactive called. +# 28. test_menu_enter_exits: Enter → menu exits, no actions. +# 29. test_menu_invalid_option: "9" → "Invalid option" shown. + +import csv +import os import pytest -from unittest.mock import MagicMock, patch from pathlib import Path +from unittest.mock import MagicMock, patch from pdd import setup_tool # --------------------------------------------------------------------------- -# Fixtures +# Module-level test data constants # --------------------------------------------------------------------------- -@pytest.fixture -def mock_cli_result(): - result = MagicMock() - result.cli_name = "test_cli" - result.provider = "test_provider" - result.api_key_configured = True - return result - - -@pytest.fixture -def mock_detect_cli(): - with patch("pdd.cli_detector.detect_and_bootstrap_cli") as m: - yield m - - -@pytest.fixture -def mock_auto_phase(): - with patch("pdd.setup_tool._run_auto_phase") as m: - yield m - - -@pytest.fixture -def mock_fallback_menu(): - with patch("pdd.setup_tool._run_fallback_menu") as m: - yield m - - -@pytest.fixture -def mock_input(): - with patch("builtins.input") as m: - yield m - - -@pytest.fixture -def mock_print(): - with patch("builtins.print") as m: - yield m +SIMPLE_REF_CSV = [ + {"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", + "base_url": "", "input": "3", "output": "15", "coding_arena_elo": "1200", + "max_reasoning_tokens": "", "structured_output": "", "reasoning_type": "", "location": ""}, + {"provider": "OpenAI", "model": "gpt-4o", "api_key": "OPENAI_API_KEY", + "base_url": "", "input": "5", "output": "15", "coding_arena_elo": "1100", + "max_reasoning_tokens": "", "structured_output": "", "reasoning_type": "", "location": ""}, +] + +BEDROCK_REF_CSV = [ + {"provider": "AWS Bedrock", "model": "bedrock/anthropic.claude-v1", + "api_key": "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_REGION_NAME", + "base_url": "", "input": "8", "output": "24", "coding_arena_elo": "1150", + "max_reasoning_tokens": "", "structured_output": "", "reasoning_type": "", "location": ""}, +] + +DEVICE_FLOW_CSV = [ + {"provider": "GitHub Copilot", "model": "copilot/gpt-4", "api_key": "", + "base_url": "", "input": "0", "output": "0", "coding_arena_elo": "1050", + "max_reasoning_tokens": "", "structured_output": "", "reasoning_type": "", "location": ""}, +] + +LOCAL_MODELS_CSV = [ + {"provider": "ollama", "model": "ollama/llama3", "api_key": "", + "base_url": "http://localhost:11434", "input": "0", "output": "0", + "coding_arena_elo": "", "max_reasoning_tokens": "", "structured_output": "", + "reasoning_type": "", "location": ""}, + {"provider": "lm_studio", "model": "lm/mistral", "api_key": "", + "base_url": "http://localhost:1234", "input": "0", "output": "0", + "coding_arena_elo": "", "max_reasoning_tokens": "", "structured_output": "", + "reasoning_type": "", "location": ""}, +] + +TEST_SUCCESS_RESULT = { + "success": True, "duration_s": 1.2, "cost": 0.001, + "error": None, "tokens": {"input": 10, "output": 20}, +} + +TEST_FAILURE_RESULT = { + "success": False, "duration_s": 0.5, "cost": 0.0, + "error": "Authentication error", "tokens": None, +} + +# Env vars to clean to prevent leakage from real environment +_ENV_VARS_TO_CLEAN = [ + "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "GEMINI_API_KEY", + "DEEPSEEK_API_KEY", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", + "AWS_REGION_NAME", "GOOGLE_APPLICATION_CREDENTIALS", "VERTEXAI_PROJECT", + "VERTEXAI_LOCATION", "AZURE_API_KEY", "AZURE_API_BASE", + "AZURE_API_VERSION", +] # --------------------------------------------------------------------------- -# Tests for run_setup +# Helpers # --------------------------------------------------------------------------- -def test_run_setup_no_cli_detected(mock_detect_cli, mock_auto_phase, mock_print): - """Test that setup exits early if no CLI is detected.""" - result = MagicMock() - result.cli_name = "" - mock_detect_cli.return_value = result - - setup_tool.run_setup() - - mock_detect_cli.assert_called_once() - mock_auto_phase.assert_not_called() - assert any("Agentic features require at least one CLI tool" in str(c) for c in mock_print.call_args_list) - - -def test_run_setup_success_path(mock_detect_cli, mock_auto_phase, mock_input, mock_print, mock_cli_result): - """Test the happy path where auto phase succeeds.""" - mock_detect_cli.return_value = mock_cli_result - mock_auto_phase.return_value = True - - with patch("pdd.setup_tool._console") as mock_console: - setup_tool.run_setup() - - mock_detect_cli.assert_called_once() - mock_auto_phase.assert_called_once() - assert any("Setup complete" in str(c) for c in mock_console.print.call_args_list) - - -def test_run_setup_fallback_path(mock_detect_cli, mock_auto_phase, mock_fallback_menu, mock_input, mock_cli_result): - """Test that fallback menu is triggered if auto phase fails.""" - mock_detect_cli.return_value = mock_cli_result - mock_auto_phase.return_value = False - - setup_tool.run_setup() - - mock_auto_phase.assert_called_once() - mock_fallback_menu.assert_called_once() - - -def test_run_setup_keyboard_interrupt(mock_detect_cli, mock_print): - """Test handling of KeyboardInterrupt during setup.""" - mock_detect_cli.side_effect = KeyboardInterrupt - - setup_tool.run_setup() - - assert any("Setup interrupted" in str(c) for c in mock_print.call_args_list) - - -def test_run_setup_no_api_key_warning(mock_detect_cli, mock_auto_phase, mock_input, mock_print): - """Test that a warning is printed if API key is not configured, but proceeds.""" +def _make_cli_result(cli_name="claude", provider="anthropic", + api_key_configured=True, skipped=False): + """Create a mock CliBootstrapResult.""" result = MagicMock() - result.cli_name = "test_cli" - result.api_key_configured = False - mock_detect_cli.return_value = result - mock_auto_phase.return_value = True - - setup_tool.run_setup() - - assert any("No API key configured" in str(c) for c in mock_print.call_args_list) - mock_auto_phase.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests for _run_auto_phase -# --------------------------------------------------------------------------- - -@patch("pdd.setup_tool._step4_test_and_summary") -@patch("pdd.setup_tool._step3_local_llms_and_pddrc") -@patch("pdd.setup_tool._step2_configure_models") -@patch("pdd.setup_tool._step1_scan_keys") -@patch("builtins.input") -def test_run_auto_phase_success(mock_input, mock_step1, mock_step2, mock_step3, mock_step4): - """Test that all 4 steps run sequentially on success.""" - mock_step1.return_value = [("ANTHROPIC_API_KEY", "shell environment")] - mock_step2.return_value = {"Anthropic": 3} - mock_step3.return_value = {"Ollama": ["llama3.2:3b"]} - - result = setup_tool._run_auto_phase() - - assert result is True - mock_step1.assert_called_once() - mock_step2.assert_called_once_with([("ANTHROPIC_API_KEY", "shell environment")]) - mock_step3.assert_called_once() - mock_step4.assert_called_once() - # 3 "Press Enter" prompts between steps - assert mock_input.call_count == 3 - - -@patch("pdd.setup_tool._step1_scan_keys") -@patch("builtins.input") -def test_run_auto_phase_exception_returns_false(mock_input, mock_step1): - """Test that exceptions in steps cause fallback.""" - mock_step1.side_effect = RuntimeError("test error") - - result = setup_tool._run_auto_phase() - - assert result is False - - -# --------------------------------------------------------------------------- -# Tests for _step1_scan_keys -# --------------------------------------------------------------------------- - -@patch("pdd.setup_tool._prompt_for_api_key") -@patch("pdd.api_key_scanner._parse_api_env_file") -@patch("pdd.api_key_scanner._detect_shell") -@patch("pdd.litellm_registry.PROVIDER_API_KEY_MAP", {"anthropic": "ANTHROPIC_API_KEY", "openai": "OPENAI_API_KEY"}) -def test_step1_finds_keys_in_env(mock_detect_shell, mock_parse, mock_prompt, tmp_path, monkeypatch): - """Test that step 1 finds keys from os.environ.""" - monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-test") - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - monkeypatch.setattr(Path, "home", lambda: tmp_path) - mock_detect_shell.return_value = "bash" - mock_parse.return_value = {} - - found = setup_tool._step1_scan_keys() - - assert len(found) == 1 - assert found[0] == ("ANTHROPIC_API_KEY", "shell environment") - mock_prompt.assert_not_called() - - -@patch("pdd.setup_tool._prompt_for_api_key") -@patch("pdd.api_key_scanner._parse_api_env_file") -@patch("pdd.api_key_scanner._detect_shell") -@patch("pdd.litellm_registry.PROVIDER_API_KEY_MAP", {"anthropic": "ANTHROPIC_API_KEY"}) -def test_step1_no_keys_triggers_prompt(mock_detect_shell, mock_parse, mock_prompt, tmp_path, monkeypatch): - """Test that step 1 prompts for a key when none found.""" - monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - monkeypatch.setattr(Path, "home", lambda: tmp_path) - mock_detect_shell.return_value = "bash" - mock_parse.return_value = {} - mock_prompt.return_value = [("ANTHROPIC_API_KEY", "~/.pdd/api-env.bash")] - - found = setup_tool._step1_scan_keys() - - assert len(found) == 1 - mock_prompt.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests for _step2_configure_models -# --------------------------------------------------------------------------- - -@patch("pdd.provider_manager._get_user_csv_path") -@patch("pdd.provider_manager._write_csv_atomic") -@patch("pdd.provider_manager._read_csv") -def test_step2_adds_matching_models(mock_read, mock_write, mock_csv_path, tmp_path): - """Test that step 2 filters reference CSV by found keys and writes user CSV.""" - mock_csv_path.return_value = tmp_path / "llm_model.csv" - - # First call: reference CSV, second call: existing user CSV (empty) - mock_read.side_effect = [ - [ - {"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", "base_url": ""}, - {"provider": "OpenAI", "model": "gpt-4", "api_key": "OPENAI_API_KEY", "base_url": ""}, - ], - [], # empty user CSV - ] - - found_keys = [("ANTHROPIC_API_KEY", "shell environment")] - result = setup_tool._step2_configure_models(found_keys) - - assert result == {"Anthropic": 1} - mock_write.assert_called_once() - written_rows = mock_write.call_args[0][1] - assert len(written_rows) == 1 - assert written_rows[0]["model"] == "claude-sonnet" - + result.cli_name = cli_name + result.provider = provider + result.api_key_configured = api_key_configured + result.skipped = skipped + return result -@patch("pdd.provider_manager._get_user_csv_path") -@patch("pdd.provider_manager._write_csv_atomic") -@patch("pdd.provider_manager._read_csv") -def test_step2_deduplicates_existing(mock_read, mock_write, mock_csv_path, tmp_path): - """Test that step 2 skips models already in user CSV.""" - mock_csv_path.return_value = tmp_path / "llm_model.csv" - mock_read.side_effect = [ - [{"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY", "base_url": ""}], - [{"provider": "Anthropic", "model": "claude-sonnet", "api_key": "ANTHROPIC_API_KEY"}], +def _write_csv_file(path, rows): + """Write a list of row dicts as a CSV file.""" + path.parent.mkdir(parents=True, exist_ok=True) + if not rows: + path.write_text("") + return + fieldnames = list(rows[0].keys()) + with open(path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def _run_setup_capture(tmp_path, monkeypatch, ref_csv_rows=None, + user_csv_rows=None, env_keys=None, + input_sequence=None, cli_results=None, + test_result=None, create_pddrc=False): + """Run run_setup() with full environment control, capturing all output. + + Mocks at true boundaries only: CLI detection, user input, model testing, + menu delegates, filesystem paths, and shell detection. Lets all internal + logic (key scanning, model filtering, CSV I/O, .pddrc creation) run + naturally. + + Returns: + (output_str, mocks_dict) — output is all captured print/console text; + mocks contains mock objects for call-count assertions. + """ + if ref_csv_rows is None: + ref_csv_rows = SIMPLE_REF_CSV + if env_keys is None: + env_keys = {"ANTHROPIC_API_KEY": "sk-ant-test123"} + if input_sequence is None: + input_sequence = ["", "", ""] + if cli_results is None: + cli_results = [_make_cli_result()] + if test_result is None: + test_result = TEST_SUCCESS_RESULT + + # --- Filesystem isolation --- + pdd_home = tmp_path / "home" + pdd_dir = pdd_home / ".pdd" + pdd_dir.mkdir(parents=True) + project_dir = tmp_path / "project" + project_dir.mkdir() + + monkeypatch.setattr(Path, "home", lambda: pdd_home) + monkeypatch.chdir(project_dir) + + # Create reference CSV alongside a fake module path + fake_module_dir = tmp_path / "fake_pdd" + fake_module_dir.mkdir() + data_dir = fake_module_dir / "data" + data_dir.mkdir() + _write_csv_file(data_dir / "llm_model.csv", ref_csv_rows) + monkeypatch.setattr(setup_tool, "__file__", + str(fake_module_dir / "setup_tool.py")) + + # Pre-populate user CSV if needed + if user_csv_rows: + _write_csv_file(pdd_dir / "llm_model.csv", user_csv_rows) + + # Create .pddrc if requested + if create_pddrc: + (project_dir / ".pddrc").write_text("version: '1.0'\n") + + # --- Environment isolation --- + for var in _ENV_VARS_TO_CLEAN: + monkeypatch.delenv(var, raising=False) + for key, val in env_keys.items(): + monkeypatch.setenv(key, val) + + # Force shell detection to "bash" for deterministic api-env path + monkeypatch.setenv("SHELL", "/bin/bash") + + # --- Output capture --- + captured_lines = [] + + def capture_print(*args, **kwargs): + captured_lines.append(" ".join(str(a) for a in args)) + + mock_console = MagicMock() + mock_console.print = lambda *a, **kw: captured_lines.append( + " ".join(str(x) for x in a)) + + # --- Input mock --- + input_iter = iter(input_sequence) + + def mock_input(prompt=""): + captured_lines.append(str(prompt)) + try: + return next(input_iter) + except StopIteration: + return "" + + # --- Boundary mocks --- + mock_detect_cli = MagicMock(return_value=cli_results) + mock_run_test = MagicMock(return_value=test_result) + mock_add_provider = MagicMock() + mock_test_interactive = MagicMock() + + # Patch sys.stdout.write/flush used by the threaded test animation + mock_stdout_write = MagicMock( + side_effect=lambda s: captured_lines.append(s)) + + patches = [ + patch("pdd.setup_tool._console", mock_console), + patch("builtins.print", capture_print), + patch("builtins.input", mock_input), + patch("pdd.cli_detector.detect_and_bootstrap_cli", mock_detect_cli), + patch("pdd.model_tester._run_test", mock_run_test), + patch("pdd.provider_manager.add_provider_from_registry", mock_add_provider), + patch("pdd.model_tester.test_model_interactive", mock_test_interactive), + patch("pdd.provider_manager._get_user_csv_path", + lambda: pdd_dir / "llm_model.csv"), + patch("pdd.provider_manager._get_shell_rc_path", lambda: None), + patch("sys.stdout"), ] - found_keys = [("ANTHROPIC_API_KEY", "shell environment")] - result = setup_tool._step2_configure_models(found_keys) - - assert result == {"Anthropic": 1} - mock_write.assert_not_called() - + for p in patches: + p.start() -@patch("pdd.provider_manager._get_user_csv_path") -@patch("pdd.provider_manager._write_csv_atomic") -@patch("pdd.provider_manager._read_csv") -def test_step2_skips_local_models(mock_read, mock_write, mock_csv_path, tmp_path): - """Test that step 2 skips local models (ollama, lm_studio, localhost).""" - mock_csv_path.return_value = tmp_path / "llm_model.csv" + # Re-enable stdout.write and flush for the test animation capture + import sys as _sys + _sys.stdout.write = mock_stdout_write + _sys.stdout.flush = MagicMock() - mock_read.side_effect = [ - [ - {"provider": "Ollama", "model": "ollama/llama", "api_key": "", "base_url": "http://localhost:11434"}, - {"provider": "lm_studio", "model": "lm/model", "api_key": "", "base_url": "http://localhost:1234"}, - {"provider": "OpenAI", "model": "gpt-local", "api_key": "OPENAI_API_KEY", "base_url": "http://localhost:8080"}, - {"provider": "Anthropic", "model": "claude", "api_key": "ANTHROPIC_API_KEY", "base_url": ""}, + try: + setup_tool.run_setup() + except (SystemExit, StopIteration): + pass + finally: + for p in patches: + p.stop() + + output = "\n".join(captured_lines) + mocks = { + "detect_cli": mock_detect_cli, + "run_test": mock_run_test, + "console": mock_console, + "add_provider": mock_add_provider, + "test_interactive": mock_test_interactive, + } + return output, mocks + + +# =========================================================================== +# I. End-to-End Success Path +# =========================================================================== + +def test_happy_path_enter_to_finish(tmp_path, monkeypatch): + """Auto-phase succeeds, user presses Enter → exit summary, no menu.""" + output, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + # Inputs: Enter after step1, Enter after step2, Enter to finish + input_sequence=["", "", ""], + ) + assert "PDD Setup Complete" in output + mocks["detect_cli"].assert_called_once() + mocks["add_provider"].assert_not_called() + + +def test_happy_path_open_menu_then_exit(tmp_path, monkeypatch): + """Auto-phase succeeds, user enters 'm' → menu shown, then exit.""" + output, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + # Inputs: Enter step1, Enter step2, 'm' for menu, Enter to exit menu + input_sequence=["", "", "m", ""], + ) + assert "PDD Setup Complete" in output + assert "Options" in output + + +def test_happy_path_skipped_cli(tmp_path, monkeypatch): + """CLI skipped → auto-phase still runs.""" + output, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + cli_results=[_make_cli_result(skipped=True, cli_name="")], + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "PDD Setup Complete" in output + assert "No API key configured" not in output + + +# =========================================================================== +# II. CLI Bootstrap Warnings +# =========================================================================== + +def test_no_api_key_warning_shown(tmp_path, monkeypatch): + """CLI found but no API key → warning appears.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + cli_results=[_make_cli_result(api_key_configured=False)], + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "No API key configured" in output + + +def test_multiple_cli_results_warning_only_for_missing(tmp_path, monkeypatch): + """Multiple CLIs, warning only for the one missing API key.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + cli_results=[ + _make_cli_result(cli_name="claude", api_key_configured=True), + _make_cli_result(cli_name="codex", api_key_configured=False), ], - [], - ] - - found_keys = [("ANTHROPIC_API_KEY", "env"), ("OPENAI_API_KEY", "env")] - result = setup_tool._step2_configure_models(found_keys) - - assert result == {"Anthropic": 1} - - -# --------------------------------------------------------------------------- -# Tests for local LLM helpers -# --------------------------------------------------------------------------- - -def test_extract_ollama_models(): - """Test Ollama model name extraction from API response.""" - data = {"models": [{"name": "llama3.2:3b"}, {"name": "openhermes:latest"}, {"name": ""}]} - result = setup_tool._extract_ollama_models(data) - assert result == ["llama3.2:3b", "openhermes:latest"] - - -def test_extract_ollama_models_empty(): - """Test Ollama extraction with empty models list.""" - assert setup_tool._extract_ollama_models({"models": []}) == [] - assert setup_tool._extract_ollama_models({}) == [] - - -def test_extract_lm_studio_models(): - """Test LM Studio model name extraction from API response.""" - data = {"data": [{"id": "model-a"}, {"id": "model-b"}, {"id": ""}]} - result = setup_tool._extract_lm_studio_models(data) - assert result == ["model-a", "model-b"] - - -def test_extract_lm_studio_models_empty(): - """Test LM Studio extraction with empty data list.""" - assert setup_tool._extract_lm_studio_models({"data": []}) == [] - assert setup_tool._extract_lm_studio_models({}) == [] - - -# --------------------------------------------------------------------------- -# Tests for _run_fallback_menu -# --------------------------------------------------------------------------- - -@patch("pdd.pddrc_initializer.offer_pddrc_init") -@patch("pdd.model_tester.test_model_interactive") -@patch("pdd.provider_manager.add_provider_from_registry") -@patch("builtins.input") -def test_run_fallback_menu_options(mock_input, mock_add_provider, mock_test_model, mock_init_pddrc): - """Test the fallback menu loop and options.""" - mock_input.side_effect = ["1", "2", "3", "5", "4"] - - setup_tool._run_fallback_menu() - - mock_add_provider.assert_called_once() - mock_test_model.assert_called_once() - mock_init_pddrc.assert_called_once() - assert mock_input.call_count == 5 - - -@patch("builtins.input") -def test_run_fallback_menu_interrupt(mock_input, mock_print): - """Test exiting fallback menu via KeyboardInterrupt.""" - mock_input.side_effect = KeyboardInterrupt - - setup_tool._run_fallback_menu() - - assert any("Setup interrupted" in str(c) for c in mock_print.call_args_list) - - -# --------------------------------------------------------------------------- -# Tests for _prompt_for_api_key -# --------------------------------------------------------------------------- - -@patch("pdd.provider_manager._save_key_to_api_env") -@patch("pdd.setup_tool.getpass") -@patch("builtins.input") -def test_prompt_for_api_key_adds_key(mock_input, mock_getpass, mock_save): - """Test that prompt flow saves a key and returns it.""" - mock_input.side_effect = [ - "1", # Select Anthropic - "n", # Don't add another - ] - mock_getpass.getpass.return_value = "sk-test-key-123" - - result = setup_tool._prompt_for_api_key() - - assert len(result) == 1 - assert result[0][0] == "ANTHROPIC_API_KEY" - mock_save.assert_called_once_with("ANTHROPIC_API_KEY", "sk-test-key-123") - - -@patch("pdd.provider_manager._save_key_to_api_env") -@patch("pdd.setup_tool.getpass") -@patch("builtins.input") -def test_prompt_for_api_key_skip(mock_input, mock_getpass, mock_save): - """Test that skip option returns empty list.""" - skip_idx = len(setup_tool._PROMPT_PROVIDERS) + 2 - mock_input.side_effect = [str(skip_idx)] - - result = setup_tool._prompt_for_api_key() + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "No API key configured" in output + + +# =========================================================================== +# III. Auto-Phase Failure / Fallback +# =========================================================================== + +def test_auto_phase_failure_triggers_menu(tmp_path, monkeypatch): + """Auto-phase fails → 'Setup incomplete' and options menu shown.""" + captured = [] + mock_console = MagicMock() + mock_console.print = lambda *a, **kw: captured.append( + " ".join(str(x) for x in a)) + + with patch("pdd.setup_tool._run_auto_phase", return_value=None), \ + patch("pdd.setup_tool._run_options_menu") as mock_menu, \ + patch("pdd.setup_tool._print_exit_summary"), \ + patch("pdd.setup_tool._print_pdd_logo"), \ + patch("pdd.setup_tool._console", mock_console), \ + patch("pdd.cli_detector.detect_and_bootstrap_cli", + return_value=[_make_cli_result()]): + setup_tool.run_setup() - assert result == [] - mock_save.assert_not_called() + output = "\n".join(captured) + assert "Setup incomplete" in output + mock_menu.assert_called_once() -@patch("pdd.provider_manager._save_key_to_api_env") -@patch("pdd.setup_tool.getpass") -@patch("builtins.input") -def test_prompt_for_api_key_empty_value_skips(mock_input, mock_getpass, mock_save): - """Test that empty key value is rejected gracefully.""" - skip_idx = len(setup_tool._PROMPT_PROVIDERS) + 2 - mock_input.side_effect = [ - "1", # Select Anthropic - str(skip_idx), # Skip after empty key - ] - mock_getpass.getpass.return_value = "" +# =========================================================================== +# IV. Interrupt Handling +# =========================================================================== - result = setup_tool._prompt_for_api_key() +def test_keyboard_interrupt_phase1(): + """KeyboardInterrupt during CLI bootstrap → clean exit.""" + captured = [] + with patch("pdd.cli_detector.detect_and_bootstrap_cli", + side_effect=KeyboardInterrupt), \ + patch("pdd.setup_tool._print_pdd_logo"), \ + patch("builtins.print", lambda *a, **kw: captured.append( + " ".join(str(x) for x in a))): + setup_tool.run_setup() + assert any("Setup interrupted" in line for line in captured) + + +def test_keyboard_interrupt_phase2(): + """KeyboardInterrupt during auto phase → clean exit.""" + captured = [] + with patch("pdd.cli_detector.detect_and_bootstrap_cli", + return_value=[_make_cli_result()]), \ + patch("pdd.setup_tool._run_auto_phase", + side_effect=KeyboardInterrupt), \ + patch("pdd.setup_tool._print_pdd_logo"), \ + patch("pdd.setup_tool._console", MagicMock()), \ + patch("builtins.print", lambda *a, **kw: captured.append( + " ".join(str(x) for x in a))): + setup_tool.run_setup() + assert any("Setup interrupted" in line for line in captured) + + +# =========================================================================== +# V. Key Scanning (via run_setup) +# =========================================================================== + +def test_scan_finds_env_keys(tmp_path, monkeypatch): + """Keys in os.environ → found with 'shell environment' source.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "ANTHROPIC_API_KEY" in output + assert "shell environment" in output + assert "1 API key" in output + + +def test_scan_finds_multiple_keys(tmp_path, monkeypatch): + """Multiple keys in os.environ → all found.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test", "OPENAI_API_KEY": "sk-openai"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "ANTHROPIC_API_KEY" in output + assert "OPENAI_API_KEY" in output + assert "2 API key" in output + + +def test_scan_no_keys_prompts_user(tmp_path, monkeypatch): + """No keys found → interactive key prompt is triggered.""" + # Use only the single-row ref CSV so skip is option "2" + ref_rows = [SIMPLE_REF_CSV[0]] + + captured = [] + mock_console = MagicMock() + mock_console.print = lambda *a, **kw: captured.append( + " ".join(str(x) for x in a)) + + with patch("pdd.setup_tool._run_auto_phase", return_value=None), \ + patch("pdd.setup_tool._print_exit_summary"), \ + patch("pdd.setup_tool._print_pdd_logo"), \ + patch("pdd.setup_tool._run_options_menu"), \ + patch("pdd.setup_tool._console", mock_console), \ + patch("pdd.cli_detector.detect_and_bootstrap_cli", + return_value=[_make_cli_result(skipped=True, cli_name="")]), \ + patch("builtins.input", return_value=""), \ + patch("builtins.print", + lambda *a, **kw: captured.append(" ".join(str(x) for x in a))): + setup_tool.run_setup() - assert result == [] - mock_save.assert_not_called() + output = "\n".join(captured) + # Verify auto-phase failure path was hit (keys couldn't be found) + assert "Setup incomplete" in output + + +def test_scan_multi_var_provider_grouped(tmp_path, monkeypatch): + """Pipe-delimited api_key → grouped display with var counts.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=BEDROCK_REF_CSV, + env_keys={ + "AWS_ACCESS_KEY_ID": "AKIAEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "secret123", + "AWS_REGION_NAME": "us-east-1", + }, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "3/3" in output + assert "AWS Bedrock" in output + + +def test_scan_multi_var_provider_partial(tmp_path, monkeypatch): + """Partial multi-var credentials → missing vars shown.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=BEDROCK_REF_CSV, + env_keys={"AWS_ACCESS_KEY_ID": "AKIAEXAMPLE"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "1/3" in output + assert "missing" in output.lower() + + +# =========================================================================== +# VI. Model Configuration (via run_setup) +# =========================================================================== + +def test_models_added_from_reference_csv(tmp_path, monkeypatch): + """Matching API keys → models written to user CSV.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + # Verify user CSV was created with the matching model + user_csv = tmp_path / "home" / ".pdd" / "llm_model.csv" + assert user_csv.exists() + content = user_csv.read_text() + assert "claude-sonnet" in content + # OpenAI should NOT be present (no key set) + assert "gpt-4o" not in content + + +def test_models_deduplicated(tmp_path, monkeypatch): + """Existing models not duplicated.""" + existing = [SIMPLE_REF_CSV[0].copy()] + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + user_csv_rows=existing, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + # Should mention "already" loaded rather than new additions + assert "already" in output.lower() or "All matching" in output + + +def test_local_models_skipped(tmp_path, monkeypatch): + """ollama/lm_studio/localhost models excluded from user CSV.""" + combined = SIMPLE_REF_CSV + LOCAL_MODELS_CSV + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=combined, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + user_csv = tmp_path / "home" / ".pdd" / "llm_model.csv" + assert user_csv.exists() + content = user_csv.read_text() + assert "ollama" not in content + assert "lm_studio" not in content + + +def test_device_flow_models_included(tmp_path, monkeypatch): + """Empty api_key (device flow) models always included.""" + combined = SIMPLE_REF_CSV + DEVICE_FLOW_CSV + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=combined, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + user_csv = tmp_path / "home" / ".pdd" / "llm_model.csv" + assert user_csv.exists() + content = user_csv.read_text() + assert "copilot" in content.lower() + + +# =========================================================================== +# VII. .pddrc Handling (via run_setup) +# =========================================================================== + +def test_pddrc_exists_confirmed(tmp_path, monkeypatch): + """.pddrc already exists → 'detected' message shown.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "pddrc" in output.lower() + assert "detected" in output.lower() + + +def test_pddrc_created_on_confirm(tmp_path, monkeypatch): + """No .pddrc, user types 'y' → file created.""" + _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=False, + # step1 Enter, pddrc "y", step2 Enter, finish Enter + input_sequence=["", "y", "", ""], + ) + assert (tmp_path / "project" / ".pddrc").exists() + + +def test_pddrc_skipped_on_enter(tmp_path, monkeypatch): + """No .pddrc, user presses Enter → file not created.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=False, + # step1 Enter, pddrc skip Enter, step2 Enter, finish Enter + input_sequence=["", "", "", ""], + ) + assert not (tmp_path / "project" / ".pddrc").exists() + + +# =========================================================================== +# VIII. Model Testing (via run_setup) +# =========================================================================== + +def test_model_test_success(tmp_path, monkeypatch): + """Model test succeeds → 'responded OK' in output.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + test_result=TEST_SUCCESS_RESULT, + input_sequence=["", "", ""], + ) + assert "responded OK" in output or "OK" in output + + +def test_model_test_failure(tmp_path, monkeypatch): + """Model test fails → error message in output.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + test_result=TEST_FAILURE_RESULT, + input_sequence=["", "", ""], + ) + assert "Authentication error" in output or "failed" in output.lower() + + +# =========================================================================== +# IX. Exit Summary +# =========================================================================== + +def test_exit_summary_writes_file(tmp_path, monkeypatch): + """PDD-SETUP-SUMMARY.txt created after setup.""" + _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + summary = tmp_path / "project" / "PDD-SETUP-SUMMARY.txt" + assert summary.exists() + content = summary.read_text() + assert "PDD Setup Complete" in content + assert "QUICK START" in content + + +def test_exit_summary_creates_sample_prompt(tmp_path, monkeypatch): + """success_python.prompt created if not existing.""" + _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert (tmp_path / "project" / "success_python.prompt").exists() + + +def test_exit_summary_quick_start_printed(tmp_path, monkeypatch): + """QUICK START section appears in terminal output.""" + output, _ = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", ""], + ) + assert "QUICK START" in output + assert "pdd generate" in output + + +# =========================================================================== +# X. Options Menu (via run_setup with 'm' input) +# =========================================================================== + +def test_menu_add_provider(tmp_path, monkeypatch): + """User selects '1' in menu → add_provider_from_registry called.""" + _, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", "m", "1", ""], + ) + mocks["add_provider"].assert_called_once() + + +def test_menu_test_model(tmp_path, monkeypatch): + """User selects '2' in menu → test_model_interactive called.""" + _, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", "m", "2", ""], + ) + mocks["test_interactive"].assert_called_once() + + +def test_menu_enter_exits(tmp_path, monkeypatch): + """User presses Enter in menu → exits, no actions.""" + _, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", "m", ""], + ) + mocks["add_provider"].assert_not_called() + mocks["test_interactive"].assert_not_called() + + +def test_menu_invalid_option(tmp_path, monkeypatch): + """User enters invalid option → 'Invalid option' message.""" + output, mocks = _run_setup_capture( + tmp_path, monkeypatch, + ref_csv_rows=SIMPLE_REF_CSV, + env_keys={"ANTHROPIC_API_KEY": "sk-test"}, + create_pddrc=True, + input_sequence=["", "", "m", "9", ""], + ) + assert "Invalid" in output or "invalid" in output.lower() + mocks["add_provider"].assert_not_called() From e5ec299cb87d19100cd5f8bc10e0e35becfaad9f Mon Sep 17 00:00:00 2001 From: Niti Goyal Date: Fri, 20 Feb 2026 09:31:07 -0500 Subject: [PATCH 10/10] Address Copilot comments --- context/provider_manager_example.py | 2 +- pdd/core/utils.py | 2 +- pdd/provider_manager.py | 32 ++++++++++++++++++++++++++--- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/context/provider_manager_example.py b/context/provider_manager_example.py index b3a881997..a09690a36 100644 --- a/context/provider_manager_example.py +++ b/context/provider_manager_example.py @@ -7,7 +7,7 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) -from pdd.setup.provider_manager import ( +from pdd.provider_manager import ( add_provider_from_registry, add_custom_provider, remove_models_by_provider, diff --git a/pdd/core/utils.py b/pdd/core/utils.py index e43bd3249..9f3523e79 100644 --- a/pdd/core/utils.py +++ b/pdd/core/utils.py @@ -85,6 +85,6 @@ def _should_show_onboarding_reminder(ctx: click.Context) -> bool: def _run_setup_utility() -> None: """Execute the interactive setup utility script.""" - result = subprocess.run([sys.executable, "-m", "pdd.setup.setup_tool"]) + result = subprocess.run([sys.executable, "-m", "pdd.setup_tool"]) if result.returncode not in (0, None): raise RuntimeError(f"Setup utility exited with status {result.returncode}") diff --git a/pdd/provider_manager.py b/pdd/provider_manager.py index bcba39588..ffa0f54c0 100644 --- a/pdd/provider_manager.py +++ b/pdd/provider_manager.py @@ -319,12 +319,38 @@ def _write_api_env_atomic(path: Path, lines: List[str]) -> None: raise +def _quote_for_shell(value: str, shell: str) -> str: + """Quote a value for the given shell, handling shell-specific edge cases. + + - POSIX shells (bash/zsh/sh/ksh): shlex.quote() is fully correct. + - fish: single quotes treat \\\\ and \\' as escape sequences (unlike POSIX), + so we must escape backslashes and single quotes within single quotes. + - csh/tcsh: single quotes DO prevent $ expansion, but ! (history expansion) + is never suppressed by any quoting. We backslash-escape ! outside quotes. + """ + if shell == "fish": + # fish single quotes recognise \\' and \\\\ as escapes + escaped = value.replace("\\", "\\\\").replace("'", "\\'") + return f"'{escaped}'" + elif shell in ("csh", "tcsh"): + # csh single quotes are mostly POSIX-like, but ! is never suppressed. + # Strategy: use shlex.quote() for the base quoting, then break out + # any ! characters so they can be backslash-escaped outside quotes. + if "!" not in value: + return shlex.quote(value) + # Split on !, quote each segment, rejoin with escaped ! + parts = value.split("!") + quoted_parts = [shlex.quote(p) for p in parts] + return "\\!".join(quoted_parts) + else: + # bash, zsh, ksh, sh — shlex.quote() is fully correct + return shlex.quote(value) + + def _build_env_export_line(key_name: str, key_value: str) -> str: """Build a shell-appropriate export line for the given key/value.""" shell = _get_shell_name() - # Use shlex.quote() for proper shell escaping of special characters - # This handles $, ", ', `, \, spaces, and other problematic chars - quoted_value = shlex.quote(key_value) + quoted_value = _quote_for_shell(key_value, shell) if shell == "fish": return f"set -gx {key_name} {quoted_value}\n"