Visual Servoing Platform version 3.6.0
Loading...
Searching...
No Matches
vpGEMM.h
1/*
2 * ViSP, open source Visual Servoing Platform software.
3 * Copyright (C) 2005 - 2023 by Inria. All rights reserved.
4 *
5 * This software is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 2 of the License, or
8 * (at your option) any later version.
9 * See the file LICENSE.txt at the root directory of this source
10 * distribution for additional information about the GNU GPL.
11 *
12 * For using ViSP with software that can not be combined with the GNU
13 * GPL, please contact Inria about acquiring a ViSP Professional
14 * Edition License.
15 *
16 * See https://visp.inria.fr for more information.
17 *
18 * This software was developed at:
19 * Inria Rennes - Bretagne Atlantique
20 * Campus Universitaire de Beaulieu
21 * 35042 Rennes Cedex
22 * France
23 *
24 * If you have questions regarding the use of this file, please contact
25 * Inria at visp@inria.fr
26 *
27 * This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE
28 * WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
29 *
30 * Description:
31 * Matrix generalized multiplication.
32 */
33
34#ifndef _vpGEMM_h_
35#define _vpGEMM_h_
36
37#include <visp3/core/vpArray2D.h>
38#include <visp3/core/vpException.h>
39
40const vpArray2D<double> null(0, 0);
41
/*!
  Bit flags selecting which operands of vpGEMM() are used transposed.
  Flags are combined with a bitwise OR, giving an \e ops value in [0, 7].
*/
enum vpGEMMmethod {
  VP_GEMM_A_T = 1 << 0, //!< Use the transpose of A.
  VP_GEMM_B_T = 1 << 1, //!< Use the transpose of B.
  VP_GEMM_C_T = 1 << 2  //!< Use the transpose of C.
};
57
58template <unsigned int>
59inline void GEMMsize(const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, unsigned int & /*Arows*/,
60 unsigned int & /*Acols*/, unsigned int & /*Brows*/, unsigned int & /*Bcols*/)
61{
62}
63
64template <>
65void inline GEMMsize<0>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
66 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
67{
68 Arows = A.getRows();
69 Acols = A.getCols();
70 Brows = B.getRows();
71 Bcols = B.getCols();
72}
73
74template <>
75inline void GEMMsize<1>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
76 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
77{
78 Arows = A.getCols();
79 Acols = A.getRows();
80 Brows = B.getRows();
81 Bcols = B.getCols();
82}
83template <>
84inline void GEMMsize<2>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
85 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
86{
87 Arows = A.getRows();
88 Acols = A.getCols();
89 Brows = B.getCols();
90 Bcols = B.getRows();
91}
92template <>
93inline void GEMMsize<3>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
94 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
95{
96 Arows = A.getCols();
97 Acols = A.getRows();
98 Brows = B.getCols();
99 Bcols = B.getRows();
100}
101
102template <>
103inline void GEMMsize<4>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
104 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
105{
106 Arows = A.getRows();
107 Acols = A.getCols();
108 Brows = B.getRows();
109 Bcols = B.getCols();
110}
111
112template <>
113inline void GEMMsize<5>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
114 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
115{
116 Arows = A.getCols();
117 Acols = A.getRows();
118 Brows = B.getRows();
119 Bcols = B.getCols();
120}
121
122template <>
123inline void GEMMsize<6>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
124 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
125{
126 Arows = A.getRows();
127 Acols = A.getCols();
128 Brows = B.getCols();
129 Bcols = B.getRows();
130}
131
132template <>
133inline void GEMMsize<7>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
134 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
135{
136 Arows = A.getCols();
137 Acols = A.getRows();
138 Brows = B.getCols();
139 Bcols = B.getRows();
140}
141
142template <unsigned int>
143inline void GEMM1(const unsigned int & /*Arows*/, const unsigned int & /*Brows*/, const unsigned int & /*Bcols*/,
144 const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, const double & /*alpha*/,
145 vpArray2D<double> & /*D*/)
146{
147}
148
149template <>
150inline void GEMM1<0>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
151 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
152{
153 for (unsigned int r = 0; r < Arows; r++)
154 for (unsigned int c = 0; c < Bcols; c++) {
155 double sum = 0;
156 for (unsigned int n = 0; n < Brows; n++)
157 sum += A[r][n] * B[n][c] * alpha;
158 D[r][c] = sum;
159 }
160}
161
162template <>
163inline void GEMM1<1>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
164 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
165{
166 for (unsigned int r = 0; r < Arows; r++)
167 for (unsigned int c = 0; c < Bcols; c++) {
168 double sum = 0;
169 for (unsigned int n = 0; n < Brows; n++)
170 sum += A[n][r] * B[n][c] * alpha;
171 D[r][c] = sum;
172 }
173}
174
175template <>
176inline void GEMM1<2>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
177 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
178{
179 for (unsigned int r = 0; r < Arows; r++)
180 for (unsigned int c = 0; c < Bcols; c++) {
181 double sum = 0;
182 for (unsigned int n = 0; n < Brows; n++)
183 sum += A[r][n] * B[c][n] * alpha;
184 D[r][c] = sum;
185 }
186}
187
188template <>
189inline void GEMM1<3>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
190 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
191{
192 for (unsigned int r = 0; r < Arows; r++)
193 for (unsigned int c = 0; c < Bcols; c++) {
194 double sum = 0;
195 for (unsigned int n = 0; n < Brows; n++)
196 sum += A[n][r] * B[c][n] * alpha;
197 D[r][c] = sum;
198 }
199}
200
201template <unsigned int>
202inline void GEMM2(const unsigned int & /*Arows*/, const unsigned int & /*Brows*/, const unsigned int & /*Bcols*/,
203 const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, const double & /*alpha*/,
204 const vpArray2D<double> & /*C*/, const double & /*beta*/, vpArray2D<double> & /*D*/)
205{
206}
207
208template <>
209inline void GEMM2<0>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
210 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
211 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
212{
213 for (unsigned int r = 0; r < Arows; r++)
214 for (unsigned int c = 0; c < Bcols; c++) {
215 double sum = 0;
216 for (unsigned int n = 0; n < Brows; n++)
217 sum += A[r][n] * B[n][c] * alpha;
218 D[r][c] = sum + C[r][c] * beta;
219 }
220}
221
222template <>
223inline void GEMM2<1>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
224 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
225 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
226{
227 for (unsigned int r = 0; r < Arows; r++)
228 for (unsigned int c = 0; c < Bcols; c++) {
229 double sum = 0;
230 for (unsigned int n = 0; n < Brows; n++)
231 sum += A[n][r] * B[n][c] * alpha;
232 D[r][c] = sum + C[r][c] * beta;
233 }
234}
235
236template <>
237inline void GEMM2<2>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
238 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
239 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
240{
241 for (unsigned int r = 0; r < Arows; r++)
242 for (unsigned int c = 0; c < Bcols; c++) {
243 double sum = 0;
244 for (unsigned int n = 0; n < Brows; n++)
245 sum += A[r][n] * B[c][n] * alpha;
246 D[r][c] = sum + C[r][c] * beta;
247 }
248}
249
250template <>
251inline void GEMM2<3>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
252 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
253 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
254{
255 for (unsigned int r = 0; r < Arows; r++)
256 for (unsigned int c = 0; c < Bcols; c++) {
257 double sum = 0;
258 for (unsigned int n = 0; n < Brows; n++)
259 sum += A[n][r] * B[c][n] * alpha;
260 D[r][c] = sum + C[r][c] * beta;
261 }
262}
263
264template <>
265inline void GEMM2<4>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
266 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
267 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
268{
269 for (unsigned int r = 0; r < Arows; r++)
270 for (unsigned int c = 0; c < Bcols; c++) {
271 double sum = 0;
272 for (unsigned int n = 0; n < Brows; n++)
273 sum += A[r][n] * B[n][c] * alpha;
274 D[r][c] = sum + C[c][r] * beta;
275 }
276}
277
278template <>
279inline void GEMM2<5>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
280 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
281 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
282{
283 for (unsigned int r = 0; r < Arows; r++)
284 for (unsigned int c = 0; c < Bcols; c++) {
285 double sum = 0;
286 for (unsigned int n = 0; n < Brows; n++)
287 sum += A[n][r] * B[n][c] * alpha;
288 D[r][c] = sum + C[c][r] * beta;
289 }
290}
291
292template <>
293inline void GEMM2<6>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
294 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
295 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
296{
297 for (unsigned int r = 0; r < Arows; r++)
298 for (unsigned int c = 0; c < Bcols; c++) {
299 double sum = 0;
300 for (unsigned int n = 0; n < Brows; n++)
301 sum += A[r][n] * B[c][n] * alpha;
302 D[r][c] = sum + C[c][r] * beta;
303 }
304}
305
306template <>
307inline void GEMM2<7>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
308 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
309 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
310{
311 for (unsigned int r = 0; r < Arows; r++)
312 for (unsigned int c = 0; c < Bcols; c++) {
313 double sum = 0;
314 for (unsigned int n = 0; n < Brows; n++)
315 sum += A[n][r] * B[c][n] * alpha;
316 D[r][c] = sum + C[c][r] * beta;
317 }
318}
319
320template <unsigned int T>
321inline void vpTGEMM(const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
322 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
323{
324 unsigned int Arows;
325 unsigned int Acols;
326 unsigned int Brows;
327 unsigned int Bcols;
328
329 GEMMsize<T>(A, B, Arows, Acols, Brows, Bcols);
330
331 try {
332 if ((Arows != D.getRows()) || (Bcols != D.getCols()))
333 D.resize(Arows, Bcols);
334 } catch (...) {
335 throw;
336 }
337
338 if (Acols != Brows) {
339 throw(vpException(vpException::dimensionError, "In vpGEMM, cannot multiply (%dx%d) matrix by (%dx%d) matrix", Arows,
340 Acols, Brows, Bcols));
341 }
342
343 if (C.getRows() != 0 && C.getCols() != 0) {
344 if ((Arows != C.getRows()) || (Bcols != C.getCols())) {
345 throw(vpException(vpException::dimensionError, "In vpGEMM, cannot add resulting (%dx%d) matrix to (%dx%d) matrix",
346 Arows, Bcols, C.getRows(), C.getCols()));
347 }
348
349 GEMM2<T>(Arows, Brows, Bcols, A, B, alpha, C, beta, D);
350 } else {
351 GEMM1<T>(Arows, Brows, Bcols, A, B, alpha, D);
352 }
353}
354
388inline void vpGEMM(const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
389 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D, const unsigned int &ops = 0)
390{
391 switch (ops) {
392 case 0:
393 vpTGEMM<0>(A, B, alpha, C, beta, D);
394 break;
395 case 1:
396 vpTGEMM<1>(A, B, alpha, C, beta, D);
397 break;
398 case 2:
399 vpTGEMM<2>(A, B, alpha, C, beta, D);
400 break;
401 case 3:
402 vpTGEMM<3>(A, B, alpha, C, beta, D);
403 break;
404 case 4:
405 vpTGEMM<4>(A, B, alpha, C, beta, D);
406 break;
407 case 5:
408 vpTGEMM<5>(A, B, alpha, C, beta, D);
409 break;
410 case 6:
411 vpTGEMM<6>(A, B, alpha, C, beta, D);
412 break;
413 case 7:
414 vpTGEMM<7>(A, B, alpha, C, beta, D);
415 break;
416 default:
417 throw(vpException(vpException::functionNotImplementedError, "Operation on vpGEMM not implemented"));
418 break;
419 }
420}
421
422#endif
Implementation of a generic 2D array used as base class for matrices and vectors.
Definition vpArray2D.h:131
unsigned int getCols() const
Definition vpArray2D.h:280
void vpGEMM(const vpArray2D< double > &A, const vpArray2D< double > &B, const double &alpha, const vpArray2D< double > &C, const double &beta, vpArray2D< double > &D, const unsigned int &ops=0)
Definition vpGEMM.h:388
unsigned int getRows() const
Definition vpArray2D.h:290
error that can be emitted by ViSP classes.
Definition vpException.h:59
@ functionNotImplementedError
Function not implemented.
Definition vpException.h:78
@ dimensionError
Bad dimension.
Definition vpException.h:83