16#ifndef dealii_differentiation_ad_ad_number_traits_h
17#define dealii_differentiation_ad_ad_number_traits_h
26#include <boost/type_traits.hpp>
67 template <
typename ADNumberType,
typename T =
void>
142 template <
typename ADNumberType,
typename T =
void>
175 template <
typename ADNumberType,
typename T =
void>
191 template <
typename ADNumberTrait,
typename T =
void>
209 template <
typename T>
217 template <
typename Number>
229 template <
typename NumberType>
239 template <
typename NumberType,
typename =
void>
249 template <
typename NumberType,
typename =
void>
259 template <
typename NumberType,
typename =
void>
269 template <
typename NumberType,
typename =
void>
288 template <
typename ADNumberTrait,
typename>
289 struct HasRequiredADInfo : std::false_type
304 template <
typename ADNumberTrait>
305 struct HasRequiredADInfo<
309 (void)std::declval<typename ADNumberTrait::real_type>(),
310 (void)std::declval<typename ADNumberTrait::derivative_type>(),
313 std::is_floating_point<typename ADNumberTrait::real_type>::value,
315 std::true_type>::type
323 template <
typename ScalarType>
324 struct ADNumberInfoFromEnum<
327 std::enable_if_t<std::is_floating_point<ScalarType>::value>>
341 template <
typename ScalarType>
353 template <
typename ADNumberType>
366 template <
typename ADNumberType>
373 "Floating point numbers cannot be marked as dependent variables."));
381 template <
typename ADNumberType>
388 template <
typename ScalarType>
398 "Marking for complex numbers has not yet been implemented."));
405 template <
typename ScalarType>
412 "Marking for complex numbers has not yet been implemented."));
419 template <
typename NumberType,
typename>
420 struct is_taped_ad_number : std::false_type
424 template <
typename NumberType,
typename>
425 struct is_tapeless_ad_number : std::false_type
429 template <
typename NumberType,
typename>
430 struct is_real_valued_ad_number : std::false_type
434 template <
typename NumberType,
typename>
435 struct is_complex_valued_ad_number : std::false_type
444 template <
typename NumberType>
446 : internal::HasRequiredADInfo<
447 ADNumberTraits<typename std::decay<NumberType>::type>>
455 template <
typename NumberType>
456 struct is_taped_ad_number<
459 ADNumberTraits<typename std::decay<NumberType>::type>::is_taped>>
468 template <
typename NumberType>
469 struct is_tapeless_ad_number<
472 ADNumberTraits<typename std::decay<NumberType>::type>::is_tapeless>>
482 template <
typename NumberType>
483 struct is_real_valued_ad_number<
486 ADNumberTraits<typename std::decay<NumberType>::type>::is_real_valued>>
496 template <
typename NumberType>
497 struct is_complex_valued_ad_number<
500 typename std::decay<NumberType>::type>::is_complex_valued>>
511 template <
typename Number>
512 struct RemoveComplexWrapper
523 template <
typename Number>
524 struct RemoveComplexWrapper<
std::complex<Number>>
526 using type =
typename RemoveComplexWrapper<Number>::type;
535 template <
typename NumberType>
543 static const NumberType &
544 value(
const NumberType &
x)
576 template <
typename ADNumberType>
577 struct ExtractData<
std::complex<ADNumberType>>
580 "Expected an auto-differentiable number.");
586 static std::complex<typename ADNumberTraits<ADNumberType>::scalar_type>
587 value(
const std::complex<ADNumberType> &
x)
590 typename ADNumberTraits<ADNumberType>::scalar_type>(
591 ExtractData<ADNumberType>::value(
x.real()),
592 ExtractData<ADNumberType>::value(
x.imag()));
602 return ExtractData<ADNumberType>::n_directional_derivatives(
x.real());
610 typename ADNumberTraits<ADNumberType>::derivative_type>
612 const unsigned int direction)
615 typename ADNumberTraits<ADNumberType>::derivative_type>(
616 ExtractData<ADNumberType>::directional_derivative(
x.real(),
618 ExtractData<ADNumberType>::directional_derivative(
x.imag(),
624 template <
typename T>
630 template <
typename F>
632 value(
const F &f, std::enable_if_t<!is_ad_number<F>::value> * =
nullptr)
637 return ::internal::NumberType<T>::value(f);
646 template <
typename F>
649 std::enable_if_t<is_ad_number<F>::value &&
650 std::is_floating_point<T>::value> * =
nullptr)
655 return NumberType<T>::value(ExtractData<F>::value(f));
664 template <
typename F>
667 std::enable_if_t<is_ad_number<F>::value && is_ad_number<T>::value>
674 template <
typename T>
675 struct NumberType<
std::complex<T>>
680 template <
typename F>
682 value(
const F &f, std::enable_if_t<!is_ad_number<F>::value> * =
nullptr)
687 return ::internal::NumberType<std::complex<T>>::value(f);
695 template <
typename F>
696 static std::complex<T>
698 std::enable_if_t<is_ad_number<F>::value &&
699 std::is_floating_point<T>::value> * =
nullptr)
704 return std::complex<T>(
705 NumberType<T>::value(ExtractData<F>::value(f)));
708 template <
typename F>
709 static std::complex<T>
710 value(
const std::complex<F> &f)
714 return std::complex<T>(NumberType<T>::value(f.real()),
715 NumberType<T>::value(f.imag()));
743 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
748 std::is_floating_point<ScalarType>::value ||
749 (boost::is_complex<ScalarType>::value &&
750 std::is_floating_point<
751 typename internal::RemoveComplexWrapper<ScalarType>::type>::value)>>
791 static const bool is_complex_valued;
806 static constexpr bool is_taped = internal::ADNumberInfoFromEnum<
807 typename internal::RemoveComplexWrapper<ScalarType>::type,
816 !(NumberTraits<ScalarType, ADNumberTypeCode>::is_taped);
824 (!boost::is_complex<ScalarType>::value);
831 static constexpr bool is_complex_valued =
832 !(NumberTraits<ScalarType, ADNumberTypeCode>::is_real_valued);
840 internal::ADNumberInfoFromEnum<
841 typename internal::RemoveComplexWrapper<ScalarType>::type,
857 using real_type =
typename internal::ADNumberInfoFromEnum<
858 typename internal::RemoveComplexWrapper<ScalarType>::type,
871 using ad_type =
typename std::
872 conditional<is_real_valued, real_type, complex_type>::type;
879 typename internal::ADNumberInfoFromEnum<
880 typename internal::RemoveComplexWrapper<ScalarType>::type,
882 std::complex<
typename internal::ADNumberInfoFromEnum<
883 typename internal::RemoveComplexWrapper<ScalarType>::type,
899 internal::ExtractData<ad_type>::value(
x));
907 const ad_type &
x,
const unsigned int direction)
909 return internal::ExtractData<ad_type>::directional_derivative(
920 return internal::ExtractData<ad_type>::n_directional_derivatives(
x);
925 std::is_same<ad_type, real_type>::value :
926 std::is_same<ad_type, complex_type>::value),
927 "Incorrect template type selected for ad_type");
929 static_assert((is_complex_valued ==
true ?
930 boost::is_complex<scalar_type>::value :
932 "Expected a complex float_type");
934 static_assert((is_complex_valued ==
true ?
935 boost::is_complex<ad_type>::value :
937 "Expected a complex ad_type");
942 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
943 const bool NumberTraits<
947 std::is_floating_point<ScalarType>::value ||
948 (boost::is_complex<ScalarType>::value &&
949 std::is_floating_point<
typename internal::RemoveComplexWrapper<
951 internal::ADNumberInfoFromEnum<
952 typename internal::RemoveComplexWrapper<ScalarType>::type,
956 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
957 const bool NumberTraits<
961 std::is_floating_point<ScalarType>::value ||
962 (boost::is_complex<ScalarType>::value &&
963 std::is_floating_point<
typename internal::RemoveComplexWrapper<
965 !(NumberTraits<ScalarType, ADNumberTypeCode>::is_taped);
968 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
969 const bool NumberTraits<
973 std::is_floating_point<ScalarType>::value ||
974 (boost::is_complex<ScalarType>::value &&
975 std::is_floating_point<
typename internal::RemoveComplexWrapper<
977 (!boost::is_complex<ScalarType>::value);
980 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
981 const bool NumberTraits<
985 std::is_floating_point<ScalarType>::value ||
986 (boost::is_complex<ScalarType>::value &&
987 std::is_floating_point<
typename internal::RemoveComplexWrapper<
988 ScalarType>::type>::value)>>::is_complex_valued =
989 !(NumberTraits<ScalarType, ADNumberTypeCode>::is_real_valued);
992 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
993 const unsigned int NumberTraits<
997 std::is_floating_point<ScalarType>::value ||
998 (boost::is_complex<ScalarType>::value &&
999 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1001 internal::ADNumberInfoFromEnum<
1002 typename internal::RemoveComplexWrapper<ScalarType>::type,
1007 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1008 constexpr bool NumberTraits<
1012 std::is_floating_point<ScalarType>::value ||
1013 (boost::is_complex<ScalarType>::value &&
1014 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1018 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1019 constexpr bool NumberTraits<
1023 std::is_floating_point<ScalarType>::value ||
1024 (boost::is_complex<ScalarType>::value &&
1025 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1029 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1030 constexpr bool NumberTraits<
1034 std::is_floating_point<ScalarType>::value ||
1035 (boost::is_complex<ScalarType>::value &&
1036 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1040 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1041 constexpr bool NumberTraits<
1045 std::is_floating_point<ScalarType>::value ||
1046 (boost::is_complex<ScalarType>::value &&
1047 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1048 ScalarType>::type>::value)>>::is_complex_valued;
1051 template <
typename ScalarType, enum NumberTypes ADNumberTypeCode>
1052 constexpr unsigned int NumberTraits<
1056 std::is_floating_point<ScalarType>::value ||
1057 (boost::is_complex<ScalarType>::value &&
1058 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1074 template <
typename ScalarType>
1075 struct NumberTraits<
1079 std::is_floating_point<ScalarType>::value ||
1080 (boost::is_complex<ScalarType>::value &&
1081 std::is_floating_point<
1082 typename internal::RemoveComplexWrapper<ScalarType>::type>::value)>>
1122 static const bool is_complex_valued;
1137 static constexpr bool is_taped =
false;
1152 (!boost::is_complex<ScalarType>::value);
1182 typename ::numbers::NumberTraits<scalar_type>::real_type;
1221 "Floating point/arithmetic numbers have no directional derivatives."));
1236 "Floating point/arithmetic numbers have no directional derivatives."));
1243 template <
typename ScalarType>
1244 const bool NumberTraits<
1248 std::is_floating_point<ScalarType>::value ||
1249 (boost::is_complex<ScalarType>::value &&
1250 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1254 template <
typename ScalarType>
1255 const bool NumberTraits<
1259 std::is_floating_point<ScalarType>::value ||
1260 (boost::is_complex<ScalarType>::value &&
1261 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1265 template <
typename ScalarType>
1266 const bool NumberTraits<
1270 std::is_floating_point<ScalarType>::value ||
1271 (boost::is_complex<ScalarType>::value &&
1272 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1274 (!boost::is_complex<ScalarType>::value);
1277 template <
typename ScalarType>
1278 const bool NumberTraits<
1282 std::is_floating_point<ScalarType>::value ||
1283 (boost::is_complex<ScalarType>::value &&
1284 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1285 ScalarType>::type>::value)>>::is_complex_valued =
1286 !(NumberTraits<ScalarType, NumberTypes::none>::is_real_valued);
1289 template <
typename ScalarType>
1290 const unsigned int NumberTraits<
1294 std::is_floating_point<ScalarType>::value ||
1295 (boost::is_complex<ScalarType>::value &&
1296 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1301 template <
typename ScalarType>
1302 constexpr bool NumberTraits<
1306 std::is_floating_point<ScalarType>::value ||
1307 (boost::is_complex<ScalarType>::value &&
1308 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1312 template <
typename ScalarType>
1313 constexpr bool NumberTraits<
1317 std::is_floating_point<ScalarType>::value ||
1318 (boost::is_complex<ScalarType>::value &&
1319 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1323 template <
typename ScalarType>
1324 constexpr bool NumberTraits<
1328 std::is_floating_point<ScalarType>::value ||
1329 (boost::is_complex<ScalarType>::value &&
1330 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1334 template <
typename ScalarType>
1335 constexpr bool NumberTraits<
1339 std::is_floating_point<ScalarType>::value ||
1340 (boost::is_complex<ScalarType>::value &&
1341 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1342 ScalarType>::type>::value)>>::is_complex_valued;
1345 template <
typename ScalarType>
1346 constexpr unsigned int NumberTraits<
1350 std::is_floating_point<ScalarType>::value ||
1351 (boost::is_complex<ScalarType>::value &&
1352 std::is_floating_point<
typename internal::RemoveComplexWrapper<
1376 template <
typename ScalarType>
1377 struct ADNumberTraits<
1380 : NumberTraits<ScalarType, NumberTypes::none>
1386 template <
typename ComplexScalarType>
1387 struct ADNumberTraits<
1390 boost::is_complex<ComplexScalarType>::value &&
1391 std::is_floating_point<typename ComplexScalarType::value_type>::value>>
1392 : NumberTraits<ComplexScalarType, NumberTypes::none>
#define DEAL_II_NAMESPACE_OPEN
#define DEAL_II_NAMESPACE_CLOSE
#define Assert(cond, exc)
static ::ExceptionBase & ExcMessage(std::string arg1)
#define AssertThrow(cond, exc)
static constexpr DEAL_II_HOST_DEVICE_ALWAYS_INLINE const T & value(const T &t)