Skip to content

Commit 6c7a537

Browse files
authored
feat: Propagate OTel context via WebSocket HTTP upgrade headers (#174)
Why === River's WebSocket connections don't carry any OTel context (traceparent, tracestate, baggage) from client to server. This means distributed tracing and baggage propagation are broken at the WebSocket boundary — the server has no way to inherit the caller's trace context or read OTel baggage entries. What changed ============ Uses the standard W3C HTTP header approach — the same mechanism any HTTP service uses for OTel propagation — applied to the WebSocket upgrade request. **Client side (`client_transport.py`, `v2/session.py`)** - Before calling `websockets.connect()`, inject the current OTel context into a headers dict via `propagate.inject()`. - Pass those headers as `extra_headers` (v1 legacy API) / `additional_headers` (v2 asyncio API) to the connect call. - This automatically includes `traceparent`, `tracestate`, and `baggage` headers if the corresponding propagators are configured in the global textmap. **Server side (`server.py`)** - In `Server.serve()`, extract the OTel context from `websocket.request_headers` via `propagate.extract()`. - Attach the extracted context as the ambient OTel context for the lifetime of the connection using `context.attach()` / `context.detach()`. - Any handler code running within the connection can now read baggage via `baggage.get_all()` and inherits the caller's trace context. **Tests (`tests/v1/test_opentelemetry.py`)** - `test_baggage_propagated_via_ws_headers`: Sets two baggage entries on the client, verifies the server handler can read them. - `test_no_baggage_when_none_set`: Verifies clean behavior when no baggage is set. - `test_traceparent_propagated_via_ws_headers`: Sets both an active span and baggage on the client, verifies both propagate. Test plan ========= ``` $ uv run pytest tests/ -v 64 passed in 8.46s ``` All existing tests pass unchanged. The 3 new tests verify end-to-end OTel context propagation through the WebSocket connection. ## Revertibility Safe to revert — only adds new `extra_headers`/`additional_headers` to `websockets.connect()` and a `propagate.extract()` + `context.attach()` wrapper on the server. No wire protocol changes, no schema changes, no data mutations. ~ written by Zerg 👾
1 parent bd88e45 commit 6c7a537

4 files changed

Lines changed: 278 additions & 31 deletions

File tree

src/replit_river/client_transport.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import nanoid
77
import websockets
8+
from opentelemetry import propagate
89
from pydantic import ValidationError
910
from websockets import (
1011
WebSocketCommonProtocol,
@@ -170,7 +171,13 @@ async def _establish_new_connection(
170171

171172
try:
172173
uri_and_metadata = await self._uri_and_metadata_factory()
173-
ws = await websockets.connect(uri_and_metadata["uri"], max_size=None)
174+
otel_headers: dict[str, str] = {}
175+
propagate.inject(otel_headers)
176+
ws = await websockets.connect(
177+
uri_and_metadata["uri"],
178+
max_size=None,
179+
extra_headers=otel_headers,
180+
)
174181
session_id = (
175182
self.generate_nanoid()
176183
if not old_session

src/replit_river/server.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Mapping
44

55
import websockets
6+
from opentelemetry import context, propagate
67
from websockets.exceptions import ConnectionClosed
78
from websockets.server import WebSocketServerProtocol
89

@@ -68,34 +69,44 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None:
6869
logger.debug(
6970
"River server started establishing session with ws: %s", websocket.id
7071
)
71-
grace_ms = self._transport_options.handshake_timeout_ms
72+
73+
# Extract OTel context (traceparent, tracestate, baggage) from the
74+
# WebSocket HTTP upgrade request headers and make it the ambient
75+
# context for the lifetime of this connection.
76+
otel_context = propagate.extract(websocket.request_headers)
77+
token = context.attach(otel_context)
78+
7279
try:
73-
session = await asyncio.wait_for(
74-
self._handshake_to_get_session(websocket),
75-
grace_ms / 1000, # wait_for unit is seconds
76-
)
77-
if not session:
80+
grace_ms = self._transport_options.handshake_timeout_ms
81+
try:
82+
session = await asyncio.wait_for(
83+
self._handshake_to_get_session(websocket),
84+
grace_ms / 1000, # wait_for unit is seconds
85+
)
86+
if not session:
87+
return
88+
except asyncio.TimeoutError:
89+
logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket")
90+
await websocket.close()
7891
return
79-
except asyncio.TimeoutError:
80-
logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket")
81-
await websocket.close()
82-
return
83-
except asyncio.CancelledError:
84-
logger.error("Handshake cancelled, closing websocket")
85-
await websocket.close()
86-
return
87-
logger.debug("River server session established, start serving messages")
92+
except asyncio.CancelledError:
93+
logger.error("Handshake cancelled, closing websocket")
94+
await websocket.close()
95+
return
96+
logger.debug("River server session established, start serving messages")
8897

89-
try:
90-
# Session serve will be closed in two cases
91-
# 1. websocket is closed
92-
# 2. exception thrown
93-
# session should be kept in order to be reused by the reconnect within the
94-
# grace period.
95-
await session.serve()
96-
except ConnectionClosed:
97-
logger.debug("ConnectionClosed while serving", exc_info=True)
98-
# We don't have to close the websocket here, it is already closed.
99-
except Exception:
100-
logger.exception("River transport error in server %s", self._server_id)
101-
await websocket.close()
98+
try:
99+
# Session serve will be closed in two cases
100+
# 1. websocket is closed
101+
# 2. exception thrown
102+
# session should be kept in order to be reused by the reconnect within
103+
# the grace period.
104+
await session.serve()
105+
except ConnectionClosed:
106+
logger.debug("ConnectionClosed while serving", exc_info=True)
107+
# We don't have to close the websocket here, it is already closed.
108+
except Exception:
109+
logger.exception("River transport error in server %s", self._server_id)
110+
await websocket.close()
111+
finally:
112+
context.detach(token)

src/replit_river/v2/session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import websockets.asyncio.client
2424
from aiochannel import Channel, ChannelEmpty, ChannelFull
2525
from aiochannel.errors import ChannelClosed
26+
from opentelemetry import propagate
2627
from opentelemetry.trace import Span, use_span
2728
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
2829
from pydantic import ValidationError
@@ -1133,9 +1134,12 @@ async def _do_ensure_connected[HandshakeMetadata](
11331134
ws: ClientConnection | None = None
11341135
try:
11351136
uri_and_metadata = await uri_and_metadata_factory()
1137+
otel_headers: dict[str, str] = {}
1138+
propagate.inject(otel_headers)
11361139
ws = await websockets.asyncio.client.connect(
11371140
uri_and_metadata["uri"],
11381141
max_size=None,
1142+
additional_headers=otel_headers,
11391143
)
11401144
transition_connecting(ws)
11411145

tests/v1/test_opentelemetry.py

Lines changed: 227 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
import contextlib
2+
import logging
23
from datetime import timedelta
3-
from typing import AsyncGenerator, AsyncIterator, Iterator
4+
from typing import AsyncGenerator, AsyncIterator, Iterator, Literal
45

56
import grpc
67
import grpc.aio
78
import pytest
9+
from opentelemetry import baggage, context, propagate, trace
10+
from opentelemetry.baggage.propagation import W3CBaggagePropagator
11+
from opentelemetry.propagators.composite import CompositePropagator
812
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
913
from opentelemetry.trace import StatusCode
14+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
15+
from websockets.server import serve
1016

1117
from replit_river.client import Client
18+
from replit_river.client_transport import UriAndMetadata
1219
from replit_river.error_schema import RiverError, RiverException
13-
from replit_river.rpc import stream_method_handler
20+
from replit_river.rpc import rpc_method_handler, stream_method_handler
21+
from replit_river.server import Server
22+
from replit_river.transport_options import TransportOptions
1423
from tests.conftest import (
1524
HandlerMapping,
1625
deserialize_error,
@@ -219,3 +228,219 @@ async def stream_data() -> AsyncGenerator[str, None]:
219228
assert len(spans) == 1
220229
assert spans[0].name == "river.client.stream.test_service.stream_method"
221230
assert spans[0].status.status_code == StatusCode.OK
231+
232+
233+
# ===== OTel context propagation via WebSocket HTTP upgrade headers =====
234+
235+
236+
# A handler that reads OTel baggage from the ambient context and returns it.
237+
async def baggage_echo_handler(request: str, ctx: grpc.aio.ServicerContext) -> str:
238+
all_baggage = baggage.get_all()
239+
# Return baggage as a comma-separated "key=value" string
240+
return ",".join(f"{k}={v}" for k, v in sorted(all_baggage.items()))
241+
242+
243+
baggage_echo_handlers: HandlerMapping = {
244+
("test_service", "baggage_echo"): (
245+
"rpc",
246+
rpc_method_handler(
247+
baggage_echo_handler, deserialize_request, serialize_response
248+
),
249+
)
250+
}
251+
252+
253+
@pytest.fixture
254+
def _enable_baggage_propagator() -> Iterator[None]:
255+
"""Temporarily install a composite propagator that includes both
256+
W3C TraceContext and W3C Baggage propagation so that
257+
``propagate.inject()`` / ``propagate.extract()`` handle the
258+
``baggage`` HTTP header."""
259+
previous = propagate.get_global_textmap()
260+
propagate.set_global_textmap(
261+
CompositePropagator(
262+
[
263+
TraceContextTextMapPropagator(),
264+
W3CBaggagePropagator(),
265+
]
266+
)
267+
)
268+
yield
269+
propagate.set_global_textmap(previous)
270+
271+
272+
@pytest.mark.asyncio
273+
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
274+
@pytest.mark.usefixtures("_enable_baggage_propagator")
275+
async def test_baggage_propagated_via_ws_headers(
276+
no_logging_error: NoErrors,
277+
server: Server,
278+
transport_options: TransportOptions,
279+
) -> None:
280+
"""Verify that OTel baggage set on the client side is propagated to the
281+
server via the WebSocket HTTP upgrade request headers."""
282+
283+
# Set baggage in the ambient OTel context *before* the client connects,
284+
# so that ``propagate.inject()`` (called inside ``websockets.connect()``)
285+
# includes the ``baggage`` header.
286+
ctx = baggage.set_baggage("test-key", "test-value")
287+
ctx = baggage.set_baggage("another-key", "another-value", context=ctx)
288+
token = context.attach(ctx)
289+
290+
binding = None
291+
try:
292+
binding = await serve(server.serve, "127.0.0.1")
293+
sockets = list(binding.sockets)
294+
assert len(sockets) == 1
295+
socket = sockets[0]
296+
297+
async def websocket_uri_factory() -> UriAndMetadata[None]:
298+
return {
299+
"uri": "ws://%s:%d" % socket.getsockname(),
300+
"metadata": None,
301+
}
302+
303+
client: Client[Literal[None]] = Client[None](
304+
uri_and_metadata_factory=websocket_uri_factory,
305+
client_id="test_client",
306+
server_id="test_server",
307+
transport_options=transport_options,
308+
)
309+
try:
310+
response = await client.send_rpc(
311+
"test_service",
312+
"baggage_echo",
313+
"ignored",
314+
serialize_request,
315+
deserialize_response,
316+
deserialize_error,
317+
timedelta(seconds=20),
318+
)
319+
# The handler returns sorted "key=value" pairs
320+
assert response == "another-key=another-value,test-key=test-value"
321+
finally:
322+
logging.debug("Start closing test client")
323+
await client.close()
324+
finally:
325+
context.detach(token)
326+
logging.debug("Start closing test server")
327+
if binding:
328+
binding.close()
329+
await server.close()
330+
if binding:
331+
await binding.wait_closed()
332+
333+
334+
@pytest.mark.asyncio
335+
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
336+
@pytest.mark.usefixtures("_enable_baggage_propagator")
337+
async def test_no_baggage_when_none_set(
338+
no_logging_error: NoErrors,
339+
server: Server,
340+
transport_options: TransportOptions,
341+
) -> None:
342+
"""Verify that when no baggage is set, the server sees empty baggage."""
343+
344+
binding = None
345+
try:
346+
binding = await serve(server.serve, "127.0.0.1")
347+
sockets = list(binding.sockets)
348+
assert len(sockets) == 1
349+
socket = sockets[0]
350+
351+
async def websocket_uri_factory() -> UriAndMetadata[None]:
352+
return {
353+
"uri": "ws://%s:%d" % socket.getsockname(),
354+
"metadata": None,
355+
}
356+
357+
client: Client[Literal[None]] = Client[None](
358+
uri_and_metadata_factory=websocket_uri_factory,
359+
client_id="test_client",
360+
server_id="test_server",
361+
transport_options=transport_options,
362+
)
363+
try:
364+
response = await client.send_rpc(
365+
"test_service",
366+
"baggage_echo",
367+
"ignored",
368+
serialize_request,
369+
deserialize_response,
370+
deserialize_error,
371+
timedelta(seconds=20),
372+
)
373+
assert response == ""
374+
finally:
375+
logging.debug("Start closing test client")
376+
await client.close()
377+
finally:
378+
logging.debug("Start closing test server")
379+
if binding:
380+
binding.close()
381+
await server.close()
382+
if binding:
383+
await binding.wait_closed()
384+
385+
386+
@pytest.mark.asyncio
387+
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
388+
@pytest.mark.usefixtures("_enable_baggage_propagator")
389+
async def test_traceparent_propagated_via_ws_headers(
390+
no_logging_error: NoErrors,
391+
server: Server,
392+
transport_options: TransportOptions,
393+
span_exporter: InMemorySpanExporter,
394+
) -> None:
395+
"""Verify that when a span is active on the client, the traceparent
396+
header is sent on the WS upgrade and the server-side context inherits
397+
the trace."""
398+
tracer = trace.get_tracer(__name__)
399+
400+
with tracer.start_as_current_span("client-operation"):
401+
# Also set some baggage
402+
ctx = baggage.set_baggage("trace-test", "yes")
403+
token = context.attach(ctx)
404+
405+
binding = None
406+
try:
407+
binding = await serve(server.serve, "127.0.0.1")
408+
sockets = list(binding.sockets)
409+
assert len(sockets) == 1
410+
socket = sockets[0]
411+
412+
async def websocket_uri_factory() -> UriAndMetadata[None]:
413+
return {
414+
"uri": "ws://%s:%d" % socket.getsockname(),
415+
"metadata": None,
416+
}
417+
418+
client: Client[Literal[None]] = Client[None](
419+
uri_and_metadata_factory=websocket_uri_factory,
420+
client_id="test_client",
421+
server_id="test_server",
422+
transport_options=transport_options,
423+
)
424+
try:
425+
response = await client.send_rpc(
426+
"test_service",
427+
"baggage_echo",
428+
"ignored",
429+
serialize_request,
430+
deserialize_response,
431+
deserialize_error,
432+
timedelta(seconds=20),
433+
)
434+
# Verify baggage was propagated
435+
assert response == "trace-test=yes"
436+
finally:
437+
logging.debug("Start closing test client")
438+
await client.close()
439+
finally:
440+
context.detach(token)
441+
logging.debug("Start closing test server")
442+
if binding:
443+
binding.close()
444+
await server.close()
445+
if binding:
446+
await binding.wait_closed()

0 commit comments

Comments
 (0)