Skip to content

Commit c507097

Browse files
authored
[Bugfix] Fix shape mismatch in LongCatAudioDiTTransformer conversion (#13494)
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent a503401 commit c507097

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

scripts/convert_longcat_audio_dit_to_diffusers.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,7 @@ def convert_longcat_audio_dit(
131131
cross_attn_norm=config.get("dit_cross_attn_norm", False),
132132
eps=config.get("dit_eps", 1e-6),
133133
use_latent_condition=config.get("dit_use_latent_condition", True),
134+
ff_mult=config.get("dit_ff_mult", 4),
134135
)
135136
transformer.load_state_dict(transformer_state_dict, strict=True)
136137
transformer = transformer.to(dtype=torch_dtype)

src/diffusers/models/transformers/transformer_longcat_audio_dit.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -475,6 +475,7 @@ def __init__(
475475
cross_attn_norm: bool = False,
476476
eps: float = 1e-6,
477477
use_latent_condition: bool = True,
478+
ff_mult: float = 4.0,
478479
):
479480
super().__init__()
480481
dim = dit_dim
@@ -498,7 +499,7 @@ def __init__(
498499
cross_attn_norm=cross_attn_norm,
499500
adaln_type=adaln_type,
500501
adaln_use_text_cond=adaln_use_text_cond,
501-
ff_mult=4.0,
502+
ff_mult=ff_mult,
502503
)
503504
for _ in range(dit_depth)
504505
]

0 commit comments

Comments
 (0)