SuperLU Distributed 9.0.0
gpu3d
batch_factorize_marshall.h
Go to the documentation of this file.
1#ifndef __SUPERLU_BATCH_FACTORIZE_MARSHALL_H__
2#define __SUPERLU_BATCH_FACTORIZE_MARSHALL_H__
3
5// Marshall Functors for batched execution
7#include "superlu_defs.h"
8
13
17 )
18 {
19 this->k_st = k_st;
20 this->ld_batch = ld_batch;
21 this->dim_batch = dim_batch;
22 this->diag_ptrs = diag_ptrs;
23
24 this->Lnzval_bc_ptr = Lnzval_bc_ptr;
25 this->Lrowind_bc_ptr = Lrowind_bc_ptr;
26 this->dperm_c_supno = dperm_c_supno;
27 this->xsup = xsup;
28 }
29
/**
 * Fill batch slot i with the diagonal-block description of one supernode.
 *
 * The supernode index is read from dperm_c_supno[k_st + i]. When both the
 * L value array and the L row-index array exist for that supernode, the slot
 * receives the value pointer, Lrowind_bc[1] as leading dimension and
 * SuperSize(k) as block dimension. Otherwise the slot is neutralised
 * (NULL pointer, leading dimension 1, dimension 0).
 *
 * @param i  Batch slot index; also the offset from k_st into dperm_c_supno.
 */
__device__ void operator()(const int_t &i) const
{
    // Supernode handled by this slot; the other marshal functors in this
    // file perform the same lookup.
    int_t k = dperm_c_supno[k_st + i];

    int_t *Lrowind_bc = Lrowind_bc_ptr[k];
    double *Lnzval = Lnzval_bc_ptr[k];

    if (Lnzval && Lrowind_bc)
    {
        diag_ptrs[i] = Lnzval;
        ld_batch[i] = Lrowind_bc[1];
        dim_batch[i] = SuperSize(k);
    }
    else
    {
        // No local data for this supernode: emit an empty batch entry.
        diag_ptrs[i] = NULL;
        ld_batch[i] = 1;
        dim_batch[i] = 0;
    }
}
49};
50
55
60 )
61 {
62 this->k_st = k_st;
63 this->diag_ptrs = diag_ptrs;
64 this->diag_ld_batch = diag_ld_batch;
65 this->diag_dim_batch = diag_dim_batch;
66 this->panel_ptrs = panel_ptrs;
67 this->panel_ld_batch = panel_ld_batch;
68 this->panel_dim_batch = panel_dim_batch;
69 this->Unzval_br_new_ptr = Unzval_br_new_ptr;
70 this->Ucolind_br_ptr = Ucolind_br_ptr;
71 this->Lnzval_bc_ptr = Lnzval_bc_ptr;
72 this->Lrowind_bc_ptr = Lrowind_bc_ptr;
73 this->dperm_c_supno = dperm_c_supno;
74 this->xsup = xsup;
75 }
76
/**
 * Fill batch slot i with the diagonal-block and U-panel descriptions used
 * for the batched triangular solve on the U side.
 *
 * The supernode index is read from dperm_c_supno[k_st + i]. If any of the
 * four L/U arrays for that supernode is missing, the slot is neutralised
 * (NULL pointers, leading dimensions 1, dimensions 0).
 *
 * @param i  Batch slot index; also the offset from k_st into dperm_c_supno.
 */
