
Commit eb0ffd7

rel-notes + version update (#166)
* rel-notes + version update
* update version
* update changelog wording
1 parent: 50e041a

3 files changed: 15 additions & 2 deletions


CHANGELOG.md

Lines changed: 12 additions & 0 deletions
@@ -1,6 +1,18 @@
 ## Latest Changes
 
+## 0.6.1 (2025-09-04)
+
+### Added
+- [Torch/JAX] Support for variable leading batch dimensions in triangle multiplicative update
+- [Torch/JAX] Triangle attention kernel support for additional input configs: all hidden_dim <= 32 and divisible by 4 for tf32/fp32, and all hidden_dim <= 128 and divisible by 8 for bf16/fp16. In the rare case that the kernel does not support an input config, it falls back to torch instead of erroring out.
+- [Torch/JAX] Tuned config for RTX PRO 6000 GPUs for triangle multiplicative update.
+- [JAX] vmap support for triangle multiplicative update and triangle attention
+- [Torch] Improved error reporting on import failure, with traceback information for the stack trace
+
+### Bug fix
+- [Torch/JAX] Fixed an illegal memory access stemming from int32 indexing for longer sequences in triangle multiplicative update and attention with pair bias.
 ## 0.6.0 (2025-08-11)
+- [JAX] Moved to using nondiff_argnums instead of nondiff_argnames to be compatible with older JAX versions
 
 ### Added
 - [Torch] New feature: Added `cuet.attention_pair_bias` (support for caching the pair bias tensor and further kernel acceleration coming soon; there may be API-related changes in the next release)
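
The input-config rule in the triangle attention entry above is mechanical, so it can be spelled out directly. The helper below is illustrative only; it is not part of the library, which applies this rule and the torch fallback internally:

```python
def kernel_supports(hidden_dim: int, dtype: str) -> bool:
    """Encode the 0.6.1 triangle attention support rule quoted above."""
    if dtype in ("tf32", "fp32"):
        return hidden_dim <= 32 and hidden_dim % 4 == 0
    if dtype in ("bf16", "fp16"):
        return hidden_dim <= 128 and hidden_dim % 8 == 0
    return False

assert kernel_supports(32, "fp32")       # kernel path
assert not kernel_supports(64, "fp32")   # unsupported config: library falls back to torch
assert kernel_supports(128, "bf16")      # kernel path
```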
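
A minimal sketch of the new [JAX] vmap entry, assuming the `cuex` alias for `cuequivariance_jax` used elsewhere in this commit; the exact argument list of `triangle_multiplicative_update` is an assumption here, so treat this as a shape-level illustration rather than the documented API:

```python
import jax
import jax.numpy as jnp
import cuequivariance_jax as cuex  # aliased as `cuex` in the docstrings


def update_one(pair):
    # pair: (seq, seq, hidden); hidden_dim must be a multiple of 32.
    # The keyword below is an assumed signature, not taken from this commit.
    return cuex.triangle_multiplicative_update(pair, direction="outgoing")


x = jnp.zeros((8, 64, 64, 128), dtype=jnp.bfloat16)  # leading batch axis of 8
y = jax.vmap(update_one)(x)  # 0.6.1 adds vmap support for this primitive
```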

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0.6.1rc2
+0.6.1

cuequivariance_torch/cuequivariance_torch/primitives/triangle.py

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,7 @@ def triangle_attention(
     Notes:
         (1) Context is saved for the backward pass. You don't need to save it manually.
         (2) Kernel precision (fp32, bf16, fp16) is based on the input dtypes. For tf32, set it from torch's global scope.
-        (3) **Limitation**: Full FP32 is not supported for the backward pass. Please set `torch.backends.cuda.matmul.allow_tf32=True`.
+        (3) The triangle attention kernel supports all hidden_dim <= 32 and divisible by 4 for tf32/fp32, and all hidden_dim <= 128 and divisible by 8 for bf16/fp16. In the rare case that the kernel does not support an input config, it falls back to torch instead of erroring out.
 
     Example:
         >>> import torch

@@ -195,6 +195,7 @@ def triangle_multiplicative_update(
         (3) **Limitation**: Currently only supports hidden_dim values that are multiples of 32.
         (4) We have moved away from the default round-towards-zero (RZ) implementation to round-to-nearest (RN) for better tf32 accuracy in cuex.triangle_multiplicative_update. In rare circumstances, this may cause minor differences in the observed results.
         (5) When using torch.compile, use `cuequivariance_ops_torch.init_triton_cache()` to initialize the Triton cache before calling the torch-compiled triangle multiplicative update.
+        (6) Although the example demonstrates the most common case of one batch dimension, the API supports a variable number of leading batch dimensions.
 
     Example:
         >>> import torch
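
As a usage note for docstring item (2) above: precision selection is driven by the tensors you pass in, while tf32 is an opt-in global switch. `torch.backends.cuda.matmul.allow_tf32` is the standard PyTorch flag that the replaced docstring line also referenced:

```python
import torch

# bf16/fp16 kernels are selected simply by passing bf16/fp16 tensors.
# For fp32 inputs, opt into the tf32 path from torch's global scope:
torch.backends.cuda.matmul.allow_tf32 = True
```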
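
Docstring item (5) in sketch form; `init_triton_cache` is named in the docstring itself, while the top-level `cuet.triangle_multiplicative_update` export and the `torch.compile` wrapping are assumptions for illustration:

```python
import torch
import cuequivariance_ops_torch
import cuequivariance_torch as cuet

# Per docstring note (5): warm the Triton cache before the compiled call.
cuequivariance_ops_torch.init_triton_cache()
update_compiled = torch.compile(cuet.triangle_multiplicative_update)
```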
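
Docstring item (6) means calls need not be reshaped down to a single batch dimension first. A sketch under assumed argument names; only the shape flexibility is taken from the note:

```python
import torch
import cuequivariance_torch as cuet

# One batch dimension (the documented example case) ...
x1 = torch.randn(2, 64, 64, 128, device="cuda")
# ... and two leading batch dimensions, supported as of 0.6.1.
x2 = torch.randn(2, 3, 64, 64, 128, device="cuda")

# `direction="outgoing"` is an assumed keyword for illustration.
out1 = cuet.triangle_multiplicative_update(x1, direction="outgoing")
out2 = cuet.triangle_multiplicative_update(x2, direction="outgoing")
```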
