Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ on:
branches: [ main ]
pull_request:

permissions:
contents: read

jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 20
steps:
- name: Checkout
uses: actions/checkout@v6
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/dependabot_auto_merge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ jobs:
auto-merge:
if: github.event.workflow_run.conclusion == 'success' && startsWith(github.event.workflow_run.head_branch, 'dependabot/')
runs-on: ubuntu-latest
timeout-minutes: 10
permissions:
contents: write
pull-requests: write
Expand Down
58 changes: 58 additions & 0 deletions src/quant_platform_kit/common/strategy_plugin_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Artifact path helpers for strategy plugin signal files."""

from __future__ import annotations

import hashlib
from pathlib import Path
from typing import Any


def materialize_local_or_gcs_artifact(
reference: str,
*,
cache_dir: Path,
client_factory: Any = None,
) -> tuple[Path, dict[str, str | None]]:
raw_reference = _required_string(reference, field_name="reference")
if not raw_reference.startswith("gs://"):
return Path(raw_reference).expanduser(), {"source_uri": None, "local_path": raw_reference}

local_path = cache_path_for_remote_artifact(raw_reference, cache_dir=cache_dir)
download_gcs_object(raw_reference, local_path, client_factory=client_factory)
return local_path, {"source_uri": raw_reference, "local_path": str(local_path)}


def download_gcs_object(uri: str, destination: Path, *, client_factory: Any = None) -> None:
if client_factory is None:
try:
from google.cloud import storage # type: ignore
except ImportError as exc:
raise RuntimeError("google-cloud-storage is required for GCS strategy plugin artifacts") from exc
client_factory = storage.Client
bucket_name, object_name = parse_gcs_uri(uri)
destination.parent.mkdir(parents=True, exist_ok=True)
client = client_factory()
client.bucket(bucket_name).blob(object_name).download_to_filename(str(destination))


def parse_gcs_uri(uri: str) -> tuple[str, str]:
raw_uri = str(uri or "").strip()
if not raw_uri.startswith("gs://"):
raise ValueError(f"Unsupported GCS URI: {raw_uri}")
bucket_name, _, object_name = raw_uri[5:].partition("/")
if not bucket_name or not object_name:
raise ValueError(f"Invalid GCS URI: {raw_uri}")
return bucket_name, object_name


def cache_path_for_remote_artifact(reference: str, *, cache_dir: Path) -> Path:
digest = hashlib.sha256(reference.encode("utf-8")).hexdigest()[:16]
leaf_name = Path(reference).name or "latest_signal.json"
return cache_dir / digest / leaf_name


def _required_string(value: Any, *, field_name: str) -> str:
text = str(value or "").strip()
if not text:
raise ValueError(f"{field_name} must be a non-empty string")
return text
42 changes: 15 additions & 27 deletions src/quant_platform_kit/common/strategy_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
from pathlib import Path
from typing import Any, Callable

from quant_platform_kit.common.strategy_plugin_artifacts import (
cache_path_for_remote_artifact,
download_gcs_object,
materialize_local_or_gcs_artifact,
parse_gcs_uri,
)

PLUGIN_CRISIS_RESPONSE_SHADOW = "crisis_response_shadow"
PLUGIN_MARKET_REGIME_CONTROL = "market_regime_control"
PLUGIN_MACRO_RISK_GOVERNOR = "macro_risk_governor"
Expand Down Expand Up @@ -1237,42 +1244,23 @@ def _sanitize_key_part(value: Any) -> str:


def _materialize_artifact_path(reference: str, *, client_factory: Any = None) -> tuple[Path, dict[str, str | None]]:
raw_reference = _required_string(reference, field_name="reference")
if not raw_reference.startswith("gs://"):
return Path(raw_reference).expanduser(), {"source_uri": None, "local_path": raw_reference}

local_path = _cache_path_for_remote_artifact(raw_reference)
_download_gcs_object(raw_reference, local_path, client_factory=client_factory)
return local_path, {"source_uri": raw_reference, "local_path": str(local_path)}
return materialize_local_or_gcs_artifact(
reference,
cache_dir=DEFAULT_PLUGIN_ARTIFACT_CACHE_DIR,
client_factory=client_factory,
)


def _download_gcs_object(uri: str, destination: Path, *, client_factory: Any = None) -> None:
if client_factory is None:
try:
from google.cloud import storage # type: ignore
except ImportError as exc:
raise RuntimeError("google-cloud-storage is required for GCS strategy plugin artifacts") from exc
client_factory = storage.Client
bucket_name, object_name = _parse_gcs_uri(uri)
destination.parent.mkdir(parents=True, exist_ok=True)
client = client_factory()
client.bucket(bucket_name).blob(object_name).download_to_filename(str(destination))
download_gcs_object(uri, destination, client_factory=client_factory)


def _parse_gcs_uri(uri: str) -> tuple[str, str]:
raw_uri = str(uri or "").strip()
if not raw_uri.startswith("gs://"):
raise ValueError(f"Unsupported GCS URI: {raw_uri}")
bucket_name, _, object_name = raw_uri[5:].partition("/")
if not bucket_name or not object_name:
raise ValueError(f"Invalid GCS URI: {raw_uri}")
return bucket_name, object_name
return parse_gcs_uri(uri)


def _cache_path_for_remote_artifact(reference: str) -> Path:
digest = hashlib.sha256(reference.encode("utf-8")).hexdigest()[:16]
leaf_name = Path(reference).name or "latest_signal.json"
return DEFAULT_PLUGIN_ARTIFACT_CACHE_DIR / digest / leaf_name
return cache_path_for_remote_artifact(reference, cache_dir=DEFAULT_PLUGIN_ARTIFACT_CACHE_DIR)


def _as_bool(value: Any, *, default: bool = False) -> bool:
Expand Down
52 changes: 52 additions & 0 deletions tests/test_strategy_plugin_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pathlib import Path

import pytest

from quant_platform_kit.common.strategy_plugin_artifacts import (
cache_path_for_remote_artifact,
materialize_local_or_gcs_artifact,
parse_gcs_uri,
)


def test_parse_gcs_uri_requires_bucket_and_object():
assert parse_gcs_uri("gs://bucket/path/latest_signal.json") == (
"bucket",
"path/latest_signal.json",
)

with pytest.raises(ValueError, match="Invalid GCS URI"):
parse_gcs_uri("gs://bucket")

with pytest.raises(ValueError, match="Unsupported GCS URI"):
parse_gcs_uri("https://example.com/latest_signal.json")


def test_cache_path_for_remote_artifact_is_stable_under_cache_dir():
cache_dir = Path("/tmp/cache")
first = cache_path_for_remote_artifact(
"gs://bucket/path/latest_signal.json",
cache_dir=cache_dir,
)
second = cache_path_for_remote_artifact(
"gs://bucket/path/latest_signal.json",
cache_dir=cache_dir,
)

assert first == second
assert first.parent.parent == cache_dir
assert first.name == "latest_signal.json"


def test_materialize_local_artifact_does_not_download():
local_path, metadata = materialize_local_or_gcs_artifact(
"~/signals/latest_signal.json",
cache_dir=Path("/tmp/cache"),
client_factory=lambda: (_ for _ in ()).throw(AssertionError("should not download")),
)

assert local_path == Path("~/signals/latest_signal.json").expanduser()
assert metadata == {
"source_uri": None,
"local_path": "~/signals/latest_signal.json",
}