/**
 *
 * @file bcsc_cspmv.c
 *
 * Functions computing matrix-vector products for the BCSC
 *
 * @copyright 2004-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
 *                      Univ. Bordeaux. All rights reserved.
 *
 * @version 6.4.0
 * @author Mathieu Faverge
 * @author Vincent Bridonneau
 * @author Theophile Terraz
 * @author Tony Delarue
 * @date 2024-07-05
 *
 * @generated from /build/pastix/src/pastix-6.4.0/bcsc/bcsc_zspmv.c, normal z -> c, Tue Dec 16 21:22:40 2025
 *
 **/
#include "common.h"
#include <math.h>
#include "bcsc/bcsc.h"
#include "bcsc_c.h"
#include "blend/solver.h"
#include "pastix/datatypes.h"

#ifndef DOXYGEN_SHOULD_SKIP_THIS

/*
 * Prototype shared by the per-column-block spmv kernels of this file
 * (__bcsc_cspmv_Ax and __bcsc_cspmv_conjAx), so that a caller can select the
 * kernel once and then apply it to every column block.
 *
 * Arguments, in order: the bcsc structure (provides rowtab), the column block
 * to process, the scalar alpha, the array of matrix values, the input vector
 * x, the scalar beta, and the output pointer for the first column of the
 * block.
 */
typedef void ( *bcsc_cspmv_Ax_fct_t )( const pastix_bcsc_t *,
                                       const bcsc_cblk_t *,
                                       pastix_complex32_t,
                                       const pastix_complex32_t *,
                                       const pastix_complex32_t *,
                                       pastix_complex32_t,
                                       pastix_complex32_t * );

/*
 * Scale the n first entries of y in place: y := beta * y.
 *
 * When beta is exactly zero the entries are overwritten with zeros through
 * memset instead of being multiplied, so the previous content of y (whatever
 * it holds) does not take part in the result.
 */
static inline void
__bcsc_cspmv_by( pastix_int_t        n,
                 pastix_complex32_t  beta,
                 pastix_complex32_t *y )
{
    pastix_int_t k;

    /* Null beta: plain reset of the vector, no multiplication involved */
    if ( beta == (pastix_complex32_t)0.0 ) {
        memset( y, 0, n * sizeof(pastix_complex32_t) );
        return;
    }

    for ( k = 0; k < n; k++ ) {
        y[k] *= beta;
    }
}

/*
 * Kernel applied to one column block of the bcsc.
 *
 * The cblk->colnbr entries pointed to by y are first scaled by beta. Then,
 * for each column j of the block, every stored value A[i] of that column is
 * multiplied by alpha and by the entry of x selected through bcsc->rowtab[i],
 * and accumulated into y[j].
 *
 * y must point to the entry matching the first column of the block, while x
 * is indexed with the values read from rowtab.
 */
static inline void
__bcsc_cspmv_Ax( const pastix_bcsc_t      *bcsc,
                 const bcsc_cblk_t        *cblk,
                 pastix_complex32_t        alpha,
                 const pastix_complex32_t *A,
                 const pastix_complex32_t *x,
                 pastix_complex32_t        beta,
                 pastix_complex32_t       *y )
{
    pastix_int_t col, idx;

    /* y := beta * y on the columns owned by this block */
    __bcsc_cspmv_by( cblk->colnbr, beta, y );

    for ( col = 0; col < cblk->colnbr; col++ )
    {
        /* coltab[col] .. coltab[col+1]-1 are the entries stored for this column */
        for ( idx = cblk->coltab[col]; idx < cblk->coltab[col+1]; idx++ )
        {
            y[col] += alpha * A[idx] * x[ bcsc->rowtab[idx] ];
        }
    }
}

/*
 * Kernel working on the whole bcsc with indirect (scattered) updates of y.
 *
 * The bcsc->gN first entries of y are scaled by beta, then all the column
 * blocks are traversed in order: each stored value A[i] of a column is
 * multiplied by alpha and by the entry of x associated with that column, and
 * added to y[ bcsc->rowtab[i] ].
 *
 * The entries of x are consumed sequentially, one per column, following the
 * order of the columns in cscftab.
 */
