Skip to content

Commit

Permalink
update interface to make more uniform; add a vignette
Browse files Browse the repository at this point in the history
  • Loading branch information
stephens999 committed Sep 23, 2017
1 parent 2f721d6 commit 6887cc7
Show file tree
Hide file tree
Showing 38 changed files with 648 additions and 269 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
.Rhistory
.RData
.Ruserdata
*.bak
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: flashr2
Type: Package
Title: Empirical Bayes Shrinkage for Factor Analysis
Version: 0.1-12
Title: Factor Analysis with Empirical Bayes Adaptive Shrinkage
Version: 0.1-13
Author: Wei Wang, Matthew Stephens
Maintainer: The package maintainer <[email protected]>
Description: More about what it does (maybe more than one line)
Expand Down
15 changes: 11 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# Generated by roxygen2: do not edit by hand

export(flash)
export(flash_add_factors_from_data)
export(flash_add_fixed_f)
export(flash_add_fixed_l)
export(flash_add_greedy)
export(flash_add_lf)
export(flash_backfit)
export(flash_get_F)
export(flash_get_f)
export(flash_get_k)
export(flash_get_l)
export(flash_get_lf)
export(flash_get_pve)
export(flash_get_sizes)
export(flash_get_udv)
export(flash_greedy)
export(flash_init_fn)
export(flash_r1)
export(flash_set_data)
export(flash_update_precision)
export(flash_zero_out_factor)
export(get_F)
export(set_flash_data)
export(udv_random)
export(udv_si)
export(udv_svd)
2 changes: 1 addition & 1 deletion R/F_objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#' @param data a flash data object
#' @param f a flash fit object
#' @export
get_F = function(data,f){
flash_get_F = function(data,f){
return(sum(unlist(f$KL_l))+sum(unlist(f$KL_f))+e_loglik(data,f))
}

Expand Down
1 change: 1 addition & 0 deletions R/ash_defaults.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
flash_default_ash_param=function(){
return(list(outputlevel=5,mixcompdist="normal",method="shrink"))
#use very small bias to null; helps makes things faster and stable
}
91 changes: 43 additions & 48 deletions R/flash.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @title Fit the rank1 FLASH model to data
#' @param data an n by p matrix or a flash data object created using \code{set_flash_data}
#' @param data an n by p matrix or a flash data object created using \code{flash_set_data}
#' @param var_type type of variance structure to assume for residuals.
#' @param tol specify how much objective can change in a single iteration to be considered not converged
#' @param init_fn function to be used to initialize the factor. This function should take parameters (Y,K)
Expand All @@ -19,27 +19,26 @@
#' f = flash_r1(Y)
#' flash_get_sizes(f)
#' @export
flash_r1 = function(data,var_type = c("by_column","constant"), init_fn = "udv_si",tol=1e-2,ash_param=list(),verbose = FALSE, nullcheck=TRUE){
if(is.matrix(data)){data = set_flash_data(data)}
flash_r1 = function(data,f_init=NULL,var_type = c("by_column","constant"), init_fn = "udv_si",tol=1e-2,ash_param=list(),verbose = FALSE, nullcheck=TRUE){
if(is.matrix(data)){data = flash_set_data(data)}
var_type=match.arg(var_type)
f = flash_init_fn(data,init_fn)
f = flash_optimize_single_fl(data,f,1,var_type,nullcheck,tol,ash_param,verbose)
f = flash_add_factors_from_data(data,f_init = f_init, init_fn=init_fn,K=1)
f = flash_optimize_single_fl(data,f,get_k(f),var_type,nullcheck,tol,ash_param,verbose)
return(f)
}


#' @title Fit the FLASH model to data by a greedy approach
#' @details Fits the model by adding a factor and then optimizing it.
#' @title Adds factors to a flash object by a greedy approach
#' @details Adds factors iteratively, at each time adding a new factor and then optimizing it.
#' It is "greedy" in that it does not return to re-optimize previous factors.
#' The function stops when an added factor contributes nothing, or Kmax is reached.
#' Each new factor is intialized by applying the function `init_fn` to the residuals
#' after removing previously-fitted factors.
#' @param data an n by p matrix or a flash data object created using \code{set_flash_data}
#' @param Kmax the maximum number of factors to be considered
#' @param var_type type of variance structure to assume for residuals.
#' @param f_init a flash fit object to start the greedy algorithm: the greedy algorithm iteratively adds factors
#' @param data an n by p matrix or a flash data object created using \code{flash_set_data}
#' @param Kmax the maximum number of factors to add to f_init
#' @param f_init a flash fit object to start the greedy algorithm: the greedy algorithm iteratively adds up to Kmax factors
#' to this initial fit. (If NULL then the greedy algorithm starts with 0 factors)
#' (Note: if f_init already contains at least Kmax factors then this function returns f_init)
#' @param var_type type of variance structure to assume for residuals.
#' @param init_fn function to be used to initialize each factor when added. This function should take as
#' input an n by p matrix of data (or a flash data object)
#' and output a list with elements (u,d,v) where u is an n-vector,
Expand All @@ -57,40 +56,30 @@ flash_r1 = function(data,var_type = c("by_column","constant"), init_fn = "udv_si
#' l = rnorm(100)
#' f = rnorm(10)
#' Y = outer(l,f) + matrix(rnorm(1000),nrow=100)
#' f = flash_greedy(Y,10)
#' f = flash_add_greedy(Y,10)
#' flash_get_sizes(f)
#' # example to show how to use a different initialization function
#' f2 = flash_greedy(Y,10,function(x,K=1){softImpute::softImpute(x,K,lambda=10)})
#' f2 = flash_add_greedy(Y,10,function(x,K=1){softImpute::softImpute(x,K,lambda=10)})
#' @export
flash_greedy = function(data,Kmax=1,var_type = c("by_column","constant"),f_init = NULL, init_fn="udv_si",tol=1e-2,ash_param=list(),verbose=FALSE,nullcheck=TRUE){
if(is.matrix(data)){data = set_flash_data(data)}
flash_add_greedy = function(data,Kmax=1,f_init = NULL,var_type = c("by_column","constant"), init_fn="udv_si",tol=1e-2,ash_param=list(),verbose=FALSE,nullcheck=TRUE){
if(is.matrix(data)){data = flash_set_data(data)}
var_type=match.arg(var_type)
f = f_init

if(is.null(f_init)){
message("fitting factor/loading ",1)
f = flash_r1(data,var_type,init_fn,tol,ash_param,verbose,nullcheck)
if(is_tiny_fl(f,1)){return(f)} #finish if not even rank 1
} else { #if initial value specified, set it
f = f_init
for(k in 1:Kmax){
message("fitting factor/loading ",k)
f = flash_r1(data,f,var_type,init_fn,tol,ash_param,verbose,nullcheck)
if(is_tiny_fl(f,get_k(f))) #test whether the factor/loading combination is effectively 0
break
}

k_init = get_k(f)
if(k_init<Kmax){ #if we still have factors to add
for(k in (k_init+1):Kmax){
f = flash_add_factors_from_residuals(data, f, init_fn)
message("fitting factor/loading ",k)
f = flash_optimize_single_fl(data,f,k,var_type,nullcheck,tol,ash_param,verbose)
if(is_tiny_fl(f,k)) #test whether the factor/loading combination is effectively 0
break
}
}
return(f)
}


#' @title Refines a fit of the FLASH model to data by "backfitting"
#' @details Iterates through the factors of a flash object, updating each until convergence
#' @param data an n by p matrix or a flash data object created using \code{set_flash_data}
#' @param data an n by p matrix or a flash data object created using \code{flash_set_data}
#' @param f a fitted flash object to be refined
#' @param kset the indices of factors to be optimized (NULL indicates all factors)
#' @param var_type type of variance structure to assume for residuals.
Expand All @@ -103,36 +92,42 @@ flash_greedy = function(data,Kmax=1,var_type = c("by_column","constant"),f_init
#' fg = flash_greedy(Y,10)
#' fb = flash_backfit(Y,fg) # refines fit from greedy by backfitting
#' flash_get_sizes(fb)
#' fsi = flash_init_fn(set_flash_data(Y),"udv_si",4)
#' fsi = flash_init_fn(flash_set_data(Y),"udv_si",4)
#' fb2 = flash_backfit(Y,fsi)
#' flash_get_sizes(fb2)
#' @export
flash_backfit = function(data,f,kset=NULL,var_type = c("by_column","constant"),tol=1e-2,ash_param=list(),verbose=FALSE){
if(is.matrix(data)){data = set_flash_data(data)}
flash_backfit = function(data,f,kset=NULL,var_type = c("by_column","constant"),tol=1e-2,ash_param=list(),verbose=FALSE,nullcheck=TRUE){
if(is.matrix(data)){data = flash_set_data(data)}
if(is.null(kset)){kset = 1:get_k(f)}
var_type=match.arg(var_type)
if(is.null(f$tau)){f=flash_update_precision(data,f,var_type)} # need to do this in case f hasn't been fit at all yet
c = get_conv_criteria(data, f)
c = flash_get_F(data, f)
diff = 1

while(diff > tol){
fit_got_worse = FALSE #flag used to check for occassional
#issues with fit getting slightly worse due to numerics. If so we will stop iterating
# to avoid potential infinite loop.
while(diff > tol & !fit_got_worse){
diff = 1
while(diff > tol){
for(k in kset){
f = flash_update_single_fl(data,f,k,var_type,ash_param)
}
cnew = get_conv_criteria(data, f)
diff = sqrt(mean((cnew-c)^2))
cnew = flash_get_F(data, f)
diff = cnew-c
c = cnew
if(verbose){
message("objective: ",c)
}
}

kset = 1:get_k(f) #now remove factors that actually hurt objective
f = perform_nullcheck(data,f,kset,var_type,verbose)
cnew = get_conv_criteria(data, f)
diff = sqrt(mean((cnew-c)^2))
if(diff<0){fit_got_worse=TRUE}

if(nullcheck){
kset = 1:get_k(f) #now remove factors that actually hurt objective
f = perform_nullcheck(data,f,kset,var_type,verbose)
}
cnew = flash_get_F(data, f)
diff = cnew-c
c = cnew
}

Expand All @@ -141,8 +136,8 @@ flash_backfit = function(data,f,kset=NULL,var_type = c("by_column","constant"),t

#' @title Main flash function
#' @details Performs Empirical Bayes factor analysis with adaptive shrinkage on both factors and loadings.
#' @param data an n by p matrix or a flash data object created using \code{set_flash_data}
#' @param Kmax the maximum total number of factors to use (including the r1+r2 covariates)
#' @param data an n by p matrix or a flash data object created using \code{flash_set_data}
#' @param Kmax the maximum total number of factors to use
#' @param column_covariates an n by r1 matrix of covariates (eg could be a column of all 1s to allow an intercept for each column)
#' @param row_covariates a p by r2 matrix of covariates (eg could be a vector of p 1s to allow an intercept for each row)
#' @param init_fn function used to initialize factors and loadings,
Expand All @@ -154,7 +149,7 @@ flash_backfit = function(data,f,kset=NULL,var_type = c("by_column","constant"),t
#' @return a fitted flash object
#' @export
flash = function(data,Kmax,column_covariates = NULL,row_covariates = NULL,f_init = NULL,init_fn=NULL, var_type = c("by_column","constant"),tol=1e-2,ash_param=list(),verbose=FALSE){
if(is.matrix(data)){data = set_flash_data(data)}
if(is.matrix(data)){data = flash_set_data(data)}
var_type=match.arg(var_type)
f=f_init
if(is.null(f)){f = flash_init_null()}
Expand Down
110 changes: 110 additions & 0 deletions R/flash_add.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#' @title add factors or loadings to f
#' @details The precision parameter in f is updated after adding
#' @param data a flash data object
#' @param f a flash fit object
#' @param LL the loadings, an n by K matrix
#' @param FF the factors, a p by K matrix
#' @param fixl an n by K matrix of TRUE/FALSE values indicating which elements of LL should be considered fixed and not changed during updates.
#' Useful for including a mean factor for example.
#' @param fixf a p by K matrix of TRUE/FALSE values; same as fixl but for factors FF.
#' @return a flash fit object, with additional factors initialized using LL and FF
#' @export
flash_add_lf = function(data,LL,FF,f_init=NULL,fixl=NULL,fixf=NULL){
if(is.null(f_init)){f_init = flash_init_null()}
f2 = flash_init_lf(LL,FF,fixl,fixf)
f = flash_combine(f_init,f2)
return(flash_update_precision(data,f))
}

#' @title add factors to a flash fit object based on data
#' @param data a flash data object
#' @param K number of factors to add
#' @param f_init an existing flash fit object to add to
#' @param init_fn the function to use to initialize new factors (typically some kind of svd-like function)
#' @details Computes the current residuals from data and f_init and adds K new factors based
#' on init_fn applied to these residuals. (If f_init is NULL then the residuals are the data)
#' @export
flash_add_factors_from_data = function(data,K,f_init=NULL,init_fn="udv_si"){
if(is.null(f_init)){f_init = flash_init_null()}
R = get_R_withmissing(data,f_init)
f2 = flash_init_fn(flash_set_data(R),init_fn,K)
f = flash_combine(f_init,f2)
return(flash_update_precision(data,f))
}



#' @title Add a set of fixed loadings to a flash fit object
#' @param data a flash data object
#' @param LL the loadings, an n by K matrix. Missing values will be initialized by the mean of the relevant column (but will generally be
#' re-estimated when refitting the model).
#' @param f_init a flash fit object to which loadings are to be added (if NULL then a new fit object is created)
#' @param fixl an n by K matrix of TRUE/FALSE values indicating which elements of LL should be considered fixed and not changed during updates.
#' Default is to fix all non-missing values, so missing values will be updated when f is updated.
#' @return a flash fit object, with loadings initialized from LL, and corresponding factors initialized to 0.
#' @export
flash_add_fixed_l = function(data, LL, f_init=NULL, fixl = NULL){
if(is.null(f_init)){f_init = flash_init_null()}
if(is.null(fixl)){fixl = !is.na(LL)}
LL = fill_missing_with_column_mean(LL)
FF = matrix(0,nrow=ncol(data$Y),ncol=ncol(LL))

f_new = flash_init_lf(LL,FF,fixl=fixl)
f = flash_combine(f_init,f_new)

# maybe in future we want to give a fit option? But then would
# need to pass in var_type? possibly not needed.
# if(fit){
# k1 = get_k(f_init)
# k2 = get_k(f)
# f = flash_backfit(data,f,kset=((k1+1):k2),var_type=xx)
# }
return(f)
}

#' @title Add a set of fixed factors to a flash fit object
#' @param data a flash data object
#' @param FF the factors, a p by K matrix. Missing values will be initialized by the mean of the relevant column (but will generally be
#' re-estimated when refitting the model).
#' @param f_init a flash fit object to which factors are to be added (if NULL then a new fit object is created)
#' @param fixf a p by K matrix of TRUE/FALSE values indicating which elements of FF should be considered fixed and not changed during updates.
#' Default is to fix all non-missing values, so missing values will be updated when f is updated.
#' @return a flash fit object, with factors initialized from FF, and corresponding loadings initialized to 0.
#' @export
flash_add_fixed_f = function(data, FF, f_init=NULL, fixf = NULL){
if(is.null(f_init)){f_init = flash_init_null()}
if(is.null(fixf)){fixf = !is.na(FF)}
FF = fill_missing_with_column_mean(FF)
LL = matrix(0,nrow=nrow(data$Y),ncol=ncol(FF))

f_new = flash_init_lf(LL,FF,fixf=fixf)
f = flash_combine(f_init,f_new)

# maybe in future we want to give a fit option? But then would
# need to pass in var_type? possibly not needed.
# if(fit){
# k1 = get_k(f_init)
# k2 = get_k(f)
# f = flash_backfit(data,f,kset=((k1+1):k2),var_type=xx)
# }
return(f)
}

NA2mean <- function(x) replace(x, is.na(x), mean(x, na.rm = TRUE))

fill_missing_with_column_mean = function(X){
apply(X, 2, NA2mean)
}


#' @title Initialize a flash fit object by applying a function to data
#' @param data a flash data object
#' @param init_fn an initialization function, which takes as input an (n by p matrix, or flash data object)
#' and K, a number of factors, and and outputs a list with elements (u,d,v)
#' @return a flash fit object
flash_init_fn = function(data,init_fn,K=1){
s = do.call(init_fn,list(get_Yorig(data),K))
f = flash_init_udv(s,K)
f = flash_update_precision(data,f)
return(f)
}
2 changes: 1 addition & 1 deletion R/flash_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#' ii) data$Y * data$missing is 0 if the original data were missing
#' @return a flash data object
#' @export
set_flash_data = function(Y, S = 0){
flash_set_data = function(Y, S = 0){
data = list(Yorig = Y, S=S, anyNA=anyNA(Y), missing = is.na(Y)) # initialize data

if(anyNA(Y)){ # replace missing data with 0s
Expand Down
Loading

0 comments on commit 6887cc7

Please sign in to comment.