Skip to content

Commit e6b3a87

Browse files
authored
Merge pull request #367 from simpeg/tree-mesh-cell-bounds
Add `TreeCell.bounds` and `TreeMesh.cell_bounds` methods
2 parents 87b9300 + fa45e02 commit e6b3a87

3 files changed

Lines changed: 173 additions & 61 deletions

File tree

discretize/_extensions/tree.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ class Cell{
123123
double volume;
124124

125125
Cell();
126-
Cell(Node *pts[4], int_t ndim, int_t maxlevel);//, function func);
127-
Cell(Node *pts[4], Cell *parent);
126+
Cell(Node *pts[8], int_t ndim, int_t maxlevel);//, function func);
127+
Cell(Node *pts[8], Cell *parent);
128128
~Cell();
129129

130130
inline Node* min_node(){ return points[0];};

discretize/_extensions/tree_ext.pyx

Lines changed: 113 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,6 @@ cdef class TreeCell:
5555
cdef void _set(self, c_Cell* cell):
5656
self._cell = cell
5757
self._dim = cell.n_dim
58-
cdef:
59-
Node *min_n = cell.min_node()
60-
Node *max_n = cell.max_node()
61-
self._x = self._cell.location[0]
62-
self._x0 = min_n.location[0]
63-
64-
self._y = self._cell.location[1]
65-
self._y0 = min_n.location[1]
66-
67-
self._wx = max_n.location[0] - self._x0
68-
self._wy = max_n.location[1] - self._y0
69-
if(self._dim > 2):
70-
self._z = self._cell.location[2]
71-
self._z0 = min_n.location[2]
72-
self._wz = max_n.location[2] - self._z0
7358

7459
@property
7560
def nodes(self):
@@ -150,8 +135,10 @@ cdef class TreeCell:
150135
(dim) numpy.ndarray
151136
Cell center location for the tree cell
152137
"""
153-
if self._dim == 2: return np.array([self._x, self._y])
154-
return np.array([self._x, self._y, self._z])
138+
loc = self._cell.location
139+
if self._dim == 2:
140+
return np.array([loc[0], loc[1]])
141+
return np.array([loc[0], loc[1], loc[2]])
155142

156143
@property
157144
def origin(self):
@@ -166,8 +153,10 @@ cdef class TreeCell:
166153
(dim) numpy.ndarray
167154
Origin location ('anchor point') for the tree cell
168155
"""
169-
if self._dim == 2: return np.array([self._x0, self._y0])
170-
return np.array([self._x0, self._y0, self._z0])
156+
loc = self._cell.min_node().location
157+
if self._dim == 2:
158+
return np.array([loc[0], loc[1]])
159+
return np.array([loc[0], loc[1], loc[2]])
171160

172161
@property
173162
def x0(self):
@@ -196,8 +185,19 @@ cdef class TreeCell:
196185
(dim) numpy.ndarray
197186
Cell dimension along each axis direction
198187
"""
199-
if self._dim == 2: return np.array([self._wx, self._wy])
200-
return np.array([self._wx, self._wy, self._wz])
188+
loc_min = self._cell.min_node().location
189+
loc_max = self._cell.max_node().location
190+
191+
if self._dim == 2:
192+
return np.array([
193+
loc_max[0] - loc_min[0],
194+
loc_max[1] - loc_min[1],
195+
])
196+
return np.array([
197+
loc_max[0] - loc_min[0],
198+
loc_max[1] - loc_min[1],
199+
loc_max[2] - loc_min[2],
200+
])
201201

202202
@property
203203
def dim(self):
@@ -221,6 +221,43 @@ cdef class TreeCell:
221221
"""
222222
return self._cell.index
223223

224+
@property
225+
def bounds(self):
226+
"""
227+
Bounds of the cell.
228+
229+
Coordinates that define the bounds of the cell. Bounds are returned in
230+
the following order: ``x0``, ``x1``, ``y0``, ``y1``, ``z0``, ``z1``.
231+
232+
Returns
233+
-------
234+
bounds : (2 * dim) array
235+
Array with the cell bounds.
236+
"""
237+
loc_min = self._cell.min_node().location
238+
loc_max = self._cell.max_node().location
239+
240+
if self.dim == 2:
241+
return np.array(
242+
[
243+
loc_min[0],
244+
loc_max[0],
245+
loc_min[1],
246+
loc_max[1],
247+
]
248+
)
249+
return np.array(
250+
[
251+
loc_min[0],
252+
loc_max[0],
253+
loc_min[1],
254+
loc_max[1],
255+
loc_min[2],
256+
loc_max[2],
257+
]
258+
)
259+
260+
224261
@property
225262
def neighbors(self):
226263
"""Indices for this cell's neighbors within its parent tree mesh.
@@ -242,63 +279,64 @@ cdef class TreeCell:
242279
neighbors = [-1]*self._dim*2
243280

244281
for i in range(self._dim*2):
245-
if self._cell.neighbors[i] is NULL:
282+
neighbor = self._cell.neighbors[i]
283+
if neighbor is NULL:
246284
continue
247-
elif self._cell.neighbors[i].is_leaf():
248-
neighbors[i] = self._cell.neighbors[i].index
285+
elif neighbor.is_leaf():
286+
neighbors[i] = neighbor.index
249287
else:
250288
if self._dim==2:
251289
if i==0:
252-
neighbors[i] = [self._cell.neighbors[i].children[1].index,
253-
self._cell.neighbors[i].children[3].index]
290+
neighbors[i] = [neighbor.children[1].index,
291+
neighbor.children[3].index]
254292
elif i==1:
255-
neighbors[i] = [self._cell.neighbors[i].children[0].index,
256-
self._cell.neighbors[i].children[2].index]
293+
neighbors[i] = [neighbor.children[0].index,
294+
neighbor.children[2].index]
257295
elif i==2:
258-
neighbors[i] = [self._cell.neighbors[i].children[2].index,
259-
self._cell.neighbors[i].children[3].index]
296+
neighbors[i] = [neighbor.children[2].index,
297+
neighbor.children[3].index]
260298
else:
261-
neighbors[i] = [self._cell.neighbors[i].children[0].index,
262-
self._cell.neighbors[i].children[1].index]
299+
neighbors[i] = [neighbor.children[0].index,
300+
neighbor.children[1].index]
263301
else:
264302
if i==0:
265-
neighbors[i] = [self._cell.neighbors[i].children[1].index,
266-
self._cell.neighbors[i].children[3].index,
267-
self._cell.neighbors[i].children[5].index,
268-
self._cell.neighbors[i].children[7].index]
303+
neighbors[i] = [neighbor.children[1].index,
304+
neighbor.children[3].index,
305+
neighbor.children[5].index,
306+
neighbor.children[7].index]
269307
elif i==1:
270-
neighbors[i] = [self._cell.neighbors[i].children[0].index,
271-
self._cell.neighbors[i].children[2].index,
272-
self._cell.neighbors[i].children[4].index,
273-
self._cell.neighbors[i].children[6].index]
308+
neighbors[i] = [neighbor.children[0].index,
309+
neighbor.children[2].index,
310+
neighbor.children[4].index,
311+
neighbor.children[6].index]
274312
elif i==2:
275-
neighbors[i] = [self._cell.neighbors[i].children[2].index,
276-
self._cell.neighbors[i].children[3].index,
277-
self._cell.neighbors[i].children[6].index,
278-
self._cell.neighbors[i].children[7].index]
313+
neighbors[i] = [neighbor.children[2].index,
314+
neighbor.children[3].index,
315+
neighbor.children[6].index,
316+
neighbor.children[7].index]
279317
elif i==3:
280-
neighbors[i] = [self._cell.neighbors[i].children[0].index,
281-
self._cell.neighbors[i].children[1].index,
282-
self._cell.neighbors[i].children[4].index,
283-
self._cell.neighbors[i].children[5].index]
318+
neighbors[i] = [neighbor.children[0].index,
319+
neighbor.children[1].index,
320+
neighbor.children[4].index,
321+
neighbor.children[5].index]
284322
elif i==4:
285-
neighbors[i] = [self._cell.neighbors[i].children[4].index,
286-
self._cell.neighbors[i].children[5].index,
287-
self._cell.neighbors[i].children[6].index,
288-
self._cell.neighbors[i].children[7].index]
323+
neighbors[i] = [neighbor.children[4].index,
324+
neighbor.children[5].index,
325+
neighbor.children[6].index,
326+
neighbor.children[7].index]
289327
else:
290-
neighbors[i] = [self._cell.neighbors[i].children[0].index,
291-
self._cell.neighbors[i].children[1].index,
292-
self._cell.neighbors[i].children[2].index,
293-
self._cell.neighbors[i].children[3].index]
328+
neighbors[i] = [neighbor.children[0].index,
329+
neighbor.children[1].index,
330+
neighbor.children[2].index,
331+
neighbor.children[3].index]
294332
return neighbors
295333

296334
@property
297335
def _index_loc(self):
336+
loc_ind = self._cell.location_ind
298337
if self._dim == 2:
299-
return tuple((self._cell.location_ind[0], self._cell.location_ind[1]))
300-
return tuple((self._cell.location_ind[0], self._cell.location_ind[1],
301-
self._cell.location_ind[2]))
338+
return tuple((loc_ind[0], loc_ind[1]))
339+
return tuple((loc_ind[0], loc_ind[1], loc_ind[2]))
302340

303341
@property
304342
def _level(self):
@@ -1190,6 +1228,22 @@ cdef class _TreeMesh:
11901228
"""
11911229
return self._finalized
11921230

1231+
@property
1232+
@cython.boundscheck(False)
1233+
def cell_bounds(self):
1234+
cell_bounds = np.empty((self.n_cells, self.dim, 2), dtype=np.float64)
1235+
cdef np.float64_t[:, :, ::1] cell_bounds_view = cell_bounds
1236+
1237+
for cell in self.tree.cells:
1238+
min_loc = cell.min_node().location
1239+
max_loc = cell.max_node().location
1240+
1241+
for i in range(self._dim):
1242+
cell_bounds_view[cell.index, i, 0] = min_loc[i]
1243+
cell_bounds_view[cell.index, i, 1] = max_loc[i]
1244+
1245+
return cell_bounds.reshape((self.n_cells, -1))
1246+
11931247
def number(self):
11941248
"""Number the cells, nodes, faces, and edges of the TreeMesh."""
11951249
self.tree.number()

tests/tree/test_tree.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,64 @@ def test_total_nodes(self, sample_mesh):
380380
)
381381