static inline void
__bcsc_cspmv_Ax_ind( const pastix_bcsc_t      *bcsc,
                     pastix_complex32_t        alpha,
                     const pastix_complex32_t *A,
                     const pastix_complex32_t *x,
                     pastix_complex32_t        beta,
                     pastix_complex32_t       *y )
{
    const bcsc_cblk_t *cblk = bcsc->cscftab;
    pastix_int_t       bloc, col, idx;
    pastix_int_t       xidx = 0; /* Running column counter over all the blocks */

    /* y := beta * y on the complete vector */
    __bcsc_cspmv_by( bcsc->gN, beta, y );

    for ( bloc = 0; bloc < bcsc->cscfnbr; bloc++, cblk++ )
    {
        for ( col = 0; col < cblk->colnbr; col++, xidx++ )
        {
            for ( idx = cblk->coltab[col]; idx < cblk->coltab[col+1]; idx++ )
            {
                y[ bcsc->rowtab[idx] ] += alpha * A[idx] * x[xidx];
            }
        }
    }
}

#if defined(PRECISION_z) || defined(PRECISION_c)
/*
 * Conjugate variant of the per-column-block kernel, only built for the
 * complex precisions.
 *
 * Same traversal as __bcsc_cspmv_Ax, except that each stored value A[i] goes
 * through conjf() before being multiplied by alpha and by the entry of x
 * selected through bcsc->rowtab[i].
 */
static inline void
__bcsc_cspmv_conjAx( const pastix_bcsc_t      *bcsc,
                     const bcsc_cblk_t        *cblk,
                     pastix_complex32_t        alpha,
                     const pastix_complex32_t *A,
                     const pastix_complex32_t *x,
                     pastix_complex32_t        beta,
                     pastix_complex32_t       *y )
{
    pastix_int_t col, idx;

    /* y := beta * y on the columns owned by this block */
    __bcsc_cspmv_by( cblk->colnbr, beta, y );

    for ( col = 0; col < cblk->colnbr; col++ )
    {
        for ( idx = cblk->coltab[col]; idx < cblk->coltab[col+1]; idx++ )
        {
            y[col] += alpha * conjf( A[idx] ) * x[ bcsc->rowtab[idx] ];
        }
    }
}
#endif

/*
 * Apply y = alpha * op(A) * x + beta * y on the bcsc column blocks of the
 * range [begin, end).
 *
 *   solvmtx  Solver matrix; gives, for each bcsc column block, the position
 *            (lcolidx) of its first column in y.
 *   trans    Operation requested on A.
 *   alpha    Scalar alpha.
 *   bcsc     The matrix in bcsc format.
 *   x        Input vector.
 *   beta     Scalar beta.
 *   y        Base pointer of the output vector: the offset of each column
 *            block is added here, callers must not pre-shift it.
 *   rank     Rank of the calling thread; only rank 0 runs the indirect
 *            kernel.
 *   begin    First bcsc column block handled by the caller.
 *   end      One past the last bcsc column block handled by the caller.
 */
