Skip to content

Commit

Permalink
feat(l2gprediction): add model as attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
ireneisdoomed committed Jan 28, 2025
1 parent eedc6ab commit 624602e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/gentropy/dataset/l2g_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
Expand All @@ -29,6 +29,8 @@ class L2GPrediction(Dataset):
confidence of the prediction that a gene is causal to an association.
"""

model: LocusToGeneModel | None = field(default=None, repr=False)

@classmethod
def get_schema(cls: type[L2GPrediction]) -> StructType:
"""Provides the schema for the L2GPrediction dataset.
Expand Down Expand Up @@ -82,7 +84,6 @@ def from_credible_set(
.fill_na()
.select_features(l2g_model.features_list)
)

return l2g_model.predict(fm, session)

def to_disease_target_evidence(
Expand Down Expand Up @@ -172,4 +173,5 @@ def add_locus_to_gene_features(
aggregated_features, on=["studyLocusId", "geneId"], how="left"
),
_schema=self.get_schema(),
model=self.model,
)
1 change: 1 addition & 0 deletions src/gentropy/method/l2g/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def predict(
return L2GPrediction(
_df=session.spark.createDataFrame(feature_matrix_pdf.filter(output_cols)),
_schema=L2GPrediction.get_schema(),
model=self,
)

def save(self: LocusToGeneModel, path: str) -> None:
Expand Down

0 comments on commit 624602e

Please sign in to comment.