#include <unittest/unittest.h>

#ifndef _MSC_VER
#define CUSP_USE_TEXTURE_MEMORY
#endif

#include <cusp/multiply.h>

#include <cusp/linear_operator.h>
#include <cusp/print.h>
#include <cusp/gallery/poisson.h>
#include <cusp/gallery/random.h>

#include <cusp/array2d.h>
#include <cusp/coo_matrix.h>
#include <cusp/csr_matrix.h>
#include <cusp/dia_matrix.h>
#include <cusp/ell_matrix.h>
#include <cusp/hyb_matrix.h>

/////////////////////////////////////////
// Sparse Matrix-Matrix Multiplication //
/////////////////////////////////////////

// Checks sparse-sparse multiplication against a dense reference product.
//
// The dense operands A and B are multiplied first; that product is the
// expected result. Both operands are then converted to SparseMatrixType and
// multiplied twice -- once through the containers and once through views
// built over those same containers (including the already-computed product).
// Each sparse result is converted back to DenseMatrixType and compared to the
// reference with operator==.
template <typename SparseMatrixType, typename DenseMatrixType>
void CompareSparseMatrixMatrixMultiply(DenseMatrixType A, DenseMatrixType B)
{
    typedef typename SparseMatrixType::view SparseViewType;

    // dense reference product
    DenseMatrixType expected;
    cusp::multiply(A, B, expected);

    // container path
    SparseMatrixType sparse_lhs(A);
    SparseMatrixType sparse_rhs(B);
    SparseMatrixType sparse_product;
    cusp::multiply(sparse_lhs, sparse_rhs, sparse_product);

    ASSERT_EQUAL(expected == DenseMatrixType(sparse_product), true);

    // view path: views wrap the containers used above
    SparseViewType lhs_view(sparse_lhs);
    SparseViewType rhs_view(sparse_rhs);
    SparseViewType product_view(sparse_product);
    cusp::multiply(lhs_view, rhs_view, product_view);

    ASSERT_EQUAL(expected == DenseMatrixType(product_view), true);
}

template <typename TestMatrix>
void TestSparseMatrixMatrixMultiply(void)
{
    cusp::array2d<float,cusp::host_memory> A(3,2);
    A(0,0) = 1.0;
    A(0,1) = 2.0;
    A(1,0) = 3.0;
    A(1,1) = 0.0;
    A(2,0) = 5.0;
    A(2,1) = 6.0;

    cusp::array2d<float,cusp::host_memory> B(2,4);
    B(0,0) = 0.0;
    B(0,1) = 2.0;
    B(0,2) = 3.0;
    B(0,3) = 4.0;
    B(1,0) = 5.0;
    B(1,1) = 0.0;
    B(1,2) = 0.0;
    B(1,3) = 8.0;

    cusp::array2d<float,cusp::host_memory> C(2,2);
    C(0,0) = 0.0;
    C(0,1) = 0.0;
    C(1,0) = 3.0;
    C(1,1) = 5.0;

    cusp::array2d<float,cusp::host_memory> D(2,1);
    D(0,0) = 2.0;
    D(1,0) = 3.0;

    cusp::array2d<float,cusp::host_memory> E(2,2);
    E(0,0) = 0.0;
    E(0,1) = 0.0;
    E(1,0) = 0.0;
    E(1,1) = 0.0;

    cusp::array2d<float,cusp::host_memory> F(2,3);
    F(0,0) = 0.0;
    F(0,1) = 1.5;
    F(0,2) = 3.0;
    F(1,0) = 0.5;
    F(1,1) = 0.0;
    F(1,2) = 0.0;

    cusp::array2d<float,cusp::host_memory> G;
    cusp::gallery::poisson5pt(G, 4, 6);

    cusp::array2d<float,cusp::host_memory> H;
    cusp::gallery::poisson5pt(H, 8, 3);

    cusp::array2d<float,cusp::host_memory> I;
    cusp::gallery::random(24, 24, 150, I);

    cusp::array2d<float,cusp::host_memory> J;
    cusp::gallery::random(24, 24, 50, J);

    cusp::array2d<float,cusp::host_memory> K;
    cusp::gallery::random(24, 12, 20, K);

    //thrust::host_vector< cusp::array2d<float,cusp::host_memory> > matrices;
    std::vector< cusp::array2d<float,cusp::host_memory> > matrices;
    matrices.push_back(A);
    matrices.push_back(B);
    matrices.push_back(C);
    matrices.push_back(D);
    matrices.push_back(E);
    matrices.push_back(F);
    matrices.push_back(G);
    matrices.push_back(H);
    matrices.push_back(I);
    matrices.push_back(J);
    matrices.push_back(K);

    // test matrix multiply for every pair of compatible matrices
    for(size_t i = 0; i < matrices.size(); i++)
    {
        const cusp::array2d<float,cusp::host_memory>& left = matrices[i];
        for(size_t j = 0; j < matrices.size(); j++)
        {
            const cusp::array2d<float,cusp::host_memory>& right = matrices[j];

            if (left.num_cols == right.num_rows)
                CompareSparseMatrixMatrixMultiply<TestMatrix>(left, right);
        }
    }

}
DECLARE_SPARSE_MATRIX_UNITTEST(TestSparseMatrixMatrixMultiply);