static inline void
__bcsc_cspmv_loop( const SolverMatrix       *solvmtx,
                   pastix_trans_t            trans,
                   pastix_complex32_t        alpha,
                   const pastix_bcsc_t      *bcsc,
                   const pastix_complex32_t *x,
                   pastix_complex32_t        beta,
                   pastix_complex32_t       *y,
                   pastix_int_t              rank,
                   pastix_int_t              begin,
                   pastix_int_t              end )
{
    bcsc_cspmv_Ax_fct_t  zspmv_Ax = __bcsc_cspmv_Ax;
    pastix_complex32_t  *valptr = NULL;
    pastix_int_t         bloc;
    bcsc_cblk_t         *cblk;

    /*
     * There are three cases:
     *    We can use the Lvalues pointer directly:
     *          - The matrix is general and we use A^t
     *          - The matrix is symmetric or hermitian
     *    We can use the Uvalues pointer directly
     *          - The matrix is general and we use A
     *    We have to use Lvalues per row (instead of column)
     *          - The matrix A is general and Uvalues is unavailable
     *
     * To this, we have to add the conjf call if ConjTrans or Hermitian
     *
     *     Mtxtype   | trans asked | algo applied
     *     ++++++++++++++++++++++++++++++++++++
     *     General   | NoTrans     | U if possible, otherwise indirect L
     *     General   | Trans       | L
     *     General   | ConjTrans   | conjf(L)
     *     Symmetric | NoTrans     | L
     *     Symmetric | Trans       | L
     *     Symmetric | ConjTrans   | conjf(L)
     *     Hermitian | NoTrans     | conjf(L)
     *     Hermitian | Trans       | L
     *     Hermitian | ConjTrans   | conjf(L)
     */
    cblk   = bcsc->cscftab + begin;
    valptr = (pastix_complex32_t*)bcsc->Lvalues;

    if ( (bcsc->mtxtype == PastixGeneral) && (trans == PastixNoTrans) )
    {
        /* U */
        if ( bcsc->Uvalues != NULL ) {
            valptr = (pastix_complex32_t*)bcsc->Uvalues;
        }
        /* Indirect L */
        else {
            /* Execute in sequential: only the first thread does the work */
            if ( rank == 0 ) {
                __bcsc_cspmv_Ax_ind( bcsc, alpha, valptr, x, beta, y );
            }
            /*
             * The indirect kernel already produced the complete result
             * (scaling by beta included). Falling through to the loop below
             * would scale y by beta a second time and add the contribution of
             * the transposed operator stored in Lvalues.
             */
            return;
        }
    }
#if defined(PRECISION_z) || defined(PRECISION_c)
    /* Conj(L) */
    else if ( ( (bcsc->mtxtype == PastixGeneral  ) && (trans == PastixConjTrans) ) ||
              ( (bcsc->mtxtype == PastixSymmetric) && (trans == PastixConjTrans) ) ||
              ( (bcsc->mtxtype == PastixHermitian) && (trans != PastixTrans    ) ) )
    {
        zspmv_Ax = __bcsc_cspmv_conjAx;
    }
#endif /* defined(PRECISION_z) || defined(PRECISION_c) */

    for( bloc=begin; bloc<end; bloc++, cblk++ )
    {
        /* Locate the entries of y that belong to this column block */
        const SolverCblk   *solv_cblk = solvmtx->cblktab + cblk->cblknum;
        pastix_complex32_t *yptr      = y + solv_cblk->lcolidx;

        assert( !(solv_cblk->cblktype & (CBLK_FANIN|CBLK_RECV)) );

        zspmv_Ax( bcsc, cblk, alpha, valptr, x, beta, yptr );
    }
}

#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/**
 *******************************************************************************
 *
 * @ingroup bcsc
 *
 * @brief Compute the matrix-vector product  y = alpha * op(A) * x + beta * y
 * (Sequential version)
 *
 * Where A is given in the bcsc format, x and y are two vectors of size n,
 * alpha and beta are two scalars, and op(A) is selected by the trans parameter.
 *
 *******************************************************************************
 *
 * @param[in] pastix_data
 *          Provide information about bcsc
 *
 * @param[in] trans
 *          Specifies whether the matrix A from the bcsc is transposed, not
 *          transposed or conjugate transposed:
 *            = PastixNoTrans:   A is not transposed;
 *            = PastixTrans:     A is transposed;
 *            = PastixConjTrans: A is conjugate transposed.
 *
 * @param[in] alpha
 *          alpha specifies the scalar alpha
 *
 * @param[in] x
 *          The vector x.
 *
 * @param[in] beta
 *          beta specifies the scalar beta
 *
 * @param[inout] y
 *          The vector y.
 *
 *******************************************************************************/
void
bcsc_cspmv_seq( const pastix_data_t      *pastix_data,
                pastix_trans_t            trans,
                pastix_complex32_t        alpha,
                const pastix_complex32_t *x,
                pastix_complex32_t        beta,
                pastix_complex32_t       *y )
{
    const pastix_bcsc_t *bcsc = pastix_data->bcsc;

    /* Nothing is computed when the matrix or one of the vectors is missing */
    if ( (bcsc != NULL) && (x != NULL) && (y != NULL) )
    {
        /* A single caller (rank 0) goes through all the column blocks */
        __bcsc_cspmv_loop( pastix_data->solvmatr,
                           trans, alpha, bcsc, x, beta, y,
                           0, 0, bcsc->cscfnbr );
    }
}

