HSSPartitionTree.hpp
Go to the documentation of this file.
1 /*
2  * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3  * Regents of the University of California, through Lawrence Berkeley
4  * National Laboratory (subject to receipt of any required approvals
5  * from the U.S. Dept. of Energy). All rights reserved.
6  *
7  * If you have questions about your rights to use or distribute this
8  * software, please contact Berkeley Lab's Technology Transfer
9  * Department at TTD@lbl.gov.
10  *
11  * NOTICE. This software is owned by the U.S. Department of Energy. As
12  * such, the U.S. Government has been granted for itself and others
13  * acting on its behalf a paid-up, nonexclusive, irrevocable,
14  * worldwide license in the Software to reproduce, prepare derivative
15  * works, and perform publicly and display publicly. Beginning five
16  * (5) years after the date permission to assert copyright is obtained
17  * from the U.S. Department of Energy, and subject to any subsequent
18  * five (5) year renewals, the U.S. Government is granted for itself
19  * and others acting on its behalf a paid-up, nonexclusive,
20  * irrevocable, worldwide license in the Software to reproduce,
21  * prepare derivative works, distribute copies to the public, perform
22  * publicly and display publicly, and to permit others to do so.
23  *
24  * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25  * (Lawrence Berkeley National Lab, Computational Research
26  * Division).
27  *
28  */
33 #ifndef HSS_PARTITION_TREE_HPP
34 #define HSS_PARTITION_TREE_HPP
35 
#include <algorithm>
#include <cassert>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>
40 
41 namespace strumpack {
42  namespace HSS {
43 
66  public:
72  int size;
73 
79  std::vector<HSSPartitionTree> c;
80 
85 
93  HSSPartitionTree(int n) : size(n) {}
94 
103  void refine(int leaf_size) {
104  assert(c.empty());
105  if (size >= 2*leaf_size) {
106  c.resize(2);
107  c[0].size = size/2;
108  c[0].refine(leaf_size);
109  c[1].size = size - size/2;
110  c[1].refine(leaf_size);
111  }
112  }
113 
117  void print() const {
118  for (auto& ch : c) ch.print();
119  std::cout << size << " ";
120  }
121 
128  int nodes() const {
129  int nr_nodes = 1;
130  for (auto& ch : c)
131  nr_nodes += ch.nodes();
132  return nr_nodes;
133  }
134 
140  int levels() const {
141  int lvls = 0;
142  for (auto& ch : c)
143  lvls = std::max(lvls, ch.levels());
144  return lvls + 1;
145  }
146 
156  truncate_complete_rec(1, min_levels());
157  }
158 
174  void expand_complete(bool allow_zero_nodes) {
175  expand_complete_rec(1, levels(), allow_zero_nodes);
176  }
177 
185  void expand_complete_levels(int lvls) {
186  expand_complete_levels_rec(1, lvls);
187  }
188 
198  template<typename integer_t> static HSSPartitionTree
199  deserialize(const std::vector<integer_t>& buf) {
200  return deserialize(buf.data());
201  }
202 
212  template<typename integer_t> static HSSPartitionTree
213  deserialize(const integer_t* buf) {
215  int n = buf[0];
216  int pid = n-1;
217  t.de_serialize_rec(buf+1, buf+n+1, buf+2*n+1, pid);
218  return t;
219  }
220 
231  std::vector<int> serialize() const {
232  int n = nodes(), pid = 0;
233  std::vector<int> buf(3*n+1);
234  buf[0] = n;
235  serialize_rec(buf.data()+1, buf.data()+n+1, buf.data()+2*n+1, pid);
236  return buf;
237  }
238 
245  bool is_complete() const {
246  if (c.empty()) return true;
247  else return c[0].levels() == c[1].levels();
248  }
249 
255  template<typename integer_t>
256  std::vector<integer_t> leaf_sizes() const {
257  std::vector<integer_t> lf;
258  leaf_sizes_rec(lf);
259  return lf;
260  }
261 
269  std::pair<std::vector<int>,std::vector<int>>
270  map_from_complete_to_leafs(int lvls) const {
271  int n = (1 << lvls) - 1;
272  std::vector<int> map0(n, -1), map1(n, -1);
273  int leaf = 0;
274  complete_to_orig_rec(1, map0, map1, leaf);
275  for (int i=0; i<n; i++) {
276  if (map0[i] == -1) map0[i] = map0[(i+1)/2-1];
277  if (map1[i] == -1) map1[i] = map1[(i+1)/2-1];
278  }
279  // std::cout << "nodes=" << nodes() << " levels()=" << levels()
280  // << " lvls=" << lvls << " map0/map1 = [";
281  // for (int i=0; i<n; i++)
282  // std::cout << map0[i] << "/" << map1[i] << " ";
283  // std::cout << std::endl;
284  return {map0, map1};
285  }
286 
287  private:
288  int min_levels() const {
289  int lvls = levels();
290  for (auto& ch : c)
291  lvls = std::min(lvls, 1 + ch.min_levels());
292  return lvls;
293  }
294  void truncate_complete_rec(int lvl, int lvls) {
295  if (lvl == lvls) c.clear();
296  else
297  for (auto& ch : c)
298  ch.truncate_complete_rec(lvl+1, lvls);
299  }
300  void expand_complete_rec(int lvl, int lvls, bool allow_zero_nodes) {
301  if (c.empty()) {
302  if (lvl != lvls) {
303  c.resize(2);
304  if (allow_zero_nodes) {
305  c[0].size = size;
306  c[1].size = 0;
307  } else {
308  int l1 = 1 << (lvls - lvl - 1);
309  c[0].size = size - l1;
310  c[1].size = l1;
311  }
312  c[0].expand_complete_rec(lvl+1, lvls, allow_zero_nodes);
313  c[1].expand_complete_rec(lvl+1, lvls, allow_zero_nodes);
314  }
315  } else
316  for (auto& ch : c)
317  ch.expand_complete_rec(lvl+1, lvls, allow_zero_nodes);
318  }
319 
320  void complete_to_orig_rec
321  (int id, std::vector<int>& map0, std::vector<int>& map1,
322  int& leaf) const {
323  if (c.empty()) map0[id-1] = map1[id-1] = leaf++;
324  else {
325  c[0].complete_to_orig_rec(id*2, map0, map1, leaf);
326  c[1].complete_to_orig_rec(id*2+1, map0, map1, leaf);
327  map0[id-1] = map0[id*2-1];
328  map1[id-1] = map1[id*2];
329  }
330  }
331 
332  void expand_complete_levels_rec(int lvl, int lvls) {
333  if (c.empty()) {
334  if (lvl != lvls) {
335  c.resize(2);
336  c[0].size = size / 2;
337  c[1].size = size - size / 2;
338  c[0].expand_complete_levels_rec(lvl+1, lvls);
339  c[1].expand_complete_levels_rec(lvl+1, lvls);
340  }
341  } else
342  for (auto& ch : c)
343  ch.expand_complete_levels_rec(lvl+1, lvls);
344  }
345  template<typename integer_t>
346  void leaf_sizes_rec(std::vector<integer_t>& lf) const {
347  for (auto& ch : c)
348  ch.leaf_sizes_rec(lf);
349  if (c.empty())
350  lf.push_back(size);
351  }
352  void serialize_rec(int* sizes, int* lchild, int* rchild, int& pid) const {
353  if (!c.empty()) {
354  c[0].serialize_rec(sizes, lchild, rchild, pid);
355  auto lroot = pid;
356  c[1].serialize_rec(sizes, lchild, rchild, pid);
357  lchild[pid] = lroot-1;
358  rchild[pid] = pid-1;
359  } else lchild[pid] = rchild[pid] = -1;
360  sizes[pid++] = size;
361  }
362 
363  template<typename integer_t> void de_serialize_rec
364  (const integer_t* sizes, const integer_t* lchild,
365  const integer_t* rchild, int& pid) {
366  size = sizes[pid--];
367  if (rchild[pid+1] != -1) {
368  c.resize(2);
369  c[1].de_serialize_rec(sizes, lchild, rchild, pid);
370  c[0].de_serialize_rec(sizes, lchild, rchild, pid);
371  }
372  }
373  };
374 
375  } // end namespace HSS
376 } // end namespace strumpack
377 
378 #endif
strumpack::HSS::HSSPartitionTree::print
void print() const
Definition: HSSPartitionTree.hpp:117
strumpack::HSS::HSSPartitionTree::c
std::vector< HSSPartitionTree > c
Definition: HSSPartitionTree.hpp:79
strumpack::HSS::HSSPartitionTree
The cluster tree, or partition tree that represents the matrix partitioning of an HSS matrix.
Definition: HSSPartitionTree.hpp:65
strumpack::HSS::HSSPartitionTree::nodes
int nodes() const
Definition: HSSPartitionTree.hpp:128
strumpack::HSS::HSSPartitionTree::expand_complete
void expand_complete(bool allow_zero_nodes)
Definition: HSSPartitionTree.hpp:174
strumpack::HSS::HSSPartitionTree::size
int size
Definition: HSSPartitionTree.hpp:72
strumpack
Definition: StrumpackOptions.hpp:42
strumpack::HSS::HSSPartitionTree::leaf_sizes
std::vector< integer_t > leaf_sizes() const
Definition: HSSPartitionTree.hpp:256
strumpack::HSS::HSSPartitionTree::HSSPartitionTree
HSSPartitionTree()
Definition: HSSPartitionTree.hpp:84
strumpack::HSS::HSSPartitionTree::map_from_complete_to_leafs
std::pair< std::vector< int >, std::vector< int > > map_from_complete_to_leafs(int lvls) const
Definition: HSSPartitionTree.hpp:270
strumpack::HSS::HSSPartitionTree::levels
int levels() const
Definition: HSSPartitionTree.hpp:140
strumpack::HSS::HSSPartitionTree::refine
void refine(int leaf_size)
Definition: HSSPartitionTree.hpp:103
strumpack::HSS::HSSPartitionTree::serialize
std::vector< int > serialize() const
Definition: HSSPartitionTree.hpp:231
strumpack::CompressionType::HSS
@ HSS
strumpack::HSS::HSSPartitionTree::is_complete
bool is_complete() const
Definition: HSSPartitionTree.hpp:245
strumpack::HSS::HSSPartitionTree::expand_complete_levels
void expand_complete_levels(int lvls)
Definition: HSSPartitionTree.hpp:185
strumpack::HSS::HSSPartitionTree::HSSPartitionTree
HSSPartitionTree(int n)
Definition: HSSPartitionTree.hpp:93
strumpack::HSS::HSSPartitionTree::deserialize
static HSSPartitionTree deserialize(const std::vector< integer_t > &buf)
Definition: HSSPartitionTree.hpp:199
strumpack::HSS::HSSPartitionTree::truncate_complete
void truncate_complete()
Definition: HSSPartitionTree.hpp:155
strumpack::HSS::HSSPartitionTree::deserialize
static HSSPartitionTree deserialize(const integer_t *buf)
Definition: HSSPartitionTree.hpp:213