[JAX] Support for cuDNN-backed flex attention#2985
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR introduces a cuDNN-frontend-backed
Confidence Score: 4/5The core score_mod dispatch, VJP wiring, C++ deserialization, and graph execution are all correct; the only new finding is a silent performance trap when lambdas or closures are used as score_mod callbacks — no wrong results, no crashes. The new dispatch path has clean separation of concerns, thorough input validation, and correct JAX custom_vjp plumbing. The C++ double-checked-locking cache and thread-local handles are sound. No data-corruption or execution-correctness bugs were found beyond what is already under discussion in existing threads. transformer_engine/jax/cpp_extensions/attention.py — the uncacheable-key path in _score_mod_callback_cache_key silently skips graph caching for lambdas and functools.partial without any diagnostic; transformer_engine/jax/csrc/extensions/attention.cpp — minor typo in getScoreModeGraphCache accessor name. Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant fused_attn
participant make_config as make_fused_attn_score_mod_config
participant fwd_rule as _fused_attn_score_mod_fwd_rule
participant py_cache as _score_mod_graph_cache (Python)
participant build_graph as _build_score_mod_fwd_graph
participant cudnn as cudnn.pygraph
participant ffi as ffi.ffi_call
participant cpp as C++ FFI Handler
participant cpp_cache as C++ Graph Cache
User->>fused_attn: "fused_attn(..., score_mod=fn, ...)"
fused_attn->>make_config: classify callback key, split tensors/scalars
make_config-->>fused_attn: _FusedAttnScoreModConfig
fused_attn->>fwd_rule: _fused_attn_score_mod(qkv, tensors, config)
Note over fwd_rule,py_cache: JAX trace time
fwd_rule->>py_cache: lookup (direction, config, avals)
alt cache miss
py_cache->>build_graph: _build_score_mod_fwd_graph(avals, config)
build_graph->>cudnn: "pygraph.sdpa(score_mod=wrapped_fn)"
cudnn-->>build_graph: serialized graph + workspace_size
build_graph-->>py_cache: _SerializedScoreModGraph
end
py_cache-->>fwd_rule: serialized graph
fwd_rule->>ffi: ffi_call(te_fused_attn_score_mod_forward_ffi, serialized_graph as attr)
Note over ffi,cpp_cache: Runtime execution
ffi->>cpp: FusedAttnScoreModForwardFFI(stream, q,k,v, variadic_tensors, attrs)
cpp->>cpp_cache: lookup by (device_id, hash0, hash1, frontend_version)
alt C++ cache miss
cpp->>cpp: deserialize graph via cudnn_frontend
cpp->>cpp_cache: store shared_ptr Graph
end
cpp->>cpp: "graph->execute(handle, variant_pack, workspace)"
cpp-->>ffi: output, stats
fwd_rule-->>User: output (+ residuals for VJP)
Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| struct ScoreModGraphEntry { | ||
| PyObject *py_graph = nullptr; | ||
| std::vector<int64_t> user_uids; | ||
| std::vector<int64_t> input_uids; | ||
| std::vector<int64_t> output_uids; | ||
| std::vector<int64_t> scalar_uids; | ||
| std::vector<ScoreModScalarStorage> scalar_values; | ||
| }; |
There was a problem hiding this comment.
Python reference leak:
Py_INCREF without a matching Py_DECREF
ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.
There was a problem hiding this comment.
@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here
There was a problem hiding this comment.
Not sure if I'd call it a leak, but yes, currently the cache is process-lifetime. If we ever encounter an issue with it's growth, then we'll need to implement some kind of eviction policy. But it is out of scope of this PR.
| intermediate_data_type=cudnn.data_type.FLOAT, | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
|
|
||
| q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) | ||
| k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) | ||
| v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) | ||
| o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape) | ||
| do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape) |
There was a problem hiding this comment.
id()-based cache keys can produce false cache hits after GC
_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
| Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, | ||
| const std::vector<void *> &input_ptrs, | ||
| const std::vector<void *> &output_ptrs, void *workspace) { | ||
| auto entry = GetScoreModGraphEntry(graph_id); | ||
| NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ", | ||
| entry->input_uids.size(), " inputs but got ", input_ptrs.size()); | ||
| NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), | ||
| "cuDNN score_mod graph expected at least ", entry->output_uids.size(), | ||
| " outputs but got ", output_ptrs.size()); | ||
|
|
||
| std::unordered_map<int64_t, void *> variant_pack; | ||
| for (size_t i = 0; i < entry->input_uids.size(); ++i) { | ||
| variant_pack.emplace(entry->input_uids[i], input_ptrs[i]); | ||
| } | ||
| for (size_t i = 0; i < entry->output_uids.size(); ++i) { | ||
| variant_pack.emplace(entry->output_uids[i], output_ptrs[i]); | ||
| } | ||
| for (size_t i = 0; i < entry->scalar_uids.size(); ++i) { | ||
| variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data()); | ||
| } | ||
|
|
||
| std::vector<std::intptr_t> user_ptrs; | ||
| user_ptrs.reserve(entry->user_uids.size()); | ||
| for (const auto uid : entry->user_uids) { | ||
| auto it = variant_pack.find(uid); | ||
| NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid); | ||
| user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second)); | ||
| } | ||
|
|
||
| auto handle = GetScoreModCudnnHandle(); | ||
| NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); | ||
| { | ||
| pybind11::gil_scoped_acquire gil; | ||
| try { | ||
| auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph); | ||
| graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace), | ||
| reinterpret_cast<std::intptr_t>(handle)); | ||
| } catch (const pybind11::error_already_set &exc) { | ||
| NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what()); | ||
| } | ||
| } | ||
| return ffi_with_cuda_error_check(); | ||
| } |
There was a problem hiding this comment.
GIL held across a CUDA FFI call boundary
ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.
| _SCORE_MOD_UID_DQ = 7 | ||
| _SCORE_MOD_UID_DK = 8 | ||
| _SCORE_MOD_UID_DV = 9 | ||
| _SCORE_MOD_FWD_TENSOR_UID_BASE = 1000 |
There was a problem hiding this comment.
_score_mod_graph_cache and C++ registry grow without bound
_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
|
|
||
| def forward(self, graph, score, tensors): | ||
| import cudnn # pylint: disable=import-outside-toplevel | ||
|
|
||
| self.before_tanh_activation = graph.div( | ||
| a=score, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
| self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) | ||
| tanh_out = graph.tanh(input=self.before_tanh_activation) | ||
| tanh_out.set_data_type(cudnn.data_type.FLOAT) | ||
| return graph.mul( | ||
| a=tanh_out, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
|
|
||
| def backward(self, graph, dscore, tensors): | ||
| import cudnn # pylint: disable=import-outside-toplevel | ||
|
|
||
| d_tanh_out = graph.mul( | ||
| a=dscore, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
| d_tanh_out.set_data_type(cudnn.data_type.FLOAT) | ||
| d_before_tanh_activation = graph.tanh_backward( | ||
| loss=d_tanh_out, | ||
| input=self.before_tanh_activation, | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
| d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) | ||
| return graph.div( | ||
| a=d_before_tanh_activation, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
|
|
||
|
|
||
| def _reference_attention( | ||
| query, key, value, scale, *, causal=False, relative_position=False, softcap=None | ||
| ): | ||
| scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale | ||
| if causal: | ||
| q_pos = jnp.arange(query.shape[1])[:, None] | ||
| kv_pos = jnp.arange(key.shape[1])[None, :] | ||
| scores = jnp.where(q_pos >= kv_pos, scores, -1e9) | ||
| if relative_position: | ||
| q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] | ||
| kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] | ||
| scores = scores + q_pos - kv_pos | ||
| if softcap is not None: |
There was a problem hiding this comment.
_ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering
backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| _SCORE_MOD_UID_K = 2 | ||
| _SCORE_MOD_UID_V = 3 | ||
| _SCORE_MOD_UID_O = 4 | ||
| _SCORE_MOD_UID_STATS = 5 |
There was a problem hiding this comment.
Where do these _SCORE_MOD_UID_XXXXX come from? Is it a C/C++ enum? If so, we should make this a Python Enum that derives its values from the C/C++ enum exposed via pybind
See this enum for reference:
There was a problem hiding this comment.
These are just arbitrary numbers, really. In fact, assigning UIDs is completely optional, cuDNN can auto-assign. UIDs are added here just for determinism / to make future troubleshooting easier, e.g. so that we know that 4 is the output tensor.
There was a problem hiding this comment.
Could you please add comments explaining this reason in the file ?
Thanks
| struct ScoreModGraphEntry { | ||
| PyObject *py_graph = nullptr; | ||
| std::vector<int64_t> user_uids; | ||
| std::vector<int64_t> input_uids; | ||
| std::vector<int64_t> output_uids; | ||
| std::vector<int64_t> scalar_uids; | ||
| std::vector<ScoreModScalarStorage> scalar_values; | ||
| }; |
There was a problem hiding this comment.
@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| status.get_message()); | ||
|
|
||
| std::lock_guard<std::mutex> lock(ScoreModGraphCacheMutex()); | ||
| auto &cache = ScoreModGraphCache(); |
There was a problem hiding this comment.
SR: Can we name ScoreModGraphCache() something like getScoreModeGraphCache()? On my first read-thru I read ScoreModeGraphCache as a constructing a new object and thought this was always using a fresh cache.
| NVTE_CHECK(status.is_good(), "Failed to deserialize cuDNN score_mod SDPA graph: ", | ||
| status.get_message()); | ||
|
|
||
| std::lock_guard<std::mutex> lock(ScoreModGraphCacheMutex()); |
There was a problem hiding this comment.
Same here about ScoreModGraphCacheMutex() -> getScoreModGraphCacheMutex()
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| score_mod: Optional[Callable] = None, | ||
| score_mod_bprop: Optional[Callable] = None, | ||
| score_mod_tensors: Optional[Mapping[str, Any]] = None, | ||
| score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, | ||
| ): |
There was a problem hiding this comment.
Looks like this is the highest API that score_mod has been plumbed to.
There are higher APIs that would need to be plumbed to as well - please do take a look
At the very least FusedDPA and DPA
| qkv: Tuple[jnp.ndarray, ...], | ||
| bias: Optional[jnp.ndarray], | ||
| sequence_descriptor: SequenceDescriptor, | ||
| sequence_descriptor: Optional[SequenceDescriptor], |
There was a problem hiding this comment.
If you are making sequence_descriptor Optional then please add a check in the function body to ensure that SequenceDescriptor is passed when the score_mod / flex attn is not being used
Users should not use non-flex attn / non-score-mod attn without sequence_descriptor
There was a problem hiding this comment.
Never mind, I think we have a check that I had forgotten about:
if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
Do confirm please
| if score_mod is None: | ||
| if score_mod_bprop is not None: | ||
| raise ValueError("score_mod_bprop requires score_mod to be provided.") | ||
| if score_mod_tensors is not None: | ||
| raise ValueError("score_mod_tensors requires score_mod to be provided.") | ||
| if score_mod_bprop_tensors is not None: | ||
| raise ValueError("score_mod_bprop_tensors requires score_mod to be provided.") |
There was a problem hiding this comment.
While being descriptive with the asserts is great in general, I am wondering if this overkill ?
The essence seems to be that we error out when score_mod is None irrespective of values of other params, so why not consolidate these three ?
And if score_mod is not None we use that as a flag that the user wants to use flex attn right ?
| def _validate_fused_attn_score_mod( | ||
| qkv: Tuple[jnp.ndarray, ...], | ||
| bias: Optional[jnp.ndarray], | ||
| sequence_descriptor: Optional[SequenceDescriptor], | ||
| seed: Optional[jnp.ndarray], | ||
| attn_bias_type: AttnBiasType, | ||
| attn_mask_type: AttnMaskType, | ||
| qkv_layout: QKVLayout, | ||
| softmax_type: AttnSoftmaxType, | ||
| dropout_probability: float, | ||
| max_segments_per_seq: int, | ||
| window_size: Optional[Tuple[int, int]], | ||
| context_parallel_strategy: CPStrategy, | ||
| context_parallel_causal_load_balanced: bool, | ||
| context_parallel_axis: str, | ||
| softmax_offset: Optional[jnp.ndarray], | ||
| stripe_size: int | None, | ||
| ): |
There was a problem hiding this comment.
I am wondering if there's merit in moving this check one step downstream to cpp_extensions similar to how there are checks for fused_attn fwd in fused_attn_fwd()
Note: There's merit in having a checking function like this and I think that would be a good addition for fused_attn as well but outside the scope of this PR
| def _fused_attn_score_mod( | ||
| qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], | ||
| score_mod_tensors: Tuple[jnp.ndarray, ...], | ||
| score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], | ||
| config, | ||
| context_checkpoint_name: str, | ||
| ): | ||
| output, _ = _fused_attn_score_mod_fwd_rule( | ||
| qkv, | ||
| score_mod_tensors, | ||
| score_mod_bprop_tensors, | ||
| config, | ||
| context_checkpoint_name, | ||
| ) | ||
| return output | ||
|
|
||
|
|
||
| def _fused_attn_score_mod_fwd_rule( | ||
| qkv, | ||
| score_mod_tensors, | ||
| score_mod_bprop_tensors, | ||
| config, | ||
| context_checkpoint_name, | ||
| ): | ||
| output, softmax_stats = tex.fused_attn_score_mod_fwd(qkv, score_mod_tensors, config) | ||
| output = checkpoint_name(output, context_checkpoint_name) | ||
| softmax_stats = checkpoint_name(softmax_stats, context_checkpoint_name) | ||
| return output, (qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats) | ||
|
|
||
|
|
||
| def _fused_attn_score_mod_bwd_rule(config, context_checkpoint_name, ctx, dz): | ||
| del context_checkpoint_name | ||
| qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats = ctx | ||
| grad_qkv = tex.fused_attn_score_mod_bwd( | ||
| qkv, | ||
| output, | ||
| dz, | ||
| softmax_stats, | ||
| score_mod_tensors, | ||
| score_mod_bprop_tensors, | ||
| config, | ||
| ) | ||
| return ( | ||
| grad_qkv, | ||
| tuple(None for _ in score_mod_tensors), | ||
| tuple(None for _ in score_mod_bprop_tensors), | ||
| ) | ||
|
|
||
|
|
||
| _fused_attn_score_mod.defvjp(_fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule) |
There was a problem hiding this comment.
Initially my thought was if we could combine fused_attn_score_mode* functions with the pre-existing fused_attn functions but as the score_mod functions are anyways not being coupled with a lot of the other features such as determinism, dropout, bias, cp, etc, I guess it is okay to keep it as a separate API
| "L1": [(4, 16, 4, 64)], | ||
| "L2": [(4, 16, 4, 64)], | ||
| } |
There was a problem hiding this comment.
Why are we running tests with the same shape ?
If only 1 shape is needed, only L1 should suffice - no need for L2 then
| dtype, | ||
| ): | ||
| _require_cudnn_frontend_score_mod() | ||
| batch, seqlen, num_heads, head_dim = data_shape |
There was a problem hiding this comment.
There seems to be a lot of input setup here,
Can we not use _setup_inputs() with minor modifications ?
You could also use _setup_inputs() and "undo" some of the setup in the test if you do not wish to modify setup_inputs() - that is fine too.
All tests in attention end up using it so would make sense to maintain uniformity if possible.
However, if it involves too much branching and changes for flex attn, I understand keeping it separate.
Additionally all tests end up using the FusedAttnRunner() - are you not able to use it for the flex tests with minor changes ? Using the runner also means that you do not need to import fused_attn in the distributed tests.
The current setup only imports the fused_attn in the test_fused_attn.py tests the the distributed tests setup and use it directly from there
The above two might help reduce duplication of code and maintain uniformity and the test inputs generated. Moreover, we would not want to add setup, breakdown for any new attn types we add, especially when we are using the same fused_attn API for all.
| inverse_reorder_causal_load_balancing, | ||
| CPStrategy, | ||
| ReorderStrategy, | ||
| fused_attn, |
There was a problem hiding this comment.
Please see below comment
Would be good to avoid this import here if possible
The CP tests also do not end up importing it but use it via the runner customcall_fused_dpa
There was a problem hiding this comment.
A lot fo the other imports like assert_allclose acan also be prevented by using the FusedAttnRunner nad setup_inputs in the fused_attn tests
| def _reference_attention( | ||
| query, key, value, scale, *, causal=False, post_scale_bias=False, softcap=None | ||
| ): | ||
| scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale | ||
| if causal: | ||
| q_pos = jnp.arange(query.shape[1])[:, None] | ||
| kv_pos = jnp.arange(key.shape[1])[None, :] | ||
| scores = jnp.where(q_pos >= kv_pos, scores, -1e9) | ||
| if post_scale_bias: | ||
| q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] | ||
| kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] | ||
| scores = scores + q_pos - kv_pos | ||
| if softcap is not None: | ||
| scores = softcap * jnp.tanh(scores / softcap) | ||
| probs = jax.nn.softmax(scores, axis=-1) | ||
| return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype) | ||
|
|
||
|
|
There was a problem hiding this comment.
WHy create your own reference and not use the jax native reference in the test file already ?
| _deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) | ||
|
|
||
|
|
||
| def _has_cudnn_frontend_python(): |
There was a problem hiding this comment.
This is a general comment for the newly added items to conform to the existing test infrastructure and not add completely "new" tests:
- Have you considered reusing the
setup_inputs()rather than each test regenerating the inputs themselves ? If not,please do - Have you consider using the
FusedAttnRunnersetup to maintain uniformity across tests ? If not, please consider
If for whatever reason, you are unable to conform to the existing test infrastructure in fused attn for the flex attn tests, please move all the flex attn items a different test file
Description
This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
Fixes # (issue)
#2492
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: