#cython: language_level=3

"""
Parallel (OpenMP) versions of the cqobjevo classes.
See ../cqobjevo.pyx for more details.
"""
from qutip.cy.cqobjevo cimport CQobjCte, CQobjEvoTd, CQobjEvoTdMatched
from qutip.cy.openmp.parfuncs cimport spmvpy_openmp
import numpy as np
import scipy.sparse as sp
cimport numpy as np
import cython
cimport cython
from cython.parallel import prange

include "../complex_math.pxi"


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef void _spmmcpy_par(complex* data, int* ind, int* ptr, complex* mat,
                      complex a, complex* out, int sp_rows,
                      unsigned int nrows, unsigned int ncols, int nthr):
    """
    Parallel sparse*dense product for "C" (row-major) ordered dense arrays.

    Accumulates ``out += a * A @ mat``, where ``A`` is the CSR matrix
    (data, ind, ptr) with ``sp_rows`` rows. ``mat`` is row-major with
    ``ncols`` columns. ``out`` is row-major with ``sp_rows`` rows and
    ``ncols`` columns.

    Rows of ``A`` are split across ``nthr`` threads with ``prange``. Each
    iteration only writes to its own row of ``out``. ``nrows`` is not used
    here; it is kept so the signature matches the "F" ordered kernel.
    """
    cdef int row, jj
    cdef unsigned int col
    # Offsets are computed in size_t so they cannot wrap for large arrays.
    cdef size_t out_off, mat_off
    cdef complex val
    for row in prange(sp_rows, nogil=True, num_threads=nthr):
        out_off = <size_t>row * ncols
        for jj in range(ptr[row], ptr[row + 1]):
            # The scaled non-zero and the start of the matching row of mat
            # do not change across the column loop, so compute them once.
            # (a * data[jj]) * mat[...] keeps the original evaluation order.
            val = a * data[jj]
            mat_off = <size_t>ind[jj] * ncols
            for col in range(ncols):
                out[out_off + col] += val * mat[mat_off + col]


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef void _spmmfpy_omp(complex* data, int* ind, int* ptr, complex* mat,
                       complex a, complex* out, unsigned int sp_rows,
                       unsigned int nrows, unsigned int ncols, int nthr):
    """
    Sparse*dense product for "F" (column-major) ordered dense arrays.

    Works one dense column at a time. ``mat`` is column-major with ``nrows``
    rows and ``ncols`` columns. ``out`` is column-major with ``sp_rows``
    rows. For each column, ``spmvpy_openmp`` is called with the CSR matrix
    (data, ind, ptr), the scale ``a`` and ``nthr`` threads. That call is
    assumed to accumulate ``a * A @ x`` into its output, matching the "C"
    ordered kernel; confirm this in qutip.cy.openmp.parfuncs.
    Columns themselves are processed sequentially.
    """
    cdef unsigned int col
    for col in range(ncols):
        # Column start offsets are computed in size_t so they cannot wrap
        # for very large dense operands.
        spmvpy_openmp(data, ind, ptr, mat + <size_t>nrows * col, a,
                      out + <size_t>sp_rows * col, sp_rows, nthr)


cdef class CQobjCteOmp(CQobjCte):
    """
    OpenMP-parallel variant of ``CQobjCte``.

    The products with the constant sparse matrix ``self.cte`` use the
    parallel kernels. The thread count ``nthr`` is set with ``set_threads``.
    """
    cdef int nthr

    def set_threads(self, nthr):
        """Set the number of threads passed to the parallel kernels."""
        self.nthr = nthr

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_vec(self, double t, complex* vec, complex* out) except -1:
        """Apply ``self.cte`` to ``vec`` into ``out`` via ``spmvpy_openmp``."""
        spmvpy_openmp(self.cte.data, self.cte.indices, self.cte.indptr, vec, 1.,
                      out, self.shape0, self.nthr)
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef complex _expect(self, double t, complex* vec) except *:
        """
        Return ``sum_i conj(vec[i]) * y[i]``.

        ``y`` is filled by ``spmvpy_openmp`` applying ``self.cte`` to ``vec``
        in a zeroed buffer of length ``shape0``.
        """
        cdef int row
        cdef complex dot = 0
        cdef complex[::1] y = np.zeros(self.shape0, dtype=complex)
        spmvpy_openmp(self.cte.data, self.cte.indices, self.cte.indptr, vec, 1.,
                      &y[0], self.shape0, self.nthr)
        for row in range(self.shape0):
            dot += conj(vec[row]) * y[row]
        return dot

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_matf(self, double t, complex* mat, complex* out,
                        int nrow, int ncol) except -1:
        """Apply ``self.cte`` to a column-major ("F") dense ``mat``."""
        _spmmfpy_omp(self.cte.data, self.cte.indices, self.cte.indptr, mat, 1.,
                     out, self.shape0, nrow, ncol, self.nthr)
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_matc(self, double t, complex* mat, complex* out,
                        int nrow, int ncol) except -1:
        """Apply ``self.cte`` to a row-major ("C") dense ``mat``."""
        _spmmcpy_par(self.cte.data, self.cte.indices, self.cte.indptr, mat, 1.,
                     out, self.shape0, nrow, ncol, self.nthr)
        return 0


