Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env_example
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This is an example of the .env file. Copy to .env and fill in your secrets.
# This is an example of the .env file. Copy to ~/.pyrit/.env and fill in your endpoint configurations.
# Note that if you are using Entra authentication for certain Azure resources (use_entra_auth = True in PyRIT),
# keys for those resources are not needed.

Expand Down
51 changes: 47 additions & 4 deletions doc/setup/populating_secrets.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,33 @@ With this setup, you can run most PyRIT notebooks and examples!

## Setting Up Environment Variables

PyRIT loads secrets and endpoints from environment variables or a `.env` file in your repo root. The `.env_example` file shows the format and available options.
PyRIT loads secrets and endpoints from environment variables or `.env` files. The `.env_example` file shows the format and available options.

### Environment Variable Precedence

When `initialize_pyrit_async` runs, environment variables are loaded in a specific order. **Later sources override earlier ones:**

```{mermaid}
flowchart LR
A["1. System Environment"] --> B{"env_files provided?"}
B -->|No| C["2. ~/.pyrit/.env"]
C --> D["3. ~/.pyrit/.env.local"]
B -->|Yes| E["2. Your specified files (in order)"]
```

**Default behavior** (no `env_files` argument):

| Priority | Source | Description |
|----------|--------|-------------|
| Lowest | System environment variables | Always loaded as the baseline |
| Medium | `~/.pyrit/.env` | Default config file (loaded if it exists) |
| Highest | `~/.pyrit/.env.local` | Local overrides (loaded if it exists) |

**Custom behavior** (with `env_files` argument): Only your specified files are loaded, in order. Default paths are completely ignored.

### Creating Your .env File

1. Copy `.env_example` to `.env` in your repository root
1. Copy `.env_example` to `.env` in your home directory in ~/.pyrit/.env
2. Add your API credentials. For example, for Azure OpenAI:

```bash
Expand All @@ -37,12 +59,33 @@ To find these values in Azure Portal: `Azure Portal > Azure AI Services > Azure

### Using .env.local for Overrides

You can use `.env.local` to override values in `.env` without modifying the base file. This is useful for:
You can use `~/.pyrit/.env.local` to override values in `~/.pyrit/.env` without modifying the base file. This is useful for:
- Testing different targets
- Using personal credentials instead of shared ones
- Switching between configurations quickly

Simply create `.env.local` and add any variables you want to override. PyRIT will prioritize `.env.local` over `.env`.
Simply create `.env.local` in your `~/.pyrit/` directory and add any variables you want to override.

### Custom Environment Files

You can also specify exactly which `.env` files to load using the `env_files` parameter:

```python
from pathlib import Path
from pyrit.setup import initialize_pyrit_async

await initialize_pyrit_async(
memory_db_type="InMemory",
env_files=[Path("./project-config.env"), Path("./local-overrides.env")]
)
```

When `env_files` is provided:
- **Only** the specified files are loaded (default paths are skipped entirely)
- Files are loaded in order—later files override earlier ones
- A `ValueError` is raised if any specified file doesn't exist

The CLI also supports custom environment files via the `--env-files` flag.

## Authentication Options

Expand Down
2 changes: 1 addition & 1 deletion pyrit/auxiliary_attacks/gcg/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parame
"Model name not supported. Currently supports 'mistral', 'llama_2', 'llama_3', 'vicuna', and 'phi_3_mini'"
)

