From 1165cfa849e2604c052007841db22ea719a8481e Mon Sep 17 00:00:00 2001 From: Pigbibi <20649888+Pigbibi@users.noreply.github.com> Date: Wed, 10 Jun 2026 20:49:36 +0800 Subject: [PATCH] Apply audit remediation --- .github/workflows/ci.yml | 4 ++ .github/workflows/dependabot_auto_merge.yml | 1 + .../common/strategy_plugin_artifacts.py | 58 +++++++++++++++++++ .../common/strategy_plugins.py | 42 +++++--------- tests/test_strategy_plugin_artifacts.py | 52 +++++++++++++++++ 5 files changed, 130 insertions(+), 27 deletions(-) create mode 100644 src/quant_platform_kit/common/strategy_plugin_artifacts.py create mode 100644 tests/test_strategy_plugin_artifacts.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 439f592..603cfd8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,13 @@ on: branches: [ main ] pull_request: +permissions: + contents: read + jobs: test: runs-on: ubuntu-latest + timeout-minutes: 20 steps: - name: Checkout uses: actions/checkout@v6 diff --git a/.github/workflows/dependabot_auto_merge.yml b/.github/workflows/dependabot_auto_merge.yml index f3f9f1a..a3a4988 100644 --- a/.github/workflows/dependabot_auto_merge.yml +++ b/.github/workflows/dependabot_auto_merge.yml @@ -9,6 +9,7 @@ jobs: auto-merge: if: github.event.workflow_run.conclusion == 'success' && startsWith(github.event.workflow_run.head_branch, 'dependabot/') runs-on: ubuntu-latest + timeout-minutes: 10 permissions: contents: write pull-requests: write diff --git a/src/quant_platform_kit/common/strategy_plugin_artifacts.py b/src/quant_platform_kit/common/strategy_plugin_artifacts.py new file mode 100644 index 0000000..0e9a863 --- /dev/null +++ b/src/quant_platform_kit/common/strategy_plugin_artifacts.py @@ -0,0 +1,58 @@ +"""Artifact path helpers for strategy plugin signal files.""" + +from __future__ import annotations + +import hashlib +from pathlib import Path +from typing import Any + + +def materialize_local_or_gcs_artifact( + reference: str, + *, + cache_dir: Path, + client_factory: Any = None, +) -> tuple[Path, dict[str, str | None]]: + raw_reference = _required_string(reference, field_name="reference") + if not raw_reference.startswith("gs://"): + return Path(raw_reference).expanduser(), {"source_uri": None, "local_path": raw_reference} + + local_path = cache_path_for_remote_artifact(raw_reference, cache_dir=cache_dir) + download_gcs_object(raw_reference, local_path, client_factory=client_factory) + return local_path, {"source_uri": raw_reference, "local_path": str(local_path)} + + +def download_gcs_object(uri: str, destination: Path, *, client_factory: Any = None) -> None: + if client_factory is None: + try: + from google.cloud import storage # type: ignore + except ImportError as exc: + raise RuntimeError("google-cloud-storage is required for GCS strategy plugin artifacts") from exc + client_factory = storage.Client + bucket_name, object_name = parse_gcs_uri(uri) + destination.parent.mkdir(parents=True, exist_ok=True) + client = client_factory() + client.bucket(bucket_name).blob(object_name).download_to_filename(str(destination)) + + +def parse_gcs_uri(uri: str) -> tuple[str, str]: + raw_uri = str(uri or "").strip() + if not raw_uri.startswith("gs://"): + raise ValueError(f"Unsupported GCS URI: {raw_uri}") + bucket_name, _, object_name = raw_uri[5:].partition("/") + if not bucket_name or not object_name: + raise ValueError(f"Invalid GCS URI: {raw_uri}") + return bucket_name, object_name + + +def cache_path_for_remote_artifact(reference: str, *, cache_dir: Path) -> Path: + digest = hashlib.sha256(reference.encode("utf-8")).hexdigest()[:16] + leaf_name = Path(reference).name or "latest_signal.json" + return cache_dir / digest / leaf_name + + +def _required_string(value: Any, *, field_name: str) -> str: + text = str(value or "").strip() + if not text: + raise ValueError(f"{field_name} must be a non-empty string") + return text diff --git a/src/quant_platform_kit/common/strategy_plugins.py b/src/quant_platform_kit/common/strategy_plugins.py index 0438ccb..d3aac15 100644 --- a/src/quant_platform_kit/common/strategy_plugins.py +++ b/src/quant_platform_kit/common/strategy_plugins.py @@ -11,6 +11,13 @@ from pathlib import Path from typing import Any, Callable +from quant_platform_kit.common.strategy_plugin_artifacts import ( + cache_path_for_remote_artifact, + download_gcs_object, + materialize_local_or_gcs_artifact, + parse_gcs_uri, +) + PLUGIN_CRISIS_RESPONSE_SHADOW = "crisis_response_shadow" PLUGIN_MARKET_REGIME_CONTROL = "market_regime_control" PLUGIN_MACRO_RISK_GOVERNOR = "macro_risk_governor" @@ -1237,42 +1244,23 @@ def _sanitize_key_part(value: Any) -> str: def _materialize_artifact_path(reference: str, *, client_factory: Any = None) -> tuple[Path, dict[str, str | None]]: - raw_reference = _required_string(reference, field_name="reference") - if not raw_reference.startswith("gs://"): - return Path(raw_reference).expanduser(), {"source_uri": None, "local_path": raw_reference} - - local_path = _cache_path_for_remote_artifact(raw_reference) - _download_gcs_object(raw_reference, local_path, client_factory=client_factory) - return local_path, {"source_uri": raw_reference, "local_path": str(local_path)} + return materialize_local_or_gcs_artifact( + reference, + cache_dir=DEFAULT_PLUGIN_ARTIFACT_CACHE_DIR, + client_factory=client_factory, + ) def _download_gcs_object(uri: str, destination: Path, *, client_factory: Any = None) -> None: - if client_factory is None: - try: - from google.cloud import storage # type: ignore - except ImportError as exc: - raise RuntimeError("google-cloud-storage is required for GCS strategy plugin artifacts") from exc - client_factory = storage.Client - bucket_name, object_name = _parse_gcs_uri(uri) - destination.parent.mkdir(parents=True, exist_ok=True) - client = client_factory() - client.bucket(bucket_name).blob(object_name).download_to_filename(str(destination)) + download_gcs_object(uri, destination, client_factory=client_factory) def _parse_gcs_uri(uri: str) -> tuple[str, str]: - raw_uri = str(uri or "").strip() - if not raw_uri.startswith("gs://"): - raise ValueError(f"Unsupported GCS URI: {raw_uri}") - bucket_name, _, object_name = raw_uri[5:].partition("/") - if not bucket_name or not object_name: - raise ValueError(f"Invalid GCS URI: {raw_uri}") - return bucket_name, object_name + return parse_gcs_uri(uri) def _cache_path_for_remote_artifact(reference: str) -> Path: - digest = hashlib.sha256(reference.encode("utf-8")).hexdigest()[:16] - leaf_name = Path(reference).name or "latest_signal.json" - return DEFAULT_PLUGIN_ARTIFACT_CACHE_DIR / digest / leaf_name + return cache_path_for_remote_artifact(reference, cache_dir=DEFAULT_PLUGIN_ARTIFACT_CACHE_DIR) def _as_bool(value: Any, *, default: bool = False) -> bool: diff --git a/tests/test_strategy_plugin_artifacts.py b/tests/test_strategy_plugin_artifacts.py new file mode 100644 index 0000000..17f82a8 --- /dev/null +++ b/tests/test_strategy_plugin_artifacts.py @@ -0,0 +1,52 @@ +from pathlib import Path + +import pytest + +from quant_platform_kit.common.strategy_plugin_artifacts import ( + cache_path_for_remote_artifact, + materialize_local_or_gcs_artifact, + parse_gcs_uri, +) + + +def test_parse_gcs_uri_requires_bucket_and_object(): + assert parse_gcs_uri("gs://bucket/path/latest_signal.json") == ( + "bucket", + "path/latest_signal.json", + ) + + with pytest.raises(ValueError, match="Invalid GCS URI"): + parse_gcs_uri("gs://bucket") + + with pytest.raises(ValueError, match="Unsupported GCS URI"): + parse_gcs_uri("https://example.com/latest_signal.json") + + +def test_cache_path_for_remote_artifact_is_stable_under_cache_dir(): + cache_dir = Path("/tmp/cache") + first = cache_path_for_remote_artifact( + "gs://bucket/path/latest_signal.json", + cache_dir=cache_dir, + ) + second = cache_path_for_remote_artifact( + "gs://bucket/path/latest_signal.json", + cache_dir=cache_dir, + ) + + assert first == second + assert first.parent.parent == cache_dir + assert first.name == "latest_signal.json" + + +def test_materialize_local_artifact_does_not_download(): + local_path, metadata = materialize_local_or_gcs_artifact( + "~/signals/latest_signal.json", + cache_dir=Path("/tmp/cache"), + client_factory=lambda: (_ for _ in ()).throw(AssertionError("should not download")), + ) + + assert local_path == Path("~/signals/latest_signal.json").expanduser() + assert metadata == { + "source_uri": None, + "local_path": "~/signals/latest_signal.json", + }