68template <
typename FEATURES,
typename LABELS>
85template <
typename ACC>
88 template <
typename A,
typename B>
89 void operator()(A & a,
B const & b)
const
98struct RFMapUpdater<ArgMaxAcc>
100 template <
typename A,
typename B>
101 void operator()(A & a,
B const & b)
const
103 auto it = std::max_element(b.
begin(), b.
end());
104 a = std::distance(b.
begin(), it);
111template <
typename FEATURES,
typename LABELS,
typename SAMPLER,
typename SCORER>
120 typedef typename FEATURES::value_type FeatureType;
149template <
typename RF,
typename SCORER,
typename VISITOR,
typename STOP,
typename RANDENGINE>
150void random_forest_single_tree(
151 typename RF::Features
const & features,
159 typedef typename RF::Features Features;
160 typedef typename Features::value_type FeatureType;
162 typedef typename RF::Node Node;
163 typedef typename RF::ACC ACC;
166 static_assert(std::is_same<SplitTests, typename RF::SplitTests>::value,
167 "random_forest_single_tree(): Wrong Random Forest class.");
170 int const num_instances = features.shape()[0];
171 size_t const num_features = features.shape()[1];
172 auto const &
spec = tree.problem_spec_;
174 vigra_precondition(num_instances == labels.
size(),
175 "random_forest_single_tree(): Shape mismatch between features and labels.");
176 vigra_precondition(num_features ==
spec.num_features_,
177 "random_forest_single_tree(): Wrong number of features.");
186 if (options.bootstrap_sampling_)
190 SamplerOptions().withReplacement().stratified(options.use_stratification_),
193 for (
int i = 0;
i <
sampler.sampleSize(); ++
i)
201 if (options.class_weights_.size() > 0)
208 auto const mtry =
spec.actual_mtry_;
213 typedef std::pair<InstanceIter, InstanceIter>
IterPair;
218 auto const rootnode = tree.graph_.addNode();
223 std::vector<double>
priors(
spec.num_classes_, 0.0);
248 for (
auto it = begin; it != end; ++it)
272 auto indices = std::vector<size_t>(options.resample_count_);
273 for (
size_t i = 0;
i < options.resample_count_; ++
i)
288 if (!
score.split_found_)
296 auto const n_left = tree.graph_.addNode();
297 auto const n_right = tree.graph_.addNode();
298 tree.graph_.addArc(node,
n_left);
299 tree.graph_.addArc(node,
n_right);
302 auto const split_iter = std::partition(begin, end,
378 typedef typename Labels::value_type LabelType;
382 pspec.num_instances(features.shape()[0])
383 .num_features(features.shape()[1])
385 .actual_msample(labels.
size());
388 size_t const tree_count = options.tree_count_;
389 vigra_precondition(tree_count > 0,
"random_forest_impl(): tree_count must not be zero.");
390 std::vector<RF>
trees(tree_count);
409 vigra_precondition(options.class_weights_.size() == 0 || options.class_weights_.size() ==
distinct_labels.
size(),
410 "random_forest_impl(): The number of class weights must be 0 or equal to the number of classes.");
413 for (
auto & t :
trees)
414 t.problem_spec_ =
pspec;
417 size_t n_threads = 1;
418 if (options.n_threads_ >= 1)
419 n_threads = options.n_threads_;
420 else if (options.n_threads_ == -1)
421 n_threads = std::thread::hardware_concurrency();
425 std::set<UInt32>
seeds;
430 vigra_assert(
seeds.
size() == n_threads,
"random_forest_impl(): Could not create random seeds.");
434 for (
auto seed :
seeds)
440 visitor.visit_before_training();
446 for (
size_t i = 0;
i < tree_count; ++
i)
453 std::vector<threading::future<void> >
futures;
454 for (
size_t i = 0;
i < tree_count; ++
i)
459 random_forest_single_tree<RF, SCORER, VisitorCopyType, STOP>(features, transformed_labels, options, tree_visitors[i], stop, trees[i], rand_engines[thread_id]);
469 rf.options_ = options;
484template <
typename FEATURES,
typename LABELS,
typename VISITOR,
typename SCORER,
typename RANDENGINE>
494 if (options.max_depth_ > 0)
495 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, DepthStop, RANDENGINE>(features, labels, options,
visitor,
DepthStop(options.max_depth_),
randengine);
496 else if (options.min_num_instances_ > 1)
497 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, NumInstancesStop, RANDENGINE>(features, labels, options,
visitor,
NumInstancesStop(options.min_num_instances_),
randengine);
498 else if (options.node_complexity_tau_ > 0)
499 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, NodeComplexityStop, RANDENGINE>(features, labels, options,
visitor,
NodeComplexityStop(options.node_complexity_tau_),
randengine);
501 return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, PurityStop, RANDENGINE>(features, labels, options,
visitor,
PurityStop(),
randengine);
579template <
typename FEATURES,
typename LABELS,
typename VISITOR,
typename RANDENGINE>
589 typedef detail::GeneralScorer<GiniScore>
GiniScorer;
591 typedef detail::GeneralScorer<KolmogorovSmirnovScore>
KSDScorer;
592 if (options.split_ == RF_GINI)
593 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, GiniScorer, RANDENGINE>(features, labels, options,
visitor,
randengine);
594 else if (options.split_ == RF_ENTROPY)
595 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, EntropyScorer, RANDENGINE>(features, labels, options,
visitor,
randengine);
596 else if (options.split_ == RF_KSD)
597 return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, KSDScorer, RANDENGINE>(features, labels, options,
visitor,
randengine);
599 throw std::runtime_error(
"random_forest(): Unknown split criterion.");
602template <
typename FEATURES,
typename LABELS,
typename VISITOR>
615template <
typename FEATURES,
typename LABELS>
617RandomForest<FEATURES, LABELS>
619 FEATURES
const & features,
620 LABELS
const & labels,
621 RandomForestOptions
const & options
627template <
typename FEATURES,
typename LABELS>
629RandomForest<FEATURES, LABELS>
631 FEATURES
const & features,
632 LABELS
const & labels
634 return random_forest(features, labels, RandomForestOptions());