Skip to content

0.4.2#62

Merged
Pavlo3P merged 21 commits into
masterfrom
0.4.2
Jul 1, 2026
Merged

0.4.2#62
Pavlo3P merged 21 commits into
masterfrom
0.4.2

Conversation

@Pavlo3P

@Pavlo3P Pavlo3P commented Jul 1, 2026

Copy link
Copy Markdown
Owner

No description provided.

Pavlo3P and others added 21 commits July 1, 2026 02:23
Widen Functional.value/grad to accept *args/**kwargs (auxiliary parameters
such as data, temperature, or a penalty weight) and add value_and_grad with
a base default that delegates to value + grad, so subclasses can override
with a single-pass evaluator. Extras are threaded through __call__ and the
check-free cores; checked_method already forwards them and validates only x.

Groundwork for the 0.4.2 optimizer-loop work (fused value_and_grad) and
TraceFunctional.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
functional: add value_and_grad and auxiliary value/grad args
minimize_optax now runs the whole optimization inside
jax.jit(jax.lax.while_loop(...)): the fused F.value_and_grad is evaluated once
per iteration and cached, so the stopping test (grad_norm <= tol) and the
progress log reuse cached values with no recomputation and no per-iteration host
sync. It returns a rich OptaxResult and supports on-device convergence +
finiteness stop, four-column progress logging (iteration / value / delta F /
grad_norm) via jax.debug.print, a preallocated history buffer with post-hoc
progress_callback replay, and an optional project retraction hook.

BREAKING: the old fixed-`steps` eager loop is removed. New signature:
  minimize_optax(F, x0, opt, *, max_iter=1000, tol=1e-6, project=None,
                 verbose=1, log_every=50, history_every=None,
                 record_history=True, progress_callback=None) -> OptaxResult
Call sites and tests migrated (steps -> max_iter; returns OptaxResult).

Builds on Functional.value_and_grad (#51).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Move the minimize_optax and OptaxResult summaries to the line after the opening
quotes (GL01), and document OptaxResult's dataclass fields under Parameters so
numpydoc validates them against the generated __init__ (PR01). Docstring-only;
no behavior change. Fixes the scripts/docstring_audit.py --check CI step.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add value/gradient evaluation counts to the minimize_optax result: nfev and
njev = num_iters + 1 (one fused value_and_grad per iteration plus the initial
evaluation). Documented that line-search value_fn calls (e.g. optax.lbfgs) are
not counted, and surfaced in the verbose summary.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Accumulate optax's num_linesearch_steps on-device through the loop and fold it
into nfev/njev, so line-search optimizers (optax.lbfgs) report their true
function/gradient evaluation counts, not just num_iters + 1. A structural finder
locates num_linesearch_steps anywhere in the opt_state (robust across optimizers
and optax versions); gradient-transformation optimizers report 0. The cumulative
count is also exposed as OptaxResult.n_linesearch_steps and in the verbose summary.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…nt rebinding

Address code review on the minimize_optax driver:

1. Complex gradients. _tree_l2_norm now uses |g|**2 (was real(g)**2, which
   reported norm 0 for a purely imaginary gradient). Finiteness now checks all
   value/gradient leaves (complex-aware) rather than only the scalar norm.

2. lbfgs evaluation counts. nfev/njev count only the driver's own value_and_grad
   calls (num_iters + 1); optax's internal line-search evaluations stay reported
   as n_linesearch_steps rather than being folded into an approximate total.
   Documented why the driver does not use optax.value_and_grad_from_state (it
   would substitute the autodiff gradient of F.value for X.riesz(F.grad), breaking
   the SpaceCore gradient contract).

3. x_element. For a TreeSpace the minimizer is rebound to a domain element via
   X.element(...), so it is a genuine element of F.domain as documented.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…citly approximate

Review follow-up on the line-search accounting:

- _linesearch_steps now sums every num_linesearch_steps found in the optax state
  instead of returning the first, so a chain of multiple line-search transforms is
  counted in full (previously only the first was counted).
- Clarify in the docstring that n_linesearch_steps is an approximate count of the
  extra internal objective evaluations (roughly one per line-search step; the exact
  per-step count is optax-internal), summed over the run.

Behavior for single-line-search optimizers (optax.lbfgs) is unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
optimize: compiled convergence-aware minimize_optax
Add trace/determinant/unit to the Jordan-algebra hierarchy (0.4.2 W2), derived
from each space's spectrum:

- Base JordanAlgebraSpace: trace = sum(spectrum), determinant = prod(spectrum),
  unit as @AbstractMethod.
- HermitianSpace: trace via the diagonal (real(einsum '...ii->...'), no eigh),
  unit = eye(n); inherits the base determinant.
- ElementwiseJordanSpace: trace/determinant reduce the element's own axes,
  unit = ones(shape).
- Tree and stacked mixins (direct sums): trace additive, determinant
  multiplicative, unit assembled from the leaf/copy units -- inherited by every
  Jordan composite with no per-class edits.

All primitives preserve leading batch axes. The oracle trace(x) == inner(unit(), x)
holds on every Euclidean Jordan algebra across all five space families.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
space: Jordan spectral primitives (trace / determinant / unit)
W3 — a tree's spectral results can now mirror the space treedef:

- TreeSpace.spectrum(x, structured=True) returns a treedef-shaped pytree whose
  leaves hold each leaf's eigenvalue vector. The default flat spectrum() is
  unchanged (it is load-bearing for SpectralLpNormFunctional and the base
  trace/determinant reductions).
- TreeSpectralDecomposition carries the producing space's treedef (static JAX
  pytree aux) and gains to_tree(), exposing the eigenvalues in the tree's own
  (possibly nested) structure. spectral_decompose tags it on.

Round-trips decompose -> from_spectrum on nested trees. A structured spectrum is
a treedef-matching pytree but NOT a space member (its leaves are eigenvalue-
shaped), documented and covered by a check_member-rejects test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…-tree-structure

space: preserve tree structure in Jordan spectral results
Add the scalar half of the functional algebra (0.4.2 W4), mirroring
linop/_algebra.py: ScaledFunctional (value = a*F.value; grad = X.scale(a, F.grad)
in the domain geometry, not raw *), the make_scaled_functional factory (unit
passthrough + nested-scalar fold), and Functional.__neg__/__mul__/__rmul__
delegating to it (NotImplemented for non-scalar operands). Exported at package
and top level.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add the additive half of the functional algebra (0.4.2 W4): SumFunctional
(value = sum of term values; grad = the term grads folded via X.add in the
domain geometry, correct on tree/stacked domains), make_functional_sum (flattens
nested sums, unwraps a lone term, validates a common domain), and
Functional.__add__/__radd__/__sub__/__rsub__ (0 + F enables builtin sum();
NotImplemented for non-Functional operands). Exported at package and top level.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…erage

Complete the functional algebra (0.4.2 W4):

- ShiftedFunctional (F + c: value shifted, gradient unchanged) and the
  make_shifted_functional factory (zero-offset passthrough, nested-offset fold).
- ZeroFunctional (value 0, grad = X.zeros()) as the additive identity;
  make_scaled_functional now returns it for a zero scalar and make_functional_sum
  drops it (empty sum -> ZeroFunctional).
- Functional.__add__/__radd__/__sub__/__rsub__ route a scalar operand to the
  affine shift (the one place functionals extend beyond LinOp).
- Register the four algebra nodes in functional_cases() so the registry-
  completeness check and the generic functional laws (value / gradient /
  directional-derivative / conversion) exercise them.
- Tests: affine + zero behavior, factory zero-handling, and an end-to-end
  minimize_optax(0.5F + 0.5F) integration.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…dation

Adversarial review caught two bugs:

- ScaledFunctional.grad/value_and_grad must conjugate a complex scalar: the
  Riesz gradient of a*F is conj(a)*grad(F) (the domain inner product conjugates
  its first argument), matching ScaledLinOp.rapply. Value still scales by a;
  only the metric-gradient side conjugates. Real scalars are unaffected.
- make_functional_sum dropped ZeroFunctional terms BEFORE validating domains, so
  F(X) + Zero(Y) silently returned F. Validate all flattened terms' domains
  (which fold in the backend/dtype context) before dropping zeros, mirroring
  linop make_sum.

Regression tests: the complex Riesz identity <grad(aF), h> == a<c, h>, and a
mismatched-domain Zero term raising instead of being swallowed.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
functional: lazy functional algebra (mirror LinOp)
Bump version to 0.4.2 and record the release notes: Jordan spectral primitives
(trace/determinant/unit), lazy functional algebra, Functional.value_and_grad,
structure-preserving tree spectra, and the compiled convergence-aware
minimize_optax driver.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@Pavlo3P Pavlo3P merged commit 9b4d40b into master Jul 1, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant