Skip to content

Commit 8d30d05

Browse files
fix(qwen): fix CFG failing when passing neg prompt embeds with none mask (#13379)
* fix(qwen): fix CFG failing when passing neg prompt embeds with none mask
* fix(qwen): safely handle missing embeds masks in edit and inpaint pipelines
* test(qwen): add tests for true cfg scale without neg prompt mask
* fix(qwen): correct comments for copied functions in controlnet and inpaint pipelines
* fix(qwen): add warnings for missing prompt and negative prompt masks in pipelines
* test(qwen): use torch_device and clarify dummy inputs in cfg mask tests
* fix(qwen): address Claude PR review feedback
* fix(qwen): fix warning message based on reviewer suggestion

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 3a4421c commit 8d30d05

16 files changed

Lines changed: 352 additions & 93 deletions

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,15 @@ def get_qwen_prompt_embeds_edit(
117117
).to(device)
118118

119119
outputs = text_encoder(
120-
input_ids=model_inputs.input_ids,
121-
attention_mask=model_inputs.attention_mask,
122-
pixel_values=model_inputs.pixel_values,
123-
image_grid_thw=model_inputs.image_grid_thw,
120+
input_ids=model_inputs["input_ids"],
121+
attention_mask=model_inputs["attention_mask"],
122+
pixel_values=model_inputs.get("pixel_values"),
123+
image_grid_thw=model_inputs.get("image_grid_thw"),
124124
output_hidden_states=True,
125125
)
126126

127127
hidden_states = outputs.hidden_states[-1]
128-
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
128+
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
129129
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
130130
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
131131
max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -173,15 +173,15 @@ def get_qwen_prompt_embeds_edit_plus(
173173
return_tensors="pt",
174174
).to(device)
175175
outputs = text_encoder(
176-
input_ids=model_inputs.input_ids,
177-
attention_mask=model_inputs.attention_mask,
178-
pixel_values=model_inputs.pixel_values,
179-
image_grid_thw=model_inputs.image_grid_thw,
176+
input_ids=model_inputs["input_ids"],
177+
attention_mask=model_inputs["attention_mask"],
178+
pixel_values=model_inputs.get("pixel_values"),
179+
image_grid_thw=model_inputs.get("image_grid_thw"),
180180
output_hidden_states=True,
181181
)
182182

183183
hidden_states = outputs.hidden_states[-1]
184-
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
184+
split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
185185
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
186186
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
187187
max_seq_len = max([e.size(0) for e in split_hidden_states])

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,22 @@ def check_inputs(
311311
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
312312
)
313313

314+
if prompt_embeds is not None and prompt_embeds_mask is None:
315+
logger.warning(
316+
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
317+
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
318+
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
319+
" used to generate `prompt_embeds`."
320+
)
321+
322+
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
323+
logger.warning(
324+
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
325+
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
326+
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
327+
" used to generate `negative_prompt_embeds`."
328+
)
329+
314330
if max_sequence_length is not None and max_sequence_length > 1024:
315331
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
316332

@@ -584,9 +600,7 @@ def __call__(
584600

585601
device = self._execution_device
586602

587-
has_neg_prompt = negative_prompt is not None or (
588-
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
589-
)
603+
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
590604

591605
if true_cfg_scale > 1 and not has_neg_prompt:
592606
logger.warning(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
"""
102102

103103

104-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
104+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
105105
def calculate_shift(
106106
image_seq_len,
107107
base_seq_len: int = 256,
@@ -239,7 +239,7 @@ def __init__(
239239
self.prompt_template_encode_start_idx = 34
240240
self.default_sample_size = 128
241241

242-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
242+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
243243
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
244244
bool_mask = mask.bool()
245245
valid_lengths = bool_mask.sum(dim=1)
@@ -248,7 +248,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor
248248

249249
return split_result
250250

251-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
251+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
252252
def _get_qwen_prompt_embeds(
253253
self,
254254
prompt: str | list[str] = None,
@@ -287,7 +287,7 @@ def _get_qwen_prompt_embeds(
287287

288288
return prompt_embeds, encoder_attention_mask
289289

290-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
290+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
291291
def encode_prompt(
292292
self,
293293
prompt: str | list[str],
@@ -318,11 +318,13 @@ def encode_prompt(
318318
if prompt_embeds is None:
319319
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
320320

321+
prompt_embeds = prompt_embeds[:, :max_sequence_length]
321322
_, seq_len, _ = prompt_embeds.shape
322323
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
323324
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
324325

325326
if prompt_embeds_mask is not None:
327+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
326328
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
327329
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
328330

@@ -374,6 +376,22 @@ def check_inputs(
374376
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
375377
)
376378

379+
if prompt_embeds is not None and prompt_embeds_mask is None:
380+
logger.warning(
381+
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
382+
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
383+
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
384+
" used to generate `prompt_embeds`."
385+
)
386+
387+
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
388+
logger.warning(
389+
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
390+
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
391+
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
392+
" used to generate `negative_prompt_embeds`."
393+
)
394+
377395
if max_sequence_length is not None and max_sequence_length > 1024:
378396
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
379397

@@ -700,9 +718,7 @@ def __call__(
700718

701719
device = self._execution_device
702720

703-
has_neg_prompt = negative_prompt is not None or (
704-
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
705-
)
721+
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
706722

707723
if true_cfg_scale > 1 and not has_neg_prompt:
708724
logger.warning(

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
"""
7575

7676

77-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
77+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
7878
def calculate_shift(
7979
image_seq_len,
8080
base_seq_len: int = 256,
@@ -221,7 +221,7 @@ def __init__(
221221
self.prompt_template_encode_start_idx = 34
222222
self.default_sample_size = 128
223223

224-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
224+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
225225
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
226226
bool_mask = mask.bool()
227227
valid_lengths = bool_mask.sum(dim=1)
@@ -230,7 +230,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor
230230

231231
return split_result
232232

233-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
233+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
234234
def _get_qwen_prompt_embeds(
235235
self,
236236
prompt: str | list[str] = None,
@@ -247,7 +247,7 @@ def _get_qwen_prompt_embeds(
247247
txt = [template.format(e) for e in prompt]
248248
txt_tokens = self.tokenizer(
249249
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
250-
).to(self.device)
250+
).to(device)
251251
encoder_hidden_states = self.text_encoder(
252252
input_ids=txt_tokens.input_ids,
253253
attention_mask=txt_tokens.attention_mask,
@@ -269,7 +269,7 @@ def _get_qwen_prompt_embeds(
269269

270270
return prompt_embeds, encoder_attention_mask
271271

272-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
272+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
273273
def encode_prompt(
274274
self,
275275
prompt: str | list[str],
@@ -280,6 +280,7 @@ def encode_prompt(
280280
max_sequence_length: int = 1024,
281281
):
282282
r"""
283+
283284
Args:
284285
prompt (`str` or `list[str]`, *optional*):
285286
prompt to be encoded
@@ -299,14 +300,18 @@ def encode_prompt(
299300
if prompt_embeds is None:
300301
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
301302

303+
prompt_embeds = prompt_embeds[:, :max_sequence_length]
302304
_, seq_len, _ = prompt_embeds.shape
303305
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
304306
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
305-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
306-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
307307

308-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
309-
prompt_embeds_mask = None
308+
if prompt_embeds_mask is not None:
309+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
310+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
311+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
312+
313+
if prompt_embeds_mask.all():
314+
prompt_embeds_mask = None
310315

311316
return prompt_embeds, prompt_embeds_mask
312317

@@ -354,12 +359,19 @@ def check_inputs(
354359
)
355360

356361
if prompt_embeds is not None and prompt_embeds_mask is None:
357-
raise ValueError(
358-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
362+
logger.warning(
363+
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
364+
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
365+
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
366+
" used to generate `prompt_embeds`."
359367
)
368+
360369
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
361-
raise ValueError(
362-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
370+
logger.warning(
371+
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
372+
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
373+
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
374+
" used to generate `negative_prompt_embeds`."
363375
)
364376

365377
if max_sequence_length is not None and max_sequence_length > 1024:
@@ -739,9 +751,7 @@ def __call__(
739751

740752
device = self._execution_device
741753

742-
has_neg_prompt = negative_prompt is not None or (
743-
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
744-
)
754+
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
745755
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
746756
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
747757
prompt=prompt,

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -247,15 +247,15 @@ def _get_qwen_prompt_embeds(
247247
).to(device)
248248

249249
outputs = self.text_encoder(
250-
input_ids=model_inputs.input_ids,
251-
attention_mask=model_inputs.attention_mask,
252-
pixel_values=model_inputs.pixel_values,
253-
image_grid_thw=model_inputs.image_grid_thw,
250+
input_ids=model_inputs["input_ids"],
251+
attention_mask=model_inputs["attention_mask"],
252+
pixel_values=model_inputs.get("pixel_values"),
253+
image_grid_thw=model_inputs.get("image_grid_thw"),
254254
output_hidden_states=True,
255255
)
256256

257257
hidden_states = outputs.hidden_states[-1]
258-
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
258+
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs["attention_mask"])
259259
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
260260
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
261261
max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -306,11 +306,13 @@ def encode_prompt(
306306
_, seq_len, _ = prompt_embeds.shape
307307
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
308308
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
309-
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
310-
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
311309

312-
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
313-
prompt_embeds_mask = None
310+
if prompt_embeds_mask is not None:
311+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
312+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
313+
314+
if prompt_embeds_mask.all():
315+
prompt_embeds_mask = None
314316

315317
return prompt_embeds, prompt_embeds_mask
316318

@@ -358,12 +360,19 @@ def check_inputs(
358360
)
359361

360362
if prompt_embeds is not None and prompt_embeds_mask is None:
361-
raise ValueError(
362-
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
363+
logger.warning(
364+
"`prompt_embeds` is provided and `prompt_embeds_mask` is not provided, so the model will treat all"
365+
" prompt tokens as valid. If `prompt_embeds` contains padding, you should provide the padding mask as"
366+
" `prompt_embeds_mask`. Make sure to generate `prompt_embeds_mask` from the same text encoder that was"
367+
" used to generate `prompt_embeds`."
363368
)
369+
364370
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
365-
raise ValueError(
366-
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
371+
logger.warning(
372+
"`negative_prompt_embeds` is provided and `negative_prompt_embeds_mask` is not provided, so the model will treat all"
373+
" negative prompt tokens as valid. If `negative_prompt_embeds` contains padding, you should provide the padding mask as"
374+
" `negative_prompt_embeds_mask`. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was"
375+
" used to generate `negative_prompt_embeds`."
367376
)
368377

369378
if max_sequence_length is not None and max_sequence_length > 1024:
@@ -705,9 +714,7 @@ def __call__(
705714
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
706715
image = image.unsqueeze(2)
707716

708-
has_neg_prompt = negative_prompt is not None or (
709-
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
710-
)
717+
has_neg_prompt = negative_prompt is not None or negative_prompt_embeds is not None
711718

712719
if true_cfg_scale > 1 and not has_neg_prompt:
713720
logger.warning(

0 commit comments

Comments (0)