///////////////////////////////////////////////
// Sparse Matrix-Dense Matrix Multiplication //
///////////////////////////////////////////////

template <typename SparseMatrixType, typename DenseMatrixType>
void CompareSparseMatrixDenseMatrixMultiply(DenseMatrixType A, DenseMatrixType B)
{
    typedef typename SparseMatrixType::value_type ValueType;
    typedef typename SparseMatrixType::memory_space MemorySpace;
    typedef cusp::array2d<ValueType,MemorySpace,cusp::column_major> DenseSpaceMatrixType;

    DenseMatrixType C(A.num_rows, B.num_cols);
    cusp::multiply(A, B, C);

    SparseMatrixType _A(A);

    // Copy B into the memory space
    DenseSpaceMatrixType B_space(B);
    // Allocate _B and ensure each column is properly aligned
    DenseSpaceMatrixType _B(B.num_rows, B.num_cols, ValueType(0), cusp::detail::round_up(B.num_rows, size_t(128)));
    // Copy columns of B into _B
    for(size_t i = 0; i < B.num_cols; i++ )
	cusp::blas::copy(B_space.column(i), _B.column(i));

    // test container
    {
        DenseSpaceMatrixType _C(C.num_rows, C.num_cols);
        cusp::multiply(_A, _B, _C);

        ASSERT_EQUAL(C == DenseMatrixType(_C), true);
    }

    {
        // test view
        DenseSpaceMatrixType _C(C.num_rows, C.num_cols);
        typename SparseMatrixType::view _Aview(_A);
        typename DenseSpaceMatrixType::view _Bview(_B), _Cview(_C);
        cusp::multiply(_Aview, _Bview, _Cview);

        ASSERT_EQUAL(C == DenseMatrixType(_C), true);
    }
}

