Skip to content

Commit

Permalink
some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
privefl committed Dec 8, 2023
1 parent b2c5c29 commit d16e7b2
Show file tree
Hide file tree
Showing 8 changed files with 630 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tmp-tests/test-glasso2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Exploratory script: compares graphical-lasso results from
# glassoFast::glassoFast() with the glasso() function compiled from
# tmp-tests/test-glasso2.cpp, on a sub-block of a genotype covariance matrix.
library(bigsnpr)

# Attach the chr22 data and copy the genotypes with the decoding table
# c(0, 1, 2, 0, NA x 252).
chr22 <- snp_attach("../Dubois2010_data/celiac_chr22.rds")
G <- chr22$genotypes$copy(code = c(0, 1, 2, 0, rep(NA, 252)))
dim(G)
# Covariance between all columns of G, run through runonce::save_run() with a
# result file. Note that the name `cov` now refers to this matrix below.
cov <- runonce::save_run(cov(G[]), file = "tmp-data/cov_chr22.rds")

# Focal column `i` and the columns whose squared covariance with it is > 0.001.
# NOTE(review): setdiff() removes index 1 here, not `i`; if cov[i, i]^2 > 0.001
# then `i` is also in `ind` and appears twice in `id_sub` -- confirm intent.
i <- 3000
ind <- setdiff(which(cov[, i] ** 2 > 0.001), 1)


# Dense sub-block of the covariance for the selected columns (focal one first).
id_sub <- c(i, ind)
cov_sub <- as.matrix(cov[id_sub, id_sub])
rho <- 0.0001
# Reference fit with glassoFast; `$w` and `$wi` are used below.
glasso0 <- glassoFast::glassoFast(cov_sub, rho = rho)
str(glasso0)
# W0.2 keeps `$w` as returned; W0 is `$w` with `rho` subtracted from its diagonal.
W0 <- W0.2 <- glasso0$w; diag(W0) <- diag(W0) - rho
all.equal(W0, cov_sub) # Mean relative difference: 0.07337141
plot(W0, cov_sub); abline(0, 1, col = "red")
W0[1:5, 1:5]
cov_sub[1:5, 1:5]
# Is `$wi` an inverse of cov_sub? Also compare it with the inverse of
# cov_sub + rho * I.
all.equal(glasso0$wi %*% cov_sub, diag(ncol(cov_sub)))
inv <- solve(cov_sub + rho * diag(ncol(cov_sub)))
all.equal(glasso0$wi, inv)
glasso0$wi[1:5, 1:5]
inv[1:5, 1:5]
plot(diag(glasso0$wi), diag(inv)); abline(0, 1, col = "red")

# Same comparisons with a 10x larger penalty / diagonal shift.
inv2 <- solve(cov_sub + 10 * rho * diag(ncol(cov_sub)))
plot(diag(inv2), diag(inv)); abline(0, 1, col = "red")
glasso0.2 <- glassoFast::glassoFast(cov_sub, rho = 10 * rho)
plot(diag(glasso0.2$wi), diag(glasso0$wi)); abline(0, 1, col = "red")

# Diagonal of `$wi` when the same amount is also added to the input diagonal.
diag(glassoFast::glassoFast(cov_sub + rho * diag(ncol(cov_sub)), rho = rho)$wi)
diag(glassoFast::glassoFast(cov_sub + 10 * rho * diag(ncol(cov_sub)), rho = 10 * rho)$wi)


# How far are the shifted inverses from being inverses of cov_sub itself?
all.equal(inv %*% cov_sub, diag(ncol(cov_sub)))
all.equal(inv2 %*% cov_sub, diag(ncol(cov_sub)))
# Refit (same call as above), invert `$wi`, and compare the result with
# cov_sub, W0 and W0.2.
glasso0 <- glassoFast::glassoFast(cov_sub, rho = rho)
inv_inv <- solve(glasso0$wi)
plot(inv_inv, cov_sub); abline(0, 1, col = "red")
hist(cov_sub - inv_inv)
hist(W0 - inv_inv)
hist(W0.2 - inv_inv)
inv_inv[1:5, 1:5]
W0.2[1:5, 1:5]
# Run glassoFast on the two inverse matrices themselves and check whether the
# resulting `$wi` gets back to cov_sub; also count its exact zeros.
glasso_inv <- glassoFast::glassoFast(glasso0$wi, rho = 10 * rho)
all.equal(glasso_inv$wi, cov_sub)
sum(glasso_inv$wi == 0)
glasso_inv2 <- glassoFast::glassoFast(inv, rho = 10 * rho)
all.equal(glasso_inv2$wi, cov_sub)
sum(glasso_inv2$wi == 0)

