Skip to content

functional: add value_and_grad and auxiliary value/grad args#51

Merged
Pavlo3P merged 1 commit into
0.4.2from
50-improve-functional-class
Jul 1, 2026
Merged

functional: add value_and_grad and auxiliary value/grad args#51
Pavlo3P merged 1 commit into
0.4.2from
50-improve-functional-class

Conversation

@Pavlo3P

@Pavlo3P Pavlo3P commented Jul 1, 2026

Copy link
Copy Markdown
Owner

Summary

Widen Functional.value/grad to accept *args, **kwargs (auxiliary parameters such as data, temperature, or a penalty weight) and add a value_and_grad method.

Changes

  • value(self, x, *args, **kwargs) / grad(self, x, *args, **kwargs); extras threaded through __call__, _value_core, _grad_core.
  • New value_and_grad(self, x, *args, **kwargs) -> (value, grad) — base default delegates to value + grad; subclasses may override with a single-pass AD.
  • checked_method unchanged — already forwards extras and validates only x.

Testing

  • pytest tests/functional tests/optimize -q → 375 passed.

🤖 Generated with Claude Code

Widen Functional.value/grad to accept *args/**kwargs (auxiliary parameters
such as data, temperature, or a penalty weight) and add value_and_grad with
a base default that delegates to value + grad, so subclasses can override
with a single-pass evaluator. Extras are threaded through __call__ and the
check-free cores; checked_method already forwards them and validates only x.

Groundwork for the 0.4.2 optimizer-loop work (fused value_and_grad) and
TraceFunctional.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@Pavlo3P Pavlo3P merged commit f6c375c into 0.4.2 Jul 1, 2026
11 checks passed
@Pavlo3P Pavlo3P deleted the 50-improve-functional-class branch July 1, 2026 05:30
Pavlo3P added a commit that referenced this pull request Jul 1, 2026
minimize_optax now runs the whole optimization inside
jax.jit(jax.lax.while_loop(...)): the fused F.value_and_grad is evaluated once
per iteration and cached, so the stopping test (grad_norm <= tol) and the
progress log reuse cached values with no recomputation and no per-iteration host
sync. It returns a rich OptaxResult and supports on-device convergence +
finiteness stop, four-column progress logging (iteration / value / delta F /
grad_norm) via jax.debug.print, a preallocated history buffer with post-hoc
progress_callback replay, and an optional project retraction hook.

BREAKING: the old fixed-`steps` eager loop is removed. New signature:
  minimize_optax(F, x0, opt, *, max_iter=1000, tol=1e-6, project=None,
                 verbose=1, log_every=50, history_every=None,
                 record_history=True, progress_callback=None) -> OptaxResult
Call sites and tests migrated (steps -> max_iter; returns OptaxResult).

Builds on Functional.value_and_grad (#51).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant