1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
10#include "open3d/ml/impl/continuous_conv/ContinuousConvTypes.h"
16/// Maps coordinates in a sphere with radius 1 to a cylinder. The input and
17/// output range of the coordinates is [-1,1]. The cylinder axis is along z.
19inline __device__ void MapSphereToCylinder(T& x, T& y, T& z) {
20 T sq_norm = x * x + y * y + z * z;
22 if (sq_norm < T(1e-8)) {
27 T norm = sqrt(sq_norm);
28 if (T(5.0 / 4) * z * z > (x * x + y * y)) {
29 T s = sqrt(3 * norm / (norm + abs(z)));
32 z = copysign(norm, z);
34 T s = norm / sqrt(x * x + y * y);
41/// Maps coordinates in a cylinder with radius 1 to a cube. The input and
42/// output range of the coordinates is [-1,1]. The cylinder axis is along z.
44inline __device__ void MapCylinderToCube(T& x, T& y, T& z) {
45 T sq_norm_xy = x * x + y * y;
47 if (sq_norm_xy < T(1e-8)) {
52 T norm_xy = sqrt(sq_norm_xy);
54 if (abs(y) <= abs(x)) {
55 T tmp = copysign(norm_xy, x);
56 y = tmp * T(4 / M_PI) * atan(y / x);
58 } else if (abs(x) <= abs(y)) {
59 T tmp = copysign(norm_xy, y);
60 x = tmp * T(4 / M_PI) * atan(x / y);
65/// Computes the filter coordinates.
66/// The inputs to this function are coordinates relative to the point where the
67/// convolution is evaluated. Coordinates are usually in the range
68/// [-extent/2,extent/2] with extent as the edge length of the bounding box of
69/// the filter shape. The output is a coordinate within the filter array, i.e.
70/// the range is [0, filter_size.xyz], if the point was inside the filter shape.
72/// The simplest filter shape is a cuboid (MAPPING=IDENTITY) and the
73/// transformation is simply [-extent/2,extent/2] -> [0, filter_size.xyz].
74/// The other type of shape that is implemented is a sphere with
75/// MAPPING=BALL_TO_CUBE_RADIAL or MAPPING=BALL_TO_CUBE_VOLUME_PRESERVING.
77/// \tparam ALIGN_CORNERS If true then the voxel centers of the outer voxels
78/// of the filter array are mapped to the boundary of the filter shape.
79/// If false then the boundary of the filter array is mapped to the
80/// boundary of the filter shape.
82/// \tparam MAPPING The mapping that is applied to the input coordinates.
83/// - BALL_TO_CUBE_RADIAL uses radial stretching to map a sphere to
85/// - BALL_TO_CUBE_VOLUME_PRESERVING uses a more expensive volume
86/// preserving mapping to map a sphere to a cube.
87/// - IDENTITY no mapping is applied to the coordinates.
89/// \param x x coordinates. Input and output variable.
90/// \param y y coordinates. Input and output variable.
91/// \param z z coordinates. Input and output variable.
93/// \param filter_size_x The spatial size of the filter array in voxels for
95/// \param filter_size_y Like \p filter_size_x
96/// \param filter_size_z Like \p filter_size_x
98/// \param inv_extent_x The reciprocal of the spatial extent of the filter
99/// in coordinate units for the x direction.
100/// \param inv_extent_y Like \p inv_extent_x
101/// \param inv_extent_z Like \p inv_extent_x
103/// \param offset_x An offset for shifting the center. Can be used to
104/// implement discrete filters with even filter size.
105/// \param offset_y Like \p offset_x
106/// \param offset_z Like \p offset_x
108template <bool ALIGN_CORNERS, CoordinateMapping MAPPING, class T>
109inline __device__ void ComputeFilterCoordinates(T& x,
112 const int& filter_size_x,
113 const int& filter_size_y,
114 const int& filter_size_z,
115 const T& inv_extent_x,
116 const T& inv_extent_y,
117 const T& inv_extent_z,
121 if (MAPPING == CoordinateMapping::BALL_TO_CUBE_RADIAL) {
122 // x,y,z is now in the range [-1,1]
123 x *= 2 * inv_extent_x;
124 y *= 2 * inv_extent_y;
125 z *= 2 * inv_extent_z;
127 T radius = sqrt(x * x + y * y + z * z);
128 T abs_max = max(abs(x), max(abs(y), abs(z)));
129 if (abs_max < T(1e-8)) {
134 // map to the unit cube with edge length 1 and range [-0.5,0.5]
135 x *= T(0.5) * radius / abs_max;
136 y *= T(0.5) * radius / abs_max;
137 z *= T(0.5) * radius / abs_max;
139 } else if (MAPPING == CoordinateMapping::BALL_TO_CUBE_VOLUME_PRESERVING) {
140 // x,y,z is now in the range [-1,1]
141 x *= 2 * inv_extent_x;
142 y *= 2 * inv_extent_y;
143 z *= 2 * inv_extent_z;
144 MapSphereToCylinder(x, y, z);
145 MapCylinderToCube(x, y, z);
150 // map to the unit cube with edge length 1 and range [-0.5,0.5]
161 x *= filter_size_x - 1;
162 y *= filter_size_y - 1;
163 z *= filter_size_z - 1;
174 x += filter_size_x / 2;
175 y += filter_size_y / 2;
176 z += filter_size_z / 2;
178 // shift if the filter size is even
179 if (filter_size_x % 2 == 0) x -= T(0.5);
180 if (filter_size_y % 2 == 0) y -= T(0.5);
181 if (filter_size_z % 2 == 0) z -= T(0.5);
185/// Computes interpolation weights and indices.
187/// \tparam INTERPOLATION One of LINEAR, LINEAR_BORDER, NEAREST_NEIGHBOR.
188/// LINEAR is trilinear interpolation with coordinate clamping.
189/// LINEAR_BORDER uses a zero border if outside the range.
190/// NEAREST_NEIGHBOR uses the nearest neighbor instead of interpolation.
192/// \param w The interpolation weights with range [0,1].
194/// \param idx The linear index addressing a value in the filter. The
195/// linear index accounts for the number of channels passed in
198/// \param x x coordinate with range [0, filter_size.x-1]. Values outside
199/// the range are handled according to INTERPOLATION (clamped, or a zero border for LINEAR_BORDER).
201/// \param y Like \p x
202/// \param z Like \p x
204/// \param filter_size_x The spatial size of the filter array in voxels.
205/// \param filter_size_y Like \p filter_size_x
206/// \param filter_size_z Like \p filter_size_x
208/// \param num_channels The number of channels of the filter.
209template <InterpolationMode INTERPOLATION, class T>
210inline __device__ void Interpolate(T* w,
215 const int& filter_size_x,
216 const int& filter_size_y,
217 const int& filter_size_z,
218 int num_channels = 1) {
219 if (INTERPOLATION == InterpolationMode::NEAREST_NEIGHBOR) {
224 // clamp to the valid range
225 xi = max(0, min(xi, filter_size_x - 1));
226 yi = max(0, min(yi, filter_size_y - 1));
227 zi = max(0, min(zi, filter_size_z - 1));
228 idx[0] = num_channels *
229 (zi * filter_size_y * filter_size_x + yi * filter_size_x + xi);
231 } else if (INTERPOLATION == InterpolationMode::LINEAR_BORDER) {
232 int xi0 = int(floor(x));
235 int yi0 = int(floor(y));
238 int zi0 = int(floor(z));
245 if (zi0 < 0 || yi0 < 0 || xi0 < 0 || zi0 >= filter_size_z ||
246 yi0 >= filter_size_y || xi0 >= filter_size_x) {
250 idx[0] = zi0 * filter_size_y * filter_size_x + yi0 * filter_size_x +
252 w[0] = (1 - a) * (1 - b) * (1 - c);
255 if (zi0 < 0 || yi0 < 0 || xi1 < 0 || zi0 >= filter_size_z ||
256 yi0 >= filter_size_y || xi1 >= filter_size_x) {
260 idx[1] = zi0 * filter_size_y * filter_size_x + yi0 * filter_size_x +
262 w[1] = (a) * (1 - b) * (1 - c);
265 if (zi0 < 0 || yi1 < 0 || xi0 < 0 || zi0 >= filter_size_z ||
266 yi1 >= filter_size_y || xi0 >= filter_size_x) {
270 idx[2] = zi0 * filter_size_y * filter_size_x + yi1 * filter_size_x +
272 w[2] = (1 - a) * (b) * (1 - c);
275 if (zi0 < 0 || yi1 < 0 || xi1 < 0 || zi0 >= filter_size_z ||
276 yi1 >= filter_size_y || xi1 >= filter_size_x) {
280 idx[3] = zi0 * filter_size_y * filter_size_x + yi1 * filter_size_x +
282 w[3] = (a) * (b) * (1 - c);
285 if (zi1 < 0 || yi0 < 0 || xi0 < 0 || zi1 >= filter_size_z ||
286 yi0 >= filter_size_y || xi0 >= filter_size_x) {
290 idx[4] = zi1 * filter_size_y * filter_size_x + yi0 * filter_size_x +
292 w[4] = (1 - a) * (1 - b) * (c);
295 if (zi1 < 0 || yi0 < 0 || xi1 < 0 || zi1 >= filter_size_z ||
296 yi0 >= filter_size_y || xi1 >= filter_size_x) {
300 idx[5] = zi1 * filter_size_y * filter_size_x + yi0 * filter_size_x +
302 w[5] = (a) * (1 - b) * (c);
305 if (zi1 < 0 || yi1 < 0 || xi0 < 0 || zi1 >= filter_size_z ||
306 yi1 >= filter_size_y || xi0 >= filter_size_x) {
310 idx[6] = zi1 * filter_size_y * filter_size_x + yi1 * filter_size_x +
312 w[6] = (1 - a) * (b) * (c);
315 if (zi1 < 0 || yi1 < 0 || xi1 < 0 || zi1 >= filter_size_z ||
316 yi1 >= filter_size_y || xi1 >= filter_size_x) {
320 idx[7] = zi1 * filter_size_y * filter_size_x + yi1 * filter_size_x +
322 w[7] = (a) * (b) * (c);
327 int xi0 = max(0, min(int(x), filter_size_x - 1));
328 int xi1 = max(0, min(xi0 + 1, filter_size_x - 1));
330 int yi0 = max(0, min(int(y), filter_size_y - 1));
331 int yi1 = max(0, min(yi0 + 1, filter_size_y - 1));
333 int zi0 = max(0, min(int(z), filter_size_z - 1));
334 int zi1 = max(0, min(zi0 + 1, filter_size_z - 1));
336 T a = max(T(0), min(x - xi0, T(1)));
337 T b = max(T(0), min(y - yi0, T(1)));
338 T c = max(T(0), min(z - zi0, T(1)));
340 w[0] = (1 - a) * (1 - b) * (1 - c);
341 w[1] = (a) * (1 - b) * (1 - c);
342 w[2] = (1 - a) * (b) * (1 - c);
343 w[3] = (a) * (b) * (1 - c);
344 w[4] = (1 - a) * (1 - b) * (c);
345 w[5] = (a) * (1 - b) * (c);
346 w[6] = (1 - a) * (b) * (c);
347 w[7] = (a) * (b) * (c);
349 idx[0] = (zi0 * filter_size_y * filter_size_x + yi0 * filter_size_x +
352 idx[1] = (zi0 * filter_size_y * filter_size_x + yi0 * filter_size_x +
355 idx[2] = (zi0 * filter_size_y * filter_size_x + yi1 * filter_size_x +
358 idx[3] = (zi0 * filter_size_y * filter_size_x + yi1 * filter_size_x +
361 idx[4] = (zi1 * filter_size_y * filter_size_x + yi0 * filter_size_x +
364 idx[5] = (zi1 * filter_size_y * filter_size_x + yi0 * filter_size_x +
367 idx[6] = (zi1 * filter_size_y * filter_size_x + yi1 * filter_size_x +
370 idx[7] = (zi1 * filter_size_y * filter_size_x + yi1 * filter_size_x +