template <typename TestMatrix>
void TestSparseMatrixDenseMatrixMultiply(void)
{
    cusp::array2d<float,cusp::host_memory> A(3,2);
    A(0,0) = 1.0;
    A(0,1) = 2.0;
    A(1,0) = 3.0;
    A(1,1) = 0.0;
    A(2,0) = 5.0;
    A(2,1) = 6.0;

    cusp::array2d<float,cusp::host_memory> B(2,4);
    B(0,0) = 0.0;
    B(0,1) = 2.0;
    B(0,2) = 3.0;
    B(0,3) = 4.0;
    B(1,0) = 5.0;
    B(1,1) = 0.0;
    B(1,2) = 0.0;
    B(1,3) = 8.0;

    cusp::array2d<float,cusp::host_memory> C(2,2);
    C(0,0) = 0.0;
    C(0,1) = 0.0;
    C(1,0) = 3.0;
    C(1,1) = 5.0;

    cusp::array2d<float,cusp::host_memory> D(2,1);
    D(0,0) = 2.0;
    D(1,0) = 3.0;

    cusp::array2d<float,cusp::host_memory> E(2,2);
    E(0,0) = 0.0;
    E(0,1) = 0.0;
    E(1,0) = 0.0;
    E(1,1) = 0.0;

    cusp::array2d<float,cusp::host_memory> F(2,3);
    F(0,0) = 0.0;
    F(0,1) = 1.5;
    F(0,2) = 3.0;
    F(1,0) = 0.5;
    F(1,1) = 0.0;
    F(1,2) = 0.0;

    cusp::array2d<float,cusp::host_memory> G;
    cusp::gallery::poisson5pt(G, 4, 6);

    cusp::array2d<float,cusp::host_memory> H;
    cusp::gallery::poisson5pt(H, 8, 3);

    cusp::array2d<float,cusp::host_memory> I;
    cusp::gallery::random(24, 24, 150, I);

    cusp::array2d<float,cusp::host_memory> J;
    cusp::gallery::random(24, 24, 50, J);

    cusp::array2d<float,cusp::host_memory> K;
    cusp::gallery::random(24, 12, 20, K);

    //thrust::host_vector< cusp::array2d<float,cusp::host_memory,cusp::column_major> > matrices;
    std::vector< cusp::array2d<float,cusp::host_memory,cusp::column_major> > matrices;
    matrices.push_back(A);
    matrices.push_back(B);
    matrices.push_back(C);
    matrices.push_back(D);
    matrices.push_back(E);
    matrices.push_back(F);
    matrices.push_back(G);
    matrices.push_back(H);
    matrices.push_back(I);
    matrices.push_back(J);
    matrices.push_back(K);

    // test matrix multiply for every pair of compatible matrices
    for(size_t i = 0; i < matrices.size(); i++)
    {
        const cusp::array2d<float,cusp::host_memory,cusp::column_major>& left = matrices[i];
        for(size_t j = 0; j < matrices.size(); j++)
        {
            const cusp::array2d<float,cusp::host_memory,cusp::column_major>& right = matrices[j];

            if (left.num_cols == right.num_rows)
                CompareSparseMatrixDenseMatrixMultiply<TestMatrix>(left, right);
        }
    }

}
DECLARE_SPARSE_MATRIX_UNITTEST(TestSparseMatrixDenseMatrixMultiply);


/////////////////////////////////////////
// Sparse Matrix-Vector Multiplication //
/////////////////////////////////////////

// Checks sparse matrix-vector multiplication against a dense reference.
//
// A host input vector x is filled with the pattern i % 10 and the reference
// output y is obtained by multiplying the dense matrix A with x. The same
// product is then computed with A converted to SparseMatrixType in three
// configurations:
//   1. container matrix, container vectors
//   2. matrix view,      container vectors
//   3. container matrix, vector views
// Every output vector is pre-filled with 10 before the multiply (presumably so
// that stale values cannot be mistaken for a correct result) and compared to
// the reference afterwards.
template <typename SparseMatrixType, typename DenseMatrixType>
void CompareSparseMatrixVectorMultiply(DenseMatrixType A)
{
    typedef typename SparseMatrixType::memory_space MemorySpace;

    // setup reference input
    cusp::array1d<float, cusp::host_memory> x(A.num_cols);
    cusp::array1d<float, cusp::host_memory> y(A.num_rows, 10);
    for(size_t i = 0; i < x.size(); i++)
        x[i] = i % 10;

    // compute reference output
    cusp::multiply(A, x, y);

    // test container
    {
        SparseMatrixType _A(A);
        cusp::array1d<float, MemorySpace> _x(x);
        cusp::array1d<float, MemorySpace> _y(A.num_rows, 10);

        cusp::multiply(_A, _x, _y);

        ASSERT_EQUAL(_y, y);
    }

    // test matrix view
    {
        SparseMatrixType _A(A);
        cusp::array1d<float, MemorySpace> _x(x);
        cusp::array1d<float, MemorySpace> _y(A.num_rows, 10);

        typename SparseMatrixType::view _V(_A);
        cusp::multiply(_V, _x, _y);

        ASSERT_EQUAL(_y, y);
    }

    // test array view
    {
        SparseMatrixType _A(A);
        cusp::array1d<float, MemorySpace> _x(x);
        cusp::array1d<float, MemorySpace> _y(A.num_rows, 10);

        // Wrap the vectors in array1d views rather than copying them into new
        // containers, so that multiply is actually driven through the view
        // types in this section.
        typename cusp::array1d<float, MemorySpace>::view _Vx(_x), _Vy(_y);
        cusp::multiply(_A, _Vx, _Vy);

        // _Vy wraps _y, so the product must be visible through the container.
        ASSERT_EQUAL(_y, y);
    }
}


