Skip to content

Commit

Permalink
Merge pull request #86 from cesmix-mit/sw/quick_wls_fix
Browse files Browse the repository at this point in the history
Quick fix to use pinv only when necessary for WLS
  • Loading branch information
swyant authored Sep 10, 2024
2 parents 83b17e9 + 0cfc782 commit acb6d22
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PotentialLearning"
uuid = "82b0a93c-c2e3-44bc-a418-f0f89b0ae5c2"
authors = ["CESMIX Team"]
version = "0.2.5"
version = "0.2.6"

[deps]
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
Expand Down
20 changes: 18 additions & 2 deletions src/Learning/linear-learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,15 @@ function learn!(

# Calculate coefficients β.
Q = Diagonal(ws[1] * ones(length(e_train)))
βs = pinv(A'*Q*A)*(A'*Q*b)

βs = Vector{Float64}()
try
βs = (A'*Q*A) \ (A'*Q*b)
catch e
println(e)
println("Linear system will be solved using pinv.")
βs = pinv(A'*Q*A)*(A'*Q*b)
end

# Update lp.
if int
Expand Down Expand Up @@ -238,7 +246,15 @@ function learn!(
# Calculate coefficients βs.
Q = Diagonal([ws[1] * ones(length(e_train));
ws[2] * ones(length(f_train))])
βs = pinv(A'*Q*A)*(A'*Q*b)

βs = Vector{Float64}()
try
βs = (A'*Q*A) \ (A'*Q*b)
catch e
println(e)
println("Linear system will be solved using pinv.")
βs = pinv(A'*Q*A)*(A'*Q*b)
end

# Update lp.
if int
Expand Down

0 comments on commit acb6d22

Please sign in to comment.