/* Ergo, version 3.5, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2016 Elias Rudberg, Emanuel H. Rubensson, Pawel Salek,
 * and Anastasia Kruchinina.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Kohn−Sham Density Functional Theory Electronic Structure Calculations 
 * with Linearly Scaling Computational Time and Memory Usage,
 * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
 * J. Chem. Theory Comput. 7, 340 (2011),
 * <http://dx.doi.org/10.1021/ct100611z>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */



#ifndef ERGO_MATRIX_HEADER
#define ERGO_MATRIX_HEADER

#include "matrix_typedefs.h" // definitions of matrix types and interval type (source)
#include "realtype.h"   // definitions of types (utilities_basic)
#include "matrix_utilities.h"
#include "integral_matrix_wrappers.h"
#include "SizesAndBlocks.h"
#include "Matrix.h"
#include "Vector.h"
#include "MatrixSymmetric.h"
#include "MatrixTriangular.h"
#include "MatrixGeneral.h"
#include "VectorGeneral.h"
#include "output.h"

#include <iostream>
#include <fstream>
#include <string.h>

#include "LanczosSeveralLargestEig.h"

#include "files_sparse.h"

// NOTE(review): a using-directive at namespace scope in a header pulls
// namespace std into every translation unit that includes this header.
// It is kept because the code below relies on unqualified vector/string.
using namespace std;

// Short alias for Ergo's real-number type ergo_real (declared via realtype.h).
typedef ergo_real real;



/** Wrapper for the matrix. Matrix operations are implemented using matrix class in Ergo.
 *
 * \tparam MatrixType Type of a matrix from the matrix library (ex. symmMatrix). */
template<typename MatrixType>
class ErgoMatrix
{
 public:

  typedef mat::normType NormType;
  typedef typename MatrixType::VectorType VectorType;
  typedef MatrixType MatrixTypeInner;

  /** Class containing parameters of MatrixType. */
  class Params
  {
  public:
    mat::SizesAndBlocks rows;
    mat::SizesAndBlocks cols;
    
    Params(){}
  Params(mat::SizesAndBlocks rows_, mat::SizesAndBlocks cols_) : rows(rows_), cols(cols_) {}
  };
  
  Params params;

 private:
  MatrixType MATRIX; /**< The matrix M itself. */

 public:

  /** Copy constructor. */
  ErgoMatrix(const ErgoMatrix<MatrixType> &M){ MATRIX = M.get_ref_to_matrix_const();}

  // constructors
 ErgoMatrix(MatrixType& X_, const Params &p = Params()) : params(p), MATRIX(X_) {}
 ErgoMatrix(const Params &p = Params()) : params(p) {}

  /** To avoid creation of an extra copy of a matrix, we use this
   *  function to get a reference to the matrix. Returns constant
   *  reference to the matrix. */
  const MatrixType& get_ref_to_matrix_const() const { return MATRIX;}
  /** To avoid creation of an extra copy of a matrix, we use this
   *  function to get a reference to the matrix. */
  MatrixType& get_ref_to_matrix() { return MATRIX;}
 
  /** Return the matrix which contains this wrapper. */
  void get_matrix(MatrixType &M /**< [out] Matrix which contains this wrapper. */
		  ) const { M = MATRIX;} // copy matrix

  /** Read matrix from .mtx file. */
  void read_from_mtx(const string &name,               /**< [in] Matrix which contains this wrapper. */
		     const mat::SizesAndBlocks &rows,  /**< [in] Parameter of MatrixType (info about rows). */
		     const mat::SizesAndBlocks &cols   /**< [in] Parameter of MatrixType (info about columns). */
		     );

  /** Write matrix to .bin file. The file has a default name and
   *  matrix is saved in particular format defined by the matrix
   *  library. */
  void writeToFile()
  {this->MATRIX.writeToFile();}

  /** Read matrix from .bin file.  The file has a default name and
   * matrix is saved in particular format defined by the matrix
   * library. */
  void readFromFile()
  {this->MATRIX.readFromFile();}

