Open3D (C++ API)  0.15.1
ShapeChecking.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
27#pragma once
28#include <iostream>
29#include <string>
30#include <tuple>
31#include <vector>
32
33namespace open3d {
34namespace ml {
35namespace op_util {
36
/// Class for representing a possibly unknown dimension value.
///
/// A DimValue is either constant (it stores a known int64_t) or unknown.
class DimValue {
public:
    /// Creates an unknown (non-constant) value.
    DimValue() : value_(0), constant_(false) {}

    /// Creates a constant value \p v. The constructor is not explicit, so an
    /// int64_t converts to a DimValue implicitly.
    DimValue(int64_t v) : value_(v), constant_(true) {}

    /// Multiplies this value with \p b in place.
    ///
    /// The product is only computed if both operands are constant; otherwise
    /// this value becomes unknown.
    /// \return Reference to this object.
    DimValue& operator*=(const DimValue& b) {
        if (constant_ && b.constant_)
            value_ *= b.value_;
        else
            constant_ = false;
        return *this;
    }

    /// Returns the decimal value for constants and "?" for unknown values.
    std::string ToString() const {
        if (constant_)
            return std::to_string(value_);
        else
            return "?";
    }

    /// Returns a mutable reference to the stored value.
    /// \throws std::runtime_error if the value is not constant.
    int64_t& value() {
        if (!constant_) throw std::runtime_error("DimValue is not constant");
        return value_;
    }

    /// Returns a mutable reference to the flag telling whether the value is
    /// known.
    bool& constant() { return constant_; }

private:
    int64_t value_;
    bool constant_;
};
65
66inline DimValue UnknownValue() { return DimValue(); }
67
/// Class for dimensions for which the value should be inferred.
///
/// A Dim created without a value starts as non-constant and refers to itself
/// as its origin. Copies keep the origin pointer of the object they were
/// copied from, so value() and constant() of a copy read and write the state
/// of that origin object. A Dim created with a value is constant and has no
/// origin.
class Dim {
public:
    /// Creates an unnamed, non-constant dimension.
    explicit Dim() : value_(0), constant_(false), origin_(this) {}

    /// Creates a named, non-constant dimension.
    explicit Dim(const std::string& name)
        : value_(0), constant_(false), origin_(this), name_(name) {}

    /// Creates a constant dimension with the given value and optional name.
    Dim(int64_t value, const std::string& name = "")
        : value_(value), constant_(true), origin_(nullptr), name_(name) {}

    /// Copies all members, including the origin pointer of \p other.
    Dim(const Dim& other)
        : value_(other.value_),
          constant_(other.constant_),
          origin_(other.origin_),
          name_(other.name_) {}

    ~Dim() {}

    Dim& operator=(const Dim&) = delete;

    /// Returns the value stored in the origin if there is one, otherwise the
    /// value stored in this object.
    int64_t& value() { return origin_ ? origin_->value_ : value_; }

    /// Returns the constant flag of the origin if there is one, otherwise the
    /// flag of this object.
    bool& constant() { return origin_ ? origin_->constant_ : constant_; }

    /// Assigns \p a if the dimension is not constant yet.
    /// \return true if the dimension holds \p a after the call.
    bool assign(int64_t a) {
        bool& is_known = constant();
        int64_t& stored = value();
        if (!is_known) {
            stored = a;
            is_known = true;
        }
        return stored == a;
    }

    /// Returns a string for the dimension.
    ///
    /// Unnamed dims yield the value or "?". Named dims yield "name(value)" or
    /// "name(?)" if \p show_value is true and just the name otherwise.
    std::string ToString(bool show_value = true) {
        if (name_.empty()) {
            if (constant()) return std::to_string(value());
            return "?";
        }
        if (!show_value) return name_;
        return name_ + "(" + (constant() ? std::to_string(value()) : "?") +
               ")";
    }

private:
    int64_t value_;
    bool constant_;
    Dim* origin_;
    std::string name_;
};
133
134//
135// Dim expression operator classes
136//
137
/// Operator class for the sum of two dim expressions.
struct DimXPlus {
    /// The sum of two constant operands is constant.
    static bool constant() { return true; }

    /// Evaluates a + b.
    static int64_t apply(int64_t a, int64_t b) { return a + b; }

    /// Solves ans = a + b for the operand that is not constant and assigns
    /// the result to it.
    /// \throws std::runtime_error if neither operand is constant.
    /// \return The result of assign() on the operand that was solved for.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (!a.constant()) {
            if (!b.constant()) {
                // Two unknowns cannot be inferred from a single value.
                std::string exstr =
                        GetString(a, false) + ToString() + GetString(b, false);
                throw std::runtime_error("Illegal dim expression: " + exstr);
            }
            return a.assign(ans - b.value());
        }
        return b.assign(ans - a.value());
    }

    /// Returns the operator symbol.
    static std::string ToString() { return "+"; }
};
158
/// Operator class for the difference of two dim expressions.
struct DimXMinus {
    /// The difference of two constant operands is constant.
    static bool constant() { return true; }

    /// Evaluates a - b.
    static int64_t apply(int64_t a, int64_t b) { return a - b; }

    /// Solves ans = a - b for the operand that is not constant and assigns
    /// the result to it.
    /// \throws std::runtime_error if neither operand is constant.
    /// \return The result of assign() on the operand that was solved for.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (!a.constant()) {
            if (!b.constant()) {
                // Two unknowns cannot be inferred from a single value.
                std::string exstr =
                        GetString(a, false) + ToString() + GetString(b, false);
                throw std::runtime_error("Illegal dim expression: " + exstr);
            }
            // a = ans + b
            return a.assign(ans + b.value());
        }
        // b = a - ans
        return b.assign(a.value() - ans);
    }

    /// Returns the operator symbol.
    static std::string ToString() { return "-"; }
};
179
/// Operator class for the product of two dim expressions.
struct DimXMultiply {
    /// The product of two constant operands is constant.
    static bool constant() { return true; }

