diff --git a/doc/code/scenarios/0_scenarios.ipynb b/doc/code/scenarios/0_scenarios.ipynb index b342540983..623fa51a46 100644 --- a/doc/code/scenarios/0_scenarios.ipynb +++ b/doc/code/scenarios/0_scenarios.ipynb @@ -390,9 +390,11 @@ } ], "source": [ - "from pyrit.cli.frontend_core import FrontendCore, print_scenarios_list_async\n", + "from pyrit.backend.services.scenario_service import get_scenario_service\n", + "from pyrit.cli._output import print_scenario_list\n", "\n", - "await print_scenarios_list_async(context=FrontendCore()) # type: ignore" + "response = await get_scenario_service().list_scenarios_async(limit=200) # type: ignore\n", + "print_scenario_list(items=[s.model_dump() for s in response.items])" ] }, { diff --git a/doc/code/scenarios/0_scenarios.py b/doc/code/scenarios/0_scenarios.py index bf04a72f62..1f973ac14c 100644 --- a/doc/code/scenarios/0_scenarios.py +++ b/doc/code/scenarios/0_scenarios.py @@ -165,9 +165,11 @@ def _build_display_group(self, *, technique_name: str, seed_group_name: str) -> # ## Existing Scenarios # %% -from pyrit.cli.frontend_core import FrontendCore, print_scenarios_list_async +from pyrit.backend.services.scenario_service import get_scenario_service +from pyrit.cli._output import print_scenario_list -await print_scenarios_list_async(context=FrontendCore()) # type: ignore +response = await get_scenario_service().list_scenarios_async(limit=200) # type: ignore +print_scenario_list(items=[s.model_dump() for s in response.items]) # %% [markdown] # diff --git a/doc/code/scenarios/2_custom_scenario_parameters.ipynb b/doc/code/scenarios/2_custom_scenario_parameters.ipynb index 1fbe651935..541be5dbde 100644 --- a/doc/code/scenarios/2_custom_scenario_parameters.ipynb +++ b/doc/code/scenarios/2_custom_scenario_parameters.ipynb @@ -290,15 +290,14 @@ } ], "source": [ - "from pyrit.cli.frontend_core import format_scenario_metadata\n", - "from pyrit.registry import ScenarioRegistry\n", + "from pyrit.backend.services.scenario_service import get_scenario_service\n", + "from pyrit.cli._output import print_scenario_list\n", "\n", "# Show scam (declares a parameter) and red_team_agent (none), so the\n", "# Supported Parameters section is visible in one and absent in the other.\n", "demo_names = {\"airt.scam\", \"foundry.red_team_agent\"}\n", - "for metadata in ScenarioRegistry.get_registry_singleton().list_metadata():\n", - " if metadata.registry_name in demo_names:\n", - " format_scenario_metadata(scenario_metadata=metadata)" + "response = await get_scenario_service().list_scenarios_async(limit=200) # type: ignore\n", + "print_scenario_list(items=[s.model_dump() for s in response.items if s.scenario_name in demo_names])" ] }, { diff --git a/doc/code/scenarios/2_custom_scenario_parameters.py b/doc/code/scenarios/2_custom_scenario_parameters.py index 2069b43eb0..bcebdb1f9e 100644 --- a/doc/code/scenarios/2_custom_scenario_parameters.py +++ b/doc/code/scenarios/2_custom_scenario_parameters.py @@ -182,15 +182,14 @@ # CLI uses is callable programmatically: # %% -from pyrit.cli.frontend_core import format_scenario_metadata -from pyrit.registry import ScenarioRegistry +from pyrit.backend.services.scenario_service import get_scenario_service +from pyrit.cli._output import print_scenario_list # Show scam (declares a parameter) and red_team_agent (none), so the # Supported Parameters section is visible in one and absent in the other. demo_names = {"airt.scam", "foundry.red_team_agent"} -for metadata in ScenarioRegistry.get_registry_singleton().list_metadata(): - if metadata.registry_name in demo_names: - format_scenario_metadata(scenario_metadata=metadata) +response = await get_scenario_service().list_scenarios_async(limit=200) # type: ignore +print_scenario_list(items=[s.model_dump() for s in response.items if s.scenario_name in demo_names]) # %% [markdown] # Notice the `Supported Parameters:` section under `airt.scam`. It's absent diff --git a/doc/myst.yml b/doc/myst.yml index df83376e47..ddcccd91c0 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -178,11 +178,11 @@ project: children: - file: api/pyrit_analytics.md - file: api/pyrit_auth.md - - file: api/pyrit_cli_frontend_core.md - - file: api/pyrit_cli_pyrit_backend.md + - file: api/pyrit_cli_api_client.md - file: api/pyrit_cli_pyrit_scan.md - file: api/pyrit_cli_pyrit_shell.md - file: api/pyrit_common.md + - file: api/pyrit_common_cli_helpers.md - file: api/pyrit_datasets.md - file: api/pyrit_embedding.md - file: api/pyrit_exceptions.md diff --git a/docker/start.sh b/docker/start.sh index 8376791733..81ae582d66 100644 --- a/docker/start.sh +++ b/docker/start.sh @@ -57,25 +57,33 @@ if [ "$PYRIT_MODE" = "jupyter" ]; then exec jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root --notebook-dir=/app/notebooks elif [ "$PYRIT_MODE" = "gui" ]; then echo "Starting PyRIT GUI on port 8000..." - # Use Azure SQL if AZURE_SQL_SERVER is set (injected by Bicep), otherwise default to SQLite. - # Note: AZURE_SQL_DB_CONNECTION_STRING is in the .env file (loaded by Python dotenv), - # but we use AZURE_SQL_SERVER here because it's a direct env var from the Bicep template. - # Build CLI arguments - BACKEND_ARGS="--host 0.0.0.0 --port 8000" + # The thin backend only takes --host/--port/--config-file/--log-level. + # Translate AZURE_SQL_SERVER and PYRIT_INITIALIZER into a runtime config file + # so the FastAPI lifespan (ConfigurationLoader) picks them up on startup. + RUNTIME_CONFIG=/tmp/pyrit_runtime.yaml + { + if [ -n "$AZURE_SQL_SERVER" ]; then + echo "Using Azure SQL database (server: $AZURE_SQL_SERVER)" >&2 + echo "memory_db_type: AzureSQL" + else + echo "Using SQLite database (AZURE_SQL_SERVER not set)" >&2 + echo "memory_db_type: SQLite" + fi + if [ -n "$PYRIT_INITIALIZER" ]; then + echo "Using initializer: $PYRIT_INITIALIZER" >&2 + echo "initializers:" + # Split comma-separated initializer names into a YAML list. + IFS=',' read -ra INIT_NAMES <<<"$PYRIT_INITIALIZER" + for name in "${INIT_NAMES[@]}"; do + echo " - $(echo "$name" | xargs)" + done + fi + } >"$RUNTIME_CONFIG" - if [ -n "$AZURE_SQL_SERVER" ]; then - echo "Using Azure SQL database (server: $AZURE_SQL_SERVER)" - BACKEND_ARGS="$BACKEND_ARGS --database AzureSQL" - else - echo "Using SQLite database (AZURE_SQL_SERVER not set)" - fi - - if [ -n "$PYRIT_INITIALIZER" ]; then - echo "Using initializer: $PYRIT_INITIALIZER" - BACKEND_ARGS="$BACKEND_ARGS --initializers $PYRIT_INITIALIZER" - fi - - exec python -m pyrit.cli.pyrit_backend $BACKEND_ARGS + exec python -m pyrit.backend.pyrit_backend \ + --host 0.0.0.0 \ + --port 8000 \ + --config-file "$RUNTIME_CONFIG" else echo "ERROR: Invalid PYRIT_MODE '$PYRIT_MODE'. Must be 'jupyter' or 'gui'" exit 1 diff --git a/frontend/dev.py b/frontend/dev.py index f76dc2b9c3..41e82f903f 100644 --- a/frontend/dev.py +++ b/frontend/dev.py @@ -141,7 +141,7 @@ def _find_pids_on_port(port): def stop_servers(): """Stop all running servers""" print("šŸ›‘ Stopping servers...") - backend_pids = find_pids_by_pattern("pyrit.cli.pyrit_backend") + backend_pids = find_pids_by_pattern("pyrit.backend.pyrit_backend") frontend_pids = find_pids_by_pattern("node.*vite") # Also find any parent dev.py processes (detached wrappers) wrapper_pids = find_pids_by_pattern("frontend/dev.py") @@ -163,6 +163,10 @@ def start_backend(*, config_file: str | None = None, initializers: list[str] | N Configuration (initializers, database, env files) is read automatically from ~/.pyrit/.pyrit_conf by the pyrit_backend CLI via ConfigurationLoader, unless overridden with *config_file*. + + When *initializers* is supplied without a *config_file*, a tiny temporary + runtime config is written to forward those names — ``pyrit_backend`` only + accepts ``--config-file`` now (no ``--initializers`` flag). """ print("šŸš€ Starting backend on port 8000...") @@ -178,7 +182,7 @@ def start_backend(*, config_file: str | None = None, initializers: list[str] | N cmd = [ sys.executable, "-m", - "pyrit.cli.pyrit_backend", + "pyrit.backend.pyrit_backend", "--host", "localhost", "--port", @@ -186,12 +190,22 @@ def start_backend(*, config_file: str | None = None, initializers: list[str] | N "--log-level", "info", ] - if config_file: - cmd.extend(["--config-file", config_file]) - # Add initializers if specified - if initializers: - cmd.extend(["--initializers"] + initializers) + # Resolve config-file: explicit wins; otherwise synthesize one from initializers. + effective_config_file = config_file + if effective_config_file is None and initializers: + import tempfile + + fd, synthesized_path = tempfile.mkstemp(suffix=".yaml", prefix="pyrit_dev_", text=True) + with os.fdopen(fd, "w") as synthesized: + synthesized.write("initializers:\n") + for name in initializers: + synthesized.write(f" - {name}\n") + effective_config_file = synthesized_path + print(f" Wrote initializer overrides to {effective_config_file}") + + if effective_config_file: + cmd.extend(["--config-file", effective_config_file]) # Pipe stdout/stderr so dev.py controls output ordering return subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) @@ -456,7 +470,7 @@ def main(): elif command == "backend": print("šŸš€ Starting backend only...") # Kill stale backend processes - stale = find_pids_by_pattern("pyrit.cli.pyrit_backend") + stale = find_pids_by_pattern("pyrit.backend.pyrit_backend") if stale: print(f" Killing stale backend PIDs: {stale}") kill_pids(stale) diff --git a/pyproject.toml b/pyproject.toml index bdc563d000..1fca17cd00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,7 @@ all = [ ] [project.scripts] -pyrit_backend = "pyrit.cli.pyrit_backend:main" +pyrit_backend = "pyrit.backend.pyrit_backend:main" pyrit_scan = "pyrit.cli.pyrit_scan:main" pyrit_shell = "pyrit.cli.pyrit_shell:main" diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index 365d2b5656..c2c2f477cf 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -30,7 +30,7 @@ targets, version, ) -from pyrit.memory import CentralMemory +from pyrit.setup.configuration_loader import ConfigurationLoader # Check for development mode from environment variable DEV_MODE = os.getenv("PYRIT_DEV_MODE", "false").lower() == "true" @@ -40,17 +40,38 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - """Manage application startup and shutdown lifecycle.""" - # Initialization is handled by the pyrit_backend CLI before uvicorn starts. - # Running 'uvicorn pyrit.backend.main:app' directly is not supported; - # use 'pyrit_backend' instead. - try: - CentralMemory.get_memory_instance() - except ValueError: - logger.warning( - "CentralMemory is not initialized. " - "Start the server via 'pyrit_backend' CLI instead of running uvicorn directly." - ) + """ + Initialize PyRIT on startup using the config file, then yield. + + Config resolution order: + 1. ``PYRIT_CONFIG_FILE`` env var (if set) + 2. ``~/.pyrit/.pyrit_conf`` (if it exists) + 3. Built-in defaults (SQLite, no initializers) + """ + config_file_env = os.getenv("PYRIT_CONFIG_FILE") + config_file = Path(config_file_env) if config_file_env else None + + config = ConfigurationLoader.load_with_overrides(config_file=config_file) + await config.initialize_pyrit_async() + + # Expose config values to route handlers via app.state + default_labels: dict[str, str] = {} + if config.operator: + default_labels["operator"] = config.operator + if config.operation: + default_labels["operation"] = config.operation + app.state.default_labels = default_labels + app.state.max_concurrent_scenario_runs = config.max_concurrent_scenario_runs + app.state.allow_custom_initializers = config.allow_custom_initializers + + if config.allow_custom_initializers: + logger.warning("Custom initializer registration is ENABLED (allow_custom_initializers: true).") + + # Mount the bundled frontend (or print a dev/missing-frontend notice). + # Done here rather than at module load so test imports of `pyrit.backend.main` + # don't emit noise and don't perform filesystem side effects. + setup_frontend() + yield @@ -124,7 +145,3 @@ def setup_frontend() -> None: print(" The frontend must be built and included in the package.") print(" Run: python build_scripts/prepare_package.py") print(" API endpoints will still work but the UI won't be available.") - - -# Set up frontend at module load time (needed when running via uvicorn) -setup_frontend() diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 1236817e19..54480b76c7 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -14,7 +14,6 @@ from pydantic import BaseModel, Field -from pyrit.backend.models.attacks import AttackSummary from pyrit.backend.models.common import PaginationInfo @@ -25,7 +24,8 @@ class ScenarioParameterSummary(BaseModel): description: str = Field(..., description="Human-readable description of the parameter") default: str | None = Field(None, description="Default value as a display string, or None if required") param_type: str = Field(..., description="Type of the parameter as a display string (e.g., 'int', 'str')") - choices: str | None = Field(None, description="Allowed values as a display string, or None if unconstrained") + choices: list[str] | None = Field(None, description="Allowed values as strings, or None if unconstrained") + is_list: bool = Field(False, description="True when the parameter accepts a list of values (e.g., list[str])") class RegisteredScenario(BaseModel): @@ -124,28 +124,3 @@ class ScenarioRunListResponse(BaseModel): """Response for listing scenario runs.""" items: list[ScenarioRunSummary] = Field(..., description="List of scenario runs") - - -# ============================================================================ -# Scenario Results Detail Models -# ============================================================================ - - -class AtomicAttackResults(BaseModel): - """Results grouped by atomic attack name.""" - - atomic_attack_name: str = Field(..., description="Name of the atomic attack (strategy)") - display_group: str | None = Field(None, description="Display group label for UI grouping") - results: list[AttackSummary] = Field(..., description="Individual attack results") - success_count: int = Field(0, ge=0, description="Number of successful attacks") - failure_count: int = Field(0, ge=0, description="Number of failed attacks") - total_count: int = Field(0, ge=0, description="Total number of attack results") - total_retries: int = Field(0, ge=0, description="Sum of retries across all attacks in this group") - error_count: int = Field(0, ge=0, description="Number of attacks with errors") - - -class ScenarioRunDetail(BaseModel): - """Full detailed results of a scenario run.""" - - run: ScenarioRunSummary = Field(..., description="The scenario run summary") - attacks: list[AtomicAttackResults] = Field(..., description="Results grouped by atomic attack") diff --git a/pyrit/backend/pyrit_backend.py b/pyrit/backend/pyrit_backend.py new file mode 100644 index 0000000000..489270540e --- /dev/null +++ b/pyrit/backend/pyrit_backend.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +PyRIT Backend CLI - Thin wrapper around uvicorn for the PyRIT backend server. + +All initialization (config loading, memory setup, initializer execution) is +handled by the FastAPI lifespan in ``pyrit.backend.main``. This CLI simply +parses host/port/config-file/log-level/reload and starts uvicorn. + +The config file path is forwarded to the app via the ``PYRIT_CONFIG_FILE`` +environment variable. +""" + +import logging +import os +import sys +from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from pathlib import Path +from typing import Optional + +from pyrit.common.cli_helpers import CONFIG_FILE_HELP, validate_log_level_argparse + + +def parse_args(*, args: Optional[list[str]] = None) -> Namespace: + """ + Parse command-line arguments for the PyRIT backend server. + + Returns: + Namespace: Parsed command-line arguments. + """ + parser = ArgumentParser( + prog="pyrit_backend", + description="""PyRIT Backend - Run the PyRIT backend API server + +All configuration (database, initializers, env-files, etc.) is read from +the config file (~/.pyrit/.pyrit_conf by default, or --config-file). + +Examples: + # Start backend with default settings + pyrit_backend + + # Start with a custom config file + pyrit_backend --config-file ./my_config.yaml + + # Start with custom port and host + pyrit_backend --host 0.0.0.0 --port 8080 + + # Start with auto-reload for development + pyrit_backend --reload +""", + formatter_class=RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Host to bind the server to (default: localhost)", + ) + + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind the server to (default: 8000)", + ) + + parser.add_argument( + "--config-file", + type=Path, + help=CONFIG_FILE_HELP, + ) + + parser.add_argument( + "--log-level", + type=validate_log_level_argparse, + default="WARNING", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", + ) + + parser.add_argument( + "--reload", + action="store_true", + help="Enable auto-reload for development (watches for file changes)", + ) + + return parser.parse_args(args) + + +def main(*, args: Optional[list[str]] = None) -> int: + """ + Start the PyRIT backend server. + + Returns: + int: Exit code (0 for success, 1 for error). + """ + sys.stdout.reconfigure(errors="replace") # type: ignore[ty:unresolved-attribute] + sys.stderr.reconfigure(errors="replace") # type: ignore[ty:unresolved-attribute] + + try: + parsed_args = parse_args(args=args) + except SystemExit as e: + return e.code if isinstance(e.code, int) else 1 + + # Forward config file to the FastAPI lifespan via env var + if parsed_args.config_file is not None: + os.environ["PYRIT_CONFIG_FILE"] = str(parsed_args.config_file) + + try: + import uvicorn + + uvicorn.run( + "pyrit.backend.main:app", + host=parsed_args.host, + port=parsed_args.port, + log_level=logging.getLevelName(parsed_args.log_level).lower() + if isinstance(parsed_args.log_level, int) + else parsed_args.log_level.lower(), + reload=parsed_args.reload, + ) + return 0 + except KeyboardInterrupt: + print("\nšŸ›‘ Backend stopped") + return 0 + except Exception as e: + print(f"\nError: {e}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 756ce76755..4052a45075 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -21,7 +21,6 @@ ListRegisteredScenariosResponse, RegisteredScenario, RunScenarioRequest, - ScenarioRunDetail, ScenarioRunListResponse, ScenarioRunSummary, ) @@ -197,24 +196,22 @@ async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: @router.get( "/runs/{scenario_result_id}/results", - response_model=ScenarioRunDetail, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, 409: {"model": ProblemDetail, "description": "Run not yet completed"}, }, ) -async def get_scenario_run_results(scenario_result_id: str) -> ScenarioRunDetail: +async def get_scenario_run_results(scenario_result_id: str) -> dict: """ Get detailed results for a completed scenario run. - Returns per-attack outcomes including objectives, responses, scores, - and success/failure counts. + Returns the full ScenarioResult serialization. Args: scenario_result_id: The scenario_result_id. Returns: - ScenarioRunDetail: Full attack-level results. + dict: ScenarioResult.to_dict() payload. """ service = get_scenario_run_service() try: @@ -227,4 +224,4 @@ async def get_scenario_run_results(scenario_result_id: str) -> ScenarioRunDetail status_code=status.HTTP_404_NOT_FOUND, detail=f"Scenario run '{scenario_result_id}' not found", ) - return result + return result.to_dict() diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 6b04f2a71c..1bf61f1b37 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -12,15 +12,10 @@ import contextlib import logging from dataclasses import dataclass -from datetime import datetime, timezone from typing import TYPE_CHECKING, Any -from pyrit.backend.mappers.attack_mappers import retry_events_to_response from pyrit.backend.models.scenarios import ( - AtomicAttackResults, - AttackSummary, RunScenarioRequest, - ScenarioRunDetail, ScenarioRunListResponse, ScenarioRunStatus, ScenarioRunSummary, @@ -448,18 +443,15 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari completed_at=scenario_result.completion_time, ) - def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | None: + def get_run_results(self, *, scenario_result_id: str) -> ScenarioResult | None: """ - Get detailed results for a completed scenario run. - - Retrieves the full ScenarioResult from CentralMemory and maps it - to a detailed response model with per-attack outcomes. + Get the ScenarioResult for a completed scenario run. Args: scenario_result_id: The scenario result ID. Returns: - ScenarioRunDetail if the run is completed and results exist, None if not found. + ScenarioResult if the run is completed and results exist, None if not found. Raises: ValueError: If the run is not in a completed state. @@ -474,83 +466,7 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non if run_response.status != ScenarioRunStatus.COMPLETED: raise ValueError(f"Results are only available for completed runs. Current status: '{run_response.status}'.") - # Build per-attack detail - attacks: list[AtomicAttackResults] = [] - display_group_map = scenario_result.display_group_map - for attack_name, attack_results in scenario_result.attack_results.items(): - details: list[AttackSummary] = [] - success_count = 0 - failure_count = 0 - group_total_retries = 0 - group_error_count = 0 - - for ar in attack_results: - score_value = None - if ar.last_score is not None: - score_value = str(ar.last_score.get_value()) - - last_response_text = None - if ar.last_response is not None: - last_response_text = str(ar.last_response) - - timestamp = ar.timestamp or datetime.now(timezone.utc) - - # Build retry event responses using the shared mapper - retry_event_responses = retry_events_to_response(ar.retry_events) - - # Extract error/retry fields - ar_error_message = ar.error_message - ar_error_type = ar.error_type - ar_error_traceback = ar.error_traceback - ar_total_retries = ar.total_retries - - details.append( - AttackSummary( - attack_result_id=ar.attack_result_id, - conversation_id=ar.conversation_id, - objective=ar.objective, - outcome=ar.outcome.value, - outcome_reason=ar.outcome_reason, - last_response=last_response_text, - score_value=score_value, - executed_turns=ar.executed_turns, - execution_time_ms=ar.execution_time_ms, - created_at=timestamp, - updated_at=timestamp, - error_message=ar_error_message, - error_type=ar_error_type, - error_traceback=ar_error_traceback, - total_retries=ar_total_retries, - retry_events=retry_event_responses, - ) - ) - - if ar.outcome == AttackOutcome.SUCCESS: - success_count += 1 - elif ar.outcome == AttackOutcome.FAILURE: - failure_count += 1 - - group_total_retries += ar_total_retries - if ar_error_message: - group_error_count += 1 - - attacks.append( - AtomicAttackResults( - atomic_attack_name=attack_name, - display_group=display_group_map.get(attack_name), - results=details, - success_count=success_count, - failure_count=failure_count, - total_count=len(details), - total_retries=group_total_retries, - error_count=group_error_count, - ) - ) - - return ScenarioRunDetail( - run=run_response, - attacks=attacks, - ) + return scenario_result _service_instance: ScenarioRunService | None = None diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index 1f8d4dee61..939b863306 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -45,6 +45,7 @@ def _metadata_to_registered_scenario(metadata: ScenarioMetadata) -> RegisteredSc default=repr(p.default) if p.default is not None else None, param_type=p.param_type, choices=p.choices, + is_list=p.is_list, ) for p in metadata.supported_parameters ], diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index f88bb0f990..e982b6ae06 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -23,6 +23,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, get_origin +from pyrit.common.cli_helpers import ( + CONFIG_FILE_HELP, + validate_log_level, + validate_log_level_argparse, +) from pyrit.common.parameter import Parameter, coerce_value if TYPE_CHECKING: @@ -62,27 +67,6 @@ def validate_database(*, database: str) -> str: return database -def validate_log_level(*, log_level: str) -> int: - """ - Validate log level and convert to logging constant. - - Args: - log_level: Log level string (case-insensitive). - - Returns: - Validated log level as logging constant (e.g., logging.WARNING). - - Raises: - ValueError: If log level is invalid. - """ - valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - level_upper = log_level.upper() - if level_upper not in valid_levels: - raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") - level_value: int = getattr(logging, level_upper) - return level_value - - def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: """ Validate and parse an integer value. @@ -226,7 +210,6 @@ def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: # apply the min_value parameter while still allowing the decorator to work correctly. # --------------------------------------------------------------------------- validate_database_argparse = _argparse_validator(validate_database) -validate_log_level_argparse = _argparse_validator(validate_log_level) positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) resolve_env_files_argparse = _argparse_validator(resolve_env_files) @@ -270,11 +253,7 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: # Shared argument help text # --------------------------------------------------------------------------- ARG_HELP = { - "config_file": ( - "Path to a YAML configuration file. Allows specifying database, initializers (with args), " - "initialization scripts, and env files. CLI arguments override config file values. " - "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." - ), + "config_file": CONFIG_FILE_HELP, "initializers": ( "Built-in initializer names to run before the scenario. " "Supports optional params with name:key=val syntax " @@ -434,12 +413,10 @@ class _ArgSpec: _RUN_ARG_SPECS: list[_ArgSpec] = [ _INITIALIZERS_ARG, - _INIT_SCRIPTS_ARG, _STRATEGIES_ARG, _MAX_CONCURRENCY_ARG, _MAX_RETRIES_ARG, _MEMORY_LABELS_ARG, - _LOG_LEVEL_ARG, _DATASET_NAMES_ARG, _MAX_DATASET_SIZE_ARG, _TARGET_ARG, @@ -596,13 +573,14 @@ def _arg_spec_from_parameter(*, param: Parameter) -> _ArgSpec: """ multi = get_origin(param.param_type) is list parser: Callable[[str], Any] | None - if param.param_type is None or param.param_type is str: - parser = None - elif multi: + if multi: # Per-element coercion; v1 only ships list[str]. parser = str + elif param.param_type is None or (param.param_type is str and param.choices is None): + # No coercion needed and no choices to enforce. + parser = None else: - + # Coerce + validate (handles ints/floats/bools AND str-with-choices). def parser(raw: str) -> Any: return coerce_value(param=param, raw_value=raw) @@ -665,6 +643,45 @@ def extract_scenario_args(*, parsed: dict[str, Any]) -> dict[str, Any]: # --------------------------------------------------------------------------- +def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> Optional[list[Parameter]]: + """ + Build ``Parameter`` objects from a scenario catalog's ``supported_parameters``. + + Maps the display ``param_type`` string ("int", "float", "bool", "str", + "list[...]", "any") back to a concrete ``param_type`` so the shell parser + can apply per-element coercion and treat list params as ``multi_value``. + + Args: + api_params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. + + Returns: + Optional[list[Parameter]]: Parameter list when ``api_params`` is non-empty, else ``None``. + """ + if not api_params: + return None + type_map: dict[str, Any] = {"int": int, "float": float, "bool": bool, "str": str} + parameters: list[Parameter] = [] + for p in api_params: + type_display = p.get("param_type", "") + if p.get("is_list"): + element_type = type_map.get(type_display.removeprefix("list[").rstrip("]"), str) + resolved_type: Any = list[element_type] # type: ignore[valid-type] + else: + resolved_type = type_map.get(type_display) + raw_choices = p.get("choices") + choices: Optional[tuple[Any, ...]] = tuple(raw_choices) if raw_choices else None + parameters.append( + Parameter( + name=p["name"], + description=p.get("description", ""), + param_type=resolved_type, + default=p.get("default"), + choices=choices, + ) + ) + return parameters + + def add_common_arguments(parser: argparse.ArgumentParser) -> None: """Add arguments shared between pyrit_shell and pyrit_scan.""" parser.add_argument("--config-file", type=Path, help=ARG_HELP["config_file"]) diff --git a/pyrit/cli/_config_reader.py b/pyrit/cli/_config_reader.py new file mode 100644 index 0000000000..8be554fe1a --- /dev/null +++ b/pyrit/cli/_config_reader.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Lightweight config reader for the PyRIT CLI thin client. + +Reads only the ``server.url`` field from ``~/.pyrit/.pyrit_conf`` (and an +optional overlay file) using ``yaml.safe_load``. No heavy pyrit imports. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +_logger = logging.getLogger(__name__) + +# Mirror the default path from pyrit.common.path without importing it. +_DEFAULT_CONFIG_DIR = Path.home() / ".pyrit" +_DEFAULT_CONFIG_FILE = _DEFAULT_CONFIG_DIR / ".pyrit_conf" + +DEFAULT_SERVER_URL = "http://localhost:8000" + +# Top-level config blocks the thin CLI does not read (server picks them up). +# Surfacing them here lets us warn users whose configs still drive scenario +# selection or scenario args from disk. +_CLIENT_IGNORED_BLOCKS = ("scenario",) + + +def read_server_url(*, config_file: Path | None = None) -> str | None: + """ + Read ``server.url`` from the default config and an optional overlay. + + Layers (later wins): + 1. ``~/.pyrit/.pyrit_conf`` (if it exists) + 2. *config_file* (if provided and exists) + + Args: + config_file (Path | None): Optional explicit config path. + + Returns: + str | None: The server URL, or ``None`` if not configured. + """ + import yaml + + paths: list[Path] = [] + if _DEFAULT_CONFIG_FILE.exists(): + paths.append(_DEFAULT_CONFIG_FILE) + if config_file is not None and config_file.exists(): + paths.append(config_file) + + url: str | None = None + for p in paths: + url = _extract_server_url(path=p, yaml_module=yaml) or url + return url + + +def warn_on_client_ignored_blocks(*, config_file: Path | None = None) -> None: + """ + Emit a one-line deprecation notice if the layered config contains blocks + the thin CLI ignores (e.g. ``scenario:``). The server still honors these. + + Args: + config_file: Optional overlay path; the default ``~/.pyrit/.pyrit_conf`` + is always checked when present. + """ + import yaml + + paths: list[Path] = [] + if _DEFAULT_CONFIG_FILE.exists(): + paths.append(_DEFAULT_CONFIG_FILE) + if config_file is not None and config_file.exists(): + paths.append(config_file) + + for p in paths: + try: + with open(p) as fh: + data = yaml.safe_load(fh) + except Exception: + continue + if not isinstance(data, dict): + continue + for block in _CLIENT_IGNORED_BLOCKS: + if block in data: + print( + f"Deprecation: '{block}:' block in {p} is ignored by the CLI " + f"(pass the scenario name positionally instead). " + f"The backend server still reads this block." + ) + + +def _extract_server_url(*, path: Path, yaml_module: Any) -> str | None: + """ + Extract ``server.url`` from a single YAML file. + + Args: + path (Path): YAML config file path. + yaml_module (Any): The imported ``yaml`` module (passed to avoid + top-level import). + + Returns: + str | None: The URL string, or ``None`` if absent/malformed. + """ + try: + with open(path) as fh: + data = yaml_module.safe_load(fh) + if isinstance(data, dict): + server_block = data.get("server") + if isinstance(server_block, dict): + raw_url = server_block.get("url") + if isinstance(raw_url, str) and raw_url.strip(): + return raw_url.strip() + except Exception: + _logger.debug("Failed to read server URL from %s", path, exc_info=True) + return None diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py new file mode 100644 index 0000000000..3204c0ad9b --- /dev/null +++ b/pyrit/cli/_output.py @@ -0,0 +1,340 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Console output formatting for the PyRIT CLI thin client. + +All functions accept plain ``dict`` payloads (deserialized JSON from the REST +API) and print human-readable output to stdout. No heavy pyrit imports. +""" + +from __future__ import annotations + +import sys +from typing import Any + +try: + import termcolor + + _HAS_COLOR = True +except ImportError: + _HAS_COLOR = False + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _cprint(text: str, *, color: str | None = None, bold: bool = False) -> None: + """Print *text*, optionally coloured if ``termcolor`` is available.""" + if _HAS_COLOR and color: + attrs = ["bold"] if bold else None + termcolor.cprint(text, color, attrs=attrs) + else: + print(text) + + +def _header(text: str) -> None: + _cprint(f"\n {text}", color="cyan", bold=True) + + +def _wrap(*, text: str, indent: str, width: int = 78) -> str: + """ + Word-wrap *text* with the given *indent*. + + Returns: + str: The wrapped text with newline separators. + """ + words = text.split() + lines: list[str] = [] + current = "" + for word in words: + if not current: + current = word + elif len(indent) + len(current) + 1 + len(word) <= width: + current += " " + word + else: + lines.append(indent + current) + current = word + if current: + lines.append(indent + current) + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Scenario listing +# --------------------------------------------------------------------------- + + +def print_scenario_list(*, items: list[dict[str, Any]]) -> None: + """ + Print a formatted list of scenarios. + + Args: + items: List of scenario dicts from ``GET /api/scenarios/catalog``. + """ + if not items: + print("No scenarios found.") + return + + print("\nAvailable Scenarios:") + print("=" * 80) + for sc in items: + _header(sc.get("scenario_name", "unknown")) + print(f" Class: {sc.get('scenario_type', '')}") + desc = sc.get("description", "") + if desc: + print(" Description:") + print(_wrap(text=desc, indent=" ")) + agg = sc.get("aggregate_strategies") or [] + if agg: + print(" Aggregate Strategies:") + print(_wrap(text=", ".join(agg), indent=" - ")) + strategies = sc.get("all_strategies") or [] + if strategies: + print(f" Available Strategies ({len(strategies)}):") + print(_wrap(text=", ".join(strategies), indent=" ")) + default_strat = sc.get("default_strategy") + if default_strat: + print(f" Default Strategy: {default_strat}") + datasets = sc.get("default_datasets") or [] + max_ds = sc.get("max_dataset_size") + if datasets: + suffix = f", max {max_ds} per dataset" if max_ds else "" + print(f" Default Datasets ({len(datasets)}{suffix}):") + print(_wrap(text=", ".join(datasets), indent=" ")) + params = sc.get("supported_parameters") or [] + if params: + print(" Supported Parameters:") + for p in params: + default_str = f" [default: {p.get('default')!r}]" if p.get("default") is not None else "" + type_str = f" ({p.get('param_type', '')})" if p.get("param_type") else "" + choices = p.get("choices") + choices_display = ", ".join(choices) if isinstance(choices, list) else choices + choices_str = f" [choices: {choices_display}]" if choices_display else "" + print(f" - {p.get('name', '?')}{type_str}{default_str}{choices_str}: {p.get('description', '')}") + print("\n" + "=" * 80) + print(f"\nTotal scenarios: {len(items)}") + + +# --------------------------------------------------------------------------- +# Initializer listing +# --------------------------------------------------------------------------- + + +def print_initializer_list(*, items: list[dict[str, Any]]) -> None: + """ + Print a formatted list of initializers. + + Args: + items: List of initializer dicts from ``GET /api/initializers``. + """ + if not items: + print("No initializers found.") + return + + print("\nAvailable Initializers:") + print("=" * 80) + for init in items: + _header(init.get("initializer_name", "unknown")) + print(f" Class: {init.get('initializer_type', '')}") + env_vars = init.get("required_env_vars") or [] + if env_vars: + print(" Required Environment Variables:") + for var in env_vars: + print(f" - {var}") + else: + print(" Required Environment Variables: None") + params = init.get("supported_parameters") or [] + if params: + print(" Supported Parameters:") + for p in params: + default_str = f" [default: {p.get('default')}]" if p.get("default") else "" + print(f" - {p.get('name', '?')}{default_str}: {p.get('description', '')}") + desc = init.get("description", "") + if desc: + print(" Description:") + print(_wrap(text=desc, indent=" ")) + print("\n" + "=" * 80) + print(f"\nTotal initializers: {len(items)}") + + +# --------------------------------------------------------------------------- +# Target listing +# --------------------------------------------------------------------------- + + +def print_target_list(*, items: list[dict[str, Any]]) -> None: + """ + Print a formatted list of targets. + + Args: + items: List of target dicts from ``GET /api/targets``. + """ + if not items: + print("\nNo targets found in registry.") + print( + "\nTargets are registered by initializers. Include an initializer that " + "registers targets, for example:\n --initializers target\n" + ) + return + + print("\nRegistered Targets:") + print("=" * 80) + for tgt in items: + _header(tgt.get("target_registry_name", "unknown")) + print(f" Class: {tgt.get('target_type', '')}") + model = tgt.get("underlying_model_name") or tgt.get("model_name") or "" + if model: + print(f" Model: {model}") + endpoint = tgt.get("endpoint") or "" + if endpoint: + print(f" Endpoint: {endpoint}") + print("\n" + "=" * 80) + print(f"\nTotal targets: {len(items)}") + + +# --------------------------------------------------------------------------- +# Scenario run progress & summary +# --------------------------------------------------------------------------- + + +def print_scenario_run_progress(*, run: dict[str, Any], total_strategies: int = 0) -> None: + """ + Print a single-line progress update (overwrites the current line). + + Args: + run: ScenarioRunSummary dict from ``GET /api/scenarios/runs/{id}``. + total_strategies: Total number of strategies expected (0 if unknown). + """ + run_status = run.get("status", "UNKNOWN") + total = run.get("total_attacks", 0) + completed = run.get("completed_attacks", 0) + rate = run.get("objective_achieved_rate", 0) + strategies_done = len(run.get("strategies_used") or []) + # Strategies the user passed may be aggregates that expand on the server + # (e.g. `single_turn` -> N concrete strategies). Trust whichever count is larger. + effective_total = max(total_strategies, strategies_done) + + parts: list[str] = [] + + if effective_total > 0: + parts.append(f"strategies: {strategies_done}/{effective_total}") + elif strategies_done > 0: + parts.append(f"strategies: {strategies_done}") + + if total > 0: + pct = int((completed / total) * 100) + bar_width = 30 + filled = int(bar_width * completed / total) + bar = "ā–ˆ" * filled + "ā–‘" * (bar_width - filled) + parts.append(f"[{bar}] {completed}/{total} attacks ({pct}%)") + else: + parts.append(f"attacks: {completed}") + + parts.append(f"success rate: {rate}%") + parts.append(run_status) + + line = "\r " + " | ".join(parts) + sys.stdout.write(line) + sys.stdout.flush() + + +def print_scenario_run_summary(*, run: dict[str, Any]) -> None: + """ + Print a brief summary of a completed scenario run. + + Args: + run: ScenarioRunSummary dict. + """ + print() # newline after progress bar + status = run.get("status", "UNKNOWN") + name = run.get("scenario_name", "unknown") + rid = run.get("scenario_result_id", "?") + total = run.get("total_attacks", 0) + completed = run.get("completed_attacks", 0) + rate = run.get("objective_achieved_rate", 0) + + print(f"\nScenario: {name}") + print(f" Result ID: {rid}") + print(f" Status: {status}") + print(f" Total Attacks: {total}") + print(f" Completed: {completed}") + print(f" Success Rate: {rate}%") + + error = run.get("error") + if error: + print(f" Error: {error}") + + strategies = run.get("strategies_used") or [] + if strategies: + print(f" Strategies: {', '.join(strategies)}") + + +# --------------------------------------------------------------------------- +# Scenario run detail (full results via output module) +# --------------------------------------------------------------------------- + + +async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None: + """ + Print detailed scenario results using the output module. + + Args: + result_dict: ``ScenarioResult.to_dict()`` payload from the REST API. + """ + from pyrit.models.scenario_result import ScenarioResult + from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter + + scenario_result = ScenarioResult.from_dict(result_dict) + printer = PrettyScenarioResultMemoryPrinter() + await printer.write_async(scenario_result) + + +# --------------------------------------------------------------------------- +# Scenario run history +# --------------------------------------------------------------------------- + + +def print_scenario_runs_list(*, runs: list[dict[str, Any]]) -> None: + """ + Print a list of scenario run summaries. + + Args: + runs: List of ScenarioRunSummary dicts from ``GET /api/scenarios/runs``. + """ + if not runs: + print("No scenario runs found.") + return + + print("\nScenario Run History:") + print("=" * 80) + for idx, run in enumerate(runs, start=1): + status = run.get("status", "?") + name = run.get("scenario_name", "unknown") + rid = run.get("scenario_result_id", "?")[:8] + total = run.get("total_attacks", 0) + rate = run.get("objective_achieved_rate", 0) + created = run.get("created_at", "?") + print(f" {idx}) [{status}] {name} (id: {rid}…) — {total} attacks, {rate}% success — {created}") + print("=" * 80) + print(f"\nTotal runs: {len(runs)}") + + +# --------------------------------------------------------------------------- +# Error display +# --------------------------------------------------------------------------- + + +def print_error_with_hint(*, message: str, hint: str | None = None) -> None: + """ + Print an error message with an optional actionable hint. + + Args: + message: The error text. + hint: Optional follow-up guidance. + """ + print(f"\nError: {message}") + if hint: + print(f"Hint: {hint}") diff --git a/pyrit/cli/_server_launcher.py b/pyrit/cli/_server_launcher.py new file mode 100644 index 0000000000..56cc61b29f --- /dev/null +++ b/pyrit/cli/_server_launcher.py @@ -0,0 +1,260 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Manage a local ``pyrit_backend`` subprocess. + +Provides helpers to probe whether a server is already running, start a +detached backend process, and (optionally) stop it. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import signal +import subprocess +import sys +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +_logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Port-based process termination +# --------------------------------------------------------------------------- + + +def _find_pid_on_port_windows(*, port: int) -> int | None: + """ + Find the PID listening on *port* on Windows via ``netstat``. + + Args: + port: TCP port to look up. + + Returns: + int | None: The PID, or ``None`` if no listener was found. + """ + try: + result = subprocess.run( + ["netstat", "-ano", "-p", "TCP"], + capture_output=True, + text=True, + timeout=5, + ) + except (OSError, subprocess.SubprocessError): + return None + for line in result.stdout.splitlines(): + if "LISTENING" not in line: + continue + tokens = line.split() + if len(tokens) < 5: + continue + # Match the local-address column exactly so port 80 doesn't match + # :8000 / :8080 / etc. Handles both IPv4 ("0.0.0.0:8000") and + # IPv6 ("[::]:8000") local-address formats. + local_addr = tokens[1] + if local_addr.rsplit(":", 1)[-1] != str(port): + continue + try: + return int(tokens[-1]) + except (ValueError, IndexError): + continue + return None + + +def _find_pid_on_port_unix(*, port: int) -> int | None: + """ + Find the first PID listening on *port* on Unix via ``lsof``. + + Args: + port: TCP port to look up. + + Returns: + int | None: The PID, or ``None`` if no listener was found. + """ + try: + result = subprocess.run( + ["lsof", "-ti", f":{port}"], + capture_output=True, + text=True, + timeout=5, + ) + except (OSError, subprocess.SubprocessError): + return None + for pid_str in result.stdout.strip().splitlines(): + try: + return int(pid_str) + except ValueError: + continue + return None + + +def stop_server_on_port(*, port: int) -> bool: + """ + Find and terminate the process listening on *port*. + + Args: + port: TCP port to look up. + + Returns: + bool: ``True`` if a process was found and signalled, ``False`` otherwise. + """ + if sys.platform == "win32": + pid = _find_pid_on_port_windows(port=port) + else: + pid = _find_pid_on_port_unix(port=port) + if pid is None: + return False + try: + os.kill(pid, signal.SIGTERM) + return True + except OSError: + return False + + +class ServerLauncher: + """ + Launch and manage a local ``pyrit_backend`` server. + + The subprocess is **detached** — it survives after the parent CLI exits. + This is intentional: a running server on ``localhost:8000`` is reusable + across multiple ``pyrit_scan`` / ``pyrit_shell`` sessions. + """ + + def __init__(self) -> None: + self._process: subprocess.Popen | None = None # type: ignore[type-arg] + self._pid: int | None = None + + # ------------------------------------------------------------------ + # Health probe + # ------------------------------------------------------------------ + + @staticmethod + async def probe_health_async(*, base_url: str) -> bool: + """ + Check whether a server at *base_url* is healthy. + + Args: + base_url: Server root URL (e.g. ``http://localhost:8000``). + + Returns: + bool: ``True`` if ``GET /api/health`` returned 200. + """ + from pyrit.cli.api_client import PyRITApiClient + + async with PyRITApiClient(base_url=base_url) as client: + return await client.health_check_async() + + # ------------------------------------------------------------------ + # Start + # ------------------------------------------------------------------ + + async def start_async( + self, + *, + host: str = "localhost", + port: int = 8000, + config_file: Path | None = None, + log_level: str | None = None, + startup_timeout: int = 30, + ) -> str: + """ + Start ``pyrit_backend`` as a detached subprocess and wait until healthy. + + Args: + host: Bind address forwarded to ``pyrit_backend --host``. + port: Bind port forwarded to ``pyrit_backend --port``. + config_file: Optional config forwarded via ``--config-file``. + log_level: Optional log level forwarded via ``--log-level``. + startup_timeout: Seconds to wait for the server to become healthy. + + Returns: + str: The ``base_url`` of the running server. + + Raises: + RuntimeError: If the server did not become healthy within the timeout. + """ + base_url = f"http://{host}:{port}" + + # Already running? + if await self.probe_health_async(base_url=base_url): + _logger.info("Server already running at %s", base_url) + return base_url + + cmd: list[str] = [ + sys.executable, + "-m", + "pyrit.backend.pyrit_backend", + "--host", + host, + "--port", + str(port), + ] + if config_file is not None: + cmd.extend(["--config-file", str(config_file)]) + if log_level is not None: + cmd.extend(["--log-level", log_level]) + + _logger.info("Launching pyrit_backend: %s", " ".join(cmd)) + + creation_flags = 0 + start_new_session = False + if os.name == "nt": + creation_flags = subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore[attr-defined] + else: + start_new_session = True + + print(f"Starting server at {base_url}...") + sys.stdout.flush() + + self._process = subprocess.Popen( + cmd, + creationflags=creation_flags, + start_new_session=start_new_session, + ) + self._pid = self._process.pid + _logger.info("Backend PID: %d", self._pid) + + # Wait for health, checking if the process crashed + for _elapsed in range(startup_timeout): + await asyncio.sleep(1) + + exit_code = self._process.poll() + if exit_code is not None: + raise RuntimeError(f"Server process exited with code {exit_code} during startup.") + + if await self.probe_health_async(base_url=base_url): + print(f"Server ready (PID {self._pid})") + return base_url + + raise RuntimeError( + f"pyrit_backend did not become healthy within {startup_timeout}s. " + f"Check the server logs or start it manually with: pyrit_backend" + ) + + # ------------------------------------------------------------------ + # Stop + # ------------------------------------------------------------------ + + def stop(self) -> None: + """Terminate the owned subprocess (if any).""" + if self._process is not None: + try: + self._process.terminate() + self._process.wait(timeout=5) + _logger.info("Stopped server (PID %d)", self._pid) + except Exception: + _logger.warning("Failed to stop server (PID %s)", self._pid, exc_info=True) + finally: + self._process = None + self._pid = None + + @property + def pid(self) -> int | None: + """PID of the owned backend process, or ``None``.""" + return self._pid diff --git a/pyrit/cli/api_client.py b/pyrit/cli/api_client.py new file mode 100644 index 0000000000..11674a2298 --- /dev/null +++ b/pyrit/cli/api_client.py @@ -0,0 +1,279 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Async REST client for the PyRIT backend API. + +Uses ``httpx`` internally but defers the import to method calls so that +importing this module does not trigger the import-guard ban on ``httpx`` +at CLI parse time. +""" + +from __future__ import annotations + +import logging +from typing import Any + +_logger = logging.getLogger(__name__) + + +class ServerNotAvailableError(Exception): + """Raised when the CLI cannot reach the PyRIT backend server.""" + + +class PyRITApiClient: + """ + Lightweight async REST client for the PyRIT backend. + + No heavy pyrit imports. + + Use as an async context manager:: + + async with PyRITApiClient(base_url="http://localhost:8000") as client: + scenarios = await client.list_scenarios_async() + """ + + def __init__(self, *, base_url: str) -> None: + """ + Initialize the API client. + + Args: + base_url (str): Base URL of the PyRIT backend (e.g., ``"http://localhost:8000"``). + """ + self._base_url = base_url.rstrip("/") + self._client: Any = None # httpx.AsyncClient (typed Any to avoid top-level import) + + async def __aenter__(self) -> PyRITApiClient: + """ + Open the underlying ``httpx.AsyncClient``. + + Returns: + PyRITApiClient: ``self``, with the HTTP client opened. + """ + import httpx + + self._client = httpx.AsyncClient(base_url=self._base_url, timeout=60.0) + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Close the underlying HTTP client.""" + await self.close_async() + + # ------------------------------------------------------------------ + # Health + # ------------------------------------------------------------------ + + async def health_check_async(self) -> bool: + """ + Probe the server health endpoint. + + Returns: + bool: ``True`` if the server returned a healthy response. + """ + import httpx + + try: + client = self._get_client() + resp = await client.get("/api/health") + return resp.status_code == 200 + except httpx.ConnectError: + return False + except Exception: + _logger.debug("Health check failed", exc_info=True) + return False + + # ------------------------------------------------------------------ + # Scenarios + # ------------------------------------------------------------------ + + async def list_scenarios_async(self, *, limit: int = 200) -> dict[str, Any]: + """ + List all available scenarios. + + Returns: + dict: ``ListRegisteredScenariosResponse`` payload. + """ + return await self._get_json(path="/api/scenarios/catalog", params={"limit": limit}) + + async def get_scenario_async(self, *, scenario_name: str) -> dict[str, Any] | None: + """ + Get metadata for a single scenario. + + Returns: + dict | None: ``RegisteredScenario`` payload, or ``None`` if 404. + + Raises: + httpx.HTTPStatusError: For non-404 HTTP error responses. + """ + import httpx + + try: + return await self._get_json(path=f"/api/scenarios/catalog/{scenario_name}") + except httpx.HTTPStatusError as exc: + if exc.response.status_code == 404: + return None + raise + + # ------------------------------------------------------------------ + # Initializers + # ------------------------------------------------------------------ + + async def list_initializers_async(self, *, limit: int = 200) -> dict[str, Any]: + """ + List all available initializers. + + Returns: + dict: ``ListRegisteredInitializersResponse`` payload. + """ + return await self._get_json(path="/api/initializers", params={"limit": limit}) + + async def register_initializer_async(self, *, name: str, script_content: str) -> dict[str, Any]: + """ + Register a custom initializer by uploading Python source code. + + Args: + name: Registry name for the initializer. + script_content: Python source code containing a ``PyRITInitializer`` subclass. + + Returns: + dict: ``RegisteredInitializer`` payload. + + Raises: + ServerNotAvailableError: If custom initializers are disabled (403). + """ + client = self._get_client() + resp = await client.post( + "/api/initializers", + json={"name": name, "script_content": script_content}, + ) + if resp.status_code == 403: + detail = resp.json().get("detail", "Custom initializer operations are disabled on the server.") + raise ServerNotAvailableError(detail) + resp.raise_for_status() + return resp.json() + + # ------------------------------------------------------------------ + # Targets + # ------------------------------------------------------------------ + + async def list_targets_async(self, *, limit: int = 200) -> dict[str, Any]: + """ + List all available targets. + + Returns: + dict: ``TargetListResponse`` payload. + """ + return await self._get_json(path="/api/targets", params={"limit": limit}) + + # ------------------------------------------------------------------ + # Scenario runs + # ------------------------------------------------------------------ + + async def start_scenario_run_async(self, *, request: dict[str, Any]) -> dict[str, Any]: + """ + Start a new scenario run. + + Args: + request: ``RunScenarioRequest``-shaped dict. + + Returns: + dict: ``ScenarioRunSummary`` payload. + """ + client = self._get_client() + resp = await client.post("/api/scenarios/runs", json=request) + resp.raise_for_status() + return resp.json() + + async def get_scenario_run_async(self, *, scenario_result_id: str) -> dict[str, Any]: + """ + Get the current status of a scenario run. + + Returns: + dict: ``ScenarioRunSummary`` payload. + """ + return await self._get_json(path=f"/api/scenarios/runs/{scenario_result_id}") + + async def get_scenario_run_results_async(self, *, scenario_result_id: str) -> dict[str, Any]: + """ + Get detailed results for a completed scenario run. + + Returns: + dict: ``ScenarioResult.to_dict()`` payload. + """ + return await self._get_json(path=f"/api/scenarios/runs/{scenario_result_id}/results") + + async def cancel_scenario_run_async(self, *, scenario_result_id: str) -> dict[str, Any]: + """ + Cancel a running scenario. + + Returns: + dict: Updated ``ScenarioRunSummary`` payload. + """ + client = self._get_client() + resp = await client.post(f"/api/scenarios/runs/{scenario_result_id}/cancel") + resp.raise_for_status() + return resp.json() + + async def list_scenario_runs_async(self, *, limit: int = 100) -> dict[str, Any]: + """ + List tracked scenario runs. + + Returns: + dict: ``ScenarioRunListResponse`` payload. + """ + return await self._get_json(path="/api/scenarios/runs", params={"limit": limit}) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def close_async(self) -> None: + """Close the underlying HTTP client.""" + if self._client is not None: + await self._client.aclose() + self._client = None + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_client(self) -> Any: + """ + Return the ``httpx.AsyncClient``, raising if not opened. + + Returns: + Any: The opened ``httpx.AsyncClient`` instance. + + Raises: + ServerNotAvailableError: If the client has not been opened via ``__aenter__``. + """ + if self._client is None: + raise ServerNotAvailableError( + f"API client is not connected to {self._base_url}. " + "Use 'async with PyRITApiClient(...)' or call __aenter__ first." + ) + return self._client + + async def _get_json(self, *, path: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + """ + GET a JSON endpoint and return the parsed response. + + Returns: + dict[str, Any]: The parsed JSON response body. + + Raises: + ServerNotAvailableError: On connection failure. + """ + import httpx + + client = self._get_client() + try: + resp = await client.get(path, params=params) + except httpx.ConnectError as exc: + raise ServerNotAvailableError( + f"Cannot connect to PyRIT server at {self._base_url}.\n" + "Hint: Use '--start-server' to launch a local backend, " + "or pass '--server-url '." + ) from exc + resp.raise_for_status() + return resp.json() diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py deleted file mode 100644 index f0a85b24d1..0000000000 --- a/pyrit/cli/frontend_core.py +++ /dev/null @@ -1,762 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Shared core logic for PyRIT Frontends. - -This module contains all the business logic for: -- Loading and discovering scenarios -- Running scenarios -- Formatting output -- Managing initialization scripts - -Both pyrit_scan and pyrit_shell use these functions. -""" - -from __future__ import annotations - -import logging -import sys -from typing import TYPE_CHECKING, Any, Optional - -from pyrit.cli._cli_args import ARG_HELP as ARG_HELP -from pyrit.cli._cli_args import AZURE_SQL as AZURE_SQL -from pyrit.cli._cli_args import IN_MEMORY as IN_MEMORY -from pyrit.cli._cli_args import SQLITE as SQLITE -from pyrit.cli._cli_args import _argparse_validator as _argparse_validator -from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg -from pyrit.cli._cli_args import add_common_arguments as add_common_arguments -from pyrit.cli._cli_args import extract_scenario_args as extract_scenario_args -from pyrit.cli._cli_args import non_negative_int as non_negative_int -from pyrit.cli._cli_args import parse_list_targets_arguments as parse_list_targets_arguments -from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels -from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments -from pyrit.cli._cli_args import positive_int as positive_int -from pyrit.cli._cli_args import resolve_env_files as resolve_env_files -from pyrit.cli._cli_args import resolve_env_files_argparse as resolve_env_files_argparse -from pyrit.cli._cli_args import validate_database as validate_database -from pyrit.cli._cli_args import validate_database_argparse as validate_database_argparse -from pyrit.cli._cli_args import validate_integer as validate_integer -from pyrit.cli._cli_args import validate_log_level as validate_log_level -from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse -from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter as ConsoleScenarioResultPrinter -from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry -from pyrit.scenario import DatasetConfiguration -from pyrit.setup import ConfigurationLoader, initialize_pyrit_async -from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP - -try: - import termcolor - - HAS_TERMCOLOR = True -except ImportError: - HAS_TERMCOLOR = False - - # Create a dummy termcolor module for fallback - class termcolor: # noqa: N801 - """Dummy termcolor fallback for colored printing if termcolor is not installed.""" - - @staticmethod - def cprint(text: str, color: str | None = None, attrs: list[Any] | None = None) -> None: - """Print text without color.""" - print(text) - - -if TYPE_CHECKING: - from collections.abc import Sequence - from pathlib import Path - - from pyrit.models.scenario_result import ScenarioResult - from pyrit.registry import ( - InitializerMetadata, - ScenarioMetadata, - ) - -logger = logging.getLogger(__name__) - - -class FrontendCore: - """ - Shared context for PyRIT operations. - - This object holds all the registries and configuration needed to run - scenarios. It can be created once (for shell) or per-command (for CLI). - """ - - def __init__( - self, - *, - config_file: Optional[Path] = None, - database: Optional[str] = None, - initialization_scripts: Optional[list[Path]] = None, - initializer_names: Optional[list[Any]] = None, - env_files: Optional[list[Path]] = None, - log_level: Optional[int] = None, - ) -> None: - """ - Initialize PyRIT context. - - Configuration is loaded in the following order (later values override earlier): - 1. Default config file (~/.pyrit/.pyrit_conf) if it exists - 2. Explicit config_file argument if provided - 3. Individual CLI arguments (database, initializers, etc.) - - Args: - config_file: Optional path to a YAML-formatted configuration file. - The file uses .pyrit_conf extension but is YAML format. - database: Database type (InMemory, SQLite, or AzureSQL). - initialization_scripts: Optional list of initialization script paths. - initializer_names: Optional list of initializer entries. Each entry can be - a string name (e.g., "simple") or a dict with 'name' and optional 'args' - (e.g., {"name": "target", "args": {"tags": "default,scorer"}}). - env_files: Optional list of environment file paths to load in order. - log_level: Logging level constant (e.g., logging.WARNING). Defaults to logging.WARNING. - - Raises: - ValueError: If database is invalid, or if config file is invalid. - FileNotFoundError: If an explicitly specified config_file does not exist. - """ - # Use provided log level or default to WARNING - self._log_level = log_level if log_level is not None else logging.WARNING - - # Load configuration using ConfigurationLoader.load_with_overrides - try: - config = ConfigurationLoader.load_with_overrides( - config_file=config_file, - memory_db_type=database, - initializers=initializer_names, - initialization_scripts=[str(p) for p in initialization_scripts] if initialization_scripts else None, - env_files=[str(p) for p in env_files] if env_files else None, - ) - except ValueError as e: - # Re-raise with user-friendly message for CLI users - error_msg = str(e) - if "memory_db_type" in error_msg: - raise ValueError( - f"Invalid database type '{database}'. Must be one of: InMemory, SQLite, AzureSQL" - ) from e - raise - - # Extract values from config for internal use - # Use canonical mapping from configuration_loader - self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] - self._initialization_scripts = config._resolve_initialization_scripts() - self._initializer_configs = config._initializer_configs if config._initializer_configs else None - self._scenario_config = config._scenario_config - self._env_files = config._resolve_env_files() - self._operator = config.operator - self._operation = config.operation - self._max_concurrent_scenario_runs = config.max_concurrent_scenario_runs - self._allow_custom_initializers = config.allow_custom_initializers - - # Lazy-loaded registries - self._scenario_registry: Optional[ScenarioRegistry] = None - self._initializer_registry: Optional[InitializerRegistry] = None - self._initialized = False - - # Configure logging - logging.basicConfig(level=self._log_level) - - async def initialize_async(self) -> None: - """ - Initialize PyRIT and load registries (heavy operation). - - Sets up memory and loads scenario/initializer registries. - Initializers are NOT run here — they are run separately - (per-scenario in pyrit_scan, or up-front in pyrit_backend). - """ - if self._initialized: - return - - # Initialize PyRIT without initializers (they run separately) - await initialize_pyrit_async( - memory_db_type=self._database, - initialization_scripts=None, - initializers=None, - env_files=self._env_files, - ) - # Mark that initial env loading has been printed - self._silent_reinit = True - - # Load registries (use singleton pattern for shared access) - self._scenario_registry = ScenarioRegistry.get_registry_singleton() - if self._initialization_scripts: - print("Discovering user scenarios...") - sys.stdout.flush() - self._scenario_registry.discover_user_scenarios() - - self._initializer_registry = InitializerRegistry() - - self._initialized = True - - def with_overrides( - self, - *, - initializer_names: Optional[list[Any]] = None, - initialization_scripts: Optional[list[Path]] = None, - log_level: Optional[int] = None, - ) -> FrontendCore: - """ - Create a derived FrontendCore with per-command overrides. - - Copies inherited state (database, env_files, operator, operation, config) - from this instance and applies the given overrides. Shares registries - with the parent to avoid redundant re-discovery and skips re-reading - config files. - - Args: - initializer_names (Optional[list[Any]]): Per-command initializer overrides. - Each entry can be a string name or a dict with 'name' and optional 'args'. - None keeps the parent's value. - initialization_scripts (Optional[list[Path]]): Per-command script overrides. - None keeps the parent's value. - log_level (Optional[int]): Per-command log level override. - None keeps the parent's value. - - Returns: - FrontendCore: A new context ready for use, without re-reading config files. - """ - derived = object.__new__(FrontendCore) - - # Inherit from parent - derived._database = self._database - derived._env_files = self._env_files - derived._operator = self._operator - derived._operation = self._operation - derived._max_concurrent_scenario_runs = self._max_concurrent_scenario_runs - derived._allow_custom_initializers = self._allow_custom_initializers - derived._scenario_config = self._scenario_config - - # Apply overrides or inherit - derived._log_level = log_level if log_level is not None else self._log_level - - if initializer_names is not None: - loader = ConfigurationLoader.from_dict({"initializers": initializer_names}) - derived._initializer_configs = loader._initializer_configs - else: - derived._initializer_configs = self._initializer_configs - - if initialization_scripts is not None: - derived._initialization_scripts = initialization_scripts - else: - derived._initialization_scripts = self._initialization_scripts - - # Share registries (singletons, no need to re-discover) - derived._scenario_registry = self._scenario_registry - derived._initializer_registry = self._initializer_registry - derived._initialized = True - derived._silent_reinit = True - - return derived - - @property - def scenario_registry(self) -> ScenarioRegistry: - """ - Get the scenario registry. Must call await initialize_async() first. - - Raises: - RuntimeError: If initialize_async() has not been called. - ValueError: If the scenario registry is not initialized. - """ - if not self._initialized: - raise RuntimeError( - "FrontendCore not initialized. Call 'await context.initialize_async()' before accessing registries." - ) - if self._scenario_registry is None: - raise ValueError("self._scenario_registry is not initialized") - return self._scenario_registry - - @property - def initializer_registry(self) -> InitializerRegistry: - """ - Get the initializer registry. Must call await initialize_async() first. - - Raises: - RuntimeError: If initialize_async() has not been called. - ValueError: If the initializer registry is not initialized. - """ - if not self._initialized: - raise RuntimeError( - "FrontendCore not initialized. Call 'await context.initialize_async()' before accessing registries." - ) - if self._initializer_registry is None: - raise ValueError("self._initializer_registry is not initialized") - return self._initializer_registry - - -async def list_scenarios_async(*, context: FrontendCore) -> list[ScenarioMetadata]: - """ - List metadata for all available scenarios. - - Args: - context: PyRIT context with loaded registries. - - Returns: - List of scenario metadata dictionaries describing each scenario class. - """ - if not context._initialized: - await context.initialize_async() - return context.scenario_registry.list_metadata() - - -async def list_initializers_async( - *, - context: FrontendCore, -) -> Sequence[InitializerMetadata]: - """ - List metadata for all available initializers. - - Args: - context: PyRIT context with loaded registries. - - Returns: - Sequence of initializer metadata dictionaries describing each initializer class. - """ - if not context._initialized: - await context.initialize_async() - return context.initializer_registry.list_metadata() - - -async def list_targets_async( - *, - context: FrontendCore, -) -> list[str]: - """ - List available target names from the TargetRegistry. - - Since targets are registered by initializers, this function requires initializers - to have been run first. Configure initializers on the FrontendCore context - (via initializer_names or initialization_scripts) before calling this function. - - Args: - context: PyRIT context with loaded registries. - - Returns: - Sorted list of registered target names. - """ - if not context._initialized: - await context.initialize_async() - - # Run initializers and/or initialization scripts to populate the target registry - if context._initializer_configs or context._initialization_scripts: - initializer_instances = [] - if context._initializer_configs: - for config in context._initializer_configs: - initializer_class = context.initializer_registry.get_class(config.name) - instance = initializer_class() - if config.args: - instance.set_params_from_args(args=config.args) - initializer_instances.append(instance) - - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances or None, - env_files=context._env_files, - silent=getattr(context, "_silent_reinit", False), - ) - - target_registry = TargetRegistry.get_registry_singleton() - return target_registry.get_names() - - -async def run_scenario_async( - *, - scenario_name: str, - context: FrontendCore, - target_name: str | None = None, - scenario_strategies: Optional[list[str]] = None, - max_concurrency: Optional[int] = None, - max_retries: Optional[int] = None, - memory_labels: Optional[dict[str, str]] = None, - dataset_names: Optional[list[str]] = None, - max_dataset_size: Optional[int] = None, - scenario_args: Optional[dict[str, Any]] = None, - print_summary: bool = True, -) -> ScenarioResult: - """ - Run a scenario by name. - - Args: - scenario_name: Name of the scenario to run. - context: PyRIT context with loaded registries. - target_name: Name of a registered target from the TargetRegistry to use as the - objective target. Targets are registered by initializers (e.g., the 'target' - initializer). Use --list-targets to see available names after initializers run. - scenario_strategies: Optional list of strategy names. - max_concurrency: Max concurrent operations. - max_retries: Max retry attempts. - memory_labels: Labels to attach to memory entries. - dataset_names: Optional list of dataset names to use instead of scenario defaults. - If provided, creates a new dataset configuration (fetches all items unless - max_dataset_size is also specified). - max_dataset_size: Optional maximum number of items to use from the dataset. - If dataset_names is provided, limits items from the new datasets. - If only max_dataset_size is provided, overrides the scenario's default limit. - scenario_args: Optional map of scenario-declared parameter values - (CLI/config merge from the caller), passed to - ``Scenario.set_params_from_args`` before ``initialize_async``. - print_summary: Whether to print the summary after execution. Defaults to True. - - Returns: - ScenarioResult: The result of the scenario execution. - - Raises: - ValueError: If scenario not found, target not found, or fails to run. - - Note: - Initializers from PyRITContext will be run before the scenario executes. - """ - # Ensure context is initialized first (loads registries) - # This must happen BEFORE we run initializers to avoid double-initialization - if not context._initialized: - await context.initialize_async() - - # Run initializers before scenario - initializer_instances = None - if context._initializer_configs: - print(f"Running {len(context._initializer_configs)} initializer(s)...") - sys.stdout.flush() - - initializer_instances = [] - - for config in context._initializer_configs: - initializer_class = context.initializer_registry.get_class(config.name) - instance = initializer_class() - if config.args: - instance.set_params_from_args(args=config.args) - initializer_instances.append(instance) - - # Re-initialize PyRIT with the scenario-specific initializers - # This resets memory and applies initializer defaults - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - silent=getattr(context, "_silent_reinit", False), - ) - - # Resolve objective target from TargetRegistry - if target_name is not None: - target_registry = TargetRegistry.get_registry_singleton() - objective_target = target_registry.get_instance_by_name(target_name) - if objective_target is None: - available_names = target_registry.get_names() - if not available_names: - raise ValueError( - f"Target '{target_name}' not found. The target registry is empty.\n" - "Targets are registered by initializers. Make sure to include an initializer " - "that registers targets (e.g., --initializers target)." - ) - raise ValueError( - f"Target '{target_name}' not found in registry.\nAvailable targets: {', '.join(available_names)}" - ) - else: - objective_target = None - - # Get scenario class - scenario_class = context.scenario_registry.get_class(scenario_name) - - if scenario_class is None: - available = ", ".join(context.scenario_registry.get_names()) - raise ValueError(f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") - - # Build initialization kwargs (these go to initialize_async, not __init__) - init_kwargs: dict[str, Any] = {} - - if objective_target is not None: - init_kwargs["objective_target"] = objective_target - - if scenario_strategies: - strategy_class = scenario_class.get_strategy_class() - strategy_enums = [] - for name in scenario_strategies: - try: - strategy_enums.append(strategy_class(name)) - except ValueError: - available_strategies = [s.value for s in strategy_class] - raise ValueError( - f"Strategy '{name}' not found for scenario '{scenario_name}'. " - f"Available: {', '.join(available_strategies)}" - ) from None - init_kwargs["scenario_strategies"] = strategy_enums - - if max_concurrency is not None: - init_kwargs["max_concurrency"] = max_concurrency - if max_retries is not None: - init_kwargs["max_retries"] = max_retries - if memory_labels is not None: - init_kwargs["memory_labels"] = memory_labels - - # Build dataset_config based on CLI args: - # - No args: scenario uses its default_dataset_config() - # - dataset_names only: new config with those datasets, fetches all items - # - dataset_names + max_dataset_size: new config with limited items - # - max_dataset_size only: default datasets with overridden limit - if dataset_names: - # User specified dataset names - create new config (fetches all unless max_dataset_size set) - init_kwargs["dataset_config"] = DatasetConfiguration( - dataset_names=dataset_names, - max_dataset_size=max_dataset_size, - ) - elif max_dataset_size is not None: - # User only specified max_dataset_size - override default config's limit - default_config = scenario_class.default_dataset_config() - default_config.max_dataset_size = max_dataset_size - init_kwargs["dataset_config"] = default_config - - # Instantiate and run - print(f"\nRunning scenario: {scenario_name}") - sys.stdout.flush() - - # Scenarios here are a concrete subclass - # Runtime parameters are passed to initialize_async() - scenario = scenario_class() # type: ignore[ty:missing-argument] - # Empty args still triggers missing-required validation + default materialization. - scenario.set_params_from_args(args=scenario_args or {}) - await scenario.initialize_async(**init_kwargs) - result = await scenario.run_async() - - # Print results if requested - if print_summary: - printer = ConsoleScenarioResultPrinter() - await printer.print_summary_async(result) - - return result - - -def _format_wrapped_text(*, text: str, indent: str, width: int = 78) -> str: - """ - Format text with word wrapping. - - Args: - text: Text to wrap. - indent: Indentation string for wrapped lines. - width: Maximum line width. Defaults to 78. - - Returns: - Formatted text with line breaks. - """ - words = text.split() - lines = [] - current_line = "" - - for word in words: - if not current_line: - current_line = word - elif len(current_line) + len(word) + 1 + len(indent) <= width: - current_line += " " + word - else: - lines.append(indent + current_line) - current_line = word - - if current_line: - lines.append(indent + current_line) - - return "\n".join(lines) - - -def _print_header(*, text: str) -> None: - """ - Print a colored header if termcolor is available. - - Args: - text: Header text to print. - """ - if HAS_TERMCOLOR: - termcolor.cprint(f"\n {text}", "cyan", attrs=["bold"]) - else: - print(f"\n {text}") - - -def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: - """ - Print formatted information about a scenario class. - - Args: - scenario_metadata: Dataclass containing scenario metadata. - """ - _print_header(text=scenario_metadata.registry_name) - print(f" Class: {scenario_metadata.class_name}") - - description = scenario_metadata.class_description - if description: - print(" Description:") - print(_format_wrapped_text(text=description, indent=" ")) - - if scenario_metadata.aggregate_strategies: - agg_strategies = scenario_metadata.aggregate_strategies - print(" Aggregate Strategies:") - formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") - print(formatted) - - if scenario_metadata.all_strategies: - strategies = scenario_metadata.all_strategies - print(f" Available Strategies ({len(strategies)}):") - formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") - print(formatted) - - if scenario_metadata.default_strategy: - print(f" Default Strategy: {scenario_metadata.default_strategy}") - - if scenario_metadata.default_datasets: - datasets = scenario_metadata.default_datasets - max_size = scenario_metadata.max_dataset_size - if datasets: - size_suffix = f", max {max_size} per dataset" if max_size else "" - print(f" Default Datasets ({len(datasets)}{size_suffix}):") - formatted = _format_wrapped_text(text=", ".join(datasets), indent=" ") - print(formatted) - else: - print(" Default Datasets: None") - - if scenario_metadata.supported_parameters: - print(" Supported Parameters:") - for param in scenario_metadata.supported_parameters: - default_str = f" [default: {param.default!r}]" if param.default is not None else "" - type_display = f" ({param.param_type})" if param.param_type else "" - choices_display = f" [choices: {param.choices}]" if param.choices else "" - print(f" - {param.name}{type_display}{default_str}{choices_display}: {param.description}") - - -def format_initializer_metadata(*, initializer_metadata: InitializerMetadata) -> None: - """ - Print formatted information about an initializer class. - - Args: - initializer_metadata: Dataclass containing initializer metadata. - """ - _print_header(text=initializer_metadata.registry_name) - print(f" Class: {initializer_metadata.class_name}") - - if initializer_metadata.required_env_vars: - print(" Required Environment Variables:") - for env_var in initializer_metadata.required_env_vars: - print(f" - {env_var}") - else: - print(" Required Environment Variables: None") - - if initializer_metadata.supported_parameters: - print(" Supported Parameters:") - for param_name, param_desc, param_default in initializer_metadata.supported_parameters: - default_str = f" [default: {param_default}]" if param_default else "" - print(f" - {param_name}{default_str}: {param_desc}") - - if initializer_metadata.class_description: - print(" Description:") - print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) - - -def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: - """ - Resolve initialization script paths. - - Args: - script_paths: List of script path strings. - - Returns: - List of resolved Path objects. - - Raises: - FileNotFoundError: If a script path does not exist. - """ - return InitializerRegistry.resolve_script_paths(script_paths=script_paths) - - -async def print_scenarios_list_async(*, context: FrontendCore) -> int: - """ - Print a formatted list of all available scenarios. - - Args: - context: PyRIT context with loaded registries. - - Returns: - Exit code (0 for success). - """ - scenarios = await list_scenarios_async(context=context) - - if not scenarios: - print("No scenarios found.") - return 0 - - print("\nAvailable Scenarios:") - print("=" * 80) - for scenario_metadata in scenarios: - format_scenario_metadata(scenario_metadata=scenario_metadata) - print("\n" + "=" * 80) - print(f"\nTotal scenarios: {len(scenarios)}") - return 0 - - -async def print_initializers_list_async(*, context: FrontendCore) -> int: - """ - Print a formatted list of all available initializers. - - Args: - context: PyRIT context with loaded registries. - - Returns: - Exit code (0 for success). - """ - initializers = await list_initializers_async(context=context) - - if not initializers: - print("No initializers found.") - return 0 - - print("\nAvailable Initializers:") - print("=" * 80) - for initializer_metadata in initializers: - format_initializer_metadata(initializer_metadata=initializer_metadata) - print("\n" + "=" * 80) - print(f"\nTotal initializers: {len(initializers)}") - return 0 - - -async def print_targets_list_async(*, context: FrontendCore) -> int: - """ - Print a formatted list of all available targets from the TargetRegistry. - - Targets are registered by initializers, so this requires initializers to run first. - If no targets are found, prints a hint about using the 'target' initializer. - - Args: - context: PyRIT context with loaded registries. - - Returns: - Exit code (0 for success). - """ - target_names = await list_targets_async(context=context) - - if not target_names: - print("\nNo targets found in registry.") - print( - "\nTargets are registered by initializers. Include an initializer that registers " - "targets, for example:\n --initializers target\n" - ) - return 0 - - target_registry = TargetRegistry.get_registry_singleton() - - print("\nRegistered Targets:") - print("=" * 80) - for name in target_names: - target = target_registry.get_instance_by_name(name) - if target is None: - print(f" {name}") - continue - - model = target._underlying_model or target._model_name or "" - endpoint = target._endpoint or "" - class_name = type(target).__name__ - - _print_header(text=name) - print(f" Class: {class_name}") - if model: - print(f" Model: {model}") - if endpoint: - print(f" Endpoint: {endpoint}") - print("\n" + "=" * 80) - print(f"\nTotal targets: {len(target_names)}") - return 0 diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py deleted file mode 100644 index 819ad7baa9..0000000000 --- a/pyrit/cli/pyrit_backend.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -PyRIT Backend CLI - Command-line interface for running the PyRIT backend server. - -This module provides the main entry point for the pyrit_backend command. -""" - -import asyncio -import sys -from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter -from pathlib import Path -from typing import Optional - -from pyrit.cli import frontend_core - - -def parse_args(*, args: Optional[list[str]] = None) -> Namespace: - """ - Parse command-line arguments for the PyRIT backend server. - - Returns: - Namespace: Parsed command-line arguments. - """ - parser = ArgumentParser( - prog="pyrit_backend", - description="""PyRIT Backend - Run the PyRIT backend API server - -Examples: - # Start backend with default settings - pyrit_backend - - # Start with built-in initializers - pyrit_backend --initializers airt - - # Start with custom initialization scripts - pyrit_backend --initialization-scripts ./my_targets.py - - # Start with custom port and host - pyrit_backend --host 0.0.0.0 --port 8080 - - # Expose to network (listen on all interfaces) - pyrit_backend --host 0.0.0.0 - - # List available initializers - pyrit_backend --list-initializers -""", - formatter_class=RawDescriptionHelpFormatter, - ) - - parser.add_argument( - "--host", - type=str, - default="localhost", - help="Host to bind the server to (default: localhost)", - ) - - parser.add_argument( - "--port", - type=int, - default=8000, - help="Port to bind the server to (default: 8000)", - ) - - parser.add_argument( - "--config-file", - type=Path, - help=frontend_core.ARG_HELP["config_file"], - ) - - parser.add_argument( - "--log-level", - type=frontend_core.validate_log_level_argparse, - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: INFO)", - ) - - parser.add_argument( - "--list-initializers", - action="store_true", - help="List all available initializers and exit", - ) - - parser.add_argument( - "--database", - type=frontend_core.validate_database_argparse, - default=None, - help=( - f"Database type to use for memory storage ({frontend_core.IN_MEMORY}, " - f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}). " - f"Defaults to value from config file, or {frontend_core.SQLITE} if not specified." - ), - ) - - parser.add_argument( - "--initializers", - type=frontend_core._parse_initializer_arg, - nargs="+", - help=frontend_core.ARG_HELP["initializers"], - ) - - parser.add_argument( - "--initialization-scripts", - type=str, - nargs="+", - help=frontend_core.ARG_HELP["initialization_scripts"], - ) - - parser.add_argument( - "--env-files", - type=str, - nargs="+", - help=frontend_core.ARG_HELP["env_files"], - ) - - parser.add_argument( - "--reload", - action="store_true", - help="Enable auto-reload for development (watches for file changes)", - ) - - return parser.parse_args(args) - - -async def initialize_and_run_async(*, parsed_args: Namespace) -> int: - """ - Initialize PyRIT and start the backend server. - - Returns: - int: Exit code (0 for success, 1 for error). - """ - from pyrit.setup import initialize_pyrit_async - - # Resolve initialization scripts if provided - initialization_scripts = None - if parsed_args.initialization_scripts: - try: - initialization_scripts = frontend_core.resolve_initialization_scripts( - script_paths=parsed_args.initialization_scripts - ) - except FileNotFoundError as e: - print(f"Error: {e}") - return 1 - - # Resolve env files if provided - env_files = None - if parsed_args.env_files: - try: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) - except ValueError as e: - print(f"Error: {e}") - return 1 - - # Create context using FrontendCore (handles config file merging) - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - database=parsed_args.database, - initialization_scripts=initialization_scripts, - initializer_names=parsed_args.initializers, - env_files=env_files, - log_level=parsed_args.log_level, - ) - - # Initialize PyRIT (loads registries, sets up memory) - print("šŸ”§ Initializing PyRIT...") - await context.initialize_async() - - # Run initializers up-front (backend runs them once at startup, not per-scenario) - initializer_instances = None - if context._initializer_configs: - print(f"Running {len(context._initializer_configs)} initializer(s)...") - initializer_instances = [] - for config in context._initializer_configs: - initializer_class = context.initializer_registry.get_class(config.name) - instance = initializer_class() - if config.args: - instance.set_params_from_args(args=config.args) - initializer_instances.append(instance) - - # Re-initialize with initializers applied - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - ) - - # Start uvicorn server - import uvicorn - - from pyrit.backend.main import app - - # Expose configured default labels to the version endpoint - default_labels: dict[str, str] = {} - if context._operator: - default_labels["operator"] = context._operator - if context._operation: - default_labels["operation"] = context._operation - app.state.default_labels = default_labels - app.state.max_concurrent_scenario_runs = context._max_concurrent_scenario_runs - app.state.allow_custom_initializers = context._allow_custom_initializers - - display_host = parsed_args.host - if context._allow_custom_initializers: - print("āš ļø WARNING: Custom initializer registration is ENABLED (allow_custom_initializers: true).") - print(" This allows arbitrary Python code execution via the REST API.") - if parsed_args.host == "0.0.0.0": - print(" 🚨 Server is bound to 0.0.0.0 — accessible from the NETWORK. Use only on trusted networks!") - else: - print(f" Server is bound to {display_host}.") - - print(f"šŸš€ Starting PyRIT backend on http://{display_host}:{parsed_args.port}") - print(f" API Docs: http://{display_host}:{parsed_args.port}/docs") - if parsed_args.host == "0.0.0.0": - print(f" Open in browser: http://localhost:{parsed_args.port}") - - uvicorn_config = uvicorn.Config( - "pyrit.backend.main:app", - host=parsed_args.host, - port=parsed_args.port, - log_level=parsed_args.log_level, - reload=parsed_args.reload, - ) - server = uvicorn.Server(uvicorn_config) - await server.serve() - - return 0 - - -def main(*, args: Optional[list[str]] = None) -> int: - """ - Start the PyRIT backend server CLI. - - Returns: - int: Exit code (0 for success, 1 for error). - """ - # Ensure emoji and other Unicode characters don't crash on Windows consoles - # that use legacy encodings like cp1252. - sys.stdout.reconfigure(errors="replace") # type: ignore[ty:unresolved-attribute] - sys.stderr.reconfigure(errors="replace") # type: ignore[ty:unresolved-attribute] - - try: - parsed_args = parse_args(args=args) - except SystemExit as e: - return e.code if isinstance(e.code, int) else 1 - - # Handle list-initializers command - if parsed_args.list_initializers: - context = frontend_core.FrontendCore(config_file=parsed_args.config_file, log_level=parsed_args.log_level) - return asyncio.run(frontend_core.print_initializers_list_async(context=context)) - - # Run the server - try: - return asyncio.run(initialize_and_run_async(parsed_args=parsed_args)) - except KeyboardInterrupt: - print("\nšŸ›‘ Backend stopped") - return 0 - except Exception as e: - print(f"\nError: {e}") - return 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index d85c2235fd..2c0d83310a 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -5,6 +5,8 @@ PyRIT CLI - Command-line interface for running security scenarios. This module provides the main entry point for the pyrit_scan command. +It is a thin REST client that talks to the PyRIT backend server over HTTP. +No heavy pyrit imports — all operations go through the REST API. """ from __future__ import annotations @@ -15,44 +17,54 @@ import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, get_origin +from typing import Any, Optional from pyrit.cli._cli_args import ( ARG_HELP, _parse_initializer_arg, - merge_config_scenario_args, non_negative_int, positive_int, validate_log_level_argparse, ) -if TYPE_CHECKING: - from pyrit.common.parameter import Parameter - from pyrit.scenario.core import Scenario +_TERMINAL_STATUSES = {"COMPLETED", "FAILED", "CANCELLED"} -# Namespacing prefix for scenario-declared params on the parsed Namespace. -_SCENARIO_DEST_PREFIX = "scenario__" -_DESCRIPTION = """PyRIT Scanner - Run security scenarios against AI systems +_DESCRIPTION = """PyRIT Scanner - Run AI security scenarios from the command line. + +Requires a running PyRIT backend server. Use --start-server to launch one, +or connect to an existing server with --server-url. Examples: - # List available scenarios, initializers, and targets + # Start the backend server + pyrit_scan --start-server + + # List scenarios, initializers, or targets pyrit_scan --list-scenarios pyrit_scan --list-initializers - pyrit_scan --list-targets --initializers target + pyrit_scan --list-targets - # Run a scenario with a target and initializers - pyrit_scan foundry.red_team_agent --target my_target --initializers target load_default_datasets + # Run single-turn cyber attacks against a target + pyrit_scan airt.cyber --target openai_chat --strategies single_turn - # Run with a configuration file (recommended for complex setups) - pyrit_scan foundry.red_team_agent --target my_target --config-file ./my_config.yaml + # Run rapid response with specific datasets and concurrency + pyrit_scan airt.rapid_response --target openai_chat + --strategies prompt_sending --dataset-names airt_hate + --max-dataset-size 5 --max-concurrency 4 - # Run with custom initialization scripts - pyrit_scan garak.encoding --target my_target --initialization-scripts ./my_config.py + # Run multi-turn red team agent with labels for tracking + pyrit_scan airt.red_team_agent --target openai_chat + --strategies crescendo + --memory-labels '{"experiment":"baseline"}' - # Run specific strategies or options - pyrit_scan foundry.red_team_agent --target my_target --strategies base64 rot13 --initializers target - pyrit_scan foundry.red_team_agent --target my_target --initializers target --max-concurrency 10 --max-retries 3 + # Register a custom initializer from a Python script + pyrit_scan --add-initializer ./my_custom_init.py + + # Connect to a remote server + pyrit_scan --server-url http://remote:8000 --list-scenarios + + # Stop the server + pyrit_scan --stop-server """ @@ -60,13 +72,8 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: """ Build the ``pyrit_scan`` argparse parser with the built-in (non-scenario) flags. - Reused across the two-pass flow: pass 1 calls with ``add_help=False`` to - identify the scenario name; pass 2 calls with ``add_help=True`` and adds - scenario-declared params on top. - Args: - add_help (bool): Whether to register the standard ``-h``/``--help`` - action. Defaults to True. + add_help (bool): Whether to register the ``-h``/``--help`` action. Returns: ArgumentParser: Parser with all built-in flags registered. @@ -78,60 +85,80 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: add_help=add_help, ) - parser.add_argument( + # -- Server management -- + server_group = parser.add_argument_group("server") + server_group.add_argument( + "--server-url", + type=str, + help="URL of the PyRIT backend server (default: http://localhost:8000)", + ) + server_group.add_argument( + "--start-server", + action="store_true", + help="Start a local backend server if one is not already running", + ) + server_group.add_argument( + "--stop-server", + action="store_true", + help="Stop the backend server and exit", + ) + server_group.add_argument( "--config-file", type=Path, help=ARG_HELP["config_file"], ) - - parser.add_argument( + server_group.add_argument( "--log-level", type=validate_log_level_argparse, default=logging.WARNING, help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", ) - parser.add_argument( + # -- Discovery -- + discovery_group = parser.add_argument_group("discovery") + discovery_group.add_argument( "--list-scenarios", action="store_true", help="List all available scenarios and exit", ) - - parser.add_argument( + discovery_group.add_argument( "--list-initializers", action="store_true", - help="List all available scenario initializers and exit", + help="List all available initializers and exit", ) - - parser.add_argument( + discovery_group.add_argument( "--list-targets", action="store_true", - help="List all available targets from the TargetRegistry and exit. " - "Requires initializers that register targets (e.g., --initializers target)", + help="List all available targets and exit", + ) + discovery_group.add_argument( + "--add-initializer", + type=str, + nargs="+", + metavar="FILE", + help="Register initializer(s) from Python script file(s) and exit", ) - parser.add_argument( + # -- Scenario run -- + run_group = parser.add_argument_group("scenario run") + run_group.add_argument( "scenario_name", type=str, nargs="?", help="Name of the scenario to run", ) - - parser.add_argument( + run_group.add_argument( + "--target", + type=str, + help=ARG_HELP["target"], + ) + run_group.add_argument( "--initializers", type=_parse_initializer_arg, nargs="+", help=ARG_HELP["initializers"], ) - - parser.add_argument( - "--initialization-scripts", - type=str, - nargs="+", - help=ARG_HELP["initialization_scripts"], - ) - - parser.add_argument( + run_group.add_argument( "--strategies", "-s", type=str, @@ -139,354 +166,569 @@ def _build_base_parser(*, add_help: bool = True) -> ArgumentParser: dest="scenario_strategies", help=ARG_HELP["scenario_strategies"], ) - - parser.add_argument( + run_group.add_argument( "--max-concurrency", type=positive_int, help=ARG_HELP["max_concurrency"], ) - - parser.add_argument( + run_group.add_argument( "--max-retries", type=non_negative_int, help=ARG_HELP["max_retries"], ) - - parser.add_argument( + run_group.add_argument( "--memory-labels", type=str, help=ARG_HELP["memory_labels"], ) - - parser.add_argument( + run_group.add_argument( "--dataset-names", type=str, nargs="+", help=ARG_HELP["dataset_names"], ) - - parser.add_argument( + run_group.add_argument( "--max-dataset-size", type=positive_int, help=ARG_HELP["max_dataset_size"], ) - parser.add_argument( - "--target", - type=str, - help=ARG_HELP["target"], - ) - return parser -def parse_args(args: Optional[list[str]] = None) -> Namespace: - """ - Parse command-line arguments using a two-pass flow. +# Namespacing prefix for scenario-declared params on the parsed Namespace. +_SCENARIO_DEST_PREFIX = "scenario__" - Pass 1 identifies the scenario name with ``parse_known_args`` so unknown - scenario flags don't fail. Pass 2 parses for real, with the resolved - scenario's declared params added as namespaced flags. - The scenario name may come from the CLI positional or, as a fallback, from - the ``scenario.name`` block in ``--config-file`` (or the default config - file). This mirrors the runtime behavior in ``main()`` so config-only - scenario names can still expose their declared CLI flags. +_SCALAR_TYPE_COERCERS: dict[str, Any] = { + "int": int, + "float": float, + "bool": lambda v: str(v).strip().lower() in ("1", "true", "yes", "y", "on"), + "str": str, +} - The CLI positional is only trusted when it resolves to a known scenario. - Pass 1 doesn't yet know about scenario-declared flags, so ``parse_known_args`` - can greedily consume an unknown flag's value (e.g. the ``"7"`` in - ``--max-turns 7``) as the positional. When that happens the positional won't - resolve, and we fall back to the config peek. + +def _scenario_param_kwargs(*, param: dict[str, Any]) -> dict[str, Any]: + """ + Build argparse ``add_argument`` kwargs for a scenario-declared parameter dict. + + Uses ``param_type``, ``is_list`` and ``choices`` from the catalog payload + so list params accept ``nargs='+'`` and scalar params get client-side + type coercion and choice validation. Args: - args (Optional[list[str]]): Argument list (``sys.argv[1:]`` when None). + param: Single entry from ``RegisteredScenario.supported_parameters``. Returns: - Namespace: Parsed command-line arguments. + dict[str, Any]: kwargs ready to pass to ``ArgumentParser.add_argument``. """ - pass1_parser = _build_base_parser(add_help=False) - parsed_pass1, _ = pass1_parser.parse_known_args(args) + kwargs: dict[str, Any] = { + "dest": f"{_SCENARIO_DEST_PREFIX}{param.get('name', '')}", + "default": argparse.SUPPRESS, + "help": param.get("description", ""), + } + if param.get("is_list"): + kwargs["nargs"] = "+" + else: + coercer = _SCALAR_TYPE_COERCERS.get(param.get("param_type", "")) + if coercer is not None and coercer is not str: + param_name = param.get("name", "") + + def _typed(raw: str) -> Any: + try: + return coercer(raw) + except (ValueError, TypeError) as exc: + raise argparse.ArgumentTypeError( + f"--{param_name.replace('_', '-')}: invalid value {raw!r} ({exc})" + ) from exc + + kwargs["type"] = _typed + choices = param.get("choices") + if choices: + kwargs["choices"] = list(choices) + return kwargs + + +def _add_scenario_params_from_api(*, parser: ArgumentParser, params: list[dict[str, Any]]) -> None: + """ + Add scenario-declared parameters (from the API response) as CLI flags. + + Args: + parser: Parser to extend. + params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. + """ + seen_flags: set[str] = set(parser._option_string_actions.keys()) + for p in params: + name = p.get("name", "") + flag = f"--{name.replace('_', '-')}" + if flag in seen_flags: + continue + parser.add_argument(flag, **_scenario_param_kwargs(param=p)) + seen_flags.add(flag) - scenario_class = _resolve_scenario_class(parsed_pass1.scenario_name) - if scenario_class is None: - fallback_name = _peek_scenario_name_from_config(config_file=parsed_pass1.config_file) - scenario_class = _resolve_scenario_class(fallback_name) - pass2_parser = _build_base_parser(add_help=True) - if scenario_class is not None: - _add_scenario_params(parser=pass2_parser, declared=scenario_class.supported_parameters()) +def _extract_scenario_args(*, parsed: Namespace) -> dict[str, Any]: + """ + Pull scenario-declared parameter values out of a parsed Namespace. - return pass2_parser.parse_args(args) + Args: + parsed: Result of ``ArgumentParser.parse_args``. + + Returns: + dict[str, Any]: Map of original parameter name to value. + """ + return { + key.removeprefix(_SCENARIO_DEST_PREFIX): value + for key, value in vars(parsed).items() + if key.startswith(_SCENARIO_DEST_PREFIX) + } -def _peek_scenario_name_from_config(*, config_file: Optional[Path]) -> Optional[str]: +def parse_args(args: Optional[list[str]] = None) -> Namespace: """ - Best-effort lookup of the scenario name in layered config (default + explicit). + Parse command-line arguments (pass 1 — tolerant of scenario-declared flags). - Pass 1 of ``parse_args`` needs the scenario name to register that scenario's - declared parameters as flags. Failures are swallowed: if the YAML is missing - or malformed, return ``None`` and let ``main`` surface the canonical error. + Pass 1 uses ``parse_known_args`` so scenario-specific flags (e.g. + ``--max-turns 7``) don't cause an error before we've had a chance to + fetch the scenario's declared parameters from the server. The unknown + leftovers are stashed on the returned Namespace as ``_unknown_args`` + so :func:`_reparse_with_scenario_params` can detect truly unknown flags + when no scenario was specified. Args: - config_file (Optional[Path]): Path from ``--config-file``. + args: Argument list (``sys.argv[1:]`` when None). Returns: - Optional[str]: The scenario name, or ``None`` if not configured / unavailable. + Namespace: Parsed command-line arguments. """ - from pyrit.common.path import DEFAULT_CONFIG_PATH - from pyrit.setup.configuration_loader import ConfigurationLoader + parser = _build_base_parser(add_help=True) + parsed, unknown = parser.parse_known_args(args) + parsed._unknown_args = unknown + parsed._raw_args = list(args) if args is not None else list(sys.argv[1:]) + return parsed - paths: list[Path] = [] - if DEFAULT_CONFIG_PATH.exists(): - paths.append(DEFAULT_CONFIG_PATH) - if config_file is not None and config_file.exists(): - paths.append(config_file) - name: Optional[str] = None - for path in paths: - try: - loaded = ConfigurationLoader.from_yaml_file(path) - except Exception: - continue - if loaded.scenario_config is not None: - name = loaded.scenario_config.name - return name +async def _resolve_server_url_async(*, parsed_args: Namespace) -> str | None: + """ + Determine the server URL and ensure it is reachable. + + Resolution order: + 1. ``--server-url`` CLI flag + 2. ``server.url`` from config file + 3. Default ``http://localhost:8000`` + If ``--start-server`` is set and the server is not healthy, launches + a local ``pyrit_backend`` subprocess. -def _resolve_scenario_class(scenario_name: Optional[str]) -> Optional[type[Scenario]]: + Returns: + str | None: The server base URL, or ``None`` if unreachable. """ - Look up a built-in scenario class by name. Returns None if missing or unknown. + from pyrit.cli._config_reader import DEFAULT_SERVER_URL, read_server_url + from pyrit.cli._server_launcher import ServerLauncher - v1 limitation: user-defined scenarios from ``--initialization-scripts`` - are not augmented at parse time. + base_url = parsed_args.server_url + if base_url is None: + base_url = read_server_url(config_file=parsed_args.config_file) or DEFAULT_SERVER_URL - Args: - scenario_name (Optional[str]): Positional scenario name from pass 1. + # Probe existing server + if await ServerLauncher.probe_health_async(base_url=base_url): + return base_url + + # Auto-start if requested + if parsed_args.start_server: + launcher = ServerLauncher() + try: + return await launcher.start_async(config_file=parsed_args.config_file) + except RuntimeError as exc: + print(f"Error: {exc}") + return None + + return None + + +def _is_command_specified(*, parsed_args: Namespace) -> bool: + """ + Return True if the user supplied any actionable command flag (besides + ``--start-server`` / ``--stop-server``). Returns: - Optional[type[Scenario]]: The scenario class, or None. + bool: ``True`` if at least one actionable command flag was provided. """ - if not scenario_name: - return None - from pyrit.registry import ScenarioRegistry + return bool( + parsed_args.list_scenarios + or parsed_args.list_initializers + or parsed_args.list_targets + or parsed_args.add_initializer + or parsed_args.scenario_name + ) - registry = ScenarioRegistry.get_registry_singleton() - try: - return registry.get_class(scenario_name) - except KeyError: - return None +def _resolve_configured_server_url(*, parsed_args: Namespace) -> str: + """ + Resolve the effective server URL (without probing). -def _add_scenario_params(*, parser: ArgumentParser, declared: list[Parameter]) -> None: + Returns: + str: The configured server URL, falling back to the built-in default. """ - Add scenario-declared parameters to ``parser`` as ``--kebab-case`` flags. + from pyrit.cli._config_reader import DEFAULT_SERVER_URL, read_server_url - Each flag uses ``dest=scenario__``, ``default=argparse.SUPPRESS``, - and a coercion ``type=`` from ``pyrit.common.parameter``. + return parsed_args.server_url or read_server_url(config_file=parsed_args.config_file) or DEFAULT_SERVER_URL - Args: - parser (ArgumentParser): Parser to extend. - declared (list[Parameter]): Scenario's declared parameters. - Raises: - ValueError: If a scenario-derived flag collides with a built-in flag or - with another scenario param that normalizes to the same kebab form. +async def _handle_stop_server_async(*, parsed_args: Namespace) -> int: """ - # Seed from existing flags so we catch built-in collisions; grow as we add. - seen_flags: set[str] = set(parser._option_string_actions.keys()) - for param in declared: - flag = f"--{param.name.replace('_', '-')}" - if flag in seen_flags: - raise ValueError( - f"Scenario parameter '{param.name}' collides with an existing flag {flag!r}. " - f"This is either a built-in CLI flag or another scenario parameter that " - f"normalizes to the same kebab-case form. Rename the parameter." - ) - kwargs: dict[str, Any] = { - "dest": f"{_SCENARIO_DEST_PREFIX}{param.name}", - "default": argparse.SUPPRESS, - "help": param.description, - } - type_callable = _argparse_type_for(param=param) - if type_callable is not None: - kwargs["type"] = type_callable - if _is_list_param(param.param_type): - kwargs["nargs"] = "+" - if param.choices is not None: - kwargs["choices"] = list(param.choices) - parser.add_argument(flag, **kwargs) - seen_flags.add(flag) + Handle ``--stop-server``: probe, then terminate the listening process. + Returns: + int: Exit code (always ``0``). + """ + from urllib.parse import urlparse + + from pyrit.cli._server_launcher import ServerLauncher, stop_server_on_port -def _argparse_type_for(*, param: Parameter) -> Optional[Any]: + base_url = _resolve_configured_server_url(parsed_args=parsed_args) + if not await ServerLauncher.probe_health_async(base_url=base_url): + print(f"No server running at {base_url}.") + return 0 + + port = urlparse(base_url).port or 8000 + if stop_server_on_port(port=port): + print(f"Server on port {port} stopped.") + else: + print(f"Server at {base_url} is running but could not identify the process.") + print(f"Find and kill it manually: look for a process listening on port {port}.") + return 0 + + +async def _handle_list_commands_async(*, client: Any, parsed_args: Namespace) -> int | None: """ - Map a ``Parameter`` to an argparse ``type=`` callable, or ``None`` for str/raw. + Dispatch ``--list-*`` flags. - For list params, ``None`` is correct because ``nargs='+'`` collects strings; - list element validation happens via ``coerce_list`` at scenario-set time. + Returns: + int | None: Exit code if a flag was handled, else ``None``. + """ + from pyrit.cli import _output + + if parsed_args.list_scenarios: + resp = await client.list_scenarios_async() + _output.print_scenario_list(items=resp.get("items", [])) + return 0 + if parsed_args.list_initializers: + resp = await client.list_initializers_async() + _output.print_initializer_list(items=resp.get("items", [])) + return 0 + if parsed_args.list_targets: + resp = await client.list_targets_async() + _output.print_target_list(items=resp.get("items", [])) + return 0 + return None - Args: - param (Parameter): The scenario-declared parameter. + +async def _handle_add_initializer_async(*, client: Any, parsed_args: Namespace) -> int: + """ + Handle ``--add-initializer``: upload one or more scripts to the server. Returns: - Optional[Any]: Coercion callable, or ``None`` if no coercion is needed. + int: Exit code (``0`` on success, ``1`` on failure). """ - from pyrit.common.parameter import coerce_value + from pyrit.cli.api_client import ServerNotAvailableError - param_type = param.param_type - if param_type is None or param_type is str or _is_list_param(param_type): - return None - return lambda raw: coerce_value(param=param, raw_value=raw) + for script_path_str in parsed_args.add_initializer: + script_path = Path(script_path_str).resolve() + if not script_path.exists(): + print(f"Error: File not found: {script_path}") + return 1 + try: + script_content = script_path.read_text() + await client.register_initializer_async( + name=script_path.stem, + script_content=script_content, + ) + print(f"Registered initializer '{script_path.stem}' from {script_path}") + except ServerNotAvailableError as exc: + print(f"Error: {exc}") + return 1 + return 0 -def _is_list_param(param_type: Any) -> bool: - """Return True when ``param_type`` is a parameterized list generic (e.g. ``list[str]``).""" - return get_origin(param_type) is list +def _reparse_with_scenario_params( + *, parsed_args: Namespace, supported_params: list[dict[str, Any]] +) -> Namespace | None: + """ + Re-parse the original args with scenario-declared flags added to the base parser. + The original argument list is read from ``parsed_args._raw_args`` (populated + by :func:`parse_args`). If no scenario-declared parameters are supplied but + pass 1 left unknown args behind, surface the error now via strict re-parse. -def _extract_scenario_args(*, parsed: Namespace) -> dict[str, Any]: + Returns: + Namespace | None: The re-parsed Namespace, or ``None`` on argparse ``SystemExit``. """ - Pull scenario-declared parameter values out of a parsed Namespace. + raw_args: list[str] = getattr(parsed_args, "_raw_args", sys.argv[1:] if len(sys.argv) > 1 else []) + + if not supported_params: + unknown = getattr(parsed_args, "_unknown_args", None) + if not unknown: + return parsed_args + # Re-parse strictly so argparse prints the standard "unrecognized arguments" error + strict_parser = _build_base_parser(add_help=True) + try: + return strict_parser.parse_args(raw_args) + except SystemExit: + return None + + pass2_parser = _build_base_parser(add_help=True) + _add_scenario_params_from_api(parser=pass2_parser, params=supported_params) + try: + return pass2_parser.parse_args(raw_args) + except SystemExit: + return None - Args: - parsed (Namespace): Result of ``ArgumentParser.parse_args``. + +def _build_run_request(*, parsed_args: Namespace, scenario_name: str) -> dict[str, Any]: + """ + Build the ``RunScenarioRequest`` dict from parsed CLI args. Returns: - dict[str, Any]: Map of original parameter name to coerced value. - Empty when the scenario declares no parameters or the user - supplied none. + dict[str, Any]: The request payload to send to ``POST /api/scenarios/runs``. """ - return { - key.removeprefix(_SCENARIO_DEST_PREFIX): value - for key, value in vars(parsed).items() - if key.startswith(_SCENARIO_DEST_PREFIX) + from pyrit.cli._cli_args import parse_memory_labels + + request: dict[str, Any] = { + "scenario_name": scenario_name, + "target_name": parsed_args.target or "", } + if parsed_args.initializers: + init_names: list[str] = [] + init_args: dict[str, dict[str, Any]] = {} + for entry in parsed_args.initializers: + if isinstance(entry, str): + init_names.append(entry) + elif isinstance(entry, dict): + name = entry["name"] + init_names.append(name) + if entry.get("args"): + init_args[name] = entry["args"] + request["initializers"] = init_names + if init_args: + request["initializer_args"] = init_args + + if parsed_args.scenario_strategies: + request["strategies"] = parsed_args.scenario_strategies + if parsed_args.max_concurrency is not None: + request["max_concurrency"] = parsed_args.max_concurrency + if parsed_args.max_retries is not None: + request["max_retries"] = parsed_args.max_retries + if parsed_args.dataset_names: + request["dataset_names"] = parsed_args.dataset_names + if parsed_args.max_dataset_size is not None: + request["max_dataset_size"] = parsed_args.max_dataset_size + if parsed_args.memory_labels: + request["labels"] = parse_memory_labels(json_string=parsed_args.memory_labels) + + scenario_params = _extract_scenario_args(parsed=parsed_args) + if scenario_params: + request["scenario_params"] = scenario_params + + return request + + +async def _poll_until_terminal_async( + *, + client: Any, + scenario_result_id: str, + total_strategies: int, +) -> dict[str, Any]: + """ + Poll the server until the run reaches a terminal status. -def main(args: Optional[list[str]] = None) -> int: + Returns: + dict[str, Any]: The final run dict. """ - Start the PyRIT scanner CLI. + from pyrit.cli import _output + + while True: + run = await client.get_scenario_run_async(scenario_result_id=scenario_result_id) + status = run.get("status", "UNKNOWN") + _output.print_scenario_run_progress(run=run, total_strategies=total_strategies) + if status in _TERMINAL_STATUSES: + return run + await asyncio.sleep(0.5) + + +async def _run_scenario_async( + *, + client: Any, + parsed_args: Namespace, + scenario_meta: dict[str, Any], +) -> int: + """ + Start a scenario run, poll for completion, and print results. Returns: - int: Exit code (0 for success, 1 for error). + int: Exit code (``0`` if the run completed successfully, ``1`` otherwise). """ - try: - parsed_args = parse_args(args) - except SystemExit as e: - return e.code if isinstance(e.code, int) else 1 + from pyrit.cli import _output + + scenario_name = parsed_args.scenario_name + request = _build_run_request(parsed_args=parsed_args, scenario_name=scenario_name) - print("Starting PyRIT...") + total_strategies = len(request.get("strategies") or scenario_meta.get("all_strategies") or []) + print(f"\nRunning scenario: {scenario_name}") sys.stdout.flush() - # Defer the heavy import until after arg parsing so --help is instant. - from pyrit.cli import frontend_core + try: + run = await client.start_scenario_run_async(request=request) + except Exception as exc: + print(f"Error starting scenario: {exc}") + return 1 - # Handle list commands (don't need full context) - if parsed_args.list_scenarios: - # Simple context just for listing - initialization_scripts = None - if parsed_args.initialization_scripts: - try: - initialization_scripts = frontend_core.resolve_initialization_scripts( - script_paths=parsed_args.initialization_scripts - ) - except FileNotFoundError as e: - print(f"Error: {e}") - return 1 - - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initialization_scripts=initialization_scripts, - log_level=parsed_args.log_level, + scenario_result_id = run.get("scenario_result_id", "") + + try: + run = await _poll_until_terminal_async( + client=client, + scenario_result_id=scenario_result_id, + total_strategies=total_strategies, ) + except KeyboardInterrupt: + print("\n\nCancelling scenario run...") + try: + await client.cancel_scenario_run_async(scenario_result_id=scenario_result_id) + print("Scenario run cancelled.") + except Exception: + print("Warning: could not cancel scenario run on server.") + return 1 - return asyncio.run(frontend_core.print_scenarios_list_async(context=context)) + if run.get("status") == "COMPLETED": + try: + detail = await client.get_scenario_run_results_async(scenario_result_id=scenario_result_id) + await _output.print_scenario_result_async(result_dict=detail) + except Exception: + _output.print_scenario_run_summary(run=run) + else: + _output.print_scenario_run_summary(run=run) - if parsed_args.list_initializers: - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - log_level=parsed_args.log_level, - ) - return asyncio.run(frontend_core.print_initializers_list_async(context=context)) + return 0 if run.get("status") == "COMPLETED" else 1 - if parsed_args.list_targets: - # Need initializers or initialization scripts to populate the target registry - initialization_scripts = None - if parsed_args.initialization_scripts: - try: - initialization_scripts = frontend_core.resolve_initialization_scripts( - script_paths=parsed_args.initialization_scripts - ) - except FileNotFoundError as e: - print(f"Error: {e}") - return 1 - - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initialization_scripts=initialization_scripts, - initializer_names=parsed_args.initializers, - log_level=parsed_args.log_level, - ) - return asyncio.run(frontend_core.print_targets_list_async(context=context)) - # Run scenario (verify scenario name from CLI positional or config block) - try: - # Collect initialization scripts - initialization_scripts = None - if parsed_args.initialization_scripts: - initialization_scripts = frontend_core.resolve_initialization_scripts( - script_paths=parsed_args.initialization_scripts - ) +async def _dispatch_with_client_async(*, client: Any, parsed_args: Namespace) -> int: + """ + Dispatch list/add-initializer/scenario-run commands once a client is open. - # Create context with initializers - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initialization_scripts=initialization_scripts, - initializer_names=parsed_args.initializers, - log_level=parsed_args.log_level, - ) + Returns: + int: Exit code from the dispatched command. + """ + list_result = await _handle_list_commands_async(client=client, parsed_args=parsed_args) + if list_result is not None: + return list_result - # Resolve the effective scenario name: CLI positional wins, config falls through. - config_scenario = context._scenario_config - effective_scenario_name = parsed_args.scenario_name or (config_scenario.name if config_scenario else None) - if not effective_scenario_name: - print("Error: No scenario specified. Provide one positionally or via the config file's `scenario:` block.") - return 1 + if parsed_args.add_initializer: + return await _handle_add_initializer_async(client=client, parsed_args=parsed_args) - # Parse memory labels if provided - memory_labels = None - if parsed_args.memory_labels: - memory_labels = frontend_core.parse_memory_labels(json_string=parsed_args.memory_labels) + scenario_name = parsed_args.scenario_name + if not scenario_name: + print("Error: No scenario specified. Provide one positionally or use --list-scenarios.") + return 1 - # Merge scenario args (CLI wins per-key over config args). - merged_scenario_args = merge_config_scenario_args( - config_scenario=config_scenario, - effective_scenario_name=effective_scenario_name, - cli_args=_extract_scenario_args(parsed=parsed_args), - ) + scenario_meta = await client.get_scenario_async(scenario_name=scenario_name) + if scenario_meta is None: + print(f"Error: Scenario '{scenario_name}' not found on server.") + resp = await client.list_scenarios_async() + names = [s.get("scenario_name", "") for s in resp.get("items", [])] + if names: + print(f"Available scenarios: {', '.join(names)}") + return 1 - # Run scenario - asyncio.run( - frontend_core.run_scenario_async( - scenario_name=effective_scenario_name, - context=context, - target_name=parsed_args.target, - scenario_strategies=parsed_args.scenario_strategies, - max_concurrency=parsed_args.max_concurrency, - max_retries=parsed_args.max_retries, - memory_labels=memory_labels, - dataset_names=parsed_args.dataset_names, - max_dataset_size=parsed_args.max_dataset_size, - scenario_args=merged_scenario_args, - ) + reparsed = _reparse_with_scenario_params( + parsed_args=parsed_args, + supported_params=scenario_meta.get("supported_parameters") or [], + ) + if reparsed is None: + return 1 + parsed_args = reparsed + + return await _run_scenario_async(client=client, parsed_args=parsed_args, scenario_meta=scenario_meta) + + +async def _run_async(*, parsed_args: Namespace) -> int: + """ + Core async logic for pyrit_scan. + + Returns: + int: Exit code (0 for success, 1 for error). + """ + from pyrit.cli import _output + from pyrit.cli.api_client import PyRITApiClient, ServerNotAvailableError + + if parsed_args.stop_server: + return await _handle_stop_server_async(parsed_args=parsed_args) + + if not (parsed_args.start_server or _is_command_specified(parsed_args=parsed_args)): + _build_base_parser().print_help() + return 0 + + base_url_result = await _resolve_server_url_async(parsed_args=parsed_args) + if base_url_result is None: + attempted = _resolve_configured_server_url(parsed_args=parsed_args) + _output.print_error_with_hint( + message=f"Server not available at {attempted}", + hint="Use '--start-server' to launch a local backend, or pass '--server-url '.", ) + return 1 + + # --start-server with no other command: just confirm and exit + if not _is_command_specified(parsed_args=parsed_args): + print(f"Server is running at {base_url_result}") return 0 - except Exception as e: - print(f"\nError: {e}") + try: + async with PyRITApiClient(base_url=base_url_result) as client: + return await _dispatch_with_client_async(client=client, parsed_args=parsed_args) + except ServerNotAvailableError as exc: + _output.print_error_with_hint( + message=str(exc), + hint="Use '--start-server' to launch a local backend, or pass '--server-url '.", + ) + return 1 + except Exception as exc: + print(f"\nError: {exc}") return 1 +def main(args: Optional[list[str]] = None) -> int: + """ + Start the PyRIT scanner CLI. + + Returns: + int: Exit code (0 for success, 1 for error). + """ + try: + parsed_args = parse_args(args) + except SystemExit as e: + return e.code if isinstance(e.code, int) else 1 + + # If there are leftover unknown flags AND no scenario was specified, + # there's no chance for pass 2 to recognize them - fail loudly now. + unknown = getattr(parsed_args, "_unknown_args", []) + if unknown and not parsed_args.scenario_name: + strict_parser = _build_base_parser(add_help=True) + try: + strict_parser.parse_args(parsed_args._raw_args) + except SystemExit as e: + return e.code if isinstance(e.code, int) else 1 + + logging.basicConfig(level=parsed_args.log_level) + + # Surface a one-line deprecation when the layered config contains blocks + # the thin CLI no longer reads (e.g. `scenario:`). The server still honors them. + from pyrit.cli._config_reader import warn_on_client_ignored_blocks + + warn_on_client_ignored_blocks(config_file=parsed_args.config_file) + + return asyncio.run(_run_async(parsed_args=parsed_args)) + + if __name__ == "__main__": sys.exit(main()) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 122122daac..c8c536b69c 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -4,256 +4,258 @@ """ PyRIT Shell - Interactive REPL for PyRIT. -This module provides an interactive shell where PyRIT modules are loaded once -at startup, making subsequent commands instant. +This module provides an interactive shell that talks to the PyRIT backend +server over HTTP. No heavy pyrit imports — all operations go through REST. """ from __future__ import annotations import asyncio import cmd +import contextlib import logging import sys import threading from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +from pyrit.cli import _banner as banner if TYPE_CHECKING: - import types + from collections.abc import Coroutine - from pyrit.cli import frontend_core - from pyrit.models.scenario_result import ScenarioResult - -from pyrit.cli import _banner as banner -from pyrit.cli._cli_args import merge_config_scenario_args -from pyrit.registry import ScenarioRegistry +_T = TypeVar("_T") class PyRITShell(cmd.Cmd): """ - Interactive shell for PyRIT. + Interactive shell for PyRIT (thin REST client). Commands: list-scenarios - List all available scenarios list-initializers - List all available initializers - list-targets [opts] - List all available targets from the registry + list-targets - List all available targets run [opts] - Run a scenario with optional parameters - scenario-history - List all previous scenario runs - print-scenario [N] - Print detailed results for scenario run(s) + scenario-history [N] - List the last N (default 10) scenario runs + print-scenario [id] - Print detailed results for a scenario run + start-server - Start a local backend server + stop-server - Stop the owned backend server help [command] - Show help for a command clear - Clear the screen exit (quit, q) - Exit the shell - - Shell Startup Options: - --config-file Path to config file (default: ~/.pyrit/.pyrit_conf) - --log-level Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - default for all runs - --no-animation Disable the animated startup banner - - Run Command Options: - --target Target name from the TargetRegistry (required) - --initializers ... Built-in initializers (supports name:key=val1,val2 syntax) - --initialization-scripts <...> Custom Python scripts to run before the scenario - --strategies, -s ... Strategy names to use - --max-concurrency Maximum concurrent operations - --max-retries Maximum retry attempts - --memory-labels JSON string of labels - --log-level Override default log level for this run """ prompt = "pyrit> " + _TERMINAL_STATUSES = {"COMPLETED", "FAILED", "CANCELLED"} + def __init__( self, *, no_animation: bool = False, - config_file: Optional[Path] = None, - database: Optional[str] = None, - initialization_scripts: Optional[list[Path]] = None, - initializer_names: Optional[list[Any]] = None, - env_files: Optional[list[Path]] = None, - log_level: Optional[int] = None, + server_url: str | None = None, + config_file: Path | None = None, + start_server: bool = False, ) -> None: """ Initialize the PyRIT shell. - The heavy ``frontend_core`` import, ``FrontendCore`` construction, and - ``initialize_async`` call all happen on a background thread so the - shell prompt appears immediately. - Args: - no_animation (bool): If True, skip the animated startup banner. - config_file (Optional[Path]): Path to a YAML configuration file. - database (Optional[str]): Database type (InMemory, SQLite, or AzureSQL). - initialization_scripts (Optional[list[Path]]): Initialization script paths. - initializer_names (Optional[list[Any]]): Initializer entries (names or dicts). - env_files (Optional[list[Path]]): Environment file paths to load in order. - log_level (Optional[int]): Logging level constant (e.g., ``logging.WARNING``). + no_animation: If True, skip the animated startup banner. + server_url: Optional explicit server URL. + config_file: Optional config file path. + start_server: If True, auto-start a local backend. """ super().__init__() self._no_animation = no_animation - self._context_kwargs: dict[str, Any] = { - k: v - for k, v in { - "config_file": config_file, - "database": database, - "initialization_scripts": initialization_scripts, - "initializer_names": initializer_names, - "env_files": env_files, - "log_level": log_level, - }.items() - if v is not None - } + self._server_url = server_url + self._config_file = config_file + self._start_server = start_server + self._api_client: Any = None # PyRITApiClient (lazy) + self._base_url: str | None = None + self._launcher: Any = None # ServerLauncher (lazy) + + # Persistent event loop running on a background thread. All async + # calls (health probe, REST methods, scenario polling) are scheduled + # here so the shared httpx.AsyncClient stays in a single loop. + self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + self._loop_thread = threading.Thread(target=self._loop.run_forever, name="pyrit-shell-loop", daemon=True) + self._loop_thread.start() + + def _run_async(self, coro: Coroutine[Any, Any, _T]) -> _T: + """ + Run a coroutine on the shell's persistent loop and return its result. - # Track scenario execution history: list of (command_string, ScenarioResult) tuples - self._scenario_history: list[tuple[str, ScenarioResult]] = [] + Args: + coro: Coroutine to schedule on the background loop. - # Set by the background thread after importing frontend_core. - self._fc: types.ModuleType | None = None - self.context: frontend_core.FrontendCore | None = None - self.default_log_level: int | None = None + Returns: + The coroutine's result. + """ + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + def _shutdown_loop(self) -> None: + """Stop the background event loop and join the thread.""" + if not self._loop.is_closed(): + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop_thread.join(timeout=5) + with contextlib.suppress(Exception): + self._loop.close() + + def _resolve_base_url(self) -> str: + """ + Determine the server base URL. - # Initialize PyRIT in background thread for faster startup. - self._init_thread = threading.Thread(target=self._background_init, daemon=True) - self._init_complete = threading.Event() - self._init_error: Optional[BaseException] = None - self._init_thread.start() + Returns: + str: The configured base URL, falling back to the built-in default. + """ + from pyrit.cli._config_reader import DEFAULT_SERVER_URL, read_server_url - def _background_init(self) -> None: - """Import heavy modules and initialize PyRIT in the background.""" - try: - from pyrit.cli import frontend_core as fc - - self._fc = fc - self.context = fc.FrontendCore(**self._context_kwargs) - self.default_log_level = self.context._log_level - asyncio.run(self.context.initialize_async()) - except BaseException as exc: - self._init_error = exc - finally: - self._init_complete.set() - - def _raise_init_error(self) -> None: - """Re-raise background initialization failures on the calling thread.""" - if self._init_error is not None: - raise self._init_error - - def _ensure_initialized(self) -> None: + if self._server_url: + return self._server_url + return read_server_url(config_file=self._config_file) or DEFAULT_SERVER_URL + + def _ensure_client(self) -> bool: """ - Wait for initialization to complete if not already done. + Ensure the API client is connected. - Raises: - RuntimeError: If frontend core initialization failed or is not complete. + Returns: + bool: ``True`` if the client is ready, ``False`` otherwise. """ - if not self._init_complete.is_set(): - print("Waiting for PyRIT initialization to complete...") - sys.stdout.flush() - self._init_complete.wait() - self._raise_init_error() - if self._fc is None or self.context is None: - raise RuntimeError("Frontend core not initialized") + if self._api_client is not None: + return True + + base_url = self._base_url or self._resolve_base_url() + + # Check health + from pyrit.cli._server_launcher import ServerLauncher + + healthy = self._run_async(ServerLauncher.probe_health_async(base_url=base_url)) + + if not healthy and self._start_server: + self._launcher = ServerLauncher() + try: + base_url = self._run_async(self._launcher.start_async(config_file=self._config_file)) + healthy = True + except RuntimeError as exc: + print(f"Error starting server: {exc}") + return False + + if not healthy: + from pyrit.cli._output import print_error_with_hint + + print_error_with_hint( + message=f"Server not available at {base_url}", + hint="Use 'start-server' to launch a local backend, or restart with --server-url.", + ) + return False + + from pyrit.cli.api_client import PyRITApiClient + + self._base_url = base_url + self._api_client = PyRITApiClient(base_url=base_url) + self._run_async(self._api_client.__aenter__()) + self._start_server = False # only auto-start once + return True def cmdloop(self, intro: Optional[str] = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: - # Play animation immediately while background init continues. - # Suppress logging during the animation so log lines don't corrupt - # the ANSI cursor-positioned frames. prev_disable = logging.root.manager.disable logging.disable(logging.CRITICAL) try: intro = banner.play_animation(no_animation=self._no_animation) finally: logging.disable(prev_disable) - - # If init already failed while the animation played, surface it now. - if self._init_complete.is_set(): - self._raise_init_error() - elif self._init_complete.is_set(): - self._raise_init_error() self.intro = intro super().cmdloop(intro=self.intro) - def do_list_scenarios(self, arg: str) -> None: - """ - List all available scenarios. + # ------------------------------------------------------------------ + # List commands + # ------------------------------------------------------------------ - Raises: - RuntimeError: If initialization has not completed. - """ + def do_list_scenarios(self, arg: str) -> None: + """List all available scenarios.""" if arg.strip(): print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}") return - self._ensure_initialized() - if self._fc is None or self.context is None: - raise RuntimeError("Frontend core not initialized") + if not self._ensure_client(): + return + from pyrit.cli import _output + try: - asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) + resp = self._run_async(self._api_client.list_scenarios_async()) + _output.print_scenario_list(items=resp.get("items", [])) except Exception as e: print(f"Error listing scenarios: {e}") def do_list_initializers(self, arg: str) -> None: - """ - List all available initializers. - - Raises: - RuntimeError: If initialization has not completed. - """ + """List all available initializers.""" if arg.strip(): print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}") return - self._ensure_initialized() - if self._fc is None or self.context is None: - raise RuntimeError("Frontend core not initialized") + if not self._ensure_client(): + return + from pyrit.cli import _output + try: - asyncio.run(self._fc.print_initializers_list_async(context=self.context)) + resp = self._run_async(self._api_client.list_initializers_async()) + _output.print_initializer_list(items=resp.get("items", [])) except Exception as e: print(f"Error listing initializers: {e}") def do_list_targets(self, arg: str) -> None: + """List all available targets.""" + if arg.strip(): + print(f"Error: list-targets does not accept arguments, got: {arg.strip()}") + return + if not self._ensure_client(): + return + from pyrit.cli import _output + + try: + resp = self._run_async(self._api_client.list_targets_async()) + _output.print_target_list(items=resp.get("items", [])) + except Exception as e: + print(f"Error listing targets: {e}") + + def do_add_initializer(self, arg: str) -> None: """ - List all available targets from the TargetRegistry. + Register an initializer from a Python script file. Usage: - list-targets - list-targets --initializers [ ...] - list-targets --initialization-scripts [ ...] - - Options: - --initializers ... Built-in initializers to run first - --initialization-scripts <...> Custom Python scripts to run first + add-initializer [ ...] + """ + if not self._ensure_client(): + return + if not arg.strip(): + print("Usage: add-initializer [ ...]") + return - Examples: - list-targets --initializers target - list-targets --initializers target:tags=default,scorer + from pyrit.cli.api_client import ServerNotAvailableError - Raises: - RuntimeError: If initialization has not completed. - """ - self._ensure_initialized() - if self._fc is None or self.context is None: - raise RuntimeError("Frontend core not initialized") - try: - list_targets_context = self.context - if arg.strip(): - args = self._fc.parse_list_targets_arguments(args_string=arg) - - resolved_scripts = None - if args["initialization_scripts"]: - resolved_scripts = self._fc.resolve_initialization_scripts( - script_paths=args["initialization_scripts"] - ) - list_targets_context = self.context.with_overrides( - initialization_scripts=resolved_scripts, - initializer_names=args["initializers"], + for script_path_str in arg.split(): + script_path = Path(script_path_str).resolve() + if not script_path.exists(): + print(f"Error: File not found: {script_path}") + return + try: + content = script_path.read_text() + self._run_async( + self._api_client.register_initializer_async(name=script_path.stem, script_content=content) ) + print(f"Registered initializer '{script_path.stem}' from {script_path}") + except ServerNotAvailableError as exc: + print(f"Error: {exc}") + return + except Exception as exc: + print(f"Error registering initializer: {exc}") + return - asyncio.run(self._fc.print_targets_list_async(context=list_targets_context)) - except ValueError as e: - print(f"Error: {e}") - except FileNotFoundError as e: - print(f"Error: {e}") - except Exception as e: - print(f"Error listing targets: {e}") + # ------------------------------------------------------------------ + # Run command + # ------------------------------------------------------------------ def do_run(self, line: str) -> None: """ @@ -263,299 +265,292 @@ def do_run(self, line: str) -> None: run [options] Options: - --target Target name from the TargetRegistry (required) - --initializers ... Built-in initializers (supports name:key=val1,val2 syntax) - --initialization-scripts <...> Custom Python scripts to run before the scenario - --strategies, -s ... Strategy names to use + --target Target name (required) + --initializers ... Initializer names (supports name:key=val syntax) + --strategies, -s ... Strategy names --max-concurrency Maximum concurrent operations --max-retries Maximum retry attempts - --memory-labels JSON string of labels (e.g., '{"key":"value"}') - --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - - Examples: - run garak.encoding --target my_target --initializers target \ - load_default_datasets - run garak.encoding --target my_target --initializers target \ - load_default_datasets --strategies base64 rot13 - run foundry.red_team_agent --target my_target --initializers target:tags=default,scorer \ - dataset:mode=strict --strategies base64 - run foundry.red_team_agent --target my_target --initializers target \ - load_default_datasets --max-concurrency 10 --max-retries 3 - run garak.encoding --target my_target --initializers target \ - load_default_datasets \ - --memory-labels '{"run_id":"test123","env":"dev"}' - run foundry.red_team_agent --target my_target --initializers target \ - load_default_datasets -s jailbreak crescendo - run garak.encoding --target my_target --initializers target \ - load_default_datasets --log-level DEBUG - run foundry.red_team_agent --target my_target --initialization-scripts ./my_custom_init.py -s all - - Note: - --target is required for every run. - Initializers can be specified per-run or configured in .pyrit_conf. - Database and env-files are configured via the config file. - - Raises: - RuntimeError: If initialization has not completed. + --memory-labels JSON string of labels + --dataset-names ... Override default dataset names + --max-dataset-size Maximum items per dataset + -- Scenario-declared parameters (see list-scenarios) + + Notes: + Database, env files, and initialization scripts are configured on + the backend via its config file. Use `add-initializer` to register + custom initializers on the running server. """ - self._ensure_initialized() - if self._fc is None or self.context is None: - raise RuntimeError("Frontend core not initialized") + if not self._ensure_client(): + return if not line.strip(): print("Error: Specify a scenario name") - print("\nUsage: run [options]") - print("\nNote: --target is required. Initializers can be specified per-run or in .pyrit_conf.") - print("\nOptions:") - print(f" --target {self._fc.ARG_HELP['target']}") - print(f" --initializers ... {self._fc.ARG_HELP['initializers']}") - print( - f" --initialization-scripts <...> {self._fc.ARG_HELP['initialization_scripts']}" - " (alternative to --initializers)" - ) - print(f" --strategies, -s ... {self._fc.ARG_HELP['scenario_strategies']}") - print(f" --max-concurrency {self._fc.ARG_HELP['max_concurrency']}") - print(f" --max-retries {self._fc.ARG_HELP['max_retries']}") - print(f" --memory-labels {self._fc.ARG_HELP['memory_labels']}") - print( - " --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" - ) - print("\nExample:") - print(" run foundry.red_team_agent --target my_target --initializers target load_default_datasets") - print("\nType 'help run' for more details and examples") + print("Usage: run --target [options]") return - # Look up declared params for the scenario so the parser can recognize - # scenario-specific flags. Built-in scenarios only in v1. - declared_params = None - scenario_name_token = line.split(maxsplit=1)[0] if line.strip() else "" - if scenario_name_token: - try: - scenario_class = ScenarioRegistry.get_registry_singleton().get_class(scenario_name_token) - except KeyError: - scenario_class = None - if scenario_class is not None: - declared_params = scenario_class.supported_parameters() + from pyrit.cli._cli_args import build_parameters_from_api, extract_scenario_args, parse_run_arguments + from pyrit.cli._output import ( + print_scenario_result_async, + print_scenario_run_progress, + print_scenario_run_summary, + ) - # Parse arguments using shared parser + # Fetch scenario metadata so the parser recognizes scenario-declared flags. + scenario_name_token = line.split(maxsplit=1)[0] try: - args = self._fc.parse_run_arguments(args_string=line, declared_params=declared_params) + scenario_meta = self._run_async(self._api_client.get_scenario_async(scenario_name=scenario_name_token)) + except Exception as exc: + print(f"Error fetching scenario metadata: {exc}") + return + if scenario_meta is None: + print(f"Error: Scenario '{scenario_name_token}' not found on server.") + return + declared_params = build_parameters_from_api(api_params=scenario_meta.get("supported_parameters") or []) + + # Parse arguments + try: + args = parse_run_arguments(args_string=line, declared_params=declared_params) except ValueError as e: print(f"Error: {e}") - # Hint when an unknown-flag error likely stems from a user-defined scenario - # introduced via --initialization-scripts (not yet supported for shell augmentation). - if declared_params is None and "--initialization-scripts" in line: - print( - "Note: scenario-specific flags from --initialization-scripts scenarios " - "are not yet supported in pyrit_shell. Built-in scenarios only in this release." - ) return - # Resolve initialization scripts if provided - resolved_scripts = None - if args["initialization_scripts"]: - try: - resolved_scripts = self._fc.resolve_initialization_scripts(script_paths=args["initialization_scripts"]) - except FileNotFoundError as e: - print(f"Error: {e}") - return + scenario_name = args["scenario_name"] - # Create a context for this run with per-command overrides, - # inheriting config_file, database, and env_files from startup. - run_context = self.context.with_overrides( - initializer_names=args["initializers"], - initialization_scripts=resolved_scripts, - log_level=args["log_level"], - ) + # Build request + request: dict[str, Any] = { + "scenario_name": scenario_name, + "target_name": args.get("target") or "", + } + + # Map initializers + initializers = args.get("initializers") + if initializers: + init_names: list[str] = [] + init_args: dict[str, dict[str, Any]] = {} + for entry in initializers: + if isinstance(entry, str): + init_names.append(entry) + elif isinstance(entry, dict): + name = entry["name"] + init_names.append(name) + if entry.get("args"): + init_args[name] = entry["args"] + request["initializers"] = init_names + if init_args: + request["initializer_args"] = init_args + + if args.get("scenario_strategies"): + request["strategies"] = args["scenario_strategies"] + if args.get("max_concurrency") is not None: + request["max_concurrency"] = args["max_concurrency"] + if args.get("max_retries") is not None: + request["max_retries"] = args["max_retries"] + if args.get("dataset_names"): + request["dataset_names"] = args["dataset_names"] + if args.get("max_dataset_size") is not None: + request["max_dataset_size"] = args["max_dataset_size"] + if args.get("memory_labels"): + request["labels"] = args["memory_labels"] + + scenario_params = extract_scenario_args(parsed=args) + if scenario_params: + request["scenario_params"] = scenario_params + + # Start run + total_strategies = len(request.get("strategies") or []) + print(f"\nRunning scenario: {scenario_name}") + sys.stdout.flush() try: - # Merge config-file scenario args (CLI wins). Shell v1 requires the - # scenario name to be provided positionally; config-only scenarios - # are not supported in the shell. - merged_scenario_args = merge_config_scenario_args( - config_scenario=self.context._scenario_config, - effective_scenario_name=args["scenario_name"], - cli_args=self._fc.extract_scenario_args(parsed=args), - ) + run = self._run_async(self._api_client.start_scenario_run_async(request=request)) + except Exception as exc: + print(f"Error starting scenario: {exc}") + return - result = asyncio.run( - self._fc.run_scenario_async( - scenario_name=args["scenario_name"], - context=run_context, - target_name=args["target"], - scenario_strategies=args["scenario_strategies"], - max_concurrency=args["max_concurrency"], - max_retries=args["max_retries"], - memory_labels=args["memory_labels"], - dataset_names=args["dataset_names"], - max_dataset_size=args["max_dataset_size"], - scenario_args=merged_scenario_args, - ) - ) - # Store the command and result in history - self._scenario_history.append((line, result)) + scenario_result_id = run.get("scenario_result_id", "") + + # Poll for completion + import time + + try: + while True: + run = self._run_async(self._api_client.get_scenario_run_async(scenario_result_id=scenario_result_id)) + status = run.get("status", "UNKNOWN") + print_scenario_run_progress(run=run, total_strategies=total_strategies) + if status in self._TERMINAL_STATUSES: + break + time.sleep(0.5) except KeyboardInterrupt: - print("\n\nScenario interrupted. Returning to shell.") - except ValueError as e: - print(f"Error: {e}") - except Exception as e: - print(f"Error running scenario: {e}") - import traceback + print("\n\nCancelling scenario run...") + try: + self._run_async(self._api_client.cancel_scenario_run_async(scenario_result_id=scenario_result_id)) + print("Scenario run cancelled.") + except Exception: + print("Warning: could not cancel scenario run.") + print("Returning to shell.") + return + + # Print results + if run.get("status") == "COMPLETED": + try: + detail = self._run_async( + self._api_client.get_scenario_run_results_async(scenario_result_id=scenario_result_id) + ) + self._run_async(print_scenario_result_async(result_dict=detail)) + except Exception: + print_scenario_run_summary(run=run) + else: + print_scenario_run_summary(run=run) - traceback.print_exc() + # ------------------------------------------------------------------ + # History commands + # ------------------------------------------------------------------ def do_scenario_history(self, arg: str) -> None: """ - Display history of scenario runs. + Display history of scenario runs from the server (most recent first). Usage: - scenario-history - - Shows a numbered list of all scenario runs with the commands used. + scenario-history Show the last 10 runs + scenario-history Show the last N runs """ - if arg.strip(): - print(f"Error: scenario-history does not accept arguments, got: {arg.strip()}") - return - if not self._scenario_history: - print("No scenario runs in history.") + arg = arg.strip() + limit = 10 + if arg: + try: + limit = int(arg) + except ValueError: + limit = 0 + if limit < 1: + print(f"Usage: scenario-history [N]. Got non-positive-integer argument: {arg!r}") + return + if not self._ensure_client(): return + from pyrit.cli._output import print_scenario_runs_list - print("\nScenario Run History:") - print("=" * 80) - for idx, (command, _) in enumerate(self._scenario_history, start=1): - print(f"{idx}) {command}") - print("=" * 80) - print(f"\nTotal runs: {len(self._scenario_history)}") - print("\nUse 'print-scenario ' to view detailed results for a specific run.") - print("Use 'print-scenario' to view detailed results for all runs.") + try: + resp = self._run_async(self._api_client.list_scenario_runs_async(limit=limit)) + print_scenario_runs_list(runs=resp.get("items", [])) + except Exception as e: + print(f"Error: {e}") def do_print_scenario(self, arg: str) -> None: """ - Print detailed results for scenario runs. + Print detailed results for a scenario run. Usage: - print-scenario Print all scenario results - print-scenario Print results for scenario run number N - - Examples: - print-scenario Show all previous scenario results - print-scenario 1 Show results from first scenario run - print-scenario 3 Show results from third scenario run + print-scenario """ - if not self._scenario_history: - print("No scenario runs in history.") + if not self._ensure_client(): return + from pyrit.cli._output import print_scenario_result_async - # Parse argument arg = arg.strip() - if not arg: - # Print all scenarios - print("\nPrinting all scenario results:") - print("=" * 80) - for idx, (command, result) in enumerate(self._scenario_history, start=1): - print(f"\n{'#' * 80}") - print(f"Scenario Run #{idx}: {command}") - print(f"{'#' * 80}") - from pyrit.output.scenario_result.pretty import ( - PrettyScenarioResultMemoryPrinter as ConsoleScenarioResultPrinter, - ) + print("Usage: print-scenario ") + print("Use 'scenario-history' to see available run IDs.") + return + + try: + detail = self._run_async(self._api_client.get_scenario_run_results_async(scenario_result_id=arg)) + self._run_async(print_scenario_result_async(result_dict=detail)) + except Exception as e: + print(f"Error: {e}") + + # ------------------------------------------------------------------ + # Server management + # ------------------------------------------------------------------ + + def do_start_server(self, arg: str) -> None: + """Start a local pyrit_backend server.""" + if arg.strip(): + print(f"Error: start-server does not accept arguments, got: {arg.strip()}") + return + from pyrit.cli._server_launcher import ServerLauncher + from pyrit.cli.api_client import PyRITApiClient + + base_url = self._resolve_base_url() + + # Check if already running + if self._run_async(ServerLauncher.probe_health_async(base_url=base_url)): + print(f"Server already running at {base_url}") + if self._api_client is None: + self._base_url = base_url + self._api_client = PyRITApiClient(base_url=base_url) + self._run_async(self._api_client.__aenter__()) + return - printer = ConsoleScenarioResultPrinter() - asyncio.run(printer.print_summary_async(result)) + self._launcher = ServerLauncher() + try: + new_url = self._run_async(self._launcher.start_async(config_file=self._config_file)) + self._base_url = new_url + # Create new client for the started server + if self._api_client is not None: + self._run_async(self._api_client.close_async()) + self._api_client = PyRITApiClient(base_url=new_url) + self._run_async(self._api_client.__aenter__()) + except RuntimeError as exc: + print(f"Error: {exc}") + + def do_stop_server(self, arg: str) -> None: + """Stop the backend server.""" + if arg.strip(): + print(f"Error: stop-server does not accept arguments, got: {arg.strip()}") + return + from pyrit.cli._server_launcher import ServerLauncher, stop_server_on_port + + # If we own the launcher, use it directly + if self._launcher is not None: + self._launcher.stop() + print("Server stopped.") else: - # Print specific scenario - try: - scenario_num = int(arg) - if scenario_num < 1 or scenario_num > len(self._scenario_history): - print(f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") - return - - command, result = self._scenario_history[scenario_num - 1] - print(f"\nScenario Run #{scenario_num}: {command}") - print("=" * 80) - from pyrit.output.scenario_result.pretty import ( - PrettyScenarioResultMemoryPrinter as ConsoleScenarioResultPrinter, - ) + # Find and kill by port. Probe first so we don't SIGTERM a non-pyrit + # process that happens to be listening on this port. + from urllib.parse import urlparse + + base_url = self._base_url or self._resolve_base_url() + port = urlparse(base_url).port or 8000 + if not self._run_async(ServerLauncher.probe_health_async(base_url=base_url)): + print(f"No pyrit backend responding at {base_url}; not stopping anything.") + return + if stop_server_on_port(port=port): + print(f"Server on port {port} stopped.") + else: + print(f"No server found on port {port}.") + return - printer = ConsoleScenarioResultPrinter() - asyncio.run(printer.print_summary_async(result)) - except ValueError: - print(f"Error: Invalid scenario number '{arg}'. Must be an integer.") + # Close the API client since the server is gone + if self._api_client is not None: + with contextlib.suppress(Exception): + self._run_async(self._api_client.close_async()) + self._api_client = None + self._launcher = None + + # ------------------------------------------------------------------ + # Utility commands + # ------------------------------------------------------------------ def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" if not arg: - from pyrit.cli._cli_args import ARG_HELP - - # Show general help (no full init needed — ARG_HELP is lightweight) super().do_help(arg) - print("\n" + "=" * 70) - print("Shell Startup Options:") - print("=" * 70) - print(" --config-file ") - print(" Path to YAML configuration file") - print(" Default: ~/.pyrit/.pyrit_conf") - print() - print(" --log-level ") - print(" Default logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL") - print(" Default: WARNING") - print(" Can be overridden per-run with 'run --log-level '") - print() - print("=" * 70) - print("Run Command Options (specified when running scenarios):") - print("=" * 70) - print(" --target (REQUIRED)") - print(f" {ARG_HELP['target']}") - print(" Example: run foundry.red_team_agent --target my_target") - print(" --initializers target load_default_datasets") - print() - print(" --initializers [ ...]") - print(f" {ARG_HELP['initializers']}") - print(" Example: run foundry.red_team_agent --target my_target") - print(" --initializers target load_default_datasets") - print(" With params: run foundry.red_team_agent --target my_target") - print(" --initializers target:tags=default,scorer") - print(" Multiple with params: run foundry.red_team_agent --target my_target") - print(" --initializers target:tags=default,scorer dataset:mode=strict") - print() - print(" --initialization-scripts [ ...] (Alternative to --initializers)") - print(f" {ARG_HELP['initialization_scripts']}") - print(" Example: run foundry.red_team_agent --initialization-scripts ./my_init.py") - print() - print(" --strategies, -s [ ...]") - print(f" {ARG_HELP['scenario_strategies']}") - print(" Example: run garak.encoding --strategies base64 rot13") - print() - print(" --max-concurrency ") - print(f" {ARG_HELP['max_concurrency']}") - print() - print(" --max-retries ") - print(f" {ARG_HELP['max_retries']}") - print() - print(" --memory-labels ") - print(f" {ARG_HELP['memory_labels']}") - print(' Example: run foundry.red_team_agent --memory-labels \'{"env":"test"}\'') - print() - print(" --log-level Override (DEBUG, INFO, WARNING, ERROR, CRITICAL)") - print() - print(" Database and env-files are configured via the config file (--config-file).") - print() - print("Start the shell like:") - print(" pyrit_shell") - print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG") + print("\nUse 'help ' for details on a specific command.") else: - # Convert hyphens to underscores (e.g. help list-targets -> help list_targets) for command lookup normalized_arg = arg.replace("-", "_") super().do_help(normalized_arg) def do_exit(self, arg: str) -> bool: """ - Exit the shell. Aliases: quit, q. + Exit the shell. Returns: - bool: True to exit the shell. + bool: Always ``True`` to signal the ``cmd`` loop to terminate. """ + if self._api_client is not None: + with contextlib.suppress(Exception): + self._run_async(self._api_client.close_async()) + self._api_client = None + self._shutdown_loop() print("\nGoodbye!") return True @@ -568,31 +563,27 @@ def do_clear(self, arg: str) -> None: # Shortcuts and aliases do_quit = do_exit do_q = do_exit - do_EOF = do_exit # Ctrl+D on Unix, Ctrl+Z on Windows # noqa: N815 + do_EOF = do_exit # noqa: N815 def emptyline(self) -> bool: """ Don't repeat last command on empty line. Returns: - bool: False to prevent repeating the last command. + bool: Always ``False`` so the ``cmd`` loop does not exit. """ return False def default(self, line: str) -> None: """Handle unknown commands and convert hyphens to underscores.""" - # Try converting hyphens to underscores for command lookup parts = line.split(None, 1) if parts: cmd_with_underscores = parts[0].replace("-", "_") method_name = f"do_{cmd_with_underscores}" - if hasattr(self, method_name): - # Call the method with the rest of the line as argument arg = parts[1] if len(parts) > 1 else "" getattr(self, method_name)(arg) return - print(f"Unknown command: {line}") print("Type 'help' or '?' for available commands") @@ -606,11 +597,23 @@ def main() -> int: """ import argparse - from pyrit.cli._cli_args import ARG_HELP, validate_log_level + from pyrit.cli._cli_args import ARG_HELP parser = argparse.ArgumentParser( prog="pyrit_shell", - description="PyRIT Interactive Shell - Load modules once, run commands instantly", + description="PyRIT Interactive Shell - Thin REST client for the PyRIT backend", + ) + + parser.add_argument( + "--server-url", + type=str, + help="URL of the PyRIT backend server (default: http://localhost:8000)", + ) + + parser.add_argument( + "--start-server", + action="store_true", + help="Start a local pyrit_backend server if one is not already running", ) parser.add_argument( @@ -624,23 +627,26 @@ def main() -> int: type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], default="WARNING", - help=( - "Default logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" - " (default: WARNING, can be overridden per-run)" - ), + help="Logging level (default: WARNING)", ) parser.add_argument( "--no-animation", action="store_true", default=False, - help="Disable the animated startup banner (show static banner instead)", + help="Disable the animated startup banner", ) args = parser.parse_args() - # Play the banner immediately, before heavy imports. - # Suppress logging so background-thread output doesn't corrupt the animation. + logging.basicConfig(level=getattr(logging, args.log_level)) + + # Surface a deprecation if the layered config has blocks the CLI ignores. + from pyrit.cli._config_reader import warn_on_client_ignored_blocks + + warn_on_client_ignored_blocks(config_file=args.config_file) + + # Play banner immediately prev_disable = logging.root.manager.disable logging.disable(logging.CRITICAL) try: @@ -648,14 +654,12 @@ def main() -> int: finally: logging.disable(prev_disable) - # Create shell with deferred initialization — the background thread - # will import frontend_core, create the FrontendCore context, and call - # initialize_async while the user is already at the prompt. try: shell = PyRITShell( no_animation=args.no_animation, + server_url=args.server_url, config_file=args.config_file, - log_level=validate_log_level(log_level=args.log_level), + start_server=args.start_server, ) shell.cmdloop(intro=intro) return 0 diff --git a/pyrit/common/cli_helpers.py b/pyrit/common/cli_helpers.py new file mode 100644 index 0000000000..ec5d9674f9 --- /dev/null +++ b/pyrit/common/cli_helpers.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Lightweight CLI helpers shared between the backend launcher (``pyrit_backend``) +and the thin REST CLI (``pyrit_scan`` / ``pyrit_shell``). + +This module intentionally has no heavy pyrit imports so it can be loaded by +either entry point without dragging in unrelated subsystems. +""" + +from __future__ import annotations + +import argparse +import logging +from typing import Any + +CONFIG_FILE_HELP = ( + "Path to a YAML configuration file. Allows specifying database, initializers (with args), " + "initialization scripts, and env files. CLI arguments override config file values. " + "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." +) + + +def validate_log_level(*, log_level: str) -> int: + """ + Validate a log level string and convert it to a ``logging`` constant. + + Args: + log_level: Log level string (case-insensitive). + + Returns: + Validated log level as a ``logging`` module constant (e.g. ``logging.WARNING``). + + Raises: + ValueError: If ``log_level`` is not one of DEBUG/INFO/WARNING/ERROR/CRITICAL. + """ + valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + level_upper = log_level.upper() + if level_upper not in valid_levels: + raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") + level_value: int = getattr(logging, level_upper) + return level_value + + +def validate_log_level_argparse(value: Any) -> int: + """ + Argparse-compatible wrapper around :func:`validate_log_level`. + + Adapts the keyword-only validator to argparse's positional ``type=`` calling + convention and converts ``ValueError`` to :class:`argparse.ArgumentTypeError`. + + Args: + value: Log level string supplied by argparse. + + Returns: + Validated log level as a ``logging`` module constant. + + Raises: + argparse.ArgumentTypeError: If ``value`` is not a valid log level. + """ + try: + return validate_log_level(log_level=value) + except ValueError as exc: + raise argparse.ArgumentTypeError(str(exc)) from exc diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index c535bc4fcf..4c34e648d2 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -71,7 +71,8 @@ class ScenarioParameterMetadata(NamedTuple): description: str default: Any param_type: str - choices: Optional[str] + choices: Optional[list[str]] + is_list: bool = False class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): @@ -220,7 +221,8 @@ def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMet description=p.description, default=p.default, param_type=_param_type_display(p.param_type), - choices=", ".join(repr(c) for c in p.choices) if p.choices else None, + choices=[str(c) for c in p.choices] if p.choices else None, + is_list=get_origin(p.param_type) is list, ) for p in scenario_class.supported_parameters() ) diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 0fe2db0a2e..411fa97def 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -54,6 +54,18 @@ class InitializerConfig: args: Optional[dict[str, YamlValue]] = None +@dataclass +class ServerConfig: + """ + Configuration for connecting to (or launching) a PyRIT backend server. + + Attributes: + url: Base URL of the backend (e.g. ``http://localhost:8000``). + """ + + url: str = "http://localhost:8000" + + @dataclass class ScenarioConfig: """ @@ -134,6 +146,7 @@ class ConfigurationLoader(YamlLoadable): scenario: Optional[Union[str, dict[str, Any]]] = None max_concurrent_scenario_runs: int = 3 allow_custom_initializers: bool = False + server: Optional[dict[str, Any]] = None extensions: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: @@ -141,6 +154,7 @@ def __post_init__(self) -> None: self._normalize_memory_db_type() self._normalize_initializers() self._normalize_scenario() + self._normalize_server() def _normalize_memory_db_type(self) -> None: """ @@ -239,6 +253,33 @@ def _normalize_scenario(self) -> None: raise ValueError(f"Scenario entry must be a string or dict, got: {type(self.scenario).__name__}") + def _normalize_server(self) -> None: + """ + Normalize the optional ``server`` block to a ``ServerConfig``. + + Accepts ``None`` (no server configured) or ``{"url": "..."}`` form. + + Raises: + ValueError: If ``server`` is not ``None`` or a dict, or if ``url`` is not a string. + """ + if self.server is None: + self._server_config: Optional[ServerConfig] = None + return + + if isinstance(self.server, dict): + url = self.server.get("url", "http://localhost:8000") + if not isinstance(url, str): + raise ValueError(f"Server 'url' must be a string. Got: {type(url).__name__}") + self._server_config = ServerConfig(url=url.rstrip("/")) + return + + raise ValueError(f"Server entry must be a dict, got: {type(self.server).__name__}") + + @property + def server_config(self) -> Optional[ServerConfig]: + """The normalized ``server:`` block, or ``None`` when not configured.""" + return self._server_config + @property def scenario_config(self) -> Optional[ScenarioConfig]: """The normalized ``scenario:`` block, or ``None`` when not configured.""" @@ -402,7 +443,7 @@ def get_default_config_path(cls) -> pathlib.Path: """ return DEFAULT_CONFIG_PATH - def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: + def resolve_initializers(self) -> Sequence["PyRITInitializer"]: """ Resolve initializer names to PyRITInitializer instances. @@ -415,6 +456,8 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: Raises: ValueError: If an initializer name is not found in the registry. """ + import logging + from pyrit.registry import InitializerRegistry if not self._initializer_configs: @@ -423,6 +466,8 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: registry = InitializerRegistry() resolved: list[PyRITInitializer] = [] + logging.getLogger(__name__).info("Running %d initializer(s)...", len(self._initializer_configs)) + for config in self._initializer_configs: initializer_class = registry.get_class(config.name) if initializer_class is None: @@ -442,7 +487,7 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: return resolved - def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: + def resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: """ Resolve initialization script paths. @@ -467,7 +512,7 @@ def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: return resolved - def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: + def resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: """ Resolve environment file paths. @@ -502,9 +547,9 @@ async def initialize_pyrit_async(self) -> None: Raises: ValueError: If configuration is invalid or initializers cannot be resolved. """ - resolved_initializers = self._resolve_initializers() - resolved_scripts = self._resolve_initialization_scripts() - resolved_env_files = self._resolve_env_files() + resolved_initializers = self.resolve_initializers() + resolved_scripts = self.resolve_initialization_scripts() + resolved_env_files = self.resolve_env_files() # Map snake_case memory_db_type to internal constant internal_memory_db_type = _MEMORY_DB_TYPE_MAP[self.memory_db_type] diff --git a/tests/unit/backend/test_main.py b/tests/unit/backend/test_main.py index b2e2efc363..b05ec5f6c8 100644 --- a/tests/unit/backend/test_main.py +++ b/tests/unit/backend/test_main.py @@ -7,32 +7,75 @@ Covers the lifespan manager and setup_frontend function. """ +import logging import os -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from pyrit.backend.main import app, lifespan, setup_frontend +from pyrit.setup.configuration_loader import ConfigurationLoader class TestLifespan: """Tests for the application lifespan context manager.""" async def test_lifespan_yields(self) -> None: - """Test that lifespan yields without performing initialization (handled by CLI).""" - with patch("pyrit.memory.CentralMemory._memory_instance", MagicMock()): + """Test that lifespan delegates to ConfigurationLoader and yields.""" + fake_config = ConfigurationLoader() + with ( + patch.object(ConfigurationLoader, "load_with_overrides", return_value=fake_config), + patch.object(ConfigurationLoader, "initialize_pyrit_async", new=AsyncMock()) as init_mock, + patch("pyrit.backend.main.setup_frontend"), + ): async with lifespan(app): - pass # Should complete without error + pass + + init_mock.assert_awaited_once() + assert app.state.default_labels == {} + assert app.state.max_concurrent_scenario_runs == fake_config.max_concurrent_scenario_runs + assert app.state.allow_custom_initializers is False - async def test_lifespan_warns_when_memory_not_initialized(self) -> None: - """Test that lifespan logs a warning when CentralMemory is not set.""" + async def test_lifespan_warns_when_custom_initializers_allowed(self) -> None: + """Test that lifespan logs a warning when allow_custom_initializers is enabled.""" + fake_config = ConfigurationLoader(allow_custom_initializers=True) with ( - patch("pyrit.memory.CentralMemory._memory_instance", None), - patch("logging.Logger.warning") as mock_warning, + patch.object(ConfigurationLoader, "load_with_overrides", return_value=fake_config), + patch.object(ConfigurationLoader, "initialize_pyrit_async", new=AsyncMock()), + patch("pyrit.backend.main.setup_frontend"), + patch.object(logging.getLogger("pyrit.backend.main"), "warning") as mock_warning, ): async with lifespan(app): pass mock_warning.assert_called_once() + async def test_lifespan_populates_default_labels_from_operator_and_operation(self) -> None: + """Test that operator and operation are exposed as default_labels.""" + fake_config = ConfigurationLoader(operator="alice", operation="op-42") + with ( + patch.object(ConfigurationLoader, "load_with_overrides", return_value=fake_config), + patch.object(ConfigurationLoader, "initialize_pyrit_async", new=AsyncMock()), + patch("pyrit.backend.main.setup_frontend"), + ): + async with lifespan(app): + pass + + assert app.state.default_labels == {"operator": "alice", "operation": "op-42"} + + async def test_lifespan_reads_config_file_env_var(self) -> None: + """Test that PYRIT_CONFIG_FILE is forwarded to ConfigurationLoader.load_with_overrides.""" + fake_config = ConfigurationLoader() + with ( + patch.dict(os.environ, {"PYRIT_CONFIG_FILE": "/tmp/foo.yaml"}, clear=False), + patch.object(ConfigurationLoader, "load_with_overrides", return_value=fake_config) as load_mock, + patch.object(ConfigurationLoader, "initialize_pyrit_async", new=AsyncMock()), + patch("pyrit.backend.main.setup_frontend"), + ): + async with lifespan(app): + pass + + call_kwargs = load_mock.call_args.kwargs + assert str(call_kwargs["config_file"]).endswith("foo.yaml") + class TestSetupFrontend: """Tests for the setup_frontend function.""" diff --git a/tests/unit/backend/test_pyrit_backend.py b/tests/unit/backend/test_pyrit_backend.py new file mode 100644 index 0000000000..8b5fa0def6 --- /dev/null +++ b/tests/unit/backend/test_pyrit_backend.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.backend import pyrit_backend + + +class TestParseArgs: + """Tests for pyrit_backend.parse_args.""" + + def test_parse_args_defaults(self) -> None: + args = pyrit_backend.parse_args(args=[]) + assert args.host == "localhost" + assert args.port == 8000 + assert args.config_file is None + assert args.reload is False + + def test_parse_args_accepts_config_file(self) -> None: + args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) + assert args.config_file == Path("./custom_conf.yaml") + + def test_parse_args_accepts_host_port(self) -> None: + args = pyrit_backend.parse_args(args=["--host", "0.0.0.0", "--port", "9000"]) + assert args.host == "0.0.0.0" + assert args.port == 9000 + + def test_parse_args_accepts_reload(self) -> None: + args = pyrit_backend.parse_args(args=["--reload"]) + assert args.reload is True + + +class TestMain: + """Tests for pyrit_backend.main.""" + + @patch("uvicorn.run") + def test_main_starts_uvicorn(self, mock_run: MagicMock) -> None: + result = pyrit_backend.main(args=[]) + assert result == 0 + mock_run.assert_called_once() + assert mock_run.call_args[0][0] == "pyrit.backend.main:app" + + @patch("uvicorn.run") + def test_main_forwards_config_file_via_env(self, mock_run: MagicMock) -> None: + import os + + with patch.dict(os.environ, {}, clear=False): + pyrit_backend.main(args=["--config-file", "./custom.yaml"]) + assert os.environ.get("PYRIT_CONFIG_FILE") is not None + assert "custom.yaml" in os.environ["PYRIT_CONFIG_FILE"] + + @patch("uvicorn.run") + def test_main_passes_host_and_port(self, mock_run: MagicMock) -> None: + pyrit_backend.main(args=["--host", "0.0.0.0", "--port", "9000"]) + call_kwargs = mock_run.call_args.kwargs + assert call_kwargs["host"] == "0.0.0.0" + assert call_kwargs["port"] == 9000 + + def test_main_invalid_args(self) -> None: + result = pyrit_backend.main(args=["--invalid-flag"]) + assert result == 2 + + @patch("uvicorn.run", side_effect=KeyboardInterrupt()) + def test_main_keyboard_interrupt_returns_zero(self, mock_run: MagicMock, capsys) -> None: + result = pyrit_backend.main(args=[]) + assert result == 0 + captured = capsys.readouterr() + assert "Backend stopped" in captured.out + + @patch("uvicorn.run", side_effect=RuntimeError("boom")) + def test_main_unexpected_exception_returns_one(self, mock_run: MagicMock, capsys) -> None: + result = pyrit_backend.main(args=[]) + assert result == 1 + captured = capsys.readouterr() + assert "boom" in captured.out + + @patch("uvicorn.run") + def test_main_forwards_log_level(self, mock_run: MagicMock) -> None: + pyrit_backend.main(args=["--log-level", "DEBUG"]) + assert mock_run.call_args.kwargs["log_level"] == "debug" + + @patch("uvicorn.run") + def test_main_forwards_reload_flag(self, mock_run: MagicMock) -> None: + pyrit_backend.main(args=["--reload"]) + assert mock_run.call_args.kwargs["reload"] is True + + +class TestParseArgsDoesNotAcceptLegacyFlags: + """ + Regression: the thin backend only takes --host/--port/--config-file/--log-level/--reload. + Legacy --database and --initializers must be rejected so callers (docker/start.sh, + frontend/dev.py) cannot silently regress to passing them. + """ + + def test_database_flag_rejected(self) -> None: + with pytest.raises(SystemExit) as exc_info: + pyrit_backend.parse_args(args=["--database", "SQLite"]) + assert exc_info.value.code != 0 + + def test_initializers_flag_rejected(self) -> None: + with pytest.raises(SystemExit) as exc_info: + pyrit_backend.parse_args(args=["--initializers", "target"]) + assert exc_info.value.code != 0 + + def test_initialization_scripts_flag_rejected(self) -> None: + with pytest.raises(SystemExit) as exc_info: + pyrit_backend.parse_args(args=["--initialization-scripts", "./x.py"]) + assert exc_info.value.code != 0 diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index fd128e3946..faf40a5b8f 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -14,10 +14,7 @@ import pyrit.backend.services.scenario_run_service as _svc_mod from pyrit.backend.main import app -from pyrit.backend.models.attacks import AttackSummary from pyrit.backend.models.scenarios import ( - AtomicAttackResults, - ScenarioRunDetail, ScenarioRunListResponse, ScenarioRunStatus, ScenarioRunSummary, @@ -234,58 +231,34 @@ class TestGetScenarioRunResultsRoute: def test_get_results_returns_200(self, client: TestClient) -> None: """Test that getting results of a completed run returns 200.""" - mock_result = ScenarioRunDetail( - run=ScenarioRunSummary( - scenario_result_id="result-uuid", - scenario_name="foundry.red_team_agent", - scenario_version=1, - status=ScenarioRunStatus.COMPLETED, - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - objective_achieved_rate=50, - labels={"team": "red"}, - completed_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ), - attacks=[ - AtomicAttackResults( - atomic_attack_name="base64_attack", - display_group="encoding", - results=[ - AttackSummary( - attack_result_id="ar-1", - conversation_id="conv-1", - objective="Extract sensitive info", - outcome="success", - outcome_reason="Model revealed data", - last_response="Here is the data...", - score_value="1.0", - executed_turns=3, - execution_time_ms=1500, - created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), - ), - ], - success_count=1, - failure_count=0, - total_count=1, - ), - ], - ) + mock_scenario_result = MagicMock() + mock_scenario_result.to_dict.return_value = { + "id": "result-uuid", + "scenario_identifier": {"name": "foundry.red_team_agent", "version": 1}, + "scenario_run_state": "COMPLETED", + "attack_results": { + "base64_attack": [ + { + "attack_result_id": "ar-1", + "conversation_id": "conv-1", + "objective": "Extract sensitive info", + "outcome": "success", + } + ] + }, + } with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: mock_service = MagicMock() - mock_service.get_run_results.return_value = mock_result + mock_service.get_run_results.return_value = mock_scenario_result mock_get.return_value = mock_service response = client.get("/api/scenarios/runs/test-run-id/results") assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["run"]["scenario_result_id"] == "result-uuid" - assert data["run"]["objective_achieved_rate"] == 50 - assert len(data["attacks"]) == 1 - assert data["attacks"][0]["atomic_attack_name"] == "base64_attack" - assert data["attacks"][0]["results"][0]["outcome"] == "success" + assert data["id"] == "result-uuid" + assert "base64_attack" in data["attack_results"] def test_get_results_not_found_returns_404(self, client: TestClient) -> None: """Test that getting results of a non-existent run returns 404.""" diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 65169c90f1..1de264ed5b 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -492,26 +492,12 @@ def test_get_results_raises_if_not_completed(self, mock_memory) -> None: service.get_run_results(scenario_result_id="sr-running") def test_get_results_returns_details_for_completed_run(self, mock_memory) -> None: - """Test that get_run_results returns full details for a completed run.""" + """Test that get_run_results returns the ScenarioResult for a completed run.""" from pyrit.models import AttackOutcome mock_attack_result = MagicMock() - mock_attack_result.attack_result_id = "ar-1" - mock_attack_result.conversation_id = "conv-1" - mock_attack_result.objective = "Extract info" mock_attack_result.outcome = AttackOutcome.SUCCESS - mock_attack_result.outcome_reason = "Model complied" - mock_attack_result.last_response = MagicMock(value="Here is the data") - mock_attack_result.last_score = MagicMock() - mock_attack_result.last_score.get_value.return_value = "1.0" - mock_attack_result.executed_turns = 3 - mock_attack_result.execution_time_ms = 1500 - mock_attack_result.timestamp = None - mock_attack_result.error_message = None - mock_attack_result.error_type = None - mock_attack_result.error_traceback = None - mock_attack_result.total_retries = 0 - mock_attack_result.retry_events = [] + mock_attack_result.objective = "Extract info" db_result = _make_db_scenario_result( result_id="sr-123", @@ -522,16 +508,10 @@ def test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non mock_memory.get_scenario_results.return_value = [db_result] service = ScenarioRunService() - detail = service.get_run_results(scenario_result_id="sr-123") - - assert detail is not None - assert detail.run.scenario_result_id == "sr-123" - assert detail.run.objective_achieved_rate == 100 - assert len(detail.attacks) == 1 - assert detail.attacks[0].atomic_attack_name == "base64_attack" - assert detail.attacks[0].success_count == 1 - assert detail.attacks[0].results[0].objective == "Extract info" - assert detail.attacks[0].results[0].outcome == "success" + result = service.get_run_results(scenario_result_id="sr-123") + + assert result is db_result + assert result.attack_results["base64_attack"][0].outcome == AttackOutcome.SUCCESS class TestScenarioRunServiceProgressReporting: diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index aa88ad3881..0471786b36 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -368,7 +368,7 @@ async def test_list_scenarios_includes_supported_parameters(self) -> None: description="Execution mode", default="fast", param_type="str", - choices="'fast', 'slow'", + choices=["fast", "slow"], ), ), ) @@ -389,12 +389,14 @@ async def test_list_scenarios_includes_supported_parameters(self) -> None: assert params[0].default == "5" assert params[0].param_type == "int" assert params[0].choices is None + assert params[0].is_list is False assert params[1].name == "mode" assert params[1].description == "Execution mode" assert params[1].default == "'fast'" assert params[1].param_type == "str" - assert params[1].choices == "'fast', 'slow'" + assert params[1].choices == ["fast", "slow"] + assert params[1].is_list is False async def test_scenario_with_no_parameters_has_empty_list(self) -> None: """Test that scenarios without parameters have empty supported_parameters.""" diff --git a/tests/unit/cli/test_api_client.py b/tests/unit/cli/test_api_client.py new file mode 100644 index 0000000000..f379985b85 --- /dev/null +++ b/tests/unit/cli/test_api_client.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for pyrit.cli.api_client.PyRITApiClient. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from pyrit.cli.api_client import PyRITApiClient, ServerNotAvailableError + + +@pytest.fixture() +def mock_httpx_client(): + """A MagicMock standing in for an opened ``httpx.AsyncClient``.""" + client = MagicMock() + client.get = AsyncMock() + client.post = AsyncMock() + client.aclose = AsyncMock() + return client + + +@pytest.fixture() +def client(mock_httpx_client): + """A PyRITApiClient with the underlying HTTP client pre-wired.""" + c = PyRITApiClient(base_url="http://localhost:8000/") + c._client = mock_httpx_client + return c + + +def _make_response(*, status_code=200, json_data=None): + resp = MagicMock() + resp.status_code = status_code + resp.json = MagicMock(return_value=json_data or {}) + resp.raise_for_status = MagicMock() + return resp + + +# --------------------------------------------------------------------------- +# Init / context manager / lifecycle +# --------------------------------------------------------------------------- + + +def test_init_strips_trailing_slash(): + c = PyRITApiClient(base_url="http://localhost:8000/") + assert c._base_url == "http://localhost:8000" + + +async def test_async_context_manager_opens_and_closes(mock_httpx_client): + c = PyRITApiClient(base_url="http://localhost:8000") + fake_async_client_cls = MagicMock(return_value=mock_httpx_client) + with patch("httpx.AsyncClient", fake_async_client_cls): + async with c as opened: + assert opened is c + assert c._client is mock_httpx_client + # After exit, close was called + mock_httpx_client.aclose.assert_awaited_once() + assert c._client is None + + +async def test_close_async_is_noop_when_already_closed(): + c = PyRITApiClient(base_url="http://localhost:8000") + await c.close_async() # Should not raise. + + +def test_get_client_raises_when_not_opened(): + c = PyRITApiClient(base_url="http://localhost:8000") + with pytest.raises(ServerNotAvailableError, match="not connected"): + c._get_client() + + +# --------------------------------------------------------------------------- +# health_check_async +# --------------------------------------------------------------------------- + + +async def test_health_check_returns_true_on_200(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(status_code=200) + assert await client.health_check_async() is True + mock_httpx_client.get.assert_awaited_once_with("/api/health") + + +async def test_health_check_returns_false_on_non_200(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(status_code=503) + assert await client.health_check_async() is False + + +async def test_health_check_returns_false_on_connect_error(client, mock_httpx_client): + mock_httpx_client.get.side_effect = httpx.ConnectError("nope") + assert await client.health_check_async() is False + + +async def test_health_check_returns_false_on_generic_exception(client, mock_httpx_client): + mock_httpx_client.get.side_effect = RuntimeError("broken") + assert await client.health_check_async() is False + + +# --------------------------------------------------------------------------- +# Scenarios +# --------------------------------------------------------------------------- + + +async def test_list_scenarios_async(client, mock_httpx_client): + payload = {"items": [{"scenario_name": "s1"}], "pagination": {}} + mock_httpx_client.get.return_value = _make_response(json_data=payload) + result = await client.list_scenarios_async(limit=10) + assert result == payload + mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/catalog", params={"limit": 10}) + + +async def test_get_scenario_async_returns_payload(client, mock_httpx_client): + payload = {"scenario_name": "foo"} + mock_httpx_client.get.return_value = _make_response(json_data=payload) + result = await client.get_scenario_async(scenario_name="foo") + assert result == payload + mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/catalog/foo", params=None) + + +async def test_get_scenario_async_returns_none_on_404(client, mock_httpx_client): + resp = _make_response(status_code=404) + error = httpx.HTTPStatusError("404", request=MagicMock(), response=resp) + mock_httpx_client.get.return_value = resp + resp.raise_for_status.side_effect = error + result = await client.get_scenario_async(scenario_name="missing") + assert result is None + + +async def test_get_scenario_async_raises_on_other_http_errors(client, mock_httpx_client): + resp = _make_response(status_code=500) + error = httpx.HTTPStatusError("500", request=MagicMock(), response=resp) + mock_httpx_client.get.return_value = resp + resp.raise_for_status.side_effect = error + with pytest.raises(httpx.HTTPStatusError): + await client.get_scenario_async(scenario_name="boom") + + +# --------------------------------------------------------------------------- +# Initializers +# --------------------------------------------------------------------------- + + +async def test_list_initializers_async(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(json_data={"items": []}) + await client.list_initializers_async(limit=5) + mock_httpx_client.get.assert_awaited_once_with("/api/initializers", params={"limit": 5}) + + +async def test_register_initializer_async_success(client, mock_httpx_client): + payload = {"initializer_name": "x"} + mock_httpx_client.post.return_value = _make_response(json_data=payload) + result = await client.register_initializer_async(name="x", script_content="print(1)") + assert result == payload + mock_httpx_client.post.assert_awaited_once_with( + "/api/initializers", json={"name": "x", "script_content": "print(1)"} + ) + + +async def test_register_initializer_async_raises_on_403(client, mock_httpx_client): + resp = _make_response(status_code=403, json_data={"detail": "Custom initializers disabled"}) + mock_httpx_client.post.return_value = resp + with pytest.raises(ServerNotAvailableError, match="disabled"): + await client.register_initializer_async(name="x", script_content="...") + + +async def test_register_initializer_async_raises_on_500(client, mock_httpx_client): + resp = _make_response(status_code=500) + resp.raise_for_status.side_effect = httpx.HTTPStatusError("500", request=MagicMock(), response=resp) + mock_httpx_client.post.return_value = resp + with pytest.raises(httpx.HTTPStatusError): + await client.register_initializer_async(name="x", script_content="...") + + +# --------------------------------------------------------------------------- +# Targets +# --------------------------------------------------------------------------- + + +async def test_list_targets_async(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(json_data={"items": []}) + await client.list_targets_async(limit=7) + mock_httpx_client.get.assert_awaited_once_with("/api/targets", params={"limit": 7}) + + +# --------------------------------------------------------------------------- +# Scenario runs +# --------------------------------------------------------------------------- + + +async def test_start_scenario_run_async(client, mock_httpx_client): + payload = {"scenario_result_id": "abc"} + mock_httpx_client.post.return_value = _make_response(json_data=payload) + request = {"scenario_name": "x"} + result = await client.start_scenario_run_async(request=request) + assert result == payload + mock_httpx_client.post.assert_awaited_once_with("/api/scenarios/runs", json=request) + + +async def test_get_scenario_run_async(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(json_data={"status": "RUNNING"}) + result = await client.get_scenario_run_async(scenario_result_id="abc") + assert result == {"status": "RUNNING"} + mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/runs/abc", params=None) + + +async def test_get_scenario_run_results_async(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(json_data={"run": {}, "attacks": []}) + result = await client.get_scenario_run_results_async(scenario_result_id="abc") + assert "run" in result + mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/runs/abc/results", params=None) + + +async def test_cancel_scenario_run_async(client, mock_httpx_client): + mock_httpx_client.post.return_value = _make_response(json_data={"status": "CANCELLED"}) + result = await client.cancel_scenario_run_async(scenario_result_id="abc") + assert result == {"status": "CANCELLED"} + mock_httpx_client.post.assert_awaited_once_with("/api/scenarios/runs/abc/cancel") + + +async def test_list_scenario_runs_async(client, mock_httpx_client): + mock_httpx_client.get.return_value = _make_response(json_data={"items": []}) + await client.list_scenario_runs_async(limit=20) + mock_httpx_client.get.assert_awaited_once_with("/api/scenarios/runs", params={"limit": 20}) + + +# --------------------------------------------------------------------------- +# _get_json error path +# --------------------------------------------------------------------------- + + +async def test_get_json_wraps_connect_error_as_server_not_available(client, mock_httpx_client): + mock_httpx_client.get.side_effect = httpx.ConnectError("nope") + with pytest.raises(ServerNotAvailableError, match="Cannot connect"): + await client.list_scenarios_async() diff --git a/tests/unit/cli/test_config_reader.py b/tests/unit/cli/test_config_reader.py new file mode 100644 index 0000000000..085a1b02b8 --- /dev/null +++ b/tests/unit/cli/test_config_reader.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for pyrit.cli._config_reader. +""" + +from unittest.mock import patch + +from pyrit.cli import _config_reader +from pyrit.cli._config_reader import DEFAULT_SERVER_URL, read_server_url + + +def test_default_server_url_constant(): + assert DEFAULT_SERVER_URL == "http://localhost:8000" + + +def test_read_server_url_returns_none_when_no_files(tmp_path): + nonexistent = tmp_path / "missing.yaml" + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", tmp_path / "missing_default.yaml"): + assert read_server_url(config_file=nonexistent) is None + + +def test_read_server_url_reads_from_default_when_no_overlay(tmp_path): + default = tmp_path / "default.yaml" + default.write_text("server:\n url: http://default-host:9000\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", default): + assert read_server_url(config_file=None) == "http://default-host:9000" + + +def test_read_server_url_overlay_overrides_default(tmp_path): + default = tmp_path / "default.yaml" + default.write_text("server:\n url: http://default-host:9000\n") + overlay = tmp_path / "overlay.yaml" + overlay.write_text("server:\n url: http://overlay-host:5000\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", default): + assert read_server_url(config_file=overlay) == "http://overlay-host:5000" + + +def test_read_server_url_overlay_missing_field_falls_back(tmp_path): + default = tmp_path / "default.yaml" + default.write_text("server:\n url: http://default-host:9000\n") + overlay = tmp_path / "overlay.yaml" + overlay.write_text("other_block: {}\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", default): + # Overlay doesn't have server.url, so default wins. + assert read_server_url(config_file=overlay) == "http://default-host:9000" + + +def test_read_server_url_strips_whitespace(tmp_path): + default = tmp_path / "default.yaml" + default.write_text("server:\n url: ' http://padded:9000 '\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", default): + assert read_server_url(config_file=None) == "http://padded:9000" + + +def test_read_server_url_non_string_returns_none(tmp_path): + bad = tmp_path / "bad.yaml" + bad.write_text("server:\n url: 12345\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", tmp_path / "missing.yaml"): + assert read_server_url(config_file=bad) is None + + +def test_read_server_url_handles_malformed_yaml(tmp_path): + bad = tmp_path / "bad.yaml" + bad.write_text(": :\nnot yaml: [unbalanced\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", tmp_path / "missing.yaml"): + assert read_server_url(config_file=bad) is None + + +def test_read_server_url_handles_non_dict_root(tmp_path): + odd = tmp_path / "odd.yaml" + odd.write_text("- 1\n- 2\n- 3\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", tmp_path / "missing.yaml"): + assert read_server_url(config_file=odd) is None + + +def test_read_server_url_empty_string_treated_as_missing(tmp_path): + empty = tmp_path / "empty.yaml" + empty.write_text("server:\n url: ''\n") + with patch.object(_config_reader, "_DEFAULT_CONFIG_FILE", tmp_path / "missing.yaml"): + assert read_server_url(config_file=empty) is None diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py deleted file mode 100644 index c4837fbbc6..0000000000 --- a/tests/unit/cli/test_frontend_core.py +++ /dev/null @@ -1,1646 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Unit tests for the frontend_core module. -""" - -import logging -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from pyrit.cli import frontend_core -from pyrit.cli._cli_args import _ArgSpec, _parse_shell_arguments -from pyrit.registry import InitializerMetadata, ScenarioMetadata, ScenarioParameterMetadata - - -class TestFrontendCore: - """Tests for FrontendCore class.""" - - @patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH", Path("/nonexistent/.pyrit_conf")) - def test_init_with_defaults(self): - """Test initialization with default parameters.""" - context = frontend_core.FrontendCore() - - assert context._database == frontend_core.SQLITE - assert context._initialization_scripts is None - assert context._initializer_configs is None - assert context._log_level == logging.WARNING - assert context._initialized is False - - def test_init_with_all_parameters(self): - """Test initialization with all parameters.""" - scripts = [Path("/test/script.py")] - initializers = ["alpha_init", "beta_init", "gamma_init"] - - context = frontend_core.FrontendCore( - database=frontend_core.IN_MEMORY, - initialization_scripts=scripts, - initializer_names=initializers, - log_level=logging.DEBUG, - ) - - assert context._database == frontend_core.IN_MEMORY - # Check path ends with expected components (Windows adds drive letter to Unix-style paths) - assert context._initialization_scripts is not None - assert len(context._initialization_scripts) == 1 - assert context._initialization_scripts[0].parts[-2:] == ("test", "script.py") - assert context._initializer_configs is not None - assert [ic.name for ic in context._initializer_configs] == initializers - assert context._log_level == logging.DEBUG - - def test_init_with_invalid_database(self): - """Test initialization with invalid database raises ValueError.""" - with pytest.raises(ValueError, match="Invalid database type"): - frontend_core.FrontendCore(database="InvalidDB") - - @patch("pyrit.cli.frontend_core.ScenarioRegistry") - @patch("pyrit.cli.frontend_core.InitializerRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - def test_initialize_loads_registries( - self, - mock_init_pyrit: AsyncMock, - mock_init_registry: MagicMock, - mock_scenario_registry: MagicMock, - ): - """Test initialize method loads registries.""" - context = frontend_core.FrontendCore() - import asyncio - - asyncio.run(context.initialize_async()) - - assert context._initialized is True - mock_init_pyrit.assert_called_once() - mock_scenario_registry.get_registry_singleton.assert_called_once() - mock_init_registry.assert_called_once() - - @patch("pyrit.cli.frontend_core.ScenarioRegistry") - @patch("pyrit.cli.frontend_core.InitializerRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_scenario_registry_property_initializes( - self, - mock_init_pyrit: AsyncMock, - mock_init_registry: MagicMock, - mock_scenario_registry: MagicMock, - ): - """Test scenario_registry property triggers initialization.""" - context = frontend_core.FrontendCore() - assert context._initialized is False - - await context.initialize_async() - registry = context.scenario_registry - - assert context._initialized is True - assert registry is not None - - @patch("pyrit.cli.frontend_core.ScenarioRegistry") - @patch("pyrit.cli.frontend_core.InitializerRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_initializer_registry_property_initializes( - self, - mock_init_pyrit: AsyncMock, - mock_init_registry: MagicMock, - mock_scenario_registry: MagicMock, - ): - """Test initializer_registry property triggers initialization.""" - context = frontend_core.FrontendCore() - assert context._initialized is False - - await context.initialize_async() - registry = context.initializer_registry - - assert context._initialized is True - assert registry is not None - - def test_scenario_registry_raises_when_none_after_init(self): - """Test scenario_registry raises ValueError when registry is None despite _initialized=True.""" - context = frontend_core.FrontendCore() - context._initialized = True - context._scenario_registry = None - - with pytest.raises(ValueError, match="self._scenario_registry is not initialized"): - _ = context.scenario_registry - - def test_initializer_registry_raises_when_none_after_init(self): - """Test initializer_registry raises ValueError when registry is None despite _initialized=True.""" - context = frontend_core.FrontendCore() - context._initialized = True - context._initializer_registry = None - - with pytest.raises(ValueError, match="self._initializer_registry is not initialized"): - _ = context.initializer_registry - - -class TestValidationFunctions: - """Tests for validation functions.""" - - def test_validate_database_valid_values(self): - """Test validate_database with valid values.""" - assert frontend_core.validate_database(database=frontend_core.IN_MEMORY) == frontend_core.IN_MEMORY - assert frontend_core.validate_database(database=frontend_core.SQLITE) == frontend_core.SQLITE - assert frontend_core.validate_database(database=frontend_core.AZURE_SQL) == frontend_core.AZURE_SQL - - def test_validate_database_invalid_value(self): - """Test validate_database with invalid value.""" - with pytest.raises(ValueError, match="Invalid database type"): - frontend_core.validate_database(database="InvalidDB") - - def test_validate_log_level_valid_values(self): - """Test validate_log_level with valid values.""" - assert frontend_core.validate_log_level(log_level="DEBUG") == logging.DEBUG - assert frontend_core.validate_log_level(log_level="INFO") == logging.INFO - assert frontend_core.validate_log_level(log_level="warning") == logging.WARNING # Case-insensitive - assert frontend_core.validate_log_level(log_level="error") == logging.ERROR - assert frontend_core.validate_log_level(log_level="CRITICAL") == logging.CRITICAL - - def test_validate_log_level_invalid_value(self): - """Test validate_log_level with invalid value.""" - with pytest.raises(ValueError, match="Invalid log level"): - frontend_core.validate_log_level(log_level="INVALID") - - def test_validate_integer_valid(self): - """Test validate_integer with valid values.""" - assert frontend_core.validate_integer("42") == 42 - assert frontend_core.validate_integer("0") == 0 - assert frontend_core.validate_integer("-5") == -5 - - def test_validate_integer_with_min_value(self): - """Test validate_integer with min_value constraint.""" - assert frontend_core.validate_integer("5", min_value=1) == 5 - assert frontend_core.validate_integer("1", min_value=1) == 1 - - def test_validate_integer_below_min_value(self): - """Test validate_integer below min_value raises ValueError.""" - with pytest.raises(ValueError, match="must be at least"): - frontend_core.validate_integer("0", min_value=1) - - def test_validate_integer_invalid_string(self): - """Test validate_integer with non-integer string.""" - with pytest.raises(ValueError, match="must be an integer"): - frontend_core.validate_integer("not_a_number") - - def test_validate_integer_custom_name(self): - """Test validate_integer with custom parameter name.""" - with pytest.raises(ValueError, match="max_retries must be an integer"): - frontend_core.validate_integer("invalid", name="max_retries") - - def test_positive_int_valid(self): - """Test positive_int with valid values.""" - assert frontend_core.positive_int("1") == 1 - assert frontend_core.positive_int("100") == 100 - - def test_positive_int_zero(self): - """Test positive_int with zero raises error.""" - import argparse - - with pytest.raises(argparse.ArgumentTypeError): - frontend_core.positive_int("0") - - def test_positive_int_negative(self): - """Test positive_int with negative value raises error.""" - import argparse - - with pytest.raises(argparse.ArgumentTypeError): - frontend_core.positive_int("-1") - - def test_non_negative_int_valid(self): - """Test non_negative_int with valid values.""" - assert frontend_core.non_negative_int("0") == 0 - assert frontend_core.non_negative_int("5") == 5 - - def test_non_negative_int_negative(self): - """Test non_negative_int with negative value raises error.""" - import argparse - - with pytest.raises(argparse.ArgumentTypeError): - frontend_core.non_negative_int("-1") - - def test_validate_database_argparse(self): - """Test validate_database_argparse wrapper.""" - assert frontend_core.validate_database_argparse(frontend_core.IN_MEMORY) == frontend_core.IN_MEMORY - - import argparse - - with pytest.raises(argparse.ArgumentTypeError): - frontend_core.validate_database_argparse("InvalidDB") - - def test_validate_log_level_argparse(self): - """Test validate_log_level_argparse wrapper.""" - assert frontend_core.validate_log_level_argparse("DEBUG") == logging.DEBUG - - import argparse - - with pytest.raises(argparse.ArgumentTypeError): - frontend_core.validate_log_level_argparse("INVALID") - - -class TestParseMemoryLabels: - """Tests for parse_memory_labels function.""" - - def test_parse_memory_labels_valid(self): - """Test parsing valid JSON labels.""" - json_str = '{"key1": "value1", "key2": "value2"}' - result = frontend_core.parse_memory_labels(json_string=json_str) - - assert result == {"key1": "value1", "key2": "value2"} - - def test_parse_memory_labels_empty(self): - """Test parsing empty JSON object.""" - result = frontend_core.parse_memory_labels(json_string="{}") - assert result == {} - - def test_parse_memory_labels_invalid_json(self): - """Test parsing invalid JSON raises ValueError.""" - with pytest.raises(ValueError, match="Invalid JSON"): - frontend_core.parse_memory_labels(json_string="not valid json") - - def test_parse_memory_labels_not_dict(self): - """Test parsing JSON array raises ValueError.""" - with pytest.raises(ValueError, match="must be a JSON object"): - frontend_core.parse_memory_labels(json_string='["array", "not", "dict"]') - - def test_parse_memory_labels_non_string_key(self): - """Test parsing with non-string values raises ValueError.""" - with pytest.raises(ValueError, match="All label keys and values must be strings"): - frontend_core.parse_memory_labels(json_string='{"key": 123}') - - -class TestResolveInitializationScripts: - """Tests for resolve_initialization_scripts function.""" - - @patch("pyrit.cli.frontend_core.InitializerRegistry.resolve_script_paths") - def test_resolve_initialization_scripts(self, mock_resolve: MagicMock): - """Test resolve_initialization_scripts calls InitializerRegistry.""" - mock_resolve.return_value = [Path("/test/script.py")] - - result = frontend_core.resolve_initialization_scripts(script_paths=["script.py"]) - - mock_resolve.assert_called_once_with(script_paths=["script.py"]) - assert result == [Path("/test/script.py")] - - -class TestListFunctions: - """Tests for list_scenarios_async and list_initializers_async functions.""" - - def test_discover_builtin_scenarios_uses_dotted_names(self): - """Built-in scenario names should be dotted (package.module) lowercase names.""" - from pyrit.registry.class_registries.scenario_registry import ScenarioRegistry - - registry = ScenarioRegistry() - registry._discover_builtin_scenarios() - - names = list(registry._class_entries.keys()) - assert len(names) > 0, "Should discover at least one built-in scenario" - for name in names: - assert "." in name, f"Scenario name '{name}' should be a dotted name (package.module)" - assert name == name.lower(), f"Scenario name '{name}' should be lowercase" - - def test_discover_builtin_scenarios_excludes_deprecated_aliases(self): - """Deprecated alias scenarios like ContentHarms must not appear in the registry.""" - from pyrit.registry.class_registries.scenario_registry import ScenarioRegistry - - registry = ScenarioRegistry() - registry._discover_builtin_scenarios() - - names = set(registry._class_entries.keys()) - class_names = {entry.registered_class.__name__ for entry in registry._class_entries.values()} - - assert "airt.content_harms" not in names, "Deprecated 'airt.content_harms' should not be registered" - assert "ContentHarms" not in class_names, "ContentHarms class should not appear under any registry name" - - async def test_list_scenarios(self): - """Test list_scenarios_async returns scenarios from registry.""" - mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [{"name": "test_scenario"}] - - context = frontend_core.FrontendCore() - context._scenario_registry = mock_registry - context._initialized = True - - result = await frontend_core.list_scenarios_async(context=context) - - assert result == [{"name": "test_scenario"}] - mock_registry.list_metadata.assert_called_once() - - async def test_list_initializers(self): - """Test list_initializers_async returns initializers from context registry.""" - mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [{"name": "test_init"}] - - context = frontend_core.FrontendCore() - context._initializer_registry = mock_registry - context._initialized = True - - result = await frontend_core.list_initializers_async(context=context) - - assert result == [{"name": "test_init"}] - mock_registry.list_metadata.assert_called_once() - - -class TestPrintFunctions: - """Tests for print functions.""" - - async def test_print_scenarios_list_with_scenarios(self, capsys): - """Test print_scenarios_list with scenarios.""" - context = frontend_core.FrontendCore() - mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [ - ScenarioMetadata( - class_name="TestScenario", - class_module="test.scenarios", - class_description="Test description", - registry_name="test", - default_strategy="default", - all_strategies=(), - aggregate_strategies=(), - default_datasets=(), - max_dataset_size=None, - ) - ] - context._scenario_registry = mock_registry - context._initialized = True - - result = await frontend_core.print_scenarios_list_async(context=context) - - assert result == 0 - captured = capsys.readouterr() - assert "Available Scenarios" in captured.out - assert "test" in captured.out - - async def test_print_scenarios_list_empty(self, capsys): - """Test print_scenarios_list with no scenarios.""" - context = frontend_core.FrontendCore() - mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [] - context._scenario_registry = mock_registry - context._initialized = True - - result = await frontend_core.print_scenarios_list_async(context=context) - - assert result == 0 - captured = capsys.readouterr() - assert "No scenarios found" in captured.out - - async def test_print_initializers_list_with_initializers(self, capsys): - """Test print_initializers_list_async with initializers.""" - context = frontend_core.FrontendCore() - mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [ - InitializerMetadata( - class_name="TestInit", - class_module="test.initializers", - class_description="Test initializer", - registry_name="test", - required_env_vars=(), - ) - ] - context._initializer_registry = mock_registry - context._initialized = True - - result = await frontend_core.print_initializers_list_async(context=context) - - assert result == 0 - captured = capsys.readouterr() - assert "Available Initializers" in captured.out - assert "test" in captured.out - - async def test_print_initializers_list_empty(self, capsys): - """Test print_initializers_list_async with no initializers.""" - context = frontend_core.FrontendCore() - mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [] - context._initializer_registry = mock_registry - context._initialized = True - - result = await frontend_core.print_initializers_list_async(context=context) - - assert result == 0 - captured = capsys.readouterr() - assert "No initializers found" in captured.out - - -class TestFormatFunctions: - """Tests for format_scenario_metadata and format_initializer_metadata.""" - - def test_format_scenario_metadata_basic(self, capsys): - """Test format_scenario_metadata with basic metadata.""" - - scenario_metadata = ScenarioMetadata( - class_name="TestScenario", - class_module="test.scenarios", - class_description="", - registry_name="test", - default_strategy="", - all_strategies=(), - aggregate_strategies=(), - default_datasets=(), - max_dataset_size=None, - ) - - frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) - - captured = capsys.readouterr() - assert "test" in captured.out - assert "TestScenario" in captured.out - - def test_format_scenario_metadata_with_description(self, capsys): - """Test format_scenario_metadata with description.""" - - scenario_metadata = ScenarioMetadata( - class_name="TestScenario", - class_module="test.scenarios", - class_description="This is a test scenario", - registry_name="test", - default_strategy="", - all_strategies=(), - aggregate_strategies=(), - default_datasets=(), - max_dataset_size=None, - ) - - frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) - - captured = capsys.readouterr() - assert "This is a test scenario" in captured.out - - def test_format_scenario_metadata_with_strategies(self, capsys): - """Test format_scenario_metadata with strategies.""" - scenario_metadata = ScenarioMetadata( - class_name="TestScenario", - class_module="test.scenarios", - class_description="", - registry_name="test", - default_strategy="strategy1", - all_strategies=("strategy1", "strategy2"), - aggregate_strategies=(), - default_datasets=(), - max_dataset_size=None, - ) - - frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) - - captured = capsys.readouterr() - assert "strategy1" in captured.out - assert "strategy2" in captured.out - assert "Default Strategy" in captured.out - - def test_format_scenario_metadata_with_supported_parameters(self, capsys): - """Test format_scenario_metadata renders the supported_parameters section.""" - scenario_metadata = ScenarioMetadata( - class_name="TestScenario", - class_module="test.scenarios", - class_description="", - registry_name="test", - default_strategy="", - all_strategies=(), - aggregate_strategies=(), - default_datasets=(), - max_dataset_size=None, - supported_parameters=( - ScenarioParameterMetadata("max_turns", "Conversation turn cap", 5, "int", None), - ScenarioParameterMetadata("mode", "Run mode", "fast", "str", "'fast', 'slow'"), - ScenarioParameterMetadata("optional_param", "Optional input", None, "str", None), - ), - ) - - frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) - - captured = capsys.readouterr() - assert "Supported Parameters:" in captured.out - assert "max_turns" in captured.out - assert "(int)" in captured.out - assert "[default: 5]" in captured.out - assert "Conversation turn cap" in captured.out - assert "[choices: 'fast', 'slow']" in captured.out - assert "optional_param" in captured.out - - def test_format_scenario_metadata_omits_section_when_no_parameters(self, capsys): - """A scenario without declared parameters should not print the Supported Parameters header.""" - scenario_metadata = ScenarioMetadata( - class_name="TestScenario", - class_module="test.scenarios", - class_description="", - registry_name="test", - default_strategy="", - all_strategies=(), - aggregate_strategies=(), - default_datasets=(), - max_dataset_size=None, - ) - - frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) - - captured = capsys.readouterr() - assert "Supported Parameters" not in captured.out - - def test_format_initializer_metadata_basic(self, capsys) -> None: - """Test format_initializer_metadata with basic metadata.""" - initializer_metadata = InitializerMetadata( - class_name="TestInit", - class_module="test.initializers", - class_description="", - registry_name="test", - required_env_vars=(), - ) - - frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) - - captured = capsys.readouterr() - assert "test" in captured.out - assert "TestInit" in captured.out - - def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: - """Test format_initializer_metadata with environment variables.""" - initializer_metadata = InitializerMetadata( - class_name="TestInit", - class_module="test.initializers", - class_description="", - registry_name="test", - required_env_vars=("VAR1", "VAR2"), - ) - - frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) - - captured = capsys.readouterr() - assert "VAR1" in captured.out - assert "VAR2" in captured.out - - def test_format_initializer_metadata_with_supported_parameters(self, capsys) -> None: - """Test format_initializer_metadata prints supported parameters.""" - initializer_metadata = InitializerMetadata( - class_name="TestInit", - class_module="test.initializers", - class_description="", - registry_name="test", - required_env_vars=(), - supported_parameters=( - ("model_name", "The model to use", None), - ("temperature", "Sampling temperature", ["0.7"]), - ), - ) - - frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) - - captured = capsys.readouterr() - assert "Supported Parameters:" in captured.out - assert "model_name" in captured.out - assert "temperature" in captured.out - assert "[default: ['0.7']]" in captured.out - - def test_format_initializer_metadata_with_description(self, capsys) -> None: - """Test format_initializer_metadata with description.""" - initializer_metadata = InitializerMetadata( - class_name="TestInit", - class_module="test.initializers", - class_description="Test description", - registry_name="test", - required_env_vars=(), - ) - - frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) - - captured = capsys.readouterr() - assert "Test description" in captured.out - - -class TestParseInitializerArg: - """Tests for _parse_initializer_arg function.""" - - def test_simple_name_returns_string(self) -> None: - """Test that a plain name without ':' returns the string as-is.""" - assert frontend_core._parse_initializer_arg("simple") == "simple" - - def test_name_with_single_param(self) -> None: - """Test name:key=value parsing.""" - result = frontend_core._parse_initializer_arg("target:tags=default") - assert result == {"name": "target", "args": {"tags": ["default"]}} - - def test_name_with_comma_separated_values(self) -> None: - """Test that comma-separated values are split into a list.""" - result = frontend_core._parse_initializer_arg("target:tags=default,scorer") - assert result == {"name": "target", "args": {"tags": ["default", "scorer"]}} - - def test_name_with_multiple_params(self) -> None: - """Test semicolon-separated multiple params.""" - result = frontend_core._parse_initializer_arg("target:tags=default;mode=strict") - assert result == {"name": "target", "args": {"tags": ["default"], "mode": ["strict"]}} - - def test_missing_name_before_colon_raises(self) -> None: - """Test that ':key=val' with no name raises ValueError.""" - with pytest.raises(ValueError, match="missing name before ':'"): - frontend_core._parse_initializer_arg(":tags=default") - - def test_missing_equals_in_param_raises(self) -> None: - """Test that 'name:badparam' without '=' raises ValueError.""" - with pytest.raises(ValueError, match="expected key=value format"): - frontend_core._parse_initializer_arg("target:badparam") - - def test_empty_key_raises(self) -> None: - """Test that 'name:=value' with empty key raises ValueError.""" - with pytest.raises(ValueError, match="empty key"): - frontend_core._parse_initializer_arg("target:=value") - - def test_colon_but_no_params_returns_string(self) -> None: - """Test that 'name:' with trailing colon but no params returns the name string.""" - result = frontend_core._parse_initializer_arg("target:") - assert result == "target" - - -class TestParseShellArguments: - """Tests for the generic _parse_shell_arguments function.""" - - def test_empty_parts_returns_none_defaults(self): - """Test that empty input returns None for all result keys.""" - spec = _ArgSpec(flags=["--foo"], result_key="foo") - result = _parse_shell_arguments(parts=[], arg_specs=[spec]) - assert result == {"foo": None} - - def test_single_value_arg(self): - """Test parsing a single-value argument.""" - spec = _ArgSpec(flags=["--name"], result_key="name") - result = _parse_shell_arguments(parts=["--name", "alice"], arg_specs=[spec]) - assert result["name"] == "alice" - - def test_single_value_with_parser(self): - """Test that single-value parser is applied.""" - spec = _ArgSpec(flags=["--count"], result_key="count", parser=int) - result = _parse_shell_arguments(parts=["--count", "42"], arg_specs=[spec]) - assert result["count"] == 42 - - def test_single_value_missing_raises(self): - """Test that missing value for single-value arg raises ValueError.""" - spec = _ArgSpec(flags=["--name"], result_key="name") - with pytest.raises(ValueError, match="--name requires a value"): - _parse_shell_arguments(parts=["--name"], arg_specs=[spec]) - - def test_multi_value_arg(self): - """Test collecting multiple values until next flag.""" - spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) - result = _parse_shell_arguments(parts=["--items", "a", "b", "c"], arg_specs=[spec]) - assert result["items"] == ["a", "b", "c"] - - def test_multi_value_stops_at_next_flag(self): - """Test that multi-value collection stops at the next known flag.""" - items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) - name_spec = _ArgSpec(flags=["--name"], result_key="name") - result = _parse_shell_arguments( - parts=["--items", "a", "b", "--name", "alice"], - arg_specs=[items_spec, name_spec], - ) - assert result["items"] == ["a", "b"] - assert result["name"] == "alice" - - def test_multi_value_stops_at_short_flag_alias(self): - """Test that multi-value collection stops at a short flag alias like -s.""" - long_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) - short_spec = _ArgSpec(flags=["-s", "--short"], result_key="short", multi_value=True) - result = _parse_shell_arguments( - parts=["--items", "a", "b", "-s", "x"], - arg_specs=[long_spec, short_spec], - ) - assert result["items"] == ["a", "b"] - assert result["short"] == ["x"] - - def test_multi_value_with_parser(self): - """Test that parser transforms each collected value.""" - spec = _ArgSpec(flags=["--nums"], result_key="nums", multi_value=True, parser=int) - result = _parse_shell_arguments(parts=["--nums", "1", "2", "3"], arg_specs=[spec]) - assert result["nums"] == [1, 2, 3] - - def test_multi_value_no_values_raises(self): - """Test that multi-value arg with no values raises ValueError.""" - items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) - name_spec = _ArgSpec(flags=["--name"], result_key="name") - with pytest.raises(ValueError, match="--items requires at least one value"): - _parse_shell_arguments( - parts=["--items", "--name", "alice"], - arg_specs=[items_spec, name_spec], - ) - - def test_unknown_flag_raises(self): - """Test that an unknown flag raises ValueError.""" - spec = _ArgSpec(flags=["--known"], result_key="known") - with pytest.raises(ValueError, match="Unknown argument: --unknown"): - _parse_shell_arguments(parts=["--unknown"], arg_specs=[spec]) - - def test_multiple_specs_all_none_when_unused(self): - """Test that unused specs default to None.""" - specs = [ - _ArgSpec(flags=["--a"], result_key="a"), - _ArgSpec(flags=["--b"], result_key="b", multi_value=True), - ] - result = _parse_shell_arguments(parts=[], arg_specs=specs) - assert result == {"a": None, "b": None} - - -class TestParseRunArguments: - """Tests for parse_run_arguments function.""" - - def test_parse_run_arguments_basic(self): - """Test parsing basic scenario name.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario") - - assert result["scenario_name"] == "test_scenario" - assert result["initializers"] is None - assert result["scenario_strategies"] is None - - def test_parse_run_arguments_with_initializers(self): - """Test parsing with initializers.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --initializers init1 init2") - - assert result["scenario_name"] == "test_scenario" - assert result["initializers"] == ["init1", "init2"] - - def test_parse_run_arguments_with_initializer_params(self): - """Test parsing initializers with key=value params.""" - result = frontend_core.parse_run_arguments( - args_string="test_scenario --initializers simple target:tags=default" - ) - - assert result["initializers"][0] == "simple" - assert result["initializers"][1] == {"name": "target", "args": {"tags": ["default"]}} - - def test_parse_run_arguments_with_initializer_multiple_params(self): - """Test parsing initializers with multiple key=value params separated by semicolons.""" - result = frontend_core.parse_run_arguments( - args_string="test_scenario --initializers target:tags=default;mode=strict" - ) - - assert result["initializers"][0] == {"name": "target", "args": {"tags": ["default"], "mode": ["strict"]}} - - def test_parse_run_arguments_with_initializer_comma_list(self): - """Test parsing initializer params with comma-separated values into lists.""" - result = frontend_core.parse_run_arguments( - args_string="test_scenario --initializers target:tags=default,scorer" - ) - - assert result["initializers"][0] == {"name": "target", "args": {"tags": ["default", "scorer"]}} - - def test_parse_run_arguments_with_strategies(self): - """Test parsing with strategies.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --strategies s1 s2") - - assert result["scenario_strategies"] == ["s1", "s2"] - - def test_parse_run_arguments_with_short_strategies(self): - """Test parsing with -s flag.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario -s s1 s2") - - assert result["scenario_strategies"] == ["s1", "s2"] - - def test_parse_run_arguments_with_max_concurrency(self): - """Test parsing with max-concurrency.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency 5") - - assert result["max_concurrency"] == 5 - - def test_parse_run_arguments_with_max_retries(self): - """Test parsing with max-retries.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --max-retries 3") - - assert result["max_retries"] == 3 - - def test_parse_run_arguments_with_memory_labels(self): - """Test parsing with memory-labels (JSON must be quoted in shell mode).""" - result = frontend_core.parse_run_arguments(args_string="""test_scenario --memory-labels '{"key":"value"}'""") - - assert result["memory_labels"] == {"key": "value"} - - def test_parse_run_arguments_with_log_level(self): - """Test parsing with log-level override.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --log-level DEBUG") - - assert result["log_level"] == logging.DEBUG - - def test_parse_run_arguments_with_initialization_scripts(self): - """Test parsing with initialization-scripts.""" - result = frontend_core.parse_run_arguments( - args_string="test_scenario --initialization-scripts script1.py script2.py" - ) - - assert result["initialization_scripts"] == ["script1.py", "script2.py"] - - def test_parse_run_arguments_with_quoted_paths(self): - """Test parsing quoted paths with spaces for shell mode.""" - result = frontend_core.parse_run_arguments( - args_string='test_scenario --initialization-scripts "/tmp/my script.py" --strategies s1' - ) - - assert result["initialization_scripts"] == ["/tmp/my script.py"] - assert result["scenario_strategies"] == ["s1"] - - def test_parse_run_arguments_with_quoted_memory_labels(self): - """Test parsing quoted JSON for memory-labels in shell mode.""" - result = frontend_core.parse_run_arguments( - args_string="""test_scenario --memory-labels '{"experiment": "test 1"}'""" - ) - - assert result["memory_labels"] == {"experiment": "test 1"} - - def test_parse_run_arguments_with_short_strategies_after_initializers(self): - """Test that -s is treated as a flag after multi-value initializers.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --initializers init1 -s s1 s2") - - assert result["initializers"] == ["init1"] - assert result["scenario_strategies"] == ["s1", "s2"] - - def test_parse_run_arguments_unterminated_quote_raises(self): - """Test that unterminated quotes raise ValueError.""" - with pytest.raises(ValueError): - frontend_core.parse_run_arguments(args_string='test_scenario --initialization-scripts "/tmp/my script.py') - - def test_parse_run_arguments_complex(self): - """Test parsing complex argument combination.""" - args = "test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10" - result = frontend_core.parse_run_arguments(args_string=args) - - assert result["scenario_name"] == "test_scenario" - assert result["initializers"] == ["init1"] - assert result["scenario_strategies"] == ["s1", "s2"] - assert result["max_concurrency"] == 10 - - def test_parse_run_arguments_empty_raises(self): - """Test parsing empty string raises ValueError.""" - with pytest.raises(ValueError, match="No scenario name provided"): - frontend_core.parse_run_arguments(args_string="") - - def test_parse_run_arguments_invalid_max_concurrency(self): - """Test parsing with invalid max-concurrency.""" - with pytest.raises(ValueError): - frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency 0") - - def test_parse_run_arguments_invalid_max_retries(self): - """Test parsing with invalid max-retries.""" - with pytest.raises(ValueError): - frontend_core.parse_run_arguments(args_string="test_scenario --max-retries -1") - - def test_parse_run_arguments_missing_value(self): - """Test parsing with missing argument value.""" - with pytest.raises(ValueError, match="requires a value"): - frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") - - -class TestParseListTargetsArguments: - """Tests for parse_list_targets_arguments function.""" - - def test_parse_list_targets_arguments_empty(self): - """Test parsing empty string returns defaults.""" - result = frontend_core.parse_list_targets_arguments(args_string="") - assert result["initializers"] is None - assert result["initialization_scripts"] is None - - def test_parse_list_targets_arguments_with_initializers(self): - """Test parsing with initializers.""" - result = frontend_core.parse_list_targets_arguments(args_string="--initializers target init2") - assert result["initializers"] == ["target", "init2"] - - def test_parse_list_targets_arguments_with_initializer_params(self): - """Test parsing initializers with key=value params.""" - result = frontend_core.parse_list_targets_arguments(args_string="--initializers target:tags=default,scorer") - assert result["initializers"] == [{"name": "target", "args": {"tags": ["default", "scorer"]}}] - - def test_parse_list_targets_arguments_with_initialization_scripts(self): - """Test parsing with initialization-scripts.""" - result = frontend_core.parse_list_targets_arguments( - args_string="--initialization-scripts script1.py script2.py" - ) - assert result["initialization_scripts"] == ["script1.py", "script2.py"] - - def test_parse_list_targets_arguments_with_both(self): - """Test parsing with both initializers and scripts.""" - result = frontend_core.parse_list_targets_arguments( - args_string="--initializers target --initialization-scripts script1.py" - ) - assert result["initializers"] == ["target"] - assert result["initialization_scripts"] == ["script1.py"] - - def test_parse_list_targets_arguments_unknown_arg_raises(self): - """Test parsing with unknown argument raises ValueError.""" - with pytest.raises(ValueError, match="Unknown argument"): - frontend_core.parse_list_targets_arguments(args_string="--unknown-flag") - - -@pytest.mark.usefixtures("patch_central_database") -class TestRunScenarioAsync: - """Tests for run_scenario_async function.""" - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_basic( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - ): - """Test running a basic scenario.""" - # Mock context - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - # Run scenario - result = await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - ) - - assert result == mock_result - # Verify scenario was instantiated with no arguments (runtime params go to initialize_async) - mock_scenario_class.assert_called_once_with() - mock_scenario_instance.initialize_async.assert_called_once_with() - mock_scenario_instance.run_async.assert_called_once() - mock_printer.print_summary_async.assert_called_once_with(mock_result) - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_run_scenario_async_not_found(self, mock_init: AsyncMock): - """Test running non-existent scenario raises ValueError.""" - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_registry.get_class.return_value = None - mock_scenario_registry.get_names.return_value = ["other_scenario"] - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - with pytest.raises(ValueError, match="Scenario 'test_scenario' not found"): - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - ) - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_with_strategies( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - ): - """Test running scenario with strategies.""" - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - - # Mock strategy enum - from enum import Enum - - class MockStrategy(Enum): - strategy1 = "strategy1" - - mock_scenario_class.get_strategy_class.return_value = MockStrategy - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - # Run with strategies - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - scenario_strategies=["strategy1"], - ) - - # Verify scenario was instantiated with no arguments - mock_scenario_class.assert_called_once_with() - # Verify strategy was passed to initialize_async - call_kwargs = mock_scenario_instance.initialize_async.call_args[1] - assert "scenario_strategies" in call_kwargs - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_with_initializers( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - ): - """Test running scenario with initializers.""" - context = frontend_core.FrontendCore(initializer_names=["test_init"]) - mock_scenario_registry = MagicMock() - mock_initializer_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - - mock_initializer_class = MagicMock() - mock_initializer_registry.get_class.return_value = mock_initializer_class - - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = mock_initializer_registry - context._initialized = True - - # Run with initializers - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - ) - - # Verify initializer was retrieved - mock_initializer_registry.get_class.assert_called_once_with("test_init") - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_with_max_concurrency( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - ): - """Test running scenario with max_concurrency.""" - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - # Run with max_concurrency - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - max_concurrency=5, - ) - - # Verify scenario was instantiated with no arguments - mock_scenario_class.assert_called_once_with() - # Verify max_concurrency was passed to initialize_async - call_kwargs = mock_scenario_instance.initialize_async.call_args[1] - assert call_kwargs["max_concurrency"] == 5 - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_without_print_summary( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - ): - """Test running scenario without printing summary.""" - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - # Run without printing - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - print_summary=False, - ) - - # Verify printer was not called - assert mock_printer.print_summary_async.call_count == 0 - - -class TestArgHelp: - """Tests for frontend_core.ARG_HELP dictionary.""" - - def test_arg_help_contains_all_keys(self): - """Test frontend_core.ARG_HELP contains expected keys.""" - expected_keys = [ - "initializers", - "initialization_scripts", - "scenario_strategies", - "max_concurrency", - "max_retries", - "memory_labels", - "database", - "log_level", - "target", - ] - - for key in expected_keys: - assert key in frontend_core.ARG_HELP - assert isinstance(frontend_core.ARG_HELP[key], str) - assert len(frontend_core.ARG_HELP[key]) > 0 - - -class TestParseRunArgumentsTarget: - """Tests for --target parsing in parse_run_arguments.""" - - def test_parse_run_arguments_with_target(self): - """Test parsing with --target.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario --target my_target") - - assert result["target"] == "my_target" - - def test_parse_run_arguments_target_with_other_args(self): - """Test parsing --target alongside other arguments.""" - result = frontend_core.parse_run_arguments( - args_string="test_scenario --target my_target --initializers init1 --max-concurrency 5" - ) - - -class TestWithOverrides: - """Tests for FrontendCore.with_overrides method.""" - - def _make_initialized_parent(self) -> frontend_core.FrontendCore: - """Create a fully-initialized FrontendCore for testing with_overrides.""" - parent = frontend_core.FrontendCore( - database=frontend_core.IN_MEMORY, - initializer_names=["parent_init"], - log_level=logging.WARNING, - ) - parent._scenario_registry = MagicMock() - parent._initializer_registry = MagicMock() - parent._initialized = True - parent._silent_reinit = True - return parent - - def test_with_overrides_inherits_fields(self): - """Test that derived context inherits database, env_files, operator, operation.""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides() - - assert derived._database == parent._database - assert derived._env_files == parent._env_files - assert derived._operator == parent._operator - assert derived._operation == parent._operation - - def test_with_overrides_shares_registries(self): - """Test that derived context shares scenario and initializer registries.""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides() - - assert derived._scenario_registry is parent._scenario_registry - assert derived._initializer_registry is parent._initializer_registry - - def test_with_overrides_sets_initialized_and_silent(self): - """Test that derived context is marked initialized with silent reinit.""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides() - - assert derived._initialized is True - assert derived._silent_reinit is True - - def test_with_overrides_none_keeps_parent_values(self): - """Test that passing None for all overrides keeps parent's values.""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides( - initializer_names=None, - initialization_scripts=None, - log_level=None, - ) - - assert derived._initializer_configs == parent._initializer_configs - assert derived._initialization_scripts == parent._initialization_scripts - assert derived._log_level == parent._log_level - - def test_with_overrides_initializer_names(self): - """Test that initializer_names override normalizes to InitializerConfig objects.""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides(initializer_names=["target", "dataset"]) - - assert derived._initializer_configs is not None - names = [ic.name for ic in derived._initializer_configs] - assert names == ["target", "dataset"] - # Parent should still have original - assert [ic.name for ic in parent._initializer_configs] == ["parent_init"] - - def test_with_overrides_initializer_names_dict(self): - """Test initializer_names with dict entries (name + args).""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides(initializer_names=[{"name": "target", "args": {"tags": "default"}}]) - - assert derived._initializer_configs is not None - assert len(derived._initializer_configs) == 1 - assert derived._initializer_configs[0].name == "target" - assert derived._initializer_configs[0].args == {"tags": "default"} - - def test_with_overrides_initialization_scripts(self): - """Test that initialization_scripts override replaces parent's scripts.""" - parent = self._make_initialized_parent() - new_scripts = [Path("/new/script.py")] - - derived = parent.with_overrides(initialization_scripts=new_scripts) - - assert derived._initialization_scripts == new_scripts - # Parent should be unchanged - assert parent._initialization_scripts != new_scripts - - def test_with_overrides_log_level(self): - """Test that log_level override replaces parent's log level.""" - parent = self._make_initialized_parent() - - derived = parent.with_overrides(log_level=logging.DEBUG) - - assert derived._log_level == logging.DEBUG - assert parent._log_level == logging.WARNING - - def test_with_overrides_does_not_mutate_parent(self): - """Test that with_overrides does not modify the parent context.""" - parent = self._make_initialized_parent() - original_configs = parent._initializer_configs - original_log_level = parent._log_level - original_scripts = parent._initialization_scripts - - parent.with_overrides( - initializer_names=["new_init"], - initialization_scripts=[Path("/new.py")], - log_level=logging.DEBUG, - ) - - assert parent._initializer_configs is original_configs - assert parent._log_level == original_log_level - assert parent._initialization_scripts is original_scripts - - def test_parse_run_arguments_target_missing_value(self): - """Test parsing --target without a value raises ValueError.""" - with pytest.raises(ValueError, match="--target requires a value"): - frontend_core.parse_run_arguments(args_string="test_scenario --target") - - def test_parse_run_arguments_no_target(self): - """Test parsing without --target returns None.""" - result = frontend_core.parse_run_arguments(args_string="test_scenario") - - assert result["target"] is None - - -@pytest.mark.usefixtures("patch_central_database") -class TestRunScenarioAsyncTarget: - """Tests for target resolution in run_scenario_async.""" - - @patch("pyrit.cli.frontend_core.TargetRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_with_valid_target( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - mock_target_registry_class: MagicMock, - ): - """Test running scenario with a valid target name resolves from registry.""" - # Setup mocks - mock_target = MagicMock() - mock_registry = MagicMock() - mock_registry.get_instance_by_name.return_value = mock_target - mock_target_registry_class.get_registry_singleton.return_value = mock_registry - - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - result = await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - target_name="my_target", - ) - - assert result == mock_result - mock_registry.get_instance_by_name.assert_called_once_with("my_target") - # Verify objective_target was passed to initialize_async - call_kwargs = mock_scenario_instance.initialize_async.call_args[1] - assert call_kwargs["objective_target"] is mock_target - - @patch("pyrit.cli.frontend_core.TargetRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_run_scenario_async_with_invalid_target( - self, - mock_init: AsyncMock, - mock_target_registry_class: MagicMock, - ): - """Test running scenario with an invalid target name raises ValueError.""" - mock_registry = MagicMock() - mock_registry.get_instance_by_name.return_value = None - mock_registry.get_names.return_value = ["target_a", "target_b"] - mock_target_registry_class.get_registry_singleton.return_value = mock_registry - - context = frontend_core.FrontendCore() - context._scenario_registry = MagicMock() - context._initializer_registry = MagicMock() - context._initialized = True - - with pytest.raises(ValueError, match="Target 'bad_target' not found in registry"): - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - target_name="bad_target", - ) - - @patch("pyrit.cli.frontend_core.TargetRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_run_scenario_async_with_empty_target_registry( - self, - mock_init: AsyncMock, - mock_target_registry_class: MagicMock, - ): - """Test running scenario with target name when registry is empty gives helpful error.""" - mock_registry = MagicMock() - mock_registry.get_instance_by_name.return_value = None - mock_registry.get_names.return_value = [] - mock_target_registry_class.get_registry_singleton.return_value = mock_registry - - context = frontend_core.FrontendCore() - context._scenario_registry = MagicMock() - context._initializer_registry = MagicMock() - context._initialized = True - - with pytest.raises(ValueError, match="target registry is empty"): - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - target_name="my_target", - ) - - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") - async def test_run_scenario_async_without_target( - self, - mock_printer_class: MagicMock, - mock_init: AsyncMock, - ): - """Test running scenario without target_name does not add objective_target to kwargs.""" - context = frontend_core.FrontendCore() - mock_scenario_registry = MagicMock() - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_result = MagicMock() - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_class.return_value = mock_scenario_class - mock_printer_class.return_value = mock_printer - - context._scenario_registry = mock_scenario_registry - context._initializer_registry = MagicMock() - context._initialized = True - - await frontend_core.run_scenario_async( - scenario_name="test_scenario", - context=context, - ) - - # Verify no objective_target was passed - call_kwargs = mock_scenario_instance.initialize_async.call_args[1] - assert "objective_target" not in call_kwargs - - -@pytest.mark.usefixtures("patch_central_database") -class TestPrintTargetsList: - """Tests for print_targets_list_async function.""" - - @patch("pyrit.cli.frontend_core.TargetRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_print_targets_list_with_targets( - self, - mock_init: AsyncMock, - mock_target_registry_class: MagicMock, - capsys, - ): - """Test print_targets_list_async displays target names.""" - mock_registry = MagicMock() - mock_registry.get_names.return_value = ["target_a", "target_b"] - mock_target_registry_class.get_registry_singleton.return_value = mock_registry - - context = frontend_core.FrontendCore() - context._scenario_registry = MagicMock() - context._initializer_registry = MagicMock() - context._initialized = True - - result = await frontend_core.print_targets_list_async(context=context) - - assert result == 0 - captured = capsys.readouterr() - assert "target_a" in captured.out - assert "target_b" in captured.out - assert "Total targets: 2" in captured.out - - @patch("pyrit.cli.frontend_core.TargetRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_print_targets_list_empty( - self, - mock_init: AsyncMock, - mock_target_registry_class: MagicMock, - capsys, - ): - """Test print_targets_list_async with no targets gives helpful hint.""" - mock_registry = MagicMock() - mock_registry.get_names.return_value = [] - mock_target_registry_class.get_registry_singleton.return_value = mock_registry - - context = frontend_core.FrontendCore() - context._scenario_registry = MagicMock() - context._initializer_registry = MagicMock() - context._initialized = True - - result = await frontend_core.print_targets_list_async(context=context) - - assert result == 0 - captured = capsys.readouterr() - assert "No targets found" in captured.out - assert "--initializers target" in captured.out - - @patch("pyrit.cli.frontend_core.TargetRegistry") - @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) - async def test_list_targets_with_initialization_scripts_calls_initialize( - self, - mock_init: AsyncMock, - mock_target_registry_class: MagicMock, - ): - """Test list_targets_async calls initialize_pyrit_async when only scripts are configured.""" - mock_registry = MagicMock() - mock_registry.get_names.return_value = ["script_target"] - mock_target_registry_class.get_registry_singleton.return_value = mock_registry - - context = frontend_core.FrontendCore() - context._scenario_registry = MagicMock() - context._initializer_registry = MagicMock() - context._initialized = True - context._initialization_scripts = ["/path/to/script.py"] - context._initializer_configs = None - - result = await frontend_core.list_targets_async(context=context) - - assert result == ["script_target"] - # Verify initialize_pyrit_async was called with the scripts - mock_init.assert_called_once() - call_kwargs = mock_init.call_args[1] - assert call_kwargs["initialization_scripts"] == ["/path/to/script.py"] - assert call_kwargs["initializers"] is None - - -class TestParseRunArgumentsScenarioParams: - """Tests for declared-parameter augmentation in parse_run_arguments.""" - - def test_parse_with_no_declared_params_unchanged(self): - """Existing behavior: declared_params=None leaves built-in parsing intact.""" - result = frontend_core.parse_run_arguments(args_string="my_scenario --max-concurrency 5") - - assert result["scenario_name"] == "my_scenario" - assert result["max_concurrency"] == 5 - - def test_int_param_coerced(self): - from pyrit.common import Parameter - - result = frontend_core.parse_run_arguments( - args_string="my_scenario --max-turns 10", - declared_params=[Parameter(name="max_turns", description="d", param_type=int, default=5)], - ) - assert result["scenario__max_turns"] == 10 - - def test_bool_param_uses_safe_coercion(self): - from pyrit.common import Parameter - - result = frontend_core.parse_run_arguments( - args_string="my_scenario --enabled false", - declared_params=[Parameter(name="enabled", description="d", param_type=bool)], - ) - assert result["scenario__enabled"] is False - - def test_list_param_collects_multiple_values(self): - from pyrit.common import Parameter - - result = frontend_core.parse_run_arguments( - args_string="my_scenario --datasets a b c", - declared_params=[Parameter(name="datasets", description="d", param_type=list[str])], - ) - assert result["scenario__datasets"] == ["a", "b", "c"] - - def test_unset_scenario_flag_is_none(self): - """Shell parser initializes absent flags to None; extract_scenario_args drops them.""" - from pyrit.common import Parameter - - result = frontend_core.parse_run_arguments( - args_string="my_scenario", - declared_params=[Parameter(name="max_turns", description="d", param_type=int, default=5)], - ) - assert result["scenario__max_turns"] is None - - def test_collision_with_built_in_flag_raises(self): - from pyrit.common import Parameter - - with pytest.raises(ValueError, match="collides with a built-in flag"): - frontend_core.parse_run_arguments( - args_string="my_scenario --max-concurrency 5", - declared_params=[Parameter(name="max_concurrency", description="d", param_type=int)], - ) - - def test_scenario_vs_scenario_collision_raises(self): - """Two declared params normalizing to the same flag fail at parser-build time.""" - from pyrit.common import Parameter - - # foo_bar and foo-bar both normalize to --foo-bar - with pytest.raises(ValueError, match="normalize to the same CLI flag"): - frontend_core.parse_run_arguments( - args_string="my_scenario --foo-bar 1", - declared_params=[ - Parameter(name="foo_bar", description="d", param_type=int), - Parameter(name="foo-bar", description="d", param_type=int), - ], - ) - - -class TestExtractScenarioArgs: - """Tests for extract_scenario_args helper.""" - - def test_no_scenario_keys_returns_empty(self): - result = frontend_core.extract_scenario_args(parsed={"scenario_name": "x", "max_concurrency": 5}) - assert result == {} - - def test_scenario_keys_extracted_with_prefix_stripped(self): - result = frontend_core.extract_scenario_args( - parsed={"scenario_name": "x", "scenario__max_turns": 10, "scenario__mode": "fast"} - ) - assert result == {"max_turns": 10, "mode": "fast"} - - def test_none_values_dropped(self): - """Absent shell flags (initialized to None) must not reach set_params_from_args.""" - result = frontend_core.extract_scenario_args(parsed={"scenario__max_turns": None, "scenario__mode": "fast"}) - assert result == {"mode": "fast"} - - -class TestParamTypeDisplay: - """Tests for the registry's _param_type_display helper.""" - - def test_none_renders_as_any(self): - from pyrit.registry.class_registries.scenario_registry import _param_type_display - - assert _param_type_display(None) == "any" - - def test_builtin_types(self): - from pyrit.registry.class_registries.scenario_registry import _param_type_display - - assert _param_type_display(int) == "int" - assert _param_type_display(str) == "str" - assert _param_type_display(bool) == "bool" - - def test_parameterized_generic_uses_repr(self): - """list[str] has no __name__; falls back to repr.""" - from pyrit.registry.class_registries.scenario_registry import _param_type_display - - assert _param_type_display(list[str]) == "list[str]" diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py new file mode 100644 index 0000000000..1869f67e73 --- /dev/null +++ b/tests/unit/cli/test_output.py @@ -0,0 +1,439 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for pyrit.cli._output formatting helpers. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +from pyrit.cli import _output + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def test_cprint_no_color_falls_back_to_print(capsys): + with patch.object(_output, "_HAS_COLOR", False): + _output._cprint("plain text", color="red", bold=True) + captured = capsys.readouterr() + assert "plain text" in captured.out + + +def test_cprint_uses_termcolor_when_available(capsys): + fake_termcolor = MagicMock() + with ( + patch.object(_output, "_HAS_COLOR", True), + patch.object(_output, "termcolor", fake_termcolor, create=True), + ): + _output._cprint("hello", color="cyan", bold=True) + fake_termcolor.cprint.assert_called_once_with("hello", "cyan", attrs=["bold"]) + + +def test_cprint_without_color_arg(capsys): + with patch.object(_output, "_HAS_COLOR", True): + _output._cprint("plain text") + captured = capsys.readouterr() + assert "plain text" in captured.out + + +def test_header_prints_with_cyan(capsys): + _output._header("Section Title") + captured = capsys.readouterr() + assert "Section Title" in captured.out + + +def test_wrap_short_text_single_line(): + result = _output._wrap(text="short text", indent=" ") + assert result == " short text" + + +def test_wrap_long_text_breaks_into_multiple_lines(): + text = "word " * 40 + result = _output._wrap(text=text.strip(), indent=" ", width=40) + assert "\n" in result + for line in result.split("\n"): + assert line.startswith(" ") + + +def test_wrap_empty_text_returns_empty_string(): + assert _output._wrap(text="", indent=" ") == "" + + +# --------------------------------------------------------------------------- +# print_scenario_list +# --------------------------------------------------------------------------- + + +def test_print_scenario_list_empty(capsys): + _output.print_scenario_list(items=[]) + captured = capsys.readouterr() + assert "No scenarios found." in captured.out + + +def test_print_scenario_list_full(capsys): + items = [ + { + "scenario_name": "airt.scam", + "scenario_type": "ScamScenario", + "description": "A test scenario.", + "aggregate_strategies": ["single_turn"], + "all_strategies": ["s1", "s2", "s3"], + "default_strategy": "s1", + "default_datasets": ["d1", "d2"], + "max_dataset_size": 50, + "supported_parameters": [ + { + "name": "max_turns", + "default": 5, + "param_type": "int", + "choices": None, + "description": "Maximum turns.", + }, + { + "name": "mode", + "default": None, + "param_type": "str", + "choices": ["a", "b"], + "description": "Mode.", + }, + ], + } + ] + _output.print_scenario_list(items=items) + captured = capsys.readouterr() + assert "airt.scam" in captured.out + assert "ScamScenario" in captured.out + assert "A test scenario." in captured.out + assert "Aggregate Strategies" in captured.out + assert "single_turn" in captured.out + assert "Available Strategies (3)" in captured.out + assert "Default Strategy: s1" in captured.out + assert "Default Datasets (2, max 50 per dataset)" in captured.out + assert "Supported Parameters" in captured.out + assert "max_turns" in captured.out + assert "mode" in captured.out + assert "Total scenarios: 1" in captured.out + + +def test_print_scenario_list_minimal_fields(capsys): + items = [{"scenario_name": "min", "scenario_type": "MinScenario"}] + _output.print_scenario_list(items=items) + captured = capsys.readouterr() + assert "min" in captured.out + assert "MinScenario" in captured.out + + +def test_print_scenario_list_no_max_dataset_size(capsys): + items = [ + { + "scenario_name": "no_max", + "scenario_type": "T", + "default_datasets": ["d1"], + } + ] + _output.print_scenario_list(items=items) + captured = capsys.readouterr() + assert "Default Datasets (1)" in captured.out + assert "max" not in captured.out.split("Default Datasets")[1].split("\n")[0] + + +# --------------------------------------------------------------------------- +# print_initializer_list +# --------------------------------------------------------------------------- + + +def test_print_initializer_list_empty(capsys): + _output.print_initializer_list(items=[]) + captured = capsys.readouterr() + assert "No initializers found." in captured.out + + +def test_print_initializer_list_full(capsys): + items = [ + { + "initializer_name": "openai_target", + "initializer_type": "OpenAITargetInitializer", + "required_env_vars": ["OPENAI_API_KEY", "OPENAI_ENDPOINT"], + "supported_parameters": [ + {"name": "model", "default": "gpt-4", "description": "Model name."}, + {"name": "temp", "default": None, "description": "Temperature."}, + ], + "description": "Registers OpenAI targets.", + }, + { + "initializer_name": "no_env", + "initializer_type": "NoEnvInitializer", + "required_env_vars": [], + }, + ] + _output.print_initializer_list(items=items) + captured = capsys.readouterr() + assert "openai_target" in captured.out + assert "OPENAI_API_KEY" in captured.out + assert "OPENAI_ENDPOINT" in captured.out + assert "Required Environment Variables: None" in captured.out + assert "model" in captured.out + assert "Registers OpenAI targets." in captured.out + assert "Total initializers: 2" in captured.out + + +# --------------------------------------------------------------------------- +# print_target_list +# --------------------------------------------------------------------------- + + +def test_print_target_list_empty(capsys): + _output.print_target_list(items=[]) + captured = capsys.readouterr() + assert "No targets found in registry" in captured.out + assert "--initializers target" in captured.out + + +def test_print_target_list_full(capsys): + items = [ + { + "target_registry_name": "openai_chat", + "target_type": "OpenAIChatTarget", + "underlying_model_name": "gpt-4", + "endpoint": "https://example.com", + }, + { + "target_registry_name": "claude", + "target_type": "AnthropicTarget", + "model_name": "claude-sonnet", + }, + { + "target_registry_name": "minimal", + "target_type": "MinimalTarget", + }, + ] + _output.print_target_list(items=items) + captured = capsys.readouterr() + assert "openai_chat" in captured.out + assert "Model: gpt-4" in captured.out + assert "Endpoint: https://example.com" in captured.out + assert "Model: claude-sonnet" in captured.out + assert "minimal" in captured.out + assert "Total targets: 3" in captured.out + + +# --------------------------------------------------------------------------- +# print_scenario_run_progress +# --------------------------------------------------------------------------- + + +def test_print_scenario_run_progress_with_known_totals(capsys): + run = { + "status": "RUNNING", + "total_attacks": 10, + "completed_attacks": 5, + "objective_achieved_rate": 30, + "strategies_used": ["s1", "s2"], + } + _output.print_scenario_run_progress(run=run, total_strategies=4) + captured = capsys.readouterr() + assert "strategies: 2/4" in captured.out + assert "5/10" in captured.out + assert "RUNNING" in captured.out + assert "30%" in captured.out + + +def test_print_scenario_run_progress_no_total_attacks(capsys): + run = { + "status": "PENDING", + "total_attacks": 0, + "completed_attacks": 0, + "objective_achieved_rate": 0, + "strategies_used": [], + } + _output.print_scenario_run_progress(run=run, total_strategies=0) + captured = capsys.readouterr() + assert "attacks: 0" in captured.out + assert "PENDING" in captured.out + + +def test_print_scenario_run_progress_strategies_done_only(capsys): + run = { + "status": "RUNNING", + "total_attacks": 0, + "completed_attacks": 0, + "objective_achieved_rate": 0, + "strategies_used": ["s1"], + } + _output.print_scenario_run_progress(run=run, total_strategies=0) + captured = capsys.readouterr() + assert "strategies: 1" in captured.out + + +# --------------------------------------------------------------------------- +# print_scenario_run_summary +# --------------------------------------------------------------------------- + + +def test_print_scenario_run_summary_completed(capsys): + run = { + "scenario_name": "test_sc", + "scenario_result_id": "abc-123", + "status": "COMPLETED", + "total_attacks": 5, + "completed_attacks": 5, + "objective_achieved_rate": 40, + "strategies_used": ["s1", "s2"], + } + _output.print_scenario_run_summary(run=run) + captured = capsys.readouterr() + assert "test_sc" in captured.out + assert "abc-123" in captured.out + assert "COMPLETED" in captured.out + assert "40%" in captured.out + assert "s1, s2" in captured.out + + +def test_print_scenario_run_summary_with_error(capsys): + run = { + "scenario_name": "failing", + "scenario_result_id": "id", + "status": "FAILED", + "total_attacks": 0, + "completed_attacks": 0, + "objective_achieved_rate": 0, + "error": "boom", + } + _output.print_scenario_run_summary(run=run) + captured = capsys.readouterr() + assert "Error:" in captured.out + assert "boom" in captured.out + + +# --------------------------------------------------------------------------- +# print_scenario_result_async +# --------------------------------------------------------------------------- + + +async def test_print_scenario_result_async_uses_pretty_printer(): + result_dict = {"some": "data"} + fake_scenario = MagicMock() + fake_printer = MagicMock() + fake_printer.write_async = AsyncMock() + + with ( + patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, + patch( + "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer + ) as printer_cls, + ): + await _output.print_scenario_result_async(result_dict=result_dict) + + from_dict_mock.assert_called_once_with(result_dict) + printer_cls.assert_called_once_with() + fake_printer.write_async.assert_awaited_once_with(fake_scenario) + + +async def test_print_scenario_result_async_roundtrip_with_real_payload(): + """ + Integration smoke test: a real ScenarioResult.to_dict() payload must flow + through ScenarioResult.from_dict() inside print_scenario_result_async + without raising. Locks the REST contract used by the CLI thin client. + """ + from datetime import datetime, timezone + + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models import AttackOutcome, AttackResult + from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + + identifier = ScenarioIdentifier(name="test.scenario", description="A test") + target_identifier = ComponentIdentifier.from_dict( + {"__type__": "FakeTarget", "__module__": "test.mod", "params": {}} + ) + attack = AttackResult( + conversation_id="conv-1", + objective="extract data", + outcome=AttackOutcome.SUCCESS, + executed_turns=2, + execution_time_ms=150, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ) + original = ScenarioResult( + scenario_identifier=identifier, + objective_target_identifier=target_identifier, + objective_scorer_identifier=None, + attack_results={"strat_a": [attack]}, + scenario_run_state="COMPLETED", + ) + payload = original.to_dict() + + # Drive print_scenario_result_async through the real from_dict path; only + # stub the printer to keep the test fast. + fake_printer = MagicMock() + fake_printer.write_async = AsyncMock() + with patch( + "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", + return_value=fake_printer, + ): + await _output.print_scenario_result_async(result_dict=payload) + + fake_printer.write_async.assert_awaited_once() + reconstructed = fake_printer.write_async.await_args.args[0] + assert isinstance(reconstructed, ScenarioResult) + assert reconstructed.scenario_identifier.name == "test.scenario" + assert list(reconstructed.attack_results.keys()) == ["strat_a"] + assert reconstructed.attack_results["strat_a"][0].outcome == AttackOutcome.SUCCESS + + +# --------------------------------------------------------------------------- +# print_scenario_runs_list +# --------------------------------------------------------------------------- + + +def test_print_scenario_runs_list_empty(capsys): + _output.print_scenario_runs_list(runs=[]) + captured = capsys.readouterr() + assert "No scenario runs found." in captured.out + + +def test_print_scenario_runs_list_populated(capsys): + runs = [ + { + "status": "COMPLETED", + "scenario_name": "scen-a", + "scenario_result_id": "abcdefgh1234", + "total_attacks": 4, + "objective_achieved_rate": 75, + "created_at": "2024-01-01", + }, + { + "status": "RUNNING", + "scenario_name": "scen-b", + "scenario_result_id": "ijklmnop5678", + "total_attacks": 0, + "objective_achieved_rate": 0, + "created_at": "2024-02-02", + }, + ] + _output.print_scenario_runs_list(runs=runs) + captured = capsys.readouterr() + assert "scen-a" in captured.out + assert "scen-b" in captured.out + assert "abcdefgh" in captured.out + assert "Total runs: 2" in captured.out + + +# --------------------------------------------------------------------------- +# print_error_with_hint +# --------------------------------------------------------------------------- + + +def test_print_error_with_hint_message_only(capsys): + _output.print_error_with_hint(message="oops") + captured = capsys.readouterr() + assert "Error: oops" in captured.out + assert "Hint:" not in captured.out + + +def test_print_error_with_hint_with_hint(capsys): + _output.print_error_with_hint(message="oops", hint="try this") + captured = capsys.readouterr() + assert "Error: oops" in captured.out + assert "Hint: try this" in captured.out diff --git a/tests/unit/cli/test_pyrit_backend.py b/tests/unit/cli/test_pyrit_backend.py deleted file mode 100644 index a6d568aad8..0000000000 --- a/tests/unit/cli/test_pyrit_backend.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -from pyrit.cli import pyrit_backend - - -class TestParseArgs: - """Tests for pyrit_backend.parse_args.""" - - def test_parse_args_defaults(self) -> None: - """Should parse backend defaults correctly.""" - args = pyrit_backend.parse_args(args=[]) - - assert args.host == "localhost" - assert args.port == 8000 - assert args.config_file is None - - def test_parse_args_accepts_config_file(self) -> None: - """Should parse --config-file argument.""" - args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) - - assert args.config_file == Path("./custom_conf.yaml") - - -class TestInitializeAndRun: - """Tests for pyrit_backend.initialize_and_run_async.""" - - async def test_initialize_and_run_passes_config_file_to_frontend_core(self) -> None: - """Should forward parsed config file path to FrontendCore.""" - parsed_args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) - - with ( - patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, - patch("uvicorn.Config") as mock_uvicorn_config, - patch("uvicorn.Server") as mock_uvicorn_server, - ): - mock_core = MagicMock() - mock_core.initialize_async = AsyncMock() - mock_core._initializer_configs = None - mock_core_class.return_value = mock_core - - mock_server = MagicMock() - mock_server.serve = AsyncMock() - mock_uvicorn_server.return_value = mock_server - - result = await pyrit_backend.initialize_and_run_async(parsed_args=parsed_args) - - assert result == 0 - mock_core_class.assert_called_once() - assert mock_core_class.call_args.kwargs["config_file"] == Path("./custom_conf.yaml") - mock_core.initialize_async.assert_awaited_once() - mock_uvicorn_config.assert_called_once() - mock_uvicorn_server.assert_called_once() - mock_server.serve.assert_awaited_once() - - async def test_startup_warning_when_custom_initializers_enabled(self, capsys) -> None: - """Should print a warning when allow_custom_initializers is True.""" - parsed_args = pyrit_backend.parse_args(args=[]) - - with ( - patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, - patch("uvicorn.Config"), - patch("uvicorn.Server") as mock_uvicorn_server, - ): - mock_core = MagicMock() - mock_core.initialize_async = AsyncMock() - mock_core._initializer_configs = None - mock_core._allow_custom_initializers = True - mock_core._operator = None - mock_core._operation = None - mock_core._max_concurrent_scenario_runs = 3 - mock_core_class.return_value = mock_core - - mock_server = MagicMock() - mock_server.serve = AsyncMock() - mock_uvicorn_server.return_value = mock_server - - await pyrit_backend.initialize_and_run_async(parsed_args=parsed_args) - - captured = capsys.readouterr() - assert "WARNING" in captured.out - assert "allow_custom_initializers" in captured.out - - async def test_no_startup_warning_when_custom_initializers_disabled(self, capsys) -> None: - """Should not print custom initializer warning when disabled.""" - parsed_args = pyrit_backend.parse_args(args=[]) - - with ( - patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, - patch("uvicorn.Config"), - patch("uvicorn.Server") as mock_uvicorn_server, - ): - mock_core = MagicMock() - mock_core.initialize_async = AsyncMock() - mock_core._initializer_configs = None - mock_core._allow_custom_initializers = False - mock_core._operator = None - mock_core._operation = None - mock_core._max_concurrent_scenario_runs = 3 - mock_core_class.return_value = mock_core - - mock_server = MagicMock() - mock_server.serve = AsyncMock() - mock_uvicorn_server.return_value = mock_server - - await pyrit_backend.initialize_and_run_async(parsed_args=parsed_args) - - captured = capsys.readouterr() - assert "allow_custom_initializers" not in captured.out diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index e3c3d42c43..8ecf404b41 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -2,11 +2,11 @@ # Licensed under the MIT license. """ -Unit tests for the pyrit_scan CLI module. +Unit tests for the pyrit_scan CLI module (thin REST client). """ import logging -from pathlib import Path +from argparse import Namespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -18,76 +18,52 @@ class TestParseArgs: """Tests for parse_args function.""" def test_parse_args_list_scenarios(self): - """Test parsing --list-scenarios flag.""" args = pyrit_scan.parse_args(["--list-scenarios"]) - assert args.list_scenarios is True assert args.scenario_name is None def test_parse_args_list_initializers(self): - """Test parsing --list-initializers flag.""" args = pyrit_scan.parse_args(["--list-initializers"]) - assert args.list_initializers is True - assert args.scenario_name is None def test_parse_args_scenario_name_only(self): - """Test parsing scenario name without options.""" args = pyrit_scan.parse_args(["test_scenario"]) - assert args.scenario_name == "test_scenario" assert args.log_level == logging.WARNING def test_parse_args_with_log_level(self): - """Test parsing with log-level option.""" args = pyrit_scan.parse_args(["test_scenario", "--log-level", "DEBUG"]) - assert args.log_level == logging.DEBUG def test_parse_args_with_initializers(self): - """Test parsing with initializers.""" args = pyrit_scan.parse_args(["test_scenario", "--initializers", "init1", "init2"]) - assert args.initializers == ["init1", "init2"] - def test_parse_args_with_initialization_scripts(self): - """Test parsing with initialization-scripts.""" - args = pyrit_scan.parse_args(["test_scenario", "--initialization-scripts", "script1.py", "script2.py"]) - - assert args.initialization_scripts == ["script1.py", "script2.py"] + def test_parse_args_with_add_initializer(self): + args = pyrit_scan.parse_args(["--add-initializer", "script1.py", "script2.py"]) + assert args.add_initializer == ["script1.py", "script2.py"] def test_parse_args_with_strategies(self): - """Test parsing with strategies.""" args = pyrit_scan.parse_args(["test_scenario", "--strategies", "s1", "s2"]) - assert args.scenario_strategies == ["s1", "s2"] def test_parse_args_with_strategies_short_flag(self): - """Test parsing with -s flag.""" args = pyrit_scan.parse_args(["test_scenario", "-s", "s1", "s2"]) - assert args.scenario_strategies == ["s1", "s2"] def test_parse_args_with_max_concurrency(self): - """Test parsing with max-concurrency.""" args = pyrit_scan.parse_args(["test_scenario", "--max-concurrency", "5"]) - assert args.max_concurrency == 5 def test_parse_args_with_max_retries(self): - """Test parsing with max-retries.""" args = pyrit_scan.parse_args(["test_scenario", "--max-retries", "3"]) - assert args.max_retries == 3 def test_parse_args_with_memory_labels(self): - """Test parsing with memory-labels.""" args = pyrit_scan.parse_args(["test_scenario", "--memory-labels", '{"key":"value"}']) - assert args.memory_labels == '{"key":"value"}' def test_parse_args_complex_command(self): - """Test parsing complex command with multiple options.""" args = pyrit_scan.parse_args( [ "encoding_scenario", @@ -106,687 +82,742 @@ def test_parse_args_complex_command(self): '{"env":"test"}', ] ) - assert args.scenario_name == "encoding_scenario" assert args.log_level == logging.INFO assert args.initializers == ["openai_target"] assert args.scenario_strategies == ["base64", "rot13"] assert args.max_concurrency == 10 assert args.max_retries == 5 - assert args.memory_labels == '{"env":"test"}' def test_parse_args_invalid_log_level(self): - """Test parsing with invalid log level raises error.""" with pytest.raises(SystemExit): pyrit_scan.parse_args(["test_scenario", "--log-level", "INVALID"]) def test_parse_args_invalid_max_concurrency(self): - """Test parsing with invalid max-concurrency raises error.""" with pytest.raises(SystemExit): pyrit_scan.parse_args(["test_scenario", "--max-concurrency", "0"]) def test_parse_args_invalid_max_retries(self): - """Test parsing with invalid max-retries raises error.""" with pytest.raises(SystemExit): pyrit_scan.parse_args(["test_scenario", "--max-retries", "-1"]) def test_parse_args_help_flag(self): - """Test parsing --help flag exits.""" with pytest.raises(SystemExit) as exc_info: pyrit_scan.parse_args(["--help"]) - assert exc_info.value.code == 0 def test_parse_args_with_target(self): - """Test parsing with --target option.""" args = pyrit_scan.parse_args(["test_scenario", "--target", "my_target"]) - assert args.target == "my_target" def test_parse_args_target_default_is_none(self): - """Test --target defaults to None when not provided.""" args = pyrit_scan.parse_args(["test_scenario"]) - assert args.target is None def test_parse_args_with_list_targets(self): - """Test parsing --list-targets flag.""" args = pyrit_scan.parse_args(["--list-targets"]) - assert args.list_targets is True + def test_parse_args_with_server_url(self): + args = pyrit_scan.parse_args(["--list-scenarios", "--server-url", "http://remote:9000"]) + assert args.server_url == "http://remote:9000" + + def test_parse_args_with_start_server(self): + args = pyrit_scan.parse_args(["--list-scenarios", "--start-server"]) + assert args.start_server is True + + def test_parse_args_with_stop_server(self): + args = pyrit_scan.parse_args(["--stop-server"]) + assert args.stop_server is True + + def test_main_with_invalid_args(self): + result = pyrit_scan.main(["--invalid-flag"]) + assert result == 2 + + +class TestExtractScenarioArgs: + """Tests for the namespaced-dest extraction helper.""" + + def test_no_scenario_keys_returns_empty(self): + result = pyrit_scan._extract_scenario_args(parsed=Namespace(scenario_name="x", config_file=None, log_level=20)) + assert result == {} + + def test_scenario_keys_extracted_with_prefix_stripped(self): + result = pyrit_scan._extract_scenario_args( + parsed=Namespace(scenario_name="x", config_file=None, scenario__max_turns=10, scenario__mode="fast") + ) + assert result == {"max_turns": 10, "mode": "fast"} + + +def _mock_api_client(): + """Create a mock PyRITApiClient with default response behaviors.""" + client = AsyncMock() + client.health_check_async.return_value = True + client.list_scenarios_async.return_value = {"items": [], "pagination": {"total": 0}} + client.list_initializers_async.return_value = {"items": [], "pagination": {"total": 0}} + client.list_targets_async.return_value = {"items": [], "pagination": {"total": 0}} + client.get_scenario_async.return_value = { + "scenario_name": "test_scenario", + "supported_parameters": [], + } + client.start_scenario_run_async.return_value = { + "scenario_result_id": "test-id-123", + "scenario_name": "test_scenario", + "status": "CREATED", + } + client.get_scenario_run_async.return_value = { + "scenario_result_id": "test-id-123", + "status": "COMPLETED", + "total_attacks": 5, + "completed_attacks": 5, + "objective_achieved_rate": 40, + } + client.get_scenario_run_results_async.return_value = { + "run": { + "scenario_result_id": "test-id-123", + "scenario_name": "test_scenario", + "status": "COMPLETED", + "total_attacks": 5, + "completed_attacks": 5, + "objective_achieved_rate": 40, + }, + "attacks": [], + } + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + class TestMain: - """Tests for main function.""" + """Tests for main function (thin REST client).""" - @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_list_scenarios(self, mock_frontend_core: MagicMock, mock_print_scenarios: AsyncMock): + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_list_scenarios(self, mock_client_class, mock_probe): """Test main with --list-scenarios flag.""" - mock_print_scenarios.return_value = 0 + mock_client = _mock_api_client() + mock_client_class.return_value = mock_client result = pyrit_scan.main(["--list-scenarios"]) assert result == 0 - mock_print_scenarios.assert_called_once() - mock_frontend_core.assert_called_once() - - @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_list_initializers( - self, - mock_frontend_core: MagicMock, - mock_print_initializers: AsyncMock, - ): + mock_client.list_scenarios_async.assert_awaited_once() + + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_list_initializers(self, mock_client_class, mock_probe): """Test main with --list-initializers flag.""" - mock_print_initializers.return_value = 0 + mock_client = _mock_api_client() + mock_client_class.return_value = mock_client result = pyrit_scan.main(["--list-initializers"]) assert result == 0 - mock_print_initializers.assert_called_once() - - @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_list_scenarios_with_scripts( - self, - mock_frontend_core: MagicMock, - mock_resolve_scripts: MagicMock, - mock_print_scenarios: AsyncMock, - ): - """Test main with --list-scenarios and --initialization-scripts.""" - mock_resolve_scripts.return_value = [Path("/test/script.py")] - mock_print_scenarios.return_value = 0 - - result = pyrit_scan.main(["--list-scenarios", "--initialization-scripts", "script.py"]) + mock_client.list_initializers_async.assert_awaited_once() - assert result == 0 - mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) - mock_print_scenarios.assert_called_once() + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_list_targets(self, mock_client_class, mock_probe): + """Test main with --list-targets flag.""" + mock_client = _mock_api_client() + mock_client_class.return_value = mock_client - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - def test_main_list_scenarios_with_missing_script(self, mock_resolve_scripts: MagicMock): - """Test main with --list-scenarios and missing script file.""" - mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + result = pyrit_scan.main(["--list-targets"]) - result = pyrit_scan.main(["--list-scenarios", "--initialization-scripts", "missing.py"]) - - assert result == 1 + assert result == 0 + mock_client.list_targets_async.assert_awaited_once() - @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_list_targets_with_initializers( - self, - mock_frontend_core: MagicMock, - mock_print_targets: AsyncMock, - ): - """Test main with --list-targets and --initializers passes initializers to FrontendCore.""" - mock_print_targets.return_value = 0 + def test_main_no_args_shows_help(self): + """Test main with no arguments shows help.""" + result = pyrit_scan.main([]) + assert result == 0 # shows help and exits - result = pyrit_scan.main(["--list-targets", "--initializers", "target"]) + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_run_scenario(self, mock_client_class, mock_probe): + """Test main running a scenario.""" + mock_client = _mock_api_client() + mock_client_class.return_value = mock_client - assert result == 0 - mock_frontend_core.assert_called_once() - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["initializer_names"] == ["target"] - mock_print_targets.assert_called_once() - - @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_list_targets_with_scripts( - self, - mock_frontend_core: MagicMock, - mock_resolve_scripts: MagicMock, - mock_print_targets: AsyncMock, - ): - """Test main with --list-targets and --initialization-scripts passes scripts to FrontendCore.""" - mock_resolve_scripts.return_value = [Path("/test/script.py")] - mock_print_targets.return_value = 0 - - result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "script.py"]) + result = pyrit_scan.main(["test_scenario", "--target", "my_target"]) assert result == 0 - mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) - mock_frontend_core.assert_called_once() - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["initialization_scripts"] == [Path("/test/script.py")] - mock_print_targets.assert_called_once() + mock_client.get_scenario_async.assert_awaited_once() + mock_client.start_scenario_run_async.assert_awaited_once() - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - def test_main_list_targets_with_missing_script(self, mock_resolve_scripts: MagicMock): - """Test main with --list-targets and missing script file.""" - mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_run_scenario_with_initializers(self, mock_client_class, mock_probe): + """Test main maps --initializers to request format.""" + mock_client = _mock_api_client() + mock_client_class.return_value = mock_client - result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "missing.py"]) + result = pyrit_scan.main(["test_scenario", "--target", "t", "--initializers", "target", "datasets"]) - assert result == 1 + assert result == 0 + call_kwargs = mock_client.start_scenario_run_async.call_args.kwargs + request = call_kwargs["request"] + assert request["initializers"] == ["target", "datasets"] - def test_main_no_scenario_specified(self, capsys): - """Test main without scenario name.""" - result = pyrit_scan.main([]) + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=False) + def test_main_server_not_available(self, mock_probe, capsys): + """Test main when server is not available.""" + result = pyrit_scan.main(["--list-scenarios"]) assert result == 1 captured = capsys.readouterr() - assert "No scenario specified" in captured.out - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_run_scenario_basic( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """Test main running a basic scenario.""" - result = pyrit_scan.main(["test_scenario", "--initializers", "test_init"]) + assert "Server not available" in captured.out - assert result == 0 - mock_asyncio_run.assert_called_once() - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_run_scenario_with_scripts( - self, - mock_frontend_core: MagicMock, - mock_resolve_scripts: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """Test main running scenario with initialization scripts.""" - mock_resolve_scripts.return_value = [Path("/test/script.py")] - - result = pyrit_scan.main(["test_scenario", "--initialization-scripts", "script.py"]) + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=False) + def test_main_stop_server(self, mock_probe, capsys): + """Test main with --stop-server.""" + result = pyrit_scan.main(["--stop-server"]) assert result == 0 - mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) - mock_asyncio_run.assert_called_once() + captured = capsys.readouterr() + assert "No server running" in captured.out - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - def test_main_run_scenario_with_missing_script(self, mock_resolve_scripts: MagicMock): - """Test main with missing initialization script.""" - mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_scenario_not_found(self, mock_client_class, mock_probe, capsys): + """Test main when scenario is not found on server.""" + mock_client = _mock_api_client() + mock_client.get_scenario_async.return_value = None + mock_client_class.return_value = mock_client - result = pyrit_scan.main(["test_scenario", "--initialization-scripts", "missing.py"]) + result = pyrit_scan.main(["nonexistent_scenario", "--target", "t"]) assert result == 1 + captured = capsys.readouterr() + assert "not found" in captured.out + + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_failed_scenario(self, mock_client_class, mock_probe): + """Test main when scenario run fails.""" + mock_client = _mock_api_client() + mock_client.get_scenario_run_async.return_value = { + "scenario_result_id": "test-id", + "status": "FAILED", + "total_attacks": 0, + "completed_attacks": 0, + "objective_achieved_rate": 0, + "error": "Something went wrong", + } + mock_client_class.return_value = mock_client + + result = pyrit_scan.main(["test_scenario", "--target", "t"]) - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_run_scenario_with_all_options( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """Test main with all scenario options.""" - result = pyrit_scan.main( - [ - "test_scenario", - "--log-level", - "DEBUG", - "--initializers", - "init1", - "init2", - "--strategies", - "s1", - "s2", - "--max-concurrency", - "10", - "--max-retries", - "5", - "--memory-labels", - '{"key":"value"}', - ] - ) - - assert result == 0 - mock_asyncio_run.assert_called_once() - - # Verify FrontendCore was called with correct args - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["log_level"] == logging.DEBUG - assert call_kwargs["initializer_names"] == ["init1", "init2"] - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.parse_memory_labels") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_run_scenario_with_memory_labels( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_parse_labels: MagicMock, - mock_asyncio_run: MagicMock, - ): - """Test main with memory labels parsing.""" - mock_parse_labels.return_value = {"key": "value"} - - result = pyrit_scan.main(["test_scenario", "--initializers", "test_init", "--memory-labels", '{"key":"value"}']) - - assert result == 0 - mock_parse_labels.assert_called_once_with(json_string='{"key":"value"}') + assert result == 1 - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_run_scenario_with_exception( - self, - mock_frontend_core: MagicMock, - mock_asyncio_run: MagicMock, - ): - """Test main handles exceptions during scenario run.""" - mock_asyncio_run.side_effect = ValueError("Test error") - result = pyrit_scan.main(["test_scenario", "--initializers", "test_init"]) +# --------------------------------------------------------------------------- +# Internal helper coverage +# --------------------------------------------------------------------------- - assert result == 1 - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_log_level_defaults_to_warning(self, mock_frontend_core: MagicMock): - """Test main uses WARNING as default log level.""" - pyrit_scan.main(["--list-scenarios"]) +class TestStopServerOnPort: + """Tests for stop_server_on_port helper (now lives in _server_launcher).""" - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["log_level"] == logging.WARNING + @patch("sys.platform", "win32") + @patch("subprocess.run") + @patch("os.kill") + def test_stop_on_windows_finds_pid_via_netstat(self, mock_kill, mock_run): + from pyrit.cli import _server_launcher - def test_main_with_invalid_args(self): - """Test main with invalid arguments.""" - result = pyrit_scan.main(["--invalid-flag"]) - - assert result == 2 # argparse returns 2 for invalid arguments - - @patch("builtins.print") - def test_main_prints_startup_message(self, mock_print: MagicMock): - """Test main prints startup message.""" - pyrit_scan.main(["--list-scenarios"]) - - # Check that "Starting PyRIT..." was printed - calls = [str(call_obj) for call_obj in mock_print.call_args_list] - assert any("Starting PyRIT" in str(call_obj) for call_obj in calls) - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_run_scenario_calls_run_scenario_async( - self, - mock_frontend_core: MagicMock, - mock_asyncio_run: MagicMock, - ): - """Test main properly calls run_scenario_async.""" - pyrit_scan.main(["test_scenario", "--initializers", "test_init", "--strategies", "s1"]) - - # Verify asyncio.run was called with run_scenario_async - assert mock_asyncio_run.call_count == 1 - assert mock_asyncio_run.call_count == 1 - - -class TestMainIntegration: - """Integration-style tests for main function.""" - - @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - def test_main_list_scenarios_integration( - self, - mock_init_pyrit: AsyncMock, - mock_scenario_registry: MagicMock, - mock_print_scenarios: AsyncMock, - ): - """Test main --list-scenarios with minimal mocking.""" - mock_print_scenarios.return_value = 0 + mock_run.return_value = MagicMock( + stdout=" TCP 0.0.0.0:8000 0.0.0.0:0 LISTENING 1234\n", + ) + assert _server_launcher.stop_server_on_port(port=8000) is True + mock_kill.assert_called_once() + + @patch("sys.platform", "win32") + @patch("subprocess.run") + @patch("os.kill") + def test_stop_on_windows_does_not_match_port_substring(self, mock_kill, mock_run): + from pyrit.cli import _server_launcher + + # Listening on :8000/:8080 should NOT be matched when stopping port 80. + mock_run.return_value = MagicMock( + stdout=( + " TCP 0.0.0.0:8000 0.0.0.0:0 LISTENING 1234\n" + " TCP 0.0.0.0:8080 0.0.0.0:0 LISTENING 2345\n" + ), + ) + assert _server_launcher.stop_server_on_port(port=80) is False + mock_kill.assert_not_called() - result = pyrit_scan.main(["--list-scenarios"]) + @patch("sys.platform", "win32") + @patch("subprocess.run") + @patch("os.kill") + def test_stop_on_windows_matches_ipv6_local_address(self, mock_kill, mock_run): + from pyrit.cli import _server_launcher - assert result == 0 + mock_run.return_value = MagicMock( + stdout=" TCP [::]:8000 [::]:0 LISTENING 9999\n", + ) + assert _server_launcher.stop_server_on_port(port=8000) is True + mock_kill.assert_called_once_with(9999, pytest.importorskip("signal").SIGTERM) - @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_main_list_initializers_integration( - self, - mock_print_initializers: AsyncMock, - ): - """Test main --list-initializers with minimal mocking.""" - mock_print_initializers.return_value = 0 + @patch("sys.platform", "linux") + @patch("subprocess.run") + @patch("os.kill") + def test_stop_on_unix_finds_pid_via_lsof(self, mock_kill, mock_run): + from pyrit.cli import _server_launcher - result = pyrit_scan.main(["--list-initializers"]) + mock_run.return_value = MagicMock(stdout="5678\n") + assert _server_launcher.stop_server_on_port(port=8000) is True + mock_kill.assert_called_once_with(5678, pytest.importorskip("signal").SIGTERM) - assert result == 0 + @patch("subprocess.run", side_effect=OSError("nope")) + def test_stop_swallows_errors_and_returns_false(self, _mock_run): + from pyrit.cli import _server_launcher + assert _server_launcher.stop_server_on_port(port=8000) is False -class TestTwoPassParsing: - """Tests for the two-pass scenario-parameter augmentation flow.""" + @patch("sys.platform", "linux") + @patch("subprocess.run") + def test_stop_returns_false_when_no_pid_found(self, mock_run): + from pyrit.cli import _server_launcher - @staticmethod - def _patch_resolve(scenario_class): - """Patch the registry lookup so tests don't depend on the real registry.""" - return patch.object(pyrit_scan, "_resolve_scenario_class", return_value=scenario_class) + mock_run.return_value = MagicMock(stdout="") + assert _server_launcher.stop_server_on_port(port=8000) is False - @staticmethod - def _make_scenario_class(declared_params): - """Build a stand-in class whose only obligation is to expose supported_parameters().""" - class _FakeScenario: - @classmethod - def supported_parameters(cls): - return list(declared_params) +class TestAddScenarioParamsFromApi: + """Tests for _add_scenario_params_from_api.""" - return _FakeScenario + def test_adds_unseen_params_as_optional_flags(self): + from argparse import ArgumentParser - def test_no_scenario_resolved_leaves_namespace_unaugmented(self): - """When the positional name is missing or unknown, scenario flags do not appear.""" - with self._patch_resolve(None): - args = pyrit_scan.parse_args(["--list-scenarios"]) + parser = ArgumentParser() + pyrit_scan._add_scenario_params_from_api( + parser=parser, + params=[ + {"name": "max_turns", "description": "Max turns."}, + {"name": "mode", "description": "Mode."}, + ], + ) + parsed = parser.parse_args(["--max-turns", "5", "--mode", "fast"]) + assert parsed.scenario__max_turns == "5" + assert parsed.scenario__mode == "fast" + + def test_skips_params_that_collide_with_existing_flags(self): + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--target") + pyrit_scan._add_scenario_params_from_api( + parser=parser, + params=[{"name": "target", "description": "..."}], + ) + parsed = parser.parse_args(["--target", "x"]) + # Original --target wins; no scenario__target added. + assert parsed.target == "x" + assert not hasattr(parsed, "scenario__target") + + +class TestBuildRunRequest: + """Tests for _build_run_request.""" + + def test_includes_initializer_args(self): + parsed = Namespace( + target="t", + initializers=[{"name": "openai_target", "args": {"model": "gpt-4"}}, "datasets"], + scenario_strategies=None, + max_concurrency=None, + max_retries=None, + dataset_names=None, + max_dataset_size=None, + memory_labels=None, + ) + request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") + assert request["initializers"] == ["openai_target", "datasets"] + assert request["initializer_args"] == {"openai_target": {"model": "gpt-4"}} + + def test_populates_optional_fields(self): + parsed = Namespace( + target="t", + initializers=None, + scenario_strategies=["s1"], + max_concurrency=3, + max_retries=2, + dataset_names=["d1"], + max_dataset_size=10, + memory_labels='{"key":"value"}', + ) + request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") + assert request["strategies"] == ["s1"] + assert request["max_concurrency"] == 3 + assert request["max_retries"] == 2 + assert request["dataset_names"] == ["d1"] + assert request["max_dataset_size"] == 10 + assert request["labels"] == {"key": "value"} + + def test_includes_scenario_declared_params(self): + parsed = Namespace( + target=None, + initializers=None, + scenario_strategies=None, + max_concurrency=None, + max_retries=None, + dataset_names=None, + max_dataset_size=None, + memory_labels=None, + scenario__max_turns="7", + ) + request = pyrit_scan._build_run_request(parsed_args=parsed, scenario_name="s") + assert request["scenario_params"] == {"max_turns": "7"} - # No scenario__-prefixed attrs sneaked in. - scenario_keys = [k for k in vars(args) if k.startswith("scenario__")] - assert scenario_keys == [] - def test_int_param_coerced(self): - """A declared int parameter coerces its string CLI value to int.""" - from pyrit.common import Parameter +class TestResolveServerUrl: + """Tests for _resolve_server_url_async.""" - scenario_class = self._make_scenario_class( - [Parameter(name="max_turns", description="d", param_type=int, default=5)] + async def test_uses_cli_flag_when_provided(self): + parsed = Namespace( + server_url="http://override:7000", + start_server=False, + config_file=None, ) - with self._patch_resolve(scenario_class): - args = pyrit_scan.parse_args(["fake_scenario", "--max-turns", "10"]) + with patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=True), + ): + result = await pyrit_scan._resolve_server_url_async(parsed_args=parsed) + assert result == "http://override:7000" - scenario_args = pyrit_scan._extract_scenario_args(parsed=args) - assert scenario_args == {"max_turns": 10} + async def test_returns_none_when_unhealthy_and_no_start_server(self): + parsed = Namespace(server_url=None, start_server=False, config_file=None) + with ( + patch( + "pyrit.cli._config_reader.read_server_url", + return_value=None, + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=False), + ), + ): + assert await pyrit_scan._resolve_server_url_async(parsed_args=parsed) is None - def test_bool_param_uses_safe_coercion(self): - """``--enabled false`` is correctly parsed to False (avoids the type=bool footgun).""" - from pyrit.common import Parameter + async def test_auto_starts_server_when_requested(self): + parsed = Namespace(server_url=None, start_server=True, config_file=None) + with ( + patch("pyrit.cli._config_reader.read_server_url", return_value=None), + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=False), + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new=AsyncMock(return_value="http://localhost:8000"), + ), + ): + assert await pyrit_scan._resolve_server_url_async(parsed_args=parsed) == "http://localhost:8000" - scenario_class = self._make_scenario_class([Parameter(name="enabled", description="d", param_type=bool)]) - with self._patch_resolve(scenario_class): - args = pyrit_scan.parse_args(["fake_scenario", "--enabled", "false"]) + async def test_returns_none_when_start_server_raises(self, capsys): + parsed = Namespace(server_url=None, start_server=True, config_file=None) + with ( + patch("pyrit.cli._config_reader.read_server_url", return_value=None), + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=False), + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new=AsyncMock(side_effect=RuntimeError("nope")), + ), + ): + assert await pyrit_scan._resolve_server_url_async(parsed_args=parsed) is None + assert "nope" in capsys.readouterr().out - assert pyrit_scan._extract_scenario_args(parsed=args) == {"enabled": False} + async def test_resolution_order_cli_beats_config_beats_default(self): + """CLI flag > config-file value > built-in default.""" + # 1) CLI flag wins even when config has a different value. + parsed = Namespace(server_url="http://cli:1111", start_server=False, config_file=None) + with ( + patch( + "pyrit.cli._config_reader.read_server_url", + return_value="http://cfg:2222", + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=True), + ), + ): + assert await pyrit_scan._resolve_server_url_async(parsed_args=parsed) == "http://cli:1111" - def test_list_param_collects_multiple_values(self): - """A declared list[str] parameter uses nargs='+' to collect successive values.""" - from pyrit.common import Parameter + # 2) Config wins when CLI omitted. + parsed = Namespace(server_url=None, start_server=False, config_file=None) + with ( + patch( + "pyrit.cli._config_reader.read_server_url", + return_value="http://cfg:2222", + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=True), + ), + ): + assert await pyrit_scan._resolve_server_url_async(parsed_args=parsed) == "http://cfg:2222" - scenario_class = self._make_scenario_class([Parameter(name="datasets", description="d", param_type=list[str])]) - with self._patch_resolve(scenario_class): - args = pyrit_scan.parse_args(["fake_scenario", "--datasets", "a", "b", "c"]) + # 3) Built-in default when neither CLI nor config provide a URL. + from pyrit.cli._config_reader import DEFAULT_SERVER_URL - assert pyrit_scan._extract_scenario_args(parsed=args) == {"datasets": ["a", "b", "c"]} + parsed = Namespace(server_url=None, start_server=False, config_file=None) + with ( + patch("pyrit.cli._config_reader.read_server_url", return_value=None), + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new=AsyncMock(return_value=True), + ), + ): + assert await pyrit_scan._resolve_server_url_async(parsed_args=parsed) == DEFAULT_SERVER_URL - def test_choices_validated_by_argparse(self): - """A value outside ``choices`` is rejected at parse time.""" - from pyrit.common import Parameter - scenario_class = self._make_scenario_class( - [Parameter(name="mode", description="d", param_type=str, choices=("fast", "slow"))] - ) - with self._patch_resolve(scenario_class): - with pytest.raises(SystemExit): - pyrit_scan.parse_args(["fake_scenario", "--mode", "medium"]) +class TestScenarioParamCoercion: + """Regression tests for client-side coercion of typed scenario-declared params.""" - def test_unset_scenario_flag_not_in_namespace(self): - """``argparse.SUPPRESS`` keeps absent flags out of the parsed Namespace.""" - from pyrit.common import Parameter + def test_list_param_uses_nargs_plus(self): + from argparse import ArgumentParser - scenario_class = self._make_scenario_class( - [Parameter(name="max_turns", description="d", param_type=int, default=5)] + parser = ArgumentParser() + pyrit_scan._add_scenario_params_from_api( + parser=parser, + params=[{"name": "items", "description": "...", "param_type": "list[str]", "is_list": True}], ) - with self._patch_resolve(scenario_class): - args = pyrit_scan.parse_args(["fake_scenario"]) + parsed = parser.parse_args(["--items", "a", "b", "c"]) + assert parsed.scenario__items == ["a", "b", "c"] - assert pyrit_scan._extract_scenario_args(parsed=args) == {} + def test_int_param_is_coerced(self): + from argparse import ArgumentParser - def test_unknown_scenario_flag_rejected(self): - """Argparse pass 2 rejects flags the scenario didn't declare.""" - from pyrit.common import Parameter - - scenario_class = self._make_scenario_class( - [Parameter(name="max_turns", description="d", param_type=int, default=5)] + parser = ArgumentParser() + pyrit_scan._add_scenario_params_from_api( + parser=parser, + params=[{"name": "max_turns", "description": "...", "param_type": "int"}], ) - with self._patch_resolve(scenario_class): - with pytest.raises(SystemExit): - pyrit_scan.parse_args(["fake_scenario", "--unknown-flag", "value"]) + parsed = parser.parse_args(["--max-turns", "7"]) + assert parsed.scenario__max_turns == 7 - def test_collision_with_built_in_flag_raises_at_build_time(self): - """A declared parameter colliding with a built-in flag fails at parser-build time.""" - from pyrit.common import Parameter + def test_int_param_invalid_value_rejected_client_side(self, capsys): + from argparse import ArgumentParser - scenario_class = self._make_scenario_class( - [Parameter(name="max_concurrency", description="d", param_type=int, default=10)] + parser = ArgumentParser() + pyrit_scan._add_scenario_params_from_api( + parser=parser, + params=[{"name": "max_turns", "description": "...", "param_type": "int"}], ) - with self._patch_resolve(scenario_class): - with pytest.raises(ValueError, match="collides with an existing flag"): - pyrit_scan.parse_args(["fake_scenario", "--max-concurrency", "5"]) + with pytest.raises(SystemExit): + parser.parse_args(["--max-turns", "not-an-int"]) + assert "invalid value" in capsys.readouterr().err - def test_two_scenario_params_with_same_kebab_form_raise(self): - """Two declared parameters that normalize to the same kebab-case flag fail with our ValueError.""" - from pyrit.common import Parameter + def test_choices_validated_client_side(self, capsys): + from argparse import ArgumentParser - scenario_class = self._make_scenario_class( - [ - Parameter(name="foo_bar", description="d", param_type=str), - Parameter(name="foo-bar", description="d", param_type=str), - ] + parser = ArgumentParser() + pyrit_scan._add_scenario_params_from_api( + parser=parser, + params=[{"name": "mode", "description": "...", "param_type": "str", "choices": ["fast", "slow"]}], ) - with self._patch_resolve(scenario_class): - with pytest.raises(ValueError, match="collides with an existing flag"): - pyrit_scan.parse_args(["fake_scenario", "--foo-bar", "x"]) + parsed = parser.parse_args(["--mode", "fast"]) + assert parsed.scenario__mode == "fast" - def test_scenario_flag_works_before_positional(self): - """Pass 1 uses the full base parser so option order does not break positional ID.""" - from pyrit.common import Parameter + with pytest.raises(SystemExit): + parser.parse_args(["--mode", "warp"]) + assert "invalid choice" in capsys.readouterr().err - scenario_class = self._make_scenario_class( - [Parameter(name="max_turns", description="d", param_type=int, default=5)] - ) - with self._patch_resolve(scenario_class): - args = pyrit_scan.parse_args(["--config-file", "foo.yaml", "fake_scenario", "--max-turns", "7"]) - # config-file landed correctly + scenario name identified + scenario param parsed - assert args.config_file == Path("foo.yaml") - assert args.scenario_name == "fake_scenario" - assert pyrit_scan._extract_scenario_args(parsed=args) == {"max_turns": 7} +class TestMainExtraPaths: + """Tests for additional main() code paths.""" - def test_help_after_scenario_lists_declared_flags(self, capsys): - """`pyrit_scan --help` shows scenario-declared flags inline.""" - from pyrit.common import Parameter + def test_main_no_args_prints_help_and_exits_zero(self, capsys): + result = pyrit_scan.main([]) + assert result == 0 + captured = capsys.readouterr() + assert "PyRIT Scanner" in captured.out or "usage" in captured.out.lower() + + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_scenario_not_found_lists_available(self, mock_client_class, _mock_probe, capsys): + mock_client = _mock_api_client() + mock_client.get_scenario_async.return_value = None + mock_client.list_scenarios_async.return_value = { + "items": [{"scenario_name": "alt_a"}, {"scenario_name": "alt_b"}], + "pagination": {}, + } + mock_client_class.return_value = mock_client + + result = pyrit_scan.main(["nonexistent", "--target", "t"]) + assert result == 1 + captured = capsys.readouterr() + assert "alt_a" in captured.out + assert "alt_b" in captured.out - scenario_class = self._make_scenario_class( - [Parameter(name="max_turns", description="Conversation turn cap", param_type=int, default=5)] - ) - with self._patch_resolve(scenario_class): - with pytest.raises(SystemExit) as exc_info: - pyrit_scan.parse_args(["fake_scenario", "--help"]) + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_start_scenario_failure(self, mock_client_class, _mock_probe, capsys): + mock_client = _mock_api_client() + mock_client.start_scenario_run_async.side_effect = RuntimeError("server full") + mock_client_class.return_value = mock_client - assert exc_info.value.code == 0 + result = pyrit_scan.main(["test_scenario", "--target", "t"]) + assert result == 1 captured = capsys.readouterr() - assert "--max-turns" in captured.out - assert "Conversation turn cap" in captured.out + assert "server full" in captured.out - def test_config_only_scenario_name_registers_scenario_flags(self): - """When pass 1's positional doesn't resolve, fall back to ``scenario.name`` from the config file.""" - from pyrit.common import Parameter + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_run_results_fallback_to_summary(self, mock_client_class, _mock_probe, capsys): + mock_client = _mock_api_client() + mock_client.get_scenario_run_results_async.side_effect = RuntimeError("nope") + mock_client_class.return_value = mock_client - scenario_class = self._make_scenario_class( - [Parameter(name="max_turns", description="d", param_type=int, default=5)] - ) + result = pyrit_scan.main(["test_scenario", "--target", "t"]) + assert result == 0 + captured = capsys.readouterr() + # The summary printer should be used as a fallback. + assert "test_scenario" in captured.out - # Resolve only "fake_scenario"; anything pass 1 misclassifies as the positional - # (e.g. the "7" from "--max-turns 7") returns None and triggers the config peek. - def fake_resolve(name): - return scenario_class if name == "fake_scenario" else None + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_start_server_only_prints_url_and_returns_zero(self, mock_client_class, _mock_probe, capsys): + result = pyrit_scan.main(["--start-server"]) + assert result == 0 + captured = capsys.readouterr() + assert "running" in captured.out.lower() + + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) + @patch("pyrit.cli._server_launcher.stop_server_on_port", return_value=True) + def test_main_stop_server_kills_process_and_returns_zero(self, _stop_mock, _mock_probe, capsys): + result = pyrit_scan.main(["--stop-server"]) + assert result == 0 + assert "stopped" in capsys.readouterr().out + + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) + @patch("pyrit.cli._server_launcher.stop_server_on_port", return_value=False) + def test_main_stop_server_when_process_cannot_be_identified(self, _stop_mock, _mock_probe, capsys): + result = pyrit_scan.main(["--stop-server"]) + assert result == 0 + out = capsys.readouterr().out + assert "could not identify" in out - with ( - patch.object(pyrit_scan, "_peek_scenario_name_from_config", return_value="fake_scenario") as peek, - patch.object(pyrit_scan, "_resolve_scenario_class", side_effect=fake_resolve), - ): - args = pyrit_scan.parse_args(["--config-file", "foo.yaml", "--max-turns", "7"]) + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_add_initializer_missing_file(self, mock_client_class, _mock_probe, capsys, tmp_path): + mock_client = _mock_api_client() + mock_client_class.return_value = mock_client + missing = tmp_path / "nonexistent.py" - peek.assert_called_once() - assert args.scenario_name is None - assert pyrit_scan._extract_scenario_args(parsed=args) == {"max_turns": 7} + result = pyrit_scan.main(["--add-initializer", str(missing)]) + assert result == 1 + assert "File not found" in capsys.readouterr().out + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_add_initializer_success(self, mock_client_class, _mock_probe, capsys, tmp_path): + mock_client = _mock_api_client() + mock_client.register_initializer_async = AsyncMock(return_value={"initializer_name": "myinit"}) + mock_client_class.return_value = mock_client -class TestExtractScenarioArgs: - """Tests for the namespaced-dest extraction helper.""" + script = tmp_path / "myinit.py" + script.write_text("# stub initializer\n") - def test_no_scenario_keys_returns_empty(self): - from argparse import Namespace + result = pyrit_scan.main(["--add-initializer", str(script)]) + assert result == 0 + assert "Registered initializer 'myinit'" in capsys.readouterr().out + mock_client.register_initializer_async.assert_awaited_once() - result = pyrit_scan._extract_scenario_args(parsed=Namespace(scenario_name="x", config_file=None, log_level=20)) - assert result == {} + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + def test_main_add_initializer_server_disabled(self, mock_client_class, _mock_probe, capsys, tmp_path): + from pyrit.cli.api_client import ServerNotAvailableError - def test_scenario_keys_extracted_with_prefix_stripped(self): - from argparse import Namespace + mock_client = _mock_api_client() + mock_client.register_initializer_async = AsyncMock(side_effect=ServerNotAvailableError("disabled")) + mock_client_class.return_value = mock_client - result = pyrit_scan._extract_scenario_args( - parsed=Namespace( - scenario_name="x", - config_file=None, - scenario__max_turns=10, - scenario__mode="fast", - ) - ) - assert result == {"max_turns": 10, "mode": "fast"} + script = tmp_path / "myinit.py" + script.write_text("# stub\n") + result = pyrit_scan.main(["--add-initializer", str(script)]) + assert result == 1 + assert "disabled" in capsys.readouterr().out -class TestConfigScenarioMerge: - """Tests for the CLI/config scenario_args merge in pyrit_scan.main().""" - @staticmethod - def _patch_resolve(scenario_class): - return patch.object(pyrit_scan, "_resolve_scenario_class", return_value=scenario_class) +class TestScenarioParamFlow: + """Regression tests for scenario-declared parameters flowing through the CLI.""" @staticmethod - def _make_scenario_class(declared_params): - class _FakeScenario: - @classmethod - def supported_parameters(cls): - return list(declared_params) - - return _FakeScenario - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_cli_args_override_config_args_per_key( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """When CLI and config both set max_turns, CLI wins per-key.""" - from pyrit.common import Parameter - from pyrit.setup.configuration_loader import ScenarioConfig - - # Config sets max_turns=5, mode=slow; CLI overrides max_turns=10. - mock_context = MagicMock() - mock_context._scenario_config = ScenarioConfig(name="scam", args={"max_turns": 5, "mode": "slow"}) - mock_frontend_core.return_value = mock_context - - scenario_class = self._make_scenario_class( - [ - Parameter(name="max_turns", description="d", param_type=int, default=5), - Parameter(name="mode", description="d", param_type=str, default="slow"), - ] - ) - with self._patch_resolve(scenario_class): - pyrit_scan.main(["scam", "--max-turns", "10"]) - - # Inspect the scenario_args kwarg passed into run_scenario_async - call_kwargs = mock_run_scenario.call_args.kwargs - assert call_kwargs["scenario_args"] == {"max_turns": 10, "mode": "slow"} - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_config_scenario_used_when_no_positional( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """Config-only scenario invocation: pyrit_scan --config-file my.yaml.""" - from pyrit.setup.configuration_loader import ScenarioConfig - - mock_context = MagicMock() - mock_context._scenario_config = ScenarioConfig(name="scam", args={"max_turns": 5}) - mock_frontend_core.return_value = mock_context - - # No positional, no scenario flags (would require pass-2 augmentation, - # which is a documented v1 limitation). - with self._patch_resolve(None): - result = pyrit_scan.main([]) + def _build_mock_client(supported_params=None, status="COMPLETED"): + from unittest.mock import AsyncMock + + client = AsyncMock() + client.list_scenarios_async.return_value = {"items": [{"scenario_name": "foo"}]} + client.get_scenario_async.return_value = { + "scenario_name": "foo", + "supported_parameters": supported_params or [], + } + client.start_scenario_run_async.return_value = {"scenario_result_id": "rid", "status": "CREATED"} + client.get_scenario_run_async.return_value = {"scenario_result_id": "rid", "status": status} + client.get_scenario_run_results_async.return_value = {"items": []} + client.close_async = AsyncMock() + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) + @patch("pyrit.cli._output.print_scenario_run_progress") + def test_scenario_declared_flag_is_forwarded(self, _mock_prog, _mock_print, mock_client_class, _mock_probe): + client = self._build_mock_client(supported_params=[{"name": "max_turns", "description": "..."}]) + mock_client_class.return_value = client + + result = pyrit_scan.main(["foo", "--target", "t", "--max-turns", "7"]) assert result == 0 - call_kwargs = mock_run_scenario.call_args.kwargs - assert call_kwargs["scenario_name"] == "scam" - assert call_kwargs["scenario_args"] == {"max_turns": 5} - - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_config_args_ignored_when_cli_specifies_different_scenario( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """CLI scenario name differs from config: config args silently dropped (CLI-wins).""" - from pyrit.setup.configuration_loader import ScenarioConfig - - mock_context = MagicMock() - mock_context._scenario_config = ScenarioConfig(name="scam", args={"max_turns": 5}) - mock_frontend_core.return_value = mock_context - - with self._patch_resolve(None): - pyrit_scan.main(["other_scenario"]) - - call_kwargs = mock_run_scenario.call_args.kwargs - assert call_kwargs["scenario_name"] == "other_scenario" - assert call_kwargs["scenario_args"] == {} - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_no_scenario_anywhere_returns_error( - self, - mock_frontend_core: MagicMock, - ): - """No CLI positional and no config scenario: explicit error message + nonzero exit.""" - mock_context = MagicMock() - mock_context._scenario_config = None - mock_frontend_core.return_value = mock_context - - with self._patch_resolve(None): - result = pyrit_scan.main([]) + sent_request = client.start_scenario_run_async.call_args.kwargs["request"] + assert sent_request["scenario_params"] == {"max_turns": "7"} + + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) + @patch("pyrit.cli._output.print_scenario_run_progress") + def test_unknown_flag_after_valid_scenario_errors(self, _mock_prog, _mock_print, mock_client_class, _mock_probe): + client = self._build_mock_client(supported_params=[{"name": "max_turns", "description": "..."}]) + mock_client_class.return_value = client + + result = pyrit_scan.main(["foo", "--target", "t", "--max-turns", "7", "--unknown-flag"]) assert result == 1 + client.start_scenario_run_async.assert_not_called() + + @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch("pyrit.cli.api_client.PyRITApiClient") + @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) + @patch("pyrit.cli._output.print_scenario_run_progress") + def test_no_scenario_params_passes_through_cleanly(self, _mock_prog, _mock_print, mock_client_class, _mock_probe): + client = self._build_mock_client(supported_params=[]) + mock_client_class.return_value = client - @patch("pyrit.cli.pyrit_scan.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_config_args_deep_copied( - self, - mock_frontend_core: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - ): - """Mutating scenario_args on one run must not leak into the config block.""" - from pyrit.setup.configuration_loader import ScenarioConfig - - original_args = {"datasets": ["a", "b"]} - mock_context = MagicMock() - mock_context._scenario_config = ScenarioConfig(name="scam", args=original_args) - mock_frontend_core.return_value = mock_context - - with self._patch_resolve(None): - pyrit_scan.main(["scam"]) - - call_kwargs = mock_run_scenario.call_args.kwargs - # Mutate the passed dict - call_kwargs["scenario_args"]["datasets"].append("c") - # Original config block must be untouched - assert original_args == {"datasets": ["a", "b"]} + result = pyrit_scan.main(["foo", "--target", "t"]) + + assert result == 0 + sent_request = client.start_scenario_run_async.call_args.kwargs["request"] + assert "scenario_params" not in sent_request + + def test_parse_args_tolerates_scenario_specific_flags(self): + # Pass 1 must not error on scenario-declared flags (they're recognized in pass 2). + parsed = pyrit_scan.parse_args(["foo", "--target", "t", "--max-turns", "7"]) + assert parsed.scenario_name == "foo" + assert parsed.target == "t" + assert parsed._unknown_args == ["--max-turns", "7"] diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index c7e70de450..233206071d 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -2,911 +2,753 @@ # Licensed under the MIT license. """ -Unit tests for the pyrit_shell CLI module. +Unit tests for the pyrit_shell CLI module (thin REST client). """ -import cmd -import logging -from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest -from pyrit.cli import _banner as banner from pyrit.cli import pyrit_shell @pytest.fixture() -def mock_fc(): - """Patch FrontendCore so the background thread uses a controllable mock context.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._env_files = None - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() - - with patch("pyrit.cli.frontend_core.FrontendCore", return_value=mock_context) as mock_fc_class: - yield mock_context, mock_fc_class +def mock_api_client(): + """Create a mock PyRITApiClient with default responses.""" + client = AsyncMock() + client.health_check_async.return_value = True + client.list_scenarios_async.return_value = {"items": [], "pagination": {"total": 0}} + client.list_initializers_async.return_value = {"items": [], "pagination": {"total": 0}} + client.list_targets_async.return_value = {"items": [], "pagination": {"total": 0}} + client.list_scenario_runs_async.return_value = {"items": []} + # Default: scenario fetch returns no declared params (back-compat for older tests) + client.get_scenario_async.return_value = {"scenario_name": "foo", "supported_parameters": []} + client.close_async = AsyncMock() + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client @pytest.fixture() -def shell(): - """Create a fully-initialized PyRITShell without spawning a background thread. - - Bypasses the real ``_background_init`` and wires up a mock FrontendCore - directly, avoiding thread + asyncio.run overhead per test. - """ - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._env_files = None - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() - - with patch("pyrit.cli.frontend_core.FrontendCore", return_value=mock_context) as mock_fc_class: - with patch.object(pyrit_shell.PyRITShell, "_background_init"): - s = pyrit_shell.PyRITShell() - # Manually set the state that _background_init would have set - from pyrit.cli import frontend_core as fc_module - - s._fc = fc_module - s.context = mock_context - s.default_log_level = mock_context._log_level - s._init_complete.set() - yield s, mock_context, mock_fc_class +def shell(mock_api_client): + """Create a PyRITShell with a pre-wired mock API client.""" + s = pyrit_shell.PyRITShell(no_animation=True) + s._api_client = mock_api_client + s._base_url = "http://localhost:8000" + return s, mock_api_client class TestPyRITShell: """Tests for PyRITShell class.""" - def test_init(self, mock_fc): - """Test PyRITShell initialization.""" - ctx, mock_fc_class = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - - assert shell._init_complete.is_set() - assert shell.context is ctx - assert shell.default_log_level == "WARNING" - assert shell._scenario_history == [] - mock_fc_class.assert_called_once_with() - ctx.initialize_async.assert_called_once() - - def test_background_init_failure_sets_event_and_raises_in_ensure_initialized(self, mock_fc): - """Test failed background initialization unblocks waiters and surfaces the original error.""" - ctx, _ = mock_fc - ctx.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=2) - - assert shell._init_complete.is_set() - with pytest.raises(RuntimeError, match="Initialization failed"): - shell._ensure_initialized() - - def test_prompt_and_intro(self, shell): - """Test shell prompt is set and cmdloop wires play_animation to intro.""" - s, ctx, _ = shell - + def test_prompt(self, shell): + s, _ = shell assert s.prompt == "pyrit> " - # Verify that cmdloop calls play_animation and passes the result as intro + def test_cmdloop_plays_animation(self): + s = pyrit_shell.PyRITShell(no_animation=True) with ( - patch("pyrit.cli._banner.play_animation", return_value="TEST_BANNER") as mock_play, + patch("pyrit.cli._banner.play_animation", return_value="BANNER") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop, ): s.cmdloop() + mock_play.assert_called_once_with(no_animation=True) + mock_cmdloop.assert_called_once_with(intro="BANNER") - mock_play.assert_called_once_with(no_animation=s._no_animation) - mock_cmdloop.assert_called_once_with(intro="TEST_BANNER") - - def test_cmdloop_honors_explicit_intro(self, shell): - """Test that cmdloop passes through a non-None intro without calling play_animation.""" - s, ctx, _ = shell - - with patch("pyrit.cli._banner.play_animation") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop: + def test_cmdloop_honors_explicit_intro(self): + s = pyrit_shell.PyRITShell(no_animation=True) + with ( + patch("pyrit.cli._banner.play_animation") as mock_play, + patch("cmd.Cmd.cmdloop") as mock_cmdloop, + ): s.cmdloop(intro="Custom intro") - mock_play.assert_not_called() mock_cmdloop.assert_called_once_with(intro="Custom intro") - @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - def test_do_list_scenarios(self, mock_print_scenarios: AsyncMock, shell): - """Test do_list_scenarios command.""" - s, ctx, _ = shell - + def test_do_list_scenarios(self, shell): + s, client = shell s.do_list_scenarios("") - - mock_print_scenarios.assert_called_once_with(context=ctx) - - @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, shell, capsys): - """Test do_list_scenarios handles exceptions.""" - s, ctx, _ = shell - mock_print_scenarios.side_effect = ValueError("Test error") - - s.do_list_scenarios("") - - captured = capsys.readouterr() - assert "Error listing scenarios" in captured.out + client.list_scenarios_async.assert_awaited_once() def test_do_list_scenarios_rejects_args(self, shell, capsys): - """Test do_list_scenarios rejects unexpected arguments.""" - s, ctx, _ = shell - + s, _ = shell s.do_list_scenarios("--unknown foo") - captured = capsys.readouterr() assert "does not accept arguments" in captured.out - @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers(self, mock_print_initializers: AsyncMock, shell): - """Test do_list_initializers command.""" - s, ctx, _ = shell - + def test_do_list_initializers(self, shell): + s, client = shell s.do_list_initializers("") - - mock_print_initializers.assert_called_once_with(context=ctx) - - @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers_with_exception(self, mock_print_initializers: AsyncMock, shell, capsys): - """Test do_list_initializers handles exceptions.""" - s, ctx, _ = shell - mock_print_initializers.side_effect = ValueError("Test error") - - s.do_list_initializers("") - - captured = capsys.readouterr() - assert "Error listing initializers" in captured.out + client.list_initializers_async.assert_awaited_once() def test_do_list_initializers_rejects_args(self, shell, capsys): - """Test do_list_initializers rejects unexpected arguments.""" - s, ctx, _ = shell - + s, _ = shell s.do_list_initializers("--unknown foo") - captured = capsys.readouterr() assert "does not accept arguments" in captured.out - @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) - def test_do_list_targets_no_args(self, mock_print_targets: AsyncMock, shell): - """Test do_list_targets with no arguments uses the default context.""" - s, ctx, _ = shell - + def test_do_list_targets(self, shell): + s, client = shell s.do_list_targets("") + client.list_targets_async.assert_awaited_once() - mock_print_targets.assert_called_once_with(context=ctx) - - @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.parse_list_targets_arguments") - def test_do_list_targets_with_initializers( - self, - mock_parse: MagicMock, - mock_print_targets: AsyncMock, - shell, - ): - """Test do_list_targets with --initializers uses context.with_overrides.""" - s, ctx, _ = shell - mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} - mock_derived = MagicMock() - ctx.with_overrides = MagicMock(return_value=mock_derived) - - s.do_list_targets("--initializers target") - - mock_parse.assert_called_once_with(args_string="--initializers target") - ctx.with_overrides.assert_called_once_with( - initialization_scripts=None, - initializer_names=["target"], - ) - mock_print_targets.assert_called_once_with(context=mock_derived) - - @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) - def test_do_list_targets_with_exception(self, mock_print_targets: AsyncMock, shell, capsys): - """Test do_list_targets handles exceptions.""" - s, ctx, _ = shell - mock_print_targets.side_effect = RuntimeError("Test error") - - s.do_list_targets("") - - captured = capsys.readouterr() - assert "Error listing targets" in captured.out - - def test_do_list_targets_parse_error(self, shell, capsys): - """Test do_list_targets shows error for invalid args.""" - s, ctx, _ = shell - - s.do_list_targets("--unknown-flag") - - captured = capsys.readouterr() - assert "Error" in captured.out - - def test_do_run_empty_line(self, shell, capsys): - """Test do_run with empty line.""" - s, ctx, _ = shell - + def test_do_run_empty_args(self, shell, capsys): + s, _ = shell s.do_run("") - captured = capsys.readouterr() assert "Specify a scenario name" in captured.out - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_basic_scenario( - self, - mock_parse_args: MagicMock, - _mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - shell, - ): - """Test do_run with basic scenario.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": ["test_init"], - "initialization_scripts": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - mock_result = MagicMock() - mock_asyncio_run.side_effect = [mock_result] - - s.do_run("test_scenario --initializers test_init") - - mock_parse_args.assert_called_once() - assert mock_asyncio_run.call_count == 1 - - # Verify result was stored in history - assert len(s._scenario_history) == 1 - assert s._scenario_history[0][0] == "test_scenario --initializers test_init" - assert s._scenario_history[0][1] == mock_result - - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_parse_error(self, mock_parse_args: MagicMock, shell, capsys): - """Test do_run with parse error.""" - s, ctx, _ = shell - mock_parse_args.side_effect = ValueError("Parse error") - - s.do_run("test_scenario --invalid") - - captured = capsys.readouterr() - assert "Error: Parse error" in captured.out - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.run_scenario_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.parse_run_arguments") - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - def test_do_run_with_initialization_scripts( - self, - mock_resolve_scripts: MagicMock, - mock_parse_args: MagicMock, - mock_run_scenario: AsyncMock, - mock_asyncio_run: MagicMock, - shell, - ): - """Test do_run with initialization scripts.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": None, - "initialization_scripts": ["script.py"], - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - mock_resolve_scripts.return_value = [Path("/test/script.py")] - mock_asyncio_run.side_effect = [MagicMock()] - - s.do_run("test_scenario --initialization-scripts script.py") - - mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) - assert mock_asyncio_run.call_count == 1 - - @patch("pyrit.cli.frontend_core.parse_run_arguments") - @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") - def test_do_run_with_missing_script( - self, - mock_resolve_scripts: MagicMock, - mock_parse_args: MagicMock, - shell, - capsys, - ): - """Test do_run with missing initialization script.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": None, - "initialization_scripts": ["missing.py"], - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") - - s.do_run("test_scenario --initialization-scripts missing.py") - - captured = capsys.readouterr() - assert "Error: Script not found" in captured.out - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_with_exception( - self, - mock_parse_args: MagicMock, - mock_asyncio_run: MagicMock, - shell, - capsys, - ): - """Test do_run handles exceptions during scenario run.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": ["test_init"], - "initialization_scripts": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - mock_asyncio_run.side_effect = [ValueError("Test error")] - - s.do_run("test_scenario --initializers test_init") - - captured = capsys.readouterr() - assert "Error: Test error" in captured.out - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_keyboard_interrupt_returns_to_shell( - self, - mock_parse_args: MagicMock, - mock_asyncio_run: MagicMock, - shell, - capsys, - ): - """Test that Ctrl+C during scenario run returns to shell instead of crashing.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": ["test_init"], - "initialization_scripts": None, - "env_files": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "database": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - mock_asyncio_run.side_effect = KeyboardInterrupt() - - s.do_run("test_scenario --initializers test_init") - - captured = capsys.readouterr() - assert "interrupted" in captured.out.lower() - # Scenario should NOT be added to history - assert len(s._scenario_history) == 0 - - def test_do_scenario_history_empty(self, shell, capsys): - """Test do_scenario_history with no history.""" - s, ctx, _ = shell - - s.do_scenario_history("") - - captured = capsys.readouterr() - assert "No scenario runs in history" in captured.out - - def test_do_scenario_history_rejects_args(self, shell, capsys): - """Test do_scenario_history rejects unexpected arguments.""" - s, ctx, _ = shell - - s.do_scenario_history("--unknown foo") - - captured = capsys.readouterr() - assert "does not accept arguments" in captured.out - - def test_do_scenario_history_with_runs(self, shell, capsys): - """Test do_scenario_history with scenario runs.""" - s, ctx, _ = shell - - s._scenario_history = [ - ("test_scenario1 --initializers init1", MagicMock()), - ("test_scenario2 --initializers init2", MagicMock()), - ] - + def test_do_scenario_history_default_limit(self, shell): + s, client = shell + client.list_scenario_runs_async.return_value = {"items": []} s.do_scenario_history("") + client.list_scenario_runs_async.assert_awaited_once_with(limit=10) - captured = capsys.readouterr() - assert "Scenario Run History" in captured.out - assert "test_scenario1" in captured.out - assert "test_scenario2" in captured.out - assert "Total runs: 2" in captured.out - - def test_do_print_scenario_empty(self, shell, capsys): - """Test do_print_scenario with no history.""" - s, ctx, _ = shell - - s.do_print_scenario("") + def test_do_scenario_history_accepts_numeric_limit(self, shell): + s, client = shell + client.list_scenario_runs_async.return_value = {"items": []} + s.do_scenario_history("3") + client.list_scenario_runs_async.assert_awaited_once_with(limit=3) + def test_do_scenario_history_rejects_non_integer(self, shell, capsys): + s, _ = shell + s.do_scenario_history("extra") captured = capsys.readouterr() - assert "No scenario runs in history" in captured.out - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter") - def test_do_print_scenario_all( - self, - mock_printer_class: MagicMock, - mock_asyncio_run: MagicMock, - shell, - capsys, - ): - """Test do_print_scenario without argument prints all.""" - s, ctx, _ = shell - mock_printer = MagicMock() - mock_printer_class.return_value = mock_printer - - s._scenario_history = [ - ("test_scenario1", MagicMock()), - ("test_scenario2", MagicMock()), - ] + assert "Usage: scenario-history" in captured.out + def test_do_print_scenario_no_args(self, shell, capsys): + s, _ = shell s.do_print_scenario("") - captured = capsys.readouterr() - assert "Printing all scenario results" in captured.out - # 2 print calls (no background init) - assert mock_asyncio_run.call_count == 2 - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter") - def test_do_print_scenario_specific( - self, - mock_printer_class: MagicMock, - mock_asyncio_run: MagicMock, - shell, - capsys, - ): - """Test do_print_scenario with specific scenario number.""" - s, ctx, _ = shell - mock_printer = MagicMock() - mock_printer_class.return_value = mock_printer - - s._scenario_history = [ - ("test_scenario1", MagicMock()), - ("test_scenario2", MagicMock()), - ] - - s.do_print_scenario("1") - - captured = capsys.readouterr() - assert "Scenario Run #1" in captured.out - # 1 print call (no background init) - assert mock_asyncio_run.call_count == 1 - - def test_do_print_scenario_invalid_number(self, shell, capsys): - """Test do_print_scenario with invalid scenario number.""" - s, ctx, _ = shell - - s._scenario_history = [ - ("test_scenario1", MagicMock()), - ] - - s.do_print_scenario("5") - - captured = capsys.readouterr() - assert "must be between 1 and 1" in captured.out - - def test_do_print_scenario_non_integer(self, shell, capsys): - """Test do_print_scenario with non-integer argument.""" - s, ctx, _ = shell - - s._scenario_history = [ - ("test_scenario1", MagicMock()), - ] - - s.do_print_scenario("invalid") - - captured = capsys.readouterr() - assert "Invalid scenario number" in captured.out - - def test_do_help_without_arg(self, shell, capsys): - """Test do_help without argument.""" - s, ctx, _ = shell - - # Capture help output - with patch("cmd.Cmd.do_help"): - s.do_help("") - captured = capsys.readouterr() - assert "Shell Startup Options" in captured.out - - def test_do_help_with_arg(self, shell): - """Test do_help with specific command.""" - s, ctx, _ = shell - - with patch("cmd.Cmd.do_help") as mock_parent_help: - s.do_help("run") - mock_parent_help.assert_called_with("run") - - def test_do_help_with_hyphenated_arg(self, shell): - """Test do_help converts hyphens to underscores for command lookup.""" - s, ctx, _ = shell - - with patch("cmd.Cmd.do_help") as mock_parent_help: - s.do_help("list-targets") - mock_parent_help.assert_called_with("list_targets") - - @patch.object(cmd.Cmd, "cmdloop") - @patch.object(banner, "play_animation") - def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, shell): - """Test cmdloop wires banner.play_animation into intro and threads --no-animation.""" - s, ctx, _ = shell - - mock_play.return_value = "animated banner" - - # Note: no_animation is not set because shell fixture uses default - s._no_animation = True - s.cmdloop() - - mock_play.assert_called_once_with(no_animation=True) - assert s.intro == "animated banner" - mock_cmdloop.assert_called_once_with(intro="animated banner") - - @patch.object(cmd.Cmd, "cmdloop") - def test_cmdloop_honors_explicit_intro(self, mock_cmdloop: MagicMock, shell): - """Test cmdloop honors a non-None intro argument without calling play_animation.""" - s, ctx, _ = shell - - s.cmdloop(intro="custom intro") - - assert s.intro == "custom intro" - mock_cmdloop.assert_called_once_with(intro="custom intro") - - def test_do_exit(self, shell, capsys): - """Test do_exit command.""" - s, ctx, _ = shell + assert "Usage" in captured.out + def test_do_exit(self, shell): + s, client = shell result = s.do_exit("") - assert result is True - captured = capsys.readouterr() - assert "Goodbye" in captured.out + client.close_async.assert_awaited_once() def test_do_quit_alias(self, shell): - """Test do_quit is alias for do_exit.""" - s, ctx, _ = shell - + s, _ = shell assert s.do_quit == s.do_exit def test_do_q_alias(self, shell): - """Test do_q is alias for do_exit.""" - s, ctx, _ = shell - + s, _ = shell assert s.do_q == s.do_exit - def test_do_eof_alias(self, shell): - """Test do_EOF is alias for do_exit.""" - s, ctx, _ = shell - - assert s.do_EOF == s.do_exit - - @patch("os.system") - def test_do_clear_windows(self, mock_system: MagicMock, shell): - """Test do_clear on Windows.""" - s, ctx, _ = shell - - with patch("os.name", "nt"): - s.do_clear("") - mock_system.assert_called_with("cls") - - @patch("os.system") - def test_do_clear_unix(self, mock_system: MagicMock, shell): - """Test do_clear on Unix.""" - s, ctx, _ = shell - - with patch("os.name", "posix"): - s.do_clear("") - mock_system.assert_called_with("clear") - def test_emptyline(self, shell): - """Test emptyline doesn't repeat last command.""" - s, ctx, _ = shell + s, _ = shell + assert s.emptyline() is False - result = s.emptyline() + def test_default_unknown_command(self, shell, capsys): + s, _ = shell + s.default("unknown_command") + captured = capsys.readouterr() + assert "Unknown command" in captured.out + def test_default_hyphen_to_underscore(self, shell): + s, client = shell + s.default("list-scenarios") + client.list_scenarios_async.assert_awaited_once() + + def test_do_stop_server_no_launcher(self, shell, capsys): + s, _ = shell + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ), + patch("pyrit.cli._server_launcher.stop_server_on_port", return_value=False), + ): + s.do_stop_server("") + captured = capsys.readouterr() + assert "No server found" in captured.out + + def test_ensure_client_already_connected(self, shell): + s, _ = shell + assert s._ensure_client() is True + + def test_ensure_client_no_server(self, capsys): + s = pyrit_shell.PyRITShell(no_animation=True) + with patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + ) as mock_probe: + mock_probe.return_value = False + result = s._ensure_client() assert result is False + captured = capsys.readouterr() + assert "Server not available" in captured.out - def test_default_with_hyphen_to_underscore(self, shell): - """Test default converts hyphens to underscores.""" - s, ctx, _ = shell - # Mock a method with underscores - s.do_list_scenarios = MagicMock() +class TestShellMain: + """Tests for the shell main() entry point.""" - s.default("list-scenarios") + def test_main_parses_server_url(self): + with ( + patch("pyrit.cli._banner.play_animation", return_value=""), + patch("pyrit.cli.pyrit_shell.PyRITShell") as mock_shell_class, + ): + mock_shell = MagicMock() + mock_shell_class.return_value = mock_shell - s.do_list_scenarios.assert_called_once_with("") + with patch("sys.argv", ["pyrit_shell", "--server-url", "http://remote:9000", "--no-animation"]): + pyrit_shell.main() - def test_default_unknown_command(self, shell, capsys): - """Test default with unknown command.""" - s, ctx, _ = shell + mock_shell_class.assert_called_once() + assert mock_shell_class.call_args.kwargs["server_url"] == "http://remote:9000" - s.default("unknown_command") + def test_main_keyboard_interrupt(self, capsys): + with ( + patch("pyrit.cli._banner.play_animation", return_value=""), + patch("pyrit.cli.pyrit_shell.PyRITShell") as mock_shell_class, + patch("sys.argv", ["pyrit_shell", "--no-animation"]), + ): + mock_shell = MagicMock() + mock_shell.cmdloop.side_effect = KeyboardInterrupt() + mock_shell_class.return_value = mock_shell - captured = capsys.readouterr() - assert "Unknown command" in captured.out + result = pyrit_shell.main() + assert result == 0 + def test_main_generic_exception(self, capsys): + with ( + patch("pyrit.cli._banner.play_animation", return_value=""), + patch("pyrit.cli.pyrit_shell.PyRITShell") as mock_shell_class, + patch("sys.argv", ["pyrit_shell", "--no-animation"]), + ): + mock_shell = MagicMock() + mock_shell.cmdloop.side_effect = RuntimeError("boom") + mock_shell_class.return_value = mock_shell -class TestNullGuards: - """Tests for null-guard checks that raise RuntimeError when _fc or context is None.""" - - @pytest.fixture() - def uninitialized_shell(self): - """Create a shell where _ensure_initialized passes but _fc and context are None.""" - with patch.object(pyrit_shell.PyRITShell, "_background_init"): - s = pyrit_shell.PyRITShell() - s._init_complete.set() - s._fc = None - s.context = None - return s - - def test_ensure_initialized_raises_when_fc_is_none(self, uninitialized_shell): - """Test _ensure_initialized raises RuntimeError when _fc is None.""" - with pytest.raises(RuntimeError, match="Frontend core not initialized"): - uninitialized_shell._ensure_initialized() - - def test_do_list_scenarios_raises_when_fc_is_none(self, uninitialized_shell): - """Test do_list_scenarios raises RuntimeError when _fc is None after _ensure_initialized.""" - with patch.object(uninitialized_shell, "_ensure_initialized"): - with pytest.raises(RuntimeError, match="Frontend core not initialized"): - uninitialized_shell.do_list_scenarios("") - - def test_do_list_initializers_raises_when_fc_is_none(self, uninitialized_shell): - """Test do_list_initializers raises RuntimeError when _fc is None after _ensure_initialized.""" - with patch.object(uninitialized_shell, "_ensure_initialized"): - with pytest.raises(RuntimeError, match="Frontend core not initialized"): - uninitialized_shell.do_list_initializers("") - - def test_do_list_targets_raises_when_fc_is_none(self, uninitialized_shell): - """Test do_list_targets raises RuntimeError when _fc is None after _ensure_initialized.""" - with patch.object(uninitialized_shell, "_ensure_initialized"): - with pytest.raises(RuntimeError, match="Frontend core not initialized"): - uninitialized_shell.do_list_targets("") - - def test_do_run_raises_when_fc_is_none(self, uninitialized_shell): - """Test do_run raises RuntimeError when _fc is None after _ensure_initialized.""" - with patch.object(uninitialized_shell, "_ensure_initialized"): - with pytest.raises(RuntimeError, match="Frontend core not initialized"): - uninitialized_shell.do_run("some_scenario --target t") - - -class TestMain: - """Tests for main function.""" - - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_default_args(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main with default arguments.""" - mock_shell = MagicMock() - mock_shell_class.return_value = mock_shell - - with patch("sys.argv", ["pyrit_shell"]): result = pyrit_shell.main() + assert result == 1 + captured = capsys.readouterr() + assert "boom" in captured.out - assert result == 0 - call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["log_level"] == logging.WARNING - mock_shell.cmdloop.assert_called_once() - - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_with_config_file_arg(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main with config-file argument.""" - mock_shell = MagicMock() - mock_shell_class.return_value = mock_shell + def test_main_log_level_and_config_file(self, tmp_path): + with ( + patch("pyrit.cli._banner.play_animation", return_value=""), + patch("pyrit.cli.pyrit_shell.PyRITShell") as mock_shell_class, + patch( + "sys.argv", + [ + "pyrit_shell", + "--no-animation", + "--log-level", + "DEBUG", + "--config-file", + str(tmp_path / "conf.yaml"), + "--start-server", + ], + ), + ): + mock_shell = MagicMock() + mock_shell_class.return_value = mock_shell + assert pyrit_shell.main() == 0 + kwargs = mock_shell_class.call_args.kwargs + assert kwargs["start_server"] is True + assert kwargs["config_file"] == tmp_path / "conf.yaml" - with patch("sys.argv", ["pyrit_shell", "--config-file", "my_config.yaml"]): - result = pyrit_shell.main() - assert result == 0 - call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["config_file"] == Path("my_config.yaml") +class TestResolveBaseUrl: + def test_explicit_server_url_wins(self): + s = pyrit_shell.PyRITShell(no_animation=True, server_url="http://custom:1234") + assert s._resolve_base_url() == "http://custom:1234" - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_with_log_level_arg(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main with log-level argument.""" - mock_shell = MagicMock() - mock_shell_class.return_value = mock_shell + def test_falls_back_to_config_reader(self, tmp_path): + s = pyrit_shell.PyRITShell(no_animation=True) + with patch("pyrit.cli._config_reader.read_server_url", return_value="http://from-cfg:8000"): + assert s._resolve_base_url() == "http://from-cfg:8000" - with patch("sys.argv", ["pyrit_shell", "--log-level", "DEBUG"]): - result = pyrit_shell.main() + def test_default_when_config_returns_none(self): + s = pyrit_shell.PyRITShell(no_animation=True) + with patch("pyrit.cli._config_reader.read_server_url", return_value=None): + from pyrit.cli._config_reader import DEFAULT_SERVER_URL - assert result == 0 - call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["log_level"] == logging.DEBUG + assert s._resolve_base_url() == DEFAULT_SERVER_URL - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_with_keyboard_interrupt(self, mock_play: MagicMock, mock_shell_class: MagicMock, capsys): - """Test main handles keyboard interrupt.""" - mock_shell = MagicMock() - mock_shell.cmdloop.side_effect = KeyboardInterrupt() - mock_shell_class.return_value = mock_shell - with patch("sys.argv", ["pyrit_shell"]): - result = pyrit_shell.main() +class TestEnsureClientStartServer: + def test_start_server_launches_when_not_running(self): + s = pyrit_shell.PyRITShell(no_animation=True, start_server=True) + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ), + patch("pyrit.cli._server_launcher.ServerLauncher.start_async", new_callable=AsyncMock) as mock_start, + patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, + ): + mock_start.return_value = "http://localhost:8000" + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value = mock_client + assert s._ensure_client() is True + assert s._api_client is mock_client + assert s._start_server is False # only auto-start once + + def test_start_server_failure_returns_false(self, capsys): + s = pyrit_shell.PyRITShell(no_animation=True, start_server=True) + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new_callable=AsyncMock, + side_effect=RuntimeError("nope"), + ), + ): + assert s._ensure_client() is False + assert "Error starting server: nope" in capsys.readouterr().out + + +class TestDoAddInitializer: + def test_no_args_prints_usage(self, shell, capsys): + s, _ = shell + s.do_add_initializer("") + assert "Usage" in capsys.readouterr().out + + def test_file_not_found(self, shell, capsys): + s, _ = shell + s.do_add_initializer("/nonexistent/_xyz_not_a_file.py") + assert "File not found" in capsys.readouterr().out + + def test_success_path(self, shell, tmp_path, capsys): + s, client = shell + script = tmp_path / "my_init.py" + script.write_text("def init(): pass") + client.register_initializer_async = AsyncMock(return_value={"status": "ok"}) + s.do_add_initializer(str(script)) + assert "Registered initializer 'my_init'" in capsys.readouterr().out + client.register_initializer_async.assert_awaited_once() + + def test_server_not_available_error(self, shell, tmp_path, capsys): + from pyrit.cli.api_client import ServerNotAvailableError + + s, client = shell + script = tmp_path / "init.py" + script.write_text("x = 1") + client.register_initializer_async = AsyncMock(side_effect=ServerNotAvailableError("server gone")) + s.do_add_initializer(str(script)) + assert "server gone" in capsys.readouterr().out + + def test_generic_error(self, shell, tmp_path, capsys): + s, client = shell + script = tmp_path / "init.py" + script.write_text("x = 1") + client.register_initializer_async = AsyncMock(side_effect=RuntimeError("boom")) + s.do_add_initializer(str(script)) + assert "Error registering initializer: boom" in capsys.readouterr().out + + +class TestDoRun: + def _run_payload(self, status="COMPLETED"): + return {"scenario_result_id": "rid-1", "status": status} + + def test_run_invalid_arguments(self, shell, capsys): + s, _ = shell + with patch("pyrit.cli._cli_args.parse_run_arguments", side_effect=ValueError("bad")): + s.do_run("foo --target t") + assert "Error: bad" in capsys.readouterr().out + + def test_run_start_failure(self, shell, capsys): + s, client = shell + client.start_scenario_run_async = AsyncMock(side_effect=RuntimeError("nope")) + with patch( + "pyrit.cli._cli_args.parse_run_arguments", + return_value={"scenario_name": "foo", "target": "t"}, + ): + s.do_run("foo --target t") + assert "Error starting scenario: nope" in capsys.readouterr().out + + def test_run_completed_path_with_results(self, shell, capsys): + s, client = shell + client.start_scenario_run_async = AsyncMock(return_value=self._run_payload()) + client.get_scenario_run_async = AsyncMock(return_value=self._run_payload("COMPLETED")) + client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) + with ( + patch( + "pyrit.cli._cli_args.parse_run_arguments", + return_value={ + "scenario_name": "foo", + "target": "t", + "initializers": ["a", {"name": "b", "args": {"x": 1}}], + "scenario_strategies": ["s1"], + "max_concurrency": 2, + "max_retries": 3, + "memory_labels": {"k": "v"}, + "dataset_names": ["d1"], + "max_dataset_size": 5, + }, + ), + patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("pyrit.cli._output.print_scenario_run_summary"), + patch("time.sleep"), + ): + s.do_run("foo --target t") + kwargs = client.start_scenario_run_async.call_args.kwargs["request"] + assert kwargs["initializers"] == ["a", "b"] + assert kwargs["initializer_args"] == {"b": {"x": 1}} + assert kwargs["strategies"] == ["s1"] + assert kwargs["max_concurrency"] == 2 + assert kwargs["max_retries"] == 3 + assert kwargs["labels"] == {"k": "v"} + assert kwargs["dataset_names"] == ["d1"] + assert kwargs["max_dataset_size"] == 5 + + def test_run_failed_status_calls_summary(self, shell): + s, client = shell + client.start_scenario_run_async = AsyncMock(return_value=self._run_payload()) + client.get_scenario_run_async = AsyncMock(return_value=self._run_payload("FAILED")) + with ( + patch( + "pyrit.cli._cli_args.parse_run_arguments", + return_value={"scenario_name": "foo", "target": "t"}, + ), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("pyrit.cli._output.print_scenario_run_summary") as mock_summary, + patch("time.sleep"), + ): + s.do_run("foo --target t") + mock_summary.assert_called_once() + + def test_run_completed_fallback_to_summary_on_results_error(self, shell): + s, client = shell + client.start_scenario_run_async = AsyncMock(return_value=self._run_payload()) + client.get_scenario_run_async = AsyncMock(return_value=self._run_payload("COMPLETED")) + client.get_scenario_run_results_async = AsyncMock(side_effect=RuntimeError("nope")) + with ( + patch( + "pyrit.cli._cli_args.parse_run_arguments", + return_value={"scenario_name": "foo", "target": "t"}, + ), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("pyrit.cli._output.print_scenario_run_summary") as mock_summary, + patch("time.sleep"), + ): + s.do_run("foo --target t") + mock_summary.assert_called_once() + + def test_run_keyboard_interrupt_cancels(self, shell, capsys): + s, client = shell + client.start_scenario_run_async = AsyncMock(return_value=self._run_payload()) + # Use MagicMock so KeyboardInterrupt raises synchronously on call — + # this simulates Ctrl+C arriving between polling iterations, matching + # how signals are delivered to the shell's main thread in production. + client.get_scenario_run_async = MagicMock(side_effect=KeyboardInterrupt) + client.cancel_scenario_run_async = AsyncMock(return_value=None) + with ( + patch( + "pyrit.cli._cli_args.parse_run_arguments", + return_value={"scenario_name": "foo", "target": "t"}, + ), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("time.sleep"), + ): + s.do_run("foo --target t") + client.cancel_scenario_run_async.assert_awaited_once() + assert "cancelled" in capsys.readouterr().out.lower() + + def test_run_keyboard_interrupt_cancel_fails_warns(self, shell, capsys): + s, client = shell + client.start_scenario_run_async = AsyncMock(return_value=self._run_payload()) + client.get_scenario_run_async = MagicMock(side_effect=KeyboardInterrupt) + client.cancel_scenario_run_async = AsyncMock(side_effect=RuntimeError("offline")) + with ( + patch( + "pyrit.cli._cli_args.parse_run_arguments", + return_value={"scenario_name": "foo", "target": "t"}, + ), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("time.sleep"), + ): + s.do_run("foo --target t") + assert "could not cancel" in capsys.readouterr().out.lower() - assert result == 0 - captured = capsys.readouterr() - assert "Interrupted" in captured.out - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_with_exception(self, mock_play: MagicMock, mock_shell_class: MagicMock, capsys): - """Test main handles exceptions.""" - mock_shell = MagicMock() - mock_shell.cmdloop.side_effect = ValueError("Test error") - mock_shell_class.return_value = mock_shell +class TestListErrors: + def test_list_scenarios_error(self, shell, capsys): + s, client = shell + client.list_scenarios_async = AsyncMock(side_effect=RuntimeError("x")) + s.do_list_scenarios("") + assert "Error listing scenarios" in capsys.readouterr().out - with patch("sys.argv", ["pyrit_shell"]): - result = pyrit_shell.main() + def test_list_initializers_error(self, shell, capsys): + s, client = shell + client.list_initializers_async = AsyncMock(side_effect=RuntimeError("x")) + s.do_list_initializers("") + assert "Error listing initializers" in capsys.readouterr().out - assert result == 1 - captured = capsys.readouterr() - assert "Error:" in captured.out - - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_creates_context_without_initializers(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main creates context without initializers.""" - mock_shell = MagicMock() - mock_shell_class.return_value = mock_shell - - with patch("sys.argv", ["pyrit_shell"]): - pyrit_shell.main() - - call_kwargs = mock_shell_class.call_args[1] - # main() should not pass initialization_scripts or initializer_names - assert "initialization_scripts" not in call_kwargs - assert "initializer_names" not in call_kwargs - - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_with_no_animation_flag(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main passes --no-animation flag to PyRITShell.""" - mock_shell = MagicMock() - mock_shell_class.return_value = mock_shell - - with patch("sys.argv", ["pyrit_shell", "--no-animation"]): - result = pyrit_shell.main() + def test_list_targets_error(self, shell, capsys): + s, client = shell + client.list_targets_async = AsyncMock(side_effect=RuntimeError("x")) + s.do_list_targets("") + assert "Error listing targets" in capsys.readouterr().out - assert result == 0 - call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["no_animation"] is True + def test_scenario_history_error(self, shell, capsys): + s, client = shell + client.list_scenario_runs_async = AsyncMock(side_effect=RuntimeError("x")) + s.do_scenario_history("") + assert "Error" in capsys.readouterr().out + + +class TestPrintScenarioAndHelp: + def test_print_scenario_success(self, shell): + s, client = shell + client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) + with patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) as mock_print: + s.do_print_scenario("rid-1") + mock_print.assert_awaited_once() + + def test_print_scenario_error(self, shell, capsys): + s, client = shell + client.get_scenario_run_results_async = AsyncMock(side_effect=RuntimeError("oops")) + s.do_print_scenario("rid-1") + assert "Error: oops" in capsys.readouterr().out + + def test_do_help_with_arg_normalizes_hyphen(self, shell): + s, _ = shell + with patch("cmd.Cmd.do_help") as mock_help: + s.do_help("list-scenarios") + mock_help.assert_called_once_with("list_scenarios") + + def test_do_help_no_arg(self, shell, capsys): + s, _ = shell + with patch("cmd.Cmd.do_help"): + s.do_help("") + assert "Use 'help '" in capsys.readouterr().out - @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_default_animation_enabled(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main defaults to animation enabled (no_animation=False).""" - mock_shell = MagicMock() - mock_shell_class.return_value = mock_shell - with patch("sys.argv", ["pyrit_shell"]): - result = pyrit_shell.main() +class TestServerManagement: + def test_start_server_already_running(self, capsys): + s = pyrit_shell.PyRITShell(no_animation=True) + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ), + patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value = mock_client + s.do_start_server("") + assert "already running" in capsys.readouterr().out + assert s._api_client is mock_client + + def test_start_server_launch_success(self): + s = pyrit_shell.PyRITShell(no_animation=True) + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ), + patch("pyrit.cli._server_launcher.ServerLauncher.start_async", new_callable=AsyncMock) as mock_start, + patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, + ): + mock_start.return_value = "http://localhost:8000" + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_class.return_value = mock_client + s.do_start_server("") + assert s._base_url == "http://localhost:8000" + + def test_start_server_launch_replaces_existing_client(self): + s = pyrit_shell.PyRITShell(no_animation=True) + existing = AsyncMock() + existing.close_async = AsyncMock() + s._api_client = existing + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ), + patch("pyrit.cli._server_launcher.ServerLauncher.start_async", new_callable=AsyncMock) as mock_start, + patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, + ): + mock_start.return_value = "http://localhost:8000" + new_client = AsyncMock() + new_client.__aenter__ = AsyncMock(return_value=new_client) + mock_client_class.return_value = new_client + s.do_start_server("") + existing.close_async.assert_awaited_once() + assert s._api_client is new_client + + def test_start_server_launch_failure(self, capsys): + s = pyrit_shell.PyRITShell(no_animation=True) + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new_callable=AsyncMock, + side_effect=RuntimeError("nope"), + ), + ): + s.do_start_server("") + assert "nope" in capsys.readouterr().out + + def test_stop_server_with_owned_launcher(self, shell, capsys): + s, client = shell + launcher = MagicMock() + s._launcher = launcher + s.do_stop_server("") + launcher.stop.assert_called_once() + assert "Server stopped" in capsys.readouterr().out + assert s._launcher is None + assert s._api_client is None + + def test_stop_server_by_port_success(self, shell, capsys): + s, _ = shell + s._base_url = "http://localhost:8000" + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ), + patch("pyrit.cli._server_launcher.stop_server_on_port", return_value=True), + ): + s.do_stop_server("") + assert "stopped" in capsys.readouterr().out - assert result == 0 - call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["no_animation"] is False - - -class TestPyRITShellRunCommand: - """Detailed tests for the run command.""" - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_run_with_all_parameters( - self, - mock_parse_args: MagicMock, - mock_asyncio_run: MagicMock, - shell, - ): - """Test run command with all parameters.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": ["init1"], - "initialization_scripts": None, - "scenario_strategies": ["s1", "s2"], - "max_concurrency": 10, - "max_retries": 5, - "memory_labels": {"key": "value"}, - "log_level": "DEBUG", - "dataset_names": None, - "max_dataset_size": None, - "target": None, + def test_stop_server_by_port_skips_when_no_pyrit_backend(self, shell, capsys): + s, _ = shell + s._base_url = "http://localhost:8000" + with ( + patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ), + patch("pyrit.cli._server_launcher.stop_server_on_port") as mock_stop, + ): + s.do_stop_server("") + mock_stop.assert_not_called() + assert "not stopping" in capsys.readouterr().out + + def test_stop_server_close_client_swallows_errors(self, shell): + s, client = shell + launcher = MagicMock() + s._launcher = launcher + client.close_async = AsyncMock(side_effect=RuntimeError("ignored")) + s.do_stop_server("") + assert s._api_client is None + + +class TestShellScenarioParamFlow: + """Regression tests: shell.do_run must forward scenario-declared parameters.""" + + def test_run_passes_scenario_declared_params(self, shell): + s, client = shell + client.get_scenario_async.return_value = { + "scenario_name": "foo", + "supported_parameters": [{"name": "max_turns", "description": "..."}], } + client.start_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "CREATED"}) + client.get_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "COMPLETED"}) + client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) - mock_asyncio_run.side_effect = [MagicMock()] - - with patch("pyrit.cli.frontend_core.FrontendCore"), patch("pyrit.cli.frontend_core.run_scenario_async"): - s.do_run("test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10") - - # Verify run_scenario_async was called with correct args - # (it's called via asyncio.run, so check the mock_asyncio_run call) - assert mock_asyncio_run.call_count == 1 - - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_run_stores_result_in_history( - self, - mock_parse_args: MagicMock, - mock_asyncio_run: MagicMock, - shell, - ): - """Test run command stores result in history.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": ["test_init"], - "initialization_scripts": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, + with ( + patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("time.sleep"), + ): + s.do_run("foo --target t --max-turns 7") + + sent_request = client.start_scenario_run_async.call_args.kwargs["request"] + assert sent_request["scenario_params"] == {"max_turns": "7"} + + def test_run_metadata_fetch_failure_aborts(self, shell, capsys): + s, client = shell + client.get_scenario_async = AsyncMock(side_effect=RuntimeError("net down")) + s.do_run("foo --target t") + assert "Error fetching scenario metadata" in capsys.readouterr().out + + def test_run_unknown_scenario_aborts(self, shell, capsys): + s, client = shell + client.get_scenario_async.return_value = None + s.do_run("foo --target t") + assert "not found on server" in capsys.readouterr().out + + def test_run_unknown_flag_for_scenario_with_declared_params_errors(self, shell, capsys): + s, client = shell + client.get_scenario_async.return_value = { + "scenario_name": "foo", + "supported_parameters": [{"name": "max_turns", "description": "..."}], } + s.do_run("foo --target t --not-a-real-flag x") + captured = capsys.readouterr().out + assert "Unknown argument" in captured or "Error" in captured + + def test_run_fat_fingered_flag_with_no_scenario_params_errors(self, shell, capsys): + """Even when the scenario declares no params, unknown flags must error (no silent no-op).""" + s, client = shell + client.get_scenario_async.return_value = {"scenario_name": "foo", "supported_parameters": []} + s.do_run("foo --target t --initialization-scripts /nope.py") + captured = capsys.readouterr().out + assert "Unknown argument: --initialization-scripts" in captured + client.start_scenario_run_async.assert_not_called() + + def test_run_fat_fingered_log_level_flag_errors(self, shell, capsys): + """--log-level was a stale shell-only flag; passing it must now error.""" + s, client = shell + client.get_scenario_async.return_value = {"scenario_name": "foo", "supported_parameters": []} + s.do_run("foo --target t --log-level DEBUG") + captured = capsys.readouterr().out + assert "Unknown argument: --log-level" in captured + client.start_scenario_run_async.assert_not_called() + + +class TestScenarioParamCoercionInShell: + """Shell-side regression tests for typed scenario params from the catalog.""" + + def test_shell_list_param_collects_multiple_values(self, shell): + s, client = shell + client.get_scenario_async.return_value = { + "scenario_name": "foo", + "supported_parameters": [ + {"name": "items", "description": "list field", "param_type": "list[str]", "is_list": True} + ], + } + client.start_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "CREATED"}) + client.get_scenario_run_async = AsyncMock(return_value={"scenario_result_id": "rid", "status": "COMPLETED"}) + client.get_scenario_run_results_async = AsyncMock(return_value={"items": []}) - mock_result1 = MagicMock() - mock_result2 = MagicMock() - mock_asyncio_run.side_effect = [mock_result1, mock_result2] - - # Run two scenarios - s.do_run("scenario1 --initializers init1") - s.do_run("scenario2 --initializers init2") - - # Verify both are in history - assert len(s._scenario_history) == 2 - assert s._scenario_history[0][1] == mock_result1 - assert s._scenario_history[1][1] == mock_result2 + with ( + patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock), + patch("pyrit.cli._output.print_scenario_run_progress"), + patch("time.sleep"), + ): + s.do_run("foo --target t --items a b c") + + sent = client.start_scenario_run_async.call_args.kwargs["request"] + assert sent["scenario_params"] == {"items": ["a", "b", "c"]} + + def test_shell_choices_rejected_before_request(self, shell, capsys): + s, client = shell + client.get_scenario_async.return_value = { + "scenario_name": "foo", + "supported_parameters": [ + {"name": "mode", "description": "...", "param_type": "str", "choices": ["fast", "slow"]} + ], + } + s.do_run("foo --target t --mode warp") + out = capsys.readouterr().out + # Parameter.coerce_value raises ValueError on out-of-choice values; + # do_run surfaces these as "Error: ...". + assert "Error" in out + client.start_scenario_run_async.assert_not_called() diff --git a/tests/unit/cli/test_server_launcher.py b/tests/unit/cli/test_server_launcher.py new file mode 100644 index 0000000000..8f3b4d3ea6 --- /dev/null +++ b/tests/unit/cli/test_server_launcher.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for pyrit.cli._server_launcher.ServerLauncher. +""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.cli._server_launcher import ServerLauncher + +# --------------------------------------------------------------------------- +# probe_health_async +# --------------------------------------------------------------------------- + + +async def test_probe_health_returns_true_when_client_healthy(): + fake_client = MagicMock() + fake_client.health_check_async = AsyncMock(return_value=True) + fake_client.__aenter__ = AsyncMock(return_value=fake_client) + fake_client.__aexit__ = AsyncMock(return_value=None) + + with patch("pyrit.cli.api_client.PyRITApiClient", return_value=fake_client): + result = await ServerLauncher.probe_health_async(base_url="http://localhost:8000") + assert result is True + fake_client.health_check_async.assert_awaited_once() + + +async def test_probe_health_returns_false_when_client_unhealthy(): + fake_client = MagicMock() + fake_client.health_check_async = AsyncMock(return_value=False) + fake_client.__aenter__ = AsyncMock(return_value=fake_client) + fake_client.__aexit__ = AsyncMock(return_value=None) + + with patch("pyrit.cli.api_client.PyRITApiClient", return_value=fake_client): + result = await ServerLauncher.probe_health_async(base_url="http://localhost:8000") + assert result is False + + +# --------------------------------------------------------------------------- +# start_async +# --------------------------------------------------------------------------- + + +async def test_start_async_returns_url_when_already_healthy(): + launcher = ServerLauncher() + with patch.object(ServerLauncher, "probe_health_async", new=AsyncMock(return_value=True)): + url = await launcher.start_async(host="localhost", port=8000) + assert url == "http://localhost:8000" + # Should not have created a subprocess. + assert launcher.pid is None + + +async def test_start_async_spawns_subprocess_and_waits_for_health(): + launcher = ServerLauncher() + fake_proc = MagicMock() + fake_proc.pid = 4321 + fake_proc.poll.return_value = None + # First health probe (already-running check) returns False, second returns True + probe = AsyncMock(side_effect=[False, True]) + + with ( + patch.object(ServerLauncher, "probe_health_async", new=probe), + patch("subprocess.Popen", return_value=fake_proc) as popen_mock, + patch("asyncio.sleep", new=AsyncMock(return_value=None)), + ): + url = await launcher.start_async( + host="localhost", + port=8001, + config_file=Path("/tmp/foo.yaml"), + log_level="INFO", + startup_timeout=5, + ) + assert url == "http://localhost:8001" + assert launcher.pid == 4321 + # Verify command construction + cmd = popen_mock.call_args.args[0] + assert "pyrit.backend.pyrit_backend" in cmd + assert "--config-file" in cmd + assert "/tmp/foo.yaml" in cmd or "\\tmp\\foo.yaml" in cmd + assert "--log-level" in cmd + assert "INFO" in cmd + + +async def test_start_async_raises_when_process_crashes_during_startup(): + launcher = ServerLauncher() + fake_proc = MagicMock() + fake_proc.pid = 42 + fake_proc.poll.return_value = 1 # exited + probe = AsyncMock(return_value=False) + + with ( + patch.object(ServerLauncher, "probe_health_async", new=probe), + patch("subprocess.Popen", return_value=fake_proc), + patch("asyncio.sleep", new=AsyncMock(return_value=None)), + ): + with pytest.raises(RuntimeError, match="exited with code 1"): + await launcher.start_async(host="localhost", port=8000, startup_timeout=3) + + +async def test_start_async_raises_when_timeout_exhausted(): + launcher = ServerLauncher() + fake_proc = MagicMock() + fake_proc.pid = 99 + fake_proc.poll.return_value = None # still running + probe = AsyncMock(return_value=False) + + with ( + patch.object(ServerLauncher, "probe_health_async", new=probe), + patch("subprocess.Popen", return_value=fake_proc), + patch("asyncio.sleep", new=AsyncMock(return_value=None)), + ): + with pytest.raises(RuntimeError, match="did not become healthy"): + await launcher.start_async(host="localhost", port=8000, startup_timeout=2) + + +# --------------------------------------------------------------------------- +# stop +# --------------------------------------------------------------------------- + + +def test_stop_terminates_process(): + launcher = ServerLauncher() + fake_proc = MagicMock() + fake_proc.pid = 12345 + launcher._process = fake_proc + launcher._pid = 12345 + + launcher.stop() + + fake_proc.terminate.assert_called_once() + fake_proc.wait.assert_called_once_with(timeout=5) + assert launcher.pid is None + assert launcher._process is None + + +def test_stop_swallows_termination_errors(): + launcher = ServerLauncher() + fake_proc = MagicMock() + fake_proc.pid = 12345 + fake_proc.terminate.side_effect = OSError("permission denied") + launcher._process = fake_proc + launcher._pid = 12345 + + # Should not raise. + launcher.stop() + assert launcher._process is None + assert launcher.pid is None + + +def test_stop_is_noop_when_no_process(): + launcher = ServerLauncher() + launcher.stop() # Should not raise. + assert launcher.pid is None diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index bba5ab6810..be994d631b 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -11,6 +11,7 @@ ConfigurationLoader, InitializerConfig, ScenarioConfig, + ServerConfig, initialize_from_config_async, ) @@ -246,53 +247,53 @@ def test_get_default_config_path(self): class TestConfigurationLoaderResolvers: """Tests for ConfigurationLoader path resolution methods.""" - def test_resolve_initialization_scripts_none_returns_none(self): + def testresolve_initialization_scripts_none_returns_none(self): """Test that None (default) returns None to signal 'use defaults'.""" config = ConfigurationLoader() - assert config._resolve_initialization_scripts() is None + assert config.resolve_initialization_scripts() is None - def test_resolve_initialization_scripts_empty_list_returns_empty_list(self): + def testresolve_initialization_scripts_empty_list_returns_empty_list(self): """Test that explicit empty list [] returns empty list to signal 'load nothing'.""" config = ConfigurationLoader(initialization_scripts=[]) - resolved = config._resolve_initialization_scripts() + resolved = config.resolve_initialization_scripts() assert resolved is not None assert resolved == [] - def test_resolve_initialization_scripts_absolute_path(self): + def testresolve_initialization_scripts_absolute_path(self): """Test resolving absolute script paths.""" config = ConfigurationLoader(initialization_scripts=["/absolute/path/script.py"]) - resolved = config._resolve_initialization_scripts() + resolved = config.resolve_initialization_scripts() assert resolved is not None assert len(resolved) == 1 # Check path ends with expected components (Windows adds drive letter to Unix-style paths) assert resolved[0].parts[-3:] == ("absolute", "path", "script.py") - def test_resolve_initialization_scripts_relative_path(self): + def testresolve_initialization_scripts_relative_path(self): """Test resolving relative script paths (converted to absolute).""" config = ConfigurationLoader(initialization_scripts=["relative/script.py"]) - resolved = config._resolve_initialization_scripts() + resolved = config.resolve_initialization_scripts() assert resolved is not None assert len(resolved) == 1 assert resolved[0].is_absolute() # Check path ends with expected components (works on both Unix and Windows) assert resolved[0].parts[-2:] == ("relative", "script.py") - def test_resolve_env_files_none_returns_none(self): + def testresolve_env_files_none_returns_none(self): """Test that None (default) returns None to signal 'use defaults'.""" config = ConfigurationLoader() - assert config._resolve_env_files() is None + assert config.resolve_env_files() is None - def test_resolve_env_files_empty_list_returns_empty_list(self): + def testresolve_env_files_empty_list_returns_empty_list(self): """Test that explicit empty list [] returns empty list to signal 'load nothing'.""" config = ConfigurationLoader(env_files=[]) - resolved = config._resolve_env_files() + resolved = config.resolve_env_files() assert resolved is not None assert resolved == [] - def test_resolve_env_files_absolute_path(self): + def testresolve_env_files_absolute_path(self): """Test resolving absolute env file paths.""" config = ConfigurationLoader(env_files=["/path/to/.env"]) - resolved = config._resolve_env_files() + resolved = config.resolve_env_files() assert resolved is not None assert len(resolved) == 1 # Check path ends with expected components (Windows adds drive letter to Unix-style paths) @@ -623,3 +624,27 @@ def test_load_with_overrides_passes_scenario_through_explicit_config(self): assert config._scenario_config == ScenarioConfig(name="scam", args={"max_turns": 10}) finally: config_path.unlink() + + +class TestNormalizeServer: + """Tests for ConfigurationLoader._normalize_server.""" + + def test_server_none_yields_no_server_config(self): + config = ConfigurationLoader(server=None) + assert config.server_config is None + + def test_server_dict_with_url_normalizes(self): + config = ConfigurationLoader(server={"url": "http://remote:9000/"}) + assert config.server_config == ServerConfig(url="http://remote:9000") + + def test_server_dict_without_url_uses_default(self): + config = ConfigurationLoader(server={}) + assert config.server_config == ServerConfig(url="http://localhost:8000") + + def test_server_url_non_string_raises(self): + with pytest.raises(ValueError, match="Server 'url' must be a string"): + ConfigurationLoader(server={"url": 12345}) + + def test_server_non_dict_raises(self): + with pytest.raises(ValueError, match="Server entry must be a dict"): + ConfigurationLoader(server="http://oops:8000") # type: ignore[arg-type]