4343class SPDBatchNormLie (nn .Module ):
4444 r"""Lie Group Batch Normalization for SPD matrices.
4545
46- This class implements the SPD instance of the LieBN framework, using
47- the three Lie group structures on the SPD manifold, corresponding to the AIM, LEM, and LCM.
46+ Implements the LieBN framework :cite:p:`chen2024liebn` for SPD manifolds.
47+ Unlike :class:`SPDBatchNormMeanVar`, which normalizes under a single
48+ Riemannian metric (AIRM), this layer exploits the **Lie group structure**
49+ of three classical SPD geometries to define centering, scaling, and biasing
50+ as group-theoretic operations with formal statistical guarantees.
51+
52+ **Algorithm.**
53+ Given a batch :math:`\{P_i\}_{i=1}^N \subset \mathcal{S}_{++}^n`, the
54+ forward pass applies three steps in the Lie algebra selected by ``metric``:
55+
56+ 1. **Centering** -- translate the batch mean :math:`M` to the group
57+ identity :math:`E` via the inverse left translation:
58+
59+ .. math::
60+
61+ \bar{P}_i = L_{M_\odot^{-1}}(P_i)
62+
63+ 2. **Scaling** -- normalize the Fréchet variance :math:`v^2` with a
64+ learnable shift :math:`s \in \mathbb{R}_{>0}`:
65+
66+ .. math::
67+
68+ \hat{P}_i = \operatorname{Exp}_E
69+ \!\left[\frac{s}{\sqrt{v^2 + \epsilon}}\,
70+ \operatorname{Log}_E(\bar{P}_i)\right]
71+
72+ 3. **Biasing** -- translate to the learnable SPD parameter :math:`B`:
73+
74+ .. math::
75+
76+ \tilde{P}_i = L_B(\hat{P}_i)
77+
78+ **Theoretical guarantees** (Proposition 4.2 of the paper):
79+
80+ * *Mean control*: after centering and biasing with :math:`B = E`,
81+ the Fréchet mean of the output batch equals :math:`E`.
82+ * *Variance control*: after scaling, the output dispersion satisfies
83+ :math:`\sum_i w_i\,d^2(\hat{P}_i, E) = s^2`.
84+
85+ **Supported metrics.**
86+ The ``metric`` parameter selects one of three Lie group structures, each
87+ inducing a family of parameterized metrics via the power deformation
88+ :math:`\mathrm{P}_\theta`. The table below summarizes how each step is
89+ realized (see Table 2 in :cite:p:`chen2024liebn`):
90+
91+ .. list-table::
92+ :header-rows: 1
93+ :widths: 25 25 25 25
94+
95+ * - Operation
96+ - :math:`(\theta,\alpha,\beta)`-AIM
97+ - :math:`(\alpha,\beta)`-LEM
98+ - :math:`\theta`-LCM
99+ * - Pullback map
100+ - :math:`\mathrm{P}_\theta`
101+ - :math:`\operatorname{mlog}`
102+ - :math:`\psi_{\mathrm{LC}} \circ \mathrm{P}_\theta`
103+ * - Left translation :math:`L_Q(P)`
104+ - :math:`Q^{1/2} P\, Q^{1/2}`
105+ - :math:`P + Q`
106+ - :math:`P + Q`
107+ * - Scaling
108+ - :math:`\operatorname{Exp}_I[s\,\operatorname{Log}_I(P)]`
109+ - :math:`s \cdot P`
110+ - :math:`s \cdot P`
111+ * - Fréchet mean
112+ - Karcher flow
113+ - Arithmetic mean
114+ - Arithmetic mean
115+ * - Running mean update
116+ - AIRM geodesic
117+ - Linear interpolation
118+ - Linear interpolation
119+
120+ **Bi-invariant distance.**
121+ The Fréchet variance uses the :math:`(\alpha, \beta)` bi-invariant metric
122+ (Definition 3 and Eq. 3 of the paper):
123+
124+ .. math::
125+
126+ d^2(P, Q) = \alpha \lVert V \rVert_F^2
127+ + \beta \, g(V)^2
128+
129+ where :math:`V` is the tangent representation (log-map) and
130+ :math:`g(V) = \log\det(P)` for AIM or :math:`\operatorname{tr}(V)`
131+ for LEM/LCM. The variance is normalized by :math:`\theta^2` for AIM
132+ and LCM.
48133
49134 Parameters
50135 ----------
51- n : int
52- Size of the SPD matrices (n x n).
53- metric : str, default="AIM"
54- Lie group invariant metric. Supported values are ``"AIM"``, ``"LEM"``,
55- and ``"LCM"``.
136+ num_features : int
137+ Size of the SPD matrices (:math:`n \times n`).
138+ metric : {"AIM", "LEM", "LCM"}, default="AIM"
139+ Lie group invariant metric.
56140 theta : float, default=1.0
57- Power deformation parameter.
141+ Power deformation parameter :math:`\theta`. When
142+ :math:`\theta = 1`, no deformation is applied.
58143 alpha : float, default=1.0
59- Frobenius norm weight in variance computation .
144+ Frobenius norm weight :math:`\alpha` in the bi-invariant distance .
60145 beta : float, default=0.0
61- Trace/logdet weight in variance computation.
146+ Trace / log-determinant weight :math:`\beta` in the bi-invariant
147+ distance. Must satisfy :math:`\min(\alpha, \alpha + n\beta) > 0`.
62148 momentum : float, default=0.1
63- Running statistics momentum.
149+ Momentum :math:`\gamma` for exponential moving average of running
150+ statistics.
64151 eps : float, default=1e-5
65- Numerical stability constant for variance normalization.
152+ Numerical stability constant :math:`\epsilon` added to the variance
153+ before taking the square root.
66154 n_iter : int, default=1
67- Number of Karcher flow iterations used by the AIM mean.
155+ Number of Karcher flow iterations for the AIM Fréchet mean.
156+ Ignored by LEM and LCM (which use arithmetic means).
68157 congruence : {"cholesky", "eig"}, default="cholesky"
69158 Implementation of the AIM congruence action (centering/biasing).
70159 ``"cholesky"`` uses the Cholesky factor :math:`L` of :math:`P` to
71- compute :math:`L X L^T ` (as in the original LieBN paper).
160+ compute :math:`L X L^\top ` (as in the original LieBN paper).
72161 ``"eig"`` uses eigendecomposition-based :math:`M^{-1/2} X M^{-1/2}`
73162 (matching :func:`~spd_learn.functional.spd_centering`).
74163 Both are mathematically equivalent; Cholesky is typically faster,
@@ -79,11 +168,56 @@ class SPDBatchNormLie(nn.Module):
79168 Device on which to create parameters and buffers.
80169 dtype : torch.dtype, optional
81170 Data type of parameters and buffers.
171+
172+ Attributes
173+ ----------
174+ bias : nn.Parameter
175+ Learnable SPD bias matrix :math:`B \in \mathcal{S}_{++}^n`,
176+ parametrized via :class:`~spd_learn.modules.SymmetricPositiveDefinite`.
177+ Initialized to the identity.
178+ shift : nn.Parameter
179+ Learnable positive scalar :math:`s > 0`,
180+ parametrized via :class:`~spd_learn.modules.PositiveDefiniteScalar`.
181+ Initialized to 1.
182+ running_mean : torch.Tensor
183+ Exponential moving average of the batch Fréchet mean.
184+ running_var : torch.Tensor
185+ Exponential moving average of the batch variance.
186+
187+ See Also
188+ --------
189+ :class:`SPDBatchNormMean` :
190+ Mean-only Riemannian batch normalization (AIRM centering without
191+ variance normalization) :cite:p:`brooks2019riemannian`.
192+ :class:`SPDBatchNormMeanVar` :
193+ Full Riemannian batch normalization under the AIRM
194+ :cite:p:`kobler2022spd`.
195+ :func:`~spd_learn.functional.frechet_mean` :
196+ Fréchet mean via Karcher flow (used internally for AIM).
197+ :func:`~spd_learn.functional.lie_group_variance` :
198+ Bi-invariant Fréchet variance computation.
199+
200+ References
201+ ----------
202+ .. bibliography::
203+ :filter: key == "chen2024liebn"
204+
205+ Examples
206+ --------
207+ >>> import torch
208+ >>> from spd_learn.modules import SPDBatchNormLie
209+ >>> bn = SPDBatchNormLie(num_features=4, metric="AIM")
210+ >>> X = torch.randn(8, 4, 4, dtype=torch.float64)
211+ >>> X = X @ X.mT + 0.1 * torch.eye(4, dtype=torch.float64)
212+ >>> bn = bn.to(dtype=torch.float64)
213+ >>> Y = bn(X)
214+ >>> Y.shape
215+ torch.Size([8, 4, 4])
82216 """
83217
84218 def __init__ (
85219 self ,
86- n ,
220+ num_features ,
87221 metric = "AIM" ,
88222 theta = 1.0 ,
89223 alpha = 1.0 ,
@@ -105,7 +239,7 @@ def __init__(
105239 raise ValueError (
106240 f"congruence must be 'cholesky' or 'eig', got '{ congruence } '"
107241 )
108- self .n = n
242+ self .num_features = num_features
109243 self .metric = metric
110244 self .theta = theta
111245 self .alpha = alpha
@@ -115,18 +249,20 @@ def __init__(
115249 self .n_iter = n_iter
116250 self .congruence = congruence
117251
118- self .bias = nn .Parameter (torch .empty (1 , n , n , device = device , dtype = dtype ))
252+ self .bias = nn .Parameter (
253+ torch .empty (1 , num_features , num_features , device = device , dtype = dtype )
254+ )
119255 self .shift = nn .Parameter (torch .empty ((), device = device , dtype = dtype ))
120256
121257 if metric == "AIM" :
122258 self .register_buffer (
123259 "running_mean" ,
124- torch .eye (n , device = device , dtype = dtype ).unsqueeze (0 ),
260+ torch .eye (num_features , device = device , dtype = dtype ).unsqueeze (0 ),
125261 )
126262 else :
127263 self .register_buffer (
128264 "running_mean" ,
129- torch .zeros (1 , n , n , device = device , dtype = dtype ),
265+ torch .zeros (1 , num_features , num_features , device = device , dtype = dtype ),
130266 )
131267 self .register_buffer ("running_var" , torch .ones ((), device = device , dtype = dtype ))
132268
@@ -238,7 +374,7 @@ def forward(self, X):
238374
def extra_repr(self):
    """Return the hyperparameter summary shown inside ``repr(module)``.

    Follows the ``torch.nn.Module.extra_repr`` convention: a single
    comma-separated string listing the layer's configuration. Reads the
    attributes set in ``__init__`` (``num_features``, ``metric``,
    ``theta``, ``alpha``, ``beta``, ``momentum``, ``congruence``).
    """
    # Reconstructed from the post-change ("+") side of the diff; the
    # extracted text was syntactically invalid (fused diff line numbers,
    # token-splitting whitespace).
    return (
        f"num_features={self.num_features}, metric={self.metric}, "
        f"theta={self.theta}, alpha={self.alpha}, beta={self.beta}, "
        f"momentum={self.momentum}, congruence={self.congruence}"
    )