
Commit 113efad

Merge pull request #21 from GitZH-Chen/main
Integrate LieBNSPD into spd_learn
2 parents 4719156 + f3e0414 commit 113efad

18 files changed: 2936 additions & 85 deletions

docs/source/api.rst

Lines changed: 1 addition & 0 deletions

@@ -362,6 +362,7 @@ or related representations.
     SPDBatchNormMean
     SPDBatchNormMeanVar
+    SPDBatchNormLie
     BatchReNorm

docs/source/conf.py

Lines changed: 16 additions & 1 deletion

@@ -250,6 +250,18 @@
 from sphinx_gallery.sorting import ExplicitOrder
 
 
+def _reset_torch_defaults(gallery_conf, fname):
+    """Reset torch global state between sphinx-gallery examples.
+
+    Some examples call ``torch.set_default_dtype(torch.float64)``, which
+    persists across examples run in the same worker process and causes
+    dtype-mismatch errors in subsequent examples.
+    """
+    import torch
+
+    torch.set_default_dtype(torch.float32)
+
+
 sphinx_gallery_conf = {
     "examples_dirs": ["../../examples"],
     "gallery_dirs": ["generated/auto_examples"],
@@ -258,10 +270,11 @@
     # Point 3: Image optimization - compress images and reduce thumbnail size
     "compress_images": ("images", "thumbnails"),
     "thumbnail_size": (400, 280),  # Smaller thumbnails for faster loading
-    # Order: tutorials first, then visualizations, then applied examples
+    # Order: tutorials, how-to guides, visualizations, then applied examples
     "subsection_order": ExplicitOrder(
         [
             "../../examples/tutorials",
+            "../../examples/howto",
             "../../examples/visualizations",
             "../../examples/applied_examples",
         ]
@@ -277,6 +290,8 @@
     # Include both plot_* files and tutorial_* files
     "filename_pattern": r"/(plot_|tutorial_)",
     "ignore_pattern": r"(__init__|spd_visualization_utils)\.py",
+    # Reset torch default dtype between examples to prevent float64 leakage
+    "reset_modules": ("matplotlib", "seaborn", _reset_torch_defaults),
     # Show signature link template (includes Colab launcher)
     "show_signature": False,
     # First cell in generated notebooks (for Colab compatibility)

docs/source/geometric_concepts.rst

Lines changed: 98 additions & 0 deletions

@@ -693,6 +693,104 @@ where :math:`\frechet` is the Fréchet mean of the batch.
 See :ref:`sphx_glr_generated_auto_examples_visualizations_plot_batchnorm_animation.py`
 
 
+Batch Normalization on SPD Manifolds
+=====================================
+
+In Euclidean deep learning, batch normalization centers activations to zero mean
+and unit variance, stabilizing gradient flow and accelerating convergence. On the
+SPD manifold, the same principle applies, but "mean" and "variance" must respect
+the curved Riemannian geometry.
+
+
+Why Euclidean BN Fails for SPD Matrices
+----------------------------------------
+
+Standard batch normalization computes :math:`\hat{x} = (x - \mu) / \sigma`. For SPD
+matrices this is problematic:
+
+- **Subtraction breaks SPD**: :math:`X - M` (with :math:`M` the arithmetic mean) may not
+  be positive definite.
+- **The swelling effect**: The Euclidean mean of SPD matrices can have a larger determinant
+  than any individual matrix, distorting the data distribution.
+- **Scale mismatch**: SPD matrices from different subjects or sessions can have vastly
+  different spectral profiles; Euclidean normalization ignores this geometric structure.
+
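The swelling effect in the second bullet is easy to verify numerically. A minimal pure-Python sketch with two hand-picked 2x2 SPD matrices, each of determinant 1 (the matrices are illustrative, not from the library):

```python
# Swelling effect: the Euclidean (arithmetic) mean of two SPD matrices,
# each with determinant 1, has a strictly larger determinant.
A = [[2.0, 0.0], [0.0, 0.5]]   # det(A) = 1
B = [[0.5, 0.0], [0.0, 2.0]]   # det(B) = 1

# Entry-wise arithmetic mean of A and B
M = [[(A[i][j] + B[i][j]) / 2 for j in range(2)] for i in range(2)]

det_M = M[0][0] * M[1][1] - M[0][1] * M[1][0]
print(det_M)  # 1.5625, larger than det(A) = det(B) = 1
```

The Fréchet mean under an affine-invariant or log-Euclidean metric avoids this inflation, which is one motivation for the Riemannian normalization described next.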
+Riemannian Batch Normalization
+-------------------------------
+
+:class:`~spd_learn.modules.SPDBatchNormMeanVar` addresses these issues by replacing
+Euclidean operations with their Riemannian counterparts under the AIRM:
+
+1. **Centering**: Compute the Fréchet mean :math:`\frechet` of the batch, then
+   apply the congruence :math:`\tilde{X}_i = \frechet^{-1/2} X_i \frechet^{-1/2}` to center
+   the batch around the identity matrix.
+2. **Variance scaling**: Compute a scalar dispersion and normalize by a learnable weight.
+3. **Biasing**: Apply a learnable SPD bias via congruence.
+
+This preserves the SPD structure at every step.
+
+
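The centering step can be sanity-checked in the commuting (diagonal) case, where the AIRM Fréchet mean reduces to the elementwise geometric mean and the congruence reduces to elementwise division. This is a simplified illustration, not the `SPDBatchNormMeanVar` implementation:

```python
import math

# AIRM centering sketched for *diagonal* SPD matrices (diagonals stored
# as lists). Diagonal matrices commute, so the Fréchet mean is the
# elementwise geometric mean, and M^{-1/2} X M^{-1/2} reduces to
# elementwise division by the mean. Illustration only, not the
# SPDBatchNormMeanVar implementation.
batch = [[2.0, 9.0], [8.0, 1.0]]   # two 2x2 diagonal SPD matrices
n, d = len(batch), len(batch[0])

# Fréchet mean = elementwise geometric mean
mean = [math.prod(x[i] for x in batch) ** (1.0 / n) for i in range(d)]

# Congruence = elementwise division; centers the batch at the identity
centered = [[x[i] / mean[i] for i in range(d)] for x in batch]

# Check: the centered batch has Fréchet mean equal to the identity
for i in range(d):
    geo = math.prod(x[i] for x in centered) ** (1.0 / n)
    assert abs(geo - 1.0) < 1e-12
```

In the full (non-commuting) case the same identity-centering property holds, but the mean must be computed iteratively (Karcher flow) and the congruence uses a genuine matrix square root.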
+Lie Group Batch Normalization (LieBN)
+--------------------------------------
+
+:class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn` generalizes
+Riemannian BN by exploiting the Lie group structure of :math:`\spd`. The key insight
+is that each Riemannian metric induces a different group action for centering and biasing.
+
+The LieBN forward pass follows five steps:
+
+1. **Deformation**: Map SPD matrices to a codomain via the metric
+   (e.g., :math:`\log(X)` for LEM, Cholesky + log-diagonal for LCM, :math:`X^\theta` for AIM).
+2. **Centering**: Translate the batch to zero/identity mean using the group action.
+3. **Scaling**: Normalize variance by a learnable dispersion parameter.
+4. **Biasing**: Translate by a learnable location parameter.
+5. **Inverse deformation**: Map back to the SPD manifold.
+
+
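The five steps above can be traced end-to-end for the LEM case. The sketch below restricts itself to diagonal SPD matrices, where matrix log/exp act elementwise, so it stays dependency-free; the function name and the scalar-variance convention are illustrative and do not reflect the `SPDBatchNormLie` API:

```python
import math

# Illustrative LieBN forward pass under LEM, restricted to diagonal SPD
# matrices (stored as diagonal lists) so matrix log/exp act elementwise.
# Function name and scalar-variance convention are illustrative only.
def liebn_lem_diag(batch, weight=1.0, bias=None, eps=1e-5):
    n, d = len(batch), len(batch[0])
    # 1. Deformation: matrix log (elementwise for diagonals)
    logs = [[math.log(v) for v in x] for x in batch]
    # 2. Centering: additive group action in the log domain
    mean = [sum(x[i] for x in logs) / n for i in range(d)]
    centered = [[x[i] - mean[i] for i in range(d)] for x in logs]
    # 3. Scaling: normalize by dispersion, rescale by learnable weight
    var = sum(v * v for x in centered for v in x) / (n * d)
    scale = weight / math.sqrt(var + eps)
    scaled = [[v * scale for v in x] for x in centered]
    # 4. Biasing: additive translation by a learnable SPD bias (in log domain)
    log_b = [math.log(b) for b in (bias or [1.0] * d)]
    biased = [[x[i] + log_b[i] for i in range(d)] for x in scaled]
    # 5. Inverse deformation: matrix exp back to the SPD manifold
    return [[math.exp(v) for v in x] for x in biased]

out = liebn_lem_diag([[1.0, 1.0], [math.e ** 2, math.e ** 2]])
# With the default identity bias, the output batch is centered at I:
assert all(abs(out[0][i] * out[1][i] - 1.0) < 1e-9 for i in range(2))
```

Swapping the metric changes each ingredient, per the table below: LCM deforms via Cholesky + log-diagonal (still additive centering), while AIM deforms via a matrix power and centers via a congruence action around an iteratively computed Karcher mean.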
+.. list-table::
+   :header-rows: 1
+   :widths: 15 25 25 25
+
+   * - Metric
+     - Deformation
+     - Mean Computation
+     - Group Action
+   * - **LEM**
+     - :math:`\log(X)`
+     - Euclidean (closed-form)
+     - Additive
+   * - **LCM**
+     - Cholesky + log-diag
+     - Euclidean (closed-form)
+     - Additive
+   * - **AIM**
+     - :math:`X^\theta`
+     - Karcher (iterative)
+     - Cholesky congruence
+
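For concreteness, the LCM deformation row can be worked by hand on a single 2x2 SPD matrix. This is a pure-Python sketch of "Cholesky + log-diag" using the standard 2x2 Cholesky factorization formulas; it is an illustration, not library code:

```python
import math

# LCM deformation sketched for one 2x2 SPD matrix X = [[a, b], [b, c]]:
# compute the lower-triangular Cholesky factor L (X = L L^T), then take
# log of its diagonal, keeping the strictly-lower part unchanged.
X = [[4.0, 2.0], [2.0, 5.0]]

l11 = math.sqrt(X[0][0])              # 2.0
l21 = X[1][0] / l11                   # 1.0
l22 = math.sqrt(X[1][1] - l21 ** 2)   # sqrt(5 - 1) = 2.0

deformed = [[math.log(l11), 0.0],
            [l21, math.log(l22)]]
```

Because both LEM and LCM land in a flat codomain where the mean is a plain average, their normalization cost is dominated by the deformation itself, which is why the table lists their mean computation as closed-form.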
+**Choosing a metric for batch normalization:**
+
+- **LEM**: Fastest (closed-form mean); a good default for most tasks.
+- **AIM**: Full affine invariance; best when data scale varies (e.g., cross-subject EEG).
+- **LCM**: Fast like LEM, with Cholesky-based numerical stability.
+
+.. code-block:: python
+
+   from spd_learn.modules import SPDBatchNormLie
+
+   # LEM is the fastest (closed-form mean); a good default
+   bn_lem = SPDBatchNormLie(num_features=32, metric="LEM")
+
+   # AIM for affine-invariant normalization
+   bn_aim = SPDBatchNormLie(num_features=32, metric="AIM", theta=1.0)
+
+   # LCM for Cholesky-based numerical stability
+   bn_lcm = SPDBatchNormLie(num_features=32, metric="LCM")
+
+.. seealso::
+
+   :ref:`tutorial-batch-normalization` -- hands-on tutorial comparing all BN strategies
+
+   :ref:`howto-add-batchnorm` -- quick integration guide
+
+   :ref:`liebn-batch-normalization` -- full benchmark reproduction across 3 datasets
+
 References
 ==========

docs/source/references.bib

Lines changed: 8 additions & 0 deletions

@@ -138,6 +138,14 @@ @inproceedings{kobler2022spd
   url={https://proceedings.neurips.cc/paper_files/paper/2022/hash/28ef7ee7cd3e03093acc39e1272411b7-Abstract-Conference.html}
 }
 
+@inproceedings{chen2024liebn,
+  title={A Lie Group Approach to Riemannian Batch Normalization},
+  author={Chen, Ziheng and Song, Yue and Xu, Yunmei and Sebe, Nicu},
+  booktitle={International Conference on Learning Representations},
+  year={2024},
+  url={https://openreview.net/forum?id=okYdj8Ysru}
+}
+
 @inproceedings{pan2022matt,
   title={MAtt: A manifold attention network for EEG decoding},
   author={Pan, Yue-Ting and Chou, Jing-Lun and Wei, Chun-Shu},