__device__ void operator()(const int_t &i) const
{
    // Supernode handled by this slot; the other marshal functors in this
    // file perform the same lookup.
    int_t k = dperm_c_supno[k_st + i];
    int_t ksupc = SuperSize(k);

    int_t *Ucolind_br = Ucolind_br_ptr[k];
    double *Unzval = Unzval_br_new_ptr[k];
    int_t *Lrowind_bc = Lrowind_bc_ptr[k];
    double *Lnzval = Lnzval_bc_ptr[k];

    if (Ucolind_br && Unzval && Lrowind_bc && Lnzval)
    {
        int upanel_rows = Ucolind_br[2];
        // Rows/columns of the diagonal block skipped before the part that
        // pairs with the stored U panel rows.
        int sup_offset = ksupc - upanel_rows;

        panel_ptrs[i] = Unzval;
        panel_ld_batch[i] = upanel_rows;
        panel_dim_batch[i] = Ucolind_br[1];

        // Step sup_offset rows down and sup_offset columns across inside the
        // L value array, whose leading dimension is Lrowind_bc[1].
        diag_ptrs[i] = Lnzval + sup_offset + sup_offset * Lrowind_bc[1];
        diag_ld_batch[i] = Lrowind_bc[1];
        diag_dim_batch[i] = upanel_rows;
    }
    else
    {
        // Empty entry: write every field so no stale value survives from a
        // previous use of the batch arrays (same neutral values as the LU
        // and SCU marshal functors).
        panel_ptrs[i] = diag_ptrs[i] = NULL;
        panel_ld_batch[i] = diag_ld_batch[i] = 1;
        panel_dim_batch[i] = diag_dim_batch[i] = 0;
    }
}
107};
108
113
118 )
119 {
120 this->k_st = k_st;
121 this->diag_ptrs = diag_ptrs;
122 this->diag_ld_batch = diag_ld_batch;
123 this->diag_dim_batch = diag_dim_batch;
124 this->panel_ptrs = panel_ptrs;
125 this->panel_ld_batch = panel_ld_batch;
126 this->panel_dim_batch = panel_dim_batch;
127
128 this->Lnzval_bc_ptr = Lnzval_bc_ptr;
129 this->Lrowind_bc_ptr = Lrowind_bc_ptr;
130 this->dperm_c_supno = dperm_c_supno;
131 this->xsup = xsup;
132 }
133
/**
 * Fill batch slot i with the diagonal-block and L-panel descriptions used
 * for the batched triangular solve on the L side.
 *
 * The supernode index is read from dperm_c_supno[k_st + i]. If the L value
 * array or the L row-index array for that supernode is missing, the slot is
 * neutralised (NULL pointers, leading dimensions 1, dimensions 0).
 *
 * @param i  Batch slot index; also the offset from k_st into dperm_c_supno.
 */
__device__ void operator()(const int_t &i) const
{
    int_t k = dperm_c_supno[k_st + i];
    int_t ksupc = SuperSize(k);
    int_t *Lrowind_bc = Lrowind_bc_ptr[k];
    double *Lnzval = Lnzval_bc_ptr[k];

    if (Lnzval && Lrowind_bc)
    {
        // Lrowind_bc[BC_HEADER + 1] is used as the row offset at which the
        // panel starts; Lrowind_bc[1] is the leading dimension of the values.
        int_t diag_block_offset = Lrowind_bc[BC_HEADER + 1];
        int_t nzrows = Lrowind_bc[1];
        int_t len = nzrows - diag_block_offset;

        panel_ptrs[i] = Lnzval + diag_block_offset;
        panel_ld_batch[i] = nzrows;
        panel_dim_batch[i] = len;
        diag_ptrs[i] = Lnzval;
        diag_ld_batch[i] = nzrows;
        diag_dim_batch[i] = ksupc;
    }
    else
    {
        // Empty entry: write every field so no stale value survives from a
        // previous use of the batch arrays (same neutral values as the LU
        // and SCU marshal functors).
        panel_ptrs[i] = diag_ptrs[i] = NULL;
        panel_ld_batch[i] = diag_ld_batch[i] = 1;
        panel_dim_batch[i] = diag_dim_batch[i] = 0;
    }
}
161};
162
164 double** A_ptrs, **B_ptrs, **C_ptrs;
169
175 int_t *xsup, double** dgpuGemmBuffs
176 )
177 {
178 this->k_st = k_st;
179 this->A_ptrs = A_ptrs;
180 this->B_ptrs = B_ptrs;
181 this->C_ptrs = C_ptrs;
182 this->lda_array = lda_array;
183 this->ldb_array = ldb_array;
184 this->ldc_array = ldc_array;
185 this->m_array = m_array;
186 this->n_array = n_array;
187 this->k_array = k_array;
188 this->ist = ist;
189 this->iend = iend;
190 this->jst = jst;
191 this->jend = jend;
192 this->Unzval_br_new_ptr = Unzval_br_new_ptr;
193 this->Ucolind_br_ptr = Ucolind_br_ptr;
194 this->Lnzval_bc_ptr = Lnzval_bc_ptr;
195 this->Lrowind_bc_ptr = Lrowind_bc_ptr;
196 this->dperm_c_supno = dperm_c_supno;
197 this->xsup = xsup;
198 this->dgpuGemmBuffs = dgpuGemmBuffs;
199 }
200
/**
 * Fill batch slot i with the operands of the Schur-complement GEMM
 * (A from the L panel, B from the U panel, C from the GEMM buffer list)
 * together with the block index ranges ist/iend and jst/jend.
 *
 * The supernode index is read from dperm_c_supno[k_st + i]. If any of the
 * four L/U arrays for that supernode is missing, the pointers are set to
 * NULL, the leading dimensions to 1 and m/n/k to 0.
 *
 * @param i  Batch slot index; also the offset from k_st into dperm_c_supno.
 */
