"""Settings API routes: application settings and model-provider profiles."""

from __future__ import annotations

from argument_risk_engine.classification.model_provider import ProviderProfile

from backend.app.schemas.settings import (
    ActiveProviderRequest,
    ActiveProviderResponse,
    AppSettings,
    ProviderListResponse,
    ProviderProfilePatch,
    ProviderProfileSchema,
    ProviderTestResponse,
)
from backend.app.services.settings_service import (
    get_active_model_provider,
    get_model_settings,
    list_model_providers,
    patch_model_provider,
    save_model_provider,
    set_active_model_provider,
    test_model_provider,
    update_model_settings,
)
from fastapi import APIRouter, HTTPException

router = APIRouter(prefix="/settings", tags=["settings"])


def _provider_not_found() -> HTTPException:
    """Build the 404 error raised when a provider id is unknown.

    The handlers below declare a ``response_model``; returning a bare
    ``{"detail": ...}`` dict from them cannot be validated against that model,
    so the error has to be raised as an HTTP error instead of returned.
    """
    return HTTPException(status_code=404, detail="provider not found")


@router.get("", response_model=AppSettings)
def get_settings() -> AppSettings:
    """Return the current application settings."""
    return get_model_settings()


@router.put("", response_model=AppSettings)
def put_settings(settings: AppSettings) -> AppSettings:
    """Replace the application settings and return the stored result."""
    return update_model_settings(settings)


@router.get("/model-providers", response_model=ProviderListResponse)
def get_model_providers() -> ProviderListResponse:
    """List every configured provider profile."""
    providers = [
        ProviderProfileSchema(**profile.model_dump())
        for profile in list_model_providers()
    ]
    return ProviderListResponse(providers=providers)


@router.post("/model-providers", response_model=ProviderProfileSchema)
def post_model_provider(profile: ProviderProfileSchema) -> ProviderProfileSchema:
    """Create or overwrite a provider profile and return the saved profile."""
    saved = save_model_provider(ProviderProfile(**profile.model_dump()))
    return ProviderProfileSchema(**saved.model_dump())


@router.patch("/model-providers/{provider_id}", response_model=ProviderProfileSchema)
def patch_model_provider_route(provider_id: str, patch: ProviderProfilePatch) -> ProviderProfileSchema:
    """Apply a partial update to one provider profile.

    Raises:
        HTTPException: 404 when ``provider_id`` does not match any profile.
    """
    updated = patch_model_provider(provider_id, patch.model_dump())
    if updated is None:
        raise _provider_not_found()
    return ProviderProfileSchema(**updated.model_dump())


@router.post("/model-providers/{provider_id}/test", response_model=ProviderTestResponse)
def test_model_provider_route(provider_id: str) -> ProviderTestResponse:
    """Run the connectivity test for one provider and return its result.

    Raises:
        HTTPException: 404 when ``provider_id`` does not match any profile.
    """
    result = test_model_provider(provider_id)
    if result is None:
        raise _provider_not_found()
    return ProviderTestResponse(**result.model_dump())


@router.get("/active-model-provider", response_model=ActiveProviderResponse)
def get_active_model_provider_route() -> ActiveProviderResponse:
    """Return the active provider id together with its profile, if one exists."""
    provider_id, provider = get_active_model_provider()
    schema = ProviderProfileSchema(**provider.model_dump()) if provider else None
    return ActiveProviderResponse(provider_id=provider_id, provider=schema)


@router.post("/active-model-provider", response_model=ActiveProviderResponse)
def post_active_model_provider(payload: ActiveProviderRequest) -> ActiveProviderResponse:
    """Select the active provider.

    Raises:
        HTTPException: 404 when the requested provider id is unknown.
    """
    provider_id, provider = set_active_model_provider(payload.provider_id)
    if provider is None:
        raise _provider_not_found()
    return ActiveProviderResponse(
        provider_id=provider_id,
        provider=ProviderProfileSchema(**provider.model_dump()),
    )
int = 2048 + temperature: float = 0.0 + supports_json_mode: str | bool = "unknown" + supports_streaming: str | bool = "unknown" + enabled: bool = True + + +class ProviderProfilePatch(BaseModel): + label: str | None = None + provider_type: ProviderType | None = None + base_url: str | None = None + model_name: str | None = None + api_key_env_var: str | None = None + timeout_seconds: int | None = None + max_tokens: int | None = None + temperature: float | None = None + supports_json_mode: str | bool | None = None + supports_streaming: str | bool | None = None + enabled: bool | None = None + + +class ProviderListResponse(BaseModel): + providers: list[ProviderProfileSchema] = Field(default_factory=list) + + +class ActiveProviderRequest(BaseModel): + provider_id: str + + +class ActiveProviderResponse(BaseModel): + provider_id: str + provider: ProviderProfileSchema | None = None + + +class ProviderTestResponse(BaseModel): + provider_id: str + status: str + latency_ms: int + warnings: list[str] = Field(default_factory=list) + models: list[str] = Field(default_factory=list) + detail: str = "" class AppSettings(BaseModel): - llm_provider: str = "deterministic" + llm_provider: str = "deterministic_baseline" model: str = "local-keyword" temperature: float = 0.0 + active_model_provider: str = "deterministic_baseline" diff --git a/backend/app/services/settings_service.py b/backend/app/services/settings_service.py index a1b2629..4b22f61 100644 --- a/backend/app/services/settings_service.py +++ b/backend/app/services/settings_service.py @@ -1,11 +1,95 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from argument_risk_engine.classification.llm_client import LLMClient, ProviderTestResult +from argument_risk_engine.classification.model_provider import ProviderProfile +from argument_risk_engine.classification.provider_registry import ProviderRegistry + +import yaml +from backend.app.core.paths import DATA_DIR from 
from backend.app.schemas.settings import AppSettings