// TODO use COO reference format and test larger problem sizes
// Exercises sparse matrix-vector multiplication for the TestMatrix format.
//
// Builds six small hand-written host matrices (including an all-zero one) and
// two matrices from cusp::gallery::poisson5pt, then runs
// CompareSparseMatrixVectorMultiply on each of them.
template <class TestMatrix>
void TestSparseMatrixVectorMultiply()
{
    cusp::array2d<float, cusp::host_memory> A(5,4);
    A(0,0) = 13;
    A(0,1) = 80;
    A(0,2) =  0;
    A(0,3) =  0;
    A(1,0) =  0;
    A(1,1) = 27;
    A(1,2) =  0;
    A(1,3) =  0;
    A(2,0) = 55;
    A(2,1) =  0;
    A(2,2) = 24;
    A(2,3) = 42;
    A(3,0) =  0;
    A(3,1) = 69;
    A(3,2) =  0;
    A(3,3) = 83;
    A(4,0) =  0;
    A(4,1) =  0;
    A(4,2) = 27;
    A(4,3) =  0;

    cusp::array2d<float,cusp::host_memory> B(2,4);
    B(0,0) = 0.0;
    B(0,1) = 2.0;
    B(0,2) = 3.0;
    B(0,3) = 4.0;
    B(1,0) = 5.0;
    B(1,1) = 0.0;
    B(1,2) = 0.0;
    B(1,3) = 8.0;

    cusp::array2d<float,cusp::host_memory> C(2,2);
    C(0,0) = 0.0;
    C(0,1) = 0.0;
    C(1,0) = 3.0;
    C(1,1) = 5.0;

    cusp::array2d<float,cusp::host_memory> D(2,1);
    D(0,0) = 2.0;
    D(1,0) = 3.0;

    // all-zero matrix
    cusp::array2d<float,cusp::host_memory> E(2,2);
    E(0,0) = 0.0;
    E(0,1) = 0.0;
    E(1,0) = 0.0;
    E(1,1) = 0.0;

    cusp::array2d<float,cusp::host_memory> F(2,3);
    F(0,0) = 0.0;
    F(0,1) = 1.5;
    F(0,2) = 3.0;
    F(1,0) = 0.5;
    F(1,1) = 0.0;
    F(1,2) = 0.0;

    // gallery-generated matrices
    cusp::array2d<float,cusp::host_memory> G;
    cusp::gallery::poisson5pt(G, 4, 6);

    cusp::array2d<float,cusp::host_memory> H;
    cusp::gallery::poisson5pt(H, 8, 3);

    // compare the sparse product against the dense reference for each matrix
    CompareSparseMatrixVectorMultiply<TestMatrix>(A);
    CompareSparseMatrixVectorMultiply<TestMatrix>(B);
    CompareSparseMatrixVectorMultiply<TestMatrix>(C);
    CompareSparseMatrixVectorMultiply<TestMatrix>(D);
    CompareSparseMatrixVectorMultiply<TestMatrix>(E);
    CompareSparseMatrixVectorMultiply<TestMatrix>(F);
    CompareSparseMatrixVectorMultiply<TestMatrix>(G);
    CompareSparseMatrixVectorMultiply<TestMatrix>(H);
}
DECLARE_SPARSE_MATRIX_UNITTEST(TestSparseMatrixVectorMultiply);


//////////////////////////////
// General Linear Operators //
//////////////////////////////

template <class MemorySpace>
void TestMultiplyIdentityOperator(void)
{
    cusp::array1d<float, MemorySpace> x(4);
    cusp::array1d<float, MemorySpace> y(4);

    x[0] =  7.0f;
    y[0] =  0.0f;
    x[1] =  5.0f;
    y[1] = -2.0f;
    x[2] =  4.0f;
    y[2] =  0.0f;
    x[3] = -3.0f;
    y[3] =  5.0f;

    cusp::identity_operator<float, MemorySpace> A(4,4);

    cusp::multiply(A, x, y);

    ASSERT_EQUAL(y[0],  7.0f);
    ASSERT_EQUAL(y[1],  5.0f);
    ASSERT_EQUAL(y[2],  4.0f);
    ASSERT_EQUAL(y[3], -3.0f);
}
DECLARE_HOST_DEVICE_UNITTEST(TestMultiplyIdentityOperator);

