gtsam 4.2.0
gtsam
DecisionTree.h
Go to the documentation of this file.
1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
20#pragma once
21
22#include <gtsam/base/Testable.h>
23#include <gtsam/base/types.h>
25
26#include <boost/serialization/nvp.hpp>
27#include <boost/shared_ptr.hpp>
28#include <functional>
29#include <iostream>
30#include <map>
31#include <set>
32#include <sstream>
33#include <string>
34#include <utility>
35#include <vector>
36
37namespace gtsam {
38
46 template<typename L, typename Y>
48 protected:
50 static bool DefaultCompare(const Y& a, const Y& b) {
51 return a == b;
52 }
53
54 public:
55 using LabelFormatter = std::function<std::string(L)>;
56 using ValueFormatter = std::function<std::string(Y)>;
57 using CompareFunc = std::function<bool(const Y&, const Y&)>;
58
60 using Unary = std::function<Y(const Y&)>;
61 using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
62 using Binary = std::function<Y(const Y&, const Y&)>;
63
65 using LabelC = std::pair<L, size_t>;
66
68 struct Leaf;
69 struct Choice;
70
72 struct Node {
73 using Ptr = boost::shared_ptr<const Node>;
74
75#ifdef DT_DEBUG_MEMORY
76 static int nrNodes;
77#endif
78
79 // Constructor
80 Node() {
81#ifdef DT_DEBUG_MEMORY
82 std::cout << ++nrNodes << " constructed " << id() << std::endl;
83 std::cout.flush();
84#endif
85 }
86
87 // Destructor
88 virtual ~Node() {
89#ifdef DT_DEBUG_MEMORY
90 std::cout << --nrNodes << " destructed " << id() << std::endl;
91 std::cout.flush();
92#endif
93 }
94
95 // Unique ID for dot files
96 const void* id() const { return this; }
97
98 // everything else is virtual, no documentation here as internal
99 virtual void print(const std::string& s,
100 const LabelFormatter& labelFormatter,
101 const ValueFormatter& valueFormatter) const = 0;
102 virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
103 const ValueFormatter& valueFormatter,
104 bool showZero) const = 0;
105 virtual bool sameLeaf(const Leaf& q) const = 0;
106 virtual bool sameLeaf(const Node& q) const = 0;
107 virtual bool equals(const Node& other, const CompareFunc& compare =
108 &DefaultCompare) const = 0;
109 virtual const Y& operator()(const Assignment<L>& x) const = 0;
110 virtual Ptr apply(const Unary& op) const = 0;
111 virtual Ptr apply(const UnaryAssignment& op,
112 const Assignment<L>& assignment) const = 0;
113 virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
114 virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
115 virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
116 virtual Ptr choose(const L& label, size_t index) const = 0;
117 virtual bool isLeaf() const = 0;
118
119 private:
122 template <class ARCHIVE>
123 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
124 };
127 public:
129 using NodePtr = typename Node::Ptr;
130
133
134 protected:
138 template<typename It, typename ValueIt>
139 NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
140
151 template <typename M, typename X>
153 std::function<L(const M&)> L_of_M,
154 std::function<Y(const X&)> Y_of_X) const;
155
156 public:
159
161 DecisionTree();
162
164 explicit DecisionTree(const Y& y);
165
167 DecisionTree(const L& label, const Y& y1, const Y& y2);
168
170 DecisionTree(const LabelC& label, const Y& y1, const Y& y2);
171
173 DecisionTree(const std::vector<LabelC>& labelCs, const std::vector<Y>& ys);
174
176 DecisionTree(const std::vector<LabelC>& labelCs, const std::string& table);
177
179 template<typename Iterator>
180 DecisionTree(Iterator begin, Iterator end, const L& label);
181
183 DecisionTree(const L& label, const DecisionTree& f0,
184 const DecisionTree& f1);
185
193 template <typename X, typename Func>
194 DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
195
206 template <typename M, typename X, typename Func>
207 DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
208 Func Y_of_X);
209
213
221 void print(const std::string& s, const LabelFormatter& labelFormatter,
222 const ValueFormatter& valueFormatter) const;
223
224 // Testable
225 bool equals(const DecisionTree& other,
226 const CompareFunc& compare = &DefaultCompare) const;
227
231
233 virtual ~DecisionTree() = default;
234
236 bool empty() const { return !root_; }
237
239 bool operator==(const DecisionTree& q) const;
240
242 const Y& operator()(const Assignment<L>& x) const;
243
258 template <typename Func>
259 void visit(Func f) const;
260
275 template <typename Func>
276 void visitLeaf(Func f) const;
277
292 template <typename Func>
293 void visitWith(Func f) const;
294
296 size_t nrLeaves() const;
297
313 template <typename Func, typename X>
314 X fold(Func f, X x0) const;
315
317 std::set<L> labels() const;
318
320 DecisionTree apply(const Unary& op) const;
321
330 DecisionTree apply(const UnaryAssignment& op) const;
331
333 DecisionTree apply(const DecisionTree& g, const Binary& op) const;
334
337 DecisionTree choose(const L& label, size_t index) const {
338 NodePtr newRoot = root_->choose(label, index);
339 return DecisionTree(newRoot);
340 }
341
343 DecisionTree combine(const L& label, size_t cardinality,
344 const Binary& op) const;
345
347 DecisionTree combine(const LabelC& labelC, const Binary& op) const {
348 return combine(labelC.first, labelC.second, op);
349 }
350
352 void dot(std::ostream& os, const LabelFormatter& labelFormatter,
353 const ValueFormatter& valueFormatter, bool showZero = true) const;
354
356 void dot(const std::string& name, const LabelFormatter& labelFormatter,
357 const ValueFormatter& valueFormatter, bool showZero = true) const;
358
360 std::string dot(const LabelFormatter& labelFormatter,
361 const ValueFormatter& valueFormatter,
362 bool showZero = true) const;
363
366
367 // internal use only
368 explicit DecisionTree(const NodePtr& root);
369
370 // internal use only
371 template<typename Iterator> NodePtr
372 compose(Iterator begin, Iterator end, const L& label) const;
373
375
376 private:
379 template <class ARCHIVE>
380 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
381 ar& BOOST_SERIALIZATION_NVP(root_);
382 }
383 }; // DecisionTree
384
385 template <class L, class Y>
386 struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
387
391 template<typename L, typename Y>
393 const typename DecisionTree<L, Y>::Unary& op) {
394 return f.apply(op);
395 }
396
398 template<typename L, typename Y>
400 const typename DecisionTree<L, Y>::UnaryAssignment& op) {
401 return f.apply(op);
402 }
403
405 template<typename L, typename Y>
407 const DecisionTree<L, Y>& g,
408 const typename DecisionTree<L, Y>::Binary& op) {
409 return f.apply(g, op);
410 }
411
418 template <typename L, typename T1, typename T2>
419 std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
420 const DecisionTree<L, std::pair<T1, T2> >& input) {
421 return std::make_pair(
422 DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
424 [](std::pair<T1, T2> i) { return i.second; }));
425 }
426
427} // namespace gtsam
Concept check for values that can be used in unit tests.
Typedefs for easier changing of types.
An assignment from labels to a discrete value index (size_t)
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition: DecisionTree.h:419
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
free versions of apply
Definition: DecisionTree.h:392
A manifold defines a space in which there is a notion of a linear tangent space that can be centered ...
Definition: concepts.h:30
Template to create a binary predicate.
Definition: Testable.h:111
A helper that implements the traits interface for GTSAM types.
Definition: Testable.h:151
An assignment from labels to value index (size_t).
Definition: Assignment.h:37
Definition: DecisionTree-inl.h:52
Definition: DecisionTree-inl.h:172
Decision Tree: L = label for variables; Y = function range (any algebra), e.g., bool, ...
Definition: DecisionTree.h:47
DecisionTree apply(const Unary &op) const
apply Unary operation "op" to f
Definition: DecisionTree-inl.h:889
DecisionTree choose(const L &label, size_t index) const
Create a new function where value(label) == index. It is like "restrict" in Darwiche09book, p. 329, ...
Definition: DecisionTree.h:337
NodePtr convertFrom(const typename DecisionTree< M, X >::NodePtr &f, std::function< L(const M &)> L_of_M, std::function< Y(const X &)> Y_of_X) const
Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
Definition: DecisionTree-inl.h:671
DecisionTree combine(const LabelC &labelC, const Binary &op) const
combine with LabelC for convenience
Definition: DecisionTree.h:347
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
Internal recursive function to create from keys, cardinalities, and Y values.
Definition: DecisionTree-inl.h:630
virtual ~DecisionTree()=default
Virtual destructor.
static bool DefaultCompare(const Y &a, const Y &b)
Default method for comparison of two objects of type Y.
Definition: DecisionTree.h:50
typename Node::Ptr NodePtr
Node base class.
Definition: DecisionTree.h:129
std::set< L > labels() const
Retrieve all unique labels as a set.
Definition: DecisionTree-inl.h:853
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:236
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:736
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:773
std::function< Y(const Y &)> Unary
Handy typedefs for unary and binary function types.
Definition: DecisionTree.h:60
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
Definition: DecisionTree-inl.h:833
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:132
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:872
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
combine subtrees on key with binary operation "op"
Definition: DecisionTree-inl.h:937
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:816
const Y & operator()(const Assignment< L > &x) const
Evaluate the tree at assignment x.
Definition: DecisionTree-inl.h:884
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
output to graphviz format, stream version
Definition: DecisionTree-inl.h:949
friend class boost::serialization::access
Serialization function.
Definition: DecisionTree.h:378
bool operator==(const DecisionTree &q) const
equality
Definition: DecisionTree-inl.h:879
std::pair< L, size_t > LabelC
A label annotated with cardinality.
Definition: DecisionTree.h:65
size_t nrLeaves() const
Return the number of leaves in the tree.
Definition: DecisionTree-inl.h:823
DecisionTree()
Default constructor (for serialization)
Definition: DecisionTree-inl.h:462
Node base class.
Definition: DecisionTree.h:72
friend class boost::serialization::access
Serialization function.
Definition: DecisionTree.h:121