Skip to content

Commit

Permalink
Update SPFinder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PSSUN committed Jan 27, 2025
1 parent 7bf7bfe commit 19a4f51
Showing 1 changed file with 51 additions and 35 deletions.
86 changes: 51 additions & 35 deletions STMiner/SPFinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from STMiner.Algorithm.algorithm import cluster
from STMiner.Algorithm.distance import *
from STMiner.Algorithm.distance import compare_gmm_distance
from STMiner.Algorithm.distribution import get_gmm
from STMiner.Algorithm.distribution import get_gmm, array_to_list
from STMiner.Algorithm.distribution import view_gmm, fit_gmms, get_gmm_from_image
from STMiner.IO.IOUtil import merge_bin_coordinate
from STMiner.IO.read_bmk import read_bmk
Expand All @@ -20,21 +20,6 @@


def scale_array(exp_matrix, total_count):
"""
Scale the expression matrix to a total count of 100.
This function calculates the total sum of the expression matrix,
determines a scale factor to achieve a total count of 100,
scales the expression matrix by this factor, and adds the scaled
values to the total count.
Parameters:
exp_matrix (numpy.ndarray): A 2D array containing expression values.
total_count (numpy.ndarray): A 1D array containing the total count.
Returns:
numpy.ndarray: The updated total count after scaling the expression matrix.
"""
total_sum = np.sum(exp_matrix)
scale_factor = 100 / total_sum
scaled_matrix = exp_matrix * scale_factor
Expand All @@ -54,6 +39,7 @@ def __init__(self, adata: Optional[AnnData] = None):
self.image_gmm = None
self.global_distance = None
self.all_labels = None
self.custom_pattern = None
self.csr_dict = {}
self.patterns_matrix_dict = {}
self.patterns_binary_matrix_dict = {}
Expand Down Expand Up @@ -321,31 +307,47 @@ def get_pattern_array(self, vote_rate: int = 0, mode: str = "vote"):
if mode == "vote":
label_list = set(self.genes_labels["labels"])
for label in label_list:
gene_list = list(
self.genes_labels[self.genes_labels["labels"] == label]["gene_id"]
)
total_count = np.zeros(get_exp_array(self.adata, gene_list[0]).shape)
total_coo_list = []
vote_array = np.zeros(get_exp_array(self.adata, gene_list[0]).shape)
for gene in gene_list:
exp_matrix = get_exp_array(self.adata, gene)
# calculate nonzero index
non_zero_coo_list = np.vstack((np.nonzero(exp_matrix))).T.tolist()
for coo in non_zero_coo_list:
total_coo_list.append(tuple(coo))
total_count = scale_array(exp_matrix, total_count)
count_dict = Counter(total_coo_list)
for ele, count in count_dict.items():
if int(count) / len(gene_list) >= vote_rate:
vote_array[ele] = 1
total_count = total_count * vote_array
binary_arr = np.where(total_count != 0, 1, total_count)
gene_list = list(self.genes_labels[self.genes_labels["labels"] == label]["gene_id"])
binary_arr, total_count = self._genes_to_pattern(gene_list, vote_rate)
self.patterns_matrix_dict[label] = total_count
self.patterns_binary_matrix_dict[label] = binary_arr
elif mode == "test":
p_value_threshold = 0.05
# TODO: rewrite test mode, improve run time.
pass
else:
raise ValueError("mode should be vote or test")

def _genes_to_pattern(self, gene_list, vote_rate):
    """
    Combine the expression arrays of several genes into one voted pattern.

    Every gene's expression array is fetched with ``get_exp_array`` and
    folded into a running total through ``scale_array``. A spot survives
    only if it is non-zero in at least one gene and the fraction of genes
    with a non-zero value at that spot is at least ``vote_rate``.

    Parameters:
    gene_list (list): Gene identifiers to combine. Must not be empty.
    vote_rate (float): Minimum fraction of genes that must be expressed
        (non-zero) at a spot for the spot to be kept.

    Returns:
    tuple: ``(binary_arr, total_count)`` where ``total_count`` is the
        accumulated array with rejected spots zeroed out and ``binary_arr``
        is the same array with every non-zero entry replaced by 1.

    Raises:
    ValueError: If ``gene_list`` is empty.
    """
    # len() rather than truthiness: gene_list may be an array-like whose
    # boolean value is ambiguous.
    n_genes = len(gene_list)
    if n_genes == 0:
        raise ValueError("gene_list should contain at least one gene")

    total_count = None
    nonzero_votes = None
    for gene in gene_list:
        exp_matrix = get_exp_array(self.adata, gene)
        if total_count is None:
            # Size the accumulators from the first gene instead of issuing
            # extra get_exp_array calls just to read the shape.
            total_count = np.zeros(exp_matrix.shape)
            nonzero_votes = np.zeros(exp_matrix.shape, dtype=np.int64)
        # One vote per gene for every spot where that gene is expressed.
        nonzero_votes += np.asarray(exp_matrix != 0)
        # NOTE(review): scale_array presumably rescales exp_matrix and adds it
        # to total_count; only the start of its body is visible here - confirm.
        total_count = scale_array(exp_matrix, total_count)

    # Spots with no votes are never kept, even when vote_rate is 0.
    kept = (nonzero_votes > 0) & (nonzero_votes / n_genes >= vote_rate)
    vote_array = np.where(kept, 1.0, 0.0)
    total_count = total_count * vote_array
    binary_arr = np.where(total_count != 0, 1, total_count)
    return binary_arr, total_count

