|
1 | 1 | import asyncio |
| 2 | +import contextvars |
| 3 | +import functools |
2 | 4 | import inspect |
3 | 5 | from concurrent.futures import Executor |
4 | 6 | from logging import getLogger |
5 | 7 | from time import time |
6 | | -from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints |
| 8 | +from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints |
7 | 9 |
|
8 | 10 | import anyio |
9 | 11 | from taskiq_dependencies import DependencyGraph |
|
23 | 25 | QUEUE_DONE = b"-1" |
24 | 26 |
|
25 | 27 |
|
26 | | -def _run_sync( |
27 | | - target: Callable[..., Any], |
28 | | - args: List[Any], |
29 | | - kwargs: Dict[str, Any], |
30 | | -) -> Any: |
31 | | - """ |
32 | | - Runs function synchronously. |
33 | | -
|
34 | | - We use this function, because |
35 | | - we cannot pass kwargs in loop.run_with_executor(). |
36 | | -
|
37 | | - :param target: function to execute. |
38 | | - :param args: list of function's args. |
39 | | - :param kwargs: dict of function's kwargs. |
40 | | - :return: result of function's execution. |
41 | | - """ |
42 | | - return target(*args, **kwargs) |
43 | | - |
44 | | - |
45 | 28 | class Receiver: |
46 | 29 | """Class that uses as a callback handler.""" |
47 | 30 |
|
@@ -255,13 +238,13 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 |
255 | 238 | else: |
256 | 239 | is_coroutine = False |
257 | 240 | # If this is a synchronous function, we |
258 | | - # run it in executor. |
| 241 | + # run it in executor and preserve the context. |
| 242 | + ctx = contextvars.copy_context() |
| 243 | + func = functools.partial(target, *message.args, **kwargs) |
259 | 244 | target_future = loop.run_in_executor( |
260 | 245 | self.executor, |
261 | | - _run_sync, |
262 | | - target, |
263 | | - message.args, |
264 | | - kwargs, |
| 246 | + ctx.run, |
| 247 | + func, |
265 | 248 | ) |
266 | 249 | timeout = message.labels.get("timeout") |
267 | 250 | if timeout is not None: |
|
0 commit comments