  /** Transfer matrix X to M, clearing previous content of M if
      any. Clear also matrix X.  */
  void transfer_from(ErgoMatrix<MatrixType> &X)
  {
    MatrixType &X_MATRIX = X.get_ref_to_matrix();
    X_MATRIX.transfer(MATRIX); // X.MATRIX -> MATRIX
  } 

  /**  Check if matrix is empty.  Notice: empty matrix != zero
   * matrix. Empty matrix has no structure. */
  bool is_empty() const
  {return MATRIX.is_empty();}

  /** Return number of non-zeros. */
  size_t nnz() const
  {return MATRIX.nnz();}
  
  /** Return number of rows. */
  size_t get_nrows() const
  {return MATRIX.get_nrows();}

  /** Return number of columns. */
  size_t get_ncols() const
  {return MATRIX.get_ncols();}

  /** Compute trace. */
  real trace() const
  {return MATRIX.trace();}

  /** Get Gershgorin spectrum bounds. */
  void gershgorin(real &eigmin,   /**< [out] Lower bound. */
		  real &eigmax    /**< [out] Upper bound. */
		  ) const
  {MATRIX.gershgorin(eigmin, eigmax);}

  /** Compute spectral norm. */
  real eucl(real requested_accuracy,  /**< [in] Tolerance in eigensolver. */
	    int maxIter = -1          /**< [in] Maximum number of
					 iterations in eigensolver. */
	    ) const
  {return MATRIX.eucl(requested_accuracy, maxIter);}

  /** Compute mixed norm. See article J. Comput. Chem. 30.6 (2009): 974-977. */
  real mixed(real requested_accuracy,  /**< [in] Tolerance in eigensolver. */
	     int maxIter = -1          /**< [in] Maximum number of
					  iterations in eigensolver. */
	     ) const
  {return MATRIX.mixed_norm(requested_accuracy, maxIter);}

  /** Compute Frobenius norm. */
  real frob() const
  {return MATRIX.frob();}

  /** Remove small matrix elements (truncate). */
  real thresh(const real tau, /**< [in] Threshold value, or maximum
				 allowed error in the matrix measured
				 by a requested norm. */ 
	      NormType norm   /**< [in] Norm used for truncation. */ 
	      ) 
  {return MATRIX.thresh(tau, norm);}

  /** Remove small matrix elements (truncate) using spectral norm in
      non-orthogonal basis set. */
  real eucl_thresh(const real tau,  /**< [in] Threshold value, or maximum
				       allowed error in the matrix measured
				       by a requested norm. */ 
		   triangMatrix * chol_factor /**< [in] Inverse Cholesky factor. 
						 Needed for non-orthogonal basis set. */ 
		   )
  {return MATRIX.eucl_thresh(tau, chol_factor);}

  /** Remove small matrix elements (truncate) using spectral norm in
      orthogonal basis set.  */
  real eucl_thresh(const real tau  /**< [in] Threshold value, or
				      maximum allowed error in the matrix
				      measured by a requested norm. */ 
		   )
  {return MATRIX.eucl_thresh(tau);}


  /** Clear matrix structure.  */
  void clear()
  {MATRIX.clear();}
  
  /** Compute matrix square and save it in X2: X2 = M^2. */
  void square(ErgoMatrix<MatrixType> &X2  /**< [out] Matrix square. */
	      ) const
    {
      X2.get_ref_to_matrix() =  (real)1.0*MATRIX*MATRIX;
      /* 
	 The code was changed by removing the return 
	 statement and sending output matrix as a parameter 
	 due to the memory problems noticed by Elias.
      */
      //MatrixType Y;
      //Y = 1.0*MATRIX*MATRIX;
      //return ErgoMatrix(Y);
    }

  /* // M *= B; */
  /* void mult(ErgoMatrix<MatrixType> &B) */
  /*   { */
  /*     //C = alpha * A * B + beta * C */
  /*     MatrixType Y(MATRIX);  */
  /*     const MatrixType& B_MATRIX = B.get_ref_to_matrix(); */
  /*     MATRIX = 1.0*Y*B_MATRIX + 0.0 * MATRIX; */
  /*   } */