MODEL_PROFILES_PATH = DATA_DIR / "config" / "model_profiles.yaml"
APP_SETTINGS_PATH = DATA_DIR / "config" / "app_settings.yaml"

_DEFAULT_PROVIDER_ID = "deterministic_baseline"
_DEFAULT_MODEL = "local-keyword"
# Keys that must never reach the registry: profiles only name an environment
# variable (api_key_env_var), they do not carry a key value.
_SECRET_KEYS = frozenset({"api_key", "raw_api_key"})

_registry = ProviderRegistry(MODEL_PROFILES_PATH)


def get_model_settings() -> AppSettings:
    """Load application settings from ``APP_SETTINGS_PATH``.

    Missing or ``null`` entries fall back to defaults, so a partially written
    settings file does not raise.
    """
    payload = _read_yaml(APP_SETTINGS_PATH)
    provider_id = payload.get("active_model_provider") or payload.get("llm_provider") or _DEFAULT_PROVIDER_ID
    return AppSettings(
        llm_provider=_value_or(payload, "llm_provider", provider_id),
        model=_value_or(payload, "model", _DEFAULT_MODEL),
        temperature=_as_float(payload.get("temperature"), 0.0),
        active_model_provider=provider_id,
    )


def update_model_settings(settings: AppSettings) -> AppSettings:
    """Persist ``settings`` to ``APP_SETTINGS_PATH`` and return what was stored."""
    provider_id = settings.llm_provider or settings.active_model_provider
    payload = settings.model_dump()
    payload["active_model_provider"] = settings.active_model_provider or provider_id
    payload["llm_provider"] = provider_id
    _write_yaml(APP_SETTINGS_PATH, payload)
    return get_model_settings()


def list_model_providers() -> list[ProviderProfile]:
    """Return all provider profiles known to the registry."""
    return [_sanitize(profile) for profile in _registry.list_profiles()]


def save_model_provider(profile: ProviderProfile) -> ProviderProfile:
    """Insert or replace ``profile`` in the registry and return the stored copy."""
    return _sanitize(_registry.upsert_profile(profile))


def patch_model_provider(provider_id: str, patch: dict[str, Any]) -> ProviderProfile | None:
    """Apply a partial update to one profile.

    Secret-bearing keys are dropped before the patch reaches the registry.
    The caller's dict is not modified.

    Returns:
        The updated profile, or ``None`` when ``provider_id`` is unknown.
    """
    cleaned = {key: value for key, value in patch.items() if key not in _SECRET_KEYS}
    return _sanitize(_registry.patch_profile(provider_id, cleaned))


def get_active_model_provider() -> tuple[str, ProviderProfile | None]:
    """Return the active provider id and its profile (``None`` if not registered)."""
    settings = get_model_settings()
    profile = _registry.get_profile(settings.active_model_provider)
    return settings.active_model_provider, _sanitize(profile)


def set_active_model_provider(provider_id: str) -> tuple[str, ProviderProfile | None]:
    """Make ``provider_id`` the active provider.

    The stored model name and temperature are synced from the profile.

    Returns:
        ``(provider_id, profile)``; ``profile`` is ``None`` and nothing is
        written when the id is unknown.
    """
    profile = _registry.get_profile(provider_id)
    if profile is None:
        return provider_id, None
    settings = get_model_settings()
    settings.active_model_provider = provider_id
    settings.llm_provider = provider_id
    # Keep the previously stored model when the profile does not name one.
    settings.model = profile.model_name or settings.model
    settings.temperature = float(profile.temperature)
    update_model_settings(settings)
    return provider_id, _sanitize(profile)


def test_model_provider(provider_id: str) -> ProviderTestResult | None:
    """Run the provider connectivity test; ``None`` when the id is unknown."""
    profile = _registry.get_profile(provider_id)
    if profile is None:
        return None
    return LLMClient(profile).test_provider()


def _sanitize(profile: ProviderProfile | None) -> ProviderProfile | None:
    """Return a metadata-only copy of ``profile`` (``None`` passes through)."""
    return None if profile is None else profile.sanitized()


def _value_or(payload: dict[str, Any], key: str, default: Any) -> Any:
    """Return ``payload[key]`` unless it is missing or ``None``."""
    value = payload.get(key)
    return default if value is None else value


def _as_float(value: Any, default: float) -> float:
    """Convert ``value`` to float, falling back to ``default`` when it cannot be."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _read_yaml(path: Path) -> dict[str, Any]:
    """Read a YAML mapping from ``path``.

    Returns an empty dict when the file is missing, empty, or its top-level
    value is not a mapping, so callers can always use ``.get``.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    loaded = yaml.safe_load(text)
    return loaded if isinstance(loaded, dict) else {}


