#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <random>
#include "Sampler.h"
#include "Constants.h"
using namespace std;

/**
 * This method samples amino acid dihedral angle conformations for the given amino acid
 * @param nsamples Number of samples to generate
 * @param mixture Associated amino acid dihedral angle mixture model
 * @param angleCount Total dihedral angles of the amino acid
 * @param gen The Pseudo Random Number Generator
 *
 * @return A vector of dihedral angle vectors
 *
*/

template<typename PRNG>
vector<vector<float>>
Sampler::sampleJointVonMisesMixture(const size_t &nsamples, const vector<JointVonMisesCircular> &mixture,
                                    const size_t &angleCount, PRNG &gen) {

    const size_t numComponents = mixture.size();
    vector<double> componentProbabilities(numComponents);
    for (int i = 0; i < numComponents; i++) {
        componentProbabilities[i] = mixture[i].getWeight();
    }
    vector<double> cdf = sumProbabilityList(componentProbabilities);


    vector<vector<float>> sampledData(nsamples, vector<float>(angleCount));

    for (int i = 0; i < nsamples; i++) {
        int componentIndex = rouletteWheelSelection(cdf, gen);
        auto &jointVonMisesCircular = mixture[componentIndex];
        for (int j = 0; j < angleCount; j++) {
            auto &voMisesCircular = jointVonMisesCircular.getVonMisesCircular()[j];
            sampledData[i][j] = (float) sampleVonMisesDistribution(voMisesCircular.getComponentMean(),
                                                                   voMisesCircular.getConcentration(), gen);
        }
    }
    return sampledData;
}

/**
 * This method samples amino acid sidechain conformations given backbone conformations
 * @param mixture Associated amino acid dihedral angle mixture model
 * @param angleCount Total dihedral angles of the amino acid
 * @param backboneDihedralAngles <phi,psi> backbone dihedral angles
 * @param gen The Pseudo Random Number Generator
 *
 * @return A vector of dihedral angle vectors
 *
*/

template<typename PRNG>
vector<vector<float>>
Sampler::sampleConditionalSideChain(const vector<JointVonMisesCircular> &mixture, const size_t &angleCount,
                                    const size_t &nsamples,
                                    const vector<float> &backbonePair, PRNG &gen) {

    size_t numComponents = mixture.size();

    vector<vector<long double>> normalizingReciprocalVector(numComponents);
    for (int k = 0; k < numComponents; k++) {
        normalizingReciprocalVector[k] = {
                getNormalizingReciprocal(mixture[k].getVonMisesCircular()[0].getConcentration()),
                getNormalizingReciprocal(mixture[k].getVonMisesCircular()[1].getConcentration())};
    }

    vector<long double> phiPsiDensities(numComponents);
    vector<vector<float>> sampledData(nsamples, vector<float>(angleCount));

    // recalculate the probabilities of models
    for (int i = 0; i < numComponents; i++) {
        phiPsiDensities[i] = calculateBackboneDensity(mixture[i], backbonePair, normalizingReciprocalVector[i]);
    }

    vector<long double> cdf = sumProbabilityList(phiPsiDensities);

    if (!cdf.empty()) {
        for (int i = 0; i < nsamples; i++) {
            int componentIndex = rouletteWheelSelection(cdf, gen);
            auto &jointVonMisesCircular = mixture[componentIndex];

            for (int j = 0; j < angleCount; j++) {
                auto &voMisesCircular = jointVonMisesCircular.getVonMisesCircular()[j + 2];
                auto value = (float) sampleVonMisesDistribution(voMisesCircular.getComponentMean(),
                                                                voMisesCircular.getConcentration(), gen);

                sampledData[i][j]= value;
            }
        }
    }
    return sampledData;
}

/**
 * This method performs a Roulette wheel selection for the provided cumulative probability distribution.
 * @param cdf The non-decreasing cumulative probability vector (e.g. from sumProbabilityList); must not be empty
 * @param gen The Pseudo Random Number Generator
 *
 * @return the smallest index i with u <= cdf[i] for a uniform random number u in [0,1),
 *         or the last index when u exceeds every entry (e.g. rounding leaves cdf.back() slightly below 1)
 *
*/
template<typename PRNG, typename T>
int Sampler::rouletteWheelSelection(const vector<T> &cdf, PRNG &gen) {
    // For an empty cdf, cdf.size() - 1 wraps around and binarySearch would read out of bounds;
    // callers are expected to check for the empty result of sumProbabilityList first.
    assert(!cdf.empty());
    uniform_real_distribution<> dis(0, 1.0);
    const T randomNum = dis(gen);
    return binarySearch(randomNum, cdf, 0, static_cast<int>(cdf.size()) - 1);
}