  /** Multiply matrix with matrix B: M *= B and enforce symmetry. */
  void mult_force_symm(const ErgoMatrix<MatrixType> &B)
  {
    const MatrixType& B_MATRIX = B.get_ref_to_matrix_const();
    //C = alpha * A * B + beta * C
    MatrixType Y(MATRIX);  // must be!
    symmMatrix::ssmmUpperTriangleOnly((real)1.0, Y, B_MATRIX, 0, MATRIX);
  }


  /** M *= value. */
  int mult_scalar(const real value)
  {MATRIX *= value; return 1;}
  
  /** M += value*I, I is identity matrix. */
  int add_identity(const real value)
  {MATRIX.add_identity(value); return 1;}

  /** M = M - B is minus = false, M = B - M if minus  = true. */
  int subtract(const ErgoMatrix<MatrixType> & B, bool minus = false)
  {
    const MatrixType& B_MATRIX = B.get_ref_to_matrix_const();
    if(minus)
      { // M = B - M
	this->mult_scalar(-1); // compute -M
	MATRIX += B_MATRIX; //  compute -M + B
      }
    else
      { // M = M - B
	MATRIX -= B_MATRIX;
      } 
    return 1;
  }


  /** M += alpha*B. */
  int add(const ErgoMatrix<MatrixType> & B, const real alpha = 1.0)
  {
    const MatrixType& B_MATRIX = B.get_ref_to_matrix_const(); 
    MATRIX += alpha*B_MATRIX; 
    return 1;
  }

  /** Transform matrix to orthonormal basis.  M = Z' * M * Z, Z = chol_factor. */
  template<typename Matrix>
    int to_norm_basis(const ErgoMatrix<Matrix>  &chol_factor)
    {
      const Matrix& chol_factor_MATRIX = chol_factor.get_ref_to_matrix_const();
      MATRIX = transpose(chol_factor_MATRIX) * MATRIX * chol_factor_MATRIX;
      return 0;
    }
  
  /** Transform matrix to non-orthonormal basis.  M= Z * M * Z', Z = chol_factor. */
  template<typename Matrix>
    int to_nonnorm_basis(const ErgoMatrix<Matrix>  &chol_factor)
    {
      const Matrix& chol_factor_MATRIX = chol_factor.get_ref_to_matrix_const();
      MATRIX = chol_factor_MATRIX * MATRIX * transpose(chol_factor_MATRIX);
      return 0;
    }
  
  /** Compute inverse Cholesky factor.  
   *
   * \tparam Matrix Type of chol_factor matrix, in general it is a
   * triangular matrix.
   */
  template<typename Matrix>
    static int inv_chol(const ErgoMatrix<MatrixType> & S,   /**< [in] Overlap matrix. */
			ErgoMatrix<Matrix>  &chol_factor    /**< [out] Inverse Cholesky factor. */)
    {
      Matrix& chol_factor_MATRIX = chol_factor.get_ref_to_matrix();
      const MatrixType& S_MATRIX = S.get_ref_to_matrix_const();
      chol_factor_MATRIX.inch(S_MATRIX, 1e-10, mat::right);
      return 1;
    }
  
  
  /** Return trace(A-B). */
  static real trace_diff(const ErgoMatrix<MatrixType> &A, const ErgoMatrix<MatrixType> &B)
  {
    return A.trace() - B.trace();
  }
  
  /** Return Frobenius norm of A-B. */
  static real frob_diff(const ErgoMatrix<MatrixType> &A, const ErgoMatrix<MatrixType> &B) 
  {  
    const MatrixType& A_MATRIX = A.get_ref_to_matrix_const();
    const MatrixType& B_MATRIX = B.get_ref_to_matrix_const();
    return MatrixType::frob_diff(A_MATRIX, B_MATRIX);
  }

