diff --git a/python/varspark/core.py b/python/varspark/core.py
index 8d85ebd4..cdab5cc9 100644
--- a/python/varspark/core.py
+++ b/python/varspark/core.py
@@ -7,14 +7,15 @@
 from varspark import java
 from varspark.etc import find_jar
 from varspark.featuresource import FeatureSource
+from varspark.covariatesource import CovariateSource
+
 
 class VarsparkContext(object):
-    """The main entry point for VariantSpark functionality.
-    """
+    """The main entry point for VariantSpark functionality."""
 
     @classmethod
     def spark_conf(cls, conf=SparkConf()):
-        """ Adds the necessary option to the spark configuration.
+        """Adds the necessary option to the spark configuration.
         Note: In client mode these need to be setup up using --jars or --driver-class-path
         """
         return conf.set("spark.jars", find_jar())
@@ -30,37 +31,63 @@ def __init__(self, ss, silent=False):
         self.sql = SQLContext.getOrCreate(self.sc)
         self._jsql = self.sql._jsqlContext
         self._jvm = self.sc._jvm
-        self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
+        self._vs_api = getattr(self._jvm, "au.csiro.variantspark.api")
         jss = ss._jsparkSession
         self._jvsc = self._vs_api.VSContext.apply(jss)
         setup_typecheck()
         if not self.silent:
-            sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
+            sys.stderr.write(
+                "Running on Apache Spark version {}\n".format(self.sc.version)
+            )
             if self.sc._jsc.sc().uiWebUrl().isDefined():
-                sys.stderr.write('SparkUI available at {}\n'.format(
-                    self.sc._jsc.sc().uiWebUrl().get()))
+                sys.stderr.write(
+                    "SparkUI available at {}\n".format(
+                        self.sc._jsc.sc().uiWebUrl().get()
+                    )
+                )
             sys.stderr.write(
-                'Welcome to\n'
-                ' _ __ _ __ _____ __ \n'
-                '| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n'
-                '| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n'
-                '| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n'
-                '|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n'
-                ' /_/ \n')
+                "Welcome to\n"
+                " _ __ _ __ _____ __ \n"
+                "| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n"
+                "| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n"
+                "| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n"
+                "|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n"
+                " /_/ \n"
+            )
 
     @params(self=object, vcf_file_path=str, min_partitions=int)
     def import_vcf(self, vcf_file_path, min_partitions=0):
-        """ Import features from a VCF file.
-        """
-        return FeatureSource(self._jvm, self._vs_api,
-                             self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path,
-                             min_partitions))
+        """Import features from a VCF file."""
+        return FeatureSource(
+            self._jvm,
+            self._vs_api,
+            self._jsql,
+            self.sql,
+            self._jvsc.importVCF(vcf_file_path, min_partitions),
+        )
+
+    @params(self=object, cov_file_path=str, cov_types=(list, dict))
+    def import_covariates(self, cov_file_path, cov_types):
+        """Import covariates from a CSV file.
+
+        :param cov_file_path: The file path for the covariate CSV file
+        :param cov_types: A list or dict describing the type of each covariate
+        """
+        # Parallelize with the Python SparkContext: self._jvm.SparkContext
+        # only names the JVM class and has no instance to call parallelize on.
+        if isinstance(cov_types, list):
+            types_rdd = self.sc.parallelize(cov_types)
+        elif isinstance(cov_types, dict):
+            types_rdd = self.sc.parallelize(list(cov_types.items()))
+        else:
+            types_rdd = None
+        return CovariateSource(
+            self._jvm,
+            self._vs_api,
+            self._jsql,
+            self.sql,
+            # py4j calls into the JVM take positional arguments only
+            self._jvsc.importCSV(cov_file_path, types_rdd),
+        )
 
     @params(self=object, label_file_path=str, col_name=str)
     def load_label(self, label_file_path, col_name):
-        """ Loads the label source file
+        """Loads the label source file
 
         :param label_file_path: The file path for the label source file
         :param col_name: the name of the column containing labels
@@ -68,8 +95,7 @@ def load_label(self, label_file_path, col_name):
         """
         return self._jvsc.loadLabel(label_file_path, col_name)
 
     def stop(self):
-        """ Shut down the VariantsContext.
-        """
+        """Shut down the VarsparkContext."""
        self.sc.stop()
        self.sc = None