Skip to content

Commit

Permalink
nicer readme
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 19, 2024
1 parent b00cfd0 commit dc551d1
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 15 deletions.
45 changes: 35 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,41 @@

This is a small package, spun out of [`ACE.jl`](https://github.com/ACEsuit/ACE.jl). The original intended use case is managing (lists of) decorated particles, i.e., point clouds embedded in some vector space, where each point is decorated with additional features such as chemical species, charge, mass, etc.

Example usage:
### Example usage

```julia
using DecoratedParticles, StaticArrays
# a silicon atom
x = PState(𝐫 = randn(SVector{3, Float64}), z = 14)
# ⟨𝐫:[-0.91, -0.87, -0.42], z:14⟩
# extract its position
x.𝐫
using DecoratedParticles, StaticArrays, LinearAlgebra, Zygote
using DecoratedParticles: PState, VState
DP = DecoratedParticles

x1 = PState( 𝐫 = randn(SVector{3, Float64}), z = 14 )
# 〖𝐫:[-0.74, -2.27, -0.83], z:14〗
x2 = PState( 𝐫 = randn(SVector{3, Float64}), z = 14 )
# 〖𝐫:[-0.63, 0.67, -0.56], z:14〗
𝐫12 = VState(x2 - x1)
# ⦅𝐫:[0.11, 2.94, 0.27]⦆

# extract the position
x1.𝐫
# 3-element SVector{3, Float64} with indices SOneTo(3):
# -1.1469536186585183
# -0.1832512302259138
# 1.0216715637205427
# -0.7424735839283951
# -2.271376247109223
# -0.8265064008465374

# arithmetic on particle states
x1 + 𝐫12 x2
# true

f(X) = sum(DP.normsq(x.𝐫) for x in X)
f([x1, x2])
# 4.115...

# the gradient of a PState is a VState
g = Zygote.gradient(f, [x1, x2])[1]
# 2-element Vector{VState{@NamedTuple{𝐫::SVector{3, Float64}}}}:
#⦅𝐫:[-1.48, -4.54, -1.65]⦆
#⦅𝐫:[-1.26, 1.35, -1.12]⦆

g[1].𝐫 2 * x1.𝐫
# true
```
20 changes: 15 additions & 5 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,20 @@ normsq(x::SVector) = dot(x, x)

# TODO: check whether this is still needed
# this function makes sure that gradients w.r.t. a PState become a VState
function rrule(::typeof(getproperty), X::XState, sym::Symbol)
# The original iple,entation is a bit more complicated, but I don't understand
# it anymore. I'll need to revisit this a bit.
# function rrule(::typeof(getproperty), X::XState, sym::Symbol)
# val = getproperty(X, sym)
# return val, w -> ( NoTangent(),
# vstate_type(w[1], X)( NamedTuple{(sym,)}((w,)) ),
# NoTangent() )
# end


import ChainRulesCore
import ChainRulesCore: rrule

function rrule(::typeof(Base.getproperty), X::XState, sym::Symbol)
val = getproperty(X, sym)
return val, w -> ( NoTangent(),
vstate_type(w[1], X)( NamedTuple{(sym,)}((w,)) ),
NoTangent() )
return val, Δ -> (NoTangent(), VState(NamedTuple{(sym,)}((Δ,))), NoTangent())
end

34 changes: 34 additions & 0 deletions test/_readme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using DecoratedParticles, StaticArrays, LinearAlgebra, Zygote
using DecoratedParticles: PState, VState
DP = DecoratedParticles

x1 = PState( 𝐫 = randn(SVector{3, Float64}), z = 14 )
# 〖𝐫:[-0.74, -2.27, -0.83], z:14〗
x2 = PState( 𝐫 = randn(SVector{3, Float64}), z = 14 )
# 〖𝐫:[-0.63, 0.67, -0.56], z:14〗
𝐫12 = VState(x2 - x1)
# ⦅𝐫:[0.11, 2.94, 0.27]⦆

# extract the position
x1.𝐫
# 3-element SVector{3, Float64} with indices SOneTo(3):
# -0.7424735839283951
# -2.271376247109223
# -0.8265064008465374

# arithmetic on particle states
x1 + 𝐫12 x2
# true

f(X) = sum(DP.normsq(x.𝐫) for x in X)
f([x1, x2])
# 4.115...

# the gradient of a PState is a VState
g = Zygote.gradient(f, [x1, x2])[1]
# 2-element Vector{VState{@NamedTuple{𝐫::SVector{3, Float64}}}}:
#⦅𝐫:[-1.48, -4.54, -1.65]⦆
#⦅𝐫:[-1.26, 1.35, -1.12]⦆

g[1].𝐫 2 * x1.𝐫
# true

0 comments on commit dc551d1

Please sign in to comment.