Skip to content

Commit 0b015e0

Browse files
committed
Add support for Fireworks image generation models
1 parent fe9238b commit 0b015e0

2 files changed

Lines changed: 242 additions & 1 deletion

File tree

llms/extensions/providers/fireworks.py

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,184 @@
1+
import json
2+
import time
3+
4+
import aiohttp
5+
6+
17
def install_fireworks(ctx):
2-
from llms.main import OpenAiCompatible
8+
from llms.main import GeneratorBase, OpenAiCompatible
9+
10+
# https://docs.fireworks.ai/api-reference/generate-a-new-image-from-a-text-prompt
class FireworksGenerator(GeneratorBase):
    """Image generation via the Fireworks AI workflows API.

    Two request flows are supported:
      - synchronous ``text_to_image`` for the schnell/dev FLUX models
        (single POST, image bytes in the response body)
      - submit-then-poll async flow for FLUX Kontext models
        (POST returns a request id; results are polled via ``get_result``)
    """

    sdk = "fireworks/image"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # All image endpoints hang off the workflows API root.
        self.base_url = "https://api.fireworks.ai/inference/v1/workflows"

    def _model_url(self, model):
        """Convert a model ID (e.g. ``fireworks/flux-1-dev-fp8``) to the API
        path ``accounts/<account>/models/<model>``.

        IDs without a "/" are passed through unchanged.
        """
        if "/" in model:
            account, model_name = model.split("/", 1)
            return f"accounts/{account}/models/{model_name}"
        return model

    def _is_async_model(self, model):
        """Kontext models use async polling instead of synchronous text_to_image."""
        # split("/")[-1] already yields the whole string when no "/" is present,
        # so no separate membership check is needed.
        return "kontext" in model.split("/")[-1]

    def _save_and_respond(self, image_data, chat, seed=None, ext="png", context=None):
        """Persist image bytes to the cache and build an OpenAI-style chat response.

        :param image_data: raw image bytes returned by the API
        :param chat: the chat request dict; ``chat["model"]`` names the output file
        :param seed: generation seed reported by the API (may be None)
        :param ext: file extension for the cached image
        :param context: passed through to the cache writer
        """
        model_name = chat["model"].split("/")[-1]
        relative_url, _info = ctx.save_image_to_cache(
            image_data,
            f"{model_name}.{ext}",
            ctx.to_file_info(chat, {"seed": seed}),
            context=context,
        )
        return ctx.log_json(
            {
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": self.default_content,
                            "images": [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": relative_url},
                                }
                            ],
                        }
                    }
                ],
                "created": int(time.time()),
            }
        )

    async def _chat_sync(self, gen_url, headers, gen_request, chat, context=None):
        """Synchronous text_to_image flow for schnell/dev models."""
        headers["Accept"] = "image/png"
        async with aiohttp.ClientSession() as session, session.post(
            gen_url,
            headers=headers,
            data=json.dumps(gen_request),
            timeout=aiohttp.ClientTimeout(total=300),
        ) as response:
            if response.status >= 400:
                text = await response.text()
                raise Exception(f"Fireworks API error {response.status}: {text}")

            # Safety filtering is reported via a response header, not a status code.
            finish_reason = response.headers.get("Finish-Reason", "")
            if finish_reason == "CONTENT_FILTERED":
                raise Exception("Image generation was filtered by safety check")

            seed = response.headers.get("Seed")
            image_data = await response.read()
            return self._save_and_respond(image_data, chat, seed=seed, context=context)

    async def _chat_async(self, gen_url, headers, gen_request, chat, context=None):
        """Async polling flow for Kontext models."""
        import asyncio
        import base64  # hoisted out of the polling loop body

        headers["Accept"] = "application/json"
        async with aiohttp.ClientSession() as session:
            # Step 1: Submit generation request
            async with session.post(
                gen_url,
                headers=headers,
                data=json.dumps(gen_request),
                timeout=aiohttp.ClientTimeout(total=60),
            ) as response:
                if response.status >= 400:
                    text = await response.text()
                    raise Exception(f"Fireworks API error {response.status}: {text}")
                result = await response.json()

            request_id = result.get("request_id")
            if not request_id:
                raise Exception(f"No request_id returned: {json.dumps(result)}")
            ctx.log(f"Fireworks async request submitted: {request_id}")

            # Step 2: Poll for result (~1s cadence, up to 120 attempts)
            result_url = f"{gen_url}/get_result"
            for attempt in range(120):
                await asyncio.sleep(1)
                async with session.post(
                    result_url,
                    headers=headers,
                    data=json.dumps({"id": request_id}),
                    timeout=aiohttp.ClientTimeout(total=30),
                ) as poll_response:
                    poll_result = await poll_response.json()
                    status = poll_result.get("status", "")

                    if status in ("Ready", "Complete", "Finished"):
                        sample = poll_result.get("result", {}).get("sample")
                        if not sample:
                            raise Exception("No image data in completed result")

                        if isinstance(sample, str) and sample.startswith("http"):
                            # Result is a URL: download the image bytes.
                            async with session.get(sample) as img_response:
                                image_data = await img_response.read()
                        else:
                            # Otherwise the sample is base64-encoded image data.
                            image_data = base64.b64decode(sample)

                        seed = poll_result.get("result", {}).get("seed")
                        return self._save_and_respond(
                            image_data, chat, seed=seed, ext="png", context=context
                        )

                    if status in ("Failed", "Error"):
                        details = poll_result.get("details", "Unknown error")
                        raise Exception(f"Fireworks generation failed: {details}")

                    ctx.log(f"Fireworks polling attempt {attempt + 1}: {status}")

            raise Exception("Fireworks generation timed out after 120 seconds")

    async def chat(self, chat, provider=None, context=None):
        """Entry point: build the generation request and dispatch to the
        sync or async flow depending on the model family.

        :param chat: chat request dict (``model``, messages, optional ``image_config``)
        :param provider: optional provider used for headers and model-name mapping
        :param context: passed through to the cache writer
        """
        headers = self.get_headers(provider, chat)
        if provider is not None:
            chat["model"] = provider.provider_model(chat["model"]) or chat["model"]

        prompt = ctx.last_user_prompt(chat)
        image_config = chat.get("image_config", {})
        aspect_ratio = ctx.chat_to_aspect_ratio(chat) or "1:1"

        gen_request = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
        }
        # Optional tuning knobs are forwarded only when explicitly configured.
        if "guidance_scale" in image_config:
            gen_request["guidance_scale"] = float(image_config["guidance_scale"])
        if "num_inference_steps" in image_config:
            gen_request["num_inference_steps"] = int(image_config["num_inference_steps"])
        if "seed" in image_config:
            gen_request["seed"] = int(image_config["seed"])

        model_path = self._model_url(chat["model"])
        # Evaluate once instead of branching twice on the same condition.
        is_async = self._is_async_model(chat["model"])
        if is_async:
            gen_url = f"{self.base_url}/{model_path}"
            ctx.log(f"POST {gen_url} (async)")
        else:
            gen_url = f"{self.base_url}/{model_path}/text_to_image"
            ctx.log(f"POST {gen_url}")
        ctx.log(self.gen_summary(gen_request))

        if is_async:
            return await self._chat_async(gen_url, headers, gen_request, chat, context=context)
        return await self._chat_sync(gen_url, headers, gen_request, chat, context=context)
