optimize: compiled convergence-aware minimize_optax #54

Merged
Pavlo3P merged 6 commits into 0.4.2 from 53-improve-optax-optimization-loop on Jul 1, 2026

@Pavlo3P Pavlo3P commented Jul 1, 2026

Summary

Turns minimize_optax into a compiled, convergence-aware driver (W1 of the 0.4.2 optimizer work). The whole loop runs inside jax.jit(jax.lax.while_loop(...)): the fused F.value_and_grad is evaluated once per iteration and cached, so the stopping test (grad_norm ≤ tol) and the progress log reuse cached values — no recomputation, no per-iteration host sync.
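The loop structure described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the driver's actual code: it uses plain gradient descent and `jax.value_and_grad` on an ordinary function in place of optax and `F.value_and_grad`, and `minimize_sketch`, `lr`, and the quadratic objective are hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp


def minimize_sketch(f, x0, *, lr=0.1, tol=1e-4, max_iter=1000):
    """Gradient descent inside one compiled while_loop (illustrative only)."""
    value_and_grad = jax.value_and_grad(f)

    def cond(state):
        k, _, _, g = state
        # The stopping test reads the cached gradient; nothing is re-evaluated.
        return jnp.logical_and(k < max_iter, jnp.linalg.norm(g) > tol)

    def body(state):
        k, x, _, g = state
        x_new = x - lr * g                        # step uses the cached gradient
        value_new, g_new = value_and_grad(x_new)  # the only evaluation this iteration
        return k + 1, x_new, value_new, g_new

    @jax.jit
    def run(x):
        value0, g0 = value_and_grad(x)            # initial evaluation before the loop
        k0 = jnp.zeros((), dtype=jnp.int32)
        return jax.lax.while_loop(cond, body, (k0, x, value0, g0))

    return run(x0)


# Hypothetical objective: squared distance to the point (3, 3).
k, x, value, g = minimize_sketch(lambda x: jnp.sum((x - 3.0) ** 2), jnp.zeros(2))
```

Because the value and gradient travel in the loop carry, the host only sees them once, when `run` returns. This also accounts for one evaluation per iteration plus the initial one.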

Breaking change: the old fixed-steps eager loop is gone. New signature returns a rich OptaxResult:

minimize_optax(F, x0, opt, *, max_iter=1000, tol=1e-6, project=None,
               verbose=1, log_every=50, history_every=None,
               record_history=True, progress_callback=None) -> OptaxResult

Changes

  • On-device convergence (grad_norm ≤ tol) + finiteness stop as a first-class status.
  • Four-column progress logging via jax.debug.print (iteration · value F(x) · objective delta ΔF = F_k − F_{k−1} · grad norm) on the log_every cadence.
  • Preallocated on-device history buffer (history_every), with progress_callback replayed after the loop (compiled loops can't call Python live).
  • Optional project retraction after each optax update.
  • optax.with_extra_args_support + value/grad/value_fn forwarded to opt.update (lbfgs / line-search).
  • OptaxResult (success/status/message/iters/final stats/x_element/history/timing) exported.
  • Call sites migrated (test_contracts.py; steps → max_iter).
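The preallocated history buffer and post-loop callback replay from the list above could look roughly like this. It is a sketch under assumptions, not the PR's implementation: the two-column layout (iteration, value), the NaN marker for unused slots, and the two-argument `progress_callback` are all hypothetical, and plain gradient descent stands in for the optax update.

```python
import jax
import jax.numpy as jnp
import numpy as np


def run_with_history(f, x0, *, lr=0.1, max_iter=20, history_every=5):
    """Fixed-size on-device history, filled on a cadence (illustrative only)."""
    value_and_grad = jax.value_and_grad(f)
    n_slots = max_iter // history_every + 1        # static size, known before tracing

    def cond(state):
        return state[0] < max_iter

    def body(state):
        k, x, hist = state
        value, g = value_and_grad(x)
        slot = k // history_every
        # Off-cadence iterations rewrite the slot with its current contents.
        row = jnp.where(k % history_every == 0,
                        jnp.stack([k.astype(value.dtype), value]),
                        hist[slot])
        return k + 1, x - lr * g, hist.at[slot].set(row)

    hist0 = jnp.full((n_slots, 2), jnp.nan)        # NaN marks slots never written
    k0 = jnp.zeros((), dtype=jnp.int32)
    return jax.jit(lambda x: jax.lax.while_loop(cond, body, (k0, x, hist0)))(x0)


_, x_final, hist = run_with_history(lambda x: jnp.sum(x ** 2), jnp.ones(3))

seen = []


def progress_callback(iteration, value):           # hypothetical callback signature
    seen.append((iteration, value))


# Replay on the host after the compiled loop has finished.
for it, value in np.asarray(hist):
    if not np.isnan(it):
        progress_callback(int(it), float(value))
```

The buffer shape has to be fixed before tracing, which is why it is sized from `max_iter` and `history_every` rather than grown inside the loop.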

Testing

  • tests/optimize/test_minimize_optax.py rewritten for the new contract (15 tests: convergence + early stop, weighted Riesz handoff, tree / bound-element pass-through, max_iter/nonfinite stop, project, four-column history + callback replay, one-eval-per-iteration, guards).
  • pytest tests/optimize -q → 51 passed. Full suite → 3347 passed, 169 skipped.

Notes

🤖 Generated with Claude Code

Pavlo3P and others added 6 commits July 1, 2026 03:14
minimize_optax now runs the whole optimization inside
jax.jit(jax.lax.while_loop(...)): the fused F.value_and_grad is evaluated once
per iteration and cached, so the stopping test (grad_norm <= tol) and the
progress log reuse cached values with no recomputation and no per-iteration host
sync. It returns a rich OptaxResult and supports on-device convergence +
finiteness stop, four-column progress logging (iteration / value / delta F /
grad_norm) via jax.debug.print, a preallocated history buffer with post-hoc
progress_callback replay, and an optional project retraction hook.

BREAKING: the old fixed-`steps` eager loop is removed. New signature:
  minimize_optax(F, x0, opt, *, max_iter=1000, tol=1e-6, project=None,
                 verbose=1, log_every=50, history_every=None,
                 record_history=True, progress_callback=None) -> OptaxResult
Call sites and tests migrated (steps -> max_iter; returns OptaxResult).

Builds on Functional.value_and_grad (#51).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Move the minimize_optax and OptaxResult summaries to the line after the opening
quotes (GL01), and document OptaxResult's dataclass fields under Parameters so
numpydoc validates them against the generated __init__ (PR01). Docstring-only;
no behavior change. Fixes the scripts/docstring_audit.py --check CI step.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add value/gradient evaluation counts to the minimize_optax result: nfev and
njev = num_iters + 1 (one fused value_and_grad per iteration plus the initial
evaluation). Documented that line-search value_fn calls (e.g. optax.lbfgs) are
not counted, and surfaced in the verbose summary.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Accumulate optax's num_linesearch_steps on-device through the loop and fold it
into nfev/njev, so line-search optimizers (optax.lbfgs) report their true
function/gradient evaluation counts, not just num_iters + 1. A structural finder
locates num_linesearch_steps anywhere in the opt_state (robust across optimizers
and optax versions); gradient-transformation optimizers report 0. The cumulative
count is also exposed as OptaxResult.n_linesearch_steps and in the verbose summary.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…nt rebinding

Address code review on the minimize_optax driver:

1. Complex gradients. _tree_l2_norm now uses |g|**2 (was real(g)**2, which
   reported norm 0 for a purely imaginary gradient). Finiteness now checks all
   value/gradient leaves (complex-aware) rather than only the scalar norm.

2. lbfgs evaluation counts. nfev/njev count only the driver's own value_and_grad
   calls (num_iters + 1); optax's internal line-search evaluations stay reported
   as n_linesearch_steps rather than being folded into an approximate total.
   Documented why the driver does not use optax.value_and_grad_from_state (it
   would substitute the autodiff gradient of F.value for X.riesz(F.grad), breaking
   the SpaceCore gradient contract).

3. x_element. For a TreeSpace the minimizer is rebound to a domain element via
   X.element(...), so it is a genuine element of F.domain as documented.
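The complex-gradient fix in item 1 can be illustrated with a small sketch. The function names here are hypothetical (the driver's helper is `_tree_l2_norm`, whose body is not shown in this PR text); the point is only the difference between `|g|**2` and `real(g)**2`, plus a finiteness check over all leaves.

```python
import jax
import jax.numpy as jnp


def tree_l2_norm(tree):
    """L2 norm over all leaves, using |g|**2 so imaginary parts count."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.abs(leaf) ** 2) for leaf in leaves))


def tree_l2_norm_real_only(tree):
    """The faulty variant: real(g)**2 silently drops imaginary parts."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.real(leaf) ** 2) for leaf in leaves))


def tree_all_finite(tree):
    """True only if every entry of every leaf is finite (real and imaginary)."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))


# A purely imaginary gradient with entries 3i and 4i spread over two leaves.
g = {"a": jnp.array([3j, 0j]), "b": jnp.array([4j])}
norm_fixed = tree_l2_norm(g)               # sqrt(9 + 16)
norm_buggy = tree_l2_norm_real_only(g)     # imaginary parts ignored
bad = {"a": jnp.array([complex(1.0, float("nan"))])}
```

With the real-only norm, a purely imaginary gradient would satisfy `grad_norm <= tol` at once and the loop would report convergence at the starting point.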

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…citly approximate

Review follow-up on the line-search accounting:

- _linesearch_steps now sums every num_linesearch_steps found in the optax state
  instead of returning the first, so a chain of multiple line-search transforms is
  counted in full (previously only the first was counted).
- Clarify in the docstring that n_linesearch_steps is an approximate count of the
  extra internal objective evaluations (roughly one per line-search step; the exact
  per-step count is optax-internal), summed over the run.

Behavior for single-line-search optimizers (optax.lbfgs) is unchanged.
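A structural finder that sums every num_linesearch_steps field might be sketched like this. The state classes below are hypothetical stand-ins, not optax's real state types, and the counts are plain Python ints where the driver would carry on-device arrays; only the field name num_linesearch_steps comes from the commit message.

```python
from typing import Any, NamedTuple


class FakeLineSearchInfo(NamedTuple):    # hypothetical stand-in for a line-search info record
    num_linesearch_steps: int
    decrease_error: float


class FakeLineSearchState(NamedTuple):   # hypothetical stand-in for a line-search state
    learning_rate: float
    info: FakeLineSearchInfo


class FakeEmptyState(NamedTuple):        # a transform that carries no state
    pass


def linesearch_steps(state: Any) -> int:
    """Sum every field named num_linesearch_steps found anywhere in `state`."""
    total = 0
    if hasattr(state, "_fields"):                  # NamedTuple: inspect fields by name
        for name in state._fields:
            child = getattr(state, name)
            total += child if name == "num_linesearch_steps" else linesearch_steps(child)
    elif isinstance(state, (tuple, list)):
        total += sum(linesearch_steps(child) for child in state)
    elif isinstance(state, dict):
        total += sum(linesearch_steps(child) for child in state.values())
    return total


# A chain with two line-search transforms: both counts are included.
chained = (FakeEmptyState(),
           FakeLineSearchState(1.0, FakeLineSearchInfo(3, 0.0)),
           FakeLineSearchState(0.5, FakeLineSearchInfo(4, 0.0)))
total_chained = linesearch_steps(chained)
total_plain = linesearch_steps((FakeEmptyState(),))   # no line search present
```

Searching by field name rather than by position is what makes the lookup tolerant of different optimizer chains; returning on the first match would have reported 3 instead of 7 for the chained state above.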

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@Pavlo3P Pavlo3P merged commit 074464e into 0.4.2 Jul 1, 2026
5 checks passed
@Pavlo3P Pavlo3P deleted the 53-improve-optax-optimization-loop branch July 1, 2026 16:06
@Pavlo3P Pavlo3P mentioned this pull request Jul 1, 2026