Loading...
Searching...
No Matches
Kernel.hpp
Go to the documentation of this file.
1/*
2 * STRUMPACK -- STRUctured Matrices PACKage, Copyright (c) 2014, The
3 * Regents of the University of California, through Lawrence Berkeley
4 * National Laboratory (subject to receipt of any required approvals
5 * from the U.S. Dept. of Energy). All rights reserved.
6 *
7 * If you have questions about your rights to use or distribute this
8 * software, please contact Berkeley Lab's Technology Transfer
9 * Department at TTD@lbl.gov.
10 *
11 * NOTICE. This software is owned by the U.S. Department of Energy. As
12 * such, the U.S. Government has been granted for itself and others
13 * acting on its behalf a paid-up, nonexclusive, irrevocable,
14 * worldwide license in the Software to reproduce, prepare derivative
15 * works, and perform publicly and display publicly. Beginning five
16 * (5) years after the date permission to assert copyright is obtained
17 * from the U.S. Department of Energy, and subject to any subsequent
18 * five (5) year renewals, the U.S. Government is granted for itself
19 * and others acting on its behalf a paid-up, nonexclusive,
20 * irrevocable, worldwide license in the Software to reproduce,
21 * prepare derivative works, distribute copies to the public, perform
22 * publicly and display publicly, and to permit others to do so.
23 *
24 * Developers: Pieter Ghysels, Francois-Henry Rouet, Xiaoye S. Li.
25 * (Lawrence Berkeley National Lab, Computational Research
26 * Division).
27 *
28 */
36#ifndef STRUMPACK_KERNEL_HPP
37#define STRUMPACK_KERNEL_HPP
38
39#include "Metrics.hpp"
40#include "HSS/HSSOptions.hpp"
41#include "dense/DenseMatrix.hpp"
42#if defined(STRUMPACK_USE_MPI)
44#if defined(STRUMPACK_USE_BPACK)
46#endif
47#endif
48
49namespace strumpack {
50
54 namespace kernel {
55
73 template<typename scalar_t> class Kernel {
74 using real_t = typename RealType<scalar_t>::value_type;
77#if defined(STRUMPACK_USE_MPI)
79#endif
80
81 public:
92 Kernel(DenseM_t& data, scalar_t lambda)
93 : data_(data), lambda_(lambda) { }
94
98 virtual ~Kernel() = default;
99
106 std::size_t n() const { return data_.cols(); }
107
113 std::size_t d() const { return data_.rows(); }
114
122 virtual scalar_t eval(std::size_t i, std::size_t j) const {
123 return eval_kernel_function(data_.ptr(0, i), data_.ptr(0, j))
124 + ((i == j) ? lambda_ : scalar_t(0.));
125 }
126
138 void operator()(const std::vector<std::size_t>& I,
139 const std::vector<std::size_t>& J,
140 DenseMatrix<real_t>& B) const {
141 assert(B.rows() == I.size() && B.cols() == J.size());
142 for (std::size_t j=0; j<J.size(); j++)
143 for (std::size_t i=0; i<I.size(); i++) {
144 assert(I[i] < n() && J[j] < n());
145 B(i, j) = eval(I[i], J[j]);
146 }
147 }
148
160 void operator()(const std::vector<std::size_t>& I,
161 const std::vector<std::size_t>& J,
162 DenseMatrix<std::complex<real_t>>& B) const {
163 assert(B.rows() == I.size() && B.cols() == J.size());
164 for (std::size_t j=0; j<J.size(); j++)
165 for (std::size_t i=0; i<I.size(); i++) {
166 assert(I[i] < n() && J[j] < n());
167 B(i, j) = eval(I[i], J[j]);
168 }
169 }
170
190 (std::vector<scalar_t>& labels, const HSS::HSSOptions<scalar_t>& opts);
191
203 std::vector<scalar_t> predict
204 (const DenseM_t& test, const DenseM_t& weights) const;
205
206#if defined(STRUMPACK_USE_MPI)
228 (const BLACSGrid& grid, std::vector<scalar_t>& labels,
229 const HSS::HSSOptions<scalar_t>& opts);
230
242 std::vector<scalar_t> predict
243 (const DenseM_t& test, const DistM_t& weights) const;
244
245#if defined(STRUMPACK_USE_BPACK)
265 (const MPIComm& c, std::vector<scalar_t>& labels,
267#endif
268#endif
269
276 const DenseM_t& data() const { return data_; }
282 DenseM_t& data() { return data_; }
283
284 std::vector<int>& permutation() { return perm_; }
285 const std::vector<int>& permutation() const { return perm_; }
286
287 virtual void permute() {
288 data_.lapmr(perm_, true);
289 }
290
291 protected:
292 DenseM_t& data_;
293 scalar_t lambda_;
294 std::vector<int> perm_;
295
310 virtual scalar_t eval_kernel_function
311 (const scalar_t* x, const scalar_t* y) const = 0;
312 };
313
314
332 template<typename scalar_t>
333 class GaussKernel : public Kernel<scalar_t> {
334 public:
345 GaussKernel(DenseMatrix<scalar_t>& data, scalar_t h, scalar_t lambda)
346 : Kernel<scalar_t>(data, lambda), h_(h) {}
347
348 protected:
349 scalar_t h_; // kernel width parameter
350
351 scalar_t eval_kernel_function
352 (const scalar_t* x, const scalar_t* y) const override {
353 return std::exp
354 (-Euclidean_distance_squared(this->d(), x, y)
355 / (scalar_t(2.) * h_ * h_));
356 }
357 };
358
359
377 template<typename scalar_t>
378 class LaplaceKernel : public Kernel<scalar_t> {
379 public:
390 LaplaceKernel(DenseMatrix<scalar_t>& data, scalar_t h, scalar_t lambda)
391 : Kernel<scalar_t>(data, lambda), h_(h) {}
392
393 protected:
394 scalar_t h_; // kernel width parameter
395
396 scalar_t eval_kernel_function
397 (const scalar_t* x, const scalar_t* y) const override {
398 return std::exp(-norm1_distance(this->d(), x, y) / h_);
399 }
400 };
401
423 template<typename scalar_t>
424 class ANOVAKernel : public Kernel<scalar_t> {
425 public:
438 (DenseMatrix<scalar_t>& data, scalar_t h, scalar_t lambda, int p=1)
439 : Kernel<scalar_t>(data, lambda), h_(h), p_(p) {
440 assert(p >= 1 && p <= int(this->d()));
441 }
442
443 protected:
444 scalar_t h_; // kernel width parameter
445 int p_; // kernel degree parameter 1 <= p_ <= this->d()
446
447 scalar_t eval_kernel_function
448 (const scalar_t* x, const scalar_t* y) const override {
449 std::vector<scalar_t> Ks(p_), Kss(p_), Kpp(p_+1);
450 Kpp[0] = 1;
451 for (int j=0; j<p_; j++) Kss[j] = 0;
452 for (std::size_t i=0; i<this->d(); i++) {
453 scalar_t tmp = std::exp
454 (-Euclidean_distance_squared(1, x+i, y+i)
455 / (scalar_t(2.) * h_ * h_));
456 Ks[0] = tmp;
457 Kss[0] += Ks[0];
458 for (int j=1; j<p_; j++) {
459 Ks[j] = Ks[j-1]*tmp;
460 Kss[j] += Ks[j];
461 }
462 }
463 for (int i=1; i<=p_; i++) {
464 Kpp[i] = 0;
465 for (int s=1; s<=i; s++)
466 Kpp[i] += std::pow(-1,s+1)*Kpp[i-s]*Kss[s-1];
467 Kpp[i] /= i;
468 }
469 return Kpp[p_];
470 }
471 };
472
473
485 template<typename scalar_t>
486 class DenseKernel : public Kernel<scalar_t> {
487 public:
499 DenseMatrix<scalar_t>& A, scalar_t lambda)
500 : Kernel<scalar_t>(data, lambda), A_(A) {}
501
502 scalar_t eval(std::size_t i, std::size_t j) const override {
503 return A_(i, j) + ((i == j) ? this->lambda_ : scalar_t(0.));
504 }
505
506 void permute() override {
508 A_.lapmt(this->perm_, true);
509 A_.lapmr(this->perm_, true);
510 }
511
512 protected:
513 DenseMatrix<scalar_t>& A_; // kernel matrix
514
515 scalar_t eval_kernel_function
516 (const scalar_t* x, const scalar_t* y) const override {
517 assert(false);
518 }
519 };
520
521
/**
 * Kernel types that can be selected by name. The names are given by
 * get_name() and parsed by kernel_type().
 */
enum class KernelType {
  DENSE = 0,    // "dense"
  GAUSS = 1,    // "Gauss"
  LAPLACE = 2,  // "Laplace"
  ANOVA = 3     // "ANOVA"
};
532
536 inline std::string get_name(KernelType k) {
537 switch (k) {
538 case KernelType::DENSE: return "dense";
539 case KernelType::GAUSS: return "Gauss";
540 case KernelType::LAPLACE: return "Laplace";
541 case KernelType::ANOVA: return "ANOVA";
542 default: return "UNKNOWN";
543 }
544 }
545
551 inline KernelType kernel_type(const std::string& k) {
552 if (k == "dense") return KernelType::DENSE;
553 else if (k == "Gauss") return KernelType::GAUSS;
554 else if (k == "Laplace") return KernelType::LAPLACE;
555 else if (k == "ANOVA") return KernelType::ANOVA;
556 std::cerr << "ERROR: Kernel type not recogonized, "
557 << " setting kernel type to Gauss."
558 << std::endl;
559 return KernelType::GAUSS;
560 }
561
574 template<typename scalar_t>
575 std::unique_ptr<Kernel<scalar_t>> create_kernel
577 scalar_t h, scalar_t lambda, int p=1) {
578 switch (k) {
579 // case KernelType::DENSE:
580 // return std::unique_ptr<Kernel<scalar_t>>
581 // (new DenseKernel<scalar_t>(args ...));
583 return std::unique_ptr<Kernel<scalar_t>>
584 (new GaussKernel<scalar_t>(data, h, lambda));
586 return std::unique_ptr<Kernel<scalar_t>>
587 (new LaplaceKernel<scalar_t>(data, h, lambda));
589 return std::unique_ptr<Kernel<scalar_t>>
590 (new ANOVAKernel<scalar_t>(data, h, lambda, p));
591 default:
592 return std::unique_ptr<Kernel<scalar_t>>
593 (new GaussKernel<scalar_t>(data, h, lambda));
594 }
595 }
596
597 } // end namespace kernel
598
599} // end namespace strumpack
600
601#endif // STRUMPACK_KERNEL_HPP
Contains the DenseMatrix and DenseMatrixWrapper classes, simple wrappers around BLAS/LAPACK style den...
Contains the DistributedMatrix and DistributedMatrixWrapper classes, wrappers around ScaLAPACK/PBLAS ...
Contains the class holding HODLR matrix options.
Contains the HSSOptions class as well as general routines for HSS options.
Definitions of distance metrics.
This is a small wrapper class around a BLACS grid and a BLACS context.
Definition BLACSGrid.hpp:66
Like DenseMatrix, this class represents a matrix, stored in column major format, to allow direct use ...
Definition DenseMatrix.hpp:1018
This class represents a matrix, stored in column major format, to allow direct use of BLAS/LAPACK rou...
Definition DenseMatrix.hpp:139
std::size_t cols() const
Definition DenseMatrix.hpp:231
std::size_t rows() const
Definition DenseMatrix.hpp:228
const scalar_t * ptr(std::size_t i, std::size_t j) const
Definition DenseMatrix.hpp:283
void lapmr(const std::vector< int > &P, bool fwd)
2D block-cyclically distributed matrix, as used by ScaLAPACK.
Definition DistributedMatrix.hpp:84
Class containing several options for the HODLR code and data-structures.
Definition HODLROptions.hpp:117
Class containing several options for the HSS code and data-structures.
Definition HSSOptions.hpp:152
Wrapper class around an MPI_Comm object.
Definition MPIWrapper.hpp:173
ANOVA kernel.
Definition Kernel.hpp:424
ANOVAKernel(DenseMatrix< scalar_t > &data, scalar_t h, scalar_t lambda, int p=1)
Definition Kernel.hpp:438
Arbitrary dense matrix, with underlying geometry.
Definition Kernel.hpp:486
DenseKernel(DenseMatrix< scalar_t > &data, DenseMatrix< scalar_t > &A, scalar_t lambda)
Definition Kernel.hpp:498
scalar_t eval(std::size_t i, std::size_t j) const override
Definition Kernel.hpp:502
Gaussian or radial basis function kernel.
Definition Kernel.hpp:333
GaussKernel(DenseMatrix< scalar_t > &data, scalar_t h, scalar_t lambda)
Definition Kernel.hpp:345
Representation of a kernel matrix.
Definition Kernel.hpp:73
virtual ~Kernel()=default
DenseM_t fit_HODLR(const MPIComm &c, std::vector< scalar_t > &labels, const HODLR::HODLROptions< scalar_t > &opts)
virtual scalar_t eval(std::size_t i, std::size_t j) const
Definition Kernel.hpp:122
std::size_t n() const
Definition Kernel.hpp:106
std::size_t d() const
Definition Kernel.hpp:113
void operator()(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseMatrix< std::complex< real_t > > &B) const
Definition Kernel.hpp:160
std::vector< scalar_t > predict(const DenseM_t &test, const DistM_t &weights) const
const DenseM_t & data() const
Definition Kernel.hpp:276
void operator()(const std::vector< std::size_t > &I, const std::vector< std::size_t > &J, DenseMatrix< real_t > &B) const
Definition Kernel.hpp:138
DenseM_t & data()
Definition Kernel.hpp:282
DistM_t fit_HSS(const BLACSGrid &grid, std::vector< scalar_t > &labels, const HSS::HSSOptions< scalar_t > &opts)
std::vector< scalar_t > predict(const DenseM_t &test, const DenseM_t &weights) const
DenseM_t fit_HSS(std::vector< scalar_t > &labels, const HSS::HSSOptions< scalar_t > &opts)
Kernel(DenseM_t &data, scalar_t lambda)
Definition Kernel.hpp:92
Laplace kernel.
Definition Kernel.hpp:378
LaplaceKernel(DenseMatrix< scalar_t > &data, scalar_t h, scalar_t lambda)
Definition Kernel.hpp:390
std::string get_name(KernelType k)
Definition Kernel.hpp:536
std::unique_ptr< Kernel< scalar_t > > create_kernel(KernelType k, DenseMatrix< scalar_t > &data, scalar_t h, scalar_t lambda, int p=1)
Definition Kernel.hpp:576
KernelType
Definition Kernel.hpp:526
KernelType kernel_type(const std::string &k)
Definition Kernel.hpp:551
Definition StrumpackOptions.hpp:44
real_t norm1_distance(std::size_t d, const scalar_t *x, const scalar_t *y)
Definition Metrics.hpp:91
real_t Euclidean_distance_squared(std::size_t d, const scalar_t *x, const scalar_t *y)
Definition Metrics.hpp:53