Skip to content

Commit 0699292

Browse files
authored
Merge pull request #2083 from su2code/fix_mz_adjoint_wall_time
Fix wall time for discrete adjoint MZ driver + OpenMP the GMRES orthogonalization when used by MZ driver
2 parents 02da4b6 + 19e200e commit 0699292

4 files changed

Lines changed: 61 additions & 31 deletions

File tree

Common/include/linear_algebra/CSysSolve.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class CSysSolve {
168168
* \brief Modified Gram-Schmidt orthogonalization
169169
* \author Based on Kesheng John Wu's mgsro subroutine in Saad's SPARSKIT
170170
*
171+
* \param[in] shared_hsbg - if the Hessenberg matrix is shared by multiple threads
171172
* \param[in] i - index indicating which vector in w is being orthogonalized
172173
* \param[in,out] Hsbg - the upper Hessenberg begin updated
173174
* \param[in,out] w - the (i+1)th vector of w is orthogonalized against the
@@ -181,7 +182,7 @@ class CSysSolve {
181182
* vector is kept in nrm0 and updated after operating with each vector
182183
*
183184
*/
184-
void ModGramSchmidt(int i, su2matrix<ScalarType>& Hsbg, std::vector<VectorType>& w) const;
185+
void ModGramSchmidt(bool shared_hsbg, int i, su2matrix<ScalarType>& Hsbg, std::vector<VectorType>& w) const;
185186

186187
/*!
187188
* \brief writes header information for a CSysSolve residual history

Common/include/linear_algebra/vector_expressions.hpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -146,28 +146,30 @@ MAKE_UNARY_FUN(sign, sign_, sign_impl)
146146

147147
/*--- Macro to create expressions and overloads for binary functions. ---*/
148148

149-
#define MAKE_BINARY_FUN(FUN, EXPR, IMPL) \
150-
/*!--- Expression class. ---*/ \
151-
template <class U, class V, class Scalar> \
152-
class EXPR : public CVecExpr<EXPR<U, V, Scalar>, Scalar> { \
153-
store_t<const U> u; \
154-
store_t<const V> v; \
155-
\
156-
public: \
157-
static constexpr bool StoreAsRef = false; \
158-
FORCEINLINE EXPR(const U& u_, const V& v_) : u(u_), v(v_) {} \
159-
FORCEINLINE auto operator[](size_t i) const RETURNS(IMPL(u[i], v[i])) \
160-
}; \
161-
/*!--- Vector with vector function overload. ---*/ \
162-
template <class U, class V, class S> \
163-
FORCEINLINE auto FUN(const CVecExpr<U, S>& u, const CVecExpr<V, S>& v) \
164-
RETURNS(EXPR<U, V, S>(u.derived(), v.derived())) /*!--- Vector with scalar function overload. ---*/ \
165-
template <class U, class S> \
166-
FORCEINLINE auto FUN(const CVecExpr<U, S>& u, decay_t<S> v) \
167-
RETURNS(EXPR<U, Bcast<S>, S>(u.derived(), Bcast<S>(v))) /*!--- Scalar with vector function overload. ---*/ \
168-
template <class S, class V> \
169-
FORCEINLINE auto FUN(decay_t<S> u, const CVecExpr<V, S>& v) \
170-
RETURNS(EXPR<Bcast<S>, V, S>(Bcast<S>(u), v.derived()))
149+
// clang-format off
150+
#define MAKE_BINARY_FUN(FUN, EXPR, IMPL) \
151+
/*!--- Expression class. ---*/ \
152+
template <class U, class V, class Scalar> \
153+
class EXPR : public CVecExpr<EXPR<U, V, Scalar>, Scalar> { \
154+
store_t<const U> u; \
155+
store_t<const V> v; \
156+
\
157+
public: \
158+
static constexpr bool StoreAsRef = false; \
159+
FORCEINLINE EXPR(const U& u_, const V& v_) : u(u_), v(v_) {} \
160+
FORCEINLINE auto operator[](size_t i) const RETURNS(IMPL(u[i], v[i])) \
161+
}; \
162+
/*!--- Vector with vector function overload. ---*/ \
163+
template <class U, class V, class S> \
164+
FORCEINLINE auto FUN(const CVecExpr<U, S>& u, const CVecExpr<V, S>& v) \
165+
RETURNS(EXPR<U, V, S>(u.derived(), v.derived())) \
166+
/*!--- Vector with scalar function overload. ---*/ \
167+
template <class U, class S> \
168+
FORCEINLINE auto FUN(const CVecExpr<U, S>& u, decay_t<S> v) RETURNS(EXPR<U, Bcast<S>, S>(u.derived(), Bcast<S>(v))) \
169+
/*!--- Scalar with vector function overload. ---*/ \
170+
template <class S, class V> \
171+
FORCEINLINE auto FUN(decay_t<S> u, const CVecExpr<V, S>& v) RETURNS(EXPR<Bcast<S>, V, S>(Bcast<S>(u), v.derived()))
172+
// clang-format on
171173

172174
/*--- std::max/min have issues (because they return by reference).
173175
* fmin and fmax return by value and thus are fine, but they would force

Common/src/linear_algebra/CSysSolve.cpp

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,19 @@ void CSysSolve<ScalarType>::SolveReduced(int n, const su2matrix<ScalarType>& Hsb
109109
}
110110

111111
template <class ScalarType>
112-
void CSysSolve<ScalarType>::ModGramSchmidt(int i, su2matrix<ScalarType>& Hsbg,
112+
void CSysSolve<ScalarType>::ModGramSchmidt(bool shared_hsbg, int i, su2matrix<ScalarType>& Hsbg,
113113
vector<CSysVector<ScalarType> >& w) const {
114+
const auto thread = omp_get_thread_num();
115+
116+
/*--- If Hsbg is shared by multiple threads calling this function, only one
117+
* thread can write into it. If Hsbg is private, all threads need to write. ---*/
118+
119+
auto SetHsbg = [&](int row, int col, const ScalarType& value) {
120+
if (!shared_hsbg || thread == 0) {
121+
Hsbg(row, col) = value;
122+
}
123+
};
124+
114125
/*--- Parameter for reorthonormalization ---*/
115126

116127
const ScalarType reorth = 0.98;
@@ -132,28 +143,29 @@ void CSysSolve<ScalarType>::ModGramSchmidt(int i, su2matrix<ScalarType>& Hsbg,
132143

133144
for (int k = 0; k < i + 1; k++) {
134145
ScalarType prod = w[i + 1].dot(w[k]);
135-
Hsbg(k, i) = prod;
146+
ScalarType h_ki = prod;
136147
w[i + 1] -= prod * w[k];
137148

138149
/*--- Check if reorthogonalization is necessary ---*/
139150

140151
if (prod * prod > thr) {
141152
prod = w[i + 1].dot(w[k]);
142-
Hsbg(k, i) += prod;
153+
h_ki += prod;
143154
w[i + 1] -= prod * w[k];
144155
}
156+
SetHsbg(k, i, h_ki);
145157

146158
/*--- Update the norm and check its size ---*/
147159

148-
nrm -= pow(Hsbg(k, i), 2);
160+
nrm -= pow(h_ki, 2);
149161
nrm = max<ScalarType>(nrm, 0.0);
150162
thr = nrm * reorth;
151163
}
152164

153165
/*--- Test the resulting vector ---*/
154166

155167
nrm = w[i + 1].norm();
156-
Hsbg(i + 1, i) = nrm;
168+
SetHsbg(i + 1, i, nrm);
157169

158170
/*--- Scale the resulting vector ---*/
159171

@@ -343,6 +355,9 @@ unsigned long CSysSolve<ScalarType>::FGMRES_LinSolver(const CSysVector<ScalarTyp
343355
const CConfig* config) const {
344356
const bool masterRank = (SU2_MPI::GetRank() == MASTER_NODE);
345357
const bool flexible = !precond.IsIdentity();
358+
/*--- If we call the solver outside of a parallel region, but the number of threads allows,
359+
* we still want to parallelize some of the expensive operations. ---*/
360+
const bool nestedParallel = !omp_in_parallel() && omp_get_max_threads() > 1;
346361

347362
/*--- Check the subspace size ---*/
348363

@@ -452,7 +467,14 @@ unsigned long CSysSolve<ScalarType>::FGMRES_LinSolver(const CSysVector<ScalarTyp
452467

453468
/*--- Modified Gram-Schmidt orthogonalization ---*/
454469

455-
ModGramSchmidt(i, H, W);
470+
if (nestedParallel) {
471+
/*--- "omp parallel if" does not work well here ---*/
472+
SU2_OMP_PARALLEL
473+
ModGramSchmidt(true, i, H, W);
474+
END_SU2_OMP_PARALLEL
475+
} else {
476+
ModGramSchmidt(false, i, H, W);
477+
}
456478

457479
/*--- Apply old Givens rotations to new column of the Hessenberg matrix then generate the
458480
new Givens rotation matrix and apply it to the last two elements of H[:][i] and g ---*/
@@ -480,8 +502,12 @@ unsigned long CSysSolve<ScalarType>::FGMRES_LinSolver(const CSysVector<ScalarTyp
480502

481503
const auto& basis = flexible ? Z : W;
482504

483-
for (unsigned long k = 0; k < i; k++) {
484-
x += y[k] * basis[k];
505+
if (nestedParallel) {
506+
SU2_OMP_PARALLEL
507+
for (unsigned long k = 0; k < i; k++) x += y[k] * basis[k];
508+
END_SU2_OMP_PARALLEL
509+
} else {
510+
for (unsigned long k = 0; k < i; k++) x += y[k] * basis[k];
485511
}
486512

487513
/*--- Recalculate final (neg.) residual (this should be optional) ---*/

SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ void CDiscAdjMultizoneDriver::Run() {
308308

309309
const unsigned long nOuterIter = driver_config->GetnOuter_Iter();
310310
const bool time_domain = driver_config->GetTime_Domain();
311+
driver_config->Set_StartTime(SU2_MPI::Wtime());
311312

312313
/*--- If the gradient of the objective function is 0 so are the adjoint variables.
313314
* Unless in unsteady problems where there are other contributions to the RHS. ---*/

0 commit comments

Comments
 (0)