Skip to content

Commit

Permalink
Merge branch 'svector-quaternion' into 'main'
Browse files Browse the repository at this point in the history
Use SVector to represent Quaternion coeffs

See merge request acubesat/adcs/adcs-simulation-julia!8
  • Loading branch information
xlxs4 committed Sep 25, 2023
2 parents 6b48ba0 + 1a21c9e commit 254e031
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 76 deletions.
14 changes: 10 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0-beta2"
manifest_format = "2.0"
project_hash = "4903dd5d84faaf8c3eae7adbcef17ba971308241"
project_hash = "07e8af666232824fc87e7f7155a52487d33dbbf4"

[[deps.Accessors]]
deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
Expand Down Expand Up @@ -246,6 +246,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"

[[deps.EpollShim_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643"
uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43"
version = "0.0.20230411+0"

[[deps.ExceptionUnwrapping]]
deps = ["Test"]
git-tree-sha1 = "e90caa41f5a86296e014e148ee061bd6c3edec96"
Expand Down Expand Up @@ -1001,9 +1007,9 @@ weakdeps = ["ChainRulesCore"]

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
git-tree-sha1 = "51621cca8651d9e334a659443a74ce50a3b6dfab"
git-tree-sha1 = "d5fb407ec3179063214bc6277712928ba78459e2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.6.3"
version = "1.6.4"
weakdeps = ["Statistics"]

[deps.StaticArrays.extensions]
Expand Down Expand Up @@ -1118,7 +1124,7 @@ uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d"
version = "0.2.0"

[[deps.Wayland_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
git-tree-sha1 = "ed8d92d9774b077c53e1da50fd81a36af3744c1c"
uuid = "a2964d1f-97da-50d4-b82a-358c7fce9d89"
version = "1.21.0+0"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SatelliteDynamics = "0e7c1a32-1b9f-5532-88a4-e668712d6a4c"
SatelliteToolbox = "6ac157d9-b43d-51bb-8fab-48bf53814f4a"
SatelliteToolboxGeomagneticField = "9fc549ba-b5d7-49a2-b268-8171e5fb6e89"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Expand Down
5 changes: 4 additions & 1 deletion src/ADCSSims.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module ADCSSims

using ConcreteStructs

using StaticArrays

using SatelliteDynamics
using SatelliteToolboxGeomagneticField

Expand All @@ -20,7 +23,7 @@ include("mekf.jl")
include("simulation.jl")
include("plots.jl")

export Quaternion, scalar, vector, rotvec, toeuler
export Quaternion, QuaternionF16, QuaternionF32, QuaternionF64, rotvec
export ReactionWheel, stribeck, deadzone_compensation, saturation_compensation
export PDController, calculate_torque
export mse, qloss
Expand Down
2 changes: 1 addition & 1 deletion src/control.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ end
function calculate_torque(PD::PDController, qtarget, qestimated, w, wtarget)
qrel = qtarget * conj(qestimated) # TODO: should it be conj(qtarget)?
werr = w - wtarget
return -sign(scalar(qrel)) * PD.Kp * vector(qrel) - PD.Kd * werr
return -sign(real(qrel)) * PD.Kp * vec(qrel) - PD.Kd * werr
end
2 changes: 1 addition & 1 deletion src/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ mse(ŷ, y; agg=mean) = agg(abs2.(ŷ .- y))

function qloss(q̂, q)
relq = q * conj(q̂)
return norm(vector(relq))
return norm(vec(relq))
end
178 changes: 109 additions & 69 deletions src/quaternion.jl
Original file line number Diff line number Diff line change
@@ -1,91 +1,131 @@
struct Quaternion{T}
q1::T
q2::T
q3::T
q4::T
# Inspired by Michael Boyle and Base.Complex

# Signify standard arithmetic operations should be implemented on it
struct Quaternion{T<:Real} <: Number
coeffs::SVector{4,T}
Quaternion{T}(v::SVector{4,T}) where {T<:Real} = new{T}(v)
# Always use SVector internally
Quaternion{T}(v::A) where {T<:Real, A<:AbstractVector} = new{T}(SVector{4,T}(v))
end

function Base.show(io::IO, Q::Quaternion)
print(io, "q1: ")
show(io, Q.q1)
print(io, ", q2: ")
show(io, Q.q2)
print(io, ", q3: ")
show(io, Q.q3)
print(io, ", q4: ")
show(io, Q.q4)
# Construct without having to specify T
Quaternion(v::SVector{4,T}) where {T<:Real} = Quaternion{T}(v)
Quaternion(v::AbstractVector{T}) where {T<:Real} = Quaternion{T}(v)

# Explicitly state type, use SVector promote_rules for mixed types
Quaternion{T}(w, x, y, z) where {T<:Real} = Quaternion(SVector{4,T}(w, x, y, z))
Quaternion{T}(x, y, z) where {T<:Real} = Quaternion(SVector{4,T}(zero(x), x, y, z))
Quaternion{T}(w::Real) where {T<:Real} = Quaternion(SVector{4,T}(w, zero(w), zero(w), zero(w)))
# Rely on type inference, use SVector promote_rules for mixed types
function Quaternion(w, x, y, z)
v = SVector{4}(w, x, y, z)
return Quaternion{eltype(v)}(v)
end

Quaternion(xs::Vector) = Quaternion(xs...)
function Quaternion(q1, q2, q3, q4)
promoted_type = promote_type(typeof(q1), typeof(q2), typeof(q3), typeof(q4))
return Quaternion{promoted_type}(promote(q1, q2, q3, q4)...)
function Quaternion(x, y, z)
v = SVector{4}(zero(x), x, y, z)
return Quaternion{eltype(v)}(v)
end

function Base.convert(::Type{Quaternion{T}}, x::T) where {T}
return Quaternion(x, x, x, x)
function Quaternion(w::Real)
v = SVector{4}(w, zero(w), zero(w), zero(w))
return Quaternion{eltype(v)}(v)
end

Base.convert(::Type{Quaternion}, Q::Quaternion) = Q
function Base.convert(::Type{T}, Q::Quaternion) where {T}
return Quaternion{T}(Q.q1, Q.q2, Q.q3, Q.q4)
# Type-preserving copy constructor
Quaternion{T}(Q::Quaternion{T}) where {T<:Real} = Quaternion(Q.coeffs...)
# Type conversion copy constructor
Quaternion{T}(Q::Quaternion{S}) where {T<:Real, S<:Real} = Quaternion(SVector{4, T}(Q.coeffs...))

const QuaternionF64 = Quaternion{Float64}
const QuaternionF32 = Quaternion{Float32}
const QuaternionF16 = Quaternion{Float16}

Base.zero(::Type{Quaternion{T}}) where {T<:Real} = Quaternion{T}(zero(T), zero(T), zero(T), zero(T))
Base.zero(Q::Quaternion{T}) where {T<:Real} = Base.zero(typeof(Q))

Base.one(::Type{Quaternion{T}}) where {T<:Real} = Quaternion{T}(one(T), zero(T), zero(T), zero(T))
Base.one(Q::Quaternion{T}) where {T<:Real} = Base.one(typeof(Q))

# Given a numeric type, return a Quaternion (not instance) specialized on T
(::Type{Quaternion})(::Type{T}) where {T<:Real} = Quaternion{T}
# Given a Quaternion specialized on T, return another Quaternion also specialized on T
(::Type{Quaternion})(::Type{Quaternion{T}}) where {T<:Real} = Quaternion{T}

Base.eltype(::Type{Quaternion{T}}) where {T<:Real} = T
Base.promote_rule(::Type{Quaternion{T}}, ::Type{S}) where {T<:Real, S<:Real} = Quaternion{promote_type(T, S)}
Base.promote_rule(::Type{Quaternion{T}}, ::Type{Quaternion{S}}) where {T<:Real, S<:Real} = Quaternion{promote_type(T, S)}

@inline function Base.getindex(Q::Quaternion, i::Integer)
@boundscheck checkbounds(Q.coeffs, i)
@inbounds return Q.coeffs[i]
end

Quaternion(x::T) where {T} = convert(Quaternion{T}, x)
Base.convert(::Type{Quaternion{T}}, x::Quaternion{T}) where {T} = x
Base.eltype(::Type{Quaternion{T}}) where {T} = T
Base.zero(Q::Quaternion{T}) where {T} = Quaternion(zero(eltype(Q)))
# Let SVector handle the underlying indexing
# TODO: Should this return a view?
Base.@propagate_inbounds Base.getindex(Q::Quaternion, I) = @view Q.coeffs[I]

Base.length(::Quaternion) = 4
function Base.getindex(Q::Quaternion, i::Int)
1 <= i <= 4 || throw(BoundsError(Q, i))
i == 1 && return Q.q1
i == 2 && return Q.q2
i == 3 && return Q.q3
i == 4 && return Q.q4
end

function Base.:*(Q1::Quaternion, Q2::Quaternion)
return Quaternion(Q1.q1 * Q2.q1 - Q1.q2 * Q2.q2 - Q1.q3 * Q2.q3 - Q1.q4 * Q2.q4,
Q1.q1 * Q2.q2 + Q1.q2 * Q2.q1 + Q1.q3 * Q2.q4 - Q1.q4 * Q2.q3,
Q1.q1 * Q2.q3 - Q1.q2 * Q2.q4 + Q1.q3 * Q2.q1 + Q1.q4 * Q2.q2,
Q1.q1 * Q2.q4 + Q1.q2 * Q2.q3 - Q1.q3 * Q2.q2 + Q1.q4 * Q2.q1)
end
Base.real(::Type{Quaternion}) = eltype(Quaternion)
Base.real(::Type{Quaternion{T}}) where {T} = T
Base.real(Q::Quaternion{T}) where {T<:Real} = Q[1]
Base.imag(Q::Quaternion{T}) where {T<:Real} = @view Q.coeffs[2:4]
Base.vec(Q::Quaternion{T}) where {T<:Real} = @view Q.coeffs[2:4]

Base.:*(Q::Quaternion, n::Number) = Quaternion(Q.q1 * n, Q.q2 * n, Q.q3 * n, Q.q4 * n)
Base.:*(n::Number, Q::Quaternion) = Q * n
Base.:/(Q::Quaternion, n::Number) = Quaternion(Q.q1 / n, Q.q2 / n, Q.q3 / n, Q.q4 / n)
function Base.:+(Q1::Quaternion, Q2::Quaternion)
return Quaternion(Q1.q1 + Q2.q1, Q1.q2 + Q2.q2, Q1.q3 + Q2.q3, Q1.q4 + Q2.q4)
end
Base.widen(::Type{Quaternion{T}}) where {T} = Quaternion{widen(T)}
Base.big(::Type{Quaternion{T}}) where {T<:Real} = Quaternion{big(T)}
Base.big(Q::Quaternion{T}) where {T<:Real} = Quaternion{big(T)}(Q)

Base.:-(Q::Quaternion) = Quaternion(-Q.q1, -Q.q2, -Q.q3, -Q.q4)
function Base.:-(Q1::Quaternion, Q2::Quaternion)
return Quaternion(Q1.q1 - Q2.q1, Q1.q2 - Q2.q2, Q1.q3 - Q2.q3, Q1.q4 - Q2.q4)
end
Base.conj(Q::Quaternion) = Quaternion(Q[1], -Q[2], -Q[3], -Q[4])
Base.abs2(Q::Quaternion) = sum(abs2, Q.coeffs)
# TODO: Should do isnan/isinf checks and scale with max of abs of coeffs?
Base.abs(Q::Quaternion) = sqrt(abs2(Q))
Base.inv(Q::Quaternion) = conj(Q) / abs2(Q)
LinearAlgebra.norm(Q::Quaternion) = abs(Q)
LinearAlgebra.normalize(Q::Quaternion) = Q / abs(Q)

function Base.:(==)(Q1::Quaternion, Q2::Quaternion)
return Q1.q1 == Q2.q1 && Q1.q2 == Q2.q2 && Q1.q3 == Q2.q3 && Q1.q4 == Q2.q4
end
Base.:-(Q::Quaternion) = Quaternion(-Q.coeffs)

Base.conj(Q::Quaternion) = Quaternion(Q.q1, -Q.q2, -Q.q3, -Q.q4)
LinearAlgebra.norm(Q::Quaternion) = sqrt(Q.q1^2 + Q.q2^2 + Q.q3^2 + Q.q4^2)
function LinearAlgebra.normalize(Q::Quaternion)
qnorm = norm(Q)
return Quaternion(Q.q1 / qnorm, Q.q2 / qnorm, Q.q3 / qnorm, Q.q4 / qnorm)
Base.:+(Q1::Quaternion, Q2::Quaternion) = Quaternion(Q1.coeffs + Q2.coeffs)
Base.:-(Q1::Quaternion, Q2::Quaternion) = Quaternion(Q1.coeffs - Q2.coeffs)
function Base.:*(Q1::Quaternion, Q2::Quaternion)
return Quaternion(
Q1[1]*Q2[1] - Q1[2]*Q2[2] - Q1[3]*Q2[3] - Q1[4]*Q2[4],
Q1[1]*Q2[2] + Q1[2]*Q2[1] + Q1[3]*Q2[4] - Q1[4]*Q2[3],
Q1[1]*Q2[3] - Q1[2]*Q2[4] + Q1[3]*Q2[1] + Q1[4]*Q2[2],
Q1[1]*Q2[4] + Q1[2]*Q2[3] - Q1[3]*Q2[2] + Q1[4]*Q2[1])
end

Base.:/(Q1::Quaternion, Q2::Quaternion) = Q1 * conj(Q2) / norm(Q2)^2
Base.inv(Q::Quaternion) = conj(Q) / norm(Q)^2
scalar(Q::Quaternion) = Q.q1
vector(Q::Quaternion) = [Q.q2, Q.q3, Q.q4]
function rotvec(v::Vector, Q::Quaternion)
Base.:/(Q1::Quaternion, Q2::Quaternion) = Q1 * inv(Q2)

Base.:*(Q::Quaternion, r::Real) = Quaternion(Q.coeffs * r)
Base.:*(r::Real, Q::Quaternion) = Quaternion(Q.coeffs * r)
Base.:/(Q::Quaternion, r::Real) = Quaternion(Q.coeffs / r)

Base.isreal(Q::Quaternion) = iszero(Q[2]) && iszero(Q[3]) && iszero(Q[4])
Base.isfinite(Q::Quaternion) = isfinite(Q[1]) && isfinite(Q[2]) && isfinite(Q[3]) && isfinite(Q[4])
Base.iszero(Q::Quaternion) = iszero(Q[1]) && iszero(Q[2]) && iszero(Q[3]) && iszero(Q[4])
Base.isone(Q::Quaternion) = isone(Q[1]) && iszero(Q[2]) && iszero(Q[3]) && iszero(Q[4])
Base.isnan(Q::Quaternion) = isnan(Q[1]) || isnan(Q[2]) || isnan(Q[3]) || isnan(Q[4])
Base.isinf(Q::Quaternion) = isinf(Q[1]) || isinf(Q[2]) || isinf(Q[3]) || isinf(Q[4])
Base.isinteger(Q::Quaternion) = isinteger(Q[1]) && isreal(Q)

function rotvec(v::A, Q::Quaternion) where {A<:AbstractVector}
Q = normalize(Q)
return vector(conj(Q) * Quaternion([0; v]) * Q)
return vec(Q * Quaternion(v[1], v[2], v[3]) * conj(Q))
end

function toeuler(Q::Quaternion)
roll = atan(2 * (Q.q1 * Q.q2 + Q.q3 * Q.q4), 1 - 2 * (Q.q2^2 + Q.q3^2))
pitch = asin(2 * (Q.q1 * Q.q3 - Q.q4 * Q.q2))
yaw = atan(2 * (Q.q1 * Q.q4 + Q.q2 * Q.q3), 1 - 2 * (Q.q3^2 + Q.q4^2))
return [(roll * 180 / π), (pitch * 180 / π), (yaw * 180 / π)]
function Base.show(io::IO, Q::Quaternion{T}) where T
print(io, "{")
show(io, T)
print(io, "}")
print(io, " q1: ")
show(io, Q[1])
print(io, ", q2: ")
show(io, Q[2])
print(io, ", q3: ")
show(io, Q[3])
print(io, ", q4: ")
show(io, Q[4])
end

0 comments on commit 254e031

Please sign in to comment.