|
23 | 23 |
|
24 | 24 | import msgspec |
25 | 25 | import pytest |
| 26 | +import structlog |
26 | 27 |
|
27 | 28 | from airflow.sdk import timezone |
28 | | -from airflow.sdk.execution_time.comms import BundleInfo, MaskSecret, StartupDetails, _ResponseFrame |
29 | | -from airflow.sdk.execution_time.task_runner import CommsDecoder |
| 29 | +from airflow.sdk.execution_time.comms import ( |
| 30 | + BundleInfo, |
| 31 | + CommsDecoder, |
| 32 | + MaskSecret, |
| 33 | + StartupDetails, |
| 34 | + VariableResult, |
| 35 | + _ResponseFrame, |
| 36 | +) |
30 | 37 |
|
31 | 38 |
|
32 | 39 | class TestCommsModels: |
@@ -148,3 +155,41 @@ def test_huge_payload(self): |
148 | 155 | # It actually failed to read at all for large values, but lets just make sure we get it all |
149 | 156 | assert len(msg.value) == 10 * 1024 * 1024 + 1 |
150 | 157 | assert msg.value[-1] == "b" |
| 158 | + |
| 159 | + def test_send_thread_safety(self): |
| 160 | + r, w = socketpair() |
| 161 | + decoder = CommsDecoder(socket=r, log=structlog.get_logger()) |
| 162 | + num_threads = 5 |
| 163 | + results = [None] * num_threads |
| 164 | + errors = [None] * num_threads |
| 165 | + request_sent = [threading.Event() for _ in range(num_threads)] |
| 166 | + |
| 167 | + def send_and_store(idx): |
| 168 | + request_sent[idx].set() # Signal that this thread is about to send |
| 169 | + try: |
| 170 | + msg = VariableResult(key=f"key{idx}", value=f"value{idx}", type="VariableResult") |
| 171 | + results[idx] = decoder.send(msg) |
| 172 | + except Exception as e: |
| 173 | + errors[idx] = e |
| 174 | + |
| 175 | + threads = [threading.Thread(target=send_and_store, args=(i,)) for i in range(num_threads)] |
| 176 | + for t in threads: |
| 177 | + t.start() |
| 178 | + |
| 179 | + # For each thread, wait until it signals it's ready, then send the response |
| 180 | + for idx in range(num_threads): |
| 181 | + request_sent[idx].wait() |
| 182 | + resp = {"type": "VariableResult", "key": f"key{idx}", "value": f"value{idx}"} |
| 183 | + frame = _ResponseFrame(idx, resp, None) |
| 184 | + data = msgspec.msgpack.encode(frame) |
| 185 | + w.sendall(len(data).to_bytes(4, byteorder="big") + data) |
| 186 | + |
| 187 | + for t in threads: |
| 188 | + t.join(timeout=5) |
| 189 | + for idx, t in enumerate(threads): |
| 190 | + assert not t.is_alive(), f"Thread {idx} did not finish (possible deadlock or hang in send method)" |
| 191 | + |
| 192 | + for idx in range(num_threads): |
| 193 | + assert errors[idx] is None, f"Thread {idx} error: {errors[idx]}" |
| 194 | + assert results[idx].key == f"key{idx}", f"Out-of-order or missing response for thread {idx}" |
| 195 | + assert results[idx].value == f"value{idx}", f"Incorrect value for thread {idx}" |
0 commit comments