cdef class CQobjEvoTdOmp(CQobjEvoTd):
    """
    OpenMP-parallel variant of ``CQobjEvoTd``.

    Each product applies the constant matrix ``self.cte`` with factor 1.
    It then applies every ``self.ops[k]`` with factor ``self.coeff_ptr[k]``.
    All products use the parallel kernels with ``nthr`` threads.
    """
    cdef int nthr

    def set_threads(self, nthr):
        """Store the thread count used by the parallel kernels."""
        self.nthr = nthr

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_vec(self, double t, complex* vec, complex* out) except -1:
        """Apply the operator at time ``t`` to ``vec``, writing into ``out``."""
        cdef int k
        cdef int[2] dims
        dims[0] = self.shape1
        dims[1] = 1
        # _factor_dyn is defined in the parent class. It presumably refreshes
        # coeff_ptr for time t; confirm in cqobjevo.pyx.
        self._factor_dyn(t, vec, dims)
        spmvpy_openmp(self.cte.data, self.cte.indices, self.cte.indptr, vec,
                      1., out, self.shape0, self.nthr)
        for k in range(self.num_ops):
            spmvpy_openmp(self.ops[k].data, self.ops[k].indices,
                          self.ops[k].indptr, vec, self.coeff_ptr[k], out,
                          self.shape0, self.nthr)
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_matf(self, double t, complex* mat, complex* out,
                        int nrow, int ncol) except -1:
        """Apply the operator at time ``t`` to a column-major ("F") ``mat``."""
        cdef int k
        cdef int[2] dims
        dims[0] = nrow
        dims[1] = ncol
        self._factor_dyn(t, mat, dims)
        _spmmfpy_omp(self.cte.data, self.cte.indices, self.cte.indptr, mat,
                     1., out, self.shape0, nrow, ncol, self.nthr)
        for k in range(self.num_ops):
            _spmmfpy_omp(self.ops[k].data, self.ops[k].indices,
                         self.ops[k].indptr, mat, self.coeff_ptr[k], out,
                         self.shape0, nrow, ncol, self.nthr)
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_matc(self, double t, complex* mat, complex* out,
                        int nrow, int ncol) except -1:
        """Apply the operator at time ``t`` to a row-major ("C") ``mat``."""
        cdef int k
        cdef int[2] dims
        dims[0] = nrow
        dims[1] = ncol
        self._factor_dyn(t, mat, dims)
        _spmmcpy_par(self.cte.data, self.cte.indices, self.cte.indptr, mat,
                     1., out, self.shape0, nrow, ncol, self.nthr)
        for k in range(self.num_ops):
            _spmmcpy_par(self.ops[k].data, self.ops[k].indices,
                         self.ops[k].indptr, mat, self.coeff_ptr[k], out,
                         self.shape0, nrow, ncol, self.nthr)
        return 0


cdef class CQobjEvoTdMatchedOmp(CQobjEvoTdMatched):
    """
    OpenMP-parallel variant of ``CQobjEvoTdMatched``.

    Each product first calls ``_call_core(self.data_t, self.coeff_ptr)``.
    It then multiplies with the sparse matrix described by
    ``self.data_ptr``, ``self.indices`` and ``self.indptr``, using the
    parallel kernels with ``nthr`` threads.
    """
    cdef int nthr

    def set_threads(self, nthr):
        """Store the thread count used by the parallel kernels."""
        self.nthr = nthr

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_vec(self, double t, complex* vec, complex* out) except -1:
        """Apply the operator at time ``t`` to ``vec``, writing into ``out``."""
        cdef int[2] dims
        dims[0] = self.shape1
        dims[1] = 1
        self._factor_dyn(t, vec, dims)
        # _call_core is defined in the parent class. It presumably fills
        # data_ptr with the summed matrix values for time t; confirm in
        # cqobjevo.pyx.
        self._call_core(self.data_t, self.coeff_ptr)
        spmvpy_openmp(self.data_ptr, &self.indices[0], &self.indptr[0], vec,
                      1., out, self.shape0, self.nthr)
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_matf(self, double t, complex* mat, complex* out,
                        int nrow, int ncol) except -1:
        """Apply the operator at time ``t`` to a column-major ("F") ``mat``."""
        cdef int[2] dims
        dims[0] = nrow
        dims[1] = ncol
        self._factor_dyn(t, mat, dims)
        self._call_core(self.data_t, self.coeff_ptr)
        _spmmfpy_omp(self.data_ptr, &self.indices[0], &self.indptr[0], mat,
                     1., out, self.shape0, nrow, ncol, self.nthr)
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int _mul_matc(self, double t, complex* mat, complex* out,
                        int nrow, int ncol) except -1:
        """Apply the operator at time ``t`` to a row-major ("C") ``mat``."""
        cdef int[2] dims
        dims[0] = nrow
        dims[1] = ncol
        self._factor_dyn(t, mat, dims)
        self._call_core(self.data_t, self.coeff_ptr)
        _spmmcpy_par(self.data_ptr, &self.indices[0], &self.indptr[0], mat,
                     1., out, self.shape0, nrow, ncol, self.nthr)
        return 0