    /// Evaluates a * b.
    static int64_t apply(int64_t a, int64_t b) { return a * b; }

    /// Inferring an operand of a product is not supported.
    /// \throws std::runtime_error always, with the expression in the message.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    /// Returns the operator symbol.
    static std::string ToString() { return "*"; }
};
194
/// Operator class for the quotient of two dim expressions.
struct DimXDivide {
    /// The quotient of two constant operands is constant.
    static bool constant() { return true; }

    /// Evaluates the integer division a / b. The divisor is not checked.
    static int64_t apply(int64_t a, int64_t b) { return a / b; }

    /// Inferring an operand of a quotient is not supported.
    /// \throws std::runtime_error always, with the expression in the message.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    /// Returns the operator symbol.
    static std::string ToString() { return "/"; }
};
209
/// Operator class for expressions that accept either of two alternatives.
struct DimXOr {
    /// An OR expression is never treated as constant.
    static bool constant() { return false; }

    /// An OR expression has no single value.
    /// \throws std::runtime_error always.
    static int64_t apply(int64_t a, int64_t b) {
        throw std::runtime_error("Cannot evaluate OR expression");
    }

    /// Tries to assign \p ans to \p a and, only if that fails, to \p b.
    /// \return true if one of the assignments succeeded.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (a.assign(ans)) {
            return true;
        }
        return b.assign(ans);
    }

    /// Returns the operator symbol.
    static std::string ToString() { return "||"; }
};
223
/// Dim expression class.
///
/// Combines a left and a right operand with the operator class \p TOp. The
/// operands are stored by value.
template <class TLeft, class TRight, class TOp>
class DimX {
public:
    /// Creates the expression for the given operands.
    static DimX<TLeft, TRight, TOp> Create(TLeft left, TRight right) {
        return DimX(left, right);
    }

    /// Evaluates the expression if it is constant and returns 0 otherwise.
    int64_t value() {
        return constant_ ? TOp::apply(left_.value(), right_.value()) : 0;
    }

    /// Returns the flag that was computed when the expression was created.
    bool& constant() { return constant_; }

