Commit 21a21cf

Rename SPDBatchNormLie parameter n to num_features for API consistency
All other batchnorm modules (SPDBatchNormMean, SPDBatchNormMeanVar, BatchReNorm) use `num_features` as their matrix-size parameter. This aligns SPDBatchNormLie with the same convention.
Parent: 0786e4f · Commit: 21a21cf

2 files changed: 159 additions & 23 deletions

spd_learn/modules/liebn.py (158 additions & 22 deletions)
@@ -43,32 +43,121 @@
 class SPDBatchNormLie(nn.Module):
     r"""Lie Group Batch Normalization for SPD matrices.

-    This class implements the SPD instance of the LieBN framework, using
-    the three Lie group structures on the SPD manifold, corresponding to the AIM, LEM, and LCM.
+    Implements the LieBN framework :cite:p:`chen2024liebn` for SPD manifolds.
+    Unlike :class:`SPDBatchNormMeanVar`, which normalizes under a single
+    Riemannian metric (AIRM), this layer exploits the **Lie group structure**
+    of three classical SPD geometries to define centering, scaling, and biasing
+    as group-theoretic operations with formal statistical guarantees.
+
+    **Algorithm.**
+    Given a batch :math:`\{P_i\}_{i=1}^N \subset \mathcal{S}_{++}^n`, the
+    forward pass applies three steps in the Lie algebra selected by ``metric``:
+
+    1. **Centering** -- translate the batch mean :math:`M` to the group
+       identity :math:`E` via the inverse left translation:
+
+       .. math::
+
+          \bar{P}_i = L_{M_\odot^{-1}}(P_i)
+
+    2. **Scaling** -- normalize the Fréchet variance :math:`v^2` with a
+       learnable shift :math:`s \in \mathbb{R}_{>0}`:
+
+       .. math::
+
+          \hat{P}_i = \operatorname{Exp}_E
+          \!\left[\frac{s}{\sqrt{v^2 + \epsilon}}\,
+          \operatorname{Log}_E(\bar{P}_i)\right]
+
+    3. **Biasing** -- translate to the learnable SPD parameter :math:`B`:
+
+       .. math::
+
+          \tilde{P}_i = L_B(\hat{P}_i)
+
+    **Theoretical guarantees** (Proposition 4.2 of the paper):
+
+    * *Mean control*: after centering and biasing with :math:`B = E`,
+      the Fréchet mean of the output batch equals :math:`E`.
+    * *Variance control*: after scaling, the output dispersion satisfies
+      :math:`\sum_i w_i\,d^2(\hat{P}_i, E) = s^2`.
+
+    **Supported metrics.**
+    The ``metric`` parameter selects one of three Lie group structures, each
+    inducing a family of parameterized metrics via the power deformation
+    :math:`\mathrm{P}_\theta`. The table below summarizes how each step is
+    realized (see Table 2 in :cite:p:`chen2024liebn`):
+
+    .. list-table::
+       :header-rows: 1
+       :widths: 25 25 25 25
+
+       * - Operation
+         - :math:`(\theta,\alpha,\beta)`-AIM
+         - :math:`(\alpha,\beta)`-LEM
+         - :math:`\theta`-LCM
+       * - Pullback map
+         - :math:`\mathrm{P}_\theta`
+         - :math:`\operatorname{mlog}`
+         - :math:`\psi_{\mathrm{LC}} \circ \mathrm{P}_\theta`
+       * - Left translation :math:`L_Q(P)`
+         - :math:`Q^{1/2} P\, Q^{1/2}`
+         - :math:`P + Q`
+         - :math:`P + Q`
+       * - Scaling
+         - :math:`\operatorname{Exp}_I[s\,\operatorname{Log}_I(P)]`
+         - :math:`s \cdot P`
+         - :math:`s \cdot P`
+       * - Fréchet mean
+         - Karcher flow
+         - Arithmetic mean
+         - Arithmetic mean
+       * - Running mean update
+         - AIRM geodesic
+         - Linear interpolation
+         - Linear interpolation
+
+    **Bi-invariant distance.**
+    The Fréchet variance uses the :math:`(\alpha, \beta)` bi-invariant metric
+    (Definition 3 and Eq. 3 of the paper):
+
+    .. math::
+
+       d^2(P, Q) = \alpha \lVert V \rVert_F^2
+       + \beta \, g(V)^2
+
+    where :math:`V` is the tangent representation (log-map) and
+    :math:`g(V) = \log\det(P)` for AIM or :math:`\operatorname{tr}(V)`
+    for LEM/LCM. The variance is normalized by :math:`\theta^2` for AIM
+    and LCM.

     Parameters
     ----------
-    n : int
-        Size of the SPD matrices (n x n).
-    metric : str, default="AIM"
-        Lie group invariant metric. Supported values are ``"AIM"``, ``"LEM"``,
-        and ``"LCM"``.
+    num_features : int
+        Size of the SPD matrices (:math:`n \times n`).
+    metric : {"AIM", "LEM", "LCM"}, default="AIM"
+        Lie group invariant metric.
     theta : float, default=1.0
-        Power deformation parameter.
+        Power deformation parameter :math:`\theta`. When
+        :math:`\theta = 1`, no deformation is applied.
     alpha : float, default=1.0
-        Frobenius norm weight in variance computation.
+        Frobenius norm weight :math:`\alpha` in the bi-invariant distance.
     beta : float, default=0.0
-        Trace/logdet weight in variance computation.
+        Trace / log-determinant weight :math:`\beta` in the bi-invariant
+        distance. Must satisfy :math:`\min(\alpha, \alpha + n\beta) > 0`.
     momentum : float, default=0.1
-        Running statistics momentum.
+        Momentum :math:`\gamma` for exponential moving average of running
+        statistics.
     eps : float, default=1e-5
-        Numerical stability constant for variance normalization.
+        Numerical stability constant :math:`\epsilon` added to the variance
+        before taking the square root.
     n_iter : int, default=1
-        Number of Karcher flow iterations used by the AIM mean.
+        Number of Karcher flow iterations for the AIM Fréchet mean.
+        Ignored by LEM and LCM (which use arithmetic means).
     congruence : {"cholesky", "eig"}, default="cholesky"
         Implementation of the AIM congruence action (centering/biasing).
         ``"cholesky"`` uses the Cholesky factor :math:`L` of :math:`P` to
-        compute :math:`L X L^T` (as in the original LieBN paper).
+        compute :math:`L X L^\top` (as in the original LieBN paper).
         ``"eig"`` uses eigendecomposition-based :math:`M^{-1/2} X M^{-1/2}`
         (matching :func:`~spd_learn.functional.spd_centering`).
         Both are mathematically equivalent; Cholesky is typically faster,
@@ -79,11 +168,56 @@ class SPDBatchNormLie(nn.Module):
         Device on which to create parameters and buffers.
     dtype : torch.dtype, optional
         Data type of parameters and buffers.
+
+    Attributes
+    ----------
+    bias : nn.Parameter
+        Learnable SPD bias matrix :math:`B \in \mathcal{S}_{++}^n`,
+        parametrized via :class:`~spd_learn.modules.SymmetricPositiveDefinite`.
+        Initialized to the identity.
+    shift : nn.Parameter
+        Learnable positive scalar :math:`s > 0`,
+        parametrized via :class:`~spd_learn.modules.PositiveDefiniteScalar`.
+        Initialized to 1.
+    running_mean : torch.Tensor
+        Exponential moving average of the batch Fréchet mean.
+    running_var : torch.Tensor
+        Exponential moving average of the batch variance.
+
+    See Also
+    --------
+    :class:`SPDBatchNormMean` :
+        Mean-only Riemannian batch normalization (AIRM centering without
+        variance normalization) :cite:p:`brooks2019riemannian`.
+    :class:`SPDBatchNormMeanVar` :
+        Full Riemannian batch normalization under the AIRM
+        :cite:p:`kobler2022spd`.
+    :func:`~spd_learn.functional.frechet_mean` :
+        Fréchet mean via Karcher flow (used internally for AIM).
+    :func:`~spd_learn.functional.lie_group_variance` :
+        Bi-invariant Fréchet variance computation.
+
+    References
+    ----------
+    .. bibliography::
+       :filter: key == "chen2024liebn"
+
+    Examples
+    --------
+    >>> import torch
+    >>> from spd_learn.modules import SPDBatchNormLie
+    >>> bn = SPDBatchNormLie(num_features=4, metric="AIM")
+    >>> X = torch.randn(8, 4, 4, dtype=torch.float64)
+    >>> X = X @ X.mT + 0.1 * torch.eye(4, dtype=torch.float64)
+    >>> bn = bn.to(dtype=torch.float64)
+    >>> Y = bn(X)
+    >>> Y.shape
+    torch.Size([8, 4, 4])
     """

     def __init__(
         self,
-        n,
+        num_features,
         metric="AIM",
         theta=1.0,
         alpha=1.0,
@@ -105,7 +239,7 @@ def __init__(
             raise ValueError(
                 f"congruence must be 'cholesky' or 'eig', got '{congruence}'"
             )
-        self.n = n
+        self.num_features = num_features
         self.metric = metric
         self.theta = theta
         self.alpha = alpha
@@ -115,18 +249,20 @@ def __init__(
         self.n_iter = n_iter
         self.congruence = congruence

-        self.bias = nn.Parameter(torch.empty(1, n, n, device=device, dtype=dtype))
+        self.bias = nn.Parameter(
+            torch.empty(1, num_features, num_features, device=device, dtype=dtype)
+        )
         self.shift = nn.Parameter(torch.empty((), device=device, dtype=dtype))

         if metric == "AIM":
             self.register_buffer(
                 "running_mean",
-                torch.eye(n, device=device, dtype=dtype).unsqueeze(0),
+                torch.eye(num_features, device=device, dtype=dtype).unsqueeze(0),
             )
         else:
             self.register_buffer(
                 "running_mean",
-                torch.zeros(1, n, n, device=device, dtype=dtype),
+                torch.zeros(1, num_features, num_features, device=device, dtype=dtype),
             )
         self.register_buffer("running_var", torch.ones((), device=device, dtype=dtype))

@@ -238,7 +374,7 @@ def forward(self, X):

     def extra_repr(self):
         return (
-            f"n={self.n}, metric={self.metric}, theta={self.theta}, "
-            f"alpha={self.alpha}, beta={self.beta}, momentum={self.momentum}, "
-            f"congruence={self.congruence}"
+            f"num_features={self.num_features}, metric={self.metric}, "
+            f"theta={self.theta}, alpha={self.alpha}, beta={self.beta}, "
+            f"momentum={self.momentum}, congruence={self.congruence}"
         )

tests/test_integration.py (1 addition & 1 deletion)
@@ -21,7 +21,7 @@
     "SPDBatchNormMean": dict(num_features=10),
     "BatchReNorm": dict(num_features=10),
     "SPDBatchNormMeanVar": dict(num_features=10),
-    "SPDBatchNormLie": dict(n=10),
+    "SPDBatchNormLie": dict(num_features=10),
     "PatchEmbeddingLayer": dict(n_chans=10, n_patches=2),
     "BiMapIncreaseDim": dict(in_features=10, out_features=20),
     "Shrinkage": dict(n_chans=10),
