#include <Rcpp.h>

extern "C" double digamma(double);

using namespace std;
using namespace Rcpp;

void updatealphau(vector<double>& xalphaut,
                  vector<int>& xn_s,
                  vector<int>& xn_u,
                  int xI,
                  int xK,
                  vector<double>& xlambda_u,
                  vector<double>& sqrt_var,
                  int xtt,
                  vector<int>& xgammat,
                  vector<int>& xAalphau) {

  double delF = 0.0;
  double log1 = 0.0;
  double log2 = 0.0;
  double sum_alphau = 0.0;
  int flag1 = 0;
  int flag0 = 0;
  int flagkk = 0;
  int temp = 0;
  for (int kk = 0; kk < xK; kk++) {
    delF = 0.0;
    log1 = 0.0;
    log2 = 0.0;
    sum_alphau = 0.0;
    for (int s = 0; s < xK; s++) {
      sum_alphau += xalphaut[s];
    }
    log2 -= xI * lgamma(xalphaut[kk]);
    delF += xI * (digamma(sum_alphau) - digamma(xalphaut[kk]));
    log2 += xI * lgamma(sum_alphau);
    for (int i = 0; i < xI; i++) {
      int lp1 = 0;
      for (int k = 0; k < xK; k++) {
        if (xgammat[xI * k + i] == 1) {
          lp1 += 1;
        }
      }
      int lp0 = xK - lp1;
      vector<int> p1(lp1);
      flag1 = 0;
      vector<int> p0(lp0);
      flag0 = 0;
      flagkk = 0;  // whether gamma_k = 1

      for (int k = 0; k < xK; k++) {
        if (xgammat[xI * k + i] == 1) {
          p1[flag1] = k;
          flag1 += 1;
          if (k == kk) {
            flagkk = 1;
          }
        } else {
          p0[flag0] = k;
          flag0 += 1;
        }
      }
      if (flagkk == 1) {
        log2 += lgamma(xn_u[i + xI * kk] + xalphaut[kk]);
        delF += digamma(xn_u[i + xI * kk] + xalphaut[kk]);
        double sum_nualphau = 0.0;
        double sum_nusalphau = 0.0;
        for (int k = 0; k < lp1; k++) {
          temp = i + xI * p1[k];
          double sum = xn_u[temp] + xalphaut[p1[k]];
          sum_nualphau += sum;
          sum_nusalphau += (sum + xn_s[temp]);
        }
        log2 -= lgamma(sum_nualphau);
        log2 += lgamma(sum_nusalphau + 1);
        delF -= digamma(sum_nualphau);
        delF += digamma(sum_nusalphau + 1);

        for (int k = 0; k < lp0; k++) {
          temp = i + xI * p0[k];
          sum_nusalphau += (xn_u[temp] + xalphaut[p0[k]] + xn_s[temp]);
        }
        delF -= digamma(sum_nusalphau + 1);
        log2 -= lgamma(sum_nusalphau + 1);
      } else {
        log2 += lgamma(xn_u[i + xI * kk] + xalphaut[kk] + xn_s[i + xI * kk]);
        delF += digamma(xn_u[i + xI * kk] + xalphaut[kk] + xn_s[i + kk * xI]);
        double sum_nusalphau = 0.0;
        for (int k = 0; k < xK; k++) {
          sum_nusalphau += xn_u[i + xI * k] + xalphaut[k] + xn_s[i + xI * k];
        }
        log2 -= lgamma(sum_nusalphau + 1);
        delF -= digamma(sum_nusalphau + 1);
      }
    }
    double mean_p = std::max(0.01, xalphaut[kk] + delF / xtt);
    Rcpp::NumericVector alpha_u_p = Rcpp::rnorm(1, mean_p, sqrt_var[kk]);

    if (alpha_u_p[0] > 0.0 && alpha_u_p[0] <= xlambda_u[kk]) {
      vector<double> alp(xK);
      for (int i = 0; i < xK; i++) {
        alp[i] = xalphaut[i];
      }
      alp[kk] = alpha_u_p[0];
      // log2 += log(gsl_ran_gaussian_pdf(alp[kk]-mean_p, sqrt_var[kk]));
      log2 += Rf_dnorm4(alp[kk], mean_p, sqrt_var[kk], 1);

      delF = 0.0;
      sum_alphau = 0.0;
      for (int s = 0; s < xK; s++) {
        sum_alphau += alp[s];
      }
      log1 -= xI * lgamma(alp[kk]);
      delF += xI * (digamma(sum_alphau) - digamma(alp[kk]));
      log1 += xI * lgamma(sum_alphau);
      for (int i = 0; i < xI; i++) {
        int lp1 = 0;
        for (int k = 0; k < xK; k++) {
          if (xgammat[xI * k + i] == 1) {
            lp1 += 1;
          }
        }
        int lp0 = xK - lp1;
        vector<int> p1(lp1);
        flag1 = 0;
        vector<int> p0(lp0);
        flag0 = 0;
        flagkk = 0;  // whether gamma_k = 1

        for (int k = 0; k < xK; k++) {
          if (xgammat[xI * k + i] == 1) {
            p1[flag1] = k;
            flag1 += 1;
            if (k == kk) {
              flagkk = 1;
            }
          } else {
            p0[flag0] = k;
            flag0 += 1;
          }
        }
        if (flagkk == 1) {
          log1 += lgamma(xn_u[i + xI * kk] + alp[kk]);
          delF += digamma(xn_u[i + xI * kk] + alp[kk]);
          double sum_nualphau = 0.0;
          double sum_nusalphau = 0.0;
          for (int k = 0; k < lp1; k++) {
            temp = i + xI * p1[k];
            double sum = xn_u[temp] + alp[p1[k]];
            sum_nualphau += sum;
            sum_nusalphau += (sum + xn_s[temp]);
          }
          log1 -= lgamma(sum_nualphau);
          log1 += lgamma(sum_nusalphau + 1);
          delF -= digamma(sum_nualphau);
          delF += digamma(sum_nusalphau + 1);

          for (int k = 0; k < lp0; k++) {
            sum_nusalphau +=
                (xn_u[i + xI * p0[k]] + alp[p0[k]] + xn_s[i + xI * p0[k]]);
          }
          delF -= digamma(sum_nusalphau + 1);
          log1 -= lgamma(sum_nusalphau + 1);
        } else {
          log1 += lgamma(xn_u[i + xI * kk] + alp[kk] + xn_s[i + xI * kk]);
          delF += digamma(xn_u[i + xI * kk] + alp[kk] + xn_s[i + xI * kk]);
          double sum_nusalphau = 0.0;
          for (int k = 0; k < xK; k++) {
            temp = i + xI * k;
            sum_nusalphau += xn_u[temp] + alp[k] + xn_s[temp];
          }
          log1 -= lgamma(sum_nusalphau + 1);
          delF -= digamma(sum_nusalphau + 1);
        }
      }
      mean_p = std::max(0.01, alp[kk] + delF / xtt);
      // log1 +=log(gsl_ran_gaussian_pdf(xalphaut[kk]-mean_p, sqrt_var[kk]));
      log1 += Rf_dnorm4(xalphaut[kk], mean_p, sqrt_var[kk], 1);

      if (log(Rcpp::as<double>(Rcpp::runif(1))) <= (log1 - log2)) {
        xalphaut[kk] = alp[kk];
        xAalphau[kk] = 1;
      }
    }
  }
}