Skip to content

Commit 73be1f2

Browse files
authored
Merge pull request #9 from BiomedSciAI/julia-cumulants
Julia cumulants
2 parents cb4a278 + 4c59d22 commit 73be1f2

File tree

6 files changed

+426
-126
lines changed

6 files changed

+426
-126
lines changed

geno4sd/topology/CuNA/cumulants.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Script metadata, kept as plain global string assignments.
__author__ = "Myson Burch"
__copyright__ = "Copyright 2023, IBM Research"
__version__ = "0.1"
__maintainer__ = "Myson Burch"
# NOTE(review): the address below looks redacted/obfuscated in this copy — confirm
# against the repository before relying on it.
__email__ = "[email protected]"
__status__ = "Development"
7+
8+
import Base.Threads.@threads
9+
10+
using Cumulants, NPZ, LinearAlgebra, Random, Statistics
11+
12+
"""
    upper_tri(A)

Return the strictly upper-triangular entries of `A` (column index greater than row
index) as a `Vector{Float64}`, ordered row by row.

The number of rows `n = size(A, 1)` determines which entries are read: every `A[i, j]`
with `1 <= i < j <= n`. The result therefore has `n * (n - 1) ÷ 2` elements and is empty
when `n <= 1`.
"""
function upper_tri(A)
    n = size(A, 1)
    # Row-by-row walk over the strict upper triangle: (1,2), (1,3), ..., (2,3), ...
    return Float64[A[i, j] for i in 1:n for j in (i + 1):n]
end
35+
36+
"""
    parse_third_order(x, n)

Flatten the entries `x[i, j, k]` with `i < j < k` of the three-dimensional array `x`
into a `Vector{Float64}`.

For each `i` in `1:n`, the sheet `x[i, :, :]` is restricted to rows and columns after
`i`, its strict upper triangle is collected with `upper_tri`, and any `NaN` is replaced
by `0.0`. The per-sheet results are concatenated in order of increasing `i`.
"""
function parse_third_order(x, n)
    third = Float64[]
    for i in 1:n
        # A view avoids copying the sheet twice; `upper_tri` only reads from it.
        sheet = @view x[i, (i + 1):end, (i + 1):end]
        entries = upper_tri(sheet)
        replace!(v -> isnan(v) ? 0.0 : v, entries)
        append!(third, entries)
    end
    return third
end
51+
52+
"""
    parse_fourth_order(x, n)

Flatten the entries `x[j, i, k, l]` with `j < i < k < l` of the four-dimensional array
`x` into a `Vector{Float64}`.

For each pair `j < i` (outer loop over `j`, inner loop over `i`, both bounded by `n`),
the sheet `x[j, i, :, :]` is restricted to rows and columns after `i`, its strict upper
triangle is collected with `upper_tri`, and any `NaN` is replaced by `0.0`. The
per-sheet results are concatenated in loop order.
"""
function parse_fourth_order(x, n)
    fourth = Float64[]
    for j in 1:n
        # Only pairs with i > j contribute, so start the inner range at j + 1.
        for i in (j + 1):n
            # A view avoids copying the sheet twice; `upper_tri` only reads from it.
            sheet = @view x[j, i, (i + 1):end, (i + 1):end]
            entries = upper_tri(sheet)
            replace!(v -> isnan(v) ? 0.0 : v, entries)
            append!(fourth, entries)
        end
    end
    return fourth
end
73+
74+
"""
    permute_dat(x)

Return a new `Matrix{Float64}` with the same size as `x` in which every column holds
the values of the corresponding column of `x` in an independently shuffled order.

`x` itself is not modified. Shuffling uses the default random number generator.
"""
function permute_dat(x)
    num_rows, num_cols = size(x)

    # Every column is overwritten below, so the buffer needs no initial contents.
    y = Matrix{Float64}(undef, num_rows, num_cols)

    for col in 1:num_cols
        # `shuffle` copies before permuting, so the view leaves `x` untouched.
        y[:, col] = shuffle(@view x[:, col])
    end

    return y
end
92+
93+
# Normalize the requested cumulant order to an `Int` (integers or decimal strings).
_cumulant_order(order::Integer) = Int(order)
_cumulant_order(order::AbstractString) = parse(Int, order)

"""
    run_cumulants(x, order; block=4)

Compute the cumulants of `x` up to `order` and return them flattened into a single
`Vector{Float64}`.

# Arguments
- `x`: data matrix passed unchanged to `cumulants`.
- `order`: highest cumulant order, as an integer or a string holding one (e.g. `"4"`).

# Keywords
- `block`: block size forwarded as the third positional argument of `cumulants`.

# Returns
The concatenation, in this order, of:
1. all first-order entries;
2. the strict upper triangle of the second-order array (when `order >= 2`);
3. the third-order entries selected by `parse_third_order` (when `order >= 3`);
4. the fourth-order entries selected by `parse_fourth_order` (when `order >= 4`).

`NaN` entries are replaced by `0.0`. Orders above four contribute nothing beyond the
fourth-order block.
"""
function run_cumulants(x, order; block::Integer=4)
    k = _cumulant_order(order)

    # cumulants(data::Matrix{T}, order::Int = 4, block::Int = 2)
    c = cumulants(x, k, Int(block))

    res = Float64[]

    means = Array(c[1])
    means[isnan.(means)] .= 0.0
    append!(res, means)
    nvars = size(means, 1)

    # Each higher order is only read when it was actually requested, so low orders
    # no longer index past the end of `c`.
    if k >= 2
        pairwise = upper_tri(Array(c[2]))
        pairwise[isnan.(pairwise)] .= 0.0
        append!(res, pairwise)
    end

    if k >= 3
        append!(res, parse_third_order(Array(c[3]), nvars))
    end

    # Compare the parsed integer rather than the raw string so that inputs such as
    # "04" or an integer 4 are handled the same way as "4".
    if k >= 4
        append!(res, parse_fourth_order(Array(c[4]), nvars))
    end

    return res
end
123+
124+
## LOAD IN DATA
# ARGS[1] is used as a raw path prefix (plain string concatenation, so it must already
# end with a path separator when it names a directory); ARGS[2] is the cumulant order.
length(ARGS) >= 2 ||
    throw(ArgumentError("usage: julia cumulants.jl <path-prefix> <order>"))

x = npzread(ARGS[1]*"julia_dat.npy")

order = ARGS[2]

## RUN CUMULANTS
res = run_cumulants(x, order)

## RUN PERMUTATIONS
n = size(res)[1]; m = 50;
# Every column is written by the loop below, so no initial contents are needed.
dat_perms = Matrix{Float64}(undef, n, m);

# Each iteration writes only its own column of `dat_perms`.
@threads for i in 1:m
    y = permute_dat(x)
    dat_perms[:,i] = run_cumulants(y, order)
end

# Per-row mean and standard deviation across the m permutations (reduce over dims=2)
mean_by_rows = mean(dat_perms, dims=2)
std_by_rows = std(dat_perms, dims=2)
# Treat numerically negligible spread as exactly zero
std_by_rows[std_by_rows .< 1e-12] .= 0
z = res .- mean_by_rows
# Rows with zero spread get a z-score of 0 instead of Inf/NaN from dividing by zero
z = ifelse.(std_by_rows .== 0.0, 0.0, z ./ std_by_rows)
z = ifelse.(isinf.(z), 0.0, z)

## WRITE TO OUTPUT
npzwrite(ARGS[1]*"julia_cumulants.npz", res, mean_by_rows, std_by_rows, z)

0 commit comments

Comments
 (0)