  /** Return spectral norm of A-B. */
  static real eucl_diff(const ErgoMatrix<MatrixType> &A, const ErgoMatrix<MatrixType> &B, real acc) 
  { 
    const MatrixType& A_MATRIX = A.get_ref_to_matrix_const();
    const MatrixType& B_MATRIX = B.get_ref_to_matrix_const();
    return MatrixType::eucl_diff(A_MATRIX, B_MATRIX, acc);
  }

  /** Return mixed norm of A-B. */
  static real mixed_diff(const ErgoMatrix<MatrixType> &A, const ErgoMatrix<MatrixType> &B, real acc) 
  {
    const MatrixType& A_MATRIX = A.get_ref_to_matrix_const();
    const MatrixType& B_MATRIX = B.get_ref_to_matrix_const();
    return MatrixType::mixed_diff(A_MATRIX, B_MATRIX, acc);
  }


  /*****************************************/
  /*         EIGENVECTORS STAFF            */
  /*****************************************/

  /** Get Rayleigh quotient: A = (y'Ay)/(y'y), y = eigVecPtr. */
  static real compute_rayleigh_quotient(const ErgoMatrix<MatrixType> &A, VectorType * eigVecPtr)
  {
    mat::SizesAndBlocks cols;
    const MatrixType& A_MATRIX = A.get_ref_to_matrix_const();
    A_MATRIX.getCols(cols);
    VectorType y, Ay;
    y.resetSizesAndBlocks(cols);
    y = *eigVecPtr;
    Ay.resetSizesAndBlocks(cols);
    real ONE = 1.0;
    y *= ONE/y.eucl(); // y = y/norm(y)
    Ay = ONE * A_MATRIX * y; // Ay = A*y
    real lambda = transpose(y) * Ay; // lambda = y'*Ay
    return lambda;
  }


  /** Function for choosing method for computing eigenvectors. */
  static int computeEigenvectors(const ErgoMatrix<MatrixType> &A,      /**< [in] Matrix for which to compute eigenvectors. */
				 real tol,                             /**< [in] Eigensolver tolerance. */
				 vector<real> &eigVal,                 /**< [out] Eigenvalue(s). */
				 vector<VectorType> &eigVec,           /**< [in/out] Eigenvector(s). */
				 int number_of_eigenvalues_to_compute, /**< [in] Number of eigenvalues which Lanczos should compute. */
				 bool use_vector_as_guess,             /**< [in] Use vector in eigVec as an initial guess. */
				 string method,                        /**< [in] Chosen eigensolver (power or Lanczos). */
 				 vector<int> & num_iter,               /**< [out] Actual number of iterations (now just num_iter[0] is used). */
				 int maxit = 200,                      /**< [in] Maximum number of iterations. */
				 bool do_deflation = false             /**< [in] Use deflation. */
				 )  
  {
    assert(number_of_eigenvalues_to_compute >= 1);
    assert(eigVal.size() >= 1); // note: number_of_eigenvalues may not be equal to eigVal.size()
    assert(eigVec.size() == eigVal.size());
    assert(eigVec.size() == num_iter.size());
    
    if(method == "power")
      {
	if(eigVal.size() > 1) 
	  throw "Error in computeEigenvectors: computation of more " 
	    "than 1 eigenpair is not implemented for the power method.";
	if(do_deflation)
	  throw "Error in computeEigenvectors: deflation is not implemented for the power method.";
	power_method(A, eigVal[0], eigVec[0], tol, use_vector_as_guess, num_iter[0], maxit);
      }
    if(method == "lanczos")
      {
	lanczos_method(A, eigVal, eigVec, number_of_eigenvalues_to_compute, tol, use_vector_as_guess, num_iter, maxit, do_deflation);
      }
    return 0;
  }
  
  /** Use Lanzcos method for computing eigenvectors. See function
      computeEigenvectors above for the meaning of parameters. */
  static void lanczos_method(const ErgoMatrix<MatrixType> &A, 
			     vector<real> &eigVal, 
			     vector<VectorType> &eigVec, 
			     int number_of_eigenvalues,
			     const real TOL, 
			     bool use_vector_as_guess, 
			     vector<int> &num_iter, 
			     int maxit = 200, 
			     bool do_deflation = false);

