Commit 0786e4f

Extract frechet_mean into functional API and rename LieBN.py to liebn.py
Add frechet_mean() to spd_learn.functional.batchnorm, unifying the duplicated Karcher flow logic from SPDBatchNormMean, SPDBatchNormMeanVar, SPDBatchNormLie, and the SPDIM tutorial into a single reusable function. Rename LieBN.py to liebn.py for snake_case consistency with all other module files, and rename karcher_steps to n_iter in SPDBatchNormLie to match the other batchnorm modules.
1 parent 175369c commit 0786e4f
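
For context, a minimal usage sketch of the new unified entry point (the random SPD batch below is illustrative, not part of the commit):

    import torch
    from spd_learn.functional import frechet_mean

    # Illustrative SPD batch: A @ A^T plus a small ridge is symmetric positive definite.
    A = torch.randn(32, 8, 8, dtype=torch.float64)
    X = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(8, dtype=torch.float64)

    mean = frechet_mean(X, max_iter=50)  # shape (1, 8, 8)
    mean, dists = frechet_mean(X, max_iter=50, return_distances=True)  # dists: shape (32,)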

7 files changed: 100 additions & 105 deletions

examples/applied_examples/plot_source_free_domain.py

Lines changed: 4 additions & 68 deletions
@@ -96,75 +96,10 @@
 # SPDIM Geometric Operations
 # --------------------------
 #
-# We define two core geometric operations needed for the SPDIM pipeline.
-# These will be included in a future release of ``spd_learn.functional``.
+# The Fréchet mean and geodesic distances used by SPDIM are available
+# directly from ``spd_learn.functional``.
 #
 
-from spd_learn.functional import (
-    get_epsilon,
-    matrix_exp,
-    matrix_log,
-    matrix_sqrt_inv,
-)
-
-
-def frechet_mean(X, max_iter=50, return_distances=False):
-    r"""Compute the Fréchet mean under the AIRM.
-
-    .. math::
-
-        \bar{X} = \arg\min_{G \in \mathcal{S}_{++}^n}
-        \sum_{i=1}^{N} d_{\text{AIRM}}^2(G, X_i)
-
-    Uses adaptive step-size Karcher flow.
-    """
-    eps = get_epsilon(X.dtype, "eigval_log")
-    n_samples = X.shape[0]
-
-    if n_samples == 1:
-        mean = X[:1]
-        if return_distances:
-            return mean, torch.zeros(X.shape[:-2], dtype=X.dtype, device=X.device)
-        return mean
-
-    w = torch.ones((*X.shape[:-2], 1, 1), dtype=X.dtype, device=X.device)
-    w = w / n_samples
-    G = (X * w).sum(dim=0, keepdim=True)
-
-    nu = 1.0
-    tau = float("inf")
-
-    for _ in range(max_iter):
-        G_sqrt, G_invsqrt = matrix_sqrt_inv.apply(G)
-        X_tangent = matrix_log.apply(G_invsqrt @ X @ G_invsqrt)
-        G_tangent = (X_tangent * w).sum(dim=0, keepdim=True)
-
-        crit = torch.norm(G_tangent, p="fro", dim=(-2, -1)).max().item()
-        if crit <= eps:
-            break
-
-        G = G_sqrt @ matrix_exp.apply(nu * G_tangent) @ G_sqrt
-
-        h = nu * crit
-        if h < tau:
-            nu = 0.95 * nu
-            tau = h
-        else:
-            nu = 0.5 * nu
-
-        if nu <= eps:
-            break
-
-    if return_distances:
-        G_sqrt, G_invsqrt = matrix_sqrt_inv.apply(G)
-        X_tangent = matrix_log.apply(G_invsqrt @ X @ G_invsqrt)
-        G_tangent = (X_tangent * w).sum(dim=0, keepdim=True)
-        distances = torch.norm(X_tangent - G_tangent, p="fro", dim=(-2, -1))
-        return G, distances
-
-    return G
-
-
 ######################################################################
 # Loading the Dataset
 # -------------------
@@ -179,12 +114,13 @@ def frechet_mean(X, max_iter=50, return_distances=False):
 # - **Source domain**: Session A (training with labels)
 # - **Target domain**: Session B (adaptation without labels)
 #
-
 from braindecode.datasets import create_from_X_y
 from moabb.datasets import BNCI2015_001
 from moabb.paradigms import MotorImagery
 from sklearn.preprocessing import LabelEncoder
 
+from spd_learn.functional import frechet_mean
+
 
 dataset = BNCI2015_001()
 paradigm = MotorImagery(
spd_learn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 from .autograd import modeig_backward, modeig_forward
 from .batchnorm import (
+    frechet_mean,
     karcher_mean_iteration,
     lie_group_variance,
     spd_centering,
@@ -157,6 +158,7 @@
     "ledoit_wolf",
     "shrinkage_covariance",
     # Batch normalization
+    "frechet_mean",
     "karcher_mean_iteration",
     "lie_group_variance",
     "spd_centering",

spd_learn/functional/batchnorm.py

Lines changed: 78 additions & 1 deletion
@@ -9,6 +9,8 @@
 
 Functions
 ---------
+frechet_mean
+    Fréchet mean of SPD matrices under the AIRM via Karcher flow.
 karcher_mean_iteration
     Single iteration of the Karcher (Fréchet) mean algorithm.
 spd_centering
@@ -26,7 +28,7 @@
 :class:`~spd_learn.modules.SPDBatchNormMeanVar` : Full Riemannian batch normalization.
 """
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -103,6 +105,80 @@ def karcher_mean_iteration(
     return new_mean
 
 
+def frechet_mean(
+    X: torch.Tensor,
+    max_iter: int = 1,
+    weights: Optional[torch.Tensor] = None,
+    return_distances: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    r"""Fréchet mean of SPD matrices under the AIRM via Karcher flow.
+
+    Computes the minimizer of the sum of squared geodesic distances:
+
+    .. math::
+
+        \bar{X} = \arg\min_{G \in \mathcal{S}_{++}^n}
+        \sum_{i=1}^{N} w_i \, d_{\text{AIRM}}^2(G, X_i)
+
+    using iterative Karcher flow initialized from the (weighted) Euclidean mean.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        Batch of SPD matrices with shape ``(batch_size, ..., n, n)``.
+    max_iter : int, default=1
+        Number of Karcher flow iterations. A single iteration is often
+        sufficient for batch normalization; use more (e.g. 50) when a
+        high-accuracy mean is needed.
+    weights : torch.Tensor, optional
+        Per-sample weights with shape broadcastable to ``X``. When ``None``,
+        uniform weights ``1/N`` are used.
+    return_distances : bool, default=False
+        If True, also returns the geodesic distances from each sample to
+        the mean.
+
+    Returns
+    -------
+    mean : torch.Tensor
+        Fréchet mean with shape ``(1, ..., n, n)``.
+    distances : torch.Tensor
+        Only returned when ``return_distances=True``. Geodesic distances
+        from each sample to the mean, with shape ``(batch_size, ...)``.
+
+    See Also
+    --------
+    :func:`karcher_mean_iteration` : Single Karcher step (lower-level).
+    :func:`~spd_learn.functional.airm_distance` : Pairwise AIRM distance.
+
+    References
+    ----------
+    See :cite:p:`pennec2006riemannian` for details on Karcher mean computation.
+    """
+    batch = X.detach()
+
+    if weights is None:
+        mean = batch.mean(dim=0, keepdim=True)
+    else:
+        mean = (batch * weights).sum(dim=0, keepdim=True)
+
+    for _ in range(max_iter):
+        mean_sqrt, mean_invsqrt = matrix_sqrt_inv.apply(mean)
+        X_tangent = matrix_log.apply(mean_invsqrt @ batch @ mean_invsqrt)
+        if weights is None:
+            mean_tangent = X_tangent.mean(dim=0, keepdim=True)
+        else:
+            mean_tangent = (X_tangent * weights).sum(dim=0, keepdim=True)
+        mean = mean_sqrt @ matrix_exp.apply(mean_tangent) @ mean_sqrt
+
+    if return_distances:
+        mean_sqrt, mean_invsqrt = matrix_sqrt_inv.apply(mean)
+        X_tangent = matrix_log.apply(mean_invsqrt @ batch @ mean_invsqrt)
+        distances = torch.norm(X_tangent, p="fro", dim=(-2, -1))
+        return mean, distances
+
+    return mean
+
+
 def spd_centering(
     X: torch.Tensor,
     mean_invsqrt: torch.Tensor,
@@ -339,6 +415,7 @@ def lie_group_variance(
 
 
 __all__ = [
+    "frechet_mean",
     "karcher_mean_iteration",
     "lie_group_variance",
     "spd_centering",

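The ``weights`` parameter is new relative to the tutorial helper this replaces; a short sketch of the weighted form (the weight values here are arbitrary placeholders):

    import torch
    from spd_learn.functional import frechet_mean

    A = torch.randn(16, 4, 4, dtype=torch.float64)
    X = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(4, dtype=torch.float64)

    # Per-sample weights, shape (batch_size, 1, 1), normalized to sum to 1
    # so they broadcast against X inside the weighted tangent-space average.
    w = torch.rand(16, 1, 1, dtype=torch.float64)
    w = w / w.sum()

    weighted_mean = frechet_mean(X, max_iter=50, weights=w)  # shape (1, 4, 4)
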
spd_learn/modules/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from .bilinear import BiMap, BiMapIncreaseDim
 from .covariance import CovLayer
 from .dropout import SPDDropout
-from .LieBN import SPDBatchNormLie
+from .liebn import SPDBatchNormLie
 from .manifold import PositiveDefiniteScalar, SymmetricPositiveDefinite
 from .modeig import ExpEig, LogEig, ReEig
 from .regularize import Shrinkage, TraceNorm

spd_learn/modules/batchnorm.py

Lines changed: 3 additions & 12 deletions
@@ -13,7 +13,7 @@
     matrix_sqrt,
 )
 from ..functional.batchnorm import (
-    karcher_mean_iteration,
+    frechet_mean,
     spd_centering,
     spd_rebiasing,
     tangent_space_variance,
@@ -194,10 +194,7 @@ def forward(self, input):
 
         """
        if self.training:
-            mean = input.mean(dim=0, keepdim=True)
-            if input.shape[0] > 1:
-                for _ in range(self.n_iter):
-                    mean = karcher_mean_iteration(input, mean)
+            mean = frechet_mean(input, max_iter=self.n_iter)
             with torch.no_grad():
                 self.running_mean = airm_geodesic(
                     self.running_mean, mean, self.momentum
@@ -478,14 +475,8 @@ def forward(self, input):
             Normalized tensor of the same shape as the input.
 
         """
-        n_samples = input.shape[0]
         if self.training:
-            # Kobler et al. SPDMBN/SPDBN: estimate batch Fréchet mean via Karcher step
-            batch_mean = input.mean(dim=0, keepdim=True)
-            if n_samples > 1:
-                for _ in range(self.n_iter):
-                    # Kobler et al. (Eq. 4): P2 L132-145; Karcher flow note: P2 L163-165
-                    batch_mean = karcher_mean_iteration(input, batch_mean)
+            batch_mean = frechet_mean(input, max_iter=self.n_iter)
 
             # Scalar dispersion: mean squared Frobenius norm of log at the mean (a single scalar, not variance matrix)
             mean_inv_sqrt = matrix_inv_sqrt.apply(batch_mean)
spd_learn/modules/{LieBN.py → liebn.py}

Lines changed: 6 additions & 15 deletions
@@ -26,7 +26,7 @@
     matrix_sqrt,
 )
 from ..functional.batchnorm import (
-    karcher_mean_iteration,
+    frechet_mean,
     lie_group_variance,
     spd_centering,
     spd_cholesky_congruence,
@@ -63,9 +63,8 @@ class SPDBatchNormLie(nn.Module):
         Running statistics momentum.
     eps : float, default=1e-5
         Numerical stability constant for variance normalization.
-    karcher_steps : int, default=1
-        Number of Karcher flow iterations used by the AIM mean. Iterations
-        stop early when the tangent update norm falls below ``1e-5``.
+    n_iter : int, default=1
+        Number of Karcher flow iterations used by the AIM mean.
     congruence : {"cholesky", "eig"}, default="cholesky"
         Implementation of the AIM congruence action (centering/biasing).
         ``"cholesky"`` uses the Cholesky factor :math:`L` of :math:`P` to
@@ -91,7 +90,7 @@ def __init__(
         beta=0.0,
         momentum=0.1,
         eps=1e-5,
-        karcher_steps=1,
+        n_iter=1,
         congruence="cholesky",
         device=None,
         dtype=None,
@@ -113,7 +112,7 @@ def __init__(
         self.beta = beta
         self.momentum = momentum
         self.eps = eps
-        self.karcher_steps = karcher_steps
+        self.n_iter = n_iter
         self.congruence = congruence
 
         self.bias = nn.Parameter(torch.empty(1, n, n, device=device, dtype=dtype))
@@ -182,15 +181,7 @@ def _translate(self, X, P, inverse=False):
     def _frechet_mean(self, X_def):
         """Fréchet mean in the deformed space."""
         if self.metric == "AIM":
-            batch = X_def.detach()
-            mean = batch.mean(dim=0, keepdim=True)
-            for _ in range(self.karcher_steps):
-                mean, mean_tangent = karcher_mean_iteration(
-                    batch, mean, detach=True, return_tangent=True
-                )
-                if mean_tangent.norm(dim=(-1, -2)).max() < 1e-5:
-                    break
-            return mean
+            return frechet_mean(X_def, max_iter=self.n_iter)
         return X_def.detach().mean(dim=0, keepdim=True)
 
     def _scale(self, X, var):
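
After the rename, callers pass ``n_iter`` where they previously passed ``karcher_steps``; a usage sketch mirroring the tests (dimension, metric value, and data are placeholders — ``"AIM"`` is the metric branch shown in this diff):

    import torch
    from spd_learn.modules import SPDBatchNormLie

    layer = SPDBatchNormLie(8, metric="AIM", n_iter=64, dtype=torch.float64)
    layer.train()

    A = torch.randn(32, 8, 8, dtype=torch.float64)
    X = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(8, dtype=torch.float64)
    out = layer(X)  # normalized batch with the same shape as X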

tests/test_liebn.py

Lines changed: 6 additions & 8 deletions
@@ -97,7 +97,7 @@ def test_post_normalization_mean(simulated_data, metric, congruence):
     """
     x, _, ndim, nobs = simulated_data
     layer = SPDBatchNormLie(
-        ndim, metric=metric, karcher_steps=64, congruence=congruence, dtype=DTYPE
+        ndim, metric=metric, n_iter=64, congruence=congruence, dtype=DTYPE
     )
     layer.train()
 
@@ -135,7 +135,7 @@ def test_post_normalization_variance(simulated_data, metric):
     this is close to 1.0.
     """
     x, _, ndim, nobs = simulated_data
-    layer = SPDBatchNormLie(ndim, metric=metric, karcher_steps=64, dtype=DTYPE)
+    layer = SPDBatchNormLie(ndim, metric=metric, n_iter=64, dtype=DTYPE)
     layer.train()
 
     with torch.no_grad():
@@ -170,9 +170,7 @@ def test_post_normalization_variance(simulated_data, metric):
 def test_running_stats_single_batch(simulated_data, metric):
     """With momentum=1.0, running stats should match batch stats exactly."""
     x, _, ndim, nobs = simulated_data
-    layer = SPDBatchNormLie(
-        ndim, metric=metric, momentum=1.0, karcher_steps=64, dtype=DTYPE
-    )
+    layer = SPDBatchNormLie(ndim, metric=metric, momentum=1.0, n_iter=64, dtype=DTYPE)
     layer.train()
 
     with torch.no_grad():
@@ -216,12 +214,12 @@ def test_running_stats_single_batch(simulated_data, metric):
 def test_running_stats_convergence(simulated_data, metric):
     """Running stats should converge to population stats over mini-batches."""
     x, _, ndim, nobs = simulated_data
-    layer = SPDBatchNormLie(ndim, metric=metric, karcher_steps=1, dtype=DTYPE)
+    layer = SPDBatchNormLie(ndim, metric=metric, n_iter=1, dtype=DTYPE)
 
     # Full-batch reference statistics (high precision)
     with torch.no_grad():
         ref_layer = SPDBatchNormLie(
-            ndim, metric=metric, momentum=1.0, karcher_steps=64, dtype=DTYPE
+            ndim, metric=metric, momentum=1.0, n_iter=64, dtype=DTYPE
         )
         ref_layer.train()
         ref_layer(x)
@@ -258,7 +256,7 @@ def test_gradient_flow(simulated_data, metric):
     # Use a small batch to keep computation fast
     x_small = x[:8].clone().requires_grad_(True)
 
-    layer = SPDBatchNormLie(ndim, metric=metric, karcher_steps=1, dtype=DTYPE)
+    layer = SPDBatchNormLie(ndim, metric=metric, n_iter=1, dtype=DTYPE)
     layer.train()
 
     output = layer(x_small)