# Compile the C++ implementation and run it with the same penalty
# (positional 200, 200 = maxiter_outer, maxiter_lasso).
Rcpp::sourceCpp("tmp-tests/test-glasso2.cpp")
glasso2 <- glasso(as.matrix(cov_sub), lambda = rho, 200, 200, tol = 1e-4, verbose = TRUE)
# First element: the W matrix returned by glasso().
W <- glasso2[[1]]
all.equal(W, cov_sub) # 0.07335971
all.equal(W, W0) # 1e-4

# Second element: the coefficient matrix (`beta` in the C++ code).
X <- glasso2[[2]]

plot(X, glasso0$wi)
# Candidate for the diagonal of the precision matrix, built from X and W and
# compared with diag(glasso0$wi).
# NOTE(review): the `1 + rho` term presumes a unit diagonal of the input;
# verify that this is appropriate for a covariance (not correlation) block.
tmp <- 1 / (1 + rho - colSums(X * (W + rho * diag(ncol(W)))))
plot(tmp, diag(glasso0$wi))
all.equal(tmp, diag(glasso0$wi)) # 0.0002222643
# Scale column j of X by -tmp[j], put tmp on the diagonal, then symmetrize.
X2 <- sweep(X, 2, -tmp, '*'); diag(X2) <- tmp
X3 <- (X2 + t(X2)) / 2
all.equal(X3, glasso0$wi) # 0.001191556
plot(X3, glasso0$wi); abline(0, 1, col = "red")
all.equal(X3 %*% cov_sub, diag(ncol(cov_sub))) # 0.841867

# Timing of both implementations with a penalty of 0.01.
microbenchmark::microbenchmark(
glassoFast::glassoFast(cov_sub, rho = 0.01),
glasso(as.matrix(cov_sub), lambda = 0.01, 200, 200, tol = 1e-4, verbose = FALSE)
)



106 changes: 106 additions & 0 deletions tmp-tests/test-glasso2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Based on algo described in DOI: 10.1371/journal.pone.0014147

#include <RcppArmadillo.h>
using namespace Rcpp;

// Soft-thresholding with rescaling: move z towards zero by l1 and divide the
// result by l2. If the shifted value does not keep the sign side of z
// (for l1 >= 0, this is when |z| <= l1), the result is 0.
inline double soft_thres(double z, double l1, double l2) {
  const bool positive = (z > 0);
  const double shifted = positive ? (z - l1) : (z + l1);
  const bool survives = positive ? (shifted > 0) : (shifted < 0);
  return survives ? shifted / l2 : 0;
}

// Coordinate-descent update of column i of `beta`, followed by a refresh of
// row/column i of `W`.
//
//   mat      : input matrix (read-only); its column i is the target
//   W        : current estimate; off-diagonal entries of row/column i are
//              overwritten at the end
//   beta     : coefficient matrix; column i is updated in place
//   dotprods : work vector of length m, overwritten with W * beta.col(i)
//   lambda   : threshold passed to soft_thres()
//   m        : number of variables
//   i        : index of the variable being updated
//   maxiter  : maximum number of passes over the coordinates
//   tol      : a pass with no coefficient change larger than this ends the loop
//
// Stops with an R error ("Divergence!") if the sum of squared residuals of a
// pass exceeds the squared norm of column i of `mat`.
void inner_lasso(const arma::mat& mat,
                 arma::mat& W,
                 arma::mat& beta,
                 arma::vec& dotprods,
                 double lambda,
                 int m,
                 int i,
                 int maxiter,
                 double tol) {

  // dotprods = W * beta.col(i), built column by column so that zero
  // coefficients cost nothing
  dotprods.zeros();
  for (int col = 0; col < m; ++col) {
    const double coef = beta(col, i);
    if (coef == 0) continue;
    dotprods += W.col(col) * coef;
  }

  // squared norm of column i of `mat`, used as the divergence threshold
  const double gap0 = std::inner_product(mat.begin_col(i), mat.end_col(i),
                                         mat.begin_col(i), 0.0);

  for (int pass = 0; pass < maxiter; ++pass) {

    bool all_small = true;  // no coefficient moved by more than tol so far
    double gap = 0;         // sum of squared residuals over this pass

    for (int j = 0; j < m; ++j) {

      if (j == i) continue;

      const double w_jj = W(j, j);
      const double resid = mat(j, i) - dotprods[j];
      gap += resid * resid;

      const double old_beta = beta(j, i);
      const double new_beta = soft_thres(resid + old_beta * w_jj, lambda, w_jj);

      const double shift = new_beta - old_beta;
      if (shift == 0) continue;

      if (std::abs(shift) > tol) all_small = false;
      beta(j, i) = new_beta;
      // keep dotprods equal to W * beta.col(i) after this coordinate change
      dotprods += W.col(j) * shift;
    }

    if (gap > gap0) Rcpp::stop("Divergence!");
    if (all_small) break;
  }

  // copy the products into row i and column i of W (diagonal untouched)
  for (int j = 0; j < m; ++j) {
    if (j == i) continue;
    const double value = dotprods[j];
    W(j, i) = value;
    W(i, j) = value;
  }
}

