|
14 | 14 | # https://github.com/networkx/networkx/blob/ead0e65bda59862e329f2e6f1da47919c6b07ca9/networkx/drawing/tests/test_pylab.py |
15 | 15 |
|
16 | 16 | import os |
| 17 | +import sys |
17 | 18 | import unittest |
18 | 19 |
|
19 | 20 | import rustworkx |
@@ -204,3 +205,81 @@ def test_hexagonal_lattice_directed(self): |
204 | 205 | plt.close("all") |
205 | 206 | mpl_draw(graph, pos=[graph.get_node_data(n) for n in range(len(graph))]) |
206 | 207 | _save_images(plt.gcf(), "test_hexagonal_lattice_directed.png") |
| 208 | + |
| 209 | + |
| 210 | +@unittest.skipUnless(HAS_MPL, "matplotlib is required for running these tests") |
| 211 | +@unittest.skipUnless(sys.platform.startswith("linux"), "Assertion tests are Linux-only") |
| 212 | +class TestMPLDrawAssertions(unittest.TestCase): |
| 213 | + def setUp(self): |
| 214 | + plt.close("all") |
| 215 | + self.fig, self.ax = plt.subplots() |
| 216 | + |
| 217 | + def tearDown(self): |
| 218 | + plt.close("all") |
| 219 | + |
| 220 | + def test_node_count(self): |
| 221 | + graph = rustworkx.generators.empty_graph(5) |
| 222 | + mpl_draw(graph, ax=self.ax) |
| 223 | + node_collection = self.ax.collections[0] |
| 224 | + self.assertEqual(len(node_collection.get_offsets()), 5) |
| 225 | + |
| 226 | + def test_node_list_filters_nodes(self): |
| 227 | + graph = rustworkx.generators.star_graph(10) |
| 228 | + mpl_draw(graph, ax=self.ax, node_list=[0, 1, 2]) |
| 229 | + node_collection = self.ax.collections[0] |
| 230 | + self.assertEqual(len(node_collection.get_offsets()), 3) |
| 231 | + |
| 232 | + def test_node_color_applied(self): |
| 233 | + graph = rustworkx.PyGraph() |
| 234 | + graph.add_nodes_from(range(3)) |
| 235 | + mpl_draw(graph, ax=self.ax, node_color=["red", "red", "red"]) |
| 236 | + colors = self.ax.collections[0].get_facecolors() |
| 237 | + self.assertEqual(len(colors), 3) |
| 238 | + |
| 239 | + def test_node_labels_drawn(self): |
| 240 | + graph = rustworkx.PyGraph() |
| 241 | + graph.add_nodes_from(["x", "y"]) |
| 242 | + mpl_draw(graph, ax=self.ax, with_labels=True, labels=str) |
| 243 | + texts = [t.get_text() for t in self.ax.texts] |
| 244 | + self.assertIn("x", texts) |
| 245 | + self.assertIn("y", texts) |
| 246 | + |
| 247 | + def test_empty_graph_no_collections(self): |
| 248 | + graph = rustworkx.PyGraph() |
| 249 | + mpl_draw(graph, ax=self.ax) |
| 250 | + # No nodes drawn means no offset collections with data |
| 251 | + offsets = [ |
| 252 | + c for c in self.ax.collections if hasattr(c, "get_offsets") and len(c.get_offsets()) > 0 |
| 253 | + ] |
| 254 | + self.assertEqual(len(offsets), 0) |
| 255 | + |
| 256 | + def test_edge_count(self): |
| 257 | + graph = rustworkx.PyGraph() |
| 258 | + graph.add_nodes_from(range(3)) |
| 259 | + graph.add_edges_from([(0, 1, None), (1, 2, None)]) |
| 260 | + mpl_draw(graph, ax=self.ax) |
| 261 | + # Edges are drawn as LineCollection or FancyArrowPatches |
| 262 | + self.assertGreater(len(self.ax.collections) + len(self.ax.patches), 1) |
| 263 | + |
| 264 | + def test_directed_graph_produces_arrows(self): |
| 265 | + graph = rustworkx.PyDiGraph() |
| 266 | + graph.add_nodes_from(range(2)) |
| 267 | + graph.add_edge(0, 1, None) |
| 268 | + mpl_draw(graph, ax=self.ax) |
| 269 | + # Directed edges are FancyArrowPatch objects in ax.patches |
| 270 | + self.assertGreater(len(self.ax.patches), 0) |
| 271 | + |
| 272 | + def test_edge_labels_drawn(self): |
| 273 | + graph = rustworkx.PyGraph() |
| 274 | + graph.add_nodes_from(range(2)) |
| 275 | + graph.add_edge(0, 1, "myedge") |
| 276 | + mpl_draw(graph, ax=self.ax, edge_labels=str) |
| 277 | + texts = [t.get_text() for t in self.ax.texts] |
| 278 | + self.assertIn("myedge", texts) |
| 279 | + |
| 280 | + def test_node_size_zero_draws_no_visible_nodes(self): |
| 281 | + graph = rustworkx.generators.empty_graph(3) |
| 282 | + mpl_draw(graph, ax=self.ax, node_size=0) |
| 283 | + # With node_size=0, collection exists but sizes should all be 0 |
| 284 | + sizes = self.ax.collections[0].get_sizes() |
| 285 | + self.assertTrue(all(s == 0 for s in sizes)) |
0 commit comments