 #
 #     \tilde{C}_i = \bar{C}_j^{-1/2} \, C_i \, \bar{C}_j^{-1/2}
 #
-# where :math:`\bar{C}_j` is the Frechet mean of the target domain.
+# where :math:`\bar{C}_j` is the Fréchet mean of the target domain.
 # However, the paper's **Proposition 2** shows that RCT only compensates
 # conditional shift when the label priors are identical across domains.
 # Under label shift, :math:`\bar{C}_j` is biased toward the
@@ ... @@
 #     \tilde{C}_i = \Phi_j^{1/2} \, \bar{C}_j^{-1/2}
 #     \, C_i \, \bar{C}_j^{-1/2} \, \Phi_j^{1/2}
 #
-# It uses ``karcher_mean`` for initialization and
-# a log-space parameterization for optimization with standard Adam.
+# It initializes :math:`\Phi_j` with the target Fréchet mean and
+# optimizes it as an SPD-constrained parameter via
+# ``torch.nn.utils.parametrize`` and ``SymmetricPositiveDefinite``.
 #
 # SPDIM optimizes the **IM loss** (Eq. 21):
 #
 # .. math::
 #
 #     \mathcal{L}_{\mathrm{IM}} = \underbrace{H(Y | X)}_{\text{conditional
 #     entropy}} - \underbrace{H(\bar{Y})}_{\text{marginal entropy}}
 #
-# This encourages confident predictions (low :math:`H(Y|X)`) while
+# Here, :math:`H(Y \mid X)` is the conditional entropy of the model
+# predictions for each target sample, and :math:`H(\bar{Y})` is the
+# entropy of the average predictive distribution across the target set.
+# This encourages confident predictions (low :math:`H(Y \mid X)`) while
 # maintaining class diversity (high :math:`H(\bar{Y})`).
 #
 
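To make the two displays above concrete, here is a minimal PyTorch sketch of the Φ-recentering transform and the IM loss. The helper names `_spd_pow` and `spdim_recenter` are hypothetical stand-ins (the tutorial uses the library's own `matrix_sqrt_inv` etc.); only the `im_loss(logits, temperature=2.0)` signature mirrors one that actually appears later in this diff. RCT is recovered as the special case `phi = I`.

```python
import torch


def _spd_pow(C, p):
    # Hypothetical helper: fractional power of an SPD matrix
    # through its eigendecomposition.
    w, V = torch.linalg.eigh(C)
    return (V * w.clamp_min(1e-12).pow(p)) @ V.transpose(-1, -2)


def spdim_recenter(C, mean, phi):
    # \tilde{C} = Phi^{1/2} M^{-1/2} C M^{-1/2} Phi^{1/2};
    # with phi = identity this reduces to plain RCT recentering.
    m_isqrt = _spd_pow(mean, -0.5)
    phi_sqrt = _spd_pow(phi, 0.5)
    return phi_sqrt @ m_isqrt @ C @ m_isqrt @ phi_sqrt


def im_loss(logits, temperature=2.0):
    # L_IM = H(Y|X) - H(Ybar): mean per-sample prediction entropy
    # minus the entropy of the batch-averaged prediction.
    p = torch.softmax(logits / temperature, dim=-1)
    h_cond = -(p * p.clamp_min(1e-12).log()).sum(dim=-1).mean()
    p_bar = p.mean(dim=0)
    h_marg = -(p_bar * p_bar.clamp_min(1e-12).log()).sum()
    return h_cond - h_marg
```

Minimizing `h_cond` pushes each prediction toward a single class, while maximizing `h_marg` prevents the degenerate solution of assigning every target sample to the same class.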
@@ ... @@
     get_epsilon,
     matrix_exp,
     matrix_log,
-    matrix_power,
     matrix_sqrt_inv,
 )
 
 
-def geodesic_transport_to_identity(X, mean, t):
-    r"""Transport SPD matrices along the geodesic toward identity.
-
-    .. math::
-
-        T_t(X) = A^{-t/2} \, X \, A^{-t/2}
-
-    When ``t = 0``: no transport. When ``t = 1``: full centering (RCT).
-    """
-    t_tensor = torch.as_tensor(t, dtype=X.dtype, device=X.device)
-    mean_pow = matrix_power.apply(mean, t_tensor * (-0.5))
-    return mean_pow @ X @ mean_pow
-
-
-def karcher_mean(X, max_iter=50, return_distances=False):
-    r"""Compute the Frechet (Karcher) mean under the AIRM.
+def frechet_mean(X, max_iter=50, return_distances=False):
+    r"""Compute the Fréchet mean under the AIRM.
 
     .. math::
 
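The body of the renamed `frechet_mean` is cut off by the hunk, so for context: the Fréchet mean under the affine-invariant metric has no closed form and is typically computed by the Karcher flow. A self-contained sketch of that iteration (assumed behavior with plain eigendecomposition-based matrix functions, not the file's exact implementation):

```python
import torch


def _spd_map(C, fn):
    # Apply a scalar function to the eigenvalues of symmetric matrices.
    w, V = torch.linalg.eigh(C)
    return (V * fn(w)) @ V.transpose(-1, -2)


def frechet_mean_sketch(X, max_iter=50, tol=1e-8):
    # Karcher flow: average the log-maps of all samples at the current
    # estimate, then take an exp-map step back onto the SPD manifold.
    mean = X.mean(dim=0)  # arithmetic mean as initialization
    for _ in range(max_iter):
        m_sqrt = _spd_map(mean, torch.sqrt)
        m_isqrt = _spd_map(mean, torch.rsqrt)
        logs = _spd_map(m_isqrt @ X @ m_isqrt,
                        lambda w: w.clamp_min(1e-12).log())
        step = logs.mean(dim=0)
        mean = m_sqrt @ _spd_map(step, torch.exp) @ m_sqrt
        if step.norm() < tol:  # mean tangent vector vanished: converged
            break
    # AIRM distances d(mean, X_i) = ||log(mean^{-1/2} X_i mean^{-1/2})||_F,
    # matching the `return_distances=True` usage later in this diff.
    return mean, logs.flatten(start_dim=-2).norm(dim=-1)
```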
@@ -263,7 +252,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 # Following the paper's ``get_label_ratio`` protocol with
 # ``ratio_level=0.2``: we keep all samples of the last class and
 # subsample the other class(es) to 20%. This creates a 5:1 class
-# imbalance, making the Frechet mean biased toward the majority class.
+# imbalance, making the Fréchet mean biased toward the majority class.
 #
 # As shown by the paper's **Proposition 2**, this biased mean causes
 # RCT to misalign: the recentered features no longer align with
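The subsampling protocol described in this hunk is easy to reproduce. A hypothetical stand-in for the paper's `get_label_ratio` helper (the real implementation may differ in details):

```python
import numpy as np


def subsample_with_label_ratio(y, ratio_level=0.2, seed=0):
    # Keep every sample of the last class; keep only `ratio_level`
    # of each other class. For a balanced binary problem,
    # ratio_level=0.2 yields a 5:1 imbalance, i.e. a majority
    # prior of 5/6 ~ 83%.
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    keep = []
    for c in classes:
        idx = np.flatnonzero(y == c)
        if c != classes[-1]:
            n_keep = max(1, int(len(idx) * ratio_level))
            idx = rng.choice(idx, size=n_keep, replace=False)
        keep.append(idx)
    return np.sort(np.concatenate(keep))
```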
@@ -448,12 +437,12 @@ def spdim_forward(model, X_spd, adapter=None):
 # -------------------------------------------------
 #
 # The **Recentering Transform (RCT)** :cite:p:`zanini2017transfer`
-# baseline recomputes the Frechet mean and variance on target SPD
+# baseline recomputes the Fréchet mean and variance on target SPD
 # features using the full Karcher flow. This corresponds to setting
 # :math:`\varphi = 1` (standard centering) in the geodesic transport.
 #
 # Under label shift, Proposition 2 predicts that this will degrade
-# performance because the biased Frechet mean shifts features away
+# performance because the biased Fréchet mean shifts features away
 # from the source decision boundary.
 #
 
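The ":math:`\varphi = 1`" remark refers to the geodesic transport `T_t(X) = A^{-t/2} X A^{-t/2}` that the deleted `geodesic_transport_to_identity` helper implemented. A sketch of that interpolation for reference, using a plain eigendecomposition in place of the library's `matrix_power`:

```python
import torch


def geodesic_transport(X, mean, t):
    # T_t(X) = mean^{-t/2} X mean^{-t/2}
    #   t = 0: no transport (features unchanged)
    #   t = 1: full recentering, i.e. the RCT baseline below
    w, V = torch.linalg.eigh(mean)
    mean_pow = (V * w.clamp_min(1e-12).pow(-0.5 * t)) @ V.transpose(-1, -2)
    return mean_pow @ X @ mean_pow
```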
@@ -462,10 +451,10 @@ def spdim_forward(model, X_spd, adapter=None):
 orig_running_var = underlying_model.spdbnorm.running_var.clone()
 
 
-def refit_spdbn_karcher(model, X_data, batch_size=32):
-    """Refit SPDBatchNormMeanVar using full Karcher mean (SPDIM style)."""
+def refit_spdbn_frechet(model, X_data, batch_size=32):
+    """Refit SPDBatchNormMeanVar using the Fréchet mean (SPDIM style)."""
     X_spd = extract_spd_features(model, X_data, batch_size=batch_size)
-    mean, distances = karcher_mean(X_spd, max_iter=50, return_distances=True)
+    mean, distances = frechet_mean(X_spd, max_iter=50, return_distances=True)
     variance = distances.square().mean(dim=0, keepdim=True).squeeze()
     with torch.no_grad():
         model.spdbnorm.running_mean.copy_(mean)
@@ -476,9 +465,9 @@ def refit_spdbn_karcher(model, X_data, batch_size=32):
 print("SFUDA Step 1: Refit BN Statistics (RCT)")
 print(f"{'=' * 50}")
 
-refit_spdbn_karcher(underlying_model, X_target_shifted)
+refit_spdbn_frechet(underlying_model, X_target_shifted)
 
-target_karcher_mean = underlying_model.spdbnorm.running_mean.clone()
+target_frechet_mean = underlying_model.spdbnorm.running_mean.clone()
 
 rct_pred = clf.predict(target_shifted_ds)
 rct_bacc = balanced_accuracy_score(y_target_shifted, rct_pred)
@@ -514,16 +503,12 @@ def im_loss(logits, temperature=2.0):
 # --------------------
 #
 # SPDIM(bias) (Eq. 19) learns a full SPD reference mean that replaces
-# the (biased) Frechet mean in the geodesic transport. With
+# the (biased) Fréchet mean in the geodesic transport. With
 # :math:`D(D+1)/2` degrees of freedom (vs 1 scalar for geodesic), it
 # can compensate both conditional and label shift.
 #
-# We parameterize the learnable mean in log space:
-# :math:`M = \exp(S)` where :math:`S` is an unconstrained symmetric
-# matrix. This guarantees :math:`M \in \mathcal{S}_{++}^D` and allows
-# standard Adam optimization via
-# :func:`~spd_learn.functional.matrix_exp` /
-# :func:`~spd_learn.functional.matrix_log`.
+# We initialize the learnable mean with the target Fréchet mean and
+# keep it on the SPD manifold via ``SymmetricPositiveDefinite``.
 #
 # Learnable SPD Recenter Module
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
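To illustrate what "keep it on the SPD manifold via ``SymmetricPositiveDefinite``" can look like: a minimal sketch using `torch.nn.utils.parametrize` with an expm-based parametrization. The class names and the exact mapping are assumptions; the library's actual `SymmetricPositiveDefinite` may differ.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SPDFromSymmetric(nn.Module):
    # Assumed parametrization: expm of the symmetric part is SPD.
    # `right_inverse` (matrix log) is what makes assignments like
    # `adapter.bias = target_frechet_mean.clone()` (seen below) work.
    def forward(self, S):
        S = 0.5 * (S + S.transpose(-1, -2))
        w, V = torch.linalg.eigh(S)
        return (V * w.exp()) @ V.transpose(-1, -2)

    def right_inverse(self, M):
        w, V = torch.linalg.eigh(M)
        return (V * w.clamp_min(1e-12).log()) @ V.transpose(-1, -2)


class LearnableRecenterSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.bias = nn.Parameter(torch.eye(dim))
        parametrize.register_parametrization(self, "bias", SPDFromSymmetric())
```

Because the unconstrained parameter lives in a flat vector space, plain `torch.optim.Adam` (as used below) suffices, and the exposed `bias` stays SPD by construction.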
@@ -569,25 +554,21 @@ def forward(self, input):
 #
 
 print(f"\n{'=' * 50}")
-print("SPDIM(bias): Learnable SPD Mean (Log-Space Parameterization)")
+print("SPDIM(bias): Learnable SPD Mean")
 print(f"{'=' * 50}")
 
-# Parameterize in log space: M = exp(S), initialized from Karcher mean
-with torch.no_grad():
-    S_init = matrix_log.apply(target_karcher_mean.squeeze(0).unsqueeze(0))
-
 X_spd_target = extract_spd_features(underlying_model, X_target_shifted, batch_size=32)
 
-print(f"Log-space parameter S initialized. Shape: {target_karcher_mean.shape}")
+print(f"SPD reference initialized. Shape: {target_frechet_mean.shape}")
 
-adapter = SPDLearnableRecenter(target_karcher_mean.shape[-1])
-adapter.bias = target_karcher_mean.clone()
+adapter = SPDLearnableRecenter(target_frechet_mean.shape[-1])
+adapter.bias = target_frechet_mean.clone()
 
 optimizer_bias = torch.optim.Adam(adapter.parameters(), lr=0.05)
 n_epochs_bias = 30  # Reduced from 200 for faster documentation build
 losses_bias = []
 best_loss_bias = float("inf")
-best_S = target_karcher_mean.clone().detach()
+best_bias = target_frechet_mean.clone().detach()
 for epoch in range(n_epochs_bias):
     optimizer_bias.zero_grad()
     logits = spdim_forward(
@@ -603,14 +584,14 @@ def forward(self, input):
     losses_bias.append(current_loss)
     if current_loss < best_loss_bias:
         best_loss_bias = current_loss
-        best_S = adapter.bias.clone().detach()
+        best_bias = adapter.bias.clone().detach()
 
     if (epoch + 1) % 10 == 0 or epoch == 0:
         print(f"  Epoch {epoch + 1:3d}/{n_epochs_bias}: loss={current_loss:.4f}")
 
 # Evaluate with best parameters
 with torch.no_grad():
-    adapter.bias = best_S
+    adapter.bias = best_bias
     logits = spdim_forward(
         underlying_model,
         X_spd_target,
@@ -735,8 +716,8 @@ def forward(self, input):
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # The Recentering Transform (RCT) :cite:p:`zanini2017transfer` computes
-# the Frechet mean of the target domain and uses it to center the SPD
-# features. Under **label shift**, the Frechet mean is biased toward
+# the Fréchet mean of the target domain and uses it to center the SPD
+# features. Under **label shift**, the Fréchet mean is biased toward
 # the over-represented class (here, *feet* at 83% of samples).
 #
 # As predicted by **Proposition 2** of the paper, this biased mean
@@ -754,8 +735,6 @@ def forward(self, input):
 # - **Best-model tracking**: Returns the parameter with lowest IM loss.
 # - **Test-time BN**: Only geodesic transport (no dispersion
 #   normalization), matching the original SPDIM test-time pipeline.
-# - **Float64 for adaptation**: SPD features are cast to float64 for
-#   numerical stability in eigendecompositions and matrix logarithms.
 #
 # References
 # ----------