-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvisualize3d.py
More file actions
109 lines (87 loc) · 3.49 KB
/
visualize3d.py
File metadata and controls
109 lines (87 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import kdtree
import test_data
import helper
def knn(kdtree, point, result):
    """Show a k-d tree together with a k-nearest-neighbour query.

    Args:
        kdtree: tree whose points and splitting planes are drawn.
        point: the query point.
        result: neighbours found for ``point``; forwarded to ``_plot_knn``.
    """
    axes = _setup()
    _plot_kdtree(axes, kdtree)
    _plot_knn(axes, point, result)
    plt.show()
def tree(kdtree):
    """Show the points and splitting planes of *kdtree* in a 3-D window."""
    axes = _setup()
    _plot_kdtree(axes, kdtree)
    plt.show()
def _setup():
    """Create a new figure holding one labelled 3-D axes and return the axes.

    Returns:
        The 3-D axes attached to a freshly created figure.
    """
    # Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*"
    # and 3.8 removed the old "seaborn" name (using it raises OSError).
    # Use whichever name this installation provides; if neither exists,
    # keep the default style instead of crashing.
    for style_name in ("seaborn-v0_8", "seaborn"):
        if style_name in mpl.style.available:
            mpl.style.use(style_name)
            break
    fig = plt.figure()
    # Axes3D(fig) no longer adds itself to the figure on matplotlib >= 3.4,
    # which leaves an empty window; add_subplot with the "3d" projection
    # (registered by importing mpl_toolkits.mplot3d) attaches it properly.
    ax = fig.add_subplot(111, projection="3d")
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    return ax
# KNN
def _plot_knn(ax, point, result):
    """Draw a k-nearest-neighbour query on a 3-D axes.

    Plots the query point and every returned neighbour. If any neighbours
    were returned, it also draws a translucent sphere around the query point.
    The sphere's radius is the largest neighbour distance in ``result``.

    Args:
        ax: 3-D axes to draw on.
        point: query point; indices 0, 1, 2 are used as x, y, z.
        result: sequence of entries whose element 0 is a neighbour point
            and whose element 1 is its distance from ``point``.
    """
    _plot_knn_point(ax, point)
    _plot_knn_result(ax, result)
    if not result:
        # No neighbours means no radius to draw; indexing result[-1] here
        # would raise IndexError.
        return
    # Use the maximum rather than result[-1] so the sphere encloses every
    # neighbour even if the result is not sorted ascending by distance.
    # NOTE(review): the sphere is only geometrically right if these
    # distances are Euclidean (not squared) — confirm against kdtree.knn.
    radius = max(entry[1] for entry in result)
    _plot_knn_sphere(ax, point, radius)
def _plot_knn_point(ax, point):
    """Mark the query point with a large red-faced "X" marker.

    Only indices 0, 1 and 2 of *point* are used (as x, y, z).
    """
    x, y, z = point[0], point[1], point[2]
    ax.plot(
        [x], [y], [z],
        marker="X",
        markersize=20,
        markerfacecolor="xkcd:red",
    )
def _plot_knn_result(ax, point_dist_list):
    """Scatter the neighbour points as pink crosses.

    Args:
        ax: 3-D axes to draw on.
        point_dist_list: entries whose element 0 is a neighbour point; only
            that point's first three coordinates are plotted.
    """
    xs, ys, zs = [], [], []
    for entry in point_dist_list:
        coords = entry[0]
        xs.append(coords[0])
        ys.append(coords[1])
        zs.append(coords[2])
    ax.scatter(
        xs, ys, zs,
        s=100, linewidth=4, marker="x", color="xkcd:pink", alpha=1,
    )
def _plot_knn_sphere(ax, point, radius):
    """Draw a translucent blue sphere of *radius* centred on *point*.

    The surface is sampled on a 100 x 100 grid of azimuth (0..2*pi) and
    polar (0..pi) angles.
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sin_polar = np.sin(polar)
    x = radius * np.outer(np.cos(azimuth), sin_polar) + point[0]
    y = radius * np.outer(np.sin(azimuth), sin_polar) + point[1]
    z = radius * np.outer(np.ones_like(azimuth), np.cos(polar)) + point[2]
    ax.plot_surface(x, y, z, color='b', rstride=4, cstride=4, alpha=0.2)
# KDTree
def _plot_kdtree(ax, kdtree):
    """Draw the tree's points first, then its splitting planes."""
    for painter in (_plot_points, _plot_planes):
        painter(ax, kdtree)
def _plot_points(ax, kdtree):
    """Scatter every point in ``kdtree.point_list`` as a yellow dot.

    Only the first three coordinates of each point are plotted.
    """
    points = kdtree.point_list
    # One list per coordinate: xs from index 0, ys from 1, zs from 2.
    xs, ys, zs = ([p[axis] for p in points] for axis in range(3))
    ax.scatter(xs, ys, zs, s=100, linewidth=4, marker=".", color="xkcd:yellow", alpha=1)
def _plot_planes(ax, kdtree):
    """Draw the splitting plane of every internal node, starting at the root."""
    root, num_dims = kdtree.root, kdtree.num_dims
    _plot_planes_helper(ax, root, num_dims)
def _plot_planes_helper(ax, node, num_dims):
    """Draw the splitting plane of *node* and of every internal descendant.

    Nodes are visited in pre-order: the node, then its left subtree, then
    its right subtree. That order fixes the sequence of ``_plot_plane``
    calls, and therefore which colours the surfaces take from the axes'
    colour cycle.

    Missing (falsy) nodes are skipped, as are leaves (no left and no right
    child). An explicit stack replaces recursion.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if not current or not (current.left or current.right):
            continue
        _plot_plane(ax, current, num_dims)
        # Push right before left so the whole left subtree is drawn first.
        pending.append(current.right)
        pending.append(current.left)
def _plot_plane(ax, node, num_dims, default_plane_width=10, num_samples=10):
    """Draw the axis-aligned splitting plane of *node* as a 3-D surface.

    The plane sits at ``node.data[node.axis]`` along ``node.axis``. It spans
    the two other axes over the node's cell bounds, as returned by
    ``helper._boundaries``; those are presumably one ``(low, high)`` pair per
    dimension with ``None`` for unbounded sides (TODO confirm). Each bound
    pair is turned into sample values by ``_dim_range``. Assumes a 3-D tree
    (``node.axis`` in 0..2).

    Args:
        ax: 3-D axes to draw on.
        node: internal tree node exposing ``data`` and ``axis``.
        num_dims: dimensionality forwarded to ``helper._boundaries``.
        default_plane_width: extent used for unbounded sides.
        num_samples: grid resolution along each in-plane axis.
    """
    axis = node.axis
    boundaries = helper._boundaries(node, num_dims)
    # The two in-plane axes follow the split axis cyclically:
    # x-split -> (y, z), y-split -> (z, x), z-split -> (x, y).
    child_axis = (axis + 1) % len(boundaries)
    grandchild_axis = (axis + 2) % len(boundaries)
    child_grid, grandchild_grid = np.meshgrid(
        _dim_range(boundaries[child_axis], default_plane_width, num_samples),
        _dim_range(boundaries[grandchild_axis], default_plane_width, num_samples),
    )
    # Every sample of the split coordinate is the node's split value.
    constant_grid = np.full_like(child_grid, node.data[axis])
    # Place each grid in its absolute x/y/z slot.
    xyz = [None, None, None]
    xyz[axis % 3] = constant_grid
    xyz[(axis + 1) % 3] = child_grid
    xyz[(axis + 2) % 3] = grandchild_grid
    ax.plot_surface(xyz[0], xyz[1], xyz[2], alpha=0.8)
def _dim_range(boundary, default_plane_width, num_samples):
    """Return ``num_samples`` evenly spaced values spanning one cell side.

    Args:
        boundary: ``(low, high)`` pair; either end may be ``None`` to mark an
            unbounded side.
        default_plane_width: how far an unbounded side extends. With both
            ends unbounded the range is ``[-w, w]``. With one end known, the
            unbounded end defaults to ``-w`` / ``w``, unless the known end
            already lies at or beyond that. In that case the range extends a
            further ``w`` away from the known end. This keeps the range
            non-degenerate and on the correct side of the bound.
        num_samples: number of values returned (passed to ``np.linspace``).

    Returns:
        1-D float array running from the resolved low end to the resolved
        high end.
    """
    low, high = boundary[0], boundary[1]
    width = default_plane_width

    begin = low
    if low is None:
        # e.g. (None, -20) with width 10 used to give linspace(-10, -20): an
        # inverted range that drew the plane outside the cell (x <= -20).
        begin = -width if high is None or high > -width else high - width

    end = high
    if high is None:
        end = width if low is None or low < width else low + width

    return np.linspace(begin, end, num_samples)
if __name__ == "__main__":
    # Demo: build a 3-D k-d tree from sample data, find the k nearest
    # neighbours of a random query point, and plot everything.
    # The tree is bound to `kd_tree` rather than `tree` so that the
    # module-level `tree()` plotting function is not overwritten.
    num_dims = 3
    kd_tree = kdtree.KDTree(test_data.list3d_2, num_dims)
    point = test_data.rand_point(num_dims)
    k = 7
    result = kd_tree.knn(point, k)
    knn(kd_tree, point, result)