SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
VwConditionalProbabilityTree.h
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #ifndef CONDITIONALPROBABILITYTREE_H__
00012 #define CONDITIONALPROBABILITYTREE_H__
00013 
00014 #include <map>
00015 
00016 #include <shogun/multiclass/tree/TreeMachine.h>
00017 #include <shogun/classifier/vw/VowpalWabbit.h>
00018 
00019 namespace shogun
00020 {
00021 
00023 struct VwConditionalProbabilityTreeNodeData
00024 {
00026     int32_t label;
00028     float64_t p_right;
00029 
00031     VwConditionalProbabilityTreeNodeData():label(-1), p_right(0) {}
00032 };
00033 
00035 typedef CTreeMachineNode<VwConditionalProbabilityTreeNodeData> node_t;
00036 
00038 class CVwConditionalProbabilityTree: public CTreeMachine<VwConditionalProbabilityTreeNodeData>
00039 {
00040 public:
00041 
00043     CVwConditionalProbabilityTree(int32_t num_passes=1)
00044         :m_num_passes(num_passes), m_feats(NULL)
00045     {
00046     }
00047 
00049     virtual ~CVwConditionalProbabilityTree() {}
00050 
00052     virtual const char* get_name() const { return "VwConditionalProbabilityTree"; }
00053 
00055     void set_num_passes(int32_t num_passes)
00056     {
00057         m_num_passes = num_passes;
00058     }
00059 
00061     int32_t get_num_passes() const
00062     {
00063         return m_num_passes;
00064     }
00065 
00069     void set_features(CStreamingVwFeatures *feats)
00070     {
00071         SG_REF(feats);
00072         SG_UNREF(m_feats);
00073         m_feats = feats;
00074     }
00075 
00077     virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00078 
00080     virtual int32_t apply_multiclass_example(VwExample* ex);
00081 protected:
00083     virtual bool train_require_labels() const { return false; }
00084 
00091     virtual bool train_machine(CFeatures* data);
00092 
00096     void train_example(VwExample *ex);
00097 
00102     void train_path(VwExample *ex, node_t *node);
00103 
00109     float64_t train_node(VwExample *ex, node_t *node);
00110 
00114     int32_t create_machine(VwExample *ex);
00115 
00121     virtual bool which_subtree(node_t *node, VwExample *ex)=0;
00122 
00124     void compute_conditional_probabilities(VwExample *ex);
00125 
00129     float64_t accumulate_conditional_probability(node_t *leaf);
00130 
00131     int32_t m_num_passes; 
00132     std::map<int32_t, node_t*> m_leaves; 
00133     CStreamingVwFeatures *m_feats; 
00134 };
00135 
00136 } /* shogun */
00137 
00138 #endif /* end of include guard: CONDITIONALPROBABILITYTREE_H__ */
00139 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation