|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | | -from collections.abc import AsyncIterator |
4 | | -from typing import Any, AsyncGenerator, Iterator, Literal |
| 3 | +from typing import Any, AsyncGenerator, Literal, Mapping |
5 | 4 |
|
6 | | -import grpc.aio |
7 | | -import nanoid # type: ignore |
| 5 | +import nanoid |
8 | 6 | import pytest |
9 | 7 | from opentelemetry import trace |
10 | 8 | from opentelemetry.sdk.trace import TracerProvider |
|
14 | 12 |
|
15 | 13 | from replit_river.client import Client |
16 | 14 | from replit_river.client_transport import UriAndMetadata |
17 | | -from replit_river.error_schema import RiverError, RiverException |
| 15 | +from replit_river.error_schema import RiverError |
18 | 16 | from replit_river.rpc import ( |
| 17 | + GenericRpcHandler, |
19 | 18 | TransportMessage, |
20 | | - rpc_method_handler, |
21 | | - stream_method_handler, |
22 | | - subscription_method_handler, |
23 | | - upload_method_handler, |
24 | 19 | ) |
25 | 20 | from replit_river.server import Server |
26 | 21 | from replit_river.transport_options import TransportOptions |
|
29 | 24 | # Modular fixtures |
30 | 25 | pytest_plugins = ["tests.river_fixtures.logging"] |
31 | 26 |
|
| 27 | +HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]] |
| 28 | + |
32 | 29 |
|
33 | 30 | def transport_message( |
34 | 31 | seq: int = 0, |
@@ -71,93 +68,22 @@ def deserialize_error(response: dict) -> RiverError: |
71 | 68 | return RiverError.model_validate(response) |
72 | 69 |
|
73 | 70 |
|
74 | | -# RPC method handlers for testing |
75 | | -async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str: |
76 | | - return f"Hello, {request}!" |
77 | | - |
78 | | - |
79 | | -async def subscription_handler( |
80 | | - request: str, context: grpc.aio.ServicerContext |
81 | | -) -> AsyncGenerator[str, None]: |
82 | | - for i in range(5): |
83 | | - yield f"Subscription message {i} for {request}" |
84 | | - |
85 | | - |
86 | | -async def upload_handler( |
87 | | - request: Iterator[str] | AsyncIterator[str], context: Any |
88 | | -) -> str: |
89 | | - uploaded_data = [] |
90 | | - if isinstance(request, AsyncIterator): |
91 | | - async for data in request: |
92 | | - uploaded_data.append(data) |
93 | | - else: |
94 | | - for data in request: |
95 | | - uploaded_data.append(data) |
96 | | - return f"Uploaded: {', '.join(uploaded_data)}" |
97 | | - |
98 | | - |
99 | | -async def stream_handler( |
100 | | - request: Iterator[str] | AsyncIterator[str], |
101 | | - context: grpc.aio.ServicerContext, |
102 | | -) -> AsyncGenerator[str, None]: |
103 | | - if isinstance(request, AsyncIterator): |
104 | | - async for data in request: |
105 | | - yield f"Stream response for {data}" |
106 | | - else: |
107 | | - for data in request: |
108 | | - yield f"Stream response for {data}" |
109 | | - |
110 | | - |
111 | | -async def stream_error_handler( |
112 | | - request: Iterator[str] | AsyncIterator[str], |
113 | | - context: grpc.aio.ServicerContext, |
114 | | -) -> AsyncGenerator[str, None]: |
115 | | - raise RiverException("INJECTED_ERROR", "test error") |
116 | | - yield "test" # appease the type checker |
117 | | - |
118 | | - |
119 | 71 | @pytest.fixture |
120 | 72 | def transport_options() -> TransportOptions: |
121 | 73 | return TransportOptions() |
122 | 74 |
|
123 | 75 |
|
124 | 76 | @pytest.fixture |
125 | | -def server(transport_options: TransportOptions) -> Server: |
| 77 | +def server_handlers(handlers: HandlerMapping) -> HandlerMapping: |
| 78 | + return handlers |
| 79 | + |
| 80 | + |
| 81 | +@pytest.fixture |
| 82 | +def server( |
| 83 | + transport_options: TransportOptions, server_handlers: HandlerMapping |
| 84 | +) -> Server: |
126 | 85 | server = Server(server_id="test_server", transport_options=transport_options) |
127 | | - server.add_rpc_handlers( |
128 | | - { |
129 | | - ("test_service", "rpc_method"): ( |
130 | | - "rpc", |
131 | | - rpc_method_handler( |
132 | | - rpc_handler, deserialize_request, serialize_response |
133 | | - ), |
134 | | - ), |
135 | | - ("test_service", "subscription_method"): ( |
136 | | - "subscription", |
137 | | - subscription_method_handler( |
138 | | - subscription_handler, deserialize_request, serialize_response |
139 | | - ), |
140 | | - ), |
141 | | - ("test_service", "upload_method"): ( |
142 | | - "upload", |
143 | | - upload_method_handler( |
144 | | - upload_handler, deserialize_request, serialize_response |
145 | | - ), |
146 | | - ), |
147 | | - ("test_service", "stream_method"): ( |
148 | | - "stream", |
149 | | - stream_method_handler( |
150 | | - stream_handler, deserialize_request, serialize_response |
151 | | - ), |
152 | | - ), |
153 | | - ("test_service", "stream_method_error"): ( |
154 | | - "stream", |
155 | | - stream_method_handler( |
156 | | - stream_error_handler, deserialize_request, serialize_response |
157 | | - ), |
158 | | - ), |
159 | | - } |
160 | | - ) |
| 86 | + server.add_rpc_handlers(server_handlers) |
161 | 87 | return server |
162 | 88 |
|
163 | 89 |
|
|
0 commit comments