From a365aac31c45573630274915c8f697b4a6af0954 Mon Sep 17 00:00:00 2001 From: Tobias Garcia Date: Mon, 1 Jun 2026 12:22:11 +0900 Subject: [PATCH] improvement: base logic for image interface --- agent_core/__init__.py | 2 + agent_core/core/image_gen_interface.py | 12 + agent_core/core/impl/image_gen/__init__.py | 6 + agent_core/core/impl/image_gen/interface.py | 582 ++++++++++++++++++ agent_core/core/models/factory.py | 22 +- agent_core/core/models/model_registry.py | 11 + agent_core/core/models/types.py | 1 + app/agent_base.py | 53 ++ app/config.py | 15 + app/data/action/generate_image.py | 532 ++-------------- app/image_gen_interface.py | 55 ++ app/internal_action_interface.py | 22 + app/main.py | 13 +- app/ui_layer/adapters/browser_adapter.py | 17 + .../src/pages/Settings/ModelSettings.tsx | 115 ++++ .../src/store/selectors/modelSettings.ts | 2 + .../src/store/slices/modelSettingsSlice.ts | 29 + app/ui_layer/settings/model_settings.py | 29 + 18 files changed, 1025 insertions(+), 493 deletions(-) create mode 100644 agent_core/core/image_gen_interface.py create mode 100644 agent_core/core/impl/image_gen/__init__.py create mode 100644 agent_core/core/impl/image_gen/interface.py create mode 100644 app/image_gen_interface.py diff --git a/agent_core/__init__.py b/agent_core/__init__.py index 1d907f95..256dfd4b 100644 --- a/agent_core/__init__.py +++ b/agent_core/__init__.py @@ -31,6 +31,7 @@ ) from agent_core.core.embedding_interface import EmbeddingInterface from agent_core.core.vlm_interface import VLMInterface +from agent_core.core.image_gen_interface import ImageGenInterface from agent_core.core.database_interface import DatabaseInterface from agent_core.core.trigger import Trigger from agent_core.core.task import Task, TodoItem, TodoStatus @@ -272,6 +273,7 @@ # Interfaces "EmbeddingInterface", "VLMInterface", + "ImageGenInterface", "DatabaseInterface", "GeminiClient", "GeminiAPIError", diff --git a/agent_core/core/image_gen_interface.py b/agent_core/core/image_gen_interface.py new file mode 100644 index 00000000..ea9a4fff --- /dev/null +++ b/agent_core/core/image_gen_interface.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +""" +Image generation interface. + +Re-exports ImageGenInterface from the impl module for backward compatibility. +Use the runtime-specific wrapper (app/image_gen_interface.py) when running +inside CraftBot — it injects the appropriate state and usage hooks. +""" + +from agent_core.core.impl.image_gen import ImageGenInterface + +__all__ = ["ImageGenInterface"] diff --git a/agent_core/core/impl/image_gen/__init__.py b/agent_core/core/impl/image_gen/__init__.py new file mode 100644 index 00000000..0cd06833 --- /dev/null +++ b/agent_core/core/impl/image_gen/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +"""Image generation interface package.""" + +from agent_core.core.impl.image_gen.interface import ImageGenInterface + +__all__ = ["ImageGenInterface"] diff --git a/agent_core/core/impl/image_gen/interface.py b/agent_core/core/impl/image_gen/interface.py new file mode 100644 index 00000000..97149085 --- /dev/null +++ b/agent_core/core/impl/image_gen/interface.py @@ -0,0 +1,582 @@ +# -*- coding: utf-8 -*- +""" +Image generation interface for agent_core. + +Supports OpenAI (gpt-image-2) and Gemini (gemini-*-image-preview) providers. +Provider logic lives here; the action (generate_image.py) just delegates. + +Mirrors VLMInterface structure — constructor, reinitialize(), hooks, dispatch. +""" + +from __future__ import annotations + +import asyncio +import base64 +import io +import os +import tempfile +import urllib.request as _urllib_request +from datetime import datetime +from typing import Any, Dict, List, Optional + +from agent_core.core.hooks import ( + GetTokenCountHook, + ReportUsageHook, + SetTokenCountHook, +) +from agent_core.utils.logger import logger + +try: + from PIL import Image as _PilImage # type: ignore[import] +except ImportError: + _PilImage = None # type: ignore[assignment] + +# OpenAI supports only three canvas sizes. Document the closest-fit mapping so +# callers understand the constraint at a glance. +# +# True 16:9 (1920×1080) and 9:16 are not available. The values below use the +# widest/tallest canvases OpenAI exposes; warn at runtime when the mismatch +# matters (resolution ≥ 2K or aspect ratio 16:9/9:16). +_OPENAI_ASPECT_MAP: Dict[str, str] = { + "1:1": "1024x1024", + "3:4": "1024x1536", + "4:3": "1536x1024", + "9:16": "1024x1536", # nearest fit — true 9:16 not supported + "16:9": "1536x1024", # nearest fit — true 16:9 not supported +} +_OPENAI_INEXACT_RATIOS = {"16:9", "9:16"} + +_OPENAI_QUALITY_MAP: Dict[str, str] = { + "1K": "medium", + "2K": "high", + "4K": "high", # API tops out at 1536px; warn caller +} + +# ── Error message catalog (provider-keyed, English) ────────────────────────── +# Used to build human-readable RuntimeErrors that flow back through the +# action-selection loop, matching VLMInterface's raise-don't-return pattern. +_ERR: Dict[str, Dict[str, str]] = { + "openai": { + "quota": "OpenAI API rate limit or quota exceeded", + "invalid_key": "Invalid OpenAI API key — verify your key in settings.", + "content_policy": "Request blocked by OpenAI content policy — modify your prompt.", + "model_not_found": ( + "OpenAI model not available — ensure your account has access to gpt-image-2." + ), + "generic": "OpenAI image generation failed", + }, + "gemini": { + "quota": "Gemini API rate limit or quota exceeded", + "invalid_key": "Invalid Gemini API key — verify your Google API key in settings.", + "content_policy": "Request blocked by Gemini safety filters — modify your prompt.", + "model_not_found": ( + "Gemini model not available — ensure your account has access to the " + "image generation preview model." + ), + "generic": "Gemini image generation failed", + }, +} + + +def _classify_error(provider: str, exc: Exception) -> str: + """Map a raw exception message to a catalog entry for the given provider.""" + msg = str(exc).lower() + catalog = _ERR.get(provider, _ERR["openai"]) + if "quota" in msg or "rate" in msg or "billing" in msg or "insufficient_quota" in msg: + return catalog["quota"] + if "invalid" in msg and "key" in msg or "invalid_api_key" in msg: + return catalog["invalid_key"] + if "content_policy" in msg or "safety" in msg or "blocked" in msg: + return catalog["content_policy"] + if "not found" in msg or "404" in msg or "not available" in msg: + return catalog["model_not_found"] + # Do NOT include the raw exception — SDK error messages can contain API key fragments. + return catalog["generic"] + + +# ── File-path helpers ───────────────────────────────────────────────────────── + +def _build_save_path( + output_path: str, + timestamp: str, + index: int, + total: int, +) -> str: + if output_path: + if total > 1: + base, ext = os.path.splitext(output_path) + ext = ext or ".png" + return f"{base}_{index + 1}{ext}" + save_path = output_path + if not os.path.splitext(save_path)[1]: + save_path += ".png" + return save_path + return os.path.join( + tempfile.gettempdir(), f"generated_image_{timestamp}_{index + 1}.png" + ) + + +def _save_pil_image(pil_image: Any, save_path: str) -> None: + parent = os.path.dirname(os.path.abspath(save_path)) + if parent: + os.makedirs(parent, exist_ok=True) + pil_image.save(save_path, "PNG") + + +def _to_pil_image(img_data: Any) -> Any: + """Convert raw image data (bytes or base64 str) to a PIL Image.""" + if isinstance(img_data, str): + return _PilImage.open(io.BytesIO(base64.b64decode(img_data))) + if isinstance(img_data, bytes): + return _PilImage.open(io.BytesIO(img_data)) + return img_data + + +# ── Gemini response helpers ─────────────────────────────────────────────────── + +def _extract_gemini_images(response: Any) -> List[Any]: + images: List[Any] = [] + if hasattr(response, "candidates") and response.candidates: + for candidate in response.candidates: + if not ( + hasattr(candidate, "content") + and hasattr(candidate.content, "parts") + ): + continue + for part in candidate.content.parts: + if ( + hasattr(part, "inline_data") + and part.inline_data + and hasattr(part.inline_data, "mime_type") + and part.inline_data.mime_type.startswith("image/") + ): + images.append(part.inline_data.data) + if not images and hasattr(response, "images"): + for img in response.images: + if hasattr(img, "data"): + images.append(img.data) + return images + + +def _gemini_block_reason(response: Any) -> Optional[str]: + if hasattr(response, "prompt_feedback"): + fb = response.prompt_feedback + if hasattr(fb, "block_reason") and fb.block_reason: + return str(fb.block_reason) + if hasattr(response, "candidates") and response.candidates: + for c in response.candidates: + if hasattr(c, "finish_reason") and c.finish_reason: + reason = str(c.finish_reason) + if "SAFETY" in reason.upper(): + return reason + return None + + +# ── Main interface ──────────────────────────────────────────────────────────── + +class ImageGenInterface: + """Image generation interface with multi-provider support. + + Supports OpenAI (gpt-image-2) and Gemini image generation models. + Uses hooks for state access and usage reporting, mirroring VLMInterface. + + Args: + provider: Provider name ("openai", "gemini"). + model: Model name override (None = use registry default). + api_key: API key (required for openai/gemini). + base_url: Base URL override (unused by image providers currently). + deferred: If True, allow deferred initialization without raising. + get_token_count: Hook to read current token count from state. + set_token_count: Hook to write token count to state. + report_usage: Optional hook to report usage for cost tracking. + """ + + def __init__( + self, + *, + provider: Optional[str] = None, + model: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + deferred: bool = False, + get_token_count: Optional[GetTokenCountHook] = None, + set_token_count: Optional[SetTokenCountHook] = None, + report_usage: Optional[ReportUsageHook] = None, + ) -> None: + self.provider = provider + self._initialized = False + self._deferred = deferred + self._init_api_key = api_key + self._init_base_url = base_url + + self._get_token_count = get_token_count or (lambda: 0) + self._set_token_count = set_token_count or (lambda x: None) + self._report_usage = report_usage + + # Defer import to avoid circular dependency (same pattern as VLMInterface) + from app.models.factory import ModelFactory + from app.models.types import InterfaceType + + ctx = ModelFactory.create( + provider=provider, + interface=InterfaceType.IMAGE_GEN, + model_override=model, + api_key=api_key, + base_url=base_url, + deferred=deferred, + ) + + self.provider = ctx["provider"] + self.model = ctx["model"] + self.client = ctx["client"] # OpenAI client or None + self._gemini_client = ctx["gemini_client"] + self._initialized = ctx.get("initialized", False) + + @property + def is_initialized(self) -> bool: + return self._initialized + + def reinitialize( + self, + provider: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + ) -> bool: + """Reinitialize with new provider/key settings (e.g. after config change).""" + from app.models.factory import ModelFactory + from app.models.types import InterfaceType + + target_provider = provider or self.provider + + if api_key is None or base_url is None: + from app.config import get_api_key, get_base_url + + target_api_key = api_key if api_key is not None else get_api_key(target_provider) + target_base_url = base_url if base_url is not None else get_base_url(target_provider) + else: + target_api_key = api_key + target_base_url = base_url + + try: + from app.config import get_image_gen_model as _get_model # type: ignore[import] + + target_model = _get_model() + except Exception: + target_model = None + + try: + ctx = ModelFactory.create( + provider=target_provider, + interface=InterfaceType.IMAGE_GEN, + model_override=target_model, + api_key=target_api_key, + base_url=target_base_url, + deferred=False, + ) + self.provider = ctx["provider"] + self.model = ctx["model"] + self.client = ctx["client"] + self._gemini_client = ctx["gemini_client"] + self._initialized = ctx.get("initialized", False) + logger.info( + f"[IMAGE_GEN] Reinitialized: provider={self.provider}, model={self.model}" + ) + return self._initialized + except Exception as e: + logger.error(f"[IMAGE_GEN] Reinitialize failed: {e}", exc_info=True) + return False + + # ─────────────────────────── Public API ────────────────────────────────── + + def generate_image( + self, + prompt: str, + resolution: str = "1K", + aspect_ratio: str = "1:1", + number_of_images: int = 1, + output_path: str = "", + negative_prompt: str = "", + reference_images: Optional[List[str]] = None, + safety_filter_level: str = "block_medium_and_above", + ) -> List[str]: + """Generate image(s) from a text prompt. + + Args: + prompt: Text description of the image to generate. + resolution: "1K", "2K", or "4K". + aspect_ratio: "1:1", "3:4", "4:3", "9:16", or "16:9". + number_of_images: Number of images to generate (1–4). + output_path: Absolute path for the output file. Timestamped temp + path used when empty. For multiple images the index is appended + before the extension. + negative_prompt: Elements to avoid (Gemini native; appended to + prompt for OpenAI which has no dedicated parameter). + reference_images: List of absolute paths to reference images. + **Provider semantics differ**: Gemini treats these as style + guidance; OpenAI's images.edit treats them as compositional / + mask inputs. Warn users accordingly. + safety_filter_level: Gemini safety threshold + ("block_none", "block_only_high", "block_medium_and_above", + "block_low_and_above"). Ignored by OpenAI. + + Returns: + List of absolute file paths to the generated PNG files. + + Raises: + RuntimeError: If the provider is unsupported or generation fails. + """ + if _PilImage is None: + raise RuntimeError( + "Pillow is required for image generation. " + "Install with: pip install Pillow" + ) + + if not prompt: + raise ValueError("prompt is required") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + ref = reference_images or [] + + if self.provider == "openai": + paths = self._openai_generate( + prompt=prompt, + resolution=resolution, + aspect_ratio=aspect_ratio, + number_of_images=number_of_images, + output_path=output_path, + negative_prompt=negative_prompt, + reference_images=ref, + timestamp=timestamp, + ) + elif self.provider == "gemini": + paths = self._gemini_generate( + prompt=prompt, + resolution=resolution, + aspect_ratio=aspect_ratio, + number_of_images=number_of_images, + output_path=output_path, + negative_prompt=negative_prompt, + reference_images=ref, + safety_filter_level=safety_filter_level, + timestamp=timestamp, + ) + else: + raise RuntimeError( + f"Provider '{self.provider}' does not support image generation. " + "Set image_gen_provider to 'openai' or 'gemini' in settings.json." + ) + + logger.info( + f"[IMAGE_GEN] Generated {len(paths)} image(s) via {self.provider} " + f"(model={self.model})" + ) + return paths + + async def generate_image_async(self, **kwargs) -> List[str]: + """Async wrapper — runs generate_image in a thread pool.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, lambda: self.generate_image(**kwargs)) + + # ──────────────────────── OpenAI provider ──────────────────────────────── + + def _openai_generate( + self, + *, + prompt: str, + resolution: str, + aspect_ratio: str, + number_of_images: int, + output_path: str, + negative_prompt: str, + reference_images: List[str], + timestamp: str, + ) -> List[str]: + size = _OPENAI_ASPECT_MAP.get(aspect_ratio, "1024x1024") + if aspect_ratio in _OPENAI_INEXACT_RATIOS: + logger.warning( + f"[IMAGE_GEN] OpenAI does not support true {aspect_ratio}. " + f"Using closest available canvas {size} (~3:2 / ~2:3)." + ) + + quality = _OPENAI_QUALITY_MAP.get(resolution, "medium") + if resolution == "4K": + logger.warning( + "[IMAGE_GEN] OpenAI image generation tops out at 1536px; " + "'4K' is mapped to quality='high' (no true 4K output)." + ) + + full_prompt = prompt + if negative_prompt: + full_prompt += f"\n\nAvoid: {negative_prompt}" + + if reference_images: + logger.warning( + "[IMAGE_GEN] OpenAI reference_images uses images.edit(), which treats " + "inputs as compositional/mask data — not style guidance as in Gemini. " + "Output may differ significantly from the Gemini provider." + ) + + try: + valid_refs = [p for p in reference_images if os.path.isfile(p)] + if valid_refs: + image_files = [open(p, "rb") for p in valid_refs] + try: + response = self.client.images.edit( + model=self.model, + image=image_files, + prompt=full_prompt, + n=number_of_images, + size=size, + quality=quality, + ) + finally: + for f in image_files: + f.close() + else: + response = self.client.images.generate( + model=self.model, + prompt=full_prompt, + n=number_of_images, + size=size, + quality=quality, + ) + except Exception as exc: + raise RuntimeError(_classify_error("openai", exc)) from exc + + images_bytes: List[bytes] = [] + for item in response.data: + if item.b64_json: + images_bytes.append(base64.b64decode(item.b64_json)) + elif item.url: + with _urllib_request.urlopen(item.url) as r: + images_bytes.append(r.read()) + + if not images_bytes: + raise RuntimeError( + "OpenAI returned no image data — try rephrasing your prompt." + ) + + paths: List[str] = [] + for i, raw in enumerate(images_bytes[:number_of_images]): + save_path = _build_save_path(output_path, timestamp, i, len(images_bytes)) + pil_image = _PilImage.open(io.BytesIO(raw)) + _save_pil_image(pil_image, save_path) + paths.append(save_path) + + return paths + + # ──────────────────────── Gemini provider ──────────────────────────────── + + def _gemini_generate( + self, + *, + prompt: str, + resolution: str, + aspect_ratio: str, + number_of_images: int, + output_path: str, + negative_prompt: str, + reference_images: List[str], + safety_filter_level: str, + timestamp: str, + ) -> List[str]: + try: + from google import genai # type: ignore[import] + from google.genai import types as gtypes # type: ignore[import] + except ImportError as exc: + raise RuntimeError( + f"google-genai is required for Gemini image generation. " + f"Install with `pip install google-genai`: {exc}" + ) from exc + + if not self._gemini_client: + raise RuntimeError( + "Gemini API key is not configured. " + "Set 'image_gen_provider' to 'openai' or add a Google API key in settings." + ) + + # Reference images as Gemini content parts (style guidance) + image_parts: List[Any] = [] + for ref_path in reference_images: + if not os.path.isfile(ref_path): + continue + try: + with open(ref_path, "rb") as f: + data = f.read() + ext = os.path.splitext(ref_path)[1].lower() + mime = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + }.get(ext, "image/png") + image_parts.append(gtypes.Part.from_bytes(data=data, mime_type=mime)) + except Exception: + pass + + gen_prompt = f"Generate an image based on the following description:\n\n{prompt}" + gen_prompt += f"\n\nImage specifications:\n- Resolution: {resolution}\n- Aspect ratio: {aspect_ratio}\n- Number of variations: {number_of_images}" + if negative_prompt: + gen_prompt += f"\n- Avoid: {negative_prompt}" + + content_parts = image_parts + [gen_prompt] + + safety_settings = None + _threshold_map = { + "block_only_high": "BLOCK_ONLY_HIGH", + "block_medium_and_above": "BLOCK_MEDIUM_AND_ABOVE", + "block_low_and_above": "BLOCK_LOW_AND_ABOVE", + } + if safety_filter_level != "block_none": + threshold = _threshold_map.get(safety_filter_level, "BLOCK_MEDIUM_AND_ABOVE") + safety_settings = [ + gtypes.SafetySetting(category=cat, threshold=threshold) + for cat in ( + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + ) + ] + + config = gtypes.GenerateContentConfig( + candidate_count=1, + response_modalities=["TEXT", "IMAGE"], + image_config=gtypes.ImageConfig(image_size=resolution), + safety_settings=safety_settings, + ) + + try: + api_key = self._gemini_client._api_key + client = genai.Client(api_key=api_key) + response = client.models.generate_content( + model=self.model, + contents=content_parts, + config=config, + ) + except Exception as exc: + raise RuntimeError(_classify_error("gemini", exc)) from exc + + images_data = _extract_gemini_images(response) + + if not images_data: + block_reason = _gemini_block_reason(response) + if block_reason: + raise RuntimeError( + f"Gemini blocked the request (safety filter: {block_reason}). " + "Try modifying your prompt or adjusting safety_filter_level." + ) + raise RuntimeError( + "Gemini returned no image data — try rephrasing your prompt or " + "check that your API key has access to image generation." + ) + + paths: List[str] = [] + for i, img_data in enumerate(images_data[:number_of_images]): + save_path = _build_save_path(output_path, timestamp, i, len(images_data)) + pil_image = _to_pil_image(img_data) + _save_pil_image(pil_image, save_path) + paths.append(save_path) + + return paths diff --git a/agent_core/core/models/factory.py b/agent_core/core/models/factory.py index 531d60d9..a2476e18 100644 --- a/agent_core/core/models/factory.py +++ b/agent_core/core/models/factory.py @@ -127,7 +127,27 @@ def create( raise ValueError(f"Unsupported provider: {provider}") cfg = PROVIDER_CONFIG[provider] - model = model_override or MODEL_REGISTRY[provider][interface] + model = model_override or MODEL_REGISTRY[provider].get(interface) + if model is None: + if deferred: + return { + "provider": provider, + "model": None, + "client": None, + "gemini_client": None, + "remote_url": None, + "byteplus": None, + "anthropic_client": None, + "bedrock_client": None, + "initialized": False, + } + supported = ", ".join( + p for p, caps in MODEL_REGISTRY.items() if caps.get(interface) + ) + raise ValueError( + f"Provider '{provider}' does not support {interface.value}. " + f"Supported providers: {supported}" + ) # Use provided base_url or fall back to default resolved_base_url = base_url or cfg.default_base_url diff --git a/agent_core/core/models/model_registry.py b/agent_core/core/models/model_registry.py index 52bb37e9..a953cc9b 100644 --- a/agent_core/core/models/model_registry.py +++ b/agent_core/core/models/model_registry.py @@ -8,46 +8,55 @@ InterfaceType.LLM: "gpt-5.2-2025-12-11", InterfaceType.VLM: "gpt-5.2-2025-12-11", InterfaceType.EMBEDDING: "text-embedding-3-small", + InterfaceType.IMAGE_GEN: "gpt-image-2", }, "gemini": { InterfaceType.LLM: "gemini-2.5-pro", InterfaceType.VLM: "gemini-2.5-pro", InterfaceType.EMBEDDING: "text-embedding-004", + InterfaceType.IMAGE_GEN: "gemini-3-pro-image", }, "anthropic": { InterfaceType.LLM: "claude-sonnet-4-5-20250929", InterfaceType.VLM: "claude-sonnet-4-5-20250929", InterfaceType.EMBEDDING: None, # Anthropic does not provide native embedding models + InterfaceType.IMAGE_GEN: None, }, "byteplus": { InterfaceType.LLM: "seed-2-0-pro-260328", InterfaceType.VLM: "seed-2-0-pro-260328", InterfaceType.EMBEDDING: "skylark-embedding-vision-250615", + InterfaceType.IMAGE_GEN: None, }, "remote": { InterfaceType.LLM: "llama3.2:3b", InterfaceType.VLM: "llava:7b", InterfaceType.EMBEDDING: "nomic-embed-text", + InterfaceType.IMAGE_GEN: None, }, "minimax": { InterfaceType.LLM: "MiniMax-Text-01", InterfaceType.VLM: "MiniMax-VL-01", InterfaceType.EMBEDDING: None, + InterfaceType.IMAGE_GEN: None, }, "deepseek": { InterfaceType.LLM: "deepseek-chat", InterfaceType.VLM: None, InterfaceType.EMBEDDING: None, + InterfaceType.IMAGE_GEN: None, }, "moonshot": { InterfaceType.LLM: "kimi-k2.5", InterfaceType.VLM: "moonshot-v1-8k-vision-preview", InterfaceType.EMBEDDING: None, + InterfaceType.IMAGE_GEN: None, }, "grok": { InterfaceType.LLM: "grok-3", InterfaceType.VLM: "grok-4-0709", InterfaceType.EMBEDDING: None, + InterfaceType.IMAGE_GEN: None, }, "openrouter": { # OpenRouter slugs follow `/` format. Default to a Claude @@ -55,6 +64,7 @@ InterfaceType.LLM: "anthropic/claude-sonnet-4.5", InterfaceType.VLM: "anthropic/claude-sonnet-4.5", InterfaceType.EMBEDDING: None, + InterfaceType.IMAGE_GEN: None, }, "bedrock": { # Default to Claude Haiku 4.5 — best price/performance on Bedrock with @@ -71,5 +81,6 @@ InterfaceType.LLM: "us.anthropic.claude-haiku-4-5-20251001-v1:0", InterfaceType.VLM: "us.anthropic.claude-haiku-4-5-20251001-v1:0", InterfaceType.EMBEDDING: "amazon.titan-embed-text-v2:0", + InterfaceType.IMAGE_GEN: None, }, } diff --git a/agent_core/core/models/types.py b/agent_core/core/models/types.py index 9f282e33..3b23f749 100644 --- a/agent_core/core/models/types.py +++ b/agent_core/core/models/types.py @@ -21,3 +21,4 @@ class InterfaceType(str, Enum): LLM = "llm" VLM = "vlm" EMBEDDING = "embedding" + IMAGE_GEN = "image_gen" diff --git a/app/agent_base.py b/app/agent_base.py index 4272c97b..58ba0542 100644 --- a/app/agent_base.py +++ b/app/agent_base.py @@ -70,6 +70,7 @@ LLMConsecutiveFailureError, ) from app.vlm_interface import VLMInterface +from app.image_gen_interface import ImageGenInterface from app.database_interface import DatabaseInterface from app.logger import logger from agent_core import ( @@ -163,6 +164,8 @@ def __init__( llm_model: str | None = None, vlm_provider: str | None = None, vlm_model: str | None = None, + image_gen_provider: str | None = None, + image_gen_model: str | None = None, deferred_init: bool = False, ) -> None: """ @@ -179,6 +182,8 @@ def __init__( llm_model: Model name override (None = use registry default). vlm_provider: Provider name for VLM (defaults to llm_provider if None). vlm_model: VLM model name override (None = use registry default). + image_gen_provider: Provider name for image generation (openai or gemini). + image_gen_model: Image gen model override (None = use registry default). deferred_init: If True, allow LLM/VLM initialization to be deferred until API key is configured (useful for first-time setup). """ @@ -212,6 +217,17 @@ def __init__( deferred=deferred_init, ) + # Image generation uses its own provider/model settings + from app.config import get_image_gen_provider as _get_img_prov + _img_provider = image_gen_provider or _get_img_prov() + _img_api_key = get_api_key(_img_provider) + self.image_gen = ImageGenInterface( + provider=_img_provider, + model=image_gen_model, + api_key=_img_api_key, + deferred=True, # always deferred — many users won't have an image-gen key + ) + self.event_stream_manager = EventStreamManager( self.llm, agent_file_system_path=AGENT_FILE_SYSTEM_PATH, @@ -308,6 +324,7 @@ def __init__( self.task_manager, self.state_manager, vlm_interface=self.vlm, + image_gen_interface=self.image_gen, memory_manager=self.memory_manager, context_engine=self.context_engine, ) @@ -2998,6 +3015,42 @@ def reinitialize_llm(self, provider: str | None = None) -> bool: ) return llm_ok and vlm_ok + def reinitialize_image_gen(self, provider: str | None = None) -> bool: + """Reinitialize the image generation interface with updated configuration. + + Creates a fresh ImageGenInterface instance rather than mutating the + existing one, so any in-flight action that holds a reference to the + old instance completes cleanly against the old provider/client. + + Args: + provider: Optional provider to switch to. If None, reads from settings. + + Returns: + True if reinitialization was successful. + """ + from app.config import get_image_gen_provider, get_api_key, get_image_gen_model + from app.image_gen_interface import ImageGenInterface + from app.internal_action_interface import InternalActionInterface + + target_provider = provider or get_image_gen_provider() + api_key = get_api_key(target_provider) + model = get_image_gen_model() + + new_interface = ImageGenInterface( + provider=target_provider, + model=model, + api_key=api_key, + deferred=False, + ) + ok = new_interface.is_initialized + if ok: + self.image_gen = new_interface + InternalActionInterface.image_gen_interface = new_interface + logger.info( + f"[AGENT] Image gen reinitialized: provider={target_provider}, success={ok}" + ) + return ok + @property def is_llm_initialized(self) -> bool: """Check if the LLM interface is properly initialized.""" diff --git a/app/config.py b/app/config.py index 3128bce6..77caad45 100644 --- a/app/config.py +++ b/app/config.py @@ -91,8 +91,10 @@ def _get_default_settings() -> Dict[str, Any]: "model": { "llm_provider": "anthropic", "vlm_provider": "anthropic", + "image_gen_provider": "openai", "llm_model": None, "vlm_model": None, + "image_gen_model": None, "slow_mode": False, "slow_mode_tpm_limit": 30000, }, @@ -215,6 +217,19 @@ def get_vlm_model() -> Optional[str]: return settings.get("model", {}).get("vlm_model") +def get_image_gen_provider() -> str: + """Get configured image generation provider.""" + settings = get_settings() + model = settings.get("model", {}) + return model.get("image_gen_provider") or model.get("vlm_provider", "openai") + + +def get_image_gen_model() -> Optional[str]: + """Get configured image generation model override (or None for default).""" + settings = get_settings() + return settings.get("model", {}).get("image_gen_model") + + def get_api_key(provider: str) -> str: """Get API key for a provider. diff --git a/app/data/action/generate_image.py b/app/data/action/generate_image.py index c51e7d6e..97ef21ac 100644 --- a/app/data/action/generate_image.py +++ b/app/data/action/generate_image.py @@ -3,13 +3,11 @@ @action( name="generate_image", - description="""Generates an image using either OpenAI's Images 2.0 (gpt-image-2) or Google's Nano Banana 2 (gemini-3.1-flash-image-preview) model. -- Automatically selects the provider based on which API key(s) are configured -- If only one API key is set, that provider is used automatically -- If both keys are configured, asks the user which provider to use and remembers the choice -- If no API keys are configured, returns an error with setup instructions -- Supports 1K, 2K, or 4K resolution and multiple aspect ratios -- TIP: When generating multiple images for the same project or related work, use 'reference_images' parameter with previously generated images to maintain consistent style across all outputs""", + description="""Generates an image from a text prompt using the configured image generation provider (OpenAI gpt-image-2 or Google Gemini). +- The provider is determined by 'image_gen_provider' in settings.json (default: openai). Supported: openai, gemini. +- Saves the result as a PNG file to output_path, or a timestamped temp file if omitted. +- TIP: When generating multiple related images, pass previously generated paths via 'reference_images' to maintain style consistency. +- NOTE: reference_images semantics differ by provider — Gemini uses them as style guidance; OpenAI's images.edit treats them as compositional/mask inputs.""", default=True, mode="CLI", action_sets=["content_creation", "image", "document_processing"], @@ -17,51 +15,43 @@ "prompt": { "type": "string", "example": "A serene mountain landscape at sunset with a lake reflection", - "description": "The text prompt describing the image to generate.", + "description": "Text description of the image to generate.", "required": True, }, "output_path": { "type": "string", "example": "C:/Users/user/Pictures/generated_image.png", - "description": "Absolute path where the generated image will be saved (e.g., C:/Users/user/image.png or /home/user/image.png). If not provided, saves to temp directory.", + "description": "Absolute path where the generated image will be saved. If omitted, saved to temp directory with a timestamped name.", }, "resolution": { "type": "string", "example": "2K", - "description": "Output resolution. Options: '1K' (1080p), '2K', '4K'. Default: '1K'. Higher resolution costs more.", + "description": "Output resolution: '1K' (default), '2K', or '4K'. Higher resolution costs more. OpenAI tops out at ~1536px regardless of this setting.", }, "aspect_ratio": { "type": "string", "example": "16:9", - "description": "Aspect ratio of the generated image. Options: '1:1', '3:4', '4:3', '9:16', '16:9'. Default: '1:1'.", + "description": "Aspect ratio: '1:1' (default), '3:4', '4:3', '9:16', '16:9'. OpenAI maps to the nearest available canvas size — true 16:9 and 9:16 are not supported natively.", }, "number_of_images": { "type": "integer", "example": 1, - "description": "Number of images to generate (1-4). Default: 1.", + "description": "Number of images to generate (1–4). Default: 1.", }, "negative_prompt": { "type": "string", "example": "blurry, low quality, distorted", - "description": "Text describing what to avoid in the generated image.", + "description": "Elements to avoid. Native on Gemini; appended to prompt for OpenAI.", }, "reference_images": { "type": "array", - "example": [ - "C:/Users/user/Pictures/reference1.png", - "C:/Users/user/Pictures/reference2.png", - ], - "description": "Optional list of reference image absolute paths to guide generation (up to 14 images). Use full absolute paths.", + "example": ["C:/Users/user/Pictures/ref.png"], + "description": "Optional reference image paths (up to 14). Gemini: style guidance. OpenAI: compositional/mask inputs (different behaviour — results may vary).", }, "safety_filter_level": { "type": "string", "example": "block_medium_and_above", - "description": "Safety filter level (Gemini only). Options: 'block_none', 'block_only_high', 'block_medium_and_above', 'block_low_and_above'. Default: 'block_medium_and_above'. Ignored when using OpenAI.", - }, - "provider_preference": { - "type": "string", - "example": "openai", - "description": "Which provider to use: 'openai' (Images 2.0 / gpt-image-2) or 'gemini' (Nano Banana 2 / gemini-3.1-flash-image-preview). Only needed when both API keys are configured and no saved preference exists. Providing this saves it as the default for future calls.", + "description": "Gemini safety filter: 'block_none', 'block_only_high', 'block_medium_and_above' (default), 'block_low_and_above'. Ignored by OpenAI.", }, }, output_schema={ @@ -72,19 +62,11 @@ }, "image_paths": { "type": "array", - "description": "List of paths to the generated image files.", - }, - "prompt_used": { - "type": "string", - "description": "The prompt that was used for generation.", - }, - "resolution": { - "type": "string", - "description": "The resolution of the generated image.", + "description": "Absolute paths to the generated PNG files.", }, "message": { "type": "string", - "description": "Status message or error message.", + "description": "Status message or error details.", }, }, requirement=["google-genai", "openai", "Pillow"], @@ -97,491 +79,63 @@ }, ) def generate_image(input_data: dict) -> dict: - """ - Generates an image using OpenAI's Images 2.0 (gpt-image-2) or Google's Nano Banana 2 (gemini-3.1-flash-image-preview). - """ - import os - import sys - import subprocess - import importlib - import tempfile - from datetime import datetime - simulated_mode = input_data.get("simulated_mode", False) - if simulated_mode: return { "status": "success", "image_paths": ["/tmp/simulated_image_001.png"], - "prompt_used": input_data.get("prompt", "Simulated prompt"), - "resolution": input_data.get("resolution", "1K"), "message": "Image generated successfully (simulated mode).", } - # Determine which provider to use based on available API keys and user preference - from app.config import get_api_key, get_settings, save_settings + import app.internal_action_interface as iai + from agent_core.core.models.model_registry import MODEL_REGISTRY + from agent_core.core.models.types import InterfaceType + from app.config import get_image_gen_provider - openai_key = get_api_key("openai") - gemini_key = get_api_key("gemini") + image_gen = iai.InternalActionInterface.image_gen_interface + current_provider = get_image_gen_provider() + registry_entry = MODEL_REGISTRY.get(current_provider, {}).get(InterfaceType.IMAGE_GEN) - if not openai_key and not gemini_key: + if image_gen is None or not registry_entry: return { "status": "error", "image_paths": [], - "prompt_used": "", - "resolution": "", "message": ( - "No image generation API key is configured. " - "Tell the user they need either an OpenAI API key (for Images 2.0 / gpt-image-2) " - "or a Google Gemini API key (for Nano Banana 2 / gemini-3.1-flash-image-preview), " - "and ask if they need help setting one up." + f"Provider '{current_provider}' does not support image generation. " + "Set 'image_gen_provider' to 'openai' or 'gemini' in settings.json " + "and ensure the corresponding API key is configured under 'api_keys'." ), } - provider_preference = input_data.get("provider_preference", "").strip().lower() - _cfg = get_settings() - saved_provider = _cfg.get("image_generation", {}).get("preferred_provider", "") - - if openai_key and not gemini_key: - provider = "openai" - elif gemini_key and not openai_key: - provider = "gemini" - else: - # Both keys present - if provider_preference in ("openai", "gemini"): - provider = provider_preference - _cfg.setdefault("image_generation", {})["preferred_provider"] = provider - save_settings(_cfg) - elif saved_provider in ("openai", "gemini"): - provider = saved_provider - else: - provider = "gemini" - - api_key = openai_key if provider == "openai" else gemini_key - - # Validate required input - prompt = input_data.get("prompt", "").strip() + prompt = str(input_data.get("prompt", "")).strip() if not prompt: return { "status": "error", "image_paths": [], - "prompt_used": "", - "resolution": "", - "message": "A prompt is required to generate an image.", - } - - # Get optional parameters - output_path = input_data.get("output_path", "") - resolution = input_data.get("resolution", "1K").upper() - aspect_ratio = input_data.get("aspect_ratio", "1:1") - number_of_images = min(max(int(input_data.get("number_of_images", 1)), 1), 4) - negative_prompt = input_data.get("negative_prompt", "") - reference_images = input_data.get("reference_images", []) - safety_filter_level = input_data.get( - "safety_filter_level", "block_medium_and_above" - ) - - # Validate resolution with user feedback - valid_resolutions = ["1K", "2K", "4K"] - warnings = [] - if resolution not in valid_resolutions: - warnings.append( - f"Invalid resolution '{resolution}'. Defaulting to '1K'. Valid options: {', '.join(valid_resolutions)}." - ) - resolution = "1K" - - # Validate aspect ratio with user feedback - valid_ratios = ["1:1", "3:4", "4:3", "9:16", "16:9"] - if aspect_ratio not in valid_ratios: - warnings.append( - f"Invalid aspect ratio '{aspect_ratio}'. Defaulting to '1:1'. Valid options: {', '.join(valid_ratios)}." - ) - aspect_ratio = "1:1" - - # Validate safety filter level with user feedback - valid_safety_levels = [ - "block_none", - "block_only_high", - "block_medium_and_above", - "block_low_and_above", - ] - if safety_filter_level not in valid_safety_levels: - warnings.append( - f"Invalid safety filter level '{safety_filter_level}'. Defaulting to 'block_medium_and_above'. Valid options: {', '.join(valid_safety_levels)}." - ) - safety_filter_level = "block_medium_and_above" - - # Validate number_of_images with user feedback - raw_num = int(input_data.get("number_of_images", 1)) - if raw_num < 1 or raw_num > 4: - warnings.append( - f"number_of_images '{raw_num}' out of range. Clamped to {number_of_images}. Valid range: 1-4." - ) - - # Limit reference images to 14 - if len(reference_images) > 14: - warnings.append( - f"Too many reference images ({len(reference_images)}). Only the first 14 will be used." - ) - reference_images = reference_images[:14] - - # Helper: extract images from Gemini response - def _extract_images_from_response(response): - images = [] - # Primary path: candidates[].content.parts[].inline_data - if hasattr(response, "candidates") and response.candidates: - for candidate in response.candidates: - if not ( - hasattr(candidate, "content") - and hasattr(candidate.content, "parts") - ): - continue - for part in candidate.content.parts: - if hasattr(part, "inline_data") and part.inline_data: - if hasattr( - part.inline_data, "mime_type" - ) and part.inline_data.mime_type.startswith("image/"): - images.append(part.inline_data.data) - # Fallback: response.images (older SDK versions) - if not images and hasattr(response, "images"): - for img in response.images: - if hasattr(img, "data"): - images.append(img.data) - elif hasattr(img, "_pil_image"): - images.append(img) - return images - - # Helper: check if response was blocked by safety filters - def _get_block_reason(response): - if hasattr(response, "prompt_feedback"): - feedback = response.prompt_feedback - if hasattr(feedback, "block_reason") and feedback.block_reason: - return str(feedback.block_reason) - if hasattr(response, "candidates") and response.candidates: - for candidate in response.candidates: - if hasattr(candidate, "finish_reason") and candidate.finish_reason: - reason = str(candidate.finish_reason) - if "SAFETY" in reason.upper(): - return reason - return None - - # Helper: build the save path for a generated image - def _build_save_path(output_path, timestamp, index, number_of_images, total_found): - if output_path: - if number_of_images > 1 or total_found > 1: - base, ext = os.path.splitext(output_path) - if not ext: - ext = ".png" - return f"{base}_{index + 1}{ext}" - else: - save_path = output_path - if not os.path.splitext(save_path)[1]: - save_path += ".png" - return save_path - else: - temp_dir = tempfile.gettempdir() - return os.path.join( - temp_dir, f"generated_image_{timestamp}_{index + 1}.png" - ) - - # Helper: convert image data to PIL Image - def _to_pil_image(img_data, Image, io, base64): - if isinstance(img_data, str): - image_bytes = base64.b64decode(img_data) - return Image.open(io.BytesIO(image_bytes)) - elif isinstance(img_data, bytes): - return Image.open(io.BytesIO(img_data)) - elif hasattr(img_data, "_pil_image"): - return img_data._pil_image - else: - return img_data - - # Ensure required packages are installed - def _ensure_package(pkg_name): - try: - importlib.import_module(pkg_name.replace("-", "_").split("[")[0]) - except ImportError: - subprocess.check_call( - [sys.executable, "-m", "pip", "install", pkg_name, "--quiet"] - ) - - try: - _ensure_package("google-genai") - _ensure_package("openai") - _ensure_package("Pillow") - except Exception as e: - return { - "status": "error", - "image_paths": [], - "prompt_used": prompt, - "resolution": resolution, - "message": f"Failed to install required packages: {str(e)}", + "message": "prompt is required.", } try: - from PIL import Image - import io - import base64 - - image_paths = [] - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - if provider == "gemini": - from google import genai - from google.genai import types - - client = genai.Client(api_key=api_key) - - # Prepare reference images if provided - image_parts = [] - for ref_path in reference_images: - if os.path.exists(ref_path): - try: - with open(ref_path, "rb") as f: - image_data = f.read() - ext = os.path.splitext(ref_path)[1].lower() - mime_map = { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".webp": "image/webp", - } - mime_type = mime_map.get(ext, "image/png") - image_parts.append( - types.Part.from_bytes(data=image_data, mime_type=mime_type) - ) - except Exception: - pass # Skip invalid reference images - - # Build the prompt with generation instructions - generation_prompt = f"""Generate an image based on the following description: - -{prompt} - -Image specifications: -- Resolution: {resolution} -- Aspect ratio: {aspect_ratio} -- Number of variations: {number_of_images}""" - - if negative_prompt: - generation_prompt += f"\n- Avoid: {negative_prompt}" - - content_parts = list(image_parts) - content_parts.append(generation_prompt) - - # Safety settings - safety_settings = None - if safety_filter_level != "block_none": - harm_block_threshold = { - "block_only_high": "BLOCK_ONLY_HIGH", - "block_medium_and_above": "BLOCK_MEDIUM_AND_ABOVE", - "block_low_and_above": "BLOCK_LOW_AND_ABOVE", - }.get(safety_filter_level, "BLOCK_MEDIUM_AND_ABOVE") - - safety_settings = [ - types.SafetySetting( - category=category, threshold=harm_block_threshold - ) - for category in ( - "HARM_CATEGORY_HARASSMENT", - "HARM_CATEGORY_HATE_SPEECH", - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "HARM_CATEGORY_DANGEROUS_CONTENT", - ) - ] - - generate_config = types.GenerateContentConfig( - candidate_count=1, - response_modalities=["TEXT", "IMAGE"], - image_config=types.ImageConfig(image_size=resolution), - safety_settings=safety_settings, - ) - - response = client.models.generate_content( - model="gemini-3.1-flash-image-preview", - contents=content_parts, - config=generate_config, - ) - - images_found = _extract_images_from_response(response) - - if not images_found: - block_reason = _get_block_reason(response) - if block_reason: - return { - "status": "error", - "image_paths": [], - "prompt_used": prompt, - "resolution": resolution, - "message": f"Image generation was blocked by safety filters: {block_reason}. Try modifying your prompt or adjusting safety_filter_level.", - } - return { - "status": "error", - "image_paths": [], - "prompt_used": prompt, - "resolution": resolution, - "message": "No images were generated. The model did not produce image output for this prompt. Try rephrasing your prompt or check if your API key has access to image generation.", - } - - for i, img_data in enumerate(images_found[:number_of_images]): - save_path = _build_save_path( - output_path, timestamp, i, number_of_images, len(images_found) - ) - parent_dir = os.path.dirname(os.path.abspath(save_path)) - if parent_dir: - os.makedirs(parent_dir, exist_ok=True) - pil_image = _to_pil_image(img_data, Image, io, base64) - pil_image.save(save_path, "PNG") - image_paths.append(save_path) - - model_label = "Nano Banana 2" - - elif provider == "openai": - from openai import OpenAI - - if safety_filter_level != "block_medium_and_above": - warnings.append( - "safety_filter_level is not supported by OpenAI and has been ignored." - ) - - # Map aspect_ratio to OpenAI size string - openai_size_map = { - "1:1": "1024x1024", - "16:9": "1536x1024", - "4:3": "1536x1024", - "9:16": "1024x1536", - "3:4": "1024x1536", - } - openai_size = openai_size_map.get(aspect_ratio, "1024x1024") - - # Map resolution to OpenAI quality - openai_quality_map = {"1K": "medium", "2K": "high", "4K": "high"} - openai_quality = openai_quality_map.get(resolution, "medium") - - # Build prompt (OpenAI has no negative_prompt param — append to prompt) - full_prompt = prompt - if negative_prompt: - full_prompt += f"\n\nAvoid: {negative_prompt}" - - client = OpenAI(api_key=api_key) - - valid_ref_paths = [p for p in reference_images if os.path.exists(p)] - if valid_ref_paths: - image_files = [open(p, "rb") for p in valid_ref_paths] - try: - response = client.images.edit( - model="gpt-image-2", - image=image_files, - prompt=full_prompt, - n=number_of_images, - size=openai_size, - ) - finally: - for f in image_files: - f.close() - else: - response = client.images.generate( - model="gpt-image-2", - prompt=full_prompt, - n=number_of_images, - size=openai_size, - quality=openai_quality, - ) - - import urllib.request as _urllib_request - - images_found = [] - for item in response.data: - if item.b64_json: - images_found.append(base64.b64decode(item.b64_json)) - elif item.url: - with _urllib_request.urlopen(item.url) as _r: - images_found.append(_r.read()) - - if not images_found: - return { - "status": "error", - "image_paths": [], - "prompt_used": prompt, - "resolution": resolution, - "message": "No images were generated. The model did not produce image output for this prompt. Try rephrasing your prompt.", - } - - for i, img_bytes in enumerate(images_found[:number_of_images]): - save_path = _build_save_path( - output_path, timestamp, i, number_of_images, len(images_found) - ) - parent_dir = os.path.dirname(os.path.abspath(save_path)) - if parent_dir: - os.makedirs(parent_dir, exist_ok=True) - pil_image = Image.open(io.BytesIO(img_bytes)) - pil_image.save(save_path, "PNG") - image_paths.append(save_path) - - model_label = "Images 2.0" - - message = ( - f"Successfully generated {len(image_paths)} image(s) using {model_label}." + paths = iai.InternalActionInterface.generate_image( + prompt=prompt, + resolution=str(input_data.get("resolution", "1K")).upper(), + aspect_ratio=str(input_data.get("aspect_ratio", "1:1")), + number_of_images=min(max(int(input_data.get("number_of_images", 1)), 1), 4), + output_path=str(input_data.get("output_path") or ""), + negative_prompt=str(input_data.get("negative_prompt") or ""), + reference_images=list(input_data.get("reference_images") or []), + safety_filter_level=str( + input_data.get("safety_filter_level") or "block_medium_and_above" + ), ) - if warnings: - message += " Warnings: " + " ".join(warnings) - return { "status": "success", - "image_paths": image_paths, - "prompt_used": prompt, - "resolution": resolution, - "message": message, + "image_paths": paths, + "message": f"Generated {len(paths)} image(s) via {current_provider}.", } - except Exception as e: - error_message = str(e) - - if provider == "gemini": - if "quota" in error_message.lower() or "rate" in error_message.lower(): - error_message = ( - f"Gemini API rate limit or quota exceeded: {error_message}" - ) - elif "invalid" in error_message.lower() and "key" in error_message.lower(): - error_message = f"Invalid Gemini API key: {error_message}. Please verify your Google API key is correct." - elif ( - "permission" in error_message.lower() - or "access" in error_message.lower() - ): - error_message = f"Gemini API access denied: {error_message}. Ensure your API key has access to the Nano Banana 2 model." - elif ( - "safety" in error_message.lower() or "blocked" in error_message.lower() - ): - error_message = f"Content blocked by Gemini safety filters: {error_message}. Try modifying your prompt." - elif "not found" in error_message.lower() or "404" in error_message: - error_message = f"Gemini model not available: {error_message}. The gemini-3.1-flash-image-preview model may not be accessible with your API key. Try using Google AI Studio to verify access." - else: - if ( - "billing" in error_message.lower() - or "insufficient_quota" in error_message.lower() - or "rate" in error_message.lower() - ): - error_message = ( - f"OpenAI API rate limit or quota exceeded: {error_message}" - ) - elif "invalid_api_key" in error_message.lower() or ( - "invalid" in error_message.lower() and "key" in error_message.lower() - ): - error_message = f"Invalid OpenAI API key: {error_message}. Please verify your OpenAI API key is correct." - elif ( - "content_policy" in error_message.lower() - or "safety" in error_message.lower() - or "blocked" in error_message.lower() - ): - error_message = f"Content blocked by OpenAI safety policy: {error_message}. Try modifying your prompt." - elif "not found" in error_message.lower() or "404" in error_message: - error_message = f"OpenAI model not available: {error_message}. The gpt-image-2 model may not be accessible with your API key." - return { "status": "error", "image_paths": [], - "prompt_used": prompt, - "resolution": resolution, - "message": error_message, + "message": str(e), } diff --git a/app/image_gen_interface.py b/app/image_gen_interface.py new file mode 100644 index 00000000..0f2fa9ce --- /dev/null +++ b/app/image_gen_interface.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +""" +Image generation interface for CraftBot. + +Re-exports ImageGenInterface from agent_core with CraftBot-specific hooks +for state access (using STATE singleton) and usage reporting. +""" + +from typing import Optional + +from agent_core.core.impl.image_gen import ImageGenInterface as _ImageGenInterface +from agent_core.core.hooks.types import UsageEventData +from app.state.agent_state import get_session_props + + +def _get_token_count() -> int: + return get_session_props().get_property("token_count", 0) + + +def _set_token_count(count: int) -> None: + get_session_props().set_property("token_count", count) + + +async def _report_usage(event: UsageEventData) -> None: + from app.usage import get_usage_reporter + + await get_usage_reporter().report(event) + + +class ImageGenInterface(_ImageGenInterface): + """ImageGenInterface configured for CraftBot's STATE singleton. + + Automatically injects the get_token_count and set_token_count hooks + that use CraftBot's global STATE object. + """ + + def __init__( + self, + *, + provider: Optional[str] = None, + model: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + deferred: bool = False, + ) -> None: + super().__init__( + provider=provider, + model=model, + api_key=api_key, + base_url=base_url, + deferred=deferred, + get_token_count=_get_token_count, + set_token_count=_set_token_count, + report_usage=_report_usage, + ) diff --git a/app/internal_action_interface.py b/app/internal_action_interface.py index 0da8dc96..c80b457b 100644 --- a/app/internal_action_interface.py +++ b/app/internal_action_interface.py @@ -10,6 +10,7 @@ from typing import Dict, Any, Optional, List, TYPE_CHECKING from app.llm import LLMInterface, LLMCallType from app.vlm_interface import VLMInterface +from app.image_gen_interface import ImageGenInterface from app.task.task_manager import TaskManager from app.task import Task from app.state.state_manager import StateManager @@ -43,6 +44,7 @@ class InternalActionInterface: task_manager: Optional[TaskManager] = None state_manager: Optional[StateManager] = None vlm_interface: Optional[VLMInterface] = None + image_gen_interface: Optional[ImageGenInterface] = None context_engine: Optional["ContextEngine"] = None gui_module: Optional["GUIModule"] = None memory_manager: Optional[MemoryManager] = None @@ -57,6 +59,7 @@ def initialize( task_manager: TaskManager, state_manager: StateManager, vlm_interface: Optional[VLMInterface] = None, + image_gen_interface: Optional[ImageGenInterface] = None, context_engine: Optional["ContextEngine"] = None, gui_module: Optional["GUIModule"] = None, memory_manager: MemoryManager | None = None, @@ -74,6 +77,7 @@ def initialize( cls.task_manager = task_manager cls.state_manager = state_manager cls.vlm_interface = vlm_interface + cls.image_gen_interface = image_gen_interface cls.context_engine = context_engine cls.gui_module = gui_module cls.memory_manager = memory_manager @@ -110,6 +114,24 @@ def describe_image(cls, image_path: str, prompt: Optional[str] = None) -> str: ) return cls.vlm_interface.describe_image(image_path, user_prompt=prompt) + @classmethod + def generate_image(cls, **kwargs) -> List[str]: + """Generate image(s) from a prompt using the image generation interface. + + Delegates all arguments to ImageGenInterface.generate_image(). + + Returns: + List of absolute file paths to the generated images. + + Raises: + RuntimeError: If image_gen_interface is not initialized or generation fails. + """ + if cls.image_gen_interface is None: + raise RuntimeError( + "InternalActionInterface not initialized with ImageGenInterface." + ) + return cls.image_gen_interface.generate_image(**kwargs) + @classmethod def perform_ocr(cls, image_path: str, user_prompt: Optional[str] = None) -> dict: """ diff --git a/app/main.py b/app/main.py index 37f4b981..34c737cd 100644 --- a/app/main.py +++ b/app/main.py @@ -65,10 +65,12 @@ def _suppress_console_logging_early() -> None: from app.config import ( get_llm_provider, get_vlm_provider, + get_image_gen_provider, get_api_key, get_base_url, get_llm_model, get_vlm_model, + get_image_gen_model, ) from app.agent_base import AgentBase @@ -116,7 +118,8 @@ def _initial_settings() -> tuple: """Determine initial provider, API key, and base URL from settings.json. Returns: - Tuple of (provider, api_key, base_url, model, vlm_provider, vlm_model, has_valid_key) + Tuple of (provider, api_key, base_url, model, vlm_provider, vlm_model, + image_gen_provider, image_gen_model, has_valid_key) where has_valid_key indicates if a working API key was found. """ # Read directly from settings.json @@ -126,11 +129,13 @@ def _initial_settings() -> tuple: model = get_llm_model() # None → use registry default for the provider vlm_prov = get_vlm_provider() vlm_mod = get_vlm_model() + img_prov = get_image_gen_provider() + img_mod = get_image_gen_model() # Remote (Ollama) doesn't require API key has_key = bool(api_key) or provider == "remote" - return provider, api_key, base_url, model, vlm_prov, vlm_mod, has_key + return provider, api_key, base_url, model, vlm_prov, vlm_mod, img_prov, img_mod, has_key async def main_async() -> None: @@ -139,7 +144,7 @@ async def main_async() -> None: browser_mode = cli_args.get("browser", False) # Get settings from settings.json - provider, api_key, base_url, model, vlm_prov, vlm_mod, has_valid_key = ( + provider, api_key, base_url, model, vlm_prov, vlm_mod, img_prov, img_mod, has_valid_key = ( _initial_settings() ) @@ -166,6 +171,8 @@ async def main_async() -> None: llm_model=model, vlm_provider=vlm_prov, vlm_model=vlm_mod, + image_gen_provider=img_prov, + image_gen_model=img_mod, deferred_init=not has_valid_key, ) diff --git a/app/ui_layer/adapters/browser_adapter.py b/app/ui_layer/adapters/browser_adapter.py index 705081d0..edd47263 100644 --- a/app/ui_layer/adapters/browser_adapter.py +++ b/app/ui_layer/adapters/browser_adapter.py @@ -4659,6 +4659,7 @@ async def _handle_model_settings_update(self, data: Dict[str, Any]) -> None: try: new_provider = data.get("llmProvider") vlm_provider = data.get("vlmProvider") + image_gen_provider = data.get("imageGenProvider") api_key = data.get("apiKey") provider_for_key = data.get("providerForKey") base_url = data.get("baseUrl") @@ -4721,8 +4722,10 @@ async def _handle_model_settings_update(self, data: Dict[str, Any]) -> None: result = update_model_settings( llm_provider=new_provider, vlm_provider=vlm_provider, + image_gen_provider=image_gen_provider, llm_model=data.get("llmModel"), vlm_model=data.get("vlmModel"), + image_gen_model=data.get("imageGenModel"), api_key=api_key, provider_for_key=provider_for_key, base_url=base_url, @@ -4744,6 +4747,20 @@ async def _handle_model_settings_update(self, data: Dict[str, Any]) -> None: f"Settings saved but LLM reinitialization failed: {e}" ) + # Reinitialize image gen interface when its provider changes + if result.get("success") and image_gen_provider: + try: + agent = self._controller.agent + agent.reinitialize_image_gen(image_gen_provider) + logger.info( + f"[BROWSER] Image gen reinitialized with provider: {image_gen_provider}" + ) + except Exception as e: + logger.warning(f"[BROWSER] Failed to reinitialize image gen: {e}") + result["warning"] = result.get("warning") or ( + f"Settings saved but image gen reinitialization failed: {e}" + ) + await self._broadcast( { "type": "model_settings_update", diff --git a/app/ui_layer/browser/frontend/src/pages/Settings/ModelSettings.tsx b/app/ui_layer/browser/frontend/src/pages/Settings/ModelSettings.tsx index 2aa8c711..7dd71538 100644 --- a/app/ui_layer/browser/frontend/src/pages/Settings/ModelSettings.tsx +++ b/app/ui_layer/browser/frontend/src/pages/Settings/ModelSettings.tsx @@ -30,6 +30,8 @@ import { selectModelHasLoadedProviders, selectModelHasLoadedSettings, selectModelHasLoadedSlowMode, + selectImageGenProvider, + selectCurrentImageGenModel, } from '../../store/selectors/modelSettings' import { getOllamaInstallPercent } from '../../utils/ollamaInstall' import { @@ -47,7 +49,9 @@ interface ProviderInfo { base_url_env?: string llm_model: string | null vlm_model: string | null + image_gen_model: string | null has_vlm: boolean + has_image_gen: boolean supports_catalog?: boolean is_bedrock?: boolean } @@ -88,6 +92,8 @@ export function ModelSettings() { const ollamaModels = useAppSelector(selectOllamaModels) const ollamaAvailable = useAppSelector(selectOllamaAvailable) const awsCredentialsStatus = useAppSelector(selectAwsCredentials) + const imageGenProvider = useAppSelector(selectImageGenProvider) + const currentImageGenModel = useAppSelector(selectCurrentImageGenModel) const hasLoadedProviders = useAppSelector(selectModelHasLoadedProviders) const hasLoadedSettings = useAppSelector(selectModelHasLoadedSettings) const hasLoadedSlowMode = useAppSelector(selectModelHasLoadedSlowMode) @@ -111,6 +117,13 @@ export function ModelSettings() { const [newAwsSessionToken, setNewAwsSessionToken] = useState('') const [newAwsRegion, setNewAwsRegion] = useState('') + // Image generation form state (transient — local). + const [newImageGenProvider, setNewImageGenProvider] = useState('') + const [newImageGenModel, setNewImageGenModel] = useState('') + const [newImageGenApiKey, setNewImageGenApiKey] = useState('') + const [imageGenHasChanges, setImageGenHasChanges] = useState(false) + const [isImageGenSaving, setIsImageGenSaving] = useState(false) + // UI state (transient — local). const [isSaving, setIsSaving] = useState(false) const [isTesting, setIsTesting] = useState(false) @@ -167,6 +180,7 @@ export function ModelSettings() { onMessage('model_settings_update', (data: unknown) => { const d = data as { success: boolean; error?: string } setIsSaving(false) + setIsImageGenSaving(false) if (d.success) { setNewApiKey('') setNewBaseUrl('') @@ -177,6 +191,10 @@ export function ModelSettings() { setNewAwsSessionToken('') setNewAwsRegion('') setHasChanges(false) + setNewImageGenProvider('') + setNewImageGenModel('') + setNewImageGenApiKey('') + setImageGenHasChanges(false) showToast('success', 'Settings saved') } else { showToast('error', d.error || 'Failed to save') @@ -419,6 +437,19 @@ export function ModelSettings() { } } + const handleImageGenSave = () => { + setIsImageGenSaving(true) + const effectiveProvider = newImageGenProvider || imageGenProvider + send('model_settings_update', { + imageGenProvider: effectiveProvider, + imageGenModel: newImageGenModel || currentImageGenModel || undefined, + ...(newImageGenApiKey ? { + apiKey: newImageGenApiKey, + providerForKey: effectiveProvider, + } : {}), + }) + } + const handleDownloadModelClick = () => { setPullPhase('selecting') setPullBytes(null) @@ -890,6 +921,90 @@ export function ModelSettings() { + {/* Image Generation */} +
+
+

