Browse code

MH sampler with improved priors should lead to better convergence

Greg Finak authored on 27/01/2021 19:11:59
Showing 5 changed files

... ...
@@ -48,29 +48,30 @@
48 48
   alpha_u = array(0, dim = c(N, K));
49 49
   alpha_s = array(0, dim = c(N, K));
50 50
 
51
-  varp_s1 = array(sqrt(3), dim = c(K, 1));
51
+  varp_s1 = array(sqrt(5), dim = c(K, 1)); # variance of the proposal distribution 1 for alpha_s
52 52
   # sqrt(var)
53
-  varp_s2 = array(sqrt(10), dim = c(K, 1));
53
+  varp_s2 = array(sqrt(15), dim = c(K, 1)); #variance of the proposal distribution 2 for alpha_s
54 54
   # sqrt(var)
55
-  varp_s2[K] = sqrt(20);
56
-  varp_s1[K] = sqrt(10);
55
+  varp_s2[K] = sqrt(25);
56
+  varp_s1[K] = sqrt(15);
57
+  # alpha_s is proposed using a mixture distribution
57 58
 
58
-  pvar_s = array(0.8, dim = c(K, 1));
59
+  pvar_s = array(0.8, dim = c(K, 1)); # mixing for the mixture proposal distribution
59 60
   pvar_s[K] = 0.6;
60
-  varp_u = array(sqrt(8), dim = c(K, 1));
61
+  varp_u = array(sqrt(10), dim = c(K, 1)); #variance of the proposal distribution for alpha_u
61 62
 
62 63
   pp = array(0.65, dim = c(I, 1))
63 64
   pb1 <- clamp(1.5 / median(indi[, K]), 0, 0.9)
64 65
   pb2 <- clamp(5.0 / median(indi[, K]), 0, 0.9)
65
-  lambda_s = rep(0, K);
66
+  lambda_s = rep(0, K);  
66 67
   lambda_s[1:K1] = (10 ^ -2) * max(N_s, N_u)
67 68
   lambda_s[K] = max(N_s, N_u) - sum(lambda_s[1:K1])
68 69
   lambda_u = lambda_s
69 70
 
70
-  alpha_u[1, 1:(K - 1)] = 10
71
+  alpha_u[1, 1:(K - 1)] = 10 #initializaion 
71 72
   alpha_u[1, K] = 150
72 73
 
73
-  alpha_s[1, 1:(K - 1)] = 10
74
+  alpha_s[1, 1:(K - 1)] = 10 #initialization 
74 75
   alpha_s[1, K] = 100
75 76
 
76 77
   #################### acceptance rate ###########################
... ...
@@ -84,7 +85,8 @@
84 85
     if (tt %% 1000 == 0) vmessage("Iteration ", tt, " of ", N, ".")
85 86
 
86 87
     # update alphau
