@@ -12,13 +12,19 @@ def visualize(
1212 save_path = "output.png" ,
1313 figsize = (12 , 12 ),
1414 node_size = 2000 ,
15- node_color = "lightblue" ,
15+ default_node_color = "lightblue" ,
1616 font_size = 10 ,
1717 edge_color = "black" ,
1818 title = "Directed Network Graph" ,
19+ most_active_users = None ,
20+ most_influential_users = None ,
1921 ):
2022 """
21- Generate a visual representation of the directed graph with enhanced arrows.
23+ Generate a visual representation of the directed graph with enhanced arrows and highlighted users.
24+
25+ Args:
26+ most_active_users (list): List of user IDs for most active users
27+ most_influential_users (list): List of user IDs for most influential users
2228 """
2329 # Create a directed networkx graph
2430 G = nx .DiGraph ()
@@ -38,15 +44,28 @@ def visualize(
3844 # Use spring layout with more space between nodes
3945 pos = nx .spring_layout (G , k = 2.5 , iterations = 50 )
4046
47+ # Create color map for nodes
48+ node_colors = []
49+ for node in G .nodes ():
50+ if most_active_users and node in most_active_users :
51+ if most_influential_users and node in most_influential_users :
52+ node_colors .append ("purple" ) # Both active and influential
53+ else :
54+ node_colors .append ("red" ) # Only active
55+ elif most_influential_users and node in most_influential_users :
56+ node_colors .append ("green" ) # Only influential
57+ else :
58+ node_colors .append (default_node_color )
59+
4160 # Draw the network - nodes first
4261 nx .draw_networkx_nodes (
4362 G ,
4463 pos ,
4564 node_size = node_size ,
46- node_color = node_color ,
47- edgecolors = "black" , # Add black border to nodes
65+ node_color = node_colors ,
66+ edgecolors = "black" ,
4867 linewidths = 2 ,
49- ) # Node border width
68+ )
5069
5170 # Draw edges with enhanced arrows
5271 edges = nx .draw_networkx_edges (
@@ -55,12 +74,12 @@ def visualize(
5574 edge_color = edge_color ,
5675 width = 2 ,
5776 arrows = True ,
58- arrowsize = 40 , # Larger arrows
59- arrowstyle = "->" , # Simple arrow style
60- connectionstyle = "arc3,rad=0.2" , # Curved edges
61- min_source_margin = 35 , # Space between arrow and source node
77+ arrowsize = 40 ,
78+ arrowstyle = "->" ,
79+ connectionstyle = "arc3,rad=0.2" ,
80+ min_source_margin = 35 ,
6281 min_target_margin = 35 ,
63- ) # Space between arrow and target node
82+ )
6483
6584 # Add labels with white background for better visibility
6685 labels = {user .id : f"{ user .name } \n (ID: { user .id } )" for user in self .graph .users }
@@ -72,6 +91,47 @@ def visualize(
7291 bbox = dict (facecolor = "white" , edgecolor = "none" , alpha = 0.7 , pad = 5 ),
7392 )
7493
94+ # Add legend
95+ legend_elements = []
96+ if most_active_users :
97+ legend_elements .append (
98+ plt .Line2D (
99+ [0 ],
100+ [0 ],
101+ marker = "o" ,
102+ color = "w" ,
103+ markerfacecolor = "red" ,
104+ markersize = 15 ,
105+ label = "Most Active" ,
106+ )
107+ )
108+ if most_influential_users :
109+ legend_elements .append (
110+ plt .Line2D (
111+ [0 ],
112+ [0 ],
113+ marker = "o" ,
114+ color = "w" ,
115+ markerfacecolor = "green" ,
116+ markersize = 15 ,
117+ label = "Most Influential" ,
118+ )
119+ )
120+ if most_active_users and most_influential_users :
121+ legend_elements .append (
122+ plt .Line2D (
123+ [0 ],
124+ [0 ],
125+ marker = "o" ,
126+ color = "w" ,
127+ markerfacecolor = "purple" ,
128+ markersize = 15 ,
129+ label = "Both Active & Influential" ,
130+ )
131+ )
132+ if legend_elements :
133+ plt .legend (handles = legend_elements , loc = "upper left" , bbox_to_anchor = (1 , 1 ))
134+
75135 # Add title
76136 plt .title (title , fontsize = 16 , pad = 20 )
77137
@@ -81,5 +141,7 @@ def visualize(
81141 # Add more space around the plot
82142 plt .margins (0.2 )
83143 plt .tight_layout ()
84- plt .savefig ("graph.png" if save_path is None else save_path )
144+ plt .savefig (
145+ "graph.png" if save_path is None else save_path , bbox_inches = "tight"
146+ ) # Added bbox_inches to prevent legend cutoff
85147 plt .close ()
0 commit comments