@@ -704,16 +704,16 @@ def local_elements(self):
704704 # def mode_mask(self):
705705 # return reduce()
706706
707- def load_from_hdf5 (self , file , index , task = None ):
707+ def load_from_hdf5 (self , file , index , task = None , func = None ):
708708 """Load grid data from an hdf5 file. Task corresponds to field name by default."""
709709 if task is None :
710710 task = self .name
711711 dset = file ['tasks' ][task ]
712712 if not np .all (dset .attrs ['grid_space' ]):
713713 raise ValueError ("Can only load data from grid space" )
714- self .load_from_global_grid_data (dset , pre_slices = (index ,))
714+ self .load_from_global_grid_data (dset , pre_slices = (index ,), func = func )
715715
716- def load_from_global_grid_data (self , global_data , pre_slices = tuple ()):
716+ def load_from_global_grid_data (self , global_data , pre_slices = tuple (), func = None ):
717717 """Load local grid data from array-like global grid data."""
718718 dim = self .dist .dim
719719 layout = self .dist .grid_layout
@@ -724,7 +724,10 @@ def load_from_global_grid_data(self, global_data, pre_slices=tuple()):
724724 component_slices = tuple (slice (None ) for cs in self .tensorsig )
725725 spatial_slices = layout .slices (self .domain , scales )
726726 local_slices = pre_slices + component_slices + spatial_slices
727- self [layout ] = global_data [local_slices ]
727+ if func is None :
728+ self [layout ] = global_data [local_slices ]
729+ else :
730+ self [layout ] = func (global_data [local_slices ])
728731 # Change scales back to dealias scales
729732 self .change_scales (self .domain .dealias )
730733
0 commit comments