|
3 | 3 | from types import MethodType |
4 | 4 | from typing import TYPE_CHECKING |
5 | 5 |
|
| 6 | +import numpy as np |
6 | 7 | import xarray as xr |
7 | 8 | from xarray.plot.accessor import DataArrayPlotAccessor |
8 | 9 |
|
|
11 | 12 | if TYPE_CHECKING: |
12 | 13 | from matplotlib.animation import FuncAnimation |
13 | 14 |
|
14 | | -import numpy as np |
15 | | -from scipy.interpolate import RegularGridInterpolator |
16 | | - |
17 | | -def resize_ndarray( |
| 15 | + |
| 16 | +def _resize_ndarray( |
18 | 17 | arr: np.ndarray, |
19 | 18 | new_shape: tuple | list | np.ndarray, |
20 | 19 | ) -> np.ndarray: |
21 | | - new_shape = list(new_shape) |
| 20 | + |
| 21 | + from scipy.interpolate import RegularGridInterpolator # noqa: PLC0415 |
22 | 22 |
|
23 | 23 | if arr.ndim != len(new_shape): |
24 | | - raise ValueError("The number of dimensions in new_shape must match the input array.") |
| 24 | + raise ValueError( |
| 25 | + f"The number of dimensions must match the input array. (original: {arr.ndim}, new: {len(new_shape)})" |
| 26 | + ) |
25 | 27 |
|
26 | 28 | old_grids = tuple(np.linspace(0, 1, size) for size in arr.shape) |
27 | | - interpolator = RegularGridInterpolator(old_grids, arr, bounds_error=False, fill_value=0) |
28 | 29 | new_grids = tuple(np.linspace(0, 1, size) for size in new_shape) |
29 | | - mesh = np.meshgrid(*new_grids, indexing='ij') |
| 30 | + mesh = np.meshgrid(*new_grids, indexing="ij") |
30 | 31 | coords = np.stack(mesh, axis=-1) |
31 | | - output = interpolator(coords) |
32 | | - |
33 | | - return output |
34 | | - |
35 | | -def resize_dataarray( |
36 | | - da: xr.DataArray, |
37 | | - new_shape: tuple | list | np.ndarray, |
38 | | - ) -> xr.DataArray: |
39 | | - |
40 | | - resized_data = resize_ndarray(da.values, new_shape) |
41 | | - |
42 | | - da_resized = da.copy() |
43 | 32 |
|
44 | | - da_resized = xr.DataArray( |
45 | | - data=resized_data, |
46 | | - dims=da.dims, |
47 | | - attrs=da.attrs, |
| 33 | + return RegularGridInterpolator(old_grids, arr, bounds_error=False, fill_value=0)( |
| 34 | + coords |
48 | 35 | ) |
49 | 36 |
|
50 | | - shape = da_resized.shape |
51 | | - for i in range(len(da_resized.dims)): |
52 | | - coord = list(da_resized.dims)[i] |
53 | | - da_resized[coord] = resize_ndarray(da[coord], [shape[i]]) |
54 | | - da_resized[coord].attrs = da[coord].attrs |
55 | | - |
56 | | - da_resized.attrs["original_shape"] = da.shape |
57 | | - |
58 | | - return da_resized |
59 | 37 |
|
60 | 38 | @xr.register_dataarray_accessor("epoch") |
61 | 39 | class EpochAccessor: |
@@ -117,3 +95,31 @@ def animate(self, *args, **kwargs) -> FuncAnimation: |
117 | 95 | anim.show = MethodType(show, anim) |
118 | 96 |
|
119 | 97 | return anim |
| 98 | + |
| 99 | + def resize( |
| 100 | + self, |
| 101 | + new_shape: tuple | list | np.ndarray, |
| 102 | + ) -> xr.DataArray: |
| 103 | + |
| 104 | + da = self._obj |
| 105 | + # Create a copy of the existing dataarray so that we can copy over the |
| 106 | + # original dims, attrs and shape |
| 107 | + da_resized = da.copy() |
| 108 | + |
| 109 | + # Resize the dataarray's data, either via upsampling or downsampling |
| 110 | + resized_data = _resize_ndarray(da.values, new_shape) |
| 111 | + |
| 112 | + da_resized = xr.DataArray( |
| 113 | + data=resized_data, |
| 114 | + dims=da.dims, |
| 115 | + attrs=da.attrs, |
| 116 | + ) |
| 117 | + # Add a new attr containing the original shape |
| 118 | + da_resized.attrs["original_shape"] = da.shape |
| 119 | + |
| 120 | + # Resize the dataarray's underlying dimensions with their new shapes |
| 121 | + for coord_name, shape in zip(da_resized.dims, da_resized.shape): |
| 122 | + da_resized[coord_name] = _resize_ndarray(da[coord_name], [shape]) |
| 123 | + da_resized[coord_name].attrs = da[coord_name].attrs |
| 124 | + |
| 125 | + return da_resized |
0 commit comments