35#ifndef VIGRA_RF3_VISITORS_HXX
36#define VIGRA_RF3_VISITORS_HXX
40#include "../multi_array.hxx"
41#include "../multi_shape.hxx"
89 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
98 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
105 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
115 template <
typename TREE,
179 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
186 double const EPS = 1e-20;
190 is_in_bag_.resize(weights.
size(),
true);
191 for (
size_t i = 0;
i < weights.
size(); ++
i)
193 if (std::abs(weights[
i]) <
EPS)
195 is_in_bag_[
i] =
false;
201 throw std::runtime_error(
"OOBError::visit_before_tree(): The tree has no out-of-bags.");
207 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
215 vigra_precondition(rf.num_trees() > 0,
"OOBError::visit_after_training(): Number of trees must be greater than zero after training.");
216 vigra_precondition(visitors.
size() == rf.num_trees(),
"OOBError::visit_after_training(): Number of visitors must be equal to number of trees.");
217 size_t const num_instances = features.shape()[0];
218 auto const num_features = features.shape()[1];
219 for (
auto vptr : visitors)
220 vigra_precondition(
vptr->is_in_bag_.
size() == num_instances,
"OOBError::visit_after_training(): Some visitors have the wrong number of data points.");
223 typedef typename std::remove_const<LABELS>::type Labels;
226 for (
size_t i = 0;
i < (
size_t)num_instances; ++
i)
230 for (
size_t k = 0;
k < visitors.
size(); ++
k)
231 if (!visitors[
k]->is_in_bag_[
i])
237 if (
pred(0) != labels(
i))
249 std::vector<bool> is_in_bag_;
269 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
279 auto const num_features = features.shape()[1];
283 double const EPS = 1e-20;
285 is_in_bag_.resize(weights.
size(),
true);
286 for (
size_t i = 0;
i < weights.
size(); ++
i)
288 if (std::abs(weights[
i]) <
EPS)
290 is_in_bag_[
i] =
false;
295 throw std::runtime_error(
"VariableImportance::visit_before_tree(): The tree has no out-of-bags.");
301 template <
typename TREE,
317 typename SCORER::Functor functor;
318 auto const region_impurity = functor.region_score(labels, weights, begin, end);
326 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
333 typedef typename std::remove_const<FEATURES>::type Features;
334 typedef typename std::remove_const<LABELS>::type Labels;
336 typedef typename Features::value_type FeatureType;
338 auto const num_features = features.shape()[1];
345 copy_out_of_bags(features, labels,
feats,
labs);
363 for (
size_t j = 0;
j < (
size_t)num_features; ++
j)
402 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
409 vigra_precondition(rf.num_trees() > 0,
"VariableImportance::visit_after_training(): Number of trees must be greater than zero after training.");
410 vigra_precondition(visitors.
size() == rf.num_trees(),
"VariableImportance::visit_after_training(): Number of visitors must be equal to number of trees.");
413 auto const num_features = features.shape()[1];
415 for (
auto vptr : visitors)
418 "VariableImportance::visit_after_training(): Shape mismatch.");
464 template <
typename F0,
typename L0,
typename F1,
typename L1>
465 void copy_out_of_bags(
476 for (
auto x : is_in_bag_)
484 for (
size_t i = 0;
i < (
size_t)num_instances; ++
i)
497 std::vector<bool> is_in_bag_;
518template <
typename VISITOR,
typename NEXT = RFStopVisiting,
bool CPY = false>
526 typename std::conditional<CPY, Visitor, Visitor &>::type visitor_;
529 RFVisitorNode(Visitor &
visitor, Next next)
535 explicit RFVisitorNode(Visitor &
visitor)
538 next_(RFStopVisiting())
541 explicit RFVisitorNode(RFVisitorNode<Visitor, Next, !CPY> & other)
543 visitor_(other.visitor_),
547 explicit RFVisitorNode(RFVisitorNode<Visitor, Next, !CPY>
const & other)
549 visitor_(other.visitor_),
553 void visit_before_training()
555 if (visitor_.is_active())
556 visitor_.visit_before_training();
557 next_.visit_before_training();
560 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
561 void visit_after_training(VISITORS & v, RF & rf,
const FEATURES & features,
const LABELS & labels)
563 typedef typename VISITORS::value_type VisitorNodeType;
564 typedef typename VisitorNodeType::Visitor VisitorType;
565 typedef typename VisitorNodeType::Next NextType;
571 if (visitor_.is_active())
573 std::vector<VisitorType*> visitors;
575 visitors.push_back(&x.visitor_);
576 visitor_.visit_after_training(visitors, rf, features, labels);
580 std::vector<NextType> nexts;
582 nexts.push_back(x.next_);
585 next_.visit_after_training(nexts, rf, features, labels);
588 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
589 void visit_before_tree(TREE & tree, FEATURES & features, LABELS & labels, WEIGHTS & weights)
591 if (visitor_.is_active())
592 visitor_.visit_before_tree(tree, features, labels, weights);
593 next_.visit_before_tree(tree, features, labels, weights);
596 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
597 void visit_after_tree(RF & rf,
602 if (visitor_.is_active())
603 visitor_.visit_after_tree(rf, features, labels, weights);
604 next_.visit_after_tree(rf, features, labels, weights);
607 template <
typename TREE,
613 void visit_after_split(TREE & tree,
622 if (visitor_.is_active())
623 visitor_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
624 next_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
634template <
typename VISITOR>
637 typedef detail::RFVisitorNode<typename VISITOR::Visitor, typename VisitorCopy<typename VISITOR::Next>::type,
true> type;
654detail::RFVisitorNode<A>
657 typedef detail::RFVisitorNode<A>
_0_t;
662template<
typename A,
typename B>
663detail::RFVisitorNode<A, detail::RFVisitorNode<B> >
664create_visitor(A & a, B & b)
666 typedef detail::RFVisitorNode<B> _1_t;
668 typedef detail::RFVisitorNode<A, _1_t> _0_t;
673template<
typename A,
typename B,
typename C>
674detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C> > >
677 typedef detail::RFVisitorNode<C> _2_t;
679 typedef detail::RFVisitorNode<B, _2_t> _1_t;
681 typedef detail::RFVisitorNode<A, _1_t> _0_t;
686template<
typename A,
typename B,
typename C,
typename D>
687detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
688 detail::RFVisitorNode<D> > > >
691 typedef detail::RFVisitorNode<D> _3_t;
693 typedef detail::RFVisitorNode<C, _3_t> _2_t;
695 typedef detail::RFVisitorNode<B, _2_t> _1_t;
697 typedef detail::RFVisitorNode<A, _1_t> _0_t;
702template<
typename A,
typename B,
typename C,
typename D,
typename E>
703detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
704 detail::RFVisitorNode<D, detail::RFVisitorNode<E> > > > >
707 typedef detail::RFVisitorNode<E> _4_t;
709 typedef detail::RFVisitorNode<D, _4_t> _3_t;
711 typedef detail::RFVisitorNode<C, _3_t> _2_t;
713 typedef detail::RFVisitorNode<B, _2_t> _1_t;
715 typedef detail::RFVisitorNode<A, _1_t> _0_t;
720template<
typename A,
typename B,
typename C,
typename D,
typename E,
722detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
723 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F> > > > > >
726 typedef detail::RFVisitorNode<F> _5_t;
728 typedef detail::RFVisitorNode<E, _5_t> _4_t;
730 typedef detail::RFVisitorNode<D, _4_t> _3_t;
732 typedef detail::RFVisitorNode<C, _3_t> _2_t;
734 typedef detail::RFVisitorNode<B, _2_t> _1_t;
736 typedef detail::RFVisitorNode<A, _1_t> _0_t;
741template<
typename A,
typename B,
typename C,
typename D,
typename E,
742 typename F,
typename G>
743detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
744 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
745 detail::RFVisitorNode<G> > > > > > >
748 typedef detail::RFVisitorNode<G> _6_t;
750 typedef detail::RFVisitorNode<F, _6_t> _5_t;
752 typedef detail::RFVisitorNode<E, _5_t> _4_t;
754 typedef detail::RFVisitorNode<D, _4_t> _3_t;
756 typedef detail::RFVisitorNode<C, _3_t> _2_t;
758 typedef detail::RFVisitorNode<B, _2_t> _1_t;
760 typedef detail::RFVisitorNode<A, _1_t> _0_t;
765template<
typename A,
typename B,
typename C,
typename D,
typename E,
766 typename F,
typename G,
typename H>
767detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
768 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
769 detail::RFVisitorNode<G, detail::RFVisitorNode<H> > > > > > > >
770create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h)
772 typedef detail::RFVisitorNode<H> _7_t;
774 typedef detail::RFVisitorNode<G, _7_t> _6_t;
776 typedef detail::RFVisitorNode<F, _6_t> _5_t;
778 typedef detail::RFVisitorNode<E, _5_t> _4_t;
780 typedef detail::RFVisitorNode<D, _4_t> _3_t;
782 typedef detail::RFVisitorNode<C, _3_t> _2_t;
784 typedef detail::RFVisitorNode<B, _2_t> _1_t;
786 typedef detail::RFVisitorNode<A, _1_t> _0_t;
791template<
typename A,
typename B,
typename C,
typename D,
typename E,
792 typename F,
typename G,
typename H,
typename I>
793detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
794 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
795 detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I> > > > > > > > >
796create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i)
798 typedef detail::RFVisitorNode<I> _8_t;
800 typedef detail::RFVisitorNode<H, _8_t> _7_t;
802 typedef detail::RFVisitorNode<G, _7_t> _6_t;
804 typedef detail::RFVisitorNode<F, _6_t> _5_t;
806 typedef detail::RFVisitorNode<E, _5_t> _4_t;
808 typedef detail::RFVisitorNode<D, _4_t> _3_t;
810 typedef detail::RFVisitorNode<C, _3_t> _2_t;
812 typedef detail::RFVisitorNode<B, _2_t> _1_t;
814 typedef detail::RFVisitorNode<A, _1_t> _0_t;
819template<
typename A,
typename B,
typename C,
typename D,
typename E,
820 typename F,
typename G,
typename H,
typename I,
typename J>
821detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
822 detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
823 detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I,
824 detail::RFVisitorNode<J> > > > > > > > > >
825create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i,
828 typedef detail::RFVisitorNode<J> _9_t;
830 typedef detail::RFVisitorNode<I, _9_t> _8_t;
832 typedef detail::RFVisitorNode<H, _8_t> _7_t;
834 typedef detail::RFVisitorNode<G, _7_t> _6_t;
836 typedef detail::RFVisitorNode<F, _6_t> _5_t;
838 typedef detail::RFVisitorNode<E, _5_t> _4_t;
840 typedef detail::RFVisitorNode<D, _4_t> _3_t;
842 typedef detail::RFVisitorNode<C, _3_t> _2_t;
844 typedef detail::RFVisitorNode<B, _2_t> _1_t;
846 typedef detail::RFVisitorNode<A, _1_t> _0_t;
MultiArrayView subarray(difference_type p, difference_type q) const
Definition multi_array.hxx:1528
const difference_type & shape() const
Definition multi_array.hxx:1648
void reshape(const difference_type &shape)
Definition multi_array.hxx:2861
Class for a single RGB value.
Definition rgbvalue.hxx:128
size_type size() const
Definition tinyvector.hxx:913
TinyVectorView< VALUETYPE, TO-FROM > subarray() const
Definition tinyvector.hxx:887
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Compute the out of bag error.
Definition random_forest_visitors.hxx:173
double oob_err_
Definition random_forest_visitors.hxx:246
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &weights)
Definition random_forest_visitors.hxx:180
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &labels)
Definition random_forest_visitors.hxx:208
The default visitor node (= "do nothing").
Definition random_forest_visitors.hxx:510
Base class from which all random forest visitors derive.
Definition random_forest_visitors.hxx:69
void visit_before_training()
Do something before training starts.
Definition random_forest_visitors.hxx:80
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
Do something before a tree has been learned.
Definition random_forest_visitors.hxx:99
void activate()
Activate the visitor.
Definition random_forest_visitors.hxx:142
void deactivate()
Deactivate the visitor.
Definition random_forest_visitors.hxx:150
void visit_after_tree(RF &, FEATURES &, LABELS &, WEIGHTS &)
Do something after a tree has been learned.
Definition random_forest_visitors.hxx:106
void visit_after_split(TREE &, FEATURES &, LABELS &, WEIGHTS &, SCORER &, ITER, ITER, ITER)
Do something after the split was made.
Definition random_forest_visitors.hxx:121
bool is_active() const
Return whether the visitor is active or not.
Definition random_forest_visitors.hxx:134
void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
Do something after all trees have been learned.
Definition random_forest_visitors.hxx:90
Compute the variable importance.
Definition random_forest_visitors.hxx:258
void visit_after_split(TREE &tree, FEATURES &, LABELS &labels, WEIGHTS &weights, SCORER &scorer, ITER begin, ITER, ITER end)
Definition random_forest_visitors.hxx:307
void visit_after_tree(RF &rf, const FEATURES &features, const LABELS &labels, WEIGHTS &)
Definition random_forest_visitors.hxx:327
void visit_before_tree(TREE &tree, FEATURES &features, LABELS &, WEIGHTS &weights)
Definition random_forest_visitors.hxx:270
size_t repetition_count_
Definition random_forest_visitors.hxx:457
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &)
Definition random_forest_visitors.hxx:403
MultiArray< 2, double > variable_importance_
Definition random_forest_visitors.hxx:452
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:344
Definition random_forest_visitors.hxx:636