/**
 * @brief Data structure for parallel arguments of spmv functions
 */
struct c_argument_spmv_s {
    pastix_trans_t            trans;         /**< Operation applied to A (no transpose, transpose or conjugate transpose) */
    pastix_complex32_t        alpha;         /**< Scalar multiplying op(A) * x                                            */
    const pastix_bcsc_t      *bcsc;          /**< The matrix A in bcsc format                                             */
    const pastix_complex32_t *x;             /**< Input vector x                                                          */
    pastix_complex32_t        beta;          /**< Scalar multiplying y on entry                                           */
    pastix_complex32_t       *y;             /**< Input/output vector y                                                   */
    SolverMatrix             *mtx;           /**< Solver matrix providing the cblktab, tasktab and per-thread task lists  */
    pastix_int_t             *start_indexes; /**< Starting position in the vector for each thread                         */
    pastix_int_t             *start_bloc;    /**< Index of the first bcsc column block for each thread                    */
};

/**
 *******************************************************************************
 *
 * @ingroup bcsc_internal
 *
 * @brief Compute the matrix-vector product  y = alpha * op(A) * x + beta * y
 *
 * Where A is given in the bcsc format, x and y are two vectors of size n, and
 * alpha and beta are two scalars.
 * The op function is specified by the trans parameter and performs the
 * operation as follows:
 *              trans = PastixNoTrans   y := alpha*A       *x + beta*y
 *              trans = PastixTrans     y := alpha*A'      *x + beta*y
 *              trans = PastixConjTrans y := alpha*conjf(A')*x + beta*y
 *
 *******************************************************************************
 *
 * @param[in] ctx
 *          the context of the current thread
 *
 * @param[inout] args
 *          The parameter as specified in bcsc_cspmv.
 *
 *******************************************************************************/
void
pthread_bcsc_cspmv( isched_thread_t *ctx,
                    void            *args )
{
    struct c_argument_spmv_s *arg        = (struct c_argument_spmv_s*)args;
    const pastix_bcsc_t      *bcsc       = arg->bcsc;
    pastix_int_t             *start_bloc = arg->start_bloc;
    pastix_int_t              begin, end, size, rank;

    rank = (pastix_int_t)ctx->rank;
    size = (pastix_int_t)ctx->global_ctx->world_size;

    /*
     * The thread handles the bcsc column blocks [begin, end): from its own
     * starting bloc up to the starting bloc of the next thread, the last
     * thread going up to the end of the matrix.
     */
    begin = start_bloc[rank];
    if ( rank == (size - 1) )
    {
        end = bcsc->cscfnbr;
    }
    else {
        end = start_bloc[rank + 1];
    }

    /*
     * The base of y is forwarded untouched: __bcsc_cspmv_loop adds the
     * position of each column block (lcolidx) itself, so applying
     * start_indexes[rank] here would shift the output twice for every thread
     * but the first one.
     */
    __bcsc_cspmv_loop( arg->mtx,
                       arg->trans, arg->alpha, bcsc, arg->x,
                       arg->beta, arg->y,
                       rank, begin, end );
}

/**
 *******************************************************************************
 *
 * @ingroup bcsc_internal
 *
 * @brief Compute the matrix-vector product  y = alpha * op(A) * x + beta * y
 *
 * Where A is given in the bcsc format, x and y are two vectors of size n, and
 * alpha and beta are two scalars.
 * The op function is specified by the trans parameter and performs the
 * operation as follows:
 *              trans = PastixNoTrans   y := alpha*A       *x + beta*y
 *              trans = PastixTrans     y := alpha*A'      *x + beta*y
 *              trans = PastixConjTrans y := alpha*conjf(A')*x + beta*y
 *
 *******************************************************************************
 *
 * @param[in] ctx
 *          the context of the current thread
 *
 * @param[inout] args
 *          The parameter as specified in bcsc_cspmv.
 *
 *******************************************************************************/
