Skip to content

Commit 7fa33ee

Browse files
committed
test(qwen): use torch_device and clarify dummy inputs in cfg mask tests
1 parent: 42587ae — this commit: 7fa33ee

6 files changed

Lines changed: 30 additions & 30 deletions

File tree

tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,24 +238,24 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_true_cfg_without_negative_prompt_embeds_mask(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        inputs = self.get_dummy_inputs("cpu")
+        inputs = self.get_dummy_inputs(torch_device)
         prompt = inputs.pop("prompt")

         prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
             prompt=prompt,
-            device="cpu",
+            device=torch_device,
             num_images_per_prompt=1,
             max_sequence_length=inputs.get("max_sequence_length", 16),
         )

         inputs["prompt_embeds"] = prompt_embeds
         inputs["prompt_embeds_mask"] = prompt_embeds_mask
         inputs["negative_prompt_embeds"] = prompt_embeds
-        inputs["negative_prompt"] = None
-        inputs["negative_prompt_embeds_mask"] = None
+        inputs.pop("negative_prompt", None)
+        inputs.pop("negative_prompt_embeds_mask", None)
         inputs["true_cfg_scale"] = 2.0

         image = pipe(**inputs).images

tests/pipelines/qwenimage/test_qwenimage_controlnet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,24 +340,24 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_true_cfg_without_negative_prompt_embeds_mask(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        inputs = self.get_dummy_inputs("cpu")
+        inputs = self.get_dummy_inputs(torch_device)
         prompt = inputs.pop("prompt")

         prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
             prompt=prompt,
-            device="cpu",
+            device=torch_device,
             num_images_per_prompt=1,
             max_sequence_length=inputs.get("max_sequence_length", 16),
         )

         inputs["prompt_embeds"] = prompt_embeds
         inputs["prompt_embeds_mask"] = prompt_embeds_mask
         inputs["negative_prompt_embeds"] = prompt_embeds
-        inputs["negative_prompt"] = None
-        inputs["negative_prompt_embeds_mask"] = None
+        inputs.pop("negative_prompt", None)
+        inputs.pop("negative_prompt_embeds_mask", None)
         inputs["true_cfg_scale"] = 2.0

         image = pipe(**inputs).images

tests/pipelines/qwenimage/test_qwenimage_edit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,25 +245,25 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=
     def test_true_cfg_without_negative_prompt_embeds_mask(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        inputs = self.get_dummy_inputs("cpu")
+        inputs = self.get_dummy_inputs(torch_device)
         prompt = inputs.pop("prompt")

         prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
             prompt=prompt,
             image=inputs.get("image") if "image" in inputs else None,
-            device="cpu",
+            device=torch_device,
             num_images_per_prompt=1,
             max_sequence_length=inputs.get("max_sequence_length", 16),
         )

         inputs["prompt_embeds"] = prompt_embeds
         inputs["prompt_embeds_mask"] = prompt_embeds_mask
         inputs["negative_prompt_embeds"] = prompt_embeds
-        inputs["negative_prompt"] = None
-        inputs["negative_prompt_embeds_mask"] = None
+        inputs.pop("negative_prompt", None)
+        inputs.pop("negative_prompt_embeds_mask", None)
         inputs["true_cfg_scale"] = 2.0

         image = pipe(**inputs).images

tests/pipelines/qwenimage/test_qwenimage_edit_plus.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,25 +255,25 @@ def test_inference_batch_single_identical():
     def test_true_cfg_without_negative_prompt_embeds_mask(self):
         components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        inputs = self.get_dummy_inputs("cpu")
+        inputs = self.get_dummy_inputs(torch_device)
         prompt = inputs.pop("prompt")

         prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
             prompt=prompt,
             image=inputs.get("image") if "image" in inputs else None,
-            device="cpu",
+            device=torch_device,
             num_images_per_prompt=1,
             max_sequence_length=inputs.get("max_sequence_length", 16),
         )

         inputs["prompt_embeds"] = prompt_embeds
         inputs["prompt_embeds_mask"] = prompt_embeds_mask
         inputs["negative_prompt_embeds"] = prompt_embeds
-        inputs["negative_prompt"] = None
-        inputs["negative_prompt_embeds_mask"] = None
+        inputs.pop("negative_prompt", None)
+        inputs.pop("negative_prompt_embeds_mask", None)
         inputs["true_cfg_scale"] = 2.0

         image = pipe(**inputs).images

tests/pipelines/qwenimage/test_qwenimage_img2img.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,24 +220,24 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_true_cfg_without_negative_prompt_embeds_mask(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        inputs = self.get_dummy_inputs("cpu")
+        inputs = self.get_dummy_inputs(torch_device)
         prompt = inputs.pop("prompt")

         prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
             prompt=prompt,
-            device="cpu",
+            device=torch_device,
             num_images_per_prompt=1,
             max_sequence_length=inputs.get("max_sequence_length", 16),
         )

         inputs["prompt_embeds"] = prompt_embeds
         inputs["prompt_embeds_mask"] = prompt_embeds_mask
         inputs["negative_prompt_embeds"] = prompt_embeds
-        inputs["negative_prompt"] = None
-        inputs["negative_prompt_embeds_mask"] = None
+        inputs.pop("negative_prompt", None)
+        inputs.pop("negative_prompt_embeds_mask", None)
         inputs["true_cfg_scale"] = 2.0

         image = pipe(**inputs).images

tests/pipelines/qwenimage/test_qwenimage_inpaint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,24 +235,24 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_true_cfg_without_negative_prompt_embeds_mask(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
-        pipe.to("cpu")
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        inputs = self.get_dummy_inputs("cpu")
+        inputs = self.get_dummy_inputs(torch_device)
         prompt = inputs.pop("prompt")

         prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
             prompt=prompt,
-            device="cpu",
+            device=torch_device,
             num_images_per_prompt=1,
             max_sequence_length=inputs.get("max_sequence_length", 16),
         )

         inputs["prompt_embeds"] = prompt_embeds
         inputs["prompt_embeds_mask"] = prompt_embeds_mask
         inputs["negative_prompt_embeds"] = prompt_embeds
-        inputs["negative_prompt"] = None
-        inputs["negative_prompt_embeds_mask"] = None
+        inputs.pop("negative_prompt", None)
+        inputs.pop("negative_prompt_embeds_mask", None)
         inputs["true_cfg_scale"] = 2.0

         image = pipe(**inputs).images

0 commit comments

Comments (0)