    /// assigns a value to the expression
    ///
    /// A constant expression is only compared with \p a. Otherwise the value
    /// is handed to TOp::backprop together with both operands.
    bool assign(int64_t a) {
        if (!constant_) {
            return TOp::backprop(a, left_, right_);
        }
        return value() == a;
    }

    /// Returns "<left><op><right>" using the ToString functions of the
    /// operands and the operator class.
    std::string ToString(bool show_value = true) {
        std::string text = left_.ToString(show_value);
        text += TOp::ToString();
        text += right_.ToString(show_value);
        return text;
    }

private:
    // The expression is constant only if both operands and the operator are.
    DimX(TLeft left, TRight right)
        : left_(left),
          right_(right),
          constant_(left.constant() && right.constant() && TOp::constant()) {}

    TLeft left_;
    TRight right_;
    bool constant_;
};
263
264//
265// define operators for dim expressions
266//
267
268#define DEFINE_DIMX_OPERATOR(opclass, symbol) \
269 inline DimX<Dim, Dim, opclass> operator symbol(Dim a, Dim b) { \
270 return DimX<Dim, Dim, opclass>::Create(a, b); \
271 } \
272 \
273 template <class TL, class TR, class TOp> \
274 inline DimX<Dim, DimX<TL, TR, TOp>, opclass> operator symbol( \
275 Dim a, DimX<TL, TR, TOp>&& b) { \
276 return DimX<Dim, DimX<TL, TR, TOp>, opclass>::Create(a, b); \
277 } \
278 \
279 template <class TL, class TR, class TOp> \
280 inline DimX<DimX<TL, TR, TOp>, Dim, opclass> operator symbol( \
281 DimX<TL, TR, TOp>&& a, Dim b) { \
282 return DimX<DimX<TL, TR, TOp>, Dim, opclass>::Create(a, b); \
283 } \
284 \
285 template <class TL1, class TR1, class TOp1, class TL2, class TR2, \
286 class TOp2> \
287 inline DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, opclass> \
288 operator symbol(DimX<TL1, TR1, TOp1>&& a, DimX<TL2, TR2, TOp2>&& b) { \
289 return DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, \
290 opclass>::Create(a, b); \
291 }
292
293DEFINE_DIMX_OPERATOR(DimXPlus, +)
294DEFINE_DIMX_OPERATOR(DimXMinus, -)
295DEFINE_DIMX_OPERATOR(DimXMultiply, *)
296DEFINE_DIMX_OPERATOR(DimXDivide, /)
297DEFINE_DIMX_OPERATOR(DimXOr, ||)
298#undef DEFINE_DIMX_OPERATOR
299
300//
301// define operators for comparing DimValue to dim expressions.
302// Using these operators will try to assign the dim value to the expression.
303//
304
305template <class TLeft, class TRight, class TOp>
307 if (a.constant()) {
308 auto b_copy(b);
309 return b_copy.assign(a.value());
310 } else
311 return true;
312}
313
314inline bool operator==(DimValue a, Dim b) {
315 if (a.constant())
316 return b.assign(a.value());
317 else
318 return true;
319}
320
321//
322// some helper classes
323//
324
/// Helper that exposes the number of template arguments as `value`.
template <class... Ts>
struct CountArgs {
    static const size_t value{sizeof...(Ts)};
};
329
330template <class TLeft, class TRight, class TOp>
331std::string GetString(DimX<TLeft, TRight, TOp> a, bool show_value = true) {
332 return a.ToString(show_value);
333}
334
335inline std::string GetString(Dim a, bool show_value = true) {
336 return a.ToString(show_value);
337}
338
339template <class TLeft, class TRight, class TOp>
341 return a.value();
342}
343
344template <class TLeft, class TRight, class TOp>
345int64_t GetValue(DimX<TLeft, TRight, TOp> a, int64_t unknown_dim_value) {
346 if (a.constant()) {
347 return a.value();
348 } else {
349 return unknown_dim_value;
350 }
351 return a.value();
352}
353
354inline int64_t GetValue(Dim a) { return a.value(); }
355
356inline int64_t GetValue(Dim a, int64_t unknown_dim_value) {
357 if (a.constant()) {
358 return a.value();
359 } else {
360 return unknown_dim_value;
361 }
362}
363
/// Overload for an empty argument list; returns an empty string.
inline std::string CreateDimXString() {
    std::string empty;
    return empty;
}
365
/// Returns the string for a single dim expression as produced by GetString().
template <class TDimX>
std::string CreateDimXString(TDimX dimex) {
    std::string text = GetString(dimex);
    return text;
}
370
/// Returns the strings of all dim expressions joined with ", ".
template <class TDimX, class... TArgs>
std::string CreateDimXString(TDimX dimex, TArgs... args) {
    const std::string head = GetString(dimex);
    const std::string tail = CreateDimXString(args...);
    return head + ", " + tail;
}
375
/// Appends the value of \p dimex to \p out. GetValue() substitutes
/// \p unknown_dim_value if the expression is not constant.
template <class TDimX>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex) {
    const int64_t dim = GetValue(dimex, unknown_dim_value);
    out.push_back(dim);
}
382
/// Appends the value of \p dimex and then, recursively, the values of the
/// remaining expressions to \p out. GetValue() substitutes
/// \p unknown_dim_value for expressions that are not constant.
template <class TDimX, class... TArgs>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex,
                     TArgs... args) {
    const int64_t head = GetValue(dimex, unknown_dim_value);
    out.push_back(head);
    CreateDimVector(out, unknown_dim_value, args...);
}
391
/// Returns a vector with the value of the single expression \p dimex, filled
/// by the overload of CreateDimVector() that takes an output vector.
template <class TDimX>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value, TDimX dimex) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex);
    return dims;
}
398
/// Returns a vector with the values of all given dim expressions in argument
/// order, filled by the overloads of CreateDimVector() that take an output
/// vector.
template <class TDimX, class... TArgs>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value,
                                     TDimX dimex,
                                     TArgs... args) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex, args...);
    return dims;
}
407
408//
// functions which check if the dim value is compatible with the expression
410//
411
412template <class TLeft, class TRight, class TOp>
414 bool status = (lhs == std::forward<DimX<TLeft, TRight, TOp>>(rhs));
415 return status;
416}
417
418inline bool CheckDim(const DimValue& lhs, Dim d) {
419 bool status = lhs == d;
420 return status;
421}
422
/// Check shape options.
///
/// The options control how a shape with more dimensions than dim expressions
/// is matched by _CheckShape() and CheckShape().
enum class CSOpt {
    /// The rank of the shape must equal the number of dim expressions.
    NONE = 0,
    /// The surplus leading dimensions are multiplied into the first checked
    /// dimension.
    COMBINE_FIRST_DIMS,
    /// The surplus leading dimensions are skipped.
    IGNORE_FIRST_DIMS,
    /// The surplus trailing dimensions are multiplied into the last checked
    /// dimension.
    COMBINE_LAST_DIMS,
    /// The surplus trailing dimensions are not checked.
    IGNORE_LAST_DIMS,
};
431
432template <CSOpt Opt = CSOpt::NONE, class TDimX>
433bool _CheckShape(const std::vector<DimValue>& shape, TDimX&& dimex) {
434 // check rank
435 const int rank_diff = shape.size() - 1;
436 if (Opt != CSOpt::NONE) {
437 if (rank_diff < 0) {
438 return false;
439 }
440 } else {
441 if (rank_diff != 0) {
442 return false;
443 }
444 }
445
446 // check dim
447 bool status;
448 if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
449 DimValue s(1);
450 for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
451 status = CheckDim(s, std::forward<TDimX>(dimex));
452 } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
453 status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
454 } else if (Opt == CSOpt::COMBINE_LAST_DIMS) {
455 DimValue s(1);
456 for (DimValue x : shape) s *= x;
457 status = CheckDim(s, std::forward<TDimX>(dimex));
458 } else {
459 status = CheckDim(shape[0], std::forward<TDimX>(dimex));
460 }
461 return status;
462}
463
464template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
465bool _CheckShape(const std::vector<DimValue>& shape,
466 TDimX&& dimex,
467 TArgs&&... args) {
468 // check rank
469 const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
470 if (Opt != CSOpt::NONE) {
471 if (rank_diff < 0) {
472 return false;
473 }
474 } else {
475 if (rank_diff != 0) {
476 return false;
477 }
478 }
479
480 // check dim
481 bool status;
482 if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
483 DimValue s(1);
484 for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
485 status = CheckDim(s, std::forward<TDimX>(dimex));
486 } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
487 status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
488 } else {
489 status = CheckDim(shape[0], std::forward<TDimX>(dimex));
490 }
491
492 const int offset = 1 + (Opt == CSOpt::COMBINE_FIRST_DIMS ||
494 ? rank_diff
495 : 0);
496 std::vector<DimValue> shape2(shape.begin() + offset, shape.end());
497 bool status2 = _CheckShape<Opt>(shape2, std::forward<TArgs>(args)...);
498
499 return status && status2;
500}
501
592template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
593std::tuple<bool, std::string> CheckShape(const std::vector<DimValue>& shape,
594 TDimX&& dimex,
595 TArgs&&... args) {
596 const bool status = _CheckShape<Opt>(shape, std::forward<TDimX>(dimex),
597 std::forward<TArgs>(args)...);
598 if (status) {
599 return std::make_tuple(status, std::string());
600 } else {
601 const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
602
603 // generate string for the actual shape. This is a bit involved because
604 // of the many options.
605 std::string shape_str;
606 if (rank_diff <= 0) {
607 shape_str = "[";
608 for (int i = 0; i < int(shape.size()); ++i) {
609 shape_str += shape[i].ToString();
610 if (i + 1 < int(shape.size())) shape_str += ", ";
611 }
612 shape_str += "]";
613 } else {
614 if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
615 shape_str += "[";
616 for (int i = 0; i < rank_diff; ++i) {
617 shape_str += shape[i].ToString();
618 if (i + 1 < int(shape.size())) shape_str += "*";
619 }
620 } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
621 shape_str += "(";
622 for (int i = 0; i < rank_diff; ++i) {
623 shape_str += shape[i].ToString();
624 if (i + 1 < rank_diff) shape_str += ", ";
625 }
626 shape_str += ")[";
627 } else {
628 shape_str = "[";
629 }
630 int start = 0;
631 if (Opt == CSOpt::COMBINE_FIRST_DIMS ||
633 start = rank_diff;
634 }
635
636 int end = shape.size();
637 if (Opt == CSOpt::COMBINE_LAST_DIMS) {
638 end -= rank_diff + 1;
639 } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
640 end -= rank_diff;
641 }
642 for (int i = start; i < end; ++i) {
643 shape_str += shape[i].ToString();
644 if (i + 1 < end) shape_str += ", ";
645 }
646 if (Opt == CSOpt::COMBINE_LAST_DIMS) {
647 shape_str += ", ";
648 for (int i = std::max<int>(0, shape.size() - rank_diff - 1);
649 i < int(shape.size()); ++i) {
650 shape_str += shape[i].ToString();
651 if (i + 1 < int(shape.size())) shape_str += "*";
652 }
653 shape_str += "]";
654 } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
655 shape_str += "](";
656 for (int i = std::max<int>(0, shape.size() - rank_diff);
657 i < int(shape.size()); ++i) {
658 shape_str += shape[i].ToString();
659 if (i + 1 < int(shape.size())) shape_str += ", ";
660 }
661 shape_str += ")";
662 } else {
663 shape_str += "]";
664 }
665 }
666
667 // generate string for the expected shape with the dim expressions
668 std::string expected_shape;
669 if ((CountArgs<TArgs...>::value + 1) == 1) {
670 expected_shape = "[" + GetString(dimex) + "]";
671
672 } else {
673 expected_shape = "[" + GetString(dimex) + ", " +
674 CreateDimXString(args...) + "]";
675 }
676
677 std::string errstr;
678 // print rank information if there is a problem with the rank
679 if ((Opt != CSOpt::NONE && rank_diff < 0) ||
680 (Opt == CSOpt::NONE && rank_diff != 0)) {
681 errstr = "got rank " + std::to_string(shape.size()) + " " +
682 shape_str + ", expected rank " +
683 std::to_string(CountArgs<TArgs...>::value + 1) + " " +
684 expected_shape;
685 } else { // rank is OK print just the shapes
686 errstr = "got " + shape_str + ", expected " + expected_shape;
687 }
688 return std::make_tuple(status, errstr);
689 }
690}
691
692} // namespace op_util
693} // namespace ml
694} // namespace open3d
#define DEFINE_DIMX_OPERATOR(opclass, symbol)
Definition: ShapeChecking.h:268
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
int64_t & value()
Definition: ShapeChecking.h:89
bool assign(int64_t a)
Definition: ShapeChecking.h:105
Dim & operator=(const Dim &)=delete
Dim(const Dim &other)
Definition: ShapeChecking.h:79
Dim()
Definition: ShapeChecking.h:71
Dim(const std::string &name)
Definition: ShapeChecking.h:73
~Dim()
Definition: ShapeChecking.h:85
bool & constant()
Definition: ShapeChecking.h:96
std::string ToString(bool show_value=true)
Definition: ShapeChecking.h:113
Dim(int64_t value, const std::string &name="")
Definition: ShapeChecking.h:76
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:38
DimValue & operator*=(const DimValue &b)
Definition: ShapeChecking.h:42
DimValue(int64_t v)
Definition: ShapeChecking.h:41
bool & constant()
Definition: ShapeChecking.h:59
DimValue()
Definition: ShapeChecking.h:40
int64_t & value()
Definition: ShapeChecking.h:55
std::string ToString() const
Definition: ShapeChecking.h:49
Dim expression class.
Definition: ShapeChecking.h:226
std::string ToString(bool show_value=true)
Definition: ShapeChecking.h:250
bool assign(int64_t a)
assigns a value to the expression
Definition: ShapeChecking.h:242
static DimX< TLeft, TRight, TOp > Create(TLeft left, TRight right)
Definition: ShapeChecking.h:228
int64_t value()
Definition: ShapeChecking.h:232
bool & constant()
Definition: ShapeChecking.h:239
std::string name
Definition: FilePCD.cpp:58
int offset
Definition: FilePCD.cpp:64
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t timeout_in_ms capture_handle capture_handle capture_handle image_handle temperature_c int
Definition: K4aPlugin.cpp:493
bool operator==(DimValue a, DimX< TLeft, TRight, TOp > &&b)
Definition: ShapeChecking.h:306
std::string GetString(DimX< TLeft, TRight, TOp > a, bool show_value=true)
Definition: ShapeChecking.h:331
DimValue UnknownValue()
Definition: ShapeChecking.h:66
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:377
bool CheckDim(const DimValue &lhs, DimX< TLeft, TRight, TOp > &&rhs)
Definition: ShapeChecking.h:413
bool _CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex)
Definition: ShapeChecking.h:433
int64_t GetValue(DimX< TLeft, TRight, TOp > a)
Definition: ShapeChecking.h:340
std::string CreateDimXString()
Definition: ShapeChecking.h:364
Definition: PinholeCameraIntrinsic.cpp:35
Definition: ShapeChecking.h:326
static const size_t value
Definition: ShapeChecking.h:327
Definition: ShapeChecking.h:195
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:200
static bool constant()
Definition: ShapeChecking.h:196
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:197
static std::string ToString()
Definition: ShapeChecking.h:207
Definition: ShapeChecking.h:159
static std::string ToString()
Definition: ShapeChecking.h:177
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:164
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:161
static bool constant()
Definition: ShapeChecking.h:160
Definition: ShapeChecking.h:180
static bool constant()
Definition: ShapeChecking.h:181
static std::string ToString()
Definition: ShapeChecking.h:192
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:182
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:185
Definition: ShapeChecking.h:210
static bool constant()
Definition: ShapeChecking.h:211
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:212
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:217
static std::string ToString()
Definition: ShapeChecking.h:221
Definition: ShapeChecking.h:138
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:143
static bool constant()
Definition: ShapeChecking.h:139
static std::string ToString()
Definition: ShapeChecking.h:156
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:140