Open3D (C++ API)  0.17.0
Loading...
Searching...
No Matches
SamplingOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
10#include "../TensorFlowHelper.h"
11#include "tensorflow/core/framework/op.h"
12#include "tensorflow/core/framework/op_kernel.h"
13#include "tensorflow/core/lib/core/errors.h"
14
15class FurthestPointSamplingOpKernel : public tensorflow::OpKernel {
16public:
18 tensorflow::OpKernelConstruction* construction)
19 : OpKernel(construction) {
20 using namespace tensorflow;
21
22 OP_REQUIRES_OK(construction,
23 construction->GetAttr("sample_size", &sample_size));
24 OP_REQUIRES(construction, sample_size > 0,
25 errors::InvalidArgument(
26 "FurthestPointSampling expects positive npoint"));
27 }
28
29 void Compute(tensorflow::OpKernelContext* context) override {
30 using namespace tensorflow;
31
32 const Tensor& inp_tensor = context->input(0);
33 OP_REQUIRES(
34 context,
35 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36 errors::InvalidArgument("FurthestPointSampling expects "
37 "(batch_size,num_points,3) inp shape"));
38 int batch_size = inp_tensor.shape().dim_size(0);
39 int pts_size = inp_tensor.shape().dim_size(1);
40 auto inp_flat = inp_tensor.flat<float>();
41 const float* inp = &(inp_flat(0));
42
43 Tensor* out_tensor;
44 OP_REQUIRES_OK(context, context->allocate_output(
45 0, TensorShape{batch_size, sample_size},
46 &out_tensor));
47 auto out_flat = out_tensor->flat<int>();
48 int* out = &(out_flat(0));
49
50 Tensor temp_tensor;
51 OP_REQUIRES_OK(context,
52 context->allocate_temp(DataTypeToEnum<float>::value,
53 TensorShape{batch_size, pts_size},
54 &temp_tensor));
55 auto temp_flat = temp_tensor.flat<float>();
56 float* temp = &(temp_flat(0));
57
58 Kernel(context, batch_size, pts_size, sample_size, inp, temp, out);
59 }
60
61 virtual void Kernel(tensorflow::OpKernelContext* context,
62 int b,
63 int n,
64 int m,
65 const float* dataset,
66 float* temp,
67 int* idxs) = 0;
68
69protected:
71};
ImGuiContext * context
Definition Window.cpp:76
FurthestPointSamplingOpKernel
Definition SamplingOpKernel.h:15
void Compute(tensorflow::OpKernelContext *context) override
Definition SamplingOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, const float *dataset, float *temp, int *idxs)=0
int sample_size
Definition SamplingOpKernel.h:70
FurthestPointSamplingOpKernel(tensorflow::OpKernelConstruction *construction)
Definition SamplingOpKernel.h:17