  /** Use power method for computing eigenvectors. See function
      computeEigenvectors above for the meaning of parameters. */
  static void power_method(const ErgoMatrix<MatrixType> &A, 
			   real &eigVal, 
			   VectorType &eigVec, 
			   const real TOL, 
			   bool use_vector_as_guess, 
			   int &num_iter, 
			   int maxit = 200);


}; // end of ErgoMatrix









template<typename MatrixType>
void ErgoMatrix<MatrixType>::read_from_mtx(const string &name,
			       const mat::SizesAndBlocks &rows,
			       const mat::SizesAndBlocks &cols)
{
  // Read the matrix in coordinate form: entries (I[k], J[k]) = val[k],
  // with the dimensions reported by the file in N and M.
  vector<int> I, J;
  vector<real> val;
  int N = 0, M = 0;
  if(read_matrix_from_mtx(name.c_str(), I, J, val, N, M) == -1)
    throw "read_matrix_from_mtx: error while reading matrix from the file.";

  // These checks are explicit throws rather than asserts so that
  // release (NDEBUG) builds never fill a structure that does not match
  // the file.
  if(N != M)
    throw "read_from_mtx: matrix stored in the file is not square.";

  this->MATRIX.resetSizesAndBlocks(rows,cols);

  // Check each dimension separately; equal products alone would also
  // accept e.g. a (2N x N/2) structure.
  if(this->get_nrows() != (size_t)N || this->get_ncols() != (size_t)N)
    throw "read_from_mtx: dimensions given by rows/cols do not match the matrix in the file.";

  this->MATRIX.assign_from_sparse(I, J, val);
}


template<typename MatrixType>
void ErgoMatrix<MatrixType>::lanczos_method(const ErgoMatrix<MatrixType> &A,
					    vector<real> &eigVal,
					    vector<VectorType> &eigVec,
					    int number_of_eigenvalues,
					    const real TOL,
					    bool use_vector_as_guess,
					    vector<int> &num_iter,
					    int maxit,
					    bool do_deflation)
  {
    // Compute an eigenpair of A with LanczosSeveralLargestEig.
    //  - without deflation: the first eigenpair is written to
    //    eigVal[0], eigVec[0] and its iteration count to num_iter[0];
    //  - with deflation: eigVec[0] (already computed) is deflated and the
    //    next eigenpair is written to index 1 of all output vectors.
    // If the solver throws a std::exception, it is treated as
    // non-convergence: the iteration count is set to maxit.
    assert(eigVal.size() == eigVec.size());
    assert(eigVal.size() == num_iter.size());
    assert( number_of_eigenvalues >= 1 );

    const MatrixType& A_MATRIX = A.get_ref_to_matrix_const();
    const real ONE = 1.0;

    if(!do_deflation)
      {
	try
	  {
	    VectorType y;
	    mat::SizesAndBlocks cols;
	    A_MATRIX.getCols(cols);
	    y.resetSizesAndBlocks(cols);
	    if(use_vector_as_guess) y = eigVec[0];
	    else y.rand(); // generate random vector
	    y *= (ONE/y.eucl()); // normalization

	    mat::arn::LanczosSeveralLargestEig<typename MatrixType::real, MatrixType, VectorType>
	      lan(A_MATRIX, y, number_of_eigenvalues, maxit);
	    lan.setAbsTol(TOL);
	    lan.setRelTol(TOL);
	    lan.run();
	    real acc = 0;
	    lan.get_ith_eigenpair(1, eigVal[0], eigVec[0], acc);
	    num_iter[0] = lan.get_num_iter();
	  }
	catch(const std::exception &)
	  {
	    num_iter[0] = maxit; // lanczos did not converge in maxIter iterations
	  }
	return;
      }

    // do_deflation: the vector stored in eigVec[0] is used for deflation.
    if(eigVec.empty() || eigVec[0].is_empty())
      throw "Error in ErgoMatrix<MatrixType>::lanczos_method : eigVec[0].is_empty()";
    if(number_of_eigenvalues <= 1)
      throw "Error in ErgoMatrix<MatrixType>::lanczos_method :  number_of_eigenvalues <= 1";
    // The result is stored at index 1, so every output vector needs
    // at least two entries (computeEigenvectors only guarantees one).
    if(eigVal.size() < 2 || eigVec.size() < 2 || num_iter.size() < 2)
      throw "Error in ErgoMatrix<MatrixType>::lanczos_method : deflation requires room for at least 2 eigenpairs";

    VectorType y, v1;
    v1 = eigVec[0];

    mat::SizesAndBlocks cols;
    A_MATRIX.getCols(cols);
    y.resetSizesAndBlocks(cols);
    if(use_vector_as_guess) y = eigVec[1];
    else y.rand(); // generate random vector
    y *= (ONE/y.eucl()); // normalization

    try
      {
	const int num_eig = 1; // just one eigenpair should be computed

	// find bounds of the spectrum
	real eigmin, eigmax;
	A_MATRIX.gershgorin(eigmin, eigmax);

	// Shift passed together with v1 to the solver; presumably it moves
	// the already found eigenvalue eigVal[0] to the uninteresting end of
	// the spectrum.
	real sigma = eigVal[0] - eigmin;

	mat::arn::LanczosSeveralLargestEig<typename MatrixType::real, MatrixType, VectorType>
	  lan(A_MATRIX, y, num_eig, maxit, 100, &v1, sigma);
	lan.setAbsTol(TOL);
	lan.setRelTol(TOL);
	lan.run();
	real acc = 0;
	lan.get_ith_eigenpair(1, eigVal[1], eigVec[1], acc);
	num_iter[1] = lan.get_num_iter();
      }
    catch(const std::exception &)
      {
	num_iter[1] = maxit; // lanczos did not converge in maxIter iterations
      }
  }