87
-    res2 <- .Call(C_updatealphau_noPu_Exp, alphaut = alpha_u[tt - 1,], n_s = n_s, n_u = n_u, I = I, K = K, lambda_u = lambda_u, var_p = varp_u, ttt = ttt, gammat = gamma[,, tt - 1])
88
+    #res2 <- .Call(C_updatealphau_noPu_Exp, alphaut = alpha_u[tt - 1,], n_s = n_s, n_u = n_u, I = I, K = K, lambda_u = lambda_u, var_p = varp_u, ttt = ttt, gammat = gamma[,, tt - 1])
89
+    res2 <- .Call(C_updatealphau_noPu_Exp_MH, alphaut = alpha_u[tt - 1,], n_s = n_s, n_u = n_u, I = I, K = K, lambda_u = lambda_u, var_p = varp_u, gammat = gamma[,, tt - 1])
88 90
     if (length(alpha_u[tt, ]) != length(res2$alphau_tt)) {
89 91
       vmessage("res2 alphau_tt length:", length(res2$alphau_tt), "\n")
90 92
       vmessage("alpha_u[tt,] length: ", length(alpha_u[tt,]), "\n")
... ...
@@ -95,6 +97,7 @@
95 97
     #update gamma
96 98
     res1 <- .Call(C_updategammak_noPu, n_s = n_s, n_u = n_u, gammat = gamma[,, tt - 1], I = I, K = K, SS = SS, alphau = alpha_u[tt,], alphas = alpha_s[tt - 1,], alpha = 1, mk = mk, Istar = Istar,
97 99
       mKstar = mKstar, pp = pp, pb1 = pb1, pb2 = pb2, indi = indi)
100
+    
98 101
     gamma[,, tt] = res1$gamma_tt;
99 102
     if (length(A_gm[, tt]) != length(res1$Ag)) {
100 103
       vmessage("res1 Ag length: ", length(res1$Ag), "\n")
... ...
@@ -107,7 +110,8 @@
107 110
     mKstar = res1$mKstar;
108 111
 
109 112
     # update alphas
110
-    res3 <- .Call(C_updatealphas_Exp, alphast = alpha_s[tt - 1,], n_s = n_s, K = K, I = I, lambda_s = lambda_s, gammat = gamma[,, tt], var_1 = varp_s1, var_2 = varp_s2, p_var = pvar_s, ttt = ttt)
113
+    # res3 <- .Call(C_updatealphas_Exp, alphast = alpha_s[tt - 1,], n_s = n_s, K = K, I = I, lambda_s = lambda_s, gammat = gamma[,, tt], var_1 = varp_s1, var_2 = varp_s2, p_var = pvar_s, ttt = ttt)
114
+    res3 <- .Call(C_updatealphas_Exp_MH, alphast = alpha_s[tt - 1,], n_s = n_s, K = K, I = I, lambda_s = lambda_s, gammat = gamma[,, tt], var_1 = varp_s1, var_2 = varp_s2, p_var = pvar_s)
111 115
     if (length(alpha_s[tt, ]) != length(res3$alphas_tt)) {
112 116
       vmessage("res3 alphas_tt length:", length(res3$alphas_tt), "\n")
113 117
       vmessage("alpha_s[tt,] length: ", length(alpha_s[tt,]), "\n")
... ...
@@ -1,4 +1,3 @@
1
-// Copyright [2014] <Fred Hutchinson Cancer Research Center>
2 1
 // This file was automatically generated by Kmisc::registerFunctions()
3 2
 
4 3
 #include <R.h>
... ...
@@ -34,6 +33,15 @@ SEXP updatealphas_Exp(SEXP alphast,
34 33
                       SEXP var_2,
35 34
                       SEXP p_var,
36 35
                       SEXP ttt);
36
+SEXP updatealphas_Exp_MH(SEXP alphast,
37
+                         SEXP n_s,
38
+                         SEXP K,
39
+                         SEXP I,
40
+                         SEXP lambda_s,
41
+                         SEXP gammat,
42
+                         SEXP var_1,
43
+                         SEXP var_2,
44
+                         SEXP p_var);
37 45
 SEXP updatealphau_noPu_Exp(SEXP alphaut,
38 46
                            SEXP n_s,
39 47
                            SEXP n_u,
... ...
@@ -43,6 +51,14 @@ SEXP updatealphau_noPu_Exp(SEXP alphaut,
43 51
                            SEXP var_p,
44 52
                            SEXP ttt,
45 53
                            SEXP gammat);
54
+SEXP updatealphau_noPu_Exp_MH(SEXP alphaut,
55
+                              SEXP n_s,
56
+                              SEXP n_u,
57
+                              SEXP I,
58
+                              SEXP K,
59
+                              SEXP lambda_u,
60
+                              SEXP var_p,
61
+                              SEXP gammat);
46 62
 SEXP updategammak_noPu(SEXP n_s,
47 63
                        SEXP n_u,
48 64
                        SEXP gammat,
... ...
@@ -69,7 +85,9 @@ R_CallMethodDef callMethods[] = {
69 85
     {"C_COMPASS_CellCounts", (DL_FUNC) & _COMPASS_CellCounts, 2},
70 86
     {"C_samplePuPs", (DL_FUNC) & samplePuPs, 9},
71 87
     {"C_updatealphas_Exp", (DL_FUNC) & updatealphas_Exp, 10},
88
+    {"C_updatealphas_Exp_MH", (DL_FUNC) & updatealphas_Exp_MH, 9},
72 89
     {"C_updatealphau_noPu_Exp", (DL_FUNC) & updatealphau_noPu_Exp, 9},
90
+    {"C_updatealphau_noPu_Exp_MH", (DL_FUNC) & updatealphau_noPu_Exp_MH, 8},
73 91
     {"C_updategammak_noPu", (DL_FUNC) & updategammak_noPu, 16},
74 92
     {NULL, NULL, 0}};
75 93
 
... ...
@@ -1,5 +1,3 @@
1
-// Copyright [2014] <Fred Hutchinson Cancer Research Center>
2
-
3 1
 // Generated by using Rcpp::compileAttributes() -> do not edit by hand
4 2
 // Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
5 3
 
... ...
@@ -10,25 +8,24 @@ using namespace Rcpp;
10 8
 // CellCounts
11 9
 IntegerMatrix CellCounts(List x, List combos);
12 10
 RcppExport SEXP _COMPASS_CellCounts(SEXP xSEXP, SEXP combosSEXP) {
13
-  BEGIN_RCPP
14
-  Rcpp::RObject rcpp_result_gen;
15
-  Rcpp::RNGScope rcpp_rngScope_gen;
16
-  Rcpp::traits::input_parameter<List>::type x(xSEXP);
17
-  Rcpp::traits::input_parameter<List>::type combos(combosSEXP);
18
-  rcpp_result_gen = Rcpp::wrap(CellCounts(x, combos));
19
-  return rcpp_result_gen;
20
-  END_RCPP
11
+BEGIN_RCPP
12
+    Rcpp::RObject rcpp_result_gen;
13
+    Rcpp::RNGScope rcpp_rngScope_gen;
14
+    Rcpp::traits::input_parameter< List >::type x(xSEXP);
15
+    Rcpp::traits::input_parameter< List >::type combos(combosSEXP);
16
+    rcpp_result_gen = Rcpp::wrap(CellCounts(x, combos));
17
+    return rcpp_result_gen;
18
+END_RCPP
21 19
 }
22 20
 // CellCounts_character
23 21
 IntegerMatrix CellCounts_character(List data, List combinations);
24
-RcppExport SEXP _COMPASS_CellCounts_character(SEXP dataSEXP,
25
-                                              SEXP combinationsSEXP) {
26
-  BEGIN_RCPP
27
-  Rcpp::RObject rcpp_result_gen;
28
-  Rcpp::RNGScope rcpp_rngScope_gen;
29
-  Rcpp::traits::input_parameter<List>::type data(dataSEXP);
30
-  Rcpp::traits::input_parameter<List>::type combinations(combinationsSEXP);
31
-  rcpp_result_gen = Rcpp::wrap(CellCounts_character(data, combinations));
32
-  return rcpp_result_gen;
33
-  END_RCPP
22
+RcppExport SEXP _COMPASS_CellCounts_character(SEXP dataSEXP, SEXP combinationsSEXP) {
23
+BEGIN_RCPP
24
+    Rcpp::RObject rcpp_result_gen;
25
+    Rcpp::RNGScope rcpp_rngScope_gen;
26
+    Rcpp::traits::input_parameter< List >::type data(dataSEXP);
27
+    Rcpp::traits::input_parameter< List >::type combinations(combinationsSEXP);
28
+    rcpp_result_gen = Rcpp::wrap(CellCounts_character(data, combinations));
29
+    return rcpp_result_gen;
30
+END_RCPP
34 31
 }
35 32
new file mode 100644
... ...
@@ -0,0 +1,183 @@
1
+#include <Rcpp.h>
2
+
3
+extern "C" double digamma(double);
4
+
5
+// [[register]]
6
+RcppExport SEXP updatealphas_Exp_MH(SEXP alphast, SEXP n_s, SEXP K, SEXP I,
7
+                                    SEXP lambda_s, SEXP gammat, SEXP var_1,
8
+                                    SEXP var_2, SEXP p_var) {
9
+  BEGIN_RCPP
10
+  Rcpp::NumericVector xalphast(alphast);
11
+
12
+  Rcpp::IntegerMatrix xn_s(n_s);
13
+  Rcpp::IntegerMatrix xgammat(gammat);
14
+  int xI = Rcpp::as<int>(I);
15
+  int xK = Rcpp::as<int>(K);
16
+  Rcpp::NumericVector sqrt_var1(var_1);
17
+  Rcpp::NumericVector sqrt_var2(var_2);
18
+  // int xtt = Rcpp::as<int>(ttt);
19
+  Rcpp::NumericVector xlambda_s(lambda_s);
20
+  Rcpp::IntegerVector xAalphas(xK);
21
+
22
+  Rcpp::RNGScope scope;
23
+  Rcpp::NumericVector xp_var(p_var);  // proposal mixture
24
+
25
+  // double delF = 0.0;
26
+  double psik = 0.;
27
+  double log1 = 0.0;
28
+  double log2 = 0.0;
29
+  double sums = 0.;
30
+  double sum_alp_ns = 0.0;
31
+  double sum_alp = 0.0;
32
+  double sum_gl_alp = 0.0;
33
+  double sum_gl_alp_ns = 0.0;
34
+  int flag1 = 0;
35
+  int flagkk = 0;
36
+  int lp1 = 0;
37
+  for (int kk = 0; kk < xK; kk++) {
38
+    // delF = 0.0;
39
+    psik = digamma(xalphast[kk]);
40
+    log1 = 0.0;
41
+    log2 = 0.0;
42
+    for (int i = 0; i < xI; i++) {
43
+      lp1 = 0;
44
+      for (int k = 0; k < xK; k++) {
45
+        if (xgammat(i, k) == 1) {
46
+          lp1 += 1;
47
+        }
48
+      }
49
+
50
+      std::vector<int> p1(lp1);
51
+      flag1 = 0;
52
+      flagkk = 0;
53
+      for (int k = 0; k < xK; k++) {
54
+        if (xgammat(i, k) == 1) {
55
+          p1[flag1] = k;
56
+          flag1 += 1;
57
+          if (k == kk) {
58
+            flagkk = 1;
59
+          }
60
+        }
61
+      }
62
+      sum_alp_ns = 0.0;
63
+      sum_alp = 0.0;
64
+      sum_gl_alp = 0.0;
65
+      sum_gl_alp_ns = 0.0;
66
+      for (int k = 0; k < lp1; k++) {
67
+        sums = xalphast[p1[k]] + xn_s(i, p1[k]);
68
+        sum_alp_ns += sums;
69
+        sum_alp += xalphast[p1[k]];
70
+        sum_gl_alp += lgamma(xalphast[p1[k]]);
71
+        sum_gl_alp_ns += lgamma(sums);
72
+      }
73
+      // if (flagkk > 0) {
74
+      // delF += digamma(xn_s(i, kk) + xalphast[kk]) - psik -
75
+      //  digamma(sum_alp_ns) + digamma(sum_alp);
76
+      //}
77
+      if (lp1 > 0) {
78
+        log2 += -(sum_gl_alp - lgamma(sum_alp)) +
79
+                (sum_gl_alp_ns - lgamma(sum_alp_ns));
80
+      }
81
+    }
82
+    // double mean_p = std::max(0.01, xalphast[kk] + delF / xtt);
83
+    Rcpp::NumericVector alpha_s_p = Rcpp::rnorm(1, xalphast[kk], sqrt_var1[kk]);
84
+
85
+    if (Rcpp::as<double>(Rcpp::rbinom(1, 1, xp_var[kk])) == 1) {
86
+      alpha_s_p = Rcpp::rnorm(1, xalphast[kk], sqrt_var1[kk]);
87
+    } else {
88
+      alpha_s_p = Rcpp::rnorm(1, xalphast[kk], sqrt_var2[kk]);
89
+    }
90
+
91
+    if (alpha_s_p[0] > 0.0) {
92
+      std::vector<double> alp(xK);
93
+
94
+      for (int i = 0; i < xK; i++) {
95
+        alp[i] = xalphast[i];
96
+      }
97
+      alp[kk] = alpha_s_p[0];
98
+      // log2 += log(xp_var[kk]*gsl_ran_gaussian_pdf(alp[kk]-mean_p,
99
+      // sqrt_var1[kk])+(1-xp_var[kk])*gsl_ran_gaussian_pdf(alp[kk]-mean_p,
100
+      // sqrt_var2[kk]));
101
+      log2 +=
102
+          log(xp_var[kk] * Rf_dnorm4(alp[kk], xalphast[kk], sqrt_var1[kk], 0) +
103
+              (1 - xp_var[kk]) *
104
+                  Rf_dnorm4(alp[kk], xalphast[kk], sqrt_var2[kk], 0));
105
+      // delF = 0.0;
106
+      psik = digamma(alp[kk]);
107
+      for (int i = 0; i < xI; i++) {
108
+        lp1 = 0;
109
+        for (int k = 0; k < xK; k++) {
110
+          if (xgammat(i, k) == 1) {
111
+            lp1 += 1;
112
+          }
113
+        }
114
+
115
+        std::vector<int> p1(lp1);
116
+        flag1 = 0;
117
+        flagkk = 0;
118
+        for (int k = 0; k < xK; k++) {
119
+          if (xgammat(i, k) == 1) {
120
+            p1[flag1] = k;
121
+            flag1 += 1;
122
+            if (k == kk) {
123
+              flagkk = 1;
124
+            }
125
+          }
126
+        }
127
+
128
+        sum_alp_ns = 0.0;
129
+        sum_alp = 0.0;
130
+        sum_gl_alp = 0.0;
131
+        sum_gl_alp_ns = 0.0;
132
+        for (int k = 0; k < lp1; k++) {
133
+          sums = alp[p1[k]] + xn_s(i, p1[k]);
134
+          sum_alp_ns += sums;
135
+          sum_alp += alp[p1[k]];
136
+          sum_gl_alp += lgamma(alp[p1[k]]);
137
+          sum_gl_alp_ns += lgamma(sums);
138
+        }
139
+        // if (flagkk > 0) {
140
+        //   delF += digamma(xn_s(i, kk) + xalphast[kk]) - psik -
141
+        //     digamma(sum_alp_ns) + digamma(sum_alp);
142
+        // }
143
+        if (lp1 > 0) {
144
+          log1 += -(sum_gl_alp - lgamma(sum_alp)) +
145
+                  (sum_gl_alp_ns - lgamma(sum_alp_ns));
146
+        }
147
+      }
148
+      // mean_p = std::max(0.01, alp[kk] + delF / xtt);
149
+      // log1 +=log(xp_var[kk]*gsl_ran_gaussian_pdf(xalphast[kk]-mean_p,
150
+      // sqrt_var1[kk])+(1-xp_var[kk])*gsl_ran_gaussian_pdf(xalphast[kk]-mean_p,
151
+      // sqrt_var2[kk]));
152
+      log1 +=
153
+          log(xp_var[kk] * Rf_dnorm4(xalphast[kk], alp[kk], sqrt_var1[kk], 0) +
154
+              (1 - xp_var[kk]) *
155
+                  Rf_dnorm4(xalphast[kk], alp[kk], sqrt_var2[kk], 0));
156
+
157
+      // log1 += log(gsl_ran_exponential_pdf(alp[kk],xlambda_s[kk]));
158
+      // //exponential prior
159
+      log1 += Rf_dexp(alp[kk], xlambda_s[kk], 1);
160
+
161
+      // log2 +=
162
+      // log(gsl_ran_exponential_pdf(xalphast[kk],xlambda_s[kk]));//exponential
163
+      // prior
164
+      log2 += Rf_dexp(xalphast[kk], xlambda_s[kk], 1);
165
+
166
+      // if (alp[kk]<0 || alp[kk]>xlambda_s[kk]) {log1+=log(0);} //Uniform prior
167
+      // if (xalphast[kk]<0 || xalphast[kk]>xlambda_s[kk]) {log2+=log(0);}
168
+      // //Uniform prior
169
+
170
+      if (log(Rcpp::as<double>(Rcpp::runif(1))) <= (log1 - log2)) {
171
+        xalphast[kk] = alp[kk];
172
+        xAalphas[kk] = 1;
173
+      } else {
174
+        xAalphas[kk] = 0;
175
+      }
176
+    }
177
+  }
178
+
179
+  return Rcpp::List::create(Rcpp::Named("alphas_tt") = xalphast,
180
+                            Rcpp::Named("Aalphas") = xAalphas);
181
+
182
+  END_RCPP
183
+}
0 184
new file mode 100644
... ...
@@ -0,0 +1,209 @@
1
+#include <Rcpp.h>
2
+extern "C" double digamma(double);
3
+
4
+// Standard MH algorithm for updating alpha_u
5
+
6
+// [[register]]
7
+RcppExport SEXP updatealphau_noPu_Exp_MH(SEXP alphaut, SEXP n_s, SEXP n_u,
8
+                                         SEXP I, SEXP K, SEXP lambda_u,
9
+                                         SEXP var_p, SEXP gammat) {
10
+  BEGIN_RCPP
11
+  Rcpp::IntegerMatrix xgammat(gammat);
12
+  Rcpp::NumericVector xalphaut(alphaut);
13
+  Rcpp::IntegerMatrix xn_s(n_s);
14
+  Rcpp::IntegerMatrix xn_u(n_u);
15
+  int xI = Rcpp::as<int>(I);
16
+  int xK = Rcpp::as<int>(K);
17
+  Rcpp::NumericVector sqrt_var(var_p);
18
+  // int xtt = Rcpp::as<int>(ttt);
19
+  Rcpp::NumericVector xlambda_u(lambda_u);
20
+  Rcpp::IntegerVector xAalphau(xK);
21
+  Rcpp::RNGScope scope;
22
+
23
+  // double delF = 0.0;
24
+  double log1 = 0.0;
25
+  double log2 = 0.0;
26
+  double sum_alphau = 0.0;
27
+  int flag1 = 0;
28
+  int flag0 = 0;
29
+  int flagkk = 0;
30
+  int lp0 = 0;
31
+  int lp1 = 0;
32
+  double sum_nusalphau = 0.0;
33
+  double sum_nualphau = 0.0;
34
+  double sums = 0.;
35
+  for (int kk = 0; kk < xK; kk++) {
36
+    // delF = 0.0;
37
+    log1 = 0.0;
38
+    log2 = 0.0;
39
+    sum_alphau = 0.0;
40
+    for (int s = 0; s < xK; s++) {
41
+      sum_alphau += xalphaut[s];
42
+    }
43
+    log2 -= xI * lgamma(xalphaut[kk]);
44
+    // delF += xI * (digamma(sum_alphau) - digamma(xalphaut[kk]));
45
+    log2 += xI * lgamma(sum_alphau);
46
+    for (int i = 0; i < xI; i++) {
47
+      lp1 = 0;
48
+      for (int k = 0; k < xK; k++) {
49
+        if (xgammat(i, k) == 1) {
50
+          lp1 += 1;
51
+        }
52
+      }
53
+      lp0 = xK - lp1;
54
+      std::vector<int> p1(lp1);
55
+      flag1 = 0;
56
+      std::vector<int> p0(lp0);
57
+      flag0 = 0;
58
+      flagkk = 0;  // whether gamma_k = 1
59
+
60
+      for (int k = 0; k < xK; k++) {
61
+        if (xgammat(i, k) == 1) {
62
+          p1[flag1] = k;
63
+          flag1 += 1;
64
+          if (k == kk) {
65
+            flagkk = 1;
66
+          }
67
+        } else {
68
+          p0[flag0] = k;
69
+          flag0 += 1;
70
+        }
71
+      }
72
+      if (flagkk == 1) {
73
+        log2 += lgamma(xn_u(i, kk) + xalphaut[kk]);
74
+        // delF += digamma(xn_u(i, kk) + xalphaut[kk]);
75
+        sum_nualphau = 0.0;
76
+        sum_nusalphau = 0.0;
77
+        for (int k = 0; k < lp1; k++) {
78
+          sums = xn_u(i, p1[k]) + xalphaut[p1[k]];
79
+          sum_nualphau += sums;
80
+          sum_nusalphau += (sums + xn_s(i, p1[k]));
81
+        }
82
+        log2 -= lgamma(sum_nualphau);
83
+        log2 += lgamma(sum_nusalphau + 1);
84
+        // delF -= digamma(sum_nualphau);
85
+        // delF += digamma(sum_nusalphau + 1);
86
+
87
+        for (int k = 0; k < lp0; k++) {
88
+          sum_nusalphau += (xn_u(i, p0[k]) + xalphaut[p0[k]] + xn_s(i, p0[k]));
89
+        }
90
+        // delF -= digamma(sum_nusalphau + 1);
91
+        log2 -= lgamma(sum_nusalphau + 1);
92
+      } else {
93
+        log2 += lgamma(xn_u(i, kk) + xalphaut[kk] + xn_s(i, kk));
94
+        // delF += digamma(xn_u(i, kk) + xalphaut[kk] + xn_s(i, kk));
95
+        sum_nusalphau = 0.0;
96
+        for (int k = 0; k < xK; k++) {
97
+          sum_nusalphau += xn_u(i, k) + xalphaut[k] + xn_s(i, k);
98
+        }
99
+        log2 -= lgamma(sum_nusalphau + 1);
100
+        // delF -= digamma(sum_nusalphau + 1);
101
+      }
102
+    }
103
+    // double mean_p = std::max(0.01, xalphaut[kk] + delF / xtt);
104
+    Rcpp::NumericVector alpha_u_p = Rcpp::rnorm(1, xalphaut[kk], sqrt_var[kk]);
105
+    if (alpha_u_p[0] > 0.0) {
106
+      std::vector<double> alp(xK);
107
+      for (int i = 0; i < xK; i++) {
108
+        alp[i] = xalphaut[i];
109
+      }
110
+      alp[kk] = alpha_u_p[0];
111
+
112
+      // log2 += log(gsl_ran_gaussian_pdf(alp[kk]-mean_p, sqrt_var[kk]));
113
+      log2 += Rf_dnorm4(alp[kk], xalphaut[kk], sqrt_var[kk], 1);
114
+
115
+      // delF = 0.0;
116
+      sum_alphau = 0.0;
117
+      for (int s = 0; s < xK; s++) {
118
+        sum_alphau += alp[s];
119
+      }
120
+      log1 -= xI * lgamma(alp[kk]);
121
+      // delF += xI * (digamma(sum_alphau) - digamma(alp[kk]));
122
+      log1 += xI * lgamma(sum_alphau);
123
+      for (int i = 0; i < xI; i++) {
124
+        lp1 = 0;
125
+        for (int k = 0; k < xK; k++) {
126
+          if (xgammat(i, k) == 1) {
127
+            lp1 += 1;
128
+          }
129
+        }
130
+        lp0 = xK - lp1;
131
+        std::vector<int> p1(lp1);
132
+        flag1 = 0;
133
+        std::vector<int> p0(lp0);
134
+        flag0 = 0;
135
+        flagkk = 0;  // whether gamma_k = 1
136
+
137
+        for (int k = 0; k < xK; k++) {
138
+          if (xgammat(i, k) == 1) {
139
+            p1[flag1] = k;
140
+            flag1 += 1;
141
+            if (k == kk) {
142
+              flagkk = 1;
143
+            }
144
+          } else {
145
+            p0[flag0] = k;
146
+            flag0 += 1;
147
+          }
148
+        }
149
+        if (flagkk == 1) {
150
+          log1 += lgamma(xn_u(i, kk) + alp[kk]);
151
+          // delF += digamma(xn_u(i, kk) + alp[kk]);
152
+          sum_nualphau = 0.0;
153
+          sum_nusalphau = 0.0;
154
+          for (int k = 0; k < lp1; k++) {
155
+            sums = xn_u(i, p1[k]) + alp[p1[k]];
156
+            sum_nualphau += sums;
157
+            sum_nusalphau += (sums + xn_s(i, p1[k]));
158
+          }
159
+          log1 -= lgamma(sum_nualphau);
160
+          log1 += lgamma(sum_nusalphau + 1);
161
+          // delF -= digamma(sum_nualphau);
162
+          // delF += digamma(sum_nusalphau + 1);
163
+
164
+          for (int k = 0; k < lp0; k++) {
165
+            sum_nusalphau += (xn_u(i, p0[k]) + alp[p0[k]] + xn_s(i, p0[k]));
166
+          }
167
+          // delF -= digamma(sum_nusalphau + 1);
168
+          log1 -= lgamma(sum_nusalphau + 1);
169
+        } else {
170
+          log1 += lgamma(xn_u(i, kk) + alp[kk] + xn_s(i, kk));
171
+          // delF += digamma(xn_u(i, kk) + alp[kk] + xn_s(i, kk));
172
+          sum_nusalphau = 0.0;
173
+          for (int k = 0; k < xK; k++) {
174
+            sum_nusalphau += xn_u(i, k) + alp[k] + xn_s(i, k);
175
+          }
176
+          log1 -= lgamma(sum_nusalphau + 1);
177
+          // delF -= digamma(sum_nusalphau + 1);
178
+        }
179
+      }
180
+      // mean_p = std::max(0.01, alp[kk] + delF / xtt);
181
+
182
+      // log1 +=log(gsl_ran_gaussian_pdf(xalphaut[kk]-mean_p, sqrt_var[kk]));
183
+      log1 += Rf_dnorm4(xalphaut[kk], alp[kk], sqrt_var[kk], 1);
184
+
185
+      // log1 += log(gsl_ran_exponential_pdf(alp[kk],xlambda_u[kk]));
186
+      // //exponential prior
187
+      log1 += Rf_dexp(alp[kk], xlambda_u[kk], 1);
188
+
189
+      // log2 += log(gsl_ran_exponential_pdf(xalphaut[kk],xlambda_u[kk]));
190
+      // //exponential prior
191
+      log2 += Rf_dexp(xalphaut[kk], xlambda_u[kk], 1);
192
+
193
+      // if (alp[kk]<0 || alp[kk]>xlambda_u[kk]) {log1+=log(0);} //Uniform prior
194
+      // if (xalphaut[kk]<0 || xalphaut[kk]>xlambda_u[kk]) {log2+=log(0);}
195
+      // //Uniform prior
196
+
197
+      if (log(Rcpp::as<double>(Rcpp::runif(1))) <= (log1 - log2)) {
198
+        xalphaut[kk] = alp[kk];
199
+        xAalphau[kk] = 1;
200
+      } else {
201
+        xAalphau[kk] = 0;
202
+      }
203
+    }
204
+  }
205
+  return Rcpp::List::create(Rcpp::Named("alphau_tt") = xalphaut,
206
+                            Rcpp::Named("Aalphau") = xAalphau);
207
+
208
+  END_RCPP
209
+}