Skip to content

Commit f3e0414

Browse files
Apply pre-commit formatting fixes (reflow long f-string calls and normalize blank lines)
1 parent 1e50bc3 commit f3e0414

3 files changed

Lines changed: 18 additions & 8 deletions

File tree

examples/howto/plot_howto_add_batchnorm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from spd_learn.modules import BiMap, LogEig, ReEig, SPDBatchNormLie
2727

28+
2829
######################################################################
2930
# Step 1: Choose Your Normalization Layer
3031
# ----------------------------------------

examples/howto/plot_howto_choose_metric.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@
4545
import torch
4646

4747
from pyriemann.datasets import make_gaussian_blobs
48+
4849
from spd_learn.modules import SPDBatchNormLie
4950

51+
5052
torch.manual_seed(42)
5153

5254
# Generate SPD data using pyriemann (2-class, 2*n_matrices total samples)
@@ -87,7 +89,9 @@
8789
_ = bn(X_bench)
8890
elapsed = (time.time() - t0) / 20
8991
timings[metric] = elapsed * 1000
90-
print(f"{metric}: {elapsed*1000:.1f} ms/batch ({n_bench}x{n_bench}, batch={batch_size})")
92+
print(
93+
f"{metric}: {elapsed * 1000:.1f} ms/batch ({n_bench}x{n_bench}, batch={batch_size})"
94+
)
9195

9296
fig, ax = plt.subplots(figsize=(6, 4))
9397
ax.bar(
@@ -177,7 +181,9 @@
177181
color=metric_colors[metric],
178182
alpha=0.8,
179183
)
180-
ax.set_title(f"{metric} Eigenvalues", fontweight="bold", color=metric_colors[metric])
184+
ax.set_title(
185+
f"{metric} Eigenvalues", fontweight="bold", color=metric_colors[metric]
186+
)
181187
ax.set_xlabel("Index")
182188

183189
plt.suptitle(
@@ -231,8 +237,7 @@
231237
ax.grid(True, alpha=0.3)
232238

233239
print(
234-
f"AIM (theta={theta}): eigval range "
235-
f"[{eigvals.min():.3f}, {eigvals.max():.3f}]"
240+
f"AIM (theta={theta}): eigval range [{eigvals.min():.3f}, {eigvals.max():.3f}]"
236241
)
237242

238243
axes[0].set_ylabel("Eigenvalue")

examples/tutorials/tutorial_05_batch_normalization.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
SPDBatchNormMeanVar,
4040
)
4141

42+
4243
torch.manual_seed(42)
4344
np.random.seed(42)
4445

@@ -224,8 +225,12 @@ def train_model(model, X_train, y_train, X_test, y_test, epochs=150, lr=5e-3):
224225
for metric, (_, accs) in liebn_results.items():
225226
epochs, vals = zip(*accs)
226227
ax2.plot(
227-
epochs, vals, "o-", label=f"LieBN ({metric})",
228-
color=colors[metric], markersize=3,
228+
epochs,
229+
vals,
230+
"o-",
231+
label=f"LieBN ({metric})",
232+
color=colors[metric],
233+
markersize=3,
229234
)
230235
ax2.set_xlabel("Epoch")
231236
ax2.set_ylabel("Test Accuracy")
@@ -328,8 +333,7 @@ def train_model(model, X_train, y_train, X_test, y_test, epochs=150, lr=5e-3):
328333
loss.backward()
329334
eigvals = torch.linalg.eigvalsh(out.detach())
330335
print(
331-
f"{metric}: min_eigval={eigvals.min():.2e}, "
332-
f"grad_norm={X_check.grad.norm():.4f}"
336+
f"{metric}: min_eigval={eigvals.min():.2e}, grad_norm={X_check.grad.norm():.4f}"
333337
)
334338
X_check.grad = None
335339

0 commit comments

Comments
 (0)