#include <algorithm>
#include <iostream>
#include <iterator>
#include <iomanip>
#include <cstdlib>
#include <cula.hpp>
#include <cuda.h>
#include <cuda_runtime.h>
#include "gincsvd_config.h"

using std::cout; 
using std::endl; 
using std::cerr;

  // Error-reporting helpers used by the CUDA_CHECK / CULA_CHECK macros.
  // Both helpers print a diagnostic to cerr and terminate the process with
  // exit(-1) when the status they are given indicates a failure.
  namespace details {
    // Checks the result of a CUDA runtime call.
    //   status : value returned by the CUDA call
    //   file   : source file of the call site (normally __FILE__)
    //   line   : source line of the call site (normally __LINE__)
    // Returns normally only when status == cudaSuccess.
    inline void cudaCheck(cudaError status, const char *file, const int line )
    {
      if (cudaSuccess != status)
      {
        cerr << "File " << file << ", line " << line << " returned CUDA error code " << status << ":\n* "
                  << cudaGetErrorString(status) << endl;
        exit(-1);        
      }
    }

    // Checks the result of a CULA call.
    //   status : value returned by the CULA call
    //   file   : source file of the call site (normally __FILE__)
    //   line   : source line of the call site (normally __LINE__)
    // Returns normally only when status == culaNoError.
    inline void culaCheck(culaStatus status, const char *file, const int line )
    {
      if (culaNoError != status)
      {
        culaInfo info = culaGetErrorInfo();
        const int errsz = 1024;
        // Zero-initialised so the buffer is always a valid C string, even if
        // the lookup below fails without writing anything into it.
        char errstr[errsz] = "";
        const char *description = errstr;
        if( culaGetErrorInfoString(status, info, errstr, errsz) != culaNoError ) {
          cerr << "error calling culaGetErrorInfoString(); this is likely an internal GINCSVD error." << endl;
          // The buffer contents cannot be trusted after a failed lookup.
          description = "(no error description available)";
        }
        cerr << "File " << file << ", line " << line << " returned CULA error code " << status << ", info " << info << ":\n* "
                  << description << endl;
        exit(-1);        
      }
    }
  }

// Call-site wrappers: forward the status expression together with the
// current file name and line number to the matching checker in `details`.
#define CUDA_CHECK(status) details::cudaCheck( (status), __FILE__, __LINE__ )
#define CULA_CHECK(status) details::culaCheck( (status), __FILE__, __LINE__ )

void init(int device) {
  CUDA_CHECK( cudaSetDevice(device) );
  cudaDeviceProp deviceProp;
  CUDA_CHECK( cudaGetDeviceProperties(&deviceProp, device) );
  cout << "Using device " << device << ", \"" << deviceProp.name << "\""
       << ", of compute capability " << deviceProp.major << "." << deviceProp.minor 
       << endl;
  if (deviceProp.asyncEngineCount == 0) {
    cout << "*** WARNING: asyncEngineCount == 0; cannot overlap communication with computation." << endl;
  }
  CULA_CHECK( culaInitialize() );
}

// Copies `count` values from the device buffer `dS` into the host buffer
// `hS` (checking the copy for errors) and prints them to cout, preceded by
// `header` and followed by a blank line.
template <class T>
static void printSingularValues(const char *header, const T *dS, T *hS, const int count) {
    CUDA_CHECK( cudaMemcpy(hS, dS, sizeof(T)*count, cudaMemcpyDeviceToHost) );
    cout << header;
    std::copy( hS, hS+count, std::ostream_iterator<T>(cout,"  ") );
    cout << "\n\n";
}

// Exercises culaDeviceGesvd for element type T on an M x N matrix of ones.
//
// Four decompositions are run and the singular values of each are printed
// next to a description of the expected result:
//   1. the first L columns of A,
//   2. the L x L VT produced by step 1,
//   3. all N columns of A (as left in dA by step 1),
//   4. the N x N VT produced by step 3.
// NOTE(review): jobu = 'O' presumably overwrites the input matrix in place,
// following the LAPACK gesvd convention - confirm against the CULA reference.
//
// Any CUDA or CULA failure is reported through CUDA_CHECK / CULA_CHECK.
template <class T>
void test_cula() {
    //  problem size
    //  don't change these without changing the expected output below
    const int M =  10000;
    const int L =      4;
    const int N =    2*L;
    //  host data
    const int LDU  = M;  
    const int LDVT = 2*L;
    T *hA  = new T[LDU*N];
    T *hS  = new T[N];
    std::fill(hA, hA+M*N, T(1));
    // device data
    T *dA, *dVT, *dS, *tmp;
    CUDA_CHECK( cudaMalloc(&dA,  sizeof(T)*LDU*N)  );
    CUDA_CHECK( cudaMalloc(&dVT, sizeof(T)*LDVT*N) );
    CUDA_CHECK( cudaMalloc(&tmp, sizeof(T)*LDVT*N) );
    CUDA_CHECK( cudaMalloc(&dS,  sizeof(T)*N)  );
    // copy all N columns to device
    CUDA_CHECK( cudaMemcpy(dA, hA, sizeof(T)*M*N, cudaMemcpyHostToDevice)     );
    // do first svd
    const int firstSVDRank = L;
    CULA_CHECK( culaDeviceGesvd('O','A',M,firstSVDRank,dA,LDU,dS,NULL,LDU,dVT,LDVT) );
    printSingularValues("Singular values of A (200 and then zeros):\n", dS, hS, firstSVDRank);
    CULA_CHECK( culaDeviceGesvd('O','A',firstSVDRank,firstSVDRank,dVT,LDVT,dS,NULL,LDVT,tmp,LDVT) );
    printSingularValues("Singular values of VT (all ones):\n", dS, hS, firstSVDRank);
    // do second svd
    const int secondSVDRank = N;
    CULA_CHECK( culaDeviceGesvd('O','A',M,secondSVDRank,dA,LDU,dS,NULL,LDU,dVT,LDVT) );
    printSingularValues("Singular values of A (200+epsilon, then ones, then zeros): \n", dS, hS, secondSVDRank);
    CULA_CHECK( culaDeviceGesvd('O','A',secondSVDRank,secondSVDRank,dVT,LDVT,dS,NULL,LDVT,tmp,LDVT) );
    printSingularValues("Singular values of VT (all ones): \n", dS, hS, secondSVDRank);
    // release host and device buffers
    delete [] hA;
    delete [] hS;
    CUDA_CHECK( cudaFree(dA)  );
    CUDA_CHECK( cudaFree(dVT) );
    CUDA_CHECK( cudaFree(dS)  );
    CUDA_CHECK( cudaFree(tmp) );
}

int main() {
  // 
  cout << std::setprecision(10) << std::scientific;
  init(0);
#ifdef GINCSVD_FLOAT
  cout << "\n**** Testing float..." << endl;
  test_cula<float>();  
#endif
#ifdef GINCSVD_DOUBLE
  cout << "\n**** Testing double..." << endl;
  test_cula<double>();  
#endif
  return 0;
}
