Skip to content

Commit 394999b

Browse files
committed
chat_models can take str as input
1 parent 42f0362 commit 394999b

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

src/agentlab/llm/chat_api.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from dataclasses import dataclass
66
from functools import partial
7-
from typing import Optional
7+
from typing import Optional, Union
88

99
import openai
1010
from huggingface_hub import InferenceClient
@@ -13,7 +13,7 @@
1313
import agentlab.llm.tracking as tracking
1414
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
1515
from agentlab.llm.huggingface_utils import HFBaseChatModel
16-
from agentlab.llm.llm_utils import AIMessage, Discussion
16+
from agentlab.llm.llm_utils import AIMessage, Discussion, HumanMessage
1717

1818

1919
def make_system_message(content: str) -> dict:
@@ -261,7 +261,13 @@ def __init__(
261261
**client_args,
262262
)
263263

264-
def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
264+
def __call__(
265+
self, messages: Union[str, list[dict]], n_samples: int = 1, temperature: float = None
266+
) -> dict:
267+
268+
if isinstance(messages, str):
269+
messages = [HumanMessage(messages)]
270+
265271
# Initialize retry tracking attributes
266272
self.retries = 0
267273
self.success = False

0 commit comments

Comments
 (0)