# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import subprocess
import traceback

import pytest

from fastdeploy import LLM, SamplingParams

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))


def format_chat_prompt(messages):
    """
    Format a multi-turn conversation into a prompt string suitable for
    chat models. Uses the Qwen2 style <|im_start|> / <|im_end|> tokens.
    """
    prompt = ""
    for msg in messages:
        role, content = msg["role"], msg["content"]
        if role == "user":
            prompt += "<|im_start|>user\n{content}<|im_end|>\n".format(content=content)
        elif role == "assistant":
            prompt += "<|im_start|>assistant\n{content}<|im_end|>\n".format(content=content)
    prompt += "<|im_start|>assistant\n"
    return prompt
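
# For illustration, format_chat_prompt([{"role": "user", "content": "Hello"}])
# renders the following prompt (shown with literal newlines), leaving the
# final assistant turn open for the model to complete:
#     <|im_start|>user
#     Hello<|im_end|>
#     <|im_start|>assistant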


@pytest.fixture(scope="module")
def model_path():
    """
    Get the model path from the environment variable MODEL_PATH;
    default to "./Qwen2-7B-Instruct" if it is not set.
    """
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        return os.path.join(base_path, "Qwen2-7B-Instruct")
    return "./Qwen2-7B-Instruct"


@pytest.fixture(scope="module")
def llm(model_path):
    """
    Fixture to initialize the LLM model with the given model path.
    """
    # Free the engine queue port first: kill any process still listening on it
    # from a previous run. lsof exits non-zero when the port is free, which
    # raises CalledProcessError and is safe to ignore.
    try:
        output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
    except subprocess.CalledProcessError:
        pass

    try:
        llm = LLM(
            model=model_path,
            tensor_parallel_size=1,
            engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
            max_model_len=4096,
        )
        print("Model loaded successfully from {}.".format(model_path))
        yield llm
    except Exception:
        print("Failed to load model from {}.".format(model_path))
        traceback.print_exc()
        pytest.fail("Failed to initialize LLM model from {}".format(model_path))


def test_generate_prompts(llm):
    """
    Test basic prompt generation.
    """
    # Only one prompt is enabled for testing currently.
    prompts = [
        "请介绍一下中国的四大发明。",
        # "太阳和地球之间的距离是多少?",
        # "写一首关于春天的古风诗。",
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )

    try:
        outputs = llm.generate(prompts, sampling_params)

        # Verify basic properties of the outputs
        assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"

        for i, output in enumerate(outputs):
            assert output.prompt == prompts[i], "Prompt mismatch for case {}".format(i + 1)
            assert isinstance(output.outputs.text, str), "Output text should be string for case {}".format(i + 1)
            assert len(output.outputs.text) > 0, "Generated text should not be empty for case {}".format(i + 1)
            assert isinstance(output.finished, bool), "'finished' should be boolean for case {}".format(i + 1)
            assert output.metrics.model_execute_time > 0, "Execution time should be positive for case {}".format(i + 1)

            print("=== Prompt generation Case {} Passed ===".format(i + 1))

    except Exception:
        print("Failed during prompt generation.")
        traceback.print_exc()
        pytest.fail("Prompt generation test failed")


def test_chat_completion(llm):
    """
    Test chat completion with multiple turns.
    """
    chat_cases = [
        [
            {"role": "user", "content": "你好,请介绍一下你自己。"},
        ],
        [
            {"role": "user", "content": "你知道地球到月球的距离是多少吗?"},
            {"role": "assistant", "content": "大约是38万公里左右。"},
            {"role": "user", "content": "那太阳到地球的距离是多少?"},
        ],
        [
            {"role": "user", "content": "请给我起一个中文名。"},
            {"role": "assistant", "content": "好的,你可以叫“星辰”。"},
            {"role": "user", "content": "再起一个。"},
            {"role": "assistant", "content": "那就叫“大海”吧。"},
            {"role": "user", "content": "再来三个。"},
        ],
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )

    for i, case in enumerate(chat_cases):
        prompt = format_chat_prompt(case)
        try:
            outputs = llm.generate(prompt, sampling_params)

            # Verify chat completion properties
            assert len(outputs) == 1, "Should return one output per prompt"
            assert isinstance(outputs[0].outputs.text, str), "Output text should be string"
            assert len(outputs[0].outputs.text) > 0, "Generated text should not be empty"
            assert outputs[0].metrics.model_execute_time > 0, "Execution time should be positive"

            print("=== Chat Case {} Passed ===".format(i + 1))

        except Exception:
            print("[ERROR] Chat Case {} failed.".format(i + 1))
            traceback.print_exc()
            pytest.fail("Chat case {} failed".format(i + 1))


if __name__ == "__main__":
    # Main entry point for the test script.
    pytest.main(["-sv", __file__])