SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
CGMShiftedFamilySolver.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Soumyajit De
00008  */
00009 
00010 #include <shogun/lib/common.h>
00011 
00012 #ifdef HAVE_EIGEN3
00013 
00014 #include <shogun/lib/SGVector.h>
00015 #include <shogun/lib/Time.h>
00016 #include <shogun/mathematics/eigen3.h>
00017 #include <shogun/mathematics/Math.h>
00018 #include <shogun/mathematics/linalg/linop/LinearOperator.h>
00019 #include <shogun/mathematics/linalg/linsolver/CGMShiftedFamilySolver.h>
00020 #include <shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h>
00021 
00022 using namespace Eigen;
00023 
00024 namespace shogun
00025 {
00026 
00027 CCGMShiftedFamilySolver::CCGMShiftedFamilySolver()
00028     : CIterativeShiftedLinearFamilySolver<float64_t, complex128_t>()
00029 {
00030 }
00031 
00032 CCGMShiftedFamilySolver::CCGMShiftedFamilySolver(bool store_residuals)
00033     : CIterativeShiftedLinearFamilySolver<float64_t, complex128_t>(store_residuals)
00034 {
00035 }
00036 
00037 CCGMShiftedFamilySolver::~CCGMShiftedFamilySolver()
00038 {
00039 }
00040 
00041 SGVector<float64_t> CCGMShiftedFamilySolver::solve(
00042     CLinearOperator<float64_t>* A, SGVector<float64_t> b)
00043 {
00044     SGVector<complex128_t> shifts(1);
00045     shifts[0]=0.0;
00046     SGVector<complex128_t> weights(1);
00047     weights[0]=1.0;
00048 
00049     return solve_shifted_weighted(A, b, shifts, weights).get_real();
00050 }
00051 
00052 SGVector<complex128_t> CCGMShiftedFamilySolver::solve_shifted_weighted(
00053     CLinearOperator<float64_t>* A, SGVector<float64_t> b,
00054     SGVector<complex128_t> shifts, SGVector<complex128_t> weights)
00055 {
00056     SG_DEBUG("Entering\n");
00057 
00058     // sanity check
00059     REQUIRE(A, "Operator is NULL!\n");
00060     REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch! [%d vs %d]\n",
00061         A->get_dimension(), b.vlen);
00062     REQUIRE(shifts.vector,"Shifts are not initialized!\n");
00063     REQUIRE(weights.vector,"Weights are not initialized!\n");
00064     REQUIRE(shifts.vlen==weights.vlen, "Number of shifts and number of "
00065         "weights are not equal! [%d vs %d]\n", shifts.vlen, weights.vlen);
00066 
00067     // the solution matrix, one column per shift, initial guess 0 for all
00068     MatrixXcd x_sh=MatrixXcd::Zero(b.vlen, shifts.vlen);
00069     MatrixXcd p_sh=MatrixXcd::Zero(b.vlen, shifts.vlen);
00070 
00071     // non-shifted direction
00072     SGVector<float64_t> p_(b.vlen);
00073 
00074     // the rest of the part hinges on eigen3 for computing norms
00075     Map<VectorXd> b_map(b.vector, b.vlen);
00076     Map<VectorXd> p(p_.vector, p_.vlen);
00077 
00078     // residual r_i=b-Ax_i, here x_0=[0], so r_0=b
00079     VectorXd r=b_map;
00080 
00081     // initial direction is same as residual
00082     p=r;
00083     p_sh=r.replicate(1, shifts.vlen).cast<complex128_t>();
00084 
00085     // non shifted initializers
00086     float64_t r_norm2=r.dot(r);
00087     float64_t beta_old=1.0;
00088     float64_t alpha=1.0;
00089 
00090     // shifted quantities
00091     SGVector<complex128_t> alpha_sh(shifts.vlen);
00092     SGVector<complex128_t> beta_sh(shifts.vlen);
00093     SGVector<complex128_t> zeta_sh_old(shifts.vlen);
00094     SGVector<complex128_t> zeta_sh_cur(shifts.vlen);
00095     SGVector<complex128_t> zeta_sh_new(shifts.vlen);
00096 
00097     // shifted initializers
00098     zeta_sh_old.set_const(1.0);
00099     zeta_sh_cur.set_const(1.0);
00100 
00101     // the iterator for this iterative solver
00102     IterativeSolverIterator<float64_t> it(r, m_max_iteration_limit,
00103         m_relative_tolerence, m_absolute_tolerence);
00104 
00105     // start the timer
00106     CTime time;
00107     time.start();
00108 
00109     // set the residuals to zero
00110     if (m_store_residuals)
00111         m_residuals.set_const(0.0);
00112 
00113     // CG iteration begins
00114     for (it.begin(r); !it.end(r); ++it)
00115     {
00116 
00117         SG_DEBUG("CG iteration %d, residual norm %f\n",
00118                 it.get_iter_info().iteration_count,
00119                 it.get_iter_info().residual_norm);
00120 
00121         if (m_store_residuals)
00122         {
00123             m_residuals[it.get_iter_info().iteration_count]
00124                 =it.get_iter_info().residual_norm;
00125         }
00126 
00127         // apply linear operator to the direction vector
00128         SGVector<float64_t> Ap_=A->apply(p_);
00129         Map<VectorXd> Ap(Ap_.vector, Ap_.vlen);
00130 
00131         // compute p^{T}Ap, if zero, failure
00132         float64_t p_dot_Ap=p.dot(Ap);
00133         if (p_dot_Ap==0.0)
00134             break;
00135 
00136         // compute the beta parameter of CG_M
00137         float64_t beta=-r_norm2/p_dot_Ap;
00138 
00139         // compute the zeta-shifted parameter of CG_M
00140         compute_zeta_sh_new(zeta_sh_old, zeta_sh_cur, shifts, beta_old, beta,
00141             alpha, zeta_sh_new);
00142 
00143         // compute beta-shifted parameter of CG_M
00144         compute_beta_sh(zeta_sh_new, zeta_sh_cur, beta, beta_sh);
00145 
00146         // update the solution vector and residual
00147         for (index_t i=0; i<shifts.vlen; ++i)
00148             x_sh.col(i)-=beta_sh[i]*p_sh.col(i);
00149 
00150         // r_{i}=r_{i-1}+\beta_{i}Ap
00151         r+=beta*Ap;
00152 
00153         // compute new ||r||_{2}, if zero, converged
00154         float64_t r_norm2_i=r.dot(r);
00155         if (r_norm2_i==0.0)
00156             break;
00157 
00158         // compute the alpha parameter of CG_M
00159         alpha=r_norm2_i/r_norm2;
00160 
00161         // update ||r||_{2}
00162         r_norm2=r_norm2_i;
00163 
00164         // update direction
00165         p=r+alpha*p;
00166 
00167         compute_alpha_sh(zeta_sh_new, zeta_sh_cur, beta_sh, beta, alpha, alpha_sh);
00168 
00169         for (index_t i=0; i<shifts.vlen; ++i)
00170         {
00171             p_sh.col(i)*=alpha_sh[i];
00172             p_sh.col(i)+=zeta_sh_new[i]*r;
00173         }
00174 
00175         // update parameters
00176         for (index_t i=0; i<shifts.vlen; ++i)
00177         {
00178             zeta_sh_old[i]=zeta_sh_cur[i];
00179             zeta_sh_cur[i]=zeta_sh_new[i];
00180         }
00181         beta_old=beta;
00182     }
00183 
00184     float64_t elapsed=time.cur_time_diff();
00185 
00186     if (!it.succeeded(r))
00187         SG_WARNING("Did not converge!\n");
00188 
00189     SG_INFO("Iteration took %ld times, residual norm=%.20lf, time elapsed=%lf\n",
00190         it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed);
00191 
00192     // compute the final result vector multiplied by weights
00193     SGVector<complex128_t> result(b.vlen);
00194     result.set_const(0.0);
00195     Map<VectorXcd> x(result.vector, result.vlen);
00196 
00197     for (index_t i=0; i<x_sh.cols(); ++i)
00198         x+=x_sh.col(i)*weights[i];
00199 
00200     SG_DEBUG("Leaving\n");
00201     return result;
00202 }
00203 
00204 }
00205 #endif // HAVE_EIGEN3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation