Commit f600a36

fix(qwen): safely handle missing embeds masks in edit and inpaint pipelines
1 parent: d7814c3

6 files changed: 49 additions & 77 deletions

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 10 additions & 10 deletions
@@ -117,15 +117,15 @@ def get_qwen_prompt_embeds_edit(
     ).to(device)

     outputs = text_encoder(
-        input_ids=model_inputs.input_ids,
-        attention_mask=model_inputs.attention_mask,
-        pixel_values=model_inputs.pixel_values,
-        image_grid_thw=model_inputs.image_grid_thw,
+        input_ids=model_inputs.get("input_ids"),
+        attention_mask=model_inputs.get("attention_mask"),
+        pixel_values=model_inputs.get("pixel_values"),
+        image_grid_thw=model_inputs.get("image_grid_thw"),
         output_hidden_states=True,
     )

     hidden_states = outputs.hidden_states[-1]
-    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
     max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -173,15 +173,15 @@ def get_qwen_prompt_embeds_edit_plus(
         return_tensors="pt",
     ).to(device)
     outputs = text_encoder(
-        input_ids=model_inputs.input_ids,
-        attention_mask=model_inputs.attention_mask,
-        pixel_values=model_inputs.pixel_values,
-        image_grid_thw=model_inputs.image_grid_thw,
+        input_ids=model_inputs.get("input_ids"),
+        attention_mask=model_inputs.get("attention_mask"),
+        pixel_values=model_inputs.get("pixel_values"),
+        image_grid_thw=model_inputs.get("image_grid_thw"),
         output_hidden_states=True,
     )

     hidden_states = outputs.hidden_states[-1]
-    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
     max_seq_len = max([e.size(0) for e in split_hidden_states])
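
The switch from attribute access to dict-style .get() matters because the processor's output (a dict-like BatchEncoding/BatchFeature in transformers) may not contain attention_mask, pixel_values, or image_grid_thw for a given input; .get() then yields None instead of raising. A minimal, self-contained sketch of the difference (the ToyBatch class is hypothetical and only mimics the dict-like behaviour):

# Hypothetical stand-in for transformers' dict-like BatchEncoding/BatchFeature.
class ToyBatch(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

# Text-only input: no pixel_values were produced.
model_inputs = ToyBatch(input_ids=[[1, 2, 3]], attention_mask=[[1, 1, 1]])

print(model_inputs.get("pixel_values"))  # None, so downstream code can branch on it
try:
    model_inputs.pixel_values            # the old access style
except AttributeError as err:
    print(f"attribute access raised: {err!r}")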

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py

Lines changed: 6 additions & 13 deletions
@@ -302,11 +302,13 @@ def encode_prompt(
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None

         return prompt_embeds, prompt_embeds_mask
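
The reordering above guards every mask operation behind a None check, so callers who pass prompt_embeds without a mask no longer crash on prompt_embeds_mask.repeat(...). A standalone sketch of the resulting behaviour (the function name and the demo values are mine; the body mirrors the diff):

import torch

def tile_prompt_embeds_mask(prompt_embeds_mask, batch_size, seq_len, num_images_per_prompt):
    # Only duplicate the mask per generated image when one exists.
    if prompt_embeds_mask is not None:
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
        # An all-ones mask masks nothing, so it is dropped entirely.
        if prompt_embeds_mask.all():
            prompt_embeds_mask = None
    return prompt_embeds_mask

print(tile_prompt_embeds_mask(None, 1, 4, 2))                                 # None passes through
print(tile_prompt_embeds_mask(torch.ones(1, 4, dtype=torch.long), 1, 4, 2))   # all-ones -> None
print(tile_prompt_embeds_mask(torch.tensor([[1, 1, 0, 0]]), 1, 4, 2).shape)   # torch.Size([2, 4])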

@@ -353,15 +355,6 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )

-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
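
With the two raise branches removed, check_inputs now accepts precomputed prompt_embeds without an accompanying mask; the mask becomes optional rather than mandatory. A trimmed sketch of the surviving validation (signature reduced to the relevant arguments):

def check_inputs(prompt_embeds=None, prompt_embeds_mask=None, max_sequence_length=None):
    # Mask pairing is no longer enforced: embeds without a mask are accepted.
    if max_sequence_length is not None and max_sequence_length > 1024:
        raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

check_inputs(prompt_embeds=object(), prompt_embeds_mask=None)  # previously raised ValueError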

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py

Lines changed: 11 additions & 18 deletions
@@ -247,15 +247,15 @@ def _get_qwen_prompt_embeds(
         ).to(device)

         outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
+            input_ids=model_inputs.get("input_ids"),
+            attention_mask=model_inputs.get("attention_mask"),
+            pixel_values=model_inputs.get("pixel_values"),
+            image_grid_thw=model_inputs.get("image_grid_thw"),
             output_hidden_states=True,
         )

         hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
         max_seq_len = max([e.size(0) for e in split_hidden_states])
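
Because model_inputs.get("attention_mask") can now feed None into self._extract_masked_hidden, the call site assumes the helper tolerates a missing mask. A sketch of that assumed behaviour (a standalone function of mine, not the pipeline's; the None branch is the assumption):

import torch

def extract_masked_hidden(hidden_states, mask):
    # hidden_states: (batch, seq, dim); mask: (batch, seq) of 0/1, or None.
    if mask is None:
        # Assumption: no mask means every token is kept.
        return [h for h in hidden_states]
    return [h[m.bool()] for h, m in zip(hidden_states, mask)]

hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print([t.shape for t in extract_masked_hidden(hidden, mask)])  # [torch.Size([3, 8]), torch.Size([2, 8])]
print([t.shape for t in extract_masked_hidden(hidden, None)])  # [torch.Size([4, 8]), torch.Size([4, 8])]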
@@ -306,11 +306,13 @@ def encode_prompt(
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None

         return prompt_embeds, prompt_embeds_mask

@@ -357,15 +359,6 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )

-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py

Lines changed: 11 additions & 9 deletions
@@ -258,15 +258,15 @@ def _get_qwen_prompt_embeds(
         ).to(device)

         outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
+            input_ids=model_inputs.get("input_ids"),
+            attention_mask=model_inputs.get("attention_mask"),
+            pixel_values=model_inputs.get("pixel_values"),
+            image_grid_thw=model_inputs.get("image_grid_thw"),
             output_hidden_states=True,
         )

         hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
         max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -318,11 +318,13 @@ def encode_prompt(
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None

         return prompt_embeds, prompt_embeds_mask

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py

Lines changed: 11 additions & 18 deletions
@@ -260,15 +260,15 @@ def _get_qwen_prompt_embeds(
         ).to(device)

         outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
+            input_ids=model_inputs.get("input_ids"),
+            attention_mask=model_inputs.get("attention_mask"),
+            pixel_values=model_inputs.get("pixel_values"),
+            image_grid_thw=model_inputs.get("image_grid_thw"),
             output_hidden_states=True,
         )

         hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.get("attention_mask"))
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
         max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -320,11 +320,13 @@ def encode_prompt(
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

-        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
-            prompt_embeds_mask = None
+        if prompt_embeds_mask is not None:
+            prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+            prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+            if prompt_embeds_mask.all():
+                prompt_embeds_mask = None

         return prompt_embeds, prompt_embeds_mask

@@ -372,15 +374,6 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )

-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py

Lines changed: 0 additions & 9 deletions
@@ -384,15 +384,6 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )

-        if prompt_embeds is not None and prompt_embeds_mask is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
-            raise ValueError(
-                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
