optimize: compiled convergence-aware minimize_optax#54
Merged
Conversation
minimize_optax now runs the whole optimization inside
jax.jit(jax.lax.while_loop(...)): the fused F.value_and_grad is evaluated once
per iteration and cached, so the stopping test (grad_norm <= tol) and the
progress log reuse cached values with no recomputation and no per-iteration host
sync. It returns a rich OptaxResult and supports on-device convergence +
finiteness stop, four-column progress logging (iteration / value / delta F /
grad_norm) via jax.debug.print, a preallocated history buffer with post-hoc
progress_callback replay, and an optional project retraction hook.
BREAKING: the old fixed-`steps` eager loop is removed. New signature:
minimize_optax(F, x0, opt, *, max_iter=1000, tol=1e-6, project=None,
verbose=1, log_every=50, history_every=None,
record_history=True, progress_callback=None) -> OptaxResult
Call sites and tests migrated (steps -> max_iter; returns OptaxResult).
Builds on Functional.value_and_grad (#51).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Move the minimize_optax and OptaxResult summaries to the line after the opening quotes (GL01), and document OptaxResult's dataclass fields under Parameters so numpydoc validates them against the generated __init__ (PR01). Docstring-only; no behavior change. Fixes the scripts/docstring_audit.py --check CI step. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add value/gradient evaluation counts to the minimize_optax result: nfev and njev = num_iters + 1 (one fused value_and_grad per iteration plus the initial evaluation). Documented that line-search value_fn calls (e.g. optax.lbfgs) are not counted, and surfaced in the verbose summary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Accumulate optax's num_linesearch_steps on-device through the loop and fold it into nfev/njev, so line-search optimizers (optax.lbfgs) report their true function/gradient evaluation counts, not just num_iters + 1. A structural finder locates num_linesearch_steps anywhere in the opt_state (robust across optimizers and optax versions); gradient-transformation optimizers report 0. The cumulative count is also exposed as OptaxResult.n_linesearch_steps and in the verbose summary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…nt rebinding Address code review on the minimize_optax driver: 1. Complex gradients. _tree_l2_norm now uses |g|**2 (was real(g)**2, which reported norm 0 for a purely imaginary gradient). Finiteness now checks all value/gradient leaves (complex-aware) rather than only the scalar norm. 2. lbfgs evaluation counts. nfev/njev count only the driver's own value_and_grad calls (num_iters + 1); optax's internal line-search evaluations stay reported as n_linesearch_steps rather than being folded into an approximate total. Documented why the driver does not use optax.value_and_grad_from_state (it would substitute the autodiff gradient of F.value for X.riesz(F.grad), breaking the SpaceCore gradient contract). 3. x_element. For a TreeSpace the minimizer is rebound to a domain element via X.element(...), so it is a genuine element of F.domain as documented. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…citly approximate Review follow-up on the line-search accounting: - _linesearch_steps now sums every num_linesearch_steps found in the optax state instead of returning the first, so a chain of multiple line-search transforms is counted in full (previously only the first was counted). - Clarify in the docstring that n_linesearch_steps is an approximate count of the extra internal objective evaluations (roughly one per line-search step; the exact per-step count is optax-internal), summed over the run. Behavior for single-line-search optimizers (optax.lbfgs) is unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Turns
minimize_optaxinto a compiled, convergence-aware driver (W1 of the 0.4.2 optimizer work). The whole loop runs insidejax.jit(jax.lax.while_loop(...)): the fusedF.value_and_gradis evaluated once per iteration and cached, so the stopping test (grad_norm ≤ tol) and the progress log reuse cached values — no recomputation, no per-iteration host sync.Breaking change: the old fixed-
stepseager loop is gone. New signature returns a richOptaxResult:Changes
grad_norm ≤ tol) + finiteness stop as a first-class status.jax.debug.print— iteration · valueF(x)· objective deltaΔF = F_k − F_{k−1}· grad norm — on thelog_everycadence.history_every), withprogress_callbackreplayed after the loop (compiled loops can't call Python live).projectretraction after each optax update.optax.with_extra_args_support+value/grad/value_fnforwarded toopt.update(lbfgs / line-search).OptaxResult(success/status/message/iters/final stats/x_element/history/timing) exported.test_contracts.py;steps→max_iter).Testing
tests/optimize/test_minimize_optax.pyrewritten for the new contract (15 tests: convergence + early stop, weighted Riesz handoff, tree / bound-element pass-through,max_iter/nonfinite stop,project, four-column history + callback replay, one-eval-per-iteration, guards).pytest tests/optimize -q→ 51 passed. Full suite → 3347 passed, 169 skipped.Notes
Functional.value_and_gradfrom functional: add value_and_grad and auxiliary value/grad args #51 (merged).minimize_scipyis unchanged.🤖 Generated with Claude Code