Open3D (C++ API)  0.16.0
InterpolateOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// The MIT License (MIT)
5//
6// Copyright (c) 2018-2021 www.open3d.org
7//
8// Permission is hereby granted, free of charge, to any person obtaining a copy
9// of this software and associated documentation files (the "Software"), to deal
10// in the Software without restriction, including without limitation the rights
11// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12// copies of the Software, and to permit persons to whom the Software is
13// furnished to do so, subject to the following conditions:
14//
15// The above copyright notice and this permission notice shall be included in
16// all copies or substantial portions of the Software.
17//
18// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24// IN THE SOFTWARE.
25// ----------------------------------------------------------------------------
26
27#pragma once
28
29#include "../TensorFlowHelper.h"
30#include "tensorflow/core/framework/op.h"
31#include "tensorflow/core/framework/op_kernel.h"
32#include "tensorflow/core/lib/core/errors.h"
33
34class ThreeNNOpKernel : public tensorflow::OpKernel {
35public:
36 explicit ThreeNNOpKernel(tensorflow::OpKernelConstruction* construction)
37 : OpKernel(construction) {}
38
39 void Compute(tensorflow::OpKernelContext* context) override {
40 using namespace tensorflow;
41
42 const Tensor& inp_tensor = context->input(0);
43 OP_REQUIRES(
44 context,
45 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
46 errors::InvalidArgument("ThreeNN expects "
47 "(batch_size,num_points,3) inp shape"));
48 int batch_size = inp_tensor.shape().dim_size(0);
49 int pts_num_out = inp_tensor.shape().dim_size(1);
50 auto inp_flat = inp_tensor.flat<float>();
51 const float* inp = &(inp_flat(0));
52
53 const Tensor& data_tensor = context->input(1);
54 OP_REQUIRES(
55 context,
56 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
57 errors::InvalidArgument(
58 "ThreeNN expects "
59 "(batch_size,num_points,3) data shape"));
60 int pts_num_in = data_tensor.shape().dim_size(1);
61 auto data_flat = data_tensor.flat<float>();
62 const float* data = &(data_flat(0));
63
64 Tensor* out_dist;
65 OP_REQUIRES_OK(
66 context,
67 context->allocate_output(
68 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
69 auto out_flat0 = out_dist->flat<float>();
70 float* out0 = &(out_flat0(0));
71
72 Tensor* out_idx;
73 OP_REQUIRES_OK(
74 context,
75 context->allocate_output(
76 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
77 auto out_flat1 = out_idx->flat<int>();
78 int* out1 = &(out_flat1(0));
79
80 Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
81 out1);
82 }
83
84 virtual void Kernel(tensorflow::OpKernelContext* context,
85 int b,
86 int n,
87 int m,
88 const float* unknown,
89 const float* known,
90 float* dist2,
91 int* idx) = 0;
92};
93
94class ThreeInterpolateOpKernel : public tensorflow::OpKernel {
95public:
97 tensorflow::OpKernelConstruction* construction)
98 : OpKernel(construction) {}
99
100 void Compute(tensorflow::OpKernelContext* context) override {
101 using namespace tensorflow;
102
103 const Tensor& inp_tensor = context->input(0);
104 OP_REQUIRES(
105 context, inp_tensor.dims() == 3,
106 errors::InvalidArgument("ThreeInterpolate expects "
107 "(batch_size,num_points,3) inp shape"));
108 int batch_size = inp_tensor.shape().dim_size(0);
109 int C = inp_tensor.shape().dim_size(1);
110 int M = inp_tensor.shape().dim_size(2);
111 auto inp_flat = inp_tensor.flat<float>();
112 const float* inp = &(inp_flat(0));
113
114 const Tensor& idx_tensor = context->input(1);
115 OP_REQUIRES(
116 context, idx_tensor.dims() == 3,
117 errors::InvalidArgument("ThreeInterpolate expects "
118 "(batch_size,num_points,3) idx shape"));
119 int N = idx_tensor.shape().dim_size(1);
120 auto idx_flat = idx_tensor.flat<int>();
121 const int* idx = &(idx_flat(0));
122
123 const Tensor& weights_tensor = context->input(2);
124 OP_REQUIRES(context, weights_tensor.dims() == 3,
125 errors::InvalidArgument(
126 "ThreeInterpolate expects "
127 "(batch_size,num_points,3) weights shape"));
128 auto weights_flat = weights_tensor.flat<float>();
129 const float* weights = &(weights_flat(0));
130
131 Tensor* out_tensor;
132 OP_REQUIRES_OK(context,
133 context->allocate_output(
134 0, TensorShape{batch_size, C, N}, &out_tensor));
135 auto out_flat = out_tensor->flat<float>();
136 float* out = &(out_flat(0));
137
138 Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
139 }
140
141 virtual void Kernel(tensorflow::OpKernelContext* context,
142 int b,
143 int c,
144 int m,
145 int n,
146 const float* points,
147 const int* idx,
148 const float* weight,
149 float* out) = 0;
150};
151
152class ThreeInterpolateGradOpKernel : public tensorflow::OpKernel {
153public:
155 tensorflow::OpKernelConstruction* construction)
156 : OpKernel(construction) {
157 OP_REQUIRES_OK(construction, construction->GetAttr("M", &M));
158 }
159
160 void Compute(tensorflow::OpKernelContext* context) override {
161 using namespace tensorflow;
162
163 const Tensor& inp_tensor = context->input(0);
164 OP_REQUIRES(
165 context, inp_tensor.dims() == 3,
166 errors::InvalidArgument("ThreeInterpolateGrad expects "
167 "(batch_size,num_points,3) inp shape"));
168 int batch_size = inp_tensor.shape().dim_size(0);
169 int C = inp_tensor.shape().dim_size(1);
170 int N = inp_tensor.shape().dim_size(2);
171 auto inp_flat = inp_tensor.flat<float>();
172 const float* inp = &(inp_flat(0));
173
174 const Tensor& idx_tensor = context->input(1);
175 OP_REQUIRES(
176 context, idx_tensor.dims() == 3,
177 errors::InvalidArgument("ThreeInterpolateGrad expects "
178 "(batch_size,num_points,3) idx shape"));
179 auto idx_flat = idx_tensor.flat<int>();
180 const int* idx = &(idx_flat(0));
181
182 const Tensor& weights_tensor = context->input(2);
183 OP_REQUIRES(context, weights_tensor.dims() == 3,
184 errors::InvalidArgument(
185 "ThreeInterpolateGrad expects "
186 "(batch_size,num_points,3) weights shape"));
187 auto weights_flat = weights_tensor.flat<float>();
188 const float* weights = &(weights_flat(0));
189
190 Tensor* out_tensor;
191 OP_REQUIRES_OK(context,
192 context->allocate_output(
193 0, TensorShape{batch_size, C, M}, &out_tensor));
194 auto out_flat = out_tensor->flat<float>();
195 float* out = &(out_flat(0));
196
197 Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
198 }
199
200 virtual void Kernel(tensorflow::OpKernelContext* context,
201 int b,
202 int c,
203 int n,
204 int m,
205 const float* grad_out,
206 const int* idx,
207 const float* weight,
208 float* grad_points) = 0;
209
210protected:
211 int M;
212};
ImGuiContext * context
Definition: Window.cpp:95
Definition: InterpolateOpKernel.h:152
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:154
int M
Definition: InterpolateOpKernel.h:211
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:160
Definition: InterpolateOpKernel.h:94
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:96
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:100
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)=0
Definition: InterpolateOpKernel.h:34
void Compute(tensorflow::OpKernelContext *context) override
Definition: InterpolateOpKernel.h:39
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition: InterpolateOpKernel.h:36
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
int points
Definition: FilePCD.cpp:73
const char const char value recording_handle imu_sample recording_handle uint8_t data
Definition: K4aPlugin.cpp:288