Open3D (C++ API)  0.15.1
TensorFlowHelper.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
27#pragma once
28#include <tensorflow/core/framework/op_kernel.h>
29#include <tensorflow/core/framework/shape_inference.h>
30#include <tensorflow/core/framework/tensor.h>
31#include <tensorflow/core/lib/core/errors.h>
32
34
35inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
36 ::tensorflow::shape_inference::InferenceContext* c,
37 ::tensorflow::shape_inference::ShapeHandle shape_handle) {
38 using namespace open3d::ml::op_util;
39 if (!c->RankKnown(shape_handle)) {
40 return std::vector<DimValue>();
41 }
42
43 std::vector<DimValue> shape;
44 const int rank = c->Rank(shape_handle);
45 for (int i = 0; i < rank; ++i) {
46 auto d = c->DimKnownRank(shape_handle, i);
47 if (c->ValueKnown(d)) {
48 shape.push_back(c->Value(d));
49 } else {
50 shape.push_back(DimValue());
51 }
52 }
53 return shape;
54}
55
57 class TDimX,
58 class... TArgs>
59std::tuple<bool, std::string> CheckShape(
60 ::tensorflow::shape_inference::InferenceContext* c,
61 ::tensorflow::shape_inference::ShapeHandle shape_handle,
62 TDimX&& dimex,
63 TArgs&&... args) {
64 if (!c->RankKnown(shape_handle)) {
65 // without rank we cannot check
66 return std::make_tuple(true, std::string());
67 }
68 return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(c, shape_handle),
69 std::forward<TDimX>(dimex),
70 std::forward<TArgs>(args)...);
71}
72
73inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
74 const tensorflow::Tensor& tensor) {
75 using namespace open3d::ml::op_util;
76
77 std::vector<DimValue> shape;
78 for (int i = 0; i < tensor.dims(); ++i) {
79 shape.push_back(tensor.dim_size(i));
80 }
81 return shape;
82}
83
85 class TDimX,
86 class... TArgs>
87std::tuple<bool, std::string> CheckShape(const tensorflow::Tensor& tensor,
88 TDimX&& dimex,
89 TArgs&&... args) {
90 return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
91 std::forward<TDimX>(dimex),
92 std::forward<TArgs>(args)...);
93}
94
95//
96// Helper function for creating a ShapeHandle from dim expressions.
97// Dim expressions which are not constant will translate to unknown dims in
98// the returned shape handle.
99//
100// Usage:
101// // ctx is of type tensorflow::shape_inference::InferenceContext*
102// {
103// using namespace open3d::ml::op_util;
104// Dim w("w");
105// Dim h("h");
106// CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
107// // 10 and assigns w and h
108// // based on the shape of
109// // handle1
110//
111// CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the
112// // last dim of handle2 matches the
113// // last dim of handle1. The first
114// // two dims must match 10, 20.
115//
116// ShapeHandle out_shape = MakeShapeHandle(ctx, Dim(), h, w);
117// ctx->set_output(0, out_shape);
118// }
119//
120//
121// See "../ShapeChecking.h" for more info and limitations.
122//
123template <class TDimX, class... TArgs>
124::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(
125 ::tensorflow::shape_inference::InferenceContext* ctx,
126 TDimX&& dimex,
127 TArgs&&... args) {
128 using namespace tensorflow::shape_inference;
129 using namespace open3d::ml::op_util;
130 std::vector<int64_t> shape = CreateDimVector(
131 int64_t(InferenceContext::kUnknownDim), dimex, args...);
132 std::vector<DimensionHandle> dims;
133 for (int64_t d : shape) {
134 dims.push_back(ctx->MakeDim(d));
135 }
136 return ctx->MakeShape(dims);
137}
138
139//
140// Macros for checking the shape of ShapeHandle during shape inference.
141//
142// Usage:
143// // ctx is of type tensorflow::shape_inference::InferenceContext*
144// {
145// using namespace open3d::ml::op_util;
146// Dim w("w");
147// Dim h("h");
148// CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
149// // 10 and assigns w and h
150// // based on the shape of
151// // handle1
152//
153// CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the
154// // last dim of handle2 matches the
155// // last dim of handle1. The first
156// // two dims must match 10, 20.
157// }
158//
159//
160// See "../ShapeChecking.h" for more info and limitations.
161//
// Checks shape_handle against the dim expressions and returns an
// InvalidArgument status from the enclosing function if the check fails.
#define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...)                          \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape(ctx, shape_handle, __VA_ARGS__);                 \
        if (TF_PREDICT_FALSE(!std::get<0>(cs_result_))) {                   \
            return tensorflow::errors::InvalidArgument(                     \
                    "invalid shape for '" #shape_handle "', " +             \
                    std::get<1>(cs_result_));                               \
        }                                                                   \
    } while (0)
173
// Same as CHECK_SHAPE_HANDLE but passes CSOpt::COMBINE_FIRST_DIMS to
// CheckShape. CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...)       \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>( \
                        ctx, shape_handle, __VA_ARGS__);                    \
        if (TF_PREDICT_FALSE(!std::get<0>(cs_result_))) {                   \
            return tensorflow::errors::InvalidArgument(                     \
                    "invalid shape for '" #shape_handle "', " +             \
                    std::get<1>(cs_result_));                               \
        }                                                                   \
    } while (0)
