@@ -12,16 +12,19 @@
 import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 
 from mellea.backends.model_options import ModelOption
+from mellea.helpers.openai_compatible_helpers import (
+    build_completion_usage,
+    stream_chat_completion_chunks,
+)
 
 from .models import (
     ChatCompletion,
     ChatCompletionMessage,
     ChatCompletionRequest,
     Choice,
-    CompletionUsage,
     OpenAIError,
     OpenAIErrorResponse,
 )
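`stream_chat_completion_chunks` itself isn't shown in this diff. For reviewers, here is a minimal sketch of what such an SSE generator typically looks like. The OpenAI `chat.completion.chunk` wire format is real, but the helper body and the `output.astream()` accessor are assumptions for illustration, not Mellea's actual implementation:

```python
import json
from typing import Any, AsyncIterator

async def stream_chunks_sketch(
    output: Any,          # assumed: exposes an async iterator of text deltas
    completion_id: str,
    model: str,
    created: int,
    stream_options: Any = None,
) -> AsyncIterator[str]:
    """Yield OpenAI-style chat.completion.chunk objects as SSE frames."""

    def sse(payload: dict) -> str:
        # Each server-sent event is a "data: <json>" line plus a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    base = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
    }
    # The first chunk announces the assistant role with an empty delta.
    yield sse({**base, "choices": [
        {"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}
    ]})
    # output.astream() is a stand-in for however ModelOutputThunk exposes tokens.
    async for token in output.astream():
        yield sse({**base, "choices": [
            {"index": 0, "delta": {"content": token}, "finish_reason": None}
        ]})
    # The final chunk carries an empty delta and the finish reason.
    yield sse({**base, "choices": [
        {"index": 0, "delta": {}, "finish_reason": "stop"}
    ]})
    # OpenAI streams always terminate with a literal [DONE] sentinel.
    yield "data: [DONE]\n\n"
```

If `stream_options.include_usage` is set, the real helper presumably emits a trailing usage-only chunk before `[DONE]`, which would explain why `stream_options` is threaded through to it below.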
@@ -94,8 +97,8 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
         "n",  # Number of completions - not supported in Mellea's model_options
         "user",  # User tracking ID - metadata, not a generation parameter
         "extra",  # Pydantic's extra fields dict - unused (see model_config)
+        "stream_options",  # Streaming options - handled separately in the streaming response
         # Not-yet-implemented OpenAI parameters (silently ignored)
-        "stream",  # Streaming responses - not yet implemented
         "stop",  # Stop sequences - not yet implemented
         "top_p",  # Nucleus sampling - not yet implemented
         "presence_penalty",  # Presence penalty - not yet implemented
@@ -111,6 +114,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
         "temperature": ModelOption.TEMPERATURE,
         "max_tokens": ModelOption.MAX_NEW_TOKENS,
         "seed": ModelOption.SEED,
+        "stream": ModelOption.STREAM,
     }
 
     filtered_options = {
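Taken together, the two `_build_model_options` hunks move `stream` from the ignore list into the rename map. The full function isn't shown; a simplified sketch of the rename-and-filter pattern it implies (the field names come from this diff, the comprehension shape is assumed):

```python
from mellea.backends.model_options import ModelOption

# Field names below come from this diff; the function body is illustrative.
OPENAI_TO_MELLEA = {
    "temperature": ModelOption.TEMPERATURE,
    "max_tokens": ModelOption.MAX_NEW_TOKENS,
    "seed": ModelOption.SEED,
    "stream": ModelOption.STREAM,  # newly mapped by this change
}

def build_model_options_sketch(request_fields: dict) -> dict:
    # Keep only fields that are set and have a Mellea equivalent,
    # renaming them to ModelOption keys.
    return {
        OPENAI_TO_MELLEA[key]: value
        for key, value in request_fields.items()
        if key in OPENAI_TO_MELLEA and value is not None
    }
```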
@@ -157,19 +161,17 @@ async def endpoint(request: ChatCompletionRequest):
             model_options=model_options,
         )
 
-        # Extract usage information from the ModelOutputThunk if available
-        usage = None
-        if hasattr(output, "usage") and output.usage is not None:
-            prompt_tokens = output.usage.get("prompt_tokens", 0)
-            completion_tokens = output.usage.get("completion_tokens", 0)
-            # Calculate total_tokens if not provided
-            total_tokens = output.usage.get(
-                "total_tokens", prompt_tokens + completion_tokens
-            )
-            usage = CompletionUsage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
+        # Handle streaming response
+        if request.stream:
+            return StreamingResponse(
+                stream_chat_completion_chunks(
+                    output=output,
+                    completion_id=completion_id,
+                    model=request.model,
+                    created=created_timestamp,
+                    stream_options=request.stream_options,
+                ),
+                media_type="text/event-stream",
             )
 
         # system_fingerprint represents backend config hash, not model name
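With `stream=True` the endpoint now returns `text/event-stream`, so any OpenAI client can consume it. For example (the base URL, port, and model name are placeholders for however this server is deployed):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="some-model",  # placeholder
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},  # ask for a trailing usage chunk
)
for chunk in stream:
    # The usage-only final chunk has an empty choices list, hence the guard.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```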
@@ -192,7 +194,7 @@ async def endpoint(request: ChatCompletionRequest):
             ],
             object="chat.completion",  # type: ignore
             system_fingerprint=system_fingerprint,
-            usage=usage,
+            usage=build_completion_usage(output),
         )  # type: ignore
     except ValueError as e:
         # Handle validation errors or invalid input
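`build_completion_usage` presumably absorbs the usage extraction deleted from the endpoint above; a sketch that mirrors the removed inline logic (the helper's real body may differ):

```python
from .models import CompletionUsage  # no longer imported at the top of this file

def build_completion_usage_sketch(output) -> CompletionUsage | None:
    # Mirror of the deleted inline logic: no usage on the ModelOutputThunk
    # means no usage block in the response.
    if not hasattr(output, "usage") or output.usage is None:
        return None
    prompt_tokens = output.usage.get("prompt_tokens", 0)
    completion_tokens = output.usage.get("completion_tokens", 0)
    # Backends may omit total_tokens; fall back to the sum.
    total_tokens = output.usage.get("total_tokens", prompt_tokens + completion_tokens)
    return CompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
    )
```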