3175

4176
class FireworksProvider(OpenAiCompatible):
5177
sdk = "@fireworks/ai-sdk-provider"
6178

7179
def __init__(self, **kwargs):
8180
super().__init__(**kwargs)
181+
self.modalities["image"] = FireworksGenerator(**kwargs)
9182

10183
async def process_chat(self, chat, provider_id=None):
11184
ret = await super().process_chat(chat, provider_id)

llms/providers-extra.json

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,72 @@
11
{
2+
"fireworks-ai": {
3+
"models": {
4+
"fireworks/flux-kontext-pro": {
5+
"name": "FLUX.1 Kontext Pro",
6+
"modalities": {
7+
"input": [
8+
"text"
9+
],
10+
"output": [
11+
"image"
12+
]
13+
},
14+
"cost": {
15+
"input": 0,
16+
"output": 0.04,
17+
"type": "per_image"
18+
}
19+
},
20+
"fireworks/flux-kontext-max": {
21+
"name": "FLUX.1 Kontext Max",
22+
"modalities": {
23+
"input": [
24+
"text"
25+
],
26+
"output": [
27+
"image"
28+
]
29+
},
30+
"cost": {
31+
"input": 0,
32+
"output": 0.08,
33+
"type": "per_image"
34+
}
35+
},
36+
"fireworks/flux-1-dev-fp8": {
37+
"name": "FLUX.1 Dev FP8",
38+
"modalities": {
39+
"input": [
40+
"text"
41+
],
42+
"output": [
43+
"image"
44+
]
45+
},
46+
"cost": {
47+
"input": 0,
48+
"output": 0.0005,
49+
"type": "per_step"
50+
}
51+
},
52+
"fireworks/flux-1-schnell-fp8": {
53+
"name": "FLUX.1 [schnell] FP8",
54+
"modalities": {
55+
"input": [
56+
"text"
57+
],
58+
"output": [
59+
"image"
60+
]
61+
},
62+
"cost": {
63+
"input": 0,
64+
"output": 0.00035,
65+
"type": "per_step"
66+
}
67+
}
68+
}
69+
},
270
"google": {
371
"models": {
472
"gemini-3.1-flash-image-preview": {

0 commit comments

Comments
 (0)