Commit a9cf69e

feat: streaming support in m serve OpenAI API server
Fixes: #822
Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent 55b6b73 commit a9cf69e

7 files changed

Lines changed: 362 additions & 37 deletions

cli/serve/app.py

Lines changed: 19 additions & 17 deletions
@@ -12,16 +12,19 @@
 import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 
 from mellea.backends.model_options import ModelOption
+from mellea.helpers.openai_compatible_helpers import (
+    build_completion_usage,
+    stream_chat_completion_chunks,
+)
 
 from .models import (
     ChatCompletion,
     ChatCompletionMessage,
     ChatCompletionRequest,
     Choice,
-    CompletionUsage,
     OpenAIError,
     OpenAIErrorResponse,
 )
@@ -94,8 +97,8 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
         "n",  # Number of completions - not supported in Mellea's model_options
         "user",  # User tracking ID - metadata, not a generation parameter
         "extra",  # Pydantic's extra fields dict - unused (see model_config)
+        "stream_options",  # Streaming options - handled separately in streaming response
         # Not-yet-implemented OpenAI parameters (silently ignored)
-        "stream",  # Streaming responses - not yet implemented
         "stop",  # Stop sequences - not yet implemented
         "top_p",  # Nucleus sampling - not yet implemented
         "presence_penalty",  # Presence penalty - not yet implemented
@@ -111,6 +114,7 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
         "temperature": ModelOption.TEMPERATURE,
         "max_tokens": ModelOption.MAX_NEW_TOKENS,
         "seed": ModelOption.SEED,
+        "stream": ModelOption.STREAM,
     }
 
     filtered_options = {
@@ -157,19 +161,17 @@ async def endpoint(request: ChatCompletionRequest):
             model_options=model_options,
         )
 
-        # Extract usage information from the ModelOutputThunk if available
-        usage = None
-        if hasattr(output, "usage") and output.usage is not None:
-            prompt_tokens = output.usage.get("prompt_tokens", 0)
-            completion_tokens = output.usage.get("completion_tokens", 0)
-            # Calculate total_tokens if not provided
-            total_tokens = output.usage.get(
-                "total_tokens", prompt_tokens + completion_tokens
-            )
-            usage = CompletionUsage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
+        # Handle streaming response
+        if request.stream:
+            return StreamingResponse(
+                stream_chat_completion_chunks(
+                    output=output,
+                    completion_id=completion_id,
+                    model=request.model,
+                    created=created_timestamp,
+                    stream_options=request.stream_options,
+                ),
+                media_type="text/event-stream",
             )
 
         # system_fingerprint represents backend config hash, not model name
@@ -192,7 +194,7 @@ async def endpoint(request: ChatCompletionRequest):
             ],
             object="chat.completion",  # type: ignore
             system_fingerprint=system_fingerprint,
-            usage=usage,
+            usage=build_completion_usage(output),
         )  # type: ignore
     except ValueError as e:
         # Handle validation errors or invalid input
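The streaming path above delegates SSE formatting to `stream_chat_completion_chunks`, imported from `mellea/helpers/openai_compatible_helpers.py`, a file not shown among the diffs on this page. As a rough sketch of the shape such a generator takes (the function name `sse_chunks`, the `text_parts` iterator, and the `[DONE]` sentinel below are illustrative assumptions, not the commit's actual code):

```python
import json
from collections.abc import AsyncIterator


async def sse_chunks(
    text_parts: AsyncIterator[str],
    completion_id: str,
    model: str,
    created: int,
) -> AsyncIterator[str]:
    """Yield OpenAI-style chat.completion.chunk events as SSE 'data:' lines."""
    first = True
    async for part in text_parts:
        delta: dict[str, str] = {"content": part}
        if first:
            delta["role"] = "assistant"  # role is only sent on the first chunk
            first = False
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
        }
        yield f"data: {json.dumps(payload)}\n\n"
    # Close the stream: an empty delta carrying finish_reason, then the sentinel.
    final = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"
```

FastAPI's `StreamingResponse` with `media_type="text/event-stream"` forwards each yielded string to the client as one SSE frame, which is why the endpoint above can return it directly.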

cli/serve/models.py

Lines changed: 63 additions & 11 deletions
@@ -2,6 +2,8 @@
 
 from pydantic import BaseModel, Field
 
+from mellea.helpers.openai_compatible_helpers import CompletionUsage
+
 
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant", "tool", "function"]
@@ -58,6 +60,13 @@ class ChatCompletionRequest(BaseModel):
     seed: int | None = None
     response_format: ResponseFormat | None = None
 
+    # OpenAI-compatible streaming options. Only applies when stream=True.
+    # Supports `include_usage` (bool) to control whether usage statistics are
+    # included in the final streaming chunk. Defaults to True (include usage)
+    # when not specified for backward compatibility. For non-streaming requests
+    # (stream=False), usage is always included regardless of this parameter.
+    stream_options: dict[str, Any] | None = None
+
     # For future/undocumented fields
     extra: dict[str, Any] = Field(default_factory=dict)
 
@@ -88,17 +97,6 @@ class Choice(BaseModel):
     """The reason the model stopped generating tokens."""
 
 
-class CompletionUsage(BaseModel):
-    completion_tokens: int
-    """Number of tokens in the generated completion."""
-
-    prompt_tokens: int
-    """Number of tokens in the prompt."""
-
-    total_tokens: int
-    """Total number of tokens used in the request (prompt + completion)."""
-
-
 class ChatCompletion(BaseModel):
     id: str
     """A unique identifier for the chat completion."""
@@ -125,6 +123,60 @@ class ChatCompletion(BaseModel):
     """Usage statistics for the completion request."""
 
 
+class ChatCompletionChunkDelta(BaseModel):
+    """Delta content in a streaming chunk."""
+
+    content: str | None = None
+    """The content fragment in this chunk."""
+
+    role: Literal["assistant"] | None = None
+    """The role (only present in first chunk)."""
+
+    refusal: str | None = None
+    """The refusal message fragment, if any."""
+
+
+class ChatCompletionChunkChoice(BaseModel):
+    """A choice in a streaming chunk."""
+
+    index: int
+    """The index of the choice in the list of choices."""
+
+    delta: ChatCompletionChunkDelta
+    """The delta content for this chunk."""
+
+    finish_reason: (
+        Literal["stop", "length", "content_filter", "tool_calls", "function_call"]
+        | None
+    ) = None
+    """The reason the model stopped generating tokens (only in final chunk)."""
+
+
+class ChatCompletionChunk(BaseModel):
+    """A chunk in a streaming chat completion response."""
+
+    id: str
+    """A unique identifier for the chat completion."""
+
+    choices: list[ChatCompletionChunkChoice]
+    """A list of chat completion choices."""
+
+    created: int
+    """The Unix timestamp (in seconds) of when the chat completion was created."""
+
+    model: str
+    """The model used for the chat completion."""
+
+    object: Literal["chat.completion.chunk"]
+    """The object type, which is always `chat.completion.chunk`."""
+
+    system_fingerprint: str | None = None
+    """This fingerprint represents the backend configuration that the model runs with."""
+
+    usage: CompletionUsage | None = None
+    """Usage statistics for the final streaming chunk when available from the backend."""
+
+
 class OpenAIError(BaseModel):
     """OpenAI API error object."""
 
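The `stream_options` field documented above follows the OpenAI convention, where `{"include_usage": true}` requests usage statistics on the final chunk. A short client-side sketch of exercising it (the port and model name are illustrative, and the empty-`choices` guard is defensive, since a usage-only final chunk may carry no delta):

```python
import openai

client = openai.OpenAI(api_key="na", base_url="http://0.0.0.0:8080/v1")

stream = client.chat.completions.create(
    messages=[{"role": "user", "content": "Say hello."}],
    model="granite4:micro-h",
    stream=True,
    stream_options={"include_usage": True},  # ask for usage on the final chunk
)

usage = None
for chunk in stream:
    # A usage-bearing final chunk may have an empty choices list.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
    if chunk.usage is not None:
        usage = chunk.usage

print()
print(usage)  # prompt_tokens / completion_tokens / total_tokens, when provided
```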

docs/examples/m_serve/README.md

Lines changed: 61 additions & 3 deletions
@@ -13,11 +13,21 @@ A simple example showing how to structure a Mellea program for serving as an API
 - Custom validation functions for API constraints
 - Handling chat message inputs
 
+### m_serve_example_streaming.py
+A dedicated streaming example for `m serve` that supports both modes:
+- `stream=False` returns a normal computed response
+- `stream=True` returns an uncomputed thunk so the server can emit
+  incremental Server-Sent Events (SSE) chunks
+
 ### pii_serve.py
 Example of serving a PII (Personally Identifiable Information) detection service.
 
 ### client.py
-Client code for testing the served API endpoints.
+Client code for testing the served API endpoints with non-streaming requests.
+
+### client_streaming.py
+Client code demonstrating streaming responses using Server-Sent Events (SSE)
+against `m_serve_example_streaming.py`.
 
 ## Concepts Demonstrated
 
@@ -26,6 +36,7 @@ Client code for testing the served API endpoints.
 - **Output Formatting**: Returning appropriate response types
 - **Validation in Production**: Using requirements in deployed services
 - **Model Options**: Passing model configuration through API
+- **Streaming Responses**: Real-time token streaming via Server-Sent Events (SSE)
 
 ## Basic Pattern
 
@@ -53,12 +64,59 @@ def serve(input: list[ChatMessage],
 
 ## Running the Server
 
+### Sampling
+
 ```bash
-# Start the server
+# Start the sampling example server
 m serve docs/examples/m_serve/m_serve_example_simple.py
 
-# In another terminal, test with client
+# In another terminal, test with the non-streaming client
 python docs/examples/m_serve/client.py
+
+### Streaming
+
+# Start the dedicated streaming example server
+m serve docs/examples/m_serve/m_serve_example_streaming.py
+
+# In another terminal, test with the streaming client
+python docs/examples/m_serve/client_streaming.py
+```
+
+## Streaming Support
+
+The server supports streaming responses via Server-Sent Events (SSE) when the
+`stream=True` parameter is set in the request. This allows clients to receive
+tokens as they are generated, providing a better user experience for long-running
+generations.
+
+For a real streaming demo, serve `m_serve_example_streaming.py`. That example
+supports both normal and streaming responses consistently. The sampling example
+(`m_serve_example_simple.py`) demonstrates rejection sampling and validation,
+not token-by-token streaming.
+
+**Key Features:**
+- Real-time token streaming using SSE
+- OpenAI-compatible streaming format (`ChatCompletionChunk`)
+- Final chunk includes usage statistics when the backend provides usage data
+- The dedicated streaming example supports both `stream=False` and `stream=True`
+- Works with any backend that supports `ModelOutputThunk.astream()`
+
+**Example:**
+```python
+import openai
+
+client = openai.OpenAI(api_key="na", base_url="http://0.0.0.0:8080/v1")
+
+# Enable streaming with stream=True
+stream = client.chat.completions.create(
+    messages=[{"role": "user", "content": "Tell me a story"}],
+    model="granite4:micro-h",
+    stream=True,
+)
+
+for chunk in stream:
+    if chunk.choices[0].delta.content:
+        print(chunk.choices[0].delta.content, end="", flush=True)
 ```
 
 ## API Endpoints
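On the wire, each streamed event in the README's example is one `data:` line carrying a JSON-encoded `chat.completion.chunk`. A raw-SSE sketch using `httpx` shows that format without the `openai` client; the `httpx` dependency and the `data: [DONE]` terminator handling follow the standard OpenAI convention and are assumptions here, not details confirmed by this commit:

```python
import json

import httpx

request_body = {
    "model": "granite4:micro-h",
    "messages": [{"role": "user", "content": "Tell me a story"}],
    "stream": True,
}

with httpx.stream(
    "POST",
    "http://0.0.0.0:8080/v1/chat/completions",
    json=request_body,
    timeout=None,
) as response:
    for line in response.iter_lines():
        if not line.startswith("data: "):
            continue  # SSE frames are "data: <json>" separated by blank lines
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        choices = chunk.get("choices", [])
        delta = choices[0].get("delta", {}) if choices else {}
        print(delta.get("content") or "", end="", flush=True)
print()
```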
docs/examples/m_serve/client_streaming.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+# pytest: skip_always
+"""Example client demonstrating responses from m serve.
+
+This example shows how to use the OpenAI Python client with a Mellea server
+started with:
+
+    m serve docs/examples/m_serve/m_serve_example_streaming.py
+
+Set ``streaming`` below to:
+- ``True`` for incremental SSE chunks
+- ``False`` for a normal non-streaming response
+"""
+
+import openai
+
+PORT = 8080
+
+client = openai.OpenAI(api_key="na", base_url=f"http://0.0.0.0:{PORT}/v1")
+
+streaming = True  # streaming enabled toggle
+
+print(f"stream={streaming} response:")
+print("-" * 50)
+
+# Request either a streaming or non-streaming response from the dedicated example server
+if streaming:
+    stream_result = client.chat.completions.create(
+        messages=[
+            {"role": "user", "content": "Count down from 100 using words not digits."}
+        ],
+        model="granite4:micro-h",
+        stream=True,
+    )
+    for chunk in stream_result:
+        if chunk.choices[0].delta.content:
+            # If you want to see the chunks more clearly separated, change end
+            print(chunk.choices[0].delta.content, end="", flush=True)
+else:
+    completion_result = client.chat.completions.create(
+        messages=[
+            {"role": "user", "content": "Count down from 100 using words not digits."}
+        ],
+        model="granite4:micro-h",
+        stream=False,
+    )
+    print(completion_result.choices[0].message.content)
+
+print("\n" + "-" * 50)
+print("Stream complete!")
docs/examples/m_serve/m_serve_example_streaming.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# pytest: ollama, e2e
+
+"""Example to run m serve with true streaming support."""
+
+import mellea
+from cli.serve.models import ChatMessage
+from mellea.backends.model_options import ModelOption
+from mellea.core import ComputedModelOutputThunk, ModelOutputThunk
+from mellea.stdlib.context import SimpleContext
+
+session = mellea.start_session(ctx=SimpleContext())
+
+
+async def serve(
+    input: list[ChatMessage],
+    requirements: list[str] | None = None,
+    model_options: dict | None = None,
+) -> ModelOutputThunk | ComputedModelOutputThunk:
+    """Support both normal and streaming responses from the same example.
+
+    Returns a computed result for non-streaming requests and an uncomputed thunk
+    for streaming requests.
+    """
+    del requirements
+    message = input[-1].content or ""
+    is_streaming = bool((model_options or {}).get(ModelOption.STREAM, False))
+
+    if is_streaming:
+        return await session.ainstruct(
+            description=message,
+            strategy=None,
+            model_options=model_options,
+            await_result=False,
+        )
+
+    return await session.ainstruct(
+        description=message,
+        strategy=None,
+        model_options=model_options,
+        await_result=True,
+    )
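This example only inspects `model_options`; the mapping added to `_build_model_options` in `cli/serve/app.py` is what translates the request's `stream` field into `ModelOption.STREAM` before `serve()` is called. A self-contained sketch of that flow, with placeholder strings standing in for the real `ModelOption` constants:

```python
# Placeholder keys; the real constants come from mellea.backends.model_options.
STREAM_KEY = "stream_option"
TEMPERATURE_KEY = "temperature_option"


def build_model_options(request: dict) -> dict:
    """Mimic _build_model_options: translate OpenAI field names to option keys."""
    mapping = {"stream": STREAM_KEY, "temperature": TEMPERATURE_KEY}
    return {mapping[k]: v for k, v in request.items() if k in mapping}


# "n" is dropped (not in the mapping); "stream" flows through, so serve()
# sees the stream flag set and returns an uncomputed thunk for SSE output.
opts = build_model_options({"stream": True, "temperature": 0.2, "n": 1})
assert opts == {STREAM_KEY: True, TEMPERATURE_KEY: 0.2}
```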
