-
Notifications
You must be signed in to change notification settings - Fork 9
/
predict.py
68 lines (68 loc) · 2.26 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np
import scanpy as sc
import anndata as ad
from tqdm import tqdm
from dataset import ViT_HER2ST, ViT_SKIN
from scipy.stats import pearsonr,spearmanr
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
def pk_load(fold, mode='train', flatten=False, dataset='her2st', r=4, ori=True, adj=True, prune='Grid', neighs=4):
    """Build the dataset object for one fold of the chosen dataset.

    ``'her2st'`` selects ``ViT_HER2ST`` and ``'cscc'`` selects ``ViT_SKIN``.
    All other options are forwarded unchanged to that class's constructor.

    Args:
        fold: fold index, passed as ``fold=``.
        mode: ``'train'`` builds the training split (``train=True``); any
            other value builds the non-training split (``train=False``).
        flatten, r, ori, adj, prune, neighs: forwarded verbatim to the
            dataset constructor.
        dataset: ``'her2st'`` or ``'cscc'``.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: if ``dataset`` is not ``'her2st'`` or ``'cscc'``.
    """
    # An explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which made an invalid name silently return the string.
    if dataset not in ('her2st', 'cscc'):
        raise ValueError(f"dataset must be 'her2st' or 'cscc', got {dataset!r}")
    dataset_cls = ViT_HER2ST if dataset == 'her2st' else ViT_SKIN
    return dataset_cls(
        train=(mode == 'train'), fold=fold, flatten=flatten,
        ori=ori, neighs=neighs, adj=adj, prune=prune, r=r,
    )
def test(model, test, device='cuda'):
    """Run ``model`` over every batch of ``test``; return predictions and ground truth as AnnData.

    Args:
        model: torch module called as ``model(patch, position, adj)``; the
            first element of its output is used as the prediction.
        test: iterable (e.g. a ``DataLoader``) whose items unpack as
            ``(patch, position, exp, adj, ..., center)``.
        device: device that the model and the ``patch``/``position``/``adj``
            tensors are moved to.

    Returns:
        ``(adata, adata_gt)``: AnnData objects holding the predicted matrix
        and the ground-truth ``exp`` matrix, both with ``center`` stored in
        ``obsm['spatial']``. With a single batch the arrays are exactly the
        squeezed batch outputs. With several batches they are joined along
        the first axis, so all spots end up in one object.

    Raises:
        ValueError: if ``test`` yields no batches.
    """

    def _stack(chunks):
        # A single batch is returned as-is; several are joined along axis 0.
        return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)

    model = model.to(device)
    model.eval()
    preds, cts, gts = [], [], []
    with torch.no_grad():
        for patch, position, exp, adj, *_, center in tqdm(test):
            patch, position = patch.to(device), position.to(device)
            # squeeze(0) drops a leading size-1 dimension; presumably the
            # loader uses batch_size=1 -- TODO confirm.
            adj = adj.to(device).squeeze(0)
            pred = model(patch, position, adj)[0]
            # Accumulate every batch instead of keeping only the last one.
            preds.append(pred.squeeze().cpu().numpy())
            cts.append(center.squeeze().cpu().numpy())
            gts.append(exp.squeeze().cpu().numpy())
    if not preds:
        raise ValueError("test loader yielded no batches")
    ct = _stack(cts)
    adata = ad.AnnData(_stack(preds))
    adata.obsm['spatial'] = ct
    adata_gt = ad.AnnData(_stack(gts))
    adata_gt.obsm['spatial'] = ct
    return adata, adata_gt
def cluster(adata, label):
    """K-means-cluster the labelled spots of ``adata`` and score the result with ARI.

    Spots whose label equals ``'undetermined'`` are excluded. The remaining
    spots are projected with ``sc.pp.pca``. They are then clustered by k-means
    (k-means++ init, ``random_state=0``) into as many clusters as there are
    distinct remaining labels.

    Args:
        adata: AnnData with one row per spot. Modified in place: an
            ``obs['kmeans']`` column of string cluster ids is added.
        label: per-spot labels aligned with the rows of ``adata``; any
            array-like (including a plain list) is accepted.

    Returns:
        ``(p, ari)``: the k-means cluster ids (as strings) of the labelled
        spots, and the adjusted Rand index between ``p`` and those spots'
        labels, rounded to 3 decimals.

    Raises:
        ValueError: if every spot is labelled ``'undetermined'``.
    """
    # np.asarray makes the comparison element-wise even for plain lists
    # (``list != str`` is a single ``True``) and yields a plain boolean mask.
    label = np.asarray(label)
    idx = label != 'undetermined'
    l = label[idx]
    n_clusters = len(set(l))
    if n_clusters == 0:
        raise ValueError("no spots with a label other than 'undetermined'")
    # PCA stores its result on the object it is given; an explicit copy
    # avoids writing into a view of ``adata``.
    tmp = adata[idx].copy()
    sc.pp.pca(tmp)
    kmeans = KMeans(n_clusters=n_clusters, init="k-means++", random_state=0).fit(tmp.obsm['X_pca'])
    p = kmeans.labels_.astype(str)
    # Excluded spots get the id str(n_clusters), which k-means never emits
    # (its ids run 0 .. n_clusters-1). The array's string dtype is sized by
    # this value, and no real id is longer than it.
    lbl = np.full(len(adata), str(n_clusters))
    lbl[idx] = p
    adata.obs['kmeans'] = lbl
    return p, round(ari_score(p, l), 3)
def get_R(data1, data2, dim=1, func=pearsonr):
    """Apply ``func`` to matching slices of two AnnData-like objects.

    For ``dim=1`` the i-th columns of ``data1.X`` and ``data2.X`` are paired
    (one call per column). For ``dim=0`` the i-th rows are paired. The number
    of calls is ``data1.shape[dim]``. Each call must return exactly two values,
    e.g. ``(statistic, p-value)`` as ``scipy.stats.pearsonr`` does.

    Args:
        data1, data2: objects exposing ``.X`` (indexable as ``X[:, i]`` and
            ``X[i, :]``) and ``.shape``.
        dim: 1 to compare column by column, 0 to compare row by row; other
            values are not supported.
        func: callable ``func(x, y)`` returning two values.

    Returns:
        Two NumPy arrays holding, per slice, the first and the second value
        returned by ``func``.
    """
    mat_a = data1.X
    mat_b = data2.X
    results = []
    for i in range(data1.shape[dim]):
        if dim == 0:
            x, y = mat_a[i, :], mat_b[i, :]
        elif dim == 1:
            x, y = mat_a[:, i], mat_b[:, i]
        stat, pvalue = func(x, y)
        results.append((stat, pvalue))
    stats = np.array([s for s, _ in results])
    pvalues = np.array([pv for _, pv in results])
    return stats, pvalues