
rf_visitors.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
37 
38 #ifdef HasHDF5
39 # include "vigra/hdf5impex.hxx"
40 #endif // HasHDF5
41 #include <vigra/windows.h>
42 #include <iostream>
43 #include <iomanip>
44 
45 #include <vigra/metaprogramming.hxx>
46 #include <vigra/multi_pointoperators.hxx>
47 #include <vigra/timing.hxx>
48 
49 namespace vigra
50 {
51 namespace rf
52 {
53 /** \addtogroup MachineLearning Machine Learning
54 **/
55 //@{
56 
57 /**
58  This namespace contains all classes and methods related to extracting information during
59  learning of the random forest. All visitors share the same interface defined in
60  visitors::VisitorBase. The member methods are invoked at certain points of the main code in
61  the order they were supplied.
62 
63  For the Random Forest the Visitor concept is implemented as a statically linked list
64  (using templates). Each visitor object is encapsulated in a detail::VisitorNode object. The
65  VisitorNode object calls the next visitor after one of its visit() methods has terminated.
66 
67  To simplify usage, create_visitor() factory methods are supplied.
68  Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
69  It is possible to supply more than one visitor. They will then be invoked in serial order.
70 
71  The calculated information is stored in public data members of the visitor classes - see the
72  documentation of the individual visitors.
73 
74  When creating a new visitor, the new class should publicly inherit from visitors::VisitorBase
75  (see e.g. visitors::OOB_Error).
76 
77  \code
78 
79  typedef xxx feature_t; // replace xxx with the actual feature type
80  typedef yyy label_t;   // same for the label type
81  MultiArrayView<2, feature_t> f = get_some_features();
82  MultiArrayView<2, label_t> l = get_some_labels();
83  RandomForest<> rf;
84 
85  //calculate OOB Error
86  visitors::OOB_Error oob_v;
87  //calculate Variable Importance
88  visitors::VariableImportanceVisitor varimp_v;
89 
90  double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
91  //the results can be found in the public data members of oob_v and varimp_v now
92 
93  \endcode
94 */
95 namespace visitors
96 {
97 
98 
99 /** Base Class from which all Visitors derive. Can be used as a template to create new
100  * Visitors.
101  */
102 class VisitorBase
103 {
104  public:
105  bool active_;
106  bool is_active()
107  {
108  return active_;
109  }
110 
111  bool has_value()
112  {
113  return false;
114  }
115 
116  VisitorBase()
117  : active_(true)
118  {}
119 
120  void deactivate()
121  {
122  active_ = false;
123  }
124  void activate()
125  {
126  active_ = true;
127  }
128 
129  /** do something after the Split has decided how to process the Region
130  * (Stack entry)
131  *
132  * \param tree reference to the tree that is currently being learned
133  * \param split reference to the split object
134  * \param parent current stack entry which was used to decide the split
135  * \param leftChild left stack entry that will be pushed
136  * \param rightChild
137  * right stack entry that will be pushed.
138  * \param features features matrix
139  * \param labels label matrix
140  * \sa RF_Traits::StackEntry_t
141  */
142  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
143  void visit_after_split( Tree & tree,
144  Split & split,
145  Region & parent,
146  Region & leftChild,
147  Region & rightChild,
148  Feature_t & features,
149  Label_t & labels)
150  {
151  ignore_argument(tree,split,parent,leftChild,rightChild,features,labels);
152  }
153 
154  /** do something after each tree has been learned
155  *
156  * \param rf reference to the random forest object that called this
157  * visitor
158  * \param pr reference to the preprocessor that processed the input
159  * \param sm reference to the sampler object
160  * \param st reference to the first stack entry
161  * \param index index of current tree
162  */
163  template<class RF, class PR, class SM, class ST>
164  void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
165  {
166  ignore_argument(rf,pr,sm,st,index);
167  }
168 
169  /** do something after all trees have been learned
170  *
171  * \param rf reference to the random forest object that called this
172  * visitor
173  * \param pr reference to the preprocessor that processed the input
174  */
175  template<class RF, class PR>
176  void visit_at_end(RF const & rf, PR const & pr)
177  {
178  ignore_argument(rf,pr);
179  }
180 
181  /** do something before learning starts
182  *
183  * \param rf reference to the random forest object that called this
184  * visitor
185  * \param pr reference to the Processor class used.
186  */
187  template<class RF, class PR>
188  void visit_at_beginning(RF const & rf, PR const & pr)
189  {
190  ignore_argument(rf,pr);
191  }
192  /** do something while traversing the tree after it has been learned
193  * (external nodes)
194  *
195  * \param tr reference to the tree object that called this visitor
196  * \param index index in the topology_ array we currently are at
197  * \param node_t type of node (one of the external node tags, e_...)
198  * \param features feature matrix
199  * \sa NodeTags;
200  *
201  * You can create the node by switching on node_tag and using the
202  * corresponding Node objects - or, if you do not care about the type,
203  * use the NodeBase class.
204  */
205  template<class TR, class IntT, class TopT,class Feat>
206  void visit_external_node(TR & tr, IntT index, TopT node_t, Feat & features)
207  {
208  ignore_argument(tr,index,node_t,features);
209  }
210 
211  /** do something when visiting an internal node after it has been learned
212  *
213  * \sa visit_external_node
214  */
215  template<class TR, class IntT, class TopT,class Feat>
216  void visit_internal_node(TR & /* tr */, IntT /* index */, TopT /* node_t */, Feat & /* features */)
217  {}
218 
219  /** return a double value. The value of the first
220  * visitor encountered that has a return value is returned with the
221  * RandomForest::learn() method - or -1.0 if no visitor with a return
222  * value existed. This functionality basically only exists so that the
223  * OOB visitor can return the oob error rate like in the old version
224  * of the random forest.
225  */
226  double return_val()
227  {
228  return -1.0;
229  }
230 };
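/* Example (sketch only, not part of the library): a minimal custom visitor.
 * It publicly inherits from visitors::VisitorBase and overrides just the hooks
 * it needs; the class name SplitCounter and the surrounding usage are made up
 * for illustration.
 *
 * \code
 * class SplitCounter : public vigra::rf::visitors::VisitorBase
 * {
 *   public:
 *     int split_count;
 *     SplitCounter() : split_count(0) {}
 *
 *     // called once per split; all other hooks keep the (empty) base behaviour
 *     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
 *     void visit_after_split(Tree &, Split &, Region &, Region &, Region &,
 *                            Feature_t &, Label_t &)
 *     {
 *         ++split_count;
 *     }
 * };
 *
 * // usage: SplitCounter counter;
 * //        rf.learn(features, labels, vigra::rf::visitors::create_visitor(counter));
 * \endcode
 */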
231 
232 
233 /** Last Visitor that should be called to stop the recursion.
234  */
235 class StopVisiting: public VisitorBase
236 {
237  public:
238  bool has_value()
239  {
240  return true;
241  }
242  double return_val()
243  {
244  return -1.0;
245  }
246 };
247 namespace detail
248 {
249 /** Container elements of the statically linked Visitor list.
250  *
251  * Use the create_visitor() factory functions to create visitor chains of up to 10 visitors.
252  *
253  */
254 template <class Visitor, class Next = StopVisiting>
255 class VisitorNode
256 {
257  public:
258 
259  StopVisiting stop_;
260  Next next_;
261  Visitor & visitor_;
262  VisitorNode(Visitor & visitor, Next & next)
263  :
264  next_(next), visitor_(visitor)
265  {}
266 
267  VisitorNode(Visitor & visitor)
268  :
269  next_(stop_), visitor_(visitor)
270  {}
271 
272  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
273  void visit_after_split( Tree & tree,
274  Split & split,
275  Region & parent,
276  Region & leftChild,
277  Region & rightChild,
278  Feature_t & features,
279  Label_t & labels)
280  {
281  if(visitor_.is_active())
282  visitor_.visit_after_split(tree, split,
283  parent, leftChild, rightChild,
284  features, labels);
285  next_.visit_after_split(tree, split, parent, leftChild, rightChild,
286  features, labels);
287  }
288 
289  template<class RF, class PR, class SM, class ST>
290  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
291  {
292  if(visitor_.is_active())
293  visitor_.visit_after_tree(rf, pr, sm, st, index);
294  next_.visit_after_tree(rf, pr, sm, st, index);
295  }
296 
297  template<class RF, class PR>
298  void visit_at_beginning(RF & rf, PR & pr)
299  {
300  if(visitor_.is_active())
301  visitor_.visit_at_beginning(rf, pr);
302  next_.visit_at_beginning(rf, pr);
303  }
304  template<class RF, class PR>
305  void visit_at_end(RF & rf, PR & pr)
306  {
307  if(visitor_.is_active())
308  visitor_.visit_at_end(rf, pr);
309  next_.visit_at_end(rf, pr);
310  }
311 
312  template<class TR, class IntT, class TopT,class Feat>
313  void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
314  {
315  if(visitor_.is_active())
316  visitor_.visit_external_node(tr, index, node_t,features);
317  next_.visit_external_node(tr, index, node_t,features);
318  }
319  template<class TR, class IntT, class TopT,class Feat>
320  void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
321  {
322  if(visitor_.is_active())
323  visitor_.visit_internal_node(tr, index, node_t,features);
324  next_.visit_internal_node(tr, index, node_t,features);
325  }
326 
327  double return_val()
328  {
329  if(visitor_.is_active() && visitor_.has_value())
330  return visitor_.return_val();
331  return next_.return_val();
332  }
333 };
334 
335 } //namespace detail
336 
337 //////////////////////////////////////////////////////////////////////////////
338 // Visitor Factory function up to 10 visitors //
339 //////////////////////////////////////////////////////////////////////////////
340 
341 /** factory method to be used with RandomForest::learn()
342  */
343 template<class A>
344 detail::VisitorNode<A>
345 create_visitor(A & a)
346 {
347  typedef detail::VisitorNode<A> _0_t;
348  _0_t _0(a);
349  return _0;
350 }
351 
352 
353 /** factory method to be used with RandomForest::learn()
354  */
355 template<class A, class B>
356 detail::VisitorNode<A, detail::VisitorNode<B> >
357 create_visitor(A & a, B & b)
358 {
359  typedef detail::VisitorNode<B> _1_t;
360  _1_t _1(b);
361  typedef detail::VisitorNode<A, _1_t> _0_t;
362  _0_t _0(a, _1);
363  return _0;
364 }
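/* Illustration: the factory functions only nest detail::VisitorNode objects.
 * For two visitors the call above returns an object of type
 * detail::VisitorNode<A, detail::VisitorNode<B, StopVisiting> >; each visit_*()
 * call runs a's hook first (if a is active), then b's, then the StopVisiting
 * sentinel ends the recursion. A rough sketch of what it builds internally
 * (oob_v and varimp_v as in the namespace example above):
 *
 * \code
 * visitors::OOB_Error oob_v;
 * visitors::VariableImportanceVisitor varimp_v;
 *
 * visitors::detail::VisitorNode<visitors::VariableImportanceVisitor> tail(varimp_v);
 * visitors::detail::VisitorNode<visitors::OOB_Error,
 *     visitors::detail::VisitorNode<visitors::VariableImportanceVisitor> > head(oob_v, tail);
 * \endcode
 */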
365 
366 
367 /** factory method to be used with RandomForest::learn()
368  */
369 template<class A, class B, class C>
370 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
371 create_visitor(A & a, B & b, C & c)
372 {
373  typedef detail::VisitorNode<C> _2_t;
374  _2_t _2(c);
375  typedef detail::VisitorNode<B, _2_t> _1_t;
376  _1_t _1(b, _2);
377  typedef detail::VisitorNode<A, _1_t> _0_t;
378  _0_t _0(a, _1);
379  return _0;
380 }
381 
382 
383 /** factory method to be used with RandomForest::learn()
384  */
385 template<class A, class B, class C, class D>
386 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
387  detail::VisitorNode<D> > > >
388 create_visitor(A & a, B & b, C & c, D & d)
389 {
390  typedef detail::VisitorNode<D> _3_t;
391  _3_t _3(d);
392  typedef detail::VisitorNode<C, _3_t> _2_t;
393  _2_t _2(c, _3);
394  typedef detail::VisitorNode<B, _2_t> _1_t;
395  _1_t _1(b, _2);
396  typedef detail::VisitorNode<A, _1_t> _0_t;
397  _0_t _0(a, _1);
398  return _0;
399 }
400 
401 
402 /** factory method to be used with RandomForest::learn()
403  */
404 template<class A, class B, class C, class D, class E>
405 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
406  detail::VisitorNode<D, detail::VisitorNode<E> > > > >
407 create_visitor(A & a, B & b, C & c,
408  D & d, E & e)
409 {
410  typedef detail::VisitorNode<E> _4_t;
411  _4_t _4(e);
412  typedef detail::VisitorNode<D, _4_t> _3_t;
413  _3_t _3(d, _4);
414  typedef detail::VisitorNode<C, _3_t> _2_t;
415  _2_t _2(c, _3);
416  typedef detail::VisitorNode<B, _2_t> _1_t;
417  _1_t _1(b, _2);
418  typedef detail::VisitorNode<A, _1_t> _0_t;
419  _0_t _0(a, _1);
420  return _0;
421 }
422 
423 
424 /** factory method to be used with RandomForest::learn()
425  */
426 template<class A, class B, class C, class D, class E,
427  class F>
428 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
429  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
430 create_visitor(A & a, B & b, C & c,
431  D & d, E & e, F & f)
432 {
433  typedef detail::VisitorNode<F> _5_t;
434  _5_t _5(f);
435  typedef detail::VisitorNode<E, _5_t> _4_t;
436  _4_t _4(e, _5);
437  typedef detail::VisitorNode<D, _4_t> _3_t;
438  _3_t _3(d, _4);
439  typedef detail::VisitorNode<C, _3_t> _2_t;
440  _2_t _2(c, _3);
441  typedef detail::VisitorNode<B, _2_t> _1_t;
442  _1_t _1(b, _2);
443  typedef detail::VisitorNode<A, _1_t> _0_t;
444  _0_t _0(a, _1);
445  return _0;
446 }
447 
448 
449 /** factory method to be used with RandomForest::learn()
450  */
451 template<class A, class B, class C, class D, class E,
452  class F, class G>
453 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
454  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
455  detail::VisitorNode<G> > > > > > >
456 create_visitor(A & a, B & b, C & c,
457  D & d, E & e, F & f, G & g)
458 {
459  typedef detail::VisitorNode<G> _6_t;
460  _6_t _6(g);
461  typedef detail::VisitorNode<F, _6_t> _5_t;
462  _5_t _5(f, _6);
463  typedef detail::VisitorNode<E, _5_t> _4_t;
464  _4_t _4(e, _5);
465  typedef detail::VisitorNode<D, _4_t> _3_t;
466  _3_t _3(d, _4);
467  typedef detail::VisitorNode<C, _3_t> _2_t;
468  _2_t _2(c, _3);
469  typedef detail::VisitorNode<B, _2_t> _1_t;
470  _1_t _1(b, _2);
471  typedef detail::VisitorNode<A, _1_t> _0_t;
472  _0_t _0(a, _1);
473  return _0;
474 }
475 
476 
477 /** factory method to be used with RandomForest::learn()
478  */
479 template<class A, class B, class C, class D, class E,
480  class F, class G, class H>
481 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
482  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
483  detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
484 create_visitor(A & a, B & b, C & c,
485  D & d, E & e, F & f,
486  G & g, H & h)
487 {
488  typedef detail::VisitorNode<H> _7_t;
489  _7_t _7(h);
490  typedef detail::VisitorNode<G, _7_t> _6_t;
491  _6_t _6(g, _7);
492  typedef detail::VisitorNode<F, _6_t> _5_t;
493  _5_t _5(f, _6);
494  typedef detail::VisitorNode<E, _5_t> _4_t;
495  _4_t _4(e, _5);
496  typedef detail::VisitorNode<D, _4_t> _3_t;
497  _3_t _3(d, _4);
498  typedef detail::VisitorNode<C, _3_t> _2_t;
499  _2_t _2(c, _3);
500  typedef detail::VisitorNode<B, _2_t> _1_t;
501  _1_t _1(b, _2);
502  typedef detail::VisitorNode<A, _1_t> _0_t;
503  _0_t _0(a, _1);
504  return _0;
505 }
506 
507 
508 /** factory method to be used with RandomForest::learn()
509  */
510 template<class A, class B, class C, class D, class E,
511  class F, class G, class H, class I>
512 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
513  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
514  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
515 create_visitor(A & a, B & b, C & c,
516  D & d, E & e, F & f,
517  G & g, H & h, I & i)
518 {
519  typedef detail::VisitorNode<I> _8_t;
520  _8_t _8(i);
521  typedef detail::VisitorNode<H, _8_t> _7_t;
522  _7_t _7(h, _8);
523  typedef detail::VisitorNode<G, _7_t> _6_t;
524  _6_t _6(g, _7);
525  typedef detail::VisitorNode<F, _6_t> _5_t;
526  _5_t _5(f, _6);
527  typedef detail::VisitorNode<E, _5_t> _4_t;
528  _4_t _4(e, _5);
529  typedef detail::VisitorNode<D, _4_t> _3_t;
530  _3_t _3(d, _4);
531  typedef detail::VisitorNode<C, _3_t> _2_t;
532  _2_t _2(c, _3);
533  typedef detail::VisitorNode<B, _2_t> _1_t;
534  _1_t _1(b, _2);
535  typedef detail::VisitorNode<A, _1_t> _0_t;
536  _0_t _0(a, _1);
537  return _0;
538 }
539 
540 /** factory method to be used with RandomForest::learn()
541  */
542 template<class A, class B, class C, class D, class E,
543  class F, class G, class H, class I, class J>
544 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
545  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
546  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
547  detail::VisitorNode<J> > > > > > > > > >
548 create_visitor(A & a, B & b, C & c,
549  D & d, E & e, F & f,
550  G & g, H & h, I & i,
551  J & j)
552 {
553  typedef detail::VisitorNode<J> _9_t;
554  _9_t _9(j);
555  typedef detail::VisitorNode<I, _9_t> _8_t;
556  _8_t _8(i, _9);
557  typedef detail::VisitorNode<H, _8_t> _7_t;
558  _7_t _7(h, _8);
559  typedef detail::VisitorNode<G, _7_t> _6_t;
560  _6_t _6(g, _7);
561  typedef detail::VisitorNode<F, _6_t> _5_t;
562  _5_t _5(f, _6);
563  typedef detail::VisitorNode<E, _5_t> _4_t;
564  _4_t _4(e, _5);
565  typedef detail::VisitorNode<D, _4_t> _3_t;
566  _3_t _3(d, _4);
567  typedef detail::VisitorNode<C, _3_t> _2_t;
568  _2_t _2(c, _3);
569  typedef detail::VisitorNode<B, _2_t> _1_t;
570  _1_t _1(b, _2);
571  typedef detail::VisitorNode<A, _1_t> _0_t;
572  _0_t _0(a, _1);
573  return _0;
574 }
575 
576 //////////////////////////////////////////////////////////////////////////////
577 // Visitors of communal interest. //
578 //////////////////////////////////////////////////////////////////////////////
579 
580 
581 /** Visitor to gain information, later needed for online learning.
582  */
583 
584 class OnlineLearnVisitor: public VisitorBase
585 {
586 public:
587  //Set if we adjust thresholds
588  bool adjust_thresholds;
589  //Current tree id
590  int tree_id;
591  //Last node id for finding parent
592  int last_node_id;
593  //Need to know the label for interior node visiting
594  vigra::Int32 current_label;
595  //marginal distribution for interior nodes
596  //
597  OnlineLearnVisitor():
598  adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
599  {}
600  struct MarginalDistribution
601  {
602  ArrayVector<Int32> leftCounts;
603  Int32 leftTotalCounts;
604  ArrayVector<Int32> rightCounts;
605  Int32 rightTotalCounts;
606  double gap_left;
607  double gap_right;
608  };
609  typedef ArrayVector<Int32> IndexList;
610 
611  //All information for one tree
612  struct TreeOnlineInformation
613  {
614  std::vector<MarginalDistribution> mag_distributions;
615  std::vector<IndexList> index_lists;
616  //map for linear index of mag_distributions
617  std::map<int,int> interior_to_index;
618  //map for linear index of index_lists
619  std::map<int,int> exterior_to_index;
620  };
621 
622  //All trees
623  std::vector<TreeOnlineInformation> trees_online_information;
624 
625  /** Initialize, set the number of trees
626  */
627  template<class RF,class PR>
628  void visit_at_beginning(RF & rf,const PR & /* pr */)
629  {
630  tree_id=0;
631  trees_online_information.resize(rf.options_.tree_count_);
632  }
633 
634  /** Reset a tree
635  */
636  void reset_tree(int tree_id)
637  {
638  trees_online_information[tree_id].mag_distributions.clear();
639  trees_online_information[tree_id].index_lists.clear();
640  trees_online_information[tree_id].interior_to_index.clear();
641  trees_online_information[tree_id].exterior_to_index.clear();
642  }
643 
644  /** simply increase the tree count
645  */
646  template<class RF, class PR, class SM, class ST>
647  void visit_after_tree(RF & /* rf */, PR & /* pr */, SM & /* sm */, ST & /* st */, int /* index */)
648  {
649  tree_id++;
650  }
651 
652  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
653  void visit_after_split( Tree & tree,
654  Split & split,
655  Region & parent,
656  Region & leftChild,
657  Region & rightChild,
658  Feature_t & features,
659  Label_t & /* labels */)
660  {
661  int linear_index;
662  int addr=tree.topology_.size();
663  if(split.createNode().typeID() == i_ThresholdNode)
664  {
665  if(adjust_thresholds)
666  {
667  //Store marginal distribution
668  linear_index=trees_online_information[tree_id].mag_distributions.size();
669  trees_online_information[tree_id].interior_to_index[addr]=linear_index;
670  trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
671 
672  trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
673  trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
674 
675  trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
676  trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
677  //Store the gap
678  double gap_left,gap_right;
679  int i;
680  gap_left=features(leftChild[0],split.bestSplitColumn());
681  for(i=1;i<leftChild.size();++i)
682  if(features(leftChild[i],split.bestSplitColumn())>gap_left)
683  gap_left=features(leftChild[i],split.bestSplitColumn());
684  gap_right=features(rightChild[0],split.bestSplitColumn());
685  for(i=1;i<rightChild.size();++i)
686  if(features(rightChild[i],split.bestSplitColumn())<gap_right)
687  gap_right=features(rightChild[i],split.bestSplitColumn());
688  trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
689  trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
690  }
691  }
692  else
693  {
694  //Store index list
695  linear_index=trees_online_information[tree_id].index_lists.size();
696  trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
697 
698  trees_online_information[tree_id].index_lists.push_back(IndexList());
699 
700  trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
701  std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
702  }
703  }
704  void add_to_index_list(int tree,int node,int index)
705  {
706  if(!this->active_)
707  return;
708  TreeOnlineInformation &ti=trees_online_information[tree];
709  ti.index_lists[ti.exterior_to_index[node]].push_back(index);
710  }
711  void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
712  {
713  if(!this->active_)
714  return;
715  trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
716  trees_online_information[src_tree].exterior_to_index.erase(src_index);
717  }
718  /** do something when visiting an internal node during getToLeaf
719  *
720  * remember as last node id, for finding the parent of the last external node
721  * also: adjust class counts and borders
722  */
723  template<class TR, class IntT, class TopT,class Feat>
724  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
725  {
726  last_node_id=index;
727  if(adjust_thresholds)
728  {
729  vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
730  //Check if we are in the gap
731  double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
732  TreeOnlineInformation &ti=trees_online_information[tree_id];
733  MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
734  if(value>m.gap_left && value<m.gap_right)
735  {
736  //Check which side we want to go to
737  if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
738  {
739  //We want to go left
740  m.gap_left=value;
741  }
742  else
743  {
744  //We want to go right
745  m.gap_right=value;
746  }
747  Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
748  }
749  //Adjust class counts
750  if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
751  {
752  ++m.rightTotalCounts;
753  ++m.rightCounts[current_label];
754  }
755  else
756  {
757  ++m.leftTotalCounts;
758  ++m.leftCounts[current_label];
759  }
760  }
761  }
762  /** do something when visiting an external node during getToLeaf
763  *
764  * Store the new index!
765  */
766 };
767 
768 //////////////////////////////////////////////////////////////////////////////
769 // Out of Bag Error estimates //
770 //////////////////////////////////////////////////////////////////////////////
771 
772 
773 /** Visitor that calculates the oob error of each individual randomized
774  * decision tree.
775  *
776  * After training a tree, all those samples that are OOB for this particular tree
777  * are put down the tree and the error is estimated.
778  * The per-tree oob error is the average of the individual error estimates
779  * (oobError = average error of one randomized tree).
780  * Note: this is not the OOB error estimate suggested by Breiman (see the OOB_Error
781  * visitor).
782  */
783 class OOB_PerTreeError: public VisitorBase
784 {
785 public:
786  /** Average error of one randomized decision tree
787  */
788  double oobError;
789 
790  int totalOobCount;
791  ArrayVector<int> oobCount,oobErrorCount;
792 
793  OOB_PerTreeError()
794  : oobError(0.0),
795  totalOobCount(0)
796  {}
797 
798 
799  bool has_value()
800  {
801  return true;
802  }
803 
804 
805  /** does the basic calculation per tree*/
806  template<class RF, class PR, class SM, class ST>
807  void visit_after_tree(RF & rf, PR & pr, SM & sm, ST &, int index)
808  {
809  //do the first time called.
810  if(int(oobCount.size()) != rf.ext_param_.row_count_)
811  {
812  oobCount.resize(rf.ext_param_.row_count_, 0);
813  oobErrorCount.resize(rf.ext_param_.row_count_, 0);
814  }
815  // go through the samples
816  for(int l = 0; l < rf.ext_param_.row_count_; ++l)
817  {
818  // if the lth sample is oob...
819  if(!sm.is_used()[l])
820  {
821  ++oobCount[l];
822  if( rf.tree(index)
823  .predictLabel(rowVector(pr.features(), l))
824  != pr.response()(l,0))
825  {
826  ++oobErrorCount[l];
827  }
828  }
829 
830  }
831  }
832 
833  /** Does the normalisation
834  */
835  template<class RF, class PR>
836  void visit_at_end(RF & rf, PR &)
837  {
838  // do some normalisation
839  for(int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
840  {
841  if(oobCount[l])
842  {
843  oobError += double(oobErrorCount[l]) / oobCount[l];
844  ++totalOobCount;
845  }
846  }
847  oobError/=totalOobCount;
848  }
849 
850 };
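/* Example (sketch): estimating the average per-tree OOB error during learning.
 * rf, features and labels stand for the user-provided RandomForest and
 * MultiArrayView<2, ...> data set up as in the namespace example above.
 *
 * \code
 * visitors::OOB_PerTreeError per_tree_v;
 * rf.learn(features, labels, visitors::create_visitor(per_tree_v));
 * std::cout << "average per-tree OOB error: " << per_tree_v.oobError << std::endl;
 * \endcode
 */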
851 
852 /** Visitor that calculates the oob error of the ensemble.
853  * This rate should be used to estimate the crossvalidation
854  * error rate.
855  * Here each sample is put down only those trees for which this sample
856  * is OOB, i.e. if sample #1 is OOB for trees 1, 3 and 5, we calculate
857  * the output using the ensemble consisting only of trees 1, 3 and 5.
858  *
859  * Using normal bagged sampling, each sample is OOB for approx. 33% of the trees.
860  * The error rate obtained in this way therefore corresponds to the crossvalidation
861  * error rate obtained using an ensemble containing 33% of the trees.
862  */
863 class OOB_Error : public VisitorBase
864 {
865  typedef MultiArrayShape<2>::type Shp;
866  int class_count;
867  bool is_weighted;
868  MultiArray<2,double> tmp_prob;
869  public:
870 
871  MultiArray<2, double> prob_oob;
872  /** Ensemble oob error rate
873  */
874  double oob_breiman;
875 
876  MultiArray<2, double> oobCount;
877  ArrayVector< int> indices;
878  OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
879 #ifdef HasHDF5
880  void save(std::string filen, std::string pathn)
881  {
882  if(*(pathn.end()-1) != '/')
883  pathn += "/";
884  const char* filename = filen.c_str();
885  MultiArray<2, double> temp(Shp(1,1), 0.0);
886  temp[0] = oob_breiman;
887  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
888  }
889 #endif
890  // negative value if the sample was in-bag (ib); the number indicates how often.
891  // value >= 0 if the sample was oob: 0 means misclassified, 1 means correct
892 
893  template<class RF, class PR>
894  void visit_at_beginning(RF & rf, PR &)
895  {
896  class_count = rf.class_count();
897  tmp_prob.reshape(Shp(1, class_count), 0);
898  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
899  is_weighted = rf.options().predict_weighted_;
900  indices.resize(rf.ext_param().row_count_);
901  if(int(oobCount.size()) != rf.ext_param_.row_count_)
902  {
903  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
904  }
905  for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
906  {
907  indices[ii] = ii;
908  }
909  }
910 
911  template<class RF, class PR, class SM, class ST>
912  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
913  {
914  // go through the samples
915  int total_oob =0;
916  // FIXME: magic number 10000: invoke special treatment when msample << sample_count
917  // (i.e. the OOB sample is very large)
918  // 40000: use at most 40000 OOB samples per class for OOB error estimate
919  if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
920  {
921  ArrayVector<int> oob_indices;
922  ArrayVector<int> cts(class_count, 0);
923  std::random_shuffle(indices.begin(), indices.end());
924  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
925  {
926  if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
927  {
928  oob_indices.push_back(indices[ii]);
929  ++cts[pr.response()(indices[ii], 0)];
930  }
931  }
932  for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
933  {
934  // update number of trees in which current sample is oob
935  ++oobCount[oob_indices[ll]];
936 
937  // update number of oob samples in this tree.
938  ++total_oob;
939  // get the predicted votes ---> tmp_prob;
940  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
941  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
942  rf.tree(index).parameters_,
943  pos);
944  tmp_prob.init(0);
945  for(int ii = 0; ii < class_count; ++ii)
946  {
947  tmp_prob[ii] = node.prob_begin()[ii];
948  }
949  if(is_weighted)
950  {
951  for(int ii = 0; ii < class_count; ++ii)
952  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
953  }
954  rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
955 
956  }
957  }else
958  {
959  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
960  {
961  // if the lth sample is oob...
962  if(!sm.is_used()[ll])
963  {
964  // update number of trees in which current sample is oob
965  ++oobCount[ll];
966 
967  // update number of oob samples in this tree.
968  ++total_oob;
969  // get the predicted votes ---> tmp_prob;
970  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
971  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
972  rf.tree(index).parameters_,
973  pos);
974  tmp_prob.init(0);
975  for(int ii = 0; ii < class_count; ++ii)
976  {
977  tmp_prob[ii] = node.prob_begin()[ii];
978  }
979  if(is_weighted)
980  {
981  for(int ii = 0; ii < class_count; ++ii)
982  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
983  }
984  rowVector(prob_oob, ll) += tmp_prob;
985  }
986  }
987  }
988  // go through the ib samples;
989  }
990 
991  /** Compute the ensemble (Breiman-style) OOB error after all trees have been learned.
992  */
993  template<class RF, class PR>
994  void visit_at_end(RF & rf, PR & pr)
995  {
996  // Ulli's original metric and Breiman-style error
997  int totalOobCount =0;
998  int breimanstyle = 0;
999  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1000  {
1001  if(oobCount[ll])
1002  {
1003  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1004  ++breimanstyle;
1005  ++totalOobCount;
1006  }
1007  }
1008  oob_breiman = double(breimanstyle)/totalOobCount;
1009  }
1010 };
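/* Example (sketch): estimating the Breiman-style ensemble OOB error. As above,
 * rf, features and labels are assumed to be set up by the caller; the HDF5 call
 * is only available when the library was compiled with HasHDF5.
 *
 * \code
 * visitors::OOB_Error oob_v;
 * rf.learn(features, labels, visitors::create_visitor(oob_v));
 * std::cout << "ensemble OOB error: " << oob_v.oob_breiman << std::endl;
 * #ifdef HasHDF5
 * oob_v.save("rf_results.h5", "oob");   // writes the dataset "oob/breiman_error"
 * #endif
 * \endcode
 */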
1011 
1012 
1013 /** Visitor that calculates different OOB error statistics
1014  */
1015 class CompleteOOBInfo : public VisitorBase
1016 {
1017  typedef MultiArrayShape<2>::type Shp;
1018  int class_count;
1019  bool is_weighted;
1020  MultiArray<2,double> tmp_prob;
1021  public:
1022 
1023  /** OOB Error rate of each individual tree
1024  */
1025  MultiArray<2, double> oob_per_tree;
1026  /** Mean of oob_per_tree
1027  */
1028  double oob_mean;
1029  /**Standard deviation of oob_per_tree
1030  */
1031  double oob_std;
1032 
1033  MultiArray<2, double> prob_oob;
1034  /** Ensemble OOB error
1035  *
1036  * \sa OOB_Error
1037  */
1038  double oob_breiman;
1039 
1040  MultiArray<2, double> oobCount;
1041  MultiArray<2, double> oobErrorCount;
1042  /** Per Tree OOB error calculated as in OOB_PerTreeError
1043  * (Ulli's version)
1044  */
1045  double oob_per_tree2;
1046 
1047  /**Column containing the development of the Ensemble
1048  * error rate with increasing number of trees
1049  */
1050  MultiArray<2, double> breiman_per_tree;
1051  /** 4 dimensional array containing the development of confusion matrices
1052  * with number of trees - can be used to estimate ROC curves etc.
1053  *
1054  * oobroc_per_tree(ii,jj,kk,ll)
1055  * corresponds to: true label = ii,
1056  * predicted label = jj,
1057  * confusion matrix after ll trees
1058  *
1059  * explanation of third index:
1060  *
1061  * Two class case:
1062  * kk = 0 - (treeCount-1)
1063  * Threshold is on Probability for class 0 is kk/(treeCount-1);
1064  * More classes:
1065  * kk = 0. Threshold on probability set by argMax of the probability array.
1066  */
1067  MultiArray<4, double> oobroc_per_tree;
1068 
1070 
1071 #ifdef HasHDF5
1072  /** save to HDF5 file
1073  */
1074  void save(std::string filen, std::string pathn)
1075  {
1076  if(*(pathn.end()-1) != '/')
1077  pathn += "/";
1078  const char* filename = filen.c_str();
1079  MultiArray<2, double> temp(Shp(1,1), 0.0);
1080  writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1081  writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1082  writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1083  temp[0] = oob_mean;
1084  writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1085  temp[0] = oob_std;
1086  writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1087  temp[0] = oob_breiman;
1088  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1089  temp[0] = oob_per_tree2;
1090  writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1091  }
1092 #endif
1093  // negative value if the sample was in-bag (ib); the number indicates how often.
1094  // value >= 0 if the sample was oob: 0 means misclassified, 1 means correct
1095 
1096  template<class RF, class PR>
1097  void visit_at_beginning(RF & rf, PR &)
1098  {
1099  class_count = rf.class_count();
1100  if(class_count == 2)
1101  oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1102  else
1103  oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
1104  tmp_prob.reshape(Shp(1, class_count), 0);
1105  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1106  is_weighted = rf.options().predict_weighted_;
1107  oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1108  breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1109  //do the first time called.
1110  if(int(oobCount.size()) != rf.ext_param_.row_count_)
1111  {
1112  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1113  oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1114  }
1115  }
1116 
1117  template<class RF, class PR, class SM, class ST>
1118  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
1119  {
1120  // go through the samples
1121  int total_oob =0;
1122  int wrong_oob =0;
1123  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1124  {
1125  // if the lth sample is oob...
1126  if(!sm.is_used()[ll])
1127  {
1128  // update number of trees in which current sample is oob
1129  ++oobCount[ll];
1130 
1131  // update number of oob samples in this tree.
1132  ++total_oob;
1133  // get the predicted votes ---> tmp_prob;
1134  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
1135  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1136  rf.tree(index).parameters_,
1137  pos);
1138  tmp_prob.init(0);
1139  for(int ii = 0; ii < class_count; ++ii)
1140  {
1141  tmp_prob[ii] = node.prob_begin()[ii];
1142  }
1143  if(is_weighted)
1144  {
1145  for(int ii = 0; ii < class_count; ++ii)
1146  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1147  }
1148  rowVector(prob_oob, ll) += tmp_prob;
1149  int label = argMax(tmp_prob);
1150 
1151  if(label != pr.response()(ll, 0))
1152  {
1153  // update number of wrong oob samples in this tree.
1154  ++wrong_oob;
1155  // update number of trees in which current sample is wrong oob
1156  ++oobErrorCount[ll];
1157  }
1158  }
1159  }
1160  int breimanstyle = 0;
1161  int totalOobCount = 0;
1162  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1163  {
1164  if(oobCount[ll])
1165  {
1166  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1167  ++breimanstyle;
1168  ++totalOobCount;
1169  if(oobroc_per_tree.shape(2) == 1)
1170  {
1171  oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
1172  }
1173  }
1174  }
1175  if(oobroc_per_tree.shape(2) == 1)
1176  oobroc_per_tree.bindOuter(index)/=totalOobCount;
1177  if(oobroc_per_tree.shape(2) > 1)
1178  {
1179  MultiArrayView<3, double> current_roc
1180  = oobroc_per_tree.bindOuter(index);
1181  for(int gg = 0; gg < current_roc.shape(2); ++gg)
1182  {
1183  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1184  {
1185  if(oobCount[ll])
1186  {
1187  int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1188  1 : 0;
1189  current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1190  }
1191  }
1192  current_roc.bindOuter(gg)/= totalOobCount;
1193  }
1194  }
1195  breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
1196  oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1197  // go through the ib samples;
1198  }
1199 
1200  /** Compute the per-tree and ensemble OOB statistics after all trees have been learned.
1201  */
1202  template<class RF, class PR>
1203  void visit_at_end(RF & rf, PR & pr)
1204  {
1205  // Ulli's original metric and Breiman-style error
1206  oob_per_tree2 = 0;
1207  int totalOobCount =0;
1208  int breimanstyle = 0;
1209  for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1210  {
1211  if(oobCount[ll])
1212  {
1213  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1214  ++breimanstyle;
1215  oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1216  ++totalOobCount;
1217  }
1218  }
1219  oob_per_tree2 /= totalOobCount;
1220  oob_breiman = double(breimanstyle)/totalOobCount;
1221  // mean error of each tree
1222  MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
1223  MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1224  rowStatistics(oob_per_tree, mean, stdDev);
1225  }
1226 };
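/* Example (sketch): collecting the full set of OOB statistics. The member names
 * used below are the public data members of CompleteOOBInfo; rf, features and
 * labels are assumed to be set up as in the namespace example.
 *
 * \code
 * visitors::CompleteOOBInfo oob_info;
 * rf.learn(features, labels, visitors::create_visitor(oob_info));
 *
 * std::cout << "mean per-tree OOB error:   " << oob_info.oob_mean      << "\n"
 *           << "std dev of per-tree error: " << oob_info.oob_std       << "\n"
 *           << "ensemble (Breiman) error:  " << oob_info.oob_breiman   << "\n"
 *           << "per-tree error (Ulli):     " << oob_info.oob_per_tree2 << std::endl;
 * // oob_info.oobroc_per_tree holds the confusion matrices per number of trees
 * \endcode
 */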
1227 
1228 /** calculate variable importance while learning.
1229  */
1230 class VariableImportanceVisitor : public VisitorBase
1231 {
1232  public:
1233 
1234  /** This Array has the same entries as the R - random forest variable
1235  * importance.
1236  * Matrix is featureCount by (classCount +2)
1237  * variable_importance_(ii,jj) is the variable importance measure of
1238  * the ii-th variable according to:
1239  * jj = 0 - (classCount-1)
1240  * classwise permutation importance
1241  * jj = columnCount(variable_importance_) -2
1242  * permutation importance
1243  * jj = columnCount(variable_importance_) -1
1244  * gini decrease importance.
1245  *
1246  * permutation importance:
1247  * The difference between the fraction of OOB samples classified correctly
1248  * before and after permuting (randomizing) the ii-th column is calculated.
1249  * The ii-th column is permuted rep_cnt times.
1250  *
1251  * class wise permutation importance:
1252  * same as permutation importance. We only look at those OOB samples whose
1253  * response corresponds to class jj.
1254  *
1255  * gini decrease importance:
1256  * row ii corresponds to the sum of all gini decreases induced by variable ii
1257  * in each node of the random forest.
1258  */
1259  MultiArray<2, double> variable_importance_;
1260  int repetition_count_;
1261  bool in_place_;
1262 
1263 #ifdef HasHDF5
1264  void save(std::string filename, std::string prefix)
1265  {
1266  prefix = "variable_importance_" + prefix;
1267  writeHDF5(filename.c_str(),
1268  prefix.c_str(),
1269  variable_importance_);
1270  }
1271 #endif
1272 
1273  /* Constructor
1274  * \param rep_cnt (default: 10) how often the
1275  * permutation should take place. Set to 1 to make the calculation faster (but
1276  * possibly less stable)
1277  */
1278  VariableImportanceVisitor(int rep_cnt = 10)
1279  : repetition_count_(rep_cnt)
1280 
1281  {}
1282 
1283  /** calculates impurity decrease based variable importance after every
1284  * split.
1285  */
1286  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1287  void visit_after_split( Tree & tree,
1288  Split & split,
1289  Region & /* parent */,
1290  Region & /* leftChild */,
1291  Region & /* rightChild */,
1292  Feature_t & /* features */,
1293  Label_t & /* labels */)
1294  {
1295  //resize to right size when called the first time
1296 
1297  Int32 const class_count = tree.ext_param_.class_count_;
1298  Int32 const column_count = tree.ext_param_.column_count_;
1299  if(variable_importance_.size() == 0)
1300  {
1301 
1302  variable_importance_
1303  .reshape(MultiArrayShape<2>::type(column_count,
1304  class_count+2));
1305  }
1306 
1307  if(split.createNode().typeID() == i_ThresholdNode)
1308  {
1309  Node<i_ThresholdNode> node(split.createNode());
1310  variable_importance_(node.column(),class_count+1)
1311  += split.region_gini_ - split.minGini();
1312  }
1313  }
1314 
1315  /** compute permutation based variable importance.
1316  * (Only an array of size oob_sample_count x 1 is created
1317  * - as opposed to oob_sample_count x feature_count in the other method.)
1318  *
1319  * \sa FieldProxy
1320  */
1321  template<class RF, class PR, class SM, class ST>
1322  void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & /* st */, int index)
1323  {
1324  typedef MultiArrayShape<2>::type Shp_t;
1325  Int32 column_count = rf.ext_param_.column_count_;
1326  Int32 class_count = rf.ext_param_.class_count_;
1327 
1328  /* This solution reduces memory consumption but is not
1329  * compatible with multithreading.
1330  */
1331  // remove the const cast on the features (yep, I know what I am
1332  // doing here) - the data is not destroyed.
1333  //typename PR::Feature_t & features
1334  // = const_cast<typename PR::Feature_t &>(pr.features());
1335 
1336  typedef typename PR::FeatureWithMemory_t FeatureArray;
1337  typedef typename FeatureArray::value_type FeatureValue;
1338 
1339  FeatureArray features = pr.features();
1340 
1341  //find the oob indices of current tree.
1342  ArrayVector<Int32> oob_indices;
1343  ArrayVector<Int32>::iterator
1344  iter;
1345  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1346  if(!sm.is_used()[ii])
1347  oob_indices.push_back(ii);
1348 
1349  //create space to back up a column
1350  ArrayVector<FeatureValue> backup_column;
1351 
1352  // Random foo
1353 #ifdef CLASSIFIER_TEST
1354  RandomMT19937 random(1);
1355 #else
1356  RandomMT19937 random(RandomSeed);
1357 #endif
1358  UniformIntRandomFunctor<RandomMT19937>
1359  randint(random);
1360 
1361 
1362  //make some space for the results
1363  MultiArray<2, double>
1364  oob_right(Shp_t(1, class_count + 1));
1365  MultiArray<2, double>
1366  perm_oob_right (Shp_t(1, class_count + 1));
1367 
1368 
1369  // get the oob success rate with the original samples
1370  for(iter = oob_indices.begin();
1371  iter != oob_indices.end();
1372  ++iter)
1373  {
1374  if(rf.tree(index)
1375  .predictLabel(rowVector(features, *iter))
1376  == pr.response()(*iter, 0))
1377  {
1378  //per class
1379  ++oob_right[pr.response()(*iter,0)];
1380  //total
1381  ++oob_right[class_count];
1382  }
1383  }
1384  //get the oob rate after permuting the ii'th dimension.
1385  for(int ii = 0; ii < column_count; ++ii)
1386  {
1387  perm_oob_right.init(0.0);
1388  //make backup of original column
1389  backup_column.clear();
1390  for(iter = oob_indices.begin();
1391  iter != oob_indices.end();
1392  ++iter)
1393  {
1394  backup_column.push_back(features(*iter,ii));
1395  }
1396 
1397  //get the oob rate after permuting the ii'th dimension.
1398  for(int rr = 0; rr < repetition_count_; ++rr)
1399  {
1400  //permute dimension.
1401  int n = oob_indices.size();
1402  for(int jj = n-1; jj >= 1; --jj)
1403  std::swap(features(oob_indices[jj], ii),
1404  features(oob_indices[randint(jj+1)], ii));
1405 
1406  //get the oob success rate after permuting
1407  for(iter = oob_indices.begin();
1408  iter != oob_indices.end();
1409  ++iter)
1410  {
1411  if(rf.tree(index)
1412  .predictLabel(rowVector(features, *iter))
1413  == pr.response()(*iter, 0))
1414  {
1415  //per class
1416  ++perm_oob_right[pr.response()(*iter, 0)];
1417  //total
1418  ++perm_oob_right[class_count];
1419  }
1420  }
1421  }
1422 
1423 
1424  //normalise and add to the variable_importance array.
1425  perm_oob_right /= repetition_count_;
1426  perm_oob_right -=oob_right;
1427  perm_oob_right *= -1;
1428  perm_oob_right /= oob_indices.size();
1429  variable_importance_
1430  .subarray(Shp_t(ii,0),
1431  Shp_t(ii+1,class_count+1)) += perm_oob_right;
1432  //copy back permuted dimension
1433  for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1434  features(oob_indices[jj], ii) = backup_column[jj];
1435  }
1436  }
1437 
1438  /** calculate permutation based impurity after every tree has been
1439  * learned. The default behaviour is that this happens out of place.
1440  * If you have very big data sets and want to avoid copying the data,
1441  * set the in_place_ flag to true.
1442  */
1443  template<class RF, class PR, class SM, class ST>
1444  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1445  {
1446  after_tree_ip_impl(rf, pr, sm, st, index);
1447  }
1448 
1449  /** Normalise variable importance after the number of trees is known.
1450  */
1451  template<class RF, class PR>
1452  void visit_at_end(RF & rf, PR & /* pr */)
1453  {
1454  variable_importance_ /= rf.trees_.size();
1455  }
1456 };
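/* Example (sketch): reading the variable importance measures after learning.
 * The column layout follows the class documentation above: columns 0..classCount-1
 * hold the class-wise permutation importances, column classCount the overall
 * permutation importance, and column classCount+1 the gini decrease importance.
 *
 * \code
 * visitors::VariableImportanceVisitor varimp_v(10);   // 10 permutation repetitions
 * rf.learn(features, labels, visitors::create_visitor(varimp_v));
 *
 * int class_count = rf.class_count();
 * for(int feat = 0; feat < rowCount(varimp_v.variable_importance_); ++feat)
 * {
 *     std::cout << "feature " << feat
 *               << "  permutation importance: "
 *               << varimp_v.variable_importance_(feat, class_count)
 *               << "  gini importance: "
 *               << varimp_v.variable_importance_(feat, class_count + 1) << std::endl;
 * }
 * \endcode
 */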
1457 
1458 /** Verbose output
1459  */
1460 class RandomForestProgressVisitor : public VisitorBase {
1461  public:
1462  RandomForestProgressVisitor() : VisitorBase() {}
1463 
1464  template<class RF, class PR, class SM, class ST>
1465  void visit_after_tree(RF& rf, PR &, SM &, ST &, int index){
1466  if(index != rf.options().tree_count_-1) {
1467  std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1468  << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1469  }
1470  else {
1471  std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1472  }
1473  }
1474 
1475  template<class RF, class PR>
1476  void visit_at_end(RF const & rf, PR const &) {
1477  std::string a = TOCS;
1478  std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
1479  }
1480 
1481  template<class RF, class PR>
1482  void visit_at_beginning(RF const & rf, PR const &) {
1483  TIC;
1484  std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
1485  }
1486 
1487  private:
1488  USETICTOC;
1489 };
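/* Example (sketch): combining the progress visitor with an error estimate.
 * The progress visitor only prints the percentage of trees learned and the
 * total learning time; it stores no results of its own.
 *
 * \code
 * visitors::RandomForestProgressVisitor progress_v;
 * visitors::OOB_Error oob_v;
 * rf.learn(features, labels, visitors::create_visitor(progress_v, oob_v));
 * \endcode
 */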
1490 
1491 
1492 /** Computes Correlation/Similarity Matrix of features while learning
1493  * random forest.
1494  */
1495 class CorrelationVisitor : public VisitorBase
1496 {
1497  public:
1498  /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1499  * created on variable ii (when variable ii was chosen)
1500  */
1501  MultiArray<2, double> gini_missc;
1502  MultiArray<2, int> tmp_labels;
1503  /** additional noise features.
1504  */
1505  MultiArray<2, double> noise;
1506  MultiArray<2, double> noise_l;
1507  /** how well can a noise column describe a partition created on variable ii.
1508  */
1509  MultiArray<2, double> corr_noise;
1510  MultiArray<2, double> corr_l;
1511 
1512  /** Similarity Matrix
1513  *
1514  * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
1515  * gini_missc
1516  * - row normalized by the number of times the column was chosen
1517  * - mean of corr_noise subtracted
1518  * - and symmetrised.
1519  *
1520  */
1521  MultiArray<2, double> similarity;
1522  /** Distance Matrix 1-similarity
1523  */
1524  MultiArray<2, double> distance;
1525  ArrayVector<int> tmp_cc;
1526 
1527  /** How often was variable ii chosen
1528  */
1529  ArrayVector<int> numChoices;
1530  typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor;
1531  ColumnDecisionFunctor bgfunc;
1532  void save(std::string, std::string)
1533  {
1534  /*
1535  std::string tmp;
1536 #define VAR_WRITE(NAME) \
1537  tmp = #NAME;\
1538  tmp += "_";\
1539  tmp += prefix;\
1540  vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
1541  VAR_WRITE(gini_missc);
1542  VAR_WRITE(corr_noise);
1543  VAR_WRITE(distance);
1544  VAR_WRITE(similarity);
1545  vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
1546 #undef VAR_WRITE
1547 */
1548  }
1549 
1550  template<class RF, class PR>
1551  void visit_at_beginning(RF const & rf, PR & pr)
1552  {
1553  typedef MultiArrayShape<2>::type Shp;
1554  int n = rf.ext_param_.column_count_;
1555  gini_missc.reshape(Shp(n +1,n+ 1));
1556  corr_noise.reshape(Shp(n + 1, 10));
1557  corr_l.reshape(Shp(n +1, 10));
1558 
1559  noise.reshape(Shp(pr.features().shape(0), 10));
1560  noise_l.reshape(Shp(pr.features().shape(0), 10));
1561  RandomMT19937 random(RandomSeed);
1562  for(int ii = 0; ii < noise.size(); ++ii)
1563  {
1564  noise[ii] = random.uniform53();
1565  noise_l[ii] = random.uniform53() > 0.5;
1566  }
1567  bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1568  tmp_labels.reshape(pr.response().shape());
1569  tmp_cc.resize(2);
1570  numChoices.resize(n+1);
1571  // look at all axes
1572  }
1573  template<class RF, class PR>
1574  void visit_at_end(RF const &, PR const &)
1575  {
1576  typedef MultiArrayShape<2>::type Shp;
1577  similarity = gini_missc;
1578 
1579  MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
1580  rowStatistics(corr_noise, mean_noise);
1581  mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
1582  int rC = similarity.shape(0);
1583  for(int jj = 0; jj < rC-1; ++jj)
1584  {
1585  rowVector(similarity, jj) /= numChoices[jj];
1586  rowVector(similarity, jj) -= mean_noise(jj, 0);
1587  }
1588  for(int jj = 0; jj < rC; ++jj)
1589  {
1590  similarity(rC -1, jj) /= numChoices[jj];
1591  }
1592  rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
1594  FindMinMax<double> minmax;
1595  inspectMultiArray(srcMultiArrayRange(similarity), minmax);
1596 
1597  for(int jj = 0; jj < rC; ++jj)
1598  similarity(jj, jj) = minmax.max;
1599 
1600  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1601  += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1602  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1603  columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
1604  for(int jj = 0; jj < rC; ++jj)
1605  similarity(jj, jj) = 0;
1606 
1607  FindMinMax<double> minmax2;
1608  inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
1609  for(int jj = 0; jj < rC; ++jj)
1610  similarity(jj, jj) = minmax2.max;
1611  distance.reshape(gini_missc.shape(), minmax2.max);
1612  distance -= similarity;
1613  }
1614 
1615  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1616  void visit_after_split( Tree &,
1617  Split & split,
1618  Region & parent,
1619  Region &,
1620  Region &,
1621  Feature_t & features,
1622  Label_t & labels)
1623  {
1624  if(split.createNode().typeID() == i_ThresholdNode)
1625  {
1626  double wgini;
1627  tmp_cc.init(0);
1628  for(int ii = 0; ii < parent.size(); ++ii)
1629  {
1630  tmp_labels[parent[ii]]
1631  = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1632  ++tmp_cc[tmp_labels[parent[ii]]];
1633  }
1634  double region_gini = bgfunc.loss_of_region(tmp_labels,
1635  parent.begin(),
1636  parent.end(),
1637  tmp_cc);
1638 
1639  int n = split.bestSplitColumn();
1640  ++numChoices[n];
1641  ++(*(numChoices.end()-1));
1642  //this functor does all the work
1643  for(int k = 0; k < features.shape(1); ++k)
1644  {
1645  bgfunc(columnVector(features, k),
1646  tmp_labels,
1647  parent.begin(), parent.end(),
1648  tmp_cc);
1649  wgini = (region_gini - bgfunc.min_gini_);
1650  gini_missc(n, k)
1651  += wgini;
1652  }
1653  for(int k = 0; k < 10; ++k)
1654  {
1655  bgfunc(columnVector(noise, k),
1656  tmp_labels,
1657  parent.begin(), parent.end(),
1658  tmp_cc);
1659  wgini = (region_gini - bgfunc.min_gini_);
1660  corr_noise(n, k)
1661  += wgini;
1662  }
1663 
1664  for(int k = 0; k < 10; ++k)
1665  {
1666  bgfunc(columnVector(noise_l, k),
1667  tmp_labels,
1668  parent.begin(), parent.end(),
1669  tmp_cc);
1670  wgini = (region_gini - bgfunc.min_gini_);
1671  corr_l(n, k)
1672  += wgini;
1673  }
1674  bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1675  wgini = (region_gini - bgfunc.min_gini_);
1676  gini_missc(n, columnCount(gini_missc)-1)
1677  += wgini;
1678 
1679  region_gini = split.region_gini_;
1680 #if 1
1681  Node<i_ThresholdNode> node(split.createNode());
1682  gini_missc(rowCount(gini_missc)-1,
1683  node.column())
1684  +=split.region_gini_ - split.minGini();
1685 #endif
1686  for(int k = 0; k < 10; ++k)
1687  {
1688  split.bgfunc(columnVector(noise, k),
1689  labels,
1690  parent.begin(), parent.end(),
1691  parent.classCounts());
1693  k)
1694  += wgini;
1695  }
1696 #if 0
1697  for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1698  {
1699  wgini = region_gini - split.min_gini_[k];
1700 
1702  split.splitColumns[k])
1703  += wgini;
1704  }
1705 
1706  for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1707  {
1708  split.bgfunc(columnVector(features, split.splitColumns[k]),
1709  labels,
1710  parent.begin(), parent.end(),
1711  parent.classCounts());
1712  wgini = region_gini - split.bgfunc.min_gini_;
1714  split.splitColumns[k]) += wgini;
1715  }
1716 #endif
1717  // remember to partition the data according to the best.
1718  gini_missc(rowCount(gini_missc)-1,
1719  columnCount(gini_missc)-1)
1720  += region_gini;
1721  SortSamplesByDimensions<Feature_t>
1722  sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1723  std::partition(parent.begin(), parent.end(), sorter);
1724  }
1725  }
1726 };
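/* Example (sketch): computing the feature similarity/distance matrices while
 * learning. similarity and distance are (numberOfFeatures+1) x (numberOfFeatures+1)
 * matrices as described in the member documentation above.
 *
 * \code
 * visitors::CorrelationVisitor corr_v;
 * rf.learn(features, labels, visitors::create_visitor(corr_v));
 *
 * std::cout << "similarity of feature 0 and feature 1: "
 *           << corr_v.similarity(0, 1) << "\n"
 *           << "distance   of feature 0 and feature 1: "
 *           << corr_v.distance(0, 1) << std::endl;
 * \endcode
 */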
1727 
1728 
1729 } // namespace visitors
1730 } // namespace rf
1731 } // namespace vigra
1732 
1733 //@}
1734 #endif // RF_VISITORS_HXX