@@ -93,7 +93,9 @@ def basis(self, x: torch.Tensor) -> torch.Tensor:
9393
9494 # ensure float dtype consistent
9595 # x = x.to(dtype=self.knots.dtype, device=self.knots.device)
96- x = x .to (dtype = self .knots .dtype , device = self .knots .device ).as_subclass (torch .Tensor )
96+ x = x .as_subclass (torch .Tensor ).to (
97+ dtype = self .knots .dtype , device = self .knots .device
98+ )
9799
98100 # make x shape (..., 1) for broadcasting
99101 x_exp = x .unsqueeze (- 1 ) # (..., 1)
@@ -155,7 +157,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
155157 # (S, n_ctrl)
156158 # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
157159 # print('B shape:', B.shape, 'cp shape:', cp.shape)
158- #out = (B @ cp.transpose(0, 1)).squeeze(-1)
160+ # out = (B @ cp.transpose(0, 1)).squeeze(-1)
159161 out = B @ cp .transpose (0 , 1 )
160162 # out = B @ cp[0]
161163 else :
@@ -164,34 +166,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
164166 # vectorized using einsum (yes, this one is actually appropriate)
165167 # (..., n) * (S, O, n) -> (..., S, O)
166168 # out = torch.einsum("...n, son -> ...so", B, cp)
167- out = torch .einsum ("bsc,sco ->bso" , B , cp )
169+ out = torch .einsum ("bsc,soc ->bso" , B , cp )
168170
169171 if self .aggregate_output == "mean" :
170172 out = out .mean (dim = - 1 ) # aggregate over O dimension if present
171173 elif self .aggregate_output == "sum" :
172174 out = out .sum (dim = - 1 )
173175
174176 # print("vectorized forward, out:", out.shape)
175-
176- return out
177177
178- def forward_basis (self , basis ):
179- """
180- Evaluate spline(s) given precomputed basis.
181-
182- """
183- cp = self .control_points
184- if cp .ndim == 2 :
185- # (S, n_ctrl)
186- # want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
187- out = basis @ cp .transpose (0 , 1 )
188- return out
189- else :
190- # (S, O, n_ctrl)
191- # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
192- # vectorized using einsum (yes, this one is actually appropriate)
193- # (..., n) * (S, O, n) -> (..., S, O)
194- # out = torch.einsum("...n, son -> ...so", B, cp)
195- out = torch .einsum ("bsc,sco->bso" , basis , cp )
196-
197- return out
178+ return out
0 commit comments