Skip to content

Commit 1d605b7

Browse files
committed
feat: implement multi-route inference proxy support
- Proto: add InferenceModelEntry message with alias/provider/model fields; add repeated models field to ClusterInferenceConfig, Set/Get request/response
- Server: add upsert_multi_model_route() for storing multiple model entries under a single route slot; update resolve_route_by_name() to expand multi-model configs into per-alias ResolvedRoute entries
- Router: add select_route() with alias-first, protocol-fallback strategy; add model_hint parameter to proxy_with_candidates() variants
- Sandbox proxy: extract model field from JSON body as routing hint
- Tests: 7 new tests covering select_route, multi-model resolution, and bundle expansion; all 291 existing tests continue to pass

Signed-off-by: Lyle Hopkins <lyle@cosmicnetworks.com>
1 parent dd2be9a commit 1d605b7

5 files changed

Lines changed: 445 additions & 29 deletions

File tree

crates/openshell-router/src/lib.rs

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ pub struct Router {
3636
client: reqwest::Client,
3737
}
3838

39+
/// Select a route from `candidates` using alias-first, protocol-fallback strategy.
40+
///
41+
/// 1. If `model_hint` is provided, find a candidate whose `name` matches the hint
42+
/// **and** whose protocols include `protocol`.
43+
/// 2. Otherwise, return the first candidate whose protocols contain `protocol`.
44+
fn select_route<'a>(
45+
candidates: &'a [ResolvedRoute],
46+
protocol: &str,
47+
model_hint: Option<&str>,
48+
) -> Option<&'a ResolvedRoute> {
49+
if let Some(hint) = model_hint {
50+
if let Some(r) = candidates.iter().find(|r| {
51+
r.name == hint && r.protocols.iter().any(|p| p == protocol)
52+
}) {
53+
return Some(r);
54+
}
55+
}
56+
candidates
57+
.iter()
58+
.find(|r| r.protocols.iter().any(|p| p == protocol))
59+
}
60+
3961
impl Router {
4062
pub fn new() -> Result<Self, RouterError> {
4163
let client = reqwest::Client::builder()
@@ -57,8 +79,10 @@ impl Router {
5779

5880
/// Proxy a raw HTTP request to the first compatible route from `candidates`.
5981
///
60-
/// Filters candidates by `source_protocol` compatibility (exact match against
61-
/// one of the route's `protocols`), then forwards to the first match.
82+
/// When `model_hint` is provided, the router first looks for a candidate whose
83+
/// `name` (alias) matches the hint and whose `protocols` list contains
84+
/// `source_protocol`. If none does, it falls back to protocol-only selection
85+
/// (first candidate whose `protocols` list contains `source_protocol`).
6286
pub async fn proxy_with_candidates(
6387
&self,
6488
source_protocol: &str,
@@ -67,11 +91,10 @@ impl Router {
6791
headers: Vec<(String, String)>,
6892
body: bytes::Bytes,
6993
candidates: &[ResolvedRoute],
94+
model_hint: Option<&str>,
7095
) -> Result<ProxyResponse, RouterError> {
7196
let normalized_source = source_protocol.trim().to_ascii_lowercase();
72-
let route = candidates
73-
.iter()
74-
.find(|r| r.protocols.iter().any(|p| p == &normalized_source))
97+
let route = select_route(candidates, &normalized_source, model_hint)
7598
.ok_or_else(|| RouterError::NoCompatibleRoute(source_protocol.to_string()))?;
7699

77100
info!(
@@ -111,11 +134,10 @@ impl Router {
111134
headers: Vec<(String, String)>,
112135
body: bytes::Bytes,
113136
candidates: &[ResolvedRoute],
137+
model_hint: Option<&str>,
114138
) -> Result<StreamingProxyResponse, RouterError> {
115139
let normalized_source = source_protocol.trim().to_ascii_lowercase();
116-
let route = candidates
117-
.iter()
118-
.find(|r| r.protocols.iter().any(|p| p == &normalized_source))
140+
let route = select_route(candidates, &normalized_source, model_hint)
119141
.ok_or_else(|| RouterError::NoCompatibleRoute(source_protocol.to_string()))?;
120142

121143
info!(
@@ -187,4 +209,57 @@ mod tests {
187209
let err = Router::from_config(&config).unwrap_err();
188210
assert!(matches!(err, RouterError::Internal(_)));
189211
}
212+
213+
fn make_route(name: &str, protocols: Vec<&str>) -> ResolvedRoute {
214+
ResolvedRoute {
215+
name: name.to_string(),
216+
endpoint: "http://localhost".to_string(),
217+
model: format!("{name}-model"),
218+
api_key: "key".to_string(),
219+
protocols: protocols.into_iter().map(String::from).collect(),
220+
auth: config::AuthHeader::Bearer,
221+
default_headers: Vec::new(),
222+
}
223+
}
224+
225+
/// With no model hint, the first candidate listing the protocol is chosen.
#[test]
fn select_route_protocol_fallback_when_no_hint() {
    let candidates = [
        make_route("ollama-local", vec!["openai_chat_completions"]),
        make_route("anthropic-prod", vec!["anthropic_messages"]),
    ];

    let chosen = select_route(&candidates, "anthropic_messages", None);

    assert_eq!(chosen.map(|r| r.name.as_str()), Some("anthropic-prod"));
}
234+
235+
/// A hint naming a protocol-compatible route beats an earlier protocol match.
#[test]
fn select_route_alias_match_takes_priority() {
    // Both candidates speak openai_chat_completions; only the hint can
    // move the choice off the first one.
    let candidates = [
        make_route("ollama-local", vec!["openai_chat_completions"]),
        make_route("openai-prod", vec!["openai_chat_completions", "openai_responses"]),
    ];

    let chosen = select_route(&candidates, "openai_chat_completions", Some("openai-prod"));

    assert_eq!(chosen.map(|r| r.name.as_str()), Some("openai-prod"));
}
245+
246+
/// A hinted alias that does not speak the requested protocol is ignored.
#[test]
fn select_route_alias_must_also_match_protocol() {
    let candidates = [
        make_route("ollama-local", vec!["openai_chat_completions"]),
        make_route("anthropic-prod", vec!["anthropic_messages"]),
    ];

    // The hint names "anthropic-prod", which only speaks anthropic_messages,
    // so selection falls back to the protocol match: ollama-local.
    let chosen = select_route(&candidates, "openai_chat_completions", Some("anthropic-prod"));

    assert_eq!(chosen.map(|r| r.name.as_str()), Some("ollama-local"));
}
257+
258+
/// When no candidate lists the requested protocol, nothing is selected.
#[test]
fn select_route_no_match_returns_none() {
    let candidates = [make_route("ollama-local", vec!["openai_chat_completions"])];

    let chosen = select_route(&candidates, "anthropic_messages", None);

    assert!(chosen.is_none());
}
190265
}

crates/openshell-router/tests/backend_integration.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ async fn proxy_forwards_request_to_backend() {
6666
vec![("content-type".to_string(), "application/json".to_string())],
6767
bytes::Bytes::from(body),
6868
&candidates,
69+
None,
6970
)
7071
.await
7172
.unwrap();
@@ -98,6 +99,7 @@ async fn proxy_upstream_401_returns_error() {
9899
vec![],
99100
bytes::Bytes::new(),
100101
&candidates,
102+
None,
101103
)
102104
.await
103105
.unwrap();
@@ -127,6 +129,7 @@ async fn proxy_no_compatible_route_returns_error() {
127129
vec![],
128130
bytes::Bytes::new(),
129131
&candidates,
132+
None,
130133
)
131134
.await
132135
.unwrap_err();
@@ -160,6 +163,7 @@ async fn proxy_strips_auth_header() {
160163
vec![("authorization".to_string(), "Bearer client-key".to_string())],
161164
bytes::Bytes::new(),
162165
&candidates,
166+
None,
163167
)
164168
.await
165169
.unwrap();
@@ -194,6 +198,7 @@ async fn proxy_mock_route_returns_canned_response() {
194198
vec![("content-type".to_string(), "application/json".to_string())],
195199
bytes::Bytes::from(body),
196200
&candidates,
201+
None,
197202
)
198203
.await
199204
.unwrap();
@@ -239,6 +244,7 @@ async fn proxy_overrides_model_in_request_body() {
239244
vec![("content-type".to_string(), "application/json".to_string())],
240245
bytes::Bytes::from(body),
241246
&candidates,
247+
None,
242248
)
243249
.await
244250
.unwrap();
@@ -277,6 +283,7 @@ async fn proxy_inserts_model_when_absent_from_body() {
277283
vec![("content-type".to_string(), "application/json".to_string())],
278284
bytes::Bytes::from(body),
279285
&candidates,
286+
None,
280287
)
281288
.await
282289
.unwrap();
@@ -332,6 +339,7 @@ async fn proxy_uses_x_api_key_for_anthropic_route() {
332339
],
333340
bytes::Bytes::from(body),
334341
&candidates,
342+
None,
335343
)
336344
.await
337345
.unwrap();
@@ -380,6 +388,7 @@ async fn proxy_anthropic_does_not_send_bearer_auth() {
380388
vec![("content-type".to_string(), "application/json".to_string())],
381389
bytes::Bytes::from(b"{}".to_vec()),
382390
&candidates,
391+
None,
383392
)
384393
.await
385394
.unwrap();
@@ -436,6 +445,7 @@ async fn proxy_forwards_client_anthropic_version_header() {
436445
],
437446
bytes::Bytes::from(body),
438447
&candidates,
448+
None,
439449
)
440450
.await
441451
.unwrap();

crates/openshell-sandbox/src/proxy.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl InferenceContext {
105105
) -> Result<openshell_router::ProxyResponse, openshell_router::RouterError> {
106106
let routes = self.system_routes.read().await;
107107
self.router
108-
.proxy_with_candidates(protocol, method, path, headers, body, &routes)
108+
.proxy_with_candidates(protocol, method, path, headers, body, &routes, None)
109109
.await
110110
}
111111
}
@@ -993,6 +993,12 @@ async fn route_inference_request(
993993
return Ok(true);
994994
}
995995

996+
// Extract the model field from the JSON body as a routing hint.
997+
// If parsing fails or model is absent, we fall back to protocol-only matching.
998+
let model_hint = serde_json::from_slice::<serde_json::Value>(&request.body)
999+
.ok()
1000+
.and_then(|v| v.get("model")?.as_str().map(String::from));
1001+
9961002
match ctx
9971003
.router
9981004
.proxy_with_candidates_streaming(
@@ -1002,6 +1008,7 @@ async fn route_inference_request(
10021008
filtered_headers,
10031009
bytes::Bytes::from(request.body.clone()),
10041010
&routes,
1011+
model_hint.as_deref(),
10051012
)
10061013
.await
10071014
{

0 commit comments

Comments
 (0)