void
pthread_bcsc_cspmv_tasktab( isched_thread_t *ctx,
                            void            *args )
{
    struct c_argument_spmv_s *arg     = (struct c_argument_spmv_s*)args;
    const pastix_bcsc_t      *bcsc    = arg->bcsc;
    SolverMatrix             *mtx     = arg->mtx;
    pastix_trans_t            trans   = arg->trans;
    pastix_complex32_t        alpha   = arg->alpha;
    pastix_complex32_t        beta    = arg->beta;
    const pastix_complex32_t *x       = arg->x;
    pastix_complex32_t       *y       = arg->y;
    pastix_int_t              rank    = (pastix_int_t)ctx->rank;
    /* List of the tasks assigned to the calling thread */
    pastix_int_t              tasknbr = mtx->ttsknbr[rank];
    pastix_int_t             *tasktab = mtx->ttsktab[rank];
    /* Default choice: non conjugated kernel working on Lvalues */
    bcsc_cspmv_Ax_fct_t       kernel  = __bcsc_cspmv_Ax;
    pastix_complex32_t       *values  = (pastix_complex32_t*)bcsc->Lvalues;
    pastix_int_t              ii;

    /*
     * Selection of the values array and of the kernel:
     *
     *     Mtxtype   | trans asked | algo applied
     *     ----------+-------------+------------------------------------
     *     General   | NoTrans     | U if possible, otherwise indirect L
     *     General   | Trans       | L
     *     General   | ConjTrans   | conjf(L)
     *     Symmetric | NoTrans     | L
     *     Symmetric | Trans       | L
     *     Symmetric | ConjTrans   | conjf(L)
     *     Hermitian | NoTrans     | conjf(L)
     *     Hermitian | Trans       | L
     *     Hermitian | ConjTrans   | conjf(L)
     *
     * "Indirect L" means that Lvalues is used per row instead of per column,
     * because the matrix is general and Uvalues is unavailable.
     */
    if ( (bcsc->mtxtype == PastixGeneral) && (trans == PastixNoTrans) )
    {
        if ( bcsc->Uvalues == NULL ) {
            /*
             * Indirect L: the whole product is done sequentially by the
             * first thread, the other ones have nothing to do.
             */
            if ( rank == 0 ) {
                __bcsc_cspmv_Ax_ind( bcsc, alpha, values, x, beta, y );
            }
            return;
        }

        /* U */
        values = (pastix_complex32_t*)bcsc->Uvalues;
    }
#if defined(PRECISION_z) || defined(PRECISION_c)
    /* Conj(L) */
    else if ( ( ( (bcsc->mtxtype == PastixGeneral  ) ||
                  (bcsc->mtxtype == PastixSymmetric) ) && (trans == PastixConjTrans) ) ||
              ( (bcsc->mtxtype == PastixHermitian) && (trans != PastixTrans) ) )
    {
        kernel = __bcsc_cspmv_conjAx;
    }
#endif /* defined(PRECISION_z) || defined(PRECISION_c) */

    for ( ii = 0; ii < tasknbr; ii++ )
    {
        /* Task -> solver column block -> matching bcsc column block */
        const Task        *task      = mtx->tasktab + tasktab[ii];
        const SolverCblk  *solv_cblk = mtx->cblktab + task->cblknum;
        const bcsc_cblk_t *bcsc_cblk = bcsc->cscftab + solv_cblk->bcscnum;

        kernel( bcsc, bcsc_cblk, alpha, values, x, beta,
                y + solv_cblk->lcolidx );
    }
}

/**
 *******************************************************************************
 *
 * @ingroup bcsc_internal
 *
 * @brief Initialize the starting bloc and vector indexes of each thread
 *   for the parallel version of spmv.
 *
 *   This function initializes the indexes of all the threads at once, so that
 *   they are computed a single time instead of once per thread. This is a more
 *   sophisticated version that tries to balance the load of each thread in
 *   terms of bloc size.
 *
 *******************************************************************************
 *
 * @param[in] pastix_data
 *          The pastix_data structure providing number of threads and holding
 *          the A matrix.
 *
 * @param[out] args
 *          The argument containing arrays to initialise (blocs and indexes).
 *
 *******************************************************************************/
