where :math:`\frechet` is the Fréchet mean of the batch.

See :ref:`sphx_glr_generated_auto_examples_visualizations_plot_batchnorm_animation.py`

Batch Normalization on SPD Manifolds
====================================

In Euclidean deep learning, batch normalization centers activations to zero mean
and unit variance, stabilizing gradient flow and accelerating convergence. On the
SPD manifold, the same principle applies — but "mean" and "variance" must respect
the curved Riemannian geometry.

Why Euclidean BN Fails for SPD Matrices
---------------------------------------

Standard batch normalization computes :math:`\hat{x} = (x - \mu) / \sigma`. For SPD
matrices this is problematic:

- **Subtraction breaks SPD**: :math:`X - M` (with :math:`M` the arithmetic mean) may not
  be positive definite.
- **The swelling effect**: The Euclidean mean of SPD matrices can have a larger determinant
  than any individual matrix, distorting the data distribution.
- **Scale mismatch**: SPD matrices from different subjects or sessions can have vastly
  different spectral profiles; Euclidean normalization ignores this geometric structure.
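
A short NumPy sketch makes the first two failure modes concrete. This is
illustrative only (the `random_spd` helper is hypothetical, not part of the
library):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_spd(n):
    # Hypothetical helper: A @ A.T plus a small ridge is always SPD
    a = rng.standard_normal((n, n))
    return a @ a.T + 1e-3 * np.eye(n)

batch = [random_spd(4) for _ in range(8)]
mean = sum(batch) / len(batch)  # arithmetic (Euclidean) mean

# Euclidean "centering" X - mean leaves the SPD cone: the residuals sum to
# the zero matrix, so at least one of them has a negative eigenvalue.
min_eig = min(np.linalg.eigvalsh(x - mean).min() for x in batch)

# Swelling: by the Minkowski determinant inequality, the determinant of the
# arithmetic mean is at least the mean of the individual determinant roots,
# so it strictly exceeds the smallest determinant in the batch.
dets = [np.linalg.det(x) for x in batch]
```

Here `min_eig` is negative, confirming that the subtracted "activations" are
no longer valid SPD matrices.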

Riemannian Batch Normalization
------------------------------

:class:`~spd_learn.modules.SPDBatchNormMeanVar` addresses these issues by replacing
Euclidean operations with their Riemannian counterparts under the AIRM:

1. **Centering**: Compute the Fréchet mean :math:`\frechet` of the batch, then
   apply the congruence :math:`\tilde{X}_i = \frechet^{-1/2} X_i \frechet^{-1/2}` to center
   the batch around the identity matrix.
2. **Variance scaling**: Compute a scalar dispersion and normalize by a learnable weight.
3. **Biasing**: Apply a learnable SPD bias via congruence.

This preserves the SPD structure at every step.
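
The centering step can be sketched in plain NumPy. This is an illustrative
fixed-point Karcher iteration on well-conditioned random matrices, not the
library's implementation:

```python
import numpy as np

def _sym_fun(x, fun):
    # Apply a scalar function to the eigenvalues of a symmetric matrix
    w, v = np.linalg.eigh(x)
    return (v * fun(w)) @ v.T

def spd_sqrt(x):
    return _sym_fun(x, np.sqrt)

def spd_invsqrt(x):
    return _sym_fun(x, lambda w: 1.0 / np.sqrt(w))

def spd_log(x):
    return _sym_fun(x, np.log)

def spd_exp(x):
    return _sym_fun(x, np.exp)

def karcher_mean(batch, n_iter=50):
    # Fixed-point iteration for the AIRM Frechet (Karcher) mean
    m = sum(batch) / len(batch)
    for _ in range(n_iter):
        m_half, m_ihalf = spd_sqrt(m), spd_invsqrt(m)
        t = sum(spd_log(m_ihalf @ x @ m_ihalf) for x in batch) / len(batch)
        m = m_half @ spd_exp(t) @ m_half
    return m

# Random SPD matrices near the identity (keeps the iteration well behaved)
rng = np.random.default_rng(0)
batch = []
for _ in range(8):
    a = np.eye(4) + 0.3 * rng.standard_normal((4, 4))
    batch.append(a @ a.T)

g = karcher_mean(batch)
g_ihalf = spd_invsqrt(g)

# Congruence centering: every output stays SPD, and the Frechet mean of
# the centered batch is (numerically) the identity matrix.
centered = [g_ihalf @ x @ g_ihalf for x in batch]
```

The key property is equivariance: the AIRM Fréchet mean of
:math:`\frechet^{-1/2} X_i \frechet^{-1/2}` is exactly
:math:`\frechet^{-1/2} \frechet \frechet^{-1/2} = I`.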

Lie Group Batch Normalization (LieBN)
-------------------------------------

:class:`~spd_learn.modules.SPDBatchNormLie` :cite:p:`chen2024liebn` generalizes
Riemannian BN by exploiting the Lie group structure of :math:`\spd`. The key insight
is that each Riemannian metric induces a different group action for centering and biasing.

The LieBN forward pass follows five steps:

1. **Deformation** — Map SPD matrices to a codomain via the metric
   (e.g., :math:`\log(X)` for LEM, Cholesky + log-diagonal for LCM, :math:`X^\theta` for AIM).
2. **Centering** — Translate the batch to zero/identity mean using the group action.
3. **Scaling** — Normalize variance by a learnable dispersion parameter.
4. **Biasing** — Translate by a learnable location parameter.
5. **Inverse deformation** — Map back to the SPD manifold.
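
Under the LEM, every step above has a closed form. The following NumPy sketch
walks through the five steps for the Log-Euclidean case; the scalar
Frobenius-norm dispersion is one plausible choice for illustration, and none
of this is the library's actual implementation:

```python
import numpy as np

def sym_log(x):
    # Matrix logarithm of an SPD matrix via eigendecomposition
    w, v = np.linalg.eigh(x)
    return (v * np.log(w)) @ v.T

def sym_exp(x):
    # Matrix exponential of a symmetric matrix
    w, v = np.linalg.eigh(x)
    return (v * np.exp(w)) @ v.T

def liebn_lem(batch, bias=None, gamma=1.0, eps=1e-5):
    # 1. Deformation: the matrix log maps SPD into the flat space of
    #    symmetric matrices, where the LEM group action is additive
    logs = np.stack([sym_log(x) for x in batch])
    # 2. Centering: subtract the closed-form Euclidean mean of the logs
    centered = logs - logs.mean(axis=0)
    # 3. Scaling: divide by a scalar dispersion (mean squared Frobenius
    #    norm here), scaled by a learnable gamma
    var = (centered ** 2).sum(axis=(1, 2)).mean()
    scaled = gamma * centered / np.sqrt(var + eps)
    # 4. Biasing: additive translation by a learnable SPD location
    if bias is not None:
        scaled = scaled + sym_log(bias)
    # 5. Inverse deformation: the matrix exp maps back to the SPD manifold
    return [sym_exp(s) for s in scaled]

rng = np.random.default_rng(0)
batch = [a @ a.T + 1e-2 * np.eye(4)
         for a in rng.standard_normal((8, 4, 4))]
out = liebn_lem(batch)
```

Because steps 1–4 stay in a flat space and step 5 maps back through the
matrix exponential, every output is SPD by construction, and the batch is
centered at the group identity in the log domain.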

.. list-table::
   :header-rows: 1
   :widths: 15 25 25 25

   * - Metric
     - Deformation
     - Mean Computation
     - Group Action
   * - **LEM**
     - :math:`\log(X)`
     - Euclidean (closed-form)
     - Additive
   * - **LCM**
     - Cholesky + log-diag
     - Euclidean (closed-form)
     - Additive
   * - **AIM**
     - :math:`X^\theta`
     - Karcher (iterative)
     - Cholesky congruence

**Choosing a metric for batch normalization:**

- **LEM**: Fastest (closed-form mean), good default for most tasks.
- **AIM**: Full affine invariance, best when data scale varies (e.g., cross-subject EEG).
- **LCM**: Fast like LEM, with Cholesky-based numerical stability.

.. code-block:: python

   from spd_learn.modules import SPDBatchNormLie

   # LEM is the fastest — good default
   bn_lem = SPDBatchNormLie(num_features=32, metric="LEM")

   # AIM for affine-invariant normalization
   bn_aim = SPDBatchNormLie(num_features=32, metric="AIM", theta=1.0)

   # LCM for Cholesky stability
   bn_lcm = SPDBatchNormLie(num_features=32, metric="LCM")

.. seealso::

   :ref:`tutorial-batch-normalization` — Hands-on tutorial comparing all BN strategies
   :ref:`howto-add-batchnorm` — Quick integration guide
   :ref:`liebn-batch-normalization` — Full benchmark reproduction across 3 datasets

References
==========
