
Commit 55b6b73

fix(cli): handle sync/async serve functions in m serve (#784)
* fix(cli): handle sync/async serve functions in m serve

  Fixes the sync/async mismatch in `m serve` by detecting the function type and handling it appropriately:
  - Async serve functions are awaited directly
  - Sync serve functions are wrapped in asyncio.to_thread() to prevent blocking FastAPI's event loop

  This ensures the server can handle concurrent requests efficiently regardless of whether user-defined serve functions are sync or async.

  Changes:
  - cli/serve/app.py: Add asyncio/inspect imports; update make_chat_endpoint() to detect coroutine functions and wrap sync functions in to_thread()
  - test/cli/test_serve_sync_async.py: Add a comprehensive test suite (9 tests), including an empirical timing test that proves the non-blocking behavior

* feat: map OpenAI parameters to ModelOption sentinels
  - temperature → ModelOption.TEMPERATURE
  - max_tokens → ModelOption.MAX_NEW_TOKENS
  - seed → ModelOption.SEED

* fix: catch BaseException to handle StopAsyncIteration

  With async added, the catch needs to handle StopAsyncIteration as well.

* fix: function docstring raises but has no Raises section

  Fix for: "Error: function raises but has no Raises section"

* fix: more mapping and filtering of model options in the m serve app
  - Filter out model, n, user, and extra
  - Comment the filtering
  - Use a map and ModelOption.replace_keys() for the mapping
  - Fix replace_keys to no-op when from == to

* fix: improve tests per review
  - Use whole seconds for the timing test so it is less flaky
  - Return a real ModelOutputThunk instead of a misleading Mock
  - Use AsyncMock with side_effect for consistency

* fix: make m serve errors OpenAI API compatible
  - For n > 1, return 400 (n > 1 is not implemented yet)
  - For pydantic validation errors, add a handler that converts them to the OpenAI API error format

* fix: filter out params set to None in m serve model options

* fix: filter out model options that are not implemented in m serve
  - Add more fields to excluded_fields. Some of these should be implemented or handled better soon, but filter/ignore them for now.
  - Move _build_model_options from a nested function inside make_chat_endpoint to module level; it has no closure dependencies and can be tested directly.
  - Add unit tests for _build_model_options

* fix: fix test_n_zero_rejected_at_http_level reusing the global FastAPI app instance and causing test pollution
  - Created a fresh FastAPI() instance for the test
  - Added the RequestValidationError import from fastapi.exceptions
  - Manually registered validation_exception_handler to maintain the same behavior

---------

Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent 7bdb9f6 commit 55b6b73

5 files changed

Lines changed: 783 additions & 15 deletions
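
The core of the fix is a small dispatch pattern. A minimal standalone sketch of it (the serve argument stands in for a user-defined serve function; call_serve is a hypothetical helper for illustration, not part of Mellea's API):

import asyncio
import inspect

async def call_serve(serve, **kwargs):
    """Dispatch a user serve function without blocking the event loop."""
    if inspect.iscoroutinefunction(serve):
        # Async serve functions are awaited directly.
        return await serve(**kwargs)
    # Sync serve functions run in a worker thread via asyncio.to_thread(),
    # so concurrent requests are not serialized behind a blocking call.
    return await asyncio.to_thread(serve, **kwargs)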

File tree

cli/serve/app.py

Lines changed: 99 additions & 10 deletions
@@ -1,16 +1,21 @@
 """A simple app that runs an OpenAI compatible server wrapped around a M program."""

+import asyncio
 import importlib.util
+import inspect
 import os
 import sys
 import time
 import uuid

 import typer
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse

+from mellea.backends.model_options import ModelOption
+
 from .models import (
     ChatCompletion,
     ChatCompletionMessage,
@@ -28,6 +33,34 @@
 )


+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(
+    request: Request, exc: RequestValidationError
+) -> JSONResponse:
+    """Convert FastAPI validation errors to OpenAI-compatible format.
+
+    FastAPI returns 422 with a 'detail' array by default. OpenAI API uses
+    400 with an 'error' object containing message, type, and param fields.
+    """
+    # Extract the first validation error
+    errors = exc.errors()
+    if errors:
+        first_error = errors[0]
+        # Get the field name from the location tuple (e.g., ('body', 'n') -> 'n')
+        param = first_error["loc"][-1] if first_error["loc"] else None
+        message = first_error["msg"]
+    else:
+        param = None
+        message = "Invalid request parameters"
+
+    return create_openai_error_response(
+        status_code=400,
+        message=message,
+        error_type="invalid_request_error",
+        param=str(param) if param else None,
+    )
+
+
 def load_module_from_path(path: str):
     """Load the module with M program in it."""
     module_name = os.path.splitext(os.path.basename(path))[0]
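
For reference, the handler above targets OpenAI's error schema rather than FastAPI's default 422 'detail' array. A hedged sketch of the resulting 400 body for a validation failure such as n=0 (create_openai_error_response is not shown in this diff, so the exact JSON shape is an assumption based on the docstring above):

error_body = {
    "error": {
        "message": "Input should be greater than or equal to 1",  # pydantic's msg
        "type": "invalid_request_error",
        "param": "n",  # last element of the error's location tuple
    }
}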
@@ -50,23 +83,79 @@ def create_openai_error_response(
     )


+def _build_model_options(request: ChatCompletionRequest) -> dict:
+    """Build model_options dict from OpenAI-compatible request parameters."""
+    excluded_fields = {
+        # Request structure fields (handled separately)
+        "messages",  # Chat messages - passed separately to serve()
+        "requirements",  # Mellea requirements - passed separately to serve()
+        # Routing/metadata fields (not generation parameters)
+        "model",  # Model identifier - used for routing, not generation
+        "n",  # Number of completions - not supported in Mellea's model_options
+        "user",  # User tracking ID - metadata, not a generation parameter
+        "extra",  # Pydantic's extra fields dict - unused (see model_config)
+        # Not-yet-implemented OpenAI parameters (silently ignored)
+        "stream",  # Streaming responses - not yet implemented
+        "stop",  # Stop sequences - not yet implemented
+        "top_p",  # Nucleus sampling - not yet implemented
+        "presence_penalty",  # Presence penalty - not yet implemented
+        "frequency_penalty",  # Frequency penalty - not yet implemented
+        "logit_bias",  # Logit bias - not yet implemented
+        "response_format",  # Response format (json_object) - not yet implemented
+        "functions",  # Legacy function calling - not yet implemented
+        "function_call",  # Legacy function calling - not yet implemented
+        "tools",  # Tool calling - not yet implemented
+        "tool_choice",  # Tool choice - not yet implemented
+    }
+    openai_to_model_option = {
+        "temperature": ModelOption.TEMPERATURE,
+        "max_tokens": ModelOption.MAX_NEW_TOKENS,
+        "seed": ModelOption.SEED,
+    }
+
+    filtered_options = {
+        key: value
+        for key, value in request.model_dump(exclude_none=True).items()
+        if key not in excluded_fields
+    }
+    return ModelOption.replace_keys(filtered_options, openai_to_model_option)
+
+
 def make_chat_endpoint(module):
     """Makes a chat endpoint using a custom module."""

     async def endpoint(request: ChatCompletionRequest):
         try:
+            # Validate that n=1 (we don't support multiple completions)
+            if request.n is not None and request.n > 1:
+                return create_openai_error_response(
+                    status_code=400,
+                    message=f"Multiple completions (n={request.n}) are not supported. Please set n=1 or omit the parameter.",
+                    error_type="invalid_request_error",
+                    param="n",
+                )
+
             completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
             created_timestamp = int(time.time())

-            output = module.serve(
-                input=request.messages,
-                requirements=request.requirements,
-                model_options={
-                    k: v
-                    for k, v in request.model_dump().items()
-                    if k not in ["messages", "requirements"]
-                },
-            )
+            model_options = _build_model_options(request)
+
+            # Detect if serve is async or sync and handle accordingly
+            if inspect.iscoroutinefunction(module.serve):
+                # It's async, await it directly
+                output = await module.serve(
+                    input=request.messages,
+                    requirements=request.requirements,
+                    model_options=model_options,
+                )
+            else:
+                # It's sync, run in thread pool to avoid blocking event loop
+                output = await asyncio.to_thread(
+                    module.serve,
+                    input=request.messages,
+                    requirements=request.requirements,
+                    model_options=model_options,
+                )

             # Extract usage information from the ModelOutputThunk if available
             usage = None
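
A quick before/after illustration of _build_model_options, using the same imports as the unit tests later in this commit (the literal values are illustrative):

from cli.serve.app import _build_model_options
from cli.serve.models import ChatCompletionRequest, ChatMessage
from mellea.backends.model_options import ModelOption

request = ChatCompletionRequest(
    model="demo-model",  # excluded: routing metadata, not a generation parameter
    messages=[ChatMessage(role="user", content="hi")],  # excluded: passed separately
    temperature=0.2,
    max_tokens=64,
    stream=False,  # excluded: not yet implemented
)
# Only mapped generation parameters survive, keyed by ModelOption sentinels.
assert _build_model_options(request) == {
    ModelOption.TEMPERATURE: 0.2,
    ModelOption.MAX_NEW_TOKENS: 64,
}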

mellea/backends/model_options.py

Lines changed: 4 additions & 0 deletions
@@ -91,6 +91,10 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]:
         # This will usually be a ModelOption.<> key.
         new_key = from_to.get(old_key, None)
         if new_key:
+            # Skip if old_key and new_key are the same (no-op replacement)
+            if old_key == new_key:
+                continue
+
             if new_options.get(new_key, None) is not None:
                 # The key already has a value associated with it in the dict. Leave it be.
                 conflict_log.append(
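
A hedged sketch of what this guard changes, assuming replace_keys otherwise copies keys through unchanged (only the identity-mapping case is exercised here):

from mellea.backends.model_options import ModelOption

options = {"temperature": 0.7, "seed": 42}
# An identity entry (from == to) is now skipped outright instead of going
# through the replacement/conflict-logging path.
assert ModelOption.replace_keys(options, {"temperature": "temperature"}) == {
    "temperature": 0.7,
    "seed": 42,
}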
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+"""Unit tests for _build_model_options function."""
+
+import pytest
+
+from cli.serve.app import _build_model_options
+from cli.serve.models import ChatCompletionRequest, ChatMessage
+from mellea.backends.model_options import ModelOption
+
+
+class TestBuildModelOptions:
+    """Direct unit tests for _build_model_options."""
+
+    def test_temperature_mapping(self):
+        """Test that temperature is correctly mapped to ModelOption.TEMPERATURE."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            temperature=0.7,
+        )
+        options = _build_model_options(request)
+        assert options[ModelOption.TEMPERATURE] == 0.7
+
+    def test_max_tokens_mapping(self):
+        """Test that max_tokens is correctly mapped to ModelOption.MAX_NEW_TOKENS."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            max_tokens=100,
+        )
+        options = _build_model_options(request)
+        assert options[ModelOption.MAX_NEW_TOKENS] == 100
+
+    def test_seed_mapping(self):
+        """Test that seed is correctly mapped to ModelOption.SEED."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            seed=42,
+        )
+        options = _build_model_options(request)
+        assert options[ModelOption.SEED] == 42
+
+    def test_multiple_options(self):
+        """Test that multiple options are correctly mapped together."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            temperature=0.8,
+            max_tokens=200,
+            seed=123,
+        )
+        options = _build_model_options(request)
+        assert options[ModelOption.TEMPERATURE] == 0.8
+        assert options[ModelOption.MAX_NEW_TOKENS] == 200
+        assert options[ModelOption.SEED] == 123
+
+    def test_excluded_fields_not_in_output(self):
+        """Test that excluded fields are not included in model_options."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            n=1,
+            user="test-user",
+            stream=False,
+            temperature=0.5,
+        )
+        options = _build_model_options(request)
+        # Check that excluded fields are not present
+        assert "model" not in options
+        assert "messages" not in options
+        assert "n" not in options
+        assert "user" not in options
+        assert "stream" not in options
+        # Check that temperature is present
+        assert ModelOption.TEMPERATURE in options
+
+    def test_none_values_excluded(self):
+        """Test that None values are excluded from output."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            temperature=None,
+            max_tokens=None,
+        )
+        options = _build_model_options(request)
+        assert ModelOption.TEMPERATURE not in options
+        assert ModelOption.MAX_NEW_TOKENS not in options
+
+    def test_minimal_request_includes_defaults(self):
+        """Test that a minimal request includes default values like temperature."""
+        request = ChatCompletionRequest(
+            model="test-model", messages=[ChatMessage(role="user", content="test")]
+        )
+        options = _build_model_options(request)
+        # ChatCompletionRequest has default temperature=1.0
+        assert options == {ModelOption.TEMPERATURE: 1.0}
+
+    def test_requirements_excluded(self):
+        """Test that requirements field is excluded from model_options."""
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[ChatMessage(role="user", content="test")],
+            requirements=["req1", "req2"],
+            temperature=0.7,
+        )
+        options = _build_model_options(request)
+        assert "requirements" not in options
+        assert ModelOption.TEMPERATURE in options
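
The sync/async test suite mentioned in the commit message (test/cli/test_serve_sync_async.py) is not shown in this view. A hedged sketch of how an empirical timing test can demonstrate the non-blocking behavior, using whole-second sleeps as the commit message describes (slow_sync_serve is a stand-in, not the repository's test code):

import asyncio
import time

def slow_sync_serve(**kwargs):
    time.sleep(1)  # stands in for a blocking model call
    return "done"

async def main():
    start = time.monotonic()
    # Two wrapped calls run concurrently in worker threads: total time
    # should be about 1s (overlapping), not about 2s (serialized).
    results = await asyncio.gather(
        asyncio.to_thread(slow_sync_serve),
        asyncio.to_thread(slow_sync_serve),
    )
    elapsed = time.monotonic() - start
    assert results == ["done", "done"]
    assert elapsed < 1.5  # whole-second sleep keeps the margin un-flaky

asyncio.run(main())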