Image Generation

+

Configure the provider used for image generation

+
+ {(() => { + const imageGenProviders = providers.filter(p => p.has_image_gen) + const effectiveImgProvider = newImageGenProvider || imageGenProvider + const imgProviderInfo = imageGenProviders.find(p => p.id === effectiveImgProvider) + return ( +
+
+ + +
+ + {/* API Key for image gen provider */} + {imgProviderInfo?.requires_api_key && ( +
+ + {apiKeys[effectiveImgProvider]?.has_key && ( +
{apiKeys[effectiveImgProvider].masked_key}
+ )} + { setNewImageGenApiKey(e.target.value); setImageGenHasChanges(true) }} + placeholder={apiKeys[effectiveImgProvider]?.has_key ? 'Enter new key to replace...' : 'Enter API key...'} + /> +
+ )} + + {/* Model override */} +
+ + { setNewImageGenModel(e.target.value); setImageGenHasChanges(true) }} + placeholder={imgProviderInfo?.image_gen_model || 'Default model'} + /> +
+ +
+ +
+
+ ) + })()} + {/* Slow Mode */}
diff --git a/app/ui_layer/browser/frontend/src/store/selectors/modelSettings.ts b/app/ui_layer/browser/frontend/src/store/selectors/modelSettings.ts index 707894b4..b98bc312 100644 --- a/app/ui_layer/browser/frontend/src/store/selectors/modelSettings.ts +++ b/app/ui_layer/browser/frontend/src/store/selectors/modelSettings.ts @@ -6,6 +6,8 @@ export const selectApiKeys = (state: RootState) => state.modelSettings.apiKeys export const selectBaseUrls = (state: RootState) => state.modelSettings.baseUrls export const selectCurrentLlmModel = (state: RootState) => state.modelSettings.currentLlmModel export const selectCurrentVlmModel = (state: RootState) => state.modelSettings.currentVlmModel +export const selectImageGenProvider = (state: RootState) => state.modelSettings.imageGenProvider +export const selectCurrentImageGenModel = (state: RootState) => state.modelSettings.currentImageGenModel export const selectSlowModeEnabled = (state: RootState) => state.modelSettings.slowModeEnabled export const selectOllamaModels = (state: RootState) => state.modelSettings.ollamaModels export const selectOllamaAvailable = (state: RootState) => state.modelSettings.ollamaAvailable diff --git a/app/ui_layer/browser/frontend/src/store/slices/modelSettingsSlice.ts b/app/ui_layer/browser/frontend/src/store/slices/modelSettingsSlice.ts index b598a596..10a7ff5c 100644 --- a/app/ui_layer/browser/frontend/src/store/slices/modelSettingsSlice.ts +++ b/app/ui_layer/browser/frontend/src/store/slices/modelSettingsSlice.ts @@ -9,7 +9,9 @@ export interface ProviderInfo { base_url_env?: string llm_model: string | null vlm_model: string | null + image_gen_model: string | null has_vlm: boolean + has_image_gen: boolean supports_catalog?: boolean is_bedrock?: boolean } @@ -30,10 +32,12 @@ export interface AwsCredentialsStatus { interface ModelSettingsState { providers: ProviderInfo[] provider: string + imageGenProvider: string apiKeys: Record baseUrls: Record currentLlmModel: string currentVlmModel: string + currentImageGenModel: string slowModeEnabled: boolean ollamaModels: string[] ollamaAvailable: boolean | null @@ -46,10 +50,12 @@ interface ModelSettingsState { const initialState: ModelSettingsState = { providers: [], provider: 'anthropic', + imageGenProvider: 'openai', apiKeys: {}, baseUrls: {}, currentLlmModel: '', currentVlmModel: '', + currentImageGenModel: '', slowModeEnabled: false, ollamaModels: [], ollamaAvailable: null, @@ -69,15 +75,19 @@ const modelSettingsSlice = createSlice({ }, setSettings(state, action: PayloadAction<{ provider: string + imageGenProvider: string llmModel: string vlmModel: string + imageGenModel: string apiKeys: Record baseUrls: Record awsCredentials?: AwsCredentialsStatus | null }>) { state.provider = action.payload.provider + state.imageGenProvider = action.payload.imageGenProvider state.currentLlmModel = action.payload.llmModel state.currentVlmModel = action.payload.vlmModel + state.currentImageGenModel = action.payload.imageGenModel state.apiKeys = action.payload.apiKeys state.baseUrls = action.payload.baseUrls if (action.payload.awsCredentials !== undefined) { @@ -91,12 +101,18 @@ const modelSettingsSlice = createSlice({ setProvider(state, action: PayloadAction) { state.provider = action.payload }, + setImageGenProvider(state, action: PayloadAction) { + state.imageGenProvider = action.payload + }, setCurrentLlmModel(state, action: PayloadAction) { state.currentLlmModel = action.payload }, setCurrentVlmModel(state, action: PayloadAction) { state.currentVlmModel = action.payload }, + setCurrentImageGenModel(state, action: PayloadAction) { + state.currentImageGenModel = action.payload + }, setApiKeys(state, action: PayloadAction>) { state.apiKeys = action.payload }, @@ -118,8 +134,10 @@ export const { setProviders, setSettings, setProvider, + setImageGenProvider, setCurrentLlmModel, setCurrentVlmModel, + setCurrentImageGenModel, setApiKeys, setBaseUrls, setSlowModeEnabled, @@ -138,8 +156,10 @@ register('model_settings_get', (data, dispatch) => { const d = data as { success: boolean llm_provider: string + image_gen_provider: string llm_model: string | null vlm_model: string | null + image_gen_model: string | null api_keys: Record base_urls: Record aws_credentials?: AwsCredentialsStatus | null @@ -147,8 +167,10 @@ register('model_settings_get', (data, dispatch) => { if (d.success) { dispatch(setSettings({ provider: d.llm_provider || 'anthropic', + imageGenProvider: d.image_gen_provider || 'openai', llmModel: d.llm_model || '', vlmModel: d.vlm_model || '', + imageGenModel: d.image_gen_model || '', apiKeys: d.api_keys || {}, baseUrls: d.base_urls || {}, awsCredentials: d.aws_credentials ?? null, @@ -160,18 +182,25 @@ register('model_settings_update', (data, dispatch) => { const d = data as { success: boolean llm_provider?: string + image_gen_provider?: string llm_model?: string | null vlm_model?: string | null + image_gen_model?: string | null api_keys?: Record base_urls?: Record aws_credentials?: AwsCredentialsStatus | null } if (!d.success) return if (d.llm_provider) dispatch(setProvider(d.llm_provider)) + // Always update imageGenProvider when the field is present (even on partial saves); + // using `!== undefined` mirrors model_settings_get so version-mismatched backends + // that omit the field don't silently leave the UI showing a stale provider. + if (d.image_gen_provider !== undefined) dispatch(setImageGenProvider(d.image_gen_provider || 'openai')) if (d.api_keys) dispatch(setApiKeys(d.api_keys)) if (d.base_urls) dispatch(setBaseUrls(d.base_urls)) if (d.llm_model !== undefined) dispatch(setCurrentLlmModel(d.llm_model || '')) if (d.vlm_model !== undefined) dispatch(setCurrentVlmModel(d.vlm_model || '')) + if (d.image_gen_model !== undefined) dispatch(setCurrentImageGenModel(d.image_gen_model || '')) if (d.aws_credentials !== undefined) dispatch(setAwsCredentials(d.aws_credentials)) }) diff --git a/app/ui_layer/settings/model_settings.py b/app/ui_layer/settings/model_settings.py index 93500ca5..81c9c31d 100644 --- a/app/ui_layer/settings/model_settings.py +++ b/app/ui_layer/settings/model_settings.py @@ -175,6 +175,7 @@ def get_available_providers() -> Dict[str, Any]: llm_model = provider_models.get(InterfaceType.LLM) vlm_model = provider_models.get(InterfaceType.VLM) + image_gen_model = provider_models.get(InterfaceType.IMAGE_GEN) providers.append( { @@ -186,6 +187,8 @@ def get_available_providers() -> Dict[str, Any]: "llm_model": llm_model, "vlm_model": vlm_model, "has_vlm": vlm_model is not None, + "image_gen_model": image_gen_model, + "has_image_gen": image_gen_model is not None, "supports_catalog": info.get("supports_catalog", False), "is_bedrock": info.get("is_bedrock", False), } @@ -222,10 +225,12 @@ def get_model_settings() -> Dict[str, Any]: # Get configured providers (settings.json is the single source of truth) llm_provider = model_settings.get("llm_provider", "anthropic") vlm_provider = model_settings.get("vlm_provider", llm_provider) + image_gen_provider = model_settings.get("image_gen_provider", "openai") # Get custom models if set llm_model = model_settings.get("llm_model") vlm_model = model_settings.get("vlm_model") + image_gen_model = model_settings.get("image_gen_model") # Check API key status for each provider (settings.json only) api_keys = {} @@ -292,8 +297,10 @@ def get_model_settings() -> Dict[str, Any]: "success": True, "llm_provider": llm_provider, "vlm_provider": vlm_provider, + "image_gen_provider": image_gen_provider, "llm_model": llm_model, "vlm_model": vlm_model, + "image_gen_model": image_gen_model, "api_keys": api_keys, "base_urls": base_urls, "aws_credentials": aws_creds_status, @@ -308,8 +315,10 @@ def get_model_settings() -> Dict[str, Any]: def update_model_settings( llm_provider: Optional[str] = None, vlm_provider: Optional[str] = None, + image_gen_provider: Optional[str] = None, llm_model: Optional[str] = None, vlm_model: Optional[str] = None, + image_gen_model: Optional[str] = None, api_key: Optional[str] = None, provider_for_key: Optional[str] = None, base_url: Optional[str] = None, @@ -367,11 +376,31 @@ def update_model_settings( if vlm_model is None: settings["model"]["vlm_model"] = None + # Update image generation provider (validate before saving) + if image_gen_provider: + supported_img_providers = { + p for p, caps in MODEL_REGISTRY.items() if caps.get(InterfaceType.IMAGE_GEN) + } + if image_gen_provider not in supported_img_providers: + return { + "success": False, + "error": ( + f"'{image_gen_provider}' does not support image generation. " + f"Supported providers: {', '.join(sorted(supported_img_providers))}" + ), + } + old_img_provider = settings["model"].get("image_gen_provider") + settings["model"]["image_gen_provider"] = image_gen_provider + if image_gen_provider != old_img_provider and image_gen_model is None: + settings["model"]["image_gen_model"] = None + # Update custom models (explicit values override the auto-clear above) if llm_model is not None: settings["model"]["llm_model"] = llm_model if llm_model else None if vlm_model is not None: settings["model"]["vlm_model"] = vlm_model if vlm_model else None + if image_gen_model is not None: + settings["model"]["image_gen_model"] = image_gen_model if image_gen_model else None # Update API key in settings.json if provider_for_key and api_key is not None: