#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include "mathutil.hxx"
#include "array_vector.hxx"
#include "sized_int.hxx"
#include "metaprogramming.hxx"
#include "functorexpression.hxx"
#include "random_forest/rf_common.hxx"
#include "random_forest/rf_nodeproxy.hxx"
#include "random_forest/rf_split.hxx"
#include "random_forest/rf_decisionTree.hxx"
#include "random_forest/rf_visitors.hxx"
#include "random_forest/rf_region.hxx"
#include "sampling.hxx"
#include "random_forest/rf_preprocessing.hxx"
#include "random_forest/rf_online_prediction_set.hxx"
#include "random_forest/rf_earlystopping.hxx"
#include "random_forest/rf_ridge_split.hxx"

namespace vigra
{

namespace detail
{

// translate the relevant RandomForestOptions into SamplerOptions
inline SamplerOptions make_sampler_opt(RandomForestOptions & RF_opt)
{
    SamplerOptions return_opt;
    return_opt.withReplacement(RF_opt.sample_with_replacement_);
    return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
    return return_opt;
}

} // namespace detail
template <class LabelType = double, class PreprocessorTag = ClassificationTag>
class RandomForest
{
  public:
    typedef RandomForestOptions          Options_t;
    typedef detail::DecisionTree         DecisionTree_t;
    typedef ProblemSpec<LabelType>       ProblemSpec_t;
    typedef GiniSplit                    Default_Split_t;
    typedef EarlyStoppStd                Default_Stop_t;
    typedef rf::visitors::StopVisiting   Default_Visitor_t;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator> StackEntry_t;
    typedef LabelType                    LabelT;
    /** Create a random forest from an external source (pre-built tree
        topology and parameter arrays).
     */
    template<class TopologyIterator, class ParameterIterator>
    RandomForest(int treeCount,
                 TopologyIterator topology_begin,
                 ParameterIterator parameter_begin,
                 ProblemSpec_t const & problem_spec,
                 Options_t const & options = Options_t())
    :   trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec),
        options_(options)
    {
        for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
        {
            trees_[k].topology_   = *topology_begin;
            trees_[k].parameters_ = *parameter_begin;
        }
    }
    ProblemSpec_t const & ext_param() const
    {
        vigra_precondition(ext_param_.used() == true,
            "RandomForest::ext_param(): "
            "Random forest has not been trained yet.");
        return ext_param_;
    }

    void set_ext_param(ProblemSpec_t const & in)
    {
        vigra_precondition(ext_param_.used() == false,
            "RandomForest::set_ext_param(): "
            "Random forest has been trained! Call reset() "
            "before specifying new extrinsic parameters.");
        ext_param_ = in;
    }
    DecisionTree_t const & tree(int index) const
    {
        return trees_[index];
    }

    DecisionTree_t & tree(int index)
    {
        return trees_[index];
    }
    int feature_count() const { return ext_param_.column_count_; }  // number of features used in training
    int column_count()  const { return ext_param_.column_count_; }  // synonym for feature_count()
    int class_count()   const { return ext_param_.class_count_; }   // number of classes used in training
    int tree_count()    const { return options_.tree_count_; }      // number of trees
    /** grow the existing trees with additional samples (rows from new_start_index on) */
    template <class U, class C1, class U2, class C2,
              class Split_t, class Stop_t, class Visitor_t, class Random_t>
    void onlineLearn(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & response,
                     int new_start_index,
                     Visitor_t visitor_, Split_t split_, Stop_t stop_,
                     Random_t const & random,
                     bool adjust_thresholds = false);

    /** online learning with default visitor, split, stop and random number generator */
    template <class U, class C1, class U2, class C2>
    void onlineLearn(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int new_start_index,
                     bool adjust_thresholds = false)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        onlineLearn(features, labels, new_start_index,
                    rf_default(), rf_default(), rf_default(), rnd, adjust_thresholds);
    }
    /** relearn a single tree with custom functors */
    template <class U, class C1, class U2, class C2,
              class Split_t, class Stop_t, class Visitor_t, class Random_t>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & response,
                     int treeId,
                     Visitor_t visitor_, Split_t split_, Stop_t stop_,
                     Random_t & random);

    /** relearn a single tree with default visitor, split, stop and random number generator */
    template <class U, class C1, class U2, class C2>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int treeId)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        reLearnTree(features, labels, treeId,
                    rf_default(), rf_default(), rf_default(), rnd);
    }
    /** learn on data with custom config and random number generator */
    template <class U, class C1, class U2, class C2,
              class Split_t, class Stop_t, class Visitor_t, class Random_t>
    void learn(MultiArrayView<2, U, C1> const & features,
               MultiArrayView<2, U2, C2> const & response,
               Visitor_t visitor, Split_t split, Stop_t stop,
               Random_t const & random);

    template <class U, class C1, class U2, class C2,
              class Split_t, class Stop_t, class Visitor_t>
    void learn(MultiArrayView<2, U, C1> const & features,
               MultiArrayView<2, U2, C2> const & response,
               Visitor_t visitor, Split_t split, Stop_t stop)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        learn(features, response, visitor, split, stop, rnd);
    }

    template <class U, class C1, class U2, class C2, class Visitor_t>
    void learn(MultiArrayView<2, U, C1> const & features,
               MultiArrayView<2, U2, C2> const & labels,
               Visitor_t visitor)
    {
        learn(features, labels, visitor, rf_default(), rf_default());
    }
    template <class U, class C1, class U2, class C2,
              class Visitor_t, class Split_t>
    void learn(MultiArrayView<2, U, C1> const & features,
               MultiArrayView<2, U2, C2> const & labels,
               Visitor_t visitor, Split_t split)
    {
        learn(features, labels, visitor, split, rf_default());
    }

    /** learn on data with default configuration */
    template <class U, class C1, class U2, class C2>
    void learn(MultiArrayView<2, U, C1> const & features,
               MultiArrayView<2, U2, C2> const & labels)
    {
        learn(features, labels, rf_default(), rf_default(), rf_default());
    }
    /** predict a label given a feature vector, using an early-stopping functor */
    template <class U, class C, class Stop>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const;

    /** predict a label given a feature vector */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        return predictLabel(features, rf_default());
    }

    /** predict a label given a feature vector and class priors */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                           ArrayVectorView<double> prior) const;
    /** predict multiple labels with given features */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
        {
            vigra_precondition(!detail::contains_nan(rowVector(features, k)),
                "RandomForest::predictLabels(): NaN in feature matrix.");
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(
                              predictLabel(rowVector(features, k)));
        }
    }
    /** predict multiple labels with given features; rows containing NaN
        are assigned nanLabel instead of being predicted */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels,
                       LabelType nanLabel) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
        {
            if(detail::contains_nan(rowVector(features, k)))
                labels(k,0) = nanLabel;
            else
                labels(k,0) = detail::RequiresExplicitCast<T>::cast(
                                  predictLabel(rowVector(features, k)));
        }
    }
    /** predict multiple labels with given features, using an early-stopping functor */
    template <class U, class C1, class T, class C2, class Stop>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels,
                       Stop & stop) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(
                              predictLabel(rowVector(features, k), stop));
    }
    /** predict the class probabilities for multiple rows, with early stopping */
    template <class U, class C1, class T, class C2, class Stop>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob,
                              Stop & stop) const;

    /** predict class probabilities for an online prediction set */
    template <class T1, class T2, class C>
    void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                              MultiArrayView<2, T2, C> & prob);

    /** predict the class probabilities for multiple rows */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const
    {
        predictProbabilities(features, prob, rf_default());
    }

    /** tree-averaged prediction scores, without per-row normalization */
    template <class U, class C1, class T, class C2>
    void predictRaw(MultiArrayView<2, U, C1> const & features,
                    MultiArrayView<2, T, C2> & prob) const;

    // data members
    Options_t                        options_;
    ProblemSpec_t                    ext_param_;
    rf::visitors::OnlineLearnVisitor online_visitor_;
    ArrayVector<DecisionTree_t>      trees_;
};
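/* Usage sketch for the interface above -- a minimal classification example.
   It assumes a feature matrix with one row per sample and an integer label
   column; num_samples/num_features are placeholders, and the tree_count()
   option setter is an assumption based on the tree_count_ option member.

       using namespace vigra;

       MultiArray<2, double> features(Shape2(num_samples, num_features));
       MultiArray<2, int>    labels(Shape2(num_samples, 1));
       // ... fill features and labels ...

       RandomForest<int> rf(RandomForestOptions().tree_count(255));
       rf.learn(features, labels);

       MultiArray<2, int>    predicted(Shape2(num_samples, 1));
       MultiArray<2, double> prob(Shape2(num_samples, rf.class_count()));
       rf.predictLabels(features, predicted);
       rf.predictProbabilities(features, prob);
*/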
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class U2, class C2,
          class Split_t, class Stop_t, class Visitor_t, class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
    onlineLearn(MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2, C2> const & response,
                int new_start_index,
                Visitor_t visitor_, Split_t split_, Stop_t stop_,
                Random_t const & random,
                bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds = adjust_thresholds;

    // see rf_preprocessing.hxx for details on the Processor class
    typedef Processor<PreprocessorTag, LabelType, U, C1, U2, C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t> RandFunctor_t;

    // pick the user-supplied split/stop/visitor functors, or the defaults
    // whenever rf_default() was passed
    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
        = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
        = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
        <rf::visitors::OnlineLearnVisitor,
         typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));

    vigra_precondition(options_.prepare_online_learning_,
        "onlineLearn(): online learning must be enabled on RandomForest construction");

    // reset the class count so the preprocessor can update the problem spec
    ext_param_.class_count_ = 0;
    Preprocessor_t preprocessor(features, response,
                                options_, ext_param_);

    // STL-compatible random functor
    RandFunctor_t randint(random);

    // pass the (updated) external parameters to the split and stop functors
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    // Poisson-sample the newly added rows
    PoissonSampler<RandomTT800> poisson_sampler(1.0,
                                                vigra::Int32(new_start_index),
                                                vigra::Int32(ext_param().row_count_));

    for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
    {
        online_visitor_.tree_id = ii;
        poisson_sampler.sample();
        std::map<int,int> leaf_parents;
        leaf_parents.clear();

        // route each new sample to its leaf and remember impure leaves
        for(int s=0; s<poisson_sampler.numOfSamples(); ++s)
        {
            int sample = poisson_sampler[s];
            online_visitor_.current_label = preprocessor.response()(sample,0);
            online_visitor_.last_node_id  = StackEntry_t::DecisionTreeNoParent;
            int leaf = trees_[ii].getToLeaf(rowVector(features,sample), online_visitor_);

            online_visitor_.add_to_index_list(ii, leaf, sample);

            // if the leaf is not pure w.r.t. the new sample's class, it must be split further
            if(Node<e_ConstProbNode>(trees_[ii].topology_, trees_[ii].parameters_, leaf)
                   .prob_begin()[preprocessor.response()(sample,0)] != 1.0)
            {
                leaf_parents[leaf] = online_visitor_.last_node_id;
            }
        }

        // continue learning below every impure leaf
        std::map<int,int>::iterator leaf_iterator;
        for(leaf_iterator=leaf_parents.begin(); leaf_iterator!=leaf_parents.end(); ++leaf_iterator)
        {
            int leaf      = leaf_iterator->first;
            int parent    = leaf_iterator->second;
            int lin_index = online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indeces;
            indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indeces.begin(),
                                     indeces.end(),
                                     ext_param_.class_count_);

            if(NodeBase(trees_[ii].topology_, trees_[ii].parameters_, parent).child(0) == leaf)
            {
                stack_entry.leftParent = parent;
            }
            else
            {
                vigra_assert(NodeBase(trees_[ii].topology_, trees_[ii].parameters_, parent).child(1) == leaf,
                             "last_node_id seems to be wrong");
                stack_entry.rightParent = parent;
            }

            trees_[ii].continueLearn(preprocessor.features(), preprocessor.response(),
                                     stack_entry, split, stop, visitor, randint, -1);
            // the replacement node was appended at the end of the topology array
            online_visitor_.move_exterior_node(ii, trees_[ii].topology_.size(), ii, leaf);
        }
    }
    online_visitor_.deactivate();
}
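/* Online-learning sketch: how onlineLearn() above is meant to be called,
   assuming the forest was created with online learning enabled and that new
   sample rows were appended to the bottom of the feature/label arrays. The
   prepare_online_learning() option setter is an assumption suggested by the
   prepare_online_learning_ flag tested above.

       RandomForest<int> rf(RandomForestOptions().prepare_online_learning(true));
       rf.learn(features, labels);                       // initial training

       int old_row_count = features.shape(0);
       // ... append new rows to features and labels ...
       rf.onlineLearn(features, labels, old_row_count);  // grow the trees with the new rows
*/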
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class U2, class C2,
          class Split_t, class Stop_t, class Visitor_t, class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
    reLearnTree(MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2, C2> const & response,
                int treeId,
                Visitor_t visitor_, Split_t split_, Stop_t stop_,
                Random_t & random)
{
    typedef Processor<PreprocessorTag, LabelType, U, C1, U2, C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t> RandFunctor_t;

    // reset the class count so the preprocessor can update the problem spec
    ext_param_.class_count_ = 0;

    // choose between user-supplied and default split/stop/visitor functors
    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
        = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
        = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
        <rf::visitors::OnlineLearnVisitor,
         typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));

    vigra_precondition(options_.prepare_online_learning_,
        "reLearnTree(): relearning a tree only makes sense if online learning is enabled");
    online_visitor_.activate();

    // STL-compatible random functor
    RandFunctor_t randint(random);

    // preprocess the data and complete the external parameters
    Preprocessor_t preprocessor(features, response,
                                options_, ext_param_);

    // pass the external parameters on to the split and stop functors
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    // draw a fresh (possibly stratified) sample for the tree to be rebuilt
    Sampler<Random_t> sampler(preprocessor.strata().begin(),
                              preprocessor.strata().end(),
                              detail::make_sampler_opt(options_)
                                  .sampleSize(ext_param().actual_msample_),
                              &random);
    sampler.sample();

    // initialize the first region/node/stack entry
    StackEntry_t
        first_stack_entry(sampler.sampledIndices().begin(),
                          sampler.sampledIndices().end(),
                          ext_param_.class_count_);
    first_stack_entry
        .set_oob_range(sampler.oobIndices().begin(),
                       sampler.oobIndices().end());
    online_visitor_.reset_tree(treeId);
    online_visitor_.tree_id = treeId;
    trees_[treeId].reset();
    trees_[treeId]
        .learn(preprocessor.features(),
               preprocessor.response(),
               first_stack_entry,
               split, stop, visitor, randint);
    visitor
        .visit_after_tree(*this,
                          preprocessor, sampler, first_stack_entry, treeId);
    online_visitor_.deactivate();
}
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class U2, class C2,
          class Split_t, class Stop_t, class Visitor_t, class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
    learn(MultiArrayView<2, U, C1> const & features,
          MultiArrayView<2, U2, C2> const & response,
          Visitor_t visitor_,
          Split_t split_,
          Stop_t stop_,
          Random_t const & random)
{
    // see rf_preprocessing.hxx for details on the Processor class
    typedef Processor<PreprocessorTag, LabelType, U, C1, U2, C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t> RandFunctor_t;

    vigra_precondition(features.shape(0) == response.shape(0),
        "RandomForest::learn(): shape mismatch between features and response.");

    // pick the user-supplied split/stop/visitor functors, or the defaults
    // whenever rf_default() was passed
    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
        = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
        = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
        <rf::visitors::OnlineLearnVisitor,
         typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));

    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();

    // STL-compatible random functor
    RandFunctor_t randint(random);

    // preprocess the data and complete the external parameters
    Preprocessor_t preprocessor(features, response,
                                options_, ext_param_);

    // pass the external parameters on to the split and stop functors
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    // create the requested number of (empty) trees
    trees_.resize(options_.tree_count_, DecisionTree_t(ext_param_));

    Sampler<Random_t> sampler(preprocessor.strata().begin(),
                              preprocessor.strata().end(),
                              detail::make_sampler_opt(options_)
                                  .sampleSize(ext_param().actual_msample_),
                              &random);

    visitor.visit_at_beginning(*this, preprocessor);

    // train one tree per iteration
    for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
    {
        // draw a fresh bootstrap/stratified sample for this tree
        sampler.sample();
        StackEntry_t
            first_stack_entry(sampler.sampledIndices().begin(),
                              sampler.sampledIndices().end(),
                              ext_param_.class_count_);
        first_stack_entry
            .set_oob_range(sampler.oobIndices().begin(),
                           sampler.oobIndices().end());
        trees_[ii]
            .learn(preprocessor.features(),
                   preprocessor.response(),
                   first_stack_entry,
                   split, stop, visitor, randint);
        visitor
            .visit_after_tree(*this,
                              preprocessor, sampler, first_stack_entry, ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    // the online visitor is only needed while training
    online_visitor_.deactivate();
}
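/* Visitor sketch: learn() notifies visitor objects at the hooks used above
   (visit_at_beginning, visit_after_tree, visit_at_end). A typical use,
   assuming the OOB_Error visitor and the create_visitor() helper from
   rf_visitors.hxx; the oob_breiman member name is an assumption.

       rf::visitors::OOB_Error oob;
       rf.learn(features, labels, rf::visitors::create_visitor(oob));
       double oob_error = oob.oob_breiman;   // out-of-bag error estimate
*/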
template <class LabelType, class Tag>
template <class U, class C, class Stop>
LabelType RandomForest<LabelType, Tag>
    ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
{
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel():"
        " Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double> probabilities(1, ext_param_.class_count_);
    LabelType d;
    predictProbabilities(features, probabilities, stop);
    ext_param_.to_classlabel(argMax(probabilities), d);
    return d;
}
template <class LabelType, class PreprocessorTag>
template <class U, class C>
LabelType RandomForest<LabelType, PreprocessorTag>
    ::predictLabel(MultiArrayView<2, U, C> const & features,
                   ArrayVectorView<double> priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double> prob(1, ext_param_.class_count_);
    predictProbabilities(features, prob);
    // weight the predicted probabilities with the given class priors
    std::transform(prob.begin(), prob.end(),
                   priors.begin(), prob.begin(),
                   Arg1()*Arg2());
    LabelType d;
    ext_param_.to_classlabel(argMax(prob), d);
    return d;
}
template <class LabelType, class PreprocessorTag>
template <class T1, class T2, class C>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                           MultiArrayView<2, T2, C> & prob)
{
    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");
    vigra_precondition(columnCount(predictionSet.features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob)
            == static_cast<MultiArrayIndex>(ext_param_.class_count_),
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");

    prob.init(NumericTraits<T2>::zero());
    std::vector<T1> totalWeights(predictionSet.indices[0].size(), 0.0);
    int set_id = -1;

    for(int k=0; k<options_.tree_count_; ++k)
    {
        set_id = (set_id+1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;

        // the stack holds pairs of (node index, range of samples reaching that
        // node); all ranges start at the root (index 2)
        std::vector<std::pair<int,set_it> > stack;
        for(set_it i=predictionSet.ranges[set_id].begin();
            i!=predictionSet.ranges[set_id].end(); ++i)
            stack.push_back(std::pair<int,set_it>(2, i));

        int num_decisions = 0;
        while(!stack.empty())
        {
            set_it range = stack.back().second;
            int index    = stack.back().first;
            stack.pop_back();
            // count the decisions made for this tree
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                // accumulate this leaf's class weights for every sample in the range
                ArrayVector<double>::iterator weights
                    = Node<e_ConstProbNode>(trees_[k].topology_,
                                            trees_[k].parameters_,
                                            index).prob_begin();
                for(int i=range->start; i!=range->end; ++i)
                {
                    for(int l=0; l<ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
                        totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
                    }
                }
                continue;
            }

            if(trees_[k].topology_[index] != i_ThresholdNode)
            {
                throw std::runtime_error(
                    "predicting with online prediction sets is only supported for RFs with threshold nodes");
            }
            Node<i_ThresholdNode> node(trees_[k].topology_, trees_[k].parameters_, index);

            if(range->min_boundaries[node.column()] >= node.threshold())
            {
                // the whole range goes to the right child
                stack.push_back(std::pair<int,set_it>(node.child(1), range));
                continue;
            }
            if(range->max_boundaries[node.column()] < node.threshold())
            {
                // the whole range goes to the left child
                stack.push_back(std::pair<int,set_it>(node.child(0), range));
                continue;
            }

            // the range has to be split between the two children
            SampleRange<T1> new_range = *range;
            new_range.min_boundaries[node.column()] = FLT_MAX;
            range->max_boundaries[node.column()]    = -FLT_MAX;
            new_range.start = new_range.end = range->end;
            int i = range->start;
            while(i != range->end)
            {
                if(predictionSet.features(predictionSet.indices[set_id][i], node.column()) >= node.threshold())
                {
                    new_range.min_boundaries[node.column()]
                        = std::min(new_range.min_boundaries[node.column()],
                                   predictionSet.features(predictionSet.indices[set_id][i], node.column()));
                    // move the sample to the back of the range (right child)
                    --range->end;
                    --new_range.start;
                    std::swap(predictionSet.indices[set_id][i], predictionSet.indices[set_id][range->end]);
                }
                else
                {
                    range->max_boundaries[node.column()]
                        = std::max(range->max_boundaries[node.column()],
                                   predictionSet.features(predictionSet.indices[set_id][i], node.column()));
                    ++i;
                }
            }

            if(range->start == range->end)
            {
                predictionSet.ranges[set_id].erase(range);
            }
            else
            {
                stack.push_back(std::pair<int,set_it>(node.child(0), range));
            }
            if(new_range.start != new_range.end)
            {
                std::pair<set_it,bool> new_it = predictionSet.ranges[set_id].insert(new_range);
                stack.push_back(std::pair<int,set_it>(node.child(1), new_it.first));
            }
        }
        predictionSet.cumulativePredTime[k] = num_decisions;
    }

    // normalize: each row's probabilities must sum to one
    for(unsigned int i=0; i<totalWeights.size(); ++i)
    {
        double test = 0.0;
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            test += prob(i, l);
            prob(i, l) /= totalWeights[i];
        }
        assert(test == totalWeights[i]);
        assert(totalWeights[i] > 0.0);
    }
}
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1> const & features,
                           MultiArrayView<2, T, C2> & prob,
                           Stop_t & stop_) const
{
    // check that the feature and probability matrices are well defined
    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob)
            == static_cast<MultiArrayIndex>(ext_param_.class_count_),
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");

    // pick the user-supplied early-stopping functor, or the default one
    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
        = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);

    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());

    // classify each row
    for(int row=0; row < rowCount(features); ++row)
    {
        MultiArrayView<2, U, StridedArrayTag> currentRow(rowVector(features, row));

        // when the features contain a NaN, the instance doesn't belong to any
        // class -- indicate this by leaving a zero probability row
        if(detail::contains_nan(currentRow))
        {
            rowVector(prob, row).init(0.0);
            continue;
        }

        ArrayVector<double>::const_iterator weights;

        // totalWeight == totalVoteCount
        double totalWeight = 0.0;

        // let each tree classify the row
        for(int k=0; k<options_.tree_count_; ++k)
        {
            // get the weights predicted by a single tree
            weights = trees_[k].predict(currentRow);

            // update the votes; weight them by the tree weight if weighted
            // prediction was requested
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                             + (1-weighted));
                prob(row, l) += static_cast<T>(cur_w);
                totalWeight  += cur_w;
            }
            // stop early if the stopping criterion is already satisfied
            if(stop.after_prediction(weights, k, rowVector(prob, row), totalWeight))
                break;
        }

        // normalize the row to a probability distribution
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }
}
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2>
void RandomForest<LabelType, PreprocessorTag>
    ::predictRaw(MultiArrayView<2, U, C1> const & features,
                 MultiArrayView<2, T, C2> & prob) const
{
    // check that the feature and probability matrices are well defined
    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictRaw():"
        " Feature matrix and probability matrix size mismatch.");
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictRaw():"
        " Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob)
            == static_cast<MultiArrayIndex>(ext_param_.class_count_),
        "RandomForest::predictRaw():"
        " Probability matrix must have as many columns as there are classes.");

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    prob.init(NumericTraits<T>::zero());

    // accumulate the (optionally weighted) votes of all trees for each row
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        // totalWeight == totalVoteCount
        double totalWeight = 0.0;

        for(int k=0; k<options_.tree_count_; ++k)
        {
            // get the weights predicted by a single tree
            weights = trees_[k].predict(rowVector(features, row));

            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                             + (1-weighted));
                prob(row, l) += static_cast<T>(cur_w);
                totalWeight  += cur_w;
            }
        }
    }
    // average over the trees instead of normalizing each row
    prob /= options_.tree_count_;
}
} // namespace vigra

#include "random_forest/rf_algorithm.hxx"

#endif // VIGRA_RANDOM_FOREST_HXX