Skip to content

Commit

Permalink
feat(L2GFeatureMatrix): track missingness rate for each feature
Browse files Browse the repository at this point in the history
  • Loading branch information
ireneisdoomed committed Dec 13, 2023
1 parent be480cd commit e69c47e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
20 changes: 20 additions & 0 deletions src/otg/dataset/l2g_feature_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ def get_schema(cls: type[L2GFeatureMatrix]) -> StructType:
"""
return parse_spark_schema("l2g_feature_matrix.json")

def calculate_feature_missingness_rate(
self: L2GFeatureMatrix,
) -> dict[str, float]:
"""Calculate the proportion of missing values in each feature.
Returns:
dict[str, float]: Dictionary of feature names and their missingness rate.
Raises:
ValueError: If no features are found.
"""
total_count = self._df.count()
if not self.features_list:
raise ValueError("No features found")

return {
feature: (self._df.filter(self._df[feature].isNull()).count() / total_count)
for feature in self.features_list
}

def fill_na(
self: L2GFeatureMatrix, value: float = 0.0, subset: list[str] | None = None
) -> L2GFeatureMatrix:
Expand Down
11 changes: 8 additions & 3 deletions src/otg/method/l2g/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,19 @@ def log_to_wandb(
wandb_evaluator.evaluate(results)
## Track feature importance
wandb_run.log({"importances": self.get_feature_importance()})
## Track training set metadata
## Track training set
training_table = wandb.Table(dataframe=training_data.df.toPandas())
wandb_run.log({"trainingSet": training_table})
# Count number of positive and negative labels
gs_counts_dict = {
"goldStandard" + row["goldStandardSet"].capitalize(): row["count"]
for row in training_data.df.groupBy("goldStandardSet").count().collect()
}
wandb_run.log(gs_counts_dict)
training_table = wandb.Table(dataframe=training_data.df.toPandas())
wandb_run.log({"trainingSet": training_table})
# Missingness rates
wandb_run.log(
"missingnessRates", training_data.calculate_feature_missingness_rate()
)

@classmethod
def load_from_disk(
Expand Down

0 comments on commit e69c47e

Please sign in to comment.