def _write_yaml(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as YAML, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(yaml.safe_dump(payload, sort_keys=False), encoding="utf-8")
"temperature": "0", - "supports_json_mode": "n/a", - "supports_streaming": "n/a", - "enabled_default": "yes", - "security_note": "No API key required; must always work offline." + "timeout_seconds": 30, + "max_tokens": 0, + "temperature": 0.0, + "supports_json_mode": true, + "supports_streaming": false, + "enabled": true }, { "provider_id": "lm_studio_local", "label": "LM Studio Local", "provider_type": "openai_compatible", "base_url": "http://localhost:1234/v1", + "model_name": "local-model", "api_key_env_var": "LM_STUDIO_API_KEY", - "default_model": "local-model", - "timeout_seconds": "60", - "max_tokens": "2048", - "temperature": "0", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, "supports_json_mode": "unknown", - "supports_streaming": "yes", - "enabled_default": "no", - "security_note": "Dashboard must not store raw key; local servers may accept dummy key." + "supports_streaming": true, + "enabled": true }, { "provider_id": "ollama_local", "label": "Ollama Local", "provider_type": "openai_compatible", "base_url": "http://localhost:11434/v1", + "model_name": "qwen3:8b", "api_key_env_var": "OLLAMA_API_KEY", - "default_model": "qwen3:8b", - "timeout_seconds": "60", - "max_tokens": "2048", - "temperature": "0", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, "supports_json_mode": "unknown", - "supports_streaming": "yes", - "enabled_default": "no", - "security_note": "Use OpenAI-compatible endpoint; avoid assuming JSON mode." 
+ "supports_streaming": true, + "enabled": true }, { "provider_id": "openai_remote", "label": "OpenAI Remote", "provider_type": "openai_compatible", "base_url": "https://api.openai.com/v1", + "model_name": "gpt-4.1-mini", "api_key_env_var": "OPENAI_API_KEY", - "default_model": "gpt-4.1-mini", - "timeout_seconds": "60", - "max_tokens": "2048", - "temperature": "0", - "supports_json_mode": "yes", - "supports_streaming": "yes", - "enabled_default": "no", - "security_note": "Backend reads key from environment variable only." + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, + "supports_json_mode": true, + "supports_streaming": true, + "enabled": true }, { "provider_id": "openrouter_remote", "label": "OpenRouter Remote", "provider_type": "openai_compatible", "base_url": "https://openrouter.ai/api/v1", + "model_name": "", "api_key_env_var": "OPENROUTER_API_KEY", - "default_model": "", - "timeout_seconds": "60", - "max_tokens": "2048", - "temperature": "0", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, "supports_json_mode": "model_dependent", - "supports_streaming": "yes", - "enabled_default": "no", - "security_note": "Treat as OpenAI-compatible; do not hardcode model names." + "supports_streaming": true, + "enabled": true }, { "provider_id": "custom_openai_compatible", "label": "Custom OpenAI-Compatible", "provider_type": "openai_compatible", "base_url": "", + "model_name": "", "api_key_env_var": "CUSTOM_LLM_API_KEY", - "default_model": "", - "timeout_seconds": "60", - "max_tokens": "2048", - "temperature": "0", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, "supports_json_mode": "unknown", "supports_streaming": "unknown", - "enabled_default": "no", - "security_note": "User-configurable base_url and model_name; never expose raw key." 
from __future__ import annotations

import json
import os
import time
from http.client import HTTPException
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import urlsplit
from urllib.request import Request, urlopen

from argument_risk_engine.classification.model_provider import ProviderProfile
from pydantic import BaseModel

# base_url is user-configurable; urlopen would otherwise also accept file://
# and ftp:// URLs, so requests are limited to HTTP(S).
_ALLOWED_URL_SCHEMES = frozenset({"http", "https"})


class ProviderTestResult(BaseModel):
    """Outcome of a provider connectivity test."""

    provider_id: str
    status: str
    latency_ms: int
    warnings: list[str]
    models: list[str]
    detail: str = ""


class LLMClient:
    """Thin client bound to an optional provider profile."""

    def __init__(self, profile: ProviderProfile | None = None):
        self.profile = profile

    def classify(self, prompt: str) -> dict[str, str]:
        """Return a placeholder response tagged with the provider id."""
        provider_id = self.profile.provider_id if self.profile else "deterministic_baseline"
        return {"provider": provider_id, "response": "LLM providers are optional; deterministic baseline is available."}

    def test_provider(self) -> ProviderTestResult:
        """Test the bound provider; deterministic providers succeed without I/O."""
        if self.profile is None or self.profile.provider_type == "deterministic":
            provider_id = self.profile.provider_id if self.profile else "deterministic_baseline"
            return ProviderTestResult(provider_id=provider_id, status="ok", latency_ms=0, warnings=[], models=[], detail="Deterministic provider is available offline.")
        return test_openai_compatible_provider(self.profile)


def test_openai_compatible_provider(profile: ProviderProfile) -> ProviderTestResult:
    """Probe an OpenAI-compatible endpoint.

    Tries ``GET /models`` first; if that fails, falls back to a minimal chat
    completion. Warnings from each failed step are collected in the result.
    """
    start = time.perf_counter()
    warnings: list[str] = []
    models: list[str] = []

    if not profile.base_url:
        return _result(profile, "not_configured", start, ["base_url is required for OpenAI-compatible providers."], models)

    api_key = os.environ.get(profile.api_key_env_var, "") if profile.api_key_env_var else ""
    if profile.api_key_env_var and not api_key:
        warnings.append(f"Environment variable {profile.api_key_env_var} is not set; local providers may still accept unauthenticated requests.")

    models_status, model_warning, models = _get_models(profile, api_key)
    if models_status == "ok":
        return _result(profile, "ok", start, warnings, models, "Fetched /models successfully.")
    if model_warning:
        warnings.append(model_warning)

    chat_status, chat_warning = _minimal_chat_completion(profile, api_key)
    if chat_status == "ok":
        return _result(profile, "ok", start, warnings, models, "Minimal chat completion succeeded.")
    if chat_warning:
        warnings.append(chat_warning)
    return _result(profile, "failed", start, warnings, models)


def _get_models(profile: ProviderProfile, api_key: str) -> tuple[str, str, list[str]]:
    """Fetch the provider's model list.

    Returns:
        ``(status, warning, model_ids)``; ``model_ids`` is empty on failure.
    """
    url = _join_url(profile.base_url, "models")
    status, payload, warning = _request_json("GET", url, api_key, None, profile.timeout_seconds)
    if status != "ok":
        return status, warning, []
    return "ok", "", _extract_model_ids(payload)


def _extract_model_ids(payload: Any) -> list[str]:
    """Pull model ids out of a ``/models`` response.

    Accepts either ``{"data": [...]}`` or a bare list; entries may be objects
    with an ``id`` key or plain values. Anything else yields no ids.
    """
    entries = payload.get("data", []) if isinstance(payload, dict) else payload
    if not isinstance(entries, list):
        return []
    model_ids: list[str] = []
    for entry in entries:
        if not entry:
            continue
        identifier = entry.get("id", entry) if isinstance(entry, dict) else entry
        model_ids.append(str(identifier))
    return model_ids


def _minimal_chat_completion(profile: ProviderProfile, api_key: str) -> tuple[str, str]:
    """Send a tiny non-streaming chat completion and return ``(status, warning)``."""
    if not profile.model_name:
        return "not_configured", "model_name is required to test chat completions."
    body = {
        "model": profile.model_name,
        "messages": [{"role": "user", "content": "Return ok."}],
        "max_tokens": 8,
        "temperature": 0,
        "stream": False,
    }
    url = _join_url(profile.base_url, "chat/completions")
    status, _payload, warning = _request_json("POST", url, api_key, body, profile.timeout_seconds)
    return status, warning


def _request_json(method: str, url: str, api_key: str, body: dict[str, Any] | None, timeout: int) -> tuple[str, dict[str, Any] | list[Any], str]:
    """Perform one JSON HTTP request.

    Never raises for network, protocol, or decoding problems; those are
    reported through the returned status and warning instead.

    Returns:
        ``(status, payload, warning)`` where ``status`` is ``"ok"`` or
        ``"failed"``, ``payload`` is the decoded JSON (``{}`` on failure or an
        empty body), and ``warning`` describes the failure.
    """
    try:
        scheme = urlsplit(url).scheme.lower()
    except ValueError:
        scheme = ""
    if scheme not in _ALLOWED_URL_SCHEMES:
        return "failed", {}, f"{method} {url} uses an unsupported URL scheme; only http and https are allowed."

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    data = json.dumps(body).encode("utf-8") if body is not None else None
    try:
        request = Request(url, data=data, headers=headers, method=method)
        with urlopen(request, timeout=max(1, int(timeout or 30))) as response:
            raw = response.read().decode("utf-8")
        return "ok", json.loads(raw) if raw else {}, ""
    # Order matters: each handler below is a subclass of a later one.
    except HTTPError as exc:
        return "failed", {}, f"{method} {url} returned HTTP {exc.code}."
    except URLError as exc:
        return "failed", {}, f"{method} {url} could not connect: {exc.reason}."
    except TimeoutError:
        return "failed", {}, f"{method} {url} timed out."
    except (json.JSONDecodeError, UnicodeDecodeError):
        return "failed", {}, f"{method} {url} returned invalid JSON."
    except (OSError, HTTPException, ValueError) as exc:
        # Connection resets or truncated reads while consuming the body, and
        # malformed URLs rejected by urllib.
        return "failed", {}, f"{method} {url} failed: {exc}."
+ + +def _join_url(base_url: str, suffix: str) -> str: + return f"{base_url.rstrip('/')}/{suffix.lstrip('/')}" + + +def _result(profile: ProviderProfile, status: str, start: float, warnings: list[str], models: list[str], detail: str = "") -> ProviderTestResult: + return ProviderTestResult( + provider_id=profile.provider_id, + status=status, + latency_ms=int((time.perf_counter() - start) * 1000), + warnings=warnings, + models=models, + detail=detail, + ) diff --git a/engine/argument_risk_engine/classification/model_provider.py b/engine/argument_risk_engine/classification/model_provider.py index 90a7302..94b545e 100644 --- a/engine/argument_risk_engine/classification/model_provider.py +++ b/engine/argument_risk_engine/classification/model_provider.py @@ -1,7 +1,40 @@ +from __future__ import annotations + +from typing import Literal + from pydantic import BaseModel +ProviderType = Literal["deterministic", "openai_compatible"] + + +class ProviderProfile(BaseModel): + provider_id: str + label: str + provider_type: ProviderType = "openai_compatible" + base_url: str = "" + model_name: str = "" + api_key_env_var: str = "" + timeout_seconds: int = 60 + max_tokens: int = 2048 + temperature: float = 0.0 + supports_json_mode: str | bool = "unknown" + supports_streaming: str | bool = "unknown" + enabled: bool = False + + def sanitized(self) -> ProviderProfile: + """Return a copy containing only metadata, never an API key value.""" + return ProviderProfile(**self.model_dump()) + + @property + def requires_secret(self) -> bool: + return self.provider_type == "openai_compatible" and bool(self.api_key_env_var) + + @property + def is_remote(self) -> bool: + return self.base_url.startswith("https://") and "localhost" not in self.base_url and "127.0.0.1" not in self.base_url + class ModelProvider(BaseModel): - name: str = "deterministic" + name: str = "deterministic_baseline" model: str = "local-keyword" temperature: float = 0.0 diff --git 
a/engine/argument_risk_engine/classification/provider_registry.py b/engine/argument_risk_engine/classification/provider_registry.py index d981c7d..1717be5 100644 --- a/engine/argument_risk_engine/classification/provider_registry.py +++ b/engine/argument_risk_engine/classification/provider_registry.py @@ -1 +1,160 @@ -PROVIDERS = {"deterministic": "Local deterministic keyword classifier", "openai": "Paid OpenAI-compatible provider"} +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import yaml +from argument_risk_engine.classification.model_provider import ProviderProfile + +DEFAULT_PROFILES: list[dict[str, Any]] = [ + { + "provider_id": "deterministic_baseline", + "label": "Deterministic Baseline", + "provider_type": "deterministic", + "base_url": "", + "model_name": "local-keyword", + "api_key_env_var": "", + "timeout_seconds": 30, + "max_tokens": 0, + "temperature": 0.0, + "supports_json_mode": True, + "supports_streaming": False, + "enabled": True, + }, + { + "provider_id": "lm_studio_local", + "label": "LM Studio Local", + "provider_type": "openai_compatible", + "base_url": "http://localhost:1234/v1", + "model_name": "local-model", + "api_key_env_var": "LM_STUDIO_API_KEY", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, + "supports_json_mode": "unknown", + "supports_streaming": True, + "enabled": True, + }, + { + "provider_id": "ollama_local", + "label": "Ollama Local", + "provider_type": "openai_compatible", + "base_url": "http://localhost:11434/v1", + "model_name": "qwen3:8b", + "api_key_env_var": "OLLAMA_API_KEY", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, + "supports_json_mode": "unknown", + "supports_streaming": True, + "enabled": True, + }, + { + "provider_id": "openai_remote", + "label": "OpenAI Remote", + "provider_type": "openai_compatible", + "base_url": "https://api.openai.com/v1", + "model_name": "gpt-4.1-mini", + "api_key_env_var": "OPENAI_API_KEY", + 
"timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, + "supports_json_mode": True, + "supports_streaming": True, + "enabled": True, + }, + { + "provider_id": "openrouter_remote", + "label": "OpenRouter Remote", + "provider_type": "openai_compatible", + "base_url": "https://openrouter.ai/api/v1", + "model_name": "", + "api_key_env_var": "OPENROUTER_API_KEY", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, + "supports_json_mode": "model_dependent", + "supports_streaming": True, + "enabled": True, + }, + { + "provider_id": "custom_openai_compatible", + "label": "Custom OpenAI-Compatible", + "provider_type": "openai_compatible", + "base_url": "", + "model_name": "", + "api_key_env_var": "CUSTOM_LLM_API_KEY", + "timeout_seconds": 60, + "max_tokens": 2048, + "temperature": 0.0, + "supports_json_mode": "unknown", + "supports_streaming": "unknown", + "enabled": False, + }, +] + + +class ProviderRegistry: + def __init__(self, path: Path): + self.path = path + + def list_profiles(self) -> list[ProviderProfile]: + payload = self._read_payload() + items = payload.get("items", DEFAULT_PROFILES) + return [self._profile_from_item(item) for item in items] + + def get_profile(self, provider_id: str) -> ProviderProfile | None: + return next((profile for profile in self.list_profiles() if profile.provider_id == provider_id), None) + + def upsert_profile(self, profile: ProviderProfile) -> ProviderProfile: + profiles = self.list_profiles() + updated = False + for index, existing in enumerate(profiles): + if existing.provider_id == profile.provider_id: + profiles[index] = profile + updated = True + break + if not updated: + profiles.append(profile) + self.save_profiles(profiles) + return profile + + def patch_profile(self, provider_id: str, patch: dict[str, Any]) -> ProviderProfile | None: + profile = self.get_profile(provider_id) + if profile is None: + return None + data = profile.model_dump() + data.update({key: value for key, value in patch.items() 
if value is not None}) + updated = ProviderProfile(**data) + return self.upsert_profile(updated) + + def save_profiles(self, profiles: list[ProviderProfile]) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + payload = {"version": "1.0", "items": [profile.model_dump() for profile in profiles]} + self.path.write_text(yaml.safe_dump(payload, sort_keys=False), encoding="utf-8") + + def _read_payload(self) -> dict[str, Any]: + if not self.path.exists(): + self.save_profiles([ProviderProfile(**item) for item in DEFAULT_PROFILES]) + payload = yaml.safe_load(self.path.read_text(encoding="utf-8")) or {} + if not payload.get("items"): + payload["items"] = DEFAULT_PROFILES + return payload + + @staticmethod + def _profile_from_item(item: dict[str, Any]) -> ProviderProfile: + normalized = dict(item) + if "model_name" not in normalized and "default_model" in normalized: + normalized["model_name"] = normalized.pop("default_model") + if "enabled" not in normalized and "enabled_default" in normalized: + normalized["enabled"] = str(normalized.pop("enabled_default")).lower() in {"yes", "true", "1"} + for int_field in ("timeout_seconds", "max_tokens"): + if int_field in normalized: + normalized[int_field] = int(normalized[int_field] or 0) + if "temperature" in normalized: + normalized["temperature"] = float(normalized["temperature"] or 0) + normalized.pop("security_note", None) + return ProviderProfile(**normalized) + + +PROVIDERS = {item["provider_id"]: item["label"] for item in DEFAULT_PROFILES} diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 722d651..886123b 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,4 +1,4 @@ -import type { AnalysisResponse, TaxonomyCoverage, TaxonomyEntry, TaxonomyImportResult, TaxonomyPackSummary, TaxonomyQualityReport, TaxonomyValidationResult } from './types' +import type { ActiveProviderResponse, AnalysisResponse, ProviderListResponse, ProviderProfile, ProviderTestResponse, 
TaxonomyCoverage, TaxonomyEntry, TaxonomyImportResult, TaxonomyPackSummary, TaxonomyQualityReport, TaxonomyValidationResult } from './types' const API_BASE = import.meta.env.VITE_API_BASE ?? 'http://localhost:8000/api' @@ -77,3 +77,49 @@ export async function updateTaxonomyActivation(riskId: string, activationStatus: if (payload.detail) throw new Error(payload.detail) return payload.entry } + + +export async function fetchModelProviders(): Promise { + const response = await fetch(`${API_BASE}/settings/model-providers`) + if (!response.ok) throw new Error('Model providers request failed') + const payload: ProviderListResponse = await response.json() + return payload.providers +} + +export async function saveModelProvider(profile: ProviderProfile): Promise { + const response = await fetch(`${API_BASE}/settings/model-providers/${encodeURIComponent(profile.provider_id)}`, { + method: 'PATCH', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(profile), + }) + if (!response.ok) throw new Error('Model provider update failed') + const payload = await response.json() + if (payload.detail) throw new Error(payload.detail) + return payload +} + +export async function fetchActiveModelProvider(): Promise { + const response = await fetch(`${API_BASE}/settings/active-model-provider`) + if (!response.ok) throw new Error('Active provider request failed') + return response.json() +} + +export async function setActiveModelProvider(providerId: string): Promise { + const response = await fetch(`${API_BASE}/settings/active-model-provider`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ provider_id: providerId }), + }) + if (!response.ok) throw new Error('Active provider update failed') + const payload = await response.json() + if (payload.detail) throw new Error(payload.detail) + return payload +} + +export async function testModelProvider(providerId: string): Promise { + const response = await 
fetch(`${API_BASE}/settings/model-providers/${encodeURIComponent(providerId)}/test`, { method: 'POST' }) + if (!response.ok) throw new Error('Provider test failed') + const payload = await response.json() + if (payload.detail) throw new Error(payload.detail) + return payload +} diff --git a/frontend/src/api/types.ts b/frontend/src/api/types.ts index e80ca38..ff42580 100644 --- a/frontend/src/api/types.ts +++ b/frontend/src/api/types.ts @@ -70,3 +70,22 @@ export type TaxonomyIssue = { code: string; message: string; severity: string; e export type TaxonomyQualityReport = { ok: boolean; entry_count: number; active_classification_count: number; error_count: number; warning_count: number; errors: TaxonomyIssue[]; warnings: TaxonomyIssue[] } export type TaxonomyPackSummary = { pack: string; entry_count: number; active_count: number; enabled_for_classification_count: number } export type TaxonomyValidationResult = { ok: boolean; entry_count: number; active_classification_count: number; errors: TaxonomyIssue[]; warnings: TaxonomyIssue[] } + +export type ProviderType = 'deterministic' | 'openai_compatible' +export type ProviderProfile = { + provider_id: string + label: string + provider_type: ProviderType + base_url: string + model_name: string + api_key_env_var: string + timeout_seconds: number + max_tokens: number + temperature: number + supports_json_mode: boolean | string + supports_streaming: boolean | string + enabled: boolean +} +export type ProviderListResponse = { providers: ProviderProfile[] } +export type ActiveProviderResponse = { provider_id: string; provider: ProviderProfile | null } +export type ProviderTestResponse = { provider_id: string; status: string; latency_ms: number; warnings: string[]; models: string[]; detail: string } diff --git a/frontend/src/components/layout/Sidebar.tsx b/frontend/src/components/layout/Sidebar.tsx index 176e675..6c69f75 100644 --- a/frontend/src/components/layout/Sidebar.tsx +++ b/frontend/src/components/layout/Sidebar.tsx @@ 
-1,2 +1,2 @@ -const items = ['Analyze', 'Taxonomy', 'Workbench', 'Settings', 'Review', 'Evaluation', 'Reports'] -export function Sidebar() { return } +const items = ['Analyze', 'Taxonomy', 'Workbench', 'Model Settings', 'Review', 'Evaluation', 'Reports'] +export function Sidebar() { return } diff --git a/frontend/src/components/settings/ModelParameterForm.tsx b/frontend/src/components/settings/ModelParameterForm.tsx index 7fb45ca..4f3ee56 100644 --- a/frontend/src/components/settings/ModelParameterForm.tsx +++ b/frontend/src/components/settings/ModelParameterForm.tsx @@ -1,2 +1,27 @@ -import { Card } from '../shared/Card' -export function ModelParameterForm() { return

ModelParameterForm

MVP placeholder wired for local file-backed workflows.

} +import type { ProviderProfile } from '../../api/types' + +const numericFields = [ + { key: 'timeout_seconds', label: 'Timeout (seconds)', min: 1, step: 1 }, + { key: 'max_tokens', label: 'Max tokens', min: 0, step: 1 }, + { key: 'temperature', label: 'Temperature', min: 0, step: 0.1 }, +] as const + +export function ModelParameterForm({ profile, onChange }: { profile: ProviderProfile; onChange: (profile: ProviderProfile) => void }) { + const update = (key: keyof ProviderProfile, value: string | number | boolean) => onChange({ ...profile, [key]: value }) + + return
+ + + + {numericFields.map(field => )} +
+} diff --git a/frontend/src/components/settings/ModelSettingsPage.tsx b/frontend/src/components/settings/ModelSettingsPage.tsx index 4a08397..83cc7d6 100644 --- a/frontend/src/components/settings/ModelSettingsPage.tsx +++ b/frontend/src/components/settings/ModelSettingsPage.tsx @@ -1,2 +1,60 @@ +import { useEffect, useState } from 'react' +import { fetchActiveModelProvider, fetchModelProviders, saveModelProvider, setActiveModelProvider, testModelProvider } from '../../api/client' +import type { ProviderProfile, ProviderTestResponse } from '../../api/types' import { Card } from '../shared/Card' -export function ModelSettingsPage() { return

ModelSettings

MVP placeholder wired for local file-backed workflows.

} +import { ProviderProfileCard } from './ProviderProfileCard' + +export function ModelSettingsPage() { + const [providers, setProviders] = useState([]) + const [activeProviderId, setActiveProviderId] = useState('deterministic_baseline') + const [tests, setTests] = useState>({}) + const [testingId, setTestingId] = useState('') + const [message, setMessage] = useState('') + + useEffect(() => { + Promise.all([fetchModelProviders(), fetchActiveModelProvider()]) + .then(([providerList, active]) => { + setProviders(providerList) + setActiveProviderId(active.provider_id) + }) + .catch(error => setMessage(error instanceof Error ? error.message : 'Unable to load model settings')) + }, []) + + const updateProvider = (profile: ProviderProfile) => setProviders(items => items.map(item => item.provider_id === profile.provider_id ? profile : item)) + const saveProvider = (profile: ProviderProfile) => saveModelProvider(profile).then(saved => { + updateProvider(saved) + setMessage(`Saved ${saved.label}`) + }).catch(error => setMessage(error instanceof Error ? error.message : 'Save failed')) + const selectProvider = (providerId: string) => setActiveModelProvider(providerId).then(active => { + setActiveProviderId(active.provider_id) + setMessage(`Active provider set to ${active.provider_id}`) + }).catch(error => setMessage(error instanceof Error ? error.message : 'Selection failed')) + const runTest = (providerId: string) => { + setTestingId(providerId) + testModelProvider(providerId) + .then(result => setTests(current => ({ ...current, [providerId]: result }))) + .catch(error => setMessage(error instanceof Error ? error.message : 'Provider test failed')) + .finally(() => setTestingId('')) + } + + return
+ +

Model Settings

+

Configure deterministic, local OpenAI-compatible, paid remote, or custom OpenAI-compatible providers. Provider metadata can be stored in the dashboard; paid secrets must be supplied to the backend through environment variables.

+ {message ?

{message}

: null} +
+
+ {providers.map(profile => )} +
+
+} diff --git a/frontend/src/components/settings/ProviderProfileCard.tsx b/frontend/src/components/settings/ProviderProfileCard.tsx index 54e2a2c..aba9183 100644 --- a/frontend/src/components/settings/ProviderProfileCard.tsx +++ b/frontend/src/components/settings/ProviderProfileCard.tsx @@ -1,2 +1,47 @@ +import type { ProviderProfile, ProviderTestResponse } from '../../api/types' +import { Badge } from '../shared/Badge' +import { Button } from '../shared/Button' import { Card } from '../shared/Card' -export function ProviderProfileCard() { return

ProviderProfileCard

MVP placeholder wired for local file-backed workflows.

} +import { ModelParameterForm } from './ModelParameterForm' +import { ProviderTestPanel } from './ProviderTestPanel' + +function isRemote(profile: ProviderProfile) { + return profile.base_url.startsWith('https://') && !profile.base_url.includes('localhost') && !profile.base_url.includes('127.0.0.1') +} + +function jsonModeVerified(profile: ProviderProfile) { + return profile.supports_json_mode === true || profile.supports_json_mode === 'yes' || profile.supports_json_mode === 'true' +} + +export function ProviderProfileCard({ profile, activeProviderId, testResult, isTesting, onChange, onSave, onSelect, onTest }: { + profile: ProviderProfile + activeProviderId: string + testResult?: ProviderTestResponse + isTesting: boolean + onChange: (profile: ProviderProfile) => void + onSave: (profile: ProviderProfile) => void + onSelect: (providerId: string) => void + onTest: (providerId: string) => void +}) { + const active = profile.provider_id === activeProviderId + const remote = isRemote(profile) + return +
+
+

{profile.label}

+

{profile.provider_id} · {profile.provider_type}

+
+ {active ? 'Active' : profile.enabled ? 'Available' : 'Disabled'} +
+ {remote ?

Paid or remote provider: configure secrets in backend environment variables, not localStorage.

: null} + {!jsonModeVerified(profile) ?

JSON mode has not been verified for this provider/model.

: null} + {profile.provider_type === 'deterministic' + ?

Works offline with no API key and no paid provider dependency.

+ : } +
+ + +
+ onTest(profile.provider_id)} /> +
+} diff --git a/frontend/src/components/settings/ProviderTestPanel.tsx b/frontend/src/components/settings/ProviderTestPanel.tsx index eac20c8..edd8f7d 100644 --- a/frontend/src/components/settings/ProviderTestPanel.tsx +++ b/frontend/src/components/settings/ProviderTestPanel.tsx @@ -1,2 +1,17 @@ -import { Card } from '../shared/Card' -export function ProviderTestPanel() { return

ProviderTest Panel

MVP placeholder wired for local file-backed workflows.

} +import type { ProviderProfile, ProviderTestResponse } from '../../api/types' +import { Button } from '../shared/Button' +import { Badge } from '../shared/Badge' + +export function ProviderTestPanel({ profile, result, isTesting, onTest }: { profile: ProviderProfile; result?: ProviderTestResponse; isTesting: boolean; onTest: () => void }) { + const status = result?.status ?? (profile.enabled ? 'not tested' : 'not configured') + return
+
+ + {status} + {result ? {result.latency_ms} ms : null} +
+ {result?.detail ?

{result.detail}

: null} + {result?.models?.length ?

Models: {result.models.slice(0, 5).join(', ')}{result.models.length > 5 ? '…' : ''}

: null} + {result?.warnings?.length ?
    {result.warnings.map(warning =>
  • {warning}
  • )}
: null} +
+} diff --git a/frontend/src/styles/global.css b/frontend/src/styles/global.css index 9d8b8ee..70d8b40 100644 --- a/frontend/src/styles/global.css +++ b/frontend/src/styles/global.css @@ -37,3 +37,8 @@ select { border: 1px solid #c8d2e1; border-radius: 10px; padding: .65rem; font: .issue-list { display: grid; gap: .45rem; padding-left: 1.1rem; } .workbench-grid { display: grid; gap: 1rem; } @media (max-width: 980px) { .taxonomy-layout { grid-template-columns: 1fr; } .drawer { max-height: none; } } +.warning { color: #8a4b00; background: #fff7e6; border: 1px solid #ffd58a; border-radius: 10px; padding: .65rem; } +.settings-page { grid-column: 1 / -1; } +.provider-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(340px, 1fr)); gap: 1rem; } +.settings-form { display: grid; gap: .75rem; margin-top: .75rem; } +.provider-test-panel { border-top: 1px solid #edf1f7; margin-top: .75rem; padding-top: .75rem; } diff --git a/tests/test_api_settings.py b/tests/test_api_settings.py index c240142..4740519 100644 --- a/tests/test_api_settings.py +++ b/tests/test_api_settings.py @@ -7,3 +7,38 @@ def test_api_settings(): response = client.put("/api/settings", json={"llm_provider": "deterministic", "model": "local-keyword", "temperature": 0}) assert response.status_code == 200 assert response.json()["llm_provider"] == "deterministic" + + +def test_model_provider_defaults_and_deterministic_test(): + client = TestClient(app) + response = client.get("/api/settings/model-providers") + assert response.status_code == 200 + providers = response.json()["providers"] + provider_ids = {provider["provider_id"] for provider in providers} + assert "deterministic_baseline" in provider_ids + assert "lm_studio_local" in provider_ids + assert "ollama_local" in provider_ids + assert "openai_remote" in provider_ids + assert all("api_key" not in provider for provider in providers) + + test_response = client.post("/api/settings/model-providers/deterministic_baseline/test") + assert 
test_response.status_code == 200 + assert test_response.json()["status"] == "ok" + + +def test_active_provider_selection_and_secret_metadata_only(): + client = TestClient(app) + select_response = client.post("/api/settings/active-model-provider", json={"provider_id": "lm_studio_local"}) + assert select_response.status_code == 200 + assert select_response.json()["provider_id"] == "lm_studio_local" + + patch_response = client.patch( + "/api/settings/model-providers/lm_studio_local", + json={"api_key_env_var": "LM_STUDIO_API_KEY", "api_key": "should-not-be-stored"}, + ) + assert patch_response.status_code == 200 + payload = patch_response.json() + assert payload["api_key_env_var"] == "LM_STUDIO_API_KEY" + assert "should-not-be-stored" not in str(payload) + + client.post("/api/settings/active-model-provider", json={"provider_id": "deterministic_baseline"})