Skip to content

Commit

Permalink
improved progress.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
qddyy committed Dec 12, 2024
1 parent 7b97cf6 commit 7ade19c
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 142 deletions.
6 changes: 2 additions & 4 deletions R/pmt.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ define_pmt <- function(
} else {
impl <- paste0("impl_", inherit, "_pmt")
cppFunction(
env = environment(super$.calculate_statistic),
depends = c(depends, "LearnNonparam"),
plugins = {
cpp_standard_ver <- evalCpp("__cplusplus")
Expand All @@ -226,7 +225,7 @@ define_pmt <- function(
hpps <- c("progress", "reorder", impl)
c(includes, paste0("#include<pmt/", hpps, ".hpp>"))
},
code = {
env = environment(super$.calculate_statistic), code = {
args <- paste0(
"arg", 1:(n <- if (inherit == "rcbd") 2 else 3)
)
Expand All @@ -236,8 +235,7 @@ define_pmt <- function(
", double n_permu, bool progress){",
"auto statistic = ", statistic, ";",
"return progress ?", paste0(
impl, "<PermuBar", c("Show", "Hide"), ">(",
paste0(
impl, "<", c("true", "false"), ">(", paste(
"clone(", args[-n], ")", collapse = ","
), ", statistic, n_permu )", collapse = ":"
), ";}"
Expand Down
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_association_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
template <typename T, typename U>
NumericVector impl_association_pmt(
template <bool progress, typename T>
RObject impl_association_pmt(
NumericVector x,
NumericVector y,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
T bar;
Stat<progress> statistic_container;

auto statistic_closure = statistic_func(x, y);
auto association_update = [x, y, &statistic_closure, &bar]() {
return bar << statistic_closure(x, y);
auto association_update = [x, y, &statistic_closure, &statistic_container]() {
return statistic_container << statistic_closure(x, y);
};

bar.init_statistic(association_update);
statistic_container.init_statistic(association_update);

if (!std::isnan(n_permu)) {
if (n_permu == 0) {
Expand All @@ -21,19 +21,19 @@ NumericVector impl_association_pmt(

NumericVector y_ = (n_permutation(x) < n_permutation(y)) ? x : y;

bar.init_statistic_permu(n_permutation(y_));
statistic_container.init_statistic_permu(n_permutation(y_));

do {
association_update();
} while (next_permutation(y_));
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
random_shuffle(y);
} while (association_update());
}
}

return bar.close();
return statistic_container.close();
}
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_ksample_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
template <typename T, typename U>
NumericVector impl_ksample_pmt(
template <bool progress, typename T>
RObject impl_ksample_pmt(
const NumericVector data,
IntegerVector group,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
T bar;
Stat<progress> statistic_container;

auto statistic_closure = statistic_func(data, group);
auto ksample_update = [data, group, &statistic_closure, &bar]() {
return bar << statistic_closure(data, group);
auto ksample_update = [data, group, &statistic_closure, &statistic_container]() {
return statistic_container << statistic_closure(data, group);
};

bar.init_statistic(ksample_update);
statistic_container.init_statistic(ksample_update);

if (!std::isnan(n_permu)) {
if (n_permu == 0) {
bar.init_statistic_permu(n_permutation(group));
statistic_container.init_statistic_permu(n_permutation(group));

do {
ksample_update();
} while (next_permutation(group));
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
random_shuffle(group);
} while (ksample_update());
}
}

return bar.close();
return statistic_container.close();
}
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_multcomp_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,45 +1,45 @@
template <typename T, typename U>
NumericVector impl_multcomp_pmt(
template <bool progress, typename T>
RObject impl_multcomp_pmt(
const IntegerVector group_i,
const IntegerVector group_j,
const NumericVector data,
IntegerVector group,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
R_len_t n_group = group[group.size() - 1];
R_len_t n_pair = n_group * (n_group - 1) / 2;

T bar(n_pair);
Stat<progress> statistic_container(n_pair);

auto multcomp_update = [group_i, group_j, data, group, n_pair, &statistic_func, &bar]() {
auto multcomp_update = [group_i, group_j, data, group, n_pair, &statistic_func, &statistic_container]() {
auto statistic_closure = statistic_func(data, group);

bool flag = false;
for (R_len_t k = 0; k < n_pair; k++) {
flag = bar << statistic_closure(group_i[k], group_j[k]);
flag = statistic_container << statistic_closure(group_i[k], group_j[k]);
};

return flag;
};

bar.init_statistic(multcomp_update);
statistic_container.init_statistic(multcomp_update);

if (!std::isnan(n_permu)) {
if (n_permu == 0) {
bar.init_statistic_permu(n_permutation(group));
statistic_container.init_statistic_permu(n_permutation(group));

do {
multcomp_update();
} while (next_permutation(group));
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
random_shuffle(group);
} while (multcomp_update());
}
}

return bar.close();
return statistic_container.close();
}
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_paired_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
template <typename T, typename U>
NumericVector impl_paired_pmt(
template <bool progress, typename T>
RObject impl_paired_pmt(
NumericVector x,
NumericVector y,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
T bar;
Stat<progress> statistic_container;

auto statistic_closure = statistic_func(x, y);
auto paired_update = [x, y, &statistic_closure, &bar]() {
return bar << statistic_closure(x, y);
auto paired_update = [x, y, &statistic_closure, &statistic_container]() {
return statistic_container << statistic_closure(x, y);
};

bar.init_statistic(paired_update);
statistic_container.init_statistic(paired_update);

if (!std::isnan(n_permu)) {
R_len_t i = 0;
R_len_t n = x.size();

if (n_permu == 0) {
bar.init_statistic_permu(1 << n);
statistic_container.init_statistic_permu(1 << n);

IntegerVector swapped(n, 0);
while (i < n) {
Expand All @@ -37,7 +37,7 @@ NumericVector impl_paired_pmt(
}
}
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
for (i = 0; i < n; i++) {
Expand All @@ -49,5 +49,5 @@ NumericVector impl_paired_pmt(
}
}

return bar.close();
return statistic_container.close();
}
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_rcbd_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
template <typename T, typename U>
NumericVector impl_rcbd_pmt(
template <bool progress, typename T>
RObject impl_rcbd_pmt(
NumericMatrix data,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
T bar;
Stat<progress> statistic_container;

auto statistic_closure = statistic_func(data);
auto rcbd_update = [data, &statistic_closure, &bar]() {
return bar << statistic_closure(data);
auto rcbd_update = [data, &statistic_closure, &statistic_container]() {
return statistic_container << statistic_closure(data);
};

bar.init_statistic(rcbd_update);
statistic_container.init_statistic(rcbd_update);

if (!std::isnan(n_permu)) {
R_len_t i;
Expand All @@ -24,7 +24,7 @@ NumericVector impl_rcbd_pmt(
total *= n_permutation(data.column(i));
}

bar.init_statistic_permu(total);
statistic_container.init_statistic_permu(total);

i = 0;
while (i < b) {
Expand All @@ -35,7 +35,7 @@ NumericVector impl_rcbd_pmt(
i = next_permutation(data.column(i)) ? 0 : i + 1;
}
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
for (i = 0; i < b; i++) {
Expand All @@ -45,5 +45,5 @@ NumericVector impl_rcbd_pmt(
}
}

return bar.close();
return statistic_container.close();
}
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_table_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
template <typename T, typename U>
NumericVector impl_table_pmt(
template <bool progress, typename T>
RObject impl_table_pmt(
IntegerVector row,
IntegerVector col,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
T bar;
Stat<progress> statistic_container;

R_len_t n = row.size();

Expand All @@ -20,31 +20,31 @@ NumericVector impl_table_pmt(
};

auto statistic_closure = statistic_func(data_filled());
auto table_update = [&data_filled, &statistic_closure, &bar]() {
return bar << statistic_closure(data_filled());
auto table_update = [&data_filled, &statistic_closure, &statistic_container]() {
return statistic_container << statistic_closure(data_filled());
};

bar.init_statistic(table_update);
statistic_container.init_statistic(table_update);

if (!std::isnan(n_permu)) {
if (n_permu == 0) {
std::sort(row.begin(), row.end());

IntegerVector col_ = (n_permutation(row) < n_permutation(col)) ? row : col;

bar.init_statistic_permu(n_permutation(col_));
statistic_container.init_statistic_permu(n_permutation(col_));

do {
table_update();
} while (next_permutation(col_));
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
random_shuffle(col);
} while (table_update());
}
}

return bar.close();
return statistic_container.close();
}
20 changes: 10 additions & 10 deletions inst/include/pmt/impl_twosample_pmt.hpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
template <typename T, typename U>
NumericVector impl_twosample_pmt(
template <bool progress, typename T>
RObject impl_twosample_pmt(
NumericVector x,
NumericVector y,
const U& statistic_func,
const T& statistic_func,
const double n_permu)
{
T bar;
Stat<progress> statistic_container;

auto statistic_closure = statistic_func(x, y);
auto twosample_update = [x, y, &statistic_closure, &bar]() {
return bar << statistic_closure(x, y);
auto twosample_update = [x, y, &statistic_closure, &statistic_container]() {
return statistic_container << statistic_closure(x, y);
};

bar.init_statistic(twosample_update);
statistic_container.init_statistic(twosample_update);

if (!std::isnan(n_permu)) {
NumericVector x_ = x.size() < y.size() ? x : y;
Expand All @@ -28,7 +28,7 @@ NumericVector impl_twosample_pmt(
for (i = m; i < n; i++) {
p[i] = 1;
}
bar.init_statistic_permu(n_permutation(p));
statistic_container.init_statistic_permu(n_permutation(p));

for (i = 0; i < n; i++) {
p[i] = i;
Expand Down Expand Up @@ -97,7 +97,7 @@ NumericVector impl_twosample_pmt(
}
}
} else {
bar.init_statistic_permu(n_permu);
statistic_container.init_statistic_permu(n_permu);

do {
for (i = 0; i < m; i++) {
Expand All @@ -110,5 +110,5 @@ NumericVector impl_twosample_pmt(
}
}

return bar.close();
return statistic_container.close();
}
Loading

0 comments on commit 7ade19c

Please sign in to comment.