void
bcsc_cspmv_get_balanced_indexes( const pastix_data_t      *pastix_data,
                                 struct c_argument_spmv_s *args )
{
    const pastix_bcsc_t *bcsc = pastix_data->bcsc;
    pastix_int_t         size = pastix_data->isched->world_size;
    pastix_int_t         total, ratio;
    pastix_int_t         load = 0; /* Entries accumulated for the current thread */
    pastix_int_t         rank = 1; /* Next thread to receive a starting point    */
    pastix_int_t         bloc;

    /* Number of entries expected in the bcsc */
    if ( bcsc->mtxtype == PastixGeneral ) {
        total = pastix_data->csc->nnzexp;
    }
    else {
        total = 2 * pastix_data->csc->nnzexp - bcsc->gN;
    }

    /* Load targeted for each thread */
    ratio = pastix_iceil( total, size );

    /* The first thread always starts at the beginning */
    args->start_bloc[0]    = 0;
    args->start_indexes[0] = 0;

    for ( bloc = 0; bloc < bcsc->cscfnbr; bloc++ )
    {
        const bcsc_cblk_t *cblk = bcsc->cscftab + bloc;

        /* The current thread is full: the next one starts with this bloc */
        if ( load >= ratio ) {
            assert( rank < size );

            args->start_bloc[rank]    = bloc;
            args->start_indexes[rank] = pastix_data->solvmatr->cblktab[bloc].fcolnum;

            total -= load;
            load   = 0;
            rank++;
        }

        /* Number of entries stored in the bloc */
        load += cblk->coltab[cblk->colnbr] - cblk->coltab[0];
    }

    /* All the expected entries must have been distributed */
    total -= load;
    assert( total == 0 );

    /* Remaining threads get an empty range located at the end */
    while ( rank < size ) {
        args->start_bloc[rank]    = bcsc->cscfnbr;
        args->start_indexes[rank] = bcsc->gN;
        rank++;
    }
}

/**
 *******************************************************************************
 *
 * @ingroup bcsc
 *
 * @brief Perform y = alpha * op(A) * x + beta * y (Parallel version)
 *
 * This function is parallelized through the internal static scheduler.
 *
 *******************************************************************************
 *
 * @param[in] pastix_data
 *          The pastix_data structure that holds the A matrix.
 *
 * @param[in] trans
 *          Specifies whether the matrix A from the bcsc is transposed, not
 *          transposed or conjugate transposed:
 *            = PastixNoTrans:   A is not transposed;
 *            = PastixTrans:     A is transposed;
 *            = PastixConjTrans: A is conjugate transposed.
 *
 * @param[in] alpha
 *          The scalar alpha.
 *
 * @param[in] x
 *          The vector x
 *
 * @param[in] beta
 *          The scalar beta.
 *
 * @param[inout] y
 *          On entry, the vector y
 *          On exit, alpha * op(A) * x + beta * y
 *
 *******************************************************************************/
void
bcsc_cspmv_smp( const pastix_data_t      *pastix_data,
                pastix_trans_t            trans,
                pastix_complex32_t        alpha,
                const pastix_complex32_t *x,
                pastix_complex32_t        beta,
                pastix_complex32_t       *y )
{
    pastix_bcsc_t            *bcsc = pastix_data->bcsc;
    struct c_argument_spmv_s  arg;

    /* Nothing is computed when the matrix or one of the vectors is missing */
    if ( (x == NULL) || (y == NULL) || (bcsc == NULL) ) {
        return;
    }

    /* Arguments shared by all the threads */
    arg.trans         = trans;
    arg.alpha         = alpha;
    arg.bcsc          = bcsc;
    arg.x             = x;
    arg.beta          = beta;
    arg.y             = y;
    arg.mtx           = pastix_data->solvmatr;
    arg.start_indexes = NULL; /* Only used by the balanced version below */
    arg.start_bloc    = NULL; /* Only used by the balanced version below */

    /* Each thread processes the column blocks of its own task list */
    isched_parallel_call( pastix_data->isched, pthread_bcsc_cspmv_tasktab, &arg );

#if 0
    /*
     * Version that balances the number of nnz per thread, instead of exploiting
     * the tasktab array.
     */
    {
        MALLOC_INTERN( arg.start_indexes, 2 * pastix_data->isched->world_size, pastix_int_t );
        arg.start_bloc = arg.start_indexes + pastix_data->isched->world_size;

        bcsc_cspmv_get_balanced_indexes( pastix_data, &arg );

        isched_parallel_call( pastix_data->isched, pthread_bcsc_cspmv, &arg );

        memFree_null( arg.start_indexes );
    }
#endif
}