template<typename MatrixType>
void ErgoMatrix<MatrixType>::power_method(const ErgoMatrix<MatrixType> &A, real &eigVal, VectorType &eigVec, const real TOL, bool use_vector_as_guess, int &num_iter, int maxit)
  {
    // Power iteration for the eigenpair of A with largest-magnitude
    // eigenvalue. Iterates until ||A*y - lambda*y|| <= TOL*|lambda| or
    // maxit iterations were performed (at least one iteration always runs).
    const real ONE = 1.0;
    const real MONE = -1.0;
    const MatrixType& A_MATRIX = A.get_ref_to_matrix_const();

    mat::SizesAndBlocks cols;
    A_MATRIX.getCols(cols);

    VectorType y;
    y.resetSizesAndBlocks(cols);
    if(use_vector_as_guess) y = eigVec;
    else y.rand(); // generate random vector
    y *= (ONE/y.eucl()); // normalization

    VectorType Ay;
    VectorType residual;
    Ay.resetSizesAndBlocks(cols);
    residual.resetSizesAndBlocks(cols);

    real lambda = 0;
    int it = 1;

    Ay = ONE * A_MATRIX * y; // Ay = A*y

    do
      {
	y = Ay;
	y *= ONE/Ay.eucl();          // y = Ay/norm(Ay)
	Ay = ONE * A_MATRIX * y;     // Ay = A*y
	lambda = transpose(y) * Ay;  // lambda = y'*Ay

	// r = A*y - lambda*y
	residual = Ay;
	residual += (MONE*lambda)*y;

	++it;
      }
    // Relative residual test written without division and without an
    // unqualified abs() (which may resolve to the integer overload):
    // ||r||/|lambda| > TOL  <=>  ||r|| > TOL*|lambda|, including lambda == 0.
    while(residual.eucl() > TOL * (lambda < 0 ? -lambda : lambda) && it <= maxit);

    printf("Power method required %d iterations.\n", it-1);

    eigVal = lambda;
    eigVec = y;
    num_iter = it-1;
  }



#endif // ERGO_MATRIX_HEADER
