@@ -222,7 +222,9 @@ def __init__(
222222 )
223223
224224 # Precompute boundary interval index
225- self ._boundary_interval_idx = self ._compute_boundary_interval ()
225+ self .register_buffer (
226+ "_boundary_interval_idx" , self ._compute_boundary_interval ()
227+ )
226228
227229 # Precompute denominators used in derivative formulas
228230 self ._compute_derivative_denominators ()
@@ -252,7 +254,7 @@ def _compute_boundary_interval(self):
252254 idx [s ] = valid_s [- 1 , 0 ] if valid_s .numel () > 0 else 0
253255
254256 return idx
255-
257+
256258 def _compute_derivative_denominators (self ):
257259 """
258260 Precompute the denominators used in the derivatives for all orders up to
@@ -334,8 +336,10 @@ def basis(self, x, collection=False):
334336 knot_right = self .knots [range_tensor , self ._boundary_interval_idx + 1 ]
335337
336338 # Identify points at the rightmost boundary
337- at_rightmost_boundary = (x >= knot_left .unsqueeze (0 )) & torch .isclose (
338- x , knot_right .unsqueeze (0 ), rtol = 1e-8 , atol = 1e-10
339+ at_rightmost_boundary = (
340+ x .squeeze (- 1 ) >= knot_left .unsqueeze (0 )
341+ ) & torch .isclose (
342+ x .squeeze (- 1 ), knot_right .unsqueeze (0 ), rtol = 1e-8 , atol = 1e-10
339343 )
340344
341345 # Ensure the correct value is set at the rightmost boundary
@@ -408,12 +412,12 @@ def forward(self, x):
408412 out = out .squeeze (- 1 )
409413
410414 return out
411-
415+
412416 def derivative (self , x , degree ):
413417 """
414418 Compute the ``degree``-th derivative of each univariate spline at the
415- given input points.
416-
419+ given input points.
420+
417421 The output has shape ``[batch, s, o]``, where ``o`` is the output
418422 dimension of each univariate spline, unless an aggregation method is
419423 specified. If both ``s`` and ``o`` are 1, the output is aggregated
@@ -472,7 +476,7 @@ def _basis_derivative(self, x, degree):
472476
473477 # Iterate over basis orders
474478 for o in range (2 , self .order + 1 ):
475-
479+
476480 # Retrieve precomputed factors
477481 left_fac = getattr (self , f"_left_factor_order_{ o } " )
478482 right_fac = getattr (self , f"_right_factor_order_{ o } " )
@@ -640,7 +644,9 @@ def knots(self, value):
640644
641645 # Recompute boundary interval when knots change
642646 if hasattr (self , "_boundary_interval_idx" ):
643- self ._boundary_interval_idx = self ._compute_boundary_interval ()
647+ self .register_buffer (
648+ "_boundary_interval_idx" , self ._compute_boundary_interval ()
649+ )
644650
645651 # Recompute derivative denominators when knots change
646652 self ._compute_derivative_denominators ()
0 commit comments