186
// Same as CHECK_SHAPE_HANDLE but passes CSOpt::IGNORE_FIRST_DIMS to
// CheckShape. CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...)        \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(  \
                        ctx, shape_handle, __VA_ARGS__);                    \
        if (TF_PREDICT_FALSE(!std::get<0>(cs_result_))) {                   \
            return tensorflow::errors::InvalidArgument(                     \
                    "invalid shape for '" #shape_handle "', " +             \
                    std::get<1>(cs_result_));                               \
        }                                                                   \
    } while (0)
199
// Same as CHECK_SHAPE_HANDLE but passes CSOpt::COMBINE_LAST_DIMS to
// CheckShape. CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...)        \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(  \
                        ctx, shape_handle, __VA_ARGS__);                    \
        if (TF_PREDICT_FALSE(!std::get<0>(cs_result_))) {                   \
            return tensorflow::errors::InvalidArgument(                     \
                    "invalid shape for '" #shape_handle "', " +             \
                    std::get<1>(cs_result_));                               \
        }                                                                   \
    } while (0)
212
// Same as CHECK_SHAPE_HANDLE but passes CSOpt::IGNORE_LAST_DIMS to
// CheckShape. CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...)         \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(   \
                        ctx, shape_handle, __VA_ARGS__);                    \
        if (TF_PREDICT_FALSE(!std::get<0>(cs_result_))) {                   \
            return tensorflow::errors::InvalidArgument(                     \
                    "invalid shape for '" #shape_handle "', " +             \
                    std::get<1>(cs_result_));                               \
        }                                                                   \
    } while (0)
225
226//
227// Macros for checking the shape of Tensors.
228// Usage:
229// // ctx is of type tensorflow::OpKernelContext*
230// {
231// using namespace open3d::ml::op_util;
232// Dim w("w");
233// Dim h("h");
234// CHECK_SHAPE(ctx, tensor1, 10, w, h); // checks if the first dim is 10
235// // and assigns w and h based on
236// // the shape of tensor1
237//
238// CHECK_SHAPE(ctx, tensor2, 10, 20, h); // this checks if the last dim
239// // of tensor2 matches the last dim
240// // of tensor1. The first two dims
241// // must match 10, 20.
242// }
243//
244//
245// See "../ShapeChecking.h" for more info and limitations.
246//
// Checks the shape of tensor against the dim expressions and reports an
// InvalidArgument error through OP_REQUIRES if the check fails.
#define CHECK_SHAPE(ctx, tensor, ...)                                       \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape(tensor, __VA_ARGS__);                            \
        OP_REQUIRES(ctx, std::get<0>(cs_result_),                           \
                    tensorflow::errors::InvalidArgument(                    \
                            "invalid shape for '" #tensor "', " +           \
                            std::get<1>(cs_result_)));                      \
    } while (0)
257
// Same as CHECK_SHAPE but passes CSOpt::COMBINE_FIRST_DIMS to CheckShape.
// CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...)                    \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>( \
                        tensor, __VA_ARGS__);                               \
        OP_REQUIRES(ctx, std::get<0>(cs_result_),                           \
                    tensorflow::errors::InvalidArgument(                    \
                            "invalid shape for '" #tensor "', " +           \
                            std::get<1>(cs_result_)));                      \
    } while (0)
269
// Same as CHECK_SHAPE but passes CSOpt::IGNORE_FIRST_DIMS to CheckShape.
// CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...)                     \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(  \
                        tensor, __VA_ARGS__);                               \
        OP_REQUIRES(ctx, std::get<0>(cs_result_),                           \
                    tensorflow::errors::InvalidArgument(                    \
                            "invalid shape for '" #tensor "', " +           \
                            std::get<1>(cs_result_)));                      \
    } while (0)
281
// Same as CHECK_SHAPE but passes CSOpt::COMBINE_LAST_DIMS to CheckShape.
// CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...)                     \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(  \
                        tensor, __VA_ARGS__);                               \
        OP_REQUIRES(ctx, std::get<0>(cs_result_),                           \
                    tensorflow::errors::InvalidArgument(                    \
                            "invalid shape for '" #tensor "', " +           \
                            std::get<1>(cs_result_)));                      \
    } while (0)
293
// Same as CHECK_SHAPE but passes CSOpt::IGNORE_LAST_DIMS to CheckShape.
// CSOpt is fully qualified so that the macro does not depend on a
// using-directive at the expansion site.
#define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...)                      \
    do {                                                                    \
        const std::tuple<bool, std::string> cs_result_ =                    \
                CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(   \
                        tensor, __VA_ARGS__);                               \
        OP_REQUIRES(ctx, std::get<0>(cs_result_),                           \
                    tensorflow::errors::InvalidArgument(                    \
                            "invalid shape for '" #tensor "', " +           \
                            std::get<1>(cs_result_)));                      \
    } while (0)
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:35
std::tuple< bool, std::string > CheckShape(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:59
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:124
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:38
Definition: ShapeChecking.h:35
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:377