def get_custom_pattern(self, gene_list, n_components=20, vote_rate: int = 0, mode: str = "vote"):
    """
    Fit a Gaussian mixture to the voted pattern of ``gene_list``.

    In "vote" mode the combined pattern is obtained from
    ``_genes_to_pattern``, rounded to 32-bit integers, converted with
    ``array_to_list`` and used to fit a scikit-learn ``GaussianMixture``.
    The fitted model is stored in ``self.custom_pattern``; nothing is
    returned. The "test" mode is not implemented yet and does nothing.

    Parameters:
    gene_list (list): Gene identifiers whose joint pattern is modelled.
    n_components (int): Number of mixture components.
    vote_rate (int): Voting threshold forwarded to ``_genes_to_pattern``.
    mode (str): Either "vote" or "test".

    Raises:
    ValueError: If ``mode`` is neither "vote" nor "test".
    """
    if mode == "test":
        p_value_threshold = 0.05
        # TODO: rewrite test mode, improve run time.
        return

    if mode != "vote":
        raise ValueError("mode should be vote or test")

    _, pattern_counts = self._genes_to_pattern(gene_list, vote_rate)

    # Imported lazily so scikit-learn is only required when a model is fitted.
    from sklearn import mixture

    mixture_model = mixture.GaussianMixture(n_components=n_components)
    # NOTE(review): array_to_list presumably expands the integer counts into
    # a list of sample points for fitting - confirm against its definition.
    rounded_counts = np.round(pattern_counts).astype(np.int32)
    mixture_model.fit(array_to_list(rounded_counts))
    self.custom_pattern = mixture_model
Expand Down Expand Up @@ -396,5 +398,19 @@ def get_all_labels(self):
]
self.all_labels = all_labels

def get_pattern_of_given_genes(self, gene_list, n_comp=20):
    """
    Build a custom spatial pattern from a user-supplied gene set.

    Genes that are absent from ``self.adata.var.index`` are dropped. The
    remaining genes are fitted with ``fit_pattern``, clustered into a single
    cluster with ``cluster_gene``, and finally passed to
    ``get_custom_pattern``, which stores its result on the instance.

    Parameters:
    gene_list (list): Gene identifiers of interest.
    n_comp (int): Number of mixture components, used both for
        ``fit_pattern`` and for ``get_custom_pattern``.

    Raises:
    ValueError: If no ST data is loaded, or if none of the given genes is
        present in the loaded data.
    """
    if self.adata is None:
        raise ValueError("Please load ST data first.")

    # Build the lookup set once instead of materialising the index as a
    # list for every gene.
    known_genes = set(self.adata.var.index)
    _genes = [gene for gene in gene_list if gene in known_genes]
    if not _genes:
        raise ValueError("None of the given genes were found in the ST data.")

    # Get expression patterns of interested gene set
    self.fit_pattern(n_comp=n_comp, gene_list=_genes)
    self.cluster_gene(n_clusters=1, mds_components=2)
    # NOTE(review): this replaces the dict created in __init__ with None, so a
    # later get_pattern_array call would fail on item assignment - confirm
    # that the reset is intended (an empty dict may be what is wanted).
    self.patterns_matrix_dict = None
    # Use the filtered genes so identifiers missing from the data never reach
    # the expression lookup.
    self.get_custom_pattern(gene_list=_genes, n_components=n_comp, vote_rate=0)

# def flush_app(self):
# self.app = App()

0 comments on commit 19a4f51

Please sign in to comment.