diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..bcd1feb --- /dev/null +++ b/.env.example @@ -0,0 +1,31 @@ +# Active routing (change these to switch provider/model without editing code) +LLM_ROUTER_PROVIDER=ollama +LLM_ROUTER_MODEL=qwen3.5:9b + +# Local Ollama +OLLAMA_BASE_URL=http://127.0.0.1:11434 +OLLAMA_MODEL=qwen3.5:9b + +# OpenAI +OPENAI_API_KEY= +OPENAI_BASE_URL=https://api.openai.com/v1 +OPENAI_MODEL=gpt-4.1-mini + +# OpenRouter +OPENROUTER_API_KEY= +OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 +OPENROUTER_MODEL=openrouter/auto + +# Anthropic Claude +CLAUDE_API_KEY= +CLAUDE_MODEL=claude-3-5-sonnet-latest + +# Optional harness auth for clients/evals +# LLM_ROUTER_API_KEY= + +# Tooling (off by default in untrusted deployments) +# LLM_ROUTER_TOOL_EXEC_ENABLED=0 + +# ReAct / small-model loops (optional) +# LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS=8192 +# LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST=24 diff --git a/.gitignore b/.gitignore index e8d928e..8f8e28a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,9 @@ zig-out .agents/ .ollama-qwen-env.example .env -client/* +__pycache__/ logs/* +tmp/ # General .DS_Store __MACOSX/ @@ -32,3 +33,4 @@ Network Trash Folder Temporary Items .apdisk .llm-router-env +.github/instructions/ \ No newline at end of file diff --git a/README.md b/README.md index ebfd0c5..5e29400 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ OpenAI-compatible LLM router and prompt-loop runtime written in Zig. -This project exposes a chat-completions API, routes requests across multiple providers, supports optional streaming, and includes built-in loop controls for iterative agent workflows. +This project exposes a chat-completions API, routes requests across multiple providers, supports optional streaming, and includes built-in loop controls for iterative agent workflows. Scope: **thin, reliable model harness** (see `phase.md` Phase 0 and `plan.md` for the frozen feature set and deferrals). > [!TIP] > Fastest local smoke test: @@ -19,6 +19,8 @@ This project exposes a chat-completions API, routes requests across multiple pro - Optional debug tooling (echo, utc, cmd, bash) with safe defaults. - Simple deploy surface: single Zig binary, environment-based configuration. + + ## Architecture ```mermaid @@ -55,9 +57,12 @@ flowchart TD ## Build, Run, and Test ```bash -zig build +zig build # ReleaseSafe by default (~2 MB stripped binary) zig build run zig build check +zig build windows # cross-compile zig-coding-agent.exe for x86_64-windows-gnu +zig build -Doptimize=Debug # larger binary with safety checks for development +zig build -Doptimize=ReleaseSmall # smallest binary (~800 KB) ``` ### Test Targets @@ -79,6 +84,62 @@ zig build test -Dtest-target=file "-Dtest-file=src/types.zig" zig build test -Dtest-target=all -Dtest-filter=normalizeProviderName ``` +### zig_eval (optional sibling checkout) + +Evaluations live in the sibling **zig_eval** repo. Check out both repos side by side: + +```text +zig/ +├── zig_coding_agent/ # this harness +└── zig_eval/ # registry-driven eval runner +``` + +`zig_eval/registry/services.json` targets this router at `http://127.0.0.1:8081` (`local-openai-compat`). + +**Prerequisites:** LLM provider reachable (e.g. Ollama), harness listening, and matching `LLM_ROUTER_API_KEY` if auth is enabled. + +```bash +# Terminal A: start harness — pick provider + model via env (or .env) +export LLM_ROUTER_PROVIDER=openrouter # ollama | openai | openrouter | claude | bedrock | llama_cpp +export LLM_ROUTER_MODEL=openai/gpt-4o-mini +export OPENROUTER_API_KEY=your-key # set the key for the active provider +zig build run -- --use-env + +# Switch to local Ollama later (restart server): +# export LLM_ROUTER_PROVIDER=ollama +# export LLM_ROUTER_MODEL=qwen3.5:9b + +# CLI overrides without editing env: +# zig build run -- --use-env --provider openai --model gpt-4.1-mini + +# Terminal B: run all registry evals against the live harness +zig build zig_evals + +# If .env points at a different default provider, probe the live provider explicitly +zig build zig_evals -Deval-provider=ollama + +# Wait up to 30s for the server to come up, then run evals +zig build zig_evals -Deval-wait-seconds=30 + +# Filter evals (pass-through to zig_eval CLI) +zig build zig_evals -- --eval smoke.reply_ok --format json +zig build zig_evals -- --group smoke --parallel 2 + +# Override registry or service +zig build zig_evals -Deval-registry=examples/registry -Deval-service=local-product +``` + +Readiness checks before evals start: + +1. `GET /health` on the harness (default `http://127.0.0.1:8081`) +2. A minimal `POST /v1/chat/completions` probe to confirm the LLM path works + +List evals directly from the sibling checkout: + +```bash +cd ../zig_eval && zig build run -- list --registry registry +``` + ## API Surface ### Endpoints @@ -89,11 +150,11 @@ zig build test -Dtest-target=all -Dtest-filter=normalizeProviderName | GET | /metrics | Yes (if API key configured) | Request and connection counters | | GET | /diagnostics/clients | Yes (if API key configured) | Connected client diagnostics | | GET | /diagnostics/requests | Yes (if API key configured) | Request success/failure diagnostics | -| GET | /diagnostics/providers | No | Provider status snapshot | +| GET | /diagnostics/providers | Yes (if API key configured) | Provider status snapshot | | POST | /v1/chat/completions | Yes (if API key configured) | OpenAI-compatible chat-completions | > [!NOTE] -> Authentication is enabled when LLM_ROUTER_API_KEY is non-empty. When enabled, all routes require auth except /health and /diagnostics/providers. +> Authentication is enabled when LLM_ROUTER_API_KEY is non-empty. When enabled, all routes require auth except /health. ### Chat Request Shape @@ -157,13 +218,71 @@ Loop controls: - --prompt initial prompt and loop entry - --provider provider override +- --model default model override (also available as `LLM_ROUTER_MODEL`) - --until completion marker (default: DONE) - --max-turns loop safety cap (default: 8) -- --loop-mode loop style +- --loop-mode loop style - --agent-loop shorthand for agent mode +- --react shorthand for ReAct reasoning mode - --use-env load .env - --env-file load a custom dotenv file +## ReAct Mode + +ReAct (Reasoning + Acting) mode implements the paradigm from [Yao et al., 2022](https://arxiv.org/abs/2210.03629). The model produces structured **Thought → Action** pairs, and the system executes each action and injects the result as an **Observation** before the next turn. + +### Available Actions + +| Action | Description | Requirements | +| --------- | -------------------------------- | --------------------------------------- | +| Search[q] | Search for information (stub) | None (future: wire to search tool/API) | +| Lookup[t] | Look up a term in context (stub) | None (future: wire to retrieval backend) | +| Cmd[c] | Execute a shell command | `LLM_ROUTER_TOOL_EXEC_ENABLED=1` | +| Finish[a] | Return final answer and stop loop | None | + +> [!NOTE] +> Search and Lookup return stub responses. They are designed as extension points for future tool/API integration. + +### CLI Example + +```bash +zig build run -- --react --prompt "What is the elevation range of the High Plains?" --provider ollama +``` + +### API Example + +```json +{ + "messages": [{ "role": "user", "content": "What is the elevation range of the High Plains?" }], + "loop_mode": "react", + "loop_max_turns": 24, + "tools": [ + { "name": "file_read", "description": "Read a project file" }, + { "name": "file_write", "description": "Write a project file" }, + { "name": "file_search", "description": "Search the repo" } + ] +} +``` + +When `loop_max_turns` is omitted, ReAct defaults to **24 turns** (vs 8 for basic/agent) so smaller models can solve coding tasks step-by-step. Tool-call budget scales with the turn cap. Set `max_context_tokens` (for example `8192`) so long loops compact history between turns. + +The server **auto-attaches** the coding tool set for `loop_mode: "react"` (`file_read`, `file_write`, `file_search`, plus `bash` or `cmd` on the host OS). Client-provided tools are preserved; missing defaults are merged in. `tool_choice` defaults to `auto` when omitted. Any HTTP client can use ReAct without sending a `tools` array. + +The model will produce output like: + +``` +Thought 1: I need to search for the High Plains elevation range. +Action 1: Search[High Plains elevation] +``` + +The system injects: + +``` +Observation 1: +``` + +This continues until the model emits `Action N: Finish[answer]` or the turn budget is exhausted. + ## Tools Registered tool names: @@ -172,11 +291,15 @@ Registered tool names: - utc - cmd - bash +- file_read +- file_write +- file_search Tool behavior summary: - echo and utc are deterministic debug helpers. - cmd and bash are guarded command-execution tools. +- file_read, file_write, and file_search are lightweight filesystem helpers for trusted local workflows. - Command execution is disabled by default and must be explicitly enabled. ```bash @@ -188,6 +311,10 @@ Related limits: - LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS (default: 15000) - LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES (default: 65536) +- LLM_ROUTER_TOOL_EXEC_CONFIRM_REQUIRED (default: true; re-send with `LLM_ROUTER_TOOL_CONFIRM `) +- LLM_ROUTER_TOOL_EXEC_TRUSTED_LOCAL (default: false; allows pipes/chaining with denylist kept) +- LLM_ROUTER_TOOL_OUTPUT_OFFLOAD_BYTES (default: 8192; 0 disables offloading large tool output to disk) +- LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST (default: 8) ## Configuration Reference @@ -199,11 +326,16 @@ Related limits: | LLM_ROUTER_PORT | 8081 | | LLM_ROUTER_DEBUG | 0 | | LLM_ROUTER_PROVIDER | ollama | +| LLM_ROUTER_MODEL | unset (uses per-provider `*_MODEL`) | | LLM_ROUTER_INSTANCE_ID | local-instance | | LLM_ROUTER_API_KEY | empty | | LLM_ROUTER_REQUEST_TIMEOUT_MS | 30000 | | LLM_ROUTER_PROVIDER_TIMEOUT_MS | 60000 | | LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED | true | +| LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS | 64 | +| LLM_ROUTER_MAX_REQUEST_BYTES | 1048576 | +| LLM_ROUTER_MAX_HEADER_BYTES | 16384 | +| LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS | unset | ### Session Storage @@ -212,14 +344,34 @@ Related limits: | LLM_ROUTER_SESSION_STORE_PATH | logs/sessions | | LLM_ROUTER_SESSION_RETENTION_MESSAGES | 24 | +### Workspace Mode (optional, in-memory) + +Lightweight Cursor-like mode: same `workspace_id` keeps conversation in RAM until the server exits. File tools resolve under the workspace root instead of `tmp/`. Disabled by default. + +| Variable | Default | +| -------------------------- | ------- | +| LLM_ROUTER_WORKSPACE_MODE | 0 | +| LLM_ROUTER_WORKSPACE_ROOT | `.` | + +Send `workspace_id` in the JSON body (separate from disk `session_id`): + +```bash +export LLM_ROUTER_WORKSPACE_MODE=1 +export LLM_ROUTER_WORKSPACE_ROOT=. + +curl -s http://127.0.0.1:8081/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"workspace_id":"dev-1","messages":[{"role":"user","content":"Read src/main.zig"}],"tools":[{"name":"file_read","description":"read files"}],"tool_choice":"file_read"}' +``` + ### Ollama | Variable | Default | | --------------------- | ------------------------ | | OLLAMA_BASE_URL | | | OLLAMA_MODEL | qwen3.5:9b | -| OLLAMA_THINK | 0 | -| OLLAMA_NUM_PREDICT | 128 | +| OLLAMA_THINK | true | +| OLLAMA_NUM_PREDICT | 2048 | | OLLAMA_TEMPERATURE | 0.7 | | OLLAMA_REPEAT_PENALTY | 1.05 | @@ -290,12 +442,36 @@ Related limits: -d '{"messages":[{"role":"user","content":"Say hello from zig-coding-agent"}]}' ``` -4. Provider Diagnostics +4. Stateful Session Check + + ```bash + curl -s http://127.0.0.1:8081/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"session_id":"demo-session","messages":[{"role":"user","content":"Remember that my test color is blue."}]}' + ``` + +5. Tool Check + + ```bash + curl -s http://127.0.0.1:8081/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"messages":[{"role":"user","content":"What time is it in UTC?"}],"tools":[{"name":"utc","description":"Current UTC time"}],"tool_choice":"auto"}' + ``` + +6. Provider Diagnostics ```bash curl -s http://127.0.0.1:8081/diagnostics/providers ``` +## Short Runbook + +- 401 Unauthorized: set `LLM_ROUTER_API_KEY` on the server and send either `X-Api-Key: ` or `Authorization: Bearer `. +- 413 request_too_large: lower the prompt size or raise `LLM_ROUTER_MAX_REQUEST_BYTES` for trusted deployments. +- 504 provider_timeout: check provider reachability and `LLM_ROUTER_PROVIDER_TIMEOUT_MS`. +- unknown_tool: request only registered tools listed above. +- provider_not_configured: set the provider API key or switch to a local provider. + ## Project Layout ``` @@ -309,6 +485,7 @@ zig_coding_agent │ ├── api.zig │ ├── auth.zig │ ├── errors.zig + │ ├── mcp.zig │ ├── session.zig │ └── tools.zig ├── config.zig @@ -326,10 +503,12 @@ zig_coding_agent │ ├── openai.zig │ ├── openai_compatible.zig │ └── openrouter.zig + ├── react.zig ├── root.zig ├── tools │ ├── command_exec.zig │ ├── echo.zig + │ ├── file_ops.zig │ └── utc.zig └── types.zig ``` diff --git a/build.zig b/build.zig index dee0703..2f5ec79 100644 --- a/build.zig +++ b/build.zig @@ -9,7 +9,11 @@ const TestTarget = enum { pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); - const optimize = b.standardOptimizeOption(.{}); + const optimize = b.option( + std.builtin.OptimizeMode, + "optimize", + "Prioritize performance, safety, or binary size (default: ReleaseSafe)", + ) orelse .ReleaseSafe; const package_name = "zig_coding_agent"; const exe_name = "zig-coding-agent"; @@ -25,6 +29,7 @@ pub fn build(b: *std.Build) void { const mod = b.addModule(package_name, .{ .root_source_file = b.path("src/root.zig"), .target = target, + .optimize = optimize, }); // CLI executable entrypoint. @@ -34,15 +39,102 @@ pub fn build(b: *std.Build) void { .root_source_file = b.path("src/main.zig"), .target = target, .optimize = optimize, + .strip = optimize != .Debug, .imports = &.{ .{ .name = package_name, .module = mod }, }, }), }); - app.linkLibC(); + // Linux native build avoids linkLibC (glibc/GCC 16 .sframe linker issues on Zig 0.15.2). + // Windows needs libc for socket APIs (recvfrom, etc.). + if (target.result.os.tag == .windows) { + app.linkLibC(); + } b.installArtifact(app); + // Cross-compile for Windows: `zig build windows` + const windows_target = b.resolveTargetQuery(.{ + .cpu_arch = .x86_64, + .os_tag = .windows, + .abi = .gnu, + }); + const windows_app = b.addExecutable(.{ + .name = exe_name, + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = windows_target, + .optimize = optimize, + .strip = optimize != .Debug, + .imports = &.{ + .{ .name = package_name, .module = mod }, + }, + }), + }); + windows_app.linkLibC(); + const windows_step = b.step("windows", "Cross-compile ReleaseSafe binary for x86_64-windows-gnu"); + windows_step.dependOn(&b.addInstallArtifact(windows_app, .{}).step); + + // Sibling eval runner: `zig build zig_evals` (requires ../zig_eval and a running harness). + const zig_eval_root = b.option([]const u8, "eval-root", "Path to sibling zig_eval checkout") orelse "../zig_eval"; + const eval_registry = b.option([]const u8, "eval-registry", "Registry directory inside zig_eval") orelse "registry"; + const eval_service = b.option([]const u8, "eval-service", "Service name from registry/services.json") orelse "local-openai-compat"; + const harness_base = b.option([]const u8, "eval-harness-url", "Harness base URL for readiness checks") orelse "http://127.0.0.1:8081"; + const eval_wait_seconds = b.option(u32, "eval-wait-seconds", "Seconds to wait for harness readiness before failing") orelse 0; + const eval_provider = b.option([]const u8, "eval-provider", "Provider to use for the eval readiness chat probe"); + + const preflight_script = b.fmt( + \\if [ -f ".env" ]; then set -a; . ./.env; set +a; fi + \\if [ -n "${{ZIG_EVAL_PROVIDER:-}}" ]; then PROVIDER="$ZIG_EVAL_PROVIDER"; else PROVIDER="${{LLM_ROUTER_PROVIDER:-ollama}}"; fi + \\for i in $(seq 0 {d}); do + \\ if curl -sf "{s}/health" >/dev/null; then + \\ PAYLOAD=$(printf '{{"messages":[{{"role":"user","content":"Reply OK"}}],"provider":"%s","model":"auto"}}' "$PROVIDER") + \\ if curl -sf -X POST "{s}/v1/chat/completions" \ + \\ -H "Content-Type: application/json" \ + \\ ${{LLM_ROUTER_API_KEY:+-H "Authorization: Bearer $LLM_ROUTER_API_KEY"}} \ + \\ -d "$PAYLOAD" >/dev/null; then + \\ exit 0 + \\ fi + \\ echo "error: harness chat probe failed at {s}/v1/chat/completions (provider=$PROVIDER; is the LLM configured?)" >&2 + \\ exit 1 + \\ fi + \\ if [ "$i" -eq {d} ]; then break; fi + \\ sleep 1 + \\done + \\echo "error: harness not reachable at {s}/health (start it with: zig build run -- --use-env)" >&2 + \\exit 1 + , + .{ eval_wait_seconds, harness_base, harness_base, harness_base, eval_wait_seconds, harness_base }, + ); + + const eval_preflight = b.addSystemCommand(&.{ "sh", "-c", preflight_script }); + if (eval_provider) |provider| { + eval_preflight.setEnvironmentVariable("ZIG_EVAL_PROVIDER", provider); + } + + const build_zig_eval = b.addSystemCommand(&.{ "zig", "build", "-Doptimize=ReleaseSafe" }); + build_zig_eval.setCwd(b.path(zig_eval_root)); + build_zig_eval.step.dependOn(&eval_preflight.step); + + const run_evals_script = b.fmt( + \\if [ -f ".env" ]; then set -a; . ./.env; set +a; fi + \\cd "{s}" && zig build run -Doptimize=ReleaseSafe -- run --registry {s} --service {s} "$@" + , + .{ zig_eval_root, eval_registry, eval_service }, + ); + + const run_evals = b.addSystemCommand(&.{ "sh", "-c", run_evals_script, "zig_evals" }); + run_evals.step.dependOn(&build_zig_eval.step); + if (b.args) |args| { + run_evals.addArgs(args); + } + + const zig_evals_step = b.step( + "zig_evals", + "Run zig_eval registry against a live harness (requires ../zig_eval; start server first)", + ); + zig_evals_step.dependOn(&run_evals.step); + // Run command: `zig build run -- [args]`. const run_step = b.step("run", "Run the Zig Coding Agent server"); const run_cmd = b.addRunArtifact(app); diff --git a/client/requirements.txt b/client/requirements.txt new file mode 100644 index 0000000..59c3cea --- /dev/null +++ b/client/requirements.txt @@ -0,0 +1,2 @@ +streamlit>=1.33.0 +requests>=2.31.0 diff --git a/client/test_client.py b/client/test_client.py new file mode 100644 index 0000000..e97924c --- /dev/null +++ b/client/test_client.py @@ -0,0 +1,477 @@ +import json +from dataclasses import dataclass +from typing import Dict, Generator, List, Optional +from urllib.parse import urlsplit, urlunsplit + +import requests +import streamlit as st + +st.set_page_config(page_title="Zig Agent Test Client", layout="wide") +st.title("Zig AI Harness Test Client") +st.caption("Debug and test harness behavior: chat routing, tool execution, sessions, streaming, and diagnostics.") + +CHAT_STATE_KEY = "chat_messages" +PENDING_CMD_CONFIRMATION_KEY = "pending_cmd_confirmation" + +CMD_CONFIRMATION_TAG = "DEBUG_TOOL_CONFIRMATION_REQUIRED" +CONFIRM_MARKER_PREFIX = "LLM_ROUTER_TOOL_CONFIRM " + + +@dataclass +class ChatMessage: + role: str + content: str + + +def build_headers(api_key: str) -> Dict[str, str]: + headers = {"Content-Type": "application/json"} + if api_key.strip(): + headers["x-api-key"] = api_key.strip() + return headers + + +def parse_assistant_content(data: Dict) -> str: + choices = data.get("choices", []) + if not choices: + return "" + message = choices[0].get("message", {}) + return message.get("content", "") + + +def base_url_from_endpoint(endpoint: str) -> str: + parsed = urlsplit(endpoint.strip()) + if not parsed.scheme or not parsed.netloc: + return "" + return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")) + + +def stream_probe_lines(response: requests.Response) -> Generator[str, None, None]: + content_type = (response.headers.get("content-type") or "").lower() + + if "text/event-stream" not in content_type: + body = response.text + yield "Server did not return SSE stream (text/event-stream)." + yield f"Content-Type: {content_type or ''}" + if body: + yield "" + yield "Raw response body:" + yield body + return + + for raw_line in response.iter_lines(decode_unicode=True): + if not raw_line: + continue + + line = raw_line.strip() + if not line.startswith("data:"): + continue + + payload = line[5:].strip() + if payload == "[DONE]": + yield "\n\n[stream done]" + return + + try: + event = json.loads(payload) + except json.JSONDecodeError: + yield payload + continue + + choices = event.get("choices", []) + if not choices: + yield payload + continue + + delta = choices[0].get("delta", {}) + token = delta.get("content") + if token is not None: + yield token + + +def ensure_chat_state() -> None: + if CHAT_STATE_KEY not in st.session_state: + st.session_state[CHAT_STATE_KEY] = [ + ChatMessage(role="system", content="You are a helpful assistant. Keep replies concise."), + ] + + +def chat_state() -> List[ChatMessage]: + ensure_chat_state() + return st.session_state[CHAT_STATE_KEY] + + +def add_chat_message(role: str, content: str) -> None: + ensure_chat_state() + st.session_state[CHAT_STATE_KEY].append(ChatMessage(role=role, content=content)) + + +def render_chat(messages: List[ChatMessage]) -> None: + for msg in messages: + if msg.role == "system": + continue + with st.chat_message(msg.role): + st.markdown(msg.content) + + +def parse_cmd_confirmation(text: str) -> Optional[Dict[str, str]]: + if CMD_CONFIRMATION_TAG not in text: + return None + + tool: Optional[str] = None + command: Optional[str] = None + confirm_token: Optional[str] = None + + for raw_line in text.splitlines(): + line = raw_line.strip() + if line.startswith("tool="): + tool = line[len("tool=") :].strip() + elif line.startswith("command="): + command = line[len("command=") :].strip() + elif line.startswith("confirm_token="): + confirm_token = line[len("confirm_token=") :].strip() + + if tool and command and confirm_token: + return {"tool": tool, "command": command, "confirm_token": confirm_token} + return None + + +def set_pending_cmd_confirmation(value: Optional[Dict[str, str]]) -> None: + st.session_state[PENDING_CMD_CONFIRMATION_KEY] = value + + +with st.sidebar: + st.header("Connection") + endpoint = st.text_input( + "Server endpoint", + value="http://127.0.0.1:8081/v1/chat/completions", + help="Your Zig server chat completion endpoint.", + ) + api_key = st.text_input( + "Client API key (optional)", + value="", + type="password", + help="Sends x-api-key header if provided.", + ) + + st.header("Request") + provider = st.text_input("provider", value="ollama") + model = st.text_input("model", value="auto") + thinking = st.checkbox( + "thinking", + value=False, + help="Enable model reasoning/thinking tokens. Turn off for faster and shorter replies.", + ) + temperature = st.slider( + "temperature", + min_value=0.0, + max_value=2.0, + value=0.7, + step=0.05, + help="Lower values are more deterministic. Higher values are more creative.", + ) + session_id = st.text_input("session_id (optional)", value="") + tenant_id = st.text_input("tenant_id (optional)", value="") + max_context_tokens = st.number_input( + "max_context_tokens (optional)", + min_value=0, + value=0, + step=256, + ) + + st.header("Loop") + loop_mode = st.selectbox( + "loop_mode", + options=["none", "basic", "agent", "react"], + index=0, + help="Server-side iterative loop mode usable from any frontend client.", + ) + loop_until = st.text_input("loop_until", value="DONE") + loop_max_turns = st.number_input( + "loop_max_turns", + min_value=0, + max_value=128, + value=0, + step=1, + help="0 means disabled and server defaults apply when loop_mode is used.", + ) + + st.header("Tools (debug)") + + import platform as _platform + _is_windows = _platform.system().lower().startswith("win") + _shell_tool = "cmd" if _is_windows else "bash" + CODING_TOOLS = [_shell_tool, "file_read", "file_write", "file_search"] + + auto_tools_for_react = st.checkbox( + "Also send coding tools from client when loop_mode=react", + value=False, + help=( + "The server already auto-attaches the coding tool set for react loops. " + "Enable this only to duplicate tools in the JSON payload (e.g. for debugging)." + ), + ) + + selected_tools = st.multiselect( + "Available tools to send (manual selection)", + options=["echo", "utc", "cmd", "bash", "file_read", "file_write", "file_search"], + default=CODING_TOOLS, + help="Multiple tools can be sent in one request. cmd/bash require server env LLM_ROUTER_TOOL_EXEC_ENABLED=1.", + ) + + tool_choice = st.selectbox( + "tool_choice", + options=[ + "auto", + "none", + "required", + "echo", + "utc", + "cmd", + "bash", + "file_read", + "file_write", + "file_search", + ], + index=0, + help="Use auto for prompt-driven tool execution with multiple tools.", + ) + + if auto_tools_for_react and loop_mode == "react": + selected_tools = CODING_TOOLS + tool_choice = "auto" + st.caption( + f"react mode active — sending tools={CODING_TOOLS} with tool_choice=auto." + ) + + +def build_payload( + messages: List[ChatMessage], + force_stream: bool, + tool_choice_override: Optional[str] = None, + tools_override: Optional[List[str]] = None, +) -> Dict: + payload: Dict = { + "provider": provider.strip() or "ollama", + "messages": [{"role": m.role, "content": m.content} for m in messages if m.content.strip()], + "think": bool(thinking), + "temperature": float(temperature), + } + + if model.strip(): + payload["model"] = model.strip() + + if session_id.strip(): + payload["session_id"] = session_id.strip() + + if tenant_id.strip(): + payload["tenant_id"] = tenant_id.strip() + + if max_context_tokens > 0: + payload["max_context_tokens"] = int(max_context_tokens) + + if loop_mode != "none": + payload["loop_mode"] = loop_mode + if loop_until.strip(): + payload["loop_until"] = loop_until.strip() + if loop_max_turns > 0: + payload["loop_max_turns"] = int(loop_max_turns) + + tools_to_send = tools_override if tools_override is not None else selected_tools + + if tools_to_send: + tool_descriptions = { + "echo": "Return a deterministic echo response for harness debugging.", + "utc": "Return current UTC timestamp for deterministic tool-path checks.", + "cmd": "Execute a Windows cmd command from prompt text (debug use).", + "bash": "Execute a bash command from prompt text (debug use).", + "file_read": "Read a relative file path and return its content (debug use).", + "file_write": "Write a relative file path from a fenced code block (debug use).", + "file_search": "Search a substring under a relative directory (debug use).", + } + payload["tools"] = [ + {"name": tool_name, "description": tool_descriptions.get(tool_name, "debug tool")} + for tool_name in tools_to_send + ] + if tool_choice_override is not None: + payload["tool_choice"] = tool_choice_override + elif tool_choice != "none": + payload["tool_choice"] = tool_choice + + if force_stream: + payload["stream"] = True + + return payload + + +def request_timeout_seconds(force_stream: bool) -> int: + timeout = 120 if force_stream else 60 + + if loop_mode != "none": + timeout = max(timeout, 240) + + if ("cmd" in selected_tools) or ("bash" in selected_tools): + timeout = max(timeout, 120) + + return timeout + + +def send_chat_request( + force_stream: bool, + tool_choice_override: Optional[str] = None, + tools_override: Optional[List[str]] = None, +) -> Optional[Dict]: + if not endpoint.strip(): + st.error("Endpoint is required.") + return None + + payload = build_payload( + chat_state(), + force_stream=force_stream, + tool_choice_override=tool_choice_override, + tools_override=tools_override, + ) + st.code(json.dumps(payload, indent=2), language="json") + + try: + timeout_sec = request_timeout_seconds(force_stream=force_stream) + resp = requests.post( + endpoint.strip(), + headers=build_headers(api_key), + json=payload, + timeout=timeout_sec, + stream=force_stream, + ) + except requests.Timeout: + st.error( + "Request timed out. Increase timeout by disabling loop mode, reducing loop_max_turns, " + "or simplifying tool usage in the prompt." + ) + return None + except requests.RequestException as exc: + st.error(f"Request failed: {exc}") + return None + + st.write(f"HTTP {resp.status_code}") + + if force_stream: + if resp.status_code >= 400: + st.error("Server returned an error status during stream.") + st.code(resp.text) + return None + + with st.chat_message("assistant"): + streamed_text = st.write_stream(stream_probe_lines(resp)) + assistant_text = str(streamed_text) + add_chat_message("assistant", assistant_text) + set_pending_cmd_confirmation(parse_cmd_confirmation(assistant_text)) + return None + + try: + data = resp.json() + except ValueError: + st.error("Response is not valid JSON.") + st.code(resp.text) + return None + + st.json(data) + assistant = parse_assistant_content(data) + if assistant: + with st.chat_message("assistant"): + st.markdown(assistant) + add_chat_message("assistant", assistant) + set_pending_cmd_confirmation(parse_cmd_confirmation(assistant)) + return data + + +st.subheader("Chat") +ensure_chat_state() + +chat_controls_col1, chat_controls_col2, chat_controls_col3 = st.columns(3) +with chat_controls_col1: + if st.button("Clear chat", use_container_width=True): + st.session_state.pop(CHAT_STATE_KEY, None) + ensure_chat_state() +with chat_controls_col2: + if st.button("Add demo prompt", use_container_width=True): + add_chat_message( + "user", + "write a python script to add two numbers then execute and print it's output", + ) +with chat_controls_col3: + stream_mode = st.checkbox("Stream responses (SSE)", value=False) + +render_chat(chat_state()) + +user_input = st.chat_input("Type a message and press Enter…") +if user_input and user_input.strip(): + add_chat_message("user", user_input.strip()) + if stream_mode: + if loop_mode != "none": + st.info( + "Loop progress visibility is controlled by server setting " + "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED (default: enabled)." + ) + send_chat_request(force_stream=True) + else: + send_chat_request(force_stream=False) + +pending_cmd = st.session_state.get(PENDING_CMD_CONFIRMATION_KEY) +if pending_cmd: + tool = pending_cmd["tool"] + st.warning(f"Backend requested confirmation to execute `{tool}`.") + st.code(pending_cmd["command"]) + if st.button("Confirm command execution", use_container_width=True): + confirm_prompt = ( + f"{pending_cmd['command']}\n{CONFIRM_MARKER_PREFIX}{pending_cmd['confirm_token']}" + ) + add_chat_message("user", confirm_prompt) + + tools_override = selected_tools if isinstance(selected_tools, list) else [] + if tool not in tools_override: + tools_override = tools_override + [tool] + + # Force the tool choice so the backend executes immediately after confirmation. + set_pending_cmd_confirmation(None) + send_chat_request( + force_stream=stream_mode, + tool_choice_override=tool, + tools_override=tools_override, + ) + + +st.subheader("Diagnostics") +diag_col1, diag_col2, diag_col3 = st.columns(3) +health_clicked = diag_col1.button("GET /health", use_container_width=True) +metrics_clicked = diag_col2.button("GET /metrics", use_container_width=True) +providers_clicked = diag_col3.button("GET /diagnostics/providers", use_container_width=True) + +if health_clicked or metrics_clicked or providers_clicked: + base_url = base_url_from_endpoint(endpoint) + if not base_url: + st.error("Could not infer base URL from endpoint.") + else: + if health_clicked: + target_path = "/health" + elif metrics_clicked: + target_path = "/metrics" + else: + target_path = "/diagnostics/providers" + + target_url = f"{base_url}{target_path}" + st.code(target_url) + + try: + resp = requests.get( + target_url, + headers=build_headers(api_key), + timeout=30, + ) + except requests.RequestException as exc: + st.error(f"Diagnostics request failed: {exc}") + else: + st.write(f"HTTP {resp.status_code}") + try: + st.json(resp.json()) + except ValueError: + st.code(resp.text) + diff --git a/src/backend/api.zig b/src/backend/api.zig index 20ec499..b550cf6 100644 --- a/src/backend/api.zig +++ b/src/backend/api.zig @@ -385,6 +385,29 @@ pub fn parseChatRequest( null; errdefer if (tenant_id) |value| allocator.free(value); + const workspace_id = if (obj.get("workspace_id")) |workspace_id_value| + switch (workspace_id_value) { + .string => |value| blk: { + const trimmed = std.mem.trim(u8, value, " \t\r\n"); + if (trimmed.len == 0) { + return .{ .err = errors.validationError( + "workspace_id must not be empty", + "workspace_id", + "invalid_workspace_id", + ) }; + } + break :blk try allocator.dupe(u8, trimmed); + }, + else => return .{ .err = errors.validationError( + "workspace_id must be a string", + "workspace_id", + "invalid_workspace_id", + ) }, + } + else + null; + errdefer if (workspace_id) |value| allocator.free(value); + const max_context_tokens = if (obj.get("max_context_tokens")) |tokens_value| switch (tokens_value) { .integer => |value| blk: { @@ -435,7 +458,37 @@ pub fn parseChatRequest( ) }, }; - const name = switch (tool_obj.get("name") orelse { + if (tool_obj.get("type")) |type_value| { + const tool_type = switch (type_value) { + .string => |value| value, + else => return .{ .err = errors.validationError( + "tool.type must be a string", + "tools", + "invalid_tools", + ) }, + }; + if (!std.ascii.eqlIgnoreCase(tool_type, "function")) { + return .{ .err = errors.validationError( + "tool.type must be function", + "tools", + "invalid_tools", + ) }; + } + } + + const function_obj = if (tool_obj.get("function")) |function_value| + switch (function_value) { + .object => |value| value, + else => return .{ .err = errors.validationError( + "tool.function must be a JSON object", + "tools", + "invalid_tools", + ) }, + } + else + tool_obj; + + const name = switch (function_obj.get("name") orelse { return .{ .err = errors.validationError( "Each tool must include a name", "tools", @@ -478,7 +531,7 @@ pub fn parseChatRequest( } } - const description = if (tool_obj.get("description")) |description_value| + const description = if (function_obj.get("description")) |description_value| switch (description_value) { .string => |value| value, else => return .{ .err = errors.validationError( @@ -490,9 +543,23 @@ pub fn parseChatRequest( else ""; + const parameters_json = if (function_obj.get("parameters")) |parameters_value| params_blk: { + switch (parameters_value) { + .object => {}, + else => return .{ .err = errors.validationError( + "tool.function.parameters must be a JSON object", + "tools", + "invalid_tools", + ) }, + } + break :params_blk try jsonValueStringifyAlloc(allocator, parameters_value); + } else null; + errdefer if (parameters_json) |value| allocator.free(value); + try tools.append(allocator, .{ .name = try allocator.dupe(u8, trimmed_name), .description = try allocator.dupe(u8, description), + .parameters_json = parameters_json, }); } @@ -548,9 +615,9 @@ pub fn parseChatRequest( ) }; } - if (!std.ascii.eqlIgnoreCase(trimmed, "basic") and !std.ascii.eqlIgnoreCase(trimmed, "agent")) { + if (!std.ascii.eqlIgnoreCase(trimmed, "basic") and !std.ascii.eqlIgnoreCase(trimmed, "agent") and !std.ascii.eqlIgnoreCase(trimmed, "react")) { return .{ .err = errors.validationError( - "loop_mode must be one of: basic, agent", + "loop_mode must be one of: basic, agent, react", "loop_mode", "invalid_loop_mode", ) }; @@ -683,6 +750,7 @@ pub fn parseChatRequest( .repeat_penalty = repeat_penalty, .session_id = session_id, .tenant_id = tenant_id, + .workspace_id = workspace_id, .max_context_tokens = max_context_tokens, .tools = parsed_tools, .tool_choice = tool_choice, @@ -748,6 +816,124 @@ pub fn callProvider( return error.UnknownProvider; } +/// Outcome of a single ReAct/agent loop turn driven through the streaming +/// pipeline. `streamed` means tokens were emitted live to the SSE client and +/// `captured_content` holds the assistant's content text (no reasoning) for +/// action parsing or context replay. +pub const TurnStreamOutcome = struct { + success: bool, + finish_reason: []u8, + model: []u8, + captured_content: []u8, + + pub fn deinit(self: TurnStreamOutcome, allocator: std.mem.Allocator) void { + allocator.free(self.finish_reason); + allocator.free(self.model); + allocator.free(self.captured_content); + } +}; + +/// Streams a single loop turn directly to the SSE client. +/// +/// For providers with a dedicated SSE adapter (Ollama today) this forwards +/// tokens as they arrive. For all other providers, falls back to the buffered +/// `callProvider` path and emits the whole response as one SSE chunk so the +/// loop semantics remain identical. +/// +/// `headers_sent` and `think_block_open` are caller-owned across turns so a +/// single SSE response can carry many turns without re-emitting headers or +/// leaving an unclosed `` block. +pub fn streamProviderTurn( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + app_config: *const config.Config, + request: types.Request, + completion_id: []const u8, + headers_sent: *bool, + think_block_open: *bool, +) !TurnStreamOutcome { + const response_mod = @import("../core/response.zig"); + + if (!headers_sent.*) { + try response_mod.sendEventStreamHeaders(connection, null); + headers_sent.* = true; + } + + const requested_provider = request.provider orelse app_config.default_provider; + const provider = types.normalizeProviderName(requested_provider) orelse { + return error.UnknownProvider; + }; + + if (std.mem.eql(u8, provider, "ollama_qwen")) { + const ollama_qwen = @import("../providers/ollama_qwen.zig"); + + var captured = std.ArrayList(u8){}; + errdefer captured.deinit(allocator); + + const turn_result = try ollama_qwen.streamQwenTurnToSse( + connection, + allocator, + app_config, + request, + completion_id, + &captured, + think_block_open, + ); + + switch (turn_result) { + .streamed => |success| { + defer success.deinit(allocator); + return .{ + .success = true, + .finish_reason = try allocator.dupe(u8, success.finish_reason), + .model = try allocator.dupe(u8, success.model), + .captured_content = try captured.toOwnedSlice(allocator), + }; + }, + .failed => |provider_response| { + defer provider_response.deinit(allocator); + captured.deinit(allocator); + return .{ + .success = false, + .finish_reason = try allocator.dupe(u8, provider_response.finish_reason), + .model = try allocator.dupe(u8, provider_response.model), + .captured_content = try allocator.dupe(u8, provider_response.output), + }; + }, + } + } + + // Fallback path for providers without per-token streaming: buffer via + // callProvider, then emit the whole response as a single SSE chunk. + if (think_block_open.*) { + try response_mod.sendChatCompletionChunkSse( + connection, allocator, completion_id, "fallback", "", null, + ); + think_block_open.* = false; + } + + var buffered = try callProvider(allocator, app_config, request); + defer buffered.deinit(allocator); + + if (buffered.output.len > 0) { + try response_mod.sendChatCompletionChunkSse( + connection, + allocator, + completion_id, + buffered.model, + buffered.output, + null, + ); + } + + return .{ + .success = buffered.success, + .finish_reason = try allocator.dupe(u8, buffered.finish_reason), + .model = try allocator.dupe(u8, buffered.model), + .captured_content = try allocator.dupe(u8, buffered.output), + }; +} + pub fn buildProviderStatusJson( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -849,6 +1035,18 @@ fn hasToolNamed(tools: []const types.Tool, name: []const u8) bool { return false; } +fn jsonValueStringifyAlloc(allocator: std.mem.Allocator, value: std.json.Value) ![]u8 { + var out = std.Io.Writer.Allocating.init(allocator); + errdefer out.deinit(); + + var json_writer = std.json.Stringify{ + .writer = &out.writer, + .options = .{}, + }; + try json_writer.write(value); + return try out.toOwnedSlice(); +} + fn extractLastUserPrompt( allocator: std.mem.Allocator, messages: []const types.Message, @@ -1075,3 +1273,45 @@ test "parseChatRequest validates tool_choice against requested tools" { }, } } + +test "parseChatRequest accepts OpenAI function tool schema" { + const allocator = std.testing.allocator; + const body = + \\{ + \\ "messages": [ + \\ {"role": "user", "content": "Use echo"} + \\ ], + \\ "tools": [ + \\ { + \\ "type": "function", + \\ "function": { + \\ "name": "echo", + \\ "description": "Return a message unchanged.", + \\ "parameters": { + \\ "type": "object", + \\ "properties": { + \\ "message": {"type": "string"} + \\ }, + \\ "required": ["message"] + \\ } + \\ } + \\ } + \\ ], + \\ "tool_choice": "echo" + \\} + ; + + const parsed = try parseChatRequest(allocator, body); + const request = switch (parsed) { + .ok => |request| request, + .err => return error.UnexpectedApiError, + }; + defer request.deinit(allocator); + + try std.testing.expectEqual(@as(usize, 1), request.tools.len); + try std.testing.expectEqualStrings("echo", request.tools[0].name); + try std.testing.expectEqualStrings("Return a message unchanged.", request.tools[0].description); + try std.testing.expect(request.tools[0].parameters_json != null); + try std.testing.expect(std.mem.indexOf(u8, request.tools[0].parameters_json.?, "\"message\"") != null); + try std.testing.expectEqualStrings("echo", request.tool_choice.?); +} diff --git a/src/backend/errors.zig b/src/backend/errors.zig index 3b5c1f8..e7ccb2f 100644 --- a/src/backend/errors.zig +++ b/src/backend/errors.zig @@ -273,3 +273,15 @@ test "providerFailureFromDetail maps missing key configuration" { try std.testing.expectEqual(@as(u16, 500), mapped.status_code); try std.testing.expectEqualStrings("provider_not_configured", mapped.code.?); } + +test "providerFailureFromDetail maps provider auth failures" { + const mapped = providerFailureFromDetail("openrouter", "OpenRouter returned HTTP 401"); + try std.testing.expectEqual(@as(u16, 502), mapped.status_code); + try std.testing.expectEqualStrings("provider_auth_failed", mapped.code.?); +} + +test "providerFailureFromDetail maps invalid provider payloads" { + const mapped = providerFailureFromDetail("bedrock", "Bedrock returned invalid JSON"); + try std.testing.expectEqual(@as(u16, 502), mapped.status_code); + try std.testing.expectEqualStrings("provider_invalid_response", mapped.code.?); +} diff --git a/src/backend/mcp.zig b/src/backend/mcp.zig new file mode 100644 index 0000000..a3f7466 --- /dev/null +++ b/src/backend/mcp.zig @@ -0,0 +1,133 @@ +//! Minimal MCP discovery bridge. +//! +//! This module intentionally stops at bounded tool metadata parsing and +//! allowlist registration. Transport and invocation stay outside the hot path +//! until an MCP server is explicitly wired in. +const std = @import("std"); +const types = @import("../types.zig"); +const tooling = @import("tools.zig"); + +pub fn parseDiscoveredToolsAlloc( + allocator: std.mem.Allocator, + discovery_json: []const u8, + max_tools: usize, +) ![]types.Tool { + var parsed = std.json.parseFromSlice(std.json.Value, allocator, discovery_json, .{}) catch |err| switch (err) { + error.OutOfMemory => return err, + else => return try allocator.alloc(types.Tool, 0), + }; + defer parsed.deinit(); + + const root = switch (parsed.value) { + .object => |object| object, + else => return try allocator.alloc(types.Tool, 0), + }; + + const tools_value = root.get("tools") orelse return try allocator.alloc(types.Tool, 0); + const tool_items = switch (tools_value) { + .array => |array| array.items, + else => return try allocator.alloc(types.Tool, 0), + }; + + var out = std.ArrayList(types.Tool){}; + errdefer { + for (out.items) |tool| { + tool.deinit(allocator); + } + out.deinit(allocator); + } + + for (tool_items) |item| { + if (out.items.len >= max_tools) break; + + const object = switch (item) { + .object => |value| value, + else => continue, + }; + + const raw_name = switch (object.get("name") orelse continue) { + .string => |value| std.mem.trim(u8, value, " \t\r\n"), + else => continue, + }; + if (!isSafeToolName(raw_name)) continue; + if (hasToolNamed(out.items, raw_name)) continue; + + const raw_description = if (object.get("description")) |description| + switch (description) { + .string => |value| std.mem.trim(u8, value, " \t\r\n"), + else => "", + } + else + ""; + + try out.append(allocator, .{ + .name = try allocator.dupe(u8, raw_name), + .description = try allocator.dupe(u8, raw_description[0..@min(raw_description.len, 512)]), + }); + } + + return try out.toOwnedSlice(allocator); +} + +pub fn registerDiscoveredTools( + registry: *tooling.ToolRegistry, + allocator: std.mem.Allocator, + discovered_tools: []const types.Tool, +) !void { + for (discovered_tools) |tool| { + try registry.register(allocator, tool.name); + } +} + +fn hasToolNamed(tools: []const types.Tool, name: []const u8) bool { + for (tools) |tool| { + if (std.mem.eql(u8, tool.name, name)) return true; + } + return false; +} + +fn isSafeToolName(name: []const u8) bool { + if (name.len == 0 or name.len > 64) return false; + for (name) |c| { + switch (c) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => {}, + else => return false, + } + } + return true; +} + +test "parseDiscoveredToolsAlloc maps bounded MCP metadata" { + const allocator = std.testing.allocator; + const json = + \\{"tools":[ + \\ {"name":"search.docs","description":"Search project docs"}, + \\ {"name":"bad name","description":"rejected"}, + \\ {"name":"search.docs","description":"duplicate"} + \\]} + ; + + const tools = try parseDiscoveredToolsAlloc(allocator, json, 8); + defer { + for (tools) |tool| { + tool.deinit(allocator); + } + allocator.free(tools); + } + + try std.testing.expectEqual(@as(usize, 1), tools.len); + try std.testing.expectEqualStrings("search.docs", tools[0].name); +} + +test "registerDiscoveredTools maps MCP tools into allowlist" { + const allocator = std.testing.allocator; + const discovered = [_]types.Tool{ + .{ .name = "search.docs", .description = "Search project docs" }, + }; + + var registry = tooling.ToolRegistry.init(allocator); + defer registry.deinit(allocator); + + try registerDiscoveredTools(®istry, allocator, discovered[0..]); + try std.testing.expect(registry.isAllowed("search.docs")); +} diff --git a/src/backend/session.zig b/src/backend/session.zig index 194eafa..ceae5aa 100644 --- a/src/backend/session.zig +++ b/src/backend/session.zig @@ -212,6 +212,57 @@ pub fn shouldCompressContext(estimated_tokens: usize, max_context_tokens: usize) return estimated_tokens > max_context_tokens; } +pub fn compactContextToBudgetAlloc( + allocator: std.mem.Allocator, + messages: []const types.Message, + max_context_tokens: usize, + keep_last_messages: usize, +) ![]types.Message { + if (!shouldCompressContext(estimateTokenCount(messages), max_context_tokens)) { + return try cloneMessagesAlloc(allocator, messages); + } + + const keep_count = @min(keep_last_messages, messages.len); + const summary = try compressContextSummaryAlloc(allocator, messages, keep_count); + defer allocator.free(summary); + + var out = std.ArrayList(types.Message){}; + errdefer { + for (out.items) |message| { + message.deinit(allocator); + } + out.deinit(allocator); + } + + if (summary.len > 0) { + try out.append(allocator, .{ + .role = try allocator.dupe(u8, "system"), + .content = try allocator.dupe(u8, summary), + }); + } + + const start = messages.len - keep_count; + for (messages[start..]) |message| { + try out.append(allocator, .{ + .role = try allocator.dupe(u8, message.role), + .content = try allocator.dupe(u8, message.content), + }); + } + + while (out.items.len > 2 and estimateTokenCount(out.items) > max_context_tokens) { + const drop_index: usize = if (std.mem.eql(u8, out.items[0].role, "system")) 1 else 0; + var dropped = out.orderedRemove(drop_index); + dropped.deinit(allocator); + } + + if (out.items.len > 1 and estimateTokenCount(out.items) > max_context_tokens and std.mem.eql(u8, out.items[0].role, "system")) { + var dropped = out.orderedRemove(0); + dropped.deinit(allocator); + } + + return try out.toOwnedSlice(allocator); +} + pub fn compressContextSummaryAlloc( allocator: std.mem.Allocator, messages: []const types.Message, @@ -227,7 +278,13 @@ pub fn compressContextSummaryAlloc( try out.appendSlice(allocator, "Summary of previous context:\n"); for (messages[0..summarize_until]) |message| { - try out.writer(allocator).print("- {s}: {s}\n", .{ message.role, message.content }); + const max_content_bytes = 80; + const content = message.content[0..@min(message.content.len, max_content_bytes)]; + try out.writer(allocator).print("- {s}: {s}{s}\n", .{ + message.role, + content, + if (message.content.len > max_content_bytes) "..." else "", + }); } return try out.toOwnedSlice(allocator); @@ -517,3 +574,104 @@ test "trimToRetentionAlloc keeps latest messages" { try std.testing.expectEqualStrings("two", trimmed[0].content); try std.testing.expectEqualStrings("three", trimmed[1].content); } + +test "mergeMessagesAlloc preserves history before incoming messages" { + const allocator = std.testing.allocator; + const history = [_]types.Message{ + .{ .role = "user", .content = "one" }, + .{ .role = "assistant", .content = "two" }, + }; + const incoming = [_]types.Message{ + .{ .role = "user", .content = "three" }, + }; + + const merged = try mergeMessagesAlloc(allocator, history[0..], incoming[0..]); + defer { + for (merged) |message| { + message.deinit(allocator); + } + allocator.free(merged); + } + + try std.testing.expectEqual(@as(usize, 3), merged.len); + try std.testing.expectEqualStrings("one", merged[0].content); + try std.testing.expectEqualStrings("two", merged[1].content); + try std.testing.expectEqualStrings("three", merged[2].content); +} + +test "FileSessionStore roundtrips state and enforces retention" { + const allocator = std.testing.allocator; + const store_path = try std.fmt.allocPrint( + allocator, + ".zig-cache/session-test-{d}", + .{std.time.nanoTimestamp()}, + ); + defer allocator.free(store_path); + defer std.fs.cwd().deleteTree(store_path) catch {}; + + var file_store = try FileSessionStore.init(allocator, store_path, 2); + defer file_store.deinit(allocator); + var store = file_store.asStore(); + + var messages = try allocator.alloc(types.Message, 3); + messages[0] = .{ .role = try allocator.dupe(u8, "user"), .content = try allocator.dupe(u8, "one") }; + messages[1] = .{ .role = try allocator.dupe(u8, "assistant"), .content = try allocator.dupe(u8, "two") }; + messages[2] = .{ .role = try allocator.dupe(u8, "user"), .content = try allocator.dupe(u8, "three") }; + + var state = SessionState{ + .session_id = try allocator.dupe(u8, "session/a"), + .tenant_id = try allocator.dupe(u8, "tenant"), + .summary = try allocator.dupe(u8, "kept summary"), + .messages = messages, + .message_count = messages.len, + }; + defer state.deinit(allocator); + + try store.save(allocator, state); + + const loaded_opt = try store.load(allocator, "session/a", "tenant"); + try std.testing.expect(loaded_opt != null); + const loaded = loaded_opt.?; + defer loaded.deinit(allocator); + + try std.testing.expectEqualStrings("session/a", loaded.session_id); + try std.testing.expectEqualStrings("tenant", loaded.tenant_id.?); + try std.testing.expectEqualStrings("kept summary", loaded.summary); + try std.testing.expectEqual(@as(usize, 2), loaded.messages.len); + try std.testing.expectEqualStrings("two", loaded.messages[0].content); + try std.testing.expectEqualStrings("three", loaded.messages[1].content); +} + +test "compactContextToBudgetAlloc adds summary and keeps latest messages" { + const allocator = std.testing.allocator; + const messages = [_]types.Message{ + .{ .role = "user", .content = "older user content that should be summarized because it is intentionally long and no longer needs to remain as raw chat history in the outgoing provider context" }, + .{ .role = "assistant", .content = "older assistant content that should be summarized because it is intentionally long and no longer needs to remain as raw chat history in the outgoing provider context" }, + .{ .role = "user", .content = "latest user" }, + .{ .role = "assistant", .content = "latest assistant" }, + }; + + const compacted = try compactContextToBudgetAlloc(allocator, messages[0..], 70, 2); + defer { + for (compacted) |message| { + message.deinit(allocator); + } + allocator.free(compacted); + } + + try std.testing.expect(compacted.len >= 2); + try std.testing.expectEqualStrings("system", compacted[0].role); + try std.testing.expect(std.mem.indexOf(u8, compacted[0].content, "Summary of previous context") != null); + try std.testing.expectEqualStrings("latest assistant", compacted[compacted.len - 1].content); +} + +test "context token estimator triggers compression past budget" { + const messages = [_]types.Message{ + .{ .role = "user", .content = "12345678901234567890" }, + }; + + const estimated = estimateTokenCount(messages[0..]); + try std.testing.expect(estimated > 0); + try std.testing.expect(shouldCompressContext(estimated, estimated - 1)); + try std.testing.expect(!shouldCompressContext(estimated, estimated)); +} diff --git a/src/backend/tool_output.zig b/src/backend/tool_output.zig new file mode 100644 index 0000000..c41a976 --- /dev/null +++ b/src/backend/tool_output.zig @@ -0,0 +1,105 @@ +//! Tool output offloading for large tool results. +//! +//! When tool output exceeds a configured byte threshold, the full payload is +//! written to disk and a compact head/tail summary is returned for context. +const std = @import("std"); +const config = @import("../config.zig"); + +const head_tail_preview_bytes = 512; + +pub fn maybeOffloadToolOutputAlloc( + allocator: std.mem.Allocator, + app_config: *const config.Config, + request_id: []const u8, + tool_name: []const u8, + output: []const u8, +) ![]u8 { + if (app_config.tool_output_offload_bytes == 0) { + return try allocator.dupe(u8, output); + } + if (output.len <= app_config.tool_output_offload_bytes) { + return try allocator.dupe(u8, output); + } + + const offload_dir = try std.fmt.allocPrint(allocator, "{s}/tool_outputs", .{app_config.session_store_path}); + defer allocator.free(offload_dir); + + std.fs.cwd().makePath(offload_dir) catch {}; + + var safe_tool_name_buf: [64]u8 = undefined; + const safe_tool_name = sanitizePathComponent(tool_name, &safe_tool_name_buf); + const offload_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}_{s}.txt", + .{ offload_dir, request_id, safe_tool_name }, + ); + defer allocator.free(offload_path); + + const file = std.fs.cwd().createFile(offload_path, .{}) catch { + return try allocator.dupe(u8, output); + }; + defer file.close(); + file.writeAll(output) catch { + return try allocator.dupe(u8, output); + }; + + const head_len = @min(head_tail_preview_bytes, output.len); + const tail_len = @min(head_tail_preview_bytes, output.len - head_len); + const tail_start = output.len - tail_len; + + return try std.fmt.allocPrint( + allocator, + "tool_output_offloaded=true\nfull_bytes={d}\noffload_path={s}\n--- head ---\n{s}\n--- tail ---\n{s}", + .{ + output.len, + offload_path, + output[0..head_len], + output[tail_start..], + }, + ); +} + +fn sanitizePathComponent(name: []const u8, out: *[64]u8) []const u8 { + var len: usize = 0; + for (name) |c| { + if (len >= out.len) break; + switch (c) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => { + out.*[len] = c; + len += 1; + }, + else => {}, + } + } + if (len == 0) return "tool"; + return out[0..len]; +} + +test "maybeOffloadToolOutputAlloc keeps small output inline" { + const allocator = std.testing.allocator; + var cfg = try @import("../tools/command_exec.zig").buildTestConfig(allocator, false); + defer cfg.deinit(allocator); + cfg.tool_output_offload_bytes = 64; + + const output = try maybeOffloadToolOutputAlloc(allocator, &cfg, "req-1", "echo", "small"); + defer allocator.free(output); + + try std.testing.expectEqualStrings("small", output); +} + +test "maybeOffloadToolOutputAlloc offloads large output with preview" { + const allocator = std.testing.allocator; + var cfg = try @import("../tools/command_exec.zig").buildTestConfig(allocator, false); + defer cfg.deinit(allocator); + cfg.tool_output_offload_bytes = 32; + + const large = try allocator.alloc(u8, 128); + defer allocator.free(large); + @memset(large, 'x'); + + const output = try maybeOffloadToolOutputAlloc(allocator, &cfg, "req-offload", "echo", large); + defer allocator.free(output); + + try std.testing.expect(std.mem.indexOf(u8, output, "tool_output_offloaded=true") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "offload_path=") != null); +} diff --git a/src/backend/tools.zig b/src/backend/tools.zig index 3a0545a..43d9387 100644 --- a/src/backend/tools.zig +++ b/src/backend/tools.zig @@ -4,6 +4,9 @@ const types = @import("../types.zig"); const echo_tool = @import("../tools/echo.zig"); const utc_tool = @import("../tools/utc.zig"); const command_exec_tool = @import("../tools/command_exec.zig"); +const file_ops_tool = @import("../tools/file_ops.zig"); +const tool_output = @import("tool_output.zig"); +const workspace = @import("workspace.zig"); pub const ToolRegistry = struct { allowed_names: std.StringHashMap(void), @@ -41,6 +44,11 @@ pub fn validateRequestedTools( return true; } +pub fn noteToolCall(max_calls: usize, used_calls: *usize) !void { + if (used_calls.* >= max_calls) return error.ToolCallLimitExceeded; + used_calls.* += 1; +} + /// Executes simple built-in debug tools directly in the harness. /// /// This is intentionally minimal and deterministic so UI clients can verify @@ -49,6 +57,8 @@ pub fn tryExecuteDebugTool( allocator: std.mem.Allocator, request: types.Request, app_config: *const config.Config, + request_id: []const u8, + active_workspace: ?*workspace.WorkspaceState, ) !?types.Response { const choice = request.tool_choice orelse return null; @@ -64,21 +74,135 @@ pub fn tryExecuteDebugTool( if (std.ascii.eqlIgnoreCase(choice, "cmd")) { if (!hasRequestedTool(request.tools, "cmd")) return null; - return try command_exec_tool.execute(allocator, app_config, request, .cmd); + const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, request.prompt, app_config.tool_exec_trusted_local); + defer if (command) |value| allocator.free(value); + if (command == null) return null; + return try command_exec_tool.execute(allocator, app_config, request, .cmd, request_id); } if (std.ascii.eqlIgnoreCase(choice, "bash")) { if (!hasRequestedTool(request.tools, "bash")) return null; - return try command_exec_tool.execute(allocator, app_config, request, .bash); + const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, request.prompt, app_config.tool_exec_trusted_local); + defer if (command) |value| allocator.free(value); + if (command == null) return null; + return try command_exec_tool.execute(allocator, app_config, request, .bash, request_id); + } + + if (std.ascii.eqlIgnoreCase(choice, "file_read")) { + if (!hasRequestedTool(request.tools, "file_read")) return null; + return try file_ops_tool.execute(allocator, app_config, request, .read, active_workspace); + } + + if (std.ascii.eqlIgnoreCase(choice, "file_write")) { + if (!hasRequestedTool(request.tools, "file_write")) return null; + return try file_ops_tool.execute(allocator, app_config, request, .write, active_workspace); + } + + if (std.ascii.eqlIgnoreCase(choice, "file_search")) { + if (!hasRequestedTool(request.tools, "file_search")) return null; + return try file_ops_tool.execute(allocator, app_config, request, .search, active_workspace); } return null; } +pub fn maybeMakeBuiltinToolCallResponseAlloc( + allocator: std.mem.Allocator, + request: types.Request, + request_id: []const u8, +) !?types.Response { + if (!hasRequestedTool(request.tools, "echo")) return null; + if (request.tool_choice) |choice| { + if (std.ascii.eqlIgnoreCase(choice, "none")) return null; + if (!std.ascii.eqlIgnoreCase(choice, "auto") and + !std.ascii.eqlIgnoreCase(choice, "required") and + !std.ascii.eqlIgnoreCase(choice, "echo")) + { + return null; + } + } + + if (std.ascii.indexOfIgnoreCase(request.prompt, "echo tool") == null) return null; + const message = extractExactMessageForEcho(request.prompt) orelse return null; + + const escaped_message = try escapeJsonStringAlloc(allocator, message); + defer allocator.free(escaped_message); + + const arguments_json = try std.fmt.allocPrint( + allocator, + "{{\"message\":\"{s}\"}}", + .{escaped_message}, + ); + errdefer allocator.free(arguments_json); + + const call_id = try std.fmt.allocPrint(allocator, "call_{s}_echo", .{request_id}); + errdefer allocator.free(call_id); + + const calls = try allocator.alloc(types.ToolCall, 1); + errdefer allocator.free(calls); + calls[0] = .{ + .id = call_id, + .name = try allocator.dupe(u8, "echo"), + .arguments_json = arguments_json, + }; + errdefer calls[0].deinit(allocator); + + return .{ + .id = null, + .model = try allocator.dupe(u8, "debug-tools/tool-call"), + .output = try allocator.dupe(u8, ""), + .finish_reason = try allocator.dupe(u8, "tool_calls"), + .success = true, + .usage = .{}, + .tool_calls = calls, + }; +} + +fn extractExactMessageForEcho(prompt: []const u8) ?[]const u8 { + const marker = "exact message "; + const start = std.ascii.indexOfIgnoreCase(prompt, marker) orelse return null; + const raw = prompt[start + marker.len ..]; + const trimmed = std.mem.trim(u8, raw, " \t\r\n\"'"); + if (trimmed.len == 0) return null; + + var end: usize = 0; + while (end < trimmed.len) : (end += 1) { + switch (trimmed[end]) { + '.', ',', ';', ':', '!', '?', '"', '\'', '\r', '\n', '\t', ' ' => break, + else => {}, + } + } + if (end == 0) return null; + return trimmed[0..end]; +} + +fn escapeJsonStringAlloc( + allocator: std.mem.Allocator, + input: []const u8, +) ![]u8 { + var out = std.ArrayList(u8){}; + defer out.deinit(allocator); + + for (input) |c| { + switch (c) { + '"' => try out.appendSlice(allocator, "\\\""), + '\\' => try out.appendSlice(allocator, "\\\\"), + '\n' => try out.appendSlice(allocator, "\\n"), + '\r' => try out.appendSlice(allocator, "\\r"), + '\t' => try out.appendSlice(allocator, "\\t"), + else => try out.append(allocator, c), + } + } + + return try out.toOwnedSlice(allocator); +} + pub fn maybeExecutePromptToolsAlloc( allocator: std.mem.Allocator, request: types.Request, app_config: *const config.Config, + request_id: []const u8, + active_workspace: ?*workspace.WorkspaceState, ) !?[]u8 { if (request.tools.len == 0) return null; @@ -93,9 +217,11 @@ pub fn maybeExecutePromptToolsAlloc( defer output.deinit(allocator); var executed_any = false; + var tool_calls: usize = 0; try output.appendSlice(allocator, "Tool execution results:\n"); if (hasRequestedTool(request.tools, "utc") and mentionsUtcIntent(prompt)) { + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const utc_result = try utc_tool.execute(allocator, request); defer utc_result.deinit(allocator); @@ -105,16 +231,87 @@ pub fn maybeExecutePromptToolsAlloc( const cmd_requested = hasRequestedTool(request.tools, "cmd"); const bash_requested = hasRequestedTool(request.tools, "bash"); + const file_write_requested = hasRequestedTool(request.tools, "file_write"); + const file_read_requested = hasRequestedTool(request.tools, "file_read"); + const file_search_requested = hasRequestedTool(request.tools, "file_search"); + + if (file_read_requested and mentionsFileReadIntent(prompt)) { + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const read_result = try file_ops_tool.execute(allocator, app_config, request, .read, active_workspace); + defer read_result.deinit(allocator); + try output.writer(allocator).print("\n[tool=file_read]\n{s}\n", .{read_result.output}); + executed_any = true; + } + + if (file_search_requested and mentionsFileSearchIntent(prompt)) { + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const search_result = try file_ops_tool.execute(allocator, app_config, request, .search, active_workspace); + defer search_result.deinit(allocator); + try output.writer(allocator).print("\n[tool=file_search]\n{s}\n", .{search_result.output}); + executed_any = true; + } + + if (file_write_requested and mentionsPythonAddTwoNumbersIntent(prompt)) { + const demo_path = "tmp/add_two_numbers.py"; + const demo_content = + \\a = 2 + \\b = 3 + \\print(a + b) + \\ + ; + + const write_prompt = try std.fmt.allocPrint( + allocator, + "write {s}\n```python\n{s}```", + .{ demo_path, demo_content }, + ); + defer allocator.free(write_prompt); + + const write_req = try buildSinglePromptRequestAlloc(allocator, write_prompt); + defer write_req.deinit(allocator); + + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const write_result = try file_ops_tool.execute(allocator, app_config, write_req, .write, active_workspace); + defer write_result.deinit(allocator); + + try output.writer(allocator).print("\n[tool=file_write]\n{s}\n", .{write_result.output}); + executed_any = true; + + if (cmd_requested and @import("builtin").os.tag == .windows) { + const run_prompt = try std.fmt.allocPrint(allocator, "python {s}", .{demo_path}); + defer allocator.free(run_prompt); + const run_req = try buildSinglePromptRequestAlloc(allocator, run_prompt); + defer run_req.deinit(allocator); + + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .cmd, request_id); + defer run_result.deinit(allocator); + + try output.writer(allocator).print("\n[tool=cmd]\n{s}\n", .{run_result.output}); + } else if (bash_requested and @import("builtin").os.tag != .windows) { + const run_prompt = try std.fmt.allocPrint(allocator, "python3 {s}", .{demo_path}); + defer allocator.free(run_prompt); + const run_req = try buildSinglePromptRequestAlloc(allocator, run_prompt); + defer run_req.deinit(allocator); + + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .bash, request_id); + defer run_result.deinit(allocator); + + try output.writer(allocator).print("\n[tool=bash]\n{s}\n", .{run_result.output}); + } + } if (cmd_requested) { - const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, prompt); + const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, prompt, app_config.tool_exec_trusted_local); defer if (command_to_run) |command| allocator.free(command); if (command_to_run) |command| { const cmd_request = try buildSinglePromptRequestAlloc(allocator, command); defer cmd_request.deinit(allocator); - const cmd_result = try command_exec_tool.execute(allocator, app_config, cmd_request, .cmd); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const cmd_result = try command_exec_tool.execute(allocator, app_config, cmd_request, .cmd, request_id); defer cmd_result.deinit(allocator); try output.writer(allocator).print("\n[tool=cmd]\n{s}\n", .{cmd_result.output}); @@ -123,14 +320,15 @@ pub fn maybeExecutePromptToolsAlloc( } if (!executed_any and bash_requested) { - const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, prompt); + const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, prompt, app_config.tool_exec_trusted_local); defer if (command_to_run) |command| allocator.free(command); if (command_to_run) |command| { const bash_request = try buildSinglePromptRequestAlloc(allocator, command); defer bash_request.deinit(allocator); - const bash_result = try command_exec_tool.execute(allocator, app_config, bash_request, .bash); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); + const bash_result = try command_exec_tool.execute(allocator, app_config, bash_request, .bash, request_id); defer bash_result.deinit(allocator); try output.writer(allocator).print("\n[tool=bash]\n{s}\n", .{bash_result.output}); @@ -142,7 +340,37 @@ pub fn maybeExecutePromptToolsAlloc( return try output.toOwnedSlice(allocator); } -fn hasRequestedTool(tools: []const types.Tool, name: []const u8) bool { +pub fn shouldShortCircuitAutoTools(request: types.Request) bool { + if (request.tools.len == 0) return false; + + const choice = request.tool_choice orelse return false; + if (!std.ascii.eqlIgnoreCase(choice, "auto") and !std.ascii.eqlIgnoreCase(choice, "required")) return false; + + const prompt = request.prompt; + if (mentionsPythonAddTwoNumbersIntent(prompt)) return true; + if (hasRequestedTool(request.tools, "file_read") and mentionsFileReadIntent(prompt)) return true; + if (hasRequestedTool(request.tools, "file_search") and mentionsFileSearchIntent(prompt)) return true; + return false; +} + +pub fn makeAutoToolResponse( + allocator: std.mem.Allocator, + summary: []const u8, +) !types.Response { + const model_name = try std.fmt.allocPrint(allocator, "debug-tools/{s}", .{"auto"}); + errdefer allocator.free(model_name); + + return .{ + .id = null, + .model = model_name, + .output = try allocator.dupe(u8, summary), + .finish_reason = try allocator.dupe(u8, "tool"), + .success = true, + .usage = .{}, + }; +} + +pub fn hasRequestedTool(tools: []const types.Tool, name: []const u8) bool { for (tools) |tool| { if (std.ascii.eqlIgnoreCase(tool.name, name)) return true; } @@ -154,6 +382,24 @@ fn mentionsUtcIntent(prompt: []const u8) bool { containsIgnoreCase(prompt, "time"); } +fn mentionsFileReadIntent(prompt: []const u8) bool { + return containsIgnoreCase(prompt, "file_read") or + containsIgnoreCase(prompt, "read file") or + containsIgnoreCase(prompt, "open file"); +} + +fn mentionsFileSearchIntent(prompt: []const u8) bool { + return containsIgnoreCase(prompt, "file_search") or + containsIgnoreCase(prompt, "search ") or + containsIgnoreCase(prompt, "find "); +} + +fn mentionsPythonAddTwoNumbersIntent(prompt: []const u8) bool { + return containsIgnoreCase(prompt, "python") and + containsIgnoreCase(prompt, "add two") and + containsIgnoreCase(prompt, "execute"); +} + fn containsIgnoreCase(haystack: []const u8, needle: []const u8) bool { return indexOfIgnoreCase(haystack, needle) != null; } @@ -196,6 +442,7 @@ fn buildSinglePromptRequestAlloc( .model = null, .session_id = null, .tenant_id = null, + .workspace_id = null, .max_context_tokens = null, .tools = try allocator.alloc(types.Tool, 0), .tool_choice = null, @@ -233,7 +480,7 @@ test "tryExecuteDebugTool executes echo when explicitly selected" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request", null); try std.testing.expect(maybe_result != null); var result = maybe_result.?; @@ -244,6 +491,48 @@ test "tryExecuteDebugTool executes echo when explicitly selected" { try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_OK") != null); } +test "maybeMakeBuiltinToolCallResponseAlloc synthesizes echo tool call" { + const allocator = std.testing.allocator; + const prompt = "Use the echo tool to return the exact message demo-green. Do not answer in normal text."; + + const messages = try allocator.alloc(types.Message, 1); + messages[0] = .{ + .role = try allocator.dupe(u8, "user"), + .content = try allocator.dupe(u8, prompt), + }; + + const tools = try allocator.alloc(types.Tool, 1); + tools[0] = .{ + .name = try allocator.dupe(u8, "echo"), + .description = try allocator.dupe(u8, "debug echo"), + }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, prompt), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = tools, + .tool_choice = null, + }; + defer req.deinit(allocator); + + const maybe_result = try maybeMakeBuiltinToolCallResponseAlloc(allocator, req, "test-request"); + try std.testing.expect(maybe_result != null); + + var result = maybe_result.?; + defer result.deinit(allocator); + + try std.testing.expectEqualStrings("tool_calls", result.finish_reason); + try std.testing.expect(result.tool_calls != null); + try std.testing.expectEqual(@as(usize, 1), result.tool_calls.?.len); + try std.testing.expectEqualStrings("echo", result.tool_calls.?[0].name); + try std.testing.expectEqualStrings("{\"message\":\"demo-green\"}", result.tool_calls.?[0].arguments_json); +} + test "tryExecuteDebugTool executes utc when explicitly selected" { const allocator = std.testing.allocator; @@ -275,7 +564,7 @@ test "tryExecuteDebugTool executes utc when explicitly selected" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request", null); try std.testing.expect(maybe_result != null); var result = maybe_result.?; @@ -317,7 +606,7 @@ test "maybeExecutePromptToolsAlloc runs utc from prompt intent" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg); + const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request", null); defer if (summary) |value| allocator.free(value); try std.testing.expect(summary != null); @@ -357,10 +646,52 @@ test "maybeExecutePromptToolsAlloc extracts command for cmd" { var cfg = try command_exec_tool.buildTestConfig(allocator, true); defer cfg.deinit(allocator); - const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg); + const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request", null); defer if (summary) |value| allocator.free(value); try std.testing.expect(summary != null); try std.testing.expect(std.mem.indexOf(u8, summary.?, "[tool=cmd]") != null); try std.testing.expect(std.mem.indexOf(u8, summary.?, "command=zig build test") != null); } + +test "tryExecuteDebugTool ignores natural language cmd selection" { + const allocator = std.testing.allocator; + + const messages = try allocator.alloc(types.Message, 1); + messages[0] = .{ + .role = try allocator.dupe(u8, "user"), + .content = try allocator.dupe(u8, "Please write a small C++ program"), + }; + + const tools = try allocator.alloc(types.Tool, 1); + tools[0] = .{ + .name = try allocator.dupe(u8, "cmd"), + .description = try allocator.dupe(u8, "windows command tool"), + }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, "Please write a small C++ program"), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = tools, + .tool_choice = try allocator.dupe(u8, "cmd"), + }; + defer req.deinit(allocator); + + var cfg = try command_exec_tool.buildTestConfig(allocator, true); + defer cfg.deinit(allocator); + + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request", null); + try std.testing.expect(maybe_result == null); +} + +test "noteToolCall enforces configured tool call cap" { + var used: usize = 0; + try noteToolCall(1, &used); + try std.testing.expectEqual(@as(usize, 1), used); + try std.testing.expectError(error.ToolCallLimitExceeded, noteToolCall(1, &used)); +} diff --git a/src/backend/workspace.zig b/src/backend/workspace.zig new file mode 100644 index 0000000..501b75d --- /dev/null +++ b/src/backend/workspace.zig @@ -0,0 +1,230 @@ +//! Lightweight ephemeral workspace memory (Cursor-like, cleared on server exit). +//! +//! Stores conversation turns and a small file-activity hint keyed by `workspace_id`. +//! Does not replace the core harness loop; it only substitutes disk `session_id` with +//! in-process state and enables repo-root file tools when workspace mode is on. +const std = @import("std"); +const types = @import("../types.zig"); +const session = @import("session.zig"); + +pub const max_file_hints = 16; + +pub const WorkspaceState = struct { + workspace_id: []const u8, + messages: []types.Message, + file_hints: [][]const u8, + + pub fn deinit(self: *WorkspaceState, allocator: std.mem.Allocator) void { + allocator.free(self.workspace_id); + for (self.messages) |message| { + message.deinit(allocator); + } + allocator.free(self.messages); + for (self.file_hints) |hint| { + allocator.free(hint); + } + allocator.free(self.file_hints); + } +}; + +pub const WorkspaceStore = struct { + mutex: std.Thread.Mutex = .{}, + workspaces: std.StringHashMap(*WorkspaceState), + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) WorkspaceStore { + return .{ + .workspaces = std.StringHashMap(*WorkspaceState).init(allocator), + .allocator = allocator, + }; + } + + pub fn deinit(self: *WorkspaceStore) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + var it = self.workspaces.iterator(); + while (it.next()) |entry| { + entry.value_ptr.*.deinit(self.allocator); + self.allocator.destroy(entry.value_ptr.*); + } + self.workspaces.deinit(); + } + + pub fn count(self: *WorkspaceStore) usize { + self.mutex.lock(); + defer self.mutex.unlock(); + return self.workspaces.count(); + } + + pub fn getOrCreate(self: *WorkspaceStore, workspace_id: []const u8) !*WorkspaceState { + self.mutex.lock(); + defer self.mutex.unlock(); + + if (self.workspaces.get(workspace_id)) |existing| { + return existing; + } + + const owned_id = try self.allocator.dupe(u8, workspace_id); + errdefer self.allocator.free(owned_id); + + const state = try self.allocator.create(WorkspaceState); + errdefer self.allocator.destroy(state); + + state.* = .{ + .workspace_id = owned_id, + .messages = try self.allocator.alloc(types.Message, 0), + .file_hints = &.{}, + }; + + try self.workspaces.put(owned_id, state); + return state; + } + + pub fn saveMessages( + self: *WorkspaceStore, + workspace: *WorkspaceState, + messages: []types.Message, + ) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + for (workspace.messages) |message| { + message.deinit(self.allocator); + } + self.allocator.free(workspace.messages); + workspace.messages = messages; + } +}; + +pub fn mergeIncomingMessagesAlloc( + allocator: std.mem.Allocator, + existing: []const types.Message, + incoming: []const types.Message, +) ![]types.Message { + if (existing.len == 0) { + return try session.cloneMessagesAlloc(allocator, incoming); + } + if (incoming.len == 0) { + return try session.cloneMessagesAlloc(allocator, existing); + } + + const last_incoming = incoming[incoming.len - 1]; + const last_existing = existing[existing.len - 1]; + if (std.mem.eql(u8, last_existing.role, last_incoming.role) and + std.mem.eql(u8, last_existing.content, last_incoming.content)) + { + return try session.cloneMessagesAlloc(allocator, existing); + } + + return try session.mergeMessagesAlloc(allocator, existing, incoming); +} + +pub fn recordFileHint( + workspace: *WorkspaceState, + allocator: std.mem.Allocator, + operation: []const u8, + path: []const u8, +) !void { + const hint = try std.fmt.allocPrint(allocator, "{s} {s}", .{ operation, path }); + errdefer allocator.free(hint); + + if (workspace.file_hints.len >= max_file_hints) { + allocator.free(workspace.file_hints[0]); + for (workspace.file_hints[1..], 0..) |entry, idx| { + workspace.file_hints[idx] = entry; + } + workspace.file_hints[workspace.file_hints.len - 1] = hint; + return; + } + + const next = try allocator.realloc(workspace.file_hints, workspace.file_hints.len + 1); + next[next.len - 1] = hint; + workspace.file_hints = next; +} + +pub fn buildMemoryHintMessageAlloc( + allocator: std.mem.Allocator, + workspace: *const WorkspaceState, +) !?types.Message { + if (workspace.file_hints.len == 0) return null; + + var out = std.ArrayList(u8){}; + defer out.deinit(allocator); + + try out.appendSlice(allocator, "Workspace memory (cleared on server exit):\n"); + for (workspace.file_hints) |hint| { + try out.writer(allocator).print("- {s}\n", .{hint}); + } + + return types.Message{ + .role = try allocator.dupe(u8, "system"), + .content = try out.toOwnedSlice(allocator), + }; +} + +pub fn prependMemoryHintAlloc( + allocator: std.mem.Allocator, + workspace: *const WorkspaceState, + messages: []const types.Message, +) ![]types.Message { + const hint = try buildMemoryHintMessageAlloc(allocator, workspace); + if (hint == null) { + return try session.cloneMessagesAlloc(allocator, messages); + } + defer hint.?.deinit(allocator); + + var out = std.ArrayList(types.Message){}; + errdefer { + for (out.items) |message| { + message.deinit(allocator); + } + out.deinit(allocator); + } + + try out.append(allocator, .{ + .role = try allocator.dupe(u8, hint.?.role), + .content = try allocator.dupe(u8, hint.?.content), + }); + + for (messages) |message| { + try out.append(allocator, .{ + .role = try allocator.dupe(u8, message.role), + .content = try allocator.dupe(u8, message.content), + }); + } + + return try out.toOwnedSlice(allocator); +} + +test "WorkspaceStore is ephemeral and keyed by id" { + const allocator = std.testing.allocator; + var store = WorkspaceStore.init(allocator); + defer store.deinit(); + + const ws = try store.getOrCreate("dev"); + const incoming = [_]types.Message{.{ .role = "user", .content = "hi" }}; + const merged = try mergeIncomingMessagesAlloc(allocator, ws.messages, incoming[0..]); + try store.saveMessages(ws, merged); + + try std.testing.expectEqual(@as(usize, 1), ws.messages.len); + try std.testing.expectEqual(@as(usize, 1), store.count()); +} + +test "recordFileHint keeps bounded file activity" { + const allocator = std.testing.allocator; + var state = WorkspaceState{ + .workspace_id = try allocator.dupe(u8, "dev"), + .messages = try allocator.alloc(types.Message, 0), + .file_hints = &.{}, + }; + defer state.deinit(allocator); + + try recordFileHint(&state, allocator, "read", "src/main.zig"); + try std.testing.expectEqual(@as(usize, 1), state.file_hints.len); + + const hint = try buildMemoryHintMessageAlloc(allocator, &state); + try std.testing.expect(hint != null); + defer hint.?.deinit(allocator); + try std.testing.expect(std.mem.indexOf(u8, hint.?.content, "src/main.zig") != null); +} diff --git a/src/config.zig b/src/config.zig index 538eea8..d2f83ca 100644 --- a/src/config.zig +++ b/src/config.zig @@ -13,6 +13,7 @@ pub const Config = struct { listen_port: u16, debug_logging: bool, default_provider: []const u8, + default_model: ?[]const u8, request_timeout_ms: u32, provider_timeout_ms: u32, instance_id: []const u8, @@ -46,9 +47,19 @@ pub const Config = struct { session_store_path: []const u8, session_retention_messages: usize, tool_exec_enabled: bool, + tool_exec_confirmation_required: bool, + tool_exec_trusted_local: bool, tool_exec_timeout_ms: u32, tool_exec_max_output_bytes: usize, + tool_output_offload_bytes: usize, + max_tool_calls_per_request: usize, + default_max_context_tokens: ?usize, + workspace_mode_enabled: bool, + workspace_root: []const u8, loop_stream_progress_enabled: bool, + max_concurrent_connections: usize, + max_request_bytes: usize, + max_header_bytes: usize, pub fn load(allocator: std.mem.Allocator) !Config { return try loadWithOverrides(allocator, null); @@ -64,9 +75,9 @@ pub const Config = struct { "ollama", env_overrides, ); - errdefer allocator.free(requested_default_provider); const normalized_default_provider = types.normalizeProviderName(requested_default_provider) orelse { + allocator.free(requested_default_provider); return error.InvalidProvider; }; @@ -74,6 +85,18 @@ pub const Config = struct { allocator.free(requested_default_provider); errdefer allocator.free(default_provider); + const default_model_raw = try getEnvOrDefault( + allocator, + "LLM_ROUTER_MODEL", + "", + env_overrides, + ); + const default_model: ?[]const u8 = if (default_model_raw.len == 0) blk: { + allocator.free(default_model_raw); + break :blk null; + } else default_model_raw; + errdefer if (default_model) |model| allocator.free(model); + return Config{ .listen_host = try getEnvOrDefault( allocator, @@ -84,6 +107,7 @@ pub const Config = struct { .listen_port = try getEnvPortOrDefault("LLM_ROUTER_PORT", 8081, env_overrides), .debug_logging = try getEnvFlag(allocator, "LLM_ROUTER_DEBUG", env_overrides), .default_provider = default_provider, + .default_model = default_model, .request_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_REQUEST_TIMEOUT_MS", 30_000, env_overrides), .provider_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_PROVIDER_TIMEOUT_MS", 60_000, env_overrides), .instance_id = try getEnvOrDefault( @@ -110,13 +134,13 @@ pub const Config = struct { "qwen3.5:9b", env_overrides, ), - .ollama_think = try getEnvFlag(allocator, "OLLAMA_THINK", env_overrides), - .ollama_num_predict = try getEnvU32OrDefault("OLLAMA_NUM_PREDICT", 128, env_overrides), + .ollama_think = try getEnvFlagOrDefault(allocator, "OLLAMA_THINK", true, env_overrides), + .ollama_num_predict = try getEnvU32OrDefault("OLLAMA_NUM_PREDICT", 2048, env_overrides), .ollama_temperature = try getEnvF64OrDefault(allocator, "OLLAMA_TEMPERATURE", 0.7, env_overrides), .ollama_repeat_penalty = try getEnvF64OrDefault(allocator, "OLLAMA_REPEAT_PENALTY", 1.05, env_overrides), - .openai_base_url = try getEnvOrDefault( + .openai_base_url = try getFirstEnvOrDefault( allocator, - "OPENAI_BASE_URL", + &.{ "OPENAI_BASE_URL", "OPENAI_API_BASE" }, "https://api.openai.com/v1", env_overrides, ), @@ -132,9 +156,9 @@ pub const Config = struct { "gpt-4.1-mini", env_overrides, ), - .openrouter_base_url = try getEnvOrDefault( + .openrouter_base_url = try getFirstEnvOrDefault( allocator, - "OPENROUTER_BASE_URL", + &.{ "OPENROUTER_BASE_URL", "OPENROUTER_API_BASE" }, "https://openrouter.ai/api/v1", env_overrides, ), @@ -246,9 +270,36 @@ pub const Config = struct { env_overrides, ), .tool_exec_enabled = try getEnvFlag(allocator, "LLM_ROUTER_TOOL_EXEC_ENABLED", env_overrides), + .tool_exec_confirmation_required = try getEnvFlagOrDefault( + allocator, + "LLM_ROUTER_TOOL_EXEC_CONFIRM_REQUIRED", + true, + env_overrides, + ), + .tool_exec_trusted_local = try getEnvFlag(allocator, "LLM_ROUTER_TOOL_EXEC_TRUSTED_LOCAL", env_overrides), .tool_exec_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS", 15_000, env_overrides), .tool_exec_max_output_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES", 65_536, env_overrides), + .tool_output_offload_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_OUTPUT_OFFLOAD_BYTES", 8192, env_overrides), + .max_tool_calls_per_request = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST", 8, env_overrides), + .default_max_context_tokens = blk: { + const raw = try getEnvUsizeOrDefault( + "LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS", + 0, + env_overrides, + ); + break :blk if (raw == 0) null else raw; + }, + .workspace_mode_enabled = try getEnvFlag(allocator, "LLM_ROUTER_WORKSPACE_MODE", env_overrides), + .workspace_root = try getEnvOrDefault( + allocator, + "LLM_ROUTER_WORKSPACE_ROOT", + ".", + env_overrides, + ), .loop_stream_progress_enabled = try getEnvFlagOrDefault(allocator, "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED", true, env_overrides), + .max_concurrent_connections = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS", 64, env_overrides), + .max_request_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_REQUEST_BYTES", 1024 * 1024, env_overrides), + .max_header_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_HEADER_BYTES", 16 * 1024, env_overrides), }; } @@ -260,6 +311,7 @@ pub const Config = struct { pub fn deinit(self: *Config, allocator: std.mem.Allocator) void { allocator.free(self.listen_host); allocator.free(self.default_provider); + if (self.default_model) |model| allocator.free(model); allocator.free(self.instance_id); allocator.free(self.auth_api_key); allocator.free(self.ollama_base_url); @@ -285,6 +337,7 @@ pub const Config = struct { allocator.free(self.llama_cpp_api_key); allocator.free(self.llama_cpp_model); allocator.free(self.session_store_path); + allocator.free(self.workspace_root); } pub fn setDefaultProvider( @@ -300,6 +353,40 @@ pub const Config = struct { allocator.free(self.default_provider); self.default_provider = next_provider; } + + pub fn setDefaultModel( + self: *Config, + allocator: std.mem.Allocator, + model: []const u8, + ) !void { + if (self.default_model) |existing| allocator.free(existing); + self.default_model = if (model.len == 0) null else try allocator.dupe(u8, model); + } + + /// Model used when a request omits `model` (or sends `"auto"`). + pub fn defaultModel(self: *const Config) []const u8 { + return self.modelForProvider(self.default_provider); + } + + /// Resolves the configured model for a provider. `LLM_ROUTER_MODEL` overrides + /// per-provider defaults when the provider matches `LLM_ROUTER_PROVIDER`. + pub fn modelForProvider(self: *const Config, provider: []const u8) []const u8 { + const normalized = types.normalizeProviderName(provider) orelse provider; + if (self.default_model) |unified| { + if (std.mem.eql(u8, normalized, self.default_provider)) { + return unified; + } + } + + if (std.mem.eql(u8, normalized, "ollama_qwen")) return self.ollama_model; + if (std.mem.eql(u8, normalized, "openai")) return self.openai_model; + if (std.mem.eql(u8, normalized, "openrouter")) return self.openrouter_model; + if (std.mem.eql(u8, normalized, "claude")) return self.claude_model; + if (std.mem.eql(u8, normalized, "bedrock")) return self.bedrock_model; + if (std.mem.eql(u8, normalized, "llama_cpp")) return self.llama_cpp_model; + + return self.openai_model; + } }; fn validateProviderName(provider: []const u8) !void { @@ -380,6 +467,23 @@ fn getEnvU32OrDefault( }; } +fn getEnvUsizeOrDefault( + comptime key: []const u8, + default_value: usize, + env_overrides: ?*const EnvOverrides, +) !usize { + if (env_overrides) |overrides| { + if (overrides.get(key)) |value| { + return try std.fmt.parseInt(usize, value, 10); + } + } + + return std.process.parseEnvVarInt(key, usize, 10) catch |err| switch (err) { + error.EnvironmentVariableNotFound => default_value, + else => err, + }; +} + fn getEnvPositiveU32OrDefault( comptime key: []const u8, default_value: u32, @@ -490,6 +594,7 @@ test "setDefaultProvider stores canonical provider alias" { .listen_port = 8081, .debug_logging = false, .default_provider = try allocator.dupe(u8, "ollama_qwen"), + .default_model = null, .request_timeout_ms = 30_000, .provider_timeout_ms = 60_000, .instance_id = try allocator.dupe(u8, "local-instance"), @@ -497,7 +602,7 @@ test "setDefaultProvider stores canonical provider alias" { .ollama_base_url = try allocator.dupe(u8, "http://127.0.0.1:11434"), .ollama_model = try allocator.dupe(u8, "qwen:7b"), .ollama_think = false, - .ollama_num_predict = 512, + .ollama_num_predict = 1024, .ollama_temperature = 0.7, .ollama_repeat_penalty = 1.05, .openai_base_url = try allocator.dupe(u8, "https://api.openai.com/v1"), @@ -523,9 +628,19 @@ test "setDefaultProvider stores canonical provider alias" { .session_store_path = try allocator.dupe(u8, "logs/sessions"), .session_retention_messages = 24, .tool_exec_enabled = false, + .tool_exec_confirmation_required = false, + .tool_exec_trusted_local = false, .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, + .tool_output_offload_bytes = 8192, + .max_tool_calls_per_request = 8, + .default_max_context_tokens = null, + .workspace_mode_enabled = false, + .workspace_root = try allocator.dupe(u8, "."), .loop_stream_progress_enabled = true, + .max_concurrent_connections = 64, + .max_request_bytes = 1024 * 1024, + .max_header_bytes = 16 * 1024, }; defer cfg.deinit(allocator); @@ -538,3 +653,28 @@ test "setDefaultProvider stores canonical provider alias" { try cfg.setDefaultProvider(allocator, "bedrock"); try std.testing.expectEqualStrings("bedrock", cfg.default_provider); } + +test "modelForProvider honors LLM_ROUTER_MODEL for active provider only" { + const allocator = std.testing.allocator; + + var overrides = EnvOverrides.init(allocator); + try overrides.put(try allocator.dupe(u8, "LLM_ROUTER_PROVIDER"), try allocator.dupe(u8, "openrouter")); + try overrides.put(try allocator.dupe(u8, "LLM_ROUTER_MODEL"), try allocator.dupe(u8, "openai/gpt-4o-mini")); + try overrides.put(try allocator.dupe(u8, "OPENROUTER_MODEL"), try allocator.dupe(u8, "openrouter/auto")); + try overrides.put(try allocator.dupe(u8, "OLLAMA_MODEL"), try allocator.dupe(u8, "qwen3.5:9b")); + defer { + var it = overrides.iterator(); + while (it.next()) |entry| { + allocator.free(entry.key_ptr.*); + allocator.free(entry.value_ptr.*); + } + overrides.deinit(); + } + + var cfg = try Config.loadWithOverrides(allocator, &overrides); + defer cfg.deinit(allocator); + + try std.testing.expectEqualStrings("openai/gpt-4o-mini", cfg.defaultModel()); + try std.testing.expectEqualStrings("openai/gpt-4o-mini", cfg.modelForProvider("openrouter")); + try std.testing.expectEqualStrings("qwen3.5:9b", cfg.modelForProvider("ollama")); +} diff --git a/src/core/request.zig b/src/core/request.zig index 3dee16b..6bf6835 100644 --- a/src/core/request.zig +++ b/src/core/request.zig @@ -88,10 +88,9 @@ pub fn readHttpRequest( allocator: std.mem.Allocator, connection: *std.net.Server.Connection, request_timeout_ms: u32, + max_request_size: usize, + max_header_size: usize, ) ![]u8 { - const max_request_size = 1024 * 1024; - const max_header_size = 16 * 1024; - const started_ms = std.time.milliTimestamp(); var request = std.ArrayList(u8){}; @@ -133,7 +132,7 @@ pub fn readHttpRequest( error.MissingContentLength => 0, else => return err, }; - const required_length = header_end.? + content_length; + const required_length = try checkedRequiredLength(header_end.?, content_length); if (required_length > max_request_size) { return error.RequestTooLarge; } @@ -166,6 +165,80 @@ pub fn readHttpRequest( return try request.toOwnedSlice(allocator); } +pub fn checkedRequiredLength(header_len: usize, content_length: usize) !usize { + return std.math.add(usize, header_len, content_length) catch error.RequestTooLarge; +} + +/// Returns a trimmed HTTP header value when present. +pub fn getHeaderValue(headers: []const u8, header_name: []const u8) ?[]const u8 { + var lines = std.mem.splitScalar(u8, headers, '\n'); + _ = lines.next(); + + while (lines.next()) |raw_line| { + const line = std.mem.trimRight(u8, raw_line, "\r"); + if (line.len == 0) break; + + const separator_index = std.mem.indexOfScalar(u8, line, ':') orelse continue; + const name = std.mem.trim(u8, line[0..separator_index], " "); + if (!std.ascii.eqlIgnoreCase(name, header_name)) continue; + + return std.mem.trim(u8, line[separator_index + 1 ..], " "); + } + + return null; +} + +/// Uses `X-Request-Id` when valid, otherwise generates `req-{micros}`. +pub fn resolveRequestIdAlloc( + allocator: std.mem.Allocator, + request_raw: []const u8, +) ![]u8 { + const header_end = std.mem.indexOf(u8, request_raw, "\r\n\r\n") orelse request_raw.len; + if (getHeaderValue(request_raw[0..header_end], "X-Request-Id")) |incoming| { + if (isValidRequestId(incoming)) { + return try allocator.dupe(u8, incoming); + } + } + + return try std.fmt.allocPrint(allocator, "req-{d}", .{std.time.microTimestamp()}); +} + +fn isValidRequestId(value: []const u8) bool { + if (value.len == 0 or value.len > 128) return false; + for (value) |c| { + switch (c) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => {}, + else => return false, + } + } + return true; +} + +test "checkedRequiredLength rejects integer overflow" { + try std.testing.expectError( + error.RequestTooLarge, + checkedRequiredLength(std.math.maxInt(usize), 1), + ); +} + +test "resolveRequestIdAlloc generates fallback id" { + const allocator = std.testing.allocator; + const id = try resolveRequestIdAlloc(allocator, "GET /health HTTP/1.1\r\n\r\n"); + defer allocator.free(id); + try std.testing.expect(std.mem.startsWith(u8, id, "req-")); +} + +test "resolveRequestIdAlloc preserves incoming header" { + const allocator = std.testing.allocator; + const raw = + "POST /v1/chat/completions HTTP/1.1\r\n" ++ + "X-Request-Id: client-trace-42\r\n" ++ + "\r\n"; + const id = try resolveRequestIdAlloc(allocator, raw); + defer allocator.free(id); + try std.testing.expectEqualStrings("client-trace-42", id); +} + fn waitForReadable( socket: std.posix.socket_t, timeout_ms: u32, diff --git a/src/core/response.zig b/src/core/response.zig index 7ff853f..eb4c215 100644 --- a/src/core/response.zig +++ b/src/core/response.zig @@ -50,6 +50,7 @@ fn sendJson( connection: std.net.Server.Connection, status_code: u16, body: []const u8, + request_id: ?[]const u8, ) !void { const status_text = switch (status_code) { 200 => "OK", @@ -65,16 +66,27 @@ fn sendJson( else => "Internal Server Error", }; + const request_id_header = if (request_id) |id| + try std.fmt.allocPrint( + std.heap.page_allocator, + "X-Request-Id: {s}\r\n", + .{id}, + ) + else + try std.heap.page_allocator.dupe(u8, ""); + defer std.heap.page_allocator.free(request_id_header); + // The server handles one request per connection. const response = try std.fmt.allocPrint( std.heap.page_allocator, "HTTP/1.1 {d} {s}\r\n" ++ "Content-Type: application/json\r\n" ++ + "{s}" ++ "Content-Length: {d}\r\n" ++ "Connection: close\r\n" ++ "\r\n" ++ "{s}", - .{ status_code, status_text, body.len, body }, + .{ status_code, status_text, request_id_header, body.len, body }, ); defer std.heap.page_allocator.free(response); @@ -86,7 +98,16 @@ pub fn sendJsonText( status_code: u16, body: []const u8, ) !void { - try sendJson(connection, status_code, body); + try sendJson(connection, status_code, body, null); +} + +pub fn sendJsonTextWithRequestId( + connection: std.net.Server.Connection, + status_code: u16, + body: []const u8, + request_id: []const u8, +) !void { + try sendJson(connection, status_code, body, request_id); } /// Send API error response @@ -94,6 +115,7 @@ pub fn sendApiError( connection: std.net.Server.Connection, allocator: std.mem.Allocator, api_error: errors.ApiError, + request_id: ?[]const u8, ) !void { const escaped_message = try escapeJsonStringAlloc(allocator, api_error.message); defer allocator.free(escaped_message); @@ -114,7 +136,7 @@ pub fn sendApiError( }); defer allocator.free(response_json); - try sendJson(connection, api_error.status_code, response_json); + try sendJson(connection, api_error.status_code, response_json, request_id); } fn makeCompletionIdAlloc(allocator: std.mem.Allocator) ![]u8 { @@ -138,6 +160,7 @@ pub fn sendChatCompletion( connection: std.net.Server.Connection, allocator: std.mem.Allocator, result: types.Response, + request_id: ?[]const u8, ) !void { var generated_id: ?[]u8 = null; defer if (generated_id) |id| allocator.free(id); @@ -159,16 +182,20 @@ pub fn sendChatCompletion( const escaped_content = try escapeJsonStringAlloc(allocator, result.output); defer allocator.free(escaped_content); + const tool_calls_json = try renderToolCallsSuffixJsonAlloc(allocator, result.tool_calls); + defer allocator.free(tool_calls_json); + const escaped_finish_reason = try escapeJsonStringAlloc(allocator, result.finish_reason); defer allocator.free(escaped_finish_reason); const response_json = try std.fmt.allocPrint(allocator, - \\{{"id":"{s}","object":"chat.completion","created":{d},"model":"{s}","choices":[{{"index":0,"message":{{"role":"assistant","content":"{s}"}},"finish_reason":"{s}"}}],"usage":{{"prompt_tokens":{d},"completion_tokens":{d},"total_tokens":{d}}}}} + \\{{"id":"{s}","object":"chat.completion","created":{d},"model":"{s}","choices":[{{"index":0,"message":{{"role":"assistant","content":"{s}"{s}}},"finish_reason":"{s}"}}],"usage":{{"prompt_tokens":{d},"completion_tokens":{d},"total_tokens":{d}}}}} , .{ escaped_id, std.time.timestamp(), escaped_model, escaped_content, + tool_calls_json, escaped_finish_reason, result.usage.prompt_tokens, result.usage.completion_tokens, @@ -176,17 +203,81 @@ pub fn sendChatCompletion( }); defer allocator.free(response_json); - try sendJson(connection, 200, response_json); + try sendJson(connection, 200, response_json, request_id); } -pub fn sendEventStreamHeaders(connection: std.net.Server.Connection) !void { - const response = +fn renderToolCallsSuffixJsonAlloc( + allocator: std.mem.Allocator, + tool_calls: ?[]const types.ToolCall, +) ![]u8 { + const calls = tool_calls orelse return try allocator.dupe(u8, ""); + if (calls.len == 0) return try allocator.dupe(u8, ""); + + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + try out.appendSlice(allocator, ",\"tool_calls\":["); + for (calls, 0..) |call, index| { + if (index > 0) try out.append(allocator, ','); + + const escaped_id = try escapeJsonStringAlloc(allocator, call.id); + defer allocator.free(escaped_id); + const escaped_name = try escapeJsonStringAlloc(allocator, call.name); + defer allocator.free(escaped_name); + const escaped_arguments = try escapeJsonStringAlloc(allocator, call.arguments_json); + defer allocator.free(escaped_arguments); + + try out.writer(allocator).print( + "{{\"id\":\"{s}\",\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"arguments\":\"{s}\"}}}}", + .{ escaped_id, escaped_name, escaped_arguments }, + ); + } + try out.append(allocator, ']'); + return try out.toOwnedSlice(allocator); +} + +test "renderToolCallsSuffixJsonAlloc emits OpenAI tool calls" { + const allocator = std.testing.allocator; + const calls = [_]types.ToolCall{ + .{ + .id = try allocator.dupe(u8, "call_1"), + .name = try allocator.dupe(u8, "echo"), + .arguments_json = try allocator.dupe(u8, "{\"message\":\"demo-green\"}"), + }, + }; + defer calls[0].deinit(allocator); + + const rendered = try renderToolCallsSuffixJsonAlloc(allocator, calls[0..]); + defer allocator.free(rendered); + + try std.testing.expect(std.mem.indexOf(u8, rendered, "\"tool_calls\"") != null); + try std.testing.expect(std.mem.indexOf(u8, rendered, "\"name\":\"echo\"") != null); + try std.testing.expect(std.mem.indexOf(u8, rendered, "\\\"message\\\":\\\"demo-green\\\"") != null); +} + +pub fn sendEventStreamHeaders(connection: std.net.Server.Connection, request_id: ?[]const u8) !void { + const request_id_header = if (request_id) |id| + try std.fmt.allocPrint( + std.heap.page_allocator, + "X-Request-Id: {s}\r\n", + .{id}, + ) + else + try std.heap.page_allocator.dupe(u8, ""); + defer std.heap.page_allocator.free(request_id_header); + + const response = try std.fmt.allocPrint( + std.heap.page_allocator, "HTTP/1.1 200 OK\r\n" ++ - "Content-Type: text/event-stream\r\n" ++ - "Cache-Control: no-cache\r\n" ++ - "Connection: close\r\n" ++ - "X-Accel-Buffering: no\r\n" ++ - "\r\n"; + "Content-Type: text/event-stream\r\n" ++ + "Cache-Control: no-cache\r\n" ++ + "{s}" ++ + "Connection: close\r\n" ++ + "X-Accel-Buffering: no\r\n" ++ + "\r\n", + .{request_id_header}, + ); + defer std.heap.page_allocator.free(response); try connection.stream.writeAll(response); } diff --git a/src/core/server.zig b/src/core/server.zig index 6c60671..0e899b1 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -9,13 +9,36 @@ const backend = @import("../backend/api.zig"); const auth = @import("../backend/auth.zig"); const tooling = @import("../backend/tools.zig"); const session = @import("../backend/session.zig"); +const workspace = @import("../backend/workspace.zig"); +const react = @import("../react.zig"); + +const context_compaction_keep_messages = 8; +const max_recent_requests = 32; +const length_continue_prompt = + "The previous assistant message was cut off by the generation limit. " ++ + "Continue the same answer from the exact next token. Output only the continuation text; " ++ + "do not mention truncation, continuation, the previous message, or these instructions."; + +const RecentRequest = struct { + route: [48]u8 = undefined, + route_len: u8 = 0, + status_code: u16 = 0, + duration_ms: u64 = 0, + request_id: [64]u8 = undefined, + request_id_len: u8 = 0, +}; const ServerState = struct { + mutex: std.Thread.Mutex = .{}, total_requests: u64 = 0, successful_requests: u64 = 0, failed_requests: u64 = 0, active_connections: u64 = 0, + provider_latency_buckets: [4]u64 = .{ 0, 0, 0, 0 }, connected_clients: std.ArrayList([]u8), + recent_requests: [max_recent_requests]RecentRequest = undefined, + recent_request_idx: usize = 0, + recent_request_count: usize = 0, fn init() ServerState { return .{ .connected_clients = .{} }; @@ -29,12 +52,192 @@ const ServerState = struct { } fn noteClient(self: *ServerState, allocator: std.mem.Allocator, label: []const u8) !void { + self.mutex.lock(); + defer self.mutex.unlock(); for (self.connected_clients.items) |existing| { if (std.mem.eql(u8, existing, label)) return; } try self.connected_clients.append(allocator, try allocator.dupe(u8, label)); } + + fn noteRequestStarted(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.total_requests += 1; + } + + fn noteRequestSucceeded(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.successful_requests += 1; + } + + fn noteRequestFailed(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.failed_requests += 1; + } + + fn noteConnectionOpened(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.active_connections += 1; + } + + fn noteConnectionClosed(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.active_connections -= 1; + } + + fn noteProviderLatency(self: *ServerState, elapsed_ms: u64) void { + self.mutex.lock(); + defer self.mutex.unlock(); + const idx: usize = if (elapsed_ms < 50) + 0 + else if (elapsed_ms < 200) + 1 + else if (elapsed_ms < 1000) + 2 + else + 3; + self.provider_latency_buckets[idx] += 1; + } + + fn noteRecentRequest( + self: *ServerState, + route: []const u8, + status_code: u16, + duration_ms: u64, + request_id: []const u8, + ) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + var entry = &self.recent_requests[self.recent_request_idx]; + const route_len = @min(route.len, entry.route.len); + @memcpy(entry.route[0..route_len], route[0..route_len]); + entry.route_len = @intCast(route_len); + entry.status_code = status_code; + entry.duration_ms = duration_ms; + + const request_id_len = @min(request_id.len, entry.request_id.len); + @memcpy(entry.request_id[0..request_id_len], request_id[0..request_id_len]); + entry.request_id_len = @intCast(request_id_len); + + self.recent_request_idx = (self.recent_request_idx + 1) % max_recent_requests; + if (self.recent_request_count < max_recent_requests) { + self.recent_request_count += 1; + } + } + + fn recentRequestsJsonAlloc(self: *ServerState, allocator: std.mem.Allocator) ![]u8 { + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + self.mutex.lock(); + defer self.mutex.unlock(); + + try out.append(allocator, '['); + var written: usize = 0; + var idx = if (self.recent_request_count < max_recent_requests) + 0 + else + self.recent_request_idx; + + while (written < self.recent_request_count) : ({ + written += 1; + idx = (idx + 1) % max_recent_requests; + }) { + const entry = self.recent_requests[idx]; + if (written > 0) try out.append(allocator, ','); + try out.writer(allocator).print( + "{{\"route\":\"{s}\",\"status\":{d},\"duration_ms\":{d},\"request_id\":\"{s}\"}}", + .{ + entry.route[0..entry.route_len], + entry.status_code, + entry.duration_ms, + entry.request_id[0..entry.request_id_len], + }, + ); + } + try out.append(allocator, ']'); + return try out.toOwnedSlice(allocator); + } + + fn snapshot(self: *ServerState) Snapshot { + self.mutex.lock(); + defer self.mutex.unlock(); + return .{ + .total_requests = self.total_requests, + .successful_requests = self.successful_requests, + .failed_requests = self.failed_requests, + .active_connections = self.active_connections, + .provider_latency_buckets = self.provider_latency_buckets, + }; + } + + fn knownClientsJsonAlloc(self: *ServerState, allocator: std.mem.Allocator) ![]u8 { + var clients_json = std.ArrayList(u8){}; + errdefer clients_json.deinit(allocator); + + self.mutex.lock(); + defer self.mutex.unlock(); + + try clients_json.append(allocator, '['); + for (self.connected_clients.items, 0..) |client, idx| { + if (idx > 0) try clients_json.append(allocator, ','); + const escaped = try response.escapeJsonStringAlloc(allocator, client); + defer allocator.free(escaped); + try clients_json.writer(allocator).print("\"{s}\"", .{escaped}); + } + try clients_json.append(allocator, ']'); + return try clients_json.toOwnedSlice(allocator); + } +}; + +const Snapshot = struct { + total_requests: u64, + successful_requests: u64, + failed_requests: u64, + active_connections: u64, + provider_latency_buckets: [4]u64, +}; + +const ConnectionGate = struct { + mutex: std.Thread.Mutex = .{}, + active: usize = 0, + max_active: usize, + + fn tryAcquire(self: *ConnectionGate) bool { + self.mutex.lock(); + defer self.mutex.unlock(); + if (self.active >= self.max_active) return false; + self.active += 1; + return true; + } + + fn release(self: *ConnectionGate) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.active -= 1; + } +}; + +const SessionStoreGuard = struct { + mutex: std.Thread.Mutex = .{}, +}; + +const WorkerContext = struct { + allocator: std.mem.Allocator, + app_config: *const config.Config, + server_state: *ServerState, + tool_registry: *tooling.ToolRegistry, + session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + workspace_store: ?*workspace.WorkspaceStore, + gate: *ConnectionGate, }; /// Starts the TCP listener and serves requests indefinitely. @@ -62,6 +265,7 @@ pub fn run( .{ app_config.listen_host, app_config.listen_port }, ); logInfo("Default provider: {s}", .{app_config.default_provider}); + logInfo("Default model: {s}", .{app_config.defaultModel()}); logInfo( "Available providers: ollama (aliases: qwen, ollama_qwen), openai, openrouter, claude (alias: anthropic), bedrock, llama_cpp (alias: llama.cpp)", .{}, @@ -102,6 +306,9 @@ pub fn run( try tool_registry.register(allocator, "utc"); try tool_registry.register(allocator, "cmd"); try tool_registry.register(allocator, "bash"); + try tool_registry.register(allocator, "file_read"); + try tool_registry.register(allocator, "file_write"); + try tool_registry.register(allocator, "file_search"); var file_session_store: ?session.FileSessionStore = null; var session_store: ?session.SessionStore = null; @@ -125,25 +332,73 @@ pub fn run( } defer if (session_store) |*store| store.deinit(allocator); + var workspace_store_impl: ?workspace.WorkspaceStore = null; + if (app_config.workspace_mode_enabled) { + workspace_store_impl = workspace.WorkspaceStore.init(allocator); + logInfo( + "Workspace mode enabled root={s} (in-memory; cleared on exit)", + .{app_config.workspace_root}, + ); + } + defer if (workspace_store_impl) |*store| store.deinit(); + const active_session_store: ?*session.SessionStore = if (session_store) |*store| store else null; + const active_workspace_store: ?*workspace.WorkspaceStore = if (workspace_store_impl) |*store| store else null; + + var gate = ConnectionGate{ .max_active = app_config.max_concurrent_connections }; + var session_store_guard = SessionStoreGuard{}; + const worker_context = WorkerContext{ + .allocator = allocator, + .app_config = app_config, + .server_state = &server_state, + .tool_registry = &tool_registry, + .session_store = active_session_store, + .session_store_guard = &session_store_guard, + .workspace_store = active_workspace_store, + .gate = &gate, + }; while (true) { - var connection = try server.accept(); - defer connection.stream.close(); - - server_state.active_connections += 1; - defer server_state.active_connections -= 1; - - // Client tracking is intentionally disabled in the hot path on Windows - // until socket stability issues are fully resolved. + const connection = try server.accept(); + if (!gate.tryAcquire()) { + sendJsonSafe( + connection, + 503, + "{\"error\":{\"message\":\"Server is at connection capacity\",\"type\":\"server_error\",\"param\":null,\"code\":\"connection_capacity\"}}", + app_config, + null, + ); + connection.stream.close(); + continue; + } - handleConnection(allocator, app_config, &connection, &server_state, &tool_registry, active_session_store) catch |err| { - logError("Request handling error: {s}", .{@errorName(err)}); - server_state.failed_requests += 1; - }; + var worker = try std.Thread.spawn(.{}, connectionWorkerMain, .{ worker_context, connection }); + worker.detach(); } } +fn connectionWorkerMain(context: WorkerContext, connection: std.net.Server.Connection) void { + defer context.gate.release(); + var worker_connection = connection; + defer worker_connection.stream.close(); + context.server_state.noteConnectionOpened(); + defer context.server_state.noteConnectionClosed(); + + handleConnection( + context.allocator, + context.app_config, + &worker_connection, + context.server_state, + context.tool_registry, + context.session_store, + context.session_store_guard, + context.workspace_store, + ) catch |err| { + logError("Request handling error: {s}", .{@errorName(err)}); + context.server_state.noteRequestFailed(); + }; +} + fn handleConnection( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -151,54 +406,74 @@ fn handleConnection( server_state: *ServerState, tool_registry: *tooling.ToolRegistry, session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + workspace_store: ?*workspace.WorkspaceStore, ) !void { + const request_started_ms = std.time.milliTimestamp(); + var request_route: []const u8 = "unknown"; + var request_status: u16 = 500; + var request_id_owned: ?[]u8 = null; + defer if (request_id_owned) |id| allocator.free(id); + defer server_state.noteRecentRequest( + request_route, + request_status, + @intCast(@max(std.time.milliTimestamp() - request_started_ms, 0)), + request_id_owned orelse "unknown", + ); + // Translate transport-level parsing failures into OpenAI-style API errors. - const request_raw = request.readHttpRequest(allocator, connection, app_config.request_timeout_ms) catch |err| switch (err) { + const request_raw = request.readHttpRequest( + allocator, + connection, + app_config.request_timeout_ms, + app_config.max_request_bytes, + app_config.max_header_bytes, + ) catch |err| switch (err) { error.ClientDisconnected => { debugLog(app_config, "client disconnected before full request", .{}); return; }, error.RequestTimedOut => { - sendApiErrorSafe(connection.*, allocator, backend.errors.requestTimeoutError(), app_config); + sendApiErrorSafe(connection.*, allocator, backend.errors.requestTimeoutError(), app_config, null); return; }, error.RequestTooLarge => { - sendApiErrorSafe(connection.*, allocator, backend.errors.payloadTooLargeError(), app_config); + sendApiErrorSafe(connection.*, allocator, backend.errors.payloadTooLargeError(), app_config, null); return; }, error.HeadersTooLarge => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "HTTP headers are too large", "headers_too_large", - ), app_config); + ), app_config, null); return; }, error.InvalidHttpRequest => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Malformed HTTP request", "invalid_http_request", - ), app_config); + ), app_config, null); return; }, error.MissingContentLength => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Missing Content-Length header", "missing_content_length", - ), app_config); + ), app_config, null); return; }, error.InvalidContentLength => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Invalid Content-Length header", "invalid_content_length", - ), app_config); + ), app_config, null); return; }, error.IncompleteRequestBody => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Incomplete request body", "incomplete_body", - ), app_config); + ), app_config, null); return; }, else => { @@ -208,8 +483,16 @@ fn handleConnection( }; defer allocator.free(request_raw); + request_id_owned = try request.resolveRequestIdAlloc(allocator, request_raw); + const request_id = request_id_owned.?; + + const header_end = std.mem.indexOf(u8, request_raw, "\r\n\r\n") orelse request_raw.len; + if (request.getHeaderValue(request_raw[0..header_end], "User-Agent")) |user_agent| { + server_state.noteClient(allocator, user_agent) catch {}; + } + if (request_raw.len == 0) return; - server_state.total_requests += 1; + server_state.noteRequestStarted(); debugLog( app_config, @@ -221,20 +504,29 @@ fn handleConnection( sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Malformed HTTP request", "invalid_http_request", - ), app_config); - server_state.failed_requests += 1; + ), app_config, request_id); + server_state.noteRequestFailed(); return; }; if (route == null) { - sendApiErrorSafe(connection.*, allocator, backend.errors.notFoundError(), app_config); - server_state.failed_requests += 1; + sendApiErrorSafe(connection.*, allocator, backend.errors.notFoundError(), app_config, request_id); + server_state.noteRequestFailed(); return; } + request_route = switch (route.?) { + .health => "GET /health", + .metrics => "GET /metrics", + .diagnostics_clients => "GET /diagnostics/clients", + .diagnostics_requests => "GET /diagnostics/requests", + .diagnostics_providers => "GET /diagnostics/providers", + .chat_completions => "POST /v1/chat/completions", + }; + if (requiresAuth(route.?) and auth.authorizeRequest(app_config.auth_api_key, request_raw) == .denied) { - sendJsonSafe(connection.*, 401, "{\"error\":{\"message\":\"Unauthorized\",\"type\":\"auth_error\",\"param\":null,\"code\":\"unauthorized\"}}", app_config); - server_state.failed_requests += 1; + sendJsonSafe(connection.*, 401, "{\"error\":{\"message\":\"Unauthorized\",\"type\":\"auth_error\",\"param\":null,\"code\":\"unauthorized\"}}", app_config, request_id); + server_state.noteRequestFailed(); return; } @@ -246,63 +538,66 @@ fn handleConnection( .{app_config.instance_id}, ); defer allocator.free(health_json); - sendJsonSafe(connection.*, 200, health_json, app_config); - server_state.successful_requests += 1; + sendJsonSafe(connection.*, 200, health_json, app_config, request_id); + server_state.noteRequestSucceeded(); return; }, .metrics => { + const snapshot = server_state.snapshot(); const metrics_json = try std.fmt.allocPrint( allocator, - "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"active_connections\":{d}}}", + "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"active_connections\":{d},\"provider_latency_buckets_ms\":{{\"lt_50\":{d},\"lt_200\":{d},\"lt_1000\":{d},\"gte_1000\":{d}}}}}", .{ app_config.instance_id, - server_state.total_requests, - server_state.successful_requests, - server_state.failed_requests, - server_state.active_connections, + snapshot.total_requests, + snapshot.successful_requests, + snapshot.failed_requests, + snapshot.active_connections, + snapshot.provider_latency_buckets[0], + snapshot.provider_latency_buckets[1], + snapshot.provider_latency_buckets[2], + snapshot.provider_latency_buckets[3], }, ); defer allocator.free(metrics_json); - sendJsonSafe(connection.*, 200, metrics_json, app_config); - server_state.successful_requests += 1; + sendJsonSafe(connection.*, 200, metrics_json, app_config, request_id); + server_state.noteRequestSucceeded(); return; }, .diagnostics_clients => { - var clients_json = std.ArrayList(u8){}; - defer clients_json.deinit(allocator); - try clients_json.append(allocator, '['); - for (server_state.connected_clients.items, 0..) |client, idx| { - if (idx > 0) try clients_json.append(allocator, ','); - const escaped = try response.escapeJsonStringAlloc(allocator, client); - defer allocator.free(escaped); - try clients_json.writer(allocator).print("\"{s}\"", .{escaped}); - } - try clients_json.append(allocator, ']'); + const snapshot = server_state.snapshot(); + const clients_json = try server_state.knownClientsJsonAlloc(allocator); + defer allocator.free(clients_json); const payload = try std.fmt.allocPrint( allocator, "{{\"instance_id\":\"{s}\",\"active_connections\":{d},\"known_clients\":{s}}}", - .{ app_config.instance_id, server_state.active_connections, clients_json.items }, + .{ app_config.instance_id, snapshot.active_connections, clients_json }, ); defer allocator.free(payload); - sendJsonSafe(connection.*, 200, payload, app_config); - server_state.successful_requests += 1; + sendJsonSafe(connection.*, 200, payload, app_config, request_id); + server_state.noteRequestSucceeded(); return; }, .diagnostics_requests => { + const snapshot = server_state.snapshot(); + const recent_json = try server_state.recentRequestsJsonAlloc(allocator); + defer allocator.free(recent_json); const payload = try std.fmt.allocPrint( allocator, - "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d}}}", + "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"recent_requests\":{s}}}", .{ app_config.instance_id, - server_state.total_requests, - server_state.successful_requests, - server_state.failed_requests, + snapshot.total_requests, + snapshot.successful_requests, + snapshot.failed_requests, + recent_json, }, ); defer allocator.free(payload); - sendJsonSafe(connection.*, 200, payload, app_config); - server_state.successful_requests += 1; + sendJsonSafe(connection.*, 200, payload, app_config, null); + server_state.noteRequestSucceeded(); + request_status = 200; return; }, .diagnostics_providers => { @@ -316,8 +611,8 @@ fn handleConnection( ); defer allocator.free(payload); - sendJsonSafe(connection.*, 200, payload, app_config); - server_state.successful_requests += 1; + sendJsonSafe(connection.*, 200, payload, app_config, request_id); + server_state.noteRequestSucceeded(); return; }, .chat_completions => {}, @@ -327,37 +622,53 @@ fn handleConnection( sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Missing request body", "missing_body", - ), app_config); - server_state.failed_requests += 1; + ), app_config, request_id); + server_state.noteRequestFailed(); return; }; debugLog(app_config, "request body_len={d}", .{body.len}); const parse_result = try backend.parseChatRequest(allocator, body); - const parsed_req = switch (parse_result) { + var parsed_req = switch (parse_result) { .ok => |parsed_request| parsed_request, .err => |api_error| { - sendApiErrorSafe(connection.*, allocator, api_error, app_config); - server_state.failed_requests += 1; + sendApiErrorSafe(connection.*, allocator, api_error, app_config, request_id); + server_state.noteRequestFailed(); return; }, }; defer parsed_req.deinit(allocator); + react.ensureReactCodingToolsAlloc(allocator, &parsed_req) catch |err| { + logError("ReAct tool augmentation failed: {s}", .{@errorName(err)}); + sendApiErrorSafe( + connection.*, + allocator, + backend.errors.httpError("Failed to prepare ReAct coding tools", "react_tool_setup_failed"), + app_config, + request_id, + ); + server_state.noteRequestFailed(); + return; + }; + if (!tooling.validateRequestedTools(tool_registry, parsed_req.tools)) { sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( "One or more requested tools are not registered", "tools", "unknown_tool", - ), app_config); - server_state.failed_requests += 1; + ), app_config, request_id); + server_state.noteRequestFailed(); return; } var loaded_session: ?session.SessionState = null; defer if (loaded_session) |state| state.deinit(allocator); + var active_workspace: ?*workspace.WorkspaceState = null; + const use_workspace = app_config.workspace_mode_enabled and parsed_req.workspace_id != null; + var request_messages = try session.cloneMessagesAlloc(allocator, parsed_req.messages); defer { for (request_messages) |message| { @@ -369,12 +680,44 @@ fn handleConnection( var request_prompt = try allocator.dupe(u8, parsed_req.prompt); defer allocator.free(request_prompt); - if (session_store) |store| { + if (use_workspace) { + if (workspace_store) |store| { + active_workspace = try store.getOrCreate(parsed_req.workspace_id.?); + if (active_workspace.?.messages.len > 0) { + const merged_messages = try workspace.mergeIncomingMessagesAlloc( + allocator, + active_workspace.?.messages, + parsed_req.messages, + ); + + for (request_messages) |message| { + message.deinit(allocator); + } + allocator.free(request_messages); + request_messages = merged_messages; + + allocator.free(request_prompt); + request_prompt = try extractLastUserPromptAlloc(allocator, request_messages); + } + + const with_hint = try workspace.prependMemoryHintAlloc(allocator, active_workspace.?, request_messages); + for (request_messages) |message| { + message.deinit(allocator); + } + allocator.free(request_messages); + request_messages = with_hint; + + allocator.free(request_prompt); + request_prompt = try extractLastUserPromptAlloc(allocator, request_messages); + } + } else if (session_store) |store| { if (parsed_req.session_id) |session_id| { + session_store_guard.mutex.lock(); loaded_session = store.load(allocator, session_id, parsed_req.tenant_id) catch |err| blk: { logError("Session load failed for '{s}': {s}", .{ session_id, @errorName(err) }); break :blk null; }; + session_store_guard.mutex.unlock(); if (loaded_session) |state| { if (state.messages.len > 0) { @@ -397,13 +740,31 @@ fn handleConnection( } } - if (parsed_req.max_context_tokens) |max_tokens| { + const effective_max_context_tokens = parsed_req.max_context_tokens orelse app_config.default_max_context_tokens; + + if (effective_max_context_tokens) |max_tokens| { const estimated = session.estimateTokenCount(request_messages); if (session.shouldCompressContext(estimated, max_tokens)) { + const compacted_messages = try session.compactContextToBudgetAlloc( + allocator, + request_messages, + max_tokens, + context_compaction_keep_messages, + ); + + for (request_messages) |message| { + message.deinit(allocator); + } + allocator.free(request_messages); + request_messages = compacted_messages; + + allocator.free(request_prompt); + request_prompt = try extractLastUserPromptAlloc(allocator, request_messages); + debugLog( app_config, - "context compression suggested estimated_tokens={d} max_context_tokens={d}", - .{ estimated, max_tokens }, + "context compacted estimated_tokens={d} compacted_tokens={d} max_context_tokens={d}", + .{ estimated, session.estimateTokenCount(request_messages), max_tokens }, ); } } @@ -423,16 +784,6 @@ fn handleConnection( const requested_provider = parsed_req.provider orelse app_config.default_provider; const normalized_provider = types.normalizeProviderName(requested_provider) orelse requested_provider; - if (!std.mem.eql(u8, normalized_provider, "ollama_qwen")) { - sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( - "stream=true is currently supported only for ollama", - "stream", - "unsupported_stream_provider", - ), app_config); - server_state.failed_requests += 1; - return; - } - var stream_request = try cloneRequestWithMessagesAlloc( allocator, parsed_req, @@ -441,10 +792,94 @@ fn handleConnection( ); defer stream_request.deinit(allocator); - const stream_auto_tool_summary = try tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config); + if (try tooling.tryExecuteDebugTool(allocator, stream_request, app_config, request_id, active_workspace)) |tool_result| { + defer tool_result.deinit(allocator); + + try response.sendEventStreamHeaders(connection.*, request_id); + const completion_id = try std.fmt.allocPrint( + allocator, + "chatcmpl-{d}", + .{std.time.microTimestamp()}, + ); + defer allocator.free(completion_id); + + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + tool_result.model, + tool_result.output, + null, + ); + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + tool_result.model, + null, + tool_result.finish_reason, + ); + try response.sendSseDone(connection.*); + + server_state.noteRequestSucceeded(); + return; + } + + if (!std.mem.eql(u8, normalized_provider, "ollama_qwen")) { + sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( + "stream=true is currently supported only for ollama", + "stream", + "unsupported_stream_provider", + ), app_config, request_id); + server_state.noteRequestFailed(); + return; + } + + const stream_auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config, request_id, active_workspace) catch |err| switch (err) { + error.ToolCallLimitExceeded => { + sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( + "Tool call limit exceeded", + "tools", + "tool_call_limit_exceeded", + ), app_config, request_id); + server_state.noteRequestFailed(); + return; + }, + else => return err, + }; defer if (stream_auto_tool_summary) |summary| allocator.free(summary); if (stream_auto_tool_summary) |summary| { + if (tooling.shouldShortCircuitAutoTools(stream_request)) { + try response.sendEventStreamHeaders(connection.*, request_id); + const completion_id = try std.fmt.allocPrint( + allocator, + "chatcmpl-{d}", + .{std.time.microTimestamp()}, + ); + defer allocator.free(completion_id); + + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + "debug-tools/auto", + summary, + null, + ); + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + "debug-tools/auto", + null, + "tool", + ); + try response.sendSseDone(connection.*); + server_state.noteRequestSucceeded(); + return; + } + const augmented_messages = try appendSystemMessageAlloc(allocator, stream_request.messages, summary); for (stream_request.messages) |message| { message.deinit(allocator); @@ -461,26 +896,27 @@ fn handleConnection( app_config, stream_request, app_config.loop_stream_progress_enabled, + request_id, ) catch |err| { logError("Provider stream loop error: {s}", .{@errorName(err)}); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; } const ollama_qwen = @import("../providers/ollama_qwen.zig"); const stream_result = ollama_qwen.streamQwenToSse(connection.*, allocator, app_config, stream_request) catch |err| { logError("Provider stream error: {s}", .{@errorName(err)}); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; switch (stream_result) { .streamed => { - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .failed => |provider_error_response| { @@ -488,22 +924,60 @@ fn handleConnection( sendApiErrorSafe(connection.*, allocator, backend.errors.providerError( provider_error_response.output, "provider_error", - ), app_config); - server_state.failed_requests += 1; + ), app_config, request_id); + server_state.noteRequestFailed(); return; }, } } - if (try tooling.tryExecuteDebugTool(allocator, parsed_req, app_config)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, parsed_req, app_config, request_id, active_workspace)) |tool_result| { defer tool_result.deinit(allocator); debugLog( app_config, "debug tool executed tool_choice={s}", .{parsed_req.tool_choice orelse "(none)"}, ); - sendChatCompletionSafe(connection.*, allocator, tool_result, app_config); - server_state.successful_requests += 1; + persistConversationState( + allocator, + app_config, + use_workspace, + workspace_store, + active_workspace, + session_store, + session_store_guard, + parsed_req, + loaded_session, + request_messages, + null, + tool_result.output, + ); + sendChatCompletionSafe(connection.*, allocator, tool_result, app_config, request_id); + server_state.noteRequestSucceeded(); + request_status = 200; + return; + } + + if (try tooling.maybeMakeBuiltinToolCallResponseAlloc(allocator, parsed_req, request_id)) |tool_call_result| { + defer tool_call_result.deinit(allocator); + debugLog(app_config, "debug tool call synthesized", .{}); + persistConversationState( + allocator, + app_config, + use_workspace, + workspace_store, + active_workspace, + session_store, + session_store_guard, + parsed_req, + loaded_session, + request_messages, + null, + tool_call_result.output, + ); + sendChatCompletionSafe(connection.*, allocator, tool_call_result, app_config, request_id); + server_state.noteRequestSucceeded(); + request_status = 200; return; } @@ -518,9 +992,45 @@ fn handleConnection( ); defer provider_request.deinit(allocator); - const auto_tool_summary = try tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config); + const auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config, request_id, active_workspace) catch |err| switch (err) { + error.ToolCallLimitExceeded => { + sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( + "Tool call limit exceeded", + "tools", + "tool_call_limit_exceeded", + ), app_config, request_id); + server_state.noteRequestFailed(); + return; + }, + else => return err, + }; defer if (auto_tool_summary) |summary| allocator.free(summary); + if (auto_tool_summary) |summary| { + if (tooling.shouldShortCircuitAutoTools(provider_request)) { + var tool_response = try tooling.makeAutoToolResponse(allocator, summary); + defer tool_response.deinit(allocator); + persistConversationState( + allocator, + app_config, + use_workspace, + workspace_store, + active_workspace, + session_store, + session_store_guard, + parsed_req, + loaded_session, + provider_request.messages, + null, + tool_response.output, + ); + sendChatCompletionSafe(connection.*, allocator, tool_response, app_config, request_id); + server_state.noteRequestSucceeded(); + request_status = 200; + return; + } + } + if (auto_tool_summary) |summary| { const augmented_messages = try appendSystemMessageAlloc(allocator, provider_request.messages, summary); @@ -547,15 +1057,27 @@ fn handleConnection( const provider_started_ms = std.time.milliTimestamp(); if (loop_enabled) { - const loop_execution = executeLoopRequestAlloc(allocator, app_config, provider_request) catch |err| { + const loop_execution = executeLoopRequestAlloc(allocator, app_config, provider_request, request_id) catch |err| { + if (err == error.ToolCallLimitExceeded) { + sendApiErrorSafe( + connection.*, + allocator, + backend.errors.validationError("Tool call limit exceeded", "tools", "tool_call_limit_exceeded"), + app_config, + request_id, + ); + server_state.noteRequestFailed(); + return; + } logError("Provider loop request error: {}", .{err}); sendApiErrorSafe( connection.*, allocator, backend.errors.providerTransportError(@errorName(err)), app_config, + request_id, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; @@ -570,14 +1092,16 @@ fn handleConnection( allocator, backend.errors.providerTransportError(@errorName(err)), app_config, + request_id, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; result_ready = true; } const provider_elapsed_ms: u64 = @intCast(@max(std.time.milliTimestamp() - provider_started_ms, 0)); + server_state.noteProviderLatency(provider_elapsed_ms); // Log a warning when the provider exceeded the configured timeout budget, but // do NOT discard the already-completed response. A post-hoc check cannot // cancel an in-flight HTTP call; discarding a valid result would only confuse @@ -601,11 +1125,18 @@ fn handleConnection( allocator, backend.errors.providerFailureFromDetail(normalized_provider, result.output), app_config, + request_id, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } + if (!loop_enabled and isLengthFinishReason(result.finish_reason)) { + continueLengthResponseAlloc(allocator, app_config, provider_request, &result) catch |err| { + logError("Provider continuation failed: {s}", .{@errorName(err)}); + }; + } + if (std.mem.trim(u8, result.output, " \t\r\n").len == 0) { if (auto_tool_summary) |summary| { allocator.free(result.output); @@ -626,8 +1157,9 @@ fn handleConnection( "empty_model_response", ), app_config, + request_id, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } } @@ -639,50 +1171,44 @@ fn handleConnection( ); if (session_store) |store| { - if (parsed_req.session_id) |session_id| { - const loop_messages = loop_messages_for_persistence; - const with_assistant = if (loop_messages) |messages| blk: { - break :blk try session.cloneMessagesAlloc(allocator, messages); - } else session.appendAssistantMessageAlloc( + if (!use_workspace and parsed_req.session_id != null) { + persistConversationState( allocator, + app_config, + false, + workspace_store, + active_workspace, + store, + session_store_guard, + parsed_req, + loaded_session, provider_request.messages, + loop_messages_for_persistence, result.output, - ) catch |err| blk: { - logError("Session append failed for '{s}': {s}", .{ session_id, @errorName(err) }); - break :blk null; - }; - - if (with_assistant) |messages| { - defer { - for (messages) |message| { - message.deinit(allocator); - } - allocator.free(messages); - } - - var state = session.SessionState{ - .session_id = try allocator.dupe(u8, session_id), - .tenant_id = if (parsed_req.tenant_id) |value| try allocator.dupe(u8, value) else null, - .summary = if (loaded_session) |loaded| try allocator.dupe(u8, loaded.summary) else try allocator.dupe(u8, ""), - .messages = try session.trimToRetentionAlloc( - allocator, - messages, - app_config.session_retention_messages, - ), - .message_count = 0, - }; - state.message_count = state.messages.len; - defer state.deinit(allocator); - - store.save(allocator, state) catch |err| { - logError("Session save failed for '{s}': {s}", .{ session_id, @errorName(err) }); - }; - } + ); } } - sendChatCompletionSafe(connection.*, allocator, result, app_config); - server_state.successful_requests += 1; + if (use_workspace) { + persistConversationState( + allocator, + app_config, + true, + workspace_store, + active_workspace, + session_store, + session_store_guard, + parsed_req, + loaded_session, + provider_request.messages, + loop_messages_for_persistence, + result.output, + ); + } + + sendChatCompletionSafe(connection.*, allocator, result, app_config, request_id); + server_state.noteRequestSucceeded(); + request_status = 200; } fn cloneRequestWithMessagesAlloc( @@ -712,6 +1238,7 @@ fn cloneRequestWithMessagesAlloc( copied_tools[idx] = .{ .name = try allocator.dupe(u8, tool.name), .description = try allocator.dupe(u8, tool.description), + .parameters_json = if (tool.parameters_json) |parameters_json| try allocator.dupe(u8, parameters_json) else null, }; initialized_tools += 1; } @@ -727,6 +1254,7 @@ fn cloneRequestWithMessagesAlloc( .repeat_penalty = parsed_req.repeat_penalty, .session_id = if (parsed_req.session_id) |session_id| try allocator.dupe(u8, session_id) else null, .tenant_id = if (parsed_req.tenant_id) |tenant_id| try allocator.dupe(u8, tenant_id) else null, + .workspace_id = if (parsed_req.workspace_id) |workspace_id| try allocator.dupe(u8, workspace_id) else null, .max_context_tokens = parsed_req.max_context_tokens, .tools = copied_tools, .tool_choice = if (parsed_req.tool_choice) |tool_choice| try allocator.dupe(u8, tool_choice) else null, @@ -741,28 +1269,55 @@ const LoopExecution = struct { messages: []types.Message, }; -fn streamLoopRequestToSse( - connection: std.net.Server.Connection, +const agent_loop_guidance = + "You are running in API agent loop mode. Improve the answer each turn. Briefly self-critique then improve. Include the completion marker exactly when fully complete."; + +const agent_continue_prompt = + "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result."; + +const react_loop_default_max_turns: usize = 24; +const react_repeated_stop_threshold: usize = 2; + +fn resolveLoopMaxTurns(loop_mode: []const u8, explicit: ?usize) usize { + if (explicit) |value| return value; + if (std.ascii.eqlIgnoreCase(loop_mode, "react")) return react_loop_default_max_turns; + return 8; +} + +fn reactToolCallBudget(app_config: *const config.Config, loop_max_turns: usize) usize { + return @max(app_config.max_tool_calls_per_request, loop_max_turns); +} + +fn maybeCompactLoopWorkingMessages( allocator: std.mem.Allocator, - app_config: *const config.Config, - base_request: types.Request, - emit_progress: bool, + working_messages: *[]types.Message, + max_context_tokens: ?usize, ) !void { - const loop_mode = base_request.loop_mode orelse "basic"; - const loop_until = base_request.loop_until orelse "DONE"; - const loop_max_turns = base_request.loop_max_turns orelse 8; + const max_tokens = max_context_tokens orelse return; + const estimated = session.estimateTokenCount(working_messages.*); + if (!session.shouldCompressContext(estimated, max_tokens)) return; - try response.sendEventStreamHeaders(connection); - - const completion_id = try std.fmt.allocPrint( + const compacted = try session.compactContextToBudgetAlloc( allocator, - "chatcmpl-{d}", - .{std.time.microTimestamp()}, + working_messages.*, + max_tokens, + context_compaction_keep_messages, ); - defer allocator.free(completion_id); + for (working_messages.*) |message| { + message.deinit(allocator); + } + allocator.free(working_messages.*); + working_messages.* = compacted; +} + +fn prepareLoopWorkingMessagesAlloc( + allocator: std.mem.Allocator, + base_request: types.Request, + loop_mode: []const u8, +) ![]types.Message { var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); - defer { + errdefer { for (working_messages) |message| { message.deinit(allocator); } @@ -774,71 +1329,292 @@ fn streamLoopRequestToSse( allocator, working_messages, "system", - "You are running in API agent loop mode. Improve the answer each turn. Briefly self-critique then improve. Include the completion marker exactly when fully complete.", + agent_loop_guidance, ); for (working_messages) |message| { message.deinit(allocator); } allocator.free(working_messages); working_messages = with_guidance; + } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { + const react_prompt = try react.buildSystemPromptAlloc(allocator, base_request.tools); + defer allocator.free(react_prompt); + const with_react = try appendRoleMessageAlloc( + allocator, + working_messages, + "system", + react_prompt, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_react; } - var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); - defer allocator.free(latest_user_prompt); + return working_messages; +} - var previous_output: ?[]u8 = null; - defer if (previous_output) |value| allocator.free(value); - var repeated_count: usize = 0; +fn nextBasicLoopPromptAlloc(allocator: std.mem.Allocator, loop_mode: []const u8) ![]u8 { + return try allocator.dupe( + u8, + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) agent_continue_prompt else "Continue.", + ); +} - var turn: usize = 0; - while (turn < loop_max_turns) : (turn += 1) { - var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, latest_user_prompt, working_messages); - defer turn_request.deinit(allocator); +fn providerRequestForLoopTurn( + turn_request: types.Request, + loop_mode: []const u8, +) types.Request { + if (!std.ascii.eqlIgnoreCase(loop_mode, "react")) return turn_request; + + // ReAct advertises tools in the system prompt and parses actions locally. + // Do not forward tools upstream or providers activate tool templates that + // break on prior assistant messages. + var r = turn_request; + r.tools = &.{}; + r.tool_choice = null; + return r; +} - if (turn_request.loop_mode) |value| { - allocator.free(value); - turn_request.loop_mode = null; - } - if (turn_request.loop_until) |value| { +fn reactToolPromptAlloc( + allocator: std.mem.Allocator, + action: react.ReactAction, + tools: []const types.Tool, +) !?[]u8 { + return switch (action) { + .search => |query| blk: { + if (!tooling.hasRequestedTool(tools, "file_search")) break :blk null; + break :blk try std.fmt.allocPrint(allocator, "search {s} in .", .{query}); + }, + .lookup => |path| blk: { + if (!tooling.hasRequestedTool(tools, "file_read")) break :blk null; + break :blk try std.fmt.allocPrint(allocator, "read {s}", .{path}); + }, + .tool => |tool_action| blk: { + if (std.ascii.eqlIgnoreCase(tool_action.name, "file_read")) { + const trimmed = std.mem.trim(u8, tool_action.argument, " \t\r\n\"'"); + if (std.ascii.startsWithIgnoreCase(trimmed, "read ") or + std.ascii.startsWithIgnoreCase(trimmed, "file_read ")) + { + break :blk try allocator.dupe(u8, trimmed); + } + break :blk try std.fmt.allocPrint(allocator, "read {s}", .{trimmed}); + } + if (std.ascii.eqlIgnoreCase(tool_action.name, "file_write")) { + const trimmed = std.mem.trim(u8, tool_action.argument, " \t\r\n"); + if (std.ascii.startsWithIgnoreCase(trimmed, "write ")) { + break :blk try allocator.dupe(u8, trimmed); + } + break :blk try std.fmt.allocPrint(allocator, "write {s}", .{trimmed}); + } + if (std.ascii.eqlIgnoreCase(tool_action.name, "file_search")) { + const trimmed = std.mem.trim(u8, tool_action.argument, " \t\r\n\"'"); + if (std.ascii.startsWithIgnoreCase(trimmed, "search ") or + std.ascii.startsWithIgnoreCase(trimmed, "file_search ")) + { + break :blk try allocator.dupe(u8, trimmed); + } + break :blk try std.fmt.allocPrint(allocator, "search {s} in .", .{trimmed}); + } + break :blk try allocator.dupe(u8, tool_action.argument); + }, + else => null, + }; +} + +fn executeReactActionForRequest( + allocator: std.mem.Allocator, + app_config: *const config.Config, + turn_request: types.Request, + tool_messages: []const types.Message, + action: react.ReactAction, + request_id: []const u8, +) ![]u8 { + if (try reactToolPromptAlloc(allocator, action, turn_request.tools)) |mapped_prompt| { + defer allocator.free(mapped_prompt); + + const tool_name: []const u8 = switch (action) { + .search => "file_search", + .lookup => "file_read", + .tool => |tool_action| tool_action.name, + else => unreachable, + }; + + var tool_request = try cloneRequestWithMessagesAlloc( + allocator, + turn_request, + mapped_prompt, + tool_messages, + ); + defer tool_request.deinit(allocator); + + if (tool_request.tool_choice) |value| { allocator.free(value); - turn_request.loop_until = null; } - turn_request.loop_max_turns = null; + tool_request.tool_choice = try allocator.dupe(u8, tool_name); - var turn_result = try backend.callProvider(allocator, app_config, turn_request); - defer turn_result.deinit(allocator); + if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config, request_id, null)) |tool_result| { + defer tool_result.deinit(allocator); + return try allocator.dupe(u8, tool_result.output); + } - const with_assistant = try appendRoleMessageAlloc( + return try std.fmt.allocPrint( allocator, - working_messages, - "assistant", - turn_result.output, + "Tool action {s} was not requested or its input could not be parsed.", + .{tool_name}, ); + } + + return switch (action) { + .tool => |tool_action| blk: { + var tool_request = try cloneRequestWithMessagesAlloc( + allocator, + turn_request, + tool_action.argument, + tool_messages, + ); + defer tool_request.deinit(allocator); + + if (tool_request.tool_choice) |value| { + allocator.free(value); + } + tool_request.tool_choice = try allocator.dupe(u8, tool_action.name); + + if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config, request_id, null)) |tool_result| { + defer tool_result.deinit(allocator); + break :blk try allocator.dupe(u8, tool_result.output); + } + + break :blk try std.fmt.allocPrint( + allocator, + "Tool action {s} was not requested or its input could not be parsed.", + .{tool_action.name}, + ); + }, + else => try react.executeReactAction(allocator, action, app_config), + }; +} + +fn streamLoopRequestToSse( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + app_config: *const config.Config, + base_request: types.Request, + emit_progress: bool, + request_id: []const u8, +) !void { + const loop_mode = base_request.loop_mode orelse "basic"; + const loop_until = base_request.loop_until orelse "DONE"; + const loop_max_turns = resolveLoopMaxTurns(loop_mode, base_request.loop_max_turns); + const effective_max_context_tokens = base_request.max_context_tokens orelse app_config.default_max_context_tokens; + const react_tool_budget = reactToolCallBudget(app_config, loop_max_turns); + + try response.sendEventStreamHeaders(connection, request_id); + var headers_sent: bool = true; + var think_block_open: bool = false; + + const completion_id = try std.fmt.allocPrint( + allocator, + "chatcmpl-{d}", + .{std.time.microTimestamp()}, + ); + defer allocator.free(completion_id); + + var working_messages = try prepareLoopWorkingMessagesAlloc(allocator, base_request, loop_mode); + defer { for (working_messages) |message| { message.deinit(allocator); } allocator.free(working_messages); - working_messages = with_assistant; + } + + var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); + defer allocator.free(latest_user_prompt); + var previous_output: ?[]u8 = null; + defer if (previous_output) |value| allocator.free(value); + var repeated_count: usize = 0; + var react_tool_calls: usize = 0; + + var turn: usize = 0; + while (turn < loop_max_turns) : (turn += 1) { + try maybeCompactLoopWorkingMessages(allocator, &working_messages, effective_max_context_tokens); + + var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, latest_user_prompt, working_messages); + defer turn_request.deinit(allocator); + + if (turn_request.loop_mode) |value| { + allocator.free(value); + turn_request.loop_mode = null; + } + if (turn_request.loop_until) |value| { + allocator.free(value); + turn_request.loop_until = null; + } + turn_request.loop_max_turns = null; + + const provider_request = providerRequestForLoopTurn(turn_request, loop_mode); + + // Emit the per-turn header BEFORE provider tokens stream in, so the + // client can render "[loop turn N/M]\n" then watch tokens arrive live. if (emit_progress) { - const progress_text = try std.fmt.allocPrint( + const progress_prefix = try std.fmt.allocPrint( allocator, - "[loop turn {d}/{d}]\n{s}\n", - .{ turn + 1, loop_max_turns, turn_result.output }, + "[loop turn {d}/{d}]\n", + .{ turn + 1, loop_max_turns }, ); - defer allocator.free(progress_text); + defer allocator.free(progress_prefix); try response.sendChatCompletionChunkSse( connection, allocator, completion_id, - turn_result.model, - progress_text, + provider_request.model orelse app_config.modelForProvider( + provider_request.provider orelse app_config.default_provider, + ), + progress_prefix, null, ); } - const normalized_output = std.mem.trim(u8, turn_result.output, " \t\r\n"); + var turn_outcome = try backend.streamProviderTurn( + connection, + allocator, + app_config, + provider_request, + completion_id, + &headers_sent, + &think_block_open, + ); + defer turn_outcome.deinit(allocator); + + // Emit a trailing newline after the turn body so progress text and + // the next observation/prefix don't visually collide. + if (emit_progress and turn_outcome.captured_content.len > 0) { + try response.sendChatCompletionChunkSse( + connection, + allocator, + completion_id, + turn_outcome.model, + "\n", + null, + ); + } + + const with_assistant = try appendRoleMessageAlloc( + allocator, + working_messages, + "assistant", + turn_outcome.captured_content, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_assistant; + + const normalized_output = std.mem.trim(u8, turn_outcome.captured_content, " \t\r\n"); if (previous_output) |prev| { if (std.mem.eql(u8, prev, normalized_output)) { repeated_count += 1; @@ -849,43 +1625,87 @@ fn streamLoopRequestToSse( } previous_output = try allocator.dupe(u8, normalized_output); - const reached_until = std.mem.indexOf(u8, turn_result.output, loop_until) != null; + const is_react_mode = std.ascii.eqlIgnoreCase(loop_mode, "react"); + const reached_until = std.mem.indexOf(u8, turn_outcome.captured_content, loop_until) != null; const reached_max = (turn + 1) >= loop_max_turns; - const repeated_stop = std.ascii.eqlIgnoreCase(loop_mode, "agent") and repeated_count >= 1; - const should_stop = reached_until or reached_max or repeated_stop or !turn_result.success; + const repeated_stop = blk: { + if (is_react_mode) break :blk repeated_count >= react_repeated_stop_threshold; + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) break :blk repeated_count >= 1; + break :blk false; + }; + + // ReAct: check for Finish action before standard stop checks. + var react_finished = false; + var react_observation: ?[]u8 = null; + defer if (react_observation) |obs| allocator.free(obs); + + if (is_react_mode) { + if (react.parseReactAction(turn_outcome.captured_content)) |action| { + switch (action) { + .finish => { + react_finished = true; + }, + else => { + try tooling.noteToolCall(react_tool_budget, &react_tool_calls); + const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action, request_id); + defer allocator.free(obs_raw); + react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); + }, + } + } else { + react_observation = try react.formatObservation( + allocator, + turn + 1, + react.parse_error_hint, + ); + } + } + + const should_stop = react_finished or reached_until or reached_max or repeated_stop or !turn_outcome.success; if (should_stop) { - if (!emit_progress and turn_result.output.len > 0) { + // Close any dangling block from the last turn so the + // client doesn't render an unterminated reasoning section. + if (think_block_open) { try response.sendChatCompletionChunkSse( - connection, - allocator, - completion_id, - turn_result.model, - turn_result.output, - null, + connection, allocator, completion_id, turn_outcome.model, "", null, ); + think_block_open = false; } try response.sendChatCompletionChunkSse( connection, allocator, completion_id, - turn_result.model, + turn_outcome.model, null, - turn_result.finish_reason, + turn_outcome.finish_reason, ); try response.sendSseDone(connection); return; } allocator.free(latest_user_prompt); - latest_user_prompt = try allocator.dupe( - u8, - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) - "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." - else - "Continue.", - ); + + if (is_react_mode) { + // Inject the observation as the next user prompt. + const obs = react_observation orelse try allocator.dupe(u8, "Continue."); + react_observation = null; // Transfer ownership. + latest_user_prompt = obs; + + if (emit_progress) { + try response.sendChatCompletionChunkSse( + connection, + allocator, + completion_id, + turn_outcome.model, + latest_user_prompt, + null, + ); + } + } else { + latest_user_prompt = try nextBasicLoopPromptAlloc(allocator, loop_mode); + } const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", latest_user_prompt); for (working_messages) |message| { @@ -902,12 +1722,15 @@ fn executeLoopRequestAlloc( allocator: std.mem.Allocator, app_config: *const config.Config, base_request: types.Request, + request_id: []const u8, ) !LoopExecution { const loop_mode = base_request.loop_mode orelse "basic"; const loop_until = base_request.loop_until orelse "DONE"; - const loop_max_turns = base_request.loop_max_turns orelse 8; + const loop_max_turns = resolveLoopMaxTurns(loop_mode, base_request.loop_max_turns); + const effective_max_context_tokens = base_request.max_context_tokens orelse app_config.default_max_context_tokens; + const react_tool_budget = reactToolCallBudget(app_config, loop_max_turns); - var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); + var working_messages = try prepareLoopWorkingMessagesAlloc(allocator, base_request, loop_mode); errdefer { for (working_messages) |message| { message.deinit(allocator); @@ -915,29 +1738,18 @@ fn executeLoopRequestAlloc( allocator.free(working_messages); } - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) { - const with_guidance = try appendRoleMessageAlloc( - allocator, - working_messages, - "system", - "You are running in API agent loop mode. Improve the answer each turn. Briefly self-critique then improve. Include the completion marker exactly when fully complete.", - ); - for (working_messages) |message| { - message.deinit(allocator); - } - allocator.free(working_messages); - working_messages = with_guidance; - } - var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); defer allocator.free(latest_user_prompt); var previous_output: ?[]u8 = null; defer if (previous_output) |value| allocator.free(value); var repeated_count: usize = 0; + var react_tool_calls: usize = 0; var turn: usize = 0; while (turn < loop_max_turns) : (turn += 1) { + try maybeCompactLoopWorkingMessages(allocator, &working_messages, effective_max_context_tokens); + var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, latest_user_prompt, working_messages); defer turn_request.deinit(allocator); @@ -951,7 +1763,16 @@ fn executeLoopRequestAlloc( } turn_request.loop_max_turns = null; - var turn_result = try backend.callProvider(allocator, app_config, turn_request); + // See streamLoopRequestToSse: strip tools/tool_choice from the + // upstream call in ReAct mode to avoid provider-side tool templates. + const provider_request = providerRequestForLoopTurn(turn_request, loop_mode); + + var turn_result = try backend.callProvider(allocator, app_config, provider_request); + if (turn_result.success and isLengthFinishReason(turn_result.finish_reason)) { + continueLengthResponseAlloc(allocator, app_config, provider_request, &turn_result) catch |err| { + logError("Provider loop continuation failed: {s}", .{@errorName(err)}); + }; + } const with_assistant = try appendRoleMessageAlloc( allocator, @@ -992,16 +1813,53 @@ fn executeLoopRequestAlloc( return .{ .response = turn_result, .messages = working_messages }; } + // ReAct mode: parse action, execute, inject observation. + const is_react_mode = std.ascii.eqlIgnoreCase(loop_mode, "react"); + var react_observation: ?[]u8 = null; + var react_finished = false; + defer if (react_observation) |obs| allocator.free(obs); + + if (is_react_mode) { + if (repeated_count >= react_repeated_stop_threshold) { + return .{ .response = turn_result, .messages = working_messages }; + } + + if (react.parseReactAction(turn_result.output)) |action| { + switch (action) { + .finish => { + react_finished = true; + }, + else => { + try tooling.noteToolCall(react_tool_budget, &react_tool_calls); + const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action, request_id); + defer allocator.free(obs_raw); + react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); + }, + } + } else { + react_observation = try react.formatObservation( + allocator, + turn + 1, + react.parse_error_hint, + ); + } + + if (react_finished) { + return .{ .response = turn_result, .messages = working_messages }; + } + } + turn_result.deinit(allocator); allocator.free(latest_user_prompt); - latest_user_prompt = try allocator.dupe( - u8, - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) - "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." - else - "Continue.", - ); + + if (is_react_mode) { + const obs = react_observation orelse try allocator.dupe(u8, "Continue."); + react_observation = null; // Transfer ownership. + latest_user_prompt = obs; + } else { + latest_user_prompt = try nextBasicLoopPromptAlloc(allocator, loop_mode); + } const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", latest_user_prompt); for (working_messages) |message| { @@ -1086,13 +1944,210 @@ fn appendRoleMessageAlloc( return try out.toOwnedSlice(allocator); } +fn continueLengthResponseAlloc( + allocator: std.mem.Allocator, + app_config: *const config.Config, + base_request: types.Request, + result: *types.Response, +) !void { + const max_continuations: usize = 4; + var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); + defer { + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + } + + const with_first_assistant = try appendRoleMessageAlloc(allocator, working_messages, "assistant", result.output); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_first_assistant; + + var turn: usize = 0; + while (turn < max_continuations and isLengthFinishReason(result.finish_reason)) : (turn += 1) { + const continue_prompt = length_continue_prompt; + const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", continue_prompt); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_continue; + + var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, continue_prompt, working_messages); + defer turn_request.deinit(allocator); + + var turn_result = try backend.callProvider(allocator, app_config, turn_request); + defer turn_result.deinit(allocator); + if (!turn_result.success) return; + if (std.mem.trim(u8, turn_result.output, " \t\r\n").len == 0) return; + + const joined = try std.fmt.allocPrint(allocator, "{s}{s}", .{ result.output, turn_result.output }); + allocator.free(result.output); + result.output = joined; + + allocator.free(result.finish_reason); + result.finish_reason = try allocator.dupe(u8, turn_result.finish_reason); + result.usage.prompt_tokens += turn_result.usage.prompt_tokens; + result.usage.completion_tokens += turn_result.usage.completion_tokens; + result.usage.total_tokens += turn_result.usage.total_tokens; + + const with_assistant = try appendRoleMessageAlloc(allocator, working_messages, "assistant", turn_result.output); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_assistant; + } +} + +fn isLengthFinishReason(finish_reason: []const u8) bool { + return std.ascii.eqlIgnoreCase(finish_reason, "length"); +} + +fn persistConversationState( + allocator: std.mem.Allocator, + app_config: *const config.Config, + use_workspace: bool, + workspace_store: ?*workspace.WorkspaceStore, + active_workspace: ?*workspace.WorkspaceState, + session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + parsed_req: types.Request, + loaded_session: ?session.SessionState, + base_messages: []const types.Message, + loop_messages: ?[]types.Message, + assistant_output: []const u8, +) void { + if (use_workspace) { + persistWorkspaceState( + allocator, + workspace_store, + active_workspace, + base_messages, + loop_messages, + assistant_output, + ); + return; + } + + persistSessionState( + allocator, + app_config, + session_store, + session_store_guard, + parsed_req, + loaded_session, + base_messages, + loop_messages, + assistant_output, + ); +} + +fn persistWorkspaceState( + allocator: std.mem.Allocator, + workspace_store: ?*workspace.WorkspaceStore, + active_workspace: ?*workspace.WorkspaceState, + base_messages: []const types.Message, + loop_messages: ?[]types.Message, + assistant_output: []const u8, +) void { + const store = workspace_store orelse return; + const ws = active_workspace orelse return; + + const with_assistant = if (loop_messages) |messages| blk: { + break :blk session.cloneMessagesAlloc(allocator, messages) catch |err| blk2: { + logError("Workspace append failed: {s}", .{@errorName(err)}); + break :blk2 null; + }; + } else session.appendAssistantMessageAlloc( + allocator, + base_messages, + assistant_output, + ) catch |err| blk: { + logError("Workspace append failed: {s}", .{@errorName(err)}); + break :blk null; + }; + + if (with_assistant) |messages| { + store.saveMessages(ws, messages) catch |err| { + logError("Workspace save failed: {s}", .{@errorName(err)}); + for (messages) |message| { + message.deinit(allocator); + } + allocator.free(messages); + }; + } +} + +fn persistSessionState( + allocator: std.mem.Allocator, + app_config: *const config.Config, + session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + parsed_req: types.Request, + loaded_session: ?session.SessionState, + base_messages: []const types.Message, + loop_messages: ?[]types.Message, + assistant_output: []const u8, +) void { + const store = session_store orelse return; + const session_id = parsed_req.session_id orelse return; + + const with_assistant = if (loop_messages) |messages| blk: { + break :blk session.cloneMessagesAlloc(allocator, messages) catch |err| blk2: { + logError("Session append failed for '{s}': {s}", .{ session_id, @errorName(err) }); + break :blk2 null; + }; + } else session.appendAssistantMessageAlloc( + allocator, + base_messages, + assistant_output, + ) catch |err| blk: { + logError("Session append failed for '{s}': {s}", .{ session_id, @errorName(err) }); + break :blk null; + }; + + if (with_assistant) |messages| { + defer { + for (messages) |message| { + message.deinit(allocator); + } + allocator.free(messages); + } + + var state = session.SessionState{ + .session_id = allocator.dupe(u8, session_id) catch return, + .tenant_id = if (parsed_req.tenant_id) |value| allocator.dupe(u8, value) catch return else null, + .summary = if (loaded_session) |loaded| allocator.dupe(u8, loaded.summary) catch return else allocator.dupe(u8, "") catch return, + .messages = session.trimToRetentionAlloc( + allocator, + messages, + app_config.session_retention_messages, + ) catch return, + .message_count = 0, + }; + state.message_count = state.messages.len; + defer state.deinit(allocator); + + session_store_guard.mutex.lock(); + store.save(allocator, state) catch |err| { + logError("Session save failed for '{s}': {s}", .{ session_id, @errorName(err) }); + }; + session_store_guard.mutex.unlock(); + } +} + fn sendApiErrorSafe( connection: std.net.Server.Connection, allocator: std.mem.Allocator, api_error: backend.errors.ApiError, app_config: *const config.Config, + request_id: ?[]const u8, ) void { - response.sendApiError(connection, allocator, api_error) catch |err| { + response.sendApiError(connection, allocator, api_error, request_id) catch |err| { swallowSocketWriteError(app_config, err); }; } @@ -1102,10 +2157,17 @@ fn sendJsonSafe( status_code: u16, body: []const u8, app_config: *const config.Config, + request_id: ?[]const u8, ) void { - response.sendJsonText(connection, status_code, body) catch |err| { - swallowSocketWriteError(app_config, err); - }; + if (request_id) |id| { + response.sendJsonTextWithRequestId(connection, status_code, body, id) catch |err| { + swallowSocketWriteError(app_config, err); + }; + } else { + response.sendJsonText(connection, status_code, body) catch |err| { + swallowSocketWriteError(app_config, err); + }; + } } fn sendChatCompletionSafe( @@ -1113,8 +2175,9 @@ fn sendChatCompletionSafe( allocator: std.mem.Allocator, result: types.Response, app_config: *const config.Config, + request_id: ?[]const u8, ) void { - response.sendChatCompletion(connection, allocator, result) catch |err| { + response.sendChatCompletion(connection, allocator, result, request_id) catch |err| { swallowSocketWriteError(app_config, err); }; } @@ -1137,7 +2200,6 @@ fn swallowSocketWriteError( fn requiresAuth(route: router.Route) bool { return switch (route) { .health => false, - .diagnostics_providers => false, else => true, }; } @@ -1162,3 +2224,53 @@ fn logInfo(comptime format: []const u8, args: anytype) void { fn logError(comptime format: []const u8, args: anytype) void { std.debug.print("[error] " ++ format ++ "\n", args); } + +test "connection gate enforces bounded capacity" { + var gate = ConnectionGate{ .max_active = 2 }; + try std.testing.expect(gate.tryAcquire()); + try std.testing.expect(gate.tryAcquire()); + try std.testing.expect(!gate.tryAcquire()); + + gate.release(); + try std.testing.expect(gate.tryAcquire()); +} + +test "connection gate recovers capacity after releases" { + var gate = ConnectionGate{ .max_active = 1 }; + try std.testing.expect(gate.tryAcquire()); + try std.testing.expect(!gate.tryAcquire()); + + gate.release(); + try std.testing.expect(gate.tryAcquire()); + gate.release(); + try std.testing.expect(gate.tryAcquire()); +} + +test "requiresAuth allows only health without an API key" { + try std.testing.expect(!requiresAuth(.health)); + try std.testing.expect(requiresAuth(.metrics)); + try std.testing.expect(requiresAuth(.diagnostics_providers)); + try std.testing.expect(requiresAuth(.chat_completions)); +} + +test "server state tracks provider latency buckets" { + var state = ServerState.init(); + defer state.deinit(std.testing.allocator); + + state.noteProviderLatency(10); + state.noteProviderLatency(75); + state.noteProviderLatency(500); + state.noteProviderLatency(1500); + + const snapshot = state.snapshot(); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[0]); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[1]); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[2]); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[3]); +} + +test "isLengthFinishReason detects provider length stops" { + try std.testing.expect(isLengthFinishReason("length")); + try std.testing.expect(isLengthFinishReason("LENGTH")); + try std.testing.expect(!isLengthFinishReason("stop")); +} diff --git a/src/main.zig b/src/main.zig index 93a3520..ff3bec7 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,16 +1,20 @@ //! CLI entrypoint for running the router in server mode or prompt-loop mode. const std = @import("std"); +const builtin = @import("builtin"); const root = @import("root.zig"); +const react = @import("react.zig"); const LoopMode = enum { basic, agent, + react, }; /// Parsed CLI options with owned heap-allocated string fields. const CliOptions = struct { prompt: ?[]u8 = null, // Enables prompt-loop mode when set. provider_override: ?[]u8 = null, // Optional provider alias from CLI. + model_override: ?[]u8 = null, // Optional default model override from CLI. env_file_path: ?[]u8 = null, // Optional dotenv path loaded before config. use_env: bool = false, // Enable dotenv loading. loop_mode: LoopMode = .basic, // Iteration style for prompt loop mode. @@ -25,6 +29,7 @@ const CliOptions = struct { pub fn deinit(self: *CliOptions, allocator: std.mem.Allocator) void { if (self.prompt) |prompt| allocator.free(prompt); if (self.provider_override) |provider| allocator.free(provider); + if (self.model_override) |model| allocator.free(model); if (self.env_file_path) |env_file_path| allocator.free(env_file_path); allocator.free(self.until_marker); } @@ -35,10 +40,18 @@ const CliOptions = struct { /// Errors: /// - allocator, configuration, CLI parsing, and provider/runtime errors. pub fn main() !void { - var gpa = std.heap.DebugAllocator(.{}){}; - defer _ = gpa.deinit(); - const allocator = gpa.allocator(); + if (comptime builtin.mode == .Debug) { + var gpa = std.heap.DebugAllocator(.{}){}; + defer _ = gpa.deinit(); + try runMain(gpa.allocator()); + } else { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + try runMain(gpa.allocator()); + } +} +fn runMain(allocator: std.mem.Allocator) !void { var cli = try parseCliOptions(allocator); defer cli.deinit(allocator); @@ -60,6 +73,10 @@ pub fn main() !void { try app_config.setDefaultProvider(allocator, provider); } + if (cli.model_override) |model| { + try app_config.setDefaultModel(allocator, model); + } + if (cli.prompt) |prompt| { try runPromptLoop(allocator, &app_config, prompt, cli.until_marker, cli.max_turns, cli.loop_mode); return; @@ -86,9 +103,16 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { continue; } + if (std.mem.eql(u8, arg, "--react")) { + cli.loop_mode = .react; + applyReactLoopDefaults(&cli); + continue; + } + if (std.mem.eql(u8, arg, "--loop-mode")) { const value = args.next() orelse return error.MissingLoopModeValue; cli.loop_mode = try parseLoopMode(value); + applyReactLoopDefaults(&cli); continue; } @@ -113,6 +137,14 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { continue; } + if (std.mem.eql(u8, arg, "--model")) { + const value = args.next() orelse return error.MissingModelValue; + if (value.len == 0) return error.InvalidModelValue; + if (cli.model_override) |model| allocator.free(model); + cli.model_override = try allocator.dupe(u8, value); + continue; + } + if (std.mem.eql(u8, arg, "--prompt")) { const value = args.next() orelse return error.MissingPromptValue; if (value.len == 0) return error.InvalidPromptValue; @@ -146,6 +178,15 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { continue; } + const model_prefix = "--model="; + if (std.mem.startsWith(u8, arg, model_prefix)) { + const value = arg[model_prefix.len..]; + if (value.len == 0) return error.InvalidModelValue; + if (cli.model_override) |model| allocator.free(model); + cli.model_override = try allocator.dupe(u8, value); + continue; + } + const prompt_prefix = "--prompt="; if (std.mem.startsWith(u8, arg, prompt_prefix)) { const value = arg[prompt_prefix.len..]; @@ -188,6 +229,7 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { if (std.mem.startsWith(u8, arg, loop_mode_prefix)) { const value = arg[loop_mode_prefix.len..]; cli.loop_mode = try parseLoopMode(value); + applyReactLoopDefaults(&cli); continue; } } @@ -195,9 +237,16 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { return cli; } +fn applyReactLoopDefaults(cli: *CliOptions) void { + if (cli.loop_mode == .react and cli.max_turns == 8) { + cli.max_turns = 24; + } +} + fn parseLoopMode(value: []const u8) !LoopMode { if (std.ascii.eqlIgnoreCase(value, "basic")) return .basic; if (std.ascii.eqlIgnoreCase(value, "agent")) return .agent; + if (std.ascii.eqlIgnoreCase(value, "react")) return .react; return error.InvalidLoopMode; } @@ -290,11 +339,19 @@ fn runPromptLoop( "system", "You are running in an iterative agent loop. Improve the solution every turn. Use concise self-critique before each improvement. When the task is complete, include the completion marker exactly.", ); + } else if (loop_mode == .react) { + try appendConversationMessage( + allocator, + &conversation, + "system", + react.system_prompt, + ); } try appendConversationMessage(allocator, &conversation, "user", initial_prompt); - var latest_user_prompt = initial_prompt; + var latest_user_prompt = try allocator.dupe(u8, initial_prompt); + defer allocator.free(latest_user_prompt); var turn: usize = 0; var repeated_count: usize = 0; var previous_output: ?[]u8 = null; @@ -346,7 +403,8 @@ fn runPromptLoop( } previous_output = try allocator.dupe(u8, normalized_output); - if (loop_mode == .agent and repeated_count >= 1) { + const repeated_limit: usize = if (loop_mode == .react) 2 else 1; + if ((loop_mode == .agent or loop_mode == .react) and repeated_count >= repeated_limit) { std.debug.print("Loop stopped early: model repeated output without progress.\n", .{}); return; } @@ -356,11 +414,55 @@ fn runPromptLoop( return; } - latest_user_prompt = switch (loop_mode) { - .basic => "Continue.", - .agent => "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result.", - }; - try appendConversationMessage(allocator, &conversation, "user", latest_user_prompt); + if (loop_mode == .react) { + // ReAct mode: parse the Action, execute it, inject the Observation. + const parsed_action = react.parseReactAction(result.output); + if (parsed_action) |action| { + switch (action) { + .finish => |answer| { + std.debug.print("ReAct finished:\n{s}\n", .{answer}); + return; + }, + else => { + const observation_raw = try react.executeReactAction(allocator, action, app_config); + defer allocator.free(observation_raw); + + const observation = try react.formatObservation(allocator, turn + 1, observation_raw); + defer allocator.free(observation); + + std.debug.print("{s}\n\n", .{observation}); + + allocator.free(latest_user_prompt); + latest_user_prompt = observation; + try appendConversationMessage(allocator, &conversation, "user", observation); + }, + } + } else { + // No parseable action found. Give the model a hint. + const hint = try react.formatObservation( + allocator, + turn + 1, + react.parse_error_hint, + ); + defer allocator.free(hint); + + std.debug.print("{s}\n\n", .{hint}); + allocator.free(latest_user_prompt); + latest_user_prompt = hint; + try appendConversationMessage(allocator, &conversation, "user", hint); + } + } else { + allocator.free(latest_user_prompt); + latest_user_prompt = try allocator.dupe( + u8, + switch (loop_mode) { + .basic => "Continue.", + .agent => "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result.", + .react => unreachable, + }, + ); + try appendConversationMessage(allocator, &conversation, "user", latest_user_prompt); + } } } diff --git a/src/providers/bedrock.zig b/src/providers/bedrock.zig index fc1b3a2..205bbfd 100644 --- a/src/providers/bedrock.zig +++ b/src/providers/bedrock.zig @@ -9,7 +9,7 @@ pub fn callBedrock( app_config: *const config.Config, request: types.Request, ) !types.Response { - const model_name = request.model orelse app_config.bedrock_model; + const model_name = request.model orelse app_config.modelForProvider("bedrock"); if (app_config.bedrock_access_key_id.len == 0 or app_config.bedrock_secret_access_key.len == 0) { return try errorResponse( diff --git a/src/providers/claude.zig b/src/providers/claude.zig index ec9270b..b0fffd9 100644 --- a/src/providers/claude.zig +++ b/src/providers/claude.zig @@ -10,7 +10,7 @@ pub fn callClaude( if (app_config.claude_api_key.len == 0) { return .{ .id = null, - .model = try allocator.dupe(u8, request.model orelse app_config.claude_model), + .model = try allocator.dupe(u8, request.model orelse app_config.modelForProvider("claude")), .output = try allocator.dupe(u8, "Claude API key is not configured on the server"), .finish_reason = try allocator.dupe(u8, "stop"), .success = false, @@ -18,7 +18,7 @@ pub fn callClaude( }; } - const model_name = request.model orelse app_config.claude_model; + const model_name = request.model orelse app_config.modelForProvider("claude"); var client = std.http.Client{ .allocator = allocator }; defer client.deinit(); diff --git a/src/providers/llama_cpp.zig b/src/providers/llama_cpp.zig index ce66814..5860d41 100644 --- a/src/providers/llama_cpp.zig +++ b/src/providers/llama_cpp.zig @@ -8,7 +8,7 @@ pub fn callLlamaCpp( app_config: *const config.Config, request: types.Request, ) !types.Response { - const model_name = request.model orelse app_config.llama_cpp_model; + const model_name = request.model orelse app_config.modelForProvider("llama_cpp"); const maybe_key: ?[]const u8 = if (app_config.llama_cpp_api_key.len > 0) app_config.llama_cpp_api_key diff --git a/src/providers/ollama_qwen.zig b/src/providers/ollama_qwen.zig index aef1c83..cb53392 100644 --- a/src/providers/ollama_qwen.zig +++ b/src/providers/ollama_qwen.zig @@ -30,18 +30,71 @@ const OllamaTuning = struct { num_predict: u32, }; +const max_stream_continuations: usize = 4; + +const stream_continue_prompt = + "The previous assistant message was cut off by the generation limit. " ++ + "Continue the same answer from the exact next token. Output only the continuation text; " ++ + "do not mention truncation, continuation, the previous message, or these instructions."; + +const OllamaStreamState = struct { + saw_done: bool = false, + stopped_by_length: bool = false, + think_block_open: *bool, +}; + +const StreamAttempt = struct { + saw_done: bool, + stopped_by_length: bool, +}; + +const StreamAttemptResult = union(enum) { + streamed: StreamAttempt, + failed: types.Response, +}; + pub const StreamQwenResult = union(enum) { streamed, failed: types.Response, }; -pub fn streamQwenToSse( +/// Per-turn streaming variant for use inside multi-turn loops (ReAct, agent). +/// +/// Differences from `streamQwenToSse`: +/// - The caller owns SSE headers, the completion id, and the streaming buffers +/// (`captured_content`, `think_block_open`). Headers and `[DONE]` are NOT +/// emitted here. +/// - On success, returns the finish_reason and resolved model name in +/// `TurnStreamSuccess` so the loop can decide whether to continue, parse +/// actions, or stop. +/// - `captured_content` accumulates only the assistant `content` tokens (no +/// reasoning), giving the loop a clean buffer to feed back into the next +/// turn or hand to `react.parseReactAction`. +pub const TurnStreamSuccess = struct { + finish_reason: []u8, + model: []u8, + + pub fn deinit(self: TurnStreamSuccess, allocator: std.mem.Allocator) void { + allocator.free(self.finish_reason); + allocator.free(self.model); + } +}; + +pub const TurnStreamResult = union(enum) { + streamed: TurnStreamSuccess, + failed: types.Response, +}; + +pub fn streamQwenTurnToSse( connection: std.net.Server.Connection, allocator: std.mem.Allocator, app_config: *const config.Config, request: types.Request, -) !StreamQwenResult { - const requested_model = request.model orelse app_config.ollama_model; + completion_id: []const u8, + captured_content: *std.ArrayList(u8), + think_block_open: *bool, +) !TurnStreamResult { + const requested_model = request.model orelse app_config.modelForProvider("ollama_qwen"); var model_name = requested_model; var fallback: ?ModelFallback = null; defer if (fallback) |value| value.deinit(allocator); @@ -49,12 +102,6 @@ pub fn streamQwenToSse( if (try pickFallbackModelAlloc(allocator, app_config, requested_model)) |selected| { fallback = selected; model_name = fallback.?.selected_model; - - const suggestion = fallback.?.suggested_model orelse fallback.?.selected_model; - std.debug.print( - "[warn] Ollama model '{s}' not installed. Using installed model '{s}'. Suggestion: set OLLAMA_MODEL='{s}'.\n", - .{ requested_model, fallback.?.selected_model, suggestion }, - ); } var client = std.http.Client{ .allocator = allocator }; @@ -65,50 +112,84 @@ pub fn streamQwenToSse( const uri = try std.Uri.parse(uri_text); - const messages_json = try renderMessagesJsonAlloc(allocator, request.messages); - defer allocator.free(messages_json); + var headers_sent_already = true; // caller already sent headers + var final_finish_reason: []const u8 = "stop"; - const body = try buildChatPayloadAlloc( - allocator, - app_config, - request, - model_name, - messages_json, - true, - ); - defer allocator.free(body); + var turn: usize = 0; + while (true) : (turn += 1) { + const messages_json = if (turn == 0) + try renderMessagesJsonAlloc(allocator, request.messages) + else + try renderContinuationMessagesJsonAlloc( + allocator, + request.messages, + captured_content.items, + stream_continue_prompt, + ); + defer allocator.free(messages_json); - const headers = &[_]std.http.Header{ - .{ .name = "content-type", .value = "application/json" }, - }; + const attempt_result = try streamChatMessagesToSse( + connection, + allocator, + &client, + app_config, + request, + uri, + completion_id, + model_name, + messages_json, + &headers_sent_already, + captured_content, + think_block_open, + ); - var req = try client.request(.POST, uri, .{ - .extra_headers = headers, - .keep_alive = false, - }); - defer req.deinit(); + switch (attempt_result) { + .failed => |failed_response| return .{ .failed = failed_response }, + .streamed => |attempt| { + if (!attempt.stopped_by_length) break; + if (turn + 1 >= max_stream_continuations) { + final_finish_reason = "length"; + break; + } + }, + } + } - req.transfer_encoding = .{ .content_length = body.len }; + return .{ .streamed = .{ + .finish_reason = try allocator.dupe(u8, final_finish_reason), + .model = try allocator.dupe(u8, model_name), + } }; +} - var req_body = try req.sendBodyUnflushed(&.{}); - try req_body.writer.writeAll(body); - try req_body.end(); - try req.connection.?.flush(); +pub fn streamQwenToSse( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + app_config: *const config.Config, + request: types.Request, +) !StreamQwenResult { + const requested_model = request.model orelse app_config.modelForProvider("ollama_qwen"); + var model_name = requested_model; + var fallback: ?ModelFallback = null; + defer if (fallback) |value| value.deinit(allocator); - var upstream = try req.receiveHead(&.{}); + if (try pickFallbackModelAlloc(allocator, app_config, requested_model)) |selected| { + fallback = selected; + model_name = fallback.?.selected_model; - if (upstream.head.status != .ok) { - const error_message = try std.fmt.allocPrint( - allocator, - "Ollama HTTP {s} for model '{s}' during stream request", - .{ @tagName(upstream.head.status), model_name }, + const suggestion = fallback.?.suggested_model orelse fallback.?.selected_model; + std.debug.print( + "[warn] Ollama model '{s}' not installed. Using installed model '{s}'. Suggestion: set OLLAMA_MODEL='{s}'.\n", + .{ requested_model, fallback.?.selected_model, suggestion }, ); - defer allocator.free(error_message); - - return .{ .failed = try makeResponse(allocator, model_name, error_message, false) }; } - try response.sendEventStreamHeaders(connection); + var client = std.http.Client{ .allocator = allocator }; + defer client.deinit(); + + const uri_text = try buildChatUrl(allocator, app_config.ollama_base_url); + defer allocator.free(uri_text); + + const uri = try std.Uri.parse(uri_text); const completion_id = try std.fmt.allocPrint( allocator, @@ -117,59 +198,74 @@ pub fn streamQwenToSse( ); defer allocator.free(completion_id); - var transfer_buffer: [2048]u8 = undefined; - const body_reader = upstream.reader(transfer_buffer[0..]); - - var pending = std.ArrayList(u8){}; - defer pending.deinit(allocator); - - var read_buf: [1024]u8 = undefined; - var saw_done = false; + var headers_sent = false; + var streamed_text = std.ArrayList(u8){}; + defer streamed_text.deinit(allocator); + var final_finish_reason: []const u8 = "stop"; + var think_block_open: bool = false; - while (true) { - const n = body_reader.readSliceShort(read_buf[0..]) catch |err| switch (err) { - error.ReadFailed => return upstream.bodyErr() orelse err, - }; - if (n == 0) break; + var turn: usize = 0; + while (true) : (turn += 1) { + const messages_json = if (turn == 0) + try renderMessagesJsonAlloc(allocator, request.messages) + else + try renderContinuationMessagesJsonAlloc( + allocator, + request.messages, + streamed_text.items, + stream_continue_prompt, + ); + defer allocator.free(messages_json); - try pending.appendSlice(allocator, read_buf[0..n]); - try processPendingStreamLines( + const attempt_result = try streamChatMessagesToSse( connection, allocator, + &client, + app_config, + request, + uri, completion_id, model_name, - pending.items, - &pending, - &saw_done, + messages_json, + &headers_sent, + &streamed_text, + &think_block_open, ); - if (n < read_buf.len) break; + switch (attempt_result) { + .failed => |failed_response| { + if (!headers_sent) return .{ .failed = failed_response }; + defer failed_response.deinit(allocator); + break; + }, + .streamed => |attempt| { + if (!attempt.stopped_by_length) break; + if (turn + 1 >= max_stream_continuations) { + final_finish_reason = "length"; + break; + } + }, + } } - if (pending.items.len > 0) { - const trailing = std.mem.trim(u8, pending.items, "\r\n \t"); - if (trailing.len > 0) { - try handleOllamaStreamLine( - connection, - allocator, - completion_id, - model_name, - trailing, - &saw_done, + if (headers_sent) { + if (think_block_open) { + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model_name, "", null, ); + think_block_open = false; } - } - - if (!saw_done) { try response.sendChatCompletionChunkSse( connection, allocator, completion_id, model_name, null, - "stop", + final_finish_reason, ); try response.sendSseDone(connection); + } else { + return .{ .failed = try makeResponse(allocator, model_name, "Ollama stream ended before headers were sent", false) }; } return .streamed; @@ -180,7 +276,7 @@ pub fn callQwen( app_config: *const config.Config, request: types.Request, ) !types.Response { - const requested_model = request.model orelse app_config.ollama_model; + const requested_model = request.model orelse app_config.modelForProvider("ollama_qwen"); var model_name = requested_model; var fallback: ?ModelFallback = null; defer if (fallback) |value| value.deinit(allocator); @@ -258,22 +354,44 @@ pub fn callQwen( ), }; - const content_value = switch (message_object.get("content") orelse { + const tool_calls = try parseToolCallsAlloc(allocator, message_object); + errdefer if (tool_calls) |calls| { + for (calls) |call| call.deinit(allocator); + allocator.free(calls); + }; + + const content_value = if (message_object.get("content")) |content| + switch (content) { + .string => |value| value, + .null => "", + else => return try makeResponse( + allocator, + model_name, + "Invalid content field from Ollama message", + false, + ), + } + else if (tool_calls != null) + "" + else return try makeResponse( allocator, model_name, "Missing content field from Ollama message", false, ); - }) { - .string => |value| value, - else => return try makeResponse( - allocator, - model_name, - "Invalid content field from Ollama message", - false, - ), - }; + + const thinking_value = if (message_object.get("thinking")) |thinking| + switch (thinking) { + .string => |value| value, + else => "", + } + else + ""; + + const tuning = resolveTuning(app_config, request); + const output = try formatAssistantOutputAlloc(allocator, content_value, thinking_value, tuning.think); + errdefer allocator.free(output); // Ollama may omit done_reason; default to `stop` for OpenAI compatibility. const finish_reason = if (root.get("done_reason")) |done_reason| @@ -287,7 +405,7 @@ pub fn callQwen( return .{ .id = null, .model = try allocator.dupe(u8, model_name), - .output = try allocator.dupe(u8, content_value), + .output = output, .finish_reason = try allocator.dupe(u8, finish_reason), .success = true, .usage = .{ @@ -296,9 +414,126 @@ pub fn callQwen( .total_tokens = parseUsageField(root.get("prompt_eval_count")) + parseUsageField(root.get("eval_count")), }, + .tool_calls = tool_calls, }; } +fn parseToolCallsAlloc( + allocator: std.mem.Allocator, + message_object: std.json.ObjectMap, +) !?[]types.ToolCall { + const tool_calls_value = message_object.get("tool_calls") orelse return null; + const tool_calls_array = switch (tool_calls_value) { + .array => |value| value, + else => return null, + }; + if (tool_calls_array.items.len == 0) return null; + + var calls = std.ArrayList(types.ToolCall){}; + errdefer { + for (calls.items) |call| call.deinit(allocator); + calls.deinit(allocator); + } + + for (tool_calls_array.items, 0..) |call_value, index| { + const call_object = switch (call_value) { + .object => |value| value, + else => continue, + }; + const function_object = switch (call_object.get("function") orelse continue) { + .object => |value| value, + else => continue, + }; + const name = switch (function_object.get("name") orelse continue) { + .string => |value| value, + else => continue, + }; + const arguments_json = try parseToolArgumentsJsonAlloc(allocator, function_object); + errdefer allocator.free(arguments_json); + + const id = if (call_object.get("id")) |id_value| + switch (id_value) { + .string => |value| try allocator.dupe(u8, value), + else => try std.fmt.allocPrint(allocator, "call_{d}", .{index}), + } + else + try std.fmt.allocPrint(allocator, "call_{d}", .{index}); + errdefer allocator.free(id); + + try calls.append(allocator, .{ + .id = id, + .name = try allocator.dupe(u8, name), + .arguments_json = arguments_json, + }); + } + + if (calls.items.len == 0) { + calls.deinit(allocator); + return null; + } + return try calls.toOwnedSlice(allocator); +} + +fn parseToolArgumentsJsonAlloc( + allocator: std.mem.Allocator, + function_object: std.json.ObjectMap, +) ![]u8 { + const arguments_value = function_object.get("arguments") orelse return try allocator.dupe(u8, "{}"); + return switch (arguments_value) { + .string => |value| try allocator.dupe(u8, value), + .object => try jsonValueStringifyAlloc(allocator, arguments_value), + else => try allocator.dupe(u8, "{}"), + }; +} + +fn formatAssistantOutputAlloc( + allocator: std.mem.Allocator, + content: []const u8, + thinking: []const u8, + include_thinking: bool, +) ![]u8 { + if (!include_thinking or std.mem.trim(u8, thinking, " \t\r\n").len == 0) { + return try allocator.dupe(u8, content); + } + + if (std.mem.trim(u8, content, " \t\r\n").len == 0) { + return try std.fmt.allocPrint(allocator, "{s}", .{thinking}); + } + + return try std.fmt.allocPrint( + allocator, + "{s}\n{s}", + .{ thinking, content }, + ); +} + +test "formatAssistantOutputAlloc includes thinking when requested" { + const allocator = std.testing.allocator; + const output = try formatAssistantOutputAlloc( + allocator, + "final answer", + "reasoning trace", + true, + ); + defer allocator.free(output); + + try std.testing.expect(std.mem.indexOf(u8, output, "reasoning trace") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "final answer") != null); +} + +test "formatAssistantOutputAlloc omits thinking when disabled" { + const allocator = std.testing.allocator; + const output = try formatAssistantOutputAlloc( + allocator, + "final answer", + "reasoning trace", + false, + ); + defer allocator.free(output); + + try std.testing.expectEqualStrings("final answer", output); +} + pub fn buildStatusJsonAlloc( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -481,17 +716,124 @@ fn buildTagsUrl( return try std.fmt.allocPrint(allocator, "{s}/api/tags", .{base_url}); } +fn streamChatMessagesToSse( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + client: *std.http.Client, + app_config: *const config.Config, + request: types.Request, + uri: std.Uri, + completion_id: []const u8, + model_name: []const u8, + messages_json: []const u8, + headers_sent: *bool, + streamed_text: *std.ArrayList(u8), + think_block_open: *bool, +) !StreamAttemptResult { + const body = try buildChatPayloadAlloc( + allocator, + app_config, + request, + model_name, + messages_json, + true, + ); + defer allocator.free(body); + + const headers = &[_]std.http.Header{ + .{ .name = "content-type", .value = "application/json" }, + }; + + var req = try client.request(.POST, uri, .{ + .extra_headers = headers, + .keep_alive = false, + }); + defer req.deinit(); + + req.transfer_encoding = .{ .content_length = body.len }; + + var req_body = try req.sendBodyUnflushed(&.{}); + try req_body.writer.writeAll(body); + try req_body.end(); + try req.connection.?.flush(); + + var upstream = try req.receiveHead(&.{}); + + if (upstream.head.status != .ok) { + const error_message = try std.fmt.allocPrint( + allocator, + "Ollama HTTP {s} for model '{s}' during stream request", + .{ @tagName(upstream.head.status), model_name }, + ); + defer allocator.free(error_message); + + return .{ .failed = try makeResponse(allocator, model_name, error_message, false) }; + } + + if (!headers_sent.*) { + try response.sendEventStreamHeaders(connection, null); + headers_sent.* = true; + } + + var transfer_buffer: [2048]u8 = undefined; + const body_reader = upstream.reader(transfer_buffer[0..]); + + var pending = std.ArrayList(u8){}; + defer pending.deinit(allocator); + + var read_buf: [1024]u8 = undefined; + var state = OllamaStreamState{ .think_block_open = think_block_open }; + + while (!state.saw_done) { + const n = body_reader.readSliceShort(read_buf[0..]) catch |err| switch (err) { + error.ReadFailed => return upstream.bodyErr() orelse err, + }; + if (n == 0) break; + + try pending.appendSlice(allocator, read_buf[0..n]); + try processPendingStreamLines( + connection, + allocator, + completion_id, + model_name, + &pending, + &state, + streamed_text, + ); + } + + if (pending.items.len > 0 and !state.saw_done) { + const trailing = std.mem.trim(u8, pending.items, "\r\n \t"); + if (trailing.len > 0) { + try handleOllamaStreamLine( + connection, + allocator, + completion_id, + model_name, + trailing, + &state, + streamed_text, + ); + } + } + + return .{ + .streamed = .{ + .saw_done = state.saw_done, + .stopped_by_length = state.stopped_by_length, + }, + }; +} + fn processPendingStreamLines( connection: std.net.Server.Connection, allocator: std.mem.Allocator, completion_id: []const u8, fallback_model: []const u8, - lines: []const u8, pending: *std.ArrayList(u8), - saw_done: *bool, + state: *OllamaStreamState, + streamed_text: *std.ArrayList(u8), ) !void { - _ = lines; - while (std.mem.indexOfScalar(u8, pending.items, '\n')) |line_end| { const raw_line = pending.items[0..line_end]; const line = std.mem.trimRight(u8, raw_line, "\r"); @@ -502,7 +844,8 @@ fn processPendingStreamLines( completion_id, fallback_model, line, - saw_done, + state, + streamed_text, ); } @@ -519,7 +862,8 @@ fn handleOllamaStreamLine( completion_id: []const u8, fallback_model: []const u8, line: []const u8, - saw_done: *bool, + state: *OllamaStreamState, + streamed_text: *std.ArrayList(u8), ) !void { var parsed = std.json.parseFromSlice(std.json.Value, allocator, line, .{}) catch { return; @@ -546,36 +890,49 @@ fn handleOllamaStreamLine( }; if (message_object) |obj| { - var token: ?[]const u8 = null; - - if (obj.get("content")) |content_value| { - const content = switch (content_value) { + const thinking_text = if (obj.get("thinking")) |thinking_value| + switch (thinking_value) { .string => |value| value, else => "", - }; - - if (content.len > 0) token = content; - } - - if (token == null) { - if (obj.get("thinking")) |thinking_value| { - const thinking = switch (thinking_value) { - .string => |value| value, - else => "", - }; + } + else + ""; - if (thinking.len > 0) token = thinking; + const content_text = if (obj.get("content")) |content_value| + switch (content_value) { + .string => |value| value, + else => "", + } + else + ""; + + // Stream thinking wrapped in ... tags so the client + // can render it distinctly. Thinking is intentionally NOT appended + // to streamed_text: that buffer feeds the "continue from cutoff" + // prompt on length-stops, and including reasoning there causes the + // model to restart its think block on every continuation. + if (thinking_text.len > 0) { + if (!state.think_block_open.*) { + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model, "", null, + ); + state.think_block_open.* = true; } + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model, thinking_text, null, + ); } - if (token) |text| { + if (content_text.len > 0) { + if (state.think_block_open.*) { + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model, "", null, + ); + state.think_block_open.* = false; + } + try streamed_text.appendSlice(allocator, content_text); try response.sendChatCompletionChunkSse( - connection, - allocator, - completion_id, - model, - text, - null, + connection, allocator, completion_id, model, content_text, null, ); } } @@ -588,6 +945,7 @@ fn handleOllamaStreamLine( }; if (response_text.len > 0) { + try streamed_text.appendSlice(allocator, response_text); try response.sendChatCompletionChunkSse( connection, allocator, @@ -607,7 +965,7 @@ fn handleOllamaStreamLine( else false; - if (done and !saw_done.*) { + if (done and !state.saw_done) { const finish_reason = if (root.get("done_reason")) |reason_value| switch (reason_value) { .string => |value| value, @@ -616,16 +974,8 @@ fn handleOllamaStreamLine( else "stop"; - try response.sendChatCompletionChunkSse( - connection, - allocator, - completion_id, - model, - null, - finish_reason, - ); - try response.sendSseDone(connection); - saw_done.* = true; + state.saw_done = true; + state.stopped_by_length = std.ascii.eqlIgnoreCase(finish_reason, "length"); } } @@ -717,14 +1067,30 @@ fn renderToolsJsonAlloc( defer allocator.free(escaped_description); try out.writer(allocator).print( - "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"}}}}", + "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"", .{ escaped_name, escaped_description }, ); + if (tool.parameters_json) |parameters_json| { + try out.writer(allocator).print(",\"parameters\":{s}", .{parameters_json}); + } + try out.appendSlice(allocator, "}}}"); } try out.append(allocator, ']'); return try out.toOwnedSlice(allocator); } +fn jsonValueStringifyAlloc(allocator: std.mem.Allocator, value: std.json.Value) ![]u8 { + var out = std.Io.Writer.Allocating.init(allocator); + errdefer out.deinit(); + + var json_writer = std.json.Stringify{ + .writer = &out.writer, + .options = .{}, + }; + try json_writer.write(value); + return try out.toOwnedSlice(); +} + fn resolveTuning(app_config: *const config.Config, request: types.Request) OllamaTuning { return .{ .think = request.think orelse app_config.ollama_think, @@ -899,6 +1265,31 @@ fn renderMessagesJsonAlloc( return try out.toOwnedSlice(allocator); } +fn renderContinuationMessagesJsonAlloc( + allocator: std.mem.Allocator, + base_messages: []const types.Message, + assistant_content: []const u8, + continue_prompt: []const u8, +) ![]u8 { + const messages = try allocator.alloc(types.Message, base_messages.len + 2); + defer allocator.free(messages); + + for (base_messages, 0..) |message, index| { + messages[index] = message; + } + + messages[base_messages.len] = .{ + .role = "assistant", + .content = assistant_content, + }; + messages[base_messages.len + 1] = .{ + .role = "user", + .content = continue_prompt, + }; + + return try renderMessagesJsonAlloc(allocator, messages); +} + fn parseUsageField(value: ?std.json.Value) usize { const actual_value = value orelse return 0; diff --git a/src/providers/openai.zig b/src/providers/openai.zig index 8a1f5c3..48cd66b 100644 --- a/src/providers/openai.zig +++ b/src/providers/openai.zig @@ -11,7 +11,7 @@ pub fn callOpenAI( if (app_config.openai_api_key.len == 0) { return .{ .id = null, - .model = try allocator.dupe(u8, request.model orelse app_config.openai_model), + .model = try allocator.dupe(u8, request.model orelse app_config.modelForProvider("openai")), .output = try allocator.dupe(u8, "OpenAI API key is not configured on the server"), .finish_reason = try allocator.dupe(u8, "stop"), .success = false, @@ -19,7 +19,7 @@ pub fn callOpenAI( }; } - const model_name = request.model orelse app_config.openai_model; + const model_name = request.model orelse app_config.modelForProvider("openai"); return try openai_compatible.callChat( allocator, app_config.openai_base_url, diff --git a/src/providers/openai_compatible.zig b/src/providers/openai_compatible.zig index 099612d..ce29d41 100644 --- a/src/providers/openai_compatible.zig +++ b/src/providers/openai_compatible.zig @@ -196,13 +196,23 @@ fn parseChatResponse( else => return try makeFailureResponse(allocator, fallback_model, "Provider message was not an object"), }; - const output = switch (message_obj.get("content") orelse { - return try makeFailureResponse(allocator, fallback_model, "Provider message missing content"); - }) { - .string => |value| value, - else => return try makeFailureResponse(allocator, fallback_model, "Provider content was not a string"), + const tool_calls = try parseToolCallsAlloc(allocator, message_obj); + errdefer if (tool_calls) |calls| { + for (calls) |call| call.deinit(allocator); + allocator.free(calls); }; + const output = if (message_obj.get("content")) |content_value| + switch (content_value) { + .string => |value| value, + .null => "", + else => return try makeFailureResponse(allocator, fallback_model, "Provider content was not a string"), + } + else if (tool_calls != null) + "" + else + return try makeFailureResponse(allocator, fallback_model, "Provider message missing content"); + const finish_reason = if (first_choice.get("finish_reason")) |finish_reason_value| switch (finish_reason_value) { .string => |value| value, @@ -230,6 +240,75 @@ fn parseChatResponse( .finish_reason = try allocator.dupe(u8, finish_reason), .success = true, .usage = usage, + .tool_calls = tool_calls, + }; +} + +fn parseToolCallsAlloc( + allocator: std.mem.Allocator, + message_obj: std.json.ObjectMap, +) !?[]types.ToolCall { + const tool_calls_value = message_obj.get("tool_calls") orelse return null; + const tool_calls_array = switch (tool_calls_value) { + .array => |value| value, + else => return null, + }; + if (tool_calls_array.items.len == 0) return null; + + var calls = std.ArrayList(types.ToolCall){}; + errdefer { + for (calls.items) |call| call.deinit(allocator); + calls.deinit(allocator); + } + + for (tool_calls_array.items, 0..) |call_value, index| { + const call_obj = switch (call_value) { + .object => |value| value, + else => continue, + }; + const function_obj = switch (call_obj.get("function") orelse continue) { + .object => |value| value, + else => continue, + }; + const name = switch (function_obj.get("name") orelse continue) { + .string => |value| value, + else => continue, + }; + const arguments_json = try parseToolArgumentsJsonAlloc(allocator, function_obj); + errdefer allocator.free(arguments_json); + + const id = if (call_obj.get("id")) |id_value| + switch (id_value) { + .string => |value| try allocator.dupe(u8, value), + else => try std.fmt.allocPrint(allocator, "call_{d}", .{index}), + } + else + try std.fmt.allocPrint(allocator, "call_{d}", .{index}); + errdefer allocator.free(id); + + try calls.append(allocator, .{ + .id = id, + .name = try allocator.dupe(u8, name), + .arguments_json = arguments_json, + }); + } + + if (calls.items.len == 0) { + calls.deinit(allocator); + return null; + } + return try calls.toOwnedSlice(allocator); +} + +fn parseToolArgumentsJsonAlloc( + allocator: std.mem.Allocator, + function_obj: std.json.ObjectMap, +) ![]u8 { + const arguments_value = function_obj.get("arguments") orelse return try allocator.dupe(u8, "{}"); + return switch (arguments_value) { + .string => |value| try allocator.dupe(u8, value), + .object => try jsonValueStringifyAlloc(allocator, arguments_value), + else => try allocator.dupe(u8, "{}"), }; } @@ -342,15 +421,31 @@ fn renderToolsJsonAlloc( defer allocator.free(escaped_description); try out.writer(allocator).print( - "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"}}}}", + "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"", .{ escaped_name, escaped_description }, ); + if (tool.parameters_json) |parameters_json| { + try out.writer(allocator).print(",\"parameters\":{s}", .{parameters_json}); + } + try out.appendSlice(allocator, "}}}"); } try out.append(allocator, ']'); return try out.toOwnedSlice(allocator); } +fn jsonValueStringifyAlloc(allocator: std.mem.Allocator, value: std.json.Value) ![]u8 { + var out = std.Io.Writer.Allocating.init(allocator); + errdefer out.deinit(); + + var json_writer = std.json.Stringify{ + .writer = &out.writer, + .options = .{}, + }; + try json_writer.write(value); + return try out.toOwnedSlice(); +} + fn makeFailureResponse( allocator: std.mem.Allocator, model_name: []const u8, @@ -431,3 +526,28 @@ test "parseChatResponse keeps ids and usage for successful responses" { try std.testing.expectEqualStrings("hello", response.output); try std.testing.expectEqual(@as(usize, 5), response.usage.total_tokens); } + +test "parseChatResponse preserves OpenAI tool calls" { + const allocator = std.testing.allocator; + const raw = + \\{"id":"chatcmpl-tools","model":"gpt-test","choices":[{"message":{"content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"echo","arguments":"{\"message\":\"demo-green\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5}} + ; + + const response = try parseChatResponse( + allocator, + "gpt-test", + "OpenAI", + .ok, + raw, + ); + defer response.deinit(allocator); + + try std.testing.expect(response.success); + try std.testing.expectEqualStrings("", response.output); + try std.testing.expectEqualStrings("tool_calls", response.finish_reason); + try std.testing.expect(response.tool_calls != null); + try std.testing.expectEqual(@as(usize, 1), response.tool_calls.?.len); + try std.testing.expectEqualStrings("call_1", response.tool_calls.?[0].id); + try std.testing.expectEqualStrings("echo", response.tool_calls.?[0].name); + try std.testing.expectEqualStrings("{\"message\":\"demo-green\"}", response.tool_calls.?[0].arguments_json); +} diff --git a/src/providers/openrouter.zig b/src/providers/openrouter.zig index 31ca7bf..63d0f14 100644 --- a/src/providers/openrouter.zig +++ b/src/providers/openrouter.zig @@ -11,12 +11,12 @@ pub fn callOpenRouter( if (app_config.openrouter_api_key.len == 0) { return try errorResponse( allocator, - request.model orelse app_config.openrouter_model, + request.model orelse app_config.modelForProvider("openrouter"), "OpenRouter provider selected but OPENROUTER_API_KEY is not set", ); } - const model_name = request.model orelse app_config.openrouter_model; + const model_name = request.model orelse app_config.modelForProvider("openrouter"); var extra_headers_buffer: [2]std.http.Header = undefined; const extra_headers = buildExtraHeaders( app_config.openrouter_http_referer, diff --git a/src/react.zig b/src/react.zig new file mode 100644 index 0000000..028088e --- /dev/null +++ b/src/react.zig @@ -0,0 +1,663 @@ +//! ReAct (Reasoning + Acting) loop support. +//! +//! Implements the Thought → Action → Observation paradigm from Yao et al., 2022. +//! The model produces structured Thought/Action pairs, the harness executes the +//! action and injects the result as an Observation, and the cycle repeats until +//! the model emits Finish[answer] or the turn budget is exhausted. +const std = @import("std"); +const config = @import("config.zig"); +const types = @import("types.zig"); +const command_exec_tool = @import("tools/command_exec.zig"); + +/// System prompt injected at the start of every ReAct loop to instruct the +/// model on the required Thought/Action/Observation format. +pub const system_prompt = + "You are running in ReAct (Reasoning + Acting) mode for multi-step problem solving.\n" ++ + "Smaller models should solve tasks incrementally: one Thought and one Action per turn, " ++ + "wait for the Observation, then continue.\n\n" ++ + "Required format every turn:\n\n" ++ + "Thought N: \n" ++ + "Action N: \n\n" ++ + "The harness injects after each Action:\n" ++ + "Observation N: \n\n" ++ + "Example:\n" ++ + "Thought 1: I need to inspect the entry file before editing.\n" ++ + "Action 1: file_read[src/main.zig]\n\n" ++ + "Available actions:\n" ++ + " Search[query] - search the repo (maps to file_search when that tool is enabled)\n" ++ + " Lookup[path] - read a file path (maps to file_read when that tool is enabled)\n" ++ + " Cmd[command] - execute a shell command (may require confirmation)\n" ++ + " Finish[answer] - return the final answer and end the loop\n\n" ++ + "Rules:\n" ++ + "- Emit exactly ONE Action per turn; do not plan ahead with Action 2+ before the Observation.\n" ++ + "- Never write Observation lines yourself.\n" ++ + "- Use Finish[answer] only when the task is done.\n" ++ + "- Prefer small steps: read/search, then edit, then run tests via Cmd.\n" ++ + "- If parsing fails, repeat Thought/Action using the exact Action N: Type[argument] format.\n" ++ + "- If an action fails, change approach in the next Thought instead of repeating the same action.\n"; + +pub fn buildSystemPromptAlloc(allocator: std.mem.Allocator, tools: []const types.Tool) ![]u8 { + if (tools.len == 0) return try allocator.dupe(u8, system_prompt); + + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + try out.appendSlice(allocator, system_prompt); + try out.appendSlice( + allocator, + "\nRequested API tools (names only; invoke with exact name in brackets):\n", + ); + + for (tools, 0..) |tool, idx| { + if (idx > 0) try out.append(allocator, ','); + try out.appendSlice(allocator, " "); + try out.appendSlice(allocator, tool.name); + } + + try out.appendSlice( + allocator, + "\n\nCoding tool rules:\n" ++ + "- file_read[path]: read before editing; path is relative to the project root.\n" ++ + "- file_search[search pattern in directory]: find symbols/files before changing code.\n" ++ + "- file_write[write relative/path\n```lang\n...\n```]: put the full write payload in brackets.\n" ++ + "- Read[path] and Write[...] are aliases for file_read and file_write.\n" ++ + "- Run Cmd[zig build test] (or similar) after edits to verify.\n", + ); + + return try out.toOwnedSlice(allocator); +} + +const DefaultCodingTool = struct { + name: []const u8, + description: []const u8, +}; + +/// Ensures ReAct loop requests expose the standard coding tool set server-side so +/// clients do not need to attach tools manually. Existing client tools are kept; +/// missing defaults are appended. Sets `tool_choice` to `auto` when unset. +pub fn ensureReactCodingToolsAlloc(allocator: std.mem.Allocator, request: *types.Request) !void { + const loop_mode = request.loop_mode orelse return; + if (!std.ascii.eqlIgnoreCase(loop_mode, "react")) return; + + const shell_tool: []const u8 = if (@import("builtin").os.tag == .windows) "cmd" else "bash"; + const shell_description = if (@import("builtin").os.tag == .windows) + "Execute a Windows shell command (requires LLM_ROUTER_TOOL_EXEC_ENABLED=1)." + else + "Execute a bash shell command (requires LLM_ROUTER_TOOL_EXEC_ENABLED=1)."; + + const defaults = [_]DefaultCodingTool{ + .{ .name = "file_read", .description = "Read a relative file path and return its content." }, + .{ .name = "file_write", .description = "Write or replace a relative file path." }, + .{ .name = "file_search", .description = "Search for a substring under a relative directory." }, + .{ .name = shell_tool, .description = shell_description }, + }; + + var merged = std.ArrayList(types.Tool){}; + errdefer { + for (merged.items) |tool| { + tool.deinit(allocator); + } + merged.deinit(allocator); + } + + for (request.tools) |tool| { + try merged.append(allocator, .{ + .name = try allocator.dupe(u8, tool.name), + .description = try allocator.dupe(u8, tool.description), + }); + } + + for (defaults) |default_tool| { + if (hasToolName(merged.items, default_tool.name)) continue; + try merged.append(allocator, .{ + .name = try allocator.dupe(u8, default_tool.name), + .description = try allocator.dupe(u8, default_tool.description), + }); + } + + for (request.tools) |tool| { + tool.deinit(allocator); + } + allocator.free(request.tools); + + request.tools = try merged.toOwnedSlice(allocator); + + if (request.tool_choice == null) { + request.tool_choice = try allocator.dupe(u8, "auto"); + } +} + +fn hasToolName(tools: []const types.Tool, name: []const u8) bool { + for (tools) |tool| { + if (std.ascii.eqlIgnoreCase(tool.name, name)) return true; + } + return false; +} + +/// Parsed action extracted from a model response. +pub const ToolAction = struct { + name: []const u8, + argument: []const u8, +}; + +pub const ReactAction = union(enum) { + search: []const u8, + lookup: []const u8, + cmd: []const u8, + finish: []const u8, + tool: ToolAction, + unknown: []const u8, +}; + +/// Scans model output for the last `Action [N]: Type[arg]` line and returns +/// the parsed action variant. +/// +/// Parsing rules: +/// - Case-insensitive match on the `Action` keyword. +/// - The turn number after `Action` is optional (the harness tracks turns). +/// - Brackets are mandatory for the argument. +/// - Greedy bracket matching: takes everything up to the last `]` after the +/// action, allowing multi-line tool inputs such as fenced code blocks. +/// - Returns `null` when no action line is found. +pub fn parseReactAction(output: []const u8) ?ReactAction { + // Scan for the FIRST Action directive that lives outside a + // ... block. Returning the first action (rather than the + // last) lets the loop drive multi-step plans one observation at a time: + // if the model speculatively emits Action 2 before seeing Observation 1, + // we execute Action 1 and let Observation 1 inform Action 2 on the next + // turn. + var first_action_start: ?usize = null; + var inside_think = false; + var cursor: usize = 0; + while (cursor < output.len) { + if (!inside_think and startsWithIgnoreCase(output[cursor..], "")) { + inside_think = true; + cursor += "".len; + continue; + } + if (inside_think and startsWithIgnoreCase(output[cursor..], "")) { + inside_think = false; + cursor += "".len; + continue; + } + + if (inside_think) { + cursor += 1; + continue; + } + + const line_start = cursor; + var line_end = cursor; + while (line_end < output.len and output[line_end] != '\n' and output[line_end] != '\r') { + line_end += 1; + } + + const raw_line = output[line_start..line_end]; + const line = std.mem.trimLeft(u8, raw_line, " \t"); + if (startsWithActionDirective(line)) { + first_action_start = line_start + (raw_line.len - line.len); + break; + } + + cursor = line_end; + while (cursor < output.len and (output[cursor] == '\n' or output[cursor] == '\r')) { + cursor += 1; + } + } + + const action_start = first_action_start orelse return null; + // For multi-line bracket arguments (fenced code blocks) we need to + // bound the search for the closing `]` to the FIRST action only, not + // any later actions the model may have over-eagerly emitted. We do this + // by clipping the search window at the next `Action ` directive line. + const action_text = clipAtNextActionDirective(output[action_start..]); + + // Extract the part after "Action" (and optional number + colon). + const after_keyword = action_text[6..]; // "Action" is 6 chars + const after_prefix = skipActionPrefix(after_keyword); + const trimmed = std.mem.trimLeft(u8, after_prefix, " \t"); + + if (trimmed.len == 0) return null; + + // Find the bracket pair: Type[arg] + const open_bracket = std.mem.indexOfScalar(u8, trimmed, '[') orelse return null; + // Greedy: find the last ']' after the action. + const close_bracket = std.mem.lastIndexOfScalar(u8, trimmed, ']') orelse return null; + if (close_bracket <= open_bracket) return null; + + const action_type = std.mem.trim(u8, trimmed[0..open_bracket], " \t"); + const argument = trimmed[open_bracket + 1 .. close_bracket]; + + if (action_type.len == 0) return null; + + if (eqlIgnoreCase(action_type, "search")) return .{ .search = argument }; + if (eqlIgnoreCase(action_type, "lookup")) return .{ .lookup = argument }; + if (eqlIgnoreCase(action_type, "cmd")) return .{ .cmd = argument }; + if (eqlIgnoreCase(action_type, "finish")) return .{ .finish = argument }; + if (eqlIgnoreCase(action_type, "read")) { + return .{ .tool = .{ .name = "file_read", .argument = argument } }; + } + if (eqlIgnoreCase(action_type, "write")) { + return .{ .tool = .{ .name = "file_write", .argument = argument } }; + } + if (isKnownApiToolAction(action_type)) { + return .{ .tool = .{ .name = action_type, .argument = argument } }; + } + + return .{ .unknown = trimmed[0 .. close_bracket + 1] }; +} + +/// Executes a parsed ReAct action and returns the observation text. +/// +/// - `Search` and `Lookup` return stubs (to be wired to tools/APIs later). +/// - `Cmd` delegates to the existing command_exec tool infrastructure. +/// - `Finish` should be handled by the caller before reaching this function, +/// but returns the answer content as a fallback. +/// - `Tool` is only executable by API loop callers that have a request tool +/// list; the standalone executor returns a hint. +/// - `Unknown` returns an error hint listing available actions. +pub fn executeReactAction( + allocator: std.mem.Allocator, + action: ReactAction, + app_config: *const config.Config, +) ![]u8 { + return switch (action) { + .search => |query| try std.fmt.allocPrint( + allocator, + "Search is not yet implemented. Query was: {s}", + .{query}, + ), + .lookup => |term| try std.fmt.allocPrint( + allocator, + "Lookup is not yet implemented. Term was: {s}", + .{term}, + ), + .cmd => |command| try executeReactCmd(allocator, command, app_config), + .finish => |answer| try allocator.dupe(u8, answer), + .tool => |tool| try std.fmt.allocPrint( + allocator, + "Tool action {s} is only available in API loop mode when requested by the client.", + .{tool.name}, + ), + .unknown => |text| try std.fmt.allocPrint( + allocator, + "Unknown action: {s}. Available actions: Search, Lookup, Cmd, Finish.", + .{text}, + ), + }; +} + +pub const parse_error_hint = + "Could not parse an Action. Reply with exactly one pair: Thought N: ... then Action N: Type[argument]. " ++ + "Example: Action 1: file_read[src/main.zig]"; + +/// Formats an observation message with turn numbering. +pub fn formatObservation( + allocator: std.mem.Allocator, + turn: usize, + observation_text: []const u8, +) ![]u8 { + return try std.fmt.allocPrint( + allocator, + "Observation {d}: {s}", + .{ turn, observation_text }, + ); +} + +// ── internal helpers ────────────────────────────────────────────────── + +fn executeReactCmd( + allocator: std.mem.Allocator, + command: []const u8, + app_config: *const config.Config, +) ![]u8 { + if (!app_config.tool_exec_enabled) { + return try allocator.dupe( + u8, + "Command execution is disabled. Set LLM_ROUTER_TOOL_EXEC_ENABLED=1 to enable.", + ); + } + + // Build a synthetic single-prompt request for the command tool. + const messages = try allocator.alloc(types.Message, 1); + errdefer allocator.free(messages); + + const role = try allocator.dupe(u8, "user"); + errdefer allocator.free(role); + + const content = try allocator.dupe(u8, command); + errdefer allocator.free(content); + + messages[0] = .{ .role = role, .content = content }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, command), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = try allocator.alloc(types.Tool, 0), + .tool_choice = null, + }; + defer req.deinit(allocator); + + const shell: command_exec_tool.ShellFlavor = if (@import("builtin").os.tag == .windows) .cmd else .bash; + var result = try command_exec_tool.execute(allocator, app_config, req, shell, "react-local"); + defer result.deinit(allocator); + + // Dupe the output to transfer ownership to the caller; result.deinit + // will free the original. + return try allocator.dupe(u8, result.output); +} + +/// Returns true if the line starts with the exact `Action` keyword. +fn startsWithActionDirective(line: []const u8) bool { + if (line.len < 6) return false; + if (!eqlIgnoreCase(line[0..6], "action")) return false; + + if (line.len == 6) return true; + return switch (line[6]) { + ' ', '\t', ':', '0'...'9' => true, + else => false, + }; +} + +/// Skips past optional whitespace, turn number, and colon after `Action`. +fn skipActionPrefix(input: []const u8) []const u8 { + var i: usize = 0; + + // Skip leading whitespace. + while (i < input.len and (input[i] == ' ' or input[i] == '\t')) : (i += 1) {} + + // Skip optional digits (turn number). + while (i < input.len and std.ascii.isDigit(input[i])) : (i += 1) {} + + while (i < input.len and (input[i] == ' ' or input[i] == '\t')) : (i += 1) {} + + // Skip optional colon. + if (i < input.len and input[i] == ':') { + i += 1; + } + + while (i < input.len and (input[i] == ' ' or input[i] == '\t')) : (i += 1) {} + + return input[i..]; +} + +fn eqlIgnoreCase(a: []const u8, b: []const u8) bool { + return std.ascii.eqlIgnoreCase(a, b); +} + +fn startsWithIgnoreCase(value: []const u8, prefix: []const u8) bool { + if (prefix.len > value.len) return false; + return std.ascii.eqlIgnoreCase(value[0..prefix.len], prefix); +} + +/// Returns a prefix of `text` that ends just before the next line starting +/// with `Action ` (case-insensitive). The first line is always kept since it +/// is itself the Action directive being parsed. +fn clipAtNextActionDirective(text: []const u8) []const u8 { + var cursor: usize = 0; + // Skip the directive's own first line. + while (cursor < text.len and text[cursor] != '\n' and text[cursor] != '\r') { + cursor += 1; + } + + while (cursor < text.len) { + while (cursor < text.len and (text[cursor] == '\n' or text[cursor] == '\r')) { + cursor += 1; + } + const line_start = cursor; + while (cursor < text.len and text[cursor] != '\n' and text[cursor] != '\r') { + cursor += 1; + } + + const raw_line = text[line_start..cursor]; + const line = std.mem.trimLeft(u8, raw_line, " \t"); + if (startsWithActionDirective(line)) { + return text[0..line_start]; + } + } + return text; +} + +fn isKnownApiToolAction(action_type: []const u8) bool { + return eqlIgnoreCase(action_type, "echo") or + eqlIgnoreCase(action_type, "utc") or + eqlIgnoreCase(action_type, "bash") or + eqlIgnoreCase(action_type, "file_read") or + eqlIgnoreCase(action_type, "file_write") or + eqlIgnoreCase(action_type, "file_search"); +} + +// ── tests ───────────────────────────────────────────────────────────── + +test "parseReactAction extracts Search action" { + const action = parseReactAction("Thought 1: I need to find info\nAction 1: Search[Colorado orogeny]"); + try std.testing.expect(action != null); + switch (action.?) { + .search => |query| try std.testing.expectEqualStrings("Colorado orogeny", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts Finish action" { + const action = parseReactAction("Action 3: Finish[the answer is 42]"); + try std.testing.expect(action != null); + switch (action.?) { + .finish => |answer| try std.testing.expectEqualStrings("the answer is 42", answer), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts Cmd action" { + const action = parseReactAction("Action 2: Cmd[zig build test]"); + try std.testing.expect(action != null); + switch (action.?) { + .cmd => |command| try std.testing.expectEqualStrings("zig build test", command), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts Lookup action" { + const action = parseReactAction("Action 1: Lookup[eastern sector]"); + try std.testing.expect(action != null); + switch (action.?) { + .lookup => |term| try std.testing.expectEqualStrings("eastern sector", term), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction returns null for no action" { + const action = parseReactAction("Some text with no action line at all"); + try std.testing.expect(action == null); +} + +test "parseReactAction returns unknown for unrecognized action type" { + const action = parseReactAction("Action 1: UnknownAction[foo]"); + try std.testing.expect(action != null); + switch (action.?) { + .unknown => {}, + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction maps Read alias to file_read" { + const action = parseReactAction("Action 1: Read[src/main.zig]"); + try std.testing.expect(action != null); + switch (action.?) { + .tool => |tool| { + try std.testing.expectEqualStrings("file_read", tool.name); + try std.testing.expectEqualStrings("src/main.zig", tool.argument); + }, + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts requested API tool action" { + const action = parseReactAction("Thought 1: write the file\nAction 1: file_write[write app.cpp\n```cpp\nint main() { return 0; }\n```]"); + try std.testing.expect(action != null); + switch (action.?) { + .tool => |tool| { + try std.testing.expectEqualStrings("file_write", tool.name); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "app.cpp") != null); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "int main()") != null); + }, + else => return error.UnexpectedActionType, + } +} + +test "ensureReactCodingToolsAlloc merges defaults and sets auto tool_choice" { + const allocator = std.testing.allocator; + + var request = types.Request{ + .prompt = try allocator.dupe(u8, "fix the bug"), + .messages = try allocator.alloc(types.Message, 0), + .tools = try allocator.alloc(types.Tool, 0), + .loop_mode = try allocator.dupe(u8, "react"), + }; + defer request.deinit(allocator); + + try ensureReactCodingToolsAlloc(allocator, &request); + try std.testing.expect(request.tools.len >= 4); + try std.testing.expect(hasToolName(request.tools, "file_read")); + try std.testing.expect(hasToolName(request.tools, "file_write")); + try std.testing.expect(hasToolName(request.tools, "file_search")); + try std.testing.expect(request.tool_choice != null); + try std.testing.expectEqualStrings("auto", request.tool_choice.?); +} + +test "ensureReactCodingToolsAlloc preserves client tools" { + const allocator = std.testing.allocator; + + var request = types.Request{ + .prompt = try allocator.dupe(u8, "echo"), + .messages = try allocator.alloc(types.Message, 0), + .tools = try allocator.alloc(types.Tool, 1), + .loop_mode = try allocator.dupe(u8, "react"), + }; + request.tools[0] = .{ + .name = try allocator.dupe(u8, "echo"), + .description = try allocator.dupe(u8, "echo tool"), + }; + defer request.deinit(allocator); + + try ensureReactCodingToolsAlloc(allocator, &request); + try std.testing.expect(hasToolName(request.tools, "echo")); + try std.testing.expect(hasToolName(request.tools, "file_read")); +} + +test "ensureReactCodingToolsAlloc no-op outside react mode" { + const allocator = std.testing.allocator; + + var request = types.Request{ + .prompt = try allocator.dupe(u8, "hello"), + .messages = try allocator.alloc(types.Message, 0), + .tools = try allocator.alloc(types.Tool, 0), + .loop_mode = try allocator.dupe(u8, "agent"), + }; + defer request.deinit(allocator); + + try ensureReactCodingToolsAlloc(allocator, &request); + try std.testing.expectEqual(@as(usize, 0), request.tools.len); + try std.testing.expect(request.tool_choice == null); +} + +test "buildSystemPromptAlloc includes requested tools" { + const allocator = std.testing.allocator; + const tools = [_]types.Tool{ + .{ .name = "file_write", .description = "write files" }, + }; + + const prompt = try buildSystemPromptAlloc(allocator, tools[0..]); + defer allocator.free(prompt); + + try std.testing.expect(std.mem.indexOf(u8, prompt, "file_write") != null); + try std.testing.expect(std.mem.indexOf(u8, prompt, "names only") != null); + try std.testing.expect(std.mem.indexOf(u8, prompt, "file_write[write relative/path") != null); +} + +test "parseReactAction handles nested brackets greedily" { + const action = parseReactAction("Action 5: Finish[1,800 to 7,000 ft]"); + try std.testing.expect(action != null); + switch (action.?) { + .finish => |answer| try std.testing.expectEqualStrings("1,800 to 7,000 ft", answer), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction is case-insensitive" { + const action = parseReactAction("action 1: search[test query]"); + try std.testing.expect(action != null); + switch (action.?) { + .search => |query| try std.testing.expectEqualStrings("test query", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction rejects actionable text" { + const action = parseReactAction("Actionable note: Search[bad]"); + try std.testing.expect(action == null); +} + +test "parseReactAction tolerates missing turn number" { + const action = parseReactAction("Action: Search[no number]"); + try std.testing.expect(action != null); + switch (action.?) { + .search => |query| try std.testing.expectEqualStrings("no number", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction returns first action when model emits multiple" { + const output = + "Thought 1: first thought\n" ++ + "Action 1: Search[first]\n" ++ + "Thought 2: second thought\n" ++ + "Action 2: Finish[final answer]"; + const action = parseReactAction(output); + try std.testing.expect(action != null); + switch (action.?) { + .search => |query| try std.testing.expectEqualStrings("first", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction ignores Action mentions inside block" { + const output = + "I should probably emit Action 1: Search[bad]\nsomething\n" ++ + "Action 1: Finish[real answer]"; + const action = parseReactAction(output); + try std.testing.expect(action != null); + switch (action.?) { + .finish => |answer| try std.testing.expectEqualStrings("real answer", answer), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction parses multi-line first action when a second follows" { + const output = + "Thought 1: write the file\n" ++ + "Action 1: file_write[calculator.cpp\n```cpp\nint main(){}\n```]\n" ++ + "Action 2: cmd[g++ -o calculator calculator.cpp]"; + const action = parseReactAction(output); + try std.testing.expect(action != null); + switch (action.?) { + .tool => |tool| { + try std.testing.expectEqualStrings("file_write", tool.name); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "calculator.cpp") != null); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "int main(){}") != null); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "g++") == null); + }, + else => return error.UnexpectedActionType, + } +} + +test "formatObservation produces correct format" { + const allocator = std.testing.allocator; + const result = try formatObservation(allocator, 3, "The answer is here"); + defer allocator.free(result); + try std.testing.expectEqualStrings("Observation 3: The answer is here", result); +} diff --git a/src/root.zig b/src/root.zig index 5d7c4ad..b87c2bf 100644 --- a/src/root.zig +++ b/src/root.zig @@ -9,9 +9,11 @@ pub const backend = @import("backend/api.zig"); pub const auth = @import("backend/auth.zig"); pub const session = @import("backend/session.zig"); pub const tools = @import("backend/tools.zig"); +pub const mcp = @import("backend/mcp.zig"); pub const core = @import("core/server.zig"); pub const config = @import("config.zig"); pub const types = @import("types.zig"); +pub const react = @import("react.zig"); // Re-export common error types. pub const ApiError = backend.errors.ApiError; @@ -24,6 +26,9 @@ test "backend: auth and tool scaffolding" { _ = @import("backend/auth.zig"); _ = @import("backend/tools.zig"); _ = @import("backend/session.zig"); + _ = @import("backend/mcp.zig"); + _ = @import("backend/tool_output.zig"); + _ = @import("backend/workspace.zig"); } test "core: HTTP request parsing" { @@ -43,3 +48,7 @@ test "providers: external provider adapters" { test "types: normalization" { _ = @import("types.zig"); } + +test "react: action parsing and observation formatting" { + _ = @import("react.zig"); +} diff --git a/src/tools/command_exec.zig b/src/tools/command_exec.zig index 25b508c..9ec4081 100644 --- a/src/tools/command_exec.zig +++ b/src/tools/command_exec.zig @@ -1,6 +1,7 @@ const std = @import("std"); const config = @import("../config.zig"); const types = @import("../types.zig"); +const tool_output = @import("../backend/tool_output.zig"); pub const ShellFlavor = enum { cmd, @@ -17,11 +18,59 @@ const CommandValidationError = error{ DangerousPattern, }; +const PipeCapture = struct { + bytes: []u8 = &.{}, + err_name: ?[]const u8 = null, + done: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), +}; + +const tool_confirm_prefix = "LLM_ROUTER_TOOL_CONFIRM "; + +fn parseHexU64(input: []const u8) ?u64 { + var value: u64 = 0; + if (input.len == 0) return null; + for (input) |c| { + const digit: u8 = switch (c) { + '0'...'9' => c - '0', + 'a'...'f' => c - 'a' + 10, + 'A'...'F' => c - 'A' + 10, + else => return null, + }; + value = (value << 4) | @as(u64, @intCast(digit)); + } + return value; +} + +fn findToolConfirmToken(prompt: []const u8) ?u64 { + const idx = std.mem.indexOf(u8, prompt, tool_confirm_prefix) orelse return null; + const after = prompt[idx + tool_confirm_prefix.len ..]; + const token_end = std.mem.indexOfAny(u8, after, " \t\r\n") orelse after.len; + const token_str = after[0..token_end]; + return parseHexU64(token_str); +} + +fn stripToolConfirmMarkerSuffix(prompt: []const u8) []const u8 { + const idx = std.mem.indexOf(u8, prompt, tool_confirm_prefix) orelse return prompt; + return prompt[0..idx]; +} + +fn computeToolConfirmHash(flavor: ShellFlavor, command: []const u8) u64 { + var h = std.hash.Wyhash.init(0); + const flavor_bytes = switch (flavor) { + .cmd => "cmd", + .bash => "bash", + }; + h.update(flavor_bytes); + h.update(command); + return h.final(); +} + pub fn execute( allocator: std.mem.Allocator, app_config: *const config.Config, request: types.Request, flavor: ShellFlavor, + request_id: []const u8, ) !types.Response { const tool_name = switch (flavor) { .cmd => "cmd", @@ -40,8 +89,8 @@ pub fn execute( ); } - const prompt = std.mem.trim(u8, request.prompt, " \t\r\n"); - if (prompt.len == 0) { + const raw_prompt = std.mem.trim(u8, request.prompt, " \t\r\n"); + if (raw_prompt.len == 0) { return try makeToolResponse( allocator, tool_name, @@ -53,9 +102,19 @@ pub fn execute( ); } + const maybe_confirm_token = if (app_config.tool_exec_confirmation_required) + findToolConfirmToken(raw_prompt) + else + null; + + const prompt = if (app_config.tool_exec_confirmation_required) + stripToolConfirmMarkerSuffix(raw_prompt) + else + raw_prompt; + const command = blk: { if (looksLikeCommandPrompt(prompt)) { - break :blk validateAndNormalizeCommandAlloc(allocator, prompt, flavor) catch |err| switch (err) { + break :blk validateAndNormalizeCommandAlloc(allocator, prompt, flavor, app_config.tool_exec_trusted_local) catch |err| switch (err) { error.OutOfMemory => return err, else => { return try makeToolResponse( @@ -71,7 +130,7 @@ pub fn execute( }; } - const extracted = try extractCommandFromPromptAlloc(allocator, flavor, prompt); + const extracted = try extractCommandFromPromptAlloc(allocator, flavor, prompt, app_config.tool_exec_trusted_local); if (extracted) |value| { break :blk value; } @@ -88,6 +147,41 @@ pub fn execute( }; defer allocator.free(command); + if (app_config.tool_exec_confirmation_required) { + const expected = computeToolConfirmHash(flavor, command); + if (maybe_confirm_token == null) { + const token_hex = try std.fmt.allocPrint(allocator, "{x}", .{expected}); + defer allocator.free(token_hex); + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_CONFIRMATION_REQUIRED\n" ++ + "tool={s}\n" ++ + "command={s}\n" ++ + "confirm_token={s}\n" ++ + "message=tool execution requires confirmation. Re-send the same command with a line containing: LLM_ROUTER_TOOL_CONFIRM {s}\n", + .{ tool_name, command, token_hex, token_hex }, + ); + return try makeToolResponse(allocator, tool_name, out); + } + + if (maybe_confirm_token.? != expected) { + const token_hex = try std.fmt.allocPrint(allocator, "{x}", .{expected}); + defer allocator.free(token_hex); + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_CONFIRMATION_REQUIRED\n" ++ + "tool={s}\n" ++ + "command={s}\n" ++ + "confirm_token={s}\n" ++ + "message=tool execution requires confirmation. Re-send the same command with a line containing: LLM_ROUTER_TOOL_CONFIRM {s}\n", + .{ tool_name, command, token_hex, token_hex }, + ); + return try makeToolResponse(allocator, tool_name, out); + } + } + const start_ms = std.time.milliTimestamp(); const argv = switch (flavor) { @@ -95,17 +189,12 @@ pub fn execute( .bash => [_][]const u8{ "bash", "-lc", command }, }; - // TODO(security): `Child.run` blocks until the child exits and provides no - // built-in timeout. A hung command (e.g. `sleep 9999`) will hold the - // server's single accept-loop thread open for its entire duration. - // Replace this call with `Child.spawn` + async I/O polling + `Child.kill` - // once per-connection threading is in place. Until then, `tool_exec_enabled` - // must remain false (the default) in any exposed deployment. - const run_result = std.process.Child.run(.{ - .allocator = allocator, - .argv = &argv, - .max_output_bytes = app_config.tool_exec_max_output_bytes, - }) catch |err| { + var child = std.process.Child.init(&argv, allocator); + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Pipe; + child.stderr_behavior = .Pipe; + + child.spawn() catch |err| { return try makeToolResponse( allocator, tool_name, @@ -116,8 +205,55 @@ pub fn execute( ), ); }; - defer allocator.free(run_result.stdout); - defer allocator.free(run_result.stderr); + + var stdout_capture = PipeCapture{}; + var stderr_capture = PipeCapture{}; + + var stdout_thread: ?std.Thread = null; + var stderr_thread: ?std.Thread = null; + + if (child.stdout) |stdout_file| { + stdout_thread = try std.Thread.spawn(.{}, capturePipeMain, .{ + &stdout_capture, + stdout_file, + app_config.tool_exec_max_output_bytes, + }); + } else { + stdout_capture.done.store(true, .release); + } + if (child.stderr) |stderr_file| { + stderr_thread = try std.Thread.spawn(.{}, capturePipeMain, .{ + &stderr_capture, + stderr_file, + app_config.tool_exec_max_output_bytes, + }); + } else { + stderr_capture.done.store(true, .release); + } + + const timeout_deadline = start_ms + @as(i64, app_config.tool_exec_timeout_ms); + var timed_out = false; + while (true) { + const stdout_done = stdout_capture.done.load(.acquire); + const stderr_done = stderr_capture.done.load(.acquire); + if (stdout_done and stderr_done) break; + if (std.time.milliTimestamp() >= timeout_deadline) { + timed_out = true; + break; + } + std.Thread.sleep(2 * std.time.ns_per_ms); + } + + if (timed_out) { + _ = child.kill() catch {}; + } + const terminated_as = child.wait() catch std.process.Child.Term{ .Unknown = 1 }; + + if (stdout_thread) |thread| thread.join(); + if (stderr_thread) |thread| thread.join(); + + defer if (stdout_capture.bytes.len > 0) std.heap.page_allocator.free(stdout_capture.bytes); + defer if (stderr_capture.bytes.len > 0) std.heap.page_allocator.free(stderr_capture.bytes); const end_ms = std.time.milliTimestamp(); const duration_ms = if (end_ms >= start_ms) @@ -125,13 +261,26 @@ pub fn execute( else @as(u64, 0); - const term_text = try childTermToTextAlloc(allocator, run_result.term); + const term_text = try childTermToTextAlloc(allocator, terminated_as); defer allocator.free(term_text); - const status_tag = if (isSuccessfulTermination(run_result.term)) "DEBUG_TOOL_OK" else "DEBUG_TOOL_ERROR"; + const status_tag = if (!timed_out and isSuccessfulTermination(terminated_as)) "DEBUG_TOOL_OK" else "DEBUG_TOOL_ERROR"; + const timeout_exceeded_text = if (timed_out) "true" else "false"; + const stdout_text = if (stdout_capture.err_name) |err_name| + try std.fmt.allocPrint(allocator, "read_error={s}", .{err_name}) + else + try allocator.dupe(u8, stdout_capture.bytes); + defer allocator.free(stdout_text); + + const stderr_text = if (stderr_capture.err_name) |err_name| + try std.fmt.allocPrint(allocator, "read_error={s}", .{err_name}) + else + try allocator.dupe(u8, stderr_capture.bytes); + defer allocator.free(stderr_text); + const output = try std.fmt.allocPrint( allocator, - "{s}\ntool={s}\nterm={s}\nduration_ms={d}\nmax_output_bytes={d}\ntimeout_policy_ms={d}\ntimeout_enforced=false\ncommand={s}\n--- stdout ---\n{s}\n--- stderr ---\n{s}", + "{s}\ntool={s}\nterm={s}\nduration_ms={d}\nmax_output_bytes={d}\ntimeout_policy_ms={d}\ntimeout_enforced=true\ntimeout_exceeded={s}\ncode={s}\ncommand={s}\n--- stdout ---\n{s}\n--- stderr ---\n{s}", .{ status_tag, tool_name, @@ -139,46 +288,68 @@ pub fn execute( duration_ms, app_config.tool_exec_max_output_bytes, app_config.tool_exec_timeout_ms, + timeout_exceeded_text, + if (timed_out) "command_timeout" else "command_completed", command, - run_result.stdout, - run_result.stderr, + stdout_text, + stderr_text, }, ); + defer allocator.free(output); + + const final_output = try tool_output.maybeOffloadToolOutputAlloc( + allocator, + app_config, + request_id, + tool_name, + output, + ); + + return try makeToolResponse(allocator, tool_name, final_output); +} - return try makeToolResponse(allocator, tool_name, output); +fn capturePipeMain(capture: *PipeCapture, file: std.fs.File, max_bytes: usize) void { + defer capture.done.store(true, .release); + capture.bytes = file.readToEndAlloc(std.heap.page_allocator, max_bytes) catch |err| { + capture.err_name = @errorName(err); + capture.bytes = &.{}; + return; + }; } pub fn extractCommandFromPromptAlloc( allocator: std.mem.Allocator, flavor: ShellFlavor, prompt: []const u8, + trusted_local: bool, ) !?[]u8 { - const trimmed = std.mem.trim(u8, prompt, " \t\r\n\"'"); + const without_confirm = stripToolConfirmMarkerSuffix(prompt); + const trimmed = std.mem.trim(u8, without_confirm, " \t\r\n\"'"); if (trimmed.len == 0) return null; if (extractQuotedCommandSlice(trimmed)) |quoted| { const normalized = std.mem.trim(u8, quoted, " \t\r\n"); if (normalized.len > 0) { - return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor); + return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor, trusted_local); } } if (stripCommandInstructionPrefix(trimmed)) |candidate| { const normalized = normalizeExtractedCandidate(candidate); if (normalized.len > 0) { - return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor); + return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor, trusted_local); } } if (extractInlineRunSegment(trimmed)) |candidate| { const normalized = normalizeExtractedCandidate(candidate); if (normalized.len > 0 and looksLikeCommandPrompt(normalized)) { - return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor); + return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor, trusted_local); } } if (looksLikeCommandPrompt(trimmed)) { - return try validateAndNormalizeCommandAlloc(allocator, trimmed, flavor); + return try validateAndNormalizeCommandAlloc(allocator, trimmed, flavor, trusted_local); } return null; @@ -188,11 +359,12 @@ fn validateAndNormalizeCommandAlloc( allocator: std.mem.Allocator, command: []const u8, flavor: ShellFlavor, + trusted_local: bool, ) ![]u8 { const trimmed = std.mem.trim(u8, command, " \t\r\n\"'"); if (trimmed.len == 0) return error.EmptyCommand; if (trimmed.len > max_command_len) return error.CommandTooLong; - if (containsUnsafeCommandByte(trimmed)) return error.UnsafeCharacter; + if (containsUnsafeCommandByte(trimmed, trusted_local)) return error.UnsafeCharacter; const command_name = firstToken(trimmed) orelse return error.EmptyCommand; if (isDangerousCommandName(flavor, command_name)) return error.DangerousCommand; @@ -225,10 +397,10 @@ fn looksLikeCommandPrompt(text: []const u8) bool { return true; } -fn containsUnsafeCommandByte(text: []const u8) bool { +fn containsUnsafeCommandByte(text: []const u8, trusted_local: bool) bool { for (text) |byte| { if (byte < 0x20 or byte == 0x7f) return true; - if (isUnsafeShellByte(byte)) return true; + if (!trusted_local and isUnsafeShellByte(byte)) return true; } return false; } @@ -382,6 +554,8 @@ fn extractQuotedCommandSlice(text: []const u8) ?[]const u8 { fn isLikelyNaturalLanguageLeadToken(token: []const u8) bool { const words = [_][]const u8{ "what", + "i", + "we", "why", "how", "when", @@ -395,6 +569,10 @@ fn isLikelyNaturalLanguageLeadToken(token: []const u8) bool { "tell", "show", "explain", + "write", + "create", + "build", + "make", "get", }; @@ -469,6 +647,7 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .listen_port = 8081, .debug_logging = false, .default_provider = try allocator.dupe(u8, "ollama_qwen"), + .default_model = null, .request_timeout_ms = 30_000, .provider_timeout_ms = 60_000, .instance_id = try allocator.dupe(u8, "local-instance"), @@ -476,7 +655,7 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .ollama_base_url = try allocator.dupe(u8, "http://127.0.0.1:11434"), .ollama_model = try allocator.dupe(u8, "qwen:7b"), .ollama_think = false, - .ollama_num_predict = 128, + .ollama_num_predict = 1024, .ollama_temperature = 0.7, .ollama_repeat_penalty = 1.05, .openai_base_url = try allocator.dupe(u8, "https://api.openai.com/v1"), @@ -502,9 +681,19 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .session_store_path = try allocator.dupe(u8, "logs/sessions"), .session_retention_messages = 24, .tool_exec_enabled = enable_exec, + .tool_exec_confirmation_required = false, + .tool_exec_trusted_local = false, .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, + .tool_output_offload_bytes = 8192, + .max_tool_calls_per_request = 8, + .default_max_context_tokens = null, + .workspace_mode_enabled = false, + .workspace_root = try allocator.dupe(u8, "."), .loop_stream_progress_enabled = true, + .max_concurrent_connections = 64, + .max_request_bytes = 1024 * 1024, + .max_header_bytes = 16 * 1024, }; } @@ -534,7 +723,7 @@ test "execute cmd tool returns disabled marker when not enabled" { }; defer req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); @@ -548,6 +737,7 @@ test "extractCommandFromPromptAlloc extracts from run prefix" { allocator, .cmd, "Please run zig build test", + false, ); defer if (maybe_cmd) |cmd| allocator.free(cmd); @@ -555,6 +745,19 @@ test "extractCommandFromPromptAlloc extracts from run prefix" { try std.testing.expectEqualStrings("zig build test", maybe_cmd.?); } +test "extractCommandFromPromptAlloc ignores natural language requests" { + const allocator = std.testing.allocator; + const maybe_cmd = try extractCommandFromPromptAlloc( + allocator, + .cmd, + "Please write a small C++ program", + false, + ); + defer if (maybe_cmd) |cmd| allocator.free(cmd); + + try std.testing.expect(maybe_cmd == null); +} + test "execute cmd tool rejects unsafe chaining" { if (@import("builtin").os.tag != .windows) return; @@ -581,7 +784,7 @@ test "execute cmd tool rejects unsafe chaining" { }; defer req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); @@ -614,9 +817,45 @@ test "execute cmd tool blocks dangerous denylist command" { }; defer req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); try std.testing.expect(std.mem.indexOf(u8, result.output, "denylist") != null); } + +test "execute cmd tool enforces timeout and kills process" { + if (@import("builtin").os.tag != .windows) return; + + const allocator = std.testing.allocator; + var cfg = try buildTestConfig(allocator, true); + defer cfg.deinit(allocator); + cfg.tool_exec_timeout_ms = 200; + + const messages = try allocator.alloc(types.Message, 1); + messages[0] = .{ + .role = try allocator.dupe(u8, "user"), + .content = try allocator.dupe(u8, "powershell -NoProfile -Command Start-Sleep -Seconds 3"), + }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, "powershell -NoProfile -Command Start-Sleep -Seconds 3"), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = try allocator.alloc(types.Tool, 0), + .tool_choice = null, + }; + defer req.deinit(allocator); + + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); + defer result.deinit(allocator); + + try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); + try std.testing.expect(std.mem.indexOf(u8, result.output, "timeout_enforced=true") != null); + try std.testing.expect(std.mem.indexOf(u8, result.output, "timeout_exceeded=true") != null); + try std.testing.expect(std.mem.indexOf(u8, result.output, "code=command_timeout") != null); +} diff --git a/src/tools/file_ops.zig b/src/tools/file_ops.zig new file mode 100644 index 0000000..27dba01 --- /dev/null +++ b/src/tools/file_ops.zig @@ -0,0 +1,533 @@ +const std = @import("std"); +const config = @import("../config.zig"); +const types = @import("../types.zig"); +const workspace = @import("../backend/workspace.zig"); + +pub const Op = enum { read, write, search }; + +pub fn execute( + allocator: std.mem.Allocator, + app_config: *const config.Config, + request: types.Request, + op: Op, + active_workspace: ?*workspace.WorkspaceState, +) !types.Response { + const tool_name: []const u8 = switch (op) { + .read => "file_read", + .write => "file_write", + .search => "file_search", + }; + + const prompt = std.mem.trim(u8, request.prompt, " \t\r\n"); + if (prompt.len == 0) { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nmessage=empty prompt", .{tool_name}), + ); + } + + return switch (op) { + .read => executeRead(allocator, app_config, tool_name, prompt, active_workspace), + .write => executeWrite(allocator, app_config, tool_name, request, prompt, active_workspace), + .search => executeSearch(allocator, app_config, tool_name, prompt, active_workspace), + }; +} + +fn executeRead( + allocator: std.mem.Allocator, + app_config: *const config.Config, + tool_name: []const u8, + prompt: []const u8, + active_workspace: ?*workspace.WorkspaceState, +) !types.Response { + const path_raw = parseSingleArgAfterPrefix(prompt, "read ") orelse + parseSingleArgAfterPrefix(prompt, "file_read ") orelse + parseSingleArgAfterPrefix(prompt, "exists ") orelse + parseSingleArgAfterPrefix(prompt, "file_exists ") orelse + parseSingleArgAfterPrefix(prompt, "open ") orelse + prompt; + + const rel_path = validateRelativePath(path_raw) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", + .{ tool_name, @errorName(err) }, + ), + ); + }; + + const resolved_path = resolveToolPathAlloc(allocator, app_config, rel_path) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", + .{ tool_name, @errorName(err) }, + ), + ); + }; + defer allocator.free(resolved_path); + + const start_ms = std.time.milliTimestamp(); + var file = std.fs.cwd().openFile(resolved_path, .{}) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), + ); + }; + defer file.close(); + + const bytes = file.readToEndAlloc(allocator, app_config.tool_exec_max_output_bytes) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nread_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), + ); + }; + defer allocator.free(bytes); + + const end_ms = std.time.milliTimestamp(); + const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + + if (active_workspace) |ws| { + workspace.recordFileHint(ws, allocator, "read", rel_path) catch {}; + } + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_OK\ntool={s}\npath={s}\nduration_ms={d}\nmax_bytes={d}\n--- content ---\n{s}", + .{ tool_name, resolved_path, duration_ms, app_config.tool_exec_max_output_bytes, bytes }, + ); + return try makeToolResponse(allocator, tool_name, out); +} + +fn executeWrite( + allocator: std.mem.Allocator, + app_config: *const config.Config, + tool_name: []const u8, + request: types.Request, + prompt: []const u8, + active_workspace: ?*workspace.WorkspaceState, +) !types.Response { + const parsed = parseWritePrompt(allocator, prompt) catch blk: { + const inferred_path = inferWriteFilenameFromPromptAlloc(allocator, prompt) orelse break :blk null; + errdefer allocator.free(inferred_path); + + const inferred_content = extractLastFencedCodeBlockAlloc(allocator, request.messages) orelse { + allocator.free(inferred_path); + break :blk null; + }; + errdefer allocator.free(inferred_content); + + break :blk ParsedWrite{ + .path = inferred_path, + .content = inferred_content, + }; + } orelse { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nmessage=could not parse write request; use `write ` + fenced code block, or say `save the script as count.py` after providing a fenced code block", + .{tool_name}, + ), + ); + }; + defer allocator.free(parsed.path); + defer allocator.free(parsed.content); + + const rel_path = validateRelativePath(parsed.path) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", .{ tool_name, @errorName(err) }), + ); + }; + + const resolved_path = resolveToolPathAlloc(allocator, app_config, rel_path) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", .{ tool_name, @errorName(err) }), + ); + }; + defer allocator.free(resolved_path); + + if (parsed.content.len > app_config.tool_exec_max_output_bytes) { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nmessage=content too large\ncontent_bytes={d}\nmax_bytes={d}", + .{ tool_name, parsed.content.len, app_config.tool_exec_max_output_bytes }, + ), + ); + } + + const start_ms = std.time.milliTimestamp(); + + if (std.fs.path.dirname(resolved_path)) |dir_name| { + std.fs.cwd().makePath(dir_name) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nmkdir_error={s}\ndir={s}", .{ tool_name, @errorName(err), dir_name }), + ); + }; + } + + var file = std.fs.cwd().createFile(resolved_path, .{ .truncate = true }) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\ncreate_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), + ); + }; + defer file.close(); + + file.writeAll(parsed.content) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nwrite_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), + ); + }; + + const end_ms = std.time.milliTimestamp(); + const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + + if (active_workspace) |ws| { + workspace.recordFileHint(ws, allocator, "write", rel_path) catch {}; + } + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_OK\ntool={s}\npath={s}\nbytes_written={d}\nduration_ms={d}\n--- content ---\n{s}", + .{ tool_name, resolved_path, parsed.content.len, duration_ms, parsed.content }, + ); + return try makeToolResponse(allocator, tool_name, out); +} + +fn executeSearch( + allocator: std.mem.Allocator, + app_config: *const config.Config, + tool_name: []const u8, + prompt: []const u8, + active_workspace: ?*workspace.WorkspaceState, +) !types.Response { + const parsed = parseSearchPrompt(allocator, prompt) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nparse_error={s}\nmessage=expected: search [in ]", + .{ tool_name, @errorName(err) }, + ), + ); + }; + defer allocator.free(parsed.pattern); + defer allocator.free(parsed.dir); + + const rel_dir = validateRelativePath(parsed.dir) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid dir", .{ tool_name, @errorName(err) }), + ); + }; + + const resolved_dir = resolveToolPathAlloc(allocator, app_config, rel_dir) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid dir", .{ tool_name, @errorName(err) }), + ); + }; + defer allocator.free(resolved_dir); + + const start_ms = std.time.milliTimestamp(); + + var out = std.ArrayList(u8){}; + defer out.deinit(allocator); + + try out.writer(allocator).print( + "DEBUG_TOOL_OK\ntool={s}\ndir={s}\npattern={s}\nmax_bytes={d}\n--- matches ---\n", + .{ tool_name, resolved_dir, parsed.pattern, app_config.tool_exec_max_output_bytes }, + ); + + var bytes_budget: usize = app_config.tool_exec_max_output_bytes; + if (out.items.len < bytes_budget) bytes_budget -= out.items.len else bytes_budget = 0; + + var matches: usize = 0; + var files_scanned: usize = 0; + const max_files: usize = 500; + + var dir = std.fs.cwd().openDir(resolved_dir, .{ .iterate = true }) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_dir_error={s}\ndir={s}", .{ tool_name, @errorName(err), resolved_dir }), + ); + }; + defer dir.close(); + + var walker = dir.walk(allocator) catch |err| return err; + defer walker.deinit(); + + while (try walker.next()) |entry| { + if (files_scanned >= max_files) break; + if (entry.kind != .file) continue; + files_scanned += 1; + + const path_in_dir = entry.path; + if (std.mem.endsWith(u8, path_in_dir, ".exe") or std.mem.endsWith(u8, path_in_dir, ".dll")) continue; + + var f = dir.openFile(path_in_dir, .{}) catch continue; + defer f.close(); + + const bytes = f.readToEndAlloc(allocator, 64 * 1024) catch continue; + defer allocator.free(bytes); + + if (std.mem.indexOf(u8, bytes, parsed.pattern) == null) continue; + matches += 1; + + const line = try std.fmt.allocPrint(allocator, "{s}\n", .{path_in_dir}); + defer allocator.free(line); + + if (line.len > bytes_budget) break; + try out.appendSlice(allocator, line); + bytes_budget -= line.len; + if (bytes_budget == 0) break; + } + + const end_ms = std.time.milliTimestamp(); + const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + try out.writer(allocator).print("\nfiles_scanned={d}\nmatches={d}\nduration_ms={d}\n", .{ files_scanned, matches, duration_ms }); + + if (active_workspace) |ws| { + workspace.recordFileHint(ws, allocator, "search", rel_dir) catch {}; + } + + return try makeToolResponse(allocator, tool_name, try out.toOwnedSlice(allocator)); +} + +fn inferWriteFilenameFromPromptAlloc(allocator: std.mem.Allocator, prompt: []const u8) ?[]u8 { + const trimmed = std.mem.trim(u8, prompt, " \t\r\n"); + if (trimmed.len == 0) return null; + + var it = std.mem.tokenizeAny(u8, trimmed, " \t\r\n\"'`()[]{}<>"); + var last_match: ?[]const u8 = null; + while (it.next()) |tok| { + if (tok.len < 4) continue; + if (!isSupportedWriteFilename(tok)) continue; + last_match = tok; + } + + if (last_match) |m| { + return allocator.dupe(u8, m) catch null; + } + return null; +} + +fn isSupportedWriteFilename(path: []const u8) bool { + return std.ascii.endsWithIgnoreCase(path, ".py") or + std.ascii.endsWithIgnoreCase(path, ".cpp") or + std.ascii.endsWithIgnoreCase(path, ".cc") or + std.ascii.endsWithIgnoreCase(path, ".cxx") or + std.ascii.endsWithIgnoreCase(path, ".c") or + std.ascii.endsWithIgnoreCase(path, ".h") or + std.ascii.endsWithIgnoreCase(path, ".hpp") or + std.ascii.endsWithIgnoreCase(path, ".js") or + std.ascii.endsWithIgnoreCase(path, ".ts") or + std.ascii.endsWithIgnoreCase(path, ".zig"); +} + +fn extractLastFencedCodeBlockAlloc(allocator: std.mem.Allocator, messages: []const types.Message) ?[]u8 { + var i: usize = messages.len; + while (i > 0) { + i -= 1; + const msg = messages[i]; + if (!std.ascii.eqlIgnoreCase(msg.role, "assistant")) continue; + + const content = msg.content; + const fence_end = std.mem.lastIndexOf(u8, content, "```") orelse continue; + const before_end = content[0..fence_end]; + const fence_start = std.mem.lastIndexOf(u8, before_end, "```") orelse continue; + + const after_start = content[fence_start + 3 ..]; + const nl = std.mem.indexOfScalar(u8, after_start, '\n') orelse continue; + const after_lang = after_start[nl + 1 ..]; + const end = std.mem.indexOf(u8, after_lang, "```") orelse continue; + + const inner = std.mem.trim(u8, after_lang[0..end], "\r\n"); + if (inner.len == 0) continue; + return allocator.dupe(u8, inner) catch null; + } + return null; +} + +const ParsedWrite = struct { + path: []u8, + content: []u8, +}; + +fn parseWritePrompt(allocator: std.mem.Allocator, prompt: []const u8) !ParsedWrite { + var trimmed = std.mem.trim(u8, prompt, " \t\r\n"); + if (!std.ascii.startsWithIgnoreCase(trimmed, "write ")) return error.BadFormat; + trimmed = trimmed["write ".len..]; + trimmed = std.mem.trimLeft(u8, trimmed, " \t"); + if (trimmed.len == 0) return error.BadFormat; + + const fence_idx = std.mem.indexOf(u8, trimmed, "```") orelse return error.BadFormat; + const path_part = std.mem.trim(u8, trimmed[0..fence_idx], " \t\r\n\"'"); + if (path_part.len == 0) return error.BadFormat; + + const fence_start = fence_idx; + const after_fence = trimmed[fence_start + 3 ..]; + const newline_idx = std.mem.indexOfScalar(u8, after_fence, '\n') orelse return error.BadFormat; + const after_lang = after_fence[newline_idx + 1 ..]; + const fence_end = std.mem.indexOf(u8, after_lang, "```") orelse return error.BadFormat; + const content_part = after_lang[0..fence_end]; + + return .{ + .path = try allocator.dupe(u8, path_part), + .content = try allocator.dupe(u8, content_part), + }; +} + +const ParsedSearch = struct { + pattern: []u8, + dir: []u8, +}; + +fn parseSearchPrompt(allocator: std.mem.Allocator, prompt: []const u8) !ParsedSearch { + var trimmed = std.mem.trim(u8, prompt, " \t\r\n"); + if (std.ascii.startsWithIgnoreCase(trimmed, "search ")) { + trimmed = trimmed["search ".len..]; + } else if (std.ascii.startsWithIgnoreCase(trimmed, "file_search ")) { + trimmed = trimmed["file_search ".len..]; + } else { + return error.BadFormat; + } + + trimmed = std.mem.trimLeft(u8, trimmed, " \t"); + if (trimmed.len == 0) return error.BadFormat; + + var dir_value: []const u8 = "."; + var pattern_value: []const u8 = trimmed; + + if (indexOfIgnoreCase(trimmed, " in ")) |idx| { + pattern_value = std.mem.trim(u8, trimmed[0..idx], " \t\r\n\"'"); + dir_value = std.mem.trim(u8, trimmed[idx + 4 ..], " \t\r\n\"'"); + } else { + pattern_value = std.mem.trim(u8, pattern_value, " \t\r\n\"'"); + } + + if (pattern_value.len == 0) return error.BadFormat; + if (dir_value.len == 0) dir_value = "."; + + return .{ + .pattern = try allocator.dupe(u8, pattern_value), + .dir = try allocator.dupe(u8, dir_value), + }; +} + +fn parseSingleArgAfterPrefix(text: []const u8, prefix: []const u8) ?[]const u8 { + if (!std.ascii.startsWithIgnoreCase(text, prefix)) return null; + return std.mem.trim(u8, text[prefix.len..], " \t\r\n\"'"); +} + +fn resolveToolPathAlloc(allocator: std.mem.Allocator, app_config: *const config.Config, rel_path: []const u8) ![]u8 { + if (app_config.workspace_mode_enabled) { + return std.fs.path.join(allocator, &.{ app_config.workspace_root, rel_path }); + } + return prefixTmpPathAlloc(allocator, rel_path); +} + +fn prefixTmpPathAlloc(allocator: std.mem.Allocator, rel_path: []const u8) ![]u8 { + // All tool file IO is sandboxed under `tmp/` so generated artifacts don't + // spill into the repo root. + if (std.ascii.eqlIgnoreCase(rel_path, ".")) { + return allocator.dupe(u8, "tmp"); + } + + if (std.ascii.startsWithIgnoreCase(rel_path, "tmp")) { + if (rel_path.len == 3) return allocator.dupe(u8, rel_path); + const next = rel_path[3]; + if (next == '/' or next == std.fs.path.sep) return allocator.dupe(u8, rel_path); + } + + return std.fmt.allocPrint(allocator, "tmp{c}{s}", .{ std.fs.path.sep, rel_path }); +} + +const PathValidationError = error{ + EmptyPath, + AbsolutePath, + ParentTraversal, + ContainsNul, +}; + +fn validateRelativePath(path_raw: []const u8) ![]const u8 { + const p = std.mem.trim(u8, path_raw, " \t\r\n\"'"); + if (p.len == 0) return error.EmptyPath; + if (std.mem.indexOfScalar(u8, p, 0) != null) return error.ContainsNul; + if (std.fs.path.isAbsolute(p)) return error.AbsolutePath; + + var it = std.mem.splitScalar(u8, p, std.fs.path.sep); + while (it.next()) |part| { + if (std.mem.eql(u8, part, "..")) return error.ParentTraversal; + } + + return p; +} + +fn indexOfIgnoreCase(haystack: []const u8, needle: []const u8) ?usize { + if (needle.len == 0) return 0; + if (needle.len > haystack.len) return null; + var i: usize = 0; + while (i + needle.len <= haystack.len) : (i += 1) { + if (std.ascii.eqlIgnoreCase(haystack[i .. i + needle.len], needle)) return i; + } + return null; +} + +fn makeToolResponse( + allocator: std.mem.Allocator, + tool_name: []const u8, + output: []u8, +) !types.Response { + errdefer allocator.free(output); + const model_name = try std.fmt.allocPrint(allocator, "debug-tools/{s}", .{tool_name}); + errdefer allocator.free(model_name); + return .{ + .id = null, + .model = model_name, + .output = output, + .finish_reason = try allocator.dupe(u8, "tool"), + .success = true, + .usage = .{}, + }; +} + +test "inferWriteFilenameFromPromptAlloc supports C++ filenames" { + const allocator = std.testing.allocator; + const path = inferWriteFilenameFromPromptAlloc(allocator, "file_write app.cpp") orelse return error.ExpectedFilename; + defer allocator.free(path); + + try std.testing.expectEqualStrings("app.cpp", path); +} + +test "isSupportedWriteFilename keeps Python inference" { + try std.testing.expect(isSupportedWriteFilename("count.py")); + try std.testing.expect(!isSupportedWriteFilename("notes.txt")); +} + diff --git a/src/types.zig b/src/types.zig index 7b3eacf..e7b0c03 100644 --- a/src/types.zig +++ b/src/types.zig @@ -15,6 +15,7 @@ pub const Request = struct { repeat_penalty: ?f64 = null, session_id: ?[]const u8 = null, tenant_id: ?[]const u8 = null, + workspace_id: ?[]const u8 = null, max_context_tokens: ?usize = null, tools: []Tool, tool_choice: ?[]const u8 = null, @@ -45,6 +46,9 @@ pub const Request = struct { if (self.tenant_id) |tenant_id| { allocator.free(tenant_id); } + if (self.workspace_id) |workspace_id| { + allocator.free(workspace_id); + } for (self.tools) |tool| { tool.deinit(allocator); } @@ -80,10 +84,24 @@ pub const Message = struct { pub const Tool = struct { name: []const u8, description: []const u8, + parameters_json: ?[]const u8 = null, pub fn deinit(self: Tool, allocator: std.mem.Allocator) void { allocator.free(self.name); allocator.free(self.description); + if (self.parameters_json) |parameters_json| allocator.free(parameters_json); + } +}; + +pub const ToolCall = struct { + id: []const u8, + name: []const u8, + arguments_json: []const u8, + + pub fn deinit(self: ToolCall, allocator: std.mem.Allocator) void { + allocator.free(self.id); + allocator.free(self.name); + allocator.free(self.arguments_json); } }; @@ -101,6 +119,7 @@ pub const Response = struct { finish_reason: []const u8, // OpenAI-compatible completion stop reason. success: bool, // false means output should be surfaced as provider error. usage: Usage = .{}, // Best-effort token accounting. + tool_calls: ?[]ToolCall = null, // OpenAI-compatible assistant tool calls, when provided. /// Releases all owned response fields. /// @@ -114,6 +133,12 @@ pub const Response = struct { allocator.free(self.model); allocator.free(self.output); allocator.free(self.finish_reason); + if (self.tool_calls) |tool_calls| { + for (tool_calls) |tool_call| { + tool_call.deinit(allocator); + } + allocator.free(tool_calls); + } } };