SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #include <shogun/structure/DisjointSet.h> 00012 #include <shogun/base/Parameter.h> 00013 00014 using namespace shogun; 00015 00016 CDisjointSet::CDisjointSet() 00017 : CSGObject() 00018 { 00019 SG_UNSTABLE("CDisjointSet::CDisjointSet()", "\n"); 00020 00021 init(); 00022 } 00023 00024 CDisjointSet::CDisjointSet(int32_t num_elements) 00025 : CSGObject() 00026 { 00027 init(); 00028 m_num_elements = num_elements; 00029 m_parent = SGVector<int32_t>(num_elements); 00030 m_rank = SGVector<int32_t>(num_elements); 00031 } 00032 00033 void CDisjointSet::init() 00034 { 00035 SG_ADD(&m_num_elements, "num_elements", "Number of elements", MS_NOT_AVAILABLE); 00036 SG_ADD(&m_parent, "parent", "Parent pointers", MS_NOT_AVAILABLE); 00037 SG_ADD(&m_rank, "rank", "Rank of each element", MS_NOT_AVAILABLE); 00038 SG_ADD(&m_is_connected, "is_connected", "Whether disjoint sets have been linked", MS_NOT_AVAILABLE); 00039 00040 m_is_connected = false; 00041 m_num_elements = -1; 00042 } 00043 00044 void CDisjointSet::make_sets() 00045 { 00046 REQUIRE(m_num_elements > 0, "%s::make_sets(): m_num_elements <= 0.\n", get_name()); 00047 00048 m_parent.range_fill(); 00049 m_rank.zero(); 00050 } 00051 00052 int32_t CDisjointSet::find_set(int32_t x) 00053 { 00054 ASSERT(x >= 0 && x < m_num_elements); 00055 00056 // path compression 00057 if (x != m_parent[x]) 00058 m_parent[x] = find_set(m_parent[x]); 00059 00060 return m_parent[x]; 00061 } 00062 00063 int32_t CDisjointSet::link_set(int32_t xroot, int32_t yroot) 00064 { 00065 ASSERT(xroot >= 0 && xroot < m_num_elements); 00066 ASSERT(yroot >= 0 && yroot < m_num_elements); 00067 ASSERT(m_parent[xroot] == xroot && m_parent[yroot] == yroot); 00068 ASSERT(xroot != yroot); 00069 00070 // union by rank 00071 if (m_rank[xroot] > m_rank[yroot]) 00072 { 00073 m_parent[yroot] = xroot; 00074 return xroot; 00075 } 00076 else 00077 { 00078 m_parent[xroot] = yroot; 00079 if (m_rank[xroot] == m_rank[yroot]) 00080 m_rank[yroot] += 1; 00081 00082 return yroot; 00083 } 00084 } 00085 00086 bool CDisjointSet::union_set(int32_t x, int32_t y) 00087 { 00088 ASSERT(x >= 0 && x < m_num_elements); 00089 ASSERT(y >= 0 && y < m_num_elements); 00090 00091 int32_t xroot = find_set(x); 00092 int32_t yroot = find_set(y); 00093 00094 if (xroot == yroot) 00095 return true; 00096 00097 link_set(xroot, yroot); 00098 return false; 00099 } 00100 00101 bool CDisjointSet::is_same_set(int32_t x, int32_t y) 00102 { 00103 ASSERT(x >= 0 && x < m_num_elements); 00104 ASSERT(y >= 0 && y < m_num_elements); 00105 00106 if (find_set(x) == find_set(y)) 00107 return true; 00108 00109 return false; 00110 } 00111 00112 int32_t CDisjointSet::get_unique_labeling(SGVector<int32_t> out_labels) 00113 { 00114 REQUIRE(m_num_elements > 0, "%s::get_unique_labeling(): m_num_elements <= 0.\n", get_name()); 00115 00116 if (out_labels.size() != m_num_elements) 00117 out_labels.resize_vector(m_num_elements); 00118 00119 SGVector<int32_t> roots(m_num_elements); 00120 SGVector<int32_t> flags(m_num_elements); 00121 SGVector<int32_t>::fill_vector(flags.vector, flags.vlen, -1); 00122 int32_t unilabel = 0; 00123 00124 for (int32_t i = 0; i < m_num_elements; i++) 00125 { 00126 roots[i] = find_set(i); 00127 // if roots[i] never be found 00128 if (flags[roots[i]] < 0) 00129 flags[roots[i]] = unilabel++; 00130 } 00131 00132 for (int32_t i = 0; i < m_num_elements; i++) 00133 out_labels[i] = flags[roots[i]]; 00134 00135 return unilabel; 00136 } 00137 00138 int32_t CDisjointSet::get_num_sets() 00139 { 00140 REQUIRE(m_num_elements > 0, "%s::get_num_sets(): m_num_elements <= 0.\n", get_name()); 00141 00142 return get_unique_labeling(SGVector<int32_t>(m_num_elements)); 00143 } 00144 00145 bool CDisjointSet::get_connected() 00146 { 00147 return m_is_connected; 00148 } 00149 00150 void CDisjointSet::set_connected(bool is_connected) 00151 { 00152 m_is_connected = is_connected; 00153 } 00154