_load_environment_files()
_load_environment_files(env_files=None)
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if not hf_token:
raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
Expand Down
74 changes: 56 additions & 18 deletions pyrit/cli/frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
database: str = SQLITE,
initialization_scripts: Optional[list[Path]] = None,
initializer_names: Optional[list[str]] = None,
env_files: Optional[list[Path]] = None,
log_level: str = "WARNING",
):
"""
Expand All @@ -86,6 +87,7 @@ def __init__(
database: Database type (InMemory, SQLite, or AzureSQL).
initialization_scripts: Optional list of initialization script paths.
initializer_names: Optional list of built-in initializer names to run.
env_files: Optional list of environment file paths to load in order.
log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING.

Raises:
Expand All @@ -95,6 +97,7 @@ def __init__(
self._database = validate_database(database=database)
self._initialization_scripts = initialization_scripts
self._initializer_names = initializer_names
self._env_files = env_files
self._log_level = validate_log_level(log_level=log_level)

# Lazy-loaded registries
Expand All @@ -119,6 +122,7 @@ async def initialize_async(self) -> None:
memory_db_type=self._database,
initialization_scripts=None,
initializers=None,
env_files=self._env_files,
)

# Load registries
Expand Down Expand Up @@ -259,6 +263,7 @@ async def run_scenario_async(
memory_db_type=context._database,
initialization_scripts=context._initialization_scripts,
initializers=initializer_instances,
env_files=context._env_files,
)

# Get scenario class
Expand Down Expand Up @@ -557,6 +562,46 @@ def wrapper(value):
return wrapper


def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]:
"""
Resolve initialization script paths.

Args:
script_paths: List of script path strings.

Returns:
List of resolved Path objects.

Raises:
FileNotFoundError: If a script path does not exist.
"""
from pyrit.cli.initializer_registry import InitializerRegistry

return InitializerRegistry.resolve_script_paths(script_paths=script_paths)


def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]:
"""
Resolve environment file paths to absolute Path objects.

Args:
env_file_paths: List of environment file path strings.

Returns:
List of resolved Path objects.

Raises:
ValueError: If any path does not exist.
"""
resolved_paths = []
for path_str in env_file_paths:
path = Path(path_str).resolve()
if not path.exists():
raise ValueError(f"Environment file not found: {path}")
resolved_paths.append(path)
return resolved_paths


# Argparse-compatible validators
#
# These wrappers adapt our core validators (which use keyword-only parameters and raise
Expand All @@ -573,6 +618,7 @@ def wrapper(value):
validate_log_level_argparse = _argparse_validator(validate_log_level)
positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1))
non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0))
resolve_env_files_argparse = _argparse_validator(resolve_env_files)


def parse_memory_labels(json_string: str) -> dict[str, str]:
Expand Down Expand Up @@ -604,24 +650,6 @@ def parse_memory_labels(json_string: str) -> dict[str, str]:
return labels


def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]:
"""
Resolve initialization script paths.

Args:
script_paths: List of script path strings.

Returns:
List of resolved Path objects.

Raises:
FileNotFoundError: If a script path does not exist.
"""
from pyrit.cli.initializer_registry import InitializerRegistry

return InitializerRegistry.resolve_script_paths(script_paths=script_paths)


def get_default_initializer_discovery_path() -> Path:
"""
Get the default path for discovering initializers.
Expand Down Expand Up @@ -688,6 +716,8 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path
ARG_HELP = {
"initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)",
"initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario",
"env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files "
"override earlier ones.",
"scenario_strategies": "List of strategy names to run (e.g., base64 rot13)",
"max_concurrency": "Maximum number of concurrent attack executions (must be >= 1)",
"max_retries": "Maximum number of automatic retries on exception (must be >= 0)",
Expand Down Expand Up @@ -728,6 +758,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
"scenario_name": parts[0],
"initializers": None,
"initialization_scripts": None,
"env_files": None,
"scenario_strategies": None,
"max_concurrency": None,
"max_retries": None,
Expand All @@ -752,6 +783,13 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
while i < len(parts) and not parts[i].startswith("--"):
result["initialization_scripts"].append(parts[i])
i += 1
elif parts[i] == "--env-files":
# Collect env file paths until next flag
result["env_files"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["env_files"].append(parts[i])
i += 1
elif parts[i] in ("--strategies", "-s"):
# Collect strategies until next flag
result["scenario_strategies"] = []
Expand Down
22 changes: 22 additions & 0 deletions pyrit/cli/pyrit_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ def parse_args(args=None) -> Namespace:
help=frontend_core.ARG_HELP["initialization_scripts"],
)

parser.add_argument(
"--env-files",
type=str,
nargs="+",
help=frontend_core.ARG_HELP["env_files"],
)

parser.add_argument(
"--strategies",
"-s",
Expand Down Expand Up @@ -152,9 +159,18 @@ def main(args=None) -> int:
print(f"Error: {e}")
return 1

env_files = None
if parsed_args.env_files:
try:
env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files)
except ValueError as e:
print(f"Error: {e}")
return 1

context = frontend_core.FrontendCore(
database=parsed_args.database,
initialization_scripts=initialization_scripts,
env_files=env_files,
log_level=parsed_args.log_level,
)

Expand All @@ -181,11 +197,17 @@ def main(args=None) -> int:
script_paths=parsed_args.initialization_scripts
)

