1+ /* *
2+ This file is part of Set--Based Graph Library.
3+
4+ SBG Library is free software: you can redistribute it and/or modify
5+ it under the terms of the GNU General Public License as published by
6+ the Free Software Foundation, either version 3 of the License, or
7+ (at your option) any later version.
8+
9+ SBG Library is distributed in the hope that it will be useful,
10+ but WITHOUT ANY WARRANTY; without even the implied warranty of
11+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+ GNU General Public License for more details.
13+
14+ You should have received a copy of the GNU General Public License
15+ along with SBG Library. If not, see <http://www.gnu.org/licenses/>.
16+
17+ ******************************************************************************/
18+
19+ #include < unordered_map>
20+
21+ #include " communication_cost.hpp"
22+ #include " partition_graph.hpp"
23+
24+
25+ using namespace std ;
26+
27+ using namespace SBG ::LIB;
28+
29+ namespace sbg_partitioner {
30+
31+ namespace internal {
32+
33+
34+ static CommunicationCost* cost_matrix = nullptr ;
35+
36+
37+ namespace {
38+
39+ ec_ic compute_EC_IC_from_map_1_to_map_2 (
40+ const Partition& partition,
41+ const SetPiece& nodes,
42+ const PWMap& map_1,
43+ const PWMap& map_2,
44+ const SetAF& set_fact)
45+ {
46+ auto nodes_set = set_fact.createSet (nodes);
47+ auto d = map_1.preImage (nodes_set);
48+ auto im = map_2.image (d);
49+ auto partition_set = from_vector (partition, set_fact);
50+ auto ic_nodes = partition_set.intersection (im);
51+ ic_nodes = ic_nodes.difference (nodes_set);
52+ auto ec_nodes = im.difference (ic_nodes);
53+ auto ic = map_2.preImage (ic_nodes).intersection (d);
54+ auto ec = map_2.preImage (ec_nodes).intersection (d);
55+
56+ return make_pair (ec, ic);
57+ }
58+
59+
60+ }
61+
62+
63+ ec_ic compute_EC_IC (
64+ const Partition& partition,
65+ const SetPiece& nodes,
66+ const SBG::LIB::WeightedSBGraph& graph)
67+ {
68+ ec_ic cost1 = compute_EC_IC_from_map_1_to_map_2 (partition, nodes, graph.map1 (), graph.map2 (), graph.fact ());
69+ ec_ic cost2 = compute_EC_IC_from_map_1_to_map_2 (partition, nodes, graph.map2 (), graph.map1 (), graph.fact ());
70+
71+ ec_ic cost = ec_ic (cost1.first .cup (cost2.first ), cost1.second .cup (cost2.second ));
72+
73+ return cost;
74+ }
75+
76+ }
77+
78+
79+ CommunicationCost::CommunicationCost (const WeightedSBGraph& graph, PartitionMap partitions)
80+ : _graph(graph),
81+ _partitions (partitions)
82+ {
83+ initialize ();
84+ }
85+
86+
87+ void CommunicationCost::initialize ()
88+ {
89+ // compute cost by interval and partitions
90+ _cost_by_partition.reserve (_partitions.size ());
91+ _ec_cost_by_interval.reserve (_partitions.size ());
92+ _ic_cost_by_interval.reserve (_partitions.size ());
93+ for (size_t i = 0 ; i < _partitions.size (); i++) {
94+
95+ _cost_by_partition.emplace_back (make_pair (_graph.fact ().createSet (), _graph.fact ().createSet ()));
96+ _ec_cost_by_interval.emplace_back ();
97+ _ic_cost_by_interval.emplace_back ();
98+ for (const auto & node : _partitions.at (i)) {
99+ auto [ec, ic] = internal::compute_EC_IC (_partitions.at (i), node, _graph);
100+
101+ _cost_by_partition.back () = { _cost_by_partition.back ().first .cup (ec), _cost_by_partition.back ().second .cup (ic) };
102+ _ec_cost_by_interval.back ().insert ({node, ec});
103+ _ic_cost_by_interval.back ().insert ({node, ic});
104+ }
105+ }
106+ }
107+
108+
109+ void CommunicationCost::update_partitions (PartitionMap& partitions, optional<vector<size_t >> modified_partitions)
110+ {
111+ _partitions = partitions;
112+ if (modified_partitions) {
113+ Set update_nodes = _graph.fact ().createSet ();
114+ // now, update communication for partitions that were updated
115+ for (size_t i : *modified_partitions) {
116+ _ec_cost_by_interval[i].clear ();
117+ _ic_cost_by_interval[i].clear ();
118+ update_nodes = update_nodes.cup (from_vector (_partitions.at (i), _graph.fact ()));
119+ _cost_by_partition[i] = make_pair (_graph.fact ().createSet (), _graph.fact ().createSet ());
120+ for (const auto & node : _partitions.at (i)) {
121+ auto [ec, ic] = internal::compute_EC_IC (_partitions.at (i), node, _graph);
122+
123+ _cost_by_partition[i] = { _cost_by_partition.at (i).first .cup (ec), _cost_by_partition.at (i).second .cup (ic) };
124+ _ec_cost_by_interval[i].insert_or_assign (node, ec);
125+ _ic_cost_by_interval[i].insert_or_assign (node, ic);
126+ }
127+ }
128+ } else {
129+ // if modified partitions was not provided, update everything
130+ _cost_by_partition.clear ();
131+ _ec_cost_by_interval.clear ();
132+ _ic_cost_by_interval.clear ();
133+ initialize ();
134+ }
135+ }
136+
137+
138+ Set CommunicationCost::get_ec_by_partition_id (unsigned partition_id)
139+ {
140+ return _cost_by_partition[partition_id].first ;
141+ }
142+
143+
144+ Set CommunicationCost::get_ec_by_interval (unsigned partition_id, const SetPiece& nodes)
145+ {
146+ if (_ec_cost_by_interval[partition_id].find (nodes) != _ec_cost_by_interval[partition_id].end ()) {
147+ return _ec_cost_by_interval[partition_id].at (nodes);
148+ }
149+
150+ auto cost = internal::compute_EC_IC (_partitions.at (partition_id), nodes, _graph);
151+ _ec_cost_by_interval[partition_id].insert ({nodes, cost.first });
152+ _ic_cost_by_interval[partition_id].insert ({nodes, cost.second });
153+
154+ return cost.first ;
155+ }
156+
157+
158+ Set CommunicationCost::get_ic_by_interval (unsigned partition_id, const SetPiece& nodes)
159+ {
160+ if (_ic_cost_by_interval[partition_id].find (nodes) != _ic_cost_by_interval[partition_id].end ()) {
161+ return _ic_cost_by_interval[partition_id].at (nodes);
162+ }
163+
164+ auto cost = internal::compute_EC_IC (_partitions.at (partition_id), nodes, _graph);
165+ _ec_cost_by_interval[partition_id].insert ({nodes, cost.first });
166+ _ic_cost_by_interval[partition_id].insert ({nodes, cost.second });
167+
168+ return cost.second ;
169+ }
170+
171+
172+ void set_communication_cost (CommunicationCost& cost_matrix)
173+ {
174+ internal::cost_matrix = new CommunicationCost (cost_matrix);
175+ }
176+
177+ CommunicationCost& get_communication_cost ()
178+ {
179+ assert (internal::cost_matrix);
180+ return *internal::cost_matrix;
181+ }
182+
183+ }
0 commit comments