1414 You should have received a copy of the GNU General Public License
1515 along with SBG Library. If not, see <http://www.gnu.org/licenses/>.
1616
17- ******************************************************************************/
17+ ******************************************************************************/
1818
1919#include < unordered_map>
2020
@@ -28,57 +28,41 @@ using namespace SBG::LIB;
2828
2929namespace sbg_partitioner {
3030
31+ unordered_map<SetPiece, Set, SetPieceHash> CommunicationCost::_communication_by_set_piece = {};
32+
3133namespace internal {
3234
3335
34- static CommunicationCost* cost_matrix = nullptr ;
36+ // the only real instance
37+ CommunicationCostPtr cost_matrix = nullptr ;
3538
3639
3740namespace {
3841
39- ec_ic compute_EC_IC_from_map_1_to_map_2 (
40- const Partition& partition,
41- const SetPiece& nodes,
42- const PWMap& map_1,
43- const PWMap& map_2,
44- const SetAF& set_fact)
42+ Set set_piece_communication (const SetPiece& nodes, const WeightedSBGraph& graph)
4543{
46- auto nodes_set = set_fact.createSet (nodes);
47- auto d = map_1.preImage (nodes_set);
48- auto im = map_2.image (d);
49- auto partition_set = from_vector (partition, set_fact);
50- auto ic_nodes = partition_set.intersection (im);
51- ic_nodes = ic_nodes.difference (nodes_set);
52- auto ec_nodes = im.difference (ic_nodes);
53- auto ic = map_2.preImage (ic_nodes).intersection (d);
54- auto ec = map_2.preImage (ec_nodes).intersection (d);
44+ // convert nodes into a set
45+ auto node_set = graph.fact ().createSet (nodes);
5546
56- return make_pair (ec, ic);
57- }
47+ // compute preImage of map1 and map2 to get the edges that connects `nodes`
48+ auto edges_map1 = graph.map1 ().preImage (node_set);
49+ auto edges_map2 = graph.map2 ().preImage (node_set);
5850
51+ // Now compute the disjoint union to remove loop edges
52+ auto communication = edges_map1.cup (edges_map2).difference (edges_map1.intersection (edges_map2));
5953
54+ return communication;
6055}
6156
62-
63- ec_ic compute_EC_IC (
64- const Partition& partition,
65- const SetPiece& nodes,
66- const SBG::LIB::WeightedSBGraph& graph)
67- {
68- ec_ic cost1 = compute_EC_IC_from_map_1_to_map_2 (partition, nodes, graph.map1 (), graph.map2 (), graph.fact ());
69- ec_ic cost2 = compute_EC_IC_from_map_1_to_map_2 (partition, nodes, graph.map2 (), graph.map1 (), graph.fact ());
70-
71- ec_ic cost = ec_ic (cost1.first .cup (cost2.first ), cost1.second .cup (cost2.second ));
72-
73- return cost;
7457}
7558
7659}
7760
7861
7962CommunicationCost::CommunicationCost (const WeightedSBGraph& graph, PartitionMap partitions)
80- : _graph(graph),
81- _partitions (partitions)
63+ : ICommunicationCost(),
64+ _graph (graph),
65+ _partitions(partitions)
8266{
8367 initialize ();
8468}
@@ -91,40 +75,53 @@ void CommunicationCost::initialize()
9175 _ec_cost_by_interval.reserve (_partitions.size ());
9276 _ic_cost_by_interval.reserve (_partitions.size ());
9377 for (size_t i = 0 ; i < _partitions.size (); i++) {
78+ Set partition_i_communication = _graph.fact ().createSet ();
79+ Set internal_communication_partition_i = _graph.fact ().createSet ();
9480
95- _cost_by_partition.emplace_back (make_pair (_graph.fact ().createSet (), _graph.fact ().createSet ()));
96- _ec_cost_by_interval.emplace_back ();
97- _ic_cost_by_interval.emplace_back ();
9881 for (const auto & node : _partitions.at (i)) {
99- auto [ec, ic] = internal::compute_EC_IC (_partitions.at (i), node, _graph);
82+ if (_communication_by_set_piece.find (node) == _communication_by_set_piece.end ()) {
83+ _communication_by_set_piece.insert ({node, internal::set_piece_communication (node, _graph)});
84+ }
10085
101- _cost_by_partition. back () = { _cost_by_partition. back (). first . cup (ec), _cost_by_partition. back (). second . cup (ic) } ;
102- _ec_cost_by_interval. back (). insert ({node, ec} );
103- _ic_cost_by_interval. back (). insert ({node, ic});
86+ const auto & node_edges = _communication_by_set_piece. at (node) ;
87+ internal_communication_partition_i = node_edges. intersection (partition_i_communication). cup (internal_communication_partition_i );
88+ partition_i_communication = partition_i_communication. cup (node_edges);
10489 }
90+
91+ auto ec_parition_i = partition_i_communication.difference (internal_communication_partition_i);
92+ _cost_by_partition.emplace_back (make_pair (ec_parition_i, move (internal_communication_partition_i)));
93+
94+ _ic_cost_by_interval.emplace_back (); // save space for this, will be filled on demand
95+ _ec_cost_by_interval.emplace_back ();
10596 }
10697}
10798
10899
109- void CommunicationCost::update_partitions (PartitionMap& partitions, optional<vector< size_t >> modified_partitions)
100+ void CommunicationCost::update_partitions (PartitionMap& partitions, optional<reference_wrapper< const list< size_t > >> modified_partitions)
110101{
111102 _partitions = partitions;
112103 if (modified_partitions) {
113- Set update_nodes = _graph.fact ().createSet ();
114104 // now, update communication for partitions that were updated
115- for (size_t i : *modified_partitions) {
116- _ec_cost_by_interval[i].clear ();
117- _ic_cost_by_interval[i].clear ();
118- update_nodes = update_nodes.cup (from_vector (_partitions.at (i), _graph.fact ()));
119- _cost_by_partition[i] = make_pair (_graph.fact ().createSet (), _graph.fact ().createSet ());
105+ for (size_t i : modified_partitions->get ()) {
106+ Set partition_i_communication = _graph.fact ().createSet ();
107+ Set internal_communication_partition_i = _graph.fact ().createSet ();
108+
120109 for (const auto & node : _partitions.at (i)) {
121- auto [ec, ic] = internal::compute_EC_IC (_partitions.at (i), node, _graph);
110+ if (_communication_by_set_piece.find (node) == _communication_by_set_piece.end ()) {
111+ _communication_by_set_piece.insert ({node, internal::set_piece_communication (node, _graph)});
112+ }
122113
123- _cost_by_partition[i] = { _cost_by_partition .at (i). first . cup (ec), _cost_by_partition. at (i). second . cup (ic) } ;
124- _ec_cost_by_interval[i]. insert_or_assign (node, ec );
125- _ic_cost_by_interval[i]. insert_or_assign (node, ic );
114+ const auto & node_edges = _communication_by_set_piece .at (node) ;
115+ internal_communication_partition_i = node_edges. intersection (partition_i_communication). cup (internal_communication_partition_i );
116+ partition_i_communication = partition_i_communication. cup (node_edges );
126117 }
127- }
118+
119+ auto ec_parition_i = partition_i_communication.difference (internal_communication_partition_i);
120+ _cost_by_partition[i] = (make_pair (ec_parition_i, move (internal_communication_partition_i)));
121+
122+ _ic_cost_by_interval[i].clear ();
123+ _ec_cost_by_interval[i].clear ();
124+ }
128125 } else {
129126 // if modified partitions was not provided, update everything
130127 _cost_by_partition.clear ();
@@ -141,17 +138,36 @@ Set CommunicationCost::get_ec_by_partition_id(unsigned partition_id)
141138}
142139
143140
141+ pair<Set, Set> CommunicationCost::compute_ec_ic (unsigned partition_id, const SetPiece& nodes)
142+ {
143+ if (_communication_by_set_piece.find (nodes) == _communication_by_set_piece.end ()) {
144+ _communication_by_set_piece.insert ({nodes, internal::set_piece_communication (nodes, _graph)});
145+ }
146+
147+ auto communication = _communication_by_set_piece.at (nodes);
148+
149+ auto ec = communication.intersection (_cost_by_partition[partition_id].first );
150+ auto ic = communication.difference (ec);
151+ _ec_cost_by_interval[partition_id].insert ({nodes, ec});
152+ _ic_cost_by_interval[partition_id].insert ({nodes, ic});
153+
154+ return { ec, ic };
155+ }
156+
157+
144158Set CommunicationCost::get_ec_by_interval (unsigned partition_id, const SetPiece& nodes)
145159{
146160 if (_ec_cost_by_interval[partition_id].find (nodes) != _ec_cost_by_interval[partition_id].end ()) {
147161 return _ec_cost_by_interval[partition_id].at (nodes);
148162 }
149163
150- auto cost = internal::compute_EC_IC (_partitions.at (partition_id), nodes, _graph);
151- _ec_cost_by_interval[partition_id].insert ({nodes, cost.first });
152- _ic_cost_by_interval[partition_id].insert ({nodes, cost.second });
164+ if (_communication_by_set_piece.find (nodes) == _communication_by_set_piece.end ()) {
165+ _communication_by_set_piece.insert ({nodes, internal::set_piece_communication (nodes, _graph)});
166+ }
167+
168+ auto [ec, _] = compute_ec_ic (partition_id, nodes);
153169
154- return cost. first ;
170+ return ec ;
155171}
156172
157173
@@ -161,20 +177,67 @@ Set CommunicationCost::get_ic_by_interval(unsigned partition_id, const SetPiece&
161177 return _ic_cost_by_interval[partition_id].at (nodes);
162178 }
163179
164- auto cost = internal::compute_EC_IC (_partitions.at (partition_id), nodes, _graph);
165- _ec_cost_by_interval[partition_id].insert ({nodes, cost.first });
166- _ic_cost_by_interval[partition_id].insert ({nodes, cost.second });
180+ if (_communication_by_set_piece.find (nodes) == _communication_by_set_piece.end ()) {
181+ _communication_by_set_piece.insert ({nodes, internal::set_piece_communication (nodes, _graph)});
182+ }
183+
184+ auto [_, ic] = compute_ec_ic (partition_id, nodes);
185+
186+ return ic;
187+ }
188+
189+
190+
191+ CommunicationCostSync::CommunicationCostSync (const WeightedSBGraph& graph, PartitionMap partitions)
192+ :ICommunicationCost(),
193+ _comm_cost(graph, partitions)
194+ {}
195+
196+
197+ void CommunicationCostSync::update_partitions (PartitionMap& partitions, optional<reference_wrapper<const list<size_t >>> modified_partitions)
198+ {
199+ const lock_guard<mutex> lock (_mutex);
200+ _comm_cost.update_partitions (partitions, modified_partitions);
201+ }
202+
203+
204+ Set CommunicationCostSync::get_ec_by_partition_id (unsigned partition_id)
205+ {
206+ const lock_guard<mutex> lock (_mutex);
207+ return _comm_cost.get_ec_by_partition_id (partition_id);
208+ }
209+
210+
211+ Set CommunicationCostSync::get_ec_by_interval (unsigned partition_id, const SetPiece& nodes)
212+ {
213+ const lock_guard<mutex> lock (_mutex);
214+ return _comm_cost.get_ec_by_interval (partition_id, nodes);
215+ }
216+
167217
168- return cost.second ;
218+ Set CommunicationCostSync::get_ic_by_interval (unsigned partition_id, const SetPiece& nodes)
219+ {
220+ const lock_guard<mutex> lock (_mutex);
221+ return _comm_cost.get_ic_by_interval (partition_id, nodes);
169222}
170223
171224
172- void set_communication_cost (CommunicationCost& cost_matrix)
225+
226+ CommunicationCostPtr create_communication_cost (const WeightedSBGraph& graph, PartitionMap partitions, bool multithreading_enabled)
227+ {
228+ if (multithreading_enabled) {
229+ return make_unique<CommunicationCostSync>(graph, partitions);
230+ } else {
231+ return make_unique<CommunicationCost>(graph, partitions);
232+ }
233+ }
234+
235+ void set_communication_cost (CommunicationCostPtr&& cost_matrix)
173236{
174- internal::cost_matrix = new CommunicationCost (cost_matrix);
237+ internal::cost_matrix = move (cost_matrix);
175238}
176239
177- CommunicationCost & get_communication_cost ()
240+ ICommunicationCost & get_communication_cost ()
178241{
179242 assert (internal::cost_matrix);
180243 return *internal::cost_matrix;
0 commit comments