/// @brief Fortran LAPACK routine SSYEVD (single precision): eigenvalues and,
/// optionally, eigenvectors of a real symmetric matrix.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference. See the LAPACK reference for the meaning of each
/// argument and for the workspace-query protocol (`lwork`/`liwork`).
///
/// NOTE(review): presumably declared inside an `extern "C"` block that is
/// not visible in this listing - confirm.
void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
             float* work, int* lwork, int* iwork, int* liwork, int* info);
/// @brief Fortran LAPACK routine DSYEVD (double precision): eigenvalues and,
/// optionally, eigenvectors of a real symmetric matrix.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference. See the LAPACK reference for the meaning of each
/// argument and for the workspace-query protocol (`lwork`/`liwork`).
///
/// NOTE(review): presumably declared inside an `extern "C"` block that is
/// not visible in this listing - confirm.
void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
             double* work, int* lwork, int* iwork, int* liwork, int* info);
/// @brief Fortran LAPACK routine SGESV (single precision): solve the linear
/// system A X = B for a general square matrix A.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference.
///
/// NOTE(review): the trailing `LDB` and `INFO` parameters were restored from
/// the LAPACK ?GESV interface - confirm against the upstream header.
void sgesv_(int* N, int* NRHS, float* A, int* LDA, int* IPIV, float* B,
            int* LDB, int* INFO);
/// @brief Fortran LAPACK routine DGESV (double precision): solve the linear
/// system A X = B for a general square matrix A.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference.
///
/// NOTE(review): the trailing `LDB` and `INFO` parameters were restored from
/// the LAPACK ?GESV interface - confirm against the upstream header.
void dgesv_(int* N, int* NRHS, double* A, int* LDA, int* IPIV, double* B,
            int* LDB, int* INFO);
/// @brief Fortran BLAS routine SGEMM (single precision): general
/// matrix-matrix product, C = alpha * op(A) * op(B) + beta * C.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference.
///
/// NOTE(review): the trailing `ldc` parameter was restored from the BLAS
/// ?GEMM interface - confirm against the upstream header.
void sgemm_(char* transa, char* transb, int* m, int* n, int* k, float* alpha,
            float* a, int* lda, float* b, int* ldb, float* beta, float* c,
            int* ldc);
/// @brief Fortran BLAS routine DGEMM (double precision): general
/// matrix-matrix product, C = alpha * op(A) * op(B) + beta * C.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference.
///
/// NOTE(review): the trailing `ldc` parameter was restored from the BLAS
/// ?GEMM interface - confirm against the upstream header.
void dgemm_(char* transa, char* transb, int* m, int* n, int* k, double* alpha,
            double* a, int* lda, double* b, int* ldb, double* beta, double* c,
            int* ldc);
/// @brief Fortran LAPACK routine SGETRF (single precision): LU factorisation
/// of a general m-by-n matrix with partial pivoting.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference.
///
/// NOTE(review): the trailing `info` parameter was restored from the LAPACK
/// ?GETRF interface - confirm against the upstream header.
int sgetrf_(const int* m, const int* n, float* a, const int* lda, int* lpiv,
            int* info);
/// @brief Fortran LAPACK routine DGETRF (double precision): LU factorisation
/// of a general m-by-n matrix with partial pivoting.
///
/// All arguments are pointers because the Fortran calling convention passes
/// everything by reference.
///
/// NOTE(review): the trailing `info` parameter was restored from the LAPACK
/// ?GETRF interface - confirm against the upstream header.
int dgetrf_(const int* m, const int* n, double* a, const int* lda, int* lpiv,
            int* info);
