Skip to content

Commit 0c2de80

Browse files
Merge pull request #17 from PlayerData/regenerate-ariadne-codegen-again
chore: regenerate ariadne codegen
2 parents 3d5d897 + 496cc20 commit 0c2de80

8 files changed

Lines changed: 2208 additions & 543 deletions

File tree

playerdatapy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
AthleteRelativeAccelzoneAttributes,
7474
AthleteRelativeDecelzoneAttributes,
7575
AthleteSpeedzoneAttributes,
76+
BandedJumpZoneLowerBoundsInput,
7677
BulkUpdateMatchEventAttributes,
7778
ClaimPersonAttributes,
7879
ClubContextAttributes,
@@ -164,6 +165,7 @@
164165
"AthleteRelativeAccelzoneAttributes",
165166
"AthleteRelativeDecelzoneAttributes",
166167
"AthleteSpeedzoneAttributes",
168+
"BandedJumpZoneLowerBoundsInput",
167169
"BaseModel",
168170
"BulkUpdateMatchEventAttributes",
169171
"ChartDataTypeEnum",

playerdatapy/async_base_client.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by ariadne-codegen
22

3+
import asyncio
34
import enum
45
import json
56
from collections.abc import AsyncIterator
@@ -12,6 +13,7 @@
1213

1314
from .base_model import UNSET, Upload
1415
from .exceptions import (
16+
GraphQLClientError,
1517
GraphQLClientGraphQLMultiError,
1618
GraphQLClientHttpError,
1719
GraphQLClientInvalidMessageFormat,
@@ -155,11 +157,11 @@ async def execute_ws(
155157
**kwargs: Any,
156158
) -> AsyncIterator[dict[str, Any]]:
157159
headers = self.ws_headers.copy()
158-
headers.update(kwargs.get("extra_headers", {}))
160+
headers.update(kwargs.pop("additional_headers", {}))
159161

160162
merged_kwargs: dict[str, Any] = {"origin": self.ws_origin}
161163
merged_kwargs.update(kwargs)
162-
merged_kwargs["extra_headers"] = headers
164+
merged_kwargs["additional_headers"] = headers
163165

164166
operation_id = str(uuid4())
165167
async with ws_connect(
@@ -168,12 +170,17 @@ async def execute_ws(
168170
**merged_kwargs,
169171
) as websocket:
170172
await self._send_connection_init(websocket)
171-
# wait for connection_ack from server
172-
await self._handle_ws_message(
173-
await websocket.recv(),
174-
websocket,
175-
expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK,
176-
)
173+
# Wait for connection_ack; some servers (e.g. Hasura) send ping before
174+
# connection_ack, so we loop and handle pings until we get ack.
175+
try:
176+
await asyncio.wait_for(
177+
self._wait_for_connection_ack(websocket),
178+
timeout=5.0,
179+
)
180+
except asyncio.TimeoutError as exc:
181+
raise GraphQLClientError(
182+
"Connection ack not received within 5 seconds"
183+
) from exc
177184
await self._send_subscribe(
178185
websocket,
179186
operation_id=operation_id,
@@ -184,7 +191,7 @@ async def execute_ws(
184191

185192
async for message in websocket:
186193
data = await self._handle_ws_message(message, websocket)
187-
if data:
194+
if data and "connection_ack" not in data:
188195
yield data
189196

190197
def _process_variables(
@@ -315,6 +322,13 @@ async def _send_connection_init(self, websocket: ClientConnection) -> None:
315322
payload["payload"] = self.ws_connection_init_payload
316323
await websocket.send(json.dumps(payload))
317324

325+
async def _wait_for_connection_ack(self, websocket: ClientConnection) -> None:
326+
"""Read messages until connection_ack; handle ping/pong in between."""
327+
async for message in websocket:
328+
data = await self._handle_ws_message(message, websocket)
329+
if data is not None and "connection_ack" in data:
330+
return
331+
318332
async def _send_subscribe(
319333
self,
320334
websocket: ClientConnection,
@@ -371,5 +385,7 @@ async def _handle_ws_message(
371385
raise GraphQLClientGraphQLMultiError.from_errors_dicts(
372386
errors_dicts=payload, data=message_dict
373387
)
388+
elif type_ == GraphQLTransportWSMessageType.CONNECTION_ACK:
389+
return {"connection_ack": True}
374390

375391
return None

0 commit comments

Comments
 (0)