diff --git a/.env_example b/.env_example index 2016926fd..656cfeee6 100644 --- a/.env_example +++ b/.env_example @@ -1,4 +1,4 @@ -# This is an example of the .env file. Copy to .env and fill in your secrets. +# This is an example of the .env file. Copy to ~/.pyrit/.env and fill in your endpoint configurations. # Note that if you are using Entra authentication for certain Azure resources (use_entra_auth = True in PyRIT), # keys for those resources are not needed. diff --git a/doc/setup/populating_secrets.md b/doc/setup/populating_secrets.md index 7c2506d23..3861aa9c1 100644 --- a/doc/setup/populating_secrets.md +++ b/doc/setup/populating_secrets.md @@ -21,11 +21,33 @@ With this setup, you can run most PyRIT notebooks and examples! ## Setting Up Environment Variables -PyRIT loads secrets and endpoints from environment variables or a `.env` file in your repo root. The `.env_example` file shows the format and available options. +PyRIT loads secrets and endpoints from environment variables or `.env` files. The `.env_example` file shows the format and available options. + +### Environment Variable Precedence + +When `initialize_pyrit_async` runs, environment variables are loaded in a specific order. **Later sources override earlier ones:** + +```{mermaid} +flowchart LR + A["1. System Environment"] --> B{"env_files provided?"} + B -->|No| C["2. ~/.pyrit/.env"] + C --> D["3. ~/.pyrit/.env.local"] + B -->|Yes| E["2. Your specified files (in order)"] +``` + +**Default behavior** (no `env_files` argument): + +| Priority | Source | Description | +|----------|--------|-------------| +| Lowest | System environment variables | Always loaded as the baseline | +| Medium | `~/.pyrit/.env` | Default config file (loaded if it exists) | +| Highest | `~/.pyrit/.env.local` | Local overrides (loaded if it exists) | + +**Custom behavior** (with `env_files` argument): Only your specified files are loaded, in order. Default paths are completely ignored. ### Creating Your .env File -1. Copy `.env_example` to `.env` in your repository root +1. Copy `.env_example` to `.env` in your home directory in ~/.pyrit/.env 2. Add your API credentials. For example, for Azure OpenAI: ```bash @@ -37,12 +59,33 @@ To find these values in Azure Portal: `Azure Portal > Azure AI Services > Azure ### Using .env.local for Overrides -You can use `.env.local` to override values in `.env` without modifying the base file. This is useful for: +You can use `~/.pyrit/.env.local` to override values in `~/.pyrit/.env` without modifying the base file. This is useful for: - Testing different targets - Using personal credentials instead of shared ones - Switching between configurations quickly -Simply create `.env.local` and add any variables you want to override. PyRIT will prioritize `.env.local` over `.env`. +Simply create `.env.local` in your `~/.pyrit/` directory and add any variables you want to override. + +### Custom Environment Files + +You can also specify exactly which `.env` files to load using the `env_files` parameter: + +```python +from pathlib import Path +from pyrit.setup import initialize_pyrit_async + +await initialize_pyrit_async( + memory_db_type="InMemory", + env_files=[Path("./project-config.env"), Path("./local-overrides.env")] +) +``` + +When `env_files` is provided: +- **Only** the specified files are loaded (default paths are skipped entirely) +- Files are loaded in order—later files override earlier ones +- A `ValueError` is raised if any specified file doesn't exist + +The CLI also supports custom environment files via the `--env-files` flag. ## Authentication Options diff --git a/pyrit/auxiliary_attacks/gcg/experiments/run.py b/pyrit/auxiliary_attacks/gcg/experiments/run.py index 3b21a18f7..bb57d823b 100644 --- a/pyrit/auxiliary_attacks/gcg/experiments/run.py +++ b/pyrit/auxiliary_attacks/gcg/experiments/run.py @@ -39,7 +39,7 @@ def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parame "Model name not supported. Currently supports 'mistral', 'llama_2', 'llama_3', 'vicuna', and 'phi_3_mini'" ) - _load_environment_files() + _load_environment_files(env_files=None) hf_token = os.environ.get("HUGGINGFACE_TOKEN") if not hf_token: raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable") diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 0b67ac042..370c53fdf 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -77,6 +77,7 @@ def __init__( database: str = SQLITE, initialization_scripts: Optional[list[Path]] = None, initializer_names: Optional[list[str]] = None, + env_files: Optional[list[Path]] = None, log_level: str = "WARNING", ): """ @@ -86,6 +87,7 @@ def __init__( database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. + env_files: Optional list of environment file paths to load in order. log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING. Raises: @@ -95,6 +97,7 @@ def __init__( self._database = validate_database(database=database) self._initialization_scripts = initialization_scripts self._initializer_names = initializer_names + self._env_files = env_files self._log_level = validate_log_level(log_level=log_level) # Lazy-loaded registries @@ -119,6 +122,7 @@ async def initialize_async(self) -> None: memory_db_type=self._database, initialization_scripts=None, initializers=None, + env_files=self._env_files, ) # Load registries @@ -259,6 +263,7 @@ async def run_scenario_async( memory_db_type=context._database, initialization_scripts=context._initialization_scripts, initializers=initializer_instances, + env_files=context._env_files, ) # Get scenario class @@ -557,6 +562,46 @@ def wrapper(value): return wrapper +def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: + """ + Resolve initialization script paths. + + Args: + script_paths: List of script path strings. + + Returns: + List of resolved Path objects. + + Raises: + FileNotFoundError: If a script path does not exist. + """ + from pyrit.cli.initializer_registry import InitializerRegistry + + return InitializerRegistry.resolve_script_paths(script_paths=script_paths) + + +def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: + """ + Resolve environment file paths to absolute Path objects. + + Args: + env_file_paths: List of environment file path strings. + + Returns: + List of resolved Path objects. + + Raises: + ValueError: If any path does not exist. + """ + resolved_paths = [] + for path_str in env_file_paths: + path = Path(path_str).resolve() + if not path.exists(): + raise ValueError(f"Environment file not found: {path}") + resolved_paths.append(path) + return resolved_paths + + # Argparse-compatible validators # # These wrappers adapt our core validators (which use keyword-only parameters and raise @@ -573,6 +618,7 @@ def wrapper(value): validate_log_level_argparse = _argparse_validator(validate_log_level) positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) +resolve_env_files_argparse = _argparse_validator(resolve_env_files) def parse_memory_labels(json_string: str) -> dict[str, str]: @@ -604,24 +650,6 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: return labels -def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: - """ - Resolve initialization script paths. - - Args: - script_paths: List of script path strings. - - Returns: - List of resolved Path objects. - - Raises: - FileNotFoundError: If a script path does not exist. - """ - from pyrit.cli.initializer_registry import InitializerRegistry - - return InitializerRegistry.resolve_script_paths(script_paths=script_paths) - - def get_default_initializer_discovery_path() -> Path: """ Get the default path for discovering initializers. @@ -688,6 +716,8 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path ARG_HELP = { "initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)", "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", + "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " + "override earlier ones.", "scenario_strategies": "List of strategy names to run (e.g., base64 rot13)", "max_concurrency": "Maximum number of concurrent attack executions (must be >= 1)", "max_retries": "Maximum number of automatic retries on exception (must be >= 0)", @@ -728,6 +758,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: "scenario_name": parts[0], "initializers": None, "initialization_scripts": None, + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, @@ -752,6 +783,13 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: while i < len(parts) and not parts[i].startswith("--"): result["initialization_scripts"].append(parts[i]) i += 1 + elif parts[i] == "--env-files": + # Collect env file paths until next flag + result["env_files"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["env_files"].append(parts[i]) + i += 1 elif parts[i] in ("--strategies", "-s"): # Collect strategies until next flag result["scenario_strategies"] = [] diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 9cb1affa1..19d9d8535 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -94,6 +94,13 @@ def parse_args(args=None) -> Namespace: help=frontend_core.ARG_HELP["initialization_scripts"], ) + parser.add_argument( + "--env-files", + type=str, + nargs="+", + help=frontend_core.ARG_HELP["env_files"], + ) + parser.add_argument( "--strategies", "-s", @@ -152,9 +159,18 @@ def main(args=None) -> int: print(f"Error: {e}") return 1 + env_files = None + if parsed_args.env_files: + try: + env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + except ValueError as e: + print(f"Error: {e}") + return 1 + context = frontend_core.FrontendCore( database=parsed_args.database, initialization_scripts=initialization_scripts, + env_files=env_files, log_level=parsed_args.log_level, ) @@ -181,11 +197,17 @@ def main(args=None) -> int: script_paths=parsed_args.initialization_scripts ) + # Collect environment files + env_files = None + if parsed_args.env_files: + env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + # Create context with initializers context = frontend_core.FrontendCore( database=parsed_args.database, initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, + env_files=env_files, log_level=parsed_args.log_level, ) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 6ae915781..d149037c7 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -39,10 +39,12 @@ class PyRITShell(cmd.Cmd): Shell Startup Options: --database Database type (InMemory, SQLite, AzureSQL) - default for all runs --log-level Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - default for all runs + --env-files ... Environment files to load in order - default for all runs Run Command Options: --initializers ... Built-in initializers to run before the scenario --initialization-scripts <...> Custom Python scripts to run before the scenario + --env-files ... Environment files to load in order (overrides startup default) --strategies, -s ... Strategy names to use --max-concurrency Maximum concurrent operations --max-retries Maximum retry attempts @@ -97,6 +99,7 @@ def __init__( self.context = context self.default_database = context._database self.default_log_level = context._log_level + self.default_env_files = context._env_files # Track scenario execution history: list of (command_string, ScenarioResult) tuples self._scenario_history: list[tuple[str, ScenarioResult]] = [] @@ -150,6 +153,7 @@ def do_run(self, line): Options: --initializers ... Built-in initializers to run before the scenario --initialization-scripts <...> Custom Python scripts to run before the scenario + --env-files ... Environment files to load in order --strategies, -s ... Strategy names to use --max-concurrency Maximum concurrent operations --max-retries Maximum retry attempts @@ -214,11 +218,24 @@ def do_run(self, line): print(f"Error: {e}") return + # Resolve env files if provided + resolved_env_files = None + if args["env_files"]: + try: + resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"]) + except ValueError as e: + print(f"Error: {e}") + return + else: + # Use default env files from shell startup + resolved_env_files = self.default_env_files + # Create a context for this run with overrides run_context = frontend_core.FrontendCore( database=args["database"] or self.default_database, initialization_scripts=resolved_scripts, initializer_names=args["initializers"], + env_files=resolved_env_files, log_level=args["log_level"] or self.default_log_level, ) # Use the existing registries (don't reinitialize) @@ -455,13 +472,30 @@ def main(): help="Default logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING, can be overridden per-run)", ) + parser.add_argument( + "--env-files", + type=str, + nargs="+", + help="Environment files to load in order (default for all runs, can be overridden per-run)", + ) + args = parser.parse_args() + # Resolve env files if provided + env_files = None + if args.env_files: + try: + env_files = frontend_core.resolve_env_files(env_file_paths=args.env_files) + except ValueError as e: + print(f"Error: {e}") + return 1 + # Create context (initializers are specified per-run, not at startup) context = frontend_core.FrontendCore( database=args.database, initialization_scripts=None, initializer_names=None, + env_files=env_files, log_level=args.log_level, ) diff --git a/pyrit/common/path.py b/pyrit/common/path.py index 40340f28a..749faf6c1 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -31,6 +31,8 @@ def in_git_repo() -> bool: PYRIT_PATH = pathlib.Path(__file__, "..", "..").resolve() +CONFIGURATION_DIRECTORY_PATH = pathlib.Path.home() / ".pyrit" + # Points to the root of the project HOME_PATH = pathlib.Path(PYRIT_PATH, "..").resolve() diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index ba16b2833..b8ebfcf45 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -28,27 +28,72 @@ MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"] -def _load_environment_files() -> None: +def _load_environment_files(env_files: Optional[Sequence[pathlib.Path]], *, silent: bool = False) -> None: """ - Load the base environment file from .env if it exists, - then load a single .env.local file if it exists, overriding previous values. - """ - base_file_path = path.HOME_PATH / ".env" - local_file_path = path.HOME_PATH / ".env.local" + Load environment files in the order they are provided. + Later files override values from earlier files. - # Load the base .env file if it exists - if base_file_path.exists(): - dotenv.load_dotenv(base_file_path, override=True, interpolate=True) - logger.info(f"Loaded {base_file_path}") - else: - dotenv.load_dotenv(verbose=True) + Args: + env_files: Optional sequence of environment file paths. If None, loads default + .env and .env.local from PyRIT home directory (only if they exist). + silent: If True, suppresses print statements about environment file loading. + Defaults to False. - # Load the .env.local file if it exists, to override base .env values - if local_file_path.exists(): - dotenv.load_dotenv(local_file_path, override=True, interpolate=True) - logger.info(f"Loaded {local_file_path}") + Raises: + ValueError: If any provided env_files do not exist. + """ + # Validate env_files exist if they were provided + if env_files is not None: + if not silent: + _print_msg(f"Loading custom environment files: {[str(f) for f in env_files]}", quiet=silent, log=True) + for env_file in env_files: + if not env_file.exists(): + raise ValueError(f"Environment file not found: {env_file}") + + # By default load .env and .env.local from home directory of the package else: - dotenv.load_dotenv(dotenv_path=dotenv.find_dotenv(".env.local"), override=True, verbose=True) + default_files = [] + base_file = path.CONFIGURATION_DIRECTORY_PATH / ".env" + local_file = path.CONFIGURATION_DIRECTORY_PATH / ".env.local" + + if base_file.exists(): + default_files.append(base_file) + if local_file.exists(): + default_files.append(local_file) + + if not silent: + if default_files: + _print_msg( + f"Found default environment files: {[str(f) for f in default_files]}", quiet=silent, log=True + ) + else: + _print_msg( + "No default environment files found. Using system environment variables only.", + quiet=silent, + log=True, + ) + + env_files = default_files + + for env_file in env_files: + dotenv.load_dotenv(env_file, override=True, interpolate=True) + if not silent: + _print_msg(f"Loaded environment file: {env_file}", quiet=silent, log=True) + + +def _print_msg(message: str, quiet: bool, log: bool) -> None: + """ + Print a standard initialization message unless quiet is True. + + Args: + message (str): The message to print and/or log. + quiet (bool): If True, suppresses the initialization message. + log (bool): If True, logs the message using the logger. + """ + if not quiet: + print(message) + if log: + logger.info(message) def _load_initializers_from_scripts( @@ -189,6 +234,8 @@ async def initialize_pyrit_async( *, initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None, initializers: Optional[Sequence["PyRITInitializer"]] = None, + env_files: Optional[Sequence[pathlib.Path]] = None, + silent: bool = False, **memory_instance_kwargs: Any, ) -> None: """ @@ -202,23 +249,17 @@ async def initialize_pyrit_async( or an 'initializers' variable that returns/contains a list of PyRITInitializer instances. initializers (Optional[Sequence[PyRITInitializer]]): Optional sequence of PyRITInitializer instances to execute directly. These provide type-safe, validated configuration with clear documentation. + env_files (Optional[Sequence[pathlib.Path]]): Optional sequence of environment file paths to load + in order. If not provided, will load default .env and .env.local files from PyRIT home if they exist. + All paths must be valid pathlib.Path objects. + silent (bool): If True, suppresses print statements about environment file loading. + Defaults to False. **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance. Raises: - ValueError: If an unsupported memory_db_type is provided. + ValueError: If an unsupported memory_db_type is provided or if env_files contains non-existent files. """ - # Handle DuckDB deprecation before validation - if memory_db_type == "DuckDB": - logger.warning( - "DuckDB is no longer supported and has been replaced by SQLite for better compatibility and performance. " - "Please update your code to use SQLite instead. " - "For migration guidance, see the SQLite Memory documentation at: " - "doc/code/memory/1_sqlite_memory.ipynb. " - "Using in-memory SQLite instead." - ) - memory_db_type = IN_MEMORY - - _load_environment_files() + _load_environment_files(env_files=env_files, silent=silent) # Reset all default values before executing initialization scripts # This ensures a clean state for each initialization diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index f984ad0cd..320d5f5fc 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -123,7 +123,7 @@ def test_do_run_empty_line(self, capsys): def test_do_run_basic_scenario( self, mock_parse_args: MagicMock, - mock_run_scenario: AsyncMock, + _mock_run_scenario: AsyncMock, mock_asyncio_run: MagicMock, ): """Test do_run with basic scenario.""" @@ -138,6 +138,7 @@ def test_do_run_basic_scenario( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, @@ -197,6 +198,7 @@ def test_do_run_with_initialization_scripts( "scenario_name": "test_scenario", "initializers": None, "initialization_scripts": ["script.py"], + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, @@ -231,6 +233,7 @@ def test_do_run_with_missing_script( "scenario_name": "test_scenario", "initializers": None, "initialization_scripts": ["missing.py"], + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, @@ -266,6 +269,7 @@ def test_do_run_with_database_override( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, @@ -306,6 +310,7 @@ def test_do_run_with_exception( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, @@ -683,6 +688,7 @@ def test_run_with_all_parameters( "scenario_name": "test_scenario", "initializers": ["init1"], "initialization_scripts": None, + "env_files": None, "scenario_strategies": ["s1", "s2"], "max_concurrency": 10, "max_retries": 5, @@ -723,6 +729,7 @@ def test_run_stores_result_in_history( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, + "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, diff --git a/tests/unit/setup/test_initialization.py b/tests/unit/setup/test_initialization.py index ce67f1e41..bed550b51 100644 --- a/tests/unit/setup/test_initialization.py +++ b/tests/unit/setup/test_initialization.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import os +import pathlib import tempfile from unittest import mock @@ -9,7 +10,10 @@ from pyrit.common.apply_defaults import reset_default_values from pyrit.setup import IN_MEMORY, initialize_pyrit_async -from pyrit.setup.initialization import _load_initializers_from_scripts +from pyrit.setup.initialization import ( + _load_environment_files, + _load_initializers_from_scripts, +) class TestLoadInitializersFromScripts: @@ -104,3 +108,136 @@ async def test_invalid_memory_type_raises_error(self): """Test that invalid memory type raises ValueError.""" with pytest.raises(ValueError, match="is not a supported type"): await initialize_pyrit_async(memory_db_type="InvalidType") # type: ignore + + +class TestLoadEnvironmentFiles: + """Tests for _load_environment_files function and env_files parameter in initialize_pyrit_async.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization.dotenv.load_dotenv") + @mock.patch("pyrit.setup.initialization.path.CONFIGURATION_DIRECTORY_PATH") + async def test_loads_default_env_files_when_none_provided(self, mock_config_path, mock_load_dotenv): + """Test that default .env and .env.local files are loaded when env_files is None.""" + # Create temporary directory and files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + env_file = temp_path / ".env" + env_local_file = temp_path / ".env.local" + + # Create the files + env_file.write_text("VAR1=value1") + env_local_file.write_text("VAR2=value2") + + # Mock CONFIGURATION_DIRECTORY_PATH to point to our temp directory + mock_config_path.__truediv__ = lambda self, other: temp_path / other + + # Call the function with None (default behavior) + _load_environment_files(env_files=None) + + # Verify both files were loaded + assert mock_load_dotenv.call_count == 2 + calls = [call[0][0] for call in mock_load_dotenv.call_args_list] + assert env_file in calls + assert env_local_file in calls + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization.dotenv.load_dotenv") + @mock.patch("pyrit.setup.initialization.path.CONFIGURATION_DIRECTORY_PATH") + async def test_only_loads_existing_default_files(self, mock_config_path, mock_load_dotenv): + """Test that only existing default files are loaded.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + env_file = temp_path / ".env" + + # Only create .env, not .env.local + env_file.write_text("VAR1=value1") + + mock_config_path.__truediv__ = lambda self, other: temp_path / other + + _load_environment_files(env_files=None) + + # Verify only one file was loaded + assert mock_load_dotenv.call_count == 1 + assert mock_load_dotenv.call_args[0][0] == env_file + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization.dotenv.load_dotenv") + async def test_loads_custom_env_files_in_order(self, mock_load_dotenv): + """Test that custom env_files are loaded in the order provided.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + env1 = temp_path / ".env.test" + env2 = temp_path / ".env.prod" + env3 = temp_path / ".env.local" + + # Create files + env1.write_text("VAR=test") + env2.write_text("VAR=prod") + env3.write_text("VAR=local") + + # Pass custom files + _load_environment_files(env_files=[env1, env2, env3]) + + # Verify all three files were loaded in order + assert mock_load_dotenv.call_count == 3 + call_args = [call[0][0] for call in mock_load_dotenv.call_args_list] + assert call_args == [env1, env2, env3] + + @pytest.mark.asyncio + async def test_raises_error_for_nonexistent_env_file(self): + """Test that ValueError is raised for non-existent env file.""" + nonexistent = pathlib.Path("/nonexistent/path/.env") + + with pytest.raises(ValueError, match="Environment file not found"): + _load_environment_files(env_files=[nonexistent]) + + @pytest.mark.asyncio + @mock.patch("pyrit.memory.central_memory.CentralMemory.set_memory_instance") + async def test_initialize_pyrit_with_custom_env_files(self, mock_set_memory): + """Test initialize_pyrit_async with custom env_files.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + env_file = temp_path / ".env.custom" + env_file.write_text("CUSTOM_VAR=custom_value") + + # Should not raise an error + await initialize_pyrit_async(memory_db_type=IN_MEMORY, env_files=[env_file]) + + mock_set_memory.assert_called_once() + + @pytest.mark.asyncio + @mock.patch("pyrit.memory.central_memory.CentralMemory.set_memory_instance") + async def test_initialize_pyrit_raises_for_nonexistent_env_file(self, mock_set_memory): + """Test that initialize_pyrit_async raises ValueError for non-existent env file.""" + nonexistent = pathlib.Path("/nonexistent/.env") + + with pytest.raises(ValueError, match="Environment file not found"): + await initialize_pyrit_async(memory_db_type=IN_MEMORY, env_files=[nonexistent]) + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization.dotenv.load_dotenv") + @mock.patch("pyrit.setup.initialization.path.HOME_PATH") + @mock.patch("pyrit.memory.central_memory.CentralMemory.set_memory_instance") + async def test_custom_env_files_override_default_behavior(self, mock_set_memory, mock_home_path, mock_load_dotenv): + """Test that passing custom env_files prevents loading default files.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + + # Create default files + default_env = temp_path / ".env" + default_env_local = temp_path / ".env.local" + default_env.write_text("DEFAULT=value") + default_env_local.write_text("DEFAULT_LOCAL=value") + + # Create custom file + custom_env = temp_path / ".env.custom" + custom_env.write_text("CUSTOM=value") + + mock_home_path.__truediv__ = lambda self, other: temp_path / other + + # Pass custom env_files - should NOT load defaults + await initialize_pyrit_async(memory_db_type=IN_MEMORY, env_files=[custom_env]) + + # Verify only custom file was loaded, not the default ones + assert mock_load_dotenv.call_count == 1 + assert mock_load_dotenv.call_args[0][0] == custom_env