Skip to content

Commit eba7477

Browse files
committed
Restrict kwarg type in code generated by @non_differentiable macro
1 parent c86125f commit eba7477

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/rule_definition_tools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
445445
return @strip_linenos quote
446446
# Manually defined kw version to save compiler work. See explanation in rules.jl
447447
function (::Core.kwftype(typeof(rrule)))(
448-
$(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)
448+
$(esc(kwargs))::NamedTuple, ::typeof(rrule), $(esc_primal_sig_parts...)
449449
)
450450
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr)
451451
end

0 commit comments

Comments
 (0)