7474"""
7575
7676
77- # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
77+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
7878def calculate_shift (
7979 image_seq_len ,
8080 base_seq_len : int = 256 ,
@@ -221,7 +221,7 @@ def __init__(
221221 self .prompt_template_encode_start_idx = 34
222222 self .default_sample_size = 128
223223
224- # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
224+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
225225 def _extract_masked_hidden (self , hidden_states : torch .Tensor , mask : torch .Tensor ):
226226 bool_mask = mask .bool ()
227227 valid_lengths = bool_mask .sum (dim = 1 )
@@ -230,7 +230,7 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor
230230
231231 return split_result
232232
233- # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
233+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
234234 def _get_qwen_prompt_embeds (
235235 self ,
236236 prompt : str | list [str ] = None ,
@@ -247,7 +247,7 @@ def _get_qwen_prompt_embeds(
247247 txt = [template .format (e ) for e in prompt ]
248248 txt_tokens = self .tokenizer (
249249 txt , max_length = self .tokenizer_max_length + drop_idx , padding = True , truncation = True , return_tensors = "pt"
250- ).to (self . device )
250+ ).to (device )
251251 encoder_hidden_states = self .text_encoder (
252252 input_ids = txt_tokens .input_ids ,
253253 attention_mask = txt_tokens .attention_mask ,
@@ -269,7 +269,7 @@ def _get_qwen_prompt_embeds(
269269
270270 return prompt_embeds , encoder_attention_mask
271271
272- # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
272+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline .encode_prompt
273273 def encode_prompt (
274274 self ,
275275 prompt : str | list [str ],
@@ -280,6 +280,7 @@ def encode_prompt(
280280 max_sequence_length : int = 1024 ,
281281 ):
282282 r"""
283+
283284 Args:
284285 prompt (`str` or `list[str]`, *optional*):
285286 prompt to be encoded
@@ -299,11 +300,13 @@ def encode_prompt(
299300 if prompt_embeds is None :
300301 prompt_embeds , prompt_embeds_mask = self ._get_qwen_prompt_embeds (prompt , device )
301302
303+ prompt_embeds = prompt_embeds [:, :max_sequence_length ]
302304 _ , seq_len , _ = prompt_embeds .shape
303305 prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
304306 prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
305307
306308 if prompt_embeds_mask is not None :
309+ prompt_embeds_mask = prompt_embeds_mask [:, :max_sequence_length ]
307310 prompt_embeds_mask = prompt_embeds_mask .repeat (1 , num_images_per_prompt , 1 )
308311 prompt_embeds_mask = prompt_embeds_mask .view (batch_size * num_images_per_prompt , seq_len )
309312
0 commit comments