|
3 | 3 | import os |
4 | 4 | import time |
5 | 5 | import uuid |
6 | | -from typing import AsyncIterator, Iterator |
| 6 | +from typing import AsyncIterator, Iterator, Dict, List |
7 | 7 |
|
8 | 8 | from langchain_core.documents import Document |
9 | 9 | from loguru import logger |
@@ -48,9 +48,33 @@ def __init__(self, unique_id: str, workflow_id: str, chat_id: str, user_id: int) |
48 | 48 | self.workflow_stop_key = f'workflow:{unique_id}:stop' |
49 | 49 | self.workflow_expire_time = settings.get_workflow_conf().timeout * 60 + 60 |
50 | 50 |
|
51 | | - def set_workflow_data(self, data: dict): |
| 51 | + def set_workflow_data(self, data: Dict, override: Dict = None): |
| 52 | + data = self.override_nodes_params(data, override) |
52 | 53 | self.redis_client.set(self.workflow_data_key, data, expiration=self.workflow_expire_time) |
53 | 54 |
|
| 55 | + @staticmethod |
| 56 | + def override_nodes_params(data: Dict, override: Dict = None) -> Dict: |
| 57 | + if not override: |
| 58 | + return data |
| 59 | + |
| 60 | + def replace_param(one_params: List[Dict], one_node_id: str): |
| 61 | + for param in one_params: |
| 62 | + param_key = param.get('key') |
| 63 | + if param_key not in override[node_id]: |
| 64 | + continue |
| 65 | + param['value'] = override[one_node_id][param_key] |
| 66 | + |
| 67 | + nodes = data.get('nodes', []) |
| 68 | + for node in nodes: |
| 69 | + node_data = node.get('data', {}) |
| 70 | + node_id = node_data.get('id') |
| 71 | + if node_id not in override: |
| 72 | + continue |
| 73 | + group_params = node_data.get('group_params', []) |
| 74 | + for group_param in group_params: |
| 75 | + replace_param(group_param.get('params', []), node_id) |
| 76 | + return data |
| 77 | + |
54 | 78 | async def async_set_workflow_data(self, data: dict): |
55 | 79 | await self.redis_client.aset(self.workflow_data_key, data, expiration=self.workflow_expire_time) |
56 | 80 |
|
|
0 commit comments