/**
 * This method performs a binary search and returns the index the provided value falls into.
 * @param value The value to be located
 * @param probabilities The cumulative (non-decreasing) probability vector
 * @param startIndex The first index of the @param probabilities range to search
 * @param endIndex The last index (inclusive) of the @param probabilities range; must be >= startIndex
 *
 * @return the smallest index i in [startIndex, endIndex) with value <= probabilities[i],
 *         or endIndex if there is none
 *
*/
template<typename T>
int Sampler::binarySearch(T value, const vector<T> &probabilities, int startIndex, int endIndex) {
    // An empty/inverted range has no valid answer (the former recursive version never terminated).
    assert(startIndex <= endIndex);
    int low = startIndex;
    int high = endIndex;
    while (low < high) {
        // Lower middle of [low, high]; never reads past high.
        const int mid = low + (high - low) / 2;
        if (value > probabilities[mid]) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }
    return low;
}

/**
 * Normalizes a vector of non-negative weights and returns its running (cumulative) sum.
 * @param probabilities The probability (or unnormalized weight) vector
 *
 * @return cumulative probability vector of the same length, or an empty vector when the
 *         weights sum to zero (including when the input is empty)
 *
*/
template<typename T>
vector<T> Sampler::sumProbabilityList(const vector<T> &probabilities) {
    T total = 0;
    for (const T &weight : probabilities) {
        total += weight;
    }
    if (total == 0) {
        return {}; // every component has zero probability
    }

    vector<T> cumulative;
    cumulative.reserve(probabilities.size());
    T running = 0;
    for (const T &weight : probabilities) {
        running += weight / total; // normalize, then accumulate
        cumulative.push_back(running);
    }
    return cumulative;
}


/**
 * This method calculate the associated probability density of the given backbone dihedral angles
 * @param component component of the mixture model
 * @param backbonePair <phi,psi> angle vector
 * @param normalizedReciprocal the reciprocal of the normalizing constant
 *
 * @return joint probability density
 *
*/
long double Sampler::calculateBackboneDensity(const JointVonMisesCircular &component, const vector<float> &backbonePair,
                                              const vector<long double> &normalizedReciprocal) {
    const VonMisesCircular &phiVonMises = component.getVonMisesCircular()[0];
    const VonMisesCircular &psiVonMises = component.getVonMisesCircular()[1];
    float phiRad = backbonePair[0];
    float psiRad = backbonePair[1];

    long double phiPdf = getPdfEvaluation(phiRad, phiVonMises.getComponentMean(), phiVonMises.getConcentration(),
                                          normalizedReciprocal[0]);
    long double psiPdf = getPdfEvaluation(psiRad, psiVonMises.getComponentMean(), psiVonMises.getConcentration(),
                                          normalizedReciprocal[1]);
    long double jointPdf = phiPdf * psiPdf * component.getWeight();

    assert(jointPdf >= 0);
    return jointPdf;
}

/**
 * Evaluates the von Mises density exp(concentration * cos(angle - mean)) * reciprocal.
 * @param angle point at which the density is evaluated
 * @param mean mean direction of the distribution
 * @param concentration concentration (kappa) of the distribution
 * @param reciprocal the reciprocal of the normalizing constant
 *
 * @return von Mises probability density (asserted non-negative)
 *
*/
long double Sampler::getPdfEvaluation(const float &angle, const float &mean, const float &concentration,
                                      const long double &reciprocal) {
    // The exponent is formed from the float inputs, then exponentiated in long double.
    const long double density = reciprocal * exp(static_cast<long double>(concentration * cos(angle - mean)));
    assert(density >= 0);
    return density;
}

/**
 * Computes 1 / (2 * pi * I0(concentration)), the reciprocal of the von Mises normalizing constant,
 * where I0 is the modified Bessel function of the first kind of order zero.
 * @param concentration of the von Mises distribution
 *
 * @return reciprocal of the normalizing constant (asserted positive)
 *
*/
long double Sampler::getNormalizingReciprocal(const float &concentration) {
    const long double besselI0 = cyl_bessel_i(0, static_cast<long double>(concentration));
    const long double reciprocal = 1.0 / (2 * M_PI * besselI0);
    assert(reciprocal > 0);
    return reciprocal;
}

