@@ -24,9 +24,10 @@ class VectorizedSpline(nn.Module):
2424
2525 def __init__ (
2626 self ,
27- order : int ,
28- knots : torch .Tensor ,
29- control_points : torch .Tensor | None = None ,
27+ order ,
28+ knots ,
29+ control_points = None ,
30+ aggregate_output = None ,
3031 ):
3132 super ().__init__ ()
3233 if not isinstance (order , int ) or order <= 0 :
@@ -68,6 +69,7 @@ def __init__(
6869 # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
6970 # )
7071 self .control_points = nn .Parameter (control_points , requires_grad = True )
72+ self .aggregate_output = aggregate_output
7173
7274 @staticmethod
7375 def _compute_boundary_interval_idx (knots : torch .Tensor ) -> int :
@@ -90,7 +92,8 @@ def basis(self, x: torch.Tensor) -> torch.Tensor:
9092 x = torch .as_tensor (x )
9193
9294 # ensure float dtype consistent
93- x = x .to (dtype = self .knots .dtype , device = self .knots .device )
95+ # x = x.to(dtype=self.knots.dtype, device=self.knots.device)
96+ x = x .to (dtype = self .knots .dtype , device = self .knots .device ).as_subclass (torch .Tensor )
9497
9598 # make x shape (..., 1) for broadcasting
9699 x_exp = x .unsqueeze (- 1 ) # (..., 1)
@@ -147,11 +150,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
147150 B = self .basis (x ) # (..., n_ctrl)
148151
149152 cp = self .control_points
153+ # print("vectorized forward, cp:", cp)
150154 if cp .ndim == 2 :
151155 # (S, n_ctrl)
152156 # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
157+ # print('B shape:', B.shape, 'cp shape:', cp.shape)
158+ #out = (B @ cp.transpose(0, 1)).squeeze(-1)
153159 out = B @ cp .transpose (0 , 1 )
154- return out
160+ # out = B @ cp[0]
155161 else :
156162 # (S, O, n_ctrl)
157163 # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
@@ -160,7 +166,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
160166 # out = torch.einsum("...n, son -> ...so", B, cp)
161167 out = torch .einsum ("bsc,sco->bso" , B , cp )
162168
163- return out
169+ if self .aggregate_output == "mean" :
170+ out = out .mean (dim = - 1 ) # aggregate over O dimension if present
171+ elif self .aggregate_output == "sum" :
172+ out = out .sum (dim = - 1 )
173+
174+ # print("vectorized forward, out:", out.shape)
175+
176+ return out
164177
165178 def forward_basis (self , basis ):
166179 """
0 commit comments