382382

383+
class TestTreeCellBounds:
384+
"""Test ``TreeCell.bounds`` method"""
385+
386+
@pytest.fixture(params=["2D", "3D"])
387+
def mesh(self, request):
388+
"""Return a sample TreeMesh"""
389+
nc = 16
390+
if request.param == "2D":
391+
h = [nc, nc]
392+
origin = (-32.4, 245.4)
393+
mesh = discretize.TreeMesh(h, origin)
394+
p1 = (origin[0] + 0.4, origin[1] + 0.4)
395+
p2 = (origin[0] + 0.6, origin[1] + 0.6)
396+
mesh.refine_box(p1, p2, levels=5, finalize=True)
397+
else:
398+
h = [nc, nc, nc]
399+
origin = (-32.4, 245.4, 192.3)
400+
mesh = discretize.TreeMesh(h, origin)
401+
p1 = (origin[0] + 0.4, origin[1] + 0.4, origin[2] + 0.7)
402+
p2 = (origin[0] + 0.6, origin[1] + 0.6, origin[2] + 0.9)
403+
mesh.refine_box(p1, p2, levels=5, finalize=True)
404+
return mesh
405+
406+
def test_bounds(self, mesh):
407+
"""Test bounds method of one of the cells in the mesh."""
408+
cell = mesh[16]
409+
nodes = mesh.nodes[cell.nodes]
410+
x1, x2 = nodes[0][0], nodes[-1][0]
411+
y1, y2 = nodes[0][1], nodes[-1][1]
412+
if mesh.dim == 2:
413+
expected_bounds = np.array([x1, x2, y1, y2])
414+
else:
415+
z1, z2 = nodes[0][2], nodes[-1][2]
416+
expected_bounds = np.array([x1, x2, y1, y2, z1, z2])
417+
np.testing.assert_equal(cell.bounds, expected_bounds)
418+
419+
def test_bounds_relations(self, mesh):
420+
"""Test if bounds are in the right order for one cell in the mesh."""
421+
cell = mesh[16]
422+
if mesh.dim == 2:
423+
x1, x2, y1, y2 = cell.bounds
424+
assert x1 < x2
425+
assert y1 < y2
426+
else:
427+
x1, x2, y1, y2, z1, z2 = cell.bounds
428+
assert x1 < x2
429+
assert y1 < y2
430+
assert z1 < z2
431+
432+
def test_cell_bounds(self, mesh):
433+
"""Test cell_bounds method of the tree mesh."""
434+
cell_bounds = mesh.cell_bounds
435+
cell_bounds_slow = np.empty((mesh.n_cells, 2 * mesh.dim))
436+
for i, cell in enumerate(mesh):
437+
cell_bounds_slow[i] = cell.bounds
438+
np.testing.assert_equal(cell_bounds, cell_bounds_slow)
439+
440+
383441
class Test2DInterpolation(unittest.TestCase):
384442
def setUp(self):
385443
def topo(x):

0 commit comments

Comments
 (0)