Skip to content

Commit 267d976

Browse files
authored
Merge pull request #401 from jcapriot/tensor_p2i
Add point2index functionality for `tensor_mesh`
2 parents 8f94934 + 11f4c2a commit 267d976

2 files changed

Lines changed: 94 additions & 1 deletion

File tree

discretize/tensor_mesh.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from discretize.base import BaseRectangularMesh, BaseTensorMesh
77
from discretize.operators import DiffOperators, InnerProducts
88
from discretize.mixins import InterfaceMixins, TensorMeshIO
9-
from discretize.utils import mkvc
9+
from discretize.utils import mkvc, as_array_n_by_dim
1010
from discretize.utils.code_utils import deprecate_property
1111

1212
from .tensor_cell import TensorCell
@@ -756,6 +756,33 @@ def cell_boundary_indices(self):
756756
indzu = self.gridCC[:, 2] == max(self.gridCC[:, 2])
757757
return indxd, indxu, indyd, indyu, indzd, indzu
758758

759+
def point2index(self, locs): # NOQA D102
760+
# Documentation inherited from discretize.base.BaseMesh
761+
762+
locs = as_array_n_by_dim(locs, self.dim)
763+
# in each dimension do a sorted search within the nodes
764+
# arrays to find the containing cell in that dimension
765+
cell_bounds = [
766+
self.nodes_x,
767+
]
768+
if self.dim > 1:
769+
cell_bounds.append(self.nodes_y)
770+
if self.dim == 3:
771+
cell_bounds.append(self.nodes_z)
772+
773+
# subtract 1 here because given the nodes [0, 1], the point 0.5 would be inserted
774+
# at index 1 to maintain the sorted list, but that corresponds to cell 0.
775+
# clipping here ensures that anything outside the mesh will return the nearest cell.
776+
multi_inds = tuple(
777+
np.clip(np.searchsorted(n, p) - 1, 0, len(n) - 2)
778+
for n, p in zip(cell_bounds, locs.T)
779+
)
780+
# and of course, we are fortran ordered in a tensor mesh.
781+
if self.dim == 1:
782+
return multi_inds[0]
783+
else:
784+
return np.ravel_multi_index(multi_inds, self.shape_cells, order="F")
785+
759786
def _repr_attributes(self):
760787
"""Represent attributes of the mesh."""
761788
attrs = {}

tests/base/test_tensor.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import numpy as np
3+
import numpy.testing as npt
34
import unittest
45
import discretize
56
from scipy.sparse.linalg import spsolve
@@ -321,5 +322,70 @@ def test_orderBackward(self):
321322
self.orderTest()
322323

323324

325+
@pytest.fixture(params=[1, 2, 3], ids=["dims-1", "dims-2", "dims-3"])
326+
def random_tensor_mesh(request):
327+
dim = request.param
328+
rng = np.random.default_rng(440122)
329+
shape = rng.integers(5, 10, dim)
330+
cell_widths = [rng.uniform(3.0, 872634.321, n) for n in shape]
331+
origin = rng.uniform(-101.031, 33.2, dim)
332+
333+
return discretize.TensorMesh(cell_widths, origin)
334+
335+
336+
def test_tensor_point2index_inside_points(random_tensor_mesh):
337+
mesh = random_tensor_mesh
338+
dim = mesh.dim
339+
m_origin = mesh.origin
340+
m_extent = np.atleast_1d(np.max(mesh.nodes, axis=0))
341+
342+
nd = 15
343+
points = np.stack(np.meshgrid(*np.linspace(m_origin, m_extent, nd).T), axis=-1)
344+
points = points.reshape((-1, dim))
345+
346+
npt.assert_array_equal(mesh.is_inside(points), True)
347+
348+
cell_inds = mesh.point2index(points)
349+
for icell, p in zip(cell_inds, points):
350+
cell = mesh[icell]
351+
c_origin, c_extent = cell.bounds.reshape((dim, 2)).T
352+
dim_test = (p >= c_origin) & (p <= c_extent)
353+
npt.assert_equal(dim_test, True)
354+
355+
356+
def test_tensor_point2index_outside_points(random_tensor_mesh):
357+
mesh = random_tensor_mesh
358+
dim = mesh.dim
359+
m_origin = mesh.origin
360+
m_extent = np.atleast_1d(np.max(mesh.nodes, axis=0))
361+
m_width = m_extent - m_origin
362+
363+
nd = 15
364+
points = np.stack(
365+
np.meshgrid(*np.linspace(m_origin - m_width * 2, m_extent + m_width * 2, nd).T),
366+
axis=-1,
367+
)
368+
points = points.reshape((-1, dim))
369+
outside_points = points[~mesh.is_inside(points)]
370+
371+
npt.assert_array_equal(mesh.is_inside(outside_points), False)
372+
373+
# manually check each point that is outside
374+
cell_inds = mesh.point2index(outside_points)
375+
for icell, p in zip(cell_inds, outside_points):
376+
cell = mesh[icell]
377+
c_origin, c_extent = cell.bounds.reshape((dim, 2)).T
378+
dim_test = np.zeros(dim, bool)
379+
for i in range(dim):
380+
p_d = p[i]
381+
if p_d < m_origin[i]:
382+
dim_test[i] = p_d < c_origin[i]
383+
elif p_d > m_extent[i]:
384+
dim_test[i] = p_d > c_extent[i]
385+
else:
386+
dim_test[i] = p_d >= c_origin[i] and p_d <= c_extent[i]
387+
npt.assert_equal(dim_test, True)
388+
389+
324390
if __name__ == "__main__":
325391
unittest.main()

0 commit comments

Comments
 (0)