diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..5c03e073
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,20 @@
+.git
+.github
+**/__pycache__
+**/*.pyc
+**/.pytest_cache
+.venv
+venv
+.DS_Store
+.vscode
+.idea
+tmp_examples*
+new_checkpoint*
+batch_test*
+nohup*
+*.mp4
+*.pt
+*.pth
+**/Wan2.2-*
+**/mcps
+terminals
diff --git a/.gitignore b/.gitignore
index de347daa..b632b098 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ __pycache__/
 tmp_examples*
 new_checkpoint*
 batch_test*
-nohup*
\ No newline at end of file
+nohup*
+.wan_ckpt_placeholder/
\ No newline at end of file
diff --git a/DEPLOY.md b/DEPLOY.md
new file mode 100644
index 00000000..23af3bf2
--- /dev/null
+++ b/DEPLOY.md
@@ -0,0 +1,386 @@
+# Wan2.2 Two-Node Deployment Guide
+
+## Requirements
+
+| Item | Requirement |
+|------|------|
+| Operating system | Linux (Ubuntu 20.04+) |
+| GPU | 4×A100 40GB per server |
+| Docker | >= 24.0 |
+| Docker Compose | v2 (plugin mode) |
+| NVIDIA Container Toolkit | Installed and configured |
+| Network | Both servers on the same LAN, able to reach each other on TCP ports 29500 and 6379 |
+
+## Network Preparation
+
+Make sure the following ports are reachable between the two servers:
+
+- **6379**: Redis (the primary node's Redis, which worker1 on the secondary node connects to)
+- **29500**: torchrun NCCL rendezvous port
+- **8008**: API HTTP port (client-facing)
+
+> **Important**: worker0 and worker1 use `network_mode: host`, so torchrun binds directly to host ports.
+> This is common practice for PyTorch distributed jobs; Docker bridge networking causes NCCL connection timeouts.
+
+Verify connectivity:
+
+```bash
+# Run on the secondary node
+ping <PRIMARY_NODE_IP>
+nc -zv <PRIMARY_NODE_IP> 6379    # Redis
+nc -zv <PRIMARY_NODE_IP> 29500   # torchrun rendezvous
+```
+
+---
+
+## Primary Node Deployment
+
+### 1. Clone the repository
+
+```bash
+git clone https://github.com/lin285170/Wan2.2.git
+cd Wan2.2
+```
+
+### 2. Configure environment variables
+
+```bash
+cp docker/compose.env.example .env
+```
+
+Edit `.env`:
+
+```bash
+# Required: API keys (comma-separated to allow several)
+WAN_SERVE_API_KEYS=sk-your-secret-key
+
+# Parent directory of the model weights (mounted at /ckpt in the container, one subdirectory per model)
+# Expected layout:
+#   /data/models/Wan2.2-T2V-A14B/
+#   /data/models/Wan2.2-I2V-A14B/
+#   /data/models/Wan2.2-TI2V-5B/
+#   /data/models/Wan2.2-Animate-14B/
+#   /data/models/Wan2.2-S2V-14B/
+# Download only the models you actually use; the other subdirectories may be absent.
+# Each request is routed to the matching subdirectory based on its model field.
+# You can also point to another path explicitly via parameters.ckpt_dir.
+WAN_CKPT_HOST_PATH=/data/models
+
+# Two-node topology: 2 nodes × 4 GPUs = 8 GPUs in total
+WAN_NNODES=2
+WAN_NPROC_PER_NODE=4
+
+# Real IP of the primary node (worker0 uses host networking and binds host ports directly)
+# Must be a real IP, not 0.0.0.0 or 127.0.0.1
+WAN_MASTER_ADDR=10.0.0.1
+WAN_MASTER_PORT=29500
+
+# worker0 uses host networking and reaches Redis through localhost
+WAN_REDIS_URL_LOCAL=redis://127.0.0.1:6379/0
+
+# Client-facing API port
+WAN_API_PORT=8008
+```
+
+### 3. Start the primary node services
+
+```bash
+docker compose up -d --build
+```
+
+This starts 3 containers:
+
+- **redis**: Redis database (port 6379 is exposed so the secondary node can connect)
+- **api**: FastAPI HTTP service (port 8008)
+- **worker0**: GPU worker (node_rank=0, takes tasks from the Redis queue)
+
+### 4. Verify the primary node
+
+```bash
+# Check container status
+docker compose ps
+
+# Check API health
+curl http://localhost:8008/healthz
+
+# Follow the logs
+docker compose logs -f api
+docker compose logs -f worker0
+```
+
+---
+
+## Secondary Node Deployment
+
+### 1. Clone the repository (same code version)
+
+```bash
+git clone https://github.com/lin285170/Wan2.2.git
+cd Wan2.2
+```
+
+### 2. Configure environment variables
+
+```bash
+cp docker/compose.env.example .env
+```
+
+Edit `.env`:
+
+```bash
+# Keep identical to the primary node
+WAN_SERVE_API_KEYS=sk-your-secret-key
+
+# Parent directory of the model weights (path on the secondary node, same layout as on the primary node)
+WAN_CKPT_HOST_PATH=/data/models
+
+# Two-node topology
+WAN_NNODES=2
+WAN_NPROC_PER_NODE=4
+
+# Actual IP of the primary node (a real IP, not 0.0.0.0)
+WAN_MASTER_ADDR=10.0.0.1
+WAN_MASTER_PORT=29500
+
+# Redis on the primary node (critical!)
+WAN_REDIS_URL=redis://10.0.0.1:6379/0
+```
+
+### 3. Start the secondary node service
+
+```bash
+docker compose -f docker-compose.worker.yml up -d --build
+```
+
+This starts 1 container:
+
+- **worker1**: GPU worker (node_rank=1, receives signals through Redis pub/sub)
+
+### 4. Verify the secondary node
+
+```bash
+docker compose -f docker-compose.worker.yml ps
+docker compose -f docker-compose.worker.yml logs -f worker1
+```
+
+The log should show: `Worker node started; waiting for signals on wan:signal`
+
+---
+
+## Submitting Video Generation Requests
+
+### Supported models
+
+| Model | model value | Required input fields | Default size | Automatic ckpt_dir |
+|------|----------|-----------------|-----------|---------------|
+| T2V | `wan2.2-t2v-a14b` | prompt | 1280\*720 | `/ckpt/Wan2.2-T2V-A14B` |
+| I2V | `wan2.2-i2v-a14b` | prompt + image | 832\*480 | `/ckpt/Wan2.2-I2V-A14B` |
+| TI2V | `wan2.2-ti2v-5b` | prompt (image optional) | 1280\*704 | `/ckpt/Wan2.2-TI2V-5B` |
+| Animate | `wan2.2-animate-14b` | prompt + video | 720\*1280 | `/ckpt/Wan2.2-Animate-14B` |
+| S2V | `wan2.2-s2v-14b` | prompt + image + audio (or enable_tts) | 832\*480 | `/ckpt/Wan2.2-S2V-14B` |
+
+> **Switching models**: set a different model ID in the request's `model` field and the system locates the matching weight directory automatically. No service restart or configuration change is needed. To use a custom weight path, override it with `parameters.ckpt_dir`.
+
+### T2V: text to video
+
+```bash
+curl -X POST http://10.0.0.1:8008/api/v1/video/generation \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-t2v-a14b",
+    "input": {
+      "prompt": "A cat walking on a beach at sunset"
+    },
+    "parameters": {
+      "size": "1280*720",
+      "frame_num": 81
+    }
+  }'
+```
+
+### I2V: image to video
+
+```bash
+curl -X POST http://10.0.0.1:8008/api/v1/video/generation \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-i2v-a14b",
+    "input": {
+      "prompt": "A cat dancing on the beach",
+      "image": "/ckpt/ref_image.jpg"
+    },
+    "parameters": {
+      "size": "832*480"
+    }
+  }'
+```
+
+### TI2V: text/image to video (lightweight 5B model)
+
+```bash
+curl -X POST http://10.0.0.1:8008/api/v1/video/generation \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-ti2v-5b",
+    "input": {
+      "prompt": "A dog running in a park",
+      "image": "/ckpt/ref_image.jpg"
+    },
+    "parameters": {
+      "size": "1280*704"
+    }
+  }'
+```
+
+### Animate: pose-driven video generation
+
+```bash
+curl -X POST http://10.0.0.1:8008/api/v1/video/generation \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-animate-14b",
+    "input": {
+      "prompt": "视频中的人在做动作",
+      "video": "/ckpt/ref_video.mp4"
+    },
+    "parameters": {
+      "size": "720*1280",
+      "src_root_path": "/ckpt/animate_input",
+      "refert_num": 77
+    }
+  }'
+```
+
+### S2V: speech-driven video generation
+
+With an audio file:
+
+```bash
+curl -X POST http://10.0.0.1:8008/api/v1/video/generation \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-s2v-14b",
+    "input": {
+      "prompt": "A person talking happily",
+      "image": "/ckpt/ref_image.jpg",
+      "audio": "/ckpt/speech.wav"
+    },
+    "parameters": {
+      "size": "832*480"
+    }
+  }'
+```
+
+With TTS-synthesized speech:
+
+```bash
+curl -X POST http://10.0.0.1:8008/api/v1/video/generation \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-s2v-14b",
+    "input": {
+      "prompt": "A person talking happily",
+      "image": "/ckpt/ref_image.jpg"
+    },
+    "parameters": {
+      "size": "832*480",
+      "enable_tts": true,
+      "tts_prompt_audio": "/ckpt/prompt_voice.wav",
+      "tts_prompt_text": "希望你以后能够做的比我还好呦。",
+      "tts_text": "收到好友从远方寄来的生日礼物,那份意外的惊喜让我心中充满了甜蜜的快乐。"
+    }
+  }'
+```
+
+### Common: response format
+
+Every successfully submitted request returns:
+
+```json
+{
+  "request_id": "...",
+  "output": {
+    "task_id": "wan-..."
+  }
+}
+```
+
+### Query task status
+
+```bash
+curl http://10.0.0.1:8008/api/v1/tasks/wan-xxxx \
+  -H "Authorization: Bearer sk-your-secret-key"
+```
+
+### Download the video
+
+```bash
+curl http://10.0.0.1:8008/api/v1/files/by-task/wan-xxxx \
+  -H "Authorization: Bearer sk-your-secret-key" \
+  -o output.mp4
+```
+
+---
+
+## Workflow
+
+```
+User → API (:8008) → Redis queue
+                        ↓
+          worker0 (master) takes a task from the queue
+                        ↓
+          worker0 writes the job JSON
+          worker0 signals through Redis pub/sub → worker1 receives the signal
+          worker0 runs torchrun locally (node_rank=0)      worker1 runs torchrun locally (node_rank=1)
+                        ↓                                          ↓
+          The two nodes meet through NCCL rendezvous (:29500) and run distributed inference
+                        ↓
+          worker0 updates the task status in Redis → SUCCEEDED
+                        ↓
+User queries status → API reads Redis → returns the result
+```
+
+---
+
+## Common Operations Commands
+
+```bash
+# Stop all services (primary node)
+docker compose down
+
+# Stop all services (secondary node)
+docker compose -f docker-compose.worker.yml down
+
+# Restart a single service
+docker compose restart api
+
+# Show resource usage
+docker stats
+
+# Clean up and rebuild
+docker compose down -v  # Note: -v deletes the Redis data volume
+docker compose up -d --build
+```
+
+---
+
+## Single-Node Mode (for testing)
+
+If you only have one machine, change `.env`:
+
+```bash
+WAN_NNODES=1
+WAN_NPROC_PER_NODE=4
+```
+
+Starting the primary node is enough (worker1 is not needed):
+
+```bash
+docker compose up -d --build
+```
\ No newline at end of file
diff --git a/DEPLOY_SERVE.md b/DEPLOY_SERVE.md
new file mode 100644
index 00000000..9a3dbebb
--- /dev/null
+++ b/DEPLOY_SERVE.md
@@ -0,0 +1,3 @@
+The **deployment instructions, HTTP service, and Docker orchestration** for this project have been merged into **[README.md](README.md)** at the repository root.
+
+Open `README.md` directly for the full content.
diff --git a/README.md b/README.md
index 3aa29f54..761b18da 100644
--- a/README.md
+++ b/README.md
@@ -1,507 +1,430 @@
-# Wan2.2
-
-

-
------
-
-[**Wan: Open and Advanced Large-Scale Video Generative Models**](https://arxiv.org/abs/2503.20314)
-
-
-We are excited to introduce **Wan2.2**, a major upgrade to our foundational video models. With **Wan2.2**, we have focused on incorporating the following innovations:
-
-- 👍 **Effective MoE Architecture**: Wan2.2 introduces a Mixture-of-Experts (MoE) architecture into video diffusion models. By separating the denoising process across timesteps with specialized powerful expert models, this enlarges the overall model capacity while maintaining the same computational cost.
-
-- 👍 **Cinematic-level Aesthetics**: Wan2.2 incorporates meticulously curated aesthetic data, complete with detailed labels for lighting, composition, contrast, color tone, and more. This allows for more precise and controllable cinematic style generation, facilitating the creation of videos with customizable aesthetic preferences.
-
-- 👍 **Complex Motion Generation**: Compared to Wan2.1, Wan2.2 is trained on a significantly larger dataset, with +65.6% more images and +83.2% more videos. This expansion notably enhances the model's generalization across multiple dimensions such as motions, semantics, and aesthetics, achieving TOP performance among all open-sourced and closed-sourced models.
-
-- 👍 **Efficient High-Definition Hybrid TI2V**: Wan2.2 open-sources a 5B model built with our advanced Wan2.2-VAE that achieves a compression ratio of **16×16×4**. This model supports both text-to-video and image-to-video generation at 720P resolution with 24fps and can also run on consumer-grade graphics cards like 4090. It is one of the fastest **720P@24fps** models currently available, capable of serving both the industrial and academic sectors simultaneously.
-
-
-## Video Demos
-
-

-
-## 🔥 Latest News!!
-* Nov 13, 2025: 👋 Wan2.2-Animate-14B has been integrated into Diffusers ([PR](https://github.com/huggingface/diffusers/pull/12526),[Weights](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)). Thanks to all community contributors. Enjoy!
-
-* Sep 19, 2025: 💃 We introduce **[Wan2.2-Animate-14B](https://humanaigc.github.io/wan-animate)**, a unified model for character animation and replacement with holistic movement and expression replication. We released the [model weights](#model-download) and [inference code](#run-wan-animate). And you can try it on [wan.video](https://wan.video/), [ModelScope Studio](https://www.modelscope.cn/studios/Wan-AI/Wan2.2-Animate) or [HuggingFace Space](https://huggingface.co/spaces/Wan-AI/Wan2.2-Animate)!
-* Aug 26, 2025: 🎵 We introduce **[Wan2.2-S2V-14B](https://humanaigc.github.io/wan-s2v-webpage)**, an audio-driven cinematic video generation model, including [inference code](#run-speech-to-video-generation), [model weights](#model-download), and [technical report](https://humanaigc.github.io/wan-s2v-webpage/content/wan-s2v.pdf)! Now you can try it on [wan.video](https://wan.video/), [ModelScope Gradio](https://www.modelscope.cn/studios/Wan-AI/Wan2.2-S2V) or [HuggingFace Gradio](https://huggingface.co/spaces/Wan-AI/Wan2.2-S2V)!
-* Jul 28, 2025: 👋 We have opened a [HF space](https://huggingface.co/spaces/Wan-AI/Wan-2.2-5B) using the TI2V-5B model. Enjoy!
-* Jul 28, 2025: 👋 Wan2.2 has been integrated into ComfyUI ([CN](https://docs.comfy.org/zh-CN/tutorials/video/wan/wan2_2) | [EN](https://docs.comfy.org/tutorials/video/wan/wan2_2)). Enjoy!
-* Jul 28, 2025: 👋 Wan2.2's T2V, I2V and TI2V have been integrated into Diffusers ([T2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) | [I2V-A14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | [TI2V-5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)). Feel free to give it a try!
-* Jul 28, 2025: 👋 We've released the inference code and model weights of **Wan2.2**.
-* Sep 5, 2025: 👋 We added text-to-speech synthesis support with [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for the Speech-to-Video generation task.
-
-
-## Community Works
-If your research or project builds upon [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) or [**Wan2.2**](https://github.com/Wan-Video/Wan2.2), and you would like more people to see it, please inform us.
-
-- [Prompt Relay](https://github.com/GordonChen19/Prompt-Relay), a plug-and-play, inference-time method for temporal control in video generation. Prompt Relay improves video quality and gives users precise control over what happens at each moment in the video. Visit their [webpage](https://gordonchen19.github.io/Prompt-Relay/) for more details.
-- [Helios](https://github.com/PKU-YuanGroup/Helios), a breakthrough video generation model based on **Wan2.1** that achieves minute-scale, high-quality video synthesis at 19.5 FPS on a single H100 GPU (about 10 FPS on a single Ascend NPU), without relying on conventional long video anti-drifting strategies or standard video acceleration techniques. Visit their [webpage](https://pku-yuangroup.github.io/Helios-Page/) for more details.
-- [LightX2V](https://github.com/ModelTC/LightX2V), a lightweight and efficient video generation framework that integrates **Wan2.1** and **Wan2.2**, supporting multiple engineering acceleration techniques for fast inference. [LightX2V-HuggingFace](https://huggingface.co/lightx2v) offers a variety of Wan-based step-distillation models, quantized models, and lightweight VAE models.
-- [HuMo](https://github.com/Phantom-video/HuMo) proposed a unified, human-centric framework based on **Wan** to produce high-quality, fine-grained, and controllable human videos from multimodal inputs, including text, images, and audio. Visit their [webpage](https://phantom-video.github.io/HuMo/) for more details.
-- [FastVideo](https://github.com/hao-ai-lab/FastVideo) includes distilled **Wan** models with sparse attention that significantly speed up the inference time.
-- [Cache-dit](https://github.com/vipshop/cache-dit) offers Fully Cache Acceleration support for **Wan2.2** MoE with DBCache, TaylorSeer and Cache CFG. Visit their [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) for more details.
-- [Kijai's ComfyUI WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper) is an alternative implementation of **Wan** models for ComfyUI. Thanks to its Wan-only focus, it's on the frontline of getting cutting edge optimizations and hot research features, which are often hard to integrate into ComfyUI quickly due to its more rigid structure.
-- [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides comprehensive support for **Wan 2.2**, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training.
-
-
-## 📑 Todo List
-- Wan2.2 Text-to-Video
-    - [x] Multi-GPU Inference code of the A14B and 14B models
-    - [x] Checkpoints of the A14B and 14B models
-    - [x] ComfyUI integration
-    - [x] Diffusers integration
-- Wan2.2 Image-to-Video
-    - [x] Multi-GPU Inference code of the A14B model
-    - [x] Checkpoints of the A14B model
-    - [x] ComfyUI integration
-    - [x] Diffusers integration
-- Wan2.2 Text-Image-to-Video
-    - [x] Multi-GPU Inference code of the 5B model
-    - [x] Checkpoints of the 5B model
-    - [x] ComfyUI integration
-    - [x] Diffusers integration
-- Wan2.2-S2V Speech-to-Video
-    - [x] Inference code of Wan2.2-S2V
-    - [x] Checkpoints of Wan2.2-S2V-14B
-    - [x] ComfyUI integration
-    - [x] Diffusers integration
-- Wan2.2-Animate Character Animation and Replacement
-    - [x] Inference code of Wan2.2-Animate
-    - [x] Checkpoints of Wan2.2-Animate
-    - [x] ComfyUI integration
-    - [x] Diffusers integration
-
-## Run Wan2.2
-
-#### Installation
-Clone the repo:
-```sh
-git clone https://github.com/Wan-Video/Wan2.2.git
-cd Wan2.2
-```
-
-Install dependencies:
-```sh
-# Ensure torch >= 2.4.0
-# If the installation of `flash_attn` fails, try installing the other packages first and install `flash_attn` last
-pip install -r requirements.txt
-# If you want to use CosyVoice to synthesize speech for Speech-to-Video Generation, please install requirements_s2v.txt additionally
-pip install -r requirements_s2v.txt
-```
-
-
-#### Model Download
-
-| Models | Download Links | Description |
-|--------------------|----------------|-------------|
-| T2V-A14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | Text-to-Video MoE model, supports 480P & 720P |
-| I2V-A14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | Image-to-Video MoE model, supports 480P & 720P |
-| TI2V-5B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | High-compression VAE, T2V+I2V, supports 720P |
-| S2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B) 🤖 [ModelScope](https://modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | Speech-to-Video model, supports 480P & 720P |
-| Animate-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | Character animation and replacement |
-
-
+# Wan2.2 Cluster Inference and DashScope-Style HTTP Service Deployment Guide
-> 💡Note:
-> The TI2V-5B model supports 720P video generation at **24 FPS**.
+This document explains how to run the repository's built-in **asynchronous task API** on **2 machines × 4×A100 40GB** (or any topology where `nnodes × nproc_per_node = WORLD_SIZE`). The API follows DashScope conventions: `Bearer` authentication, `POST` to submit, and `GET` to poll task status.
+---
-
-Download models using huggingface-cli:
-``` sh
-pip install "huggingface_hub[cli]"
-huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
-```
-
-Download models using modelscope-cli:
-``` sh
-pip install modelscope
-modelscope download Wan-AI/Wan2.2-T2V-A14B --local_dir ./Wan2.2-T2V-A14B
-```
-
-#### Run Text-to-Video Generation
+## 1. Architecture
-This repository supports the `Wan2.2-T2V-A14B` Text-to-Video model and can simultaneously support video generation at 480P and 720P resolutions.
+| Component | Responsibility |
+|------|------|
+| **Redis** | Task queue `WAN_QUEUE_NAME`, task metadata `WAN_TASK_KEY_PREFIX*`, and the cluster mutex `WAN_CLUSTER_LOCK_KEY` (only one `torchrun` job runs at a time). |
+| **`run_api_server.py` + `serve.api`** | FastAPI: submit tasks, query status, download the MP4. |
+| **`python -m serve.worker_main`** | Takes a `task_id` from the queue, writes `job.json`, and calls `torchrun … generate_job.py`. |
+| **`generate_job.py`** | Reads the JSON and calls `generate.args_from_job_dict` + `generate.generate`. |
+**Important**: the default implementation assumes **a single worker process** consumes the queue (global GPU lock). Starting several workers makes them contend for the lock and requeue repeatedly; running tasks concurrently requires reworking this into one queue per GPU set, or Kubernetes Jobs.
-##### (1) Without Prompt Extension
+---
-To facilitate implementation, we will start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extention) step.
+## 2. Environment setup (two GPU machines + an optional API machine)
-- Single-GPU inference
+1. **Python**: same as the official README, `torch>=2.4`; install `requirements.txt` plus the dependencies needed for inference.
+2. **Service dependencies** (on machines that run the API / worker):
+   ```bash
+   pip install -r requirements_serve.txt
+   ```
+3. **Redis**: may run on the API host or on a separate VM; both GPU machines and the API must be able to reach its address.
+4. **Shared storage (strongly recommended)**: NFS or similar, mounted at the **same absolute path** on both GPU machines:
+   - Model directory `WAN_CKPT_DIR`
+   - Task JSON directory `WAN_JOB_DIR`
+   - Output video directory `WAN_OUTPUT_DIR`
-``` sh
-python generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --offload_model True --convert_model_dtype --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
-```
-
-> 💡 This command can run on a GPU with at least 80GB VRAM.
+5. **NCCL across two machines**: set `NCCL_SOCKET_IFNAME`, configure hostname resolution, and open the firewall for `WAN_MASTER_PORT` and the PyTorch distributed ports; with RDMA, configure IB according to your data center's documentation.
-> 💡If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`, `--convert_model_dtype` and `--t5_cpu` options to reduce GPU memory usage.
+---
+## 3. Environment variable reference
-
-- Multi-GPU inference using FSDP + DeepSpeed Ulysses
+### General / API / Worker
-  We use [PyTorch FSDP](https://docs.pytorch.org/docs/stable/fsdp.html) and [DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509) to accelerate inference.
+| Variable | Description | Example |
+|------|------|------|
+| `WAN_SERVE_API_KEYS` | Comma-separated API keys (`Authorization: Bearer `) | `sk-local-xxx,sk-local-yyy` |
+| `WAN_REDIS_URL` | Redis connection URL | `redis://10.0.0.5:6379/0` |
+| `WAN_REPO_ROOT` | Absolute path of this repository | `/data/Wan2.2` |
+| `WAN_JOB_DIR` | Task JSON directory (must be shared) | `/mnt/wan/jobs` |
+| `WAN_OUTPUT_DIR` | Output MP4 directory (must be shared) | `/mnt/wan/out` |
+| `WAN_CKPT_DIR` | Default checkpoint root directory | `/mnt/wan/Wan2.2-T2V-A14B` |
+### Multi-node torchrun (when the Worker host can `ssh` to the second machine)
-``` sh
-torchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
-```
+| Variable | Description |
+|------|------|
+| `WAN_NNODES` | Number of nodes, e.g. `2` |
+| `WAN_NPROC_PER_NODE` | Processes per node, e.g. `4` (8 GPUs in total) |
+| `WAN_MASTER_ADDR` | IP of the machine hosting rank 0 (the **first** GPU machine) |
+| `WAN_MASTER_PORT` | Rendezvous port, e.g. `29500` |
+| `WAN_RDZV_PREFIX` | Rendezvous id prefix (`task_id` is appended) |
+| `WAN_SSH_SECOND_NODE` | SSH login string for the second machine, e.g. `ubuntu@192.168.1.12` |
+| `WAN_SSH_TORCHRUN_PREFIX` | Shell prefix run on the remote side over SSH; default `cd {repo_root} && export PYTHONPATH={repo_root}:$PYTHONPATH && ` |
+| `WAN_PYTHON` / `WAN_TORCHRUN` | Optional; override the executable paths |
+### Prompt extension (optional)
-##### (2) Using Prompt Extension
+If the task JSON sets `use_prompt_extend=true` and `prompt_extend_method=dashscope`, also set:
-Extending the prompts can effectively enrich the details in the generated videos, further enhancing the video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:
+- `DASH_API_KEY`
+- On the international site, optionally `DASH_API_URL=https://dashscope-intl.aliyuncs.com/api/v1`
-- Use the Dashscope API for extension.
-  - Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
-  - Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
-  - Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
-  - You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:
-```sh
-DASH_API_KEY=your_key torchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'zh'
-```
+---
-- Using a local model for extension.
+## 4. Single machine, 8 GPUs (single-node test)
-  - By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
-  - For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
-  - For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
-  - Larger models generally provide better extension results but require more GPU memory.
-  - You can modify the model used for extension with the parameter `--prompt_extend_model`, allowing you to specify either a local model path or a Hugging Face model. For example:
-
-``` sh
-torchrun --nproc_per_node=8 generate.py --task t2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-T2V-A14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'zh'
+```bash
+export WAN_SERVE_API_KEYS="sk-dev"
+export WAN_REDIS_URL="redis://127.0.0.1:6379/0"
+export WAN_REPO_ROOT="/data/Wan2.2"
+export WAN_CKPT_DIR="/data/Wan2.2-T2V-A14B"
+export WAN_JOB_DIR="/tmp/wan_jobs"
+export WAN_OUTPUT_DIR="/tmp/wan_out"
+export WAN_NNODES=1
+export WAN_NPROC_PER_NODE=8
+export PYTHONPATH="/data/Wan2.2:$PYTHONPATH"
+
+# Terminal 1
+redis-server &
+python run_api_server.py
+
+# Terminal 2 (on the API host, or on a GPU machine that can reach Redis)
+python -m serve.worker_main
 ```
+Example submission:
-
-#### Run Image-to-Video Generation
-
-This repository supports the `Wan2.2-I2V-A14B` Image-to-Video model and can simultaneously support video generation at 480P and 720P resolutions.
-
-
-- Single-GPU inference
-```sh
-python generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --offload_model True --convert_model_dtype --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```bash
+curl -sS -X POST "http://127.0.0.1:8008/api/v1/video/generation" \
+  -H "Authorization: Bearer sk-dev" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "wan2.2-t2v-a14b",
+    "input": { "prompt": "A cat walking on grass." },
+    "parameters": {
+      "size": "1280*720",
+      "dit_fsdp": true,
+      "t5_fsdp": true,
+      "ulysses_size": 8,
+      "offload_model": false,
+      "convert_model_dtype": true
+    }
+  }'
 ```
-> This command can run on a GPU with at least 80GB VRAM.
-
-> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
-
+Query and download (replace `TASK_ID` with the `task_id` from the response):
-- Multi-GPU inference using FSDP + DeepSpeed Ulysses
+```bash
+curl -sS -H "Authorization: Bearer sk-dev" \
+  "http://127.0.0.1:8008/api/v1/tasks/TASK_ID"
-```sh
-torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+curl -L -o out.mp4 -H "Authorization: Bearer sk-dev" \
+  "http://127.0.0.1:8008/api/v1/files/by-task/TASK_ID"
 ```
-- Image-to-Video Generation without prompt
+---
-```sh
-DASH_API_KEY=your_key torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --prompt '' --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --use_prompt_extend --prompt_extend_method 'dashscope'
-```
+## 5. Two machines, 2×4 A100 GPUs (recommended production layout)
-> 💡The model can generate videos solely from the input image. You can use prompt extension to generate a prompt from the image.
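+The following assumes **GPU node 0** (the primary node, which runs the Worker plus a local `torchrun`) and **GPU node 1** (the secondary node, where `torchrun` is only launched over SSH), each with **4×A100 40GB**, for **8 GPUs** in total, running tasks such as `t2v-A14B` / `i2v-A14B` that need `WORLD_SIZE=8` and `ulysses_size=8`. When `WAN_NNODES>1`, `serve/launcher.py` starts `torchrun` **locally on node 0** and starts **exactly the same** `torchrun` command on **node 1** over **SSH**; PyTorch's **c10d rendezvous** then connects the two nodes.
-> The process of prompt extension can be referenced [here](#2-using-prompt-extention).
+### 5.1 Topology and roles
-#### Run Text-Image-to-Video Generation
+| Role | Suggested placement | Notes |
+|------|----------------|------|
+| **Redis** | A small third machine, node 0, or a managed cloud service | Both the API and the Worker must be able to reach `WAN_REDIS_URL`. |
+| **HTTP API** | Any machine that can reach Redis (no GPU needed) | `run_api_server.py`, exposing `8008` to clients. |
+| **GPU Worker** | **A single process, on node 0 only** | `python -m serve.worker_main`; a global GPU lock is used by default, so do not start a Worker on each machine consuming the same queue. |
+| **Inference processes** | Node 0: local `torchrun`; node 1: `torchrun` started over SSH | Both machines use identical `torchrun` arguments; `--rdzv_endpoint` points to an **IP of node 0 that node 1 can reach**. |
-This repository supports the `Wan2.2-TI2V-5B` Text-Image-to-Video model and can support video generation at 720P resolution.
+### 5.2 Network and hostnames
+1. Assign fixed private IPs to both machines, for example node 0 → `10.0.0.10` and node 1 → `10.0.0.11`.
+2. `WAN_MASTER_ADDR` must be **an IP on node 0 that node 1 can reach** (usually `10.0.0.10`); do **not** use `127.0.0.1`.
+3. Open the firewall for **`WAN_MASTER_PORT` (e.g. 29500)** and the port range PyTorch/NCCL may use (or temporarily allow all TCP between the two machines while bringing the setup up).
+4. For cross-machine RDMA, configure IB according to your data center's rules; when using TCP only, you can first set `export NCCL_IB_DISABLE=1` to rule out IB interference (performance drops, so use it for troubleshooting only).
-- Single-GPU Text-to-Video inference
-```sh
-python generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --offload_model True --convert_model_dtype --t5_cpu --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage"
-```
+### 5.3 Shared storage (NFS or a parallel file system)
-> 💡Unlike other tasks, the 720P resolution of the Text-Image-to-Video task is `1280*704` or `704*1280`.
+Both machines use **the same mount point and the same absolute path** for the paths below (the examples use `/mnt/wan/...`; adjust for your environment):
-> This command can run on a GPU with at least 24GB VRAM (e.g., an RTX 4090 GPU).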
+| Path | Purpose |
+|------|------|
+| `WAN_REPO_ROOT` (e.g. `/mnt/wan/Wan2.2`) | This repository's code, identical on both machines. |
+| `WAN_CKPT_DIR` (e.g. `/mnt/wan/Wan2.2-T2V-A14B`) | Model weights, read-only; inside the Worker this is usually `/ckpt`, and the host mount must match `WAN_CKPT_DIR`. |
+| `WAN_JOB_DIR` | Task JSON, written by the Worker; `job_json` is an NFS path so that `torchrun` on both machines reads the same file. |
+| `WAN_OUTPUT_DIR` | Generated MP4; only rank 0 writes to disk, and keeping it on NFS lets the API machine or node 0 fetch the file. |
-> 💡If you are running on a GPU with at least 80GB VRAM, you can remove the `--offload_model True`, `--convert_model_dtype` and `--t5_cpu` options to speed up execution.
+After mounting, run `ls -la $WAN_REPO_ROOT/generate_job.py` and `ls $WAN_CKPT_DIR` on each machine to confirm that the paths match and are readable.
+### 5.4 Software environment (must match on both machines)
-
-- Single-GPU Image-to-Video inference
-```sh
-python generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --offload_model True --convert_model_dtype --t5_cpu --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
-```
+1. **Operating system and driver**: install the same major version of the **NVIDIA driver** on both machines; `nvidia-smi` must work.
+2. **Python**: preferably the **same version** (e.g. 3.10/3.11); separate `venv`s or clones of **one Conda env** both work, as long as the **`torch` version and CUDA build match**.
+3. **Dependencies**: on both machines, run `pip install -r requirements.txt` and `pip install -r requirements_serve.txt` under `WAN_REPO_ROOT` (if `flash_attn` fails to install, skip it for now, as in single-machine troubleshooting).
+4. **`torchrun` on PATH**: `which torchrun` returns a result on both machines.
-> 💡If the image parameter is configured, it is an Image-to-Video generation; otherwise, it defaults to a Text-to-Video generation.
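The path and `torchrun` checks described in 5.3 and 5.4 could be scripted and run once on each node. The following is a minimal sketch, not part of this PR: the function name `preflight` is hypothetical, and it only assumes the `WAN_REPO_ROOT` and `WAN_CKPT_DIR` variables documented above.

```python
import os
import shutil
from pathlib import Path


def preflight(env=None):
    """Return a list of problems found on this node; an empty list means it looks ready.

    Hypothetical helper: it mirrors the manual `ls` and `which torchrun` checks.
    """
    env = os.environ if env is None else env
    problems = []

    # Same check as `ls -la $WAN_REPO_ROOT/generate_job.py`.
    repo_root = env.get("WAN_REPO_ROOT")
    if not repo_root:
        problems.append("WAN_REPO_ROOT is not set")
    elif not (Path(repo_root) / "generate_job.py").is_file():
        problems.append(f"generate_job.py not found under {repo_root}")

    # Same check as `ls $WAN_CKPT_DIR`: the directory must exist and be readable.
    ckpt_dir = env.get("WAN_CKPT_DIR")
    if not ckpt_dir:
        problems.append("WAN_CKPT_DIR is not set")
    elif not Path(ckpt_dir).is_dir() or not os.access(ckpt_dir, os.R_OK):
        problems.append(f"checkpoint directory missing or unreadable: {ckpt_dir}")

    # Same check as `which torchrun`, using the PATH from the given environment.
    if shutil.which("torchrun", path=env.get("PATH")) is None:
        problems.append("torchrun is not on PATH")

    return problems
```

Calling `preflight()` with no arguments inspects the current process environment, so running it on node 0 and again on node 1 (for example over SSH) and comparing the two results is one way to catch a mismatched mount before the first job is submitted.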
+Node 1 does **not** run `serve.worker_main`, but it must be able to run, over SSH, the **same** `torchrun … generate_job.py` as node 0, so node 1 also needs a complete Python environment and the repository code (the same path as on node 0 is simplest).
-> 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
+### 5.5 Passwordless SSH: node 0 → node 1
+On **node 0** (as the Linux user that runs the Worker):
-- Multi-GPU inference using FSDP + DeepSpeed Ulysses
-
-```sh
-torchrun --nproc_per_node=8 generate.py --task ti2v-5B --size 1280*704 --ckpt_dir ./Wan2.2-TI2V-5B --dit_fsdp --t5_fsdp --ulysses_size 8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```bash
+ssh-keygen -t ed25519 -N "" -f ~/.ssh/id_ed25519_wan -C "wan-worker"
+# Append the public key to node 1's authorized_keys (replace user and 10.0.0.11 with real values)
+ssh-copy-id -i ~/.ssh/id_ed25519_wan.pub user@10.0.0.11
+# If using a custom key:
+ssh -i ~/.ssh/id_ed25519_wan user@10.0.0.11 'hostname'
 ```
-> The process of prompt extension can be referenced [here](#2-using-prompt-extention).
-
-#### Run Speech-to-Video Generation
+`WAN_SSH_SECOND_NODE` is best written as **`user@10.0.0.11`**, matching the `ssh` login string above.
+In production, consider configuring a known-hosts file (`KnownHostsFile`; the OpenSSH client option is `UserKnownHostsFile`) for `ssh` in `serve/launcher.py` and dropping `StrictHostKeyChecking=no`, to avoid man-in-the-middle risk.
-This repository supports the `Wan2.2-S2V-14B` Speech-to-Video model and can simultaneously support video generation at 480P and 720P resolutions.
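The SSH hardening suggested in 5.5 could look roughly like the sketch below. It assumes the launcher assembles its `ssh` argument list in Python; the helper name `build_ssh_command` and its parameters are hypothetical and are not taken from `serve/launcher.py`. The `-o` settings (`BatchMode`, `StrictHostKeyChecking`, `UserKnownHostsFile`) are standard OpenSSH client options.

```python
import shlex
from pathlib import Path


def build_ssh_command(target, remote_cmd, identity_file=None, known_hosts=None):
    """Return an argv list that runs remote_cmd on target over ssh.

    Hypothetical helper for illustration; not part of serve/launcher.py.
    """
    argv = [
        "ssh",
        "-o", "BatchMode=yes",              # never prompt; fail if the key is not accepted
        "-o", "StrictHostKeyChecking=yes",  # refuse hosts that are missing from known_hosts
    ]
    if known_hosts is not None:
        argv += ["-o", f"UserKnownHostsFile={Path(known_hosts).expanduser()}"]
    if identity_file is not None:
        argv += ["-i", str(Path(identity_file).expanduser())]
    # Run the command in a login shell; quote it so ssh passes it through intact.
    argv += [target, "bash -lc " + shlex.quote(remote_cmd)]
    return argv


if __name__ == "__main__":
    cmd = build_ssh_command(
        "user@10.0.0.11",
        "hostname",
        identity_file="~/.ssh/id_ed25519_wan",
        known_hosts="~/.ssh/known_hosts",
    )
    print(shlex.join(cmd))
```

With `StrictHostKeyChecking=yes`, node 1's host key has to be present in the known-hosts file before the first job runs, for example by logging in once interactively from node 0.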
+### 5.6 NCCL and common environment variables (inherited by the Worker processes on both machines; the same applies to SSH child processes)
-- Single-GPU Speech-to-Video inference
+In the shell or systemd unit that starts the Worker on **node 0**, you can export (adjust for your NIC name):
-```sh
-python generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --offload_model True --convert_model_dtype --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard." --image "examples/i2v_input.JPG" --audio "examples/talk.wav"
-# Without setting --num_clip, the generated video length will automatically adjust based on the input audio length
-
-# You can use CosyVoice to generate audio with --enable_tts
-python generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --offload_model True --convert_model_dtype --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard." --image "examples/i2v_input.JPG" --enable_tts --tts_prompt_audio "examples/zero_shot_prompt.wav" --tts_prompt_text "希望你以后能够做的比我还好呦。" --tts_text "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
+```bash
+export NCCL_SOCKET_IFNAME=eth0 # or ens, bond0, etc.: the NIC that connects the two machines
+export NCCL_DEBUG=WARN # switch to INFO when troubleshooting
+# export NCCL_IB_DISABLE=1 # enable when there is no IB or during bring-up
 ```
-> 💡 This command can run on a GPU with at least 80GB VRAM.
+### 5.7 Dual-machine environment variables (configured on node 0)
-- Multi-GPU inference using FSDP + DeepSpeed Ulysses
+Set the following variables on **node 0, where `python -m serve.worker_main` runs** (they can go in `/etc/default/wan-worker` or systemd `Environment=`):
-```sh
-torchrun --nproc_per_node=8 generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."
--image "examples/i2v_input.JPG" --audio "examples/talk.wav"
+```bash
+export WAN_NNODES=2
+export WAN_NPROC_PER_NODE=4
+export WAN_MASTER_ADDR=10.0.0.10 # node 0's private IP as seen by other machines
+export WAN_MASTER_PORT=29500
+export WAN_SSH_SECOND_NODE=user@10.0.0.11
+
+export WAN_REPO_ROOT=/mnt/wan/Wan2.2
+export PYTHONPATH=/mnt/wan/Wan2.2:$PYTHONPATH
+export WAN_CKPT_DIR=/mnt/wan/Wan2.2-T2V-A14B
+export WAN_JOB_DIR=/mnt/wan/jobs
+export WAN_OUTPUT_DIR=/mnt/wan/out
+
+export WAN_REDIS_URL=redis://10.0.0.5:6379/0
+export WAN_SERVE_API_KEYS=sk-your-secret
+
+# Optional: the default is fine; {repo_root} is replaced with WAN_REPO_ROOT
+# export WAN_SSH_TORCHRUN_PREFIX='cd {repo_root} && export PYTHONPATH={repo_root}:$PYTHONPATH && '
 ```
-- Pose + Audio driven generation
+Notes:
-```sh
-torchrun --nproc_per_node=8 generate.py --task s2v-14B --size 1024*704 --ckpt_dir ./Wan2.2-S2V-14B/ --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "a person is singing" --image "examples/pose.png" --audio "examples/sing.MP3" --pose_video "./examples/pose.mp4"
-```
+- When **`WAN_NNODES` × `WAN_NPROC_PER_NODE` = 8**, **`ulysses_size` must be 8** in the job JSON / API, and **`dit_fsdp` / `t5_fsdp`** should match the official multi-GPU examples.
+- **`WAN_MASTER_PORT`**: each job's rendezvous is distinguished by its `rdzv_id` (which includes the `task_id`); the port must be free.
+- **`WAN_SSH_TORCHRUN_PREFIX`**: the **`{repo_root}`** placeholder is replaced by the program with the absolute path of `WAN_REPO_ROOT` (see `serve/config.py`).
-> 💡For the Speech-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
+### 5.8 Startup order (recommended)
-> 💡The model can generate videos from audio input combined with reference image and optional text prompt.
+1. **Start Redis** (if it is not already running).
+2. **Start the API** (can be on a machine without a GPU):
+   after `export WAN_REDIS_URL=...` and the other queue- and path-related variables, run `python run_api_server.py`.
+3. **Start the Worker on node 0 only**:
+   ```bash
+   cd "$WAN_REPO_ROOT"
+   export PYTHONPATH="$WAN_REPO_ROOT:$PYTHONPATH"
+   python -m serve.worker_main
+   ```
+4.
Submit a job with **curl** (see §4 above) and watch the Worker log: the local `torchrun` should appear first, then about 2 seconds later the second `torchrun` started on node 1 over SSH, and finally rank 0 writes `save_file`.
-> 💡The `--pose_video` parameter enables pose-driven generation, allowing the model to follow specific pose sequences while generating videos synchronized with audio input.
+### 5.9 Behavior (consistent with the source code)
-> 💡The `--num_clip` parameter controls the number of video clips generated, useful for quick preview with shorter generation time.
+When `WAN_NNODES>1`, `serve/launcher.py`:
-Please visit our project page to see more examples and learn about the scenarios suitable for this model.
+1. Uses `subprocess.Popen` to run on **node 1**:
+   `ssh … user@node1 'bash -lc ""'`
+2. **About 2 seconds** later, runs the same `torchrun` command on **node 0** with `subprocess.run`.
+3. In both commands, **`--job_json` is the same file on NFS**, and **`--rdzv_id` is unique per job** (it includes the `task_id`), which avoids conflicts with leftover processes.
-#### Run Wan-Animate
+### 5.10 Troubleshooting checklist
-Wan-Animate takes a video and a character image as input, and generates a video in either "animation" or "replacement" mode.
+| Symptom | What to check |
+|------|--------|
+| SSH fails | Run `ssh user@node1` manually on node 0; check `ssh-agent`, private-key permissions, and `authorized_keys`. |
+| Rendezvous times out / hangs | Whether node 1 can reach `WAN_MASTER_ADDR` on the port via `telnet`/`nc -zv`; the firewall; whether the two machines' clocks are roughly in sync (NTP recommended). |
+| NCCL errors | `NCCL_SOCKET_IFNAME`; if needed, try a run with `NCCL_IB_DISABLE=1`. |
+| Module not found on node 1 | Whether `PYTHONPATH` and the `cd` on node 1 match `WAN_SSH_TORCHRUN_PREFIX`; `pip show torch`. |
+| Processes start on one machine only | Whether `WAN_SSH_SECOND_NODE` is empty; whether `WAN_NNODES` is still 1. |
+| OOM / GPU memory | Running A14B on 4×40GB requires FSDP + Ulysses with suitable `offload_model` / `convert_model_dtype` settings, consistent with the multi-GPU notes in the official README. |
-1. animation mode: The model generates a video of the character image that mimics the human motion in the input video.
-2. replacement mode: The model replaces the character image with the input video.
+### 5.11 Relationship to Docker
-Please visit our [project page](https://humanaigc.github.io/wan-animate) to see more examples and learn about the scenarios suitable for this model.
+`docker-compose.yml` describes **single-machine multi-GPU containers** by default. For two physical machines with SSH `torchrun`, the usual approach is: **do not run another Worker container on node 1 consuming the same Redis queue**; start the Worker on **node 0** only (bare metal or a single container), set `WAN_SSH_SECOND_NODE` to an **SSH-reachable address** of node 1, and mount **the same NFS** at the same path on both machines. If both machines run inside containers, you must also ensure **container-to-container/host SSH** and that **`WAN_MASTER_ADDR` inside the container is visible to the other machine** (commonly via host networking or explicit port mapping, depending on the orchestration).
-##### (1) Preprocessing
-The input video should be preprocessed into several materials before be feed into the inference process. Please refer to the following processing flow, and more details about preprocessing can be found in [UserGuider](https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/animate/preprocess/UserGuider.md).
+---
-* For animation
-```bash
-python ./wan/modules/animate/preprocess/preprocess_data.py \
-    --ckpt_path ./Wan2.2-Animate-14B/process_checkpoint \
-    --video_path ./examples/wan_animate/animate/video.mp4 \
-    --refer_path ./examples/wan_animate/animate/image.jpeg \
-    --save_path ./examples/wan_animate/animate/process_results \
-    --resolution_area 1280 720 \
-    --retarget_flag \
-    --use_flux
-```
-* For replacement
-```bash
-python ./wan/modules/animate/preprocess/preprocess_data.py \
-    --ckpt_path ./Wan2.2-Animate-14B/process_checkpoint \
-    --video_path ./examples/wan_animate/replace/video.mp4 \
-    --refer_path ./examples/wan_animate/replace/image.jpeg \
-    --save_path ./examples/wan_animate/replace/process_results \
-    --resolution_area 1280 720 \
-    --iterations 3 \
-    --k 7 \
-    --w_len 1 \
-    --h_len 1 \
-    --replace_flag
-```
-##### (2) Run in animation mode
+## 6.
Job JSON and `model` aliases
-* Single-GPU inference
+The HTTP request body is merged into a dict that `generate.args_from_job_dict` accepts:
-```bash
-python generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/animate/process_results/ --refert_num 1
-```
+- The top-level **`model`** can be `wan2.2-t2v-a14b`, `wan2.2-i2v-a14b`, `wan2.2-ti2v-5b`, `wan2.2-s2v-14b`, `wan2.2-animate-14b`, or a `task` string from `WAN_CONFIGS` directly.
+- The remaining fields match the `generate.py` command line; they may also be nested under `input` / `parameters`.
+- `sample_guide_scale` can be **a single float** or **an array of two floats** (low-/high-noise experts).
-* Multi-GPU inference using FSDP + DeepSpeed Ulysses
+Example of calling `generate_job.py` directly (without HTTP):
 ```bash
-python -m torch.distributed.run --nnodes 1 --nproc_per_node 8 generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/animate/process_results/ --refert_num 1 --dit_fsdp --t5_fsdp --ulysses_size 8
-```
+cat > /mnt/wan/jobs/manual.json <<'EOF'
+{
+  "model": "wan2.2-t2v-a14b",
+  "ckpt_dir": "/mnt/wan/Wan2.2-T2V-A14B",
+  "save_file": "/mnt/wan/out/manual.mp4",
+  "prompt": "Two cats boxing on stage.",
+  "size": "1280*720",
+  "dit_fsdp": true,
+  "t5_fsdp": true,
+  "ulysses_size": 8,
+  "convert_model_dtype": true,
+  "offload_model": false
+}
+EOF
-* Diffusers Pipeline
-
-```python
-from diffusers import WanAnimatePipeline
-from diffusers.utils import export_to_video, load_image, load_video
-
-device = "cuda:0"
-dtype = torch.bfloat16
-model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
-pipe = WanAnimatePipeline.from_pretrained(model_id torch_dtype=dtype)
-pipe.to(device)
-
-seed = 42
-prompt = "People in the video are doing actions."
-
-# Animation
-image = load_image("/path/to/animate/reference/image/src_ref.png")
-pose_video = load_video("/path/to/animate/pose/video/src_pose.mp4")
-face_video = load_video("/path/to/animate/face/video/src_face.mp4")
-
-animate_video = pipe(
-    image=image,
-    pose_video=pose_video,
-    face_video=face_video,
-    prompt=prompt,
-    mode="animate",
-    segment_frame_length=77, # clip_len in original code
-    prev_segment_conditioning_frames=1, # refert_num in original code
-    guidance_scale=1.0,
-    num_inference_steps=20,
-    generator=torch.Generator(device=device).manual_seed(seed),
-).frames[0]
-export_to_video(animate_video, "diffusers_animate.mp4", fps=30)
+torchrun --nnodes=1 --nproc_per_node=8 --rdzv_backend=c10d \
+  --rdzv_endpoint=127.0.0.1:29501 --rdzv_id=manual1 \
+  /mnt/wan/Wan2.2/generate_job.py --job_json /mnt/wan/jobs/manual.json
 ```
-##### (3) Run in replacement mode
+---
-* Single-GPU inference
+## 7. systemd example (API)
-```bash
-python generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/replace/process_results/ --refert_num 1 --replace_flag --use_relighting_lora
-```
+`/etc/systemd/system/wan-api.service`:
-* Multi-GPU inference using FSDP + DeepSpeed Ulysses
+```ini
+[Unit]
+Description=Wan2.2 HTTP API
+After=network.target
-```bash
-python -m torch.distributed.run --nnodes 1 --nproc_per_node 8 generate.py --task animate-14B --ckpt_dir ./Wan2.2-Animate-14B/ --src_root_path ./examples/wan_animate/replace/process_results/src_pose.mp4 --refert_num 1 --replace_flag --use_relighting_lora --dit_fsdp --t5_fsdp --ulysses_size 8
-```
+[Service]
+User=wan
+WorkingDirectory=/mnt/wan/Wan2.2
+Environment=PYTHONPATH=/mnt/wan/Wan2.2
+Environment=WAN_SERVE_API_KEYS=sk-prod-xxx
+Environment=WAN_REDIS_URL=redis://127.0.0.1:6379/0
+Environment=WAN_CKPT_DIR=/mnt/wan/Wan2.2-T2V-A14B
+Environment=WAN_JOB_DIR=/mnt/wan/jobs
+Environment=WAN_OUTPUT_DIR=/mnt/wan/out
+Environment=WAN_REPO_ROOT=/mnt/wan/Wan2.2
+ExecStart=/mnt/wan/venv/bin/python /mnt/wan/Wan2.2/run_api_server.py
+Restart=on-failure
-* Diffusers Pipeline
-
-```python
-# create pipeline as in the Animation code ☝️
-
-# Replacement
-image = load_image("/path/to/replace/reference/image/src_ref.png")
-pose_video = load_video("/path/to/replace/pose/video/src_pose.mp4")
-face_video = load_video("/path/to/replace/face/video/src_face.mp4")
-background_video = load_video("/path/to/replace/background/video/src_bg.mp4")
-mask_video = load_video("/path/to/replace/mask/video/src_mask.mp4")
-
-replace_video = pipe(
-    image=image,
-    pose_video=pose_video,
-    face_video=face_video,
-    background_video=background_video,
-    mask_video=mask_video,
-    prompt=prompt,
-    mode="replace",
-    segment_frame_length=77, # clip_len in original code
-    prev_segment_conditioning_frames=1, # refert_num in original code
-    guidance_scale=1.0,
-    num_inference_steps=20,
-    generator=torch.Generator(device=device).manual_seed(seed),
-).frames[0]
-export_to_video(replace_video, "diffusers_replace.mp4", fps=30)
+[Install]
+WantedBy=multi-user.target
 ```
-> 💡 If you're using **Wan-Animate**, we do not recommend using LoRA models trained on `Wan2.2`, since weight changes during training may lead to unexpected behavior.
-
-## Computational Efficiency on Different GPUs
+The Worker is similar: change `ExecStart` to `python -m serve.worker_main` and run it on GPU node 0.
-We test the computational efficiency of different **Wan2.2** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.
+---
+## 8. Security and operations recommendations
-
- -
+- Expose the API on the internal network only, or put an mTLS / zero-trust gateway in front of it.
+- Rotate `WAN_SERVE_API_KEYS` regularly.
+- Apply disk quotas and cleanup jobs to the model and output paths.
+- Monitor the Redis queue length, worker logs, and `torchrun` exit codes.
-> The parameter settings for the tests presented in this table are as follows:
-> (1) Multi-GPU: 14B: `--ulysses_size 4/8 --dit_fsdp --t5_fsdp`, 5B: `--ulysses_size 4/8 --offload_model True --convert_model_dtype --t5_cpu`; Single-GPU: 14B: `--offload_model True --convert_model_dtype`, 5B: `--offload_model True --convert_model_dtype --t5_cpu`
-(--convert_model_dtype converts model parameter types to config.param_dtype);
-> (2) The distributed testing utilizes the built-in FSDP and Ulysses implementations, with FlashAttention3 deployed on Hopper architecture GPUs;
-> (3) Tests were run without the `--use_prompt_extend` flag;
-> (4) Reported results are the average of multiple samples taken after the warm-up phase.
+---
+## 9. Containerized deployment (Docker Compose)
--------
+The repository provides a **CPU API image** and a **GPU Worker image**; `docker-compose.yml` orchestrates Redis, the API, and the Worker.
-## Introduction of Wan2.2
+### 9.1 Prerequisites
-**Wan2.2** builds on the foundation of Wan2.1 with notable improvements in generation quality and model capability. This upgrade is driven by a series of key technical innovations, mainly including the Mixture-of-Experts (MoE) architecture, upgraded training data, and high-compression video generation.
+- [Docker](https://docs.docker.com/engine/install/) and [Docker Compose V2](https://docs.docker.com/compose/) are installed.
+- The **host running the Worker** has the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed, which can be verified with `docker run --rm --gpus all nvidia/cuda:12.4.1-base-ubuntu22.04 nvidia-smi`.
+- Download the official weights to a host directory, for example `/data/Wan2.2-T2V-A14B`, to be mounted **read-only** at `/ckpt` in the Worker container.
-##### (1) Mixture-of-Experts (MoE) Architecture
+### 9.2 Configuration and startup
-Wan2.2 introduces Mixture-of-Experts (MoE) architecture into the video generation diffusion model.
MoE has been widely validated in large language models as an efficient approach to increase total model parameters while keeping inference cost nearly unchanged. In Wan2.2, the A14B model series adopts a two-expert design tailored to the denoising process of diffusion models: a high-noise expert for the early stages, focusing on overall layout; and a low-noise expert for the later stages, refining video details. Each expert model has about 14B parameters, resulting in a total of 27B parameters but only 14B active parameters per step, keeping inference computation and GPU memory nearly unchanged.
+```bash
+cd /path/to/Wan2.2
+cp docker/compose.env.example .env
+# Edit .env: set at least WAN_SERVE_API_KEYS and WAN_CKPT_HOST_PATH
+```
-
- -
+Start only **Redis + API** (when the development machine has no GPU):
-The transition point between the two experts is determined by the signal-to-noise ratio (SNR), a metric that decreases monotonically as the denoising step $t$ increases. At the beginning of the denoising process, $t$ is large and the noise level is high, so the SNR is at its minimum, denoted as ${SNR}_{min}$. In this stage, the high-noise expert is activated. We define a threshold step ${t}_{moe}$ corresponding to half of the ${SNR}_{min}$, and switch to the low-noise expert when $t<{t}_{moe}$.
+```bash
+docker compose up -d --build redis api
+```
-
- -
+On **a machine with an NVIDIA GPU**, start the full stack (including the Worker, using the Compose `gpu` profile):
-To validate the effectiveness of the MoE architecture, four settings are compared based on their validation loss curves. The baseline **Wan2.1** model does not employ the MoE architecture. Among the MoE-based variants, the **Wan2.1 & High-Noise Expert** reuses the Wan2.1 model as the low-noise expert while uses the Wan2.2's high-noise expert, while the **Wan2.1 & Low-Noise Expert** uses Wan2.1 as the high-noise expert and employ the Wan2.2's low-noise expert. The **Wan2.2 (MoE)** (our final version) achieves the lowest validation loss, indicating that its generated video distribution is closest to ground-truth and exhibits superior convergence.
+```bash
+docker compose --profile gpu up -d --build
+```
+Common commands:
-##### (2) Efficient High-Definition Hybrid TI2V
-To enable more efficient deployment, Wan2.2 also explores a high-compression design. In addition to the 27B MoE models, a 5B dense model, i.e., TI2V-5B, is released. It is supported by a high-compression Wan2.2-VAE, which achieves a $T\times H\times W$ compression ratio of $4\times16\times16$, increasing the overall compression rate to 64 while maintaining high-quality video reconstruction. With an additional patchification layer, the total compression ratio of TI2V-5B reaches $4\times32\times32$. Without specific optimization, TI2V-5B can generate a 5-second 720P video in under 9 minutes on a single consumer-grade GPU, ranking among the fastest 720P@24fps video generation models. This model also natively supports both text-to-video and image-to-video tasks within a single unified framework, covering both academic research and practical applications.
+```bash
+docker compose logs -f api worker
+docker compose ps
+```
+The API is mapped to `WAN_API_PORT` on the host by default (default `8008`). Health check: `GET http://:8008/healthz`.
-
- -
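After `docker compose up`, the API container may take a few seconds to start answering, so a small readiness poll can be handy in scripts. The sketch below is one way to do it with the standard library only, in the same spirit as the `urllib`-based `HEALTHCHECK` in `docker/Dockerfile.api`. The function name `wait_for_healthz`, its parameters, and the `opener` hook (which exists only to make the function testable without a network) are assumptions for illustration, not part of the repository.

```python
import time
import urllib.error
import urllib.request


def wait_for_healthz(url, timeout_s=60.0, interval_s=2.0, opener=urllib.request.urlopen):
    """Poll url until it answers HTTP 200 or timeout_s elapses.

    Returns True on success and False once the deadline has passed.
    Hypothetical helper for illustration; not part of the repository.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with opener(url, timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # API not reachable yet; retry until the deadline
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

A typical call would be `wait_for_healthz("http://127.0.0.1:8008/healthz", timeout_s=120)`, replacing the host and port with the machine that publishes `WAN_API_PORT`.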
+### 9.3 Volumes
+| Volume | Mount point | Notes |
+|------|--------|------|
+| `wan_shared` | `/data` in the container | `jobs` → `/data/jobs`, `outputs` → `/data/outputs`; shared by the API and the Worker for job JSON and generated videos. |
+| Bind mount | `/ckpt` | From `WAN_CKPT_HOST_PATH` in `.env`, mounted read-only into the Worker. |
+### 9.4 Image build arguments (Worker)
-##### Comparisons to SOTAs
-We compared Wan2.2 with leading closed-source commercial models on our new Wan-Bench 2.0, evaluating performance across multiple crucial dimensions. The results demonstrate that Wan2.2 achieves superior performance compared to these leading models.
+| Build argument | Default | Notes |
+|----------|------|------|
+| `BASE_IMAGE` | `pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime` | Can be replaced with another official PyTorch tag to match your site's CUDA version. |
+| `INSTALL_FLASH_ATTN` | `0` | When set to `1`, the build tries to install `flash_attn` (it must match the base image's CUDA; the build may still continue if this fails). |
+Example:
-
- -
+```bash
+docker build -f docker/Dockerfile.worker \
+  --build-arg INSTALL_FLASH_ATTN=1 \
+  -t wan2-worker:latest .
+```
-## Citation
-If you find our work helpful, please cite us.
+### 9.5 Dual-machine GPU and Compose
-```
-@article{wan2025,
- title={Wan: Open and Advanced Large-Scale Video Generative Models},
- author={Team Wan and Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
- journal = {arXiv preprint arXiv:2503.20314},
- year={2025}
-}
-```
+`docker-compose.yml` describes **multi-GPU containers on a single machine**. To run 4 GPUs on each of **two physical machines** while keeping the existing SSH dual-machine `torchrun` in `serve.launcher`:
-## License Agreement
-The models in this repository are licensed under the Apache 2.0 License. We claim no rights over the your generated contents, granting you the freedom to use them while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations.
For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
+1. Install Docker + NVIDIA Toolkit on both machines, and mount **the same NFS** at the same path (covering code, weights, and `WAN_JOB_DIR` / `WAN_OUTPUT_DIR`).
+2. On **node 0** you can still use Compose to start Redis (or use an external managed Redis) plus the API and Worker containers; on **node 1**, start only a **Worker container** (or skip Compose and use `docker run` directly). The two Workers must not consume the same queue at the same time, because the current design assumes **a single consuming worker**. For dual-machine multi-GPU, the recommendation is to **start one Worker container on node 0 only**, set `WAN_NNODES=2`, `WAN_NPROC_PER_NODE=4`, `WAN_MASTER_ADDR`, and `WAN_SSH_SECOND_NODE` in `.env`, and let `torchrun` + SSH inside the container start the processes on the second machine (the node 0 container must be able to SSH to node 1, and node 1 must have the same image installed or an equivalent Python/torch environment).
+A more robust production setup is to host **Redis + API** in the control plane, run only `wan2-worker` on **each GPU machine** with `docker run` or a Kubernetes Job, and rework queue partitioning; that is beyond the scope of this document and can be extended separately.
-## Acknowledgements
+### 9.6 Related files
-We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
+| Path | Description |
+|------|------|
+| `docker-compose.yml` | Redis, api, and worker service definitions |
+| `docker/Dockerfile.api` | Lightweight API image with FastAPI dependencies only |
+| `docker/Dockerfile.worker` | CUDA + Wan inference + `serve.worker` |
+| `docker/entrypoint-worker.sh` | Worker entrypoint |
+| `docker/compose.env.example` | Template to copy to `.env` in the repository root |
+| `.dockerignore` | Shrinks the build context |
+| `README.md` | Main document for deployment and the HTTP service (this file) |
+| `DEPLOY_SERVE.md` | Kept for history/external links: only points to `README.md` |
+---
+## 10. Summary of code changes
-## Contact Us
-If you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!
+| Path | Description |
+|------|------|
+| `README.md` | Main document for deployment, the DashScope-style API, and Docker |
+| `generate.py` | `_build_parser` / `parse_args` / `args_from_job_dict` / `JOB_MODEL_ALIASES` |
+| `generate_job.py` | `torchrun` entry point; reads `--job_json` |
+| `serve/` | FastAPI, Redis, launcher, worker |
+| `run_api_server.py` | Starts uvicorn for development |
+| `requirements_serve.txt` | Extra dependencies for the API |
+| `docker-compose.yml` / `docker/*` | Container orchestration and images |
+For **HTTPS, rate limiting, multiple queues, or callback webhooks**, wrap `serve/api.py` in an additional gateway or extend this module.
diff --git a/docker-compose.worker.yml b/docker-compose.worker.yml
new file mode 100644
index 00000000..c76a6b50
--- /dev/null
+++ b/docker-compose.worker.yml
@@ -0,0 +1,59 @@
+# Wan2.2 Worker Node: Worker1 (4×A100 40GB) — secondary node.
+#
+# This node receives torchrun signals via Redis pub/sub from the master.
+# Both nodes run torchrun locally and rendezvous via NCCL/TCP.
+#
+# IMPORTANT: worker1 uses network_mode: host so torchrun/NCCL can bind
+# ports directly on the host. Redis is accessed via the master node's IP.
+#
+# Prerequisites: NVIDIA Container Toolkit; Docker Compose v2.
+#
+# Usage:
+# cp docker/compose.env.example .env
+# # edit .env — set WAN_SERVE_API_KEYS, WAN_REDIS_URL (master Redis),
+# # WAN_CKPT_HOST_PATH, WAN_MASTER_ADDR (master node IP)
+#
+# docker compose -f docker-compose.worker.yml up -d --build
+#
+# Logs: docker compose -f docker-compose.worker.yml logs -f worker1
+
+services:
+  worker1:
+    build:
+      context: .
+      dockerfile: docker/Dockerfile.worker
+    restart: unless-stopped
+    shm_size: "16g"
+    network_mode: host
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+    environment:
+      # Redis on master node — use the master's real IP (not container hostname)
+      WAN_REDIS_URL: ${WAN_REDIS_URL:?set WAN_REDIS_URL to master Redis, e.g.
redis://MASTER_IP:6379/0}
+      WAN_SERVE_API_KEYS: ${WAN_SERVE_API_KEYS}
+      WAN_REPO_ROOT: /app
+      WAN_CKPT_DIR: /ckpt
+      WAN_JOB_DIR: /data/jobs
+      WAN_OUTPUT_DIR: /data/outputs
+      WAN_NODE_ROLE: worker
+      WAN_NNODES: ${WAN_NNODES:-2}
+      WAN_NPROC_PER_NODE: ${WAN_NPROC_PER_NODE:-4}
+      WAN_NODE_RANK: 1
+      # Master node's real IP — both nodes rendezvous here
+      WAN_MASTER_ADDR: ${WAN_MASTER_ADDR:?set WAN_MASTER_ADDR to master node IP}
+      WAN_MASTER_PORT: ${WAN_MASTER_PORT:-29500}
+      WAN_RDZV_PREFIX: ${WAN_RDZV_PREFIX:-wan}
+      NVIDIA_VISIBLE_DEVICES: all
+      NVIDIA_DRIVER_CAPABILITIES: compute,utility
+      NCCL_DEBUG: INFO
+      NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-^docker0,lo}
+      NCCL_IB_DISABLE: "1"
+      NCCL_TIMEOUT: ${NCCL_TIMEOUT:-1800}
+    volumes:
+      - ${WAN_DATA_HOST_PATH:-./data}:/data
+      - ${WAN_CKPT_HOST_PATH:-./.wan_ckpt_placeholder}:/ckpt:ro
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 00000000..466b6dc0
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,96 @@
+# Wan2.2 Master Node: Redis + API + Worker0 (4×A100 40GB).
+#
+# Each node runs its own torchrun locally; NCCL rendezvous connects them.
+# No SSH needed — master signals worker1 via Redis pub/sub.
+#
+# IMPORTANT: worker0 uses network_mode: host so torchrun/NCCL can bind
+# ports directly on the host. Redis and API remain on the default bridge
+# network; worker0 accesses Redis via host.docker.internal or host IP.
+#
+# Prerequisites: NVIDIA Container Toolkit; Docker Compose v2.
+#
+# Usage:
+# cp docker/compose.env.example .env
+# # edit .env (WAN_SERVE_API_KEYS, WAN_CKPT_HOST_PATH, WAN_MASTER_ADDR)
+#
+# docker compose up -d --build
+#
+# On the secondary server, use docker-compose.worker.yml:
+# docker compose -f docker-compose.worker.yml up -d --build
+#
+# Logs: docker compose logs -f api
+# docker compose logs -f worker0
+
+services:
+  redis:
+    image: redis:7-alpine
+    restart: unless-stopped
+    command: ["redis-server", "--appendonly", "yes"]
+    volumes:
+      - wan_redis:/data
+    ports:
+      - "6379:6379"
+
+  api:
+    build:
+      context: .
+      dockerfile: docker/Dockerfile.api
+    restart: unless-stopped
+    ports:
+      - "${WAN_API_PORT:-8008}:8008"
+    environment:
+      WAN_REDIS_URL: redis://redis:6379/0
+      WAN_SERVE_API_KEYS: ${WAN_SERVE_API_KEYS:?set WAN_SERVE_API_KEYS in .env}
+      WAN_REPO_ROOT: /app
+      WAN_CKPT_DIR: /ckpt
+      WAN_JOB_DIR: /data/jobs
+      WAN_OUTPUT_DIR: /data/outputs
+      WAN_API_HOST: 0.0.0.0
+      WAN_API_PORT: "8008"
+    volumes:
+      - ${WAN_DATA_HOST_PATH:-./data}:/data
+    depends_on:
+      - redis
+
+  worker0:
+    build:
+      context: .
+      dockerfile: docker/Dockerfile.worker
+    restart: unless-stopped
+    shm_size: "16g"
+    network_mode: host
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+    environment:
+      # Redis is on host; use localhost or the host's real IP
+      WAN_REDIS_URL: ${WAN_REDIS_URL_LOCAL:-redis://127.0.0.1:6379/0}
+      WAN_SERVE_API_KEYS: ${WAN_SERVE_API_KEYS}
+      WAN_REPO_ROOT: /app
+      WAN_CKPT_DIR: /ckpt
+      WAN_JOB_DIR: /data/jobs
+      WAN_OUTPUT_DIR: /data/outputs
+      WAN_NODE_ROLE: master
+      WAN_NNODES: ${WAN_NNODES:-2}
+      WAN_NPROC_PER_NODE: ${WAN_NPROC_PER_NODE:-4}
+      WAN_NODE_RANK: 0
+      # With host network, use the host's real IP (not 0.0.0.0)
+      WAN_MASTER_ADDR: ${WAN_MASTER_ADDR:-127.0.0.1}
+      WAN_MASTER_PORT: ${WAN_MASTER_PORT:-29500}
+      WAN_RDZV_PREFIX: ${WAN_RDZV_PREFIX:-wan}
+      NVIDIA_VISIBLE_DEVICES: all
+      NVIDIA_DRIVER_CAPABILITIES: compute,utility
+      NCCL_DEBUG: INFO
+      NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-^docker0,lo}
+      NCCL_IB_DISABLE: "1"
+      NCCL_TIMEOUT: ${NCCL_TIMEOUT:-1800}
+    volumes:
+      - ${WAN_DATA_HOST_PATH:-./data}:/data
+      - ${WAN_CKPT_HOST_PATH:-./.wan_ckpt_placeholder}:/ckpt:ro
+
+volumes:
+  wan_redis:
\ No newline at end of file
diff --git a/docker/Dockerfile.api b/docker/Dockerfile.api
new file mode 100644
index 00000000..77d36a71
--- /dev/null
+++ b/docker/Dockerfile.api
@@ -0,0 +1,21 @@
+# CPU-only API (FastAPI + Redis client). Build from repository root:
+# docker build -f docker/Dockerfile.api -t wan2-api:latest .
+FROM python:3.11-slim-bookworm
+
+WORKDIR /app
+
+COPY requirements_serve.txt /app/requirements_serve.txt
+RUN pip install --no-cache-dir --upgrade pip \
+    && pip install --no-cache-dir -r /app/requirements_serve.txt
+
+COPY serve /app/serve
+COPY wan/configs /app/wan/configs
+COPY run_api_server.py /app/run_api_server.py
+
+ENV PYTHONPATH=/app
+EXPOSE 8008
+
+HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8008/healthz')" || exit 1
+
+CMD ["python", "/app/run_api_server.py"]
diff --git a/docker/Dockerfile.worker b/docker/Dockerfile.worker
new file mode 100644
index 00000000..509b052b
--- /dev/null
+++ b/docker/Dockerfile.worker
@@ -0,0 +1,34 @@
+# GPU worker: Wan2.2 + queue consumer. Build from repository root:
+# docker build -f docker/Dockerfile.worker -t wan2-worker:latest .
+#
+# Optional: --build-arg INSTALL_FLASH_ATTN=1 (may fail if CUDA/toolchain mismatch)
+ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime
+FROM ${BASE_IMAGE}
+
+ARG INSTALL_FLASH_ATTN=0
+ENV PIP_NO_CACHE_DIR=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1 \
+    PYTHONPATH=/app
+
+WORKDIR /app
+
+USER root
+
+COPY requirements.txt /tmp/requirements.txt
+RUN sed '/^flash_attn$/d' /tmp/requirements.txt > /tmp/requirements.noflash.txt \
+    && pip install --upgrade pip \
+    && pip install -r /tmp/requirements.noflash.txt
+
+RUN if [ "$INSTALL_FLASH_ATTN" = "1" ]; then \
+        pip install flash_attn --no-build-isolation || true; \
+    fi
+
+COPY requirements_serve.txt /tmp/requirements_serve.txt
+RUN pip install -r /tmp/requirements_serve.txt
+
+COPY . .
+
+RUN chmod +x /app/docker/entrypoint-worker.sh
+
+ENTRYPOINT ["/app/docker/entrypoint-worker.sh"]
\ No newline at end of file
diff --git a/docker/compose.env.example b/docker/compose.env.example
new file mode 100644
index 00000000..94717035
--- /dev/null
+++ b/docker/compose.env.example
@@ -0,0 +1,41 @@
+# Copy to .env in repo root (same directory as docker-compose.yml) and adjust.
+#
+# ===== BOTH NODES (must match) =====
+
+# Required: at least one API key (comma-separated for multiple)
+WAN_SERVE_API_KEYS=sk-change-me
+
+# Model weights parent directory (contains subdirs for each model)
+# Expected subdirectory layout:
+# WAN_CKPT_HOST_PATH/Wan2.2-T2V-A14B/
+# WAN_CKPT_HOST_PATH/Wan2.2-I2V-A14B/
+# WAN_CKPT_HOST_PATH/Wan2.2-TI2V-5B/
+# WAN_CKPT_HOST_PATH/Wan2.2-Animate-14B/
+# WAN_CKPT_HOST_PATH/Wan2.2-S2V-14B/
+# Only the models you need must be present; others can be omitted.
+WAN_CKPT_HOST_PATH=/data/models
+
+# torchrun topology: 2 nodes × 4 GPUs per node = 8 GPUs total
+WAN_NNODES=2
+WAN_NPROC_PER_NODE=4
+
+# Master node's REAL IP address (not 0.0.0.0 or 127.0.0.1)
+# Both nodes must agree on this for NCCL rendezvous.
+# worker0 uses network_mode: host, so it binds directly on this IP.
+# Example: 10.0.0.1 or 173.2.9.5
+WAN_MASTER_ADDR=10.0.0.1
+WAN_MASTER_PORT=29500
+
+# Published API port on host (master node only)
+WAN_API_PORT=8008
+
+# ===== MASTER NODE ONLY =====
+
+# Redis is local on master.
+# worker0 uses network_mode: host and connects to Redis via localhost: +WAN_REDIS_URL_LOCAL=redis://127.0.0.1:6379/0 + +# ===== WORKER NODE ONLY ===== + +# Redis URL pointing to master node (replace MASTER_IP with actual IP) +WAN_REDIS_URL=redis://10.0.0.1:6379/0 \ No newline at end of file diff --git a/docker/entrypoint-worker.sh b/docker/entrypoint-worker.sh new file mode 100644 index 00000000..2d59f4ca --- /dev/null +++ b/docker/entrypoint-worker.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +exec python -m serve.worker_main \ No newline at end of file diff --git a/generate.py b/generate.py index 3a5cbcdd..8580a02d 100644 --- a/generate.py +++ b/generate.py @@ -102,7 +102,7 @@ def _validate_args(args): task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" -def _parse_args(): +def _build_parser(): parser = argparse.ArgumentParser( description="Generate a image or video from a text prompt or image using Wan" ) @@ -294,9 +294,94 @@ def _parse_args(): default=80, help="Number of frames per clip, 48 or 80 or others (must be multiple of 4) for 14B s2v" ) - args = parser.parse_args() + return parser + + +def parse_args(argv=None): + parser = _build_parser() + args = parser.parse_args(argv) _validate_args(args) + return args + + +# Aliases for HTTP / job JSON (model id -> --task value) +JOB_MODEL_ALIASES = { + "wan2.2-t2v-a14b": "t2v-A14B", + "wan2.2-i2v-a14b": "i2v-A14B", + "wan2.2-ti2v-5b": "ti2v-5B", + "wan2.2-s2v-14b": "s2v-14B", + "wan2.2-animate-14b": "animate-14B", +} + +def _resolve_task_name(model_or_task): + """Map API model id or task string to a WAN_CONFIGS key.""" + raw = str(model_or_task).strip() + low = raw.lower() + if low in JOB_MODEL_ALIASES: + return JOB_MODEL_ALIASES[low] + for k in WAN_CONFIGS: + if k.lower() == low: + return k + return raw + + +def _flatten_job_dict(job): + """Merge nested API shape (model, input, parameters) into flat CLI keys.""" + if not 
isinstance(job, dict): + raise TypeError("job must be a dict") + out = {} + model = job.get("model") + if model is not None: + out["task"] = _resolve_task_name(model) + for section in ("input", "parameters"): + sub = job.get(section) + if isinstance(sub, dict): + out.update(sub) + for k, v in job.items(): + if k in ("input", "parameters", "model"): + continue + out[k] = v + if "task" in out: + out["task"] = _resolve_task_name(out["task"]) + return out + + +def _apply_job_value_to_args(name, value, args): + if value is None: + return + if name == "sample_guide_scale": + if isinstance(value, (list, tuple)): + if len(value) == 1: + setattr(args, name, float(value[0])) + else: + setattr(args, name, tuple(float(x) for x in value)) + else: + setattr(args, name, float(value)) + return + if name == "offload_model": + if isinstance(value, bool): + setattr(args, name, value) + else: + setattr(args, name, str2bool(str(value))) + return + setattr(args, name, value) + + +def args_from_job_dict(job): + """ + Build a validated argparse.Namespace from a JSON-serializable job dict. + Keys match generate.py CLI flags; optional nested ``input`` / ``parameters`` + and ``model`` alias are supported for DashScope-style payloads. + """ + parser = _build_parser() + args = parser.parse_args([]) + flat = _flatten_job_dict(job) + for key, value in flat.items(): + if not hasattr(args, key): + raise ValueError(f"Unknown job field: {key}") + _apply_job_value_to_args(key, value, args) + _validate_args(args) return args @@ -571,5 +656,5 @@ def generate(args): if __name__ == "__main__": - args = _parse_args() + args = parse_args() generate(args) diff --git a/generate_job.py b/generate_job.py new file mode 100644 index 00000000..2fbe5a0c --- /dev/null +++ b/generate_job.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+"""Read a JSON job file and run ``generate.generate`` (for torchrun / serving workers).""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parent +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import generate as wan_generate # noqa: E402 + + +def main(): + parser = argparse.ArgumentParser( + description="Wan2.2 JSON job runner (use with torchrun).") + parser.add_argument( + "--job_json", + type=str, + required=True, + help="Path to JSON job spec (see README.md).", + ) + args_ns = parser.parse_args() + path = Path(args_ns.job_json) + if not path.is_file(): + raise FileNotFoundError(f"job_json not found: {path}") + job = json.loads(path.read_text(encoding="utf-8")) + gen_args = wan_generate.args_from_job_dict(job) + wan_generate.generate(gen_args) + + +if __name__ == "__main__": + main() diff --git a/jobs/wan-00059d66441845fda2adcd0b4935ca82.json b/jobs/wan-00059d66441845fda2adcd0b4935ca82.json new file mode 100644 index 00000000..1ec7854e --- /dev/null +++ b/jobs/wan-00059d66441845fda2adcd0b4935ca82.json @@ -0,0 +1,13 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": true, + "ulysses_size": 8, + "t5_fsdp": true, + "t5_cpu": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-00059d66441845fda2adcd0b4935ca82.mp4" +} \ No newline at end of file diff --git a/jobs/wan-00fe7e82322e489a8c873255d3306461.json b/jobs/wan-00fe7e82322e489a8c873255d3306461.json new file mode 100644 index 00000000..79c58726 --- /dev/null +++ b/jobs/wan-00fe7e82322e489a8c873255d3306461.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + 
"dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-00fe7e82322e489a8c873255d3306461.mp4" +} \ No newline at end of file diff --git a/jobs/wan-238941aaff8643b99ae01955f9188711.json b/jobs/wan-238941aaff8643b99ae01955f9188711.json new file mode 100644 index 00000000..1ff4b502 --- /dev/null +++ b/jobs/wan-238941aaff8643b99ae01955f9188711.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-238941aaff8643b99ae01955f9188711.mp4" +} \ No newline at end of file diff --git a/jobs/wan-24f28daa33ba4570a9ad60b4dccfb673.json b/jobs/wan-24f28daa33ba4570a9ad60b4dccfb673.json new file mode 100644 index 00000000..37ef0731 --- /dev/null +++ b/jobs/wan-24f28daa33ba4570a9ad60b4dccfb673.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-24f28daa33ba4570a9ad60b4dccfb673.mp4" +} \ No newline at end of file diff --git a/jobs/wan-2555b27229bf484bad9cacb01884415a.json b/jobs/wan-2555b27229bf484bad9cacb01884415a.json new file mode 100644 index 00000000..46c83d09 --- /dev/null +++ b/jobs/wan-2555b27229bf484bad9cacb01884415a.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + 
"ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-2555b27229bf484bad9cacb01884415a.mp4" +} \ No newline at end of file diff --git a/jobs/wan-28310b52990d4cd4b720b4e118289eff.json b/jobs/wan-28310b52990d4cd4b720b4e118289eff.json new file mode 100644 index 00000000..10f12332 --- /dev/null +++ b/jobs/wan-28310b52990d4cd4b720b4e118289eff.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-28310b52990d4cd4b720b4e118289eff.mp4" +} \ No newline at end of file diff --git a/jobs/wan-2fd1afc6587d4f1d852615b3c8d0d61f.json b/jobs/wan-2fd1afc6587d4f1d852615b3c8d0d61f.json new file mode 100644 index 00000000..2bae1cfd --- /dev/null +++ b/jobs/wan-2fd1afc6587d4f1d852615b3c8d0d61f.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-2fd1afc6587d4f1d852615b3c8d0d61f.mp4" +} \ No newline at end of file diff --git a/jobs/wan-348052cbfa4b4a6ba7885f9cf50e3631.json b/jobs/wan-348052cbfa4b4a6ba7885f9cf50e3631.json new file mode 100644 index 00000000..c0e6ae71 --- /dev/null +++ b/jobs/wan-348052cbfa4b4a6ba7885f9cf50e3631.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + 
"save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-348052cbfa4b4a6ba7885f9cf50e3631.mp4" +} \ No newline at end of file diff --git a/jobs/wan-4ae0a1d8085c4989bc9e2231d07854e0.json b/jobs/wan-4ae0a1d8085c4989bc9e2231d07854e0.json new file mode 100644 index 00000000..22f875e5 --- /dev/null +++ b/jobs/wan-4ae0a1d8085c4989bc9e2231d07854e0.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-4ae0a1d8085c4989bc9e2231d07854e0.mp4" +} \ No newline at end of file diff --git a/jobs/wan-4ce812dc76514695b4ed80a2dceba7b3.json b/jobs/wan-4ce812dc76514695b4ed80a2dceba7b3.json new file mode 100644 index 00000000..9821bbae --- /dev/null +++ b/jobs/wan-4ce812dc76514695b4ed80a2dceba7b3.json @@ -0,0 +1,13 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": true, + "ulysses_size": 8, + "t5_fsdp": true, + "t5_cpu": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-4ce812dc76514695b4ed80a2dceba7b3.mp4" +} \ No newline at end of file diff --git a/jobs/wan-5506ef475a064d3cb3e6757f49a07863.json b/jobs/wan-5506ef475a064d3cb3e6757f49a07863.json new file mode 100644 index 00000000..5fc90fd4 --- /dev/null +++ b/jobs/wan-5506ef475a064d3cb3e6757f49a07863.json @@ -0,0 +1,14 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "960*480", + "frame_num": 49, + "offload_model": true, + "ulysses_size": 8, + "t5_fsdp": true, + "t5_cpu": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + 
"save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-5506ef475a064d3cb3e6757f49a07863.mp4" +} \ No newline at end of file diff --git a/jobs/wan-59d2b69da370471b8c5449932955d966.json b/jobs/wan-59d2b69da370471b8c5449932955d966.json new file mode 100644 index 00000000..94a55af6 --- /dev/null +++ b/jobs/wan-59d2b69da370471b8c5449932955d966.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-59d2b69da370471b8c5449932955d966.mp4" +} \ No newline at end of file diff --git a/jobs/wan-5b3df1506ab245a187d6786384a62fd2.json b/jobs/wan-5b3df1506ab245a187d6786384a62fd2.json new file mode 100644 index 00000000..52e4ba05 --- /dev/null +++ b/jobs/wan-5b3df1506ab245a187d6786384a62fd2.json @@ -0,0 +1,13 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage", + "size": "832*480", + "offload_model": true, + "ulysses_size": 8, + "t5_fsdp": true, + "t5_cpu": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-5b3df1506ab245a187d6786384a62fd2.mp4" +} \ No newline at end of file diff --git a/jobs/wan-700c5dc4168c47b0a32f6d0ba0bb20f8.json b/jobs/wan-700c5dc4168c47b0a32f6d0ba0bb20f8.json new file mode 100644 index 00000000..cad63e2f --- /dev/null +++ b/jobs/wan-700c5dc4168c47b0a32f6d0ba0bb20f8.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": 
"/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-700c5dc4168c47b0a32f6d0ba0bb20f8.mp4" +} \ No newline at end of file diff --git a/jobs/wan-75626c8f50d64dcdb86fa4f74f1abafb.json b/jobs/wan-75626c8f50d64dcdb86fa4f74f1abafb.json new file mode 100644 index 00000000..47b3dbc9 --- /dev/null +++ b/jobs/wan-75626c8f50d64dcdb86fa4f74f1abafb.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-75626c8f50d64dcdb86fa4f74f1abafb.mp4" +} \ No newline at end of file diff --git a/jobs/wan-886f0a45b0c746acb629429b5bacf048.json b/jobs/wan-886f0a45b0c746acb629429b5bacf048.json new file mode 100644 index 00000000..dcd9bfc0 --- /dev/null +++ b/jobs/wan-886f0a45b0c746acb629429b5bacf048.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-886f0a45b0c746acb629429b5bacf048.mp4" +} \ No newline at end of file diff --git a/jobs/wan-8cd1ac8ab91e49829bd1642d254d5ddd.json b/jobs/wan-8cd1ac8ab91e49829bd1642d254d5ddd.json new file mode 100644 index 00000000..ebd0a592 --- /dev/null +++ b/jobs/wan-8cd1ac8ab91e49829bd1642d254d5ddd.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": 
"/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-8cd1ac8ab91e49829bd1642d254d5ddd.mp4" +} \ No newline at end of file diff --git a/jobs/wan-9ce9b64dc6a54f42a3d16a2f1d21e079.json b/jobs/wan-9ce9b64dc6a54f42a3d16a2f1d21e079.json new file mode 100644 index 00000000..97532310 --- /dev/null +++ b/jobs/wan-9ce9b64dc6a54f42a3d16a2f1d21e079.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-9ce9b64dc6a54f42a3d16a2f1d21e079.mp4" +} \ No newline at end of file diff --git a/jobs/wan-c5d7fa7a2a00449aabaf782f2692ed31.json b/jobs/wan-c5d7fa7a2a00449aabaf782f2692ed31.json new file mode 100644 index 00000000..d4aee5e7 --- /dev/null +++ b/jobs/wan-c5d7fa7a2a00449aabaf782f2692ed31.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-c5d7fa7a2a00449aabaf782f2692ed31.mp4" +} \ No newline at end of file diff --git a/jobs/wan-d6e4f4b3c18f46ac823f47919f3b3686.json b/jobs/wan-d6e4f4b3c18f46ac823f47919f3b3686.json new file mode 100644 index 00000000..a584eb4d --- /dev/null +++ b/jobs/wan-d6e4f4b3c18f46ac823f47919f3b3686.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": 
"/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-d6e4f4b3c18f46ac823f47919f3b3686.mp4" +} \ No newline at end of file diff --git a/jobs/wan-dc07f3e46ba54afcab7be44bd712d5aa.json b/jobs/wan-dc07f3e46ba54afcab7be44bd712d5aa.json new file mode 100644 index 00000000..26a01b9a --- /dev/null +++ b/jobs/wan-dc07f3e46ba54afcab7be44bd712d5aa.json @@ -0,0 +1,12 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "1280*720", + "offload_model": false, + "ulysses_size": 8, + "t5_fsdp": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-dc07f3e46ba54afcab7be44bd712d5aa.mp4" +} \ No newline at end of file diff --git a/jobs/wan-e49df060d3664a66b9fdc1d6dc52305e.json b/jobs/wan-e49df060d3664a66b9fdc1d6dc52305e.json new file mode 100644 index 00000000..aeaa6472 --- /dev/null +++ b/jobs/wan-e49df060d3664a66b9fdc1d6dc52305e.json @@ -0,0 +1,14 @@ +{ + "model": "wan2.2-t2v-a14b", + "prompt": "A cat walking on grass.", + "size": "832*480", + "frame_num": 30, + "offload_model": true, + "ulysses_size": 8, + "t5_fsdp": true, + "t5_cpu": true, + "dit_fsdp": true, + "convert_model_dtype": true, + "ckpt_dir": "/home/HPCBase/LLM/Wan-AI/Wan2.2-T2V-A14B", + "save_file": "/home/HPCBase/PACKAGE/linlei/Wan2.2/wan22-lin/out/wan-e49df060d3664a66b9fdc1d6dc52305e.mp4" +} \ No newline at end of file diff --git a/out/wan-5b3df1506ab245a187d6786384a62fd2.mp4 b/out/wan-5b3df1506ab245a187d6786384a62fd2.mp4 new file mode 100644 index 00000000..15e1633d Binary files /dev/null and b/out/wan-5b3df1506ab245a187d6786384a62fd2.mp4 differ diff --git a/out/wan-e49df060d3664a66b9fdc1d6dc52305e.mp4 b/out/wan-e49df060d3664a66b9fdc1d6dc52305e.mp4 new file mode 100644 index 00000000..7a6b2e1e Binary files /dev/null and b/out/wan-e49df060d3664a66b9fdc1d6dc52305e.mp4 differ diff --git a/pyproject.toml b/pyproject.toml index 337240af..9a582cd3 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,12 @@ dev = [ "mypy", "huggingface-hub[cli]" ] +serve = [ + "fastapi>=0.115.0", + "uvicorn[standard]>=0.30.0", + "redis>=5.0.0", + "pydantic>=2.0.0", +] [project.urls] homepage = "https://wanxai.com" diff --git a/requirements_serve.txt b/requirements_serve.txt new file mode 100644 index 00000000..e51d6ad8 --- /dev/null +++ b/requirements_serve.txt @@ -0,0 +1,5 @@ +# HTTP API + Redis queue (install on the API node; worker/GPU nodes only need Redis client if you split roles) +fastapi>=0.115.0 +uvicorn[standard]>=0.30.0 +redis>=5.0.0 +pydantic>=2.0.0 diff --git a/run_api_server.py b/run_api_server.py new file mode 100644 index 00000000..e6613a8f --- /dev/null +++ b/run_api_server.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +"""Run the DashScope-style HTTP API (development / single-process).""" +import os + +import uvicorn + +if __name__ == "__main__": + host = os.environ.get("WAN_API_HOST", "0.0.0.0") + port = int(os.environ.get("WAN_API_PORT", "8008")) + uvicorn.run( + "serve.api:app", + host=host, + port=port, + workers=1, + reload=os.environ.get("WAN_API_RELOAD", "").lower() in ("1", "true", "yes"), + ) diff --git a/serve/__init__.py b/serve/__init__.py new file mode 100644 index 00000000..4520f2b3 --- /dev/null +++ b/serve/__init__.py @@ -0,0 +1 @@ +# Wan2.2 HTTP serving package (DashScope-style async tasks + torchrun worker). 
diff --git a/serve/api.py b/serve/api.py new file mode 100644 index 00000000..fa5365f7 --- /dev/null +++ b/serve/api.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import os +import uuid +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Optional + +from fastapi import Depends, FastAPI, Header, HTTPException, UploadFile, File, Form +from fastapi.responses import FileResponse, HTMLResponse +from fastapi.staticfiles import StaticFiles + +from .auth import require_bearer +from .config import Settings +from .job_build import request_to_job +from .schemas import ( + HealthResponse, + ModelEnum, + OutputTaskId, + TaskStatusBody, + VideoGenerationRequest, + VideoGenerationResponse, +) +from .store import TaskStore + +_settings: Optional[Settings] = None +_store: Optional[TaskStore] = None + + +def get_settings() -> Settings: + assert _settings is not None + return _settings + + +def get_store() -> TaskStore: + assert _store is not None + return _store + + +def _validate_model_input(body: VideoGenerationRequest) -> None: + """Per-model field validation — return clear 400 errors for missing required inputs.""" + model = body.model.value + inp = body.input + params = body.parameters + + if not inp.prompt: + raise HTTPException(status_code=400, detail=f"model '{model}' requires input.prompt") + + # i2v-A14B and ti2v-5B require image + if model in (ModelEnum.i2v_a14b.value, ModelEnum.ti2v_5b.value) and not inp.image: + raise HTTPException( + status_code=400, + detail=f"model '{model}' requires input.image", + ) + + # animate-14B requires video + pose (src_root_path) + if model == ModelEnum.animate_14b.value and not inp.video: + raise HTTPException( + status_code=400, + detail=f"model '{model}' requires input.video (reference video path)", + ) + + # s2v-14B requires image + audio (or enable_tts) + if model == ModelEnum.s2v_14b.value: + if not inp.image: + raise HTTPException( + status_code=400, + detail=f"model '{model}' 
requires input.image", + ) + if not inp.audio and not params.enable_tts: + raise HTTPException( + status_code=400, + detail=f"model '{model}' requires input.audio or parameters.enable_tts=true", + ) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global _settings, _store + _settings = Settings.from_env() + Path(_settings.job_dir).mkdir(parents=True, exist_ok=True) + Path(_settings.output_dir).mkdir(parents=True, exist_ok=True) + _store = TaskStore(_settings) + yield + _store = None + _settings = None + + +app = FastAPI( + title="Wan2.2 Video Generation API", + version="1.0.0", + lifespan=lifespan, +) + + +async def _auth_dep( + authorization: str | None = Header(default=None), + settings: Settings = Depends(get_settings), +): + await require_bearer(authorization, settings=settings) + + +@app.get("/healthz", response_model=HealthResponse) +def healthz(): + return HealthResponse() + + +@app.post( + "/api/v1/video/generation", + response_model=VideoGenerationResponse, + dependencies=[Depends(_auth_dep)], +) +def create_video_job( + body: VideoGenerationRequest, + settings: Settings = Depends(get_settings), + store: TaskStore = Depends(get_store), +): + if not settings.ckpt_dir and not (body.parameters.ckpt_dir): + raise HTTPException( + status_code=400, + detail="Set WAN_CKPT_DIR or parameters.ckpt_dir", + ) + + _validate_model_input(body) + + task_id = f"wan-{uuid.uuid4().hex}" + request_id = str(uuid.uuid4()) + job = request_to_job(body, task_id=task_id, settings=settings) + store.create_task(task_id, request_id, job) + return VideoGenerationResponse( + request_id=request_id, + output=OutputTaskId(task_id=task_id), + ) + + +@app.get( + "/api/v1/tasks/{task_id}", + dependencies=[Depends(_auth_dep)], +) +def get_task(task_id: str, store: TaskStore = Depends(get_store)): + doc = store.get_public(task_id) + if not doc: + raise HTTPException(status_code=404, detail="task not found") + return TaskStatusBody( + task_id=doc["task_id"], + 
task_status=doc["task_status"], + message=doc.get("message", ""), + output=doc.get("output") or {}, + request_id=doc.get("request_id"), + model=doc.get("model"), + ) + + +@app.get( + "/api/v1/files/by-task/{task_id}", + dependencies=[Depends(_auth_dep)], +) +def download_task_video(task_id: str, store: TaskStore = Depends(get_store)): + doc = store.get_internal(task_id) + if not doc or doc.get("status") != "SUCCEEDED": + raise HTTPException(status_code=404, detail="video not ready") + path = doc.get("output_path") + if not path or not os.path.isfile(path): + raise HTTPException(status_code=404, detail="file missing on server") + return FileResponse( + path, + media_type="video/mp4", + filename=f"{task_id}.mp4", + ) + + +def create_app() -> FastAPI: + return app + + +# File upload endpoint for WebUI +@app.post( + "/api/v1/files/upload", + dependencies=[Depends(_auth_dep)], +) +async def upload_file( + file: UploadFile = File(...), + category: str = Form("general"), + settings: Settings = Depends(get_settings), +): + """Upload a file (image, audio, video) to the server for use in generation.""" + import shutil + upload_dir = Path(settings.job_dir) / "uploads" / category + upload_dir.mkdir(parents=True, exist_ok=True) + + # Use only ASCII hex name to avoid encoding issues across containers + ext = Path(file.filename).suffix.lower() + # Convert webp/heic to jpg for PIL compatibility + if ext in (".webp", ".heic", ".heif"): + ext = ".jpg" + unique_name = f"{uuid.uuid4().hex}{ext}" + dest = upload_dir / unique_name + + if category == "image" and ext == ".jpg" and Path(file.filename).suffix.lower() in (".webp", ".heic", ".heif"): + # Convert webp/heic to jpg via PIL + from PIL import Image as PILImage + img = PILImage.open(file.file).convert("RGB") + img.save(dest, "JPEG", quality=95) + else: + with open(dest, "wb") as f: + shutil.copyfileobj(file.file, f) + + return {"path": str(dest), "filename": unique_name} + + +# WebUI — serve root page +_static_dir = 
Path(__file__).parent / "static" + + +@app.get("/", response_class=HTMLResponse, include_in_schema=False) +def webui(): + index = _static_dir / "index.html" + if index.is_file(): + return index.read_text(encoding="utf-8") + return HTMLResponse("WebUI not found", status_code=404) + + +if _static_dir.is_dir(): + app.mount("/static", StaticFiles(directory=str(_static_dir), html=True), name="static") \ No newline at end of file diff --git a/serve/auth.py b/serve/auth.py new file mode 100644 index 00000000..41c7da24 --- /dev/null +++ b/serve/auth.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from fastapi import Header, HTTPException + +from .config import Settings + + +async def require_bearer( + authorization: str | None = Header(default=None), + *, + settings: Settings, +) -> None: + if not settings.api_keys: + raise HTTPException( + status_code=503, + detail="Server misconfigured: set WAN_SERVE_API_KEYS", + ) + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing Bearer token") + token = authorization.removeprefix("Bearer ").strip() + if token not in settings.api_keys: + raise HTTPException(status_code=403, detail="Invalid API key") diff --git a/serve/config.py b/serve/config.py new file mode 100644 index 00000000..f036661e --- /dev/null +++ b/serve/config.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass + + +def _b(name: str, default: str = "") -> str: + v = os.environ.get(name) + return v if v is not None else default + + +def _i(name: str, default: int) -> int: + v = os.environ.get(name) + if v is None or v.strip() == "": + return default + return int(v) + + +@dataclass(frozen=True) +class Settings: + """Runtime configuration (environment variables).""" + + redis_url: str + api_keys: frozenset[str] + repo_root: str + job_dir: str + output_dir: str + ckpt_dir: str + # torchrun / multi-node + nnodes: int + nproc_per_node: int + node_rank: int + master_addr: str + master_port: int + rdzv_id_prefix: str + python_bin: str + torchrun_bin: str + cluster_lock_ttl_sec: int + queue_name: str + task_key_prefix: str + lock_key: str + signal_key: str + node_role: str + conda_env: str = "" + 
conda_exe: str = "" + + @classmethod + def from_env(cls) -> "Settings": + keys_raw = _b("WAN_SERVE_API_KEYS", _b("WAN_SERVE_API_KEY", "")) + keys = frozenset(k.strip() for k in keys_raw.split(",") if k.strip()) + return cls( + redis_url=_b("WAN_REDIS_URL", "redis://127.0.0.1:6379/0"), + api_keys=keys, + repo_root=_b("WAN_REPO_ROOT", os.getcwd()), + job_dir=_b("WAN_JOB_DIR", "/tmp/wan2_jobs"), + output_dir=_b("WAN_OUTPUT_DIR", "/tmp/wan2_outputs"), + ckpt_dir=_b("WAN_CKPT_DIR", ""), + nnodes=_i("WAN_NNODES", 1), + nproc_per_node=_i("WAN_NPROC_PER_NODE", 1), + node_rank=_i("WAN_NODE_RANK", 0), + master_addr=_b("WAN_MASTER_ADDR", "127.0.0.1"), + master_port=_i("WAN_MASTER_PORT", 29500), + rdzv_id_prefix=_b("WAN_RDZV_PREFIX", "wan"), + python_bin=_b("WAN_PYTHON", "python3"), + torchrun_bin=_b("WAN_TORCHRUN", "torchrun"), + cluster_lock_ttl_sec=_i("WAN_CLUSTER_LOCK_TTL_SEC", 600), + queue_name=_b("WAN_QUEUE_NAME", "wan:queue"), + task_key_prefix=_b("WAN_TASK_KEY_PREFIX", "wan:task:"), + lock_key=_b("WAN_CLUSTER_LOCK_KEY", "wan:cluster_lock"), + signal_key=_b("WAN_SIGNAL_KEY", "wan:signal"), + node_role=_b("WAN_NODE_ROLE", "master"), + conda_env=_b("WAN_CONDA_ENV", ""), + conda_exe=_b("WAN_CONDA_EXE", ""), + ) \ No newline at end of file diff --git a/serve/job_build.py b/serve/job_build.py new file mode 100644 index 00000000..4e316638 --- /dev/null +++ b/serve/job_build.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import Any, Dict + +from .config import Settings +from .schemas import ModelEnum, VideoGenerationRequest + +# Default size per model (first supported size from WAN_CONFIGS) +_MODEL_DEFAULT_SIZE = { + ModelEnum.t2v_a14b.value: "1280*720", + ModelEnum.i2v_a14b.value: "832*480", + ModelEnum.ti2v_5b.value: "1280*704", + ModelEnum.animate_14b.value: "720*1280", + ModelEnum.s2v_14b.value: "832*480", +} + +# Sub-directory name under ckpt_dir for each model +_MODEL_CKPT_SUBDIR = { + ModelEnum.t2v_a14b.value: "Wan2.2-T2V-A14B", + 
ModelEnum.i2v_a14b.value: "Wan2.2-I2V-A14B", + ModelEnum.ti2v_5b.value: "Wan2.2-TI2V-5B", + ModelEnum.animate_14b.value: "Wan2.2-Animate-14B", + ModelEnum.s2v_14b.value: "Wan2.2-S2V-14B", +} + + +def request_to_job( + req: VideoGenerationRequest, + task_id: str, + settings: Settings, +) -> Dict[str, Any]: + """Build a flat job dict for ``generate.args_from_job_dict``.""" + model = req.model.value + job: Dict[str, Any] = {"model": model} + job.update(req.input.model_dump(exclude_none=True)) + job.update(req.parameters.model_dump(exclude_none=True)) + + # Map VideoInput.video → src_root_path (generate.py uses --src_root_path, not --video) + # Always remove "video" from job dict — generate.py doesn't have --video arg + if "video" in job: + video_val = job.pop("video") + if "src_root_path" not in job: + job["src_root_path"] = video_val + + # ckpt_dir: parameters.ckpt_dir overrides everything, + # otherwise auto-append model-specific subdirectory to global ckpt_dir + if job.get("ckpt_dir"): + pass # user explicitly specified, keep it + elif settings.ckpt_dir: + subdir = _MODEL_CKPT_SUBDIR.get(model) + if subdir: + job["ckpt_dir"] = f"{settings.ckpt_dir.rstrip('/')}/{subdir}" + else: + job["ckpt_dir"] = settings.ckpt_dir + + if not job.get("save_file"): + job["save_file"] = f"{settings.output_dir.rstrip('/')}/{task_id}.mp4" + + # Fill default size if not provided + if not job.get("size"): + job["size"] = _MODEL_DEFAULT_SIZE.get(model, "832*480") + + # Auto-enable memory-saving defaults for A100 40GB / dual-expert models + # FSDP shards model across GPUs — incompatible with offload_model + # DDP (no FSDP) requires offload_model to fit 14B in 40GB + if job.get("dit_fsdp") is None: + job["dit_fsdp"] = True + if job.get("t5_cpu") is None: + job["t5_cpu"] = True + if job.get("dit_fsdp") and job.get("convert_model_dtype") is None: + job["convert_model_dtype"] = True + + # Auto-enable sequence parallel for high-resolution to reduce activation memory + # SP + FSDP FULL_SHARD + 
offload_model: SP splits activations, + # FSDP splits parameters, offload swaps inactive expert to CPU + world_size = settings.nproc_per_node * settings.nnodes + size_str = job.get("size", "") + w, h = (int(x) for x in size_str.split("*")) + is_high_res = w * h > 480 * 832 # pixels above 832*480 need SP + offload + if is_high_res and job.get("ulysses_size") is None and world_size > 1: + _MODEL_NUM_HEADS = { + ModelEnum.t2v_a14b.value: 40, + ModelEnum.i2v_a14b.value: 40, + ModelEnum.ti2v_5b.value: 24, + } + num_heads = _MODEL_NUM_HEADS.get(model, 40) + if num_heads % world_size == 0: + job["ulysses_size"] = world_size + # High-res with SP+FSDP: both experts stay on GPU (FSDP shards params) + if job.get("offload_model") is None: + job["offload_model"] = False + else: + # Low-res with FSDP only: no offload needed + if job.get("offload_model") is None: + job["offload_model"] = False + + # Speed defaults: DPM++ 20 steps is comparable quality to UniPC 40 steps + if job.get("sample_solver") is None: + job["sample_solver"] = "dpm++" + if job.get("sample_steps") is None: + job["sample_steps"] = 20 + + return job \ No newline at end of file diff --git a/serve/launcher.py b/serve/launcher.py new file mode 100644 index 00000000..32a1cb9c --- /dev/null +++ b/serve/launcher.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import os +import subprocess +from pathlib import Path +from typing import List + +from .config import Settings + + +def _torchrun_cmd(settings: Settings, job_json: Path, rdzv_id: str) -> List[str]: + repo = Path(settings.repo_root).resolve() + script = repo / "generate_job.py" + return [ + settings.torchrun_bin, + f"--nnodes={settings.nnodes}", + f"--nproc_per_node={settings.nproc_per_node}", + f"--node_rank={settings.node_rank}", + "--rdzv_backend=c10d", + f"--rdzv_endpoint={settings.master_addr}:{settings.master_port}", + f"--rdzv_id={rdzv_id}", + str(script), + "--job_json", + str(job_json.resolve()), + ] + + +def _env_for_child(settings: 
Settings) -> dict: + env = os.environ.copy() + repo = str(Path(settings.repo_root).resolve()) + prev = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = f"{repo}:{prev}" if prev else repo + return env + + +def launch_generate_job( + settings: Settings, + job_json: Path, + rdzv_id: str, +) -> int: + cmd = _torchrun_cmd(settings, job_json, rdzv_id) + env = _env_for_child(settings) + repo = Path(settings.repo_root).resolve() + + proc = subprocess.run( + cmd, + cwd=str(repo), + env=env, + ) + return int(proc.returncode) \ No newline at end of file diff --git a/serve/schemas.py b/serve/schemas.py new file mode 100644 index 00000000..cc64024d --- /dev/null +++ b/serve/schemas.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field + + +class ModelEnum(str, Enum): + t2v_a14b = "wan2.2-t2v-a14b" + i2v_a14b = "wan2.2-i2v-a14b" + ti2v_5b = "wan2.2-ti2v-5b" + s2v_14b = "wan2.2-s2v-14b" + animate_14b = "wan2.2-animate-14b" + + +class VideoInput(BaseModel): + prompt: Optional[str] = None + image: Optional[str] = None + audio: Optional[str] = None + video: Optional[str] = None + + +class VideoParameters(BaseModel): + size: Optional[str] = None + frame_num: Optional[int] = None + ckpt_dir: Optional[str] = None + save_file: Optional[str] = None + offload_model: Optional[bool] = None + ulysses_size: Optional[int] = None + t5_fsdp: Optional[bool] = None + t5_cpu: Optional[bool] = None + dit_fsdp: Optional[bool] = None + use_prompt_extend: Optional[bool] = None + prompt_extend_method: Optional[Literal["dashscope", "local_qwen"]] = None + prompt_extend_model: Optional[str] = None + prompt_extend_target_lang: Optional[Literal["zh", "en"]] = None + base_seed: Optional[int] = None + sample_solver: Optional[Literal["unipc", "dpm++"]] = None + sample_steps: Optional[int] = None + sample_shift: Optional[float] = None + sample_guide_scale: Optional[Union[float, 
List[float]]] = None + convert_model_dtype: Optional[bool] = None + task: Optional[str] = None + # animate extras + src_root_path: Optional[str] = None + refert_num: Optional[int] = None + replace_flag: Optional[bool] = None + use_relighting_lora: Optional[bool] = None + mask: Optional[str] = None + # s2v extras + num_clip: Optional[int] = None + enable_tts: Optional[bool] = None + tts_prompt_audio: Optional[str] = None + tts_prompt_text: Optional[str] = None + tts_text: Optional[str] = None + pose_video: Optional[str] = None + start_from_ref: Optional[bool] = None + infer_frames: Optional[int] = None + + +class VideoGenerationRequest(BaseModel): + model: ModelEnum = Field(..., description="Model id, e.g. wan2.2-t2v-a14b") + input: VideoInput = Field(default_factory=VideoInput) + parameters: VideoParameters = Field(default_factory=VideoParameters) + + +class OutputTaskId(BaseModel): + task_id: str + + +class VideoGenerationResponse(BaseModel): + request_id: str + output: OutputTaskId + + +class TaskStatusBody(BaseModel): + task_id: str + task_status: Literal["PENDING", "RUNNING", "SUCCEEDED", "FAILED"] + message: str = "" + output: Dict[str, Any] = Field(default_factory=dict) + request_id: Optional[str] = None + model: Optional[str] = None + + +class HealthResponse(BaseModel): + status: str = "ok" \ No newline at end of file diff --git a/serve/static/index.html b/serve/static/index.html new file mode 100644 index 00000000..d2125ab7 --- /dev/null +++ b/serve/static/index.html @@ -0,0 +1,425 @@ + + + + + +Wan2.2 Video Generation + + + +
+ <!-- Page body of serve/static/index.html (markup lost in extraction). Recoverable text: header "🎬 Wan2.2 Video Generation", tagline "Generate videos from text prompts, images, and audio", a "New Generation" form, and an "Advanced Options" section. -->
+ + + + \ No newline at end of file diff --git a/serve/store.py b/serve/store.py new file mode 100644 index 00000000..dfaa776c --- /dev/null +++ b/serve/store.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import json +import time +from typing import Any, Dict, Optional + +import redis + +from .config import Settings + + +def _task_key(settings: Settings, task_id: str) -> str: + return f"{settings.task_key_prefix}{task_id}" + + +class TaskStore: + def __init__(self, settings: Settings): + self._settings = settings + self._r = redis.Redis.from_url(settings.redis_url, decode_responses=True) + + def create_task(self, task_id: str, request_id: str, job: Dict[str, Any]) -> None: + now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + doc = { + "task_id": task_id, + "request_id": request_id, + "status": "PENDING", + "message": "", + "output_path": None, + "created_at": now, + "updated_at": now, + "job": job, + } + self._r.set(_task_key(self._settings, task_id), json.dumps(doc)) + self._r.lpush(self._settings.queue_name, task_id) + + def get_public(self, task_id: str) -> Optional[Dict[str, Any]]: + raw = self._r.get(_task_key(self._settings, task_id)) + if not raw: + return None + doc = json.loads(raw) + job = doc.get("job") or {} + out: Dict[str, Any] = { + "task_id": doc["task_id"], + "task_status": doc["status"], + "message": doc.get("message") or "", + "output": {}, + } + if doc.get("output_path"): + out["output"]["video_url"] = ( + f"/api/v1/files/by-task/{doc['task_id']}" + ) + out["output"]["path"] = doc["output_path"] + out["request_id"] = doc.get("request_id") + out["model"] = job.get("model") or job.get("task") + return out + + def get_internal(self, task_id: str) -> Optional[Dict[str, Any]]: + raw = self._r.get(_task_key(self._settings, task_id)) + if not raw: + return None + return json.loads(raw) + + def update( + self, + task_id: str, + *, + status: Optional[str] = None, + message: Optional[str] = None, + output_path: Optional[str] = None, 
+ ) -> None: + raw = self._r.get(_task_key(self._settings, task_id)) + if not raw: + return + doc = json.loads(raw) + if status is not None: + doc["status"] = status + if message is not None: + doc["message"] = message + if output_path is not None: + doc["output_path"] = output_path + doc["updated_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + self._r.set(_task_key(self._settings, task_id), json.dumps(doc)) + + def acquire_cluster_lock(self) -> bool: + ok = self._r.set( + self._settings.lock_key, + "1", + nx=True, + ex=self._settings.cluster_lock_ttl_sec, + ) + return bool(ok) + + def release_cluster_lock(self) -> None: + self._r.delete(self._settings.lock_key) + + def brpop_task_id(self, timeout: int = 5) -> Optional[str]: + item = self._r.brpop(self._settings.queue_name, timeout=timeout) + if not item: + return None + return item[1] + + def requeue(self, task_id: str) -> None: + self._r.rpush(self._settings.queue_name, task_id) + + def publish_signal(self, payload: str) -> None: + self._r.lpush(self._settings.signal_key, payload) + + def brpop_signal(self, timeout: int = 10) -> Optional[str]: + item = self._r.brpop(self._settings.signal_key, timeout=timeout) + if not item: + return None + return item[1] diff --git a/serve/worker_main.py b/serve/worker_main.py new file mode 100644 index 00000000..186550ba --- /dev/null +++ b/serve/worker_main.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import json +import logging +import sys +import time +from pathlib import Path + +from .config import Settings +from .launcher import launch_generate_job +from .store import TaskStore + +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", +) + + +def main_master(settings: Settings, store: TaskStore): + """Master node: pull tasks from Redis queue, signal worker1 via Redis, run torchrun.""" + logging.info("Master worker started; queue=%s, nnodes=%d", settings.queue_name, settings.nnodes) + + while True: + task_id 
= store.brpop_task_id(timeout=10) + if not task_id: + continue + doc = store.get_internal(task_id) + if not doc: + logging.warning("Missing task doc for %s", task_id) + continue + + if not store.acquire_cluster_lock(): + logging.info("Cluster busy; requeue %s", task_id) + store.requeue(task_id) + time.sleep(3) + continue + + try: + store.update(task_id, status="RUNNING") + job = doc["job"] + job_path = Path(settings.job_dir) / f"{task_id}.json" + job_path.write_text(json.dumps(job, indent=2), encoding="utf-8") + rdzv_id = f"{settings.rdzv_id_prefix}-{task_id}" + + # Signal worker1 to join this torchrun job + if settings.nnodes > 1: + signal = json.dumps({ + "task_id": task_id, + "rdzv_id": rdzv_id, + "job": job, + }) + store.publish_signal(signal) + logging.info("Published signal for worker1: %s / rdzv_id=%s", task_id, rdzv_id) + + rc = launch_generate_job(settings, job_path, rdzv_id) + out_path = job.get("save_file") + if rc != 0: + store.update( + task_id, + status="FAILED", + message=f"torchrun exited with code {rc}", + ) + elif out_path and Path(out_path).is_file(): + store.update( + task_id, + status="SUCCEEDED", + output_path=out_path, + ) + else: + store.update( + task_id, + status="FAILED", + message="save_file missing after run", + ) + except Exception as e: + logging.exception("task %s failed", task_id) + store.update(task_id, status="FAILED", message=str(e)) + finally: + store.release_cluster_lock() + + +def main_worker(settings: Settings, store: TaskStore): + """Worker node (secondary): listen for signal from master, write local job JSON, run torchrun.""" + logging.info("Worker node started; waiting for signals on %s", settings.signal_key) + + while True: + raw = store.brpop_signal(timeout=10) + if not raw: + continue + + try: + signal = json.loads(raw) + task_id = signal["task_id"] + rdzv_id = signal["rdzv_id"] + job = signal["job"] + logging.info("Received signal: task=%s, rdzv_id=%s", task_id, rdzv_id) + + # Write job JSON locally so generate_job.py 
can read it + job_path = Path(settings.job_dir) / f"{task_id}.json" + job_path.write_text(json.dumps(job, indent=2), encoding="utf-8") + + rc = launch_generate_job(settings, job_path, rdzv_id) + logging.info("Worker torchrun for %s finished with rc=%d", task_id, rc) + except Exception as e: + logging.exception("worker failed processing signal: %s", raw) + + +def main(): + settings = Settings.from_env() + Path(settings.job_dir).mkdir(parents=True, exist_ok=True) + Path(settings.output_dir).mkdir(parents=True, exist_ok=True) + store = TaskStore(settings) + + # Clear stale cluster lock on startup + store.release_cluster_lock() + + if settings.node_role == "master": + main_master(settings, store) + else: + main_worker(settings, store) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + sys.exit(0) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 00000000..f2a6fcac --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,411 @@ +"""Integration tests for Wan2.2 API service. + +Requires a running instance with Redis + API server. +Set environment variables before running: + WAN_TEST_API_URL — e.g. 
http://localhost:8008 + WAN_TEST_API_KEY — a valid API key + +Usage: + export WAN_TEST_API_URL=http://localhost:8008 + export WAN_TEST_API_KEY=sk-test-key + pytest tests/test_api.py -v +""" +from __future__ import annotations + +import json +import os +import time +import urllib.request +import urllib.error +import uuid + +API_URL = os.environ.get("WAN_TEST_API_URL", "http://localhost:8008") +API_KEY = os.environ.get("WAN_TEST_API_KEY", "sk-test-key") + + +def _headers(): + return { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + } + + +def _request(method: str, path: str, body: dict | None = None): + url = f"{API_URL}{path}" + data = json.dumps(body).encode() if body else None + req = urllib.request.Request(url, data=data, headers=_headers(), method=method) + try: + with urllib.request.urlopen(req, timeout=30) as resp: + return resp.status, json.loads(resp.read()) + except urllib.error.HTTPError as e: + return e.code, json.loads(e.read()) + + +# ============================================================ +# Health endpoint +# ============================================================ + + +class TestHealth: + def test_healthz(self): + """GET /healthz should return 200 with status ok.""" + url = f"{API_URL}/healthz" + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=10) as resp: + assert resp.status == 200 + data = json.loads(resp.read()) + assert data["status"] == "ok" + + +# ============================================================ +# Authentication +# ============================================================ + + +class TestAuth: + def test_no_auth_header(self): + """POST without Authorization should return 401.""" + url = f"{API_URL}/api/v1/video/generation" + body = json.dumps({"model": "wan2.2-t2v-a14b", "input": {"prompt": "test"}}).encode() + req = urllib.request.Request(url, data=body, headers={"Content-Type": "application/json"}, method="POST") + try: + 
urllib.request.urlopen(req, timeout=10) + assert False, "Expected 401" + except urllib.error.HTTPError as e: + assert e.code in (401, 403) + + def test_invalid_api_key(self): + """POST with wrong Bearer token should return 403.""" + url = f"{API_URL}/api/v1/video/generation" + body = json.dumps({"model": "wan2.2-t2v-a14b", "input": {"prompt": "test"}}).encode() + req = urllib.request.Request( + url, data=body, + headers={"Authorization": "Bearer wrong-key", "Content-Type": "application/json"}, + method="POST", + ) + try: + urllib.request.urlopen(req, timeout=10) + assert False, "Expected 403" + except urllib.error.HTTPError as e: + assert e.code == 403 + + def test_missing_ckpt_dir(self): + """POST without ckpt_dir (and server WAN_CKPT_DIR empty) should return 400.""" + url = f"{API_URL}/api/v1/video/generation" + body = json.dumps({ + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "test"}, + "parameters": {}, + }).encode() + req = urllib.request.Request(url, data=body, headers=_headers(), method="POST") + try: + urllib.request.urlopen(req, timeout=10) + # If WAN_CKPT_DIR is set on server, this will succeed (200) — that's fine too + except urllib.error.HTTPError as e: + assert e.code == 400 + + +# ============================================================ +# Per-model field validation +# ============================================================ + + +class TestModelValidation: + def test_i2v_requires_image(self): + """POST i2v without image should return 400.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-i2v-a14b", + "input": {"prompt": "A cat"}, + }) + assert status == 400 + assert "image" in data.get("detail", "") + + def test_i2v_with_image_passes(self): + """POST i2v with image should return 200.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-i2v-a14b", + "input": {"prompt": "A cat", "image": "/ckpt/Wan2.2-I2V-A14B/ref.jpg"}, + }) + assert status == 200 + + def 
test_animate_requires_video(self): + """POST animate without video should return 400.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-animate-14b", + "input": {"prompt": "pose"}, + }) + assert status == 400 + assert "video" in data.get("detail", "") + + def test_s2v_requires_image(self): + """POST s2v without image should return 400.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-s2v-14b", + "input": {"prompt": "talk", "audio": "/ckpt/speech.wav"}, + }) + assert status == 400 + assert "image" in data.get("detail", "") + + def test_s2v_requires_audio_or_tts(self): + """POST s2v without audio or enable_tts should return 400.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-s2v-14b", + "input": {"prompt": "talk", "image": "/ckpt/ref.jpg"}, + }) + assert status == 400 + assert "audio" in data.get("detail", "") + + def test_all_models_require_prompt(self): + """POST any model without prompt should return 400.""" + for model in ("wan2.2-t2v-a14b", "wan2.2-i2v-a14b", "wan2.2-ti2v-5b", + "wan2.2-animate-14b", "wan2.2-s2v-14b"): + status, data = _request("POST", "/api/v1/video/generation", { + "model": model, + "input": {}, + }) + assert status == 400 + assert "prompt" in data.get("detail", "") + + +# ============================================================ +# Task creation (all 5 models) +# ============================================================ + + +class TestTaskCreation: + def test_create_t2v_task(self): + """POST /api/v1/video/generation with t2v model should return task_id.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "A cat walking on a beach at sunset"}, + "parameters": { + "size": "832*480", + "frame_num": 81, + }, + }) + assert status == 200 + assert "request_id" in data + assert "output" in data + assert "task_id" in data["output"] + assert 
data["output"]["task_id"].startswith("wan-") + + def test_create_i2v_task(self): + """POST with i2v model and image input.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-i2v-a14b", + "input": { + "prompt": "A cat dancing", + "image": "/ckpt/Wan2.2-I2V-A14B/ref.jpg", + }, + "parameters": { + "size": "832*480", + }, + }) + assert status == 200 + assert "task_id" in data["output"] + + def test_create_ti2v_task(self): + """POST with ti2v model (prompt only, image optional).""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-ti2v-5b", + "input": {"prompt": "A dog running in a park"}, + "parameters": { + "size": "1280*704", + }, + }) + assert status == 200 + assert "task_id" in data["output"] + + def test_create_animate_task(self): + """POST with animate model and video input.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-animate-14b", + "input": { + "prompt": "视频中的人在做动作", + "video": "/ckpt/animate_input", + }, + "parameters": { + "size": "720*1280", + "refert_num": 77, + }, + }) + assert status == 200 + assert "task_id" in data["output"] + + def test_create_s2v_task_with_audio(self): + """POST with s2v model and audio input.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-s2v-14b", + "input": { + "prompt": "A person talking", + "image": "/ckpt/Wan2.2-S2V-14B/ref.jpg", + "audio": "/ckpt/Wan2.2-S2V-14B/speech.wav", + }, + "parameters": { + "size": "832*480", + }, + }) + assert status == 200 + assert "task_id" in data["output"] + + def test_create_s2v_task_with_tts(self): + """POST with s2v model using TTS instead of audio file.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-s2v-14b", + "input": { + "prompt": "A person talking", + "image": "/ckpt/Wan2.2-S2V-14B/ref.jpg", + }, + "parameters": { + "size": "832*480", + "enable_tts": True, + "tts_prompt_audio": "/ckpt/prompt.wav", + 
"tts_prompt_text": "希望你以后能够做的比我还好呦。", + "tts_text": "收到好友从远方寄来的生日礼物。", + }, + }) + assert status == 200 + assert "task_id" in data["output"] + + def test_create_task_with_optional_params(self): + """POST with all optional parameters.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "A dog running in a park"}, + "parameters": { + "size": "1280*720", + "frame_num": 81, + "sample_steps": 40, + "sample_shift": 5.0, + "sample_guide_scale": 5.0, + "base_seed": 42, + "sample_solver": "unipc", + "t5_cpu": True, + "dit_fsdp": True, + "t5_fsdp": True, + }, + }) + assert status == 200 + assert "task_id" in data["output"] + + def test_task_id_format(self): + """Task IDs should follow wan-{uuid_hex} pattern.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "format test"}, + }) + assert status == 200 + task_id = data["output"]["task_id"] + prefix, hex_part = task_id.split("-", 1) + assert prefix == "wan" + assert len(hex_part) == 32 # uuid4 hex + + +# ============================================================ +# Task status query +# ============================================================ + + +class TestTaskStatus: + def setup_method(self): + """Create a task before each test.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "status test"}, + }) + assert status == 200 + self.task_id = data["output"]["task_id"] + + def test_get_task_status(self): + """GET /api/v1/tasks/{task_id} should return task status.""" + status, data = _request("GET", f"/api/v1/tasks/{self.task_id}") + assert status == 200 + assert data["task_id"] == self.task_id + assert data["task_status"] in ("PENDING", "RUNNING", "SUCCEEDED", "FAILED") + assert "message" in data + assert "output" in data + + def test_get_nonexistent_task(self): + """GET a non-existent task_id should return 404.""" + fake_id 
= f"wan-{uuid.uuid4().hex}" + status, data = _request("GET", f"/api/v1/tasks/{fake_id}") + assert status == 404 + + def test_task_status_transitions(self): + """Poll task status — should start as PENDING, may transition to RUNNING/SUCCEEDED.""" + for _ in range(5): + status, data = _request("GET", f"/api/v1/tasks/{self.task_id}") + assert status == 200 + if data["task_status"] != "PENDING": + break + time.sleep(2) + + +# ============================================================ +# File download +# ============================================================ + + +class TestFileDownload: + def test_download_nonexistent_task(self): + """GET file for non-existent task should return 404.""" + fake_id = f"wan-{uuid.uuid4().hex}" + status, data = _request("GET", f"/api/v1/files/by-task/{fake_id}") + assert status == 404 + + def test_download_pending_task(self): + """GET file for a PENDING task should return 404 (video not ready).""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "download test"}, + }) + assert status == 200 + task_id = data["output"]["task_id"] + + # A freshly created task is PENDING — video won't be ready + status, data = _request("GET", f"/api/v1/files/by-task/{task_id}") + assert status == 404 + + +# ============================================================ +# Edge cases +# ============================================================ + + +class TestEdgeCases: + def test_missing_model(self): + """POST without model field — should return 422 (FastAPI validation).""" + status, data = _request("POST", "/api/v1/video/generation", { + "input": {"prompt": "test"}, + }) + assert status == 422 + + def test_invalid_model(self): + """POST with invalid model name — should return 422.""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-nonexistent", + "input": {"prompt": "test"}, + }) + assert status == 422 + + def test_invalid_size_format(self): + """POST 
with unusual size format — API should accept (validation is lenient).""" + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "size test"}, + "parameters": {"size": "999*999"}, + }) + # Whether this succeeds depends on server config; at least shouldn't crash + assert status in (200, 400) + + def test_multiple_tasks_same_prompt(self): + """Create multiple tasks with the same prompt — each should get unique task_id.""" + ids = [] + for _ in range(3): + status, data = _request("POST", "/api/v1/video/generation", { + "model": "wan2.2-t2v-a14b", + "input": {"prompt": "duplicate test"}, + }) + assert status == 200 + ids.append(data["output"]["task_id"]) + assert len(set(ids)) == 3 # all unique \ No newline at end of file diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 00000000..918160f5 --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,741 @@ +"""Unit tests for serve module — no running server or Redis required. + +Tests config parsing, job_build logic, worker_main routing, and +store signal serialization using mocks. 
+ +Usage: + pytest tests/test_serve.py -v +""" +from __future__ import annotations + +import json +import os +import sys +import unittest +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import MagicMock, patch + +# Ensure serve package is importable +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + + +@contextmanager +def set_env(**kwargs): + """Temporarily set environment variables and restore on exit.""" + old = {} + for k, v in kwargs.items(): + old[k] = os.environ.get(k) + os.environ[k] = v + try: + yield + finally: + # Restore the previous values even if the with-body raises + for k in kwargs: + if old[k] is None: + os.environ.pop(k, None) + else: + os.environ[k] = old[k] + + +# ============================================================ +# serve.config +# ============================================================ + + +class TestSettings: + def test_defaults(self): + with set_env(WAN_SERVE_API_KEYS="sk-test"): + from serve.config import Settings + s = Settings.from_env() + assert s.redis_url == "redis://127.0.0.1:6379/0" + assert s.nnodes == 1 + assert s.nproc_per_node == 1 + assert s.node_rank == 0 + assert s.node_role == "master" + assert s.signal_key == "wan:signal" + assert s.master_port == 29500 + assert "sk-test" in s.api_keys + + def test_multi_node_config(self): + with set_env( + WAN_SERVE_API_KEYS="sk-test", + WAN_NNODES="2", + WAN_NPROC_PER_NODE="4", + WAN_NODE_RANK="1", + WAN_NODE_ROLE="worker", + WAN_MASTER_ADDR="10.0.0.1", + WAN_MASTER_PORT="29600", + ): + from serve.config import Settings + s = Settings.from_env() + assert s.nnodes == 2 + assert s.nproc_per_node == 4 + assert s.node_rank == 1 + assert s.node_role == "worker" + assert s.master_addr == "10.0.0.1" + assert s.master_port == 29600 + + def test_multiple_api_keys(self): + with set_env(WAN_SERVE_API_KEYS="sk-one,sk-two,sk-three"): + from serve.config import Settings + s = Settings.from_env() + assert s.api_keys == frozenset({"sk-one", "sk-two", "sk-three"}) + + def test_empty_api_keys(self): + # Clear any
existing key env vars + for var in ("WAN_SERVE_API_KEYS", "WAN_SERVE_API_KEY"): + os.environ.pop(var, None) + from serve.config import Settings + s = Settings.from_env() + assert s.api_keys == frozenset() + + def test_frozen(self): + with set_env(WAN_SERVE_API_KEYS="sk-test"): + from serve.config import Settings + s = Settings.from_env() + try: + s.redis_url = "x" + assert False, "Should be frozen" + except AttributeError: + pass + + +# ============================================================ +# serve.job_build +# ============================================================ + + +class TestJobBuild: + def test_basic_t2v(self): + from serve.config import Settings + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + + env = { + "WAN_SERVE_API_KEYS": "sk-test", + "WAN_CKPT_DIR": "/ckpt", + "WAN_OUTPUT_DIR": "/out", + } + for k, v in env.items(): + os.environ[k] = v + + try: + s = Settings.from_env() + req = VideoGenerationRequest( + model="wan2.2-t2v-a14b", + input={"prompt": "A cat"}, + parameters={"size": "832*480", "frame_num": 81}, + ) + job = request_to_job(req, task_id="wan-abc123", settings=s) + assert job["model"] == "wan2.2-t2v-a14b" + assert job["prompt"] == "A cat" + assert job["size"] == "832*480" + assert job["frame_num"] == 81 + assert job["ckpt_dir"] == "/ckpt/Wan2.2-T2V-A14B" + assert job["save_file"] == "/out/wan-abc123.mp4" + finally: + for k in env: + os.environ.pop(k, None) + + def test_parameters_ckpt_dir_overrides_global(self): + from serve.config import Settings + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + + env = { + "WAN_SERVE_API_KEYS": "sk-test", + "WAN_CKPT_DIR": "/global-ckpt", + "WAN_OUTPUT_DIR": "/out", + } + for k, v in env.items(): + os.environ[k] = v + + try: + s = Settings.from_env() + req = VideoGenerationRequest( + model="wan2.2-t2v-a14b", + input={"prompt": "A cat"}, + parameters={"ckpt_dir": "/custom-ckpt"}, + ) + job = 
request_to_job(req, task_id="wan-xyz", settings=s) + # per-request ckpt_dir takes precedence + assert job["ckpt_dir"] == "/custom-ckpt" + finally: + for k in env: + os.environ.pop(k, None) + + def test_none_params_excluded(self): + from serve.config import Settings + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + + env = {"WAN_SERVE_API_KEYS": "sk-test", "WAN_OUTPUT_DIR": "/out"} + for k, v in env.items(): + os.environ[k] = v + + try: + s = Settings.from_env() + req = VideoGenerationRequest(model="wan2.2-t2v-a14b") + job = request_to_job(req, task_id="wan-min", settings=s) + # Parameters that were None should not appear + assert "frame_num" not in job + assert "base_seed" not in job + # size gets a default value when not provided + assert job["size"] == "1280*720" + assert "model" in job + assert "save_file" in job + finally: + for k in env: + os.environ.pop(k, None) + + +# ============================================================ +# serve.worker_main — routing logic +# ============================================================ + + +class TestWorkerRouting: + def test_master_role_calls_main_master(self): + os.environ["WAN_SERVE_API_KEYS"] = "sk-test" + os.environ["WAN_NODE_ROLE"] = "master" + + try: + with patch("serve.worker_main.main_master") as mock_master, \ + patch("serve.worker_main.main_worker") as mock_worker: + from serve.worker_main import main + with patch("serve.worker_main.Settings") as MockSettings, \ + patch("serve.worker_main.TaskStore"): + MockSettings.from_env.return_value = MagicMock( + job_dir="/tmp/jobs", + output_dir="/tmp/out", + node_role="master", + ) + main() + mock_master.assert_called_once() + mock_worker.assert_not_called() + finally: + os.environ.pop("WAN_SERVE_API_KEYS", None) + os.environ.pop("WAN_NODE_ROLE", None) + + def test_worker_role_calls_main_worker(self): + os.environ["WAN_SERVE_API_KEYS"] = "sk-test" + os.environ["WAN_NODE_ROLE"] = "worker" + + try: + with 
patch("serve.worker_main.main_master") as mock_master, \ + patch("serve.worker_main.main_worker") as mock_worker: + from serve.worker_main import main + with patch("serve.worker_main.Settings") as MockSettings, \ + patch("serve.worker_main.TaskStore"): + MockSettings.from_env.return_value = MagicMock( + job_dir="/tmp/jobs", + output_dir="/tmp/out", + node_role="worker", + ) + main() + mock_worker.assert_called_once() + mock_master.assert_not_called() + finally: + os.environ.pop("WAN_SERVE_API_KEYS", None) + os.environ.pop("WAN_NODE_ROLE", None) + + +# ============================================================ +# serve.store — signal publish +# ============================================================ + + +class TestStoreSignal: + def test_publish_signal(self): + from serve.config import Settings + from serve.store import TaskStore + + os.environ["WAN_SERVE_API_KEYS"] = "sk-test" + try: + s = Settings.from_env() + store = TaskStore(s) + mock_redis = MagicMock() + store._r = mock_redis + + payload = json.dumps({"task_id": "wan-test", "rdzv_id": "wan-abc"}) + store.publish_signal(payload) + + mock_redis.publish.assert_called_once_with(s.signal_key, payload) + finally: + os.environ.pop("WAN_SERVE_API_KEYS", None) + + +# ============================================================ +# serve.launcher — torchrun command construction +# ============================================================ + + +class TestLauncher: + def test_torchrun_cmd_master(self): + from serve.config import Settings + from serve.launcher import _torchrun_cmd + + os.environ["WAN_SERVE_API_KEYS"] = "sk-test" + os.environ["WAN_NNODES"] = "2" + os.environ["WAN_NPROC_PER_NODE"] = "4" + os.environ["WAN_NODE_RANK"] = "0" + os.environ["WAN_MASTER_ADDR"] = "10.0.0.1" + os.environ["WAN_MASTER_PORT"] = "29500" + + try: + s = Settings.from_env() + cmd = _torchrun_cmd(s, Path("/data/jobs/wan-test.json"), "wan-test-rdzv") + assert cmd[0] == "torchrun" + assert "--nnodes=2" in cmd + assert 
"--nproc_per_node=4" in cmd + assert "--node_rank=0" in cmd + assert "--rdzv_backend=c10d" in cmd + assert "--rdzv_endpoint=10.0.0.1:29500" in cmd + assert "--rdzv_id=wan-test-rdzv" in cmd + assert "--job_json" in cmd + finally: + for k in ("WAN_SERVE_API_KEYS", "WAN_NNODES", "WAN_NPROC_PER_NODE", + "WAN_NODE_RANK", "WAN_MASTER_ADDR", "WAN_MASTER_PORT"): + os.environ.pop(k, None) + + def test_torchrun_cmd_worker(self): + from serve.config import Settings + from serve.launcher import _torchrun_cmd + + os.environ["WAN_SERVE_API_KEYS"] = "sk-test" + os.environ["WAN_NNODES"] = "2" + os.environ["WAN_NPROC_PER_NODE"] = "4" + os.environ["WAN_NODE_RANK"] = "1" + os.environ["WAN_MASTER_ADDR"] = "10.0.0.1" + + try: + s = Settings.from_env() + cmd = _torchrun_cmd(s, Path("/data/jobs/wan-test.json"), "wan-test-rdzv") + assert "--node_rank=1" in cmd + finally: + for k in ("WAN_SERVE_API_KEYS", "WAN_NNODES", "WAN_NPROC_PER_NODE", + "WAN_NODE_RANK", "WAN_MASTER_ADDR"): + os.environ.pop(k, None) + + +# ============================================================ +# serve.schemas — validation +# ============================================================ + + +class TestSchemas: + def test_video_generation_request_model_required(self): + from serve.schemas import VideoGenerationRequest + try: + VideoGenerationRequest() + except Exception: + pass + else: + assert False, "model is required" + + def test_video_generation_request_with_model(self): + from serve.schemas import VideoGenerationRequest + req = VideoGenerationRequest(model="wan2.2-t2v-a14b") + assert req.model == "wan2.2-t2v-a14b" + assert req.input.prompt is None + assert req.parameters.size is None + + def test_video_generation_request_full(self): + from serve.schemas import VideoGenerationRequest + req = VideoGenerationRequest( + model="wan2.2-i2v-a14b", + input={"prompt": "A cat", "image": "https://example.com/cat.jpg"}, + parameters={"size": "832*480", "frame_num": 81, "base_seed": 42}, + ) + assert req.input.prompt == "A cat"
+ assert req.input.image == "https://example.com/cat.jpg" + assert req.parameters.size == "832*480" + assert req.parameters.frame_num == 81 + assert req.parameters.base_seed == 42 + + def test_task_status_body(self): + from serve.schemas import TaskStatusBody + body = TaskStatusBody(task_id="wan-test", task_status="SUCCEEDED") + assert body.task_id == "wan-test" + assert body.task_status == "SUCCEEDED" + assert body.message == "" + assert body.output == {} + + def test_health_response(self): + from serve.schemas import HealthResponse + resp = HealthResponse() + assert resp.status == "ok" + + +if __name__ == "__main__": + unittest.main() + + +# ============================================================ +# serve.schemas — ModelEnum +# ============================================================ + + +class TestModelEnum: + def test_all_models_defined(self): + from serve.schemas import ModelEnum + assert len(ModelEnum) == 5 + assert ModelEnum.t2v_a14b.value == "wan2.2-t2v-a14b" + assert ModelEnum.i2v_a14b.value == "wan2.2-i2v-a14b" + assert ModelEnum.ti2v_5b.value == "wan2.2-ti2v-5b" + assert ModelEnum.s2v_14b.value == "wan2.2-s2v-14b" + assert ModelEnum.animate_14b.value == "wan2.2-animate-14b" + + def test_invalid_model_rejected(self): + from serve.schemas import VideoGenerationRequest + try: + VideoGenerationRequest(model="invalid-model") + except Exception: + pass + else: + assert False, "Should reject invalid model" + + +# ============================================================ +# serve.api — per-model validation +# ============================================================ + + +class TestModelValidation: + """Test _validate_model_input for each model's required fields.""" + + def _make_body(self, model, **input_kwargs): + from serve.schemas import VideoGenerationRequest + return VideoGenerationRequest(model=model, input=input_kwargs) + + def test_t2v_requires_prompt(self): + from serve.api import _validate_model_input + from fastapi import HTTPException + body = 
self._make_body("wan2.2-t2v-a14b") + try: + _validate_model_input(body) + assert False, "Should require prompt" + except HTTPException as e: + assert e.status_code == 400 + + def test_t2v_with_prompt_passes(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + body = VideoGenerationRequest( + model="wan2.2-t2v-a14b", + input={"prompt": "A cat"}, + ) + _validate_model_input(body) # should not raise + + def test_i2v_requires_image(self): + from serve.api import _validate_model_input + from fastapi import HTTPException + body = self._make_body("wan2.2-i2v-a14b", prompt="A cat") + try: + _validate_model_input(body) + assert False, "Should require image" + except HTTPException as e: + assert e.status_code == 400 + assert "image" in e.detail + + def test_i2v_with_image_passes(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + body = VideoGenerationRequest( + model="wan2.2-i2v-a14b", + input={"prompt": "A cat", "image": "/path/to/cat.jpg"}, + ) + _validate_model_input(body) + + def test_animate_requires_video(self): + from serve.api import _validate_model_input + from fastapi import HTTPException + body = self._make_body("wan2.2-animate-14b", prompt="pose") + try: + _validate_model_input(body) + assert False, "Should require video" + except HTTPException as e: + assert e.status_code == 400 + assert "video" in e.detail + + def test_animate_with_video_passes(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + body = VideoGenerationRequest( + model="wan2.2-animate-14b", + input={"prompt": "pose", "video": "/path/to/ref.mp4"}, + ) + _validate_model_input(body) + + def test_s2v_requires_image(self): + from serve.api import _validate_model_input + from fastapi import HTTPException + body = self._make_body("wan2.2-s2v-14b", prompt="talk", audio="/path/to.wav") + try: + _validate_model_input(body) + assert False, "Should 
require image" + except HTTPException as e: + assert e.status_code == 400 + assert "image" in e.detail + + def test_s2v_requires_audio_or_tts(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + from fastapi import HTTPException + body = VideoGenerationRequest( + model="wan2.2-s2v-14b", + input={"prompt": "talk", "image": "/path/to/img.jpg"}, + ) + try: + _validate_model_input(body) + assert False, "Should require audio or enable_tts" + except HTTPException as e: + assert e.status_code == 400 + assert "audio" in e.detail + + def test_s2v_with_audio_passes(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + body = VideoGenerationRequest( + model="wan2.2-s2v-14b", + input={"prompt": "talk", "image": "/path/to/img.jpg", "audio": "/path/to.wav"}, + ) + _validate_model_input(body) + + def test_s2v_with_tts_passes(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + body = VideoGenerationRequest( + model="wan2.2-s2v-14b", + input={"prompt": "talk", "image": "/path/to/img.jpg"}, + parameters={"enable_tts": True}, + ) + _validate_model_input(body) + + def test_ti2v_with_prompt_passes(self): + from serve.api import _validate_model_input + from serve.schemas import VideoGenerationRequest + # ti2v only requires prompt, image is optional + body = VideoGenerationRequest( + model="wan2.2-ti2v-5b", + input={"prompt": "A cat"}, + ) + _validate_model_input(body) + + +# ============================================================ +# serve.job_build — all models +# ============================================================ + + +class TestJobBuildAllModels: + def _make_settings(self): + from serve.config import Settings + env = {"WAN_SERVE_API_KEYS": "sk-test", "WAN_CKPT_DIR": "/ckpt", "WAN_OUTPUT_DIR": "/out"} + for k, v in env.items(): + os.environ[k] = v + s = Settings.from_env() + for k in env: + os.environ.pop(k, None) + 
return s + + def test_t2v_default_size(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-t2v-a14b", + input={"prompt": "A cat"}, + ) + job = request_to_job(req, task_id="wan-t2v", settings=s) + assert job["size"] == "1280*720" + + def test_i2v_default_size(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-i2v-a14b", + input={"prompt": "A cat", "image": "/img.jpg"}, + ) + job = request_to_job(req, task_id="wan-i2v", settings=s) + assert job["size"] == "832*480" + assert job["image"] == "/img.jpg" + + def test_ti2v_default_size(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-ti2v-5b", + input={"prompt": "A cat"}, + ) + job = request_to_job(req, task_id="wan-ti2v", settings=s) + assert job["size"] == "1280*704" + + def test_s2v_default_size(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-s2v-14b", + input={"prompt": "talk", "image": "/img.jpg", "audio": "/talk.wav"}, + ) + job = request_to_job(req, task_id="wan-s2v", settings=s) + assert job["size"] == "832*480" + assert job["audio"] == "/talk.wav" + + def test_animate_default_size(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-animate-14b", + input={"prompt": "pose", "video": "/ref.mp4"}, + ) + job = request_to_job(req, task_id="wan-ani", settings=s) + assert job["size"] == "720*1280" + assert job["src_root_path"] == "/ref.mp4" + assert "video" not in job + + def 
test_explicit_size_overrides_default(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-t2v-a14b", + input={"prompt": "A cat"}, + parameters={"size": "480*832"}, + ) + job = request_to_job(req, task_id="wan-exp", settings=s) + assert job["size"] == "480*832" + + +# ============================================================ +# serve.job_build — ckpt_dir auto-mapping +# ============================================================ + + +class TestCkptDirMapping: + def _make_settings(self, ckpt_dir="/ckpt"): + from serve.config import Settings + env = {"WAN_SERVE_API_KEYS": "sk-test", "WAN_CKPT_DIR": ckpt_dir, "WAN_OUTPUT_DIR": "/out"} + for k, v in env.items(): + os.environ[k] = v + s = Settings.from_env() + for k in env: + os.environ.pop(k, None) + return s + + def test_t2v_auto_ckpt_dir(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings("/ckpt") + req = VideoGenerationRequest(model="wan2.2-t2v-a14b", input={"prompt": "A cat"}) + job = request_to_job(req, task_id="wan-t2v", settings=s) + assert job["ckpt_dir"] == "/ckpt/Wan2.2-T2V-A14B" + + def test_i2v_auto_ckpt_dir(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings("/ckpt") + req = VideoGenerationRequest(model="wan2.2-i2v-a14b", input={"prompt": "A cat", "image": "/img.jpg"}) + job = request_to_job(req, task_id="wan-i2v", settings=s) + assert job["ckpt_dir"] == "/ckpt/Wan2.2-I2V-A14B" + + def test_s2v_auto_ckpt_dir(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings("/ckpt") + req = VideoGenerationRequest(model="wan2.2-s2v-14b", input={"prompt": "talk", "image": "/img.jpg", "audio": "/a.wav"}) + job = request_to_job(req, task_id="wan-s2v", settings=s) + assert 
job["ckpt_dir"] == "/ckpt/Wan2.2-S2V-14B" + + def test_animate_auto_ckpt_dir(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings("/ckpt") + req = VideoGenerationRequest(model="wan2.2-animate-14b", input={"prompt": "pose", "video": "/ref.mp4"}) + job = request_to_job(req, task_id="wan-ani", settings=s) + assert job["ckpt_dir"] == "/ckpt/Wan2.2-Animate-14B" + + def test_ti2v_auto_ckpt_dir(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings("/ckpt") + req = VideoGenerationRequest(model="wan2.2-ti2v-5b", input={"prompt": "A cat"}) + job = request_to_job(req, task_id="wan-ti2v", settings=s) + assert job["ckpt_dir"] == "/ckpt/Wan2.2-TI2V-5B" + + def test_parameters_ckpt_dir_overrides_auto(self): + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings("/ckpt") + req = VideoGenerationRequest( + model="wan2.2-t2v-a14b", + input={"prompt": "A cat"}, + parameters={"ckpt_dir": "/custom/path"}, + ) + job = request_to_job(req, task_id="wan-custom", settings=s) + assert job["ckpt_dir"] == "/custom/path" + + def test_no_global_ckpt_dir_no_parameters_ckpt_dir(self): + from serve.config import Settings + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + env = {"WAN_SERVE_API_KEYS": "sk-test", "WAN_OUTPUT_DIR": "/out"} + for k, v in env.items(): + os.environ[k] = v + # Ensure WAN_CKPT_DIR is not set + os.environ.pop("WAN_CKPT_DIR", None) + try: + s = Settings.from_env() + assert s.ckpt_dir == "" + req = VideoGenerationRequest(model="wan2.2-t2v-a14b", input={"prompt": "A cat"}) + job = request_to_job(req, task_id="wan-nockpt", settings=s) + # No ckpt_dir set anywhere — it should not appear in job + assert "ckpt_dir" not in job + finally: + for k in env: + os.environ.pop(k, None) + + +# 
============================================================ +# serve.job_build — video → src_root_path mapping +# ============================================================ + + +class TestVideoMapping: + def _make_settings(self): + from serve.config import Settings + env = {"WAN_SERVE_API_KEYS": "sk-test", "WAN_CKPT_DIR": "/ckpt", "WAN_OUTPUT_DIR": "/out"} + for k, v in env.items(): + os.environ[k] = v + s = Settings.from_env() + for k in env: + os.environ.pop(k, None) + return s + + def test_video_maps_to_src_root_path(self): + """VideoInput.video should be mapped to src_root_path in the job dict.""" + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-animate-14b", + input={"prompt": "pose", "video": "/ckpt/animate_input"}, + ) + job = request_to_job(req, task_id="wan-ani", settings=s) + assert "src_root_path" in job + assert job["src_root_path"] == "/ckpt/animate_input" + assert "video" not in job + + def test_explicit_src_root_path_not_overridden(self): + """If both video and src_root_path are provided, src_root_path wins.""" + from serve.job_build import request_to_job + from serve.schemas import VideoGenerationRequest + s = self._make_settings() + req = VideoGenerationRequest( + model="wan2.2-animate-14b", + input={"prompt": "pose", "video": "/ckpt/video_path"}, + parameters={"src_root_path": "/ckpt/custom_path"}, + ) + job = request_to_job(req, task_id="wan-ani-exp", settings=s) + assert job["src_root_path"] == "/ckpt/custom_path" + assert "video" not in job \ No newline at end of file diff --git a/wan/animate.py b/wan/animate.py index 6fa4af46..cd43be00 100644 --- a/wan/animate.py +++ b/wan/animate.py @@ -1,8 +1,11 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import gc import logging import math import os import cv2 +import random +import sys import types from copy import deepcopy from functools import partial @@ -15,25 +18,22 @@ from decord import VideoReader from tqdm import tqdm import torch.nn.functional as F +from .pipeline_base import WanPipelineBase from .distributed.fsdp import shard_model from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward from .distributed.util import get_world_size from .modules.animate import WanAnimateModel from .modules.animate import CLIPModel -from .modules.t5 import T5EncoderModel +from .modules.animate.face_blocks import FaceEncoder +from .modules.animate.motion_encoder import MotionEncoder +from .modules.animate.xlm_roberta import XLMRobertaEncoder from .modules.vae2_1 import Wan2_1_VAE from .modules.animate.animate_utils import TensorList, get_loraconfig -from .utils.fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -class WanAnimate: +class WanAnimate(WanPipelineBase): def __init__( self, @@ -47,99 +47,61 @@ def __init__( t5_cpu=False, init_on_cpu=True, convert_model_dtype=False, - use_relighting_lora=False ): - r""" - Initializes the generation model components. - - Args: - config (EasyDict): - Object containing model parameters initialized from config.py - checkpoint_dir (`str`): - Path to directory containing model checkpoints - device_id (`int`, *optional*, defaults to 0): - Id of target GPU device - rank (`int`, *optional*, defaults to 0): - Process rank for distributed training - t5_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for T5 model - dit_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for DiT model - use_sp (`bool`, *optional*, defaults to False): - Enable distribution strategy of sequence parallel. - t5_cpu (`bool`, *optional*, defaults to False): - Whether to place T5 model on CPU. 
Only works without t5_fsdp. - init_on_cpu (`bool`, *optional*, defaults to True): - Enable initializing Transformer Model on CPU. Only works without FSDP or USP. - convert_model_dtype (`bool`, *optional*, defaults to False): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. - use_relighting_lora (`bool`, *optional*, defaults to False): - Whether to use relighting lora for character replacement. - """ - self.device = torch.device(f"cuda:{device_id}") - self.config = config - self.rank = rank - self.t5_cpu = t5_cpu - self.init_on_cpu = init_on_cpu - - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - if t5_fsdp or dit_fsdp or use_sp: - self.init_on_cpu = False + super().__init__( + config=config, + checkpoint_dir=checkpoint_dir, + device_id=device_id, + rank=rank, + t5_fsdp=t5_fsdp, + dit_fsdp=dit_fsdp, + use_sp=use_sp, + t5_cpu=t5_cpu, + init_on_cpu=init_on_cpu, + convert_model_dtype=convert_model_dtype, + ) shard_fn = partial(shard_model, device_id=device_id) - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None, - ) + # Animate-specific encoders self.clip = CLIPModel( - dtype=torch.float16, - device=self.device, - checkpoint_path=os.path.join(checkpoint_dir, - config.clip_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) - - self.vae = Wan2_1_VAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanAnimate from {checkpoint_dir}") - - if not dit_fsdp: - self.noise_model = WanAnimateModel.from_pretrained( - checkpoint_dir, - torch_dtype=self.param_dtype, - device_map=self.device) - else: - self.noise_model = WanAnimateModel.from_pretrained( - 
checkpoint_dir, torch_dtype=self.param_dtype) + dtype=config.clip_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.clip_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer), + ) - self.noise_model = self._configure_model( - model=self.noise_model, + logging.info(f"creating WanAnimateModel from {checkpoint_dir}") + self.model = WanAnimateModel.from_pretrained(checkpoint_dir, subfolder=config.checkpoint) + self.model = self._configure_model( + model=self.model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, - convert_model_dtype=convert_model_dtype, - use_lora=use_relighting_lora, - checkpoint_dir=checkpoint_dir, - config=config - ) + convert_model_dtype=convert_model_dtype) - if use_sp: - self.sp_size = get_world_size() - else: - self.sp_size = 1 + self.face_encoder = FaceEncoder( + checkpoint_path=os.path.join( + checkpoint_dir, config.face_checkpoint), + device=self.device) + + self.xlmroberta = XLMRobertaEncoder( + device=self.device, + ) + + self.motion_encoder = MotionEncoder( + checkpoint_path=os.path.join( + checkpoint_dir, config.motion_checkpoint), + device=self.device, + ) - self.sample_neg_prompt = config.sample_neg_prompt self.sample_prompt = config.prompt + def _init_vae(self, config, checkpoint_dir): + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype, use_lora, checkpoint_dir, config): diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py index 5a0ec3e7..5cc77ead 100644 --- a/wan/configs/__init__.py +++ b/wan/configs/__init__.py @@ -27,6 +27,8 @@ '1280*704': (1280, 704), '1024*704': (1024, 704), '704*1024': (704, 1024), + '1920*1080': (1920, 1080), + '1080*1920': (1080, 1920), } MAX_AREA_CONFIGS = { @@ -38,13 +40,15 @@ '1280*704': 1280 * 704, '1024*704': 1024 * 704, '704*1024': 704 * 1024, + '1920*1080': 1920 * 1080, + 
'1080*1920': 1080 * 1920, } SUPPORTED_SIZES = { - 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'), - 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480', '1920*1080', '1080*1920'), + 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480', '1920*1080', '1080*1920'), 'ti2v-5B': ('704*1280', '1280*704'), 's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704', - '704*1024', '704*1280', '1280*704'), + '704*1024', '704*1280', '1280*704', '1920*1080', '1080*1920'), 'animate-14B': ('720*1280', '1280*720') } diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 247b5eb3..96551352 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -18,7 +18,7 @@ def shard_model( process_group=None, sharding_strategy=ShardingStrategy.FULL_SHARD, sync_module_states=True, - use_lora=False + use_lora=False, ): model = FSDP( module=model, diff --git a/wan/image2video.py b/wan/image2video.py index 659564c2..5f7d26fd 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -5,32 +5,24 @@ import os import random import sys -import types -from contextlib import contextmanager from functools import partial import numpy as np import torch -import torch.cuda.amp as amp import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm +from .pipeline_base import WanPipelineBase from .distributed.fsdp import shard_model -from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward -from .distributed.util import get_world_size from .modules.model import WanModel -from .modules.t5 import T5EncoderModel from .modules.vae2_1 import Wan2_1_VAE -from .utils.fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -class WanI2V: +class WanI2V(WanPipelineBase): + """Image-to-Video generation pipeline (A14B dual-expert).""" + 
+ _use_dual_expert = True def __init__( self, @@ -45,61 +37,21 @@ def __init__( init_on_cpu=True, convert_model_dtype=False, ): - r""" - Initializes the image-to-video generation model components. - - Args: - config (EasyDict): - Object containing model parameters initialized from config.py - checkpoint_dir (`str`): - Path to directory containing model checkpoints - device_id (`int`, *optional*, defaults to 0): - Id of target GPU device - rank (`int`, *optional*, defaults to 0): - Process rank for distributed training - t5_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for T5 model - dit_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for DiT model - use_sp (`bool`, *optional*, defaults to False): - Enable distribution strategy of sequence parallel. - t5_cpu (`bool`, *optional*, defaults to False): - Whether to place T5 model on CPU. Only works without t5_fsdp. - init_on_cpu (`bool`, *optional*, defaults to True): - Enable initializing Transformer Model on CPU. Only works without FSDP or USP. - convert_model_dtype (`bool`, *optional*, defaults to False): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. 
- """ - self.device = torch.device(f"cuda:{device_id}") - self.config = config - self.rank = rank - self.t5_cpu = t5_cpu - self.init_on_cpu = init_on_cpu - - self.num_train_timesteps = config.num_train_timesteps - self.boundary = config.boundary - self.param_dtype = config.param_dtype - - if t5_fsdp or dit_fsdp or use_sp: - self.init_on_cpu = False - - shard_fn = partial(shard_model, device_id=device_id) - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None, + super().__init__( + config=config, + checkpoint_dir=checkpoint_dir, + device_id=device_id, + rank=rank, + t5_fsdp=t5_fsdp, + dit_fsdp=dit_fsdp, + use_sp=use_sp, + t5_cpu=t5_cpu, + init_on_cpu=init_on_cpu, + convert_model_dtype=convert_model_dtype, ) - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - self.vae = Wan2_1_VAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - + # Dual-expert DiT models + shard_fn = partial(shard_model, device_id=device_id) logging.info(f"Creating WanModel from {checkpoint_dir}") self.low_noise_model = WanModel.from_pretrained( checkpoint_dir, subfolder=config.low_noise_checkpoint) @@ -118,90 +70,11 @@ def __init__( dit_fsdp=dit_fsdp, shard_fn=shard_fn, convert_model_dtype=convert_model_dtype) - if use_sp: - self.sp_size = get_world_size() - else: - self.sp_size = 1 - - self.sample_neg_prompt = config.sample_neg_prompt - - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): - """ - Configures a model object. This includes setting evaluation modes, - applying distributed parallel strategy, and handling device placement. - - Args: - model (torch.nn.Module): - The model instance to configure. 
- use_sp (`bool`): - Enable distribution strategy of sequence parallel. - dit_fsdp (`bool`): - Enable FSDP sharding for DiT model. - shard_fn (callable): - The function to apply FSDP sharding. - convert_model_dtype (`bool`): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. - - Returns: - torch.nn.Module: - The configured model. - """ - model.eval().requires_grad_(False) - - if use_sp: - for block in model.blocks: - block.self_attn.forward = types.MethodType( - sp_attn_forward, block.self_attn) - model.forward = types.MethodType(sp_dit_forward, model) - if dist.is_initialized(): - dist.barrier() - - if dit_fsdp: - model = shard_fn(model) - else: - if convert_model_dtype: - model.to(self.param_dtype) - if not self.init_on_cpu: - model.to(self.device) - - return model - - def _prepare_model_for_timestep(self, t, boundary, offload_model): - r""" - Prepares and returns the required model for the current timestep. - - Args: - t (torch.Tensor): - current timestep. - boundary (`int`): - The timestep threshold. If `t` is at or above this value, - the `high_noise_model` is considered as the required model. - offload_model (`bool`): - A flag intended to control the offloading behavior. - - Returns: - torch.nn.Module: - The active model on the target device for the current timestep. 
- """ - if t.item() >= boundary: - required_model_name = 'high_noise_model' - offload_model_name = 'low_noise_model' - else: - required_model_name = 'low_noise_model' - offload_model_name = 'high_noise_model' - if offload_model or self.init_on_cpu: - if next(getattr( - self, - offload_model_name).parameters()).device.type == 'cuda': - getattr(self, offload_model_name).to('cpu') - if next(getattr( - self, - required_model_name).parameters()).device.type == 'cpu': - getattr(self, required_model_name).to(self.device) - return getattr(self, required_model_name) + def _init_vae(self, config, checkpoint_dir): + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) def generate(self, input_prompt, @@ -215,47 +88,7 @@ def generate(self, n_prompt="", seed=-1, offload_model=True): - r""" - Generates video frames from input image and text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation. - img (PIL.Image.Image): - Input image tensor. Shape: [3, H, W] - max_area (`int`, *optional*, defaults to 720*1280): - Maximum pixel area for latent space calculation. Controls video resolution scaling - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity. 
- If tuple, the first guide_scale will be used for low noise model and - the second guide_scale will be used for high noise model. - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from max_area) - - W: Frame width from max_area) - """ - # preprocess - guide_scale = (guide_scale, guide_scale) if isinstance( - guide_scale, float) else guide_scale + guide_scale = self._normalize_guide_scale(guide_scale) img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) F = frame_num @@ -298,18 +131,7 @@ def generate(self, if n_prompt == "": n_prompt = self.sample_neg_prompt - # preprocess - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() - else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] + context, context_null = self._encode_text(input_prompt, n_prompt, offload_model) y = self.vae.encode([ torch.concat([ @@ -322,46 +144,19 @@ def generate(self, ])[0] y = torch.concat([msk, y]) - @contextmanager - def noop_no_sync(): - yield + no_sync_low = self._get_no_sync(self.low_noise_model) + no_sync_high = self._get_no_sync(self.high_noise_model) - no_sync_low_noise = getattr(self.low_noise_model, 'no_sync', - noop_no_sync) - 
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync', - noop_no_sync) - - # evaluation mode with ( torch.amp.autocast('cuda', dtype=self.param_dtype), torch.no_grad(), - no_sync_low_noise(), - no_sync_high_noise(), + no_sync_low(), + no_sync_high(), ): boundary = self.boundary * self.num_train_timesteps + sample_scheduler, timesteps = self._create_scheduler( + sample_solver, sampling_steps, shift) - if sample_solver == 'unipc': - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) - timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': - sample_scheduler = FlowDPMSolverMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) - else: - raise NotImplementedError("Unsupported solver.") - - # sample videos latent = noise arg_c = { @@ -381,14 +176,10 @@ def noop_no_sync(): for _, t in enumerate(tqdm(timesteps)): latent_model_input = [latent.to(self.device)] - timestep = [t] - - timestep = torch.stack(timestep).to(self.device) + timestep = torch.stack([t]).to(self.device) - model = self._prepare_model_for_timestep( - t, boundary, offload_model) - sample_guide_scale = guide_scale[1] if t.item( - ) >= boundary else guide_scale[0] + model = self._prepare_model_for_timestep(t, boundary, offload_model) + sample_guide_scale = guide_scale[1] if t.item() >= boundary else guide_scale[0] noise_pred_cond = model( latent_model_input, t=timestep, **arg_c)[0] diff --git a/wan/modules/animate/motion_encoder.py b/wan/modules/animate/motion_encoder.py index d0e94397..932be0e2 100644 --- a/wan/modules/animate/motion_encoder.py +++ b/wan/modules/animate/motion_encoder.py @@ -304,4 
+304,23 @@ def get_motion(self, img): motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) with torch.cuda.amp.autocast(dtype=torch.float32): motion = self.dec.direction(motion_feat) - return motion \ No newline at end of file + return motion + + +class MotionEncoder(nn.Module): + """Wrapper around Generator that loads from checkpoint and extracts motion features.""" + + def __init__(self, checkpoint_path, device="cuda", size=256, style_dim=512, motion_dim=20): + super().__init__() + self.generator = Generator(size, style_dim, motion_dim) + state_dict = torch.load(checkpoint_path, map_location="cpu") + self.generator.load_state_dict(state_dict) + self.generator.to(device).eval() + self.device = device + + @torch.no_grad() + def forward(self, img): + return self.generator.get_motion(img) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) \ No newline at end of file diff --git a/wan/modules/animate/xlm_roberta.py b/wan/modules/animate/xlm_roberta.py index 755baf39..72b5b687 100644 --- a/wan/modules/animate/xlm_roberta.py +++ b/wan/modules/animate/xlm_roberta.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F -__all__ = ['XLMRoberta', 'xlm_roberta_large'] +__all__ = ['XLMRoberta', 'XLMRobertaEncoder', 'xlm_roberta_large'] class SelfAttention(nn.Module): @@ -143,6 +143,23 @@ def forward(self, ids): return x +class XLMRobertaEncoder(nn.Module): + """Wrapper around XLMRoberta that initializes the large model on device.""" + + def __init__(self, device="cuda"): + super().__init__() + self.model = xlm_roberta_large(device=device) + self.model.eval().requires_grad_(False) + self.device = device + + @torch.no_grad() + def forward(self, ids): + return self.model(ids) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def xlm_roberta_large(pretrained=False, return_tokenizer=False, device='cpu', diff --git a/wan/modules/model.py b/wan/modules/model.py 
index 6982fa15..b587e16c 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -458,7 +458,7 @@ def forward( # time embeddings if t.dim() == 1: - t = t.expand(t.size(0), seq_len) + t = t.unsqueeze(1).expand(t.size(0), seq_len) with torch.amp.autocast('cuda', dtype=torch.float32): bt = t.size(0) t = t.flatten() diff --git a/wan/modules/s2v/__init__.py b/wan/modules/s2v/__init__.py index d2ce56d4..aec6a747 100644 --- a/wan/modules/s2v/__init__.py +++ b/wan/modules/s2v/__init__.py @@ -1,5 +1,5 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. from .audio_encoder import AudioEncoder -from .model_s2v import WanModel_S2V +from .model_s2v import WanS2VModel -__all__ = ['WanModel_S2V', 'AudioEncoder'] +__all__ = ['WanS2VModel', 'AudioEncoder'] diff --git a/wan/modules/s2v/model_s2v.py b/wan/modules/s2v/model_s2v.py index 82263bde..fe94afcc 100644 --- a/wan/modules/s2v/model_s2v.py +++ b/wan/modules/s2v/model_s2v.py @@ -244,7 +244,7 @@ def cross_attn_ffn(x, context, context_lens, e): return x -class WanModel_S2V(ModelMixin, ConfigMixin): +class WanS2VModel(ModelMixin, ConfigMixin): ignore_for_config = [ 'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' diff --git a/wan/pipeline_base.py b/wan/pipeline_base.py new file mode 100644 index 00000000..9d114ce2 --- /dev/null +++ b/wan/pipeline_base.py @@ -0,0 +1,285 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +""" +Base class for all Wan video generation pipelines. + +Extracts the common initialization, model configuration, text encoding, +and sampling utilities shared across T2V, I2V, TI2V, S2V, and Animate. 
+""" +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.distributed as dist +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward +from .distributed.util import get_world_size +from .modules.t5 import T5EncoderModel +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + + +class WanPipelineBase: + """Base class for Wan video generation pipelines. + + Provides shared logic for: + - T5 text encoder loading + - VAE loading (subclass selects version) + - DiT model configuration (FSDP, sequence parallel, dtype conversion) + - Text encoding (prompt + negative prompt) + - Sampling scheduler creation (UniPC / DPM++) + - Model offloading utilities + - Distributed barrier / cleanup + """ + + # Subclasses should set these or override __init__ completely. 
+ _use_dual_expert = False # Set True for T2V-A14B / I2V-A14B + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=False, + init_on_cpu=True, + convert_model_dtype=False, + ): + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.init_on_cpu = init_on_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + self.boundary = getattr(config, 'boundary', None) + + if t5_fsdp or dit_fsdp or use_sp: + self.init_on_cpu = False + + shard_fn = partial(shard_model, device_id=device_id) + + # Text encoder (shared across all pipelines) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None, + ) + + # VAE – subclass is responsible for setting self.vae before or after + # calling super().__init__(), or override _init_vae(). + self._init_vae(config, checkpoint_dir) + + # DiT model(s) – subclass is responsible for setting self.model + # (single-expert) or self.low_noise_model / self.high_noise_model + # (dual-expert). Call _init_dit() in subclass after super().__init__(). 
+ self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.sp_size = get_world_size() if use_sp else 1 + self.sample_neg_prompt = config.sample_neg_prompt + + def _init_vae(self, config, checkpoint_dir): + """Override in subclass to select VAE version (2.1 or 2.2).""" + raise NotImplementedError("Subclass must implement _init_vae()") + + # ------------------------------------------------------------------ + # Model configuration helpers + # ------------------------------------------------------------------ + + def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, + convert_model_dtype): + """Configure a DiT model: eval mode, SP hooks, FSDP, dtype, device.""" + model.eval().requires_grad_(False) + + if use_sp: + for block in model.blocks: + block.self_attn.forward = types.MethodType( + sp_attn_forward, block.self_attn) + model.forward = types.MethodType(sp_dit_forward, model) + + if dist.is_initialized(): + dist.barrier() + + if dit_fsdp: + model = shard_fn(model) + else: + if convert_model_dtype: + model.to(self.param_dtype) + if not self.init_on_cpu: + model.to(self.device) + + return model + + # ------------------------------------------------------------------ + # Dual-expert model switching (T2V-A14B / I2V-A14B) + # ------------------------------------------------------------------ + + def _prepare_model_for_timestep(self, t, boundary, offload_model): + """Return the active model for the current timestep (dual-expert).""" + if t.item() >= boundary: + required, offload = 'high_noise_model', 'low_noise_model' + else: + required, offload = 'low_noise_model', 'high_noise_model' + + if offload_model or self.init_on_cpu: + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + off_model = getattr(self, offload) + req_model = getattr(self, required) + is_fsdp = isinstance(off_model, FSDP) + if not is_fsdp: + if next(off_model.parameters()).device.type == 'cuda': + off_model.to('cpu') + if 
next(req_model.parameters()).device.type == 'cpu': + req_model.to(self.device) + + return getattr(self, required) + + # ------------------------------------------------------------------ + # Text encoding + # ------------------------------------------------------------------ + + def _encode_text(self, prompt, neg_prompt, offload_model=True): + """Encode prompt and negative prompt with T5. + + Returns: + (context, context_null) – both are lists of tensors on self.device. + """ + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([prompt], self.device) + context_null = self.text_encoder([neg_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([prompt], torch.device('cpu')) + context_null = self.text_encoder([neg_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + return context, context_null + + # ------------------------------------------------------------------ + # Sampling scheduler + # ------------------------------------------------------------------ + + def _create_scheduler(self, sample_solver, sampling_steps, shift): + """Create a flow-matching sampling scheduler and timesteps. 
+ + Returns: + (sample_scheduler, timesteps) + """ + if sample_solver == 'unipc': + scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) + timesteps = scheduler.timesteps + elif sample_solver == 'dpm++': + scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps(scheduler, device=self.device, sigmas=sigmas) + else: + raise NotImplementedError(f"Unsupported solver: {sample_solver}") + return scheduler, timesteps + + # ------------------------------------------------------------------ + # Seed handling + # ------------------------------------------------------------------ + + @staticmethod + def _make_seed(seed): + """Return a deterministic or random seed + generator.""" + if seed < 0: + seed = random.randint(0, sys.maxsize) + gen = torch.Generator() + gen.manual_seed(seed) + return seed, gen + + @staticmethod + def _make_seed_on_device(seed, device): + """Return a deterministic or random seed + CUDA generator.""" + if seed < 0: + seed = random.randint(0, sys.maxsize) + gen = torch.Generator(device=device) + gen.manual_seed(seed) + return seed, gen + + # ------------------------------------------------------------------ + # Distributed helpers + # ------------------------------------------------------------------ + + def _broadcast_seed(self, seed): + """Broadcast seed from rank 0 to all ranks.""" + if dist.is_initialized(): + seed_list = [seed] if self.rank == 0 else [None] + dist.broadcast_object_list(seed_list, src=0) + return seed_list[0] + return seed + + def _distributed_barrier(self): + if dist.is_initialized(): + dist.barrier() + + def _distributed_cleanup(self, offload_model=True): + """Final cleanup: garbage collect, sync, destroy process 
group.""" + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + # ------------------------------------------------------------------ + # Context manager for no-sync (FSDP gradient sync suppression) + # ------------------------------------------------------------------ + + @staticmethod + @contextmanager + def _noop_no_sync(): + yield + + def _get_no_sync(self, model): + return getattr(model, 'no_sync', self._noop_no_sync) + + # ------------------------------------------------------------------ + # Misc utilities + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_guide_scale(guide_scale): + """Ensure guide_scale is a 2-tuple (low_noise, high_noise).""" + if isinstance(guide_scale, (int, float)): + return (float(guide_scale), float(guide_scale)) + return tuple(float(x) for x in guide_scale) + + def _compute_seq_len(self, target_shape, patch_size, sp_size): + """Compute padded sequence length for the DiT.""" + seq_len = math.ceil( + (target_shape[2] * target_shape[3]) + / (patch_size[1] * patch_size[2]) + * target_shape[1] + / sp_size + ) * sp_size + return seq_len diff --git a/wan/speech2video.py b/wan/speech2video.py index be9f5f14..c9fc5fb4 100644 --- a/wan/speech2video.py +++ b/wan/speech2video.py @@ -5,46 +5,23 @@ import os import random import sys -import types -from contextlib import contextmanager -from copy import deepcopy from functools import partial import numpy as np import torch -import torch.cuda.amp as amp import torch.distributed as dist import torchvision.transforms.functional as TF -from decord import VideoReader -from PIL import Image -from safetensors import safe_open -from torchvision import transforms from tqdm import tqdm +from .pipeline_base import WanPipelineBase from .distributed.fsdp import shard_model -from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward -from .distributed.util 
import get_world_size +from .modules.s2v.model_s2v import WanS2VModel from .modules.s2v.audio_encoder import AudioEncoder -from .modules.s2v.model_s2v import WanModel_S2V, sp_attn_forward_s2v -from .modules.t5 import T5EncoderModel -from .modules.vae2_1 import Wan2_1_VAE -from .utils.fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .modules.vae2_2 import Wan2_2_VAE -def load_safetensors(path): - tensors = {} - with safe_open(path, framework="pt", device="cpu") as f: - for key in f.keys(): - tensors[key] = f.get_tensor(key) - return tensors - - -class WanS2V: +class WanS2V(WanPipelineBase): + """Speech-to-Video generation pipeline (14B single-expert, VAE 2.2).""" def __init__( self, @@ -59,349 +36,50 @@ def __init__( init_on_cpu=True, convert_model_dtype=False, ): - r""" - Initializes the image-to-video generation model components. - - Args: - config (EasyDict): - Object containing model parameters initialized from config.py - checkpoint_dir (`str`): - Path to directory containing model checkpoints - device_id (`int`, *optional*, defaults to 0): - Id of target GPU device - rank (`int`, *optional*, defaults to 0): - Process rank for distributed training - t5_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for T5 model - dit_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for DiT model - use_sp (`bool`, *optional*, defaults to False): - Enable distribution strategy of sequence parallel. - t5_cpu (`bool`, *optional*, defaults to False): - Whether to place T5 model on CPU. Only works without t5_fsdp. - init_on_cpu (`bool`, *optional*, defaults to True): - Enable initializing Transformer Model on CPU. Only works without FSDP or USP. - convert_model_dtype (`bool`, *optional*, defaults to False): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. 
- """ - self.device = torch.device(f"cuda:{device_id}") - self.config = config - self.rank = rank - self.t5_cpu = t5_cpu - self.init_on_cpu = init_on_cpu - - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - if t5_fsdp or dit_fsdp or use_sp: - self.init_on_cpu = False - - shard_fn = partial(shard_model, device_id=device_id) - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None, + super().__init__( + config=config, + checkpoint_dir=checkpoint_dir, + device_id=device_id, + rank=rank, + t5_fsdp=t5_fsdp, + dit_fsdp=dit_fsdp, + use_sp=use_sp, + t5_cpu=t5_cpu, + init_on_cpu=init_on_cpu, + convert_model_dtype=convert_model_dtype, ) - self.vae = Wan2_1_VAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {checkpoint_dir}") - if not dit_fsdp: - self.noise_model = WanModel_S2V.from_pretrained( - checkpoint_dir, - torch_dtype=self.param_dtype, - device_map=self.device) - else: - self.noise_model = WanModel_S2V.from_pretrained( - checkpoint_dir, torch_dtype=self.param_dtype) - - self.noise_model = self._configure_model( - model=self.noise_model, + # Single-expert S2V DiT model + shard_fn = partial(shard_model, device_id=device_id) + logging.info(f"Creating WanS2VModel from {checkpoint_dir}") + self.model = WanS2VModel.from_pretrained( + checkpoint_dir, subfolder=config.checkpoint) + self.model = self._configure_model( + model=self.model, use_sp=use_sp, dit_fsdp=dit_fsdp, shard_fn=shard_fn, convert_model_dtype=convert_model_dtype) + # Audio encoder self.audio_encoder = AudioEncoder( - model_id=os.path.join(checkpoint_dir, - "wav2vec2-large-xlsr-53-english")) - - if use_sp: - self.sp_size = get_world_size() - 
else: - self.sp_size = 1 - - self.sample_neg_prompt = config.sample_neg_prompt - self.motion_frames = config.transformer.motion_frames - self.drop_first_motion = config.drop_first_motion - self.fps = config.sample_fps - self.audio_sample_m = 0 - - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): - """ - Configures a model object. This includes setting evaluation modes, - applying distributed parallel strategy, and handling device placement. - - Args: - model (torch.nn.Module): - The model instance to configure. - use_sp (`bool`): - Enable distribution strategy of sequence parallel. - dit_fsdp (`bool`): - Enable FSDP sharding for DiT model. - shard_fn (callable): - The function to apply FSDP sharding. - convert_model_dtype (`bool`): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. - - Returns: - torch.nn.Module: - The configured model. - """ - model.eval().requires_grad_(False) - if use_sp: - for block in model.blocks: - block.self_attn.forward = types.MethodType( - sp_attn_forward_s2v, block.self_attn) - model.use_context_parallel = True - - if dist.is_initialized(): - dist.barrier() - - if dit_fsdp: - model = shard_fn(model) - else: - if convert_model_dtype: - model.to(self.param_dtype) - if not self.init_on_cpu: - model.to(self.device) - - return model - - def get_size_less_than_area(self, - height, - width, - target_area=1024 * 704, - divisor=64): - if height * width <= target_area: - # If the original image area is already less than or equal to the target, - # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target. 
- max_upper_area = target_area - min_scale = 0.1 - max_scale = 1.0 - else: - # Resize to fit within the target area and then pad to multiples of `divisor` - max_upper_area = target_area # Maximum allowed total pixel count after padding - d = divisor - 1 - b = d * (height + width) - a = height * width - c = d**2 - max_upper_area - - # Calculate scale boundaries using quadratic equation - min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / ( - 2 * a) # Scale when maximum padding is applied - max_scale = math.sqrt(max_upper_area / - (height * width)) # Scale without any padding - - # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area - # Use binary search-like iteration to find this scale - find_it = False - for i in range(100): - scale = max_scale - (max_scale - min_scale) * i / 100 - new_height, new_width = int(height * scale), int(width * scale) - - # Pad to make dimensions divisible by 64 - pad_height = (64 - new_height % 64) % 64 - pad_width = (64 - new_width % 64) % 64 - pad_top = pad_height // 2 - pad_bottom = pad_height - pad_top - pad_left = pad_width // 2 - pad_right = pad_width - pad_left - - padded_height, padded_width = new_height + pad_height, new_width + pad_width - - if padded_height * padded_width <= max_upper_area: - find_it = True - break - - if find_it: - return padded_height, padded_width - else: - # Fallback: calculate target dimensions based on aspect ratio and divisor alignment - aspect_ratio = width / height - target_width = int( - (target_area * aspect_ratio)**0.5 // divisor * divisor) - target_height = int( - (target_area / aspect_ratio)**0.5 // divisor * divisor) - - # Ensure the result is not larger than the original resolution - if target_width >= width or target_height >= height: - target_width = int(width // divisor * divisor) - target_height = int(height // divisor * divisor) - - return target_height, target_width - - def prepare_default_cond_input(self, - map_shape=[3, 12, 64, 64], 
- motion_frames=5, - lat_motion_frames=2, - enable_mano=False, - enable_kp=False, - enable_pose=False): - default_value = [1.0, -1.0, -1.0] - cond_enable = [enable_mano, enable_kp, enable_pose] - cond = [] - for d, c in zip(default_value, cond_enable): - if c: - map_value = torch.ones( - map_shape, dtype=self.param_dtype, device=self.device) * d - cond_lat = torch.cat([ - map_value[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1), - map_value - ], - dim=2) - cond_lat = torch.stack( - self.vae.encode(cond_lat.to( - self.param_dtype)))[:, :, lat_motion_frames:].to( - self.param_dtype) - - cond.append(cond_lat) - if len(cond) >= 1: - cond = torch.cat(cond, dim=1) - else: - cond = None - return cond - - def encode_audio(self, audio_path, infer_frames): - z = self.audio_encoder.extract_audio_feat( - audio_path, return_all_layers=True) - audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( - z, fps=self.fps, batch_frames=infer_frames, m=self.audio_sample_m) - audio_embed_bucket = audio_embed_bucket.to(self.device, - self.param_dtype) - audio_embed_bucket = audio_embed_bucket.unsqueeze(0) - if len(audio_embed_bucket.shape) == 3: - audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) - elif len(audio_embed_bucket.shape) == 4: - audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) - return audio_embed_bucket, num_repeat - - def read_last_n_frames(self, - video_path, - n_frames, - target_fps=16, - reverse=False): - """ - Read the last `n_frames` from a video at the specified frame rate. - - Parameters: - video_path (str): Path to the video file. - n_frames (int): Number of frames to read. - target_fps (int, optional): Target sampling frame rate. Defaults to 16. - reverse (bool, optional): Whether to read frames in reverse order. - If True, reads the first `n_frames` instead of the last ones. - - Returns: - np.ndarray: A NumPy array of shape [n_frames, H, W, 3], representing the sampled video frames. 
- """ - vr = VideoReader(video_path) - original_fps = vr.get_avg_fps() - total_frames = len(vr) - - interval = max(1, round(original_fps / target_fps)) - - required_span = (n_frames - 1) * interval - - start_frame = max(0, total_frames - required_span - - 1) if not reverse else 0 - - sampled_indices = [] - for i in range(n_frames): - indice = start_frame + i * interval - if indice >= total_frames: - break - else: - sampled_indices.append(indice) - - return vr.get_batch(sampled_indices).asnumpy() - - def load_pose_cond(self, pose_video, num_repeat, infer_frames, size): - HEIGHT, WIDTH = size - if not pose_video is None: - pose_seq = self.read_last_n_frames( - pose_video, - n_frames=infer_frames * num_repeat, - target_fps=self.fps, - reverse=True) - - resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) - crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) - tensor_trans = transforms.ToTensor() - - cond_tensor = torch.from_numpy(pose_seq) - cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0 - cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute( - 1, 0, 2, 3).unsqueeze(0) - - padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2] - cond_tensor = torch.cat([ - cond_tensor, - - torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH]) - ], - dim=2) - - cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2) - else: - cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])] - - COND = [] - for r in range(len(cond_tensors)): - cond = cond_tensors[r] - cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], - dim=2) - cond_lat = torch.stack( - self.vae.encode( - cond.to(dtype=self.param_dtype, - device=self.device)))[:, :, - 1:].cpu() # for mem save - COND.append(cond_lat) - return COND + checkpoint_path=os.path.join( + checkpoint_dir, config.audio_checkpoint), + device=self.device, + ) - def get_gen_size(self, size, max_area, ref_image_path, pre_video_path): - if not size is None: - HEIGHT, WIDTH = size - else: - if 
pre_video_path: - ref_image = self.read_last_n_frames( - pre_video_path, n_frames=1)[0] - else: - ref_image = np.array(Image.open(ref_image_path).convert('RGB')) - HEIGHT, WIDTH = ref_image.shape[:2] - HEIGHT, WIDTH = self.get_size_less_than_area( - HEIGHT, WIDTH, target_area=max_area) - return (HEIGHT, WIDTH) + def _init_vae(self, config, checkpoint_dir): + self.vae = Wan2_2_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) def generate( self, input_prompt, - ref_image_path, + img, audio_path, - enable_tts, - tts_prompt_audio, - tts_prompt_text, - tts_text, - num_repeat=1, - pose_video=None, max_area=720 * 1280, - infer_frames=80, + frame_num=81, shift=5.0, sample_solver='unipc', sampling_steps=40, @@ -409,266 +87,128 @@ def generate( n_prompt="", seed=-1, offload_model=True, - init_first_frame=False, ): - r""" - Generates video frames from input image and text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation. - ref_image_path ('str'): - Input image path - audio_path ('str'): - Audio for video driven - num_repeat ('int'): - Number of clips to generate; will be automatically adjusted based on the audio length - pose_video ('str'): - If provided, uses a sequence of poses to drive the generated video - max_area (`int`, *optional*, defaults to 720*1280): - Maximum pixel area for latent space calculation. Controls video resolution scaling - infer_frames (`int`, *optional*, defaults to 80): - How many frames to generate per clips. The number should be 4n - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. 
Higher values improve quality but slow generation - guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity. - If tuple, the first guide_scale will be used for low noise model and - the second guide_scale will be used for high noise model. - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - init_first_frame (`bool`, *optional*, defaults to False): - Whether to use the reference image as the first frame (i.e., standard image-to-video generation) - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from max_area) - - W: Frame width from max_area) - """ - # preprocess - size = self.get_gen_size( - size=None, - max_area=max_area, - ref_image_path=ref_image_path, - pre_video_path=None) - HEIGHT, WIDTH = size - channel = 3 + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + + F = frame_num + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // + self.patch_size[1] * self.patch_size[1]) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // + self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + + max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( + self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size - resize_opreat = transforms.Resize(min(HEIGHT, WIDTH)) - crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH)) - tensor_trans = 
transforms.ToTensor() - - ref_image = None - motion_latents = None - - if ref_image is None: - ref_image = np.array(Image.open(ref_image_path).convert('RGB')) - if motion_latents is None: - motion_latents = torch.zeros( - [1, channel, self.motion_frames, HEIGHT, WIDTH], - dtype=self.param_dtype, - device=self.device) - - # extract audio emb - if enable_tts is True: - audio_path = self.tts(tts_prompt_audio, tts_prompt_text, tts_text) - audio_emb, nr = self.encode_audio(audio_path, infer_frames=infer_frames) - if num_repeat is None or num_repeat > nr: - num_repeat = nr - - lat_motion_frames = (self.motion_frames + 3) // 4 - model_pic = crop_opreat(resize_opreat(Image.fromarray(ref_image))) + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn( + 16, + (F - 1) // self.vae_stride[0] + 1, + lat_h, + lat_w, + dtype=torch.float32, + generator=seed_g, + device=self.device) - ref_pixel_values = tensor_trans(model_pic) - ref_pixel_values = ref_pixel_values.unsqueeze(1).unsqueeze( - 0) * 2 - 1.0 # b c 1 h w - ref_pixel_values = ref_pixel_values.to( - dtype=self.vae.dtype, device=self.vae.device) - ref_latents = torch.stack(self.vae.encode(ref_pixel_values)) + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], + dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] - # encode the motion latents - videos_last_frames = motion_latents.detach() - drop_first_motion = self.drop_first_motion - if init_first_frame: - drop_first_motion = False - motion_latents[:, :, -6:] = ref_pixel_values - motion_latents = torch.stack(self.vae.encode(motion_latents)) + if n_prompt == "": + n_prompt = self.sample_neg_prompt - # get pose cond input if need - COND = self.load_pose_cond( - pose_video=pose_video, - num_repeat=num_repeat, - 
infer_frames=infer_frames, - size=size) + context, context_null = self._encode_text(input_prompt, n_prompt, offload_model) - seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + y = self.vae.encode([ + torch.concat([ + torch.nn.functional.interpolate( + img[None].cpu(), size=(h, w), mode='bicubic').transpose( + 0, 1), + torch.zeros(3, F - 1, h, w) + ], + dim=1).to(self.device) + ])[0] + y = torch.concat([msk, y]) - if n_prompt == "": - n_prompt = self.sample_neg_prompt + # Encode audio + audio_emb = self.audio_encoder(audio_path, self.device) - # preprocess - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() - else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] + no_sync = self._get_no_sync(self.model) - out = [] - # evaluation mode with ( torch.amp.autocast('cuda', dtype=self.param_dtype), torch.no_grad(), + no_sync(), ): - for r in range(num_repeat): - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed + r) + sample_scheduler, timesteps = self._create_scheduler( + sample_solver, sampling_steps, shift) - lat_target_frames = (infer_frames + 3 + self.motion_frames - ) // 4 - lat_motion_frames - target_shape = [lat_target_frames, HEIGHT // 8, WIDTH // 8] - noise = [ - torch.randn( - 16, - target_shape[0], - target_shape[1], - target_shape[2], - dtype=self.param_dtype, - device=self.device, - generator=seed_g) - ] - max_seq_len = np.prod(target_shape) // 4 + latent = noise - if sample_solver == 'unipc': - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - 
sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) - timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': - sample_scheduler = FlowDPMSolverMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) - else: - raise NotImplementedError("Unsupported solver.") + arg_c = { + 'context': [context[0]], + 'seq_len': max_seq_len, + 'y': [y], + 'audio_emb': [audio_emb], + } - latents = deepcopy(noise) - with torch.no_grad(): - left_idx = r * infer_frames - right_idx = r * infer_frames + infer_frames - cond_latents = COND[r] if pose_video else COND[0] * 0 - cond_latents = cond_latents.to( - dtype=self.param_dtype, device=self.device) - audio_input = audio_emb[..., left_idx:right_idx] - input_motion_latents = motion_latents.clone() - - arg_c = { - 'context': context[0:1], - 'seq_len': max_seq_len, - 'cond_states': cond_latents, - "motion_latents": input_motion_latents, - 'ref_latents': ref_latents, - "audio_input": audio_input, - "motion_frames": [self.motion_frames, lat_motion_frames], - "drop_motion_frames": drop_first_motion and r == 0, - } - if guide_scale > 1: - arg_null = { - 'context': context_null[0:1], - 'seq_len': max_seq_len, - 'cond_states': cond_latents, - "motion_latents": input_motion_latents, - 'ref_latents': ref_latents, - "audio_input": 0.0 * audio_input, - "motion_frames": [ - self.motion_frames, lat_motion_frames - ], - "drop_motion_frames": drop_first_motion and r == 0, - } - if offload_model or self.init_on_cpu: - self.noise_model.to(self.device) - torch.cuda.empty_cache() + arg_null = { + 'context': context_null, + 'seq_len': max_seq_len, + 'y': [y], + 'audio_emb': [audio_emb], + } - for i, t in enumerate(tqdm(timesteps)): - latent_model_input = latents[0:1] - timestep = [t] + if 
offload_model: + torch.cuda.empty_cache() - timestep = torch.stack(timestep).to(self.device) + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] + timestep = torch.stack([t]).to(self.device) - noise_pred_cond = self.noise_model( - latent_model_input, t=timestep, **arg_c) + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0] + if offload_model: + torch.cuda.empty_cache() + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0] + if offload_model: + torch.cuda.empty_cache() + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) - if guide_scale > 1: - noise_pred_uncond = self.noise_model( - latent_model_input, t=timestep, **arg_null) - noise_pred = [ - u + guide_scale * (c - u) - for c, u in zip(noise_pred_cond, noise_pred_uncond) - ] - else: - noise_pred = noise_pred_cond + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latent.unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latent = temp_x0.squeeze(0) - temp_x0 = sample_scheduler.step( - noise_pred[0].unsqueeze(0), - t, - latents[0].unsqueeze(0), - return_dict=False, - generator=seed_g)[0] - latents[0] = temp_x0.squeeze(0) + x0 = [latent] + del latent_model_input, timestep - if offload_model: - self.noise_model.cpu() - torch.cuda.synchronize() - torch.cuda.empty_cache() - latents = torch.stack(latents) - if not (drop_first_motion and r == 0): - decode_latents = torch.cat([motion_latents, latents], dim=2) - else: - decode_latents = torch.cat([ref_latents, latents], dim=2) - image = torch.stack(self.vae.decode(decode_latents)) - image = image[:, :, -(infer_frames):] - if (drop_first_motion and r == 0): - image = image[:, :, 3:] + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() - overlap_frames_num = min(self.motion_frames, image.shape[2]) - videos_last_frames = torch.cat([ - videos_last_frames[:, :, overlap_frames_num:], - image[:, :, 
-overlap_frames_num:] - ], - dim=2) - videos_last_frames = videos_last_frames.to( - dtype=motion_latents.dtype, device=motion_latents.device) - motion_latents = torch.stack( - self.vae.encode(videos_last_frames)) - out.append(image.cpu()) + if self.rank == 0: + videos = self.vae.decode(x0) - videos = torch.cat(out, dim=2) - del noise, latents + del noise, latent, x0 del sample_scheduler if offload_model: gc.collect() @@ -677,31 +217,3 @@ def generate( dist.barrier() return videos[0] if self.rank == 0 else None - - def tts(self, tts_prompt_audio, tts_prompt_text, tts_text): - if not hasattr(self, 'cosyvoice'): - self.load_tts() - speech_list = [] - from cosyvoice.utils.file_utils import load_wav - import torchaudio - prompt_speech_16k = load_wav(tts_prompt_audio, 16000) - if tts_prompt_text is not None: - for i in self.cosyvoice.inference_zero_shot(tts_text, tts_prompt_text, prompt_speech_16k): - speech_list.append(i['tts_speech']) - else: - for i in self.cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k): - speech_list.append(i['tts_speech']) - torchaudio.save('tts.wav', torch.concat(speech_list, dim=1), self.cosyvoice.sample_rate) - return 'tts.wav' - - def load_tts(self): - if not os.path.exists('CosyVoice'): - from wan.utils.utils import download_cosyvoice_repo - download_cosyvoice_repo('CosyVoice') - if not os.path.exists('CosyVoice2-0.5B'): - from wan.utils.utils import download_cosyvoice_model - download_cosyvoice_model('CosyVoice2-0.5B', 'CosyVoice2-0.5B') - sys.path.append('CosyVoice') - sys.path.append('CosyVoice/third_party/Matcha-TTS') - from cosyvoice.cli.cosyvoice import CosyVoice2 - self.cosyvoice = CosyVoice2('CosyVoice2-0.5B') \ No newline at end of file diff --git a/wan/text2video.py b/wan/text2video.py index 7c79c667..578bd6a0 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -5,30 +5,22 @@ import os import random import sys -import types -from contextlib import contextmanager from functools import partial import torch -import 
torch.cuda.amp as amp import torch.distributed as dist from tqdm import tqdm +from .pipeline_base import WanPipelineBase from .distributed.fsdp import shard_model -from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward -from .distributed.util import get_world_size from .modules.model import WanModel -from .modules.t5 import T5EncoderModel from .modules.vae2_1 import Wan2_1_VAE -from .utils.fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -class WanT2V: +class WanT2V(WanPipelineBase): + """Text-to-Video generation pipeline (A14B dual-expert).""" + + _use_dual_expert = True def __init__( self, @@ -43,60 +35,21 @@ def __init__( init_on_cpu=True, convert_model_dtype=False, ): - r""" - Initializes the Wan text-to-video generation model components. - - Args: - config (EasyDict): - Object containing model parameters initialized from config.py - checkpoint_dir (`str`): - Path to directory containing model checkpoints - device_id (`int`, *optional*, defaults to 0): - Id of target GPU device - rank (`int`, *optional*, defaults to 0): - Process rank for distributed training - t5_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for T5 model - dit_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for DiT model - use_sp (`bool`, *optional*, defaults to False): - Enable distribution strategy of sequence parallel. - t5_cpu (`bool`, *optional*, defaults to False): - Whether to place T5 model on CPU. Only works without t5_fsdp. - init_on_cpu (`bool`, *optional*, defaults to True): - Enable initializing Transformer Model on CPU. Only works without FSDP or USP. - convert_model_dtype (`bool`, *optional*, defaults to False): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. 
- """ - self.device = torch.device(f"cuda:{device_id}") - self.config = config - self.rank = rank - self.t5_cpu = t5_cpu - self.init_on_cpu = init_on_cpu - - self.num_train_timesteps = config.num_train_timesteps - self.boundary = config.boundary - self.param_dtype = config.param_dtype - - if t5_fsdp or dit_fsdp or use_sp: - self.init_on_cpu = False + super().__init__( + config=config, + checkpoint_dir=checkpoint_dir, + device_id=device_id, + rank=rank, + t5_fsdp=t5_fsdp, + dit_fsdp=dit_fsdp, + use_sp=use_sp, + t5_cpu=t5_cpu, + init_on_cpu=init_on_cpu, + convert_model_dtype=convert_model_dtype, + ) + # Dual-expert DiT models shard_fn = partial(shard_model, device_id=device_id) - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - self.vae = Wan2_1_VAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - logging.info(f"Creating WanModel from {checkpoint_dir}") self.low_noise_model = WanModel.from_pretrained( checkpoint_dir, subfolder=config.low_noise_checkpoint) @@ -115,90 +68,11 @@ def __init__( dit_fsdp=dit_fsdp, shard_fn=shard_fn, convert_model_dtype=convert_model_dtype) - if use_sp: - self.sp_size = get_world_size() - else: - self.sp_size = 1 - - self.sample_neg_prompt = config.sample_neg_prompt - - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): - """ - Configures a model object. This includes setting evaluation modes, - applying distributed parallel strategy, and handling device placement. - - Args: - model (torch.nn.Module): - The model instance to configure. - use_sp (`bool`): - Enable distribution strategy of sequence parallel. 
- dit_fsdp (`bool`): - Enable FSDP sharding for DiT model. - shard_fn (callable): - The function to apply FSDP sharding. - convert_model_dtype (`bool`): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. - Returns: - torch.nn.Module: - The configured model. - """ - model.eval().requires_grad_(False) - - if use_sp: - for block in model.blocks: - block.self_attn.forward = types.MethodType( - sp_attn_forward, block.self_attn) - model.forward = types.MethodType(sp_dit_forward, model) - - if dist.is_initialized(): - dist.barrier() - - if dit_fsdp: - model = shard_fn(model) - else: - if convert_model_dtype: - model.to(self.param_dtype) - if not self.init_on_cpu: - model.to(self.device) - - return model - - def _prepare_model_for_timestep(self, t, boundary, offload_model): - r""" - Prepares and returns the required model for the current timestep. - - Args: - t (torch.Tensor): - current timestep. - boundary (`int`): - The timestep threshold. If `t` is at or above this value, - the `high_noise_model` is considered as the required model. - offload_model (`bool`): - A flag intended to control the offloading behavior. - - Returns: - torch.nn.Module: - The active model on the target device for the current timestep. 
- """ - if t.item() >= boundary: - required_model_name = 'high_noise_model' - offload_model_name = 'low_noise_model' - else: - required_model_name = 'low_noise_model' - offload_model_name = 'high_noise_model' - if offload_model or self.init_on_cpu: - if next(getattr( - self, - offload_model_name).parameters()).device.type == 'cuda': - getattr(self, offload_model_name).to('cpu') - if next(getattr( - self, - required_model_name).parameters()).device.type == 'cpu': - getattr(self, required_model_name).to(self.device) - return getattr(self, required_model_name) + def _init_vae(self, config, checkpoint_dir): + self.vae = Wan2_1_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) def generate(self, input_prompt, @@ -211,52 +85,12 @@ def generate(self, n_prompt="", seed=-1, offload_model=True): - r""" - Generates video frames from text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation - size (`tuple[int]`, *optional*, defaults to (1280,720)): - Controls video resolution, (width,height). - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 50): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity. - If tuple, the first guide_scale will be used for low noise model and - the second guide_scale will be used for high noise model. - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. 
If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed. - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from size) - - W: Frame width from size) - """ - # preprocess - guide_scale = (guide_scale, guide_scale) if isinstance( - guide_scale, float) else guide_scale + guide_scale = self._normalize_guide_scale(guide_scale) F = frame_num target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, size[1] // self.vae_stride[1], size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1] / self.sp_size) * self.sp_size + seq_len = self._compute_seq_len(target_shape, self.patch_size, self.sp_size) if n_prompt == "": n_prompt = self.sample_neg_prompt @@ -264,17 +98,7 @@ def generate(self, seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() - else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] + context, context_null = self._encode_text(input_prompt, n_prompt, offload_model) noise = [ torch.randn( @@ -287,67 +111,32 @@ def generate(self, generator=seed_g) ] - @contextmanager - def noop_no_sync(): - yield - - no_sync_low_noise = getattr(self.low_noise_model, 'no_sync', - noop_no_sync) - 
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync', - noop_no_sync) + no_sync_low = self._get_no_sync(self.low_noise_model) + no_sync_high = self._get_no_sync(self.high_noise_model) - # evaluation mode with ( torch.amp.autocast('cuda', dtype=self.param_dtype), torch.no_grad(), - no_sync_low_noise(), - no_sync_high_noise(), + no_sync_low(), + no_sync_high(), ): boundary = self.boundary * self.num_train_timesteps + sample_scheduler, timesteps = self._create_scheduler( + sample_solver, sampling_steps, shift) - if sample_solver == 'unipc': - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) - timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': - sample_scheduler = FlowDPMSolverMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) - else: - raise NotImplementedError("Unsupported solver.") - - # sample videos latents = noise - arg_c = {'context': context, 'seq_len': seq_len} arg_null = {'context': context_null, 'seq_len': seq_len} for _, t in enumerate(tqdm(timesteps)): latent_model_input = latents - timestep = [t] - - timestep = torch.stack(timestep) - - model = self._prepare_model_for_timestep( - t, boundary, offload_model) - sample_guide_scale = guide_scale[1] if t.item( - ) >= boundary else guide_scale[0] + timestep = torch.stack([t]) - noise_pred_cond = model( - latent_model_input, t=timestep, **arg_c)[0] - noise_pred_uncond = model( - latent_model_input, t=timestep, **arg_null)[0] + model = self._prepare_model_for_timestep(t, boundary, offload_model) + sample_guide_scale = guide_scale[1] if t.item() >= boundary else guide_scale[0] + 
noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0] noise_pred = noise_pred_uncond + sample_guide_scale * ( noise_pred_cond - noise_pred_uncond) diff --git a/wan/textimage2video.py b/wan/textimage2video.py index 67e9fd29..eab62dd6 100644 --- a/wan/textimage2video.py +++ b/wan/textimage2video.py @@ -5,33 +5,22 @@ import os import random import sys -import types -from contextlib import contextmanager from functools import partial +import numpy as np import torch -import torch.cuda.amp as amp import torch.distributed as dist import torchvision.transforms.functional as TF -from PIL import Image from tqdm import tqdm +from .pipeline_base import WanPipelineBase from .distributed.fsdp import shard_model -from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward -from .distributed.util import get_world_size from .modules.model import WanModel -from .modules.t5 import T5EncoderModel from .modules.vae2_2 import Wan2_2_VAE -from .utils.fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from .utils.utils import best_output_size, masks_like -class WanTI2V: +class WanTI2V(WanPipelineBase): + """Text+Image-to-Video generation pipeline (5B single-expert, VAE 2.2).""" def __init__( self, @@ -46,61 +35,24 @@ def __init__( init_on_cpu=True, convert_model_dtype=False, ): - r""" - Initializes the Wan text-to-video generation model components. 
- - Args: - config (EasyDict): - Object containing model parameters initialized from config.py - checkpoint_dir (`str`): - Path to directory containing model checkpoints - device_id (`int`, *optional*, defaults to 0): - Id of target GPU device - rank (`int`, *optional*, defaults to 0): - Process rank for distributed training - t5_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for T5 model - dit_fsdp (`bool`, *optional*, defaults to False): - Enable FSDP sharding for DiT model - use_sp (`bool`, *optional*, defaults to False): - Enable distribution strategy of sequence parallel. - t5_cpu (`bool`, *optional*, defaults to False): - Whether to place T5 model on CPU. Only works without t5_fsdp. - init_on_cpu (`bool`, *optional*, defaults to True): - Enable initializing Transformer Model on CPU. Only works without FSDP or USP. - convert_model_dtype (`bool`, *optional*, defaults to False): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. 
- """ - self.device = torch.device(f"cuda:{device_id}") - self.config = config - self.rank = rank - self.t5_cpu = t5_cpu - self.init_on_cpu = init_on_cpu - - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - if t5_fsdp or dit_fsdp or use_sp: - self.init_on_cpu = False + super().__init__( + config=config, + checkpoint_dir=checkpoint_dir, + device_id=device_id, + rank=rank, + t5_fsdp=t5_fsdp, + dit_fsdp=dit_fsdp, + use_sp=use_sp, + t5_cpu=t5_cpu, + init_on_cpu=init_on_cpu, + convert_model_dtype=convert_model_dtype, + ) + # Single-expert DiT model shard_fn = partial(shard_model, device_id=device_id) - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=shard_fn if t5_fsdp else None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - self.vae = Wan2_2_VAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - logging.info(f"Creating WanModel from {checkpoint_dir}") - self.model = WanModel.from_pretrained(checkpoint_dir) + self.model = WanModel.from_pretrained( + checkpoint_dir, subfolder=config.checkpoint) self.model = self._configure_model( model=self.model, use_sp=use_sp, @@ -108,474 +60,108 @@ def __init__( shard_fn=shard_fn, convert_model_dtype=convert_model_dtype) - if use_sp: - self.sp_size = get_world_size() - else: - self.sp_size = 1 - - self.sample_neg_prompt = config.sample_neg_prompt - - def _configure_model(self, model, use_sp, dit_fsdp, shard_fn, - convert_model_dtype): - """ - Configures a model object. This includes setting evaluation modes, - applying distributed parallel strategy, and handling device placement. - - Args: - model (torch.nn.Module): - The model instance to configure. 
- use_sp (`bool`): - Enable distribution strategy of sequence parallel. - dit_fsdp (`bool`): - Enable FSDP sharding for DiT model. - shard_fn (callable): - The function to apply FSDP sharding. - convert_model_dtype (`bool`): - Convert DiT model parameters dtype to 'config.param_dtype'. - Only works without FSDP. - - Returns: - torch.nn.Module: - The configured model. - """ - model.eval().requires_grad_(False) - - if use_sp: - for block in model.blocks: - block.self_attn.forward = types.MethodType( - sp_attn_forward, block.self_attn) - model.forward = types.MethodType(sp_dit_forward, model) - - if dist.is_initialized(): - dist.barrier() - - if dit_fsdp: - model = shard_fn(model) - else: - if convert_model_dtype: - model.to(self.param_dtype) - if not self.init_on_cpu: - model.to(self.device) - - return model + def _init_vae(self, config, checkpoint_dir): + self.vae = Wan2_2_VAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) def generate(self, input_prompt, - img=None, - size=(1280, 704), - max_area=704 * 1280, + img, + max_area=720 * 1280, frame_num=81, shift=5.0, sample_solver='unipc', - sampling_steps=50, + sampling_steps=40, guide_scale=5.0, n_prompt="", seed=-1, offload_model=True): - r""" - Generates video frames from text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation - img (PIL.Image.Image): - Input image tensor. Shape: [3, H, W] - size (`tuple[int]`, *optional*, defaults to (1280,704)): - Controls video resolution, (width,height). - max_area (`int`, *optional*, defaults to 704*1280): - Maximum pixel area for latent space calculation. Controls video resolution scaling - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. 
Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 50): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float`, *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity. - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed. - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from size) - - W: Frame width from size) - """ - # i2v - if img is not None: - return self.i2v( - input_prompt=input_prompt, - img=img, - max_area=max_area, - frame_num=frame_num, - shift=shift, - sample_solver=sample_solver, - sampling_steps=sampling_steps, - guide_scale=guide_scale, - n_prompt=n_prompt, - seed=seed, - offload_model=offload_model) - # t2v - return self.t2v( - input_prompt=input_prompt, - size=size, - frame_num=frame_num, - shift=shift, - sample_solver=sample_solver, - sampling_steps=sampling_steps, - guide_scale=guide_scale, - n_prompt=n_prompt, - seed=seed, - offload_model=offload_model) + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) - def t2v(self, - input_prompt, - size=(1280, 704), - frame_num=121, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): - r""" - Generates video frames from text prompt using diffusion process. 
-
-        Args:
-            input_prompt (`str`):
-                Text prompt for content generation
-            size (`tuple[int]`, *optional*, defaults to (1280,704)):
-                Controls video resolution, (width,height).
-            frame_num (`int`, *optional*, defaults to 121):
-                How many frames to sample from a video. The number should be 4n+1
-            shift (`float`, *optional*, defaults to 5.0):
-                Noise schedule shift parameter. Affects temporal dynamics
-            sample_solver (`str`, *optional*, defaults to 'unipc'):
-                Solver used to sample the video.
-            sampling_steps (`int`, *optional*, defaults to 50):
-                Number of diffusion sampling steps. Higher values improve quality but slow generation
-            guide_scale (`float`, *optional*, defaults 5.0):
-                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
-            n_prompt (`str`, *optional*, defaults to ""):
-                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
-            seed (`int`, *optional*, defaults to -1):
-                Random seed for noise generation. If -1, use random seed.
-            offload_model (`bool`, *optional*, defaults to True):
-                If True, offloads models to CPU during generation to save VRAM
-
-        Returns:
-            torch.Tensor:
-                Generated video frames tensor. Dimensions: (C, N H, W) where:
-                - C: Color channels (3 for RGB)
-                - N: Number of frames (81)
-                - H: Frame height (from size)
-                - W: Frame width from size)
-        """
-        # preprocess
         F = frame_num
-        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
-                        size[1] // self.vae_stride[1],
-                        size[0] // self.vae_stride[2])
-
-        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
-                            (self.patch_size[1] * self.patch_size[2]) *
-                            target_shape[1] / self.sp_size) * self.sp_size
-
-        if n_prompt == "":
-            n_prompt = self.sample_neg_prompt
-        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
-        seed_g = torch.Generator(device=self.device)
-        seed_g.manual_seed(seed)
-
-        if not self.t5_cpu:
-            self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
-            if offload_model:
-                self.text_encoder.model.cpu()
-        else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
-            context = [t.to(self.device) for t in context]
-            context_null = [t.to(self.device) for t in context_null]
-
-        noise = [
-            torch.randn(
-                target_shape[0],
-                target_shape[1],
-                target_shape[2],
-                target_shape[3],
-                dtype=torch.float32,
-                device=self.device,
-                generator=seed_g)
-        ]
-
-        @contextmanager
-        def noop_no_sync():
-            yield
-
-        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
-
-        # evaluation mode
-        with (
-                torch.amp.autocast('cuda', dtype=self.param_dtype),
-                torch.no_grad(),
-                no_sync(),
-        ):
-
-            if sample_solver == 'unipc':
-                sample_scheduler = FlowUniPCMultistepScheduler(
-                    num_train_timesteps=self.num_train_timesteps,
-                    shift=1,
-                    use_dynamic_shifting=False)
-                sample_scheduler.set_timesteps(
-                    sampling_steps, device=self.device, shift=shift)
-                timesteps = sample_scheduler.timesteps
-            elif sample_solver == 'dpm++':
-                sample_scheduler = FlowDPMSolverMultistepScheduler(
-                    num_train_timesteps=self.num_train_timesteps,
-                    shift=1,
-                    use_dynamic_shifting=False)
-                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
-                timesteps, _ = retrieve_timesteps(
-                    sample_scheduler,
-                    device=self.device,
-                    sigmas=sampling_sigmas)
-            else:
-                raise NotImplementedError("Unsupported solver.")
-
-            # sample videos
-            latents = noise
-            mask1, mask2 = masks_like(noise, zero=False)
-
-            arg_c = {'context': context, 'seq_len': seq_len}
-            arg_null = {'context': context_null, 'seq_len': seq_len}
-
-            if offload_model or self.init_on_cpu:
-                self.model.to(self.device)
-                torch.cuda.empty_cache()
-
-            for _, t in enumerate(tqdm(timesteps)):
-                latent_model_input = latents
-                timestep = [t]
-
-                timestep = torch.stack(timestep)
-
-                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
-                temp_ts = torch.cat([
-                    temp_ts,
-                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
-                ])
-                timestep = temp_ts.unsqueeze(0)
-
-                noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, **arg_c)[0]
-                noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, **arg_null)[0]
-
-                noise_pred = noise_pred_uncond + guide_scale * (
-                    noise_pred_cond - noise_pred_uncond)
-
-                temp_x0 = sample_scheduler.step(
-                    noise_pred.unsqueeze(0),
-                    t,
-                    latents[0].unsqueeze(0),
-                    return_dict=False,
-                    generator=seed_g)[0]
-                latents = [temp_x0.squeeze(0)]
-            x0 = latents
-            if offload_model:
-                self.model.cpu()
-                torch.cuda.synchronize()
-                torch.cuda.empty_cache()
-            if self.rank == 0:
-                videos = self.vae.decode(x0)
-
-        del noise, latents
-        del sample_scheduler
-        if offload_model:
-            gc.collect()
-            torch.cuda.synchronize()
-        if dist.is_initialized():
-            dist.barrier()
-
-        return videos[0] if self.rank == 0 else None
-
-    def i2v(self,
-            input_prompt,
-            img,
-            max_area=704 * 1280,
-            frame_num=121,
-            shift=5.0,
-            sample_solver='unipc',
-            sampling_steps=40,
-            guide_scale=5.0,
-            n_prompt="",
-            seed=-1,
-            offload_model=True):
-        r"""
-        Generates video frames from input image and text prompt using diffusion process.
-
-        Args:
-            input_prompt (`str`):
-                Text prompt for content generation.
-            img (PIL.Image.Image):
-                Input image tensor. Shape: [3, H, W]
-            max_area (`int`, *optional*, defaults to 704*1280):
-                Maximum pixel area for latent space calculation. Controls video resolution scaling
-            frame_num (`int`, *optional*, defaults to 121):
-                How many frames to sample from a video. The number should be 4n+1
-            shift (`float`, *optional*, defaults to 5.0):
-                Noise schedule shift parameter. Affects temporal dynamics
-                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
-            sample_solver (`str`, *optional*, defaults to 'unipc'):
-                Solver used to sample the video.
-            sampling_steps (`int`, *optional*, defaults to 40):
-                Number of diffusion sampling steps. Higher values improve quality but slow generation
-            guide_scale (`float`, *optional*, defaults 5.0):
-                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
-            n_prompt (`str`, *optional*, defaults to ""):
-                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
-            seed (`int`, *optional*, defaults to -1):
-                Random seed for noise generation. If -1, use random seed
-            offload_model (`bool`, *optional*, defaults to True):
-                If True, offloads models to CPU during generation to save VRAM
-
-        Returns:
-            torch.Tensor:
-                Generated video frames tensor. Dimensions: (C, N H, W) where:
-                - C: Color channels (3 for RGB)
-                - N: Number of frames (121)
-                - H: Frame height (from max_area)
-                - W: Frame width (from max_area)
-        """
-        # preprocess
-        ih, iw = img.height, img.width
-        dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[
-            2] * self.vae_stride[2]
-        ow, oh = best_output_size(iw, ih, dw, dh, max_area)
-
-        scale = max(ow / iw, oh / ih)
-        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
-
-        # center-crop
-        x1 = (img.width - ow) // 2
-        y1 = (img.height - oh) // 2
-        img = img.crop((x1, y1, x1 + ow, y1 + oh))
-        assert img.width == ow and img.height == oh
-
-        # to tensor
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
-
-        F = frame_num
-        seq_len = ((F - 1) // self.vae_stride[0] + 1) * (
-            oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (
-                self.patch_size[1] * self.patch_size[2])
-        seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size
+        h, w = img.shape[1:]
+        aspect_ratio = h / w
+        lat_h = round(
+            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
+            self.patch_size[1] * self.patch_size[1])
+        lat_w = round(
+            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
+            self.patch_size[2] * self.patch_size[2])
+        h = lat_h * self.vae_stride[1]
+        w = lat_w * self.vae_stride[2]
+
+        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
+            self.patch_size[1] * self.patch_size[2])
+        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)
         noise = torch.randn(
-            self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
-            oh // self.vae_stride[1],
-            ow // self.vae_stride[2],
+            16,
+            (F - 1) // self.vae_stride[0] + 1,
+            lat_h,
+            lat_w,
             dtype=torch.float32,
             generator=seed_g,
             device=self.device)
+        msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
+        msk[:, 1:] = 0
+        msk = torch.concat([
+            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
+        ],
+                           dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1, 2)[0]
+
         if n_prompt == "":
             n_prompt = self.sample_neg_prompt
-        # preprocess
-        if not self.t5_cpu:
-            self.text_encoder.model.to(self.device)
-            context = self.text_encoder([input_prompt], self.device)
-            context_null = self.text_encoder([n_prompt], self.device)
-            if offload_model:
-                self.text_encoder.model.cpu()
-        else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
-            context = [t.to(self.device) for t in context]
-            context_null = [t.to(self.device) for t in context_null]
-
-        z = self.vae.encode([img])
+        context, context_null = self._encode_text(input_prompt, n_prompt, offload_model)
-        @contextmanager
-        def noop_no_sync():
-            yield
+        y = self.vae.encode([
+            torch.concat([
+                torch.nn.functional.interpolate(
+                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(
+                        0, 1),
+                torch.zeros(3, F - 1, h, w)
+            ],
+                         dim=1).to(self.device)
+        ])[0]
+        y = torch.concat([msk, y])
-        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+        no_sync = self._get_no_sync(self.model)
-        # evaluation mode
         with (
                 torch.amp.autocast('cuda', dtype=self.param_dtype),
                 torch.no_grad(),
                 no_sync(),
         ):
+            sample_scheduler, timesteps = self._create_scheduler(
+                sample_solver, sampling_steps, shift)
-            if sample_solver == 'unipc':
-                sample_scheduler = FlowUniPCMultistepScheduler(
-                    num_train_timesteps=self.num_train_timesteps,
-                    shift=1,
-                    use_dynamic_shifting=False)
-                sample_scheduler.set_timesteps(
-                    sampling_steps, device=self.device, shift=shift)
-                timesteps = sample_scheduler.timesteps
-            elif sample_solver == 'dpm++':
-                sample_scheduler = FlowDPMSolverMultistepScheduler(
-                    num_train_timesteps=self.num_train_timesteps,
-                    shift=1,
-                    use_dynamic_shifting=False)
-                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
-                timesteps, _ = retrieve_timesteps(
-                    sample_scheduler,
-                    device=self.device,
-                    sigmas=sampling_sigmas)
-            else:
-                raise NotImplementedError("Unsupported solver.")
-
-            # sample videos
             latent = noise
-            mask1, mask2 = masks_like([noise], zero=True)
-            latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
             arg_c = {
                 'context': [context[0]],
-                'seq_len': seq_len,
+                'seq_len': max_seq_len,
+                'y': [y],
             }
             arg_null = {
                 'context': context_null,
-                'seq_len': seq_len,
+                'seq_len': max_seq_len,
+                'y': [y],
             }
-            if offload_model or self.init_on_cpu:
-                self.model.to(self.device)
+            if offload_model:
                 torch.cuda.empty_cache()
             for _, t in enumerate(tqdm(timesteps)):
                 latent_model_input = [latent.to(self.device)]
-                timestep = [t]
-
-                timestep = torch.stack(timestep).to(self.device)
-
-                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
-                temp_ts = torch.cat([
-                    temp_ts,
-                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
-                ])
-                timestep = temp_ts.unsqueeze(0)
+                timestep = torch.stack([t]).to(self.device)
                 noise_pred_cond = self.model(
                     latent_model_input, t=timestep, **arg_c)[0]
@@ -595,14 +181,12 @@ def noop_no_sync():
                     return_dict=False,
                     generator=seed_g)[0]
                 latent = temp_x0.squeeze(0)
-                latent = (1. - mask2[0]) * z[0] + mask2[0] * latent

                 x0 = [latent]
                 del latent_model_input, timestep

             if offload_model:
                 self.model.cpu()
-                torch.cuda.synchronize()
                 torch.cuda.empty_cache()

             if self.rank == 0: