SHOGUN
v3.2.0
|
00001 /* This program is free software: you can redistribute it and/or modify 00002 * it under the terms of the GNU General Public License as published by 00003 * the Free Software Foundation, either version 3 of the License, or 00004 * (at your option) any later version. 00005 * 00006 * This program is distributed in the hope that it will be useful, 00007 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00008 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00009 * GNU General Public License for more details. 00010 * 00011 * You should have received a copy of the GNU General Public License 00012 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00013 * 00014 * Copyright (C) 2009 - 2012 Jun Liu and Jieping Ye 00015 */ 00016 00017 #include <shogun/lib/slep/tree/altra.h> 00018 00019 void altra(double *x, double *v, int n, double *ind, int nodes, double mult) 00020 { 00021 int i, j; 00022 double lambda,twoNorm, ratio; 00023 00024 /* 00025 * test whether the first node is special 00026 */ 00027 if ((int) ind[0]==-1){ 00028 00029 /* 00030 *Recheck whether ind[1] equals to zero 00031 */ 00032 if ((int) ind[1]!=-1){ 00033 printf("\n Error! \n Check ind"); 00034 exit(1); 00035 } 00036 00037 lambda=mult*ind[2]; 00038 00039 for(j=0;j<n;j++){ 00040 if (v[j]>lambda) 00041 x[j]=v[j]-lambda; 00042 else 00043 if (v[j]<-lambda) 00044 x[j]=v[j]+lambda; 00045 else 00046 x[j]=0; 00047 } 00048 00049 i=1; 00050 } 00051 else{ 00052 memcpy(x, v, sizeof(double) * n); 00053 i=0; 00054 } 00055 00056 /* 00057 * sequentially process each node 00058 * 00059 */ 00060 for(;i < nodes; i++){ 00061 /* 00062 * compute the L2 norm of this group 00063 */ 00064 twoNorm=0; 00065 for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++) 00066 twoNorm += x[j] * x[j]; 00067 twoNorm=sqrt(twoNorm); 00068 00069 lambda=mult*ind[3*i+2]; 00070 if (twoNorm>lambda){ 00071 ratio=(twoNorm-lambda)/twoNorm; 00072 00073 /* 00074 * shrinkage this group by ratio 00075 */ 00076 for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++) 00077 x[j]*=ratio; 00078 } 00079 else{ 00080 /* 00081 * threshold this group to zero 00082 */ 00083 for(j=(int) ind[3*i]-1;j<(int) ind[3*i+1];j++) 00084 x[j]=0; 00085 } 00086 } 00087 } 00088 00089 void altra_mt(double *X, double *V, int n, int k, double *ind, int nodes, double mult) 00090 { 00091 int i, j; 00092 00093 double *x=(double *)malloc(sizeof(double)*k); 00094 double *v=(double *)malloc(sizeof(double)*k); 00095 00096 for (i=0;i<n;i++){ 00097 /* 00098 * copy a row of V to v 00099 * 00100 */ 00101 for(j=0;j<k;j++) 00102 v[j]=V[j*n + i]; 00103 00104 altra(x, v, k, ind, nodes, mult); 00105 00106 /* 00107 * copy the solution to X 00108 */ 00109 for(j=0;j<k;j++) 00110 X[j*n+i]=x[j]; 00111 } 00112 00113 free(x); 00114 free(v); 00115 } 00116 00117 void computeLambda2Max(double *lambda2_max, double *x, int n, double *ind, int nodes) 00118 { 00119 int i, j; 00120 double twoNorm; 00121 00122 *lambda2_max=0; 00123 00124 for(i=0;i < nodes; i++){ 00125 /* 00126 * compute the L2 norm of this group 00127 */ 00128 twoNorm=0; 00129 for(j=(int) ind[3*i]-1;j< (int) ind[3*i+1];j++) 00130 twoNorm += x[j] * x[j]; 00131 twoNorm=sqrt(twoNorm); 00132 00133 twoNorm=twoNorm/ind[3*i+2]; 00134 00135 if (twoNorm >*lambda2_max ) 00136 *lambda2_max=twoNorm; 00137 } 00138 } 00139 00140 double treeNorm(double *x, int ldx, int n, double *ind, int nodes){ 00141 00142 int i, j; 00143 double twoNorm, lambda; 00144 00145 double tree_norm = 0; 00146 00147 /* 00148 * test whether the first node is special 00149 */ 00150 if ((int) ind[0]==-1){ 00151 00152 /* 00153 *Recheck whether ind[1] equals to zero 00154 */ 00155 if ((int) ind[1]!=-1){ 00156 printf("\n Error! \n Check ind"); 00157 exit(1); 00158 } 00159 00160 lambda=ind[2]; 00161 00162 for(j=0;j<n*ldx;j+=ldx){ 00163 tree_norm+=fabs(x[j]); 00164 } 00165 00166 tree_norm = tree_norm * lambda; 00167 00168 i=1; 00169 } 00170 else{ 00171 i=0; 00172 } 00173 00174 /* 00175 * sequentially process each node 00176 * 00177 */ 00178 for(;i < nodes; i++){ 00179 /* 00180 * compute the L2 norm of this group 00181 */ 00182 twoNorm=0; 00183 00184 int n_in_node = (int) ind[3*i+1] - (int) ind[3*i]-1; 00185 for(j=(int) ind[3*i]-1;j< (int) ind[3*i]-1 + n_in_node*ldx;j+=ldx) 00186 twoNorm += x[j] * x[j]; 00187 twoNorm=sqrt(twoNorm); 00188 00189 lambda=ind[3*i+2]; 00190 00191 tree_norm = tree_norm + lambda*twoNorm; 00192 } 00193 00194 return tree_norm; 00195 } 00196 00197 double findLambdaMax(double *v, int n, double *ind, int nodes){ 00198 00199 int i; 00200 double lambda=0,squaredWeight=0, lambda1,lambda2; 00201 double *x=(double *)malloc(sizeof(double)*n); 00202 double *ind2=(double *)malloc(sizeof(double)*nodes*3); 00203 int num=0; 00204 00205 for(i=0;i<n;i++){ 00206 lambda+=v[i]*v[i]; 00207 } 00208 00209 if ( (int)ind[0]==-1 ) 00210 squaredWeight=n*ind[2]*ind[2]; 00211 else 00212 squaredWeight=ind[2]*ind[2]; 00213 00214 for (i=1;i<nodes;i++){ 00215 squaredWeight+=ind[3*i+2]*ind[3*i+2]; 00216 } 00217 00218 /* set lambda to an initial guess 00219 */ 00220 lambda=sqrt(lambda/squaredWeight); 00221 00222 /* 00223 printf("\n\n lambda=%2.5f",lambda); 00224 */ 00225 00226 /* 00227 *copy ind to ind2, 00228 *and scale the weight 3*i+2 00229 */ 00230 for(i=0;i<nodes;i++){ 00231 ind2[3*i]=ind[3*i]; 00232 ind2[3*i+1]=ind[3*i+1]; 00233 ind2[3*i+2]=ind[3*i+2]*lambda; 00234 } 00235 00236 /* test whether the solution is zero or not 00237 */ 00238 altra(x, v, n, ind2, nodes); 00239 for(i=0;i<n;i++){ 00240 if (x[i]!=0) 00241 break; 00242 } 00243 00244 if (i>=n) { 00245 /*x is a zero vector*/ 00246 lambda2=lambda; 00247 lambda1=lambda; 00248 00249 num=0; 00250 00251 while(1){ 00252 num++; 00253 00254 lambda2=lambda; 00255 lambda1=lambda1/2; 00256 /* update ind2 00257 */ 00258 for(i=0;i<nodes;i++){ 00259 ind2[3*i+2]=ind[3*i+2]*lambda1; 00260 } 00261 00262 /* compute and test whether x is zero 00263 */ 00264 altra(x, v, n, ind2, nodes); 00265 for(i=0;i<n;i++){ 00266 if (x[i]!=0) 00267 break; 00268 } 00269 00270 if (i<n){ 00271 break; 00272 /*x is not zero 00273 *we have found lambda1 00274 */ 00275 } 00276 } 00277 00278 } 00279 else{ 00280 /*x is a non-zero vector*/ 00281 lambda2=lambda; 00282 lambda1=lambda; 00283 00284 num=0; 00285 while(1){ 00286 num++; 00287 00288 lambda1=lambda2; 00289 lambda2=lambda2*2; 00290 /* update ind2 00291 */ 00292 for(i=0;i<nodes;i++){ 00293 ind2[3*i+2]=ind[3*i+2]*lambda2; 00294 } 00295 00296 /* compute and test whether x is zero 00297 */ 00298 altra(x, v, n, ind2, nodes); 00299 for(i=0;i<n;i++){ 00300 if (x[i]!=0) 00301 break; 00302 } 00303 00304 if (i>=n){ 00305 break; 00306 /*x is a zero vector 00307 *we have found lambda2 00308 */ 00309 } 00310 } 00311 } 00312 00313 /* 00314 printf("\n num=%d, lambda1=%2.5f, lambda2=%2.5f",num, lambda1,lambda2); 00315 */ 00316 00317 while ( fabs(lambda2-lambda1) > lambda2 * 1e-10 ){ 00318 00319 num++; 00320 00321 lambda=(lambda1+lambda2)/2; 00322 00323 /* update ind2 00324 */ 00325 for(i=0;i<nodes;i++){ 00326 ind2[3*i+2]=ind[3*i+2]*lambda; 00327 } 00328 00329 /* compute and test whether x is zero 00330 */ 00331 altra(x, v, n, ind2, nodes); 00332 for(i=0;i<n;i++){ 00333 if (x[i]!=0) 00334 break; 00335 } 00336 00337 if (i>=n){ 00338 lambda2=lambda; 00339 } 00340 else{ 00341 lambda1=lambda; 00342 } 00343 00344 /* 00345 printf("\n lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2); 00346 */ 00347 } 00348 00349 /* 00350 printf("\n num=%d",num); 00351 00352 printf(" lambda1=%2.5f, lambda2=%2.5f",lambda1,lambda2); 00353 00354 */ 00355 00356 free(x); 00357 free(ind2); 00358 00359 return lambda2; 00360 } 00361 00362 double findLambdaMax_mt(double *V, int n, int k, double *ind, int nodes) 00363 { 00364 int i, j; 00365 00366 double *v=(double *)malloc(sizeof(double)*k); 00367 double lambda; 00368 00369 double lambdaMax=0; 00370 00371 for (i=0;i<n;i++){ 00372 /* 00373 * copy a row of V to v 00374 * 00375 */ 00376 for(j=0;j<k;j++) 00377 v[j]=V[j*n + i]; 00378 00379 lambda = findLambdaMax(v, k, ind, nodes); 00380 00381 /* 00382 printf("\n lambda=%5.2f",lambda); 00383 */ 00384 00385 if (lambda>lambdaMax) 00386 lambdaMax=lambda; 00387 } 00388 00389 /* 00390 printf("\n *lambdaMax=%5.2f",*lambdaMax); 00391 */ 00392 00393 free(v); 00394 return lambdaMax; 00395 }