11# Generated by ariadne-codegen
22
3+ import asyncio
34import enum
45import json
56from collections .abc import AsyncIterator
1213
1314from .base_model import UNSET , Upload
1415from .exceptions import (
16+ GraphQLClientError ,
1517 GraphQLClientGraphQLMultiError ,
1618 GraphQLClientHttpError ,
1719 GraphQLClientInvalidMessageFormat ,
@@ -155,11 +157,11 @@ async def execute_ws(
155157 ** kwargs : Any ,
156158 ) -> AsyncIterator [dict [str , Any ]]:
157159 headers = self .ws_headers .copy ()
158- headers .update (kwargs .get ( "extra_headers " , {}))
160+ headers .update (kwargs .pop ( "additional_headers " , {}))
159161
160162 merged_kwargs : dict [str , Any ] = {"origin" : self .ws_origin }
161163 merged_kwargs .update (kwargs )
162- merged_kwargs ["extra_headers " ] = headers
164+ merged_kwargs ["additional_headers " ] = headers
163165
164166 operation_id = str (uuid4 ())
165167 async with ws_connect (
@@ -168,12 +170,17 @@ async def execute_ws(
168170 ** merged_kwargs ,
169171 ) as websocket :
170172 await self ._send_connection_init (websocket )
171- # wait for connection_ack from server
172- await self ._handle_ws_message (
173- await websocket .recv (),
174- websocket ,
175- expected_type = GraphQLTransportWSMessageType .CONNECTION_ACK ,
176- )
173+ # Wait for connection_ack; some servers (e.g. Hasura) send ping before
174+ # connection_ack, so we loop and handle pings until we get ack.
175+ try :
176+ await asyncio .wait_for (
177+ self ._wait_for_connection_ack (websocket ),
178+ timeout = 5.0 ,
179+ )
180+ except asyncio .TimeoutError as exc :
181+ raise GraphQLClientError (
182+ "Connection ack not received within 5 seconds"
183+ ) from exc
177184 await self ._send_subscribe (
178185 websocket ,
179186 operation_id = operation_id ,
@@ -184,7 +191,7 @@ async def execute_ws(
184191
185192 async for message in websocket :
186193 data = await self ._handle_ws_message (message , websocket )
187- if data :
194+ if data and "connection_ack" not in data :
188195 yield data
189196
190197 def _process_variables (
@@ -315,6 +322,13 @@ async def _send_connection_init(self, websocket: ClientConnection) -> None:
315322 payload ["payload" ] = self .ws_connection_init_payload
316323 await websocket .send (json .dumps (payload ))
317324
325+ async def _wait_for_connection_ack (self , websocket : ClientConnection ) -> None :
326+ """Read messages until connection_ack; handle ping/pong in between."""
327+ async for message in websocket :
328+ data = await self ._handle_ws_message (message , websocket )
329+ if data is not None and "connection_ack" in data :
330+ return
331+
318332 async def _send_subscribe (
319333 self ,
320334 websocket : ClientConnection ,
@@ -371,5 +385,7 @@ async def _handle_ws_message(
371385 raise GraphQLClientGraphQLMultiError .from_errors_dicts (
372386 errors_dicts = payload , data = message_dict
373387 )
388+ elif type_ == GraphQLTransportWSMessageType .CONNECTION_ACK :
389+ return {"connection_ack" : True }
374390
375391 return None
0 commit comments