__device__ void operator()(const int_t &i) const
{
    int_t k = dperm_c_supno[k_st + i];

    int_t ksupc = SuperSize(k);
    int_t *Ucolind_br = Ucolind_br_ptr[k];
    double *Unzval = Unzval_br_new_ptr[k];
    int_t *Lrowind_bc = Lrowind_bc_ptr[k];
    double *Lnzval = Lnzval_bc_ptr[k];

    if (Ucolind_br && Unzval && Lrowind_bc && Lnzval)
    {
        int upanel_rows = Ucolind_br[2];
        int sup_offset = ksupc - upanel_rows;

        int_t diag_block_offset = Lrowind_bc[BC_HEADER + 1];
        int_t L_nzrows = Lrowind_bc[1];
        int_t L_len = L_nzrows - diag_block_offset;

        // A starts diag_block_offset rows down and sup_offset columns across
        // inside the L value array (leading dimension L_nzrows).
        A_ptrs[i] = Lnzval + diag_block_offset + sup_offset * L_nzrows;
        B_ptrs[i] = Unzval;
        // The output operand must be set for an active slot; otherwise
        // C_ptrs[i] keeps whatever a previous batch left there.
        // NOTE(review): assumes dgpuGemmBuffs holds one buffer per batch
        // slot and is indexed by i -- confirm against its allocation.
        C_ptrs[i] = dgpuGemmBuffs[i];

        lda_array[i] = L_nzrows;
        ldb_array[i] = upanel_rows;
        ldc_array[i] = L_len;

        m_array[i] = L_len;
        n_array[i] = Ucolind_br[1];
        k_array[i] = upanel_rows;

        ist[i] = 1;
        jst[i] = 0;
        iend[i] = Lrowind_bc[0];
        jend[i] = Ucolind_br[0];
    }
    else
    {
        A_ptrs[i] = B_ptrs[i] = C_ptrs[i] = NULL;
        lda_array[i] = ldb_array[i] = ldc_array[i] = 1;
        m_array[i] = n_array[i] = k_array[i] = 0;
    }
}
244};
245
246template <class T, class offT>
248{
250 offT* offsets;
251
253 {
254 this->base_mem = base_mem;
255 this->offsets = offsets;
256 this->ptrs = ptrs;
257 }
258
/**
 * Write ptrs[index]: NULL when offsets[index] is negative, otherwise
 * base_mem advanced by offsets[index] elements.
 */
inline __host__ __device__ void operator()(const offT &index) const
{
    const offT shift = offsets[index];

    if (shift < 0)
    {
        ptrs[index] = NULL;
    }
    else
    {
        ptrs[index] = base_mem + shift;
    }
}
263};
264
/**
 * Apply UnaryOffsetPtrAssign to every index in [0, num_arrays) using
 * Thrust's CUDA execution policy, so that each ptrs[j] is derived from
 * base_mem and offsets[j].
 *
 * @param base_mem    Base address the offsets are relative to.
 * @param offsets     One offset per array.
 * @param ptrs        Output array of pointers, one per array.
 * @param num_arrays  Number of entries to process.
 */
template<class T, class offT>
inline void generateOffsetPointers(T *base_mem, offT *offsets, T **ptrs, size_t num_arrays)
{
    thrust::counting_iterator<offT> first(0);
    thrust::counting_iterator<offT> last(num_arrays);

    thrust::for_each(
        thrust::system::cuda::par, first, last,
        UnaryOffsetPtrAssign<T, offT>(base_mem, offsets, ptrs)
    );
}
275
/**
 * Unary functor that maps an index x to end[x] - st[x].
 *
 * The two arrays are captured by pointer at construction; the call operator
 * is device-only.
 */