// Graphical lasso driver: repeatedly sweeps inner_lasso() over all variables
// until the largest absolute change in W during a sweep is below `tol`, or
// `maxiter_outer` sweeps have been done.
//
//   mat           : input matrix (its number of columns defines the size)
//   lambda        : penalty; added to the diagonal of W while iterating and
//                   subtracted again before returning
//   maxiter_outer : maximum number of sweeps
//   maxiter_lasso : maximum number of passes inside each inner_lasso() call
//   tol           : tolerance for both the inner and the outer loop
//   verbose       : print the sweep number and the max / mean squared change
//
// Returns a list of two matrices: W and the coefficient matrix beta.
// [[Rcpp::export]]
ListOf<NumericMatrix> glasso(const arma::mat& mat,
                             double lambda,
                             int maxiter_outer,
                             int maxiter_lasso,
                             double tol,
                             bool verbose) {

  const int m = mat.n_cols;

  // working copy of the input with lambda on top of its diagonal
  arma::mat W = mat + 0;
  W.diag() += lambda;

  arma::mat beta(m, m, arma::fill::zeros);
  arma::vec dotprods(m);  // scratch space shared by all inner_lasso() calls

  for (int sweep = 1; sweep <= maxiter_outer; ++sweep) {

    Rcpp::checkUserInterrupt();
    if (verbose) Rcpp::Rcout << sweep << std::endl;

    // snapshot of W before this sweep, to measure how much it moves
    const arma::mat W_before = W + 0;

    for (int i = 0; i < m; ++i) {
      inner_lasso(mat, W, beta, dotprods, lambda, m, i, maxiter_lasso, tol);
    }

    const arma::mat delta = W - W_before;
    const double max_diff = arma::max(arma::max(arma::abs(delta)));
    const double diff2 = arma::mean(arma::mean(arma::square(delta)));
    if (verbose) Rcpp::Rcout << max_diff << " // " << diff2 << std::endl;

    if (max_diff < tol) break;
  }

  // undo the diagonal shift before handing W back
  W.diag() -= lambda;

  NumericMatrix W_out = as<NumericMatrix>(wrap(W));
  NumericMatrix beta_out = as<NumericMatrix>(wrap(beta));
  return List::create(W_out, beta_out);
}
114 changes: 114 additions & 0 deletions tmp-tests/test-linear-solver.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Exploratory script: compares ways of getting (a column of) the inverse of a
# SNP correlation matrix -- glassoFast, bigutilsr::regul_glasso(),
# bigsparser::sp_solve_sym(), and solvers compiled from tmp-tests/*.cpp.
# NOTE: `chr22` and `POS2` are not defined in this file; they must already
# exist in the session.
corr <- runonce::save_run(
snp_cor(chr22$genotypes, infos.pos = POS2, size = 3 / 1000, ncores = 6),
file = "tmp-data/corr_chr22.rds"
)

# All other variants, ordered by decreasing squared correlation with variant 1.
ind <- setdiff(order(corr[, 1] ** 2, decreasing = TRUE), 1)

# Start the subset with variant 1 only.
id_sub <- 1
ic <- 0

# Take the next variant from `ind` and add it to the subset.
# NOTE(review): only one step is written here; presumably these lines are
# re-run by hand to grow `id_sub` -- confirm.
j <- ind[ic <- ic + 1]

id_sub <- c(id_sub, j)
corr_sub <- as.matrix(corr[id_sub, id_sub])
glasso <- glassoFast::glassoFast(corr_sub, rho = 0.01)
glasso$wi

# Keep only the variants with a non-zero entry in the first column of `$wi`,
# refit on that smaller block, and compare the first columns of both fits.
ind2 <- which(glasso$wi[, 1] != 0)
id_sub2 <- id_sub[ind2]
length(id_sub2) / length(id_sub)
corr_sub2 <- as.matrix(corr[id_sub2, id_sub2])
glasso2 <- glassoFast::glassoFast(corr_sub2, rho = 0.01)
plot(glasso2$wi[, 1], glasso$wi[ind2, ind2][, 1])
all.equal(glasso2$wi[, 1], glasso$wi[ind2, ind2][, 1])

