Skip to content

Commit 3ac0d74

Browse files
authored
Fix read out-of-order issue with send method in CommsDecoder (#64894)
* refactor: Fix read out-of-order issue with send method in CommsDecoder
1 parent 6fd0142 commit 3ac0d74

2 files changed

Lines changed: 48 additions & 3 deletions

File tree

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
216216
# always be in the return type union
217217
return resp # type: ignore[return-value]
218218

219-
return self._get_response()
219+
return self._get_response()
220220

221221
async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
222222
"""

task-sdk/tests/task_sdk/execution_time/test_comms.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,17 @@
2323

2424
import msgspec
2525
import pytest
26+
import structlog
2627

2728
from airflow.sdk import timezone
28-
from airflow.sdk.execution_time.comms import BundleInfo, MaskSecret, StartupDetails, _ResponseFrame
29-
from airflow.sdk.execution_time.task_runner import CommsDecoder
29+
from airflow.sdk.execution_time.comms import (
30+
BundleInfo,
31+
CommsDecoder,
32+
MaskSecret,
33+
StartupDetails,
34+
VariableResult,
35+
_ResponseFrame,
36+
)
3037

3138

3239
class TestCommsModels:
@@ -148,3 +155,41 @@ def test_huge_payload(self):
148155
# It actually failed to read at all for large values, but let's just make sure we get it all
149156
assert len(msg.value) == 10 * 1024 * 1024 + 1
150157
assert msg.value[-1] == "b"
158+
159+
def test_send_thread_safety(self):
160+
r, w = socketpair()
161+
decoder = CommsDecoder(socket=r, log=structlog.get_logger())
162+
num_threads = 5
163+
results = [None] * num_threads
164+
errors = [None] * num_threads
165+
request_sent = [threading.Event() for _ in range(num_threads)]
166+
167+
def send_and_store(idx):
168+
request_sent[idx].set() # Signal that this thread is about to send
169+
try:
170+
msg = VariableResult(key=f"key{idx}", value=f"value{idx}", type="VariableResult")
171+
results[idx] = decoder.send(msg)
172+
except Exception as e:
173+
errors[idx] = e
174+
175+
threads = [threading.Thread(target=send_and_store, args=(i,)) for i in range(num_threads)]
176+
for t in threads:
177+
t.start()
178+
179+
# For each thread, wait until it signals it's ready, then send the response
180+
for idx in range(num_threads):
181+
request_sent[idx].wait()
182+
resp = {"type": "VariableResult", "key": f"key{idx}", "value": f"value{idx}"}
183+
frame = _ResponseFrame(idx, resp, None)
184+
data = msgspec.msgpack.encode(frame)
185+
w.sendall(len(data).to_bytes(4, byteorder="big") + data)
186+
187+
for t in threads:
188+
t.join(timeout=5)
189+
for idx, t in enumerate(threads):
190+
assert not t.is_alive(), f"Thread {idx} did not finish (possible deadlock or hang in send method)"
191+
192+
for idx in range(num_threads):
193+
assert errors[idx] is None, f"Thread {idx} error: {errors[idx]}"
194+
assert results[idx].key == f"key{idx}", f"Out-of-order or missing response for thread {idx}"
195+
assert results[idx].value == f"value{idx}", f"Incorrect value for thread {idx}"

0 commit comments

Comments
 (0)