From 7ade19cb7490c3a6fb7d9654e04c8013d2ce0c40 Mon Sep 17 00:00:00 2001 From: qddyy Date: Thu, 12 Dec 2024 17:45:55 +0800 Subject: [PATCH] improved `progress.hpp` --- R/pmt.R | 6 +- inst/include/pmt/impl_association_pmt.hpp | 20 ++-- inst/include/pmt/impl_ksample_pmt.hpp | 20 ++-- inst/include/pmt/impl_multcomp_pmt.hpp | 20 ++-- inst/include/pmt/impl_paired_pmt.hpp | 20 ++-- inst/include/pmt/impl_rcbd_pmt.hpp | 20 ++-- inst/include/pmt/impl_table_pmt.hpp | 20 ++-- inst/include/pmt/impl_twosample_pmt.hpp | 20 ++-- inst/include/pmt/progress.hpp | 109 +++++++++++----------- src/pmt_interface.cpp | 28 +++--- 10 files changed, 141 insertions(+), 142 deletions(-) diff --git a/R/pmt.R b/R/pmt.R index 380502ec..fa5c8f12 100644 --- a/R/pmt.R +++ b/R/pmt.R @@ -216,7 +216,6 @@ define_pmt <- function( } else { impl <- paste0("impl_", inherit, "_pmt") cppFunction( - env = environment(super$.calculate_statistic), depends = c(depends, "LearnNonparam"), plugins = { cpp_standard_ver <- evalCpp("__cplusplus") @@ -226,7 +225,7 @@ define_pmt <- function( hpps <- c("progress", "reorder", impl) c(includes, paste0("#include")) }, - code = { + env = environment(super$.calculate_statistic), code = { args <- paste0( "arg", 1:(n <- if (inherit == "rcbd") 2 else 3) ) @@ -236,8 +235,7 @@ define_pmt <- function( ", double n_permu, bool progress){", "auto statistic = ", statistic, ";", "return progress ?", paste0( - impl, "(", - paste0( + impl, "<", c("true", "false"), ">(", paste( "clone(", args[-n], ")", collapse = "," ), ", statistic, n_permu )", collapse = ":" ), ";}" diff --git a/inst/include/pmt/impl_association_pmt.hpp b/inst/include/pmt/impl_association_pmt.hpp index c0c8148e..f22eb84b 100644 --- a/inst/include/pmt/impl_association_pmt.hpp +++ b/inst/include/pmt/impl_association_pmt.hpp @@ -1,18 +1,18 @@ -template -NumericVector impl_association_pmt( +template +RObject impl_association_pmt( NumericVector x, NumericVector y, - const U& statistic_func, + const T& statistic_func, const double n_permu) { - T bar; + Stat statistic_container; auto statistic_closure = statistic_func(x, y); - auto association_update = [x, y, &statistic_closure, &bar]() { - return bar << statistic_closure(x, y); + auto association_update = [x, y, &statistic_closure, &statistic_container]() { + return statistic_container << statistic_closure(x, y); }; - bar.init_statistic(association_update); + statistic_container.init_statistic(association_update); if (!std::isnan(n_permu)) { if (n_permu == 0) { @@ -21,13 +21,13 @@ NumericVector impl_association_pmt( NumericVector y_ = (n_permutation(x) < n_permutation(y)) ? x : y; - bar.init_statistic_permu(n_permutation(y_)); + statistic_container.init_statistic_permu(n_permutation(y_)); do { association_update(); } while (next_permutation(y_)); } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { random_shuffle(y); @@ -35,5 +35,5 @@ NumericVector impl_association_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/impl_ksample_pmt.hpp b/inst/include/pmt/impl_ksample_pmt.hpp index 16f9a71b..a5f7cc06 100644 --- a/inst/include/pmt/impl_ksample_pmt.hpp +++ b/inst/include/pmt/impl_ksample_pmt.hpp @@ -1,28 +1,28 @@ -template -NumericVector impl_ksample_pmt( +template +RObject impl_ksample_pmt( const NumericVector data, IntegerVector group, - const U& statistic_func, + const T& statistic_func, const double n_permu) { - T bar; + Stat statistic_container; auto statistic_closure = statistic_func(data, group); - auto ksample_update = [data, group, &statistic_closure, &bar]() { - return bar << statistic_closure(data, group); + auto ksample_update = [data, group, &statistic_closure, &statistic_container]() { + return statistic_container << statistic_closure(data, group); }; - bar.init_statistic(ksample_update); + statistic_container.init_statistic(ksample_update); if (!std::isnan(n_permu)) { if (n_permu == 0) { - bar.init_statistic_permu(n_permutation(group)); + statistic_container.init_statistic_permu(n_permutation(group)); do { ksample_update(); } while (next_permutation(group)); } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { random_shuffle(group); @@ -30,5 +30,5 @@ NumericVector impl_ksample_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/impl_multcomp_pmt.hpp b/inst/include/pmt/impl_multcomp_pmt.hpp index 9e1b679f..5bf5e5af 100644 --- a/inst/include/pmt/impl_multcomp_pmt.hpp +++ b/inst/include/pmt/impl_multcomp_pmt.hpp @@ -1,39 +1,39 @@ -template -NumericVector impl_multcomp_pmt( +template +RObject impl_multcomp_pmt( const IntegerVector group_i, const IntegerVector group_j, const NumericVector data, IntegerVector group, - const U& statistic_func, + const T& statistic_func, const double n_permu) { R_len_t n_group = group[group.size() - 1]; R_len_t n_pair = n_group * (n_group - 1) / 2; - T bar(n_pair); + Stat statistic_container(n_pair); - auto multcomp_update = [group_i, group_j, data, group, n_pair, &statistic_func, &bar]() { + auto multcomp_update = [group_i, group_j, data, group, n_pair, &statistic_func, &statistic_container]() { auto statistic_closure = statistic_func(data, group); bool flag = false; for (R_len_t k = 0; k < n_pair; k++) { - flag = bar << statistic_closure(group_i[k], group_j[k]); + flag = statistic_container << statistic_closure(group_i[k], group_j[k]); }; return flag; }; - bar.init_statistic(multcomp_update); + statistic_container.init_statistic(multcomp_update); if (!std::isnan(n_permu)) { if (n_permu == 0) { - bar.init_statistic_permu(n_permutation(group)); + statistic_container.init_statistic_permu(n_permutation(group)); do { multcomp_update(); } while (next_permutation(group)); } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { random_shuffle(group); @@ -41,5 +41,5 @@ NumericVector impl_multcomp_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/impl_paired_pmt.hpp b/inst/include/pmt/impl_paired_pmt.hpp index 07daae81..049c079d 100644 --- a/inst/include/pmt/impl_paired_pmt.hpp +++ b/inst/include/pmt/impl_paired_pmt.hpp @@ -1,25 +1,25 @@ -template -NumericVector impl_paired_pmt( +template +RObject impl_paired_pmt( NumericVector x, NumericVector y, - const U& statistic_func, + const T& statistic_func, const double n_permu) { - T bar; + Stat statistic_container; auto statistic_closure = statistic_func(x, y); - auto paired_update = [x, y, &statistic_closure, &bar]() { - return bar << statistic_closure(x, y); + auto paired_update = [x, y, &statistic_closure, &statistic_container]() { + return statistic_container << statistic_closure(x, y); }; - bar.init_statistic(paired_update); + statistic_container.init_statistic(paired_update); if (!std::isnan(n_permu)) { R_len_t i = 0; R_len_t n = x.size(); if (n_permu == 0) { - bar.init_statistic_permu(1 << n); + statistic_container.init_statistic_permu(1 << n); IntegerVector swapped(n, 0); while (i < n) { @@ -37,7 +37,7 @@ NumericVector impl_paired_pmt( } } } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { for (i = 0; i < n; i++) { @@ -49,5 +49,5 @@ NumericVector impl_paired_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/impl_rcbd_pmt.hpp b/inst/include/pmt/impl_rcbd_pmt.hpp index bfe981e8..258fcd4f 100644 --- a/inst/include/pmt/impl_rcbd_pmt.hpp +++ b/inst/include/pmt/impl_rcbd_pmt.hpp @@ -1,17 +1,17 @@ -template -NumericVector impl_rcbd_pmt( +template +RObject impl_rcbd_pmt( NumericMatrix data, - const U& statistic_func, + const T& statistic_func, const double n_permu) { - T bar; + Stat statistic_container; auto statistic_closure = statistic_func(data); - auto rcbd_update = [data, &statistic_closure, &bar]() { - return bar << statistic_closure(data); + auto rcbd_update = [data, &statistic_closure, &statistic_container]() { + return statistic_container << statistic_closure(data); }; - bar.init_statistic(rcbd_update); + statistic_container.init_statistic(rcbd_update); if (!std::isnan(n_permu)) { R_len_t i; @@ -24,7 +24,7 @@ NumericVector impl_rcbd_pmt( total *= n_permutation(data.column(i)); } - bar.init_statistic_permu(total); + statistic_container.init_statistic_permu(total); i = 0; while (i < b) { @@ -35,7 +35,7 @@ NumericVector impl_rcbd_pmt( i = next_permutation(data.column(i)) ? 0 : i + 1; } } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { for (i = 0; i < b; i++) { @@ -45,5 +45,5 @@ NumericVector impl_rcbd_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/impl_table_pmt.hpp b/inst/include/pmt/impl_table_pmt.hpp index a04309f7..d58513f3 100644 --- a/inst/include/pmt/impl_table_pmt.hpp +++ b/inst/include/pmt/impl_table_pmt.hpp @@ -1,11 +1,11 @@ -template -NumericVector impl_table_pmt( +template +RObject impl_table_pmt( IntegerVector row, IntegerVector col, - const U& statistic_func, + const T& statistic_func, const double n_permu) { - T bar; + Stat statistic_container; R_len_t n = row.size(); @@ -20,11 +20,11 @@ NumericVector impl_table_pmt( }; auto statistic_closure = statistic_func(data_filled()); - auto table_update = [&data_filled, &statistic_closure, &bar]() { - return bar << statistic_closure(data_filled()); + auto table_update = [&data_filled, &statistic_closure, &statistic_container]() { + return statistic_container << statistic_closure(data_filled()); }; - bar.init_statistic(table_update); + statistic_container.init_statistic(table_update); if (!std::isnan(n_permu)) { if (n_permu == 0) { @@ -32,13 +32,13 @@ NumericVector impl_table_pmt( IntegerVector col_ = (n_permutation(row) < n_permutation(col)) ? row : col; - bar.init_statistic_permu(n_permutation(col_)); + statistic_container.init_statistic_permu(n_permutation(col_)); do { table_update(); } while (next_permutation(col_)); } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { random_shuffle(col); @@ -46,5 +46,5 @@ NumericVector impl_table_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/impl_twosample_pmt.hpp b/inst/include/pmt/impl_twosample_pmt.hpp index 91642d33..f34146c7 100644 --- a/inst/include/pmt/impl_twosample_pmt.hpp +++ b/inst/include/pmt/impl_twosample_pmt.hpp @@ -1,18 +1,18 @@ -template -NumericVector impl_twosample_pmt( +template +RObject impl_twosample_pmt( NumericVector x, NumericVector y, - const U& statistic_func, + const T& statistic_func, const double n_permu) { - T bar; + Stat statistic_container; auto statistic_closure = statistic_func(x, y); - auto twosample_update = [x, y, &statistic_closure, &bar]() { - return bar << statistic_closure(x, y); + auto twosample_update = [x, y, &statistic_closure, &statistic_container]() { + return statistic_container << statistic_closure(x, y); }; - bar.init_statistic(twosample_update); + statistic_container.init_statistic(twosample_update); if (!std::isnan(n_permu)) { NumericVector x_ = x.size() < y.size() ? x : y; @@ -28,7 +28,7 @@ NumericVector impl_twosample_pmt( for (i = m; i < n; i++) { p[i] = 1; } - bar.init_statistic_permu(n_permutation(p)); + statistic_container.init_statistic_permu(n_permutation(p)); for (i = 0; i < n; i++) { p[i] = i; @@ -97,7 +97,7 @@ NumericVector impl_twosample_pmt( } } } else { - bar.init_statistic_permu(n_permu); + statistic_container.init_statistic_permu(n_permu); do { for (i = 0; i < m; i++) { @@ -110,5 +110,5 @@ NumericVector impl_twosample_pmt( } } - return bar.close(); + return statistic_container.close(); } \ No newline at end of file diff --git a/inst/include/pmt/progress.hpp b/inst/include/pmt/progress.hpp index a2d465fa..4ae3abb3 100644 --- a/inst/include/pmt/progress.hpp +++ b/inst/include/pmt/progress.hpp @@ -32,49 +32,64 @@ constexpr std::array generate_bars(std::integer_se constexpr auto generated_bars = generate_bars(std::make_integer_sequence()); -class PermuBarHide { +template +class Stat { public: - PermuBarHide(const R_len_t statistic_size = 1) : + Stat(R_len_t statistic_size = 1) : + _progress_i(0), + _progress_every(2), _statistic_size(statistic_size) { } template - void init_statistic(const T& update_bar) + void init_statistic(T& update) { _init_statistic_buffer(_statistic_size, 1); - update_bar(); + update(); _statistic = _statistic_buffer; _statistic_buffer = NumericVector(0); } - void init_statistic_permu(const double n_permu) + void init_statistic_permu(double n_permu) { + _init_progress(); + _init_statistic_buffer(n_permu, _statistic_size); } - bool operator<<(const double statistic) + bool operator<<(double statistic) { + _update_progress(); + _statistic_buffer[_buffer_i++] = statistic; return _buffer_i != _buffer_size; } - NumericVector close() + RObject close() { + _clear_progress(); + _statistic.attr("permu") = _statistic_buffer; return _statistic; - } + }; private: - const R_len_t _statistic_size; - - NumericVector _statistic; + RObject _statistic; NumericVector _statistic_buffer; - void _init_statistic_buffer(const double n, const R_len_t size) + R_xlen_t _buffer_i; + R_xlen_t _buffer_size; + + R_xlen_t _progress_i; + R_xlen_t _progress_every; + + R_len_t _statistic_size; + + void _init_statistic_buffer(double n, R_len_t size) { double total = n * size; if (total <= 0 || total > R_XLEN_T_MAX) { @@ -91,55 +106,41 @@ class PermuBarHide { } } -protected: - R_xlen_t _buffer_i; - R_xlen_t _buffer_size; + void _init_progress(); + void _update_progress(); + void _clear_progress(); }; -class PermuBarShow : public PermuBarHide { -public: - template - PermuBarShow(Args&&... args) : - PermuBarHide(std::forward(args)...), - _show_i(0), - _show_every(2) { } - - template - auto init_statistic_permu(Args&&... args) - { - PermuBarHide::init_statistic_permu(std::forward(args)...); - - _show_i = 0; - _show_every = (_buffer_size < 100) ? 1 : _buffer_size / 100; +template <> +void Stat::_init_progress() { } - _show(); - } +template <> +void Stat::_init_progress() +{ + _progress_i = 0; + _progress_every = (_buffer_size < 100) ? 1 : _buffer_size / 100; - template - auto operator<<(Args&&... args) - { - if (++_show_i == _show_every) { - _show_i = 0; - _show(); - } + Rcout << generated_bars[0].data(); +} - return PermuBarHide::operator<<(std::forward(args)...); - } +template <> +void Stat::_update_progress() { } - template - auto close(Args&&... args) - { - Rcout << "\015\033[K\033[0m"; +template <> +void Stat::_update_progress() +{ + if (++_progress_i == _progress_every) { + _progress_i = 0; - return PermuBarHide::close(std::forward(args)...); + Rcout << generated_bars[static_cast(100 * _buffer_i / _buffer_size)].data(); } +} -private: - R_xlen_t _show_i; - R_xlen_t _show_every; +template <> +void Stat::_clear_progress() { } - void _show() const - { - Rcout << generated_bars[static_cast(100 * _buffer_i / _buffer_size)].data(); - } -}; \ No newline at end of file +template <> +void Stat::_clear_progress() +{ + Rcout << "\015\033[K\033[0m"; +} \ No newline at end of file diff --git a/src/pmt_interface.cpp b/src/pmt_interface.cpp index 5786188a..e62b3d5b 100644 --- a/src/pmt_interface.cpp +++ b/src/pmt_interface.cpp @@ -47,8 +47,8 @@ SEXP twosample_pmt( const bool progress) { return progress ? - impl_twosample_pmt>(clone(x), clone(y), statistic_func, n_permu) : - impl_twosample_pmt>(clone(x), clone(y), statistic_func, n_permu); + impl_twosample_pmt>(clone(x), clone(y), statistic_func, n_permu) : + impl_twosample_pmt>(clone(x), clone(y), statistic_func, n_permu); } #include "pmt/impl_ksample_pmt.hpp" @@ -62,8 +62,8 @@ SEXP ksample_pmt( const bool progress) { return progress ? - impl_ksample_pmt>(data, clone(group), statistic_func, n_permu) : - impl_ksample_pmt>(data, clone(group), statistic_func, n_permu); + impl_ksample_pmt>(data, clone(group), statistic_func, n_permu) : + impl_ksample_pmt>(data, clone(group), statistic_func, n_permu); } #include "pmt/impl_multcomp_pmt.hpp" @@ -79,8 +79,8 @@ SEXP multcomp_pmt( const bool progress) { return progress ? - impl_multcomp_pmt>(group_i, group_j, data, clone(group), statistic_func, n_permu) : - impl_multcomp_pmt>(group_i, group_j, data, clone(group), statistic_func, n_permu); + impl_multcomp_pmt>(group_i, group_j, data, clone(group), statistic_func, n_permu) : + impl_multcomp_pmt>(group_i, group_j, data, clone(group), statistic_func, n_permu); } #include "pmt/impl_paired_pmt.hpp" @@ -94,8 +94,8 @@ SEXP paired_pmt( const bool progress) { return progress ? - impl_paired_pmt>(clone(x), clone(y), statistic_func, n_permu) : - impl_paired_pmt>(clone(x), clone(y), statistic_func, n_permu); + impl_paired_pmt>(clone(x), clone(y), statistic_func, n_permu) : + impl_paired_pmt>(clone(x), clone(y), statistic_func, n_permu); } #include "pmt/impl_rcbd_pmt.hpp" @@ -108,8 +108,8 @@ SEXP rcbd_pmt( const bool progress) { return progress ? - impl_rcbd_pmt>(clone(data), statistic_func, n_permu) : - impl_rcbd_pmt>(clone(data), statistic_func, n_permu); + impl_rcbd_pmt>(clone(data), statistic_func, n_permu) : + impl_rcbd_pmt>(clone(data), statistic_func, n_permu); } #include "pmt/impl_association_pmt.hpp" @@ -123,8 +123,8 @@ SEXP association_pmt( const bool progress) { return progress ? - impl_association_pmt>(clone(x), clone(y), statistic_func, n_permu) : - impl_association_pmt>(clone(x), clone(y), statistic_func, n_permu); + impl_association_pmt>(clone(x), clone(y), statistic_func, n_permu) : + impl_association_pmt>(clone(x), clone(y), statistic_func, n_permu); } #include "pmt/impl_table_pmt.hpp" @@ -138,6 +138,6 @@ SEXP table_pmt( const bool progress) { return progress ? - impl_table_pmt>(clone(row), clone(col), statistic_func, n_permu) : - impl_table_pmt>(clone(row), clone(col), statistic_func, n_permu); + impl_table_pmt>(clone(row), clone(col), statistic_func, n_permu) : + impl_table_pmt>(clone(row), clone(col), statistic_func, n_permu); } \ No newline at end of file