// Sacado ViewFad atomic_add: atomically accumulates a Fad expression `xx`
// into the ViewFad pointed to by `dst`, one scalar component at a time
// (each derivative entry, then the value), using Kokkos::atomic_add on the
// underlying scalars.
// NOTE(review): this chunk is an extraction with elided lines (the embedded
// original line numbers jump, e.g. 49->53, 54->60); the size-check/abort
// calls and the derivative loop header are partially missing and are NOT
// reconstructed here. Only the error-message typo is fixed
// ("incompatiable" -> "incompatible").
30 #ifndef SACADO_FAD_EXP_ATOMIC_HPP 31 #define SACADO_FAD_EXP_ATOMIC_HPP 34 #if defined(HAVE_SACADO_KOKKOSCORE) 37 #include "Kokkos_Atomic.hpp" 38 #include "impl/Kokkos_Error.hpp" 46 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
48 void atomic_add(ViewFadPtr<ValT,sl,ss,U> dst,
const Expr<T>& xx) {
// Bring scalar overloads of Kokkos::atomic_add into scope for the
// component-wise calls below.
49 using Kokkos::atomic_add;
// Compare the derivative-array sizes of source expression and destination;
// presumably `x` is the concrete expression derived from `xx` (its
// construction was elided by extraction) -- TODO confirm against upstream.
53 const int xsz =
x.size();
54 const int sz = dst->size();
// Destination cannot be resized inside an atomic update.
60 "Sacado error: Fad resize within atomic_add() not supported!");
// Mismatched non-zero sizes cannot be combined component-wise.
62 if (xsz != sz && sz > 0 && xsz > 0)
64 "Sacado error: Fad assignment of incompatible sizes!");
// Accumulate each derivative component atomically (loop header elided).
67 if (sz > 0 && xsz > 0) {
69 atomic_add(&(dst->fastAccessDx(
i)),
x.fastAccessDx(
i));
// Finally accumulate the value component atomically.
72 atomic_add(&(dst->val()),
x.val());
// Lock-based "apply-op-then-return-new-value" fallback for Fad types too
// large for hardware atomics. A per-address lock (on dest_val) is taken,
// `op.apply(*dest, val)` is stored/returned, and the lock is released.
// The backend is chosen by preprocessor: host, CUDA (team / warp-cooperative
// variants), or HIP.
// NOTE(review): extraction has elided lines (embedded original line numbers
// jump); fragments below are kept byte-identical to the source chunk.
78 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
81 atomic_oper_fetch_impl(
const Oper& op, DestPtrT dest, ValT* dest_val,
// Host path: spin until the address lock is acquired, fence, apply, fence,
// unlock.
87 #if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) 88 while (!Kokkos::Impl::lock_address_host_space((
void*)dest_val))
90 Kokkos::memory_fence();
91 return_type return_val = op.apply(*dest,
val);
93 Kokkos::memory_fence();
94 Kokkos::Impl::unlock_address_host_space((
void*)dest_val);
// CUDA path. With hierarchical ViewFad parallelism the whole thread team
// (blockDim.x) cooperates on one Fad, so only thread 0 takes the lock and
// the acquired flag is broadcast with shfl; otherwise use_team is false.
96 #elif defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_CUDA) 101 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD) 102 const bool use_team = (blockDim.x > 1);
104 const bool use_team =
false;
109 if (threadIdx.x == 0)
110 go = !Kokkos::Impl::lock_address_cuda_space((
void*)dest_val);
111 go = Kokkos::shfl(go, 0, blockDim.x);
113 Kokkos::memory_fence();
114 return_type return_val = op.apply(*dest,
val);
116 Kokkos::memory_fence();
117 if (threadIdx.x == 0)
118 Kokkos::Impl::unlock_address_cuda_space((
void*)dest_val);
// Non-team CUDA path: classic per-thread lock loop. Each active lane
// retries until it wins the lock once; the ballot of "done" lanes
// terminates the loop when it matches the initially-active mask.
// KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK selects the *_MASK (sync) variants
// of activemask/ballot needed on newer architectures.
122 return_type return_val;
125 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK 126 unsigned int mask = KOKKOS_IMPL_CUDA_ACTIVEMASK;
127 unsigned int active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, 1);
129 unsigned int active = KOKKOS_IMPL_CUDA_BALLOT(1);
131 unsigned int done_active = 0;
132 while (active != done_active) {
134 if (Kokkos::Impl::lock_address_cuda_space((
void*)dest_val)) {
135 Kokkos::memory_fence();
136 return_val = op.apply(*dest,
val);
138 Kokkos::memory_fence();
139 Kokkos::Impl::unlock_address_cuda_space((
void*)dest_val);
143 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK 144 done_active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, done);
146 done_active = KOKKOS_IMPL_CUDA_BALLOT(done);
// HIP path: large types are unsupported (abort); otherwise the same
// per-lane __ballot retry loop as CUDA (lock calls elided by extraction).
151 #elif defined(__HIP_DEVICE_COMPILE__) 153 Kokkos::abort(
"atomic_oper_fetch not implemented for large types.");
154 return_type return_val;
156 unsigned int active = __ballot(1);
157 unsigned int done_active = 0;
158 while (active != done_active) {
162 return_val = op.apply(*dest,
val);
168 done_active = __ballot(done);
// Lock-based "return-old-value-then-apply-op" fallback (fetch-then-oper
// counterpart of atomic_oper_fetch_impl): reads *dest into return_val,
// stores op.apply(return_val, val), and returns the OLD value. Same
// backend selection as above: host, CUDA (team / warp-cooperative), HIP.
// NOTE(review): extraction has elided lines; fragments kept byte-identical
// except the HIP abort message, which named the wrong function
// ("atomic_oper_fetch" in atomic_fetch_oper_impl) -- fixed to
// "atomic_fetch_oper".
174 template <
typename Oper,
typename DestPtrT,
typename ValT,
typename T>
177 atomic_fetch_oper_impl(
const Oper& op, DestPtrT dest, ValT* dest_val,
// Host path: spin on the address lock, snapshot old value, apply, unlock.
183 #ifdef KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST 184 while (!Kokkos::Impl::lock_address_host_space((
void*)dest_val))
186 Kokkos::memory_fence();
187 return_type return_val = *dest;
188 *dest = op.apply(return_val,
val);
189 Kokkos::memory_fence();
190 Kokkos::Impl::unlock_address_host_space((
void*)dest_val);
// CUDA team path: thread 0 acquires the lock, result broadcast via shfl.
192 #elif defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_CUDA) 197 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD) 198 const bool use_team = (blockDim.x > 1);
200 const bool use_team =
false;
205 if (threadIdx.x == 0)
206 go = !Kokkos::Impl::lock_address_cuda_space((
void*)dest_val);
207 go = Kokkos::shfl(go, 0, blockDim.x);
209 Kokkos::memory_fence();
210 return_type return_val = *dest;
211 *dest = op.apply(return_val,
val);
212 Kokkos::memory_fence();
213 if (threadIdx.x == 0)
214 Kokkos::Impl::unlock_address_cuda_space((
void*)dest_val);
// Non-team CUDA path: per-lane lock/retry loop driven by warp ballots
// (masked variants when KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK is defined).
218 return_type return_val;
221 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK 222 unsigned int mask = KOKKOS_IMPL_CUDA_ACTIVEMASK;
223 unsigned int active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, 1);
225 unsigned int active = KOKKOS_IMPL_CUDA_BALLOT(1);
227 unsigned int done_active = 0;
228 while (active != done_active) {
230 if (Kokkos::Impl::lock_address_cuda_space((
void*)dest_val)) {
231 Kokkos::memory_fence();
// The read of *dest into return_val (original line 232) was elided by
// extraction.
233 *dest = op.apply(return_val,
val);
234 Kokkos::memory_fence();
235 Kokkos::Impl::unlock_address_cuda_space((
void*)dest_val);
239 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK 240 done_active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, done);
242 done_active = KOKKOS_IMPL_CUDA_BALLOT(done);
// HIP path: large types unsupported; otherwise the __ballot retry loop.
247 #elif defined(__HIP_DEVICE_COMPILE__) 249 Kokkos::abort(
"atomic_fetch_oper not implemented for large types.");
250 return_type return_val;
252 unsigned int active = __ballot(1);
253 unsigned int done_active = 0;
254 while (active != done_active) {
259 *dest = op.apply(return_val,
val);
264 done_active = __ballot(done);
// Dispatch wrappers mapping the two Fad destination kinds onto
// atomic_oper_fetch_impl, passing the address of the value component as
// the lock address.
// NOTE(review): extraction elided return types, braces and the ViewFadPtr
// overload's value parameter; fragments kept byte-identical.
272 template <
typename Oper,
typename S>
// GeneralFad overload: dest is a plain pointer, lock on &(dest->val()).
274 atomic_oper_fetch(
const Oper& op, GeneralFad<S>* dest,
275 const GeneralFad<S>&
val)
277 return Impl::atomic_oper_fetch_impl(op, dest, &(dest->val()),
val);
279 template <
typename Oper,
typename ValT,
unsigned sl,
unsigned ss,
280 typename U,
typename T>
// ViewFadPtr overload: dest is a pointer-like view object, so the lock
// address is taken via dest.val() (note: dot, not arrow).
282 atomic_oper_fetch(
const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
285 return Impl::atomic_oper_fetch_impl(op, dest, &dest.val(),
val);
// Dispatch wrappers mapping the two Fad destination kinds onto
// atomic_fetch_oper_impl (fetch-then-oper: returns the OLD value per the
// impl above), locking on the address of the value component.
// NOTE(review): extraction elided return types, braces and the ViewFadPtr
// overload's value parameter; fragments kept byte-identical.
288 template <
typename Oper,
typename S>
// GeneralFad overload.
290 atomic_fetch_oper(
const Oper& op, GeneralFad<S>* dest,
291 const GeneralFad<S>&
val)
293 return Impl::atomic_fetch_oper_impl(op, dest, &(dest->val()),
val);
295 template <
typename Oper,
typename ValT,
unsigned sl,
unsigned ss,
296 typename U,
typename T>
// ViewFadPtr overload (pointer-like view object; member access via dot).
298 atomic_fetch_oper(
const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
301 return Impl::atomic_fetch_oper_impl(op, dest, &dest.val(),
val);
// Binary-operation functors passed as `op` to the atomic impls above.
// Each exposes a static apply(val1, val2) with a trailing decltype return
// so Fad expression types deduce correctly.
// NOTE(review): the enclosing struct declarations were elided by
// extraction; the Max/Min/Add/Sub/Mul/Div identification below is inferred
// from the call sites (Impl::MaxOper() etc.) and body order -- confirm
// against upstream.
// Presumably MaxOper::apply: unqualified max() found via ADL on Fad types.
306 template <
class Scalar1,
class Scalar2>
307 KOKKOS_FORCEINLINE_FUNCTION
308 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
309 -> decltype(
max(val1,val2))
311 return max(val1,val2);
// Presumably MinOper::apply.
315 template <
class Scalar1,
class Scalar2>
316 KOKKOS_FORCEINLINE_FUNCTION
317 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
318 -> decltype(
min(val1,val2))
320 return min(val1,val2);
// Presumably AddOper::apply (return statement elided by extraction).
324 template <
class Scalar1,
class Scalar2>
325 KOKKOS_FORCEINLINE_FUNCTION
326 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
327 -> decltype(val1+val2)
// Presumably SubOper::apply.
333 template <
class Scalar1,
class Scalar2>
334 KOKKOS_FORCEINLINE_FUNCTION
335 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
336 -> decltype(val1-val2)
// Presumably MulOper::apply.
342 template <
class Scalar1,
class Scalar2>
343 KOKKOS_FORCEINLINE_FUNCTION
344 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
345 -> decltype(val1*val2)
// Presumably DivOper::apply.
351 template <
class Scalar1,
class Scalar2>
352 KOKKOS_FORCEINLINE_FUNCTION
353 static auto apply(
const Scalar1& val1,
const Scalar2& val2)
354 -> decltype(val1/val2)
// Public "oper-fetch" entry points (return the NEW value): max, min, add,
// sub. Each has a GeneralFad-pointer overload and a ViewFadPtr/Expr
// overload, all delegating to Impl::atomic_oper_fetch with the matching
// functor.
// NOTE(review): return types and braces elided by extraction; fragments
// kept byte-identical.
// atomic_max_fetch: *dest = max(*dest, val), returns new value.
364 template <
typename S>
366 atomic_max_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
367 return Impl::atomic_oper_fetch(Impl::MaxOper(), dest,
val);
369 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
371 atomic_max_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
372 return Impl::atomic_oper_fetch(Impl::MaxOper(), dest,
val);
// atomic_min_fetch: *dest = min(*dest, val), returns new value.
374 template <
typename S>
376 atomic_min_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
377 return Impl::atomic_oper_fetch(Impl::MinOper(), dest,
val);
379 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
381 atomic_min_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
382 return Impl::atomic_oper_fetch(Impl::MinOper(), dest,
val);
// atomic_add_fetch: *dest += val, returns new value.
384 template <
typename S>
386 atomic_add_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
387 return Impl::atomic_oper_fetch(Impl::AddOper(), dest,
val);
389 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
391 atomic_add_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
392 return Impl::atomic_oper_fetch(Impl::AddOper(), dest,
val);
// atomic_sub_fetch: *dest -= val, returns new value.
394 template <
typename S>
396 atomic_sub_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
397 return Impl::atomic_oper_fetch(Impl::SubOper(), dest,
val);
399 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
401 atomic_sub_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
402 return Impl::atomic_oper_fetch(Impl::SubOper(), dest,
val);
// atomic_mul_fetch: *dest *= val atomically, returns the NEW value.
// GeneralFad-pointer and ViewFadPtr/Expr overloads, delegating to
// Impl::atomic_oper_fetch with MulOper.
// NOTE(review): return types and braces elided by extraction; fragments
// kept byte-identical except one fix: the GeneralFad overload was the only
// wrapper in the file calling unqualified atomic_oper_fetch (relying on
// ADL via Impl::MulOper); qualified with Impl:: for consistency with every
// sibling wrapper.
404 template <
typename S>
406 atomic_mul_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
407 return Impl::atomic_oper_fetch(Impl::MulOper(), dest,
val);
409 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
411 atomic_mul_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
412 return Impl::atomic_oper_fetch(Impl::MulOper(), dest,
val);
// atomic_div_fetch: *dest /= val atomically, returns the NEW value.
// GeneralFad-pointer and ViewFadPtr/Expr overloads, delegating to
// Impl::atomic_oper_fetch with DivOper.
// NOTE(review): return types and braces elided by extraction; fragments
// kept byte-identical.
414 template <
typename S>
416 atomic_div_fetch(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
417 return Impl::atomic_oper_fetch(Impl::DivOper(), dest,
val);
419 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
421 atomic_div_fetch(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
422 return Impl::atomic_oper_fetch(Impl::DivOper(), dest,
val);
// Public "fetch-oper" entry points (return the OLD value, per
// atomic_fetch_oper_impl): max, min, add, sub, mul, div. Each has a
// GeneralFad-pointer overload and a ViewFadPtr/Expr overload, all
// delegating to Impl::atomic_fetch_oper with the matching functor.
// NOTE(review): return types and braces elided by extraction; fragments
// kept byte-identical.
// atomic_fetch_max: returns old value, stores max(old, val).
425 template <
typename S>
427 atomic_fetch_max(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
428 return Impl::atomic_fetch_oper(Impl::MaxOper(), dest,
val);
430 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
432 atomic_fetch_max(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
433 return Impl::atomic_fetch_oper(Impl::MaxOper(), dest,
val);
// atomic_fetch_min: returns old value, stores min(old, val).
435 template <
typename S>
437 atomic_fetch_min(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
438 return Impl::atomic_fetch_oper(Impl::MinOper(), dest,
val);
440 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
442 atomic_fetch_min(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
443 return Impl::atomic_fetch_oper(Impl::MinOper(), dest,
val);
// atomic_fetch_add: returns old value, stores old + val.
445 template <
typename S>
447 atomic_fetch_add(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
448 return Impl::atomic_fetch_oper(Impl::AddOper(), dest,
val);
450 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
452 atomic_fetch_add(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
453 return Impl::atomic_fetch_oper(Impl::AddOper(), dest,
val);
// atomic_fetch_sub: returns old value, stores old - val.
455 template <
typename S>
457 atomic_fetch_sub(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
458 return Impl::atomic_fetch_oper(Impl::SubOper(), dest,
val);
460 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
462 atomic_fetch_sub(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
463 return Impl::atomic_fetch_oper(Impl::SubOper(), dest,
val);
// atomic_fetch_mul: returns old value, stores old * val.
465 template <
typename S>
467 atomic_fetch_mul(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
468 return Impl::atomic_fetch_oper(Impl::MulOper(), dest,
val);
470 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
472 atomic_fetch_mul(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
473 return Impl::atomic_fetch_oper(Impl::MulOper(), dest,
val);
// atomic_fetch_div: returns old value, stores old / val.
475 template <
typename S>
477 atomic_fetch_div(GeneralFad<S>* dest,
const GeneralFad<S>&
val) {
478 return Impl::atomic_fetch_oper(Impl::DivOper(), dest,
val);
480 template <
typename ValT,
unsigned sl,
unsigned ss,
typename U,
typename T>
482 atomic_fetch_div(ViewFadPtr<ValT,sl,ss,U> dest,
const Expr<T>&
val) {
483 return Impl::atomic_fetch_oper(Impl::DivOper(), dest,
val);
491 #endif // HAVE_SACADO_KOKKOSCORE 492 #endif // SACADO_FAD_EXP_VIEWFAD_HPP #define SACADO_FAD_THREAD_SINGLE
SimpleFad< ValueT > min(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define SACADO_FAD_DERIV_LOOP(I, SZ)
Get the base Fad type from a view/expression.
T derived_type
Typename of derived object, returned by derived()
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define SACADO_INLINE_FUNCTION