
Commit 3a7ecb1

Authored by yiyixuxu, yiyi@huggingface.co, and claude
[agents docs] add float64 gotcha (#13472)
* [docs] add float64 + runtime weight-dtype gotchas to models.md

  Document two dtype pitfalls surfaced by Ernie-Image follow-up #13464: unconditional torch.float64 in RoPE/precompute (breaks MPS/NPU) and reading a child module's weight dtype at runtime (breaks gguf/quant).

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* update claude config to allow .ai folder

* [ci] fetch default branch before .ai/ checkout in claude_review

  When triggered by pull_request_review_comment, actions/checkout lands on the PR head and fetch-depth=1 means origin/main isn't tracked, so the follow-up `git checkout origin/main -- .ai/` fails with exit 128. Fetch the default branch explicitly first.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* combine #10 into #8

* Apply suggestions from code review

  Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b3889ea commit 3a7ecb1

1 file changed

Lines changed: 11 additions & 1 deletion

File tree

.ai/models.md

@@ -73,4 +73,14 @@ Consult the implementations in `src/diffusers/models/transformers/` if you need
 
 7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
 
-8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
+8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16`, and don't cast activations by reading a weight's dtype (`self.linear.weight.dtype`) -- the stored weight dtype isn't the compute dtype under gguf / quantized loading. Always derive the cast target from the input tensor's dtype or `self.dtype`.
+
+9. **`torch.float64` anywhere in the model.** MPS and several NPU backends don't support float64 -- ops will either error out or silently fall back. Reference repos commonly reach for float64 in RoPE frequency bases, timestep embeddings, sinusoidal position encodings, and similar "precision-sensitive" precompute code (`torch.arange(..., dtype=torch.float64)`, `.double()`, `torch.float64` literals). When porting a model, grep for `float64` / `double()` up front and resolve as follows:
+   - **Default: just use `torch.float32`.** For inference it is almost always sufficient -- the precision difference in RoPE angles, timestep embeddings, etc. is immaterial to image/video quality. Flip it and move on.
+   - **Only if float32 visibly degrades output, fall back to the device-gated pattern** we use in the repo:
+     ```python
+     is_mps = hidden_states.device.type == "mps"
+     is_npu = hidden_states.device.type == "npu"
+     freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+     ```
+   See `transformer_flux.py`, `transformer_flux2.py`, `transformer_wan.py`, `unet_2d_condition.py` for reference usages. Never leave an unconditional `torch.float64` in the model.
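For context on item 8, here is a minimal sketch of the recommended pattern. The `ExampleProjection` module is hypothetical and not part of this commit or the diffusers codebase; it only illustrates deriving the cast target from the incoming activations rather than from a stored weight's dtype or a hardcoded literal.

```python
import torch
import torch.nn as nn


class ExampleProjection(nn.Module):
    """Hypothetical block illustrating item 8; not code from this commit."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # Helper tensor created once at init time in the default dtype.
        self.register_buffer("scale", torch.ones(dim), persistent=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Derive the compute dtype from the activations (or self.dtype), not from
        # self.linear.weight.dtype -- under gguf / quantized loading the stored
        # weight dtype is not the compute dtype. Avoid hardcoding torch.float32.
        compute_dtype = hidden_states.dtype
        scale = self.scale.to(device=hidden_states.device, dtype=compute_dtype)
        return self.linear(hidden_states) * scale
```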
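And for item 9, a sketch of how the device-gated `freqs_dtype` might be used in a RoPE-style precompute. The `rope_frequencies` function and its signature are hypothetical; see the transformer files listed above for the actual implementations.

```python
import torch


def rope_frequencies(positions: torch.Tensor, dim: int, theta: float = 10000.0):
    """Hypothetical RoPE precompute illustrating the device-gated dtype pattern."""
    # float64 is unsupported on MPS and several NPU backends, so only request it
    # where the backend actually provides it.
    is_mps = positions.device.type == "mps"
    is_npu = positions.device.type == "npu"
    freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

    # Frequency bases computed in the gated dtype.
    inv_freq = 1.0 / theta ** (
        torch.arange(0, dim, 2, device=positions.device, dtype=freqs_dtype) / dim
    )
    freqs = torch.outer(positions.to(freqs_dtype), inv_freq)

    # Cast back to float32 before the result touches activations, so nothing
    # downstream ever sees a float64 tensor.
    return freqs.cos().float(), freqs.sin().float()
```

Casting the cos/sin tables back to float32 before returning keeps any float64 math confined to the precompute, so downstream activations never carry a dtype the backend can't handle.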

0 commit comments
