SuperLU Distributed 9.0.0
gpu3d
cublas_cusolver_wrappers.hpp
Go to the documentation of this file.
1#pragma once
2#include <cublas_v2.h>
3
4template <typename Ftype>
5cusolverStatus_t myCusolverGetrf(cusolverDnHandle_t handle, int m, int n, Ftype *A, int lda, Ftype *Workspace, int *devIpiv, int *devInfo);
6
7template <>
8cusolverStatus_t myCusolverGetrf<double>(cusolverDnHandle_t handle, int m, int n, double *A, int lda, double *Workspace, int *devIpiv, int *devInfo)
9{
10 return cusolverDnDgetrf(handle, m, n, A, lda, Workspace, devIpiv, devInfo);
11}
12
13template <>
14cusolverStatus_t myCusolverGetrf<float>(cusolverDnHandle_t handle, int m, int n, float *A, int lda, float *Workspace, int *devIpiv, int *devInfo)
15{
16 return cusolverDnSgetrf(handle, m, n, A, lda, Workspace, devIpiv, devInfo);
17}
18
19template <>
20cusolverStatus_t myCusolverGetrf<cuComplex>(cusolverDnHandle_t handle, int m, int n, cuComplex *A, int lda, cuComplex *Workspace, int *devIpiv, int *devInfo)
21{
22 return cusolverDnCgetrf(handle, m, n, A, lda, Workspace, devIpiv, devInfo);
23}
24
25template <>
26cusolverStatus_t myCusolverGetrf<cuDoubleComplex>(cusolverDnHandle_t handle, int m, int n, cuDoubleComplex *A, int lda, cuDoubleComplex *Workspace, int *devIpiv, int *devInfo)
27{
28 return cusolverDnZgetrf(handle, m, n, A, lda, Workspace, devIpiv, devInfo);
29}
30
31template <typename Ftype>
32cublasStatus_t myCublasTrsm(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const Ftype *alpha, const Ftype *A, int lda, Ftype *B, int ldb);
33
34template <>
35cublasStatus_t myCublasTrsm<double>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const double *alpha, const double *A, int lda, double *B, int ldb)
36{
37 return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
38}
39
40template <>
41cublasStatus_t myCublasTrsm<float>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const float *alpha, const float *A, int lda, float *B, int ldb)
42{
43 return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
44}
45
46template <>
47cublasStatus_t myCublasTrsm<cuComplex>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const cuComplex *alpha, const cuComplex *A, int lda, cuComplex *B, int ldb)
48{
49 return cublasCtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
50}
51
52template <>
53cublasStatus_t myCublasTrsm<cuDoubleComplex>(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb)
54{
55 return cublasZtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
56}
57
58template <typename Ftype>
59cublasStatus_t myCublasScal(cublasHandle_t handle, int n, const Ftype *alpha, Ftype *x, int incx);
60
61template <typename Ftype>
62cublasStatus_t myCublasAxpy(cublasHandle_t handle, int n, const Ftype *alpha, const Ftype *x, int incx, Ftype *y, int incy);
63
64template <>
65cublasStatus_t myCublasScal<double>(cublasHandle_t handle, int n, const double *alpha, double *x, int incx)
66{
67 return cublasDscal(handle, n, alpha, x, incx);
68}
69
70template <>
71cublasStatus_t myCublasScal<float>(cublasHandle_t handle, int n, const float *alpha, float *x, int incx)
72{
73 return cublasSscal(handle, n, alpha, x, incx);
74}
75
76template <>
77cublasStatus_t myCublasAxpy<double>(cublasHandle_t handle, int n, const double *alpha, const double *x, int incx, double *y, int incy)
78{
79 return cublasDaxpy(handle, n, alpha, x, incx, y, incy);
80}
81
82template <>
83cublasStatus_t myCublasAxpy<float>(cublasHandle_t handle, int n, const float *alpha, const float *x, int incx, float *y, int incy)
84{
85 return cublasSaxpy(handle, n, alpha, x, incx, y, incy);
86}
87
88template <>
89cublasStatus_t myCublasScal<cuComplex>(cublasHandle_t handle, int n, const cuComplex *alpha, cuComplex *x, int incx)
90{
91 return cublasCscal(handle, n, alpha, x, incx);
92}
93
94template <>
95cublasStatus_t myCublasScal<cuDoubleComplex>(cublasHandle_t handle, int n, const cuDoubleComplex *alpha, cuDoubleComplex *x, int incx)
96{
97 return cublasZscal(handle, n, alpha, x, incx);
98}
99
100template <>
101cublasStatus_t myCublasAxpy<cuComplex>(cublasHandle_t handle, int n, const cuComplex *alpha, const cuComplex *x, int incx, cuComplex *y, int incy)
102{
103 return cublasCaxpy(handle, n, alpha, x, incx, y, incy);
104}
105
106template <>
107cublasStatus_t myCublasAxpy<cuDoubleComplex>(cublasHandle_t handle, int n, const cuDoubleComplex *alpha, const cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy)
108{
109 return cublasZaxpy(handle, n, alpha, x, incx, y, incy);
110}
111
112template <typename Ftype>
113cublasStatus_t myCublasGemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const Ftype *alpha, const Ftype *A, int lda, const Ftype *B, int ldb, const Ftype *beta, Ftype *C, int ldc);
114
115template <>
116cublasStatus_t myCublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *A, int lda, const double *B, int ldb, const double *beta, double *C, int ldc)
117{
118 return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
119}
120
121template <>
122cublasStatus_t myCublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *A, int lda, const float *B, int ldb, const float *beta, float *C, int ldc)
123{
124 return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
125}
126
127template <>
128cublasStatus_t myCublasGemm<cuComplex>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex *alpha, const cuComplex *A, int lda, const cuComplex *B, int ldb, const cuComplex *beta, cuComplex *C, int ldc)
129{
130 return cublasCgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
131}
132
133template <>
134cublasStatus_t myCublasGemm<cuDoubleComplex>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc)
135{
136 return cublasZgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
137}
138
139template <>
140cublasStatus_t myCublasGemm<doublecomplex>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const doublecomplex *alpha, const doublecomplex *A, int lda, const doublecomplex *B, int ldb, const doublecomplex *beta, doublecomplex *C, int ldc)
141{
142 // return cublasZgemm(handle, transa, transb, m, n, k,
143 // alpha, A, lda, B, ldb, beta, C, ldc);
144 // cast doublecomplex to cuDoubleComplex
145 return cublasZgemm(
146 handle, transa, transb, m, n, k,
147 reinterpret_cast<const cuDoubleComplex *>(alpha),
148 reinterpret_cast<const cuDoubleComplex *>(A), lda,
149 reinterpret_cast<const cuDoubleComplex *>(B), ldb,
150 reinterpret_cast<const cuDoubleComplex *>(beta),
151 reinterpret_cast<cuDoubleComplex *>(C), ldc);
152
153}
154
155template <>
157 cusolverDnHandle_t handle, int m, int n, doublecomplex *A, int lda,
158 doublecomplex *Workspace, int *devIpiv, int *devInfo)
159{
160 // return cusolverDnZgetrf(handle, m, n, A, lda, Workspace, devIpiv, devInfo);
161 // cast doublecomplex to cuDoubleComplex
162 return cusolverDnZgetrf(
163 handle, m, n, reinterpret_cast<cuDoubleComplex *>(A), lda,
164 reinterpret_cast<cuDoubleComplex *>(Workspace), devIpiv, devInfo);
165}
166
167// now creating the wrappers for the other functions
168template <>
169cublasStatus_t myCublasTrsm<doublecomplex>(cublasHandle_t handle,
170 cublasSideMode_t side, cublasFillMode_t uplo,
171 cublasOperation_t trans, cublasDiagType_t diag,
172 int m, int n,
173 const doublecomplex *alpha,
174 const doublecomplex *A, int lda,
175 doublecomplex *B, int ldb) {
176 // Your implementation here
177 // You can use cublasZtrsm function because it's for cuDoubleComplex type
178 return cublasZtrsm(handle, side, uplo, trans, diag, m, n,
179 reinterpret_cast<const cuDoubleComplex*>(alpha),
180 reinterpret_cast<const cuDoubleComplex*>(A), lda,
181 reinterpret_cast<cuDoubleComplex*>(B), ldb);
182}
183
184template <>
185cublasStatus_t myCublasScal<doublecomplex>(cublasHandle_t handle, int n,
186 const doublecomplex *alpha,
187 doublecomplex *x, int incx) {
188 // Your implementation here
189 // You can use cublasZscal function because it's for cuDoubleComplex type
190 return cublasZscal(handle, n, reinterpret_cast<const cuDoubleComplex*>(alpha),
191 reinterpret_cast<cuDoubleComplex*>(x), incx);
192}
193
194template <>
195cublasStatus_t myCublasAxpy<doublecomplex>(cublasHandle_t handle, int n,
196 const doublecomplex *alpha,
197 const doublecomplex *x, int incx,
198 doublecomplex *y, int incy) {
199 // Your implementation here
200 // You can use cublasZaxpy function because it's for cuDoubleComplex type
201 return cublasZaxpy(handle, n, reinterpret_cast<const cuDoubleComplex*>(alpha),
202 reinterpret_cast<const cuDoubleComplex*>(x), incx,
203 reinterpret_cast<cuDoubleComplex*>(y), incy);
204}
205
206
207// cublasStatus_t myCublasScal<doublecomplex>
208// cublasStatus_t myCublasAxpy<doublecomplex>
209// cublasStatus_t myCublasGemm<doublecomplex>
cublasStatus_t myCublasGemm< cuComplex >(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex *alpha, const cuComplex *A, int lda, const cuComplex *B, int ldb, const cuComplex *beta, cuComplex *C, int ldc)
Definition: cublas_cusolver_wrappers.hpp:128
cusolverStatus_t myCusolverGetrf< cuComplex >(cusolverDnHandle_t handle, int m, int n, cuComplex *A, int lda, cuComplex *Workspace, int *devIpiv, int *devInfo)
Definition: cublas_cusolver_wrappers.hpp:20
cusolverStatus_t myCusolverGetrf< doublecomplex >(cusolverDnHandle_t handle, int m, int n, doublecomplex *A, int lda, doublecomplex *Workspace, int *devIpiv, int *devInfo)
Definition: cublas_cusolver_wrappers.hpp:156
cublasStatus_t myCublasScal(cublasHandle_t handle, int n, const Ftype *alpha, Ftype *x, int incx)
cusolverStatus_t myCusolverGetrf< double >(cusolverDnHandle_t handle, int m, int n, double *A, int lda, double *Workspace, int *devIpiv, int *devInfo)
Definition: cublas_cusolver_wrappers.hpp:8
cublasStatus_t myCublasAxpy< float >(cublasHandle_t handle, int n, const float *alpha, const float *x, int incx, float *y, int incy)
Definition: cublas_cusolver_wrappers.hpp:83
cublasStatus_t myCublasAxpy< doublecomplex >(cublasHandle_t handle, int n, const doublecomplex *alpha, const doublecomplex *x, int incx, doublecomplex *y, int incy)
Definition: cublas_cusolver_wrappers.hpp:195
cublasStatus_t myCublasAxpy< cuComplex >(cublasHandle_t handle, int n, const cuComplex *alpha, const cuComplex *x, int incx, cuComplex *y, int incy)
Definition: cublas_cusolver_wrappers.hpp:101
cublasStatus_t myCublasScal< doublecomplex >(cublasHandle_t handle, int n, const doublecomplex *alpha, doublecomplex *x, int incx)
Definition: cublas_cusolver_wrappers.hpp:185
cublasStatus_t myCublasGemm< doublecomplex >(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const doublecomplex *alpha, const doublecomplex *A, int lda, const doublecomplex *B, int ldb, const doublecomplex *beta, doublecomplex *C, int ldc)
Definition: cublas_cusolver_wrappers.hpp:140
cublasStatus_t myCublasAxpy< double >(cublasHandle_t handle, int n, const double *alpha, const double *x, int incx, double *y, int incy)
Definition: cublas_cusolver_wrappers.hpp:77
cublasStatus_t myCublasTrsm< float >(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const float *alpha, const float *A, int lda, float *B, int ldb)
Definition: cublas_cusolver_wrappers.hpp:41
cublasStatus_t myCublasScal< double >(cublasHandle_t handle, int n, const double *alpha, double *x, int incx)
Definition: cublas_cusolver_wrappers.hpp:65
cublasStatus_t myCublasGemm< float >(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *A, int lda, const float *B, int ldb, const float *beta, float *C, int ldc)
Definition: cublas_cusolver_wrappers.hpp:122
cublasStatus_t myCublasGemm< double >(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *A, int lda, const double *B, int ldb, const double *beta, double *C, int ldc)
Definition: cublas_cusolver_wrappers.hpp:116
cublasStatus_t myCublasAxpy< cuDoubleComplex >(cublasHandle_t handle, int n, const cuDoubleComplex *alpha, const cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy)
Definition: cublas_cusolver_wrappers.hpp:107
cusolverStatus_t myCusolverGetrf(cusolverDnHandle_t handle, int m, int n, Ftype *A, int lda, Ftype *Workspace, int *devIpiv, int *devInfo)
cusolverStatus_t myCusolverGetrf< cuDoubleComplex >(cusolverDnHandle_t handle, int m, int n, cuDoubleComplex *A, int lda, cuDoubleComplex *Workspace, int *devIpiv, int *devInfo)
Definition: cublas_cusolver_wrappers.hpp:26
cublasStatus_t myCublasScal< float >(cublasHandle_t handle, int n, const float *alpha, float *x, int incx)
Definition: cublas_cusolver_wrappers.hpp:71
cublasStatus_t myCublasTrsm< doublecomplex >(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const doublecomplex *alpha, const doublecomplex *A, int lda, doublecomplex *B, int ldb)
Definition: cublas_cusolver_wrappers.hpp:169
cublasStatus_t myCublasTrsm< cuComplex >(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const cuComplex *alpha, const cuComplex *A, int lda, cuComplex *B, int ldb)
Definition: cublas_cusolver_wrappers.hpp:47
cublasStatus_t myCublasTrsm< cuDoubleComplex >(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb)
Definition: cublas_cusolver_wrappers.hpp:53
cublasStatus_t myCublasScal< cuDoubleComplex >(cublasHandle_t handle, int n, const cuDoubleComplex *alpha, cuDoubleComplex *x, int incx)
Definition: cublas_cusolver_wrappers.hpp:95
cublasStatus_t myCublasAxpy(cublasHandle_t handle, int n, const Ftype *alpha, const Ftype *x, int incx, Ftype *y, int incy)
cublasStatus_t myCublasTrsm(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const Ftype *alpha, const Ftype *A, int lda, Ftype *B, int ldb)
cublasStatus_t myCublasGemm< cuDoubleComplex >(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc)
Definition: cublas_cusolver_wrappers.hpp:134
cublasStatus_t myCublasScal< cuComplex >(cublasHandle_t handle, int n, const cuComplex *alpha, cuComplex *x, int incx)
Definition: cublas_cusolver_wrappers.hpp:89
cusolverStatus_t myCusolverGetrf< float >(cusolverDnHandle_t handle, int m, int n, float *A, int lda, float *Workspace, int *devIpiv, int *devInfo)
Definition: cublas_cusolver_wrappers.hpp:14
cublasStatus_t myCublasTrsm< double >(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const double *alpha, const double *A, int lda, double *B, int ldb)
Definition: cublas_cusolver_wrappers.hpp:35
cublasStatus_t myCublasGemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const Ftype *alpha, const Ftype *A, int lda, const Ftype *B, int ldb, const Ftype *beta, Ftype *C, int ldc)
integer, parameter, public trans
Definition: superlupara.f90:35
Definition: dcomplex.h:30