Skip to content

Commit

Permalink
STYLE: Format with black (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
NickEdwards7502 committed Sep 19, 2024
1 parent 8f11e62 commit d671f35
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions python/varspark/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ class VariantSparkPySparkTestCase(unittest.TestCase):

@classmethod
def setUpClass(self):
sconf = SparkConf(loadDefaults=False) \
.set("spark.driver.extraClassPath", find_variants_jar())
spark = SparkSession.builder.config(conf=sconf) \
.appName("test").master("local").getOrCreate()
sconf = SparkConf(loadDefaults=False).set(
"spark.driver.extraClassPath", find_variants_jar()
)
spark = (
SparkSession.builder.config(conf=sconf)
.appName("test")
.master("local")
.getOrCreate()
)
self.sc = spark.sparkContext

@classmethod
Expand All @@ -37,28 +42,37 @@ def setUp(self):
def test_variants_context_parameter_type(self):
with self.assertRaises(TypeError) as cm:
self.vc.load_label(label_file_path=123, col_name=456)
self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
str(cm.exception))
self.assertEqual(
"keyword argument label_file_path = 123 doesn't match signature str",
str(cm.exception),
)

def test_rfmodel(self):
label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050678')
feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
label_data_path = os.path.join(PROJECT_DIR, "data/chr22-labels.csv")
label = self.vc.load_label(
label_file_path=label_data_path, col_name="22_16050678"
)
feature_data_path = os.path.join(PROJECT_DIR, "data/chr22_1000.vcf")
features = self.vc.import_vcf(vcf_file_path=feature_data_path)
rf = RFModelContext(self.spark, mtry_fraction=None, oob=True, seed=17, var_ordinal_levels=3)
rf = RFModelContext(
self.spark, mtry_fraction=None, oob=True, seed=17, var_ordinal_levels=3
)
rf.fit_trees(features, label, n_trees=200, batch_size=50)
imp_analysis = rf.importance_analysis()
imp_vars = imp_analysis.important_variables(20)
most_imp_var = imp_vars['variable'][0]
self.assertEqual('22_16050678_C_T', most_imp_var)
most_imp_var = imp_vars["variable"][0]
self.assertEqual("22_16050678_C_T", most_imp_var)
df = imp_analysis.variable_importance(normalized=True)
self.assertEqual('22_16050678_C_T',
str(df.sort_values(by='importance', ascending=False)['variant_id'].iloc[0]))
self.assertEqual(
"22_16050678_C_T",
str(df.sort_values(by="importance", ascending=False)["variant_id"].iloc[0]),
)
oob_error = rf.oob_error()
self.assertEqual(0.004578754578754579, oob_error)
fdrCalc = rf.get_lfdr()
_, fdr = fdrCalc.compute_fdr(countThreshold = 2, local_fdr_cutoff = 0.05)
_, fdr = fdrCalc.compute_fdr(countThreshold=2, local_fdr_cutoff=0.05)
self.assertEqual(0.0002976892628282768, fdr)

if __name__ == '__main__':

if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit d671f35

Please sign in to comment.