Skip to content

Commit e2ec4d0

Browse files
fix index mismatch and remove unused function
1 parent c1d7e26 commit e2ec4d0

1 file changed

Lines changed: 6 additions & 25 deletions

File tree

pina/_src/model/vectorized_spline.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def basis(self, x: torch.Tensor) -> torch.Tensor:
9393

9494
# ensure float dtype consistent
9595
# x = x.to(dtype=self.knots.dtype, device=self.knots.device)
96-
x = x.to(dtype=self.knots.dtype, device=self.knots.device).as_subclass(torch.Tensor)
96+
x = x.as_subclass(torch.Tensor).to(
97+
dtype=self.knots.dtype, device=self.knots.device
98+
)
9799

98100
# make x shape (..., 1) for broadcasting
99101
x_exp = x.unsqueeze(-1) # (..., 1)
@@ -155,7 +157,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
155157
# (S, n_ctrl)
156158
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
157159
# print('B shape:', B.shape, 'cp shape:', cp.shape)
158-
#out = (B @ cp.transpose(0, 1)).squeeze(-1)
160+
# out = (B @ cp.transpose(0, 1)).squeeze(-1)
159161
out = B @ cp.transpose(0, 1)
160162
# out = B @ cp[0]
161163
else:
@@ -164,34 +166,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
164166
# vectorized using einsum (yes, this one is actually appropriate)
165167
# (..., n) * (S, O, n) -> (..., S, O)
166168
# out = torch.einsum("...n, son -> ...so", B, cp)
167-
out = torch.einsum("bsc,sco->bso", B, cp)
169+
out = torch.einsum("bsc,soc->bso", B, cp)
168170

169171
if self.aggregate_output == "mean":
170172
out = out.mean(dim=-1) # aggregate over O dimension if present
171173
elif self.aggregate_output == "sum":
172174
out = out.sum(dim=-1)
173175

174176
# print("vectorized forward, out:", out.shape)
175-
176-
return out
177177

178-
def forward_basis(self, basis):
179-
"""
180-
Evaluate spline(s) given precomputed basis.
181-
182-
"""
183-
cp = self.control_points
184-
if cp.ndim == 2:
185-
# (S, n_ctrl)
186-
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
187-
out = basis @ cp.transpose(0, 1)
188-
return out
189-
else:
190-
# (S, O, n_ctrl)
191-
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
192-
# vectorized using einsum (yes, this one is actually appropriate)
193-
# (..., n) * (S, O, n) -> (..., S, O)
194-
# out = torch.einsum("...n, son -> ...so", B, cp)
195-
out = torch.einsum("bsc,sco->bso", basis, cp)
196-
197-
return out
178+
return out

0 commit comments

Comments
 (0)