Commit 76b21a9

Address SPDIM review comments
1 parent e0a1324

1 file changed: 30 additions & 51 deletions

examples/applied_examples/plot_source_free_domain.py
@@ -38,7 +38,7 @@
 #
 #     \tilde{C}_i = \bar{C}_j^{-1/2} \, C_i \, \bar{C}_j^{-1/2}
 #
-# where :math:`\bar{C}_j` is the Frechet mean of the target domain.
+# where :math:`\bar{C}_j` is the Fréchet mean of the target domain.
 # However, the paper's **Proposition 2** shows that RCT only compensates
 # conditional shift when the label priors are identical across domains.
 # Under label shift, :math:`\bar{C}_j` is biased toward the
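For context, the recentering in the equation above amounts to whitening each SPD matrix by the target mean. A minimal sketch in plain PyTorch (the example itself uses ``matrix_sqrt_inv`` from ``spd_learn.functional``; the helper name here is hypothetical):

import torch

def recenter_to_identity(C, mean):
    # mean^{-1/2} @ C @ mean^{-1/2}: inverse square root via eigendecomposition.
    eigvals, eigvecs = torch.linalg.eigh(mean)
    inv_sqrt = eigvecs @ torch.diag_embed(eigvals.rsqrt()) @ eigvecs.mT
    return inv_sqrt @ C @ inv_sqrt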
@@ -57,8 +57,9 @@
 #     \tilde{C}_i = \Phi_j^{1/2} \, \bar{C}_j^{-1/2}
 #     \, C_i \, \bar{C}_j^{-1/2} \, \Phi_j^{1/2}
 #
-# It uses ``karcher_mean`` for initialization and
-# a log-space parameterization for optimization with standard Adam.
+# It initializes :math:`\Phi_j` with the target Fréchet mean and
+# optimizes it as an SPD-constrained parameter via
+# ``torch.nn.utils.parametrize`` and ``SymmetricPositiveDefinite``.
 #
 # SPDIM optimizes the **IM loss** (Eq. 21):
 #
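A minimal sketch of the SPD-constrained parameter setup the new lines refer to, assuming a ``SymmetricPositiveDefinite`` parametrization that maps an unconstrained matrix to SPD via the matrix exponential (the library's actual class may be implemented differently):

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class SymmetricPositiveDefinite(nn.Module):
    def forward(self, X):
        # exp of the symmetric part of X is always SPD.
        return torch.linalg.matrix_exp(0.5 * (X + X.mT))

    def right_inverse(self, M):
        # Lets callers assign an SPD matrix (e.g. a Fréchet mean) directly:
        # matrix logarithm via eigendecomposition.
        w, V = torch.linalg.eigh(M)
        return V @ torch.diag_embed(w.log()) @ V.mT

module = nn.Module()
module.bias = nn.Parameter(torch.zeros(8, 8))  # zero init -> bias starts at identity
parametrize.register_parametrization(module, "bias", SymmetricPositiveDefinite())
# `module.bias` now stays on the SPD manifold under standard Adam updates.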
@@ -67,7 +68,10 @@
 #     \mathcal{L}_{\mathrm{IM}} = \underbrace{H(Y | X)}_{\text{conditional
 #     entropy}} - \underbrace{H(\bar{Y})}_{\text{marginal entropy}}
 #
-# This encourages confident predictions (low :math:`H(Y|X)`) while
+# Here, :math:`H(Y \mid X)` is the conditional entropy of the model
+# predictions for each target sample, and :math:`H(\bar{Y})` is the
+# entropy of the average predictive distribution across the target set.
+# This encourages confident predictions (low :math:`H(Y \mid X)`) while
 # maintaining class diversity (high :math:`H(\bar{Y})`).
 #
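The script's ``im_loss`` definition is outside this diff (only its signature, ``im_loss(logits, temperature=2.0)``, appears in a later hunk header). A sketch consistent with Eq. 21, with the temperature handling assumed:

import torch.nn.functional as F

def im_loss(logits, temperature=2.0):
    # Eq. 21: L_IM = H(Y|X) - H(Ybar), on temperature-scaled softmax outputs.
    log_p = F.log_softmax(logits / temperature, dim=-1)
    p = log_p.exp()
    h_cond = -(p * log_p).sum(dim=-1).mean()  # conditional entropy H(Y|X)
    p_bar = p.mean(dim=0)  # average predictive distribution
    h_marg = -(p_bar * p_bar.clamp_min(1e-12).log()).sum()  # marginal entropy H(Ybar)
    return h_cond - h_marg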

@@ -98,27 +102,12 @@
     get_epsilon,
     matrix_exp,
     matrix_log,
-    matrix_power,
     matrix_sqrt_inv,
 )
 
 
-def geodesic_transport_to_identity(X, mean, t):
-    r"""Transport SPD matrices along the geodesic toward identity.
-
-    .. math::
-
-        T_t(X) = A^{-t/2} \, X \, A^{-t/2}
-
-    When ``t = 0``: no transport. When ``t = 1``: full centering (RCT).
-    """
-    t_tensor = torch.as_tensor(t, dtype=X.dtype, device=X.device)
-    mean_pow = matrix_power.apply(mean, t_tensor * (-0.5))
-    return mean_pow @ X @ mean_pow
-
-
-def karcher_mean(X, max_iter=50, return_distances=False):
-    r"""Compute the Frechet (Karcher) mean under the AIRM.
+def frechet_mean(X, max_iter=50, return_distances=False):
+    r"""Compute the Fréchet mean under the AIRM.
 
     .. math::
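The body of the renamed ``frechet_mean`` is elided here. For reference, the Karcher flow it implements can be sketched as follows (an illustrative re-implementation with ``torch.linalg``; the example's version also returns AIRM distances when ``return_distances=True``):

import torch

def _eig_fn(M, fn):
    # Apply a scalar function to the eigenvalues of a symmetric matrix.
    w, V = torch.linalg.eigh(M)
    return V @ torch.diag_embed(fn(w)) @ V.mT

def frechet_mean_sketch(X, max_iter=50, tol=1e-10):
    # Karcher flow for the AIRM Fréchet mean of SPD matrices X: (N, D, D).
    mean = X.mean(dim=0)  # Euclidean mean as a starting point
    for _ in range(max_iter):
        sqrt = _eig_fn(mean, torch.sqrt)
        inv_sqrt = _eig_fn(mean, torch.rsqrt)
        # Average the log-maps of all matrices at the current mean.
        logs = _eig_fn(inv_sqrt @ X @ inv_sqrt, torch.log)
        step = logs.mean(dim=0)
        mean = sqrt @ torch.linalg.matrix_exp(step) @ sqrt
        if step.norm() < tol:
            break
    return mean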
@@ -263,7 +252,7 @@ def karcher_mean(X, max_iter=50, return_distances=False):
 # Following the paper's ``get_label_ratio`` protocol with
 # ``ratio_level=0.2``: we keep all samples of the last class and
 # subsample the other class(es) to 20%. This creates a 5:1 class
-# imbalance, making the Frechet mean biased toward the majority class.
+# imbalance, making the Fréchet mean biased toward the majority class.
 #
 # As shown by the paper's **Proposition 2**, this biased mean causes
 # RCT to misalign: the recentered features no longer align with
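A sketch of the subsampling protocol described in this hunk (the function name mirrors the paper's ``get_label_ratio``; the exact implementation details are an assumption):

import numpy as np

def subsample_label_ratio(y, ratio_level=0.2, seed=0):
    # Keep all samples of the last class; subsample the others to ratio_level.
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    keep = []
    for c in classes:
        idx = np.flatnonzero(y == c)
        if c != classes[-1]:
            idx = rng.choice(idx, size=max(1, int(len(idx) * ratio_level)), replace=False)
        keep.append(idx)
    return np.sort(np.concatenate(keep))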
@@ -448,12 +437,12 @@ def spdim_forward(model, X_spd, adapter=None):
 # -------------------------------------------------
 #
 # The **Recentering Transform (RCT)** :cite:p:`zanini2017transfer`
-# baseline recomputes the Frechet mean and variance on target SPD
+# baseline recomputes the Fréchet mean and variance on target SPD
 # features using the full Karcher flow. This corresponds to setting
 # :math:`\varphi = 1` (standard centering) in the geodesic transport.
 #
 # Under label shift, Proposition 2 predicts that this will degrade
-# performance because the biased Frechet mean shifts features away
+# performance because the biased Fréchet mean shifts features away
 # from the source decision boundary.
 #
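The geodesic transport referred to here is the map that this commit's deleted ``geodesic_transport_to_identity`` helper implemented, :math:`T_t(X) = A^{-t/2} X A^{-t/2}`. In plain PyTorch it might look like:

import torch

def geodesic_transport(X, mean, phi=1.0):
    # mean^{-phi/2} @ X @ mean^{-phi/2}; phi = 0 leaves X unchanged,
    # phi = 1 is full recentering (RCT).
    w, V = torch.linalg.eigh(mean)
    mean_pow = V @ torch.diag_embed(w.pow(-phi / 2)) @ V.mT
    return mean_pow @ X @ mean_pow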

@@ -462,10 +451,10 @@ def spdim_forward(model, X_spd, adapter=None):
 orig_running_var = underlying_model.spdbnorm.running_var.clone()
 
 
-def refit_spdbn_karcher(model, X_data, batch_size=32):
-    """Refit SPDBatchNormMeanVar using full Karcher mean (SPDIM style)."""
+def refit_spdbn_frechet(model, X_data, batch_size=32):
+    """Refit SPDBatchNormMeanVar using the Fréchet mean (SPDIM style)."""
     X_spd = extract_spd_features(model, X_data, batch_size=batch_size)
-    mean, distances = karcher_mean(X_spd, max_iter=50, return_distances=True)
+    mean, distances = frechet_mean(X_spd, max_iter=50, return_distances=True)
     variance = distances.square().mean(dim=0, keepdim=True).squeeze()
     with torch.no_grad():
         model.spdbnorm.running_mean.copy_(mean)
@@ -476,9 +465,9 @@ def refit_spdbn_karcher(model, X_data, batch_size=32):
 print("SFUDA Step 1: Refit BN Statistics (RCT)")
 print(f"{'=' * 50}")
 
-refit_spdbn_karcher(underlying_model, X_target_shifted)
+refit_spdbn_frechet(underlying_model, X_target_shifted)
 
-target_karcher_mean = underlying_model.spdbnorm.running_mean.clone()
+target_frechet_mean = underlying_model.spdbnorm.running_mean.clone()
 
 rct_pred = clf.predict(target_shifted_ds)
 rct_bacc = balanced_accuracy_score(y_target_shifted, rct_pred)
@@ -514,16 +503,12 @@ def im_loss(logits, temperature=2.0):
 # --------------------
 #
 # SPDIM(bias) (Eq. 19) learns a full SPD reference mean that replaces
-# the (biased) Frechet mean in the geodesic transport. With
+# the (biased) Fréchet mean in the geodesic transport. With
 # :math:`D(D+1)/2` degrees of freedom (vs 1 scalar for geodesic), it
 # can compensate both conditional and label shift.
 #
-# We parameterize the learnable mean in log space:
-# :math:`M = \exp(S)` where :math:`S` is an unconstrained symmetric
-# matrix. This guarantees :math:`M \in \mathcal{S}_{++}^D` and allows
-# standard Adam optimization via
-# :func:`~spd_learn.functional.matrix_exp` /
-# :func:`~spd_learn.functional.matrix_log`.
+# We initialize the learnable mean with the target Fréchet mean and
+# keep it on the SPD manifold via ``SymmetricPositiveDefinite``.
 #
 # Learnable SPD Recenter Module
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
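The ``SPDLearnableRecenter`` module itself sits outside this diff. A hypothetical sketch of how such a module could combine a learnable SPD ``bias`` with the parametrization sketched earlier (the actual module may apply the Eq. 19 transform differently):

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class SymmetricPositiveDefinite(nn.Module):
    # Same idea as the earlier sketch: exp of the symmetric part is SPD.
    def forward(self, X):
        return torch.linalg.matrix_exp(0.5 * (X + X.mT))

    def right_inverse(self, M):
        w, V = torch.linalg.eigh(M)
        return V @ torch.diag_embed(w.log()) @ V.mT

class SPDLearnableRecenter(nn.Module):
    """Hypothetical sketch: recenter SPD inputs by a learnable SPD reference."""

    def __init__(self, dim):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(dim, dim))  # starts at identity
        parametrize.register_parametrization(self, "bias", SymmetricPositiveDefinite())

    def forward(self, input):
        # bias^{-1/2} @ input @ bias^{-1/2}: transport toward identity.
        w, V = torch.linalg.eigh(self.bias)
        inv_sqrt = V @ torch.diag_embed(w.rsqrt()) @ V.mT
        return inv_sqrt @ input @ inv_sqrt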
@@ -569,25 +554,21 @@ def forward(self, input):
 #
 
 print(f"\n{'=' * 50}")
-print("SPDIM(bias): Learnable SPD Mean (Log-Space Parameterization)")
+print("SPDIM(bias): Learnable SPD Mean")
 print(f"{'=' * 50}")
 
-# Parameterize in log space: M = exp(S), initialized from Karcher mean
-with torch.no_grad():
-    S_init = matrix_log.apply(target_karcher_mean.squeeze(0).unsqueeze(0))
-
 X_spd_target = extract_spd_features(underlying_model, X_target_shifted, batch_size=32)
 
-print(f"Log-space parameter S initialized. Shape: {target_karcher_mean.shape}")
+print(f"SPD reference initialized. Shape: {target_frechet_mean.shape}")
 
-adapter = SPDLearnableRecenter(target_karcher_mean.shape[-1])
-adapter.bias = target_karcher_mean.clone()
+adapter = SPDLearnableRecenter(target_frechet_mean.shape[-1])
+adapter.bias = target_frechet_mean.clone()
 
 optimizer_bias = torch.optim.Adam(adapter.parameters(), lr=0.05)
 n_epochs_bias = 30  # Reduced from 200 for faster documentation build
 losses_bias = []
 best_loss_bias = float("inf")
-best_S = target_karcher_mean.clone().detach()
+best_bias = target_frechet_mean.clone().detach()
 for epoch in range(n_epochs_bias):
     optimizer_bias.zero_grad()
     logits = spdim_forward(
@@ -603,14 +584,14 @@ def forward(self, input):
     losses_bias.append(current_loss)
     if current_loss < best_loss_bias:
         best_loss_bias = current_loss
-        best_S = adapter.bias.clone().detach()
+        best_bias = adapter.bias.clone().detach()
 
     if (epoch + 1) % 10 == 0 or epoch == 0:
         print(f"  Epoch {epoch + 1:3d}/{n_epochs_bias}: loss={current_loss:.4f}")
 
 # Evaluate with best parameters
 with torch.no_grad():
-    adapter.bias = best_S
+    adapter.bias = best_bias
     logits = spdim_forward(
         underlying_model,
         X_spd_target,
@@ -735,8 +716,8 @@ def forward(self, input):
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # The Recentering Transform (RCT) :cite:p:`zanini2017transfer` computes
-# the Frechet mean of the target domain and uses it to center the SPD
-# features. Under **label shift**, the Frechet mean is biased toward
+# the Fréchet mean of the target domain and uses it to center the SPD
+# features. Under **label shift**, the Fréchet mean is biased toward
 # the over-represented class (here, *feet* at 83% of samples).
 #
 # As predicted by **Proposition 2** of the paper, this biased mean
@@ -754,8 +735,6 @@ def forward(self, input):
 # - **Best-model tracking**: Returns the parameter with lowest IM loss.
 # - **Test-time BN**: Only geodesic transport (no dispersion
 #   normalization), matching the original SPDIM test-time pipeline.
-# - **Float64 for adaptation**: SPD features are cast to float64 for
-#   numerical stability in eigendecompositions and matrix logarithms.
 #
 # References
 # ----------