/**
 * Draws one angle from a von Mises distribution with the Best-Fisher accept/reject algorithm
 * (wrapped Cauchy envelope).
 *
 * @param mean Mean direction (presumably radians within [-pi, pi]; the result is wrapped only once)
 * @param kappa Concentration. |kappa| < 1e-8 is treated as the uniform circular distribution;
 *              a negative kappa behaves like the original algorithm (mass concentrated opposite the mean).
 * @param gen The Pseudo Random Number Generator
 *
 * @return sampled angle. Calls exit(1) if Constants::MAX_ITERATIONS trials are all rejected.
 *
 * NOTE: the envelope parameter is computed in double precision, so for a fixed seed the drawn values
 * may differ from builds that computed it in float. The accept/reject test is exact for any envelope
 * parameter r with |r| > 1, so the target distribution is unchanged.
 *
 * References:
 *  [1] D. J. Best and N. I. Fisher, Efficient Simulation of the von Mises Distribution, Applied Statistics, 28, 2, 152--157, (1979)
 *  [2]  N. I. Fisher, Statistical analysis of circular data, Cambridge University Press, (1993).
*/
template<typename PRNG>
double Sampler::sampleVonMisesDistribution(float mean, float kappa, PRNG &gen) {
    constexpr double kUniformKappa = 1e-8; // below this the distribution is numerically uniform
    constexpr double kTaylorKappa = 1e-5;  // below this a - sqrt(2a) loses all precision

    uniform_real_distribution<> dis(0, 1.0);

    // Work in double: in float, 1 + 4 * kappa * kappa rounds to exactly 1 for kappa below ~1.2e-4,
    // which makes b = 0, r = inf and every trial NaN (the sampler then exits the program).
    const double k = kappa;

    if (fabs(k) < kUniformKappa) {
        return M_PI * (2 * dis(gen) - 1);
    }

    // Envelope parameter r is loop-invariant, so compute it once.
    double r;
    if (fabs(k) < kTaylorKappa) {
        // Second-order Taylor expansion of r around kappa = 0.
        r = 1 / k + k;
    } else {
        const double a = 1 + sqrt(1 + 4 * k * k);
        const double b = (a - sqrt(2 * a)) / (2 * k);
        r = (1 + b * b) / (2 * b);
    }

    bool accept = false;
    double theta = 0;
    int iterations = 0;

    do {
        const double u_1 = dis(gen);
        const double u_2 = dis(gen);
        const double u_3 = dis(gen);

        const double z = cos(M_PI * u_1);
        const double f = (1 + r * z) / (r + z);
        const double c = k * (r - f);

        // Quick acceptance test first, exact test (u_2 <= c * exp(1 - c)) second.
        accept = (c * (2 - c) - u_2 > 0.0) || (log(c / u_2) + 1 - c >= 0.0);

        if (accept) {
            const double x = u_3 - 0.5;
            const int sign = (x > 0) ? 1 : ((x < 0) ? -1 : 0);
            theta = sign * acos(f) + mean;
        }
        iterations++;
    } while (!accept && (iterations < Constants::MAX_ITERATIONS));

    if (!accept) {
        cerr << "Exceed max accept/reject sampling iterations" << endl;
        exit(1);
    }

    // Single wrap back into [-pi, pi]; sufficient when mean lies in [-pi, pi].
    if (theta > M_PI) {
        return -2 * M_PI + theta;
    } else if (theta < -M_PI) {
        return 2 * M_PI + theta;
    }
    return theta;
}

// Explicit instantiations of the public sampling templates for the supported generators.
// The template definitions live in this source file, so other translation units presumably
// rely on these instantiations (TODO confirm against Sampler.h).

// std::mt19937
template vector<vector<float>>
Sampler::sampleJointVonMisesMixture<mt19937>(const size_t &, const vector<JointVonMisesCircular> &,
                                             const size_t &, mt19937 &);

template vector<vector<float>>
Sampler::sampleConditionalSideChain<mt19937>(const vector<JointVonMisesCircular> &,
                                             const size_t &, const size_t &, const vector<float> &, mt19937 &);

// std::minstd_rand0
template vector<vector<float>>
Sampler::sampleJointVonMisesMixture<minstd_rand0>(const size_t &, const vector<JointVonMisesCircular> &,
                                                  const size_t &, minstd_rand0 &);

template vector<vector<float>>
Sampler::sampleConditionalSideChain<minstd_rand0>(const vector<JointVonMisesCircular> &,
                                                  const size_t &, const size_t &, const vector<float> &,
                                                  minstd_rand0 &);