35 #include <vector>
36 #include <cassert>
37 #include <memory>
38 #include <functional>
39 #include <algorithm>
41 #include "StructuredOptions.hpp"
42 #include "ClusterTree.hpp"
43 #include "dense/DenseMatrix.hpp"
44 #if defined(STRUMPACK_USE_MPI)
45 #include "dense/DistributedMatrix.hpp"
46 #endif
48 namespace strumpack {
55  namespace structured {
/**
 * Callable type that takes a pair of std::size_t indices (i, j) and
 * returns a single scalar_t value.
 *
 * \tparam scalar_t type of the returned value
 */
template<typename scalar_t>
using extract_t = std::function<scalar_t(std::size_t i, std::size_t j)>;
// extract_block_t: alias for a std::function returning void whose first two
// parameters are the index vectors I and J (both const std::vector<std::size_t>&).
// NOTE(review): the parameter list is cut off after J in this listing (there is
// no closing ")>;" before the next declaration). Restore the remaining
// parameter(s) from the upstream header; as shown this alias is incomplete.
79  template<typename scalar_t>
80  using extract_block_t = std::function
81  <void(const std::vector<std::size_t>& I,
82  const std::vector<std::size_t>& J,
// mult_t: alias for a std::function returning void whose visible parameters are
// a Trans value op and a read-only DenseMatrix<scalar_t> R.
// NOTE(review): the parameter list is cut off after R in this listing (no
// closing ")>;" is present). Restore the remaining parameter(s) from the
// upstream header; as shown this alias is incomplete.
97  template<typename scalar_t>
98  using mult_t = std::function
99  <void(Trans op,
100  const DenseMatrix<scalar_t>& R,
103 #if defined(STRUMPACK_USE_MPI)
// extract_dist_block_t (only declared when STRUMPACK_USE_MPI is defined): alias
// for a std::function returning void whose first two parameters are the index
// vectors I and J (both const std::vector<std::size_t>&).
// NOTE(review): the parameter list is cut off after J in this listing (no
// closing ")>;" is present). Restore the remaining parameter(s) from the
// upstream header; as shown this alias is incomplete.
115  template<typename scalar_t>
116  using extract_dist_block_t = std::function
117  <void(const std::vector<std::size_t>& I,
118  const std::vector<std::size_t>& J,
// mult_2d_t (only declared when STRUMPACK_USE_MPI is defined): alias for a
// std::function returning void whose first parameter is a Trans value op.
// NOTE(review): the parameter list is cut off after op in this listing (no
// closing ")>;" is present). Restore the remaining parameter(s) from the
// upstream header; as shown this alias is incomplete.
133  template<typename scalar_t>
134  using mult_2d_t = std::function
135  <void(Trans op,
// mult_1d_t (only declared when STRUMPACK_USE_MPI is defined): alias for a
// std::function returning void that receives a Trans value op, a read-only
// DenseMatrix<scalar_t> R, and two const std::vector<int>& arguments named
// rdist and cdist.
// NOTE(review): the embedded listing numbers jump from 160 to 162 inside the
// parameter list, so one parameter line may be missing here (no writable
// output argument is visible). Confirm against the upstream header.
157  template<typename scalar_t>
158  using mult_1d_t = std::function
159  <void(Trans op,
160  const DenseMatrix<scalar_t>& R,
162  const std::vector<int>& rdist,
163  const std::vector<int>& cdist)>;
164 #endif
209  template<typename scalar_t> class StructuredMatrix {
210  using real_t = typename RealType<scalar_t>::value_type;
212  public:
217  virtual ~StructuredMatrix() = default;
222  virtual std::size_t rows() const = 0;
227  virtual std::size_t cols() const = 0;
236  virtual std::size_t memory() const = 0;
244  virtual std::size_t nonzeros() const = 0;
252  virtual std::size_t rank() const = 0;
262  virtual std::size_t local_rows() const {
263  throw std::invalid_argument
264  ("1d block row distribution not supported for this format.");
265  }
274  virtual std::size_t begin_row() const {
275  throw std::invalid_argument
276  ("1d block row distribution not supported for this format.");
277  }
286  virtual std::size_t end_row() const {
287  throw std::invalid_argument
288  ("1d block row distribution not supported for this format.");
289  }
299  virtual const std::vector<int>& dist() const {
300  static std::vector<int> d = {0, int(rows())};
301  return d;
302  };
310  virtual const std::vector<int>& rdist() const {
311  return dist();
312  };
320  virtual const std::vector<int>& cdist() const {
321  return dist();
322  };
332  virtual void mult(Trans op, const DenseMatrix<scalar_t>& x,
333  DenseMatrix<scalar_t>& y) const;
347  void mult(Trans op, int m, const scalar_t* x, int ldx,
348  scalar_t* y, int ldy) const;
350 #if defined(STRUMPACK_USE_MPI)
359  virtual void mult(Trans op, const DistributedMatrix<scalar_t>& x,
360  DistributedMatrix<scalar_t>& y) const;
361 #endif
369  virtual void factor();
378  virtual void solve(DenseMatrix<scalar_t>& b) const;
389  virtual void solve(int nrhs, scalar_t* b, int ldb) const {
390  int lr = rows();
391  try { lr = local_rows(); }
392  catch(...) {}
393  DenseMatrixWrapper<scalar_t> B(lr, nrhs, b, ldb);
394  solve(B);
395  }
397 #if defined(STRUMPACK_USE_MPI)
405  virtual void solve(DistributedMatrix<scalar_t>& b) const;
406 #endif
417  virtual void shift(scalar_t s);
419  };
// Factory declaration returning std::unique_ptr<StructuredMatrix<scalar_t>>.
// Visible parameters: the StructuredOptions opts and two optional ClusterTree
// pointers (row_tree, col_tree), both defaulting to nullptr.
// NOTE(review): the embedded listing numbers jump from 460 to 462, and the line
// carrying the function name and its leading parameter(s) is missing. Restore
// it from the upstream header; as shown this declaration is incomplete.
459  template<typename scalar_t>
460  std::unique_ptr<StructuredMatrix<scalar_t>>
462  const StructuredOptions<scalar_t>& opts,
463  const structured::ClusterTree* row_tree=nullptr,
464  const structured::ClusterTree* col_tree=nullptr);
506  template<typename scalar_t>
507  std::unique_ptr<StructuredMatrix<scalar_t>>
508  construct_from_dense(int rows, int cols, const scalar_t* A, int ldA,
509  const StructuredOptions<scalar_t>& opts,
510  const structured::ClusterTree* row_tree=nullptr,
511  const structured::ClusterTree* col_tree=nullptr);
545  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
546  construct_from_elements(int rows, int cols,
547  const extract_t<scalar_t>& A,
548  const StructuredOptions<scalar_t>& opts,
549  const structured::ClusterTree* row_tree=nullptr,
550  const structured::ClusterTree* col_tree=nullptr);
585  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
586  construct_from_elements(int rows, int cols,
587  const extract_block_t<scalar_t>& A,
588  const StructuredOptions<scalar_t>& opts,
589  const structured::ClusterTree* row_tree=nullptr,
590  const structured::ClusterTree* col_tree=nullptr);
624  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
625  construct_matrix_free(int rows, int cols,
626  const mult_t<scalar_t>& Amult,
627  const StructuredOptions<scalar_t>& opts,
628  const structured::ClusterTree* row_tree=nullptr,
629  const structured::ClusterTree* col_tree=nullptr);
// Factory declaration returning std::unique_ptr<StructuredMatrix<scalar_t>>.
// Visible parameters: a mult_t callable Amult, an extract_block_t callable
// Aelem, the StructuredOptions opts, and two optional ClusterTree pointers
// (row_tree, col_tree), both defaulting to nullptr.
// NOTE(review): the embedded listing numbers jump from 665 to 667, and the line
// carrying the function name and its leading parameter(s) is missing. Restore
// it from the upstream header; as shown this declaration is incomplete.
665  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
667  const mult_t<scalar_t>& Amult,
668  const extract_block_t<scalar_t>& Aelem,
669  const StructuredOptions<scalar_t>& opts,
670  const structured::ClusterTree* row_tree=nullptr,
671  const structured::ClusterTree* col_tree=nullptr);
// Factory declaration returning std::unique_ptr<StructuredMatrix<scalar_t>>.
// Visible parameters: a mult_t callable Amult, an extract_t callable Aelem,
// the StructuredOptions opts, and two optional ClusterTree pointers
// (row_tree, col_tree), both defaulting to nullptr.
// NOTE(review): the embedded listing numbers jump from 707 to 709, and the line
// carrying the function name and its leading parameter(s) is missing. Restore
// it from the upstream header; as shown this declaration is incomplete.
707  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
709  const mult_t<scalar_t>& Amult,
710  const extract_t<scalar_t>& Aelem,
711  const StructuredOptions<scalar_t>& opts,
712  const structured::ClusterTree* row_tree=nullptr,
713  const structured::ClusterTree* col_tree=nullptr);
716 #if defined(STRUMPACK_USE_MPI)
// Factory declaration (inside the STRUMPACK_USE_MPI section) returning
// std::unique_ptr<StructuredMatrix<scalar_t>>. Visible parameters: the
// StructuredOptions opts and two optional ClusterTree pointers (row_tree,
// col_tree), both defaulting to nullptr.
// NOTE(review): the embedded listing numbers jump from 747 to 749, and the line
// carrying the function name and its leading parameter(s) is missing. Restore
// it from the upstream header; as shown this declaration is incomplete.
747  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
749  const StructuredOptions<scalar_t>& opts,
750  const structured::ClusterTree* row_tree=nullptr,
751  const structured::ClusterTree* col_tree=nullptr);
786  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
787  construct_from_elements(const MPIComm& comm, int rows, int cols,
788  const extract_t<scalar_t>& A,
789  const StructuredOptions<scalar_t>& opts,
790  const structured::ClusterTree* row_tree=nullptr,
791  const structured::ClusterTree* col_tree=nullptr);
// Factory declaration (inside the STRUMPACK_USE_MPI section) named
// construct_from_elements, returning std::unique_ptr<StructuredMatrix<scalar_t>>.
// Visible parameters: the communicator comm, rows, cols, the StructuredOptions
// opts, and two optional ClusterTree pointers (row_tree, col_tree), both
// defaulting to nullptr.
// NOTE(review): the embedded listing numbers jump from 827 to 829, and no
// element-callback parameter is visible even though the name implies one. A
// parameter line appears to be missing; confirm against the upstream header.
826  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
827  construct_from_elements(const MPIComm& comm, int rows, int cols,
829  const StructuredOptions<scalar_t>& opts,
830  const structured::ClusterTree* row_tree=nullptr,
831  const structured::ClusterTree* col_tree=nullptr);
863  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
864  construct_matrix_free(const MPIComm& comm, const BLACSGrid* g,
865  int rows, int cols,
866  const mult_2d_t<scalar_t>& Amult,
867  const StructuredOptions<scalar_t>& opts,
868  const structured::ClusterTree* row_tree=nullptr,
869  const structured::ClusterTree* col_tree=nullptr);
901  template<typename scalar_t> std::unique_ptr<StructuredMatrix<scalar_t>>
902  construct_matrix_free(const MPIComm& comm, int rows, int cols,
903  const mult_1d_t<scalar_t>& Amult,
904  const StructuredOptions<scalar_t>& opts,
905  const structured::ClusterTree* row_tree=nullptr,
906  const structured::ClusterTree* col_tree=nullptr);
908 #endif
910  } // end namespace structured
911 } // end namespace strumpack
