Skip to content

Commit eb86a2e

Browse files
committed
fix(qwen): correct comments for copied functions in controlnet and inpaint pipelines
1 parent 908c304 commit eb86a2e

2 files changed

Lines changed: 14 additions & 9 deletions

File tree

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
"""
102102

103103

104-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
104+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
105105
def calculate_shift(
106106
image_seq_len,
107107
base_seq_len: int = 256,
@@ -239,7 +239,7 @@ def __init__(
239239
self.prompt_template_encode_start_idx = 34
240240
self.default_sample_size = 128
241241

242-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
242+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
243243
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
244244
bool_mask = mask.bool()
245245
valid_lengths = bool_mask.sum(dim=1)
@@ -248,7 +248,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor
248248

249249
return split_result
250250

251-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
251+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
252252
def _get_qwen_prompt_embeds(
253253
self,
254254
prompt: str | list[str] = None,
@@ -287,7 +287,7 @@ def _get_qwen_prompt_embeds(
287287

288288
return prompt_embeds, encoder_attention_mask
289289

290-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
290+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
291291
def encode_prompt(
292292
self,
293293
prompt: str | list[str],
@@ -318,11 +318,13 @@ def encode_prompt(
318318
if prompt_embeds is None:
319319
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
320320

321+
prompt_embeds = prompt_embeds[:, :max_sequence_length]
321322
_, seq_len, _ = prompt_embeds.shape
322323
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
323324
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
324325

325326
if prompt_embeds_mask is not None:
327+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
326328
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
327329
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
328330

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
"""
7575

7676

77-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
77+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
7878
def calculate_shift(
7979
image_seq_len,
8080
base_seq_len: int = 256,
@@ -221,7 +221,7 @@ def __init__(
221221
self.prompt_template_encode_start_idx = 34
222222
self.default_sample_size = 128
223223

224-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
224+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
225225
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
226226
bool_mask = mask.bool()
227227
valid_lengths = bool_mask.sum(dim=1)
@@ -230,7 +230,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor
230230

231231
return split_result
232232

233-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
233+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
234234
def _get_qwen_prompt_embeds(
235235
self,
236236
prompt: str | list[str] = None,
@@ -247,7 +247,7 @@ def _get_qwen_prompt_embeds(
247247
txt = [template.format(e) for e in prompt]
248248
txt_tokens = self.tokenizer(
249249
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
250-
).to(self.device)
250+
).to(device)
251251
encoder_hidden_states = self.text_encoder(
252252
input_ids=txt_tokens.input_ids,
253253
attention_mask=txt_tokens.attention_mask,
@@ -269,7 +269,7 @@ def _get_qwen_prompt_embeds(
269269

270270
return prompt_embeds, encoder_attention_mask
271271

272-
# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
272+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
273273
def encode_prompt(
274274
self,
275275
prompt: str | list[str],
@@ -280,6 +280,7 @@ def encode_prompt(
280280
max_sequence_length: int = 1024,
281281
):
282282
r"""
283+
283284
Args:
284285
prompt (`str` or `list[str]`, *optional*):
285286
prompt to be encoded
@@ -299,11 +300,13 @@ def encode_prompt(
299300
if prompt_embeds is None:
300301
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
301302

303+
prompt_embeds = prompt_embeds[:, :max_sequence_length]
302304
_, seq_len, _ = prompt_embeds.shape
303305
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
304306
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
305307

306308
if prompt_embeds_mask is not None:
309+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
307310
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
308311
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
309312

0 commit comments

Comments (0)