Skip to content

Commit e7ed771

Browse files
committed
Refactor resize_dataarray to dataarray.epoch.resize()
1 parent 5b5db10 commit e7ed771

1 file changed

Lines changed: 39 additions & 33 deletions

File tree

src/sdf_xarray/dataarray_accessor.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from types import MethodType
44
from typing import TYPE_CHECKING
55

6+
import numpy as np
67
import xarray as xr
78
from xarray.plot.accessor import DataArrayPlotAccessor
89

@@ -11,51 +12,28 @@
1112
if TYPE_CHECKING:
1213
from matplotlib.animation import FuncAnimation
1314

14-
import numpy as np
15-
from scipy.interpolate import RegularGridInterpolator
16-
17-
def resize_ndarray(
15+
16+
def _resize_ndarray(
1817
arr: np.ndarray,
1918
new_shape: tuple | list | np.ndarray,
2019
) -> np.ndarray:
21-
new_shape = list(new_shape)
20+
21+
from scipy.interpolate import RegularGridInterpolator # noqa: PLC0415
2222

2323
if arr.ndim != len(new_shape):
24-
raise ValueError("The number of dimensions in new_shape must match the input array.")
24+
raise ValueError(
25+
f"The number of dimensions must match the input array. (original: {arr.ndim}, new: {len(new_shape)})"
26+
)
2527

2628
old_grids = tuple(np.linspace(0, 1, size) for size in arr.shape)
27-
interpolator = RegularGridInterpolator(old_grids, arr, bounds_error=False, fill_value=0)
2829
new_grids = tuple(np.linspace(0, 1, size) for size in new_shape)
29-
mesh = np.meshgrid(*new_grids, indexing='ij')
30+
mesh = np.meshgrid(*new_grids, indexing="ij")
3031
coords = np.stack(mesh, axis=-1)
31-
output = interpolator(coords)
32-
33-
return output
34-
35-
def resize_dataarray(
36-
da: xr.DataArray,
37-
new_shape: tuple | list | np.ndarray,
38-
) -> xr.DataArray:
39-
40-
resized_data = resize_ndarray(da.values, new_shape)
41-
42-
da_resized = da.copy()
4332

44-
da_resized = xr.DataArray(
45-
data=resized_data,
46-
dims=da.dims,
47-
attrs=da.attrs,
33+
return RegularGridInterpolator(old_grids, arr, bounds_error=False, fill_value=0)(
34+
coords
4835
)
4936

50-
shape = da_resized.shape
51-
for i in range(len(da_resized.dims)):
52-
coord = list(da_resized.dims)[i]
53-
da_resized[coord] = resize_ndarray(da[coord], [shape[i]])
54-
da_resized[coord].attrs = da[coord].attrs
55-
56-
da_resized.attrs["original_shape"] = da.shape
57-
58-
return da_resized
5937

6038
@xr.register_dataarray_accessor("epoch")
6139
class EpochAccessor:
@@ -117,3 +95,31 @@ def animate(self, *args, **kwargs) -> FuncAnimation:
11795
anim.show = MethodType(show, anim)
11896

11997
return anim
98+
99+
def resize(
100+
self,
101+
new_shape: tuple | list | np.ndarray,
102+
) -> xr.DataArray:
103+
104+
da = self._obj
105+
# Create a copy of the existing dataarray so that we can copy over the
106+
# original dims, attrs and shape
107+
da_resized = da.copy()
108+
109+
# Resize the dataarray's data, either via upsampling or downsampling
110+
resized_data = _resize_ndarray(da.values, new_shape)
111+
112+
da_resized = xr.DataArray(
113+
data=resized_data,
114+
dims=da.dims,
115+
attrs=da.attrs,
116+
)
117+
# Add a new attr containing the original shape
118+
da_resized.attrs["original_shape"] = da.shape
119+
120+
# Resize the dataarray's underlying dimensions with their new shapes
121+
for coord_name, shape in zip(da_resized.dims, da_resized.shape):
122+
da_resized[coord_name] = _resize_ndarray(da[coord_name], [shape])
123+
da_resized[coord_name].attrs = da[coord_name].attrs
124+
125+
return da_resized

0 commit comments

Comments
 (0)