# Collect environment files
env_files = None
if parsed_args.env_files:
env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files)

# Create context with initializers
context = frontend_core.FrontendCore(
database=parsed_args.database,
initialization_scripts=initialization_scripts,
initializer_names=parsed_args.initializers,
env_files=env_files,
log_level=parsed_args.log_level,
)

Expand Down
34 changes: 34 additions & 0 deletions pyrit/cli/pyrit_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ class PyRITShell(cmd.Cmd):
Shell Startup Options:
--database <type> Database type (InMemory, SQLite, AzureSQL) - default for all runs
--log-level <level> Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - default for all runs
--env-files <path> ... Environment files to load in order - default for all runs

Run Command Options:
--initializers <name> ... Built-in initializers to run before the scenario
--initialization-scripts <...> Custom Python scripts to run before the scenario
--env-files <path> ... Environment files to load in order (overrides startup default)
--strategies, -s <s1> ... Strategy names to use
--max-concurrency <N> Maximum concurrent operations
--max-retries <N> Maximum retry attempts
Expand Down Expand Up @@ -97,6 +99,7 @@ def __init__(
self.context = context
self.default_database = context._database
self.default_log_level = context._log_level
self.default_env_files = context._env_files

# Track scenario execution history: list of (command_string, ScenarioResult) tuples
self._scenario_history: list[tuple[str, ScenarioResult]] = []
Expand Down Expand Up @@ -150,6 +153,7 @@ def do_run(self, line):
Options:
--initializers <name> ... Built-in initializers to run before the scenario
--initialization-scripts <...> Custom Python scripts to run before the scenario
--env-files <path> ... Environment files to load in order
--strategies, -s <s1> <s2> ... Strategy names to use
--max-concurrency <N> Maximum concurrent operations
--max-retries <N> Maximum retry attempts
Expand Down Expand Up @@ -214,11 +218,24 @@ def do_run(self, line):
print(f"Error: {e}")
return

# Resolve env files if provided
resolved_env_files = None
if args["env_files"]:
try:
resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"])
except ValueError as e:
print(f"Error: {e}")
return
else:
# Use default env files from shell startup
resolved_env_files = self.default_env_files

# Create a context for this run with overrides
run_context = frontend_core.FrontendCore(
database=args["database"] or self.default_database,
initialization_scripts=resolved_scripts,
initializer_names=args["initializers"],
env_files=resolved_env_files,
log_level=args["log_level"] or self.default_log_level,
)
# Use the existing registries (don't reinitialize)
Expand Down Expand Up @@ -455,13 +472,30 @@ def main():
help="Default logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING, can be overridden per-run)",
)

parser.add_argument(
"--env-files",
type=str,
nargs="+",
help="Environment files to load in order (default for all runs, can be overridden per-run)",
)

args = parser.parse_args()

# Resolve env files if provided
env_files = None
if args.env_files:
try:
env_files = frontend_core.resolve_env_files(env_file_paths=args.env_files)
except ValueError as e:
print(f"Error: {e}")
return 1

# Create context (initializers are specified per-run, not at startup)
context = frontend_core.FrontendCore(
database=args.database,
initialization_scripts=None,
initializer_names=None,
env_files=env_files,
log_level=args.log_level,
)

Expand Down
2 changes: 2 additions & 0 deletions pyrit/common/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def in_git_repo() -> bool:

PYRIT_PATH = pathlib.Path(__file__, "..", "..").resolve()

CONFIGURATION_DIRECTORY_PATH = pathlib.Path.home() / ".pyrit"

# Points to the root of the project
HOME_PATH = pathlib.Path(PYRIT_PATH, "..").resolve()

Expand Down
Loading
Loading