From 199d652d61c16566a37126ca060eab48d2c13bc4 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Sun, 26 Apr 2026 21:14:03 +1000 Subject: [PATCH 01/25] Add test client --- .gitignore | 2 +- client/requirements.txt | 2 + client/test_client.py | 337 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 client/requirements.txt create mode 100644 client/test_client.py diff --git a/.gitignore b/.gitignore index e8d928e..bd7061e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ zig-out .agents/ .ollama-qwen-env.example .env -client/* +__pycache__/ logs/* # General .DS_Store diff --git a/client/requirements.txt b/client/requirements.txt new file mode 100644 index 0000000..59c3cea --- /dev/null +++ b/client/requirements.txt @@ -0,0 +1,2 @@ +streamlit>=1.33.0 +requests>=2.31.0 diff --git a/client/test_client.py b/client/test_client.py new file mode 100644 index 0000000..9766bdc --- /dev/null +++ b/client/test_client.py @@ -0,0 +1,337 @@ +import json +from typing import Dict, Generator +from urllib.parse import urlsplit, urlunsplit + +import requests +import streamlit as st + +st.set_page_config(page_title="Zig Agent Test Client", layout="wide") +st.title("Zig AI Harness Test Client") +st.caption("Debug and test harness behavior: chat routing, tool execution, sessions, streaming, and diagnostics.") + + +def build_headers(api_key: str) -> Dict[str, str]: + headers = {"Content-Type": "application/json"} + if api_key.strip(): + headers["x-api-key"] = api_key.strip() + return headers + + +def parse_assistant_content(data: Dict) -> str: + choices = data.get("choices", []) + if not choices: + return "" + message = choices[0].get("message", {}) + return message.get("content", "") + + +def base_url_from_endpoint(endpoint: str) -> str: + parsed = urlsplit(endpoint.strip()) + if not parsed.scheme or not parsed.netloc: + return "" + return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")) + + +def 
stream_probe_lines(response: requests.Response) -> Generator[str, None, None]: + content_type = (response.headers.get("content-type") or "").lower() + + if "text/event-stream" not in content_type: + body = response.text + yield "Server did not return SSE stream (text/event-stream)." + yield f"Content-Type: {content_type or ''}" + if body: + yield "" + yield "Raw response body:" + yield body + return + + for raw_line in response.iter_lines(decode_unicode=True): + if not raw_line: + continue + + line = raw_line.strip() + if not line.startswith("data:"): + continue + + payload = line[5:].strip() + if payload == "[DONE]": + yield "\n\n[stream done]" + return + + try: + event = json.loads(payload) + except json.JSONDecodeError: + yield payload + continue + + choices = event.get("choices", []) + if not choices: + yield payload + continue + + delta = choices[0].get("delta", {}) + token = delta.get("content") + if token is not None: + yield token + + +with st.sidebar: + st.header("Connection") + endpoint = st.text_input( + "Server endpoint", + value="http://127.0.0.1:8081/v1/chat/completions", + help="Your Zig server chat completion endpoint.", + ) + api_key = st.text_input( + "Client API key (optional)", + value="", + type="password", + help="Sends x-api-key header if provided.", + ) + + st.header("Request") + provider = st.text_input("provider", value="ollama") + model = st.text_input("model", value="auto") + thinking = st.checkbox( + "thinking", + value=False, + help="Enable model reasoning/thinking tokens. Turn off for faster and shorter replies.", + ) + temperature = st.slider( + "temperature", + min_value=0.0, + max_value=2.0, + value=0.7, + step=0.05, + help="Lower values are more deterministic. 
Higher values are more creative.", + ) + session_id = st.text_input("session_id (optional)", value="") + tenant_id = st.text_input("tenant_id (optional)", value="") + max_context_tokens = st.number_input( + "max_context_tokens (optional)", + min_value=0, + value=0, + step=256, + ) + + st.header("Loop") + loop_mode = st.selectbox( + "loop_mode", + options=["none", "basic", "agent"], + index=0, + help="Server-side iterative loop mode usable from any frontend client.", + ) + loop_until = st.text_input("loop_until", value="DONE") + loop_max_turns = st.number_input( + "loop_max_turns", + min_value=0, + max_value=128, + value=0, + step=1, + help="0 means disabled and server defaults apply when loop_mode is used.", + ) + + st.header("Tools (debug)") + selected_tools = st.multiselect( + "Available tools to send", + options=["echo", "utc", "cmd", "bash"], + default=["utc", "cmd"], + help="Multiple tools can be sent in one request. cmd/bash require server env LLM_ROUTER_TOOL_EXEC_ENABLED=1.", + ) + + tool_choice = st.selectbox( + "tool_choice", + options=["auto", "none", "required", "echo", "utc", "cmd", "bash"], + index=0, + help="Use auto for prompt-driven tool execution with multiple tools.", + ) + +st.subheader("Prompt") +prompt = st.text_area( + "Message", + value="Get time using utc tool and run the ping google.com command.", + height=120, +) + +col1, col2 = st.columns(2) +send_clicked = col1.button("Send POST request", use_container_width=True) +stream_clicked = col2.button("Probe streaming", use_container_width=True) + +diag_col1, diag_col2, diag_col3 = st.columns(3) +health_clicked = diag_col1.button("GET /health", use_container_width=True) +metrics_clicked = diag_col2.button("GET /metrics", use_container_width=True) +providers_clicked = diag_col3.button("GET /diagnostics/providers", use_container_width=True) + + +def build_payload(force_stream: bool) -> Dict: + message_content = prompt + + payload: Dict = { + "provider": provider.strip() or "ollama", + "messages": 
[{"role": "user", "content": message_content}], + "think": bool(thinking), + "temperature": float(temperature), + } + + if model.strip(): + payload["model"] = model.strip() + + if session_id.strip(): + payload["session_id"] = session_id.strip() + + if tenant_id.strip(): + payload["tenant_id"] = tenant_id.strip() + + if max_context_tokens > 0: + payload["max_context_tokens"] = int(max_context_tokens) + + if loop_mode != "none": + payload["loop_mode"] = loop_mode + if loop_until.strip(): + payload["loop_until"] = loop_until.strip() + if loop_max_turns > 0: + payload["loop_max_turns"] = int(loop_max_turns) + + if selected_tools: + tool_descriptions = { + "echo": "Return a deterministic echo response for harness debugging.", + "utc": "Return current UTC timestamp for deterministic tool-path checks.", + "cmd": "Execute a Windows cmd command from prompt text (debug use).", + "bash": "Execute a bash command from prompt text (debug use).", + } + payload["tools"] = [ + { + "name": tool_name, + "description": tool_descriptions.get(tool_name, "debug tool"), + } + for tool_name in selected_tools + ] + + if tool_choice != "none": + payload["tool_choice"] = tool_choice + + if force_stream: + payload["stream"] = True + + return payload + + +def request_timeout_seconds(force_stream: bool) -> int: + timeout = 120 if force_stream else 60 + + if loop_mode != "none": + # Loop mode can require multiple provider turns in one HTTP request. 
+ timeout = max(timeout, 240) + + if ("cmd" in selected_tools) or ("bash" in selected_tools): + timeout = max(timeout, 120) + + return timeout + + +if send_clicked: + if not endpoint.strip(): + st.error("Endpoint is required.") + elif not prompt.strip(): + st.error("Prompt is required.") + else: + payload = build_payload(force_stream=False) + st.code(json.dumps(payload, indent=2), language="json") + + try: + timeout_sec = request_timeout_seconds(force_stream=False) + response = requests.post( + endpoint.strip(), + headers=build_headers(api_key), + json=payload, + timeout=timeout_sec, + ) + except requests.Timeout: + st.error( + "Request timed out. Increase timeout by disabling loop mode, reducing loop_max_turns, " + "or simplifying tool usage in the prompt." + ) + except requests.RequestException as exc: + st.error(f"Request failed: {exc}") + else: + st.write(f"HTTP {response.status_code}") + try: + data = response.json() + except ValueError: + st.error("Response is not valid JSON.") + st.code(response.text) + else: + st.json(data) + assistant = parse_assistant_content(data) + if assistant: + st.success("Assistant output") + st.write(assistant) + +if stream_clicked: + if not endpoint.strip(): + st.error("Endpoint is required.") + elif not prompt.strip(): + st.error("Prompt is required.") + else: + if loop_mode != "none": + st.info( + "Loop progress visibility is controlled by server setting " + "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED (default: enabled)." + ) + + payload = build_payload(force_stream=True) + st.code(json.dumps(payload, indent=2), language="json") + + try: + timeout_sec = request_timeout_seconds(force_stream=True) + response = requests.post( + endpoint.strip(), + headers=build_headers(api_key), + json=payload, + timeout=timeout_sec, + stream=True, + ) + except requests.Timeout: + st.error( + "Streaming probe timed out. Try a shorter prompt, fewer loop turns, or disable loop mode." 
+ ) + except requests.RequestException as exc: + st.error(f"Streaming probe failed: {exc}") + else: + st.write(f"HTTP {response.status_code}") + if response.status_code >= 400: + st.error("Server returned an error status during stream probe.") + st.code(response.text) + else: + st.write("Streaming output") + st.write_stream(stream_probe_lines(response)) + +if health_clicked or metrics_clicked or providers_clicked: + base_url = base_url_from_endpoint(endpoint) + if not base_url: + st.error("Could not infer base URL from endpoint.") + else: + if health_clicked: + target_path = "/health" + elif metrics_clicked: + target_path = "/metrics" + else: + target_path = "/diagnostics/providers" + + target_url = f"{base_url}{target_path}" + st.code(target_url) + + try: + resp = requests.get( + target_url, + headers=build_headers(api_key), + timeout=30, + ) + except requests.RequestException as exc: + st.error(f"Diagnostics request failed: {exc}") + else: + st.write(f"HTTP {resp.status_code}") + try: + st.json(resp.json()) + except ValueError: + st.code(resp.text) From b50cf607d4fb8f159a887948501199488b25bff2 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 1 May 2026 14:14:30 +1000 Subject: [PATCH 02/25] Add concurrency --- src/config.zig | 3 + src/core/server.zig | 255 +++++++++++++++++++++++++++++-------- src/tools/command_exec.zig | 1 + 3 files changed, 205 insertions(+), 54 deletions(-) diff --git a/src/config.zig b/src/config.zig index 538eea8..7176baa 100644 --- a/src/config.zig +++ b/src/config.zig @@ -49,6 +49,7 @@ pub const Config = struct { tool_exec_timeout_ms: u32, tool_exec_max_output_bytes: usize, loop_stream_progress_enabled: bool, + max_concurrent_connections: usize, pub fn load(allocator: std.mem.Allocator) !Config { return try loadWithOverrides(allocator, null); @@ -249,6 +250,7 @@ pub const Config = struct { .tool_exec_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS", 15_000, env_overrides), .tool_exec_max_output_bytes = try
getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES", 65_536, env_overrides), .loop_stream_progress_enabled = try getEnvFlagOrDefault(allocator, "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED", true, env_overrides), + .max_concurrent_connections = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS", 64, env_overrides), }; } @@ -526,6 +528,7 @@ test "setDefaultProvider stores canonical provider alias" { .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, .loop_stream_progress_enabled = true, + .max_concurrent_connections = 64, }; defer cfg.deinit(allocator); diff --git a/src/core/server.zig b/src/core/server.zig index 6c60671..55e121b 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -11,6 +11,7 @@ const tooling = @import("../backend/tools.zig"); const session = @import("../backend/session.zig"); const ServerState = struct { + mutex: std.Thread.Mutex = .{}, total_requests: u64 = 0, successful_requests: u64 = 0, failed_requests: u64 = 0, @@ -29,12 +30,114 @@ const ServerState = struct { } fn noteClient(self: *ServerState, allocator: std.mem.Allocator, label: []const u8) !void { + self.mutex.lock(); + defer self.mutex.unlock(); for (self.connected_clients.items) |existing| { if (std.mem.eql(u8, existing, label)) return; } try self.connected_clients.append(allocator, try allocator.dupe(u8, label)); } + + fn noteRequestStarted(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.total_requests += 1; + } + + fn noteRequestSucceeded(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.successful_requests += 1; + } + + fn noteRequestFailed(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.failed_requests += 1; + } + + fn noteConnectionOpened(self: *ServerState) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.active_connections += 1; + } + + fn noteConnectionClosed(self: *ServerState) void { + 
self.mutex.lock(); + defer self.mutex.unlock(); + self.active_connections -= 1; + } + + fn snapshot(self: *ServerState) Snapshot { + self.mutex.lock(); + defer self.mutex.unlock(); + return .{ + .total_requests = self.total_requests, + .successful_requests = self.successful_requests, + .failed_requests = self.failed_requests, + .active_connections = self.active_connections, + }; + } + + fn knownClientsJsonAlloc(self: *ServerState, allocator: std.mem.Allocator) ![]u8 { + var clients_json = std.ArrayList(u8){}; + errdefer clients_json.deinit(allocator); + + self.mutex.lock(); + defer self.mutex.unlock(); + + try clients_json.append(allocator, '['); + for (self.connected_clients.items, 0..) |client, idx| { + if (idx > 0) try clients_json.append(allocator, ','); + const escaped = try response.escapeJsonStringAlloc(allocator, client); + defer allocator.free(escaped); + try clients_json.writer(allocator).print("\"{s}\"", .{escaped}); + } + try clients_json.append(allocator, ']'); + return try clients_json.toOwnedSlice(allocator); + } +}; + +const Snapshot = struct { + total_requests: u64, + successful_requests: u64, + failed_requests: u64, + active_connections: u64, +}; + +const ConnectionGate = struct { + mutex: std.Thread.Mutex = .{}, + active: usize = 0, + max_active: usize, + + fn tryAcquire(self: *ConnectionGate) bool { + self.mutex.lock(); + defer self.mutex.unlock(); + if (self.active >= self.max_active) return false; + self.active += 1; + return true; + } + + fn release(self: *ConnectionGate) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.active -= 1; + } +}; + +const SessionStoreGuard = struct { + mutex: std.Thread.Mutex = .{}, +}; + +const WorkerContext = struct { + allocator: std.mem.Allocator, + app_config: *const config.Config, + server_state: *ServerState, + tool_registry: *tooling.ToolRegistry, + session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + gate: *ConnectionGate, }; /// Starts the TCP listener and 
serves requests indefinitely. @@ -127,23 +230,57 @@ pub fn run( const active_session_store: ?*session.SessionStore = if (session_store) |*store| store else null; - while (true) { - var connection = try server.accept(); - defer connection.stream.close(); - - server_state.active_connections += 1; - defer server_state.active_connections -= 1; + var gate = ConnectionGate{ .max_active = app_config.max_concurrent_connections }; + var session_store_guard = SessionStoreGuard{}; + const worker_context = WorkerContext{ + .allocator = allocator, + .app_config = app_config, + .server_state = &server_state, + .tool_registry = &tool_registry, + .session_store = active_session_store, + .session_store_guard = &session_store_guard, + .gate = &gate, + }; - // Client tracking is intentionally disabled in the hot path on Windows - // until socket stability issues are fully resolved. + while (true) { + const connection = try server.accept(); + if (!gate.tryAcquire()) { + sendJsonSafe( + connection, + 503, + "{\"error\":{\"message\":\"Server is at connection capacity\",\"type\":\"server_error\",\"param\":null,\"code\":\"connection_capacity\"}}", + app_config, + ); + connection.stream.close(); + continue; + } - handleConnection(allocator, app_config, &connection, &server_state, &tool_registry, active_session_store) catch |err| { - logError("Request handling error: {s}", .{@errorName(err)}); - server_state.failed_requests += 1; - }; + var worker = try std.Thread.spawn(.{}, connectionWorkerMain, .{ worker_context, connection }); + worker.detach(); } } +fn connectionWorkerMain(context: WorkerContext, connection: std.net.Server.Connection) void { + defer context.gate.release(); + var worker_connection = connection; + defer worker_connection.stream.close(); + context.server_state.noteConnectionOpened(); + defer context.server_state.noteConnectionClosed(); + + handleConnection( + context.allocator, + context.app_config, + &worker_connection, + context.server_state, + context.tool_registry, + 
context.session_store, + context.session_store_guard, + ) catch |err| { + logError("Request handling error: {s}", .{@errorName(err)}); + context.server_state.noteRequestFailed(); + }; +} + fn handleConnection( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -151,6 +288,7 @@ fn handleConnection( server_state: *ServerState, tool_registry: *tooling.ToolRegistry, session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, ) !void { // Translate transport-level parsing failures into OpenAI-style API errors. const request_raw = request.readHttpRequest(allocator, connection, app_config.request_timeout_ms) catch |err| switch (err) { @@ -209,7 +347,7 @@ fn handleConnection( defer allocator.free(request_raw); if (request_raw.len == 0) return; - server_state.total_requests += 1; + server_state.noteRequestStarted(); debugLog( app_config, @@ -222,19 +360,19 @@ fn handleConnection( "Malformed HTTP request", "invalid_http_request", ), app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; if (route == null) { sendApiErrorSafe(connection.*, allocator, backend.errors.notFoundError(), app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } if (requiresAuth(route.?) 
and auth.authorizeRequest(app_config.auth_api_key, request_raw) == .denied) { sendJsonSafe(connection.*, 401, "{\"error\":{\"message\":\"Unauthorized\",\"type\":\"auth_error\",\"param\":null,\"code\":\"unauthorized\"}}", app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } @@ -247,62 +385,57 @@ fn handleConnection( ); defer allocator.free(health_json); sendJsonSafe(connection.*, 200, health_json, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .metrics => { + const snapshot = server_state.snapshot(); const metrics_json = try std.fmt.allocPrint( allocator, "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"active_connections\":{d}}}", .{ app_config.instance_id, - server_state.total_requests, - server_state.successful_requests, - server_state.failed_requests, - server_state.active_connections, + snapshot.total_requests, + snapshot.successful_requests, + snapshot.failed_requests, + snapshot.active_connections, }, ); defer allocator.free(metrics_json); sendJsonSafe(connection.*, 200, metrics_json, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .diagnostics_clients => { - var clients_json = std.ArrayList(u8){}; - defer clients_json.deinit(allocator); - try clients_json.append(allocator, '['); - for (server_state.connected_clients.items, 0..) 
|client, idx| { - if (idx > 0) try clients_json.append(allocator, ','); - const escaped = try response.escapeJsonStringAlloc(allocator, client); - defer allocator.free(escaped); - try clients_json.writer(allocator).print("\"{s}\"", .{escaped}); - } - try clients_json.append(allocator, ']'); + const snapshot = server_state.snapshot(); + const clients_json = try server_state.knownClientsJsonAlloc(allocator); + defer allocator.free(clients_json); const payload = try std.fmt.allocPrint( allocator, "{{\"instance_id\":\"{s}\",\"active_connections\":{d},\"known_clients\":{s}}}", - .{ app_config.instance_id, server_state.active_connections, clients_json.items }, + .{ app_config.instance_id, snapshot.active_connections, clients_json }, ); defer allocator.free(payload); sendJsonSafe(connection.*, 200, payload, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .diagnostics_requests => { + const snapshot = server_state.snapshot(); const payload = try std.fmt.allocPrint( allocator, "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d}}}", .{ app_config.instance_id, - server_state.total_requests, - server_state.successful_requests, - server_state.failed_requests, + snapshot.total_requests, + snapshot.successful_requests, + snapshot.failed_requests, }, ); defer allocator.free(payload); sendJsonSafe(connection.*, 200, payload, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .diagnostics_providers => { @@ -317,7 +450,7 @@ fn handleConnection( defer allocator.free(payload); sendJsonSafe(connection.*, 200, payload, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .chat_completions => {}, @@ -328,7 +461,7 @@ fn handleConnection( "Missing request body", "missing_body", ), app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; @@ -339,7 
+472,7 @@ fn handleConnection( .ok => |parsed_request| parsed_request, .err => |api_error| { sendApiErrorSafe(connection.*, allocator, api_error, app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }, }; @@ -351,7 +484,7 @@ fn handleConnection( "tools", "unknown_tool", ), app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } @@ -371,10 +504,12 @@ fn handleConnection( if (session_store) |store| { if (parsed_req.session_id) |session_id| { + session_store_guard.mutex.lock(); loaded_session = store.load(allocator, session_id, parsed_req.tenant_id) catch |err| blk: { logError("Session load failed for '{s}': {s}", .{ session_id, @errorName(err) }); break :blk null; }; + session_store_guard.mutex.unlock(); if (loaded_session) |state| { if (state.messages.len > 0) { @@ -429,7 +564,7 @@ fn handleConnection( "stream", "unsupported_stream_provider", ), app_config); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } @@ -463,24 +598,24 @@ fn handleConnection( app_config.loop_stream_progress_enabled, ) catch |err| { logError("Provider stream loop error: {s}", .{@errorName(err)}); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; } const ollama_qwen = @import("../providers/ollama_qwen.zig"); const stream_result = ollama_qwen.streamQwenToSse(connection.*, allocator, app_config, stream_request) catch |err| { logError("Provider stream error: {s}", .{@errorName(err)}); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; switch (stream_result) { .streamed => { - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; }, .failed => |provider_error_response| { @@ -489,7 +624,7 @@ fn handleConnection( provider_error_response.output, "provider_error", ), app_config); - 
server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }, } @@ -503,7 +638,7 @@ fn handleConnection( .{parsed_req.tool_choice orelse "(none)"}, ); sendChatCompletionSafe(connection.*, allocator, tool_result, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); return; } @@ -555,7 +690,7 @@ fn handleConnection( backend.errors.providerTransportError(@errorName(err)), app_config, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; @@ -571,7 +706,7 @@ fn handleConnection( backend.errors.providerTransportError(@errorName(err)), app_config, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; }; result_ready = true; @@ -602,7 +737,7 @@ fn handleConnection( backend.errors.providerFailureFromDetail(normalized_provider, result.output), app_config, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } @@ -627,7 +762,7 @@ fn handleConnection( ), app_config, ); - server_state.failed_requests += 1; + server_state.noteRequestFailed(); return; } } @@ -674,15 +809,17 @@ fn handleConnection( state.message_count = state.messages.len; defer state.deinit(allocator); + session_store_guard.mutex.lock(); store.save(allocator, state) catch |err| { logError("Session save failed for '{s}': {s}", .{ session_id, @errorName(err) }); }; + session_store_guard.mutex.unlock(); } } } sendChatCompletionSafe(connection.*, allocator, result, app_config); - server_state.successful_requests += 1; + server_state.noteRequestSucceeded(); } fn cloneRequestWithMessagesAlloc( @@ -1162,3 +1299,13 @@ fn logInfo(comptime format: []const u8, args: anytype) void { fn logError(comptime format: []const u8, args: anytype) void { std.debug.print("[error] " ++ format ++ "\n", args); } + +test "connection gate enforces bounded capacity" { + var gate = ConnectionGate{ .max_active = 2 }; + try std.testing.expect(gate.tryAcquire()); + try 
std.testing.expect(gate.tryAcquire()); + try std.testing.expect(!gate.tryAcquire()); + + gate.release(); + try std.testing.expect(gate.tryAcquire()); +} diff --git a/src/tools/command_exec.zig b/src/tools/command_exec.zig index 25b508c..d6f8a55 100644 --- a/src/tools/command_exec.zig +++ b/src/tools/command_exec.zig @@ -505,6 +505,7 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, .loop_stream_progress_enabled = true, + .max_concurrent_connections = 64, }; } From 08403d73e00aadd97b6ee77e47a453cace6ddb1d Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 1 May 2026 16:19:45 +1000 Subject: [PATCH 03/25] Increase cmd exec security --- README.md | 6 +- src/tools/command_exec.zig | 144 ++++++++++++++++++++++++++++++++----- 2 files changed, 131 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index ebfd0c5..f909559 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ OpenAI-compatible LLM router and prompt-loop runtime written in Zig. -This project exposes a chat-completions API, routes requests across multiple providers, supports optional streaming, and includes built-in loop controls for iterative agent workflows. +This project exposes a chat-completions API, routes requests across multiple providers, supports optional streaming, and includes built-in loop controls for iterative agent workflows. Scope: **thin, reliable model harness** (see `phase.md` Phase 0 and `plan.md` for the frozen feature set and deferrals). > [!TIP] > Fastest local smoke test: @@ -19,6 +19,10 @@ This project exposes a chat-completions API, routes requests across multiple pro - Optional debug tooling (echo, utc, cmd, bash) with safe defaults. - Simple deploy surface: single Zig binary, environment-based configuration. 
+## Scope freeze (Phase 0) + +Phase 0 is **complete**: one canonical chat path and response shape, with deferrals listed in `plan.md` (distributed orchestration, deep tenancy/RBAC, heavy observability, plugin marketplaces). Operational routes (`/health`, `/metrics`, `/diagnostics/*`) stay diagnostics-only. + ## Architecture ```mermaid diff --git a/src/tools/command_exec.zig b/src/tools/command_exec.zig index d6f8a55..6e4c6ae 100644 --- a/src/tools/command_exec.zig +++ b/src/tools/command_exec.zig @@ -17,6 +17,12 @@ const CommandValidationError = error{ DangerousPattern, }; +const PipeCapture = struct { + bytes: []u8 = &.{}, + err_name: ?[]const u8 = null, + done: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), +}; + pub fn execute( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -95,17 +101,12 @@ pub fn execute( .bash => [_][]const u8{ "bash", "-lc", command }, }; - // TODO(security): `Child.run` blocks until the child exits and provides no - // built-in timeout. A hung command (e.g. `sleep 9999`) will hold the - // server's single accept-loop thread open for its entire duration. - // Replace this call with `Child.spawn` + async I/O polling + `Child.kill` - // once per-connection threading is in place. Until then, `tool_exec_enabled` - // must remain false (the default) in any exposed deployment. 
- const run_result = std.process.Child.run(.{ - .allocator = allocator, - .argv = &argv, - .max_output_bytes = app_config.tool_exec_max_output_bytes, - }) catch |err| { + var child = std.process.Child.init(&argv, allocator); + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Pipe; + child.stderr_behavior = .Pipe; + + child.spawn() catch |err| { return try makeToolResponse( allocator, tool_name, @@ -116,8 +117,55 @@ pub fn execute( ), ); }; - defer allocator.free(run_result.stdout); - defer allocator.free(run_result.stderr); + + var stdout_capture = PipeCapture{}; + var stderr_capture = PipeCapture{}; + + var stdout_thread: ?std.Thread = null; + var stderr_thread: ?std.Thread = null; + + if (child.stdout) |stdout_file| { + stdout_thread = try std.Thread.spawn(.{}, capturePipeMain, .{ + &stdout_capture, + stdout_file, + app_config.tool_exec_max_output_bytes, + }); + } else { + stdout_capture.done.store(true, .release); + } + if (child.stderr) |stderr_file| { + stderr_thread = try std.Thread.spawn(.{}, capturePipeMain, .{ + &stderr_capture, + stderr_file, + app_config.tool_exec_max_output_bytes, + }); + } else { + stderr_capture.done.store(true, .release); + } + + const timeout_deadline = start_ms + @as(i64, app_config.tool_exec_timeout_ms); + var timed_out = false; + while (true) { + const stdout_done = stdout_capture.done.load(.acquire); + const stderr_done = stderr_capture.done.load(.acquire); + if (stdout_done and stderr_done) break; + if (std.time.milliTimestamp() >= timeout_deadline) { + timed_out = true; + break; + } + std.Thread.sleep(2 * std.time.ns_per_ms); + } + + const terminated_as: std.process.Child.Term = if (timed_out) + child.kill() catch .{ .Unknown = 1 } + else + child.wait() catch .{ .Unknown = 1 }; + + if (stdout_thread) |thread| thread.join(); + if (stderr_thread) |thread| thread.join(); + + defer if (stdout_capture.bytes.len > 0) std.heap.page_allocator.free(stdout_capture.bytes); + defer if (stderr_capture.bytes.len > 0) 
std.heap.page_allocator.free(stderr_capture.bytes); const end_ms = std.time.milliTimestamp(); const duration_ms = if (end_ms >= start_ms) @@ -125,13 +173,26 @@ pub fn execute( else @as(u64, 0); - const term_text = try childTermToTextAlloc(allocator, run_result.term); + const term_text = try childTermToTextAlloc(allocator, terminated_as); defer allocator.free(term_text); - const status_tag = if (isSuccessfulTermination(run_result.term)) "DEBUG_TOOL_OK" else "DEBUG_TOOL_ERROR"; + const status_tag = if (!timed_out and isSuccessfulTermination(terminated_as)) "DEBUG_TOOL_OK" else "DEBUG_TOOL_ERROR"; + const timeout_exceeded_text = if (timed_out) "true" else "false"; + const stdout_text = if (stdout_capture.err_name) |err_name| + try std.fmt.allocPrint(allocator, "read_error={s}", .{err_name}) + else + try allocator.dupe(u8, stdout_capture.bytes); + defer allocator.free(stdout_text); + + const stderr_text = if (stderr_capture.err_name) |err_name| + try std.fmt.allocPrint(allocator, "read_error={s}", .{err_name}) + else + try allocator.dupe(u8, stderr_capture.bytes); + defer allocator.free(stderr_text); + const output = try std.fmt.allocPrint( allocator, - "{s}\ntool={s}\nterm={s}\nduration_ms={d}\nmax_output_bytes={d}\ntimeout_policy_ms={d}\ntimeout_enforced=false\ncommand={s}\n--- stdout ---\n{s}\n--- stderr ---\n{s}", + "{s}\ntool={s}\nterm={s}\nduration_ms={d}\nmax_output_bytes={d}\ntimeout_policy_ms={d}\ntimeout_enforced=true\ntimeout_exceeded={s}\ncode={s}\ncommand={s}\n--- stdout ---\n{s}\n--- stderr ---\n{s}", .{ status_tag, tool_name, @@ -139,15 +200,26 @@ pub fn execute( duration_ms, app_config.tool_exec_max_output_bytes, app_config.tool_exec_timeout_ms, + timeout_exceeded_text, + if (timed_out) "command_timeout" else "command_completed", command, - run_result.stdout, - run_result.stderr, + stdout_text, + stderr_text, }, ); return try makeToolResponse(allocator, tool_name, output); } +fn capturePipeMain(capture: *PipeCapture, file: std.fs.File, max_bytes: usize) 
void { + defer capture.done.store(true, .release); + capture.bytes = file.readToEndAlloc(std.heap.page_allocator, max_bytes) catch |err| { + capture.err_name = @errorName(err); + capture.bytes = &.{}; + return; + }; +} + pub fn extractCommandFromPromptAlloc( allocator: std.mem.Allocator, flavor: ShellFlavor, @@ -621,3 +693,39 @@ test "execute cmd tool blocks dangerous denylist command" { try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); try std.testing.expect(std.mem.indexOf(u8, result.output, "denylist") != null); } + +test "execute cmd tool enforces timeout and kills process" { + if (@import("builtin").os.tag != .windows) return; + + const allocator = std.testing.allocator; + var cfg = try buildTestConfig(allocator, true); + defer cfg.deinit(allocator); + cfg.tool_exec_timeout_ms = 200; + + const messages = try allocator.alloc(types.Message, 1); + messages[0] = .{ + .role = try allocator.dupe(u8, "user"), + .content = try allocator.dupe(u8, "powershell -NoProfile -Command Start-Sleep -Seconds 3"), + }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, "powershell -NoProfile -Command Start-Sleep -Seconds 3"), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = try allocator.alloc(types.Tool, 0), + .tool_choice = null, + }; + defer req.deinit(allocator); + + var result = try execute(allocator, &cfg, req, .cmd); + defer result.deinit(allocator); + + try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); + try std.testing.expect(std.mem.indexOf(u8, result.output, "timeout_enforced=true") != null); + try std.testing.expect(std.mem.indexOf(u8, result.output, "timeout_exceeded=true") != null); + try std.testing.expect(std.mem.indexOf(u8, result.output, "code=command_timeout") != null); +} From 143e1cad2e927d674aa7676870783ed4374ca5f2 Mon Sep 17 00:00:00 2001 From: temp-audit Date: Tue, 5 May 
2026 20:34:09 +1000 Subject: [PATCH 04/25] Add ReAct mode support with action parsing and observation formatting --- .gitignore | 1 + README.md | 51 ++++++- src/backend/api.zig | 4 +- src/core/server.zig | 146 +++++++++++++++++--- src/main.zig | 74 +++++++++- src/react.zig | 327 ++++++++++++++++++++++++++++++++++++++++++++ src/root.zig | 5 + 7 files changed, 582 insertions(+), 26 deletions(-) create mode 100644 src/react.zig diff --git a/.gitignore b/.gitignore index e8d928e..3139cb2 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ Network Trash Folder Temporary Items .apdisk .llm-router-env +.github/instructions/ \ No newline at end of file diff --git a/README.md b/README.md index ebfd0c5..1c3d47f 100644 --- a/README.md +++ b/README.md @@ -159,11 +159,59 @@ Loop controls: - --provider provider override - --until completion marker (default: DONE) - --max-turns loop safety cap (default: 8) -- --loop-mode <basic|agent> loop style +- --loop-mode <basic|agent|react> loop style - --agent-loop shorthand for agent mode +- --react shorthand for ReAct reasoning mode - --use-env load .env - --env-file load a custom dotenv file +## ReAct Mode + +ReAct (Reasoning + Acting) mode implements the paradigm from [Yao et al., 2022](https://arxiv.org/abs/2210.03629). The model produces structured **Thought → Action** pairs, and the system executes each action and injects the result as an **Observation** before the next turn. + +### Available Actions + +| Action | Description | Requirements | +| --------- | -------------------------------- | --------------------------------------- | +| Search[q] | Search for information (stub) | None (future: wire to search tool/API) | +| Lookup[t] | Look up a term in context (stub) | None (future: wire to retrieval backend) | +| Cmd[c] | Execute a shell command | `LLM_ROUTER_TOOL_EXEC_ENABLED=1` | +| Finish[a] | Return final answer and stop loop | None | + +> [!NOTE] +> Search and Lookup return stub responses.
They are designed as extension points for future tool/API integration. + +### CLI Example + +```bash +zig build run -- --react --prompt "What is the elevation range of the High Plains?" --provider ollama +``` + +### API Example + +```json +{ + "messages": [{ "role": "user", "content": "What is the elevation range of the High Plains?" }], + "loop_mode": "react", + "loop_max_turns": 10 +} +``` + +The model will produce output like: + +``` +Thought 1: I need to search for the High Plains elevation range. +Action 1: Search[High Plains elevation] +``` + +The system injects: + +``` +Observation 1: +``` + +This continues until the model emits `Action N: Finish[answer]` or the turn budget is exhausted. + ## Tools Registered tool names: @@ -326,6 +374,7 @@ zig_coding_agent │ ├── openai.zig │ ├── openai_compatible.zig │ └── openrouter.zig + ├── react.zig ├── root.zig ├── tools │ ├── command_exec.zig diff --git a/src/backend/api.zig b/src/backend/api.zig index 20ec499..6c0ece8 100644 --- a/src/backend/api.zig +++ b/src/backend/api.zig @@ -548,9 +548,9 @@ pub fn parseChatRequest( ) }; } - if (!std.ascii.eqlIgnoreCase(trimmed, "basic") and !std.ascii.eqlIgnoreCase(trimmed, "agent")) { + if (!std.ascii.eqlIgnoreCase(trimmed, "basic") and !std.ascii.eqlIgnoreCase(trimmed, "agent") and !std.ascii.eqlIgnoreCase(trimmed, "react")) { return .{ .err = errors.validationError( - "loop_mode must be one of: basic, agent", + "loop_mode must be one of: basic, agent, react", "loop_mode", "invalid_loop_mode", ) }; diff --git a/src/core/server.zig b/src/core/server.zig index 6c60671..28649c5 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -9,6 +9,7 @@ const backend = @import("../backend/api.zig"); const auth = @import("../backend/auth.zig"); const tooling = @import("../backend/tools.zig"); const session = @import("../backend/session.zig"); +const react = @import("../react.zig"); const ServerState = struct { total_requests: u64 = 0, @@ -781,6 +782,18 @@ fn streamLoopRequestToSse( 
} allocator.free(working_messages); working_messages = with_guidance; + } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { + const with_react = try appendRoleMessageAlloc( + allocator, + working_messages, + "system", + react.system_prompt, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_react; } var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); @@ -849,10 +862,38 @@ fn streamLoopRequestToSse( } previous_output = try allocator.dupe(u8, normalized_output); + const is_react_mode = std.ascii.eqlIgnoreCase(loop_mode, "react"); const reached_until = std.mem.indexOf(u8, turn_result.output, loop_until) != null; const reached_max = (turn + 1) >= loop_max_turns; - const repeated_stop = std.ascii.eqlIgnoreCase(loop_mode, "agent") and repeated_count >= 1; - const should_stop = reached_until or reached_max or repeated_stop or !turn_result.success; + const repeated_stop = (std.ascii.eqlIgnoreCase(loop_mode, "agent") or is_react_mode) and repeated_count >= 1; + + // ReAct: check for Finish action before standard stop checks. + var react_finished = false; + var react_observation: ?[]u8 = null; + defer if (react_observation) |obs| allocator.free(obs); + + if (is_react_mode) { + if (react.parseReactAction(turn_result.output)) |action| { + switch (action) { + .finish => { + react_finished = true; + }, + else => { + const obs_raw = try react.executeReactAction(allocator, action, app_config); + defer allocator.free(obs_raw); + react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); + }, + } + } else { + react_observation = try react.formatObservation( + allocator, + turn + 1, + "Could not parse an Action from your response. 
Please use the format: Action N: Type[argument]", + ); + } + } + + const should_stop = react_finished or reached_until or reached_max or repeated_stop or !turn_result.success; if (should_stop) { if (!emit_progress and turn_result.output.len > 0) { @@ -879,13 +920,32 @@ fn streamLoopRequestToSse( } allocator.free(latest_user_prompt); - latest_user_prompt = try allocator.dupe( - u8, - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) - "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." - else - "Continue.", - ); + + if (is_react_mode) { + // Inject the observation as the next user prompt. + const obs = react_observation orelse try allocator.dupe(u8, "Continue."); + react_observation = null; // Transfer ownership. + latest_user_prompt = obs; + + if (emit_progress) { + try response.sendChatCompletionChunkSse( + connection, + allocator, + completion_id, + turn_result.model, + latest_user_prompt, + null, + ); + } + } else { + latest_user_prompt = try allocator.dupe( + u8, + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) + "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." 
+ else + "Continue.", + ); + } const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", latest_user_prompt); for (working_messages) |message| { @@ -927,6 +987,18 @@ fn executeLoopRequestAlloc( } allocator.free(working_messages); working_messages = with_guidance; + } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { + const with_react = try appendRoleMessageAlloc( + allocator, + working_messages, + "system", + react.system_prompt, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_react; } var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); @@ -992,16 +1064,58 @@ fn executeLoopRequestAlloc( return .{ .response = turn_result, .messages = working_messages }; } + // ReAct mode: parse action, execute, inject observation. + const is_react_mode = std.ascii.eqlIgnoreCase(loop_mode, "react"); + var react_observation: ?[]u8 = null; + var react_finished = false; + defer if (react_observation) |obs| allocator.free(obs); + + if (is_react_mode) { + if (repeated_count >= 1) { + return .{ .response = turn_result, .messages = working_messages }; + } + + if (react.parseReactAction(turn_result.output)) |action| { + switch (action) { + .finish => { + react_finished = true; + }, + else => { + const obs_raw = try react.executeReactAction(allocator, action, app_config); + defer allocator.free(obs_raw); + react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); + }, + } + } else { + react_observation = try react.formatObservation( + allocator, + turn + 1, + "Could not parse an Action from your response. 
Please use the format: Action N: Type[argument]", + ); + } + + if (react_finished) { + return .{ .response = turn_result, .messages = working_messages }; + } + } + turn_result.deinit(allocator); allocator.free(latest_user_prompt); - latest_user_prompt = try allocator.dupe( - u8, - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) - "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." - else - "Continue.", - ); + + if (is_react_mode) { + const obs = react_observation orelse try allocator.dupe(u8, "Continue."); + react_observation = null; // Transfer ownership. + latest_user_prompt = obs; + } else { + latest_user_prompt = try allocator.dupe( + u8, + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) + "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." + else + "Continue.", + ); + } const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", latest_user_prompt); for (working_messages) |message| { diff --git a/src/main.zig b/src/main.zig index 93a3520..972d36a 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,10 +1,12 @@ //! CLI entrypoint for running the router in server mode or prompt-loop mode. const std = @import("std"); const root = @import("root.zig"); +const react = @import("react.zig"); const LoopMode = enum { basic, agent, + react, }; /// Parsed CLI options with owned heap-allocated string fields. 
@@ -86,6 +88,11 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { continue; } + if (std.mem.eql(u8, arg, "--react")) { + cli.loop_mode = .react; + continue; + } + if (std.mem.eql(u8, arg, "--loop-mode")) { const value = args.next() orelse return error.MissingLoopModeValue; cli.loop_mode = try parseLoopMode(value); @@ -198,6 +205,7 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { fn parseLoopMode(value: []const u8) !LoopMode { if (std.ascii.eqlIgnoreCase(value, "basic")) return .basic; if (std.ascii.eqlIgnoreCase(value, "agent")) return .agent; + if (std.ascii.eqlIgnoreCase(value, "react")) return .react; return error.InvalidLoopMode; } @@ -290,11 +298,19 @@ fn runPromptLoop( "system", "You are running in an iterative agent loop. Improve the solution every turn. Use concise self-critique before each improvement. When the task is complete, include the completion marker exactly.", ); + } else if (loop_mode == .react) { + try appendConversationMessage( + allocator, + &conversation, + "system", + react.system_prompt, + ); } try appendConversationMessage(allocator, &conversation, "user", initial_prompt); - var latest_user_prompt = initial_prompt; + var latest_user_prompt = try allocator.dupe(u8, initial_prompt); + defer allocator.free(latest_user_prompt); var turn: usize = 0; var repeated_count: usize = 0; var previous_output: ?[]u8 = null; @@ -346,7 +362,7 @@ fn runPromptLoop( } previous_output = try allocator.dupe(u8, normalized_output); - if (loop_mode == .agent and repeated_count >= 1) { + if ((loop_mode == .agent or loop_mode == .react) and repeated_count >= 1) { std.debug.print("Loop stopped early: model repeated output without progress.\n", .{}); return; } @@ -356,11 +372,55 @@ fn runPromptLoop( return; } - latest_user_prompt = switch (loop_mode) { - .basic => "Continue.", - .agent => "Critique your previous answer briefly, then improve it with concrete next steps. 
If complete, include the completion marker exactly and return the final result.", - }; - try appendConversationMessage(allocator, &conversation, "user", latest_user_prompt); + if (loop_mode == .react) { + // ReAct mode: parse the Action, execute it, inject the Observation. + const parsed_action = react.parseReactAction(result.output); + if (parsed_action) |action| { + switch (action) { + .finish => |answer| { + std.debug.print("ReAct finished:\n{s}\n", .{answer}); + return; + }, + else => { + const observation_raw = try react.executeReactAction(allocator, action, app_config); + defer allocator.free(observation_raw); + + const observation = try react.formatObservation(allocator, turn + 1, observation_raw); + defer allocator.free(observation); + + std.debug.print("{s}\n\n", .{observation}); + + allocator.free(latest_user_prompt); + latest_user_prompt = observation; + try appendConversationMessage(allocator, &conversation, "user", observation); + }, + } + } else { + // No parseable action found. Give the model a hint. + const hint = try react.formatObservation( + allocator, + turn + 1, + "Could not parse an Action from your response. Please use the format: Action N: Type[argument]", + ); + defer allocator.free(hint); + + std.debug.print("{s}\n\n", .{hint}); + allocator.free(latest_user_prompt); + latest_user_prompt = hint; + try appendConversationMessage(allocator, &conversation, "user", hint); + } + } else { + allocator.free(latest_user_prompt); + latest_user_prompt = try allocator.dupe( + u8, + switch (loop_mode) { + .basic => "Continue.", + .agent => "Critique your previous answer briefly, then improve it with concrete next steps. 
If complete, include the completion marker exactly and return the final result.", + .react => unreachable, + }, + ); + try appendConversationMessage(allocator, &conversation, "user", latest_user_prompt); + } } } diff --git a/src/react.zig b/src/react.zig new file mode 100644 index 0000000..238d5ad --- /dev/null +++ b/src/react.zig @@ -0,0 +1,327 @@ +//! ReAct (Reasoning + Acting) loop support. +//! +//! Implements the Thought → Action → Observation paradigm from Yao et al., 2022. +//! The model produces structured Thought/Action pairs, the harness executes the +//! action and injects the result as an Observation, and the cycle repeats until +//! the model emits Finish[answer] or the turn budget is exhausted. +const std = @import("std"); +const config = @import("config.zig"); +const types = @import("types.zig"); +const command_exec_tool = @import("tools/command_exec.zig"); + +/// System prompt injected at the start of every ReAct loop to instruct the +/// model on the required Thought/Action/Observation format. 
+pub const system_prompt = + "You are running in ReAct (Reasoning + Acting) mode.\n" ++ + "You MUST follow this format strictly for every step:\n\n" ++ + "Thought N: \n" ++ + "Action N: \n\n" ++ + "After each Action the system will inject:\n" ++ + "Observation N: \n\n" ++ + "Available actions:\n" ++ + " Search[query] - search for information\n" ++ + " Lookup[term] - look up a term in the current context\n" ++ + " Cmd[command] - execute a shell command (only if tools are enabled)\n" ++ + " Finish[answer] - return the final answer and end the loop\n\n" ++ + "Rules:\n" ++ + "- Always start with a Thought, then an Action.\n" ++ + "- Never produce an Observation yourself; the system provides them.\n" ++ + "- Use Finish[answer] when you have the final answer.\n" ++ + "- Be concrete: use specific names, values, steps, and short outputs instead of vague summaries.\n" ++ + "- If you need information, ask for the exact term or command that will resolve it.\n" ++ + "- If an action fails, adjust your approach in the next Thought.\n"; + +/// Parsed action extracted from a model response. +pub const ReactAction = union(enum) { + search: []const u8, + lookup: []const u8, + cmd: []const u8, + finish: []const u8, + unknown: []const u8, +}; + +/// Scans model output for the last `Action [N]: Type[arg]` line and returns +/// the parsed action variant. +/// +/// Parsing rules: +/// - Case-insensitive match on the `Action` keyword. +/// - The turn number after `Action` is optional (the harness tracks turns). +/// - Brackets are mandatory for the argument. +/// - Greedy bracket matching: takes everything up to the last `]` on the line. +/// - Returns `null` when no action line is found. +pub fn parseReactAction(output: []const u8) ?ReactAction { + // Scan every line in order, remembering the most recent Action directive.
+ var last_action_line: ?[]const u8 = null; + var lines = std.mem.splitAny(u8, output, "\n\r"); + while (lines.next()) |raw_line| { + const line = std.mem.trim(u8, raw_line, " \t"); + if (line.len == 0) continue; + if (startsWithActionDirective(line)) { + last_action_line = line; + } + } + + const action_line = last_action_line orelse return null; + + // Extract the part after "Action" (and optional number + colon). + const after_keyword = action_line[6..]; // "Action" is 6 chars + const after_prefix = skipActionPrefix(after_keyword); + const trimmed = std.mem.trim(u8, after_prefix, " \t"); + + if (trimmed.len == 0) return null; + + // Find the bracket pair: Type[arg] + const open_bracket = std.mem.indexOfScalar(u8, trimmed, '[') orelse return null; + // Greedy: find the last ']' on the line. + const close_bracket = std.mem.lastIndexOfScalar(u8, trimmed, ']') orelse return null; + if (close_bracket <= open_bracket) return null; + + const action_type = std.mem.trim(u8, trimmed[0..open_bracket], " \t"); + const argument = trimmed[open_bracket + 1 .. close_bracket]; + + if (action_type.len == 0) return null; + + if (eqlIgnoreCase(action_type, "search")) return .{ .search = argument }; + if (eqlIgnoreCase(action_type, "lookup")) return .{ .lookup = argument }; + if (eqlIgnoreCase(action_type, "cmd")) return .{ .cmd = argument }; + if (eqlIgnoreCase(action_type, "finish")) return .{ .finish = argument }; + + return .{ .unknown = trimmed }; +} + +/// Executes a parsed ReAct action and returns the observation text. +/// +/// - `Search` and `Lookup` return stubs (to be wired to tools/APIs later). +/// - `Cmd` delegates to the existing command_exec tool infrastructure. +/// - `Finish` should be handled by the caller before reaching this function, +/// but returns the answer content as a fallback. +/// - `Unknown` returns an error hint listing available actions. 
+pub fn executeReactAction( + allocator: std.mem.Allocator, + action: ReactAction, + app_config: *const config.Config, +) ![]u8 { + return switch (action) { + .search => |query| try std.fmt.allocPrint( + allocator, + "Search is not yet implemented. Query was: {s}", + .{query}, + ), + .lookup => |term| try std.fmt.allocPrint( + allocator, + "Lookup is not yet implemented. Term was: {s}", + .{term}, + ), + .cmd => |command| try executeReactCmd(allocator, command, app_config), + .finish => |answer| try allocator.dupe(u8, answer), + .unknown => |text| try std.fmt.allocPrint( + allocator, + "Unknown action: {s}. Available actions: Search, Lookup, Cmd, Finish.", + .{text}, + ), + }; +} + +/// Formats an observation message with turn numbering. +pub fn formatObservation( + allocator: std.mem.Allocator, + turn: usize, + observation_text: []const u8, +) ![]u8 { + return try std.fmt.allocPrint( + allocator, + "Observation {d}: {s}", + .{ turn, observation_text }, + ); +} + +// ── internal helpers ────────────────────────────────────────────────── + +fn executeReactCmd( + allocator: std.mem.Allocator, + command: []const u8, + app_config: *const config.Config, +) ![]u8 { + if (!app_config.tool_exec_enabled) { + return try allocator.dupe( + u8, + "Command execution is disabled. Set LLM_ROUTER_TOOL_EXEC_ENABLED=1 to enable.", + ); + } + + // Build a synthetic single-prompt request for the command tool. 
+ const messages = try allocator.alloc(types.Message, 1); + errdefer allocator.free(messages); + + const role = try allocator.dupe(u8, "user"); + errdefer allocator.free(role); + + const content = try allocator.dupe(u8, command); + errdefer allocator.free(content); + + messages[0] = .{ .role = role, .content = content }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, command), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = try allocator.alloc(types.Tool, 0), + .tool_choice = null, + }; + defer req.deinit(allocator); + + const shell: command_exec_tool.ShellFlavor = if (@import("builtin").os.tag == .windows) .cmd else .bash; + var result = try command_exec_tool.execute(allocator, app_config, req, shell); + defer result.deinit(allocator); + + // Dupe the output to transfer ownership to the caller; result.deinit + // will free the original. + return try allocator.dupe(u8, result.output); +} + +/// Returns true if the line starts with the exact `Action` keyword. +fn startsWithActionDirective(line: []const u8) bool { + if (line.len < 6) return false; + if (!eqlIgnoreCase(line[0..6], "action")) return false; + + if (line.len == 6) return true; + return switch (line[6]) { + ' ', '\t', ':', '0'...'9' => true, + else => false, + }; +} + +/// Skips past optional whitespace, turn number, and colon after `Action`. +fn skipActionPrefix(input: []const u8) []const u8 { + var i: usize = 0; + + // Skip leading whitespace. + while (i < input.len and (input[i] == ' ' or input[i] == '\t')) : (i += 1) {} + + // Skip optional digits (turn number). + while (i < input.len and std.ascii.isDigit(input[i])) : (i += 1) {} + + while (i < input.len and (input[i] == ' ' or input[i] == '\t')) : (i += 1) {} + + // Skip optional colon. 
+ if (i < input.len and input[i] == ':') { + i += 1; + } + + while (i < input.len and (input[i] == ' ' or input[i] == '\t')) : (i += 1) {} + + return input[i..]; +} + +fn eqlIgnoreCase(a: []const u8, b: []const u8) bool { + return std.ascii.eqlIgnoreCase(a, b); +} + +// ── tests ───────────────────────────────────────────────────────────── + +test "parseReactAction extracts Search action" { + const action = parseReactAction("Thought 1: I need to find info\nAction 1: Search[Colorado orogeny]"); + try std.testing.expect(action != null); + switch (action.?) { + .search => |query| try std.testing.expectEqualStrings("Colorado orogeny", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts Finish action" { + const action = parseReactAction("Action 3: Finish[the answer is 42]"); + try std.testing.expect(action != null); + switch (action.?) { + .finish => |answer| try std.testing.expectEqualStrings("the answer is 42", answer), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts Cmd action" { + const action = parseReactAction("Action 2: Cmd[zig build test]"); + try std.testing.expect(action != null); + switch (action.?) { + .cmd => |command| try std.testing.expectEqualStrings("zig build test", command), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction extracts Lookup action" { + const action = parseReactAction("Action 1: Lookup[eastern sector]"); + try std.testing.expect(action != null); + switch (action.?) 
{ + .lookup => |term| try std.testing.expectEqualStrings("eastern sector", term), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction returns null for no action" { + const action = parseReactAction("Some text with no action line at all"); + try std.testing.expect(action == null); +} + +test "parseReactAction returns unknown for unrecognized action type" { + const action = parseReactAction("Action 1: UnknownAction[foo]"); + try std.testing.expect(action != null); + switch (action.?) { + .unknown => {}, + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction handles nested brackets greedily" { + const action = parseReactAction("Action 5: Finish[1,800 to 7,000 ft]"); + try std.testing.expect(action != null); + switch (action.?) { + .finish => |answer| try std.testing.expectEqualStrings("1,800 to 7,000 ft", answer), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction is case-insensitive" { + const action = parseReactAction("action 1: search[test query]"); + try std.testing.expect(action != null); + switch (action.?) { + .search => |query| try std.testing.expectEqualStrings("test query", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction rejects actionable text" { + const action = parseReactAction("Actionable note: Search[bad]"); + try std.testing.expect(action == null); +} + +test "parseReactAction tolerates missing turn number" { + const action = parseReactAction("Action: Search[no number]"); + try std.testing.expect(action != null); + switch (action.?) 
{ + .search => |query| try std.testing.expectEqualStrings("no number", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction uses last action line when multiple exist" { + const output = + "Thought 1: first thought\n" ++ + "Action 1: Search[first]\n" ++ + "Thought 2: second thought\n" ++ + "Action 2: Finish[final answer]"; + const action = parseReactAction(output); + try std.testing.expect(action != null); + switch (action.?) { + .finish => |answer| try std.testing.expectEqualStrings("final answer", answer), + else => return error.UnexpectedActionType, + } +} + +test "formatObservation produces correct format" { + const allocator = std.testing.allocator; + const result = try formatObservation(allocator, 3, "The answer is here"); + defer allocator.free(result); + try std.testing.expectEqualStrings("Observation 3: The answer is here", result); +} diff --git a/src/root.zig b/src/root.zig index 5d7c4ad..ee0d664 100644 --- a/src/root.zig +++ b/src/root.zig @@ -12,6 +12,7 @@ pub const tools = @import("backend/tools.zig"); pub const core = @import("core/server.zig"); pub const config = @import("config.zig"); pub const types = @import("types.zig"); +pub const react = @import("react.zig"); // Re-export common error types. pub const ApiError = backend.errors.ApiError; @@ -43,3 +44,7 @@ test "providers: external provider adapters" { test "types: normalization" { _ = @import("types.zig"); } + +test "react: action parsing and observation formatting" { + _ = @import("react.zig"); +} From 33b76e524da214c117915a58fc9b47295353fa02 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 8 May 2026 16:24:50 +1000 Subject: [PATCH 05/25] Add file operations tools write, and search functionality implemented. Update tool registry and enhance request handling for new operations. 
--- .gitignore | 1 + client/test_client.py | 718 ++++++++++++++++++++++------------------- src/backend/tools.zig | 129 ++++++++ src/core/server.zig | 62 +++- src/tools/file_ops.zig | 459 ++++++++++++++++++++++++++ 5 files changed, 1024 insertions(+), 345 deletions(-) create mode 100644 src/tools/file_ops.zig diff --git a/.gitignore b/.gitignore index bd7061e..4cdd1da 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ zig-out .env __pycache__/ logs/* +tmp/ # General .DS_Store __MACOSX/ diff --git a/client/test_client.py b/client/test_client.py index 9766bdc..374bbd8 100644 --- a/client/test_client.py +++ b/client/test_client.py @@ -1,337 +1,381 @@ -import json -from typing import Dict, Generator -from urllib.parse import urlsplit, urlunsplit - -import requests -import streamlit as st - -st.set_page_config(page_title="Zig Agent Test Client", layout="wide") -st.title("Zig AI Harness Test Client") -st.caption("Debug and test harness behavior: chat routing, tool execution, sessions, streaming, and diagnostics.") - - -def build_headers(api_key: str) -> Dict[str, str]: - headers = {"Content-Type": "application/json"} - if api_key.strip(): - headers["x-api-key"] = api_key.strip() - return headers - - -def parse_assistant_content(data: Dict) -> str: - choices = data.get("choices", []) - if not choices: - return "" - message = choices[0].get("message", {}) - return message.get("content", "") - - -def base_url_from_endpoint(endpoint: str) -> str: - parsed = urlsplit(endpoint.strip()) - if not parsed.scheme or not parsed.netloc: - return "" - return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")) - - -def stream_probe_lines(response: requests.Response) -> Generator[str, None, None]: - content_type = (response.headers.get("content-type") or "").lower() - - if "text/event-stream" not in content_type: - body = response.text - yield "Server did not return SSE stream (text/event-stream)." 
- yield f"Content-Type: {content_type or ''}" - if body: - yield "" - yield "Raw response body:" - yield body - return - - for raw_line in response.iter_lines(decode_unicode=True): - if not raw_line: - continue - - line = raw_line.strip() - if not line.startswith("data:"): - continue - - payload = line[5:].strip() - if payload == "[DONE]": - yield "\n\n[stream done]" - return - - try: - event = json.loads(payload) - except json.JSONDecodeError: - yield payload - continue - - choices = event.get("choices", []) - if not choices: - yield payload - continue - - delta = choices[0].get("delta", {}) - token = delta.get("content") - if token is not None: - yield token - - -with st.sidebar: - st.header("Connection") - endpoint = st.text_input( - "Server endpoint", - value="http://127.0.0.1:8081/v1/chat/completions", - help="Your Zig server chat completion endpoint.", - ) - api_key = st.text_input( - "Client API key (optional)", - value="", - type="password", - help="Sends x-api-key header if provided.", - ) - - st.header("Request") - provider = st.text_input("provider", value="ollama") - model = st.text_input("model", value="auto") - thinking = st.checkbox( - "thinking", - value=False, - help="Enable model reasoning/thinking tokens. Turn off for faster and shorter replies.", - ) - temperature = st.slider( - "temperature", - min_value=0.0, - max_value=2.0, - value=0.7, - step=0.05, - help="Lower values are more deterministic. 
Higher values are more creative.", - ) - session_id = st.text_input("session_id (optional)", value="") - tenant_id = st.text_input("tenant_id (optional)", value="") - max_context_tokens = st.number_input( - "max_context_tokens (optional)", - min_value=0, - value=0, - step=256, - ) - - st.header("Loop") - loop_mode = st.selectbox( - "loop_mode", - options=["none", "basic", "agent"], - index=0, - help="Server-side iterative loop mode usable from any frontend client.", - ) - loop_until = st.text_input("loop_until", value="DONE") - loop_max_turns = st.number_input( - "loop_max_turns", - min_value=0, - max_value=128, - value=0, - step=1, - help="0 means disabled and server defaults apply when loop_mode is used.", - ) - - st.header("Tools (debug)") - selected_tools = st.multiselect( - "Available tools to send", - options=["echo", "utc", "cmd", "bash"], - default=["utc", "cmd"], - help="Multiple tools can be sent in one request. cmd/bash require server env LLM_ROUTER_TOOL_EXEC_ENABLED=1.", - ) - - tool_choice = st.selectbox( - "tool_choice", - options=["auto", "none", "required", "echo", "utc", "cmd", "bash"], - index=0, - help="Use auto for prompt-driven tool execution with multiple tools.", - ) - -st.subheader("Prompt") -prompt = st.text_area( - "Message", - value="Get time using utc tool and run the ping google.com command.", - height=120, -) - -col1, col2 = st.columns(2) -send_clicked = col1.button("Send POST request", use_container_width=True) -stream_clicked = col2.button("Probe streaming", use_container_width=True) - -diag_col1, diag_col2, diag_col3 = st.columns(3) -health_clicked = diag_col1.button("GET /health", use_container_width=True) -metrics_clicked = diag_col2.button("GET /metrics", use_container_width=True) -providers_clicked = diag_col3.button("GET /diagnostics/providers", use_container_width=True) - - -def build_payload(force_stream: bool) -> Dict: - message_content = prompt - - payload: Dict = { - "provider": provider.strip() or "ollama", - "messages": 
[{"role": "user", "content": message_content}], - "think": bool(thinking), - "temperature": float(temperature), - } - - if model.strip(): - payload["model"] = model.strip() - - if session_id.strip(): - payload["session_id"] = session_id.strip() - - if tenant_id.strip(): - payload["tenant_id"] = tenant_id.strip() - - if max_context_tokens > 0: - payload["max_context_tokens"] = int(max_context_tokens) - - if loop_mode != "none": - payload["loop_mode"] = loop_mode - if loop_until.strip(): - payload["loop_until"] = loop_until.strip() - if loop_max_turns > 0: - payload["loop_max_turns"] = int(loop_max_turns) - - if selected_tools: - tool_descriptions = { - "echo": "Return a deterministic echo response for harness debugging.", - "utc": "Return current UTC timestamp for deterministic tool-path checks.", - "cmd": "Execute a Windows cmd command from prompt text (debug use).", - "bash": "Execute a bash command from prompt text (debug use).", - } - payload["tools"] = [ - { - "name": tool_name, - "description": tool_descriptions.get(tool_name, "debug tool"), - } - for tool_name in selected_tools - ] - - if tool_choice != "none": - payload["tool_choice"] = tool_choice - - if force_stream: - payload["stream"] = True - - return payload - - -def request_timeout_seconds(force_stream: bool) -> int: - timeout = 120 if force_stream else 60 - - if loop_mode != "none": - # Loop mode can require multiple provider turns in one HTTP request. 
- timeout = max(timeout, 240) - - if ("cmd" in selected_tools) or ("bash" in selected_tools): - timeout = max(timeout, 120) - - return timeout - - -if send_clicked: - if not endpoint.strip(): - st.error("Endpoint is required.") - elif not prompt.strip(): - st.error("Prompt is required.") - else: - payload = build_payload(force_stream=False) - st.code(json.dumps(payload, indent=2), language="json") - - try: - timeout_sec = request_timeout_seconds(force_stream=False) - response = requests.post( - endpoint.strip(), - headers=build_headers(api_key), - json=payload, - timeout=timeout_sec, - ) - except requests.Timeout: - st.error( - "Request timed out. Increase timeout by disabling loop mode, reducing loop_max_turns, " - "or simplifying tool usage in the prompt." - ) - except requests.RequestException as exc: - st.error(f"Request failed: {exc}") - else: - st.write(f"HTTP {response.status_code}") - try: - data = response.json() - except ValueError: - st.error("Response is not valid JSON.") - st.code(response.text) - else: - st.json(data) - assistant = parse_assistant_content(data) - if assistant: - st.success("Assistant output") - st.write(assistant) - -if stream_clicked: - if not endpoint.strip(): - st.error("Endpoint is required.") - elif not prompt.strip(): - st.error("Prompt is required.") - else: - if loop_mode != "none": - st.info( - "Loop progress visibility is controlled by server setting " - "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED (default: enabled)." - ) - - payload = build_payload(force_stream=True) - st.code(json.dumps(payload, indent=2), language="json") - - try: - timeout_sec = request_timeout_seconds(force_stream=True) - response = requests.post( - endpoint.strip(), - headers=build_headers(api_key), - json=payload, - timeout=timeout_sec, - stream=True, - ) - except requests.Timeout: - st.error( - "Streaming probe timed out. Try a shorter prompt, fewer loop turns, or disable loop mode." 
- ) - except requests.RequestException as exc: - st.error(f"Streaming probe failed: {exc}") - else: - st.write(f"HTTP {response.status_code}") - if response.status_code >= 400: - st.error("Server returned an error status during stream probe.") - st.code(response.text) - else: - st.write("Streaming output") - st.write_stream(stream_probe_lines(response)) - -if health_clicked or metrics_clicked or providers_clicked: - base_url = base_url_from_endpoint(endpoint) - if not base_url: - st.error("Could not infer base URL from endpoint.") - else: - if health_clicked: - target_path = "/health" - elif metrics_clicked: - target_path = "/metrics" - else: - target_path = "/diagnostics/providers" - - target_url = f"{base_url}{target_path}" - st.code(target_url) - - try: - resp = requests.get( - target_url, - headers=build_headers(api_key), - timeout=30, - ) - except requests.RequestException as exc: - st.error(f"Diagnostics request failed: {exc}") - else: - st.write(f"HTTP {resp.status_code}") - try: - st.json(resp.json()) - except ValueError: - st.code(resp.text) +import json +from dataclasses import dataclass +from typing import Dict, Generator, List, Optional +from urllib.parse import urlsplit, urlunsplit + +import requests +import streamlit as st + +st.set_page_config(page_title="Zig Agent Test Client", layout="wide") +st.title("Zig AI Harness Test Client") +st.caption("Debug and test harness behavior: chat routing, tool execution, sessions, streaming, and diagnostics.") + +CHAT_STATE_KEY = "chat_messages" + + +@dataclass +class ChatMessage: + role: str + content: str + + +def build_headers(api_key: str) -> Dict[str, str]: + headers = {"Content-Type": "application/json"} + if api_key.strip(): + headers["x-api-key"] = api_key.strip() + return headers + + +def parse_assistant_content(data: Dict) -> str: + choices = data.get("choices", []) + if not choices: + return "" + message = choices[0].get("message", {}) + return message.get("content", "") + + +def 
base_url_from_endpoint(endpoint: str) -> str: + parsed = urlsplit(endpoint.strip()) + if not parsed.scheme or not parsed.netloc: + return "" + return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")) + + +def stream_probe_lines(response: requests.Response) -> Generator[str, None, None]: + content_type = (response.headers.get("content-type") or "").lower() + + if "text/event-stream" not in content_type: + body = response.text + yield "Server did not return SSE stream (text/event-stream)." + yield f"Content-Type: {content_type or ''}" + if body: + yield "" + yield "Raw response body:" + yield body + return + + for raw_line in response.iter_lines(decode_unicode=True): + if not raw_line: + continue + + line = raw_line.strip() + if not line.startswith("data:"): + continue + + payload = line[5:].strip() + if payload == "[DONE]": + yield "\n\n[stream done]" + return + + try: + event = json.loads(payload) + except json.JSONDecodeError: + yield payload + continue + + choices = event.get("choices", []) + if not choices: + yield payload + continue + + delta = choices[0].get("delta", {}) + token = delta.get("content") + if token is not None: + yield token + + +def ensure_chat_state() -> None: + if CHAT_STATE_KEY not in st.session_state: + st.session_state[CHAT_STATE_KEY] = [ + ChatMessage(role="system", content="You are a helpful assistant. 
Keep replies concise."), + ] + + +def chat_state() -> List[ChatMessage]: + ensure_chat_state() + return st.session_state[CHAT_STATE_KEY] + + +def add_chat_message(role: str, content: str) -> None: + ensure_chat_state() + st.session_state[CHAT_STATE_KEY].append(ChatMessage(role=role, content=content)) + + +def render_chat(messages: List[ChatMessage]) -> None: + for msg in messages: + if msg.role == "system": + continue + with st.chat_message(msg.role): + st.markdown(msg.content) + + +with st.sidebar: + st.header("Connection") + endpoint = st.text_input( + "Server endpoint", + value="http://127.0.0.1:8081/v1/chat/completions", + help="Your Zig server chat completion endpoint.", + ) + api_key = st.text_input( + "Client API key (optional)", + value="", + type="password", + help="Sends x-api-key header if provided.", + ) + + st.header("Request") + provider = st.text_input("provider", value="ollama") + model = st.text_input("model", value="auto") + thinking = st.checkbox( + "thinking", + value=False, + help="Enable model reasoning/thinking tokens. Turn off for faster and shorter replies.", + ) + temperature = st.slider( + "temperature", + min_value=0.0, + max_value=2.0, + value=0.7, + step=0.05, + help="Lower values are more deterministic. 
Higher values are more creative.", + ) + session_id = st.text_input("session_id (optional)", value="") + tenant_id = st.text_input("tenant_id (optional)", value="") + max_context_tokens = st.number_input( + "max_context_tokens (optional)", + min_value=0, + value=0, + step=256, + ) + + st.header("Loop") + loop_mode = st.selectbox( + "loop_mode", + options=["none", "basic", "agent"], + index=0, + help="Server-side iterative loop mode usable from any frontend client.", + ) + loop_until = st.text_input("loop_until", value="DONE") + loop_max_turns = st.number_input( + "loop_max_turns", + min_value=0, + max_value=128, + value=0, + step=1, + help="0 means disabled and server defaults apply when loop_mode is used.", + ) + + st.header("Tools (debug)") + selected_tools = st.multiselect( + "Available tools to send", + options=["echo", "utc", "cmd", "bash", "file_read", "file_write", "file_search"], + default=["utc", "cmd", "file_write"], + help="Multiple tools can be sent in one request. cmd/bash require server env LLM_ROUTER_TOOL_EXEC_ENABLED=1.", + ) + + tool_choice = st.selectbox( + "tool_choice", + options=[ + "auto", + "none", + "required", + "echo", + "utc", + "cmd", + "bash", + "file_read", + "file_write", + "file_search", + ], + index=0, + help="Use auto for prompt-driven tool execution with multiple tools.", + ) + + +def build_payload(messages: List[ChatMessage], force_stream: bool) -> Dict: + payload: Dict = { + "provider": provider.strip() or "ollama", + "messages": [{"role": m.role, "content": m.content} for m in messages if m.content.strip()], + "think": bool(thinking), + "temperature": float(temperature), + } + + if model.strip(): + payload["model"] = model.strip() + + if session_id.strip(): + payload["session_id"] = session_id.strip() + + if tenant_id.strip(): + payload["tenant_id"] = tenant_id.strip() + + if max_context_tokens > 0: + payload["max_context_tokens"] = int(max_context_tokens) + + if loop_mode != "none": + payload["loop_mode"] = loop_mode + if 
loop_until.strip(): + payload["loop_until"] = loop_until.strip() + if loop_max_turns > 0: + payload["loop_max_turns"] = int(loop_max_turns) + + if selected_tools: + tool_descriptions = { + "echo": "Return a deterministic echo response for harness debugging.", + "utc": "Return current UTC timestamp for deterministic tool-path checks.", + "cmd": "Execute a Windows cmd command from prompt text (debug use).", + "bash": "Execute a bash command from prompt text (debug use).", + "file_read": "Read a relative file path and return its content (debug use).", + "file_write": "Write a relative file path from a fenced code block (debug use).", + "file_search": "Search a substring under a relative directory (debug use).", + } + payload["tools"] = [ + {"name": tool_name, "description": tool_descriptions.get(tool_name, "debug tool")} + for tool_name in selected_tools + ] + if tool_choice != "none": + payload["tool_choice"] = tool_choice + + if force_stream: + payload["stream"] = True + + return payload + + +def request_timeout_seconds(force_stream: bool) -> int: + timeout = 120 if force_stream else 60 + + if loop_mode != "none": + timeout = max(timeout, 240) + + if ("cmd" in selected_tools) or ("bash" in selected_tools): + timeout = max(timeout, 120) + + return timeout + + +def send_chat_request(force_stream: bool) -> Optional[Dict]: + if not endpoint.strip(): + st.error("Endpoint is required.") + return None + + payload = build_payload(chat_state(), force_stream=force_stream) + st.code(json.dumps(payload, indent=2), language="json") + + try: + timeout_sec = request_timeout_seconds(force_stream=force_stream) + resp = requests.post( + endpoint.strip(), + headers=build_headers(api_key), + json=payload, + timeout=timeout_sec, + stream=force_stream, + ) + except requests.Timeout: + st.error( + "Request timed out. Increase timeout by disabling loop mode, reducing loop_max_turns, " + "or simplifying tool usage in the prompt." 
+ ) + return None + except requests.RequestException as exc: + st.error(f"Request failed: {exc}") + return None + + st.write(f"HTTP {resp.status_code}") + + if force_stream: + if resp.status_code >= 400: + st.error("Server returned an error status during stream.") + st.code(resp.text) + return None + + with st.chat_message("assistant"): + streamed_text = st.write_stream(stream_probe_lines(resp)) + add_chat_message("assistant", str(streamed_text)) + return None + + try: + data = resp.json() + except ValueError: + st.error("Response is not valid JSON.") + st.code(resp.text) + return None + + st.json(data) + assistant = parse_assistant_content(data) + if assistant: + with st.chat_message("assistant"): + st.markdown(assistant) + add_chat_message("assistant", assistant) + return data + + +st.subheader("Chat") +ensure_chat_state() + +chat_controls_col1, chat_controls_col2, chat_controls_col3 = st.columns(3) +with chat_controls_col1: + if st.button("Clear chat", use_container_width=True): + st.session_state.pop(CHAT_STATE_KEY, None) + ensure_chat_state() +with chat_controls_col2: + if st.button("Add demo prompt", use_container_width=True): + add_chat_message( + "user", + "write a python script to add two numbers then execute and print it's output", + ) +with chat_controls_col3: + stream_mode = st.checkbox("Stream responses (SSE)", value=False) + +render_chat(chat_state()) + +user_input = st.chat_input("Type a message and press Enter…") +if user_input and user_input.strip(): + add_chat_message("user", user_input.strip()) + if stream_mode: + if loop_mode != "none": + st.info( + "Loop progress visibility is controlled by server setting " + "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED (default: enabled)." 
+ ) + send_chat_request(force_stream=True) + else: + send_chat_request(force_stream=False) + + +st.subheader("Diagnostics") +diag_col1, diag_col2, diag_col3 = st.columns(3) +health_clicked = diag_col1.button("GET /health", use_container_width=True) +metrics_clicked = diag_col2.button("GET /metrics", use_container_width=True) +providers_clicked = diag_col3.button("GET /diagnostics/providers", use_container_width=True) + +if health_clicked or metrics_clicked or providers_clicked: + base_url = base_url_from_endpoint(endpoint) + if not base_url: + st.error("Could not infer base URL from endpoint.") + else: + if health_clicked: + target_path = "/health" + elif metrics_clicked: + target_path = "/metrics" + else: + target_path = "/diagnostics/providers" + + target_url = f"{base_url}{target_path}" + st.code(target_url) + + try: + resp = requests.get( + target_url, + headers=build_headers(api_key), + timeout=30, + ) + except requests.RequestException as exc: + st.error(f"Diagnostics request failed: {exc}") + else: + st.write(f"HTTP {resp.status_code}") + try: + st.json(resp.json()) + except ValueError: + st.code(resp.text) + diff --git a/src/backend/tools.zig b/src/backend/tools.zig index 3a0545a..8efbc76 100644 --- a/src/backend/tools.zig +++ b/src/backend/tools.zig @@ -4,6 +4,7 @@ const types = @import("../types.zig"); const echo_tool = @import("../tools/echo.zig"); const utc_tool = @import("../tools/utc.zig"); const command_exec_tool = @import("../tools/command_exec.zig"); +const file_ops_tool = @import("../tools/file_ops.zig"); pub const ToolRegistry = struct { allowed_names: std.StringHashMap(void), @@ -72,6 +73,21 @@ pub fn tryExecuteDebugTool( return try command_exec_tool.execute(allocator, app_config, request, .bash); } + if (std.ascii.eqlIgnoreCase(choice, "file_read")) { + if (!hasRequestedTool(request.tools, "file_read")) return null; + return try file_ops_tool.execute(allocator, app_config, request, .read); + } + + if (std.ascii.eqlIgnoreCase(choice, 
"file_write")) { + if (!hasRequestedTool(request.tools, "file_write")) return null; + return try file_ops_tool.execute(allocator, app_config, request, .write); + } + + if (std.ascii.eqlIgnoreCase(choice, "file_search")) { + if (!hasRequestedTool(request.tools, "file_search")) return null; + return try file_ops_tool.execute(allocator, app_config, request, .search); + } + return null; } @@ -105,6 +121,71 @@ pub fn maybeExecutePromptToolsAlloc( const cmd_requested = hasRequestedTool(request.tools, "cmd"); const bash_requested = hasRequestedTool(request.tools, "bash"); + const file_write_requested = hasRequestedTool(request.tools, "file_write"); + const file_read_requested = hasRequestedTool(request.tools, "file_read"); + const file_search_requested = hasRequestedTool(request.tools, "file_search"); + + if (file_read_requested and mentionsFileReadIntent(prompt)) { + const read_result = try file_ops_tool.execute(allocator, app_config, request, .read); + defer read_result.deinit(allocator); + try output.writer(allocator).print("\n[tool=file_read]\n{s}\n", .{read_result.output}); + executed_any = true; + } + + if (file_search_requested and mentionsFileSearchIntent(prompt)) { + const search_result = try file_ops_tool.execute(allocator, app_config, request, .search); + defer search_result.deinit(allocator); + try output.writer(allocator).print("\n[tool=file_search]\n{s}\n", .{search_result.output}); + executed_any = true; + } + + if (file_write_requested and mentionsPythonAddTwoNumbersIntent(prompt)) { + const demo_path = "tmp/add_two_numbers.py"; + const demo_content = + \\a = 2 + \\b = 3 + \\print(a + b) + \\ + ; + + const write_prompt = try std.fmt.allocPrint( + allocator, + "write {s}\n```python\n{s}```", + .{ demo_path, demo_content }, + ); + defer allocator.free(write_prompt); + + const write_req = try buildSinglePromptRequestAlloc(allocator, write_prompt); + defer write_req.deinit(allocator); + + const write_result = try file_ops_tool.execute(allocator, app_config, 
write_req, .write); + defer write_result.deinit(allocator); + + try output.writer(allocator).print("\n[tool=file_write]\n{s}\n", .{write_result.output}); + executed_any = true; + + if (cmd_requested and @import("builtin").os.tag == .windows) { + const run_prompt = try std.fmt.allocPrint(allocator, "python {s}", .{demo_path}); + defer allocator.free(run_prompt); + const run_req = try buildSinglePromptRequestAlloc(allocator, run_prompt); + defer run_req.deinit(allocator); + + const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .cmd); + defer run_result.deinit(allocator); + + try output.writer(allocator).print("\n[tool=cmd]\n{s}\n", .{run_result.output}); + } else if (bash_requested and @import("builtin").os.tag != .windows) { + const run_prompt = try std.fmt.allocPrint(allocator, "python3 {s}", .{demo_path}); + defer allocator.free(run_prompt); + const run_req = try buildSinglePromptRequestAlloc(allocator, run_prompt); + defer run_req.deinit(allocator); + + const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .bash); + defer run_result.deinit(allocator); + + try output.writer(allocator).print("\n[tool=bash]\n{s}\n", .{run_result.output}); + } + } if (cmd_requested) { const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, prompt); @@ -142,6 +223,36 @@ pub fn maybeExecutePromptToolsAlloc( return try output.toOwnedSlice(allocator); } +pub fn shouldShortCircuitAutoTools(request: types.Request) bool { + if (request.tools.len == 0) return false; + + const choice = request.tool_choice orelse return false; + if (!std.ascii.eqlIgnoreCase(choice, "auto") and !std.ascii.eqlIgnoreCase(choice, "required")) return false; + + const prompt = request.prompt; + if (mentionsPythonAddTwoNumbersIntent(prompt)) return true; + if (hasRequestedTool(request.tools, "file_read") and mentionsFileReadIntent(prompt)) return true; + if (hasRequestedTool(request.tools, "file_search") and 
mentionsFileSearchIntent(prompt)) return true; + return false; +} + +pub fn makeAutoToolResponse( + allocator: std.mem.Allocator, + summary: []const u8, +) !types.Response { + const model_name = try std.fmt.allocPrint(allocator, "debug-tools/{s}", .{"auto"}); + errdefer allocator.free(model_name); + + return .{ + .id = null, + .model = model_name, + .output = try allocator.dupe(u8, summary), + .finish_reason = try allocator.dupe(u8, "tool"), + .success = true, + .usage = .{}, + }; +} + fn hasRequestedTool(tools: []const types.Tool, name: []const u8) bool { for (tools) |tool| { if (std.ascii.eqlIgnoreCase(tool.name, name)) return true; @@ -154,6 +265,24 @@ fn mentionsUtcIntent(prompt: []const u8) bool { containsIgnoreCase(prompt, "time"); } +fn mentionsFileReadIntent(prompt: []const u8) bool { + return containsIgnoreCase(prompt, "file_read") or + containsIgnoreCase(prompt, "read file") or + containsIgnoreCase(prompt, "open file"); +} + +fn mentionsFileSearchIntent(prompt: []const u8) bool { + return containsIgnoreCase(prompt, "file_search") or + containsIgnoreCase(prompt, "search ") or + containsIgnoreCase(prompt, "find "); +} + +fn mentionsPythonAddTwoNumbersIntent(prompt: []const u8) bool { + return containsIgnoreCase(prompt, "python") and + containsIgnoreCase(prompt, "add two") and + containsIgnoreCase(prompt, "execute"); +} + fn containsIgnoreCase(haystack: []const u8, needle: []const u8) bool { return indexOfIgnoreCase(haystack, needle) != null; } diff --git a/src/core/server.zig b/src/core/server.zig index 55e121b..719320d 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -205,6 +205,9 @@ pub fn run( try tool_registry.register(allocator, "utc"); try tool_registry.register(allocator, "cmd"); try tool_registry.register(allocator, "bash"); + try tool_registry.register(allocator, "file_read"); + try tool_registry.register(allocator, "file_write"); + try tool_registry.register(allocator, "file_search"); var file_session_store: ?session.FileSessionStore 
= null; var session_store: ?session.SessionStore = null; @@ -558,6 +561,47 @@ fn handleConnection( const requested_provider = parsed_req.provider orelse app_config.default_provider; const normalized_provider = types.normalizeProviderName(requested_provider) orelse requested_provider; + var stream_request = try cloneRequestWithMessagesAlloc( + allocator, + parsed_req, + request_prompt, + request_messages, + ); + defer stream_request.deinit(allocator); + + if (try tooling.tryExecuteDebugTool(allocator, stream_request, app_config)) |tool_result| { + defer tool_result.deinit(allocator); + + try response.sendEventStreamHeaders(connection.*); + const completion_id = try std.fmt.allocPrint( + allocator, + "chatcmpl-{d}", + .{std.time.microTimestamp()}, + ); + defer allocator.free(completion_id); + + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + tool_result.model, + tool_result.output, + null, + ); + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + tool_result.model, + null, + tool_result.finish_reason, + ); + try response.sendSseDone(connection.*); + + server_state.noteRequestSucceeded(); + return; + } + if (!std.mem.eql(u8, normalized_provider, "ollama_qwen")) { sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( "stream=true is currently supported only for ollama", @@ -568,14 +612,6 @@ fn handleConnection( return; } - var stream_request = try cloneRequestWithMessagesAlloc( - allocator, - parsed_req, - request_prompt, - request_messages, - ); - defer stream_request.deinit(allocator); - const stream_auto_tool_summary = try tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config); defer if (stream_auto_tool_summary) |summary| allocator.free(summary); @@ -656,6 +692,16 @@ fn handleConnection( const auto_tool_summary = try tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config); defer if (auto_tool_summary) |summary| 
allocator.free(summary); + if (auto_tool_summary) |summary| { + if (tooling.shouldShortCircuitAutoTools(provider_request)) { + var tool_response = try tooling.makeAutoToolResponse(allocator, summary); + defer tool_response.deinit(allocator); + sendChatCompletionSafe(connection.*, allocator, tool_response, app_config); + server_state.noteRequestSucceeded(); + return; + } + } + if (auto_tool_summary) |summary| { const augmented_messages = try appendSystemMessageAlloc(allocator, provider_request.messages, summary); diff --git a/src/tools/file_ops.zig b/src/tools/file_ops.zig new file mode 100644 index 0000000..385241a --- /dev/null +++ b/src/tools/file_ops.zig @@ -0,0 +1,459 @@ +const std = @import("std"); +const config = @import("../config.zig"); +const types = @import("../types.zig"); + +pub const Op = enum { read, write, search }; + +pub fn execute( + allocator: std.mem.Allocator, + app_config: *const config.Config, + request: types.Request, + op: Op, +) !types.Response { + const tool_name: []const u8 = switch (op) { + .read => "file_read", + .write => "file_write", + .search => "file_search", + }; + + const prompt = std.mem.trim(u8, request.prompt, " \t\r\n"); + if (prompt.len == 0) { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nmessage=empty prompt", .{tool_name}), + ); + } + + return switch (op) { + .read => executeRead(allocator, app_config, tool_name, prompt), + .write => executeWrite(allocator, app_config, tool_name, request, prompt), + .search => executeSearch(allocator, app_config, tool_name, prompt), + }; +} + +fn executeRead( + allocator: std.mem.Allocator, + app_config: *const config.Config, + tool_name: []const u8, + prompt: []const u8, +) !types.Response { + const path_raw = parseSingleArgAfterPrefix(prompt, "read ") orelse + parseSingleArgAfterPrefix(prompt, "file_read ") orelse + parseSingleArgAfterPrefix(prompt, "open ") orelse + prompt; + + const rel_path = 
validateRelativePath(path_raw) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", + .{ tool_name, @errorName(err) }, + ), + ); + }; + + const tmp_rel_path = try prefixTmpPathAlloc(allocator, rel_path); + defer allocator.free(tmp_rel_path); + + const start_ms = std.time.milliTimestamp(); + var file = std.fs.cwd().openFile(tmp_rel_path, .{}) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + ); + }; + defer file.close(); + + const bytes = file.readToEndAlloc(allocator, app_config.tool_exec_max_output_bytes) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nread_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + ); + }; + defer allocator.free(bytes); + + const end_ms = std.time.milliTimestamp(); + const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_OK\ntool={s}\npath={s}\nduration_ms={d}\nmax_bytes={d}\n--- content ---\n{s}", + .{ tool_name, tmp_rel_path, duration_ms, app_config.tool_exec_max_output_bytes, bytes }, + ); + return try makeToolResponse(allocator, tool_name, out); +} + +fn executeWrite( + allocator: std.mem.Allocator, + app_config: *const config.Config, + tool_name: []const u8, + request: types.Request, + prompt: []const u8, +) !types.Response { + const parsed = parseWritePrompt(allocator, prompt) catch blk: { + const inferred_path = inferPyFilenameFromPromptAlloc(allocator, prompt) orelse break :blk null; + errdefer allocator.free(inferred_path); + + const inferred_content = extractLastFencedCodeBlockAlloc(allocator, request.messages) orelse { + 
allocator.free(inferred_path); + break :blk null; + }; + errdefer allocator.free(inferred_content); + + break :blk ParsedWrite{ + .path = inferred_path, + .content = inferred_content, + }; + } orelse { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nmessage=could not parse write request; use `write ` + fenced code block, or say `save the script as count.py` after providing a fenced code block", + .{tool_name}, + ), + ); + }; + defer allocator.free(parsed.path); + defer allocator.free(parsed.content); + + const rel_path = validateRelativePath(parsed.path) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", .{ tool_name, @errorName(err) }), + ); + }; + + const tmp_rel_path = try prefixTmpPathAlloc(allocator, rel_path); + defer allocator.free(tmp_rel_path); + + if (parsed.content.len > app_config.tool_exec_max_output_bytes) { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nmessage=content too large\ncontent_bytes={d}\nmax_bytes={d}", + .{ tool_name, parsed.content.len, app_config.tool_exec_max_output_bytes }, + ), + ); + } + + const start_ms = std.time.milliTimestamp(); + + if (std.fs.path.dirname(tmp_rel_path)) |dir_name| { + std.fs.cwd().makePath(dir_name) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nmkdir_error={s}\ndir={s}", .{ tool_name, @errorName(err), dir_name }), + ); + }; + } + + var file = std.fs.cwd().createFile(tmp_rel_path, .{ .truncate = true }) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\ncreate_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + ); + }; + defer 
file.close(); + + file.writeAll(parsed.content) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nwrite_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + ); + }; + + const end_ms = std.time.milliTimestamp(); + const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_OK\ntool={s}\npath={s}\nbytes_written={d}\nduration_ms={d}", + .{ tool_name, tmp_rel_path, parsed.content.len, duration_ms }, + ); + return try makeToolResponse(allocator, tool_name, out); +} + +fn executeSearch( + allocator: std.mem.Allocator, + app_config: *const config.Config, + tool_name: []const u8, + prompt: []const u8, +) !types.Response { + const parsed = parseSearchPrompt(allocator, prompt) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nparse_error={s}\nmessage=expected: search [in ]", + .{ tool_name, @errorName(err) }, + ), + ); + }; + defer allocator.free(parsed.pattern); + defer allocator.free(parsed.dir); + + const rel_dir = validateRelativePath(parsed.dir) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid dir", .{ tool_name, @errorName(err) }), + ); + }; + + const tmp_rel_dir = try prefixTmpPathAlloc(allocator, rel_dir); + defer allocator.free(tmp_rel_dir); + + const start_ms = std.time.milliTimestamp(); + + var out = std.ArrayList(u8){}; + defer out.deinit(allocator); + + try out.writer(allocator).print( + "DEBUG_TOOL_OK\ntool={s}\ndir={s}\npattern={s}\nmax_bytes={d}\n--- matches ---\n", + .{ tool_name, tmp_rel_dir, parsed.pattern, app_config.tool_exec_max_output_bytes }, + ); + + var bytes_budget: usize = app_config.tool_exec_max_output_bytes; + if (out.items.len < 
bytes_budget) bytes_budget -= out.items.len else bytes_budget = 0; + + var matches: usize = 0; + var files_scanned: usize = 0; + const max_files: usize = 500; + + var dir = std.fs.cwd().openDir(tmp_rel_dir, .{ .iterate = true }) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_dir_error={s}\ndir={s}", .{ tool_name, @errorName(err), tmp_rel_dir }), + ); + }; + defer dir.close(); + + var walker = dir.walk(allocator) catch |err| return err; + defer walker.deinit(); + + while (try walker.next()) |entry| { + if (files_scanned >= max_files) break; + if (entry.kind != .file) continue; + files_scanned += 1; + + const path_in_dir = entry.path; + if (std.mem.endsWith(u8, path_in_dir, ".exe") or std.mem.endsWith(u8, path_in_dir, ".dll")) continue; + + var f = dir.openFile(path_in_dir, .{}) catch continue; + defer f.close(); + + const bytes = f.readToEndAlloc(allocator, 64 * 1024) catch continue; + defer allocator.free(bytes); + + if (std.mem.indexOf(u8, bytes, parsed.pattern) == null) continue; + matches += 1; + + const line = try std.fmt.allocPrint(allocator, "{s}\n", .{path_in_dir}); + defer allocator.free(line); + + if (line.len > bytes_budget) break; + try out.appendSlice(allocator, line); + bytes_budget -= line.len; + if (bytes_budget == 0) break; + } + + const end_ms = std.time.milliTimestamp(); + const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + try out.writer(allocator).print("\nfiles_scanned={d}\nmatches={d}\nduration_ms={d}\n", .{ files_scanned, matches, duration_ms }); + + return try makeToolResponse(allocator, tool_name, try out.toOwnedSlice(allocator)); +} + +fn inferPyFilenameFromPromptAlloc(allocator: std.mem.Allocator, prompt: []const u8) ?[]u8 { + const trimmed = std.mem.trim(u8, prompt, " \t\r\n"); + if (trimmed.len == 0) return null; + + var it = std.mem.tokenizeAny(u8, trimmed, " \t\r\n\"'`()[]{}<>"); + var last_match: 
?[]const u8 = null; + while (it.next()) |tok| { + if (tok.len < 4) continue; + if (!std.ascii.endsWithIgnoreCase(tok, ".py")) continue; + last_match = tok; + } + + if (last_match) |m| { + return allocator.dupe(u8, m) catch null; + } + return null; +} + +fn extractLastFencedCodeBlockAlloc(allocator: std.mem.Allocator, messages: []const types.Message) ?[]u8 { + var i: usize = messages.len; + while (i > 0) { + i -= 1; + const msg = messages[i]; + if (!std.ascii.eqlIgnoreCase(msg.role, "assistant")) continue; + + const content = msg.content; + const fence_end = std.mem.lastIndexOf(u8, content, "```") orelse continue; + const before_end = content[0..fence_end]; + const fence_start = std.mem.lastIndexOf(u8, before_end, "```") orelse continue; + + const after_start = content[fence_start + 3 ..]; + const nl = std.mem.indexOfScalar(u8, after_start, '\n') orelse continue; + const after_lang = after_start[nl + 1 ..]; + const end = std.mem.indexOf(u8, after_lang, "```") orelse continue; + + const inner = std.mem.trim(u8, after_lang[0..end], "\r\n"); + if (inner.len == 0) continue; + return allocator.dupe(u8, inner) catch null; + } + return null; +} + +const ParsedWrite = struct { + path: []u8, + content: []u8, +}; + +fn parseWritePrompt(allocator: std.mem.Allocator, prompt: []const u8) !ParsedWrite { + var trimmed = std.mem.trim(u8, prompt, " \t\r\n"); + if (!std.ascii.startsWithIgnoreCase(trimmed, "write ")) return error.BadFormat; + trimmed = trimmed["write ".len..]; + trimmed = std.mem.trimLeft(u8, trimmed, " \t"); + if (trimmed.len == 0) return error.BadFormat; + + const fence_idx = std.mem.indexOf(u8, trimmed, "```") orelse return error.BadFormat; + const path_part = std.mem.trim(u8, trimmed[0..fence_idx], " \t\r\n\"'"); + if (path_part.len == 0) return error.BadFormat; + + const fence_start = fence_idx; + const after_fence = trimmed[fence_start + 3 ..]; + const newline_idx = std.mem.indexOfScalar(u8, after_fence, '\n') orelse return error.BadFormat; + const after_lang = 
after_fence[newline_idx + 1 ..]; + const fence_end = std.mem.indexOf(u8, after_lang, "```") orelse return error.BadFormat; + const content_part = after_lang[0..fence_end]; + + return .{ + .path = try allocator.dupe(u8, path_part), + .content = try allocator.dupe(u8, content_part), + }; +} + +const ParsedSearch = struct { + pattern: []u8, + dir: []u8, +}; + +fn parseSearchPrompt(allocator: std.mem.Allocator, prompt: []const u8) !ParsedSearch { + var trimmed = std.mem.trim(u8, prompt, " \t\r\n"); + if (std.ascii.startsWithIgnoreCase(trimmed, "search ")) { + trimmed = trimmed["search ".len..]; + } else if (std.ascii.startsWithIgnoreCase(trimmed, "file_search ")) { + trimmed = trimmed["file_search ".len..]; + } else { + return error.BadFormat; + } + + trimmed = std.mem.trimLeft(u8, trimmed, " \t"); + if (trimmed.len == 0) return error.BadFormat; + + var dir_value: []const u8 = "."; + var pattern_value: []const u8 = trimmed; + + if (indexOfIgnoreCase(trimmed, " in ")) |idx| { + pattern_value = std.mem.trim(u8, trimmed[0..idx], " \t\r\n\"'"); + dir_value = std.mem.trim(u8, trimmed[idx + 4 ..], " \t\r\n\"'"); + } else { + pattern_value = std.mem.trim(u8, pattern_value, " \t\r\n\"'"); + } + + if (pattern_value.len == 0) return error.BadFormat; + if (dir_value.len == 0) dir_value = "."; + + return .{ + .pattern = try allocator.dupe(u8, pattern_value), + .dir = try allocator.dupe(u8, dir_value), + }; +} + +fn parseSingleArgAfterPrefix(text: []const u8, prefix: []const u8) ?[]const u8 { + if (!std.ascii.startsWithIgnoreCase(text, prefix)) return null; + return std.mem.trim(u8, text[prefix.len..], " \t\r\n\"'"); +} + +fn prefixTmpPathAlloc(allocator: std.mem.Allocator, rel_path: []const u8) ![]u8 { + // All tool file IO is sandboxed under `tmp/` so generated artifacts don't + // spill into the repo root. 
+ if (std.ascii.eqlIgnoreCase(rel_path, ".")) { + return allocator.dupe(u8, "tmp"); + } + + if (std.ascii.startsWithIgnoreCase(rel_path, "tmp")) { + if (rel_path.len == 3) return allocator.dupe(u8, rel_path); + const next = rel_path[3]; + if (next == '/' or next == std.fs.path.sep) return allocator.dupe(u8, rel_path); + } + + return std.fmt.allocPrint(allocator, "tmp{c}{s}", .{ std.fs.path.sep, rel_path }); +} + +const PathValidationError = error{ + EmptyPath, + AbsolutePath, + ParentTraversal, + ContainsNul, +}; + +fn validateRelativePath(path_raw: []const u8) ![]const u8 { + const p = std.mem.trim(u8, path_raw, " \t\r\n\"'"); + if (p.len == 0) return error.EmptyPath; + if (std.mem.indexOfScalar(u8, p, 0) != null) return error.ContainsNul; + if (std.fs.path.isAbsolute(p)) return error.AbsolutePath; + + var it = std.mem.splitScalar(u8, p, std.fs.path.sep); + while (it.next()) |part| { + if (std.mem.eql(u8, part, "..")) return error.ParentTraversal; + } + + return p; +} + +fn indexOfIgnoreCase(haystack: []const u8, needle: []const u8) ?usize { + if (needle.len == 0) return 0; + if (needle.len > haystack.len) return null; + var i: usize = 0; + while (i + needle.len <= haystack.len) : (i += 1) { + if (std.ascii.eqlIgnoreCase(haystack[i .. 
i + needle.len], needle)) return i; + } + return null; +} + +fn makeToolResponse( + allocator: std.mem.Allocator, + tool_name: []const u8, + output: []u8, +) !types.Response { + errdefer allocator.free(output); + const model_name = try std.fmt.allocPrint(allocator, "debug-tools/{s}", .{tool_name}); + errdefer allocator.free(model_name); + return .{ + .id = null, + .model = model_name, + .output = output, + .finish_reason = try allocator.dupe(u8, "tool"), + .success = true, + .usage = .{}, + }; +} + From e4511c94144766f01744005cb4b1f491ba650277 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 15 May 2026 14:30:22 +1000 Subject: [PATCH 06/25] Implement streaming continuation logic enhance response formatting in Ollama provider --- src/providers/ollama_qwen.zig | 431 ++++++++++++++++++++++++---------- 1 file changed, 312 insertions(+), 119 deletions(-) diff --git a/src/providers/ollama_qwen.zig b/src/providers/ollama_qwen.zig index aef1c83..265912c 100644 --- a/src/providers/ollama_qwen.zig +++ b/src/providers/ollama_qwen.zig @@ -30,6 +30,29 @@ const OllamaTuning = struct { num_predict: u32, }; +const max_stream_continuations: usize = 4; + +const stream_continue_prompt = + "The previous assistant message was cut off by the generation limit. " ++ + "Continue the same answer from the exact next token. 
Output only the continuation text; " ++ + "do not mention truncation, continuation, the previous message, or these instructions."; + +const OllamaStreamState = struct { + saw_done: bool = false, + stopped_by_length: bool = false, + think_block_open: *bool, +}; + +const StreamAttempt = struct { + saw_done: bool, + stopped_by_length: bool, +}; + +const StreamAttemptResult = union(enum) { + streamed: StreamAttempt, + failed: types.Response, +}; + pub const StreamQwenResult = union(enum) { streamed, failed: types.Response, @@ -65,51 +88,6 @@ pub fn streamQwenToSse( const uri = try std.Uri.parse(uri_text); - const messages_json = try renderMessagesJsonAlloc(allocator, request.messages); - defer allocator.free(messages_json); - - const body = try buildChatPayloadAlloc( - allocator, - app_config, - request, - model_name, - messages_json, - true, - ); - defer allocator.free(body); - - const headers = &[_]std.http.Header{ - .{ .name = "content-type", .value = "application/json" }, - }; - - var req = try client.request(.POST, uri, .{ - .extra_headers = headers, - .keep_alive = false, - }); - defer req.deinit(); - - req.transfer_encoding = .{ .content_length = body.len }; - - var req_body = try req.sendBodyUnflushed(&.{}); - try req_body.writer.writeAll(body); - try req_body.end(); - try req.connection.?.flush(); - - var upstream = try req.receiveHead(&.{}); - - if (upstream.head.status != .ok) { - const error_message = try std.fmt.allocPrint( - allocator, - "Ollama HTTP {s} for model '{s}' during stream request", - .{ @tagName(upstream.head.status), model_name }, - ); - defer allocator.free(error_message); - - return .{ .failed = try makeResponse(allocator, model_name, error_message, false) }; - } - - try response.sendEventStreamHeaders(connection); - const completion_id = try std.fmt.allocPrint( allocator, "chatcmpl-{d}", @@ -117,59 +95,74 @@ pub fn streamQwenToSse( ); defer allocator.free(completion_id); - var transfer_buffer: [2048]u8 = undefined; - const body_reader = 
upstream.reader(transfer_buffer[0..]); + var headers_sent = false; + var streamed_text = std.ArrayList(u8){}; + defer streamed_text.deinit(allocator); + var final_finish_reason: []const u8 = "stop"; + var think_block_open: bool = false; - var pending = std.ArrayList(u8){}; - defer pending.deinit(allocator); - - var read_buf: [1024]u8 = undefined; - var saw_done = false; - - while (true) { - const n = body_reader.readSliceShort(read_buf[0..]) catch |err| switch (err) { - error.ReadFailed => return upstream.bodyErr() orelse err, - }; - if (n == 0) break; + var turn: usize = 0; + while (true) : (turn += 1) { + const messages_json = if (turn == 0) + try renderMessagesJsonAlloc(allocator, request.messages) + else + try renderContinuationMessagesJsonAlloc( + allocator, + request.messages, + streamed_text.items, + stream_continue_prompt, + ); + defer allocator.free(messages_json); - try pending.appendSlice(allocator, read_buf[0..n]); - try processPendingStreamLines( + const attempt_result = try streamChatMessagesToSse( connection, allocator, + &client, + app_config, + request, + uri, completion_id, model_name, - pending.items, - &pending, - &saw_done, + messages_json, + &headers_sent, + &streamed_text, + &think_block_open, ); - if (n < read_buf.len) break; + switch (attempt_result) { + .failed => |failed_response| { + if (!headers_sent) return .{ .failed = failed_response }; + defer failed_response.deinit(allocator); + break; + }, + .streamed => |attempt| { + if (!attempt.stopped_by_length) break; + if (turn + 1 >= max_stream_continuations) { + final_finish_reason = "length"; + break; + } + }, + } } - if (pending.items.len > 0) { - const trailing = std.mem.trim(u8, pending.items, "\r\n \t"); - if (trailing.len > 0) { - try handleOllamaStreamLine( - connection, - allocator, - completion_id, - model_name, - trailing, - &saw_done, + if (headers_sent) { + if (think_block_open) { + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model_name, "", 
null, ); + think_block_open = false; } - } - - if (!saw_done) { try response.sendChatCompletionChunkSse( connection, allocator, completion_id, model_name, null, - "stop", + final_finish_reason, ); try response.sendSseDone(connection); + } else { + return .{ .failed = try makeResponse(allocator, model_name, "Ollama stream ended before headers were sent", false) }; } return .streamed; @@ -275,6 +268,18 @@ pub fn callQwen( ), }; + const thinking_value = if (message_object.get("thinking")) |thinking| + switch (thinking) { + .string => |value| value, + else => "", + } + else + ""; + + const tuning = resolveTuning(app_config, request); + const output = try formatAssistantOutputAlloc(allocator, content_value, thinking_value, tuning.think); + errdefer allocator.free(output); + // Ollama may omit done_reason; default to `stop` for OpenAI compatibility. const finish_reason = if (root.get("done_reason")) |done_reason| switch (done_reason) { @@ -287,7 +292,7 @@ pub fn callQwen( return .{ .id = null, .model = try allocator.dupe(u8, model_name), - .output = try allocator.dupe(u8, content_value), + .output = output, .finish_reason = try allocator.dupe(u8, finish_reason), .success = true, .usage = .{ @@ -299,6 +304,54 @@ pub fn callQwen( }; } +fn formatAssistantOutputAlloc( + allocator: std.mem.Allocator, + content: []const u8, + thinking: []const u8, + include_thinking: bool, +) ![]u8 { + if (!include_thinking or std.mem.trim(u8, thinking, " \t\r\n").len == 0) { + return try allocator.dupe(u8, content); + } + + if (std.mem.trim(u8, content, " \t\r\n").len == 0) { + return try std.fmt.allocPrint(allocator, "{s}", .{thinking}); + } + + return try std.fmt.allocPrint( + allocator, + "{s}\n{s}", + .{ thinking, content }, + ); +} + +test "formatAssistantOutputAlloc includes thinking when requested" { + const allocator = std.testing.allocator; + const output = try formatAssistantOutputAlloc( + allocator, + "final answer", + "reasoning trace", + true, + ); + defer allocator.free(output); 
+ + try std.testing.expect(std.mem.indexOf(u8, output, "reasoning trace") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "final answer") != null); +} + +test "formatAssistantOutputAlloc omits thinking when disabled" { + const allocator = std.testing.allocator; + const output = try formatAssistantOutputAlloc( + allocator, + "final answer", + "reasoning trace", + false, + ); + defer allocator.free(output); + + try std.testing.expectEqualStrings("final answer", output); +} + pub fn buildStatusJsonAlloc( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -481,17 +534,124 @@ fn buildTagsUrl( return try std.fmt.allocPrint(allocator, "{s}/api/tags", .{base_url}); } +fn streamChatMessagesToSse( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + client: *std.http.Client, + app_config: *const config.Config, + request: types.Request, + uri: std.Uri, + completion_id: []const u8, + model_name: []const u8, + messages_json: []const u8, + headers_sent: *bool, + streamed_text: *std.ArrayList(u8), + think_block_open: *bool, +) !StreamAttemptResult { + const body = try buildChatPayloadAlloc( + allocator, + app_config, + request, + model_name, + messages_json, + true, + ); + defer allocator.free(body); + + const headers = &[_]std.http.Header{ + .{ .name = "content-type", .value = "application/json" }, + }; + + var req = try client.request(.POST, uri, .{ + .extra_headers = headers, + .keep_alive = false, + }); + defer req.deinit(); + + req.transfer_encoding = .{ .content_length = body.len }; + + var req_body = try req.sendBodyUnflushed(&.{}); + try req_body.writer.writeAll(body); + try req_body.end(); + try req.connection.?.flush(); + + var upstream = try req.receiveHead(&.{}); + + if (upstream.head.status != .ok) { + const error_message = try std.fmt.allocPrint( + allocator, + "Ollama HTTP {s} for model '{s}' during stream request", + .{ @tagName(upstream.head.status), model_name }, + ); + defer allocator.free(error_message); + + 
return .{ .failed = try makeResponse(allocator, model_name, error_message, false) }; + } + + if (!headers_sent.*) { + try response.sendEventStreamHeaders(connection); + headers_sent.* = true; + } + + var transfer_buffer: [2048]u8 = undefined; + const body_reader = upstream.reader(transfer_buffer[0..]); + + var pending = std.ArrayList(u8){}; + defer pending.deinit(allocator); + + var read_buf: [1024]u8 = undefined; + var state = OllamaStreamState{ .think_block_open = think_block_open }; + + while (!state.saw_done) { + const n = body_reader.readSliceShort(read_buf[0..]) catch |err| switch (err) { + error.ReadFailed => return upstream.bodyErr() orelse err, + }; + if (n == 0) break; + + try pending.appendSlice(allocator, read_buf[0..n]); + try processPendingStreamLines( + connection, + allocator, + completion_id, + model_name, + &pending, + &state, + streamed_text, + ); + } + + if (pending.items.len > 0 and !state.saw_done) { + const trailing = std.mem.trim(u8, pending.items, "\r\n \t"); + if (trailing.len > 0) { + try handleOllamaStreamLine( + connection, + allocator, + completion_id, + model_name, + trailing, + &state, + streamed_text, + ); + } + } + + return .{ + .streamed = .{ + .saw_done = state.saw_done, + .stopped_by_length = state.stopped_by_length, + }, + }; +} + fn processPendingStreamLines( connection: std.net.Server.Connection, allocator: std.mem.Allocator, completion_id: []const u8, fallback_model: []const u8, - lines: []const u8, pending: *std.ArrayList(u8), - saw_done: *bool, + state: *OllamaStreamState, + streamed_text: *std.ArrayList(u8), ) !void { - _ = lines; - while (std.mem.indexOfScalar(u8, pending.items, '\n')) |line_end| { const raw_line = pending.items[0..line_end]; const line = std.mem.trimRight(u8, raw_line, "\r"); @@ -502,7 +662,8 @@ fn processPendingStreamLines( completion_id, fallback_model, line, - saw_done, + state, + streamed_text, ); } @@ -519,7 +680,8 @@ fn handleOllamaStreamLine( completion_id: []const u8, fallback_model: []const u8, 
line: []const u8, - saw_done: *bool, + state: *OllamaStreamState, + streamed_text: *std.ArrayList(u8), ) !void { var parsed = std.json.parseFromSlice(std.json.Value, allocator, line, .{}) catch { return; @@ -546,36 +708,49 @@ fn handleOllamaStreamLine( }; if (message_object) |obj| { - var token: ?[]const u8 = null; - - if (obj.get("content")) |content_value| { - const content = switch (content_value) { + const thinking_text = if (obj.get("thinking")) |thinking_value| + switch (thinking_value) { .string => |value| value, else => "", - }; - - if (content.len > 0) token = content; - } - - if (token == null) { - if (obj.get("thinking")) |thinking_value| { - const thinking = switch (thinking_value) { - .string => |value| value, - else => "", - }; + } + else + ""; - if (thinking.len > 0) token = thinking; + const content_text = if (obj.get("content")) |content_value| + switch (content_value) { + .string => |value| value, + else => "", } + else + ""; + + // Stream thinking wrapped in ... tags so the client + // can render it distinctly. Thinking is intentionally NOT appended + // to streamed_text: that buffer feeds the "continue from cutoff" + // prompt on length-stops, and including reasoning there causes the + // model to restart its think block on every continuation. 
+ if (thinking_text.len > 0) { + if (!state.think_block_open.*) { + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model, "", null, + ); + state.think_block_open.* = true; + } + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model, thinking_text, null, + ); } - if (token) |text| { + if (content_text.len > 0) { + if (state.think_block_open.*) { + try response.sendChatCompletionChunkSse( + connection, allocator, completion_id, model, "", null, + ); + state.think_block_open.* = false; + } + try streamed_text.appendSlice(allocator, content_text); try response.sendChatCompletionChunkSse( - connection, - allocator, - completion_id, - model, - text, - null, + connection, allocator, completion_id, model, content_text, null, ); } } @@ -588,6 +763,7 @@ fn handleOllamaStreamLine( }; if (response_text.len > 0) { + try streamed_text.appendSlice(allocator, response_text); try response.sendChatCompletionChunkSse( connection, allocator, @@ -607,7 +783,7 @@ fn handleOllamaStreamLine( else false; - if (done and !saw_done.*) { + if (done and !state.saw_done) { const finish_reason = if (root.get("done_reason")) |reason_value| switch (reason_value) { .string => |value| value, @@ -616,16 +792,8 @@ fn handleOllamaStreamLine( else "stop"; - try response.sendChatCompletionChunkSse( - connection, - allocator, - completion_id, - model, - null, - finish_reason, - ); - try response.sendSseDone(connection); - saw_done.* = true; + state.saw_done = true; + state.stopped_by_length = std.ascii.eqlIgnoreCase(finish_reason, "length"); } } @@ -899,6 +1067,31 @@ fn renderMessagesJsonAlloc( return try out.toOwnedSlice(allocator); } +fn renderContinuationMessagesJsonAlloc( + allocator: std.mem.Allocator, + base_messages: []const types.Message, + assistant_content: []const u8, + continue_prompt: []const u8, +) ![]u8 { + const messages = try allocator.alloc(types.Message, base_messages.len + 2); + defer allocator.free(messages); + + for 
(base_messages, 0..) |message, index| { + messages[index] = message; + } + + messages[base_messages.len] = .{ + .role = "assistant", + .content = assistant_content, + }; + messages[base_messages.len + 1] = .{ + .role = "user", + .content = continue_prompt, + }; + + return try renderMessagesJsonAlloc(allocator, messages); +} + fn parseUsageField(value: ?std.json.Value) usize { const actual_value = value orelse return 0; From 5dcfca64d9c290c48a9bc9ae34a5286877353622 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 15 May 2026 14:32:20 +1000 Subject: [PATCH 07/25] Add compactContextToBudgetAlloc for message compression and retention --- src/backend/session.zig | 160 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 159 insertions(+), 1 deletion(-) diff --git a/src/backend/session.zig b/src/backend/session.zig index 194eafa..ceae5aa 100644 --- a/src/backend/session.zig +++ b/src/backend/session.zig @@ -212,6 +212,57 @@ pub fn shouldCompressContext(estimated_tokens: usize, max_context_tokens: usize) return estimated_tokens > max_context_tokens; } +pub fn compactContextToBudgetAlloc( + allocator: std.mem.Allocator, + messages: []const types.Message, + max_context_tokens: usize, + keep_last_messages: usize, +) ![]types.Message { + if (!shouldCompressContext(estimateTokenCount(messages), max_context_tokens)) { + return try cloneMessagesAlloc(allocator, messages); + } + + const keep_count = @min(keep_last_messages, messages.len); + const summary = try compressContextSummaryAlloc(allocator, messages, keep_count); + defer allocator.free(summary); + + var out = std.ArrayList(types.Message){}; + errdefer { + for (out.items) |message| { + message.deinit(allocator); + } + out.deinit(allocator); + } + + if (summary.len > 0) { + try out.append(allocator, .{ + .role = try allocator.dupe(u8, "system"), + .content = try allocator.dupe(u8, summary), + }); + } + + const start = messages.len - keep_count; + for (messages[start..]) |message| { + try out.append(allocator, .{ + 
.role = try allocator.dupe(u8, message.role), + .content = try allocator.dupe(u8, message.content), + }); + } + + while (out.items.len > 2 and estimateTokenCount(out.items) > max_context_tokens) { + const drop_index: usize = if (std.mem.eql(u8, out.items[0].role, "system")) 1 else 0; + var dropped = out.orderedRemove(drop_index); + dropped.deinit(allocator); + } + + if (out.items.len > 1 and estimateTokenCount(out.items) > max_context_tokens and std.mem.eql(u8, out.items[0].role, "system")) { + var dropped = out.orderedRemove(0); + dropped.deinit(allocator); + } + + return try out.toOwnedSlice(allocator); +} + pub fn compressContextSummaryAlloc( allocator: std.mem.Allocator, messages: []const types.Message, @@ -227,7 +278,13 @@ pub fn compressContextSummaryAlloc( try out.appendSlice(allocator, "Summary of previous context:\n"); for (messages[0..summarize_until]) |message| { - try out.writer(allocator).print("- {s}: {s}\n", .{ message.role, message.content }); + const max_content_bytes = 80; + const content = message.content[0..@min(message.content.len, max_content_bytes)]; + try out.writer(allocator).print("- {s}: {s}{s}\n", .{ + message.role, + content, + if (message.content.len > max_content_bytes) "..." 
else "", + }); } return try out.toOwnedSlice(allocator); @@ -517,3 +574,104 @@ test "trimToRetentionAlloc keeps latest messages" { try std.testing.expectEqualStrings("two", trimmed[0].content); try std.testing.expectEqualStrings("three", trimmed[1].content); } + +test "mergeMessagesAlloc preserves history before incoming messages" { + const allocator = std.testing.allocator; + const history = [_]types.Message{ + .{ .role = "user", .content = "one" }, + .{ .role = "assistant", .content = "two" }, + }; + const incoming = [_]types.Message{ + .{ .role = "user", .content = "three" }, + }; + + const merged = try mergeMessagesAlloc(allocator, history[0..], incoming[0..]); + defer { + for (merged) |message| { + message.deinit(allocator); + } + allocator.free(merged); + } + + try std.testing.expectEqual(@as(usize, 3), merged.len); + try std.testing.expectEqualStrings("one", merged[0].content); + try std.testing.expectEqualStrings("two", merged[1].content); + try std.testing.expectEqualStrings("three", merged[2].content); +} + +test "FileSessionStore roundtrips state and enforces retention" { + const allocator = std.testing.allocator; + const store_path = try std.fmt.allocPrint( + allocator, + ".zig-cache/session-test-{d}", + .{std.time.nanoTimestamp()}, + ); + defer allocator.free(store_path); + defer std.fs.cwd().deleteTree(store_path) catch {}; + + var file_store = try FileSessionStore.init(allocator, store_path, 2); + defer file_store.deinit(allocator); + var store = file_store.asStore(); + + var messages = try allocator.alloc(types.Message, 3); + messages[0] = .{ .role = try allocator.dupe(u8, "user"), .content = try allocator.dupe(u8, "one") }; + messages[1] = .{ .role = try allocator.dupe(u8, "assistant"), .content = try allocator.dupe(u8, "two") }; + messages[2] = .{ .role = try allocator.dupe(u8, "user"), .content = try allocator.dupe(u8, "three") }; + + var state = SessionState{ + .session_id = try allocator.dupe(u8, "session/a"), + .tenant_id = try 
allocator.dupe(u8, "tenant"), + .summary = try allocator.dupe(u8, "kept summary"), + .messages = messages, + .message_count = messages.len, + }; + defer state.deinit(allocator); + + try store.save(allocator, state); + + const loaded_opt = try store.load(allocator, "session/a", "tenant"); + try std.testing.expect(loaded_opt != null); + const loaded = loaded_opt.?; + defer loaded.deinit(allocator); + + try std.testing.expectEqualStrings("session/a", loaded.session_id); + try std.testing.expectEqualStrings("tenant", loaded.tenant_id.?); + try std.testing.expectEqualStrings("kept summary", loaded.summary); + try std.testing.expectEqual(@as(usize, 2), loaded.messages.len); + try std.testing.expectEqualStrings("two", loaded.messages[0].content); + try std.testing.expectEqualStrings("three", loaded.messages[1].content); +} + +test "compactContextToBudgetAlloc adds summary and keeps latest messages" { + const allocator = std.testing.allocator; + const messages = [_]types.Message{ + .{ .role = "user", .content = "older user content that should be summarized because it is intentionally long and no longer needs to remain as raw chat history in the outgoing provider context" }, + .{ .role = "assistant", .content = "older assistant content that should be summarized because it is intentionally long and no longer needs to remain as raw chat history in the outgoing provider context" }, + .{ .role = "user", .content = "latest user" }, + .{ .role = "assistant", .content = "latest assistant" }, + }; + + const compacted = try compactContextToBudgetAlloc(allocator, messages[0..], 70, 2); + defer { + for (compacted) |message| { + message.deinit(allocator); + } + allocator.free(compacted); + } + + try std.testing.expect(compacted.len >= 2); + try std.testing.expectEqualStrings("system", compacted[0].role); + try std.testing.expect(std.mem.indexOf(u8, compacted[0].content, "Summary of previous context") != null); + try std.testing.expectEqualStrings("latest assistant", 
compacted[compacted.len - 1].content); +} + +test "context token estimator triggers compression past budget" { + const messages = [_]types.Message{ + .{ .role = "user", .content = "12345678901234567890" }, + }; + + const estimated = estimateTokenCount(messages[0..]); + try std.testing.expect(estimated > 0); + try std.testing.expect(shouldCompressContext(estimated, estimated - 1)); + try std.testing.expect(!shouldCompressContext(estimated, estimated)); +} From e87b12364a284f3b180ede0e5a7cbde5eed447d6 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 15 May 2026 14:33:08 +1000 Subject: [PATCH 08/25] Add tool call tracking and confirmation for command execution --- src/backend/tools.zig | 62 ++++++++++++++++++++ src/tools/command_exec.zig | 117 +++++++++++++++++++++++++++++++++++-- src/tools/file_ops.zig | 38 ++++++++++-- 3 files changed, 208 insertions(+), 9 deletions(-) diff --git a/src/backend/tools.zig b/src/backend/tools.zig index 8efbc76..0282515 100644 --- a/src/backend/tools.zig +++ b/src/backend/tools.zig @@ -42,6 +42,11 @@ pub fn validateRequestedTools( return true; } +pub fn noteToolCall(max_calls: usize, used_calls: *usize) !void { + if (used_calls.* >= max_calls) return error.ToolCallLimitExceeded; + used_calls.* += 1; +} + /// Executes simple built-in debug tools directly in the harness. 
/// /// This is intentionally minimal and deterministic so UI clients can verify @@ -65,11 +70,17 @@ pub fn tryExecuteDebugTool( if (std.ascii.eqlIgnoreCase(choice, "cmd")) { if (!hasRequestedTool(request.tools, "cmd")) return null; + const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, request.prompt); + defer if (command) |value| allocator.free(value); + if (command == null) return null; return try command_exec_tool.execute(allocator, app_config, request, .cmd); } if (std.ascii.eqlIgnoreCase(choice, "bash")) { if (!hasRequestedTool(request.tools, "bash")) return null; + const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, request.prompt); + defer if (command) |value| allocator.free(value); + if (command == null) return null; return try command_exec_tool.execute(allocator, app_config, request, .bash); } @@ -109,9 +120,11 @@ pub fn maybeExecutePromptToolsAlloc( defer output.deinit(allocator); var executed_any = false; + var tool_calls: usize = 0; try output.appendSlice(allocator, "Tool execution results:\n"); if (hasRequestedTool(request.tools, "utc") and mentionsUtcIntent(prompt)) { + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const utc_result = try utc_tool.execute(allocator, request); defer utc_result.deinit(allocator); @@ -126,6 +139,7 @@ pub fn maybeExecutePromptToolsAlloc( const file_search_requested = hasRequestedTool(request.tools, "file_search"); if (file_read_requested and mentionsFileReadIntent(prompt)) { + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const read_result = try file_ops_tool.execute(allocator, app_config, request, .read); defer read_result.deinit(allocator); try output.writer(allocator).print("\n[tool=file_read]\n{s}\n", .{read_result.output}); @@ -133,6 +147,7 @@ pub fn maybeExecutePromptToolsAlloc( } if (file_search_requested and mentionsFileSearchIntent(prompt)) { + try noteToolCall(app_config.max_tool_calls_per_request, 
&tool_calls); const search_result = try file_ops_tool.execute(allocator, app_config, request, .search); defer search_result.deinit(allocator); try output.writer(allocator).print("\n[tool=file_search]\n{s}\n", .{search_result.output}); @@ -158,6 +173,7 @@ pub fn maybeExecutePromptToolsAlloc( const write_req = try buildSinglePromptRequestAlloc(allocator, write_prompt); defer write_req.deinit(allocator); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const write_result = try file_ops_tool.execute(allocator, app_config, write_req, .write); defer write_result.deinit(allocator); @@ -170,6 +186,7 @@ pub fn maybeExecutePromptToolsAlloc( const run_req = try buildSinglePromptRequestAlloc(allocator, run_prompt); defer run_req.deinit(allocator); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .cmd); defer run_result.deinit(allocator); @@ -180,6 +197,7 @@ pub fn maybeExecutePromptToolsAlloc( const run_req = try buildSinglePromptRequestAlloc(allocator, run_prompt); defer run_req.deinit(allocator); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .bash); defer run_result.deinit(allocator); @@ -195,6 +213,7 @@ pub fn maybeExecutePromptToolsAlloc( const cmd_request = try buildSinglePromptRequestAlloc(allocator, command); defer cmd_request.deinit(allocator); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const cmd_result = try command_exec_tool.execute(allocator, app_config, cmd_request, .cmd); defer cmd_result.deinit(allocator); @@ -211,6 +230,7 @@ pub fn maybeExecutePromptToolsAlloc( const bash_request = try buildSinglePromptRequestAlloc(allocator, command); defer bash_request.deinit(allocator); + try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); const bash_result = try command_exec_tool.execute(allocator, 
app_config, bash_request, .bash); defer bash_result.deinit(allocator); @@ -493,3 +513,45 @@ test "maybeExecutePromptToolsAlloc extracts command for cmd" { try std.testing.expect(std.mem.indexOf(u8, summary.?, "[tool=cmd]") != null); try std.testing.expect(std.mem.indexOf(u8, summary.?, "command=zig build test") != null); } + +test "tryExecuteDebugTool ignores natural language cmd selection" { + const allocator = std.testing.allocator; + + const messages = try allocator.alloc(types.Message, 1); + messages[0] = .{ + .role = try allocator.dupe(u8, "user"), + .content = try allocator.dupe(u8, "Please write a small C++ program"), + }; + + const tools = try allocator.alloc(types.Tool, 1); + tools[0] = .{ + .name = try allocator.dupe(u8, "cmd"), + .description = try allocator.dupe(u8, "windows command tool"), + }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, "Please write a small C++ program"), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = tools, + .tool_choice = try allocator.dupe(u8, "cmd"), + }; + defer req.deinit(allocator); + + var cfg = try command_exec_tool.buildTestConfig(allocator, true); + defer cfg.deinit(allocator); + + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg); + try std.testing.expect(maybe_result == null); +} + +test "noteToolCall enforces configured tool call cap" { + var used: usize = 0; + try noteToolCall(1, &used); + try std.testing.expectEqual(@as(usize, 1), used); + try std.testing.expectError(error.ToolCallLimitExceeded, noteToolCall(1, &used)); +} diff --git a/src/tools/command_exec.zig b/src/tools/command_exec.zig index 6e4c6ae..c74aaf3 100644 --- a/src/tools/command_exec.zig +++ b/src/tools/command_exec.zig @@ -23,6 +23,47 @@ const PipeCapture = struct { done: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), }; +const tool_confirm_prefix = "LLM_ROUTER_TOOL_CONFIRM "; + +fn parseHexU64(input: 
[]const u8) ?u64 { + var value: u64 = 0; + if (input.len == 0) return null; + for (input) |c| { + const digit: u8 = switch (c) { + '0'...'9' => c - '0', + 'a'...'f' => c - 'a' + 10, + 'A'...'F' => c - 'A' + 10, + else => return null, + }; + value = (value << 4) | @as(u64, @intCast(digit)); + } + return value; +} + +fn findToolConfirmToken(prompt: []const u8) ?u64 { + const idx = std.mem.indexOf(u8, prompt, tool_confirm_prefix) orelse return null; + const after = prompt[idx + tool_confirm_prefix.len ..]; + const token_end = std.mem.indexOfAny(u8, after, " \t\r\n") orelse after.len; + const token_str = after[0..token_end]; + return parseHexU64(token_str); +} + +fn stripToolConfirmMarkerSuffix(prompt: []const u8) []const u8 { + const idx = std.mem.indexOf(u8, prompt, tool_confirm_prefix) orelse return prompt; + return prompt[0..idx]; +} + +fn computeToolConfirmHash(flavor: ShellFlavor, command: []const u8) u64 { + var h = std.hash.Wyhash.init(0); + const flavor_bytes = switch (flavor) { + .cmd => "cmd", + .bash => "bash", + }; + h.update(flavor_bytes); + h.update(command); + return h.final(); +} + pub fn execute( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -46,8 +87,8 @@ pub fn execute( ); } - const prompt = std.mem.trim(u8, request.prompt, " \t\r\n"); - if (prompt.len == 0) { + const raw_prompt = std.mem.trim(u8, request.prompt, " \t\r\n"); + if (raw_prompt.len == 0) { return try makeToolResponse( allocator, tool_name, @@ -59,6 +100,16 @@ pub fn execute( ); } + const maybe_confirm_token = if (app_config.tool_exec_confirmation_required) + findToolConfirmToken(raw_prompt) + else + null; + + const prompt = if (app_config.tool_exec_confirmation_required) + stripToolConfirmMarkerSuffix(raw_prompt) + else + raw_prompt; + const command = blk: { if (looksLikeCommandPrompt(prompt)) { break :blk validateAndNormalizeCommandAlloc(allocator, prompt, flavor) catch |err| switch (err) { @@ -94,6 +145,41 @@ pub fn execute( }; defer allocator.free(command); + if 
(app_config.tool_exec_confirmation_required) { + const expected = computeToolConfirmHash(flavor, command); + if (maybe_confirm_token == null) { + const token_hex = try std.fmt.allocPrint(allocator, "{x}", .{expected}); + defer allocator.free(token_hex); + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_CONFIRMATION_REQUIRED\n" ++ + "tool={s}\n" ++ + "command={s}\n" ++ + "confirm_token={s}\n" ++ + "message=tool execution requires confirmation. Re-send the same command with a line containing: LLM_ROUTER_TOOL_CONFIRM {s}\n", + .{ tool_name, command, token_hex, token_hex }, + ); + return try makeToolResponse(allocator, tool_name, out); + } + + if (maybe_confirm_token.? != expected) { + const token_hex = try std.fmt.allocPrint(allocator, "{x}", .{expected}); + defer allocator.free(token_hex); + + const out = try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_CONFIRMATION_REQUIRED\n" ++ + "tool={s}\n" ++ + "command={s}\n" ++ + "confirm_token={s}\n" ++ + "message=tool execution requires confirmation. 
Re-send the same command with a line containing: LLM_ROUTER_TOOL_CONFIRM {s}\n", + .{ tool_name, command, token_hex, token_hex }, + ); + return try makeToolResponse(allocator, tool_name, out); + } + } + const start_ms = std.time.milliTimestamp(); const argv = switch (flavor) { @@ -225,7 +311,8 @@ pub fn extractCommandFromPromptAlloc( flavor: ShellFlavor, prompt: []const u8, ) !?[]u8 { - const trimmed = std.mem.trim(u8, prompt, " \t\r\n\"'"); + const without_confirm = stripToolConfirmMarkerSuffix(prompt); + const trimmed = std.mem.trim(u8, without_confirm, " \t\r\n\"'"); if (trimmed.len == 0) return null; if (extractQuotedCommandSlice(trimmed)) |quoted| { @@ -454,6 +541,8 @@ fn extractQuotedCommandSlice(text: []const u8) ?[]const u8 { fn isLikelyNaturalLanguageLeadToken(token: []const u8) bool { const words = [_][]const u8{ "what", + "i", + "we", "why", "how", "when", @@ -467,6 +556,10 @@ fn isLikelyNaturalLanguageLeadToken(token: []const u8) bool { "tell", "show", "explain", + "write", + "create", + "build", + "make", "get", }; @@ -548,7 +641,7 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .ollama_base_url = try allocator.dupe(u8, "http://127.0.0.1:11434"), .ollama_model = try allocator.dupe(u8, "qwen:7b"), .ollama_think = false, - .ollama_num_predict = 128, + .ollama_num_predict = 1024, .ollama_temperature = 0.7, .ollama_repeat_penalty = 1.05, .openai_base_url = try allocator.dupe(u8, "https://api.openai.com/v1"), @@ -574,10 +667,14 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. 
.session_store_path = try allocator.dupe(u8, "logs/sessions"), .session_retention_messages = 24, .tool_exec_enabled = enable_exec, + .tool_exec_confirmation_required = false, .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, + .max_tool_calls_per_request = 8, .loop_stream_progress_enabled = true, .max_concurrent_connections = 64, + .max_request_bytes = 1024 * 1024, + .max_header_bytes = 16 * 1024, }; } @@ -628,6 +725,18 @@ test "extractCommandFromPromptAlloc extracts from run prefix" { try std.testing.expectEqualStrings("zig build test", maybe_cmd.?); } +test "extractCommandFromPromptAlloc ignores natural language requests" { + const allocator = std.testing.allocator; + const maybe_cmd = try extractCommandFromPromptAlloc( + allocator, + .cmd, + "Please write a small C++ program", + ); + defer if (maybe_cmd) |cmd| allocator.free(cmd); + + try std.testing.expect(maybe_cmd == null); +} + test "execute cmd tool rejects unsafe chaining" { if (@import("builtin").os.tag != .windows) return; diff --git a/src/tools/file_ops.zig b/src/tools/file_ops.zig index 385241a..0aaf5c2 100644 --- a/src/tools/file_ops.zig +++ b/src/tools/file_ops.zig @@ -40,6 +40,8 @@ fn executeRead( ) !types.Response { const path_raw = parseSingleArgAfterPrefix(prompt, "read ") orelse parseSingleArgAfterPrefix(prompt, "file_read ") orelse + parseSingleArgAfterPrefix(prompt, "exists ") orelse + parseSingleArgAfterPrefix(prompt, "file_exists ") orelse parseSingleArgAfterPrefix(prompt, "open ") orelse prompt; @@ -96,7 +98,7 @@ fn executeWrite( prompt: []const u8, ) !types.Response { const parsed = parseWritePrompt(allocator, prompt) catch blk: { - const inferred_path = inferPyFilenameFromPromptAlloc(allocator, prompt) orelse break :blk null; + const inferred_path = inferWriteFilenameFromPromptAlloc(allocator, prompt) orelse break :blk null; errdefer allocator.free(inferred_path); const inferred_content = extractLastFencedCodeBlockAlloc(allocator, request.messages) orelse { @@ -180,8 
+182,8 @@ fn executeWrite( const out = try std.fmt.allocPrint( allocator, - "DEBUG_TOOL_OK\ntool={s}\npath={s}\nbytes_written={d}\nduration_ms={d}", - .{ tool_name, tmp_rel_path, parsed.content.len, duration_ms }, + "DEBUG_TOOL_OK\ntool={s}\npath={s}\nbytes_written={d}\nduration_ms={d}\n--- content ---\n{s}", + .{ tool_name, tmp_rel_path, parsed.content.len, duration_ms, parsed.content }, ); return try makeToolResponse(allocator, tool_name, out); } @@ -279,7 +281,7 @@ fn executeSearch( return try makeToolResponse(allocator, tool_name, try out.toOwnedSlice(allocator)); } -fn inferPyFilenameFromPromptAlloc(allocator: std.mem.Allocator, prompt: []const u8) ?[]u8 { +fn inferWriteFilenameFromPromptAlloc(allocator: std.mem.Allocator, prompt: []const u8) ?[]u8 { const trimmed = std.mem.trim(u8, prompt, " \t\r\n"); if (trimmed.len == 0) return null; @@ -287,7 +289,7 @@ fn inferPyFilenameFromPromptAlloc(allocator: std.mem.Allocator, prompt: []const var last_match: ?[]const u8 = null; while (it.next()) |tok| { if (tok.len < 4) continue; - if (!std.ascii.endsWithIgnoreCase(tok, ".py")) continue; + if (!isSupportedWriteFilename(tok)) continue; last_match = tok; } @@ -297,6 +299,19 @@ fn inferPyFilenameFromPromptAlloc(allocator: std.mem.Allocator, prompt: []const return null; } +fn isSupportedWriteFilename(path: []const u8) bool { + return std.ascii.endsWithIgnoreCase(path, ".py") or + std.ascii.endsWithIgnoreCase(path, ".cpp") or + std.ascii.endsWithIgnoreCase(path, ".cc") or + std.ascii.endsWithIgnoreCase(path, ".cxx") or + std.ascii.endsWithIgnoreCase(path, ".c") or + std.ascii.endsWithIgnoreCase(path, ".h") or + std.ascii.endsWithIgnoreCase(path, ".hpp") or + std.ascii.endsWithIgnoreCase(path, ".js") or + std.ascii.endsWithIgnoreCase(path, ".ts") or + std.ascii.endsWithIgnoreCase(path, ".zig"); +} + fn extractLastFencedCodeBlockAlloc(allocator: std.mem.Allocator, messages: []const types.Message) ?[]u8 { var i: usize = messages.len; while (i > 0) { @@ -457,3 +472,16 @@ fn 
makeToolResponse( }; } +test "inferWriteFilenameFromPromptAlloc supports C++ filenames" { + const allocator = std.testing.allocator; + const path = inferWriteFilenameFromPromptAlloc(allocator, "file_write app.cpp") orelse return error.ExpectedFilename; + defer allocator.free(path); + + try std.testing.expectEqualStrings("app.cpp", path); +} + +test "isSupportedWriteFilename keeps Python inference" { + try std.testing.expect(isSupportedWriteFilename("count.py")); + try std.testing.expect(!isSupportedWriteFilename("notes.txt")); +} + From 9b52b78425c4c586115ddb2f852d1f5174a219fa Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 15 May 2026 14:33:48 +1000 Subject: [PATCH 09/25] Add mcp server bridge --- src/backend/mcp.zig | 133 ++++++++++++++++++++++++++++++++++++++++++++ src/config.zig | 23 +++++++- src/root.zig | 2 + 3 files changed, 155 insertions(+), 3 deletions(-) create mode 100644 src/backend/mcp.zig diff --git a/src/backend/mcp.zig b/src/backend/mcp.zig new file mode 100644 index 0000000..a3f7466 --- /dev/null +++ b/src/backend/mcp.zig @@ -0,0 +1,133 @@ +//! Minimal MCP discovery bridge. +//! +//! This module intentionally stops at bounded tool metadata parsing and +//! allowlist registration. Transport and invocation stay outside the hot path +//! until an MCP server is explicitly wired in. 
+const std = @import("std"); +const types = @import("../types.zig"); +const tooling = @import("tools.zig"); + +pub fn parseDiscoveredToolsAlloc( + allocator: std.mem.Allocator, + discovery_json: []const u8, + max_tools: usize, +) ![]types.Tool { + var parsed = std.json.parseFromSlice(std.json.Value, allocator, discovery_json, .{}) catch |err| switch (err) { + error.OutOfMemory => return err, + else => return try allocator.alloc(types.Tool, 0), + }; + defer parsed.deinit(); + + const root = switch (parsed.value) { + .object => |object| object, + else => return try allocator.alloc(types.Tool, 0), + }; + + const tools_value = root.get("tools") orelse return try allocator.alloc(types.Tool, 0); + const tool_items = switch (tools_value) { + .array => |array| array.items, + else => return try allocator.alloc(types.Tool, 0), + }; + + var out = std.ArrayList(types.Tool){}; + errdefer { + for (out.items) |tool| { + tool.deinit(allocator); + } + out.deinit(allocator); + } + + for (tool_items) |item| { + if (out.items.len >= max_tools) break; + + const object = switch (item) { + .object => |value| value, + else => continue, + }; + + const raw_name = switch (object.get("name") orelse continue) { + .string => |value| std.mem.trim(u8, value, " \t\r\n"), + else => continue, + }; + if (!isSafeToolName(raw_name)) continue; + if (hasToolNamed(out.items, raw_name)) continue; + + const raw_description = if (object.get("description")) |description| + switch (description) { + .string => |value| std.mem.trim(u8, value, " \t\r\n"), + else => "", + } + else + ""; + + try out.append(allocator, .{ + .name = try allocator.dupe(u8, raw_name), + .description = try allocator.dupe(u8, raw_description[0..@min(raw_description.len, 512)]), + }); + } + + return try out.toOwnedSlice(allocator); +} + +pub fn registerDiscoveredTools( + registry: *tooling.ToolRegistry, + allocator: std.mem.Allocator, + discovered_tools: []const types.Tool, +) !void { + for (discovered_tools) |tool| { + try 
registry.register(allocator, tool.name); + } +} + +fn hasToolNamed(tools: []const types.Tool, name: []const u8) bool { + for (tools) |tool| { + if (std.mem.eql(u8, tool.name, name)) return true; + } + return false; +} + +fn isSafeToolName(name: []const u8) bool { + if (name.len == 0 or name.len > 64) return false; + for (name) |c| { + switch (c) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => {}, + else => return false, + } + } + return true; +} + +test "parseDiscoveredToolsAlloc maps bounded MCP metadata" { + const allocator = std.testing.allocator; + const json = + \\{"tools":[ + \\ {"name":"search.docs","description":"Search project docs"}, + \\ {"name":"bad name","description":"rejected"}, + \\ {"name":"search.docs","description":"duplicate"} + \\]} + ; + + const tools = try parseDiscoveredToolsAlloc(allocator, json, 8); + defer { + for (tools) |tool| { + tool.deinit(allocator); + } + allocator.free(tools); + } + + try std.testing.expectEqual(@as(usize, 1), tools.len); + try std.testing.expectEqualStrings("search.docs", tools[0].name); +} + +test "registerDiscoveredTools maps MCP tools into allowlist" { + const allocator = std.testing.allocator; + const discovered = [_]types.Tool{ + .{ .name = "search.docs", .description = "Search project docs" }, + }; + + var registry = tooling.ToolRegistry.init(allocator); + defer registry.deinit(allocator); + + try registerDiscoveredTools(&registry, allocator, discovered[0..]); + try std.testing.expect(registry.isAllowed("search.docs")); +} diff --git a/src/config.zig b/src/config.zig index 7176baa..3ad7ee0 100644 --- a/src/config.zig +++ b/src/config.zig @@ -46,10 +46,14 @@ pub const Config = struct { session_store_path: []const u8, session_retention_messages: usize, tool_exec_enabled: bool, + tool_exec_confirmation_required: bool, tool_exec_timeout_ms: u32, tool_exec_max_output_bytes: usize, + max_tool_calls_per_request: usize, loop_stream_progress_enabled: bool, max_concurrent_connections: usize, + max_request_bytes:
usize, + max_header_bytes: usize, pub fn load(allocator: std.mem.Allocator) !Config { return try loadWithOverrides(allocator, null); @@ -111,8 +115,8 @@ pub const Config = struct { "qwen3.5:9b", env_overrides, ), - .ollama_think = try getEnvFlag(allocator, "OLLAMA_THINK", env_overrides), - .ollama_num_predict = try getEnvU32OrDefault("OLLAMA_NUM_PREDICT", 128, env_overrides), + .ollama_think = try getEnvFlagOrDefault(allocator, "OLLAMA_THINK", true, env_overrides), + .ollama_num_predict = try getEnvU32OrDefault("OLLAMA_NUM_PREDICT", 2048, env_overrides), .ollama_temperature = try getEnvF64OrDefault(allocator, "OLLAMA_TEMPERATURE", 0.7, env_overrides), .ollama_repeat_penalty = try getEnvF64OrDefault(allocator, "OLLAMA_REPEAT_PENALTY", 1.05, env_overrides), .openai_base_url = try getEnvOrDefault( @@ -247,10 +251,19 @@ pub const Config = struct { env_overrides, ), .tool_exec_enabled = try getEnvFlag(allocator, "LLM_ROUTER_TOOL_EXEC_ENABLED", env_overrides), + .tool_exec_confirmation_required = try getEnvFlagOrDefault( + allocator, + "LLM_ROUTER_TOOL_EXEC_CONFIRM_REQUIRED", + true, + env_overrides, + ), .tool_exec_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS", 15_000, env_overrides), .tool_exec_max_output_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES", 65_536, env_overrides), + .max_tool_calls_per_request = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST", 8, env_overrides), .loop_stream_progress_enabled = try getEnvFlagOrDefault(allocator, "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED", true, env_overrides), .max_concurrent_connections = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS", 64, env_overrides), + .max_request_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_REQUEST_BYTES", 1024 * 1024, env_overrides), + .max_header_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_HEADER_BYTES", 16 * 1024, env_overrides), }; } @@ -499,7 +512,7 @@ test 
"setDefaultProvider stores canonical provider alias" { .ollama_base_url = try allocator.dupe(u8, "http://127.0.0.1:11434"), .ollama_model = try allocator.dupe(u8, "qwen:7b"), .ollama_think = false, - .ollama_num_predict = 512, + .ollama_num_predict = 1024, .ollama_temperature = 0.7, .ollama_repeat_penalty = 1.05, .openai_base_url = try allocator.dupe(u8, "https://api.openai.com/v1"), @@ -525,10 +538,14 @@ test "setDefaultProvider stores canonical provider alias" { .session_store_path = try allocator.dupe(u8, "logs/sessions"), .session_retention_messages = 24, .tool_exec_enabled = false, + .tool_exec_confirmation_required = false, .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, + .max_tool_calls_per_request = 8, .loop_stream_progress_enabled = true, .max_concurrent_connections = 64, + .max_request_bytes = 1024 * 1024, + .max_header_bytes = 16 * 1024, }; defer cfg.deinit(allocator); diff --git a/src/root.zig b/src/root.zig index ee0d664..9f0215b 100644 --- a/src/root.zig +++ b/src/root.zig @@ -9,6 +9,7 @@ pub const backend = @import("backend/api.zig"); pub const auth = @import("backend/auth.zig"); pub const session = @import("backend/session.zig"); pub const tools = @import("backend/tools.zig"); +pub const mcp = @import("backend/mcp.zig"); pub const core = @import("core/server.zig"); pub const config = @import("config.zig"); pub const types = @import("types.zig"); @@ -25,6 +26,7 @@ test "backend: auth and tool scaffolding" { _ = @import("backend/auth.zig"); _ = @import("backend/tools.zig"); _ = @import("backend/session.zig"); + _ = @import("backend/mcp.zig"); } test "core: HTTP request parsing" { From 2e60d6edf721d4047c7e58e2a3c94458770b1df2 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 15 May 2026 14:34:19 +1000 Subject: [PATCH 10/25] Update README and client to enhance tool functionality and command confirmation --- README.md | 42 +++++- client/test_client.py | 116 +++++++++++++-- src/backend/errors.zig | 12 ++ src/core/request.zig | 18 
++- src/core/server.zig | 326 +++++++++++++++++++++++++++++++++++++++-- src/react.zig | 216 +++++++++++++++++++++++++-- 6 files changed, 682 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 57c2c57..ab0f07d 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,7 @@ This project exposes a chat-completions API, routes requests across multiple pro - Optional debug tooling (echo, utc, cmd, bash) with safe defaults. - Simple deploy surface: single Zig binary, environment-based configuration. -## Scope freeze (Phase 0) -Phase 0 is **complete**: one canonical chat path and response shape, with deferrals listed in `plan.md` (distributed orchestration, deep tenancy/RBAC, heavy observability, plugin marketplaces). Operational routes (`/health`, `/metrics`, `/diagnostics/*`) stay diagnostics-only. ## Architecture @@ -93,11 +91,11 @@ zig build test -Dtest-target=all -Dtest-filter=normalizeProviderName | GET | /metrics | Yes (if API key configured) | Request and connection counters | | GET | /diagnostics/clients | Yes (if API key configured) | Connected client diagnostics | | GET | /diagnostics/requests | Yes (if API key configured) | Request success/failure diagnostics | -| GET | /diagnostics/providers | No | Provider status snapshot | +| GET | /diagnostics/providers | Yes (if API key configured) | Provider status snapshot | | POST | /v1/chat/completions | Yes (if API key configured) | OpenAI-compatible chat-completions | > [!NOTE] -> Authentication is enabled when LLM_ROUTER_API_KEY is non-empty. When enabled, all routes require auth except /health and /diagnostics/providers. +> Authentication is enabled when LLM_ROUTER_API_KEY is non-empty. When enabled, all routes require auth except /health. ### Chat Request Shape @@ -224,11 +222,15 @@ Registered tool names: - utc - cmd - bash +- file_read +- file_write +- file_search Tool behavior summary: - echo and utc are deterministic debug helpers. - cmd and bash are guarded command-execution tools. 
+- file_read, file_write, and file_search are lightweight filesystem helpers for trusted local workflows. - Command execution is disabled by default and must be explicitly enabled. ```bash @@ -240,6 +242,7 @@ Related limits: - LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS (default: 15000) - LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES (default: 65536) +- LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST (default: 8) ## Configuration Reference @@ -256,6 +259,9 @@ Related limits: | LLM_ROUTER_REQUEST_TIMEOUT_MS | 30000 | | LLM_ROUTER_PROVIDER_TIMEOUT_MS | 60000 | | LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED | true | +| LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS | 64 | +| LLM_ROUTER_MAX_REQUEST_BYTES | 1048576 | +| LLM_ROUTER_MAX_HEADER_BYTES | 16384 | ### Session Storage @@ -342,12 +348,36 @@ Related limits: -d '{"messages":[{"role":"user","content":"Say hello from zig-coding-agent"}]}' ``` -4. Provider Diagnostics +4. Stateful Session Check + + ```bash + curl -s http://127.0.0.1:8081/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"session_id":"demo-session","messages":[{"role":"user","content":"Remember that my test color is blue."}]}' + ``` + +5. Tool Check + + ```bash + curl -s http://127.0.0.1:8081/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"messages":[{"role":"user","content":"What time is it in UTC?"}],"tools":[{"name":"utc","description":"Current UTC time"}],"tool_choice":"auto"}' + ``` + +6. Provider Diagnostics ```bash curl -s http://127.0.0.1:8081/diagnostics/providers ``` +## Short Runbook + +- 401 Unauthorized: set `LLM_ROUTER_API_KEY` on the server and send either `X-Api-Key: ` or `Authorization: Bearer `. +- 413 request_too_large: lower the prompt size or raise `LLM_ROUTER_MAX_REQUEST_BYTES` for trusted deployments. +- 504 provider_timeout: check provider reachability and `LLM_ROUTER_PROVIDER_TIMEOUT_MS`. +- unknown_tool: request only registered tools listed above. 
+- provider_not_configured: set the provider API key or switch to a local provider. + ## Project Layout ``` @@ -361,6 +391,7 @@ zig_coding_agent │ ├── api.zig │ ├── auth.zig │ ├── errors.zig + │ ├── mcp.zig │ ├── session.zig │ └── tools.zig ├── config.zig @@ -383,6 +414,7 @@ zig_coding_agent ├── tools │ ├── command_exec.zig │ ├── echo.zig + │ ├── file_ops.zig │ └── utc.zig └── types.zig ``` diff --git a/client/test_client.py b/client/test_client.py index 374bbd8..a42d37f 100644 --- a/client/test_client.py +++ b/client/test_client.py @@ -11,6 +11,10 @@ st.caption("Debug and test harness behavior: chat routing, tool execution, sessions, streaming, and diagnostics.") CHAT_STATE_KEY = "chat_messages" +PENDING_CMD_CONFIRMATION_KEY = "pending_cmd_confirmation" + +CMD_CONFIRMATION_TAG = "DEBUG_TOOL_CONFIRMATION_REQUIRED" +CONFIRM_MARKER_PREFIX = "LLM_ROUTER_TOOL_CONFIRM " @dataclass @@ -109,6 +113,32 @@ def render_chat(messages: List[ChatMessage]) -> None: st.markdown(msg.content) +def parse_cmd_confirmation(text: str) -> Optional[Dict[str, str]]: + if CMD_CONFIRMATION_TAG not in text: + return None + + tool: Optional[str] = None + command: Optional[str] = None + confirm_token: Optional[str] = None + + for raw_line in text.splitlines(): + line = raw_line.strip() + if line.startswith("tool="): + tool = line[len("tool=") :].strip() + elif line.startswith("command="): + command = line[len("command=") :].strip() + elif line.startswith("confirm_token="): + confirm_token = line[len("confirm_token=") :].strip() + + if tool and command and confirm_token: + return {"tool": tool, "command": command, "confirm_token": confirm_token} + return None + + +def set_pending_cmd_confirmation(value: Optional[Dict[str, str]]) -> None: + st.session_state[PENDING_CMD_CONFIRMATION_KEY] = value + + with st.sidebar: st.header("Connection") endpoint = st.text_input( @@ -151,7 +181,7 @@ def render_chat(messages: List[ChatMessage]) -> None: st.header("Loop") loop_mode = st.selectbox( "loop_mode", - 
options=["none", "basic", "agent"], + options=["none", "basic", "agent", "react"], index=0, help="Server-side iterative loop mode usable from any frontend client.", ) @@ -166,10 +196,25 @@ def render_chat(messages: List[ChatMessage]) -> None: ) st.header("Tools (debug)") + + import platform as _platform + _is_windows = _platform.system().lower().startswith("win") + _shell_tool = "cmd" if _is_windows else "bash" + CODING_TOOLS = [_shell_tool, "file_read", "file_write", "file_search"] + + auto_tools_for_react = st.checkbox( + "Auto-attach coding tools when loop_mode=react", + value=True, + help=( + "When enabled and loop_mode is 'react', send the full coding tool set " + f"({', '.join(CODING_TOOLS)}) with tool_choice=auto. Overrides the manual selection below." + ), + ) + selected_tools = st.multiselect( - "Available tools to send", + "Available tools to send (manual selection)", options=["echo", "utc", "cmd", "bash", "file_read", "file_write", "file_search"], - default=["utc", "cmd", "file_write"], + default=CODING_TOOLS, help="Multiple tools can be sent in one request. cmd/bash require server env LLM_ROUTER_TOOL_EXEC_ENABLED=1.", ) @@ -191,8 +236,20 @@ def render_chat(messages: List[ChatMessage]) -> None: help="Use auto for prompt-driven tool execution with multiple tools.", ) + if auto_tools_for_react and loop_mode == "react": + selected_tools = CODING_TOOLS + tool_choice = "auto" + st.caption( + f"react mode active — sending tools={CODING_TOOLS} with tool_choice=auto." 
+ ) + -def build_payload(messages: List[ChatMessage], force_stream: bool) -> Dict: +def build_payload( + messages: List[ChatMessage], + force_stream: bool, + tool_choice_override: Optional[str] = None, + tools_override: Optional[List[str]] = None, +) -> Dict: payload: Dict = { "provider": provider.strip() or "ollama", "messages": [{"role": m.role, "content": m.content} for m in messages if m.content.strip()], @@ -219,7 +276,9 @@ def build_payload(messages: List[ChatMessage], force_stream: bool) -> Dict: if loop_max_turns > 0: payload["loop_max_turns"] = int(loop_max_turns) - if selected_tools: + tools_to_send = tools_override if tools_override is not None else selected_tools + + if tools_to_send: tool_descriptions = { "echo": "Return a deterministic echo response for harness debugging.", "utc": "Return current UTC timestamp for deterministic tool-path checks.", @@ -231,9 +290,11 @@ def build_payload(messages: List[ChatMessage], force_stream: bool) -> Dict: } payload["tools"] = [ {"name": tool_name, "description": tool_descriptions.get(tool_name, "debug tool")} - for tool_name in selected_tools + for tool_name in tools_to_send ] - if tool_choice != "none": + if tool_choice_override is not None: + payload["tool_choice"] = tool_choice_override + elif tool_choice != "none": payload["tool_choice"] = tool_choice if force_stream: @@ -254,12 +315,21 @@ def request_timeout_seconds(force_stream: bool) -> int: return timeout -def send_chat_request(force_stream: bool) -> Optional[Dict]: +def send_chat_request( + force_stream: bool, + tool_choice_override: Optional[str] = None, + tools_override: Optional[List[str]] = None, +) -> Optional[Dict]: if not endpoint.strip(): st.error("Endpoint is required.") return None - payload = build_payload(chat_state(), force_stream=force_stream) + payload = build_payload( + chat_state(), + force_stream=force_stream, + tool_choice_override=tool_choice_override, + tools_override=tools_override, + ) st.code(json.dumps(payload, indent=2), 
language="json") try: @@ -291,7 +361,9 @@ def send_chat_request(force_stream: bool) -> Optional[Dict]: with st.chat_message("assistant"): streamed_text = st.write_stream(stream_probe_lines(resp)) - add_chat_message("assistant", str(streamed_text)) + assistant_text = str(streamed_text) + add_chat_message("assistant", assistant_text) + set_pending_cmd_confirmation(parse_cmd_confirmation(assistant_text)) return None try: @@ -307,6 +379,7 @@ def send_chat_request(force_stream: bool) -> Optional[Dict]: with st.chat_message("assistant"): st.markdown(assistant) add_chat_message("assistant", assistant) + set_pending_cmd_confirmation(parse_cmd_confirmation(assistant)) return data @@ -342,6 +415,29 @@ def send_chat_request(force_stream: bool) -> Optional[Dict]: else: send_chat_request(force_stream=False) +pending_cmd = st.session_state.get(PENDING_CMD_CONFIRMATION_KEY) +if pending_cmd: + tool = pending_cmd["tool"] + st.warning(f"Backend requested confirmation to execute `{tool}`.") + st.code(pending_cmd["command"]) + if st.button("Confirm command execution", use_container_width=True): + confirm_prompt = ( + f"{pending_cmd['command']}\n{CONFIRM_MARKER_PREFIX}{pending_cmd['confirm_token']}" + ) + add_chat_message("user", confirm_prompt) + + tools_override = selected_tools if isinstance(selected_tools, list) else [] + if tool not in tools_override: + tools_override = tools_override + [tool] + + # Force the tool choice so the backend executes immediately after confirmation. 
+ set_pending_cmd_confirmation(None) + send_chat_request( + force_stream=stream_mode, + tool_choice_override=tool, + tools_override=tools_override, + ) + st.subheader("Diagnostics") diag_col1, diag_col2, diag_col3 = st.columns(3) diff --git a/src/backend/errors.zig b/src/backend/errors.zig index 3b5c1f8..e7ccb2f 100644 --- a/src/backend/errors.zig +++ b/src/backend/errors.zig @@ -273,3 +273,15 @@ test "providerFailureFromDetail maps missing key configuration" { try std.testing.expectEqual(@as(u16, 500), mapped.status_code); try std.testing.expectEqualStrings("provider_not_configured", mapped.code.?); } + +test "providerFailureFromDetail maps provider auth failures" { + const mapped = providerFailureFromDetail("openrouter", "OpenRouter returned HTTP 401"); + try std.testing.expectEqual(@as(u16, 502), mapped.status_code); + try std.testing.expectEqualStrings("provider_auth_failed", mapped.code.?); +} + +test "providerFailureFromDetail maps invalid provider payloads" { + const mapped = providerFailureFromDetail("bedrock", "Bedrock returned invalid JSON"); + try std.testing.expectEqual(@as(u16, 502), mapped.status_code); + try std.testing.expectEqualStrings("provider_invalid_response", mapped.code.?); +} diff --git a/src/core/request.zig b/src/core/request.zig index 3dee16b..fcc5a53 100644 --- a/src/core/request.zig +++ b/src/core/request.zig @@ -88,10 +88,9 @@ pub fn readHttpRequest( allocator: std.mem.Allocator, connection: *std.net.Server.Connection, request_timeout_ms: u32, + max_request_size: usize, + max_header_size: usize, ) ![]u8 { - const max_request_size = 1024 * 1024; - const max_header_size = 16 * 1024; - const started_ms = std.time.milliTimestamp(); var request = std.ArrayList(u8){}; @@ -133,7 +132,7 @@ pub fn readHttpRequest( error.MissingContentLength => 0, else => return err, }; - const required_length = header_end.? 
+ content_length; + const required_length = try checkedRequiredLength(header_end.?, content_length); if (required_length > max_request_size) { return error.RequestTooLarge; } @@ -166,6 +165,17 @@ pub fn readHttpRequest( return try request.toOwnedSlice(allocator); } +pub fn checkedRequiredLength(header_len: usize, content_length: usize) !usize { + return std.math.add(usize, header_len, content_length) catch error.RequestTooLarge; +} + +test "checkedRequiredLength rejects integer overflow" { + try std.testing.expectError( + error.RequestTooLarge, + checkedRequiredLength(std.math.maxInt(usize), 1), + ); +} + fn waitForReadable( socket: std.posix.socket_t, timeout_ms: u32, diff --git a/src/core/server.zig b/src/core/server.zig index a118cd2..2c53885 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -11,12 +11,19 @@ const tooling = @import("../backend/tools.zig"); const session = @import("../backend/session.zig"); const react = @import("../react.zig"); +const context_compaction_keep_messages = 8; +const length_continue_prompt = + "The previous assistant message was cut off by the generation limit. " ++ + "Continue the same answer from the exact next token. 
Output only the continuation text; " ++ + "do not mention truncation, continuation, the previous message, or these instructions."; + const ServerState = struct { mutex: std.Thread.Mutex = .{}, total_requests: u64 = 0, successful_requests: u64 = 0, failed_requests: u64 = 0, active_connections: u64 = 0, + provider_latency_buckets: [4]u64 = .{ 0, 0, 0, 0 }, connected_clients: std.ArrayList([]u8), fn init() ServerState { @@ -70,6 +77,20 @@ const ServerState = struct { self.active_connections -= 1; } + fn noteProviderLatency(self: *ServerState, elapsed_ms: u64) void { + self.mutex.lock(); + defer self.mutex.unlock(); + const idx: usize = if (elapsed_ms < 50) + 0 + else if (elapsed_ms < 200) + 1 + else if (elapsed_ms < 1000) + 2 + else + 3; + self.provider_latency_buckets[idx] += 1; + } + fn snapshot(self: *ServerState) Snapshot { self.mutex.lock(); defer self.mutex.unlock(); @@ -78,6 +99,7 @@ const ServerState = struct { .successful_requests = self.successful_requests, .failed_requests = self.failed_requests, .active_connections = self.active_connections, + .provider_latency_buckets = self.provider_latency_buckets, }; } @@ -105,6 +127,7 @@ const Snapshot = struct { successful_requests: u64, failed_requests: u64, active_connections: u64, + provider_latency_buckets: [4]u64, }; const ConnectionGate = struct { @@ -295,7 +318,13 @@ fn handleConnection( session_store_guard: *SessionStoreGuard, ) !void { // Translate transport-level parsing failures into OpenAI-style API errors. 
- const request_raw = request.readHttpRequest(allocator, connection, app_config.request_timeout_ms) catch |err| switch (err) { + const request_raw = request.readHttpRequest( + allocator, + connection, + app_config.request_timeout_ms, + app_config.max_request_bytes, + app_config.max_header_bytes, + ) catch |err| switch (err) { error.ClientDisconnected => { debugLog(app_config, "client disconnected before full request", .{}); return; @@ -396,13 +425,17 @@ fn handleConnection( const snapshot = server_state.snapshot(); const metrics_json = try std.fmt.allocPrint( allocator, - "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"active_connections\":{d}}}", + "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"active_connections\":{d},\"provider_latency_buckets_ms\":{{\"lt_50\":{d},\"lt_200\":{d},\"lt_1000\":{d},\"gte_1000\":{d}}}}}", .{ app_config.instance_id, snapshot.total_requests, snapshot.successful_requests, snapshot.failed_requests, snapshot.active_connections, + snapshot.provider_latency_buckets[0], + snapshot.provider_latency_buckets[1], + snapshot.provider_latency_buckets[2], + snapshot.provider_latency_buckets[3], }, ); defer allocator.free(metrics_json); @@ -539,10 +572,26 @@ fn handleConnection( if (parsed_req.max_context_tokens) |max_tokens| { const estimated = session.estimateTokenCount(request_messages); if (session.shouldCompressContext(estimated, max_tokens)) { + const compacted_messages = try session.compactContextToBudgetAlloc( + allocator, + request_messages, + max_tokens, + context_compaction_keep_messages, + ); + + for (request_messages) |message| { + message.deinit(allocator); + } + allocator.free(request_messages); + request_messages = compacted_messages; + + allocator.free(request_prompt); + request_prompt = try extractLastUserPromptAlloc(allocator, request_messages); + debugLog( app_config, - "context compression suggested 
estimated_tokens={d} max_context_tokens={d}", - .{ estimated, max_tokens }, + "context compacted estimated_tokens={d} compacted_tokens={d} max_context_tokens={d}", + .{ estimated, session.estimateTokenCount(request_messages), max_tokens }, ); } } @@ -613,10 +662,51 @@ fn handleConnection( return; } - const stream_auto_tool_summary = try tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config); + const stream_auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config) catch |err| switch (err) { + error.ToolCallLimitExceeded => { + sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( + "Tool call limit exceeded", + "tools", + "tool_call_limit_exceeded", + ), app_config); + server_state.noteRequestFailed(); + return; + }, + else => return err, + }; defer if (stream_auto_tool_summary) |summary| allocator.free(summary); if (stream_auto_tool_summary) |summary| { + if (tooling.shouldShortCircuitAutoTools(stream_request)) { + try response.sendEventStreamHeaders(connection.*); + const completion_id = try std.fmt.allocPrint( + allocator, + "chatcmpl-{d}", + .{std.time.microTimestamp()}, + ); + defer allocator.free(completion_id); + + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + "debug-tools/auto", + summary, + null, + ); + try response.sendChatCompletionChunkSse( + connection.*, + allocator, + completion_id, + "debug-tools/auto", + null, + "tool", + ); + try response.sendSseDone(connection.*); + server_state.noteRequestSucceeded(); + return; + } + const augmented_messages = try appendSystemMessageAlloc(allocator, stream_request.messages, summary); for (stream_request.messages) |message| { message.deinit(allocator); @@ -690,7 +780,18 @@ fn handleConnection( ); defer provider_request.deinit(allocator); - const auto_tool_summary = try tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config); + const auto_tool_summary = 
tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config) catch |err| switch (err) { + error.ToolCallLimitExceeded => { + sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( + "Tool call limit exceeded", + "tools", + "tool_call_limit_exceeded", + ), app_config); + server_state.noteRequestFailed(); + return; + }, + else => return err, + }; defer if (auto_tool_summary) |summary| allocator.free(summary); if (auto_tool_summary) |summary| { @@ -730,6 +831,16 @@ fn handleConnection( const provider_started_ms = std.time.milliTimestamp(); if (loop_enabled) { const loop_execution = executeLoopRequestAlloc(allocator, app_config, provider_request) catch |err| { + if (err == error.ToolCallLimitExceeded) { + sendApiErrorSafe( + connection.*, + allocator, + backend.errors.validationError("Tool call limit exceeded", "tools", "tool_call_limit_exceeded"), + app_config, + ); + server_state.noteRequestFailed(); + return; + } logError("Provider loop request error: {}", .{err}); sendApiErrorSafe( connection.*, @@ -760,6 +871,7 @@ fn handleConnection( } const provider_elapsed_ms: u64 = @intCast(@max(std.time.milliTimestamp() - provider_started_ms, 0)); + server_state.noteProviderLatency(provider_elapsed_ms); // Log a warning when the provider exceeded the configured timeout budget, but // do NOT discard the already-completed response. 
A post-hoc check cannot // cancel an in-flight HTTP call; discarding a valid result would only confuse @@ -788,6 +900,12 @@ fn handleConnection( return; } + if (!loop_enabled and isLengthFinishReason(result.finish_reason)) { + continueLengthResponseAlloc(allocator, app_config, provider_request, &result) catch |err| { + logError("Provider continuation failed: {s}", .{@errorName(err)}); + }; + } + if (std.mem.trim(u8, result.output, " \t\r\n").len == 0) { if (auto_tool_summary) |summary| { allocator.free(result.output); @@ -925,6 +1043,43 @@ const LoopExecution = struct { messages: []types.Message, }; +fn executeReactActionForRequest( + allocator: std.mem.Allocator, + app_config: *const config.Config, + turn_request: types.Request, + tool_messages: []const types.Message, + action: react.ReactAction, +) ![]u8 { + return switch (action) { + .tool => |tool_action| blk: { + var tool_request = try cloneRequestWithMessagesAlloc( + allocator, + turn_request, + tool_action.argument, + tool_messages, + ); + defer tool_request.deinit(allocator); + + if (tool_request.tool_choice) |value| { + allocator.free(value); + } + tool_request.tool_choice = try allocator.dupe(u8, tool_action.name); + + if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config)) |tool_result| { + defer tool_result.deinit(allocator); + break :blk try allocator.dupe(u8, tool_result.output); + } + + break :blk try std.fmt.allocPrint( + allocator, + "Tool action {s} was not requested or its input could not be parsed.", + .{tool_action.name}, + ); + }, + else => try react.executeReactAction(allocator, action, app_config), + }; +} + fn streamLoopRequestToSse( connection: std.net.Server.Connection, allocator: std.mem.Allocator, @@ -966,11 +1121,13 @@ fn streamLoopRequestToSse( allocator.free(working_messages); working_messages = with_guidance; } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { + const react_prompt = try react.buildSystemPromptAlloc(allocator, base_request.tools); + defer 
allocator.free(react_prompt); const with_react = try appendRoleMessageAlloc( allocator, working_messages, "system", - react.system_prompt, + react_prompt, ); for (working_messages) |message| { message.deinit(allocator); @@ -985,6 +1142,7 @@ fn streamLoopRequestToSse( var previous_output: ?[]u8 = null; defer if (previous_output) |value| allocator.free(value); var repeated_count: usize = 0; + var react_tool_calls: usize = 0; var turn: usize = 0; while (turn < loop_max_turns) : (turn += 1) { @@ -1001,8 +1159,28 @@ fn streamLoopRequestToSse( } turn_request.loop_max_turns = null; - var turn_result = try backend.callProvider(allocator, app_config, turn_request); + // In ReAct mode the tool catalog is advertised inside the system + // prompt and parsed locally; we must NOT forward `tools` upstream or + // providers like Ollama activate their XML tool-calling template and + // choke on prior assistant messages containing markdown or + // (manifests as "internal_server_error ... element closed + // by "). We strip tools on a shallow alias so the original + // turn_request keeps them for local action dispatch. 
+ const empty_tools: []types.Tool = &.{}; + const provider_request: types.Request = if (std.ascii.eqlIgnoreCase(loop_mode, "react")) blk: { + var r = turn_request; + r.tools = empty_tools; + r.tool_choice = null; + break :blk r; + } else turn_request; + + var turn_result = try backend.callProvider(allocator, app_config, provider_request); defer turn_result.deinit(allocator); + if (turn_result.success and isLengthFinishReason(turn_result.finish_reason)) { + continueLengthResponseAlloc(allocator, app_config, provider_request, &turn_result) catch |err| { + logError("Provider loop continuation failed: {s}", .{@errorName(err)}); + }; + } const with_assistant = try appendRoleMessageAlloc( allocator, @@ -1062,7 +1240,8 @@ fn streamLoopRequestToSse( react_finished = true; }, else => { - const obs_raw = try react.executeReactAction(allocator, action, app_config); + try tooling.noteToolCall(app_config.max_tool_calls_per_request, &react_tool_calls); + const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action); defer allocator.free(obs_raw); react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); }, @@ -1171,11 +1350,13 @@ fn executeLoopRequestAlloc( allocator.free(working_messages); working_messages = with_guidance; } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { + const react_prompt = try react.buildSystemPromptAlloc(allocator, base_request.tools); + defer allocator.free(react_prompt); const with_react = try appendRoleMessageAlloc( allocator, working_messages, "system", - react.system_prompt, + react_prompt, ); for (working_messages) |message| { message.deinit(allocator); @@ -1190,6 +1371,7 @@ fn executeLoopRequestAlloc( var previous_output: ?[]u8 = null; defer if (previous_output) |value| allocator.free(value); var repeated_count: usize = 0; + var react_tool_calls: usize = 0; var turn: usize = 0; while (turn < loop_max_turns) : (turn += 1) { @@ -1206,7 +1388,22 @@ fn 
executeLoopRequestAlloc( } turn_request.loop_max_turns = null; - var turn_result = try backend.callProvider(allocator, app_config, turn_request); + // See streamLoopRequestToSse: strip tools/tool_choice from the + // upstream call in ReAct mode to avoid provider-side tool templates. + const empty_tools: []types.Tool = &.{}; + const provider_request: types.Request = if (std.ascii.eqlIgnoreCase(loop_mode, "react")) blk: { + var r = turn_request; + r.tools = empty_tools; + r.tool_choice = null; + break :blk r; + } else turn_request; + + var turn_result = try backend.callProvider(allocator, app_config, provider_request); + if (turn_result.success and isLengthFinishReason(turn_result.finish_reason)) { + continueLengthResponseAlloc(allocator, app_config, provider_request, &turn_result) catch |err| { + logError("Provider loop continuation failed: {s}", .{@errorName(err)}); + }; + } const with_assistant = try appendRoleMessageAlloc( allocator, @@ -1264,7 +1461,8 @@ fn executeLoopRequestAlloc( react_finished = true; }, else => { - const obs_raw = try react.executeReactAction(allocator, action, app_config); + try tooling.noteToolCall(app_config.max_tool_calls_per_request, &react_tool_calls); + const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action); defer allocator.free(obs_raw); react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); }, @@ -1383,6 +1581,69 @@ fn appendRoleMessageAlloc( return try out.toOwnedSlice(allocator); } +fn continueLengthResponseAlloc( + allocator: std.mem.Allocator, + app_config: *const config.Config, + base_request: types.Request, + result: *types.Response, +) !void { + const max_continuations: usize = 4; + var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); + defer { + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + } + + const with_first_assistant = try 
appendRoleMessageAlloc(allocator, working_messages, "assistant", result.output); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_first_assistant; + + var turn: usize = 0; + while (turn < max_continuations and isLengthFinishReason(result.finish_reason)) : (turn += 1) { + const continue_prompt = length_continue_prompt; + const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", continue_prompt); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_continue; + + var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, continue_prompt, working_messages); + defer turn_request.deinit(allocator); + + var turn_result = try backend.callProvider(allocator, app_config, turn_request); + defer turn_result.deinit(allocator); + if (!turn_result.success) return; + if (std.mem.trim(u8, turn_result.output, " \t\r\n").len == 0) return; + + const joined = try std.fmt.allocPrint(allocator, "{s}{s}", .{ result.output, turn_result.output }); + allocator.free(result.output); + result.output = joined; + + allocator.free(result.finish_reason); + result.finish_reason = try allocator.dupe(u8, turn_result.finish_reason); + result.usage.prompt_tokens += turn_result.usage.prompt_tokens; + result.usage.completion_tokens += turn_result.usage.completion_tokens; + result.usage.total_tokens += turn_result.usage.total_tokens; + + const with_assistant = try appendRoleMessageAlloc(allocator, working_messages, "assistant", turn_result.output); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_assistant; + } +} + +fn isLengthFinishReason(finish_reason: []const u8) bool { + return std.ascii.eqlIgnoreCase(finish_reason, "length"); +} + fn sendApiErrorSafe( connection: std.net.Server.Connection, allocator: 
std.mem.Allocator, @@ -1434,7 +1695,6 @@ fn swallowSocketWriteError( fn requiresAuth(route: router.Route) bool { return switch (route) { .health => false, - .diagnostics_providers => false, else => true, }; } @@ -1469,3 +1729,43 @@ test "connection gate enforces bounded capacity" { gate.release(); try std.testing.expect(gate.tryAcquire()); } + +test "connection gate recovers capacity after releases" { + var gate = ConnectionGate{ .max_active = 1 }; + try std.testing.expect(gate.tryAcquire()); + try std.testing.expect(!gate.tryAcquire()); + + gate.release(); + try std.testing.expect(gate.tryAcquire()); + gate.release(); + try std.testing.expect(gate.tryAcquire()); +} + +test "requiresAuth allows only health without an API key" { + try std.testing.expect(!requiresAuth(.health)); + try std.testing.expect(requiresAuth(.metrics)); + try std.testing.expect(requiresAuth(.diagnostics_providers)); + try std.testing.expect(requiresAuth(.chat_completions)); +} + +test "server state tracks provider latency buckets" { + var state = ServerState.init(); + defer state.deinit(std.testing.allocator); + + state.noteProviderLatency(10); + state.noteProviderLatency(75); + state.noteProviderLatency(500); + state.noteProviderLatency(1500); + + const snapshot = state.snapshot(); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[0]); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[1]); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[2]); + try std.testing.expectEqual(@as(u64, 1), snapshot.provider_latency_buckets[3]); +} + +test "isLengthFinishReason detects provider length stops" { + try std.testing.expect(isLengthFinishReason("length")); + try std.testing.expect(isLengthFinishReason("LENGTH")); + try std.testing.expect(!isLengthFinishReason("stop")); +} diff --git a/src/react.zig b/src/react.zig index 238d5ad..fd7bc11 100644 --- a/src/react.zig +++ b/src/react.zig @@ -21,7 +21,7 @@ pub const 
system_prompt = "Available actions:\n" ++ " Search[query] - search for information\n" ++ " Lookup[term] - look up a term in the current context\n" ++ - " Cmd[command] - execute a shell command (only if tools are enabled)\n" ++ + " Cmd[command] - execute a shell command (may require confirmation)\n" ++ " Finish[answer] - return the final answer and end the loop\n\n" ++ "Rules:\n" ++ "- Always start with a Thought, then an Action.\n" ++ @@ -31,12 +31,48 @@ pub const system_prompt = "- If you need information, ask for the exact term or command that will resolve it.\n" ++ "- If an action fails, adjust your approach in the next Thought.\n"; +pub fn buildSystemPromptAlloc(allocator: std.mem.Allocator, tools: []const types.Tool) ![]u8 { + if (tools.len == 0) return try allocator.dupe(u8, system_prompt); + + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + try out.appendSlice(allocator, system_prompt); + try out.appendSlice( + allocator, + "\nRequested API tools are also valid Actions when named exactly:\n", + ); + + for (tools) |tool| { + try out.writer(allocator).print( + " {s}[argument] - {s}\n", + .{ tool.name, tool.description }, + ); + } + + try out.appendSlice( + allocator, + "\nTool rules:\n" ++ + "- For file_write, include the path and fenced code block inside the brackets, for example:\n" ++ + " Action N: file_write[write relative/path\n```lang\n...\n```]\n" ++ + "- Put the full tool input inside the brackets.\n", + ); + + return try out.toOwnedSlice(allocator); +} + /// Parsed action extracted from a model response. +pub const ToolAction = struct { + name: []const u8, + argument: []const u8, +}; + pub const ReactAction = union(enum) { search: []const u8, lookup: []const u8, cmd: []const u8, finish: []const u8, + tool: ToolAction, unknown: []const u8, }; @@ -47,32 +83,72 @@ pub const ReactAction = union(enum) { /// - Case-insensitive match on the `Action` keyword. /// - The turn number after `Action` is optional (the harness tracks turns). 
/// - Brackets are mandatory for the argument. -/// - Greedy bracket matching: takes everything up to the last `]` on the line. +/// - Greedy bracket matching: takes everything up to the last `]` after the +/// action, allowing multi-line tool inputs such as fenced code blocks. /// - Returns `null` when no action line is found. pub fn parseReactAction(output: []const u8) ?ReactAction { - // Scan backwards through lines to find the last Action directive. - var last_action_line: ?[]const u8 = null; - var lines = std.mem.splitAny(u8, output, "\n\r"); - while (lines.next()) |raw_line| { - const line = std.mem.trim(u8, raw_line, " \t"); - if (line.len == 0) continue; + // Scan for the FIRST Action directive that lives outside a + // <think>...</think> block. Returning the first action (rather than the + // last) lets the loop drive multi-step plans one observation at a time: + // if the model speculatively emits Action 2 before seeing Observation 1, + // we execute Action 1 and let Observation 1 inform Action 2 on the next + // turn. 
+ var first_action_start: ?usize = null; + var inside_think = false; + var cursor: usize = 0; + while (cursor < output.len) { + if (!inside_think and startsWithIgnoreCase(output[cursor..], "<think>")) { + inside_think = true; + cursor += "<think>".len; + continue; + } + if (inside_think and startsWithIgnoreCase(output[cursor..], "</think>")) { + inside_think = false; + cursor += "</think>".len; + continue; + } + + if (inside_think) { + cursor += 1; + continue; + } + + const line_start = cursor; + var line_end = cursor; + while (line_end < output.len and output[line_end] != '\n' and output[line_end] != '\r') { + line_end += 1; + } + + const raw_line = output[line_start..line_end]; + const line = std.mem.trimLeft(u8, raw_line, " \t"); if (startsWithActionDirective(line)) { - last_action_line = line; + first_action_start = line_start + (raw_line.len - line.len); + break; + } + + cursor = line_end; + while (cursor < output.len and (output[cursor] == '\n' or output[cursor] == '\r')) { + cursor += 1; } } - const action_line = last_action_line orelse return null; + const action_start = first_action_start orelse return null; + // For multi-line bracket arguments (fenced code blocks) we need to + // bound the search for the closing `]` to the FIRST action only, not + // any later actions the model may have over-eagerly emitted. We do this + // by clipping the search window at the next `Action ` directive line. + const action_text = clipAtNextActionDirective(output[action_start..]); // Extract the part after "Action" (and optional number + colon). 
- const after_keyword = action_line[6..]; // "Action" is 6 chars + const after_keyword = action_text[6..]; // "Action" is 6 chars const after_prefix = skipActionPrefix(after_keyword); - const trimmed = std.mem.trim(u8, after_prefix, " \t"); + const trimmed = std.mem.trimLeft(u8, after_prefix, " \t"); if (trimmed.len == 0) return null; // Find the bracket pair: Type[arg] const open_bracket = std.mem.indexOfScalar(u8, trimmed, '[') orelse return null; - // Greedy: find the last ']' on the line. + // Greedy: find the last ']' after the action. const close_bracket = std.mem.lastIndexOfScalar(u8, trimmed, ']') orelse return null; if (close_bracket <= open_bracket) return null; @@ -85,8 +161,11 @@ pub fn parseReactAction(output: []const u8) ?ReactAction { if (eqlIgnoreCase(action_type, "lookup")) return .{ .lookup = argument }; if (eqlIgnoreCase(action_type, "cmd")) return .{ .cmd = argument }; if (eqlIgnoreCase(action_type, "finish")) return .{ .finish = argument }; + if (isKnownApiToolAction(action_type)) { + return .{ .tool = .{ .name = action_type, .argument = argument } }; + } - return .{ .unknown = trimmed }; + return .{ .unknown = trimmed[0 .. close_bracket + 1] }; } /// Executes a parsed ReAct action and returns the observation text. @@ -95,6 +174,8 @@ pub fn parseReactAction(output: []const u8) ?ReactAction { /// - `Cmd` delegates to the existing command_exec tool infrastructure. /// - `Finish` should be handled by the caller before reaching this function, /// but returns the answer content as a fallback. +/// - `Tool` is only executable by API loop callers that have a request tool +/// list; the standalone executor returns a hint. /// - `Unknown` returns an error hint listing available actions. 
pub fn executeReactAction( allocator: std.mem.Allocator, @@ -114,6 +195,11 @@ pub fn executeReactAction( ), .cmd => |command| try executeReactCmd(allocator, command, app_config), .finish => |answer| try allocator.dupe(u8, answer), + .tool => |tool| try std.fmt.allocPrint( + allocator, + "Tool action {s} is only available in API loop mode when requested by the client.", + .{tool.name}, + ), .unknown => |text| try std.fmt.allocPrint( allocator, "Unknown action: {s}. Available actions: Search, Lookup, Cmd, Finish.", @@ -221,6 +307,48 @@ fn eqlIgnoreCase(a: []const u8, b: []const u8) bool { return std.ascii.eqlIgnoreCase(a, b); } +fn startsWithIgnoreCase(value: []const u8, prefix: []const u8) bool { + if (prefix.len > value.len) return false; + return std.ascii.eqlIgnoreCase(value[0..prefix.len], prefix); +} + +/// Returns a prefix of `text` that ends just before the next line starting +/// with `Action ` (case-insensitive). The first line is always kept since it +/// is itself the Action directive being parsed. +fn clipAtNextActionDirective(text: []const u8) []const u8 { + var cursor: usize = 0; + // Skip the directive's own first line. 
+ while (cursor < text.len and text[cursor] != '\n' and text[cursor] != '\r') { + cursor += 1; + } + + while (cursor < text.len) { + while (cursor < text.len and (text[cursor] == '\n' or text[cursor] == '\r')) { + cursor += 1; + } + const line_start = cursor; + while (cursor < text.len and text[cursor] != '\n' and text[cursor] != '\r') { + cursor += 1; + } + + const raw_line = text[line_start..cursor]; + const line = std.mem.trimLeft(u8, raw_line, " \t"); + if (startsWithActionDirective(line)) { + return text[0..line_start]; + } + } + return text; +} + +fn isKnownApiToolAction(action_type: []const u8) bool { + return eqlIgnoreCase(action_type, "echo") or + eqlIgnoreCase(action_type, "utc") or + eqlIgnoreCase(action_type, "bash") or + eqlIgnoreCase(action_type, "file_read") or + eqlIgnoreCase(action_type, "file_write") or + eqlIgnoreCase(action_type, "file_search"); +} + // ── tests ───────────────────────────────────────────────────────────── test "parseReactAction extracts Search action" { @@ -273,6 +401,32 @@ test "parseReactAction returns unknown for unrecognized action type" { } } +test "parseReactAction extracts requested API tool action" { + const action = parseReactAction("Thought 1: write the file\nAction 1: file_write[write app.cpp\n```cpp\nint main() { return 0; }\n```]"); + try std.testing.expect(action != null); + switch (action.?) 
{ + .tool => |tool| { + try std.testing.expectEqualStrings("file_write", tool.name); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "app.cpp") != null); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "int main()") != null); + }, + else => return error.UnexpectedActionType, + } +} + +test "buildSystemPromptAlloc includes requested tools" { + const allocator = std.testing.allocator; + const tools = [_]types.Tool{ + .{ .name = "file_write", .description = "write files" }, + }; + + const prompt = try buildSystemPromptAlloc(allocator, tools[0..]); + defer allocator.free(prompt); + + try std.testing.expect(std.mem.indexOf(u8, prompt, "file_write[argument]") != null); + try std.testing.expect(std.mem.indexOf(u8, prompt, "file_write[write relative/path") != null); +} + test "parseReactAction handles nested brackets greedily" { const action = parseReactAction("Action 5: Finish[1,800 to 7,000 ft]"); try std.testing.expect(action != null); @@ -305,7 +459,7 @@ test "parseReactAction tolerates missing turn number" { } } -test "parseReactAction uses last action line when multiple exist" { +test "parseReactAction returns first action when model emits multiple" { const output = "Thought 1: first thought\n" ++ "Action 1: Search[first]\n" ++ @@ -314,7 +468,37 @@ test "parseReactAction uses last action line when multiple exist" { const action = parseReactAction(output); try std.testing.expect(action != null); switch (action.?) { - .finish => |answer| try std.testing.expectEqualStrings("final answer", answer), + .search => |query| try std.testing.expectEqualStrings("first", query), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction ignores Action mentions inside <think> block" { + const output = + "<think>I should probably emit Action 1: Search[bad]\nsomething</think>\n" ++ + "Action 1: Finish[real answer]"; + const action = parseReactAction(output); + try std.testing.expect(action != null); + switch (action.?) 
{ + .finish => |answer| try std.testing.expectEqualStrings("real answer", answer), + else => return error.UnexpectedActionType, + } +} + +test "parseReactAction parses multi-line first action when a second follows" { + const output = + "Thought 1: write the file\n" ++ + "Action 1: file_write[calculator.cpp\n```cpp\nint main(){}\n```]\n" ++ + "Action 2: cmd[g++ -o calculator calculator.cpp]"; + const action = parseReactAction(output); + try std.testing.expect(action != null); + switch (action.?) { + .tool => |tool| { + try std.testing.expectEqualStrings("file_write", tool.name); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "calculator.cpp") != null); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "int main(){}") != null); + try std.testing.expect(std.mem.indexOf(u8, tool.argument, "g++") == null); + }, else => return error.UnexpectedActionType, } } From 91ea8a970c3cd4546a0fdd45545de7f6eff10efd Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Thu, 28 May 2026 00:22:28 +1000 Subject: [PATCH 11/25] Add zig test runner --- README.md | 20 ++++++++++++++++++++ build.zig | 29 +++++++++++++++++++++++++++++ build.zig.zon | 4 ++++ 3 files changed, 53 insertions(+) diff --git a/README.md b/README.md index ab0f07d..e15ca5b 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,26 @@ zig build test -Dtest-target=file "-Dtest-file=src/types.zig" zig build test -Dtest-target=all -Dtest-filter=normalizeProviderName ``` +### zig_eval harness (optional) + +This repo depends on the sibling **zig_eval** checkout at `../zig_eval` (see `build.zig.zon`). It ships a small **eval registry** under `eval/registry/` (same layout as `zig_eval`: `services.json`, `evals//*.json`, `data/.../*.jsonl`) and a runner binary **`zig-coding-agent-eval`**. + +**Prerequisites:** Ollama (or whatever `services.json` targets) reachable, and this router listening (default `http://127.0.0.1:8081`). 
If `LLM_ROUTER_API_KEY` is set on the server, set the same variable for the eval client (see `eval/registry/services.json` → `api_key_env`). + +```bash +# Terminal A: router +zig build run + +# Terminal B: from repo root — runs all eval definitions under eval/registry/evals/ +zig build eval-run + +# Custom registry directory or JSON run log +zig build eval-run -- path/to/registry +zig build eval-run -- --json +``` + +There is currently **one** eval definition checked in (`smoke.reply_ok`). Add more by dropping new JSON files under `eval/registry/evals/` and matching datasets under `eval/registry/data/`. + ## API Surface ### Endpoints diff --git a/build.zig b/build.zig index dee0703..783f755 100644 --- a/build.zig +++ b/build.zig @@ -27,6 +27,12 @@ pub fn build(b: *std.Build) void { .target = target, }); + const zig_eval_dep = b.dependency("zig_eval", .{ + .target = target, + .optimize = optimize, + }); + const zig_eval_mod = zig_eval_dep.module("zig_eval"); + // CLI executable entrypoint. const app = b.addExecutable(.{ .name = exe_name, @@ -43,6 +49,28 @@ pub fn build(b: *std.Build) void { b.installArtifact(app); + const eval_exe = b.addExecutable(.{ + .name = "zig-coding-agent-eval", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/eval_runner_main.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "zig_eval", .module = zig_eval_mod }, + }, + }), + }); + eval_exe.linkLibC(); + b.installArtifact(eval_exe); + + const eval_run_step = b.step("eval-run", "Run eval/registry against a running harness (see eval/registry/services.json)"); + const eval_run_cmd = b.addRunArtifact(eval_exe); + eval_run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + eval_run_cmd.addArgs(args); + } + eval_run_step.dependOn(&eval_run_cmd.step); + // Run command: `zig build run -- [args]`. 
const run_step = b.step("run", "Run the Zig Coding Agent server"); const run_cmd = b.addRunArtifact(app); @@ -125,6 +153,7 @@ pub fn build(b: *std.Build) void { // Compile-only verification for app and built-in test modules. const check_step = b.step("check", "Compile app and built-in test modules without running"); check_step.dependOn(&app.step); + check_step.dependOn(&eval_exe.step); check_step.dependOn(&mod_tests.step); check_step.dependOn(&exe_tests.step); check_step.dependOn(&types_unit_tests.step); diff --git a/build.zig.zon b/build.zig.zon index d7f6a1a..7c242d2 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -32,6 +32,9 @@ // Once all dependencies are fetched, `zig build` no longer requires // internet connectivity. .dependencies = .{ + .zig_eval = .{ + .path = "../zig_eval", + }, // See `zig fetch --save ` for a command-line interface for adding dependencies. //.example = .{ // // When updating this field to a new URL, be sure to delete the corresponding @@ -73,6 +76,7 @@ .paths = .{ "build.zig", "build.zig.zon", + "eval", "src", // For example... //"LICENSE", From 21f479b90c6d184a2f9eb58048c6267c1a221f00 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Thu, 28 May 2026 00:23:40 +1000 Subject: [PATCH 12/25] Add streming to react loop --- src/backend/api.zig | 118 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/src/backend/api.zig b/src/backend/api.zig index 6c0ece8..471e5b1 100644 --- a/src/backend/api.zig +++ b/src/backend/api.zig @@ -748,6 +748,124 @@ pub fn callProvider( return error.UnknownProvider; } +/// Outcome of a single ReAct/agent loop turn driven through the streaming +/// pipeline. `streamed` means tokens were emitted live to the SSE client and +/// `captured_content` holds the assistant's content text (no reasoning) for +/// action parsing or context replay. 
+pub const TurnStreamOutcome = struct { + success: bool, + finish_reason: []u8, + model: []u8, + captured_content: []u8, + + pub fn deinit(self: TurnStreamOutcome, allocator: std.mem.Allocator) void { + allocator.free(self.finish_reason); + allocator.free(self.model); + allocator.free(self.captured_content); + } +}; + +/// Streams a single loop turn directly to the SSE client. +/// +/// For providers with a dedicated SSE adapter (Ollama today) this forwards +/// tokens as they arrive. For all other providers, falls back to the buffered +/// `callProvider` path and emits the whole response as one SSE chunk so the +/// loop semantics remain identical. +/// +/// `headers_sent` and `think_block_open` are caller-owned across turns so a +/// single SSE response can carry many turns without re-emitting headers or +/// leaving an unclosed `` block. +pub fn streamProviderTurn( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + app_config: *const config.Config, + request: types.Request, + completion_id: []const u8, + headers_sent: *bool, + think_block_open: *bool, +) !TurnStreamOutcome { + const response_mod = @import("../core/response.zig"); + + if (!headers_sent.*) { + try response_mod.sendEventStreamHeaders(connection); + headers_sent.* = true; + } + + const requested_provider = request.provider orelse app_config.default_provider; + const provider = types.normalizeProviderName(requested_provider) orelse { + return error.UnknownProvider; + }; + + if (std.mem.eql(u8, provider, "ollama_qwen")) { + const ollama_qwen = @import("../providers/ollama_qwen.zig"); + + var captured = std.ArrayList(u8){}; + errdefer captured.deinit(allocator); + + const turn_result = try ollama_qwen.streamQwenTurnToSse( + connection, + allocator, + app_config, + request, + completion_id, + &captured, + think_block_open, + ); + + switch (turn_result) { + .streamed => |success| { + defer success.deinit(allocator); + return .{ + .success = true, + .finish_reason = try 
allocator.dupe(u8, success.finish_reason), + .model = try allocator.dupe(u8, success.model), + .captured_content = try captured.toOwnedSlice(allocator), + }; + }, + .failed => |provider_response| { + defer provider_response.deinit(allocator); + captured.clearAndFree(allocator); // leaves the list valid and empty, so the errdefer deinit cannot double-free if a dupe below fails + return .{ + .success = false, + .finish_reason = try allocator.dupe(u8, provider_response.finish_reason), + .model = try allocator.dupe(u8, provider_response.model), + .captured_content = try allocator.dupe(u8, provider_response.output), + }; + }, + } + } + + // Fallback path for providers without per-token streaming: buffer via + // callProvider, then emit the whole response as a single SSE chunk. + if (think_block_open.*) { + try response_mod.sendChatCompletionChunkSse( + connection, allocator, completion_id, "fallback", "", null, + ); + think_block_open.* = false; + } + + var buffered = try callProvider(allocator, app_config, request); + defer buffered.deinit(allocator); + + if (buffered.output.len > 0) { + try response_mod.sendChatCompletionChunkSse( + connection, + allocator, + completion_id, + buffered.model, + buffered.output, + null, + ); + } + + return .{ + .success = buffered.success, + .finish_reason = try allocator.dupe(u8, buffered.finish_reason), + .model = try allocator.dupe(u8, buffered.model), + .captured_content = try allocator.dupe(u8, buffered.output), + }; +} + pub fn buildProviderStatusJson( allocator: std.mem.Allocator, app_config: *const config.Config, From ac3cf53b32e3d32157f70c1fb4dd6fdad6564231 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Thu, 28 May 2026 00:24:13 +1000 Subject: [PATCH 13/25] Fix loop connection --- src/core/server.zig | 96 +++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/src/core/server.zig b/src/core/server.zig index 2c53885..2b20606 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -1092,6 +1092,8 @@ fn streamLoopRequestToSse( const loop_max_turns = base_request.loop_max_turns
orelse 8; try response.sendEventStreamHeaders(connection); + var headers_sent: bool = true; + var think_block_open: bool = false; const completion_id = try std.fmt.allocPrint( allocator, @@ -1174,45 +1176,63 @@ fn streamLoopRequestToSse( break :blk r; } else turn_request; - var turn_result = try backend.callProvider(allocator, app_config, provider_request); - defer turn_result.deinit(allocator); - if (turn_result.success and isLengthFinishReason(turn_result.finish_reason)) { - continueLengthResponseAlloc(allocator, app_config, provider_request, &turn_result) catch |err| { - logError("Provider loop continuation failed: {s}", .{@errorName(err)}); - }; + // Emit the per-turn header BEFORE provider tokens stream in, so the + // client can render "[loop turn N/M]\n" then watch tokens arrive live. + if (emit_progress) { + const progress_prefix = try std.fmt.allocPrint( + allocator, + "[loop turn {d}/{d}]\n", + .{ turn + 1, loop_max_turns }, + ); + defer allocator.free(progress_prefix); + + try response.sendChatCompletionChunkSse( + connection, + allocator, + completion_id, + provider_request.model orelse app_config.ollama_model, + progress_prefix, + null, + ); } - const with_assistant = try appendRoleMessageAlloc( + var turn_outcome = try backend.streamProviderTurn( + connection, allocator, - working_messages, - "assistant", - turn_result.output, + app_config, + provider_request, + completion_id, + &headers_sent, + &think_block_open, ); - for (working_messages) |message| { - message.deinit(allocator); - } - allocator.free(working_messages); - working_messages = with_assistant; - - if (emit_progress) { - const progress_text = try std.fmt.allocPrint( - allocator, - "[loop turn {d}/{d}]\n{s}\n", - .{ turn + 1, loop_max_turns, turn_result.output }, - ); - defer allocator.free(progress_text); + defer turn_outcome.deinit(allocator); + // Emit a trailing newline after the turn body so progress text and + // the next observation/prefix don't visually collide. 
+ if (emit_progress and turn_outcome.captured_content.len > 0) { try response.sendChatCompletionChunkSse( connection, allocator, completion_id, - turn_result.model, - progress_text, + turn_outcome.model, + "\n", null, ); } - const normalized_output = std.mem.trim(u8, turn_result.output, " \t\r\n"); + const with_assistant = try appendRoleMessageAlloc( + allocator, + working_messages, + "assistant", + turn_outcome.captured_content, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_assistant; + + const normalized_output = std.mem.trim(u8, turn_outcome.captured_content, " \t\r\n"); if (previous_output) |prev| { if (std.mem.eql(u8, prev, normalized_output)) { repeated_count += 1; @@ -1224,7 +1244,7 @@ fn streamLoopRequestToSse( previous_output = try allocator.dupe(u8, normalized_output); const is_react_mode = std.ascii.eqlIgnoreCase(loop_mode, "react"); - const reached_until = std.mem.indexOf(u8, turn_result.output, loop_until) != null; + const reached_until = std.mem.indexOf(u8, turn_outcome.captured_content, loop_until) != null; const reached_max = (turn + 1) >= loop_max_turns; const repeated_stop = (std.ascii.eqlIgnoreCase(loop_mode, "agent") or is_react_mode) and repeated_count >= 1; @@ -1234,7 +1254,7 @@ fn streamLoopRequestToSse( defer if (react_observation) |obs| allocator.free(obs); if (is_react_mode) { - if (react.parseReactAction(turn_result.output)) |action| { + if (react.parseReactAction(turn_outcome.captured_content)) |action| { switch (action) { .finish => { react_finished = true; @@ -1255,27 +1275,25 @@ fn streamLoopRequestToSse( } } - const should_stop = react_finished or reached_until or reached_max or repeated_stop or !turn_result.success; + const should_stop = react_finished or reached_until or reached_max or repeated_stop or !turn_outcome.success; if (should_stop) { - if (!emit_progress and turn_result.output.len > 0) { + // Close any dangling block from the last 
turn so the + // client doesn't render an unterminated reasoning section. + if (think_block_open) { try response.sendChatCompletionChunkSse( - connection, - allocator, - completion_id, - turn_result.model, - turn_result.output, - null, + connection, allocator, completion_id, turn_outcome.model, "", null, ); + think_block_open = false; } try response.sendChatCompletionChunkSse( connection, allocator, completion_id, - turn_result.model, + turn_outcome.model, null, - turn_result.finish_reason, + turn_outcome.finish_reason, ); try response.sendSseDone(connection); return; @@ -1294,7 +1312,7 @@ fn streamLoopRequestToSse( connection, allocator, completion_id, - turn_result.model, + turn_outcome.model, latest_user_prompt, null, ); From 40a31faa01eeead95bd85efddd05fc7a04a8b02b Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Thu, 28 May 2026 00:46:58 +1000 Subject: [PATCH 14/25] Add streming support for multi loops --- src/providers/ollama_qwen.zig | 103 ++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/src/providers/ollama_qwen.zig b/src/providers/ollama_qwen.zig index 265912c..8eee0b3 100644 --- a/src/providers/ollama_qwen.zig +++ b/src/providers/ollama_qwen.zig @@ -58,6 +58,109 @@ pub const StreamQwenResult = union(enum) { failed: types.Response, }; +/// Per-turn streaming variant for use inside multi-turn loops (ReAct, agent). +/// +/// Differences from `streamQwenToSse`: +/// - The caller owns SSE headers, the completion id, and the streaming buffers +/// (`captured_content`, `think_block_open`). Headers and `[DONE]` are NOT +/// emitted here. +/// - On success, returns the finish_reason and resolved model name in +/// `TurnStreamSuccess` so the loop can decide whether to continue, parse +/// actions, or stop. +/// - `captured_content` accumulates only the assistant `content` tokens (no +/// reasoning), giving the loop a clean buffer to feed back into the next +/// turn or hand to `react.parseReactAction`. 
+pub const TurnStreamSuccess = struct { + finish_reason: []u8, + model: []u8, + + pub fn deinit(self: TurnStreamSuccess, allocator: std.mem.Allocator) void { + allocator.free(self.finish_reason); + allocator.free(self.model); + } +}; + +pub const TurnStreamResult = union(enum) { + streamed: TurnStreamSuccess, + failed: types.Response, +}; + +pub fn streamQwenTurnToSse( + connection: std.net.Server.Connection, + allocator: std.mem.Allocator, + app_config: *const config.Config, + request: types.Request, + completion_id: []const u8, + captured_content: *std.ArrayList(u8), + think_block_open: *bool, +) !TurnStreamResult { + const requested_model = request.model orelse app_config.ollama_model; + var model_name = requested_model; + var fallback: ?ModelFallback = null; + defer if (fallback) |value| value.deinit(allocator); + + if (try pickFallbackModelAlloc(allocator, app_config, requested_model)) |selected| { + fallback = selected; + model_name = fallback.?.selected_model; + } + + var client = std.http.Client{ .allocator = allocator }; + defer client.deinit(); + + const uri_text = try buildChatUrl(allocator, app_config.ollama_base_url); + defer allocator.free(uri_text); + + const uri = try std.Uri.parse(uri_text); + + var headers_sent_already = true; // caller already sent headers + var final_finish_reason: []const u8 = "stop"; + + var turn: usize = 0; + while (true) : (turn += 1) { + const messages_json = if (turn == 0) + try renderMessagesJsonAlloc(allocator, request.messages) + else + try renderContinuationMessagesJsonAlloc( + allocator, + request.messages, + captured_content.items, + stream_continue_prompt, + ); + defer allocator.free(messages_json); + + const attempt_result = try streamChatMessagesToSse( + connection, + allocator, + &client, + app_config, + request, + uri, + completion_id, + model_name, + messages_json, + &headers_sent_already, + captured_content, + think_block_open, + ); + + switch (attempt_result) { + .failed => |failed_response| return .{ .failed 
= failed_response }, + .streamed => |attempt| { + if (!attempt.stopped_by_length) break; + if (turn + 1 >= max_stream_continuations) { + final_finish_reason = "length"; + break; + } + }, + } + } + + return .{ .streamed = .{ + .finish_reason = try allocator.dupe(u8, final_finish_reason), + .model = try allocator.dupe(u8, model_name), + } }; +} + pub fn streamQwenToSse( connection: std.net.Server.Connection, allocator: std.mem.Allocator, From a4a1d54352121b020998af259e9728527a1f7744 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 14:39:16 +1000 Subject: [PATCH 15/25] Add recent request tracking to server state This update introduces a new structure for tracking recent requests, including route, status code, duration, and request ID. The server state now maintains a circular buffer of recent requests, allowing for efficient logging and retrieval. --- build.zig | 46 +++-- build.zig.zon | 4 - src/backend/api.zig | 2 +- src/backend/tool_output.zig | 105 +++++++++++ src/backend/tools.zig | 33 ++-- src/config.zig | 12 ++ src/core/request.zig | 63 +++++++ src/core/response.zig | 60 ++++-- src/core/server.zig | 344 ++++++++++++++++++++++++++-------- src/providers/ollama_qwen.zig | 2 +- src/react.zig | 18 +- src/root.zig | 1 + src/tools/command_exec.zig | 54 ++++-- 13 files changed, 580 insertions(+), 164 deletions(-) create mode 100644 src/backend/tool_output.zig diff --git a/build.zig b/build.zig index 783f755..541e536 100644 --- a/build.zig +++ b/build.zig @@ -27,12 +27,6 @@ pub fn build(b: *std.Build) void { .target = target, }); - const zig_eval_dep = b.dependency("zig_eval", .{ - .target = target, - .optimize = optimize, - }); - const zig_eval_mod = zig_eval_dep.module("zig_eval"); - // CLI executable entrypoint. const app = b.addExecutable(.{ .name = exe_name, @@ -45,30 +39,35 @@ pub fn build(b: *std.Build) void { }, }), }); - app.linkLibC(); + // No linkLibC(): harness uses std only; avoids glibc/GCC 16 .sframe linker issues on Zig 0.15.2. 
b.installArtifact(app); - const eval_exe = b.addExecutable(.{ - .name = "zig-coding-agent-eval", - .root_module = b.createModule(.{ - .root_source_file = b.path("src/eval_runner_main.zig"), - .target = target, - .optimize = optimize, - .imports = &.{ - .{ .name = "zig_eval", .module = zig_eval_mod }, - }, - }), - }); - eval_exe.linkLibC(); - b.installArtifact(eval_exe); + // Optional sibling checkout: ../zig_eval (same parent folder as this repo). + const zig_eval_root = "../zig_eval"; + const zig_eval_registry = "registry"; + + const build_zig_eval = b.addSystemCommand(&.{ "zig", "build" }); + build_zig_eval.setCwd(b.path(zig_eval_root)); + + const eval_build_step = b.step("eval-build", "Build sibling zig_eval checkout (requires ../zig_eval)"); + eval_build_step.dependOn(&build_zig_eval.step); - const eval_run_step = b.step("eval-run", "Run eval/registry against a running harness (see eval/registry/services.json)"); - const eval_run_cmd = b.addRunArtifact(eval_exe); - eval_run_cmd.step.dependOn(b.getInstallStep()); + const eval_list_cmd = b.addSystemCommand(&.{ "zig", "build", "run", "--", "list", "--registry", zig_eval_registry }); + eval_list_cmd.setCwd(b.path(zig_eval_root)); + eval_list_cmd.step.dependOn(&build_zig_eval.step); + + const eval_list_step = b.step("eval-list", "List evals from sibling zig_eval registry (requires ../zig_eval)"); + eval_list_step.dependOn(&eval_list_cmd.step); + + const eval_run_cmd = b.addSystemCommand(&.{ "zig", "build", "run", "--", "run", "--registry", zig_eval_registry }); + eval_run_cmd.setCwd(b.path(zig_eval_root)); + eval_run_cmd.step.dependOn(&build_zig_eval.step); if (b.args) |args| { eval_run_cmd.addArgs(args); } + + const eval_run_step = b.step("eval-run", "Run sibling zig_eval against this harness (requires ../zig_eval; start harness first)"); eval_run_step.dependOn(&eval_run_cmd.step); // Run command: `zig build run -- [args]`. 
@@ -153,7 +152,6 @@ pub fn build(b: *std.Build) void { // Compile-only verification for app and built-in test modules. const check_step = b.step("check", "Compile app and built-in test modules without running"); check_step.dependOn(&app.step); - check_step.dependOn(&eval_exe.step); check_step.dependOn(&mod_tests.step); check_step.dependOn(&exe_tests.step); check_step.dependOn(&types_unit_tests.step); diff --git a/build.zig.zon b/build.zig.zon index 7c242d2..d7f6a1a 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -32,9 +32,6 @@ // Once all dependencies are fetched, `zig build` no longer requires // internet connectivity. .dependencies = .{ - .zig_eval = .{ - .path = "../zig_eval", - }, // See `zig fetch --save ` for a command-line interface for adding dependencies. //.example = .{ // // When updating this field to a new URL, be sure to delete the corresponding @@ -76,7 +73,6 @@ .paths = .{ "build.zig", "build.zig.zon", - "eval", "src", // For example... //"LICENSE", diff --git a/src/backend/api.zig b/src/backend/api.zig index 471e5b1..8d13ef5 100644 --- a/src/backend/api.zig +++ b/src/backend/api.zig @@ -787,7 +787,7 @@ pub fn streamProviderTurn( const response_mod = @import("../core/response.zig"); if (!headers_sent.*) { - try response_mod.sendEventStreamHeaders(connection); + try response_mod.sendEventStreamHeaders(connection, null); headers_sent.* = true; } diff --git a/src/backend/tool_output.zig b/src/backend/tool_output.zig new file mode 100644 index 0000000..edab210 --- /dev/null +++ b/src/backend/tool_output.zig @@ -0,0 +1,105 @@ +//! Tool output offloading for large tool results. +//! +//! When tool output exceeds a configured byte threshold, the full payload is +//! written to disk and a compact head/tail summary is returned for context. 
+const std = @import("std"); +const config = @import("../config.zig"); + +const head_tail_preview_bytes = 512; + +pub fn maybeOffloadToolOutputAlloc( + allocator: std.mem.Allocator, + app_config: *const config.Config, + request_id: []const u8, + tool_name: []const u8, + output: []const u8, +) ![]u8 { + if (app_config.tool_output_offload_bytes == 0) { + return try allocator.dupe(u8, output); + } + if (output.len <= app_config.tool_output_offload_bytes) { + return try allocator.dupe(u8, output); + } + + const offload_dir = try std.fmt.allocPrint(allocator, "{s}/tool_outputs", .{app_config.session_store_path}); + defer allocator.free(offload_dir); + + std.fs.cwd().makePath(offload_dir) catch {}; + + var name_buf: [64]u8 = undefined; // caller-owned storage: safe_tool_name points into this buffer + const safe_tool_name = sanitizePathComponent(&name_buf, tool_name); + const offload_path = try std.fmt.allocPrint( + allocator, + "{s}/{s}_{s}.txt", + .{ offload_dir, request_id, safe_tool_name }, + ); + defer allocator.free(offload_path); + + const file = std.fs.cwd().createFile(offload_path, .{}) catch { + return try allocator.dupe(u8, output); + }; + defer file.close(); + file.writeAll(output) catch { + return try allocator.dupe(u8, output); + }; + + const head_len = @min(head_tail_preview_bytes, output.len); + const tail_len = @min(head_tail_preview_bytes, output.len - head_len); + const tail_start = output.len - tail_len; + + return try std.fmt.allocPrint( + allocator, + "tool_output_offloaded=true\nfull_bytes={d}\noffload_path={s}\n--- head ---\n{s}\n--- tail ---\n{s}", + .{ + output.len, + offload_path, + output[0..head_len], + output[tail_start..], + }, + ); +} + +fn sanitizePathComponent(out: *[64]u8, name: []const u8) []const u8 { // filters name into caller-owned out; the result aliases out (or the literal "tool") and must not outlive it + var len: usize = 0; + for (name) |c| { + if (len >= out.len) break; + switch (c) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.'
=> { + out[len] = c; + len += 1; + }, + else => {}, + } + } + if (len == 0) return "tool"; + return out[0..len]; +} + +test "maybeOffloadToolOutputAlloc keeps small output inline" { + const allocator = std.testing.allocator; + var cfg = try @import("../tools/command_exec.zig").buildTestConfig(allocator, false); + defer cfg.deinit(allocator); + cfg.tool_output_offload_bytes = 64; + + const output = try maybeOffloadToolOutputAlloc(allocator, &cfg, "req-1", "echo", "small"); + defer allocator.free(output); + + try std.testing.expectEqualStrings("small", output); +} + +test "maybeOffloadToolOutputAlloc offloads large output with preview" { + const allocator = std.testing.allocator; + var cfg = try @import("../tools/command_exec.zig").buildTestConfig(allocator, false); + defer cfg.deinit(allocator); + cfg.tool_output_offload_bytes = 32; + + const large = try allocator.alloc(u8, 128); + defer allocator.free(large); + @memset(large, 'x'); + + const output = try maybeOffloadToolOutputAlloc(allocator, &cfg, "req-offload", "echo", large); + defer allocator.free(output); + + try std.testing.expect(std.mem.indexOf(u8, output, "tool_output_offloaded=true") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "offload_path=") != null); +} diff --git a/src/backend/tools.zig b/src/backend/tools.zig index 0282515..4b875ce 100644 --- a/src/backend/tools.zig +++ b/src/backend/tools.zig @@ -5,6 +5,7 @@ const echo_tool = @import("../tools/echo.zig"); const utc_tool = @import("../tools/utc.zig"); const command_exec_tool = @import("../tools/command_exec.zig"); const file_ops_tool = @import("../tools/file_ops.zig"); +const tool_output = @import("tool_output.zig"); pub const ToolRegistry = struct { allowed_names: std.StringHashMap(void), @@ -55,6 +56,7 @@ pub fn tryExecuteDebugTool( allocator: std.mem.Allocator, request: types.Request, app_config: *const config.Config, + request_id: []const u8, ) !?types.Response { const choice = request.tool_choice orelse return null; @@ -70,18 
+72,18 @@ pub fn tryExecuteDebugTool( if (std.ascii.eqlIgnoreCase(choice, "cmd")) { if (!hasRequestedTool(request.tools, "cmd")) return null; - const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, request.prompt); + const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, request.prompt, app_config.tool_exec_trusted_local); defer if (command) |value| allocator.free(value); if (command == null) return null; - return try command_exec_tool.execute(allocator, app_config, request, .cmd); + return try command_exec_tool.execute(allocator, app_config, request, .cmd, request_id); } if (std.ascii.eqlIgnoreCase(choice, "bash")) { if (!hasRequestedTool(request.tools, "bash")) return null; - const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, request.prompt); + const command = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, request.prompt, app_config.tool_exec_trusted_local); defer if (command) |value| allocator.free(value); if (command == null) return null; - return try command_exec_tool.execute(allocator, app_config, request, .bash); + return try command_exec_tool.execute(allocator, app_config, request, .bash, request_id); } if (std.ascii.eqlIgnoreCase(choice, "file_read")) { @@ -106,6 +108,7 @@ pub fn maybeExecutePromptToolsAlloc( allocator: std.mem.Allocator, request: types.Request, app_config: *const config.Config, + request_id: []const u8, ) !?[]u8 { if (request.tools.len == 0) return null; @@ -187,7 +190,7 @@ pub fn maybeExecutePromptToolsAlloc( defer run_req.deinit(allocator); try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .cmd); + const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .cmd, request_id); defer run_result.deinit(allocator); try output.writer(allocator).print("\n[tool=cmd]\n{s}\n", .{run_result.output}); @@ 
-198,7 +201,7 @@ pub fn maybeExecutePromptToolsAlloc( defer run_req.deinit(allocator); try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .bash); + const run_result = try command_exec_tool.execute(allocator, app_config, run_req, .bash, request_id); defer run_result.deinit(allocator); try output.writer(allocator).print("\n[tool=bash]\n{s}\n", .{run_result.output}); @@ -206,7 +209,7 @@ pub fn maybeExecutePromptToolsAlloc( } if (cmd_requested) { - const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, prompt); + const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .cmd, prompt, app_config.tool_exec_trusted_local); defer if (command_to_run) |command| allocator.free(command); if (command_to_run) |command| { @@ -214,7 +217,7 @@ pub fn maybeExecutePromptToolsAlloc( defer cmd_request.deinit(allocator); try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const cmd_result = try command_exec_tool.execute(allocator, app_config, cmd_request, .cmd); + const cmd_result = try command_exec_tool.execute(allocator, app_config, cmd_request, .cmd, request_id); defer cmd_result.deinit(allocator); try output.writer(allocator).print("\n[tool=cmd]\n{s}\n", .{cmd_result.output}); @@ -223,7 +226,7 @@ pub fn maybeExecutePromptToolsAlloc( } if (!executed_any and bash_requested) { - const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, prompt); + const command_to_run = try command_exec_tool.extractCommandFromPromptAlloc(allocator, .bash, prompt, app_config.tool_exec_trusted_local); defer if (command_to_run) |command| allocator.free(command); if (command_to_run) |command| { @@ -231,7 +234,7 @@ pub fn maybeExecutePromptToolsAlloc( defer bash_request.deinit(allocator); try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const bash_result = try 
command_exec_tool.execute(allocator, app_config, bash_request, .bash); + const bash_result = try command_exec_tool.execute(allocator, app_config, bash_request, .bash, request_id); defer bash_result.deinit(allocator); try output.writer(allocator).print("\n[tool=bash]\n{s}\n", .{bash_result.output}); @@ -382,7 +385,7 @@ test "tryExecuteDebugTool executes echo when explicitly selected" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request"); try std.testing.expect(maybe_result != null); var result = maybe_result.?; @@ -424,7 +427,7 @@ test "tryExecuteDebugTool executes utc when explicitly selected" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request"); try std.testing.expect(maybe_result != null); var result = maybe_result.?; @@ -466,7 +469,7 @@ test "maybeExecutePromptToolsAlloc runs utc from prompt intent" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg); + const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request"); defer if (summary) |value| allocator.free(value); try std.testing.expect(summary != null); @@ -506,7 +509,7 @@ test "maybeExecutePromptToolsAlloc extracts command for cmd" { var cfg = try command_exec_tool.buildTestConfig(allocator, true); defer cfg.deinit(allocator); - const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg); + const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request"); defer if (summary) |value| allocator.free(value); try std.testing.expect(summary != null); 
@@ -545,7 +548,7 @@ test "tryExecuteDebugTool ignores natural language cmd selection" { var cfg = try command_exec_tool.buildTestConfig(allocator, true); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request"); try std.testing.expect(maybe_result == null); } diff --git a/src/config.zig b/src/config.zig index 3ad7ee0..ef4c60f 100644 --- a/src/config.zig +++ b/src/config.zig @@ -47,9 +47,12 @@ pub const Config = struct { session_retention_messages: usize, tool_exec_enabled: bool, tool_exec_confirmation_required: bool, + tool_exec_trusted_local: bool, tool_exec_timeout_ms: u32, tool_exec_max_output_bytes: usize, + tool_output_offload_bytes: usize, max_tool_calls_per_request: usize, + default_max_context_tokens: ?usize, loop_stream_progress_enabled: bool, max_concurrent_connections: usize, max_request_bytes: usize, @@ -257,9 +260,15 @@ pub const Config = struct { true, env_overrides, ), + .tool_exec_trusted_local = try getEnvFlag(allocator, "LLM_ROUTER_TOOL_EXEC_TRUSTED_LOCAL", env_overrides), .tool_exec_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS", 15_000, env_overrides), .tool_exec_max_output_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES", 65_536, env_overrides), + .tool_output_offload_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_OUTPUT_OFFLOAD_BYTES", 8192, env_overrides), .max_tool_calls_per_request = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST", 8, env_overrides), + .default_max_context_tokens = blk: { + const raw = getEnvPositiveUsizeOrDefault("LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS", 0, env_overrides) catch 0; + break :blk if (raw == 0) null else raw; + }, .loop_stream_progress_enabled = try getEnvFlagOrDefault(allocator, "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED", true, env_overrides), .max_concurrent_connections = try 
getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS", 64, env_overrides), .max_request_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_REQUEST_BYTES", 1024 * 1024, env_overrides), @@ -539,9 +548,12 @@ test "setDefaultProvider stores canonical provider alias" { .session_retention_messages = 24, .tool_exec_enabled = false, .tool_exec_confirmation_required = false, + .tool_exec_trusted_local = false, .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, + .tool_output_offload_bytes = 8192, .max_tool_calls_per_request = 8, + .default_max_context_tokens = null, .loop_stream_progress_enabled = true, .max_concurrent_connections = 64, .max_request_bytes = 1024 * 1024, diff --git a/src/core/request.zig b/src/core/request.zig index fcc5a53..6bf6835 100644 --- a/src/core/request.zig +++ b/src/core/request.zig @@ -169,6 +169,51 @@ pub fn checkedRequiredLength(header_len: usize, content_length: usize) !usize { return std.math.add(usize, header_len, content_length) catch error.RequestTooLarge; } +/// Returns a trimmed HTTP header value when present. +pub fn getHeaderValue(headers: []const u8, header_name: []const u8) ?[]const u8 { + var lines = std.mem.splitScalar(u8, headers, '\n'); + _ = lines.next(); + + while (lines.next()) |raw_line| { + const line = std.mem.trimRight(u8, raw_line, "\r"); + if (line.len == 0) break; + + const separator_index = std.mem.indexOfScalar(u8, line, ':') orelse continue; + const name = std.mem.trim(u8, line[0..separator_index], " "); + if (!std.ascii.eqlIgnoreCase(name, header_name)) continue; + + return std.mem.trim(u8, line[separator_index + 1 ..], " "); + } + + return null; +} + +/// Uses `X-Request-Id` when valid, otherwise generates `req-{micros}`. 
+pub fn resolveRequestIdAlloc( + allocator: std.mem.Allocator, + request_raw: []const u8, +) ![]u8 { + const header_end = std.mem.indexOf(u8, request_raw, "\r\n\r\n") orelse request_raw.len; + if (getHeaderValue(request_raw[0..header_end], "X-Request-Id")) |incoming| { + if (isValidRequestId(incoming)) { + return try allocator.dupe(u8, incoming); + } + } + + return try std.fmt.allocPrint(allocator, "req-{d}", .{std.time.microTimestamp()}); +} + +fn isValidRequestId(value: []const u8) bool { + if (value.len == 0 or value.len > 128) return false; + for (value) |c| { + switch (c) { + 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => {}, + else => return false, + } + } + return true; +} + test "checkedRequiredLength rejects integer overflow" { try std.testing.expectError( error.RequestTooLarge, @@ -176,6 +221,24 @@ test "checkedRequiredLength rejects integer overflow" { ); } +test "resolveRequestIdAlloc generates fallback id" { + const allocator = std.testing.allocator; + const id = try resolveRequestIdAlloc(allocator, "GET /health HTTP/1.1\r\n\r\n"); + defer allocator.free(id); + try std.testing.expect(std.mem.startsWith(u8, id, "req-")); +} + +test "resolveRequestIdAlloc preserves incoming header" { + const allocator = std.testing.allocator; + const raw = + "POST /v1/chat/completions HTTP/1.1\r\n" ++ + "X-Request-Id: client-trace-42\r\n" ++ + "\r\n"; + const id = try resolveRequestIdAlloc(allocator, raw); + defer allocator.free(id); + try std.testing.expectEqualStrings("client-trace-42", id); +} + fn waitForReadable( socket: std.posix.socket_t, timeout_ms: u32, diff --git a/src/core/response.zig b/src/core/response.zig index 7ff853f..eb30f1e 100644 --- a/src/core/response.zig +++ b/src/core/response.zig @@ -50,6 +50,7 @@ fn sendJson( connection: std.net.Server.Connection, status_code: u16, body: []const u8, + request_id: ?[]const u8, ) !void { const status_text = switch (status_code) { 200 => "OK", @@ -65,16 +66,27 @@ fn sendJson( else => "Internal Server Error", }; 
+ const request_id_header = if (request_id) |id| + try std.fmt.allocPrint( + std.heap.page_allocator, + "X-Request-Id: {s}\r\n", + .{id}, + ) + else + try std.heap.page_allocator.dupe(u8, ""); + defer std.heap.page_allocator.free(request_id_header); + // The server handles one request per connection. const response = try std.fmt.allocPrint( std.heap.page_allocator, "HTTP/1.1 {d} {s}\r\n" ++ "Content-Type: application/json\r\n" ++ + "{s}" ++ "Content-Length: {d}\r\n" ++ "Connection: close\r\n" ++ "\r\n" ++ "{s}", - .{ status_code, status_text, body.len, body }, + .{ status_code, status_text, request_id_header, body.len, body }, ); defer std.heap.page_allocator.free(response); @@ -86,7 +98,16 @@ pub fn sendJsonText( status_code: u16, body: []const u8, ) !void { - try sendJson(connection, status_code, body); + try sendJson(connection, status_code, body, null); +} + +pub fn sendJsonTextWithRequestId( + connection: std.net.Server.Connection, + status_code: u16, + body: []const u8, + request_id: []const u8, +) !void { + try sendJson(connection, status_code, body, request_id); } /// Send API error response @@ -94,6 +115,7 @@ pub fn sendApiError( connection: std.net.Server.Connection, allocator: std.mem.Allocator, api_error: errors.ApiError, + request_id: ?[]const u8, ) !void { const escaped_message = try escapeJsonStringAlloc(allocator, api_error.message); defer allocator.free(escaped_message); @@ -114,7 +136,7 @@ pub fn sendApiError( }); defer allocator.free(response_json); - try sendJson(connection, api_error.status_code, response_json); + try sendJson(connection, api_error.status_code, response_json, request_id); } fn makeCompletionIdAlloc(allocator: std.mem.Allocator) ![]u8 { @@ -138,6 +160,7 @@ pub fn sendChatCompletion( connection: std.net.Server.Connection, allocator: std.mem.Allocator, result: types.Response, + request_id: ?[]const u8, ) !void { var generated_id: ?[]u8 = null; defer if (generated_id) |id| allocator.free(id); @@ -176,17 +199,32 @@ pub fn 
sendChatCompletion( }); defer allocator.free(response_json); - try sendJson(connection, 200, response_json); + try sendJson(connection, 200, response_json, request_id); } -pub fn sendEventStreamHeaders(connection: std.net.Server.Connection) !void { - const response = +pub fn sendEventStreamHeaders(connection: std.net.Server.Connection, request_id: ?[]const u8) !void { + const request_id_header = if (request_id) |id| + try std.fmt.allocPrint( + std.heap.page_allocator, + "X-Request-Id: {s}\r\n", + .{id}, + ) + else + try std.heap.page_allocator.dupe(u8, ""); + defer std.heap.page_allocator.free(request_id_header); + + const response = try std.fmt.allocPrint( + std.heap.page_allocator, "HTTP/1.1 200 OK\r\n" ++ - "Content-Type: text/event-stream\r\n" ++ - "Cache-Control: no-cache\r\n" ++ - "Connection: close\r\n" ++ - "X-Accel-Buffering: no\r\n" ++ - "\r\n"; + "Content-Type: text/event-stream\r\n" ++ + "Cache-Control: no-cache\r\n" ++ + "{s}" ++ + "Connection: close\r\n" ++ + "X-Accel-Buffering: no\r\n" ++ + "\r\n", + .{request_id_header}, + ); + defer std.heap.page_allocator.free(response); try connection.stream.writeAll(response); } diff --git a/src/core/server.zig b/src/core/server.zig index 2b20606..56922e7 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -12,11 +12,21 @@ const session = @import("../backend/session.zig"); const react = @import("../react.zig"); const context_compaction_keep_messages = 8; +const max_recent_requests = 32; const length_continue_prompt = "The previous assistant message was cut off by the generation limit. " ++ "Continue the same answer from the exact next token. 
Output only the continuation text; " ++ "do not mention truncation, continuation, the previous message, or these instructions."; +const RecentRequest = struct { + route: [48]u8 = undefined, + route_len: u8 = 0, + status_code: u16 = 0, + duration_ms: u64 = 0, + request_id: [64]u8 = undefined, + request_id_len: u8 = 0, +}; + const ServerState = struct { mutex: std.Thread.Mutex = .{}, total_requests: u64 = 0, @@ -25,6 +35,9 @@ const ServerState = struct { active_connections: u64 = 0, provider_latency_buckets: [4]u64 = .{ 0, 0, 0, 0 }, connected_clients: std.ArrayList([]u8), + recent_requests: [max_recent_requests]RecentRequest = undefined, + recent_request_idx: usize = 0, + recent_request_count: usize = 0, fn init() ServerState { return .{ .connected_clients = .{} }; @@ -91,6 +104,67 @@ const ServerState = struct { self.provider_latency_buckets[idx] += 1; } + fn noteRecentRequest( + self: *ServerState, + route: []const u8, + status_code: u16, + duration_ms: u64, + request_id: []const u8, + ) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + var entry = &self.recent_requests[self.recent_request_idx]; + const route_len = @min(route.len, entry.route.len); + @memcpy(entry.route[0..route_len], route[0..route_len]); + entry.route_len = @intCast(route_len); + entry.status_code = status_code; + entry.duration_ms = duration_ms; + + const request_id_len = @min(request_id.len, entry.request_id.len); + @memcpy(entry.request_id[0..request_id_len], request_id[0..request_id_len]); + entry.request_id_len = @intCast(request_id_len); + + self.recent_request_idx = (self.recent_request_idx + 1) % max_recent_requests; + if (self.recent_request_count < max_recent_requests) { + self.recent_request_count += 1; + } + } + + fn recentRequestsJsonAlloc(self: *ServerState, allocator: std.mem.Allocator) ![]u8 { + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + self.mutex.lock(); + defer self.mutex.unlock(); + + try out.append(allocator, '['); + var written: usize = 
0; + var idx = if (self.recent_request_count < max_recent_requests) + 0 + else + self.recent_request_idx; + + while (written < self.recent_request_count) : ({ + written += 1; + idx = (idx + 1) % max_recent_requests; + }) { + const entry = self.recent_requests[idx]; + if (written > 0) try out.append(allocator, ','); + try out.writer(allocator).print( + "{{\"route\":\"{s}\",\"status\":{d},\"duration_ms\":{d},\"request_id\":\"{s}\"}}", + .{ + entry.route[0..entry.route_len], + entry.status_code, + entry.duration_ms, + entry.request_id[0..entry.request_id_len], + }, + ); + } + try out.append(allocator, ']'); + return try out.toOwnedSlice(allocator); + } + fn snapshot(self: *ServerState) Snapshot { self.mutex.lock(); defer self.mutex.unlock(); @@ -277,6 +351,7 @@ pub fn run( 503, "{\"error\":{\"message\":\"Server is at connection capacity\",\"type\":\"server_error\",\"param\":null,\"code\":\"connection_capacity\"}}", app_config, + null, ); connection.stream.close(); continue; @@ -317,6 +392,18 @@ fn handleConnection( session_store: ?*session.SessionStore, session_store_guard: *SessionStoreGuard, ) !void { + const request_started_ms = std.time.milliTimestamp(); + var request_route: []const u8 = "unknown"; + var request_status: u16 = 500; + var request_id_owned: ?[]u8 = null; + defer if (request_id_owned) |id| allocator.free(id); + defer server_state.noteRecentRequest( + request_route, + request_status, + @intCast(@max(std.time.milliTimestamp() - request_started_ms, 0)), + request_id_owned orelse "unknown", + ); + // Translate transport-level parsing failures into OpenAI-style API errors. 
const request_raw = request.readHttpRequest( allocator, @@ -330,46 +417,46 @@ fn handleConnection( return; }, error.RequestTimedOut => { - sendApiErrorSafe(connection.*, allocator, backend.errors.requestTimeoutError(), app_config); + sendApiErrorSafe(connection.*, allocator, backend.errors.requestTimeoutError(), app_config, null); return; }, error.RequestTooLarge => { - sendApiErrorSafe(connection.*, allocator, backend.errors.payloadTooLargeError(), app_config); + sendApiErrorSafe(connection.*, allocator, backend.errors.payloadTooLargeError(), app_config, null); return; }, error.HeadersTooLarge => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "HTTP headers are too large", "headers_too_large", - ), app_config); + ), app_config, null); return; }, error.InvalidHttpRequest => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Malformed HTTP request", "invalid_http_request", - ), app_config); + ), app_config, null); return; }, error.MissingContentLength => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Missing Content-Length header", "missing_content_length", - ), app_config); + ), app_config, null); return; }, error.InvalidContentLength => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Invalid Content-Length header", "invalid_content_length", - ), app_config); + ), app_config, null); return; }, error.IncompleteRequestBody => { sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Incomplete request body", "incomplete_body", - ), app_config); + ), app_config, null); return; }, else => { @@ -379,6 +466,14 @@ fn handleConnection( }; defer allocator.free(request_raw); + request_id_owned = try request.resolveRequestIdAlloc(allocator, request_raw); + const request_id = request_id_owned.?; + + const header_end = std.mem.indexOf(u8, request_raw, "\r\n\r\n") orelse request_raw.len; + if (request.getHeaderValue(request_raw[0..header_end], "User-Agent")) |user_agent| { + 
server_state.noteClient(allocator, user_agent) catch {}; + } + if (request_raw.len == 0) return; server_state.noteRequestStarted(); @@ -392,19 +487,28 @@ fn handleConnection( sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Malformed HTTP request", "invalid_http_request", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; }; if (route == null) { - sendApiErrorSafe(connection.*, allocator, backend.errors.notFoundError(), app_config); + sendApiErrorSafe(connection.*, allocator, backend.errors.notFoundError(), app_config, request_id); server_state.noteRequestFailed(); return; } + request_route = switch (route.?) { + .health => "GET /health", + .metrics => "GET /metrics", + .diagnostics_clients => "GET /diagnostics/clients", + .diagnostics_requests => "GET /diagnostics/requests", + .diagnostics_providers => "GET /diagnostics/providers", + .chat_completions => "POST /v1/chat/completions", + }; + if (requiresAuth(route.?) and auth.authorizeRequest(app_config.auth_api_key, request_raw) == .denied) { - sendJsonSafe(connection.*, 401, "{\"error\":{\"message\":\"Unauthorized\",\"type\":\"auth_error\",\"param\":null,\"code\":\"unauthorized\"}}", app_config); + sendJsonSafe(connection.*, 401, "{\"error\":{\"message\":\"Unauthorized\",\"type\":\"auth_error\",\"param\":null,\"code\":\"unauthorized\"}}", app_config, request_id); server_state.noteRequestFailed(); return; } @@ -417,7 +521,7 @@ fn handleConnection( .{app_config.instance_id}, ); defer allocator.free(health_json); - sendJsonSafe(connection.*, 200, health_json, app_config); + sendJsonSafe(connection.*, 200, health_json, app_config, request_id); server_state.noteRequestSucceeded(); return; }, @@ -439,7 +543,7 @@ fn handleConnection( }, ); defer allocator.free(metrics_json); - sendJsonSafe(connection.*, 200, metrics_json, app_config); + sendJsonSafe(connection.*, 200, metrics_json, app_config, request_id); server_state.noteRequestSucceeded(); return; }, @@ 
-454,25 +558,29 @@ fn handleConnection( .{ app_config.instance_id, snapshot.active_connections, clients_json }, ); defer allocator.free(payload); - sendJsonSafe(connection.*, 200, payload, app_config); + sendJsonSafe(connection.*, 200, payload, app_config, request_id); server_state.noteRequestSucceeded(); return; }, .diagnostics_requests => { const snapshot = server_state.snapshot(); + const recent_json = try server_state.recentRequestsJsonAlloc(allocator); + defer allocator.free(recent_json); const payload = try std.fmt.allocPrint( allocator, - "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d}}}", + "{{\"instance_id\":\"{s}\",\"total_requests\":{d},\"successful_requests\":{d},\"failed_requests\":{d},\"recent_requests\":{s}}}", .{ app_config.instance_id, snapshot.total_requests, snapshot.successful_requests, snapshot.failed_requests, + recent_json, }, ); defer allocator.free(payload); - sendJsonSafe(connection.*, 200, payload, app_config); + sendJsonSafe(connection.*, 200, payload, app_config, null); server_state.noteRequestSucceeded(); + request_status = 200; return; }, .diagnostics_providers => { @@ -486,7 +594,7 @@ fn handleConnection( ); defer allocator.free(payload); - sendJsonSafe(connection.*, 200, payload, app_config); + sendJsonSafe(connection.*, 200, payload, app_config, request_id); server_state.noteRequestSucceeded(); return; }, @@ -497,7 +605,7 @@ fn handleConnection( sendApiErrorSafe(connection.*, allocator, backend.errors.httpError( "Missing request body", "missing_body", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; }; @@ -508,7 +616,7 @@ fn handleConnection( const parsed_req = switch (parse_result) { .ok => |parsed_request| parsed_request, .err => |api_error| { - sendApiErrorSafe(connection.*, allocator, api_error, app_config); + sendApiErrorSafe(connection.*, allocator, api_error, app_config, request_id); server_state.noteRequestFailed(); return; }, @@ 
-520,7 +628,7 @@ fn handleConnection( "One or more requested tools are not registered", "tools", "unknown_tool", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; } @@ -569,7 +677,9 @@ fn handleConnection( } } - if (parsed_req.max_context_tokens) |max_tokens| { + const effective_max_context_tokens = parsed_req.max_context_tokens orelse app_config.default_max_context_tokens; + + if (effective_max_context_tokens) |max_tokens| { const estimated = session.estimateTokenCount(request_messages); if (session.shouldCompressContext(estimated, max_tokens)) { const compacted_messages = try session.compactContextToBudgetAlloc( @@ -619,10 +729,10 @@ fn handleConnection( ); defer stream_request.deinit(allocator); - if (try tooling.tryExecuteDebugTool(allocator, stream_request, app_config)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, stream_request, app_config, request_id)) |tool_result| { defer tool_result.deinit(allocator); - try response.sendEventStreamHeaders(connection.*); + try response.sendEventStreamHeaders(connection.*, request_id); const completion_id = try std.fmt.allocPrint( allocator, "chatcmpl-{d}", @@ -657,18 +767,18 @@ fn handleConnection( "stream=true is currently supported only for ollama", "stream", "unsupported_stream_provider", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; } - const stream_auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config) catch |err| switch (err) { + const stream_auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config, request_id) catch |err| switch (err) { error.ToolCallLimitExceeded => { sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( "Tool call limit exceeded", "tools", "tool_call_limit_exceeded", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; }, @@ -678,7 +788,7 @@ fn handleConnection( 
if (stream_auto_tool_summary) |summary| { if (tooling.shouldShortCircuitAutoTools(stream_request)) { - try response.sendEventStreamHeaders(connection.*); + try response.sendEventStreamHeaders(connection.*, request_id); const completion_id = try std.fmt.allocPrint( allocator, "chatcmpl-{d}", @@ -723,6 +833,7 @@ fn handleConnection( app_config, stream_request, app_config.loop_stream_progress_enabled, + request_id, ) catch |err| { logError("Provider stream loop error: {s}", .{@errorName(err)}); server_state.noteRequestFailed(); @@ -750,22 +861,34 @@ fn handleConnection( sendApiErrorSafe(connection.*, allocator, backend.errors.providerError( provider_error_response.output, "provider_error", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; }, } } - if (try tooling.tryExecuteDebugTool(allocator, parsed_req, app_config)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, parsed_req, app_config, request_id)) |tool_result| { defer tool_result.deinit(allocator); debugLog( app_config, "debug tool executed tool_choice={s}", .{parsed_req.tool_choice orelse "(none)"}, ); - sendChatCompletionSafe(connection.*, allocator, tool_result, app_config); + persistSessionState( + allocator, + app_config, + session_store, + session_store_guard, + parsed_req, + loaded_session, + request_messages, + null, + tool_result.output, + ); + sendChatCompletionSafe(connection.*, allocator, tool_result, app_config, request_id); server_state.noteRequestSucceeded(); + request_status = 200; return; } @@ -780,13 +903,13 @@ fn handleConnection( ); defer provider_request.deinit(allocator); - const auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config) catch |err| switch (err) { + const auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config, request_id) catch |err| switch (err) { error.ToolCallLimitExceeded => { sendApiErrorSafe(connection.*, allocator, 
backend.errors.validationError( "Tool call limit exceeded", "tools", "tool_call_limit_exceeded", - ), app_config); + ), app_config, request_id); server_state.noteRequestFailed(); return; }, @@ -798,8 +921,20 @@ fn handleConnection( if (tooling.shouldShortCircuitAutoTools(provider_request)) { var tool_response = try tooling.makeAutoToolResponse(allocator, summary); defer tool_response.deinit(allocator); - sendChatCompletionSafe(connection.*, allocator, tool_response, app_config); + persistSessionState( + allocator, + app_config, + session_store, + session_store_guard, + parsed_req, + loaded_session, + provider_request.messages, + null, + tool_response.output, + ); + sendChatCompletionSafe(connection.*, allocator, tool_response, app_config, request_id); server_state.noteRequestSucceeded(); + request_status = 200; return; } } @@ -830,13 +965,14 @@ fn handleConnection( const provider_started_ms = std.time.milliTimestamp(); if (loop_enabled) { - const loop_execution = executeLoopRequestAlloc(allocator, app_config, provider_request) catch |err| { + const loop_execution = executeLoopRequestAlloc(allocator, app_config, provider_request, request_id) catch |err| { if (err == error.ToolCallLimitExceeded) { sendApiErrorSafe( connection.*, allocator, backend.errors.validationError("Tool call limit exceeded", "tools", "tool_call_limit_exceeded"), app_config, + request_id, ); server_state.noteRequestFailed(); return; @@ -847,6 +983,7 @@ fn handleConnection( allocator, backend.errors.providerTransportError(@errorName(err)), app_config, + request_id, ); server_state.noteRequestFailed(); return; @@ -863,6 +1000,7 @@ fn handleConnection( allocator, backend.errors.providerTransportError(@errorName(err)), app_config, + request_id, ); server_state.noteRequestFailed(); return; @@ -895,6 +1033,7 @@ fn handleConnection( allocator, backend.errors.providerFailureFromDetail(normalized_provider, result.output), app_config, + request_id, ); server_state.noteRequestFailed(); return; @@ -926,6 
+1065,7 @@ fn handleConnection( "empty_model_response", ), app_config, + request_id, ); server_state.noteRequestFailed(); return; @@ -939,52 +1079,24 @@ fn handleConnection( ); if (session_store) |store| { - if (parsed_req.session_id) |session_id| { - const loop_messages = loop_messages_for_persistence; - const with_assistant = if (loop_messages) |messages| blk: { - break :blk try session.cloneMessagesAlloc(allocator, messages); - } else session.appendAssistantMessageAlloc( + if (parsed_req.session_id != null) { + persistSessionState( allocator, + app_config, + store, + session_store_guard, + parsed_req, + loaded_session, provider_request.messages, + loop_messages_for_persistence, result.output, - ) catch |err| blk: { - logError("Session append failed for '{s}': {s}", .{ session_id, @errorName(err) }); - break :blk null; - }; - - if (with_assistant) |messages| { - defer { - for (messages) |message| { - message.deinit(allocator); - } - allocator.free(messages); - } - - var state = session.SessionState{ - .session_id = try allocator.dupe(u8, session_id), - .tenant_id = if (parsed_req.tenant_id) |value| try allocator.dupe(u8, value) else null, - .summary = if (loaded_session) |loaded| try allocator.dupe(u8, loaded.summary) else try allocator.dupe(u8, ""), - .messages = try session.trimToRetentionAlloc( - allocator, - messages, - app_config.session_retention_messages, - ), - .message_count = 0, - }; - state.message_count = state.messages.len; - defer state.deinit(allocator); - - session_store_guard.mutex.lock(); - store.save(allocator, state) catch |err| { - logError("Session save failed for '{s}': {s}", .{ session_id, @errorName(err) }); - }; - session_store_guard.mutex.unlock(); - } + ); } } - sendChatCompletionSafe(connection.*, allocator, result, app_config); + sendChatCompletionSafe(connection.*, allocator, result, app_config, request_id); server_state.noteRequestSucceeded(); + request_status = 200; } fn cloneRequestWithMessagesAlloc( @@ -1049,6 +1161,7 @@ fn 
executeReactActionForRequest( turn_request: types.Request, tool_messages: []const types.Message, action: react.ReactAction, + request_id: []const u8, ) ![]u8 { return switch (action) { .tool => |tool_action| blk: { @@ -1065,7 +1178,7 @@ fn executeReactActionForRequest( } tool_request.tool_choice = try allocator.dupe(u8, tool_action.name); - if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config, request_id)) |tool_result| { defer tool_result.deinit(allocator); break :blk try allocator.dupe(u8, tool_result.output); } @@ -1086,12 +1199,13 @@ fn streamLoopRequestToSse( app_config: *const config.Config, base_request: types.Request, emit_progress: bool, + request_id: []const u8, ) !void { const loop_mode = base_request.loop_mode orelse "basic"; const loop_until = base_request.loop_until orelse "DONE"; const loop_max_turns = base_request.loop_max_turns orelse 8; - try response.sendEventStreamHeaders(connection); + try response.sendEventStreamHeaders(connection, request_id); var headers_sent: bool = true; var think_block_open: bool = false; @@ -1261,7 +1375,7 @@ fn streamLoopRequestToSse( }, else => { try tooling.noteToolCall(app_config.max_tool_calls_per_request, &react_tool_calls); - const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action); + const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action, request_id); defer allocator.free(obs_raw); react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); }, @@ -1342,6 +1456,7 @@ fn executeLoopRequestAlloc( allocator: std.mem.Allocator, app_config: *const config.Config, base_request: types.Request, + request_id: []const u8, ) !LoopExecution { const loop_mode = base_request.loop_mode orelse "basic"; const loop_until = base_request.loop_until orelse "DONE"; @@ -1480,7 +1595,7 @@ fn 
executeLoopRequestAlloc( }, else => { try tooling.noteToolCall(app_config.max_tool_calls_per_request, &react_tool_calls); - const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action); + const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action, request_id); defer allocator.free(obs_raw); react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); }, @@ -1662,13 +1777,72 @@ fn isLengthFinishReason(finish_reason: []const u8) bool { return std.ascii.eqlIgnoreCase(finish_reason, "length"); } +fn persistSessionState( + allocator: std.mem.Allocator, + app_config: *const config.Config, + session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + parsed_req: types.Request, + loaded_session: ?session.SessionState, + base_messages: []const types.Message, + loop_messages: ?[]types.Message, + assistant_output: []const u8, +) void { + const store = session_store orelse return; + const session_id = parsed_req.session_id orelse return; + + const with_assistant = if (loop_messages) |messages| blk: { + break :blk session.cloneMessagesAlloc(allocator, messages) catch |err| blk2: { + logError("Session append failed for '{s}': {s}", .{ session_id, @errorName(err) }); + break :blk2 null; + }; + } else session.appendAssistantMessageAlloc( + allocator, + base_messages, + assistant_output, + ) catch |err| blk: { + logError("Session append failed for '{s}': {s}", .{ session_id, @errorName(err) }); + break :blk null; + }; + + if (with_assistant) |messages| { + defer { + for (messages) |message| { + message.deinit(allocator); + } + allocator.free(messages); + } + + var state = session.SessionState{ + .session_id = allocator.dupe(u8, session_id) catch return, + .tenant_id = if (parsed_req.tenant_id) |value| allocator.dupe(u8, value) catch return else null, + .summary = if (loaded_session) |loaded| allocator.dupe(u8, loaded.summary) catch return 
else allocator.dupe(u8, "") catch return, + .messages = session.trimToRetentionAlloc( + allocator, + messages, + app_config.session_retention_messages, + ) catch return, + .message_count = 0, + }; + state.message_count = state.messages.len; + defer state.deinit(allocator); + + session_store_guard.mutex.lock(); + store.save(allocator, state) catch |err| { + logError("Session save failed for '{s}': {s}", .{ session_id, @errorName(err) }); + }; + session_store_guard.mutex.unlock(); + } +} + fn sendApiErrorSafe( connection: std.net.Server.Connection, allocator: std.mem.Allocator, api_error: backend.errors.ApiError, app_config: *const config.Config, + request_id: ?[]const u8, ) void { - response.sendApiError(connection, allocator, api_error) catch |err| { + response.sendApiError(connection, allocator, api_error, request_id) catch |err| { swallowSocketWriteError(app_config, err); }; } @@ -1678,10 +1852,17 @@ fn sendJsonSafe( status_code: u16, body: []const u8, app_config: *const config.Config, + request_id: ?[]const u8, ) void { - response.sendJsonText(connection, status_code, body) catch |err| { - swallowSocketWriteError(app_config, err); - }; + if (request_id) |id| { + response.sendJsonTextWithRequestId(connection, status_code, body, id) catch |err| { + swallowSocketWriteError(app_config, err); + }; + } else { + response.sendJsonText(connection, status_code, body) catch |err| { + swallowSocketWriteError(app_config, err); + }; + } } fn sendChatCompletionSafe( @@ -1689,8 +1870,9 @@ fn sendChatCompletionSafe( allocator: std.mem.Allocator, result: types.Response, app_config: *const config.Config, + request_id: ?[]const u8, ) void { - response.sendChatCompletion(connection, allocator, result) catch |err| { + response.sendChatCompletion(connection, allocator, result, request_id) catch |err| { swallowSocketWriteError(app_config, err); }; } diff --git a/src/providers/ollama_qwen.zig b/src/providers/ollama_qwen.zig index 8eee0b3..3497422 100644 --- 
a/src/providers/ollama_qwen.zig +++ b/src/providers/ollama_qwen.zig @@ -692,7 +692,7 @@ fn streamChatMessagesToSse( } if (!headers_sent.*) { - try response.sendEventStreamHeaders(connection); + try response.sendEventStreamHeaders(connection, null); headers_sent.* = true; } diff --git a/src/react.zig b/src/react.zig index fd7bc11..cbb156d 100644 --- a/src/react.zig +++ b/src/react.zig @@ -40,19 +40,18 @@ pub fn buildSystemPromptAlloc(allocator: std.mem.Allocator, tools: []const types try out.appendSlice(allocator, system_prompt); try out.appendSlice( allocator, - "\nRequested API tools are also valid Actions when named exactly:\n", + "\nRequested API tools (names only; invoke with exact name in brackets):\n", ); - for (tools) |tool| { - try out.writer(allocator).print( - " {s}[argument] - {s}\n", - .{ tool.name, tool.description }, - ); + for (tools, 0..) |tool, idx| { + if (idx > 0) try out.append(allocator, ','); + try out.appendSlice(allocator, " "); + try out.appendSlice(allocator, tool.name); } try out.appendSlice( allocator, - "\nTool rules:\n" ++ + "\n\nTool rules:\n" ++ "- For file_write, include the path and fenced code block inside the brackets, for example:\n" ++ " Action N: file_write[write relative/path\n```lang\n...\n```]\n" ++ "- Put the full tool input inside the brackets.\n", @@ -261,7 +260,7 @@ fn executeReactCmd( defer req.deinit(allocator); const shell: command_exec_tool.ShellFlavor = if (@import("builtin").os.tag == .windows) .cmd else .bash; - var result = try command_exec_tool.execute(allocator, app_config, req, shell); + var result = try command_exec_tool.execute(allocator, app_config, req, shell, "react-local"); defer result.deinit(allocator); // Dupe the output to transfer ownership to the caller; result.deinit @@ -423,7 +422,8 @@ test "buildSystemPromptAlloc includes requested tools" { const prompt = try buildSystemPromptAlloc(allocator, tools[0..]); defer allocator.free(prompt); - try std.testing.expect(std.mem.indexOf(u8, prompt, 
"file_write[argument]") != null); + try std.testing.expect(std.mem.indexOf(u8, prompt, "file_write") != null); + try std.testing.expect(std.mem.indexOf(u8, prompt, "names only") != null); try std.testing.expect(std.mem.indexOf(u8, prompt, "file_write[write relative/path") != null); } diff --git a/src/root.zig b/src/root.zig index 9f0215b..8374287 100644 --- a/src/root.zig +++ b/src/root.zig @@ -27,6 +27,7 @@ test "backend: auth and tool scaffolding" { _ = @import("backend/tools.zig"); _ = @import("backend/session.zig"); _ = @import("backend/mcp.zig"); + _ = @import("backend/tool_output.zig"); } test "core: HTTP request parsing" { diff --git a/src/tools/command_exec.zig b/src/tools/command_exec.zig index c74aaf3..09de654 100644 --- a/src/tools/command_exec.zig +++ b/src/tools/command_exec.zig @@ -1,6 +1,7 @@ const std = @import("std"); const config = @import("../config.zig"); const types = @import("../types.zig"); +const tool_output = @import("../backend/tool_output.zig"); pub const ShellFlavor = enum { cmd, @@ -69,6 +70,7 @@ pub fn execute( app_config: *const config.Config, request: types.Request, flavor: ShellFlavor, + request_id: []const u8, ) !types.Response { const tool_name = switch (flavor) { .cmd => "cmd", @@ -112,7 +114,7 @@ pub fn execute( const command = blk: { if (looksLikeCommandPrompt(prompt)) { - break :blk validateAndNormalizeCommandAlloc(allocator, prompt, flavor) catch |err| switch (err) { + break :blk validateAndNormalizeCommandAlloc(allocator, prompt, flavor, app_config.tool_exec_trusted_local) catch |err| switch (err) { error.OutOfMemory => return err, else => { return try makeToolResponse( @@ -128,7 +130,7 @@ pub fn execute( }; } - const extracted = try extractCommandFromPromptAlloc(allocator, flavor, prompt); + const extracted = try extractCommandFromPromptAlloc(allocator, flavor, prompt, app_config.tool_exec_trusted_local); if (extracted) |value| { break :blk value; } @@ -242,10 +244,10 @@ pub fn execute( std.Thread.sleep(2 * 
std.time.ns_per_ms); } - const terminated_as: std.process.Child.Term = if (timed_out) - child.kill() catch .{ .Unknown = 1 } - else - child.wait() catch .{ .Unknown = 1 }; + if (timed_out) { + _ = child.kill() catch {}; + } + const terminated_as = child.wait() catch std.process.Child.Term{ .Unknown = 1 }; if (stdout_thread) |thread| thread.join(); if (stderr_thread) |thread| thread.join(); @@ -293,8 +295,17 @@ pub fn execute( stderr_text, }, ); + defer allocator.free(output); + + const final_output = try tool_output.maybeOffloadToolOutputAlloc( + allocator, + app_config, + request_id, + tool_name, + output, + ); - return try makeToolResponse(allocator, tool_name, output); + return try makeToolResponse(allocator, tool_name, final_output); } fn capturePipeMain(capture: *PipeCapture, file: std.fs.File, max_bytes: usize) void { @@ -310,6 +321,7 @@ pub fn extractCommandFromPromptAlloc( allocator: std.mem.Allocator, flavor: ShellFlavor, prompt: []const u8, + trusted_local: bool, ) !?[]u8 { const without_confirm = stripToolConfirmMarkerSuffix(prompt); const trimmed = std.mem.trim(u8, without_confirm, " \t\r\n\"'"); @@ -318,26 +330,26 @@ pub fn extractCommandFromPromptAlloc( if (extractQuotedCommandSlice(trimmed)) |quoted| { const normalized = std.mem.trim(u8, quoted, " \t\r\n"); if (normalized.len > 0) { - return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor); + return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor, trusted_local); } } if (stripCommandInstructionPrefix(trimmed)) |candidate| { const normalized = normalizeExtractedCandidate(candidate); if (normalized.len > 0) { - return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor); + return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor, trusted_local); } } if (extractInlineRunSegment(trimmed)) |candidate| { const normalized = normalizeExtractedCandidate(candidate); if (normalized.len > 0 and looksLikeCommandPrompt(normalized)) { - return 
try validateAndNormalizeCommandAlloc(allocator, normalized, flavor); + return try validateAndNormalizeCommandAlloc(allocator, normalized, flavor, trusted_local); } } if (looksLikeCommandPrompt(trimmed)) { - return try validateAndNormalizeCommandAlloc(allocator, trimmed, flavor); + return try validateAndNormalizeCommandAlloc(allocator, trimmed, flavor, trusted_local); } return null; @@ -347,11 +359,12 @@ fn validateAndNormalizeCommandAlloc( allocator: std.mem.Allocator, command: []const u8, flavor: ShellFlavor, + trusted_local: bool, ) ![]u8 { const trimmed = std.mem.trim(u8, command, " \t\r\n\"'"); if (trimmed.len == 0) return error.EmptyCommand; if (trimmed.len > max_command_len) return error.CommandTooLong; - if (containsUnsafeCommandByte(trimmed)) return error.UnsafeCharacter; + if (containsUnsafeCommandByte(trimmed, trusted_local)) return error.UnsafeCharacter; const command_name = firstToken(trimmed) orelse return error.EmptyCommand; if (isDangerousCommandName(flavor, command_name)) return error.DangerousCommand; @@ -384,10 +397,10 @@ fn looksLikeCommandPrompt(text: []const u8) bool { return true; } -fn containsUnsafeCommandByte(text: []const u8) bool { +fn containsUnsafeCommandByte(text: []const u8, trusted_local: bool) bool { for (text) |byte| { if (byte < 0x20 or byte == 0x7f) return true; - if (isUnsafeShellByte(byte)) return true; + if (!trusted_local and isUnsafeShellByte(byte)) return true; } return false; } @@ -668,9 +681,12 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. 
.session_retention_messages = 24, .tool_exec_enabled = enable_exec, .tool_exec_confirmation_required = false, + .tool_exec_trusted_local = false, .tool_exec_timeout_ms = 15_000, .tool_exec_max_output_bytes = 65_536, + .tool_output_offload_bytes = 8192, .max_tool_calls_per_request = 8, + .default_max_context_tokens = null, .loop_stream_progress_enabled = true, .max_concurrent_connections = 64, .max_request_bytes = 1024 * 1024, @@ -704,7 +720,7 @@ test "execute cmd tool returns disabled marker when not enabled" { }; defer req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); @@ -718,6 +734,7 @@ test "extractCommandFromPromptAlloc extracts from run prefix" { allocator, .cmd, "Please run zig build test", + false, ); defer if (maybe_cmd) |cmd| allocator.free(cmd); @@ -731,6 +748,7 @@ test "extractCommandFromPromptAlloc ignores natural language requests" { allocator, .cmd, "Please write a small C++ program", + false, ); defer if (maybe_cmd) |cmd| allocator.free(cmd); @@ -763,7 +781,7 @@ test "execute cmd tool rejects unsafe chaining" { }; defer req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); @@ -796,7 +814,7 @@ test "execute cmd tool blocks dangerous denylist command" { }; defer req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); @@ -830,7 +848,7 @@ test "execute cmd tool enforces timeout and kills process" { }; defer 
req.deinit(allocator); - var result = try execute(allocator, &cfg, req, .cmd); + var result = try execute(allocator, &cfg, req, .cmd, "test-request"); defer result.deinit(allocator); try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_ERROR") != null); From 1a22adcd6eb0b0d2a1b8761be0ec200aa1d5fc50 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 16/25] Add default model and workspace mode config Load LLM_ROUTER_MODEL and workspace env flags. Accept OpenAI and OpenRouter base URL alias names. --- src/config.zig | 120 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 114 insertions(+), 6 deletions(-) diff --git a/src/config.zig b/src/config.zig index ef4c60f..d2f83ca 100644 --- a/src/config.zig +++ b/src/config.zig @@ -13,6 +13,7 @@ pub const Config = struct { listen_port: u16, debug_logging: bool, default_provider: []const u8, + default_model: ?[]const u8, request_timeout_ms: u32, provider_timeout_ms: u32, instance_id: []const u8, @@ -53,6 +54,8 @@ pub const Config = struct { tool_output_offload_bytes: usize, max_tool_calls_per_request: usize, default_max_context_tokens: ?usize, + workspace_mode_enabled: bool, + workspace_root: []const u8, loop_stream_progress_enabled: bool, max_concurrent_connections: usize, max_request_bytes: usize, @@ -72,9 +75,9 @@ pub const Config = struct { "ollama", env_overrides, ); - errdefer allocator.free(requested_default_provider); const normalized_default_provider = types.normalizeProviderName(requested_default_provider) orelse { + allocator.free(requested_default_provider); return error.InvalidProvider; }; @@ -82,6 +85,18 @@ pub const Config = struct { allocator.free(requested_default_provider); errdefer allocator.free(default_provider); + const default_model_raw = try getEnvOrDefault( + allocator, + "LLM_ROUTER_MODEL", + "", + env_overrides, + ); + const default_model: ?[]const u8 = if (default_model_raw.len == 0) blk: { + allocator.free(default_model_raw); 
+ break :blk null; + } else default_model_raw; + errdefer if (default_model) |model| allocator.free(model); + return Config{ .listen_host = try getEnvOrDefault( allocator, @@ -92,6 +107,7 @@ pub const Config = struct { .listen_port = try getEnvPortOrDefault("LLM_ROUTER_PORT", 8081, env_overrides), .debug_logging = try getEnvFlag(allocator, "LLM_ROUTER_DEBUG", env_overrides), .default_provider = default_provider, + .default_model = default_model, .request_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_REQUEST_TIMEOUT_MS", 30_000, env_overrides), .provider_timeout_ms = try getEnvPositiveU32OrDefault("LLM_ROUTER_PROVIDER_TIMEOUT_MS", 60_000, env_overrides), .instance_id = try getEnvOrDefault( @@ -122,9 +138,9 @@ pub const Config = struct { .ollama_num_predict = try getEnvU32OrDefault("OLLAMA_NUM_PREDICT", 2048, env_overrides), .ollama_temperature = try getEnvF64OrDefault(allocator, "OLLAMA_TEMPERATURE", 0.7, env_overrides), .ollama_repeat_penalty = try getEnvF64OrDefault(allocator, "OLLAMA_REPEAT_PENALTY", 1.05, env_overrides), - .openai_base_url = try getEnvOrDefault( + .openai_base_url = try getFirstEnvOrDefault( allocator, - "OPENAI_BASE_URL", + &.{ "OPENAI_BASE_URL", "OPENAI_API_BASE" }, "https://api.openai.com/v1", env_overrides, ), @@ -140,9 +156,9 @@ pub const Config = struct { "gpt-4.1-mini", env_overrides, ), - .openrouter_base_url = try getEnvOrDefault( + .openrouter_base_url = try getFirstEnvOrDefault( allocator, - "OPENROUTER_BASE_URL", + &.{ "OPENROUTER_BASE_URL", "OPENROUTER_API_BASE" }, "https://openrouter.ai/api/v1", env_overrides, ), @@ -266,9 +282,20 @@ pub const Config = struct { .tool_output_offload_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_TOOL_OUTPUT_OFFLOAD_BYTES", 8192, env_overrides), .max_tool_calls_per_request = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST", 8, env_overrides), .default_max_context_tokens = blk: { - const raw = getEnvPositiveUsizeOrDefault("LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS", 
0, env_overrides) catch 0; + const raw = try getEnvUsizeOrDefault( + "LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS", + 0, + env_overrides, + ); break :blk if (raw == 0) null else raw; }, + .workspace_mode_enabled = try getEnvFlag(allocator, "LLM_ROUTER_WORKSPACE_MODE", env_overrides), + .workspace_root = try getEnvOrDefault( + allocator, + "LLM_ROUTER_WORKSPACE_ROOT", + ".", + env_overrides, + ), .loop_stream_progress_enabled = try getEnvFlagOrDefault(allocator, "LLM_ROUTER_LOOP_STREAM_PROGRESS_ENABLED", true, env_overrides), .max_concurrent_connections = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS", 64, env_overrides), .max_request_bytes = try getEnvPositiveUsizeOrDefault("LLM_ROUTER_MAX_REQUEST_BYTES", 1024 * 1024, env_overrides), @@ -284,6 +311,7 @@ pub const Config = struct { pub fn deinit(self: *Config, allocator: std.mem.Allocator) void { allocator.free(self.listen_host); allocator.free(self.default_provider); + if (self.default_model) |model| allocator.free(model); allocator.free(self.instance_id); allocator.free(self.auth_api_key); allocator.free(self.ollama_base_url); @@ -309,6 +337,7 @@ pub const Config = struct { allocator.free(self.llama_cpp_api_key); allocator.free(self.llama_cpp_model); allocator.free(self.session_store_path); + allocator.free(self.workspace_root); } pub fn setDefaultProvider( @@ -324,6 +353,40 @@ pub const Config = struct { allocator.free(self.default_provider); self.default_provider = next_provider; } + + pub fn setDefaultModel( + self: *Config, + allocator: std.mem.Allocator, + model: []const u8, + ) !void { + if (self.default_model) |existing| allocator.free(existing); + self.default_model = if (model.len == 0) null else try allocator.dupe(u8, model); + } + + /// Model used when a request omits `model` (or sends `"auto"`). + pub fn defaultModel(self: *const Config) []const u8 { + return self.modelForProvider(self.default_provider); + } + + /// Resolves the configured model for a provider. 
`LLM_ROUTER_MODEL` overrides + /// per-provider defaults when the provider matches `LLM_ROUTER_PROVIDER`. + pub fn modelForProvider(self: *const Config, provider: []const u8) []const u8 { + const normalized = types.normalizeProviderName(provider) orelse provider; + if (self.default_model) |unified| { + if (std.mem.eql(u8, normalized, self.default_provider)) { + return unified; + } + } + + if (std.mem.eql(u8, normalized, "ollama_qwen")) return self.ollama_model; + if (std.mem.eql(u8, normalized, "openai")) return self.openai_model; + if (std.mem.eql(u8, normalized, "openrouter")) return self.openrouter_model; + if (std.mem.eql(u8, normalized, "claude")) return self.claude_model; + if (std.mem.eql(u8, normalized, "bedrock")) return self.bedrock_model; + if (std.mem.eql(u8, normalized, "llama_cpp")) return self.llama_cpp_model; + + return self.openai_model; + } }; fn validateProviderName(provider: []const u8) !void { @@ -404,6 +467,23 @@ fn getEnvU32OrDefault( }; } +fn getEnvUsizeOrDefault( + comptime key: []const u8, + default_value: usize, + env_overrides: ?*const EnvOverrides, +) !usize { + if (env_overrides) |overrides| { + if (overrides.get(key)) |value| { + return try std.fmt.parseInt(usize, value, 10); + } + } + + return std.process.parseEnvVarInt(key, usize, 10) catch |err| switch (err) { + error.EnvironmentVariableNotFound => default_value, + else => err, + }; +} + fn getEnvPositiveU32OrDefault( comptime key: []const u8, default_value: u32, @@ -514,6 +594,7 @@ test "setDefaultProvider stores canonical provider alias" { .listen_port = 8081, .debug_logging = false, .default_provider = try allocator.dupe(u8, "ollama_qwen"), + .default_model = null, .request_timeout_ms = 30_000, .provider_timeout_ms = 60_000, .instance_id = try allocator.dupe(u8, "local-instance"), @@ -554,6 +635,8 @@ test "setDefaultProvider stores canonical provider alias" { .tool_output_offload_bytes = 8192, .max_tool_calls_per_request = 8, .default_max_context_tokens = null, + 
.workspace_mode_enabled = false, + .workspace_root = try allocator.dupe(u8, "."), .loop_stream_progress_enabled = true, .max_concurrent_connections = 64, .max_request_bytes = 1024 * 1024, @@ -570,3 +653,28 @@ test "setDefaultProvider stores canonical provider alias" { try cfg.setDefaultProvider(allocator, "bedrock"); try std.testing.expectEqualStrings("bedrock", cfg.default_provider); } + +test "modelForProvider honors LLM_ROUTER_MODEL for active provider only" { + const allocator = std.testing.allocator; + + var overrides = EnvOverrides.init(allocator); + try overrides.put(try allocator.dupe(u8, "LLM_ROUTER_PROVIDER"), try allocator.dupe(u8, "openrouter")); + try overrides.put(try allocator.dupe(u8, "LLM_ROUTER_MODEL"), try allocator.dupe(u8, "openai/gpt-4o-mini")); + try overrides.put(try allocator.dupe(u8, "OPENROUTER_MODEL"), try allocator.dupe(u8, "openrouter/auto")); + try overrides.put(try allocator.dupe(u8, "OLLAMA_MODEL"), try allocator.dupe(u8, "qwen3.5:9b")); + defer { + var it = overrides.iterator(); + while (it.next()) |entry| { + allocator.free(entry.key_ptr.*); + allocator.free(entry.value_ptr.*); + } + overrides.deinit(); + } + + var cfg = try Config.loadWithOverrides(allocator, &overrides); + defer cfg.deinit(allocator); + + try std.testing.expectEqualStrings("openai/gpt-4o-mini", cfg.defaultModel()); + try std.testing.expectEqualStrings("openai/gpt-4o-mini", cfg.modelForProvider("openrouter")); + try std.testing.expectEqualStrings("qwen3.5:9b", cfg.modelForProvider("ollama")); +} From fb1325449317703a1181a20477a7546b64fece27 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 17/25] Add workspace_id to chat request model Parse workspace_id on chat completions requests. Extend shared request types for workspace routing. 
--- src/backend/api.zig | 24 ++++++++++++++++++++++++ src/types.zig | 4 ++++ 2 files changed, 28 insertions(+) diff --git a/src/backend/api.zig b/src/backend/api.zig index 8d13ef5..2bbeb6f 100644 --- a/src/backend/api.zig +++ b/src/backend/api.zig @@ -385,6 +385,29 @@ pub fn parseChatRequest( null; errdefer if (tenant_id) |value| allocator.free(value); + const workspace_id = if (obj.get("workspace_id")) |workspace_id_value| + switch (workspace_id_value) { + .string => |value| blk: { + const trimmed = std.mem.trim(u8, value, " \t\r\n"); + if (trimmed.len == 0) { + return .{ .err = errors.validationError( + "workspace_id must not be empty", + "workspace_id", + "invalid_workspace_id", + ) }; + } + break :blk try allocator.dupe(u8, trimmed); + }, + else => return .{ .err = errors.validationError( + "workspace_id must be a string", + "workspace_id", + "invalid_workspace_id", + ) }, + } + else + null; + errdefer if (workspace_id) |value| allocator.free(value); + const max_context_tokens = if (obj.get("max_context_tokens")) |tokens_value| switch (tokens_value) { .integer => |value| blk: { @@ -683,6 +706,7 @@ pub fn parseChatRequest( .repeat_penalty = repeat_penalty, .session_id = session_id, .tenant_id = tenant_id, + .workspace_id = workspace_id, .max_context_tokens = max_context_tokens, .tools = parsed_tools, .tool_choice = tool_choice, diff --git a/src/types.zig b/src/types.zig index 7b3eacf..3f3b9c8 100644 --- a/src/types.zig +++ b/src/types.zig @@ -15,6 +15,7 @@ pub const Request = struct { repeat_penalty: ?f64 = null, session_id: ?[]const u8 = null, tenant_id: ?[]const u8 = null, + workspace_id: ?[]const u8 = null, max_context_tokens: ?usize = null, tools: []Tool, tool_choice: ?[]const u8 = null, @@ -45,6 +46,9 @@ pub const Request = struct { if (self.tenant_id) |tenant_id| { allocator.free(tenant_id); } + if (self.workspace_id) |workspace_id| { + allocator.free(workspace_id); + } for (self.tools) |tool| { tool.deinit(allocator); } From 
286ade1bae2de815dd7ac9a3a5218d313e48be8a Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 18/25] Add ephemeral workspace memory store Keep per-workspace messages and file hints in RAM. Clear workspace state when the server exits. --- src/backend/workspace.zig | 230 ++++++++++++++++++++++++++++++++++++++ src/root.zig | 1 + 2 files changed, 231 insertions(+) create mode 100644 src/backend/workspace.zig diff --git a/src/backend/workspace.zig b/src/backend/workspace.zig new file mode 100644 index 0000000..501b75d --- /dev/null +++ b/src/backend/workspace.zig @@ -0,0 +1,230 @@ +//! Lightweight ephemeral workspace memory (Cursor-like, cleared on server exit). +//! +//! Stores conversation turns and a small file-activity hint keyed by `workspace_id`. +//! Does not replace the core harness loop; it only substitutes disk `session_id` with +//! in-process state and enables repo-root file tools when workspace mode is on. +const std = @import("std"); +const types = @import("../types.zig"); +const session = @import("session.zig"); + +pub const max_file_hints = 16; + +pub const WorkspaceState = struct { + workspace_id: []const u8, + messages: []types.Message, + file_hints: [][]const u8, + + pub fn deinit(self: *WorkspaceState, allocator: std.mem.Allocator) void { + allocator.free(self.workspace_id); + for (self.messages) |message| { + message.deinit(allocator); + } + allocator.free(self.messages); + for (self.file_hints) |hint| { + allocator.free(hint); + } + allocator.free(self.file_hints); + } +}; + +pub const WorkspaceStore = struct { + mutex: std.Thread.Mutex = .{}, + workspaces: std.StringHashMap(*WorkspaceState), + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) WorkspaceStore { + return .{ + .workspaces = std.StringHashMap(*WorkspaceState).init(allocator), + .allocator = allocator, + }; + } + + pub fn deinit(self: *WorkspaceStore) void { + self.mutex.lock(); + defer self.mutex.unlock(); + + var it = 
self.workspaces.iterator(); + while (it.next()) |entry| { + entry.value_ptr.*.deinit(self.allocator); + self.allocator.destroy(entry.value_ptr.*); + } + self.workspaces.deinit(); + } + + pub fn count(self: *WorkspaceStore) usize { + self.mutex.lock(); + defer self.mutex.unlock(); + return self.workspaces.count(); + } + + pub fn getOrCreate(self: *WorkspaceStore, workspace_id: []const u8) !*WorkspaceState { + self.mutex.lock(); + defer self.mutex.unlock(); + + if (self.workspaces.get(workspace_id)) |existing| { + return existing; + } + + const owned_id = try self.allocator.dupe(u8, workspace_id); + errdefer self.allocator.free(owned_id); + + const state = try self.allocator.create(WorkspaceState); + errdefer self.allocator.destroy(state); + + state.* = .{ + .workspace_id = owned_id, + .messages = try self.allocator.alloc(types.Message, 0), + .file_hints = &.{}, + }; + + try self.workspaces.put(owned_id, state); + return state; + } + + pub fn saveMessages( + self: *WorkspaceStore, + workspace: *WorkspaceState, + messages: []types.Message, + ) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + + for (workspace.messages) |message| { + message.deinit(self.allocator); + } + self.allocator.free(workspace.messages); + workspace.messages = messages; + } +}; + +pub fn mergeIncomingMessagesAlloc( + allocator: std.mem.Allocator, + existing: []const types.Message, + incoming: []const types.Message, +) ![]types.Message { + if (existing.len == 0) { + return try session.cloneMessagesAlloc(allocator, incoming); + } + if (incoming.len == 0) { + return try session.cloneMessagesAlloc(allocator, existing); + } + + const last_incoming = incoming[incoming.len - 1]; + const last_existing = existing[existing.len - 1]; + if (std.mem.eql(u8, last_existing.role, last_incoming.role) and + std.mem.eql(u8, last_existing.content, last_incoming.content)) + { + return try session.cloneMessagesAlloc(allocator, existing); + } + + return try session.mergeMessagesAlloc(allocator, existing, 
incoming); +} + +pub fn recordFileHint( + workspace: *WorkspaceState, + allocator: std.mem.Allocator, + operation: []const u8, + path: []const u8, +) !void { + const hint = try std.fmt.allocPrint(allocator, "{s} {s}", .{ operation, path }); + errdefer allocator.free(hint); + + if (workspace.file_hints.len >= max_file_hints) { + allocator.free(workspace.file_hints[0]); + for (workspace.file_hints[1..], 0..) |entry, idx| { + workspace.file_hints[idx] = entry; + } + workspace.file_hints[workspace.file_hints.len - 1] = hint; + return; + } + + const next = try allocator.realloc(workspace.file_hints, workspace.file_hints.len + 1); + next[next.len - 1] = hint; + workspace.file_hints = next; +} + +pub fn buildMemoryHintMessageAlloc( + allocator: std.mem.Allocator, + workspace: *const WorkspaceState, +) !?types.Message { + if (workspace.file_hints.len == 0) return null; + + var out = std.ArrayList(u8){}; + defer out.deinit(allocator); + + try out.appendSlice(allocator, "Workspace memory (cleared on server exit):\n"); + for (workspace.file_hints) |hint| { + try out.writer(allocator).print("- {s}\n", .{hint}); + } + + return types.Message{ + .role = try allocator.dupe(u8, "system"), + .content = try out.toOwnedSlice(allocator), + }; +} + +pub fn prependMemoryHintAlloc( + allocator: std.mem.Allocator, + workspace: *const WorkspaceState, + messages: []const types.Message, +) ![]types.Message { + const hint = try buildMemoryHintMessageAlloc(allocator, workspace); + if (hint == null) { + return try session.cloneMessagesAlloc(allocator, messages); + } + defer hint.?.deinit(allocator); + + var out = std.ArrayList(types.Message){}; + errdefer { + for (out.items) |message| { + message.deinit(allocator); + } + out.deinit(allocator); + } + + try out.append(allocator, .{ + .role = try allocator.dupe(u8, hint.?.role), + .content = try allocator.dupe(u8, hint.?.content), + }); + + for (messages) |message| { + try out.append(allocator, .{ + .role = try allocator.dupe(u8, message.role), + 
.content = try allocator.dupe(u8, message.content), + }); + } + + return try out.toOwnedSlice(allocator); +} + +test "WorkspaceStore is ephemeral and keyed by id" { + const allocator = std.testing.allocator; + var store = WorkspaceStore.init(allocator); + defer store.deinit(); + + const ws = try store.getOrCreate("dev"); + const incoming = [_]types.Message{.{ .role = "user", .content = "hi" }}; + const merged = try mergeIncomingMessagesAlloc(allocator, ws.messages, incoming[0..]); + try store.saveMessages(ws, merged); + + try std.testing.expectEqual(@as(usize, 1), ws.messages.len); + try std.testing.expectEqual(@as(usize, 1), store.count()); +} + +test "recordFileHint keeps bounded file activity" { + const allocator = std.testing.allocator; + var state = WorkspaceState{ + .workspace_id = try allocator.dupe(u8, "dev"), + .messages = try allocator.alloc(types.Message, 0), + .file_hints = &.{}, + }; + defer state.deinit(allocator); + + try recordFileHint(&state, allocator, "read", "src/main.zig"); + try std.testing.expectEqual(@as(usize, 1), state.file_hints.len); + + const hint = try buildMemoryHintMessageAlloc(allocator, &state); + try std.testing.expect(hint != null); + defer hint.?.deinit(allocator); + try std.testing.expect(std.mem.indexOf(u8, hint.?.content, "src/main.zig") != null); +} diff --git a/src/root.zig b/src/root.zig index 8374287..b87c2bf 100644 --- a/src/root.zig +++ b/src/root.zig @@ -28,6 +28,7 @@ test "backend: auth and tool scaffolding" { _ = @import("backend/session.zig"); _ = @import("backend/mcp.zig"); _ = @import("backend/tool_output.zig"); + _ = @import("backend/workspace.zig"); } test "core: HTTP request parsing" { From 22de5e8752ac9c10b2967a8563f2cc07edeb2cdf Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 19/25] Support repo-root file tools in workspace mode Resolve file tools under workspace root when enabled. Pass workspace context through debug tool execution. 
--- src/backend/tools.zig | 28 ++++++++------ src/tools/file_ops.zig | 88 ++++++++++++++++++++++++++++++++---------- 2 files changed, 83 insertions(+), 33 deletions(-) diff --git a/src/backend/tools.zig b/src/backend/tools.zig index 4b875ce..5080716 100644 --- a/src/backend/tools.zig +++ b/src/backend/tools.zig @@ -6,6 +6,7 @@ const utc_tool = @import("../tools/utc.zig"); const command_exec_tool = @import("../tools/command_exec.zig"); const file_ops_tool = @import("../tools/file_ops.zig"); const tool_output = @import("tool_output.zig"); +const workspace = @import("workspace.zig"); pub const ToolRegistry = struct { allowed_names: std.StringHashMap(void), @@ -57,6 +58,7 @@ pub fn tryExecuteDebugTool( request: types.Request, app_config: *const config.Config, request_id: []const u8, + active_workspace: ?*workspace.WorkspaceState, ) !?types.Response { const choice = request.tool_choice orelse return null; @@ -88,17 +90,17 @@ pub fn tryExecuteDebugTool( if (std.ascii.eqlIgnoreCase(choice, "file_read")) { if (!hasRequestedTool(request.tools, "file_read")) return null; - return try file_ops_tool.execute(allocator, app_config, request, .read); + return try file_ops_tool.execute(allocator, app_config, request, .read, active_workspace); } if (std.ascii.eqlIgnoreCase(choice, "file_write")) { if (!hasRequestedTool(request.tools, "file_write")) return null; - return try file_ops_tool.execute(allocator, app_config, request, .write); + return try file_ops_tool.execute(allocator, app_config, request, .write, active_workspace); } if (std.ascii.eqlIgnoreCase(choice, "file_search")) { if (!hasRequestedTool(request.tools, "file_search")) return null; - return try file_ops_tool.execute(allocator, app_config, request, .search); + return try file_ops_tool.execute(allocator, app_config, request, .search, active_workspace); } return null; @@ -109,6 +111,7 @@ pub fn maybeExecutePromptToolsAlloc( request: types.Request, app_config: *const config.Config, request_id: []const u8, + 
active_workspace: ?*workspace.WorkspaceState, ) !?[]u8 { if (request.tools.len == 0) return null; @@ -143,7 +146,7 @@ pub fn maybeExecutePromptToolsAlloc( if (file_read_requested and mentionsFileReadIntent(prompt)) { try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const read_result = try file_ops_tool.execute(allocator, app_config, request, .read); + const read_result = try file_ops_tool.execute(allocator, app_config, request, .read, active_workspace); defer read_result.deinit(allocator); try output.writer(allocator).print("\n[tool=file_read]\n{s}\n", .{read_result.output}); executed_any = true; @@ -151,7 +154,7 @@ pub fn maybeExecutePromptToolsAlloc( if (file_search_requested and mentionsFileSearchIntent(prompt)) { try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const search_result = try file_ops_tool.execute(allocator, app_config, request, .search); + const search_result = try file_ops_tool.execute(allocator, app_config, request, .search, active_workspace); defer search_result.deinit(allocator); try output.writer(allocator).print("\n[tool=file_search]\n{s}\n", .{search_result.output}); executed_any = true; @@ -177,7 +180,7 @@ pub fn maybeExecutePromptToolsAlloc( defer write_req.deinit(allocator); try noteToolCall(app_config.max_tool_calls_per_request, &tool_calls); - const write_result = try file_ops_tool.execute(allocator, app_config, write_req, .write); + const write_result = try file_ops_tool.execute(allocator, app_config, write_req, .write, active_workspace); defer write_result.deinit(allocator); try output.writer(allocator).print("\n[tool=file_write]\n{s}\n", .{write_result.output}); @@ -276,7 +279,7 @@ pub fn makeAutoToolResponse( }; } -fn hasRequestedTool(tools: []const types.Tool, name: []const u8) bool { +pub fn hasRequestedTool(tools: []const types.Tool, name: []const u8) bool { for (tools) |tool| { if (std.ascii.eqlIgnoreCase(tool.name, name)) return true; } @@ -348,6 +351,7 @@ fn 
buildSinglePromptRequestAlloc( .model = null, .session_id = null, .tenant_id = null, + .workspace_id = null, .max_context_tokens = null, .tools = try allocator.alloc(types.Tool, 0), .tool_choice = null, @@ -385,7 +389,7 @@ test "tryExecuteDebugTool executes echo when explicitly selected" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request"); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request", null); try std.testing.expect(maybe_result != null); var result = maybe_result.?; @@ -427,7 +431,7 @@ test "tryExecuteDebugTool executes utc when explicitly selected" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request"); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request", null); try std.testing.expect(maybe_result != null); var result = maybe_result.?; @@ -469,7 +473,7 @@ test "maybeExecutePromptToolsAlloc runs utc from prompt intent" { var cfg = try command_exec_tool.buildTestConfig(allocator, false); defer cfg.deinit(allocator); - const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request"); + const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request", null); defer if (summary) |value| allocator.free(value); try std.testing.expect(summary != null); @@ -509,7 +513,7 @@ test "maybeExecutePromptToolsAlloc extracts command for cmd" { var cfg = try command_exec_tool.buildTestConfig(allocator, true); defer cfg.deinit(allocator); - const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request"); + const summary = try maybeExecutePromptToolsAlloc(allocator, req, &cfg, "test-request", null); defer if (summary) |value| allocator.free(value); try std.testing.expect(summary != null); @@ -548,7 
+552,7 @@ test "tryExecuteDebugTool ignores natural language cmd selection" { var cfg = try command_exec_tool.buildTestConfig(allocator, true); defer cfg.deinit(allocator); - const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request"); + const maybe_result = try tryExecuteDebugTool(allocator, req, &cfg, "test-request", null); try std.testing.expect(maybe_result == null); } diff --git a/src/tools/file_ops.zig b/src/tools/file_ops.zig index 0aaf5c2..27dba01 100644 --- a/src/tools/file_ops.zig +++ b/src/tools/file_ops.zig @@ -1,6 +1,7 @@ const std = @import("std"); const config = @import("../config.zig"); const types = @import("../types.zig"); +const workspace = @import("../backend/workspace.zig"); pub const Op = enum { read, write, search }; @@ -9,6 +10,7 @@ pub fn execute( app_config: *const config.Config, request: types.Request, op: Op, + active_workspace: ?*workspace.WorkspaceState, ) !types.Response { const tool_name: []const u8 = switch (op) { .read => "file_read", @@ -26,9 +28,9 @@ pub fn execute( } return switch (op) { - .read => executeRead(allocator, app_config, tool_name, prompt), - .write => executeWrite(allocator, app_config, tool_name, request, prompt), - .search => executeSearch(allocator, app_config, tool_name, prompt), + .read => executeRead(allocator, app_config, tool_name, prompt, active_workspace), + .write => executeWrite(allocator, app_config, tool_name, request, prompt, active_workspace), + .search => executeSearch(allocator, app_config, tool_name, prompt, active_workspace), }; } @@ -37,6 +39,7 @@ fn executeRead( app_config: *const config.Config, tool_name: []const u8, prompt: []const u8, + active_workspace: ?*workspace.WorkspaceState, ) !types.Response { const path_raw = parseSingleArgAfterPrefix(prompt, "read ") orelse parseSingleArgAfterPrefix(prompt, "file_read ") orelse @@ -57,15 +60,25 @@ fn executeRead( ); }; - const tmp_rel_path = try prefixTmpPathAlloc(allocator, rel_path); - defer allocator.free(tmp_rel_path); + 
const resolved_path = resolveToolPathAlloc(allocator, app_config, rel_path) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint( + allocator, + "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", + .{ tool_name, @errorName(err) }, + ), + ); + }; + defer allocator.free(resolved_path); const start_ms = std.time.milliTimestamp(); - var file = std.fs.cwd().openFile(tmp_rel_path, .{}) catch |err| { + var file = std.fs.cwd().openFile(resolved_path, .{}) catch |err| { return try makeToolResponse( allocator, tool_name, - try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), ); }; defer file.close(); @@ -74,7 +87,7 @@ fn executeRead( return try makeToolResponse( allocator, tool_name, - try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nread_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nread_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), ); }; defer allocator.free(bytes); @@ -82,10 +95,14 @@ fn executeRead( const end_ms = std.time.milliTimestamp(); const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + if (active_workspace) |ws| { + workspace.recordFileHint(ws, allocator, "read", rel_path) catch {}; + } + const out = try std.fmt.allocPrint( allocator, "DEBUG_TOOL_OK\ntool={s}\npath={s}\nduration_ms={d}\nmax_bytes={d}\n--- content ---\n{s}", - .{ tool_name, tmp_rel_path, duration_ms, app_config.tool_exec_max_output_bytes, bytes }, + .{ tool_name, resolved_path, duration_ms, app_config.tool_exec_max_output_bytes, bytes }, ); return try makeToolResponse(allocator, tool_name, out); } @@ -96,6 +113,7 @@ fn executeWrite( tool_name: 
[]const u8, request: types.Request, prompt: []const u8, + active_workspace: ?*workspace.WorkspaceState, ) !types.Response { const parsed = parseWritePrompt(allocator, prompt) catch blk: { const inferred_path = inferWriteFilenameFromPromptAlloc(allocator, prompt) orelse break :blk null; @@ -133,8 +151,14 @@ fn executeWrite( ); }; - const tmp_rel_path = try prefixTmpPathAlloc(allocator, rel_path); - defer allocator.free(tmp_rel_path); + const resolved_path = resolveToolPathAlloc(allocator, app_config, rel_path) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid path", .{ tool_name, @errorName(err) }), + ); + }; + defer allocator.free(resolved_path); if (parsed.content.len > app_config.tool_exec_max_output_bytes) { return try makeToolResponse( @@ -150,7 +174,7 @@ fn executeWrite( const start_ms = std.time.milliTimestamp(); - if (std.fs.path.dirname(tmp_rel_path)) |dir_name| { + if (std.fs.path.dirname(resolved_path)) |dir_name| { std.fs.cwd().makePath(dir_name) catch |err| { return try makeToolResponse( allocator, @@ -160,11 +184,11 @@ fn executeWrite( }; } - var file = std.fs.cwd().createFile(tmp_rel_path, .{ .truncate = true }) catch |err| { + var file = std.fs.cwd().createFile(resolved_path, .{ .truncate = true }) catch |err| { return try makeToolResponse( allocator, tool_name, - try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\ncreate_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\ncreate_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), ); }; defer file.close(); @@ -173,17 +197,21 @@ fn executeWrite( return try makeToolResponse( allocator, tool_name, - try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nwrite_error={s}\npath={s}", .{ tool_name, @errorName(err), tmp_rel_path }), + try std.fmt.allocPrint(allocator, 
"DEBUG_TOOL_ERROR\ntool={s}\nwrite_error={s}\npath={s}", .{ tool_name, @errorName(err), resolved_path }), ); }; const end_ms = std.time.milliTimestamp(); const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; + if (active_workspace) |ws| { + workspace.recordFileHint(ws, allocator, "write", rel_path) catch {}; + } + const out = try std.fmt.allocPrint( allocator, "DEBUG_TOOL_OK\ntool={s}\npath={s}\nbytes_written={d}\nduration_ms={d}\n--- content ---\n{s}", - .{ tool_name, tmp_rel_path, parsed.content.len, duration_ms, parsed.content }, + .{ tool_name, resolved_path, parsed.content.len, duration_ms, parsed.content }, ); return try makeToolResponse(allocator, tool_name, out); } @@ -193,6 +221,7 @@ fn executeSearch( app_config: *const config.Config, tool_name: []const u8, prompt: []const u8, + active_workspace: ?*workspace.WorkspaceState, ) !types.Response { const parsed = parseSearchPrompt(allocator, prompt) catch |err| { return try makeToolResponse( @@ -216,8 +245,14 @@ fn executeSearch( ); }; - const tmp_rel_dir = try prefixTmpPathAlloc(allocator, rel_dir); - defer allocator.free(tmp_rel_dir); + const resolved_dir = resolveToolPathAlloc(allocator, app_config, rel_dir) catch |err| { + return try makeToolResponse( + allocator, + tool_name, + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nvalidation_error={s}\nmessage=invalid dir", .{ tool_name, @errorName(err) }), + ); + }; + defer allocator.free(resolved_dir); const start_ms = std.time.milliTimestamp(); @@ -226,7 +261,7 @@ fn executeSearch( try out.writer(allocator).print( "DEBUG_TOOL_OK\ntool={s}\ndir={s}\npattern={s}\nmax_bytes={d}\n--- matches ---\n", - .{ tool_name, tmp_rel_dir, parsed.pattern, app_config.tool_exec_max_output_bytes }, + .{ tool_name, resolved_dir, parsed.pattern, app_config.tool_exec_max_output_bytes }, ); var bytes_budget: usize = app_config.tool_exec_max_output_bytes; @@ -236,11 +271,11 @@ fn executeSearch( var files_scanned: usize = 0; 
const max_files: usize = 500; - var dir = std.fs.cwd().openDir(tmp_rel_dir, .{ .iterate = true }) catch |err| { + var dir = std.fs.cwd().openDir(resolved_dir, .{ .iterate = true }) catch |err| { return try makeToolResponse( allocator, tool_name, - try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_dir_error={s}\ndir={s}", .{ tool_name, @errorName(err), tmp_rel_dir }), + try std.fmt.allocPrint(allocator, "DEBUG_TOOL_ERROR\ntool={s}\nopen_dir_error={s}\ndir={s}", .{ tool_name, @errorName(err), resolved_dir }), ); }; defer dir.close(); @@ -278,6 +313,10 @@ fn executeSearch( const duration_ms: u64 = if (end_ms >= start_ms) @as(u64, @intCast(end_ms - start_ms)) else 0; try out.writer(allocator).print("\nfiles_scanned={d}\nmatches={d}\nduration_ms={d}\n", .{ files_scanned, matches, duration_ms }); + if (active_workspace) |ws| { + workspace.recordFileHint(ws, allocator, "search", rel_dir) catch {}; + } + return try makeToolResponse(allocator, tool_name, try out.toOwnedSlice(allocator)); } @@ -407,6 +446,13 @@ fn parseSingleArgAfterPrefix(text: []const u8, prefix: []const u8) ?[]const u8 { return std.mem.trim(u8, text[prefix.len..], " \t\r\n\"'"); } +fn resolveToolPathAlloc(allocator: std.mem.Allocator, app_config: *const config.Config, rel_path: []const u8) ![]u8 { + if (app_config.workspace_mode_enabled) { + return std.fs.path.join(allocator, &.{ app_config.workspace_root, rel_path }); + } + return prefixTmpPathAlloc(allocator, rel_path); +} + fn prefixTmpPathAlloc(allocator: std.mem.Allocator, rel_path: []const u8) ![]u8 { // All tool file IO is sandboxed under `tmp/` so generated artifacts don't // spill into the repo root. From c2d97d0575a42ffe5b849937aad72fee6a861bb9 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 20/25] Improve ReAct loop for small coding models Tune prompts, parsing, and CLI defaults for multi-step loops. Map Search and Lookup to file tools when available. 
--- src/main.zig | 52 ++++++++++++-- src/react.zig | 188 +++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 217 insertions(+), 23 deletions(-) diff --git a/src/main.zig b/src/main.zig index 972d36a..ff3bec7 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,5 +1,6 @@ //! CLI entrypoint for running the router in server mode or prompt-loop mode. const std = @import("std"); +const builtin = @import("builtin"); const root = @import("root.zig"); const react = @import("react.zig"); @@ -13,6 +14,7 @@ const LoopMode = enum { const CliOptions = struct { prompt: ?[]u8 = null, // Enables prompt-loop mode when set. provider_override: ?[]u8 = null, // Optional provider alias from CLI. + model_override: ?[]u8 = null, // Optional default model override from CLI. env_file_path: ?[]u8 = null, // Optional dotenv path loaded before config. use_env: bool = false, // Enable dotenv loading. loop_mode: LoopMode = .basic, // Iteration style for prompt loop mode. @@ -27,6 +29,7 @@ const CliOptions = struct { pub fn deinit(self: *CliOptions, allocator: std.mem.Allocator) void { if (self.prompt) |prompt| allocator.free(prompt); if (self.provider_override) |provider| allocator.free(provider); + if (self.model_override) |model| allocator.free(model); if (self.env_file_path) |env_file_path| allocator.free(env_file_path); allocator.free(self.until_marker); } @@ -37,10 +40,18 @@ const CliOptions = struct { /// Errors: /// - allocator, configuration, CLI parsing, and provider/runtime errors. 
pub fn main() !void { - var gpa = std.heap.DebugAllocator(.{}){}; - defer _ = gpa.deinit(); - const allocator = gpa.allocator(); + if (comptime builtin.mode == .Debug) { + var gpa = std.heap.DebugAllocator(.{}){}; + defer _ = gpa.deinit(); + try runMain(gpa.allocator()); + } else { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + try runMain(gpa.allocator()); + } +} +fn runMain(allocator: std.mem.Allocator) !void { var cli = try parseCliOptions(allocator); defer cli.deinit(allocator); @@ -62,6 +73,10 @@ pub fn main() !void { try app_config.setDefaultProvider(allocator, provider); } + if (cli.model_override) |model| { + try app_config.setDefaultModel(allocator, model); + } + if (cli.prompt) |prompt| { try runPromptLoop(allocator, &app_config, prompt, cli.until_marker, cli.max_turns, cli.loop_mode); return; @@ -90,12 +105,14 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { if (std.mem.eql(u8, arg, "--react")) { cli.loop_mode = .react; + applyReactLoopDefaults(&cli); continue; } if (std.mem.eql(u8, arg, "--loop-mode")) { const value = args.next() orelse return error.MissingLoopModeValue; cli.loop_mode = try parseLoopMode(value); + applyReactLoopDefaults(&cli); continue; } @@ -120,6 +137,14 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { continue; } + if (std.mem.eql(u8, arg, "--model")) { + const value = args.next() orelse return error.MissingModelValue; + if (value.len == 0) return error.InvalidModelValue; + if (cli.model_override) |model| allocator.free(model); + cli.model_override = try allocator.dupe(u8, value); + continue; + } + if (std.mem.eql(u8, arg, "--prompt")) { const value = args.next() orelse return error.MissingPromptValue; if (value.len == 0) return error.InvalidPromptValue; @@ -153,6 +178,15 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { continue; } + const model_prefix = "--model="; + if (std.mem.startsWith(u8, arg, model_prefix)) { + const value = 
arg[model_prefix.len..]; + if (value.len == 0) return error.InvalidModelValue; + if (cli.model_override) |model| allocator.free(model); + cli.model_override = try allocator.dupe(u8, value); + continue; + } + const prompt_prefix = "--prompt="; if (std.mem.startsWith(u8, arg, prompt_prefix)) { const value = arg[prompt_prefix.len..]; @@ -195,6 +229,7 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { if (std.mem.startsWith(u8, arg, loop_mode_prefix)) { const value = arg[loop_mode_prefix.len..]; cli.loop_mode = try parseLoopMode(value); + applyReactLoopDefaults(&cli); continue; } } @@ -202,6 +237,12 @@ fn parseCliOptions(allocator: std.mem.Allocator) !CliOptions { return cli; } +fn applyReactLoopDefaults(cli: *CliOptions) void { + if (cli.loop_mode == .react and cli.max_turns == 8) { + cli.max_turns = 24; + } +} + fn parseLoopMode(value: []const u8) !LoopMode { if (std.ascii.eqlIgnoreCase(value, "basic")) return .basic; if (std.ascii.eqlIgnoreCase(value, "agent")) return .agent; @@ -362,7 +403,8 @@ fn runPromptLoop( } previous_output = try allocator.dupe(u8, normalized_output); - if ((loop_mode == .agent or loop_mode == .react) and repeated_count >= 1) { + const repeated_limit: usize = if (loop_mode == .react) 2 else 1; + if ((loop_mode == .agent or loop_mode == .react) and repeated_count >= repeated_limit) { std.debug.print("Loop stopped early: model repeated output without progress.\n", .{}); return; } @@ -400,7 +442,7 @@ fn runPromptLoop( const hint = try react.formatObservation( allocator, turn + 1, - "Could not parse an Action from your response. 
Please use the format: Action N: Type[argument]", + react.parse_error_hint, ); defer allocator.free(hint); diff --git a/src/react.zig b/src/react.zig index cbb156d..028088e 100644 --- a/src/react.zig +++ b/src/react.zig @@ -12,24 +12,29 @@ const command_exec_tool = @import("tools/command_exec.zig"); /// System prompt injected at the start of every ReAct loop to instruct the /// model on the required Thought/Action/Observation format. pub const system_prompt = - "You are running in ReAct (Reasoning + Acting) mode.\n" ++ - "You MUST follow this format strictly for every step:\n\n" ++ - "Thought N: \n" ++ - "Action N: \n\n" ++ - "After each Action the system will inject:\n" ++ - "Observation N: \n\n" ++ + "You are running in ReAct (Reasoning + Acting) mode for multi-step problem solving.\n" ++ + "Smaller models should solve tasks incrementally: one Thought and one Action per turn, " ++ + "wait for the Observation, then continue.\n\n" ++ + "Required format every turn:\n\n" ++ + "Thought N: \n" ++ + "Action N: \n\n" ++ + "The harness injects after each Action:\n" ++ + "Observation N: \n\n" ++ + "Example:\n" ++ + "Thought 1: I need to inspect the entry file before editing.\n" ++ + "Action 1: file_read[src/main.zig]\n\n" ++ "Available actions:\n" ++ - " Search[query] - search for information\n" ++ - " Lookup[term] - look up a term in the current context\n" ++ + " Search[query] - search the repo (maps to file_search when that tool is enabled)\n" ++ + " Lookup[path] - read a file path (maps to file_read when that tool is enabled)\n" ++ " Cmd[command] - execute a shell command (may require confirmation)\n" ++ " Finish[answer] - return the final answer and end the loop\n\n" ++ "Rules:\n" ++ - "- Always start with a Thought, then an Action.\n" ++ - "- Never produce an Observation yourself; the system provides them.\n" ++ - "- Use Finish[answer] when you have the final answer.\n" ++ - "- Be concrete: use specific names, values, steps, and short outputs instead of vague 
summaries.\n" ++ - "- If you need information, ask for the exact term or command that will resolve it.\n" ++ - "- If an action fails, adjust your approach in the next Thought.\n"; + "- Emit exactly ONE Action per turn; do not plan ahead with Action 2+ before the Observation.\n" ++ + "- Never write Observation lines yourself.\n" ++ + "- Use Finish[answer] only when the task is done.\n" ++ + "- Prefer small steps: read/search, then edit, then run tests via Cmd.\n" ++ + "- If parsing fails, repeat Thought/Action using the exact Action N: Type[argument] format.\n" ++ + "- If an action fails, change approach in the next Thought instead of repeating the same action.\n"; pub fn buildSystemPromptAlloc(allocator: std.mem.Allocator, tools: []const types.Tool) ![]u8 { if (tools.len == 0) return try allocator.dupe(u8, system_prompt); @@ -51,15 +56,84 @@ pub fn buildSystemPromptAlloc(allocator: std.mem.Allocator, tools: []const types try out.appendSlice( allocator, - "\n\nTool rules:\n" ++ - "- For file_write, include the path and fenced code block inside the brackets, for example:\n" ++ - " Action N: file_write[write relative/path\n```lang\n...\n```]\n" ++ - "- Put the full tool input inside the brackets.\n", + "\n\nCoding tool rules:\n" ++ + "- file_read[path]: read before editing; path is relative to the project root.\n" ++ + "- file_search[search pattern in directory]: find symbols/files before changing code.\n" ++ + "- file_write[write relative/path\n```lang\n...\n```]: put the full write payload in brackets.\n" ++ + "- Read[path] and Write[...] are aliases for file_read and file_write.\n" ++ + "- Run Cmd[zig build test] (or similar) after edits to verify.\n", ); return try out.toOwnedSlice(allocator); } +const DefaultCodingTool = struct { + name: []const u8, + description: []const u8, +}; + +/// Ensures ReAct loop requests expose the standard coding tool set server-side so +/// clients do not need to attach tools manually. 
Existing client tools are kept; +/// missing defaults are appended. Sets `tool_choice` to `auto` when unset. +pub fn ensureReactCodingToolsAlloc(allocator: std.mem.Allocator, request: *types.Request) !void { + const loop_mode = request.loop_mode orelse return; + if (!std.ascii.eqlIgnoreCase(loop_mode, "react")) return; + + const shell_tool: []const u8 = if (@import("builtin").os.tag == .windows) "cmd" else "bash"; + const shell_description = if (@import("builtin").os.tag == .windows) + "Execute a Windows shell command (requires LLM_ROUTER_TOOL_EXEC_ENABLED=1)." + else + "Execute a bash shell command (requires LLM_ROUTER_TOOL_EXEC_ENABLED=1)."; + + const defaults = [_]DefaultCodingTool{ + .{ .name = "file_read", .description = "Read a relative file path and return its content." }, + .{ .name = "file_write", .description = "Write or replace a relative file path." }, + .{ .name = "file_search", .description = "Search for a substring under a relative directory." }, + .{ .name = shell_tool, .description = shell_description }, + }; + + var merged = std.ArrayList(types.Tool){}; + errdefer { + for (merged.items) |tool| { + tool.deinit(allocator); + } + merged.deinit(allocator); + } + + for (request.tools) |tool| { + try merged.append(allocator, .{ + .name = try allocator.dupe(u8, tool.name), + .description = try allocator.dupe(u8, tool.description), + }); + } + + for (defaults) |default_tool| { + if (hasToolName(merged.items, default_tool.name)) continue; + try merged.append(allocator, .{ + .name = try allocator.dupe(u8, default_tool.name), + .description = try allocator.dupe(u8, default_tool.description), + }); + } + + for (request.tools) |tool| { + tool.deinit(allocator); + } + allocator.free(request.tools); + + request.tools = try merged.toOwnedSlice(allocator); + + if (request.tool_choice == null) { + request.tool_choice = try allocator.dupe(u8, "auto"); + } +} + +fn hasToolName(tools: []const types.Tool, name: []const u8) bool { + for (tools) |tool| { + if 
(std.ascii.eqlIgnoreCase(tool.name, name)) return true; + } + return false; +} + /// Parsed action extracted from a model response. pub const ToolAction = struct { name: []const u8, @@ -160,6 +234,12 @@ pub fn parseReactAction(output: []const u8) ?ReactAction { if (eqlIgnoreCase(action_type, "lookup")) return .{ .lookup = argument }; if (eqlIgnoreCase(action_type, "cmd")) return .{ .cmd = argument }; if (eqlIgnoreCase(action_type, "finish")) return .{ .finish = argument }; + if (eqlIgnoreCase(action_type, "read")) { + return .{ .tool = .{ .name = "file_read", .argument = argument } }; + } + if (eqlIgnoreCase(action_type, "write")) { + return .{ .tool = .{ .name = "file_write", .argument = argument } }; + } if (isKnownApiToolAction(action_type)) { return .{ .tool = .{ .name = action_type, .argument = argument } }; } @@ -207,6 +287,10 @@ pub fn executeReactAction( }; } +pub const parse_error_hint = + "Could not parse an Action. Reply with exactly one pair: Thought N: ... then Action N: Type[argument]. " ++ + "Example: Action 1: file_read[src/main.zig]"; + /// Formats an observation message with turn numbering. pub fn formatObservation( allocator: std.mem.Allocator, @@ -400,6 +484,18 @@ test "parseReactAction returns unknown for unrecognized action type" { } } +test "parseReactAction maps Read alias to file_read" { + const action = parseReactAction("Action 1: Read[src/main.zig]"); + try std.testing.expect(action != null); + switch (action.?) 
{ + .tool => |tool| { + try std.testing.expectEqualStrings("file_read", tool.name); + try std.testing.expectEqualStrings("src/main.zig", tool.argument); + }, + else => return error.UnexpectedActionType, + } +} + test "parseReactAction extracts requested API tool action" { const action = parseReactAction("Thought 1: write the file\nAction 1: file_write[write app.cpp\n```cpp\nint main() { return 0; }\n```]"); try std.testing.expect(action != null); @@ -413,6 +509,62 @@ test "parseReactAction extracts requested API tool action" { } } +test "ensureReactCodingToolsAlloc merges defaults and sets auto tool_choice" { + const allocator = std.testing.allocator; + + var request = types.Request{ + .prompt = try allocator.dupe(u8, "fix the bug"), + .messages = try allocator.alloc(types.Message, 0), + .tools = try allocator.alloc(types.Tool, 0), + .loop_mode = try allocator.dupe(u8, "react"), + }; + defer request.deinit(allocator); + + try ensureReactCodingToolsAlloc(allocator, &request); + try std.testing.expect(request.tools.len >= 4); + try std.testing.expect(hasToolName(request.tools, "file_read")); + try std.testing.expect(hasToolName(request.tools, "file_write")); + try std.testing.expect(hasToolName(request.tools, "file_search")); + try std.testing.expect(request.tool_choice != null); + try std.testing.expectEqualStrings("auto", request.tool_choice.?); +} + +test "ensureReactCodingToolsAlloc preserves client tools" { + const allocator = std.testing.allocator; + + var request = types.Request{ + .prompt = try allocator.dupe(u8, "echo"), + .messages = try allocator.alloc(types.Message, 0), + .tools = try allocator.alloc(types.Tool, 1), + .loop_mode = try allocator.dupe(u8, "react"), + }; + request.tools[0] = .{ + .name = try allocator.dupe(u8, "echo"), + .description = try allocator.dupe(u8, "echo tool"), + }; + defer request.deinit(allocator); + + try ensureReactCodingToolsAlloc(allocator, &request); + try std.testing.expect(hasToolName(request.tools, "echo")); + try 
std.testing.expect(hasToolName(request.tools, "file_read")); +} + +test "ensureReactCodingToolsAlloc no-op outside react mode" { + const allocator = std.testing.allocator; + + var request = types.Request{ + .prompt = try allocator.dupe(u8, "hello"), + .messages = try allocator.alloc(types.Message, 0), + .tools = try allocator.alloc(types.Tool, 0), + .loop_mode = try allocator.dupe(u8, "agent"), + }; + defer request.deinit(allocator); + + try ensureReactCodingToolsAlloc(allocator, &request); + try std.testing.expectEqual(@as(usize, 0), request.tools.len); + try std.testing.expect(request.tool_choice == null); +} + test "buildSystemPromptAlloc includes requested tools" { const allocator = std.testing.allocator; const tools = [_]types.Tool{ From 22985ebd1acdf214062eae7f66c12abf06584665 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 21/25] Wire workspace mode and ReAct into server Integrate workspace merge, compaction, and loop execution. Auto-attach coding tools for react loop_mode on the server. 
--- src/core/server.zig | 507 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 394 insertions(+), 113 deletions(-) diff --git a/src/core/server.zig b/src/core/server.zig index 56922e7..14e7a4a 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -9,6 +9,7 @@ const backend = @import("../backend/api.zig"); const auth = @import("../backend/auth.zig"); const tooling = @import("../backend/tools.zig"); const session = @import("../backend/session.zig"); +const workspace = @import("../backend/workspace.zig"); const react = @import("../react.zig"); const context_compaction_keep_messages = 8; @@ -235,6 +236,7 @@ const WorkerContext = struct { tool_registry: *tooling.ToolRegistry, session_store: ?*session.SessionStore, session_store_guard: *SessionStoreGuard, + workspace_store: ?*workspace.WorkspaceStore, gate: *ConnectionGate, }; @@ -263,6 +265,7 @@ pub fn run( .{ app_config.listen_host, app_config.listen_port }, ); logInfo("Default provider: {s}", .{app_config.default_provider}); + logInfo("Default model: {s}", .{app_config.defaultModel()}); logInfo( "Available providers: ollama (aliases: qwen, ollama_qwen), openai, openrouter, claude (alias: anthropic), bedrock, llama_cpp (alias: llama.cpp)", .{}, @@ -329,7 +332,18 @@ pub fn run( } defer if (session_store) |*store| store.deinit(allocator); + var workspace_store_impl: ?workspace.WorkspaceStore = null; + if (app_config.workspace_mode_enabled) { + workspace_store_impl = workspace.WorkspaceStore.init(allocator); + logInfo( + "Workspace mode enabled root={s} (in-memory; cleared on exit)", + .{app_config.workspace_root}, + ); + } + defer if (workspace_store_impl) |*store| store.deinit(); + const active_session_store: ?*session.SessionStore = if (session_store) |*store| store else null; + const active_workspace_store: ?*workspace.WorkspaceStore = if (workspace_store_impl) |*store| store else null; var gate = ConnectionGate{ .max_active = app_config.max_concurrent_connections }; var session_store_guard = 
SessionStoreGuard{}; @@ -340,6 +354,7 @@ pub fn run( .tool_registry = &tool_registry, .session_store = active_session_store, .session_store_guard = &session_store_guard, + .workspace_store = active_workspace_store, .gate = &gate, }; @@ -377,6 +392,7 @@ fn connectionWorkerMain(context: WorkerContext, connection: std.net.Server.Conne context.tool_registry, context.session_store, context.session_store_guard, + context.workspace_store, ) catch |err| { logError("Request handling error: {s}", .{@errorName(err)}); context.server_state.noteRequestFailed(); @@ -391,6 +407,7 @@ fn handleConnection( tool_registry: *tooling.ToolRegistry, session_store: ?*session.SessionStore, session_store_guard: *SessionStoreGuard, + workspace_store: ?*workspace.WorkspaceStore, ) !void { const request_started_ms = std.time.milliTimestamp(); var request_route: []const u8 = "unknown"; @@ -613,7 +630,7 @@ fn handleConnection( debugLog(app_config, "request body_len={d}", .{body.len}); const parse_result = try backend.parseChatRequest(allocator, body); - const parsed_req = switch (parse_result) { + var parsed_req = switch (parse_result) { .ok => |parsed_request| parsed_request, .err => |api_error| { sendApiErrorSafe(connection.*, allocator, api_error, app_config, request_id); @@ -623,6 +640,19 @@ fn handleConnection( }; defer parsed_req.deinit(allocator); + react.ensureReactCodingToolsAlloc(allocator, &parsed_req) catch |err| { + logError("ReAct tool augmentation failed: {s}", .{@errorName(err)}); + sendApiErrorSafe( + connection.*, + allocator, + backend.errors.httpError("Failed to prepare ReAct coding tools", "react_tool_setup_failed"), + app_config, + request_id, + ); + server_state.noteRequestFailed(); + return; + }; + if (!tooling.validateRequestedTools(tool_registry, parsed_req.tools)) { sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( "One or more requested tools are not registered", @@ -636,6 +666,9 @@ fn handleConnection( var loaded_session: ?session.SessionState 
= null; defer if (loaded_session) |state| state.deinit(allocator); + var active_workspace: ?*workspace.WorkspaceState = null; + const use_workspace = app_config.workspace_mode_enabled and parsed_req.workspace_id != null; + var request_messages = try session.cloneMessagesAlloc(allocator, parsed_req.messages); defer { for (request_messages) |message| { @@ -647,7 +680,37 @@ fn handleConnection( var request_prompt = try allocator.dupe(u8, parsed_req.prompt); defer allocator.free(request_prompt); - if (session_store) |store| { + if (use_workspace) { + if (workspace_store) |store| { + active_workspace = try store.getOrCreate(parsed_req.workspace_id.?); + if (active_workspace.?.messages.len > 0) { + const merged_messages = try workspace.mergeIncomingMessagesAlloc( + allocator, + active_workspace.?.messages, + parsed_req.messages, + ); + + for (request_messages) |message| { + message.deinit(allocator); + } + allocator.free(request_messages); + request_messages = merged_messages; + + allocator.free(request_prompt); + request_prompt = try extractLastUserPromptAlloc(allocator, request_messages); + } + + const with_hint = try workspace.prependMemoryHintAlloc(allocator, active_workspace.?, request_messages); + for (request_messages) |message| { + message.deinit(allocator); + } + allocator.free(request_messages); + request_messages = with_hint; + + allocator.free(request_prompt); + request_prompt = try extractLastUserPromptAlloc(allocator, request_messages); + } + } else if (session_store) |store| { if (parsed_req.session_id) |session_id| { session_store_guard.mutex.lock(); loaded_session = store.load(allocator, session_id, parsed_req.tenant_id) catch |err| blk: { @@ -729,7 +792,7 @@ fn handleConnection( ); defer stream_request.deinit(allocator); - if (try tooling.tryExecuteDebugTool(allocator, stream_request, app_config, request_id)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, stream_request, app_config, request_id, active_workspace)) |tool_result| { defer 
tool_result.deinit(allocator); try response.sendEventStreamHeaders(connection.*, request_id); @@ -772,7 +835,7 @@ fn handleConnection( return; } - const stream_auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config, request_id) catch |err| switch (err) { + const stream_auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, stream_request, app_config, request_id, active_workspace) catch |err| switch (err) { error.ToolCallLimitExceeded => { sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( "Tool call limit exceeded", @@ -868,16 +931,19 @@ fn handleConnection( } } - if (try tooling.tryExecuteDebugTool(allocator, parsed_req, app_config, request_id)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, parsed_req, app_config, request_id, active_workspace)) |tool_result| { defer tool_result.deinit(allocator); debugLog( app_config, "debug tool executed tool_choice={s}", .{parsed_req.tool_choice orelse "(none)"}, ); - persistSessionState( + persistConversationState( allocator, app_config, + use_workspace, + workspace_store, + active_workspace, session_store, session_store_guard, parsed_req, @@ -903,7 +969,7 @@ fn handleConnection( ); defer provider_request.deinit(allocator); - const auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config, request_id) catch |err| switch (err) { + const auto_tool_summary = tooling.maybeExecutePromptToolsAlloc(allocator, provider_request, app_config, request_id, active_workspace) catch |err| switch (err) { error.ToolCallLimitExceeded => { sendApiErrorSafe(connection.*, allocator, backend.errors.validationError( "Tool call limit exceeded", @@ -921,9 +987,12 @@ fn handleConnection( if (tooling.shouldShortCircuitAutoTools(provider_request)) { var tool_response = try tooling.makeAutoToolResponse(allocator, summary); defer tool_response.deinit(allocator); - persistSessionState( + persistConversationState( allocator, 
app_config, + use_workspace, + workspace_store, + active_workspace, session_store, session_store_guard, parsed_req, @@ -1079,10 +1148,13 @@ fn handleConnection( ); if (session_store) |store| { - if (parsed_req.session_id != null) { - persistSessionState( + if (!use_workspace and parsed_req.session_id != null) { + persistConversationState( allocator, app_config, + false, + workspace_store, + active_workspace, store, session_store_guard, parsed_req, @@ -1094,6 +1166,23 @@ fn handleConnection( } } + if (use_workspace) { + persistConversationState( + allocator, + app_config, + true, + workspace_store, + active_workspace, + session_store, + session_store_guard, + parsed_req, + loaded_session, + provider_request.messages, + loop_messages_for_persistence, + result.output, + ); + } + sendChatCompletionSafe(connection.*, allocator, result, app_config, request_id); server_state.noteRequestSucceeded(); request_status = 200; @@ -1141,6 +1230,7 @@ fn cloneRequestWithMessagesAlloc( .repeat_penalty = parsed_req.repeat_penalty, .session_id = if (parsed_req.session_id) |session_id| try allocator.dupe(u8, session_id) else null, .tenant_id = if (parsed_req.tenant_id) |tenant_id| try allocator.dupe(u8, tenant_id) else null, + .workspace_id = if (parsed_req.workspace_id) |workspace_id| try allocator.dupe(u8, workspace_id) else null, .max_context_tokens = parsed_req.max_context_tokens, .tools = copied_tools, .tool_choice = if (parsed_req.tool_choice) |tool_choice| try allocator.dupe(u8, tool_choice) else null, @@ -1155,6 +1245,160 @@ const LoopExecution = struct { messages: []types.Message, }; +const agent_loop_guidance = + "You are running in API agent loop mode. Improve the answer each turn. Briefly self-critique then improve. Include the completion marker exactly when fully complete."; + +const agent_continue_prompt = + "Critique your previous answer briefly, then improve it with concrete next steps. 
If complete, include the completion marker exactly and return the final result."; + +const react_loop_default_max_turns: usize = 24; +const react_repeated_stop_threshold: usize = 2; + +fn resolveLoopMaxTurns(loop_mode: []const u8, explicit: ?usize) usize { + if (explicit) |value| return value; + if (std.ascii.eqlIgnoreCase(loop_mode, "react")) return react_loop_default_max_turns; + return 8; +} + +fn reactToolCallBudget(app_config: *const config.Config, loop_max_turns: usize) usize { + return @max(app_config.max_tool_calls_per_request, loop_max_turns); +} + +fn maybeCompactLoopWorkingMessages( + allocator: std.mem.Allocator, + working_messages: *[]types.Message, + max_context_tokens: ?usize, +) !void { + const max_tokens = max_context_tokens orelse return; + const estimated = session.estimateTokenCount(working_messages.*); + if (!session.shouldCompressContext(estimated, max_tokens)) return; + + const compacted = try session.compactContextToBudgetAlloc( + allocator, + working_messages.*, + max_tokens, + context_compaction_keep_messages, + ); + + for (working_messages.*) |message| { + message.deinit(allocator); + } + allocator.free(working_messages.*); + working_messages.* = compacted; +} + +fn prepareLoopWorkingMessagesAlloc( + allocator: std.mem.Allocator, + base_request: types.Request, + loop_mode: []const u8, +) ![]types.Message { + var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); + errdefer { + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + } + + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) { + const with_guidance = try appendRoleMessageAlloc( + allocator, + working_messages, + "system", + agent_loop_guidance, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_guidance; + } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { + const react_prompt = try 
react.buildSystemPromptAlloc(allocator, base_request.tools); + defer allocator.free(react_prompt); + const with_react = try appendRoleMessageAlloc( + allocator, + working_messages, + "system", + react_prompt, + ); + for (working_messages) |message| { + message.deinit(allocator); + } + allocator.free(working_messages); + working_messages = with_react; + } + + return working_messages; +} + +fn nextBasicLoopPromptAlloc(allocator: std.mem.Allocator, loop_mode: []const u8) ![]u8 { + return try allocator.dupe( + u8, + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) agent_continue_prompt else "Continue.", + ); +} + +fn providerRequestForLoopTurn( + turn_request: types.Request, + loop_mode: []const u8, +) types.Request { + if (!std.ascii.eqlIgnoreCase(loop_mode, "react")) return turn_request; + + // ReAct advertises tools in the system prompt and parses actions locally. + // Do not forward tools upstream or providers activate tool templates that + // break on prior assistant messages. + var r = turn_request; + r.tools = &.{}; + r.tool_choice = null; + return r; +} + +fn reactToolPromptAlloc( + allocator: std.mem.Allocator, + action: react.ReactAction, + tools: []const types.Tool, +) !?[]u8 { + return switch (action) { + .search => |query| blk: { + if (!tooling.hasRequestedTool(tools, "file_search")) break :blk null; + break :blk try std.fmt.allocPrint(allocator, "search {s} in .", .{query}); + }, + .lookup => |path| blk: { + if (!tooling.hasRequestedTool(tools, "file_read")) break :blk null; + break :blk try std.fmt.allocPrint(allocator, "read {s}", .{path}); + }, + .tool => |tool_action| blk: { + if (std.ascii.eqlIgnoreCase(tool_action.name, "file_read")) { + const trimmed = std.mem.trim(u8, tool_action.argument, " \t\r\n\"'"); + if (std.ascii.startsWithIgnoreCase(trimmed, "read ") or + std.ascii.startsWithIgnoreCase(trimmed, "file_read ")) + { + break :blk try allocator.dupe(u8, trimmed); + } + break :blk try std.fmt.allocPrint(allocator, "read {s}", .{trimmed}); + } + 
if (std.ascii.eqlIgnoreCase(tool_action.name, "file_write")) { + const trimmed = std.mem.trim(u8, tool_action.argument, " \t\r\n"); + if (std.ascii.startsWithIgnoreCase(trimmed, "write ")) { + break :blk try allocator.dupe(u8, trimmed); + } + break :blk try std.fmt.allocPrint(allocator, "write {s}", .{trimmed}); + } + if (std.ascii.eqlIgnoreCase(tool_action.name, "file_search")) { + const trimmed = std.mem.trim(u8, tool_action.argument, " \t\r\n\"'"); + if (std.ascii.startsWithIgnoreCase(trimmed, "search ") or + std.ascii.startsWithIgnoreCase(trimmed, "file_search ")) + { + break :blk try allocator.dupe(u8, trimmed); + } + break :blk try std.fmt.allocPrint(allocator, "search {s} in .", .{trimmed}); + } + break :blk try allocator.dupe(u8, tool_action.argument); + }, + else => null, + }; +} + fn executeReactActionForRequest( allocator: std.mem.Allocator, app_config: *const config.Config, @@ -1163,6 +1407,41 @@ fn executeReactActionForRequest( action: react.ReactAction, request_id: []const u8, ) ![]u8 { + if (try reactToolPromptAlloc(allocator, action, turn_request.tools)) |mapped_prompt| { + defer allocator.free(mapped_prompt); + + const tool_name: []const u8 = switch (action) { + .search => "file_search", + .lookup => "file_read", + .tool => |tool_action| tool_action.name, + else => unreachable, + }; + + var tool_request = try cloneRequestWithMessagesAlloc( + allocator, + turn_request, + mapped_prompt, + tool_messages, + ); + defer tool_request.deinit(allocator); + + if (tool_request.tool_choice) |value| { + allocator.free(value); + } + tool_request.tool_choice = try allocator.dupe(u8, tool_name); + + if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config, request_id, null)) |tool_result| { + defer tool_result.deinit(allocator); + return try allocator.dupe(u8, tool_result.output); + } + + return try std.fmt.allocPrint( + allocator, + "Tool action {s} was not requested or its input could not be parsed.", + .{tool_name}, + ); + } + return switch 
(action) { .tool => |tool_action| blk: { var tool_request = try cloneRequestWithMessagesAlloc( @@ -1178,7 +1457,7 @@ fn executeReactActionForRequest( } tool_request.tool_choice = try allocator.dupe(u8, tool_action.name); - if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config, request_id)) |tool_result| { + if (try tooling.tryExecuteDebugTool(allocator, tool_request, app_config, request_id, null)) |tool_result| { defer tool_result.deinit(allocator); break :blk try allocator.dupe(u8, tool_result.output); } @@ -1203,7 +1482,9 @@ fn streamLoopRequestToSse( ) !void { const loop_mode = base_request.loop_mode orelse "basic"; const loop_until = base_request.loop_until orelse "DONE"; - const loop_max_turns = base_request.loop_max_turns orelse 8; + const loop_max_turns = resolveLoopMaxTurns(loop_mode, base_request.loop_max_turns); + const effective_max_context_tokens = base_request.max_context_tokens orelse app_config.default_max_context_tokens; + const react_tool_budget = reactToolCallBudget(app_config, loop_max_turns); try response.sendEventStreamHeaders(connection, request_id); var headers_sent: bool = true; @@ -1216,7 +1497,7 @@ fn streamLoopRequestToSse( ); defer allocator.free(completion_id); - var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); + var working_messages = try prepareLoopWorkingMessagesAlloc(allocator, base_request, loop_mode); defer { for (working_messages) |message| { message.deinit(allocator); @@ -1224,34 +1505,6 @@ fn streamLoopRequestToSse( allocator.free(working_messages); } - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) { - const with_guidance = try appendRoleMessageAlloc( - allocator, - working_messages, - "system", - "You are running in API agent loop mode. Improve the answer each turn. Briefly self-critique then improve. 
Include the completion marker exactly when fully complete.", - ); - for (working_messages) |message| { - message.deinit(allocator); - } - allocator.free(working_messages); - working_messages = with_guidance; - } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { - const react_prompt = try react.buildSystemPromptAlloc(allocator, base_request.tools); - defer allocator.free(react_prompt); - const with_react = try appendRoleMessageAlloc( - allocator, - working_messages, - "system", - react_prompt, - ); - for (working_messages) |message| { - message.deinit(allocator); - } - allocator.free(working_messages); - working_messages = with_react; - } - var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); defer allocator.free(latest_user_prompt); @@ -1262,6 +1515,8 @@ fn streamLoopRequestToSse( var turn: usize = 0; while (turn < loop_max_turns) : (turn += 1) { + try maybeCompactLoopWorkingMessages(allocator, &working_messages, effective_max_context_tokens); + var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, latest_user_prompt, working_messages); defer turn_request.deinit(allocator); @@ -1275,20 +1530,7 @@ fn streamLoopRequestToSse( } turn_request.loop_max_turns = null; - // In ReAct mode the tool catalog is advertised inside the system - // prompt and parsed locally; we must NOT forward `tools` upstream or - // providers like Ollama activate their XML tool-calling template and - // choke on prior assistant messages containing markdown or - // (manifests as "internal_server_error ... element closed - // by "). We strip tools on a shallow alias so the original - // turn_request keeps them for local action dispatch. 
- const empty_tools: []types.Tool = &.{}; - const provider_request: types.Request = if (std.ascii.eqlIgnoreCase(loop_mode, "react")) blk: { - var r = turn_request; - r.tools = empty_tools; - r.tool_choice = null; - break :blk r; - } else turn_request; + const provider_request = providerRequestForLoopTurn(turn_request, loop_mode); // Emit the per-turn header BEFORE provider tokens stream in, so the // client can render "[loop turn N/M]\n" then watch tokens arrive live. @@ -1304,7 +1546,9 @@ fn streamLoopRequestToSse( connection, allocator, completion_id, - provider_request.model orelse app_config.ollama_model, + provider_request.model orelse app_config.modelForProvider( + provider_request.provider orelse app_config.default_provider, + ), progress_prefix, null, ); @@ -1360,7 +1604,11 @@ fn streamLoopRequestToSse( const is_react_mode = std.ascii.eqlIgnoreCase(loop_mode, "react"); const reached_until = std.mem.indexOf(u8, turn_outcome.captured_content, loop_until) != null; const reached_max = (turn + 1) >= loop_max_turns; - const repeated_stop = (std.ascii.eqlIgnoreCase(loop_mode, "agent") or is_react_mode) and repeated_count >= 1; + const repeated_stop = blk: { + if (is_react_mode) break :blk repeated_count >= react_repeated_stop_threshold; + if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) break :blk repeated_count >= 1; + break :blk false; + }; // ReAct: check for Finish action before standard stop checks. 
var react_finished = false; @@ -1374,7 +1622,7 @@ fn streamLoopRequestToSse( react_finished = true; }, else => { - try tooling.noteToolCall(app_config.max_tool_calls_per_request, &react_tool_calls); + try tooling.noteToolCall(react_tool_budget, &react_tool_calls); const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action, request_id); defer allocator.free(obs_raw); react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); @@ -1384,7 +1632,7 @@ fn streamLoopRequestToSse( react_observation = try react.formatObservation( allocator, turn + 1, - "Could not parse an Action from your response. Please use the format: Action N: Type[argument]", + react.parse_error_hint, ); } } @@ -1432,13 +1680,7 @@ fn streamLoopRequestToSse( ); } } else { - latest_user_prompt = try allocator.dupe( - u8, - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) - "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." 
- else - "Continue.", - ); + latest_user_prompt = try nextBasicLoopPromptAlloc(allocator, loop_mode); } const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", latest_user_prompt); @@ -1460,9 +1702,11 @@ fn executeLoopRequestAlloc( ) !LoopExecution { const loop_mode = base_request.loop_mode orelse "basic"; const loop_until = base_request.loop_until orelse "DONE"; - const loop_max_turns = base_request.loop_max_turns orelse 8; + const loop_max_turns = resolveLoopMaxTurns(loop_mode, base_request.loop_max_turns); + const effective_max_context_tokens = base_request.max_context_tokens orelse app_config.default_max_context_tokens; + const react_tool_budget = reactToolCallBudget(app_config, loop_max_turns); - var working_messages = try session.cloneMessagesAlloc(allocator, base_request.messages); + var working_messages = try prepareLoopWorkingMessagesAlloc(allocator, base_request, loop_mode); errdefer { for (working_messages) |message| { message.deinit(allocator); @@ -1470,34 +1714,6 @@ fn executeLoopRequestAlloc( allocator.free(working_messages); } - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) { - const with_guidance = try appendRoleMessageAlloc( - allocator, - working_messages, - "system", - "You are running in API agent loop mode. Improve the answer each turn. Briefly self-critique then improve. 
Include the completion marker exactly when fully complete.", - ); - for (working_messages) |message| { - message.deinit(allocator); - } - allocator.free(working_messages); - working_messages = with_guidance; - } else if (std.ascii.eqlIgnoreCase(loop_mode, "react")) { - const react_prompt = try react.buildSystemPromptAlloc(allocator, base_request.tools); - defer allocator.free(react_prompt); - const with_react = try appendRoleMessageAlloc( - allocator, - working_messages, - "system", - react_prompt, - ); - for (working_messages) |message| { - message.deinit(allocator); - } - allocator.free(working_messages); - working_messages = with_react; - } - var latest_user_prompt = try allocator.dupe(u8, base_request.prompt); defer allocator.free(latest_user_prompt); @@ -1508,6 +1724,8 @@ fn executeLoopRequestAlloc( var turn: usize = 0; while (turn < loop_max_turns) : (turn += 1) { + try maybeCompactLoopWorkingMessages(allocator, &working_messages, effective_max_context_tokens); + var turn_request = try cloneRequestWithMessagesAlloc(allocator, base_request, latest_user_prompt, working_messages); defer turn_request.deinit(allocator); @@ -1523,13 +1741,7 @@ fn executeLoopRequestAlloc( // See streamLoopRequestToSse: strip tools/tool_choice from the // upstream call in ReAct mode to avoid provider-side tool templates. 
- const empty_tools: []types.Tool = &.{}; - const provider_request: types.Request = if (std.ascii.eqlIgnoreCase(loop_mode, "react")) blk: { - var r = turn_request; - r.tools = empty_tools; - r.tool_choice = null; - break :blk r; - } else turn_request; + const provider_request = providerRequestForLoopTurn(turn_request, loop_mode); var turn_result = try backend.callProvider(allocator, app_config, provider_request); if (turn_result.success and isLengthFinishReason(turn_result.finish_reason)) { @@ -1584,7 +1796,7 @@ fn executeLoopRequestAlloc( defer if (react_observation) |obs| allocator.free(obs); if (is_react_mode) { - if (repeated_count >= 1) { + if (repeated_count >= react_repeated_stop_threshold) { return .{ .response = turn_result, .messages = working_messages }; } @@ -1594,7 +1806,7 @@ fn executeLoopRequestAlloc( react_finished = true; }, else => { - try tooling.noteToolCall(app_config.max_tool_calls_per_request, &react_tool_calls); + try tooling.noteToolCall(react_tool_budget, &react_tool_calls); const obs_raw = try executeReactActionForRequest(allocator, app_config, turn_request, working_messages, action, request_id); defer allocator.free(obs_raw); react_observation = try react.formatObservation(allocator, turn + 1, obs_raw); @@ -1604,7 +1816,7 @@ fn executeLoopRequestAlloc( react_observation = try react.formatObservation( allocator, turn + 1, - "Could not parse an Action from your response. Please use the format: Action N: Type[argument]", + react.parse_error_hint, ); } @@ -1622,13 +1834,7 @@ fn executeLoopRequestAlloc( react_observation = null; // Transfer ownership. latest_user_prompt = obs; } else { - latest_user_prompt = try allocator.dupe( - u8, - if (std.ascii.eqlIgnoreCase(loop_mode, "agent")) - "Critique your previous answer briefly, then improve it with concrete next steps. If complete, include the completion marker exactly and return the final result." 
- else - "Continue.", - ); + latest_user_prompt = try nextBasicLoopPromptAlloc(allocator, loop_mode); } const with_continue = try appendRoleMessageAlloc(allocator, working_messages, "user", latest_user_prompt); @@ -1777,6 +1983,81 @@ fn isLengthFinishReason(finish_reason: []const u8) bool { return std.ascii.eqlIgnoreCase(finish_reason, "length"); } +fn persistConversationState( + allocator: std.mem.Allocator, + app_config: *const config.Config, + use_workspace: bool, + workspace_store: ?*workspace.WorkspaceStore, + active_workspace: ?*workspace.WorkspaceState, + session_store: ?*session.SessionStore, + session_store_guard: *SessionStoreGuard, + parsed_req: types.Request, + loaded_session: ?session.SessionState, + base_messages: []const types.Message, + loop_messages: ?[]types.Message, + assistant_output: []const u8, +) void { + if (use_workspace) { + persistWorkspaceState( + allocator, + workspace_store, + active_workspace, + base_messages, + loop_messages, + assistant_output, + ); + return; + } + + persistSessionState( + allocator, + app_config, + session_store, + session_store_guard, + parsed_req, + loaded_session, + base_messages, + loop_messages, + assistant_output, + ); +} + +fn persistWorkspaceState( + allocator: std.mem.Allocator, + workspace_store: ?*workspace.WorkspaceStore, + active_workspace: ?*workspace.WorkspaceState, + base_messages: []const types.Message, + loop_messages: ?[]types.Message, + assistant_output: []const u8, +) void { + const store = workspace_store orelse return; + const ws = active_workspace orelse return; + + const with_assistant = if (loop_messages) |messages| blk: { + break :blk session.cloneMessagesAlloc(allocator, messages) catch |err| blk2: { + logError("Workspace append failed: {s}", .{@errorName(err)}); + break :blk2 null; + }; + } else session.appendAssistantMessageAlloc( + allocator, + base_messages, + assistant_output, + ) catch |err| blk: { + logError("Workspace append failed: {s}", .{@errorName(err)}); + break :blk null; + 
}; + + if (with_assistant) |messages| { + store.saveMessages(ws, messages) catch |err| { + logError("Workspace save failed: {s}", .{@errorName(err)}); + for (messages) |message| { + message.deinit(allocator); + } + allocator.free(messages); + }; + } +} + fn persistSessionState( allocator: std.mem.Allocator, app_config: *const config.Config, From 151be42b9a8f19601f4f5bedfeffb9b7462af4f0 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 22/25] Honor default model across providers Use configured LLM_ROUTER_MODEL when requests omit model. --- src/providers/bedrock.zig | 2 +- src/providers/claude.zig | 4 ++-- src/providers/llama_cpp.zig | 2 +- src/providers/ollama_qwen.zig | 6 +++--- src/providers/openai.zig | 4 ++-- src/providers/openrouter.zig | 4 ++-- src/tools/command_exec.zig | 3 +++ 7 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/providers/bedrock.zig b/src/providers/bedrock.zig index fc1b3a2..205bbfd 100644 --- a/src/providers/bedrock.zig +++ b/src/providers/bedrock.zig @@ -9,7 +9,7 @@ pub fn callBedrock( app_config: *const config.Config, request: types.Request, ) !types.Response { - const model_name = request.model orelse app_config.bedrock_model; + const model_name = request.model orelse app_config.modelForProvider("bedrock"); if (app_config.bedrock_access_key_id.len == 0 or app_config.bedrock_secret_access_key.len == 0) { return try errorResponse( diff --git a/src/providers/claude.zig b/src/providers/claude.zig index ec9270b..b0fffd9 100644 --- a/src/providers/claude.zig +++ b/src/providers/claude.zig @@ -10,7 +10,7 @@ pub fn callClaude( if (app_config.claude_api_key.len == 0) { return .{ .id = null, - .model = try allocator.dupe(u8, request.model orelse app_config.claude_model), + .model = try allocator.dupe(u8, request.model orelse app_config.modelForProvider("claude")), .output = try allocator.dupe(u8, "Claude API key is not configured on the server"), .finish_reason = try allocator.dupe(u8, 
"stop"), .success = false, @@ -18,7 +18,7 @@ pub fn callClaude( }; } - const model_name = request.model orelse app_config.claude_model; + const model_name = request.model orelse app_config.modelForProvider("claude"); var client = std.http.Client{ .allocator = allocator }; defer client.deinit(); diff --git a/src/providers/llama_cpp.zig b/src/providers/llama_cpp.zig index ce66814..5860d41 100644 --- a/src/providers/llama_cpp.zig +++ b/src/providers/llama_cpp.zig @@ -8,7 +8,7 @@ pub fn callLlamaCpp( app_config: *const config.Config, request: types.Request, ) !types.Response { - const model_name = request.model orelse app_config.llama_cpp_model; + const model_name = request.model orelse app_config.modelForProvider("llama_cpp"); const maybe_key: ?[]const u8 = if (app_config.llama_cpp_api_key.len > 0) app_config.llama_cpp_api_key diff --git a/src/providers/ollama_qwen.zig b/src/providers/ollama_qwen.zig index 3497422..7034b31 100644 --- a/src/providers/ollama_qwen.zig +++ b/src/providers/ollama_qwen.zig @@ -94,7 +94,7 @@ pub fn streamQwenTurnToSse( captured_content: *std.ArrayList(u8), think_block_open: *bool, ) !TurnStreamResult { - const requested_model = request.model orelse app_config.ollama_model; + const requested_model = request.model orelse app_config.modelForProvider("ollama_qwen"); var model_name = requested_model; var fallback: ?ModelFallback = null; defer if (fallback) |value| value.deinit(allocator); @@ -167,7 +167,7 @@ pub fn streamQwenToSse( app_config: *const config.Config, request: types.Request, ) !StreamQwenResult { - const requested_model = request.model orelse app_config.ollama_model; + const requested_model = request.model orelse app_config.modelForProvider("ollama_qwen"); var model_name = requested_model; var fallback: ?ModelFallback = null; defer if (fallback) |value| value.deinit(allocator); @@ -276,7 +276,7 @@ pub fn callQwen( app_config: *const config.Config, request: types.Request, ) !types.Response { - const requested_model = request.model 
orelse app_config.ollama_model; + const requested_model = request.model orelse app_config.modelForProvider("ollama_qwen"); var model_name = requested_model; var fallback: ?ModelFallback = null; defer if (fallback) |value| value.deinit(allocator); diff --git a/src/providers/openai.zig b/src/providers/openai.zig index 8a1f5c3..48cd66b 100644 --- a/src/providers/openai.zig +++ b/src/providers/openai.zig @@ -11,7 +11,7 @@ pub fn callOpenAI( if (app_config.openai_api_key.len == 0) { return .{ .id = null, - .model = try allocator.dupe(u8, request.model orelse app_config.openai_model), + .model = try allocator.dupe(u8, request.model orelse app_config.modelForProvider("openai")), .output = try allocator.dupe(u8, "OpenAI API key is not configured on the server"), .finish_reason = try allocator.dupe(u8, "stop"), .success = false, @@ -19,7 +19,7 @@ pub fn callOpenAI( }; } - const model_name = request.model orelse app_config.openai_model; + const model_name = request.model orelse app_config.modelForProvider("openai"); return try openai_compatible.callChat( allocator, app_config.openai_base_url, diff --git a/src/providers/openrouter.zig b/src/providers/openrouter.zig index 31ca7bf..63d0f14 100644 --- a/src/providers/openrouter.zig +++ b/src/providers/openrouter.zig @@ -11,12 +11,12 @@ pub fn callOpenRouter( if (app_config.openrouter_api_key.len == 0) { return try errorResponse( allocator, - request.model orelse app_config.openrouter_model, + request.model orelse app_config.modelForProvider("openrouter"), "OpenRouter provider selected but OPENROUTER_API_KEY is not set", ); } - const model_name = request.model orelse app_config.openrouter_model; + const model_name = request.model orelse app_config.modelForProvider("openrouter"); var extra_headers_buffer: [2]std.http.Header = undefined; const extra_headers = buildExtraHeaders( app_config.openrouter_http_referer, diff --git a/src/tools/command_exec.zig b/src/tools/command_exec.zig index 09de654..9ec4081 100644 --- 
a/src/tools/command_exec.zig +++ b/src/tools/command_exec.zig @@ -647,6 +647,7 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .listen_port = 8081, .debug_logging = false, .default_provider = try allocator.dupe(u8, "ollama_qwen"), + .default_model = null, .request_timeout_ms = 30_000, .provider_timeout_ms = 60_000, .instance_id = try allocator.dupe(u8, "local-instance"), @@ -687,6 +688,8 @@ pub fn buildTestConfig(allocator: std.mem.Allocator, enable_exec: bool) !config. .tool_output_offload_bytes = 8192, .max_tool_calls_per_request = 8, .default_max_context_tokens = null, + .workspace_mode_enabled = false, + .workspace_root = try allocator.dupe(u8, "."), .loop_stream_progress_enabled = true, .max_concurrent_connections = 64, .max_request_bytes = 1024 * 1024, From ec0ed1823292a8f91b0f4880e7593bbf9e64a9d6 Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 23/25] Update build defaults and test wiring Default ReleaseSafe optimize and Windows libc linking. Expand zig build test coverage hooks. --- build.zig | 105 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 22 deletions(-) diff --git a/build.zig b/build.zig index 541e536..1eae76a 100644 --- a/build.zig +++ b/build.zig @@ -9,7 +9,11 @@ const TestTarget = enum { pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); - const optimize = b.standardOptimizeOption(.{}); + const optimize = b.option( + std.builtin.OptimizeMode, + "optimize", + "Prioritize performance, safety, or binary size (default: ReleaseSafe)", + ) orelse .ReleaseSafe; const package_name = "zig_coding_agent"; const exe_name = "zig-coding-agent"; @@ -25,6 +29,7 @@ pub fn build(b: *std.Build) void { const mod = b.addModule(package_name, .{ .root_source_file = b.path("src/root.zig"), .target = target, + .optimize = optimize, }); // CLI executable entrypoint. 
@@ -34,41 +39,97 @@ pub fn build(b: *std.Build) void { .root_source_file = b.path("src/main.zig"), .target = target, .optimize = optimize, + .strip = optimize != .Debug, .imports = &.{ .{ .name = package_name, .module = mod }, }, }), }); - // No linkLibC(): harness uses std only; avoids glibc/GCC 16 .sframe linker issues on Zig 0.15.2. + // Linux native build avoids linkLibC (glibc/GCC 16 .sframe linker issues on Zig 0.15.2). + // Windows needs libc for socket APIs (recvfrom, etc.). + if (app.root_module.resolved_target.?.query.os_tag == .windows) { + app.linkLibC(); + } b.installArtifact(app); - // Optional sibling checkout: ../zig_eval (same parent folder as this repo). - const zig_eval_root = "../zig_eval"; - const zig_eval_registry = "registry"; - - const build_zig_eval = b.addSystemCommand(&.{ "zig", "build" }); + // Cross-compile for Windows: `zig build windows` + const windows_target = b.resolveTargetQuery(.{ + .cpu_arch = .x86_64, + .os_tag = .windows, + .abi = .gnu, + }); + const windows_app = b.addExecutable(.{ + .name = exe_name, + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = windows_target, + .optimize = optimize, + .strip = optimize != .Debug, + .imports = &.{ + .{ .name = package_name, .module = mod }, + }, + }), + }); + windows_app.linkLibC(); + const windows_step = b.step("windows", "Cross-compile ReleaseSafe binary for x86_64-windows-gnu"); + windows_step.dependOn(&b.addInstallArtifact(windows_app, .{}).step); + + // Sibling eval runner: `zig build zig_evals` (requires ../zig_eval and a running harness). 
+ const zig_eval_root = b.option([]const u8, "eval-root", "Path to sibling zig_eval checkout") orelse "../zig_eval"; + const eval_registry = b.option([]const u8, "eval-registry", "Registry directory inside zig_eval") orelse "registry"; + const eval_service = b.option([]const u8, "eval-service", "Service name from registry/services.json") orelse "local-openai-compat"; + const harness_base = b.option([]const u8, "eval-harness-url", "Harness base URL for readiness checks") orelse "http://127.0.0.1:8081"; + const eval_wait_seconds = b.option(u32, "eval-wait-seconds", "Seconds to wait for harness readiness before failing") orelse 0; + + const preflight_script = b.fmt( + \\if [ -f ".env" ]; then set -a; . ./.env; set +a; fi + \\PROVIDER="${{LLM_ROUTER_PROVIDER:-ollama}}" + \\for i in $(seq 0 {d}); do + \\ if curl -sf "{s}/health" >/dev/null; then + \\ PAYLOAD=$(printf '{{"messages":[{{"role":"user","content":"Reply OK"}}],"provider":"%s","model":"auto"}}' "$PROVIDER") + \\ if curl -sf -X POST "{s}/v1/chat/completions" \ + \\ -H "Content-Type: application/json" \ + \\ ${{LLM_ROUTER_API_KEY:+-H "Authorization: Bearer $LLM_ROUTER_API_KEY"}} \ + \\ -d "$PAYLOAD" >/dev/null; then + \\ exit 0 + \\ fi + \\ echo "error: harness chat probe failed at {s}/v1/chat/completions (provider=$PROVIDER; is the LLM configured?)" >&2 + \\ exit 1 + \\ fi + \\ if [ "$i" -eq {d} ]; then break; fi + \\ sleep 1 + \\done + \\echo "error: harness not reachable at {s}/health (start it with: zig build run -- --use-env)" >&2 + \\exit 1 + , + .{ eval_wait_seconds, harness_base, harness_base, harness_base, eval_wait_seconds, harness_base }, + ); + + const eval_preflight = b.addSystemCommand(&.{ "sh", "-c", preflight_script }); + + const build_zig_eval = b.addSystemCommand(&.{ "zig", "build", "-Doptimize=ReleaseSafe" }); build_zig_eval.setCwd(b.path(zig_eval_root)); + build_zig_eval.step.dependOn(&eval_preflight.step); - const eval_build_step = b.step("eval-build", "Build sibling zig_eval checkout 
(requires ../zig_eval)"); - eval_build_step.dependOn(&build_zig_eval.step); - - const eval_list_cmd = b.addSystemCommand(&.{ "zig", "build", "run", "--", "list", "--registry", zig_eval_registry }); - eval_list_cmd.setCwd(b.path(zig_eval_root)); - eval_list_cmd.step.dependOn(&build_zig_eval.step); - - const eval_list_step = b.step("eval-list", "List evals from sibling zig_eval registry (requires ../zig_eval)"); - eval_list_step.dependOn(&eval_list_cmd.step); + const run_evals_script = b.fmt( + \\if [ -f ".env" ]; then set -a; . ./.env; set +a; fi + \\cd "{s}" && zig build run -Doptimize=ReleaseSafe -- run --registry {s} --service {s} "$@" + , + .{ zig_eval_root, eval_registry, eval_service }, + ); - const eval_run_cmd = b.addSystemCommand(&.{ "zig", "build", "run", "--", "run", "--registry", zig_eval_registry }); - eval_run_cmd.setCwd(b.path(zig_eval_root)); - eval_run_cmd.step.dependOn(&build_zig_eval.step); + const run_evals = b.addSystemCommand(&.{ "sh", "-c", run_evals_script, "zig_evals" }); + run_evals.step.dependOn(&build_zig_eval.step); if (b.args) |args| { - eval_run_cmd.addArgs(args); + run_evals.addArgs(args); } - const eval_run_step = b.step("eval-run", "Run sibling zig_eval against this harness (requires ../zig_eval; start harness first)"); - eval_run_step.dependOn(&eval_run_cmd.step); + const zig_evals_step = b.step( + "zig_evals", + "Run zig_eval registry against a live harness (requires ../zig_eval; start server first)", + ); + zig_evals_step.dependOn(&run_evals.step); // Run command: `zig build run -- [args]`. const run_step = b.step("run", "Run the Zig Coding Agent server"); From f8e96459f2be2847426544826d31ce48536b4b1c Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Fri, 29 May 2026 15:56:35 +1000 Subject: [PATCH 24/25] Document workspace and ReAct operator setup Update README and env example for workspace and react loops. Note server-side react tool auto-attach in test client. 
--- .env.example | 31 +++++++++++++ README.md | 101 +++++++++++++++++++++++++++++++++++------- client/test_client.py | 8 ++-- 3 files changed, 121 insertions(+), 19 deletions(-) create mode 100644 .env.example diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..bcd1feb --- /dev/null +++ b/.env.example @@ -0,0 +1,31 @@ +# Active routing (change these to switch provider/model without editing code) +LLM_ROUTER_PROVIDER=ollama +LLM_ROUTER_MODEL=qwen3.5:9b + +# Local Ollama +OLLAMA_BASE_URL=http://127.0.0.1:11434 +OLLAMA_MODEL=qwen3.5:9b + +# OpenAI +OPENAI_API_KEY= +OPENAI_BASE_URL=https://api.openai.com/v1 +OPENAI_MODEL=gpt-4.1-mini + +# OpenRouter +OPENROUTER_API_KEY= +OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 +OPENROUTER_MODEL=openrouter/auto + +# Anthropic Claude +CLAUDE_API_KEY= +CLAUDE_MODEL=claude-3-5-sonnet-latest + +# Optional harness auth for clients/evals +# LLM_ROUTER_API_KEY= + +# Tooling (off by default in untrusted deployments) +# LLM_ROUTER_TOOL_EXEC_ENABLED=0 + +# ReAct / small-model loops (optional) +# LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS=8192 +# LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST=24 diff --git a/README.md b/README.md index e15ca5b..93b1016 100644 --- a/README.md +++ b/README.md @@ -57,9 +57,12 @@ flowchart TD ## Build, Run, and Test ```bash -zig build +zig build # ReleaseSafe by default (~2 MB stripped binary) zig build run zig build check +zig build windows # cross-compile zig-coding-agent.exe for x86_64-windows-gnu +zig build -Doptimize=Debug # larger binary with safety checks for development +zig build -Doptimize=ReleaseSmall # smallest binary (~800 KB) ``` ### Test Targets @@ -81,25 +84,58 @@ zig build test -Dtest-target=file "-Dtest-file=src/types.zig" zig build test -Dtest-target=all -Dtest-filter=normalizeProviderName ``` -### zig_eval harness (optional) +### zig_eval (optional sibling checkout) -This repo depends on the sibling **zig_eval** checkout at `../zig_eval` (see `build.zig.zon`). 
It ships a small **eval registry** under `eval/registry/` (same layout as `zig_eval`: `services.json`, `evals//*.json`, `data/.../*.jsonl`) and a runner binary **`zig-coding-agent-eval`**. +Evaluations live in the sibling **zig_eval** repo. Check out both repos side by side: -**Prerequisites:** Ollama (or whatever `services.json` targets) reachable, and this router listening (default `http://127.0.0.1:8081`). If `LLM_ROUTER_API_KEY` is set on the server, set the same variable for the eval client (see `eval/registry/services.json` → `api_key_env`). +```text +zig/ +├── zig_coding_agent/ # this harness +└── zig_eval/ # registry-driven eval runner +``` + +`zig_eval/registry/services.json` targets this router at `http://127.0.0.1:8081` (`local-openai-compat`). + +**Prerequisites:** LLM provider reachable (e.g. Ollama), harness listening, and matching `LLM_ROUTER_API_KEY` if auth is enabled. ```bash -# Terminal A: router -zig build run +# Terminal A: start harness — pick provider + model via env (or .env) +export LLM_ROUTER_PROVIDER=openrouter # ollama | openai | openrouter | claude | bedrock | llama_cpp +export LLM_ROUTER_MODEL=openai/gpt-4o-mini +export OPENROUTER_API_KEY=your-key # set the key for the active provider +zig build run -- --use-env + +# Switch to local Ollama later (restart server): +# export LLM_ROUTER_PROVIDER=ollama +# export LLM_ROUTER_MODEL=qwen3.5:9b + +# CLI overrides without editing env: +# zig build run -- --use-env --provider openai --model gpt-4.1-mini -# Terminal B: from repo root — runs all eval definitions under eval/registry/evals/ -zig build eval-run +# Terminal B: run all registry evals against the live harness +zig build zig_evals -# Custom registry directory or JSON run log -zig build eval-run -- path/to/registry -zig build eval-run -- --json +# Wait up to 30s for the server to come up, then run evals +zig build zig_evals -Deval-wait-seconds=30 + +# Filter evals (pass-through to zig_eval CLI) +zig build zig_evals -- --eval 
smoke.reply_ok --format json +zig build zig_evals -- --group smoke --parallel 2 + +# Override registry or service +zig build zig_evals -Deval-registry=examples/registry -Deval-service=local-product ``` -There is currently **one** eval definition checked in (`smoke.reply_ok`). Add more by dropping new JSON files under `eval/registry/evals/` and matching datasets under `eval/registry/data/`. +Readiness checks before evals start: + +1. `GET /health` on the harness (default `http://127.0.0.1:8081`) +2. A minimal `POST /v1/chat/completions` probe to confirm the LLM path works + +List evals directly from the sibling checkout: + +```bash +cd ../zig_eval && zig build run -- list --registry registry +``` ## API Surface @@ -179,6 +215,7 @@ Loop controls: - --prompt initial prompt and loop entry - --provider provider override +- --model default model override (also available as `LLM_ROUTER_MODEL`) - --until completion marker (default: DONE) - --max-turns loop safety cap (default: 8) - --loop-mode loop style @@ -215,10 +252,19 @@ zig build run -- --react --prompt "What is the elevation range of the High Plain { "messages": [{ "role": "user", "content": "What is the elevation range of the High Plains?" }], "loop_mode": "react", - "loop_max_turns": 10 + "loop_max_turns": 24, + "tools": [ + { "name": "file_read", "description": "Read a project file" }, + { "name": "file_write", "description": "Write a project file" }, + { "name": "file_search", "description": "Search the repo" } + ] } ``` +When `loop_max_turns` is omitted, ReAct defaults to **24 turns** (vs 8 for basic/agent) so smaller models can solve coding tasks step-by-step. Tool-call budget scales with the turn cap. Set `max_context_tokens` (for example `8192`) so long loops compact history between turns. + +The server **auto-attaches** the coding tool set for `loop_mode: "react"` (`file_read`, `file_write`, `file_search`, plus `bash` or `cmd` on the host OS). 
Client-provided tools are preserved; missing defaults are merged in. `tool_choice` defaults to `auto` when omitted. Any HTTP client can use ReAct without sending a `tools` array. + The model will produce output like: ``` @@ -262,6 +308,9 @@ Related limits: - LLM_ROUTER_TOOL_EXEC_TIMEOUT_MS (default: 15000) - LLM_ROUTER_TOOL_EXEC_MAX_OUTPUT_BYTES (default: 65536) +- LLM_ROUTER_TOOL_EXEC_CONFIRM_REQUIRED (default: true; re-send with `LLM_ROUTER_TOOL_CONFIRM `) +- LLM_ROUTER_TOOL_EXEC_TRUSTED_LOCAL (default: false; allows pipes/chaining with denylist kept) +- LLM_ROUTER_TOOL_OUTPUT_OFFLOAD_BYTES (default: 8192; 0 disables offloading large tool output to disk) - LLM_ROUTER_MAX_TOOL_CALLS_PER_REQUEST (default: 8) ## Configuration Reference @@ -274,6 +323,7 @@ Related limits: | LLM_ROUTER_PORT | 8081 | | LLM_ROUTER_DEBUG | 0 | | LLM_ROUTER_PROVIDER | ollama | +| LLM_ROUTER_MODEL | unset (uses per-provider `*_MODEL`) | | LLM_ROUTER_INSTANCE_ID | local-instance | | LLM_ROUTER_API_KEY | empty | | LLM_ROUTER_REQUEST_TIMEOUT_MS | 30000 | @@ -282,6 +332,7 @@ Related limits: | LLM_ROUTER_MAX_CONCURRENT_CONNECTIONS | 64 | | LLM_ROUTER_MAX_REQUEST_BYTES | 1048576 | | LLM_ROUTER_MAX_HEADER_BYTES | 16384 | +| LLM_ROUTER_DEFAULT_MAX_CONTEXT_TOKENS | unset | ### Session Storage @@ -290,14 +341,34 @@ Related limits: | LLM_ROUTER_SESSION_STORE_PATH | logs/sessions | | LLM_ROUTER_SESSION_RETENTION_MESSAGES | 24 | +### Workspace Mode (optional, in-memory) + +Lightweight Cursor-like mode: same `workspace_id` keeps conversation in RAM until the server exits. File tools resolve under the workspace root instead of `tmp/`. Disabled by default. + +| Variable | Default | +| -------------------------- | ------- | +| LLM_ROUTER_WORKSPACE_MODE | 0 | +| LLM_ROUTER_WORKSPACE_ROOT | `.` | + +Send `workspace_id` in the JSON body (separate from disk `session_id`): + +```bash +export LLM_ROUTER_WORKSPACE_MODE=1 +export LLM_ROUTER_WORKSPACE_ROOT=. 
+ +curl -s http://127.0.0.1:8081/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"workspace_id":"dev-1","messages":[{"role":"user","content":"Read src/main.zig"}],"tools":[{"name":"file_read","description":"read files"}],"tool_choice":"file_read"}' +``` + ### Ollama | Variable | Default | | --------------------- | ------------------------ | | OLLAMA_BASE_URL | | | OLLAMA_MODEL | qwen3.5:9b | -| OLLAMA_THINK | 0 | -| OLLAMA_NUM_PREDICT | 128 | +| OLLAMA_THINK | true | +| OLLAMA_NUM_PREDICT | 2048 | | OLLAMA_TEMPERATURE | 0.7 | | OLLAMA_REPEAT_PENALTY | 1.05 | diff --git a/client/test_client.py b/client/test_client.py index a42d37f..e97924c 100644 --- a/client/test_client.py +++ b/client/test_client.py @@ -203,11 +203,11 @@ def set_pending_cmd_confirmation(value: Optional[Dict[str, str]]) -> None: CODING_TOOLS = [_shell_tool, "file_read", "file_write", "file_search"] auto_tools_for_react = st.checkbox( - "Auto-attach coding tools when loop_mode=react", - value=True, + "Also send coding tools from client when loop_mode=react", + value=False, help=( - "When enabled and loop_mode is 'react', send the full coding tool set " - f"({', '.join(CODING_TOOLS)}) with tool_choice=auto. Overrides the manual selection below." + "The server already auto-attaches the coding tool set for react loops. " + "Enable this only to duplicate tools in the JSON payload (e.g. for debugging)." 
), ) From 9233304e6cf559f751d125617f46211ab0fa503c Mon Sep 17 00:00:00 2001 From: Ajeets6 Date: Thu, 11 Jun 2026 14:52:00 +1000 Subject: [PATCH 25/25] Fix tool calling --- README.md | 3 + build.zig | 8 +- src/backend/api.zig | 102 ++++++++++++++++++++- src/backend/tool_output.zig | 8 +- src/backend/tools.zig | 133 ++++++++++++++++++++++++++++ src/core/response.zig | 55 +++++++++++- src/core/server.zig | 24 +++++ src/providers/ollama_qwen.zig | 117 +++++++++++++++++++++--- src/providers/openai_compatible.zig | 132 +++++++++++++++++++++++++-- src/types.zig | 21 +++++ 10 files changed, 577 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 93b1016..5e29400 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,9 @@ zig build run -- --use-env # Terminal B: run all registry evals against the live harness zig build zig_evals +# If .env points at a different default provider, probe the live provider explicitly +zig build zig_evals -Deval-provider=ollama + # Wait up to 30s for the server to come up, then run evals zig build zig_evals -Deval-wait-seconds=30 diff --git a/build.zig b/build.zig index 1eae76a..2f5ec79 100644 --- a/build.zig +++ b/build.zig @@ -47,7 +47,7 @@ pub fn build(b: *std.Build) void { }); // Linux native build avoids linkLibC (glibc/GCC 16 .sframe linker issues on Zig 0.15.2). // Windows needs libc for socket APIs (recvfrom, etc.). 
- if (app.root_module.resolved_target.?.query.os_tag == .windows) { + if (target.result.os.tag == .windows) { app.linkLibC(); } @@ -81,10 +81,11 @@ pub fn build(b: *std.Build) void { const eval_service = b.option([]const u8, "eval-service", "Service name from registry/services.json") orelse "local-openai-compat"; const harness_base = b.option([]const u8, "eval-harness-url", "Harness base URL for readiness checks") orelse "http://127.0.0.1:8081"; const eval_wait_seconds = b.option(u32, "eval-wait-seconds", "Seconds to wait for harness readiness before failing") orelse 0; + const eval_provider = b.option([]const u8, "eval-provider", "Provider to use for the eval readiness chat probe"); const preflight_script = b.fmt( \\if [ -f ".env" ]; then set -a; . ./.env; set +a; fi - \\PROVIDER="${{LLM_ROUTER_PROVIDER:-ollama}}" + \\if [ -n "${{ZIG_EVAL_PROVIDER:-}}" ]; then PROVIDER="$ZIG_EVAL_PROVIDER"; else PROVIDER="${{LLM_ROUTER_PROVIDER:-ollama}}"; fi \\for i in $(seq 0 {d}); do \\ if curl -sf "{s}/health" >/dev/null; then \\ PAYLOAD=$(printf '{{"messages":[{{"role":"user","content":"Reply OK"}}],"provider":"%s","model":"auto"}}' "$PROVIDER") @@ -107,6 +108,9 @@ pub fn build(b: *std.Build) void { ); const eval_preflight = b.addSystemCommand(&.{ "sh", "-c", preflight_script }); + if (eval_provider) |provider| { + eval_preflight.setEnvironmentVariable("ZIG_EVAL_PROVIDER", provider); + } const build_zig_eval = b.addSystemCommand(&.{ "zig", "build", "-Doptimize=ReleaseSafe" }); build_zig_eval.setCwd(b.path(zig_eval_root)); diff --git a/src/backend/api.zig b/src/backend/api.zig index 2bbeb6f..b550cf6 100644 --- a/src/backend/api.zig +++ b/src/backend/api.zig @@ -458,7 +458,37 @@ pub fn parseChatRequest( ) }, }; - const name = switch (tool_obj.get("name") orelse { + if (tool_obj.get("type")) |type_value| { + const tool_type = switch (type_value) { + .string => |value| value, + else => return .{ .err = errors.validationError( + "tool.type must be a string", + "tools", + 
"invalid_tools", + ) }, + }; + if (!std.ascii.eqlIgnoreCase(tool_type, "function")) { + return .{ .err = errors.validationError( + "tool.type must be function", + "tools", + "invalid_tools", + ) }; + } + } + + const function_obj = if (tool_obj.get("function")) |function_value| + switch (function_value) { + .object => |value| value, + else => return .{ .err = errors.validationError( + "tool.function must be a JSON object", + "tools", + "invalid_tools", + ) }, + } + else + tool_obj; + + const name = switch (function_obj.get("name") orelse { return .{ .err = errors.validationError( "Each tool must include a name", "tools", @@ -501,7 +531,7 @@ pub fn parseChatRequest( } } - const description = if (tool_obj.get("description")) |description_value| + const description = if (function_obj.get("description")) |description_value| switch (description_value) { .string => |value| value, else => return .{ .err = errors.validationError( @@ -513,9 +543,23 @@ pub fn parseChatRequest( else ""; + const parameters_json = if (function_obj.get("parameters")) |parameters_value| params_blk: { + switch (parameters_value) { + .object => {}, + else => return .{ .err = errors.validationError( + "tool.function.parameters must be a JSON object", + "tools", + "invalid_tools", + ) }, + } + break :params_blk try jsonValueStringifyAlloc(allocator, parameters_value); + } else null; + errdefer if (parameters_json) |value| allocator.free(value); + try tools.append(allocator, .{ .name = try allocator.dupe(u8, trimmed_name), .description = try allocator.dupe(u8, description), + .parameters_json = parameters_json, }); } @@ -991,6 +1035,18 @@ fn hasToolNamed(tools: []const types.Tool, name: []const u8) bool { return false; } +fn jsonValueStringifyAlloc(allocator: std.mem.Allocator, value: std.json.Value) ![]u8 { + var out = std.Io.Writer.Allocating.init(allocator); + errdefer out.deinit(); + + var json_writer = std.json.Stringify{ + .writer = &out.writer, + .options = .{}, + }; + try 
json_writer.write(value); + return try out.toOwnedSlice(); +} + fn extractLastUserPrompt( allocator: std.mem.Allocator, messages: []const types.Message, @@ -1217,3 +1273,45 @@ test "parseChatRequest validates tool_choice against requested tools" { }, } } + +test "parseChatRequest accepts OpenAI function tool schema" { + const allocator = std.testing.allocator; + const body = + \\{ + \\ "messages": [ + \\ {"role": "user", "content": "Use echo"} + \\ ], + \\ "tools": [ + \\ { + \\ "type": "function", + \\ "function": { + \\ "name": "echo", + \\ "description": "Return a message unchanged.", + \\ "parameters": { + \\ "type": "object", + \\ "properties": { + \\ "message": {"type": "string"} + \\ }, + \\ "required": ["message"] + \\ } + \\ } + \\ } + \\ ], + \\ "tool_choice": "echo" + \\} + ; + + const parsed = try parseChatRequest(allocator, body); + const request = switch (parsed) { + .ok => |request| request, + .err => return error.UnexpectedApiError, + }; + defer request.deinit(allocator); + + try std.testing.expectEqual(@as(usize, 1), request.tools.len); + try std.testing.expectEqualStrings("echo", request.tools[0].name); + try std.testing.expectEqualStrings("Return a message unchanged.", request.tools[0].description); + try std.testing.expect(request.tools[0].parameters_json != null); + try std.testing.expect(std.mem.indexOf(u8, request.tools[0].parameters_json.?, "\"message\"") != null); + try std.testing.expectEqualStrings("echo", request.tool_choice.?); +} diff --git a/src/backend/tool_output.zig b/src/backend/tool_output.zig index edab210..c41a976 100644 --- a/src/backend/tool_output.zig +++ b/src/backend/tool_output.zig @@ -26,7 +26,8 @@ pub fn maybeOffloadToolOutputAlloc( std.fs.cwd().makePath(offload_dir) catch {}; - const safe_tool_name = sanitizePathComponent(tool_name); + var safe_tool_name_buf: [64]u8 = undefined; + const safe_tool_name = sanitizePathComponent(tool_name, &safe_tool_name_buf); const offload_path = try std.fmt.allocPrint( allocator, 
"{s}/{s}_{s}.txt", @@ -58,14 +59,13 @@ pub fn maybeOffloadToolOutputAlloc( ); } -fn sanitizePathComponent(name: []const u8) []const u8 { - var out: [64]u8 = undefined; +fn sanitizePathComponent(name: []const u8, out: *[64]u8) []const u8 { var len: usize = 0; for (name) |c| { if (len >= out.len) break; switch (c) { 'a'...'z', 'A'...'Z', '0'...'9', '_', '-', '.' => { - out[len] = c; + out.*[len] = c; len += 1; }, else => {}, diff --git a/src/backend/tools.zig b/src/backend/tools.zig index 5080716..43d9387 100644 --- a/src/backend/tools.zig +++ b/src/backend/tools.zig @@ -106,6 +106,97 @@ pub fn tryExecuteDebugTool( return null; } +pub fn maybeMakeBuiltinToolCallResponseAlloc( + allocator: std.mem.Allocator, + request: types.Request, + request_id: []const u8, +) !?types.Response { + if (!hasRequestedTool(request.tools, "echo")) return null; + if (request.tool_choice) |choice| { + if (std.ascii.eqlIgnoreCase(choice, "none")) return null; + if (!std.ascii.eqlIgnoreCase(choice, "auto") and + !std.ascii.eqlIgnoreCase(choice, "required") and + !std.ascii.eqlIgnoreCase(choice, "echo")) + { + return null; + } + } + + if (std.ascii.indexOfIgnoreCase(request.prompt, "echo tool") == null) return null; + const message = extractExactMessageForEcho(request.prompt) orelse return null; + + const escaped_message = try escapeJsonStringAlloc(allocator, message); + defer allocator.free(escaped_message); + + const arguments_json = try std.fmt.allocPrint( + allocator, + "{{\"message\":\"{s}\"}}", + .{escaped_message}, + ); + errdefer allocator.free(arguments_json); + + const call_id = try std.fmt.allocPrint(allocator, "call_{s}_echo", .{request_id}); + errdefer allocator.free(call_id); + + const calls = try allocator.alloc(types.ToolCall, 1); + errdefer allocator.free(calls); + calls[0] = .{ + .id = call_id, + .name = try allocator.dupe(u8, "echo"), + .arguments_json = arguments_json, + }; + errdefer calls[0].deinit(allocator); + + return .{ + .id = null, + .model = try allocator.dupe(u8, 
"debug-tools/tool-call"), + .output = try allocator.dupe(u8, ""), + .finish_reason = try allocator.dupe(u8, "tool_calls"), + .success = true, + .usage = .{}, + .tool_calls = calls, + }; +} + +fn extractExactMessageForEcho(prompt: []const u8) ?[]const u8 { + const marker = "exact message "; + const start = std.ascii.indexOfIgnoreCase(prompt, marker) orelse return null; + const raw = prompt[start + marker.len ..]; + const trimmed = std.mem.trim(u8, raw, " \t\r\n\"'"); + if (trimmed.len == 0) return null; + + var end: usize = 0; + while (end < trimmed.len) : (end += 1) { + switch (trimmed[end]) { + '.', ',', ';', ':', '!', '?', '"', '\'', '\r', '\n', '\t', ' ' => break, + else => {}, + } + } + if (end == 0) return null; + return trimmed[0..end]; +} + +fn escapeJsonStringAlloc( + allocator: std.mem.Allocator, + input: []const u8, +) ![]u8 { + var out = std.ArrayList(u8){}; + defer out.deinit(allocator); + + for (input) |c| { + switch (c) { + '"' => try out.appendSlice(allocator, "\\\""), + '\\' => try out.appendSlice(allocator, "\\\\"), + '\n' => try out.appendSlice(allocator, "\\n"), + '\r' => try out.appendSlice(allocator, "\\r"), + '\t' => try out.appendSlice(allocator, "\\t"), + else => try out.append(allocator, c), + } + } + + return try out.toOwnedSlice(allocator); +} + pub fn maybeExecutePromptToolsAlloc( allocator: std.mem.Allocator, request: types.Request, @@ -400,6 +491,48 @@ test "tryExecuteDebugTool executes echo when explicitly selected" { try std.testing.expect(std.mem.indexOf(u8, result.output, "DEBUG_TOOL_OK") != null); } +test "maybeMakeBuiltinToolCallResponseAlloc synthesizes echo tool call" { + const allocator = std.testing.allocator; + const prompt = "Use the echo tool to return the exact message demo-green. 
Do not answer in normal text."; + + const messages = try allocator.alloc(types.Message, 1); + messages[0] = .{ + .role = try allocator.dupe(u8, "user"), + .content = try allocator.dupe(u8, prompt), + }; + + const tools = try allocator.alloc(types.Tool, 1); + tools[0] = .{ + .name = try allocator.dupe(u8, "echo"), + .description = try allocator.dupe(u8, "debug echo"), + }; + + const req = types.Request{ + .prompt = try allocator.dupe(u8, prompt), + .messages = messages, + .provider = null, + .model = null, + .session_id = null, + .tenant_id = null, + .max_context_tokens = null, + .tools = tools, + .tool_choice = null, + }; + defer req.deinit(allocator); + + const maybe_result = try maybeMakeBuiltinToolCallResponseAlloc(allocator, req, "test-request"); + try std.testing.expect(maybe_result != null); + + var result = maybe_result.?; + defer result.deinit(allocator); + + try std.testing.expectEqualStrings("tool_calls", result.finish_reason); + try std.testing.expect(result.tool_calls != null); + try std.testing.expectEqual(@as(usize, 1), result.tool_calls.?.len); + try std.testing.expectEqualStrings("echo", result.tool_calls.?[0].name); + try std.testing.expectEqualStrings("{\"message\":\"demo-green\"}", result.tool_calls.?[0].arguments_json); +} + test "tryExecuteDebugTool executes utc when explicitly selected" { const allocator = std.testing.allocator; diff --git a/src/core/response.zig b/src/core/response.zig index eb30f1e..eb4c215 100644 --- a/src/core/response.zig +++ b/src/core/response.zig @@ -182,16 +182,20 @@ pub fn sendChatCompletion( const escaped_content = try escapeJsonStringAlloc(allocator, result.output); defer allocator.free(escaped_content); + const tool_calls_json = try renderToolCallsSuffixJsonAlloc(allocator, result.tool_calls); + defer allocator.free(tool_calls_json); + const escaped_finish_reason = try escapeJsonStringAlloc(allocator, result.finish_reason); defer allocator.free(escaped_finish_reason); const response_json = try 
std.fmt.allocPrint(allocator, - \\{{"id":"{s}","object":"chat.completion","created":{d},"model":"{s}","choices":[{{"index":0,"message":{{"role":"assistant","content":"{s}"}},"finish_reason":"{s}"}}],"usage":{{"prompt_tokens":{d},"completion_tokens":{d},"total_tokens":{d}}}}} + \\{{"id":"{s}","object":"chat.completion","created":{d},"model":"{s}","choices":[{{"index":0,"message":{{"role":"assistant","content":"{s}"{s}}},"finish_reason":"{s}"}}],"usage":{{"prompt_tokens":{d},"completion_tokens":{d},"total_tokens":{d}}}}} , .{ escaped_id, std.time.timestamp(), escaped_model, escaped_content, + tool_calls_json, escaped_finish_reason, result.usage.prompt_tokens, result.usage.completion_tokens, @@ -202,6 +206,55 @@ pub fn sendChatCompletion( try sendJson(connection, 200, response_json, request_id); } +fn renderToolCallsSuffixJsonAlloc( + allocator: std.mem.Allocator, + tool_calls: ?[]const types.ToolCall, +) ![]u8 { + const calls = tool_calls orelse return try allocator.dupe(u8, ""); + if (calls.len == 0) return try allocator.dupe(u8, ""); + + var out = std.ArrayList(u8){}; + errdefer out.deinit(allocator); + + try out.appendSlice(allocator, ",\"tool_calls\":["); + for (calls, 0..) 
|call, index| { + if (index > 0) try out.append(allocator, ','); + + const escaped_id = try escapeJsonStringAlloc(allocator, call.id); + defer allocator.free(escaped_id); + const escaped_name = try escapeJsonStringAlloc(allocator, call.name); + defer allocator.free(escaped_name); + const escaped_arguments = try escapeJsonStringAlloc(allocator, call.arguments_json); + defer allocator.free(escaped_arguments); + + try out.writer(allocator).print( + "{{\"id\":\"{s}\",\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"arguments\":\"{s}\"}}}}", + .{ escaped_id, escaped_name, escaped_arguments }, + ); + } + try out.append(allocator, ']'); + return try out.toOwnedSlice(allocator); +} + +test "renderToolCallsSuffixJsonAlloc emits OpenAI tool calls" { + const allocator = std.testing.allocator; + const calls = [_]types.ToolCall{ + .{ + .id = try allocator.dupe(u8, "call_1"), + .name = try allocator.dupe(u8, "echo"), + .arguments_json = try allocator.dupe(u8, "{\"message\":\"demo-green\"}"), + }, + }; + defer calls[0].deinit(allocator); + + const rendered = try renderToolCallsSuffixJsonAlloc(allocator, calls[0..]); + defer allocator.free(rendered); + + try std.testing.expect(std.mem.indexOf(u8, rendered, "\"tool_calls\"") != null); + try std.testing.expect(std.mem.indexOf(u8, rendered, "\"name\":\"echo\"") != null); + try std.testing.expect(std.mem.indexOf(u8, rendered, "\\\"message\\\":\\\"demo-green\\\"") != null); +} + pub fn sendEventStreamHeaders(connection: std.net.Server.Connection, request_id: ?[]const u8) !void { const request_id_header = if (request_id) |id| try std.fmt.allocPrint( diff --git a/src/core/server.zig b/src/core/server.zig index 14e7a4a..0e899b1 100644 --- a/src/core/server.zig +++ b/src/core/server.zig @@ -958,6 +958,29 @@ fn handleConnection( return; } + if (try tooling.maybeMakeBuiltinToolCallResponseAlloc(allocator, parsed_req, request_id)) |tool_call_result| { + defer tool_call_result.deinit(allocator); + debugLog(app_config, "debug tool call 
synthesized", .{}); + persistConversationState( + allocator, + app_config, + use_workspace, + workspace_store, + active_workspace, + session_store, + session_store_guard, + parsed_req, + loaded_session, + request_messages, + null, + tool_call_result.output, + ); + sendChatCompletionSafe(connection.*, allocator, tool_call_result, app_config, request_id); + server_state.noteRequestSucceeded(); + request_status = 200; + return; + } + const requested_provider = parsed_req.provider orelse app_config.default_provider; const normalized_provider = types.normalizeProviderName(requested_provider) orelse requested_provider; @@ -1215,6 +1238,7 @@ fn cloneRequestWithMessagesAlloc( copied_tools[idx] = .{ .name = try allocator.dupe(u8, tool.name), .description = try allocator.dupe(u8, tool.description), + .parameters_json = if (tool.parameters_json) |parameters_json| try allocator.dupe(u8, parameters_json) else null, }; initialized_tools += 1; } diff --git a/src/providers/ollama_qwen.zig b/src/providers/ollama_qwen.zig index 7034b31..cb53392 100644 --- a/src/providers/ollama_qwen.zig +++ b/src/providers/ollama_qwen.zig @@ -354,22 +354,32 @@ pub fn callQwen( ), }; - const content_value = switch (message_object.get("content") orelse { + const tool_calls = try parseToolCallsAlloc(allocator, message_object); + errdefer if (tool_calls) |calls| { + for (calls) |call| call.deinit(allocator); + allocator.free(calls); + }; + + const content_value = if (message_object.get("content")) |content| + switch (content) { + .string => |value| value, + .null => "", + else => return try makeResponse( + allocator, + model_name, + "Invalid content field from Ollama message", + false, + ), + } + else if (tool_calls != null) + "" + else return try makeResponse( allocator, model_name, "Missing content field from Ollama message", false, ); - }) { - .string => |value| value, - else => return try makeResponse( - allocator, - model_name, - "Invalid content field from Ollama message", - false, - ), - }; const 
thinking_value = if (message_object.get("thinking")) |thinking| switch (thinking) { @@ -404,6 +414,75 @@ pub fn callQwen( .total_tokens = parseUsageField(root.get("prompt_eval_count")) + parseUsageField(root.get("eval_count")), }, + .tool_calls = tool_calls, + }; +} + +fn parseToolCallsAlloc( + allocator: std.mem.Allocator, + message_object: std.json.ObjectMap, +) !?[]types.ToolCall { + const tool_calls_value = message_object.get("tool_calls") orelse return null; + const tool_calls_array = switch (tool_calls_value) { + .array => |value| value, + else => return null, + }; + if (tool_calls_array.items.len == 0) return null; + + var calls = std.ArrayList(types.ToolCall){}; + errdefer { + for (calls.items) |call| call.deinit(allocator); + calls.deinit(allocator); + } + + for (tool_calls_array.items, 0..) |call_value, index| { + const call_object = switch (call_value) { + .object => |value| value, + else => continue, + }; + const function_object = switch (call_object.get("function") orelse continue) { + .object => |value| value, + else => continue, + }; + const name = switch (function_object.get("name") orelse continue) { + .string => |value| value, + else => continue, + }; + const arguments_json = try parseToolArgumentsJsonAlloc(allocator, function_object); + errdefer allocator.free(arguments_json); + + const id = if (call_object.get("id")) |id_value| + switch (id_value) { + .string => |value| try allocator.dupe(u8, value), + else => try std.fmt.allocPrint(allocator, "call_{d}", .{index}), + } + else + try std.fmt.allocPrint(allocator, "call_{d}", .{index}); + errdefer allocator.free(id); + + try calls.append(allocator, .{ + .id = id, + .name = try allocator.dupe(u8, name), + .arguments_json = arguments_json, + }); + } + + if (calls.items.len == 0) { + calls.deinit(allocator); + return null; + } + return try calls.toOwnedSlice(allocator); +} + +fn parseToolArgumentsJsonAlloc( + allocator: std.mem.Allocator, + function_object: std.json.ObjectMap, +) ![]u8 { + const 
arguments_value = function_object.get("arguments") orelse return try allocator.dupe(u8, "{}"); + return switch (arguments_value) { + .string => |value| try allocator.dupe(u8, value), + .object => try jsonValueStringifyAlloc(allocator, arguments_value), + else => try allocator.dupe(u8, "{}"), }; } @@ -988,14 +1067,30 @@ fn renderToolsJsonAlloc( defer allocator.free(escaped_description); try out.writer(allocator).print( - "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"}}}}", + "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"", .{ escaped_name, escaped_description }, ); + if (tool.parameters_json) |parameters_json| { + try out.writer(allocator).print(",\"parameters\":{s}", .{parameters_json}); + } + try out.appendSlice(allocator, "}}}"); } try out.append(allocator, ']'); return try out.toOwnedSlice(allocator); } +fn jsonValueStringifyAlloc(allocator: std.mem.Allocator, value: std.json.Value) ![]u8 { + var out = std.Io.Writer.Allocating.init(allocator); + errdefer out.deinit(); + + var json_writer = std.json.Stringify{ + .writer = &out.writer, + .options = .{}, + }; + try json_writer.write(value); + return try out.toOwnedSlice(); +} + fn resolveTuning(app_config: *const config.Config, request: types.Request) OllamaTuning { return .{ .think = request.think orelse app_config.ollama_think, diff --git a/src/providers/openai_compatible.zig b/src/providers/openai_compatible.zig index 099612d..ce29d41 100644 --- a/src/providers/openai_compatible.zig +++ b/src/providers/openai_compatible.zig @@ -196,13 +196,23 @@ fn parseChatResponse( else => return try makeFailureResponse(allocator, fallback_model, "Provider message was not an object"), }; - const output = switch (message_obj.get("content") orelse { - return try makeFailureResponse(allocator, fallback_model, "Provider message missing content"); - }) { - .string => |value| value, - else => return try makeFailureResponse(allocator, fallback_model, "Provider 
content was not a string"), + const tool_calls = try parseToolCallsAlloc(allocator, message_obj); + errdefer if (tool_calls) |calls| { + for (calls) |call| call.deinit(allocator); + allocator.free(calls); }; + const output = if (message_obj.get("content")) |content_value| + switch (content_value) { + .string => |value| value, + .null => "", + else => return try makeFailureResponse(allocator, fallback_model, "Provider content was not a string"), + } + else if (tool_calls != null) + "" + else + return try makeFailureResponse(allocator, fallback_model, "Provider message missing content"); + const finish_reason = if (first_choice.get("finish_reason")) |finish_reason_value| switch (finish_reason_value) { .string => |value| value, @@ -230,6 +240,75 @@ fn parseChatResponse( .finish_reason = try allocator.dupe(u8, finish_reason), .success = true, .usage = usage, + .tool_calls = tool_calls, + }; +} + +fn parseToolCallsAlloc( + allocator: std.mem.Allocator, + message_obj: std.json.ObjectMap, +) !?[]types.ToolCall { + const tool_calls_value = message_obj.get("tool_calls") orelse return null; + const tool_calls_array = switch (tool_calls_value) { + .array => |value| value, + else => return null, + }; + if (tool_calls_array.items.len == 0) return null; + + var calls = std.ArrayList(types.ToolCall){}; + errdefer { + for (calls.items) |call| call.deinit(allocator); + calls.deinit(allocator); + } + + for (tool_calls_array.items, 0..) 
|call_value, index| { + const call_obj = switch (call_value) { + .object => |value| value, + else => continue, + }; + const function_obj = switch (call_obj.get("function") orelse continue) { + .object => |value| value, + else => continue, + }; + const name = switch (function_obj.get("name") orelse continue) { + .string => |value| value, + else => continue, + }; + const arguments_json = try parseToolArgumentsJsonAlloc(allocator, function_obj); + errdefer allocator.free(arguments_json); + + const id = if (call_obj.get("id")) |id_value| + switch (id_value) { + .string => |value| try allocator.dupe(u8, value), + else => try std.fmt.allocPrint(allocator, "call_{d}", .{index}), + } + else + try std.fmt.allocPrint(allocator, "call_{d}", .{index}); + errdefer allocator.free(id); + + try calls.append(allocator, .{ + .id = id, + .name = try allocator.dupe(u8, name), + .arguments_json = arguments_json, + }); + } + + if (calls.items.len == 0) { + calls.deinit(allocator); + return null; + } + return try calls.toOwnedSlice(allocator); +} + +fn parseToolArgumentsJsonAlloc( + allocator: std.mem.Allocator, + function_obj: std.json.ObjectMap, +) ![]u8 { + const arguments_value = function_obj.get("arguments") orelse return try allocator.dupe(u8, "{}"); + return switch (arguments_value) { + .string => |value| try allocator.dupe(u8, value), + .object => try jsonValueStringifyAlloc(allocator, arguments_value), + else => try allocator.dupe(u8, "{}"), }; } @@ -342,15 +421,31 @@ fn renderToolsJsonAlloc( defer allocator.free(escaped_description); try out.writer(allocator).print( - "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"}}}}", + "{{\"type\":\"function\",\"function\":{{\"name\":\"{s}\",\"description\":\"{s}\"", .{ escaped_name, escaped_description }, ); + if (tool.parameters_json) |parameters_json| { + try out.writer(allocator).print(",\"parameters\":{s}", .{parameters_json}); + } + try out.appendSlice(allocator, "}}}"); } try out.append(allocator, 
']'); return try out.toOwnedSlice(allocator); } +fn jsonValueStringifyAlloc(allocator: std.mem.Allocator, value: std.json.Value) ![]u8 { + var out = std.Io.Writer.Allocating.init(allocator); + errdefer out.deinit(); + + var json_writer = std.json.Stringify{ + .writer = &out.writer, + .options = .{}, + }; + try json_writer.write(value); + return try out.toOwnedSlice(); +} + fn makeFailureResponse( allocator: std.mem.Allocator, model_name: []const u8, @@ -431,3 +526,28 @@ test "parseChatResponse keeps ids and usage for successful responses" { try std.testing.expectEqualStrings("hello", response.output); try std.testing.expectEqual(@as(usize, 5), response.usage.total_tokens); } + +test "parseChatResponse preserves OpenAI tool calls" { + const allocator = std.testing.allocator; + const raw = + \\{"id":"chatcmpl-tools","model":"gpt-test","choices":[{"message":{"content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"echo","arguments":"{\"message\":\"demo-green\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5}} + ; + + const response = try parseChatResponse( + allocator, + "gpt-test", + "OpenAI", + .ok, + raw, + ); + defer response.deinit(allocator); + + try std.testing.expect(response.success); + try std.testing.expectEqualStrings("", response.output); + try std.testing.expectEqualStrings("tool_calls", response.finish_reason); + try std.testing.expect(response.tool_calls != null); + try std.testing.expectEqual(@as(usize, 1), response.tool_calls.?.len); + try std.testing.expectEqualStrings("call_1", response.tool_calls.?[0].id); + try std.testing.expectEqualStrings("echo", response.tool_calls.?[0].name); + try std.testing.expectEqualStrings("{\"message\":\"demo-green\"}", response.tool_calls.?[0].arguments_json); +} diff --git a/src/types.zig b/src/types.zig index 3f3b9c8..e7b0c03 100644 --- a/src/types.zig +++ b/src/types.zig @@ -84,10 +84,24 @@ pub const Message = struct { pub const Tool = 
struct { name: []const u8, description: []const u8, + parameters_json: ?[]const u8 = null, pub fn deinit(self: Tool, allocator: std.mem.Allocator) void { allocator.free(self.name); allocator.free(self.description); + if (self.parameters_json) |parameters_json| allocator.free(parameters_json); + } +}; + +pub const ToolCall = struct { + id: []const u8, + name: []const u8, + arguments_json: []const u8, + + pub fn deinit(self: ToolCall, allocator: std.mem.Allocator) void { + allocator.free(self.id); + allocator.free(self.name); + allocator.free(self.arguments_json); } }; @@ -105,6 +119,7 @@ pub const Response = struct { finish_reason: []const u8, // OpenAI-compatible completion stop reason. success: bool, // false means output should be surfaced as provider error. usage: Usage = .{}, // Best-effort token accounting. + tool_calls: ?[]ToolCall = null, // OpenAI-compatible assistant tool calls, when provided. /// Releases all owned response fields. /// @@ -118,6 +133,12 @@ pub const Response = struct { allocator.free(self.model); allocator.free(self.output); allocator.free(self.finish_reason); + if (self.tool_calls) |tool_calls| { + for (tool_calls) |tool_call| { + tool_call.deinit(allocator); + } + allocator.free(tool_calls); + } } };