Skip to content

Commit e130701

Browse files
authored
Fix MeshFilter.get_pandas_dataframe to handle all mesh types (#3817)
1 parent 8c24c1c commit e130701

8 files changed

Lines changed: 834 additions & 697 deletions

File tree

openmc/filter.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -972,53 +972,36 @@ def get_pandas_dataframe(self, data_size, stride, **kwargs):
972972
Returns
973973
-------
974974
pandas.DataFrame
975-
A Pandas DataFrame with three columns describing the x,y,z mesh
976-
cell indices corresponding to each filter bin. The number of rows
977-
in the DataFrame is the same as the total number of bins in the
978-
corresponding tally, with the filter bin appropriately tiled to map
979-
to the corresponding tally bins.
975+
A Pandas DataFrame with columns describing the mesh cell indices
976+
corresponding to each filter bin. Column names depend on the mesh
977+
type (e.g., x/y/z for RegularMesh, r/phi/z for CylindricalMesh,
978+
r/theta/phi for SphericalMesh, or element index for
979+
UnstructuredMesh). The number of rows in the DataFrame is the same
980+
as the total number of bins in the corresponding tally, with the
981+
filter bin appropriately tiled to map to the corresponding tally
982+
bins.
980983
981984
See also
982985
--------
983986
Tally.get_pandas_dataframe(), CrossFilter.get_pandas_dataframe()
984987
985988
"""
986-
# Initialize Pandas DataFrame
987-
df = pd.DataFrame()
988-
989989
# Initialize dictionary to build Pandas Multi-index column
990990
filter_dict = {}
991991

992992
# Append mesh ID as outermost index of multi-index
993993
mesh_key = f'mesh {self.mesh.id}'
994994

995-
# Find mesh dimensions - use 3D indices for simplicity
996-
n_dim = len(self.mesh.dimension)
997-
if n_dim == 3:
998-
nx, ny, nz = self.mesh.dimension
999-
elif n_dim == 2:
1000-
nx, ny = self.mesh.dimension
1001-
nz = 1
1002-
else:
1003-
nx = self.mesh.dimension
1004-
ny = nz = 1
1005-
1006-
# Generate multi-index sub-column for x-axis
1007-
filter_dict[mesh_key, 'x'] = _repeat_and_tile(
1008-
np.arange(1, nx + 1), stride, data_size)
1009-
1010-
# Generate multi-index sub-column for y-axis
1011-
filter_dict[mesh_key, 'y'] = _repeat_and_tile(
1012-
np.arange(1, ny + 1), nx * stride, data_size)
1013-
1014-
# Generate multi-index sub-column for z-axis
1015-
filter_dict[mesh_key, 'z'] = _repeat_and_tile(
1016-
np.arange(1, nz + 1), nx * ny * stride, data_size)
995+
# Determine index base (0-based for unstructured, 1-based otherwise)
996+
idx_start = 0 if isinstance(self.mesh, openmc.UnstructuredMesh) else 1
1017997

1018-
# Initialize a Pandas DataFrame from the mesh dictionary
1019-
df = pd.concat([df, pd.DataFrame(filter_dict)])
998+
# Generate a multi-index sub-column for each axis
999+
for label, dim_size in zip(self.mesh._axis_labels, self.mesh.dimension):
1000+
filter_dict[mesh_key, label] = _repeat_and_tile(
1001+
np.arange(idx_start, idx_start + dim_size), stride, data_size)
1002+
stride *= dim_size
10201003

1021-
return df
1004+
return pd.DataFrame(filter_dict)
10221005

10231006
def to_xml_element(self):
10241007
"""Return XML Element representing the Filter.

openmc/mesh.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,12 @@ def name(self, name: str):
236236
self._name = name
237237
else:
238238
self._name = ''
239-
239+
240240
@property
241241
@abstractmethod
242242
def lower_left(self):
243243
pass
244-
244+
245245
@property
246246
@abstractmethod
247247
def upper_right(self):
@@ -255,7 +255,7 @@ def bounding_box(self) -> openmc.BoundingBox:
255255
@abstractmethod
256256
def indices(self):
257257
pass
258-
258+
259259
@property
260260
@abstractmethod
261261
def n_elements(self):
@@ -537,6 +537,11 @@ def dimension(self):
537537
def n_dimension(self):
538538
pass
539539

540+
@property
541+
@abstractmethod
542+
def _axis_labels(self):
543+
pass
544+
540545
@property
541546
@abstractmethod
542547
def _grids(self):
@@ -636,7 +641,7 @@ def centroids(self):
636641
s0 = (slice(0, -1),)*ndim + (slice(None),)
637642
s1 = (slice(1, None),)*ndim + (slice(None),)
638643
return (vertices[s0] + vertices[s1]) / 2
639-
644+
640645
@property
641646
def n_elements(self):
642647
return np.prod(self.dimension)
@@ -995,6 +1000,10 @@ def n_dimension(self):
9951000
else:
9961001
return None
9971002

1003+
@property
1004+
def _axis_labels(self):
1005+
return ('x', 'y', 'z')[:self.n_dimension]
1006+
9981007
@property
9991008
def lower_left(self):
10001009
return self._lower_left
@@ -1475,6 +1484,10 @@ def dimension(self):
14751484
def n_dimension(self):
14761485
return 3
14771486

1487+
@property
1488+
def _axis_labels(self):
1489+
return ('x', 'y', 'z')
1490+
14781491
@property
14791492
def x_grid(self):
14801493
return self._x_grid
@@ -1709,6 +1722,10 @@ def dimension(self):
17091722
def n_dimension(self):
17101723
return 3
17111724

1725+
@property
1726+
def _axis_labels(self):
1727+
return ('r', 'phi', 'z')
1728+
17121729
@property
17131730
def origin(self):
17141731
return self._origin
@@ -2156,6 +2173,10 @@ def dimension(self):
21562173
def n_dimension(self):
21572174
return 3
21582175

2176+
@property
2177+
def _axis_labels(self):
2178+
return ('r', 'theta', 'phi')
2179+
21592180
@property
21602181
def origin(self):
21612182
return self._origin
@@ -2671,6 +2692,10 @@ def dimension(self):
26712692
def n_dimension(self):
26722693
return 3
26732694

2695+
@property
2696+
def _axis_labels(self):
2697+
return ('element_index',)
2698+
26742699
@property
26752700
@require_statepoint_data
26762701
def indices(self):

openmc/mgxs/mdgxs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -877,8 +877,8 @@ def get_pandas_dataframe(self, groups='all', nuclides='all',
877877
# energy groups such that data is from fast to thermal
878878
if self.domain_type == 'mesh':
879879
mesh_str = f'mesh {self.domain.id}'
880-
df.sort_values(by=[(mesh_str, 'x'), (mesh_str, 'y'),
881-
(mesh_str, 'z')] + columns, inplace=True)
880+
mesh_cols = [(mesh_str, label) for label in self.domain._axis_labels]
881+
df.sort_values(by=mesh_cols + columns, inplace=True)
882882
else:
883883
df.sort_values(by=[self.domain_type] + columns, inplace=True)
884884

openmc/mgxs/mgxs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,8 +2134,8 @@ def get_pandas_dataframe(self, groups='all', nuclides='all',
21342134
# energy groups such that data is from fast to thermal
21352135
if self.domain_type == 'mesh':
21362136
mesh_str = f'mesh {self.domain.id}'
2137-
df.sort_values(by=[(mesh_str, 'x'), (mesh_str, 'y'),
2138-
(mesh_str, 'z')] + columns, inplace=True)
2137+
mesh_cols = [(mesh_str, label) for label in self.domain._axis_labels]
2138+
df.sort_values(by=mesh_cols + columns, inplace=True)
21392139
else:
21402140
df.sort_values(by=[self.domain_type] + columns, inplace=True)
21412141

0 commit comments

Comments
 (0)