/**
 *******************************************************************************
 *
 * @brief Compute the matrix-vector product  y = alpha * op(A) * x + beta * y
 *
 * Where A is given in the bcsc format, x and y are two vectors of size n, and
 * alpha and beta are two scalars.
 * The op function is specified by the trans parameter and performs the
 * operation as follows:
 *              trans = PastixNoTrans   y := alpha*A       *x + beta*y
 *              trans = PastixTrans     y := alpha*A'      *x + beta*y
 *              trans = PastixConjTrans y := alpha*conjf(A')*x + beta*y
 *
 * This function is used only for testing.
 *
 *******************************************************************************
 *
 * @param[in] pastix_data
 *          Provide information about bcsc, and select the scheduling version
 *          based on iparm[IPARM_SCHEDULER].
 *
 * @param[in] trans
 *          Specifies whether the matrix A from the bcsc is transposed, not
 *          transposed or conjugate transposed:
 *            = PastixNoTrans:   A is not transposed;
 *            = PastixTrans:     A is transposed;
 *            = PastixConjTrans: A is conjugate transposed.
 *
 * @param[in] alpha
 *          alpha specifies the scalar alpha
 *
 * @param[in] x
 *          The vector x.
 *
 * @param[in] beta
 *          beta specifies the scalar beta
 *
 * @param[inout] y
 *          The vector y.
 *
 *******************************************************************************/
void
bcsc_cspmv( const pastix_data_t      *pastix_data,
            pastix_trans_t            trans,
            pastix_complex32_t        alpha,
            const pastix_complex32_t *x,
            pastix_complex32_t        beta,
            pastix_complex32_t       *y )
{
    const pastix_int_t       *iparm  = pastix_data->iparm;
    pastix_trans_t            transA = iparm[IPARM_TRANSPOSE_SOLVE];
    const pastix_complex32_t *xglobal;

    /*
     * Combine the requested operation with iparm[IPARM_TRANSPOSE_SOLVE]:
     *
     * trans           | transA          | Final
     * ----------------+-----------------+-----------------
     * PastixNoTrans   | PastixNoTrans   | PastixNoTrans
     * PastixNoTrans   | PastixTrans     | PastixTrans
     * PastixNoTrans   | PastixConjTrans | PastixConjTrans
     * PastixTrans     | PastixNoTrans   | PastixTrans
     * PastixTrans     | PastixTrans     | PastixNoTrans
     * PastixTrans     | PastixConjTrans | INCORRECT
     * PastixConjTrans | PastixNoTrans   | PastixConjTrans
     * PastixConjTrans | PastixTrans     | INCORRECT
     * PastixConjTrans | PastixConjTrans | PastixNoTrans
     */
    if ( trans == PastixNoTrans ) {
        /* Nothing requested by the caller: transA alone decides */
        trans = transA;
    }
    else if ( transA != PastixNoTrans ) {
        /* Both are set: they either cancel each other or are incompatible */
        if ( trans != transA ) {
            pastix_print_error( "bcsc_cspmv: incompatible trans and transA" );
            return;
        }
        trans = PastixNoTrans;
    }
    /* Otherwise transA is PastixNoTrans and trans is kept as requested */

    /*
     * The vector handed to the kernels is the one returned by
     * bvec_cgather_remote(); when it is not x itself, it is released at the
     * end of this function.
     */
    xglobal = bvec_cgather_remote( pastix_data, x );

    /* Static and dynamic schedulers use the threaded version */
    switch ( iparm[IPARM_SCHEDULER] ) {
    case PastixSchedStatic:
    case PastixSchedDynamic:
        bcsc_cspmv_smp( pastix_data, trans, alpha, xglobal, beta, y );
        break;
    default:
        bcsc_cspmv_seq( pastix_data, trans, alpha, xglobal, beta, y );
        break;
    }

    if ( xglobal != x ) {
        free( (void*)xglobal );
    }
}
