SuperLU Distributed 9.0.0
gpu3d
anc25d-GPU_impl.hpp
Go to the documentation of this file.
1#pragma once
2#include <cstdio>
3#include "superlu_ddefs.h"
4#include "lupanels.hpp"
5#include "lupanels_GPU.cuh"
8
template <typename Ftype>
// NOTE(review): the declarator line of this function (its name, any enclosing-class
// qualifier, and any leading parameters) was lost from this extracted view; the
// parameter list below is only the tail of the original signature. The closing
// comment at the bottom names it dAncestorFactor_ASYNC.
    int_t alvl,                        // ancestor level index used with the anc25d communicator
    sForest_t *sforest,                // supernodal elimination forest to factor
    diagFactBufs_type<Ftype> **dFBufs, // size maxEtree level
    gEtreeInfo_t *gEtreeInfo,          // global etree info
    int tag_ub)                        // MPI tag upper bound (not used in the visible body)
{
    // Factors every supernode of `sforest`, one topological level at a time.
    // Per supernode k: (1) sum-reduce the partial L and U panels (GPU buffers)
    // over the anc25d communicator onto rank kRoot; (2) the rank owning the
    // grid factors the diagonal and solves both panels on the GPU, broadcasts
    // the panels along its process row/column, and runs the GPU Schur-complement
    // update; (3) the factored panels are broadcast back over the anc25d
    // communicator so all replicas agree.
    // Returns 1 if the forest is empty, 0 on completion.
    int_t nnodes = sforest->nNodes; // number of nodes in the tree
    if (nnodes < 1)
    {
        // Empty forest: nothing to factor at this ancestor level.
        return 1;
    }

#if (DEBUGlevel >= 1)
    CHECK_MALLOC(grid3d->iam, "Enter dAncestorFactor_ASYNC()");
#endif

    int_t *perm_c_supno = sforest->nodeList; // list of nodes in the order of factorization
    treeTopoInfo_t *treeTopoInfo = &sforest->topoInfo;
    // int_t *myIperm = treeTopoInfo->myIperm;
    int_t maxTopoLevel = treeTopoInfo->numLvl;        // number of topological levels
    int_t *eTreeTopLims = treeTopoInfo->eTreeTopLims; // eTreeTopLims[lvl..lvl+1) bounds each level

    /*main loop over all the levels*/
    int_t numLA = getNumLookAhead(options); // look-ahead depth (not used in the visible body)

    for (int_t topoLvl = 0; topoLvl < maxTopoLevel; ++topoLvl)
    {
        /* code */
        int_t k_st = eTreeTopLims[topoLvl];      // first node index of this level
        int_t k_end = eTreeTopLims[topoLvl + 1]; // one past the last node of this level
        for (int_t k0 = k_st; k0 < k_end; ++k0)
        {
            int_t k = perm_c_supno[k0];            // global supernode number
            int kRoot = anc25d.rootRank(k0, alvl); // reduction root within the anc25d communicator
            // reduce the l and u panels to the root with MPI_Comm = anc25d.getComm(alvl);
            if (mycol == kcol(k))
            {
                void* sendBuf = (void*) lPanelVec[g2lCol(k)].gpuPanel.val;
                // NOTE(review): MPI_IN_PLACE is only valid on the reduce root, so
                // rankHasGrid(k0, alvl) presumably holds exactly on kRoot -- confirm.
                if (anc25d.rankHasGrid(k0, alvl))
                    sendBuf = MPI_IN_PLACE;

                // Sum partial L-panel values (device buffers) onto kRoot.
                MPI_Reduce(sendBuf, lPanelVec[g2lCol(k)].gpuPanel.val,
                           lPanelVec[g2lCol(k)].nzvalSize(), get_mpi_type<Ftype>(), MPI_SUM, kRoot, anc25d.getComm(alvl));

            }

            if (myrow == krow(k))
            {
                void* sendBuf = (void*) uPanelVec[g2lRow(k)].gpuPanel.val;
                if (anc25d.rankHasGrid(k0, alvl))
                    sendBuf = MPI_IN_PLACE;
                // Sum partial U-panel values (device buffers) onto kRoot.
                MPI_Reduce(sendBuf, uPanelVec[g2lRow(k)].gpuPanel.val,
                           uPanelVec[g2lRow(k)].nzvalSize(), get_mpi_type<Ftype>(), MPI_SUM, kRoot, anc25d.getComm(alvl));
            }


            // Only the rank that owns the grid for (k0, alvl) factors this node.
            if (anc25d.rankHasGrid(k0, alvl))
            {

                int_t offset = k0 - k_st; // slot into the per-level buffer pool dFBufs
                // int_t ksupc = SuperSize(k);
                // Fused GPU diagonal factorization + L/U panel solve.
                dDFactPSolveGPU(k, offset, dFBufs);

                // Disabled reference path (CPU factorization + explicit diagonal
                // broadcasts/panel solves), kept for documentation:
#if 0
                /*======= Diagonal Factorization ======*/
                if (iam == procIJ(k, k))
                {
                    lPanelVec[g2lCol(k)].diagFactor(k, dFBufs[offset]->BlockUFactor, ksupc,
                                                    thresh, xsup, options, stat, info);
                    lPanelVec[g2lCol(k)].packDiagBlock(dFBufs[offset]->BlockLFactor, ksupc);
                }

                /*======= Diagonal Broadcast ======*/
                if (myrow == krow(k))
                    MPI_Bcast((void *)dFBufs[offset]->BlockLFactor, ksupc * ksupc,
                              get_mpi_type<Ftype>(), kcol(k), (grid->rscp).comm);
                if (mycol == kcol(k))
                    MPI_Bcast((void *)dFBufs[offset]->BlockUFactor, ksupc * ksupc,
                              get_mpi_type<Ftype>(), krow(k), (grid->cscp).comm);

                /*======= Panel Update ======*/
                if (myrow == krow(k))
                    uPanelVec[g2lRow(k)].panelSolve(ksupc, dFBufs[offset]->BlockLFactor, ksupc);

                if (mycol == kcol(k))
                    lPanelVec[g2lCol(k)].panelSolve(ksupc, dFBufs[offset]->BlockUFactor, ksupc);
#endif
                /*======= Panel Broadcast ======*/
                // upanel_t k_upanel(UidxRecvBufs[0], UvalRecvBufs[0]);
                // lpanel_t k_lpanel(LidxRecvBufs[0], LvalRecvBufs[0]);
                /*======= Panel Broadcast ======*/
                // View over the shared receive buffers (host + GPU copies);
                // owner rows/columns re-point the views at their own panels below.
                xupanel_t<Ftype> k_upanel(UidxRecvBufs[0], UvalRecvBufs[0],
                                          A_gpu.UidxRecvBufs[0], A_gpu.UvalRecvBufs[0]);
                xlpanel_t<Ftype> k_lpanel(LidxRecvBufs[0], LvalRecvBufs[0],
                                          A_gpu.LidxRecvBufs[0], A_gpu.LvalRecvBufs[0]);
                if (myrow == krow(k))
                {
                    k_upanel = uPanelVec[g2lRow(k)];
                }
                if (mycol == kcol(k))
                    k_lpanel = lPanelVec[g2lCol(k)];

                // Broadcast the U panel (index + values) down the process column.
                // NOTE(review): device pointers are passed to MPI_Bcast, which
                // presumably requires a CUDA-aware MPI -- confirm build assumption.
                if (UidxSendCounts[k] > 0)
                {
                    MPI_Bcast(k_upanel.gpuPanel.index, UidxSendCounts[k], mpi_int_t, krow(k), grid3d->cscp.comm);
                    MPI_Bcast(k_upanel.gpuPanel.val, UvalSendCounts[k], get_mpi_type<Ftype>(), krow(k), grid3d->cscp.comm);
                }

                // Broadcast the L panel (index + values) across the process row.
                if (LidxSendCounts[k] > 0)
                {
                    MPI_Bcast(k_lpanel.gpuPanel.index, LidxSendCounts[k], mpi_int_t, kcol(k), grid3d->rscp.comm);
                    MPI_Bcast(k_lpanel.gpuPanel.val, LvalSendCounts[k], get_mpi_type<Ftype>(), kcol(k), grid3d->rscp.comm);
                }

/*======= Schurcomplement Update ======*/
#warning single node only
                // dSchurComplementUpdate(k, lPanelVec[g2lCol(k)], uPanelVec[g2lRow(k)]);
                // dSchurComplementUpdate(k, lPanelVec[g2lCol(k)], k_upanel);
                // Both panels must be non-empty for the trailing update to apply.
                if (UidxSendCounts[k] > 0 && LidxSendCounts[k] > 0)
                {
                    k_upanel.checkCorrectness();
                    // dSchurComplementUpdate(k, k_lpanel, k_upanel);
                    int streamId = 0; // single GPU stream used here
                    dSchurComplementUpdateGPU(
                        streamId,
                        k, k_lpanel, k_upanel);
                }
            }
            // Broadcast the factored l and u panels from the root with MPI_Comm = anc25d.getComm(alvl);
            if (mycol == kcol(k))
                MPI_Bcast(lPanelVec[g2lCol(k)].gpuPanel.val,
                          lPanelVec[g2lCol(k)].nzvalSize(), get_mpi_type<Ftype>(), kRoot, anc25d.getComm(alvl));

            if (myrow == krow(k))
                MPI_Bcast(uPanelVec[g2lRow(k)].gpuPanel.val,
                          uPanelVec[g2lRow(k)].nzvalSize(), get_mpi_type<Ftype>(), kRoot, anc25d.getComm(alvl));
            // MPI_Barrier(grid3d->comm);

        } /*for k0= k_st:k_end */

    } /*for topoLvl = 0:maxTopoLevel*/

#if (DEBUGlevel >= 1)
    CHECK_MALLOC(grid3d->iam, "Exit dAncestorFactor_ASYNC()");
#endif

    return 0;
} /* dAncestorFactor_ASYNC */
Definition: xlupanels.hpp:22
Definition: xlupanels.hpp:176
typename std::conditional< std::is_same< Ftype, double >::value, ddiagFactBufs_t, typename std::conditional< std::is_same< Ftype, float >::value, sdiagFactBufs_t, typename std::conditional< std::is_same< Ftype, doublecomplex >::value, zdiagFactBufs_t, void >::type >::type >::type diagFactBufs_type
Definition: luAuxStructTemplated.hpp:147
Definition: superlu_defs.h:978
Definition: superlu_defs.h:989
treeTopoInfo_t topoInfo
Definition: superlu_defs.h:999
int_t * nodeList
Definition: superlu_defs.h:992
int_t nNodes
Definition: superlu_defs.h:991
Definition: superlu_defs.h:970
int_t numLvl
Definition: superlu_defs.h:971
int_t * eTreeTopLims
Definition: superlu_defs.h:972
Definition: xlupanels.hpp:335
Distributed SuperLU data types and function prototypes.
#define mpi_int_t
Definition: superlu_defs.h:120
int_t getNumLookAhead(superlu_dist_options_t *)
Definition: treeFactorization.c:186
int64_t int_t
Definition: superlu_defs.h:119
#define CHECK_MALLOC(pnum, where)
Definition: util_dist.h:56