Commit c41a3c3

[Feat] Adds LongCat-AudioDiT pipeline (#13390)

Authored by RuixiangMa, dg845, and github-actions[bot]

Commit message:

* Add LongCat-AudioDiT pipeline
* Apply suggestions from code review
* Apply style fixes
* Assorted intermediate "upd" commits

Signed-off-by: Lancer <maruixiang6688@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Parent: 0d79fc2

18 files changed: 2070 additions, 0 deletions

docs/source/en/_toctree.yml (+2 −0)

```diff
@@ -490,6 +490,8 @@
   - sections:
     - local: api/pipelines/audioldm2
       title: AudioLDM 2
+    - local: api/pipelines/longcat_audio_dit
+      title: LongCat-AudioDiT
     - local: api/pipelines/stable_audio
       title: Stable Audio
     title: Audio
```
docs/source/en/api/pipelines/longcat_audio_dit.md (new file, +61 −0)

<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LongCat-AudioDiT

LongCat-AudioDiT is a text-to-audio diffusion model from Meituan LongCat. The diffusers integration exposes a standard [`DiffusionPipeline`] interface for text-conditioned audio generation.

This pipeline supports loading the original flat LongCat checkpoint layout from either a local directory or a Hugging Face Hub repository containing:

- `config.json`
- `model.safetensors`

The loader builds the text encoder, transformer, and VAE from `config.json`, restores each component's weights from `model.safetensors`, and ties the shared UMT5 embedding when needed.

This pipeline was adapted from the LongCat-AudioDiT reference implementation: https://github.com/meituan-longcat/LongCat-AudioDiT

## Usage

```py
import soundfile as sf
import torch

from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

audio = pipeline(
    prompt="A calm ocean wave ambience with soft wind in the background.",
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios

output = audio[0, 0].float().cpu().numpy()
sf.write("longcat.wav", output, pipeline.sample_rate)
```

## Tips

- `audio_end_in_s` is the most direct way to control output duration.
- `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`.
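As a standalone sanity check of that shape convention (an editor's sketch using numpy in place of a real tensor; 24000 Hz is the conversion script's default `sampling_rate`, so read `pipeline.sample_rate` in real code):

```python
import numpy as np

# Assumed sample rate; real code should use pipeline.sample_rate.
sample_rate = 24000
audio_end_in_s = 5.0
num_samples = int(audio_end_in_s * sample_rate)

# Stand-in for a pipeline output with output_type="pt": (batch, channels, samples).
audios = np.zeros((1, 1, num_samples), dtype=np.float32)

# First batch item, first channel -> a 1-D waveform ready for soundfile.write.
waveform = audios[0, 0]
print(waveform.shape)  # (120000,)
```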
## LongCatAudioDiTPipeline

[[autodoc]] LongCatAudioDiTPipeline
  - all
  - __call__
  - from_pretrained

docs/source/en/api/pipelines/overview.md (+1 −0)

```diff
@@ -29,6 +29,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
 |---|---|
 | [AnimateDiff](animatediff) | text2video |
 | [AudioLDM2](audioldm2) | text2audio |
+| [LongCat-AudioDiT](longcat_audio_dit) | text2audio |
 | [AuraFlow](aura_flow) | text2image |
 | [Bria 3.2](bria_3_2) | text2image |
 | [CogVideoX](cogvideox) | text2video |
```
scripts/convert_longcat_audio_dit_to_diffusers.py (new file, +224 −0)

The script begins with the license header, usage notes, imports, and checkpoint discovery:

```py
#!/usr/bin/env python3
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Usage:
#   python scripts/convert_longcat_audio_dit_to_diffusers.py --checkpoint_path /path/to/model --output_path /data/models
#   python scripts/convert_longcat_audio_dit_to_diffusers.py --repo_id meituan-longcat/LongCat-AudioDiT-1B --output_path /data/models
#   python scripts/convert_longcat_audio_dit_to_diffusers.py --checkpoint_path /path/to/model --output_path /data/models --dtype fp16

import argparse
import json
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, UMT5Config, UMT5EncoderModel

from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    LongCatAudioDiTPipeline,
    LongCatAudioDiTTransformer,
    LongCatAudioDiTVae,
)


def find_checkpoint(input_dir: Path):
    safetensors_file = input_dir / "model.safetensors"
    if safetensors_file.exists():
        return input_dir, safetensors_file

    index_file = input_dir / "model.safetensors.index.json"
    if index_file.exists():
        with open(index_file) as f:
            index = json.load(f)
        weight_map = index.get("weight_map", {})
        first_weight = list(weight_map.values())[0]
        return input_dir, input_dir / first_weight

    for subdir in input_dir.iterdir():
        if subdir.is_dir():
            safetensors_file = subdir / "model.safetensors"
            if safetensors_file.exists():
                return subdir, safetensors_file
            index_file = subdir / "model.safetensors.index.json"
            if index_file.exists():
                with open(index_file) as f:
                    index = json.load(f)
                weight_map = index.get("weight_map", {})
                first_weight = list(weight_map.values())[0]
                return subdir, subdir / first_weight

    raise FileNotFoundError(f"No checkpoint found in {input_dir}")
```
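The lookup order above — plain `model.safetensors`, then a sharded `model.safetensors.index.json`, then one level of subdirectories — can be exercised without any real weights. This is a minimal re-sketch of the top-level cases using only the standard library (file names below are dummies):

```python
import json
import tempfile
from pathlib import Path


def resolve_checkpoint(input_dir: Path):
    """Minimal sketch of find_checkpoint's top-level cases."""
    single = input_dir / "model.safetensors"
    if single.exists():
        return input_dir, single
    index_file = input_dir / "model.safetensors.index.json"
    if index_file.exists():
        index = json.loads(index_file.read_text())
        # The index maps parameter names to shard files; take the first shard.
        first_weight = next(iter(index.get("weight_map", {}).values()))
        return input_dir, input_dir / first_weight
    raise FileNotFoundError(f"No checkpoint found in {input_dir}")


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Sharded layout: only an index file is present.
    (root / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"transformer.proj.weight": "model-00001-of-00002.safetensors"}})
    )
    model_dir, ckpt = resolve_checkpoint(root)
    print(ckpt.name)  # model-00001-of-00002.safetensors
```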
The conversion entry point validates arguments, resolves the checkpoint, and splits the flat state dict by key prefix (the original used hard-coded slice offsets 12, 4, and 13; they are spelled as `len(prefix)` here for clarity):

```py
def convert_longcat_audio_dit(
    checkpoint_path: str | None = None,
    repo_id: str | None = None,
    output_path: str = "",
    dtype: str = "fp32",
    text_encoder_model: str = "google/umt5-xxl",
):
    if not checkpoint_path and not repo_id:
        raise ValueError("Either --checkpoint_path or --repo_id must be provided")
    if checkpoint_path and repo_id:
        raise ValueError("Cannot specify both --checkpoint_path and --repo_id")

    dtype_map = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    torch_dtype = dtype_map.get(dtype, torch.float32)

    if repo_id:
        input_dir = Path(snapshot_download(repo_id, local_files_only=False))
        model_name = repo_id.split("/")[-1]
    else:
        input_dir = Path(checkpoint_path)
        if not input_dir.exists():
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        model_name = None

    model_dir, checkpoint_path = find_checkpoint(input_dir)
    if model_name is None:
        model_name = model_dir.name

    config_path = model_dir / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"config.json not found in {model_dir}")

    with open(config_path) as f:
        config = json.load(f)

    state_dict = load_file(checkpoint_path)

    # Split the flat checkpoint into per-component state dicts by key prefix.
    transformer_keys = [k for k in state_dict.keys() if k.startswith("transformer.")]
    transformer_state_dict = {key[len("transformer."):]: state_dict[key] for key in transformer_keys}

    vae_keys = [k for k in state_dict.keys() if k.startswith("vae.")]
    vae_state_dict = {key[len("vae."):]: state_dict[key] for key in vae_keys}

    text_encoder_keys = [k for k in state_dict.keys() if k.startswith("text_encoder.")]
    text_encoder_state_dict = {key[len("text_encoder."):]: state_dict[key] for key in text_encoder_keys}
```
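The slice offsets used for the split are exactly the lengths of the `transformer.`, `vae.`, and `text_encoder.` prefixes (12, 4, and 13 characters). A toy flat state dict (illustrative key names, not the real checkpoint's) shows the behavior:

```python
state_dict = {
    "transformer.blocks.0.attn.q.weight": 1,
    "vae.decoder.conv.bias": 2,
    "text_encoder.encoder.embed_tokens.weight": 3,
}


def split_by_prefix(state_dict, prefix):
    # Keep only keys under `prefix` and strip it off.
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


transformer_sd = split_by_prefix(state_dict, "transformer.")
vae_sd = split_by_prefix(state_dict, "vae.")
text_sd = split_by_prefix(state_dict, "text_encoder.")

print(sorted(transformer_sd))  # ['blocks.0.attn.q.weight']
```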
The components are then built from `config.json` and their weights loaded:

```py
    transformer = LongCatAudioDiTTransformer(
        dit_dim=config["dit_dim"],
        dit_depth=config["dit_depth"],
        dit_heads=config["dit_heads"],
        dit_text_dim=config["dit_text_dim"],
        latent_dim=config["latent_dim"],
        dropout=config.get("dit_dropout", 0.0),
        bias=config.get("dit_bias", True),
        cross_attn=config.get("dit_cross_attn", True),
        adaln_type=config.get("dit_adaln_type", "global"),
        adaln_use_text_cond=config.get("dit_adaln_use_text_cond", True),
        long_skip=config.get("dit_long_skip", True),
        text_conv=config.get("dit_text_conv", True),
        qk_norm=config.get("dit_qk_norm", True),
        cross_attn_norm=config.get("dit_cross_attn_norm", False),
        eps=config.get("dit_eps", 1e-6),
        use_latent_condition=config.get("dit_use_latent_condition", True),
    )
    transformer.load_state_dict(transformer_state_dict, strict=True)
    transformer = transformer.to(dtype=torch_dtype)

    vae_config = dict(config["vae_config"])
    vae_config.pop("model_type", None)
    vae = LongCatAudioDiTVae(**vae_config)
    vae.load_state_dict(vae_state_dict, strict=True)
    vae = vae.to(dtype=torch_dtype)

    text_encoder_config = UMT5Config.from_dict(config["text_encoder_config"])
    text_encoder = UMT5EncoderModel(text_encoder_config)
    text_missing, text_unexpected = text_encoder.load_state_dict(text_encoder_state_dict, strict=False)

    # The checkpoint may omit the shared embedding; it is tied to the encoder's
    # embed_tokens below, so only that key is allowed to be missing.
    allowed_missing = {"shared.weight"}
    unexpected_missing = set(text_missing) - allowed_missing
    if unexpected_missing:
        raise RuntimeError(f"Unexpected missing text encoder weights: {sorted(unexpected_missing)}")
    if text_unexpected:
        raise RuntimeError(f"Unexpected text encoder weights: {sorted(text_unexpected)}")
    if "shared.weight" in text_missing:
        text_encoder.shared.weight.data.copy_(text_encoder.encoder.embed_tokens.weight.data)

    text_encoder = text_encoder.to(dtype=torch_dtype)
```
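The pattern in the text-encoder load — `strict=False`, allow exactly the keys you intend to reconstruct, fail on anything else, then tie the shared embedding — can be sketched framework-free with plain dicts (hypothetical key names, not the real UMT5 state dict):

```python
expected_keys = {"shared.weight", "encoder.embed_tokens.weight", "encoder.block.0.weight"}
loaded_keys = {"encoder.embed_tokens.weight", "encoder.block.0.weight"}  # checkpoint omits shared.weight

missing = expected_keys - loaded_keys
unexpected = loaded_keys - expected_keys

# Only the key we know how to reconstruct may be missing.
allowed_missing = {"shared.weight"}
if missing - allowed_missing:
    raise RuntimeError(f"Unexpected missing weights: {sorted(missing - allowed_missing)}")
if unexpected:
    raise RuntimeError(f"Unexpected weights: {sorted(unexpected)}")

# "Tying": the shared embedding is filled from encoder.embed_tokens.
weights = {"encoder.embed_tokens.weight": [0.1, 0.2, 0.3]}
if "shared.weight" in missing:
    weights["shared.weight"] = weights["encoder.embed_tokens.weight"]

print(weights["shared.weight"])  # [0.1, 0.2, 0.3]
```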
Finally, the tokenizer, scheduler, and pipeline are assembled and saved:

```py
    tokenizer = AutoTokenizer.from_pretrained(text_encoder_model)

    scheduler_config = {"shift": 1.0, "invert_sigmas": True}
    scheduler_config.update(config.get("scheduler_config", {}))
    scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config)

    pipeline = LongCatAudioDiTPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    pipeline.sample_rate = config.get("sampling_rate", 24000)
    pipeline.vae_scale_factor = config.get("vae_scale_factor", config.get("latent_hop", 2048))
    pipeline.max_wav_duration = config.get("max_wav_duration", 30.0)
    pipeline.text_norm_feat = config.get("text_norm_feat", True)
    pipeline.text_add_embed = config.get("text_add_embed", True)

    output_path = Path(output_path) / f"{model_name}-Diffusers"
    output_path.mkdir(parents=True, exist_ok=True)

    pipeline.save_pretrained(output_path)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Path to local model directory",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default=None,
        help="HuggingFace repo_id to download model",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Output directory")
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        choices=["fp32", "fp16", "bf16"],
        help="Data type for converted weights",
    )
    parser.add_argument(
        "--text_encoder_model",
        type=str,
        default="google/umt5-xxl",
        help="HuggingFace model ID for text encoder tokenizer",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    convert_longcat_audio_dit(
        checkpoint_path=args.checkpoint_path,
        repo_id=args.repo_id,
        output_path=args.output_path,
        dtype=args.dtype,
        text_encoder_model=args.text_encoder_model,
    )
```

src/diffusers/__init__.py (+6 −0)

```diff
@@ -254,6 +254,8 @@
     "Kandinsky3UNet",
     "Kandinsky5Transformer3DModel",
     "LatteTransformer3DModel",
+    "LongCatAudioDiTTransformer",
+    "LongCatAudioDiTVae",
     "LongCatImageTransformer2DModel",
     "LTX2VideoTransformer3DModel",
     "LTXVideoTransformer3DModel",
@@ -599,6 +601,7 @@
     "LEditsPPPipelineStableDiffusionXL",
     "LLaDA2Pipeline",
     "LLaDA2PipelineOutput",
+    "LongCatAudioDiTPipeline",
     "LongCatImageEditPipeline",
     "LongCatImagePipeline",
     "LTX2ConditionPipeline",
@@ -1058,6 +1061,8 @@
     Kandinsky3UNet,
     Kandinsky5Transformer3DModel,
     LatteTransformer3DModel,
+    LongCatAudioDiTTransformer,
+    LongCatAudioDiTVae,
     LongCatImageTransformer2DModel,
     LTX2VideoTransformer3DModel,
     LTXVideoTransformer3DModel,
@@ -1378,6 +1383,7 @@
     LEditsPPPipelineStableDiffusionXL,
     LLaDA2Pipeline,
     LLaDA2PipelineOutput,
+    LongCatAudioDiTPipeline,
     LongCatImageEditPipeline,
     LongCatImagePipeline,
     LTX2ConditionPipeline,
```

src/diffusers/models/__init__.py (+4 −0)

```diff
@@ -50,6 +50,7 @@
     _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
+    _import_structure["autoencoders.autoencoder_longcat_audio_dit"] = ["LongCatAudioDiTVae"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
     _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
@@ -112,6 +113,7 @@
     _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
     _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
     _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
+    _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"]
     _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
     _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
@@ -180,6 +182,7 @@
         AutoencoderTiny,
         AutoencoderVidTok,
         ConsistencyDecoderVAE,
+        LongCatAudioDiTVae,
         VQModel,
     )
     from .cache_utils import CacheMixin
@@ -233,6 +236,7 @@
         HunyuanVideoTransformer3DModel,
         Kandinsky5Transformer3DModel,
         LatteTransformer3DModel,
+        LongCatAudioDiTTransformer,
         LongCatImageTransformer2DModel,
         LTX2VideoTransformer3DModel,
         LTXVideoTransformer3DModel,
```

src/diffusers/models/autoencoders/__init__.py (+1 −0)

```diff
@@ -19,6 +19,7 @@
 from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_kl_wan import AutoencoderKLWan
+from .autoencoder_longcat_audio_dit import LongCatAudioDiTVae
 from .autoencoder_oobleck import AutoencoderOobleck
 from .autoencoder_rae import AutoencoderRAE
 from .autoencoder_tiny import AutoencoderTiny
```
