Open3D (C++ API)  0.15.1
ContinuousConvTranspose.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
#pragma once

#include <cstdint>
#include <cstring>
#include <vector>

#include <tbb/parallel_for.h>

#include "open3d/ml/impl/continuous_conv/CoordinateTransformation.h"

33namespace open3d {
34namespace ml {
35namespace impl {
36
39template <class TFeat,
40 class TOut,
41 class TReal,
42 class TIndex,
43 InterpolationMode INTERPOLATION,
44 CoordinateMapping MAPPING,
45 bool ALIGN_CORNERS,
46 bool INDIVIDUAL_EXTENT,
47 bool ISOTROPIC_EXTENT,
48 bool NORMALIZE>
50 TOut* out_features,
51 const std::vector<int>& filter_dims,
52 const TFeat* filter,
53 size_t num_out,
54 const TReal* out_positions,
55 const TFeat* out_importance,
56 size_t num_inp,
57 const TReal* inp_positions,
58 const TFeat* inp_features,
59 const TFeat* inp_neighbors_importance_sum,
60 const int64_t* inp_neighbors_row_splits,
61 size_t neighbors_index_size,
62 const TIndex* neighbors_index,
63 const TFeat* neighbors_importance,
64 const int64_t* neighbors_row_splits,
65 const TReal* extents,
66 const TReal* offsets) {
67 const bool NEIGHBORS_IMPORTANCE = inp_neighbors_importance_sum;
68 const int VECSIZE = 32;
69 typedef Eigen::Array<TReal, VECSIZE, 1> Vec_t;
70 typedef InterpolationVec<TReal, VECSIZE, INTERPOLATION> InterpolationVec_t;
71 InterpolationVec_t interpolation;
72
73 const int in_channels = filter_dims[filter_dims.size() - 2];
74 const int out_channels = filter_dims[filter_dims.size() - 1];
75
76 int spatial_filter_size = 1;
77 for (int i = 0; i < 3; ++i) spatial_filter_size *= filter_dims[i];
78 Eigen::Array<int, 3, 1> filter_size_xyz(filter_dims[2], filter_dims[1],
79 filter_dims[0]);
80
81 memset(out_features, 0, sizeof(TOut) * num_out * out_channels);
82
83 tbb::parallel_for(
84 tbb::blocked_range<size_t>(0, num_out, 32),
85 [&](const tbb::blocked_range<size_t>& r) {
86 int range_length = r.end() - r.begin();
87
88 Eigen::Matrix<TFeat, Eigen::Dynamic, Eigen::Dynamic> B(
89 in_channels * spatial_filter_size, range_length);
90 B.setZero();
91
92 typedef Eigen::Array<TFeat, VECSIZE, Eigen::Dynamic> Matrix;
93 Matrix infeat(VECSIZE, in_channels);
94
95 Eigen::Array<TReal, 3, 1> offsets_(offsets[0], offsets[1],
96 offsets[2]);
97
98 Eigen::Array<TReal, VECSIZE, 3> inv_extents;
99 if (INDIVIDUAL_EXTENT == false) {
100 if (ISOTROPIC_EXTENT) {
101 inv_extents = 1 / extents[0];
102 } else {
103 inv_extents.col(0) = 1 / extents[0];
104 inv_extents.col(1) = 1 / extents[1];
105 inv_extents.col(2) = 1 / extents[2];
106 }
107 }
108
109 for (size_t out_idx = r.begin(); out_idx != r.end();
110 ++out_idx) {
111 const int out_col = out_idx - r.begin();
112 const size_t neighbor_start = neighbors_row_splits[out_idx];
113 const size_t neighbor_end =
114 (out_idx + 1 < num_out
115 ? neighbors_row_splits[out_idx + 1]
116 : neighbors_index_size);
117
118 typename InterpolationVec_t::Weight_t interp_weights;
119 typename InterpolationVec_t::Idx_t interp_indices;
120
121 int vec_valid_count = 0;
122 Vec_t x, y, z;
123
124 // set to zero to avoid problems with vectors with less than
125 // VECSIZE valid entries
126 x.setZero();
127 y.setZero();
128 z.setZero();
129 for (size_t n = neighbor_start; n < neighbor_end; ++n) {
130 const size_t inp_idx = neighbors_index[n];
131
132 const int i = vec_valid_count;
133 x(i) = out_positions[out_idx * 3 + 0] -
134 inp_positions[inp_idx * 3 + 0];
135 y(i) = out_positions[out_idx * 3 + 1] -
136 inp_positions[inp_idx * 3 + 1];
137 z(i) = out_positions[out_idx * 3 + 2] -
138 inp_positions[inp_idx * 3 + 2];
139
140 if (INDIVIDUAL_EXTENT) {
141 if (ISOTROPIC_EXTENT) {
142 inv_extents.row(i) = 1 / extents[inp_idx];
143 } else {
144 inv_extents(i, 0) =
145 1 / extents[3 * inp_idx + 0];
146 inv_extents(i, 1) =
147 1 / extents[3 * inp_idx + 1];
148 inv_extents(i, 2) =
149 1 / extents[3 * inp_idx + 2];
150 }
151 }
152
153 TFeat n_importance = NEIGHBORS_IMPORTANCE
154 ? neighbors_importance[n]
155 : TFeat(1);
156 for (int ic = 0; ic < in_channels; ++ic)
157 infeat(i, ic) =
158 inp_features[inp_idx * in_channels + ic] *
159 n_importance;
160
161 if (NORMALIZE) {
162 TFeat normalizer(1);
163 if (NEIGHBORS_IMPORTANCE) {
164 if (inp_neighbors_importance_sum[inp_idx] !=
165 TFeat(0))
166 normalizer /= inp_neighbors_importance_sum
167 [inp_idx];
168 } else {
169 size_t num_inp_neighbors;
170 const size_t inp_neighbor_start =
171 inp_neighbors_row_splits[inp_idx];
172 const size_t inp_neighbor_end =
173 inp_neighbors_row_splits[inp_idx + 1];
174 num_inp_neighbors =
175 inp_neighbor_end - inp_neighbor_start;
176 if (num_inp_neighbors > 0)
177 normalizer /= TFeat(num_inp_neighbors);
178 }
179 for (int ic = 0; ic < in_channels; ++ic)
180 infeat(i, ic) *= normalizer;
181 }
182
183 ++vec_valid_count;
184 if (vec_valid_count == VECSIZE ||
185 n + 1 == neighbor_end) {
186 ComputeFilterCoordinates<ALIGN_CORNERS, MAPPING>(
187 x, y, z, filter_size_xyz, inv_extents,
188 offsets_);
189 interpolation.Interpolate(
190 interp_weights, interp_indices, x, y, z,
191 filter_size_xyz, in_channels);
192 for (int k = 0; k < vec_valid_count; ++k) {
193 for (int j = 0; j < InterpolationVec_t::Size();
194 ++j) {
195 for (int ic = 0; ic < in_channels; ++ic)
196 B(interp_indices(j, k) + ic, out_col) +=
197 TFeat(interp_weights(j, k)) *
198 infeat(k, ic);
199 }
200 }
201 vec_valid_count = 0;
202 }
203 }
204
205 } // out_idx
206
207 Eigen::Map<const Eigen::Matrix<TFeat, Eigen::Dynamic,
208 Eigen::Dynamic>>
209 A(filter, out_channels,
210 spatial_filter_size * in_channels);
211 Eigen::Map<Eigen::Matrix<TOut, Eigen::Dynamic, Eigen::Dynamic>>
212 C(out_features + (r.begin() * out_channels),
213 out_channels, range_length);
214
215 C = (A * B).template cast<TOut>();
216 if (out_importance) {
217 for (int i = 0; i < range_length; ++i)
218 C.col(i) *= TOut(out_importance[r.begin() + i]);
219 }
220 });
221}
222
304template <class TFeat, class TOut, class TReal, class TIndex>
305void CConvTransposeComputeFeaturesCPU(TOut* out_features,
306 const std::vector<int>& filter_dims,
307 const TFeat* filter,
308 size_t num_out,
309 const TReal* out_positions,
310 const TFeat* out_importance,
311 size_t num_inp,
312 const TReal* inp_positions,
313 const TFeat* inp_features,
314 const TFeat* inp_neighbors_importance_sum,
315 const int64_t* inp_neighbors_row_splits,
316 size_t neighbors_index_size,
317 const TIndex* neighbors_index,
318 const TFeat* neighbors_importance,
319 const int64_t* neighbors_row_splits,
320 const TReal* extents,
321 const TReal* offsets,
322 InterpolationMode interpolation,
323 CoordinateMapping coordinate_mapping,
324 bool align_corners,
325 bool individual_extent,
326 bool isotropic_extent,
327 bool normalize) {
328#define FN_PARAMETERS \
329 out_features, filter_dims, filter, num_out, out_positions, out_importance, \
330 num_inp, inp_positions, inp_features, \
331 inp_neighbors_importance_sum, inp_neighbors_row_splits, \
332 neighbors_index_size, neighbors_index, neighbors_importance, \
333 neighbors_row_splits, extents, offsets
334
335#define CALL_TEMPLATE(INTERPOLATION, MAPPING, ALIGN_CORNERS, \
336 INDIVIDUAL_EXTENT, ISOTROPIC_EXTENT, NORMALIZE) \
337 if (INTERPOLATION == interpolation && MAPPING == coordinate_mapping && \
338 ALIGN_CORNERS == align_corners && \
339 INDIVIDUAL_EXTENT == individual_extent && \
340 ISOTROPIC_EXTENT == isotropic_extent && NORMALIZE == normalize) \
341 _CConvTransposeComputeFeaturesCPU<TFeat, TOut, TReal, TIndex, \
342 INTERPOLATION, MAPPING, \
343 ALIGN_CORNERS, INDIVIDUAL_EXTENT, \
344 ISOTROPIC_EXTENT, NORMALIZE>( \
345 FN_PARAMETERS);
346
347#define CALL_TEMPLATE2(INTERPOLATION, MAPPING) \
348 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, true, true, true) \
349 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, true, true, false) \
350 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, true, false, true) \
351 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, true, false, false) \
352 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, false, true, true) \
353 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, false, true, false) \
354 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, false, false, true) \
355 CALL_TEMPLATE(INTERPOLATION, MAPPING, true, false, false, false) \
356 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, true, true, true) \
357 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, true, true, false) \
358 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, true, false, true) \
359 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, true, false, false) \
360 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, false, true, true) \
361 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, false, true, false) \
362 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, false, false, true) \
363 CALL_TEMPLATE(INTERPOLATION, MAPPING, false, false, false, false)
364
365#define CALL_TEMPLATE3(INTERPOLATION) \
366 CALL_TEMPLATE2(INTERPOLATION, CoordinateMapping::BALL_TO_CUBE_RADIAL) \
367 CALL_TEMPLATE2(INTERPOLATION, \
368 CoordinateMapping::BALL_TO_CUBE_VOLUME_PRESERVING) \
369 CALL_TEMPLATE2(INTERPOLATION, CoordinateMapping::IDENTITY)
370
371#define CALL_TEMPLATE4 \
372 CALL_TEMPLATE3(InterpolationMode::LINEAR) \
373 CALL_TEMPLATE3(InterpolationMode::LINEAR_BORDER) \
374 CALL_TEMPLATE3(InterpolationMode::NEAREST_NEIGHBOR)
375
377
378#undef CALL_TEMPLATE
379#undef CALL_TEMPLATE2
380#undef CALL_TEMPLATE3
381#undef CALL_TEMPLATE4
382
383#undef FN_PARAMETERS
384}
385
386} // namespace impl
387} // namespace ml
388} // namespace open3d
#define CALL_TEMPLATE4
#define VECSIZE
InterpolationMode
Definition: ContinuousConvTypes.h:37
void _CConvTransposeComputeFeaturesCPU(TOut *out_features, const std::vector< int > &filter_dims, const TFeat *filter, size_t num_out, const TReal *out_positions, const TFeat *out_importance, size_t num_inp, const TReal *inp_positions, const TFeat *inp_features, const TFeat *inp_neighbors_importance_sum, const int64_t *inp_neighbors_row_splits, size_t neighbors_index_size, const TIndex *neighbors_index, const TFeat *neighbors_importance, const int64_t *neighbors_row_splits, const TReal *extents, const TReal *offsets)
Definition: ContinuousConvTranspose.h:49
void CConvTransposeComputeFeaturesCPU(TOut *out_features, const std::vector< int > &filter_dims, const TFeat *filter, size_t num_out, const TReal *out_positions, const TFeat *out_importance, size_t num_inp, const TReal *inp_positions, const TFeat *inp_features, const TFeat *inp_neighbors_importance_sum, const int64_t *inp_neighbors_row_splits, size_t neighbors_index_size, const TIndex *neighbors_index, const TFeat *neighbors_importance, const int64_t *neighbors_row_splits, const TReal *extents, const TReal *offsets, InterpolationMode interpolation, CoordinateMapping coordinate_mapping, bool align_corners, bool individual_extent, bool isotropic_extent, bool normalize)
Definition: ContinuousConvTranspose.h:305
CoordinateMapping
Definition: ContinuousConvTypes.h:45
Definition: PinholeCameraIntrinsic.cpp:35
Class for computing interpolation weights.
Definition: CoordinateTransformation.h:204