template<typename T>
struct element_diff : public thrust::unary_function<T,T>
{
    T* st, *end;

    /** Store the two arrays whose element-wise difference is returned. */
    element_diff(T* st, T* end)
    {
        this->st = st;
        this->end = end;
    }

    /** Return end[x] - st[x]. */
    __device__ T operator()(const T &x) const
    {
        return end[x] - st[x];
    }
};
291
292#endif
#define BatchDim_t
Definition: batch_factorize.h:11
void generateOffsetPointers(T *base_mem, offT *offsets, T **ptrs, size_t num_arrays)
Definition: batch_factorize_marshall.h:266
Definition: batch_factorize_marshall.h:9
int_t k_st
Definition: batch_factorize_marshall.h:12
double ** Lnzval_bc_ptr
Definition: batch_factorize_marshall.h:11
int_t * xsup
Definition: batch_factorize_marshall.h:12
BatchDim_t * dim_batch
Definition: batch_factorize_marshall.h:10
double ** diag_ptrs
Definition: batch_factorize_marshall.h:11
__device__ void operator()(const int_t &i) const
Definition: batch_factorize_marshall.h:30
MarshallLUFunc_flat(int_t k_st, double **diag_ptrs, BatchDim_t *ld_batch, BatchDim_t *dim_batch, double **Lnzval_bc_ptr, int_t **Lrowind_bc_ptr, int_t *dperm_c_supno, int_t *xsup)
Definition: batch_factorize_marshall.h:14
BatchDim_t * ld_batch
Definition: batch_factorize_marshall.h:10
int_t * dperm_c_supno
Definition: batch_factorize_marshall.h:12
int_t ** Lrowind_bc_ptr
Definition: batch_factorize_marshall.h:12
Definition: batch_factorize_marshall.h:163
BatchDim_t * ist
Definition: batch_factorize_marshall.h:168
int_t ** Lrowind_bc_ptr
Definition: batch_factorize_marshall.h:167
MarshallSCUFunc_flat(int_t k_st, double **A_ptrs, BatchDim_t *lda_array, double **B_ptrs, BatchDim_t *ldb_array, double **C_ptrs, BatchDim_t *ldc_array, BatchDim_t *m_array, BatchDim_t *n_array, BatchDim_t *k_array, BatchDim_t *ist, BatchDim_t *iend, BatchDim_t *jst, BatchDim_t *jend, double **Unzval_br_new_ptr, int_t **Ucolind_br_ptr, double **Lnzval_bc_ptr, int_t **Lrowind_bc_ptr, int_t *dperm_c_supno, int_t *xsup, double **dgpuGemmBuffs)
Definition: batch_factorize_marshall.h:170
__device__ void operator()(const int_t &i) const
Definition: batch_factorize_marshall.h:201
BatchDim_t * ldc_array
Definition: batch_factorize_marshall.h:165
BatchDim_t * k_array
Definition: batch_factorize_marshall.h:165
int_t * xsup
Definition: batch_factorize_marshall.h:167
int_t k_st
Definition: batch_factorize_marshall.h:167
BatchDim_t * jend
Definition: batch_factorize_marshall.h:168
BatchDim_t * lda_array
Definition: batch_factorize_marshall.h:165
int_t ** Ucolind_br_ptr
Definition: batch_factorize_marshall.h:167
double ** C_ptrs
Definition: batch_factorize_marshall.h:164
BatchDim_t * n_array
Definition: batch_factorize_marshall.h:165
double ** Lnzval_bc_ptr
Definition: batch_factorize_marshall.h:166
double ** Unzval_br_new_ptr
Definition: batch_factorize_marshall.h:166
BatchDim_t * m_array
Definition: batch_factorize_marshall.h:165
int_t * dperm_c_supno
Definition: batch_factorize_marshall.h:167
BatchDim_t * ldb_array
Definition: batch_factorize_marshall.h:165
double ** B_ptrs
Definition: batch_factorize_marshall.h:164
BatchDim_t * jst
Definition: batch_factorize_marshall.h:168
BatchDim_t * iend
Definition: batch_factorize_marshall.h:168
double ** A_ptrs
Definition: batch_factorize_marshall.h:164
double ** dgpuGemmBuffs
Definition: batch_factorize_marshall.h:166
Definition: batch_factorize_marshall.h:109
BatchDim_t * panel_dim_batch
Definition: batch_factorize_marshall.h:110
int_t * xsup
Definition: batch_factorize_marshall.h:112
BatchDim_t * diag_dim_batch
Definition: batch_factorize_marshall.h:110
double ** Lnzval_bc_ptr
Definition: batch_factorize_marshall.h:111
BatchDim_t * panel_ld_batch
Definition: batch_factorize_marshall.h:110
int_t k_st
Definition: batch_factorize_marshall.h:112
MarshallTRSMLFunc_flat(int_t k_st, double **diag_ptrs, BatchDim_t *diag_ld_batch, BatchDim_t *diag_dim_batch, double **panel_ptrs, BatchDim_t *panel_ld_batch, BatchDim_t *panel_dim_batch, double **Lnzval_bc_ptr, int_t **Lrowind_bc_ptr, int_t *dperm_c_supno, int_t *xsup)
Definition: batch_factorize_marshall.h:114
__device__ void operator()(const int_t &i) const
Definition: batch_factorize_marshall.h:134
int_t * dperm_c_supno
Definition: batch_factorize_marshall.h:112
double ** panel_ptrs
Definition: batch_factorize_marshall.h:111
BatchDim_t * diag_ld_batch
Definition: batch_factorize_marshall.h:110
int_t ** Lrowind_bc_ptr
Definition: batch_factorize_marshall.h:112
double ** diag_ptrs
Definition: batch_factorize_marshall.h:111
Definition: batch_factorize_marshall.h:51
BatchDim_t * panel_ld_batch
Definition: batch_factorize_marshall.h:52
BatchDim_t * diag_ld_batch
Definition: batch_factorize_marshall.h:52
MarshallTRSMUFunc_flat(int_t k_st, double **diag_ptrs, BatchDim_t *diag_ld_batch, BatchDim_t *diag_dim_batch, double **panel_ptrs, BatchDim_t *panel_ld_batch, BatchDim_t *panel_dim_batch, double **Unzval_br_new_ptr, int_t **Ucolind_br_ptr, double **Lnzval_bc_ptr, int_t **Lrowind_bc_ptr, int_t *dperm_c_supno, int_t *xsup)
Definition: batch_factorize_marshall.h:56
int_t * xsup
Definition: batch_factorize_marshall.h:54
double ** Unzval_br_new_ptr
Definition: batch_factorize_marshall.h:53
__device__ void operator()(const int_t &i) const
Definition: batch_factorize_marshall.h:77
double ** panel_ptrs
Definition: batch_factorize_marshall.h:53
BatchDim_t * panel_dim_batch
Definition: batch_factorize_marshall.h:52
BatchDim_t * diag_dim_batch
Definition: batch_factorize_marshall.h:52
int_t ** Ucolind_br_ptr
Definition: batch_factorize_marshall.h:54
double ** diag_ptrs
Definition: batch_factorize_marshall.h:53
int_t ** Lrowind_bc_ptr
Definition: batch_factorize_marshall.h:54
double ** Lnzval_bc_ptr
Definition: batch_factorize_marshall.h:53
int_t * dperm_c_supno
Definition: batch_factorize_marshall.h:54
int_t k_st
Definition: batch_factorize_marshall.h:54
Definition: batch_factorize_marshall.h:248
UnaryOffsetPtrAssign(T *base_mem, offT *offsets, T **ptrs)
Definition: batch_factorize_marshall.h:252
T ** ptrs
Definition: batch_factorize_marshall.h:249
T * base_mem
Definition: batch_factorize_marshall.h:249
__host__ __device__ void operator()(const offT &index) const
Definition: batch_factorize_marshall.h:259
offT * offsets
Definition: batch_factorize_marshall.h:250
Definition: batch_factorize_marshall.h:278
__device__ T operator()(const T &x) const
Definition: batch_factorize_marshall.h:286
element_diff(T *st, T *end)
Definition: batch_factorize_marshall.h:280
T * st
Definition: batch_factorize_marshall.h:279
T * end
Definition: batch_factorize_marshall.h:279
Definitions which are precision-neutral.
#define SuperSize(bnum)
Definition: superlu_defs.h:271
int64_t int_t
Definition: superlu_defs.h:119
#define BC_HEADER
Definition: superlu_defs.h:198
int i
Definition: sutil_dist.c:287