Skip to content

Commit c1d7e26

Browse files
committed
fix output dimension for vectorized spline
1 parent a548161 commit c1d7e26

3 files changed

Lines changed: 24 additions & 7 deletions

File tree

pina/_src/model/spline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def forward(self, x):
278278
:rtype: torch.Tensor
279279
"""
280280
basis = self.basis(x.as_subclass(torch.Tensor))
281+
# print("normal forward, cp:", self.control_points)
281282
return basis @ self.control_points
282283

283284
def derivative(self, x, degree):

pina/_src/model/vectorized_spline.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ class VectorizedSpline(nn.Module):
2424

2525
def __init__(
2626
self,
27-
order: int,
28-
knots: torch.Tensor,
29-
control_points: torch.Tensor | None = None,
27+
order,
28+
knots,
29+
control_points=None,
30+
aggregate_output=None,
3031
):
3132
super().__init__()
3233
if not isinstance(order, int) or order <= 0:
@@ -68,6 +69,7 @@ def __init__(
6869
# f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}."
6970
# )
7071
self.control_points = nn.Parameter(control_points, requires_grad=True)
72+
self.aggregate_output = aggregate_output
7173

7274
@staticmethod
7375
def _compute_boundary_interval_idx(knots: torch.Tensor) -> int:
@@ -90,7 +92,8 @@ def basis(self, x: torch.Tensor) -> torch.Tensor:
9092
x = torch.as_tensor(x)
9193

9294
# ensure float dtype consistent
93-
x = x.to(dtype=self.knots.dtype, device=self.knots.device)
95+
# x = x.to(dtype=self.knots.dtype, device=self.knots.device)
96+
x = x.to(dtype=self.knots.dtype, device=self.knots.device).as_subclass(torch.Tensor)
9497

9598
# make x shape (..., 1) for broadcasting
9699
x_exp = x.unsqueeze(-1) # (..., 1)
@@ -147,11 +150,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
147150
B = self.basis(x) # (..., n_ctrl)
148151

149152
cp = self.control_points
153+
# print("vectorized forward, cp:", cp)
150154
if cp.ndim == 2:
151155
# (S, n_ctrl)
152156
# want (..., S) = (..., n_ctrl) @ (n_ctrl, S)
157+
# print('B shape:', B.shape, 'cp shape:', cp.shape)
158+
#out = (B @ cp.transpose(0, 1)).squeeze(-1)
153159
out = B @ cp.transpose(0, 1)
154-
return out
160+
# out = B @ cp[0]
155161
else:
156162
# (S, O, n_ctrl)
157163
# Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S
@@ -160,7 +166,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
160166
# out = torch.einsum("...n, son -> ...so", B, cp)
161167
out = torch.einsum("bsc,sco->bso", B, cp)
162168

163-
return out
169+
if self.aggregate_output == "mean":
170+
out = out.mean(dim=-1) # aggregate over O dimension if present
171+
elif self.aggregate_output == "sum":
172+
out = out.sum(dim=-1)
173+
174+
# print("vectorized forward, out:", out.shape)
175+
176+
return out
164177

165178
def forward_basis(self, basis):
166179
"""

tests/test_model/test_spline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ def test_vectorized(args, N):
217217
result_single = torch.stack([
218218
splines[i](x) for i in range(N)
219219
])
220-
result_single = result_single.permute(1, 2, 0)
220+
result_single = result_single.permute(1, 2, 0) # shape (100, N)
221221
out_vectorized = vectorized_spline(x)
222+
print("result single shape:", result_single.shape)
223+
print("out vectorized shape:", out_vectorized.shape)
224+
assert out_vectorized.shape == (100, 1, N)
222225
assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5)

0 commit comments

Comments
 (0)