Added group norm L0 and shifted group norm L0 #117
AHsu98 wants to merge 9 commits into JuliaSmoothOptimizers:master
Thank you @AHsu98, could you please rebase and update this PR?
+ to .+ and some minor syntax changes Co-authored-by: Dominique <dominique.orban@gmail.com>
…same changes that @dpo suggested on the checking for this groupNormL0 to make groupNormL2 match as well
…in groupNormL0 and groupNormL2 as it was causing tests to error with the way I was checking, and added the groupNormL0 and shiftedGroupNormL0 to the main file. Added groupNormL0 to the tests, and it's not erroring, but haven't added the correctness check yet.
Alright, I rebased, made some updates (fixed some original mistakes), and incorporated the previous suggestions. It'd be nice to add a check that the indices are non-overlapping. The naive check I tried worked for the basic case of a vector of vectors of integer indices, but caused errors in the tests, maybe from shifts or multiple indices. I also made the suggested changes to the checks on groupNormL2.
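A minimal sketch of such a non-overlap check, for illustration only. The helper name `groups_are_disjoint` is hypothetical and is not part of this PR. It assumes each group is an iterable of integer indices (a vector or a range), so it does not cover the shifted or multi-index cases that reportedly broke the tests.

```julia
# Hypothetical helper, not part of the PR: returns true when no index
# appears more than once across all groups.
# Assumption: every group is an iterable of integers (Vector{Int}, UnitRange, ...).
function groups_are_disjoint(idx)
  seen = Set{Int}()
  for group in idx
    for i in group
      # An index already recorded means two groups overlap
      # (or one group lists the same index twice).
      i in seen && return false
      push!(seen, i)
    end
  end
  return true
end

groups_are_disjoint([[1, 2], [3, 4]])  # true
groups_are_disjoint([1:2, 2:3])        # false, index 2 is shared
```

Under that assumption, the check could sit next to the existing weight and length checks in the inner constructor, at a cost linear in the total number of indices.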
    ysum = R(0)
    for (idx, λ) ∈ zip(f.idx, f.lambda)
      yt = norm(x[idx])^2
      if yt !=0
Suggested change:
-     if yt !=0
+     if yt > 0
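For context, the loop under review can be completed into a standalone sketch. The function name `group_l0_value` and the `ysum += λ` accumulation are assumptions inferred from the PR description, not code shown in this diff. Because `yt` is a squared norm it is never negative, so `yt != 0` and `yt > 0` agree on ordinary inputs; they differ only when `yt` is `NaN`, where `NaN != 0` is `true` and `NaN > 0` is `false`.

```julia
using LinearAlgebra  # provides norm

# Hypothetical standalone version of the evaluation loop, not the PR's method.
# Assumption: each group with a nonzero L2 norm contributes its weight λ.
function group_l0_value(x::AbstractVector{R}, lambda, idx) where {R <: Real}
  ysum = R(0)
  for (g, λ) in zip(idx, lambda)
    yt = norm(x[g])^2   # squared L2 norm of the group, always >= 0
    if yt > 0
      ysum += λ
    end
  end
  return ysum
end

x = [0.0, 0.0, 3.0, 4.0]
group_l0_value(x, [1.0, 2.0], [1:2, 3:4])  # 2.0, only the second group is nonzero
```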
    struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I}
      lambda::RR
      idx::I

      function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I}
        any(lambda .< 0) && error("weights λ must be nonnegative")
        length(lambda) != length(idx) && error("number of weights and groups must be the same")
        new{R, RR, I}(lambda, idx)
      end
    end
Suggested change:
- struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I}
-   lambda::RR
-   idx::I
-   function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I}
-     any(lambda .< 0) && error("weights λ must be nonnegative")
-     length(lambda) != length(idx) && error("number of weights and groups must be the same")
-     new{R, RR, I}(lambda, idx)
-   end
- end
+ struct GroupNormL0{R <: Real, V <: AbstractVector{R}, I}
+   lambda::V
+   idx::I
+   function GroupNormL0{R, V, I}(lambda::V, idx::I) where {R <: Real, V <: AbstractVector{R}, I}
+     any(lambda .< 0) && error("weights λ must be nonnegative")
+     length(lambda) != length(idx) && error("number of weights and groups must be the same")
+     new{R, V, I}(lambda, idx)
+   end
+ end
    function prox!(
      y::AbstractArray{R},
      f::GroupNormL0{R, RR, I},
      x::AbstractArray{R},
      γ::R = R(1),
    ) where {R <: Real, RR <: AbstractVector{R}, I}
Suggested change:
- function prox!(
-   y::AbstractArray{R},
-   f::GroupNormL0{R, RR, I},
-   x::AbstractArray{R},
-   γ::R = R(1),
- ) where {R <: Real, RR <: AbstractVector{R}, I}
+ function prox!(
+   y::AbstractArray{R},
+   f::GroupNormL0{R, V, I},
+   x::AbstractArray{R},
+   γ::R = R(1),
+ ) where {R <: Real, V <: AbstractVector{R}, I}
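The body of `prox!` is not shown in this thread. As a rough illustration of what a group L0 proximal step computes, here is a sketch derived from the definition in the PR description, assuming non-overlapping groups. The out-of-place helper `group_l0_prox` is hypothetical and is not the PR's implementation. For each group, keeping the block costs `γ * λ` while zeroing it costs half its squared norm, so the block is kept only when its squared norm exceeds `2γλ`.

```julia
using LinearAlgebra  # provides norm

# Hypothetical out-of-place sketch, not the PR's prox! method.
# Assumptions: groups do not overlap; entries outside every group are unpenalized.
function group_l0_prox(x::AbstractVector{R}, lambda, idx, γ::R = R(1)) where {R <: Real}
  y = copy(x)  # unpenalized entries pass through unchanged
  for (g, λ) in zip(idx, lambda)
    # Zero the block when that is no more expensive than keeping it.
    if norm(x[g])^2 <= 2 * γ * λ
      y[g] .= 0
    end
  end
  return y
end

x = [0.5, 0.5, 3.0, 4.0]
group_l0_prox(x, [1.0, 1.0], [1:2, 3:4])  # [0.0, 0.0, 3.0, 4.0]
```

At an exact tie both choices are minimizers; this sketch picks zero.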
    @@ -17,13 +17,9 @@ struct GroupNormL2{R <: Real, RR <: AbstractVector{R}, I}
      idx::I

      function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I}
Suggested change:
-   function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I}
+   function GroupNormL2{R, V, I}(lambda::V, idx::I) where {R <: Real, V <: AbstractVector{R}, I}
        end
        any(lambda .< 0) && error("weights λ must be nonnegative")
        length(lambda) != length(idx) && error("number of weights and groups must be the same")
        new{R, RR, I}(lambda, idx)
Suggested change:
-     new{R, RR, I}(lambda, idx)
+     new{R, V, I}(lambda, idx)
There may be others in this file. It’s just easier to remember that V stands for “vector”.
    mutable struct ShiftedGroupNormL0{
      R <: Real,
      RR <: AbstractVector{R},
First commit towards adding a group L0 norm. Happy to change the name; perhaps L02 is better. It is defined as the L0 norm of the vector of L2 norms of the groups of entries selected by the indices, weighted by the values of lambda. It still needs a bit of testing, I think, but it's pretty simple: only a couple of lines change from GroupL2.
Also, in both this and GroupL2, we should call out that we require it to be a separable sum (no indices overlap).
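For reference, one way to write the definition described here is shown below. The notation is an assumption made for illustration: $m$ is the number of groups, $x_{[i]}$ is the subvector of $x$ selected by the $i$-th index set, and $\lambda_i \ge 0$ is its weight.

$$
f(x) = \sum_{i=1}^{m} \lambda_i \, \mathbb{1}\!\left\{ \|x_{[i]}\|_2 \neq 0 \right\}
$$

When the index sets do not overlap, each term depends on its own block of $x$ only, which is the separability requirement mentioned above.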