# Same block through bigutilsr::regul_glasso(); its result is compared with
# the input and with glassoFast's `$w`.
glasso3 <- bigutilsr::regul_glasso(corr_sub2, lambda = 0.01)
all.equal(glasso3, corr_sub2)
all.equal(glasso2$w, corr_sub2)

# glasso4 is glassoFast's `$w` with its diagonal reset to 1.
glasso4 <- glasso2$w; diag(glasso4) <- 1
all.equal(glasso3, glasso2$w)
all.equal(glasso3, glasso4)
glasso3[1:5, 1:5]
glasso2$w[1:5, 1:5]

microbenchmark::microbenchmark(
glasso2 <- glassoFast::glassoFast(corr_sub2, rho = 0.01),
glasso3 <- bigutilsr::regul_glasso(corr_sub2, lambda = 0.01)
)

# Solve corr_sub2 %*% x = e_1 (b is 1 in position 1, 0 elsewhere) with
# bigsparser, and compare with the first column of glassoFast's `$wi`.
corr_sub2_sfbm <- bigsnpr::as_SFBM(corr[id_sub2, id_sub2], compact = TRUE)
b <- rep(0, length(id_sub2)); b[1] <- 1
test_solve <- bigsparser::sp_solve_sym(corr_sub2_sfbm, b, add_to_diag = 0)
plot(glasso2$wi[, 1], test_solve); abline(0, 1, col = "red")

# test_solver() comes from test-linear-solver2.cpp (not part of this file).
Rcpp::sourceCpp("tmp-tests/test-linear-solver2.cpp",
showOutput = FALSE, echo = FALSE)
test_solve3 <- test_solver(as(corr[id_sub2, id_sub2], "generalMatrix"), b)

plot(glasso2$wi[, 1], test_solve3); abline(0, 1, col = "red")
plot(test_solve3, test_solve)
all.equal(test_solve3, test_solve)

# Conversion done once, outside the benchmark.
corr_sub3 <- as(corr[id_sub2, id_sub2], "generalMatrix")

microbenchmark::microbenchmark(
glasso2 <- glassoFast::glassoFast(corr_sub2, rho = 0.01),
glasso3 <- bigutilsr::regul_glasso(corr_sub2, lambda = 0.01),
test_solve <- bigsparser::sp_solve_sym(corr_sub2_sfbm, b),
test_solve2 <- bigsparser::sp_solve_sym(corr_sub2_sfbm, b, add_to_diag = 0.01),
# as(corr[id_sub2, id_sub2], "generalMatrix"),
test_solve3 <- test_solver(corr_sub3, b)
)


# Same solve on the full correlation matrix, then with 0.01 and 0.1 added to
# the diagonal; trailing comments are the timings recorded by the author.
corr2 <- as(corr, "generalMatrix")
b2 <- rep(0, ncol(corr2)); b2[1] <- 1
system.time(
test_solve4 <- test_solver(corr2, b2)
)
plot(test_solve4[ind2], test_solve3)

system.time(
test_solve5 <- test_solver(corr2 + Matrix::Diagonal(n = ncol(corr2), x = 0.01), b2)
) # 39 sec
plot(test_solve4, test_solve5)

system.time(
test_solve6 <- test_solver(corr2 + Matrix::Diagonal(n = ncol(corr2), x = 0.1), b2)
) # 8 sec
plot(test_solve6, test_solve5); abline(0, 1, col = "red")

# Solvers from test-linear-solver3.cpp (not part of this file), applied to
# the matrix converted with as.matrix().
Rcpp::sourceCpp("tmp-tests/test-linear-solver3.cpp")
corr3 <- as.matrix(corr2)
system.time(
test_solve7 <- test_solver_dense(corr3, b2)
) # 87 sec
all.equal(test_solve7, test_solve4)

# Dense matrix with 0.01 added to the diagonal, used by all solvers below.
corr4 <- corr3; diag(corr4) <- diag(corr3) + 0.01
system.time(
test_solve8 <- test_solver_dense(corr4, b2)
) # 21 sec
all.equal(test_solve8, test_solve5)

# The functions below take only the matrix (no right-hand side argument);
# each result is compared with test_solve9.
system.time(
test_solve9 <- test_solver_dense_sparse(corr4)
) # same time
all.equal(test_solve8, test_solve9)

system.time(
test_solve10 <- test_ConjugateGradient(corr4)
) # 28 sec
all.equal(test_solve10, test_solve9)

system.time(
test_solve11 <- test_BiCGSTAB(corr4)
) # 225 sec
all.equal(test_solve11, test_solve9) # 1e-4 relative diff

system.time(
test_solve12 <- test_GMRES(corr4)
) # 119 sec
all.equal(test_solve12, test_solve9) # 1e-2 relative diff
Loading

0 comments on commit d16e7b2

Please sign in to comment.