Open3D (C++ API)  0.17.0
Loading...
Searching...
No Matches
InterpolateOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
10#include "../TensorFlowHelper.h"
11#include "tensorflow/core/framework/op.h"
12#include "tensorflow/core/framework/op_kernel.h"
13#include "tensorflow/core/lib/core/errors.h"
14
15class ThreeNNOpKernel : public tensorflow::OpKernel {
16public:
17 explicit ThreeNNOpKernel(tensorflow::OpKernelConstruction* construction)
18 : OpKernel(construction) {}
19
20 void Compute(tensorflow::OpKernelContext* context) override {
21 using namespace tensorflow;
22
23 const Tensor& inp_tensor = context->input(0);
24 OP_REQUIRES(
25 context,
26 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
27 errors::InvalidArgument("ThreeNN expects "
28 "(batch_size,num_points,3) inp shape"));
29 int batch_size = inp_tensor.shape().dim_size(0);
30 int pts_num_out = inp_tensor.shape().dim_size(1);
31 auto inp_flat = inp_tensor.flat<float>();
32 const float* inp = &(inp_flat(0));
33
34 const Tensor& data_tensor = context->input(1);
35 OP_REQUIRES(
36 context,
37 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
38 errors::InvalidArgument(
39 "ThreeNN expects "
40 "(batch_size,num_points,3) data shape"));
41 int pts_num_in = data_tensor.shape().dim_size(1);
42 auto data_flat = data_tensor.flat<float>();
43 const float* data = &(data_flat(0));
44
45 Tensor* out_dist;
46 OP_REQUIRES_OK(
47 context,
48 context->allocate_output(
49 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
50 auto out_flat0 = out_dist->flat<float>();
51 float* out0 = &(out_flat0(0));
52
53 Tensor* out_idx;
54 OP_REQUIRES_OK(
55 context,
56 context->allocate_output(
57 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
58 auto out_flat1 = out_idx->flat<int>();
59 int* out1 = &(out_flat1(0));
60
61 Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,
62 out1);
63 }
64
65 virtual void Kernel(tensorflow::OpKernelContext* context,
66 int b,
67 int n,
68 int m,
69 const float* unknown,
70 const float* known,
71 float* dist2,
72 int* idx) = 0;
73};
74
75class ThreeInterpolateOpKernel : public tensorflow::OpKernel {
76public:
78 tensorflow::OpKernelConstruction* construction)
79 : OpKernel(construction) {}
80
81 void Compute(tensorflow::OpKernelContext* context) override {
82 using namespace tensorflow;
83
84 const Tensor& inp_tensor = context->input(0);
85 OP_REQUIRES(
86 context, inp_tensor.dims() == 3,
87 errors::InvalidArgument("ThreeInterpolate expects "
88 "(batch_size,num_points,3) inp shape"));
89 int batch_size = inp_tensor.shape().dim_size(0);
90 int C = inp_tensor.shape().dim_size(1);
91 int M = inp_tensor.shape().dim_size(2);
92 auto inp_flat = inp_tensor.flat<float>();
93 const float* inp = &(inp_flat(0));
94
95 const Tensor& idx_tensor = context->input(1);
96 OP_REQUIRES(
97 context, idx_tensor.dims() == 3,
98 errors::InvalidArgument("ThreeInterpolate expects "
99 "(batch_size,num_points,3) idx shape"));
100 int N = idx_tensor.shape().dim_size(1);
101 auto idx_flat = idx_tensor.flat<int>();
102 const int* idx = &(idx_flat(0));
103
104 const Tensor& weights_tensor = context->input(2);
105 OP_REQUIRES(context, weights_tensor.dims() == 3,
106 errors::InvalidArgument(
107 "ThreeInterpolate expects "
108 "(batch_size,num_points,3) weights shape"));
109 auto weights_flat = weights_tensor.flat<float>();
110 const float* weights = &(weights_flat(0));
111
112 Tensor* out_tensor;
113 OP_REQUIRES_OK(context,
114 context->allocate_output(
115 0, TensorShape{batch_size, C, N}, &out_tensor));
116 auto out_flat = out_tensor->flat<float>();
117 float* out = &(out_flat(0));
118
119 Kernel(context, batch_size, C, M, N, inp, idx, weights, out);
120 }
121
122 virtual void Kernel(tensorflow::OpKernelContext* context,
123 int b,
124 int c,
125 int m,
126 int n,
127 const float* points,
128 const int* idx,
129 const float* weight,
130 float* out) = 0;
131};
132
133class ThreeInterpolateGradOpKernel : public tensorflow::OpKernel {
134public:
136 tensorflow::OpKernelConstruction* construction)
137 : OpKernel(construction) {
138 OP_REQUIRES_OK(construction, construction->GetAttr("M", &M));
139 }
140
141 void Compute(tensorflow::OpKernelContext* context) override {
142 using namespace tensorflow;
143
144 const Tensor& inp_tensor = context->input(0);
145 OP_REQUIRES(
146 context, inp_tensor.dims() == 3,
147 errors::InvalidArgument("ThreeInterpolateGrad expects "
148 "(batch_size,num_points,3) inp shape"));
149 int batch_size = inp_tensor.shape().dim_size(0);
150 int C = inp_tensor.shape().dim_size(1);
151 int N = inp_tensor.shape().dim_size(2);
152 auto inp_flat = inp_tensor.flat<float>();
153 const float* inp = &(inp_flat(0));
154
155 const Tensor& idx_tensor = context->input(1);
156 OP_REQUIRES(
157 context, idx_tensor.dims() == 3,
158 errors::InvalidArgument("ThreeInterpolateGrad expects "
159 "(batch_size,num_points,3) idx shape"));
160 auto idx_flat = idx_tensor.flat<int>();
161 const int* idx = &(idx_flat(0));
162
163 const Tensor& weights_tensor = context->input(2);
164 OP_REQUIRES(context, weights_tensor.dims() == 3,
165 errors::InvalidArgument(
166 "ThreeInterpolateGrad expects "
167 "(batch_size,num_points,3) weights shape"));
168 auto weights_flat = weights_tensor.flat<float>();
169 const float* weights = &(weights_flat(0));
170
171 Tensor* out_tensor;
172 OP_REQUIRES_OK(context,
173 context->allocate_output(
174 0, TensorShape{batch_size, C, M}, &out_tensor));
175 auto out_flat = out_tensor->flat<float>();
176 float* out = &(out_flat(0));
177
178 Kernel(context, batch_size, C, N, M, inp, idx, weights, out);
179 }
180
181 virtual void Kernel(tensorflow::OpKernelContext* context,
182 int b,
183 int c,
184 int n,
185 int m,
186 const float* grad_out,
187 const int* idx,
188 const float* weight,
189 float* grad_points) = 0;
190
191protected:
192 int M;
193};
Eigen::Matrix3Xd M
Definition PointCloudPlanarPatchDetection.cpp:507
ImGuiContext * context
Definition Window.cpp:76
ThreeInterpolateGradOpKernel
Definition InterpolateOpKernel.h:133
ThreeInterpolateGradOpKernel(tensorflow::OpKernelConstruction *construction)
Definition InterpolateOpKernel.h:135
int M
Definition InterpolateOpKernel.h:192
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points)=0
void Compute(tensorflow::OpKernelContext *context) override
Definition InterpolateOpKernel.h:141
ThreeInterpolateOpKernel
Definition InterpolateOpKernel.h:75
ThreeInterpolateOpKernel(tensorflow::OpKernelConstruction *construction)
Definition InterpolateOpKernel.h:77
void Compute(tensorflow::OpKernelContext *context) override
Definition InterpolateOpKernel.h:81
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int m, int n, const float *points, const int *idx, const float *weight, float *out)=0
ThreeNNOpKernel
Definition InterpolateOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition InterpolateOpKernel.h:20
ThreeNNOpKernel(tensorflow::OpKernelConstruction *construction)
Definition InterpolateOpKernel.h:17
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *unknown, const float *known, float *dist2, int *idx)=0
int points
Definition FilePCD.cpp:54