59template <std::
floating_po
int T>
60void dot_blas(std::span<const T>
A, std::array<std::size_t, 2>
Ashape,
61 std::span<const T>
B, std::array<std::size_t, 2>
Bshape,
64 static_assert(std::is_same_v<T, float>
or std::is_same_v<T, double>);
77 if constexpr (std::is_same_v<T, float>)
82 else if constexpr (std::is_same_v<T, double>)
/// @brief Compute the outer product of vectors u and v.
/// @param u The first vector; must provide `size()`, `operator[]` and a
/// `value_type`.
/// @param v The second vector; must provide `size()` and `operator[]`.
/// @return The flattened product together with its shape
/// `{u.size(), v.size()}`.
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  std::vector<typename U::value_type> result(u.size() * v.size());
  for (std::size_t i = 0; i < u.size(); ++i)
  {
    for (std::size_t j = 0; j < v.size(); ++j)
    {
      // NOTE(review): this assignment was not visible in the listing;
      // row-major placement is inferred from the loop nest and the returned
      // shape - confirm against the upstream header.
      result[i * v.size() + j] = u[i] * v[j];
    }
  }

  return {std::move(result), {u.size(), v.size()}};
}
/// @brief Compute the cross product u x v.
/// @param u The first vector; elements 0, 1 and 2 are read.
/// @param v The second vector; elements 0, 1 and 2 are read.
/// @return The three components of u x v.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  return {u[1] * v[2] - u[2] * v[1],
          u[2] * v[0] - u[0] * v[2],
          u[0] * v[1] - u[1] * v[0]};
}
126template <std::
floating_po
int T>
127std::pair<std::vector<T>, std::vector<T>>
eigh(std::span<const T>
A,
131 std::vector<T> M(
A.begin(),
A.end());
134 std::vector<T>
w(
n, 0);
143 std::vector<T>
work(1);
144 std::vector<int>
iwork(1);
147 if constexpr (std::is_same_v<T, float>)
152 else if constexpr (std::is_same_v<T, double>)
159 throw std::runtime_error(
"Could not find workspace size for syevd.");
166 if constexpr (std::is_same_v<T, float>)
171 else if constexpr (std::is_same_v<T, double>)
177 throw std::runtime_error(
"Eigenvalue computation did not converge.");
179 return {std::move(
w), std::move(M)};
186template <std::
floating_po
int T>
187std::vector<T>
solve(md::mdspan<
const T, md::dextents<std::size_t, 2>>
A,
188 md::mdspan<
const T, md::dextents<std::size_t, 2>>
B)
191 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left>
_A(
194 for (std::size_t
i = 0;
i <
A.extent(0); ++
i)
195 for (std::size_t
j = 0;
j <
A.extent(1); ++
j)
197 for (std::size_t
i = 0;
i <
B.extent(0); ++
i)
198 for (std::size_t
j = 0;
j <
B.extent(1); ++
j)
201 int N =
_A.extent(0);
203 int lda =
_A.extent(0);
204 int ldb =
_B.extent(0);
207 std::vector<int>
piv(
N);
209 if constexpr (std::is_same_v<T, float>)
211 else if constexpr (std::is_same_v<T, double>)
214 throw std::runtime_error(
"Call to dgesv failed: " + std::to_string(
info));
217 std::vector<T>
rb(
_B.extent(0) *
_B.extent(1));
218 md::mdspan<T, md::dextents<std::size_t, 2>>
r(
rb.data(),
_B.extents());
219 for (std::size_t
i = 0;
i <
_B.extent(0); ++
i)
220 for (std::size_t
j = 0;
j <
_B.extent(1); ++
j)
229template <std::
floating_po
int T>
233 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left>
_A(
235 for (std::size_t
i = 0;
i <
A.extent(0); ++
i)
236 for (std::size_t
j = 0;
j <
A.extent(1); ++
j)
239 std::vector<T>
B(
A.extent(1), 1);
240 int N =
_A.extent(0);
242 int lda =
_A.extent(0);
246 std::vector<int>
piv(
N);
248 if constexpr (std::is_same_v<T, float>)
250 else if constexpr (std::is_same_v<T, double>)
255 throw std::runtime_error(
"dgesv failed due to invalid value: "
256 + std::to_string(
info));
269template <std::
floating_po
int T>
270std::vector<std::size_t>
273 std::size_t dim =
A.second[0];
280 if constexpr (std::is_same_v<T, float>)
282 else if constexpr (std::is_same_v<T, double>)
287 throw std::runtime_error(
"LU decomposition failed: "
288 + std::to_string(
info));
291 std::vector<std::size_t>
perm(dim);
292 for (std::size_t
i = 0;
i < dim; ++
i)
305template <
typename U,
typename V,
typename W>
307 typename std::decay_t<U>::value_type
alpha = 1,
308 typename std::decay_t<U>::value_type
beta = 0)
310 using T =
typename std::decay_t<U>::value_type;
315 if (
A.extent(0) *
B.extent(1) *
A.extent(1) < 256)
317 for (std::size_t
i = 0;
i <
A.extent(0); ++
i)
319 for (std::size_t
j = 0;
j <
B.extent(1); ++
j)
324 for (std::size_t
k = 0;
k <
A.extent(1); ++
k)
332 static_assert(std::is_same_v<typename std::decay_t<U>::layout_type,
334 static_assert(std::is_same_v<typename std::decay_t<V>::layout_type,
336 static_assert(std::is_same_v<typename std::decay_t<W>::layout_type,
338 static_assert(std::is_same_v<typename std::decay_t<V>::value_type,
T>);
339 static_assert(std::is_same_v<typename std::decay_t<W>::value_type,
T>);
341 std::span(
A.data_handle(),
A.size()), {A.extent(0), A.extent(1)},
342 std::span(
B.data_handle(),
B.size()), {B.extent(0), B.extent(1)},
343 std::span(
C.data_handle(),
C.size()),
alpha,
beta);
350template <std::
floating_po
int T>
351std::vector<T>
eye(std::size_t
n)
353 std::vector<T>
I(
n *
n, 0);
354 md::mdspan<T, md::dextents<std::size_t, 2>>
Iview(
I.data(),
n,
n);
355 for (std::size_t
i = 0;
i <
n; ++
i)
364template <std::
floating_po
int T>
366 std::size_t
start = 0)
368 for (std::size_t
i =
start;
i < wcoeffs.extent(0); ++
i)
371 for (std::size_t
k = 0;
k < wcoeffs.extent(1); ++
k)
372 norm += wcoeffs(
i,
k) * wcoeffs(
i,
k);
375 if (
norm < 2 * std::numeric_limits<T>::epsilon())
377 throw std::runtime_error(
"Cannot orthogonalise the rows of a matrix "
378 "with incomplete row rank");
381 for (std::size_t
k = 0;
k < wcoeffs.extent(1); ++
k)
384 for (std::size_t
j =
i + 1;
j < wcoeffs.extent(0); ++
j)
387 for (std::size_t
k = 0;
k < wcoeffs.extent(1); ++
k)
388 a += wcoeffs(
i,
k) * wcoeffs(
j,
k);
389 for (std::size_t
k = 0;
k < wcoeffs.extent(1); ++
k)
390 wcoeffs(
j,
k) -=
a * wcoeffs(
i,
k);
A finite element.
Definition finite-element.h:137
Mathematical functions.
Definition math.h:51
bool is_singular(md::mdspan< const T, md::dextents< std::size_t, 2 > > A)
Check if A is a singular matrix.
Definition math.h:230
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Compute the cross product u x v.
Definition math.h:111
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 > > &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition math.h:271
void orthogonalise(md::mdspan< T, md::dextents< std::size_t, 2 > > wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition math.h:365
std::vector< T > solve(md::mdspan< const T, md::dextents< std::size_t, 2 > > A, md::mdspan< const T, md::dextents< std::size_t, 2 > > B)
Solve A X = B.
Definition math.h:187
void dot(const U &A, const V &B, W &&C, typename std::decay_t< U >::value_type alpha=1, typename std::decay_t< U >::value_type beta=0)
Compute C = alpha A * B + beta C.
Definition math.h:306
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Compute the eigenvalues and eigenvectors of a square Hermitian matrix A.
Definition math.h:127
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition math.h:351
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition math.h:97