Skip to content

Commit

Permalink
Merge pull request #25 from florencejt/refactor/add_mcvae_training_to…
Browse files Browse the repository at this point in the history
…_layer_mods

Fixed tests in test_pipeline to remove params arguments
  • Loading branch information
florencejt committed May 15, 2024
2 parents ba2ae01 + 6952440 commit ea4db6c
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 126 deletions.
81 changes: 48 additions & 33 deletions tests/test_data/test_get_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ def test_get_data_module_custom(create_test_files):
graph_maker=None,
)

data_paths = {"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}
data_paths = {
"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}

# Call the prepare_fusion_data function with custom fusion type (non-graph)
dm = prepare_fusion_data(prediction_task="binary", fusion_model=fusion_model, data_paths=data_paths,
output_paths=None, test_size=0.3, batch_size=8, multiclass_dims=None, )
dm = prepare_fusion_data(
prediction_task="binary",
fusion_model=fusion_model,
data_paths=data_paths,
output_paths=None,
test_size=0.3,
batch_size=8,
multiclass_dimensions=None,
)

# Add assertions based on your expectations
assert isinstance(dm, TrainTestDataModule)
Expand All @@ -59,10 +67,11 @@ def test_get_k_fold_data_module_custom(create_test_files):
graph_maker=None,
)

data_paths = {"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}
data_paths = {
"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}

# Call the prepare_fusion_data function with custom fusion type (non-graph)
dm = prepare_fusion_data(
Expand All @@ -72,7 +81,8 @@ def test_get_k_fold_data_module_custom(create_test_files):
output_paths=None,
kfold=True,
num_folds=7,
batch_size=16)
batch_size=16,
)

# Add assertions based on your expectations
assert isinstance(dm, KFoldDataModule)
Expand All @@ -91,19 +101,21 @@ def test_get_graph_data_module(create_test_files):
graph_maker=MockGraphMakerModule,
)

data_paths = {"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}
data_paths = {
"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}

# Call the prepare_fusion_data function with custom fusion type (non-graph)
dm = prepare_fusion_data(prediction_task="regression",
fusion_model=fusion_model,
data_paths=data_paths,
output_paths=None,
multiclass_dims=None,
test_size=0.3,
)
dm = prepare_fusion_data(
prediction_task="regression",
fusion_model=fusion_model,
data_paths=data_paths,
output_paths=None,
multiclass_dimensions=None,
test_size=0.3,
)

# Add assertions based on your expectations
assert isinstance(dm, LightningNodeData)
Expand All @@ -121,19 +133,22 @@ def test_get_kfold_graph_data_module(create_test_files):
graph_maker=MockGraphMakerModule,
)

data_paths = {"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}
data_paths = {
"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
"image": image_torch_file_2d,
}

# Call the prepare_fusion_data function with custom fusion type (non-graph)
dm = prepare_fusion_data(prediction_task="regression",
fusion_model=fusion_model,
data_paths=data_paths,
output_paths=None,
kfold=True,
num_folds=8,
batch_size=16)
dm = prepare_fusion_data(
prediction_task="regression",
fusion_model=fusion_model,
data_paths=data_paths,
output_paths=None,
kfold=True,
num_folds=8,
batch_size=16,
)

# Add assertions based on your expectations
for fold_dm in dm:
Expand Down
52 changes: 30 additions & 22 deletions tests/test_pipeline/test_kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import matplotlib.pyplot as plt


@pytest.mark.filterwarnings("ignore:.*does not have many workers*.", )
@pytest.mark.filterwarnings(
"ignore:.*does not have many workers*.",
)
@pytest.mark.filterwarnings("ignore:.*The number of training batches*.")
@pytest.mark.filterwarnings("ignore:.*No positive samples in targets,*.")
@pytest.mark.filterwarnings("ignore:.*No negative samples in targets,*.")
@pytest.mark.filterwarnings("ignore:.*exists and is not empty*.")
def test_5fold_cv(create_test_files, tmp_path):
model_conditions = {"modality_type": "all", }
model_conditions = {
"modality_type": "all",
}

tabular1_csv = create_test_files["tabular1_csv"]
tabular2_csv = create_test_files["tabular2_csv"]
Expand All @@ -33,18 +37,7 @@ def test_5fold_cv(create_test_files, tmp_path):
checkpoint_dir = tmp_path / "checkpoint_dir"
checkpoint_dir.mkdir()

modifications = {
"AttentionAndSelfActivation": {"attention_reduction_ratio": 2}
}

params = {
# "test_size": 0.2,
"prediction_task": "binary",
"multiclass_dimensions": None,
"kfold": True,
"num_folds": 5,
"wandb_logging": False,
}
modifications = {"AttentionAndSelfActivation": {"attention_reduction_ratio": 2}}

data_paths = {
"tabular1": tabular1_csv,
Expand All @@ -58,18 +51,33 @@ def test_5fold_cv(create_test_files, tmp_path):
"losses": str(loss_log_dir),
}

new_metrics = ["accuracy", "precision", "recall", "f1", "auroc", "auprc", "balanced_accuracy"]
new_metrics = [
"accuracy",
"precision",
"recall",
"f1",
"auroc",
"auprc",
"balanced_accuracy",
]

fusion_models = import_chosen_fusion_models(model_conditions, skip_models=["MCVAE_tab"])
fusion_models = import_chosen_fusion_models(
model_conditions, skip_models=["MCVAE_tab"]
)

for model in fusion_models:
print("model", model)
dm = prepare_fusion_data(fusion_model=model,
data_paths=data_paths,
output_paths=output_paths,
layer_mods=modifications,
max_epochs=2,
**params)
dm = prepare_fusion_data(
fusion_model=model,
data_paths=data_paths,
output_paths=output_paths,
layer_mods=modifications,
max_epochs=2,
prediction_task="binary",
multiclass_dimensions=None,
kfold=True,
num_folds=5,
)

single_model_list = train_and_save_models(
data_module=dm,
Expand Down
96 changes: 50 additions & 46 deletions tests/test_pipeline/test_modify_and_new_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,28 +638,30 @@
"dropout_prob": 0.4,
},
"AttentionWeightedGraphMaker": {
"early_stop_callback": EarlyStopping(monitor="val_loss", ),
"early_stop_callback": EarlyStopping(
monitor="val_loss",
),
"edge_probability_threshold": 80,
"attention_MLP_test_size": 0.3,
"AttentionWeightingMLPInstance.weighting_layers": nn.ModuleDict({
"Layer 1": nn.Sequential(nn.Linear(25, 100),
nn.ReLU()),
"Layer 2": nn.Sequential(nn.Linear(100, 75),
nn.ReLU()),
"Layer 3": nn.Sequential(nn.Linear(75, 75),
nn.ReLU()),
"Layer 4": nn.Sequential(nn.Linear(75, 100),
nn.ReLU()),
"Layer 5": nn.Sequential(nn.Linear(100, 25),
nn.ReLU()),
})
"AttentionWeightingMLPInstance.weighting_layers": nn.ModuleDict(
{
"Layer 1": nn.Sequential(nn.Linear(25, 100), nn.ReLU()),
"Layer 2": nn.Sequential(nn.Linear(100, 75), nn.ReLU()),
"Layer 3": nn.Sequential(nn.Linear(75, 75), nn.ReLU()),
"Layer 4": nn.Sequential(nn.Linear(75, 100), nn.ReLU()),
"Layer 5": nn.Sequential(nn.Linear(100, 25), nn.ReLU()),
}
),
},
}


# Tests for adding modifications and evaluating on new data

@pytest.mark.filterwarnings("ignore:.*does not have many workers*.", )

@pytest.mark.filterwarnings(
"ignore:.*does not have many workers*.",
)
@pytest.mark.filterwarnings("ignore:.*The number of training batches*.")
@pytest.mark.filterwarnings("ignore:.*No positive samples in targets,*.")
@pytest.mark.filterwarnings("ignore:.*No negative samples in targets,*.")
Expand All @@ -668,7 +670,9 @@
@pytest.mark.filterwarnings("ignore:.*distutils Version classes are deprecated*.")
def test_train_and_test(create_test_files_more_features, tmp_path):
model_conditions = {"fusion_type": "all"}
fusion_models = import_chosen_fusion_models(model_conditions, skip_models=["MCVAE_tab"])
fusion_models = import_chosen_fusion_models(
model_conditions, skip_models=["MCVAE_tab"]
)

tabular1_csv = create_test_files_more_features["tabular1_csv"]
tabular2_csv = create_test_files_more_features["tabular2_csv"]
Expand All @@ -691,14 +695,6 @@ def test_train_and_test(create_test_files_more_features, tmp_path):
checkpoint_dir = tmp_path / f"checkpoint_dir_{timestamp}"
checkpoint_dir.mkdir()

params = {
"test_size": 0.2,
"prediction_task": "binary",
"multiclass_dimensions": None,
"kfold": False,
"wandb_logging": False,
}

data_paths = {
"tabular1": tabular1_csv,
"tabular2": tabular2_csv,
Expand All @@ -718,14 +714,16 @@ def test_train_and_test(create_test_files_more_features, tmp_path):
}

for model in fusion_models:
dm = prepare_fusion_data(fusion_model=model,
data_paths=data_paths,
output_paths=output_paths,
params=params,
layer_mods=layer_mods,
max_epochs=2,
**params
)
dm = prepare_fusion_data(
fusion_model=model,
data_paths=data_paths,
output_paths=output_paths,
layer_mods=layer_mods,
max_epochs=2,
prediction_task="binary",
multiclass_dimensions=None,
kfold=False,
)

single_model_list = train_and_save_models(
data_module=dm,
Expand All @@ -745,16 +743,20 @@ def test_train_and_test(create_test_files_more_features, tmp_path):
assert fig is not None

if trained_model.model.fusion_type != "graph":
fig_new_data = ConfusionMatrix.from_new_data([trained_model], output_paths, test_data_paths,
layer_mods=layer_mods)
fig_new_data = ConfusionMatrix.from_new_data(
[trained_model], output_paths, test_data_paths, layer_mods=layer_mods
)
assert fig_new_data is not None

plt.close("all")


# kfold version

@pytest.mark.filterwarnings("ignore:.*does not have many workers*.", )

@pytest.mark.filterwarnings(
"ignore:.*does not have many workers*.",
)
@pytest.mark.filterwarnings("ignore:.*The number of training batches*.")
@pytest.mark.filterwarnings("ignore:.*No positive samples in targets,*.")
@pytest.mark.filterwarnings("ignore:.*No negative samples in targets,*.")
Expand All @@ -763,7 +765,9 @@ def test_train_and_test(create_test_files_more_features, tmp_path):
@pytest.mark.filterwarnings("ignore:.*distutils Version classes are deprecated*.")
def test_kfold(create_test_files_more_features, tmp_path):
model_conditions = {"class_name": "all"}
fusion_models = import_chosen_fusion_models(model_conditions, skip_models=["MCVAE_tab"])
fusion_models = import_chosen_fusion_models(
model_conditions, skip_models=["MCVAE_tab"]
)

tabular1_csv = create_test_files_more_features["tabular1_csv"]
tabular2_csv = create_test_files_more_features["tabular2_csv"]
Expand Down Expand Up @@ -792,7 +796,6 @@ def test_kfold(create_test_files_more_features, tmp_path):
"multiclass_dimensions": None,
"kfold": True,
"num_folds": 3,
"wandb_logging": False,
}

data_paths = {
Expand All @@ -815,14 +818,14 @@ def test_kfold(create_test_files_more_features, tmp_path):

for model in fusion_models:

dm = prepare_fusion_data(fusion_model=model,
data_paths=data_paths,
output_paths=output_paths,
params=params,
layer_mods=layer_mods,
max_epochs=2,
**params
)
dm = prepare_fusion_data(
fusion_model=model,
data_paths=data_paths,
output_paths=output_paths,
layer_mods=layer_mods,
max_epochs=2,
**params,
)

single_model_list = train_and_save_models(
data_module=dm,
Expand All @@ -843,8 +846,9 @@ def test_kfold(create_test_files_more_features, tmp_path):
assert fig is not None

if single_model_list[0].model.fusion_type != "graph":
fig_new_data = ConfusionMatrix.from_new_data(single_model_list, output_paths, test_data_paths,
layer_mods=layer_mods)
fig_new_data = ConfusionMatrix.from_new_data(
single_model_list, output_paths, test_data_paths, layer_mods=layer_mods
)
assert fig_new_data is not None

plt.close("all")
Loading

0 comments on commit ea4db6c

Please sign in to comment.