Skip to content

Commit 6caa873

Browse files
fix minor shape bug
1 parent 6eb49fb commit 6caa873

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

pina/_src/model/vectorized_spline.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def __init__(
222222
)
223223

224224
# Precompute boundary interval index
225-
self._boundary_interval_idx = self._compute_boundary_interval()
225+
self.register_buffer(
226+
"_boundary_interval_idx", self._compute_boundary_interval()
227+
)
226228

227229
# Precompute denominators used in derivative formulas
228230
self._compute_derivative_denominators()
@@ -252,7 +254,7 @@ def _compute_boundary_interval(self):
252254
idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0
253255

254256
return idx
255-
257+
256258
def _compute_derivative_denominators(self):
257259
"""
258260
Precompute the denominators used in the derivatives for all orders up to
@@ -334,8 +336,10 @@ def basis(self, x, collection=False):
334336
knot_right = self.knots[range_tensor, self._boundary_interval_idx + 1]
335337

336338
# Identify points at the rightmost boundary
337-
at_rightmost_boundary = (x >= knot_left.unsqueeze(0)) & torch.isclose(
338-
x, knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10
339+
at_rightmost_boundary = (
340+
x.squeeze(-1) >= knot_left.unsqueeze(0)
341+
) & torch.isclose(
342+
x.squeeze(-1), knot_right.unsqueeze(0), rtol=1e-8, atol=1e-10
339343
)
340344

341345
# Ensure the correct value is set at the rightmost boundary
@@ -408,12 +412,12 @@ def forward(self, x):
408412
out = out.squeeze(-1)
409413

410414
return out
411-
415+
412416
def derivative(self, x, degree):
413417
"""
414418
Compute the ``degree``-th derivative of each univariate spline at the
415-
given input points.
416-
419+
given input points.
420+
417421
The output has shape ``[batch, s, o]``, where ``o`` is the output
418422
dimension of each univariate spline, unless an aggregation method is
419423
specified. If both ``s`` and ``o`` are 1, the output is aggregated
@@ -472,7 +476,7 @@ def _basis_derivative(self, x, degree):
472476

473477
# Iterate over basis orders
474478
for o in range(2, self.order + 1):
475-
479+
476480
# Retrieve precomputed factors
477481
left_fac = getattr(self, f"_left_factor_order_{o}")
478482
right_fac = getattr(self, f"_right_factor_order_{o}")
@@ -640,7 +644,9 @@ def knots(self, value):
640644

641645
# Recompute boundary interval when knots change
642646
if hasattr(self, "_boundary_interval_idx"):
643-
self._boundary_interval_idx = self._compute_boundary_interval()
647+
self.register_buffer(
648+
"_boundary_interval_idx", self._compute_boundary_interval()
649+
)
644650

645651
# Recompute derivative denominators when knots change
646652
self._compute_derivative_denominators()

0 commit comments

Comments
 (0)