SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #ifndef CONDITIONALPROBABILITYTREE_H__ 00012 #define CONDITIONALPROBABILITYTREE_H__ 00013 00014 #include <map> 00015 00016 #include <shogun/multiclass/tree/TreeMachine.h> 00017 #include <shogun/classifier/vw/VowpalWabbit.h> 00018 00019 namespace shogun 00020 { 00021 00023 struct VwConditionalProbabilityTreeNodeData 00024 { 00026 int32_t label; 00028 float64_t p_right; 00029 00031 VwConditionalProbabilityTreeNodeData():label(-1), p_right(0) {} 00032 }; 00033 00035 typedef CTreeMachineNode<VwConditionalProbabilityTreeNodeData> node_t; 00036 00038 class CVwConditionalProbabilityTree: public CTreeMachine<VwConditionalProbabilityTreeNodeData> 00039 { 00040 public: 00041 00043 CVwConditionalProbabilityTree(int32_t num_passes=1) 00044 :m_num_passes(num_passes), m_feats(NULL) 00045 { 00046 } 00047 00049 virtual ~CVwConditionalProbabilityTree() {} 00050 00052 virtual const char* get_name() const { return "VwConditionalProbabilityTree"; } 00053 00055 void set_num_passes(int32_t num_passes) 00056 { 00057 m_num_passes = num_passes; 00058 } 00059 00061 int32_t get_num_passes() const 00062 { 00063 return m_num_passes; 00064 } 00065 00069 void set_features(CStreamingVwFeatures *feats) 00070 { 00071 SG_REF(feats); 00072 SG_UNREF(m_feats); 00073 m_feats = feats; 00074 } 00075 00077 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00078 00080 virtual int32_t apply_multiclass_example(VwExample* ex); 00081 protected: 00083 virtual bool train_require_labels() const { return false; } 00084 00091 virtual bool train_machine(CFeatures* data); 00092 00096 void train_example(VwExample *ex); 00097 00102 void train_path(VwExample *ex, node_t *node); 00103 00109 float64_t train_node(VwExample *ex, node_t *node); 00110 00114 int32_t create_machine(VwExample *ex); 00115 00121 virtual bool which_subtree(node_t *node, VwExample *ex)=0; 00122 00124 void compute_conditional_probabilities(VwExample *ex); 00125 00129 float64_t accumulate_conditional_probability(node_t *leaf); 00130 00131 int32_t m_num_passes; 00132 std::map<int32_t, node_t*> m_leaves; 00133 CStreamingVwFeatures *m_feats; 00134 }; 00135 00136 } /* shogun */ 00137 00138 #endif /* end of include guard: CONDITIONALPROBABILITYTREE_H__ */ 00139