From b56d128ae787b4be871aefb042c91033630f352e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 8 Feb 2023 10:05:00 +0100 Subject: [PATCH 1/9] Fix whitespace --- docs/docs/usage.md | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/docs/usage.md b/docs/docs/usage.md index d6cd9e9e..f280270e 100644 --- a/docs/docs/usage.md +++ b/docs/docs/usage.md @@ -296,7 +296,7 @@ runWeightedT (WeightedT m) = runStateT m 1 `WeightedT m` is not an instance of `MonadDistribution`, but only as instance of `MonadFactor` (and that, only when `m` is an instance of `Monad`). However, since `StateT` is a monad transformer, there is a function `lift :: m Double -> WeightedT m Double`. -So if we take a `MonadDistribution` instance like `SamplerIO`, then `WeightedT SamplerIO` is an instance of both `MonadDistribution` and `MonadFactor`. Which means it is an instance of `MonadMeasure`. +So if we take a `MonadDistribution` instance like `SamplerIO`, then `Weighted SamplerIO` is an instance of both `MonadDistribution` and `MonadFactor`. Which means it is an instance of `MonadMeasure`. So we can successfully write `(sampler . runWeightedT) sprinkler` and get a program of type `IO (Bool, Log Double)`. When run, this will draw a sample from `sprinkler` along with an **unnormalized** density for that sample. @@ -339,7 +339,7 @@ PopulationT m a ~ m [Log Double -> (a, Log Double)] Note that while `ListT` isn't in general a valid monad transformer, we're not requiring it to be one here. -`PopulationT` is used to represent a collection of particles (in the statistical sense), along with their weights. +`Population` is used to represent a collection of particles (in the statistical sense), along with their weights. There are several useful functions associated with it: @@ -360,7 +360,7 @@ gives [([((),0.5),((),0.5)],1.0)] ``` -Observe how here we have interpreted `(spawn 2)` as of type `PopulationT Enumerator ()`. +Observe how here we have interpreted `(spawn 2)` as of type `Population Enumerator ()`. `resampleGeneric` takes a function to probabilistically select a set of indices from a vector, and makes a new population by selecting those indices. @@ -393,8 +393,8 @@ Summary of key info on `SequentialT`: ```haskell -newtype SequentialT m a = - SequentialT {runSequentialT :: Coroutine (Await ()) m a} +newtype Sequential m a = + Sequential {runSequential :: Coroutine (Await ()) m a} ``` This is a wrapper for the `Coroutine` type applied to the `Await` constructor from `Control.Monad.Coroutine`, which is defined thus: @@ -410,7 +410,7 @@ newtype Await x y = Await (x -> y) Unpacking that: ```haskell -SequentialT m a ~ m (Either (() -> SequentialT m a) a) +Sequential m a ~ m (Either (() -> Sequential m a) a) ``` As usual, `m` is going to be some other probability monad, so understand `SequentialT m a` as representing a program which, after making a random choice or doing conditioning, we either obtain an `a` value, or a paused computation, which when resumed gets us back to a new `SequentialT m a`. @@ -501,11 +501,11 @@ The latter is best understood if you're familiar with the standard use of a free ```haskell newtype SamF a = Random (Double -> a) -newtype DensityT m a = - DensityT {getDensityT :: FT SamF m a} +newtype Density m a = + Density {density :: FT SamF m a} -instance Monad m => MonadDistribution (DensityT m) where - random = DensityT $ liftF (Random id) +instance Monad m => MonadDistribution (Density m) where + random = Density $ liftF (Random id) ``` The monad-bayes implementation uses a more efficient implementation of `FreeT`, namely `FT` from the `free` package, known as the *Church transformed Free monad*. This is a technique explained in https://begriffs.com/posts/2016-02-04-difference-lists-and-codennsity.html. But that only changes the operational semantics - performance aside, it works just the same as the standard `FreeT` datatype. @@ -575,7 +575,7 @@ data Trace a = Trace } ``` -We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `DensityT` is for. +We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `Density` is for. The simplest version of `TracedT` is in `Control.Monad.Bayes.TracedT.Basic` @@ -635,13 +635,13 @@ example = do return x ``` -`(enumerator . runWeightedT) example` gives `[((False,0.0),0.5),((True,1.0),0.5)]`. This is quite edifying for understanding `(sampler . runWeightedT) example`. What it says is that there are precisely two ways the program will run, each with equal probability: either you get `False` with weight `0.0` or `True` with weight `1.0`. +`(enumerator . weighted) example` gives `[((False,0.0),0.5),((True,1.0),0.5)]`. This is quite edifying for understanding `(sampler . weighted) example`. What it says is that there are precisely two ways the program will run, each with equal probability: either you get `False` with weight `0.0` or `True` with weight `1.0`. ### Quadrature As described on the section on `Integrator`, we can interpret our probabilistic program of type `MonadDistribution m => m a` as having concrete type `Integrator a`. This views our program as an integrator, allowing us to calculate expectations, probabilities and so on via quadrature (i.e. numerical approximation of an integral). -This can also handle programs of type `MonadMeasure m => m a`, that is, programs with `factor` statements. For these cases, a function `normalize :: WeightedT Integrator a -> Integrator a` is employed. For example, +This can also handle programs of type `MonadMeasure m => m a`, that is, programs with `factor` statements. For these cases, a function `normalize :: Weighted Integrator a -> Integrator a` is employed. For example, ```haskell model :: MonadMeasure m => m Double @@ -652,7 +652,7 @@ model = do return var ``` -is really an unnormalized measure, rather than a probability distribution. `normalize` views it as of type `WeightedT Integrator Double`, which is isomorphic to `(Double -> (Double, Log Double) -> Double)`. This can be used to compute the normalization constant, and divide the integrator's output by it, all within `Integrator`. +is really an unnormalized measure, rather than a probability distribution. `normalize` views it as of type `Weighted Integrator Double`, which is isomorphic to `(Double -> (Double, Log Double) -> Double)`. This can be used to compute the normalization constant, and divide the integrator's output by it, all within `Integrator`. ### Independent forward sampling @@ -796,7 +796,7 @@ pmmh :: pmmh mcmcConf smcConf param model = (mcmc mcmcConf :: T m [(a, Log Double)] -> m [[(a, Log Double)]]) ((param :: T m b) >>= - (runPopulationT :: P (T m) a -> T m [(a, Log Double)]) + (population :: P (T m) a -> T m [(a, Log Double)]) . (pushEvidence :: P (T m) a -> P (T m) a) . Pop.hoist (lift :: forall x. m x -> T m x) . (smc smcConf :: S (P m) a -> P m a) From ebc3b36536d45a6ee2f51c52f6e56f965ea6e78a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Fri, 3 Feb 2023 18:06:10 +0100 Subject: [PATCH 2/9] Replace ListT by free list transformer --- docs/docs/usage.md | 14 ++++++++------ src/Control/Monad/Bayes/Class.hs | 8 ++++---- src/Control/Monad/Bayes/Population.hs | 19 +++++++++++-------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/docs/docs/usage.md b/docs/docs/usage.md index f280270e..c0ddf6f0 100644 --- a/docs/docs/usage.md +++ b/docs/docs/usage.md @@ -328,18 +328,20 @@ Summary of key info on `PopulationT`: - `instance MonadFactor m => instance MonadFactor (PopulationT m)` ```haskell -newtype PopulationT m a = PopulationT (WeightedT (ListT m) a) +newtype PopulationT m a = PopulationT (WeightedT (FreeT [] m) a) ``` -So: +The `FreeT []` construction is for branching our probabilistic program into different branches, +corresponding to different choices of a random variable. + +It is interpreted, using `runPopulationT`, to: ```haskell -PopulationT m a ~ m [Log Double -> (a, Log Double)] +m [(a, Log Double)] ``` -Note that while `ListT` isn't in general a valid monad transformer, we're not requiring it to be one here. - -`Population` is used to represent a collection of particles (in the statistical sense), along with their weights. +This shows that `Population` is used to compute a collection of particles (in the statistical sense), along with their weights. +Each `a` corresponds to one particle, and `Log Double` is the type of its weight. There are several useful functions associated with it: diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index 6a8c1803..61642bfa 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -79,9 +79,9 @@ import Control.Monad (replicateM, when) import Control.Monad.Cont (ContT) import Control.Monad.Except (ExceptT, lift) import Control.Monad.Identity (IdentityT) -import Control.Monad.List (ListT) import Control.Monad.Reader (ReaderT) import Control.Monad.State (StateT) +import Control.Monad.Trans.Free.Ap (FreeT) import Control.Monad.Writer (WriterT) import Data.Histogram qualified as H import Data.Histogram.Fill qualified as H @@ -390,15 +390,15 @@ instance (MonadFactor m) => MonadFactor (StateT s m) where instance (MonadMeasure m) => MonadMeasure (StateT s m) -instance (MonadDistribution m) => MonadDistribution (ListT m) where +instance (Applicative f, (MonadDistribution m)) => MonadDistribution (FreeT f m) where random = lift random bernoulli = lift . bernoulli categorical = lift . categorical -instance (MonadFactor m) => MonadFactor (ListT m) where +instance (Applicative f, (MonadFactor m)) => MonadFactor (FreeT f m) where score = lift . score -instance (MonadMeasure m) => MonadMeasure (ListT m) +instance (Applicative f, (MonadMeasure m)) => MonadMeasure (FreeT f m) instance (MonadDistribution m) => MonadDistribution (ContT r m) where random = lift random diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 2a384f77..9708dd51 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -52,7 +52,10 @@ import Control.Monad.Bayes.Weighted runWeightedT, weightedT, ) -import Control.Monad.List (ListT (..), MonadIO, MonadTrans (..)) +import Control.Monad.Bayes.Weighted qualified as Weighted +import Control.Monad.IO.Class +import Control.Monad.Trans +import Control.Monad.Trans.Free.Ap import Data.List (unfoldr) import Data.List qualified import Data.Maybe (catMaybes) @@ -63,7 +66,7 @@ import Numeric.Log qualified as Log import Prelude hiding (all, sum) -- | A collection of weighted samples, or particles. -newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (ListT m) a} +newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (FreeT [] m) a} deriving newtype (Functor, Applicative, Monad, MonadIO, MonadDistribution, MonadFactor, MonadMeasure) instance MonadTrans PopulationT where @@ -71,16 +74,16 @@ instance MonadTrans PopulationT where -- | Explicit representation of the weighted sample with weights in the log -- domain. -runPopulationT :: PopulationT m a -> m [(a, Log Double)] -runPopulationT = runListT . runWeightedT . getPopulationT +runPopulationT :: (Monad m) => PopulationT m a -> m [(a, Log Double)] +runPopulationT = iterT (fmap concat . sequence) . fmap pure . runWeightedT . getPopulationT -- | Explicit representation of the weighted sample. -explicitPopulation :: (Functor m) => PopulationT m a -> m [(a, Double)] +explicitPopulation :: (Monad m) => PopulationT m a -> m [(a, Double)] explicitPopulation = fmap (map (second (exp . ln))) . runPopulationT -- | Initialize 'PopulationT' with a concrete weighted sample. fromWeightedList :: (Monad m) => m [(a, Log Double)] -> PopulationT m a -fromWeightedList = PopulationT . weightedT . ListT +fromWeightedList = PopulationT . weightedT . FreeT . fmap (Free . fmap pure) -- | Increase the sample size by a given factor. -- The weights are adjusted such that their sum is preserved. @@ -265,8 +268,8 @@ popAvg f p = do -- | Applies a transformation to the inner monad. hoist :: - (Monad n) => + (Monad m, (Monad n)) => (forall x. m x -> n x) -> PopulationT m a -> PopulationT n a -hoist f = fromWeightedList . f . runPopulationT +hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT From 0d05c10085723ae0d7f150a3c4264e3d395f494f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 8 Feb 2023 11:04:54 +0100 Subject: [PATCH 3/9] Derive Alternative and MonadPlus for Population --- src/Control/Monad/Bayes/Population.hs | 5 +++-- src/Control/Monad/Bayes/Weighted.hs | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 9708dd51..4c06bdb6 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -37,8 +37,9 @@ module Control.Monad.Bayes.Population ) where +import Control.Applicative (Alternative) import Control.Arrow (second) -import Control.Monad (replicateM) +import Control.Monad (MonadPlus, replicateM) import Control.Monad.Bayes.Class ( MonadDistribution (categorical, logCategorical, random, uniform), MonadFactor, @@ -67,7 +68,7 @@ import Prelude hiding (all, sum) -- | A collection of weighted samples, or particles. newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (FreeT [] m) a} - deriving newtype (Functor, Applicative, Monad, MonadIO, MonadDistribution, MonadFactor, MonadMeasure) + deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO, MonadPlus, MonadDistribution, MonadFactor, MonadMeasure) instance MonadTrans PopulationT where lift = PopulationT . lift . lift diff --git a/src/Control/Monad/Bayes/Weighted.hs b/src/Control/Monad/Bayes/Weighted.hs index a1bbe44a..843f430f 100644 --- a/src/Control/Monad/Bayes/Weighted.hs +++ b/src/Control/Monad/Bayes/Weighted.hs @@ -24,6 +24,8 @@ module Control.Monad.Bayes.Weighted ) where +import Control.Applicative (Alternative) +import Control.Monad (MonadPlus) import Control.Monad.Bayes.Class ( MonadDistribution, MonadFactor (..), @@ -36,7 +38,7 @@ import Numeric.Log (Log) -- | Execute the program using the prior distribution, while accumulating likelihood. newtype WeightedT m a = WeightedT (StateT (Log Double) m a) -- StateT is more efficient than WriterT - deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadDistribution) + deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO, MonadPlus, MonadTrans, MonadDistribution) instance (Monad m) => MonadFactor (WeightedT m) where score w = WeightedT (modify (* w)) From 593193529c82751225203b83b5daa58785d0ea75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 8 Feb 2023 11:05:38 +0100 Subject: [PATCH 4/9] Use standard FreeT instead of applicative --- src/Control/Monad/Bayes/Class.hs | 2 +- src/Control/Monad/Bayes/Population.hs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index 61642bfa..4d36fbb6 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -81,7 +81,7 @@ import Control.Monad.Except (ExceptT, lift) import Control.Monad.Identity (IdentityT) import Control.Monad.Reader (ReaderT) import Control.Monad.State (StateT) -import Control.Monad.Trans.Free.Ap (FreeT) +import Control.Monad.Trans.Free (FreeT) import Control.Monad.Writer (WriterT) import Data.Histogram qualified as H import Data.Histogram.Fill qualified as H diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 4c06bdb6..a62a726e 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -56,7 +56,7 @@ import Control.Monad.Bayes.Weighted import Control.Monad.Bayes.Weighted qualified as Weighted import Control.Monad.IO.Class import Control.Monad.Trans -import Control.Monad.Trans.Free.Ap +import Control.Monad.Trans.Free import Data.List (unfoldr) import Data.List qualified import Data.Maybe (catMaybes) From eeb4928fc6669719cfd0925e79b90b85fefd9d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 8 Feb 2023 16:58:10 +0100 Subject: [PATCH 5/9] Fix RMSMC algorithms --- src/Control/Monad/Bayes/Inference/RMSMC.hs | 10 +++++----- src/Control/Monad/Bayes/Population.hs | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/Control/Monad/Bayes/Inference/RMSMC.hs b/src/Control/Monad/Bayes/Inference/RMSMC.hs index d9bc8a9b..a3b7dfb8 100644 --- a/src/Control/Monad/Bayes/Inference/RMSMC.hs +++ b/src/Control/Monad/Bayes/Inference/RMSMC.hs @@ -25,7 +25,7 @@ import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..)) import Control.Monad.Bayes.Inference.SMC import Control.Monad.Bayes.Population ( PopulationT, - spawn, + flatten, withParticles, ) import Control.Monad.Bayes.Sequential.Coroutine as Seq @@ -50,8 +50,8 @@ rmsmc :: PopulationT m a rmsmc (MCMCConfig {..}) (SMCConfig {..}) = marginal - . S.sequentially (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler) numSteps - . S.hoistFirst (TrStat.hoist (spawn numParticles >>)) + . S.sequentially (composeCopies numMCMCSteps (TrStat.hoist flatten . mhStep) . TrStat.hoist resampler) numSteps + . S.hoistFirst (TrStat.hoist (withParticles numParticles)) -- | Resample-move Sequential Monte Carlo with a more efficient -- tracing representation. @@ -64,7 +64,7 @@ rmsmcBasic :: PopulationT m a rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) = TrBas.marginal - . S.sequentially (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler) numSteps + . S.sequentially (TrBas.hoist flatten . composeCopies numMCMCSteps (TrBas.hoist flatten . TrBas.mhStep) . TrBas.hoist resampler) numSteps . S.hoistFirst (TrBas.hoist (withParticles numParticles)) -- | A variant of resample-move Sequential Monte Carlo @@ -79,7 +79,7 @@ rmsmcDynamic :: PopulationT m a rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) = TrDyn.marginal - . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps TrDyn.mhStep . TrDyn.hoist resampler) numSteps + . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist flatten . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps . S.hoistFirst (TrDyn.hoist (withParticles numParticles)) -- | Apply a function a given number of times. diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index a62a726e..c520bcc1 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -34,6 +34,7 @@ module Control.Monad.Bayes.Population collapse, popAvg, withParticles, + flatten, ) where @@ -274,3 +275,7 @@ hoist :: PopulationT m a -> PopulationT n a hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT + +-- | Flatten all layers of the free structure +flatten :: (Monad m) => PopulationT m a -> PopulationT m a +flatten = fromWeightedList . runPopulationT From 0f50c50d6df7cabd226575523c30ec5fb4fbda54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Sun, 31 Dec 2023 17:57:26 +0100 Subject: [PATCH 6/9] Add applicative ListT transformer --- monad-bayes.cabal | 4 ++++ src/Control/Applicative/List.hs | 17 +++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 src/Control/Applicative/List.hs diff --git a/monad-bayes.cabal b/monad-bayes.cabal index 8dc7b22b..dcf2fac2 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -86,6 +86,7 @@ common test-deps library import: deps exposed-modules: + Control.Applicative.List Control.Monad.Bayes.Class Control.Monad.Bayes.Density.Free Control.Monad.Bayes.Density.State @@ -114,8 +115,11 @@ library other-modules: Control.Monad.Bayes.Traced.Common default-language: Haskell2010 default-extensions: + ApplicativeDo BlockArguments + DerivingStrategies FlexibleContexts + GeneralizedNewtypeDeriving ImportQualifiedPost LambdaCase OverloadedStrings diff --git a/src/Control/Applicative/List.hs b/src/Control/Applicative/List.hs new file mode 100644 index 00000000..1ebc618b --- /dev/null +++ b/src/Control/Applicative/List.hs @@ -0,0 +1,17 @@ +module Control.Applicative.List where + +-- base + +import Control.Applicative +import Data.Functor.Compose + +-- | _Applicative_ transformer adding a list/nondeterminism/choice effect. +-- It is not a valid monad transformer, but it is a valid 'Applicative'. +newtype ListT m a = ListT {getListT :: Compose [] m a} + deriving newtype (Functor, Applicative, Alternative) + +lift :: m a -> ListT m a +lift = ListT . Compose . pure + +runListT :: ListT m a -> [m a] +runListT = getCompose . getListT From 4035c9ad03bb9f56f58e425b253d87df1f9e9af8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Sun, 31 Dec 2023 18:45:38 +0100 Subject: [PATCH 7/9] Add and use applicative population transformer --- monad-bayes.cabal | 8 +++--- src/Control/Applicative/List.hs | 31 ++++++++++++++++++---- src/Control/Monad/Bayes/Inference/RMSMC.hs | 7 ++--- src/Control/Monad/Bayes/Population.hs | 12 +++++++-- 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/monad-bayes.cabal b/monad-bayes.cabal index dcf2fac2..f4c26287 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -63,21 +63,21 @@ common deps , scientific ^>=0.3 , statistics >=0.14.0 && <0.17 , text >=1.2 && <2.1 + , transformers ^>=0.5.6 , vector >=0.12.0 && <0.14 , vty ^>=5.38 common test-deps build-depends: , abstract-par ^>=0.3 - , criterion >=1.5 && <1.7 + , criterion >=1.5 && <1.7 , directory ^>=1.3 , hspec ^>=2.11 , monad-bayes - , optparse-applicative >=0.17 && <0.19 + , optparse-applicative >=0.17 && <0.19 , process ^>=1.6 , QuickCheck ^>=2.14 - , time >=1.9 && <1.13 - , transformers ^>=0.5.6 + , time >=1.9 && <1.13 , typed-process ^>=0.2 autogen-modules: Paths_monad_bayes diff --git a/src/Control/Applicative/List.hs b/src/Control/Applicative/List.hs index 1ebc618b..557c218a 100644 --- a/src/Control/Applicative/List.hs +++ b/src/Control/Applicative/List.hs @@ -1,17 +1,38 @@ +{-# LANGUAGE StandaloneDeriving #-} + module Control.Applicative.List where -- base - import Control.Applicative +-- transformers +import Control.Monad.Trans.Writer.Strict import Data.Functor.Compose +-- log-domain +import Numeric.Log (Log) + +-- * Applicative ListT -- | _Applicative_ transformer adding a list/nondeterminism/choice effect. -- It is not a valid monad transformer, but it is a valid 'Applicative'. -newtype ListT m a = ListT {getListT :: Compose [] m a} +newtype ListT m a = ListT {getListT :: Compose m [] a} deriving newtype (Functor, Applicative, Alternative) -lift :: m a -> ListT m a -lift = ListT . Compose . pure +lift :: (Functor m) => m a -> ListT m a +lift = ListT . Compose . fmap pure -runListT :: ListT m a -> [m a] +runListT :: ListT m a -> m [a] runListT = getCompose . getListT + +-- * Applicative Population transformer + +-- WriterT has to be used instead of WeightedT, +-- since WeightedT uses StateT under the hood, +-- which requires a Monad (ListT m) constraint. +newtype PopulationT m a = PopulationT {getPopulationT :: WriterT (Log Double) (ListT m) a} + deriving newtype (Functor, Applicative, Alternative) + +runPopulationT :: PopulationT m a -> m [(a, Log Double)] +runPopulationT = runListT . runWriterT . getPopulationT + +fromWeightedList :: m [(a, Log Double)] -> PopulationT m a +fromWeightedList = PopulationT . WriterT . ListT . Compose diff --git a/src/Control/Monad/Bayes/Inference/RMSMC.hs b/src/Control/Monad/Bayes/Inference/RMSMC.hs index a3b7dfb8..626eeae0 100644 --- a/src/Control/Monad/Bayes/Inference/RMSMC.hs +++ b/src/Control/Monad/Bayes/Inference/RMSMC.hs @@ -26,6 +26,7 @@ import Control.Monad.Bayes.Inference.SMC import Control.Monad.Bayes.Population ( PopulationT, flatten, + single, withParticles, ) import Control.Monad.Bayes.Sequential.Coroutine as Seq @@ -50,7 +51,7 @@ rmsmc :: PopulationT m a rmsmc (MCMCConfig {..}) (SMCConfig {..}) = marginal - . S.sequentially (composeCopies numMCMCSteps (TrStat.hoist flatten . mhStep) . TrStat.hoist resampler) numSteps + . S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps . S.hoistFirst (TrStat.hoist (withParticles numParticles)) -- | Resample-move Sequential Monte Carlo with a more efficient @@ -64,7 +65,7 @@ rmsmcBasic :: PopulationT m a rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) = TrBas.marginal - . S.sequentially (TrBas.hoist flatten . composeCopies numMCMCSteps (TrBas.hoist flatten . TrBas.mhStep) . TrBas.hoist resampler) numSteps + . S.sequentially (TrBas.hoist (single . flatten) . composeCopies numMCMCSteps (TrBas.hoist (single . flatten) . TrBas.mhStep) . TrBas.hoist resampler) numSteps . S.hoistFirst (TrBas.hoist (withParticles numParticles)) -- | A variant of resample-move Sequential Monte Carlo @@ -79,7 +80,7 @@ rmsmcDynamic :: PopulationT m a rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) = TrDyn.marginal - . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist flatten . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps + . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist (single . flatten) . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps . S.hoistFirst (TrDyn.hoist (withParticles numParticles)) -- | Apply a function a given number of times. diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index c520bcc1..9b97dc31 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -35,10 +35,12 @@ module Control.Monad.Bayes.Population popAvg, withParticles, flatten, + single, ) where import Control.Applicative (Alternative) +import Control.Applicative.List qualified as ApplicativeListT import Control.Arrow (second) import Control.Monad (MonadPlus, replicateM) import Control.Monad.Bayes.Class @@ -277,5 +279,11 @@ hoist :: hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT -- | Flatten all layers of the free structure -flatten :: (Monad m) => PopulationT m a -> PopulationT m a -flatten = fromWeightedList . runPopulationT +flatten :: (Monad m) => PopulationT m a -> ApplicativeListT.PopulationT m a +flatten = ApplicativeListT.fromWeightedList . runPopulationT + +-- | Create a population from a single layer of branching computations. +-- +-- Similar to 'fromWeightedListT'. +single :: (Monad m) => ApplicativeListT.PopulationT m a -> PopulationT m a +single = fromWeightedList . ApplicativeListT.runPopulationT From 084862f72310f93101cc294a553f06e3c09d5b10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 3 Jan 2024 09:50:32 +0100 Subject: [PATCH 8/9] Applicative Population transformer --- monad-bayes.cabal | 1 + src/Control/Applicative/List.hs | 21 ++------------ src/Control/Monad/Bayes/Population.hs | 17 +++++++---- .../Monad/Bayes/Population/Applicative.hs | 29 +++++++++++++++++++ 4 files changed, 44 insertions(+), 24 deletions(-) create mode 100644 src/Control/Monad/Bayes/Population/Applicative.hs diff --git a/monad-bayes.cabal b/monad-bayes.cabal index f4c26287..ab68cf98 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -101,6 +101,7 @@ library Control.Monad.Bayes.Inference.TUI Control.Monad.Bayes.Integrator Control.Monad.Bayes.Population + Control.Monad.Bayes.Population.Applicative Control.Monad.Bayes.Sampler.Lazy Control.Monad.Bayes.Sampler.Strict Control.Monad.Bayes.Sequential.Coroutine diff --git a/src/Control/Applicative/List.hs b/src/Control/Applicative/List.hs index 557c218a..a6a0a99a 100644 --- a/src/Control/Applicative/List.hs +++ b/src/Control/Applicative/List.hs @@ -4,11 +4,7 @@ module Control.Applicative.List where -- base import Control.Applicative --- transformers -import Control.Monad.Trans.Writer.Strict import Data.Functor.Compose --- log-domain -import Numeric.Log (Log) -- * Applicative ListT @@ -17,22 +13,11 @@ import Numeric.Log (Log) newtype ListT m a = ListT {getListT :: Compose m [] a} deriving newtype (Functor, Applicative, Alternative) +listT :: m [a] -> ListT m a +listT = ListT . Compose + lift :: (Functor m) => m a -> ListT m a lift = ListT . Compose . fmap pure runListT :: ListT m a -> m [a] runListT = getCompose . getListT - --- * Applicative Population transformer - --- WriterT has to be used instead of WeightedT, --- since WeightedT uses StateT under the hood, --- which requires a Monad (ListT m) constraint. -newtype PopulationT m a = PopulationT {getPopulationT :: WriterT (Log Double) (ListT m) a} - deriving newtype (Functor, Applicative, Alternative) - -runPopulationT :: PopulationT m a -> m [(a, Log Double)] -runPopulationT = runListT . runWriterT . getPopulationT - -fromWeightedList :: m [(a, Log Double)] -> PopulationT m a -fromWeightedList = PopulationT . WriterT . ListT . Compose diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 9b97dc31..aa177de1 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -40,7 +40,6 @@ module Control.Monad.Bayes.Population where import Control.Applicative (Alternative) -import Control.Applicative.List qualified as ApplicativeListT import Control.Arrow (second) import Control.Monad (MonadPlus, replicateM) import Control.Monad.Bayes.Class @@ -49,6 +48,7 @@ import Control.Monad.Bayes.Class MonadMeasure, factor, ) +import Control.Monad.Bayes.Population.Applicative qualified as Applicative import Control.Monad.Bayes.Weighted ( WeightedT, applyWeight, @@ -70,6 +70,11 @@ import Numeric.Log qualified as Log import Prelude hiding (all, sum) -- | A collection of weighted samples, or particles. +-- +-- This monad transformer is internally represented as a free monad, +-- which means that each layer of its computation contains a collection of weighted samples. +-- These can be flattened with 'flatten', +-- but the result is not a monad anymore. newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (FreeT [] m) a} deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO, MonadPlus, MonadDistribution, MonadFactor, MonadMeasure) @@ -278,12 +283,12 @@ hoist :: PopulationT n a hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT --- | Flatten all layers of the free structure -flatten :: (Monad m) => PopulationT m a -> ApplicativeListT.PopulationT m a -flatten = ApplicativeListT.fromWeightedList . runPopulationT +-- | Flatten all layers of the free structure. +flatten :: (Monad m) => PopulationT m a -> Applicative.PopulationT m a +flatten = Applicative.fromWeightedList . runPopulationT -- | Create a population from a single layer of branching computations. -- -- Similar to 'fromWeightedListT'. -single :: (Monad m) => ApplicativeListT.PopulationT m a -> PopulationT m a -single = fromWeightedList . ApplicativeListT.runPopulationT +single :: (Monad m) => Applicative.PopulationT m a -> PopulationT m a +single = fromWeightedList . Applicative.runPopulationT diff --git a/src/Control/Monad/Bayes/Population/Applicative.hs b/src/Control/Monad/Bayes/Population/Applicative.hs new file mode 100644 index 00000000..7f13d484 --- /dev/null +++ b/src/Control/Monad/Bayes/Population/Applicative.hs @@ -0,0 +1,29 @@ +-- | 'PopulationT' turns a single sample into a collection of weighted samples. +-- +-- This module contains an _'Applicative'_ transformer corresponding to the Population monad transformer from the article. +-- It is based on the old-fashioned 'ListT', which is not a valid monad transformer, but a valid applicative transformer. +-- The corresponding monad transformer is contained in 'Control.Monad.Bayes.Population'. +-- One can convert from the monad transformer to the applicative transformer by 'flatten'ing. +module Control.Monad.Bayes.Population.Applicative where + +import Control.Applicative +import Control.Applicative.List +import Control.Monad.Trans.Writer.Strict +import Data.Functor.Compose +import Numeric.Log (Log) + +-- * Applicative Population transformer + +-- WriterT has to be used instead of WeightedT, +-- since WeightedT uses StateT under the hood, +-- which requires a Monad (ListT m) constraint. + +-- | A collection of weighted samples, or particles. +newtype PopulationT m a = PopulationT {getPopulationT :: WriterT (Log Double) (ListT m) a} + deriving newtype (Functor, Applicative, Alternative) + +runPopulationT :: PopulationT m a -> m [(a, Log Double)] +runPopulationT = runListT . runWriterT . getPopulationT + +fromWeightedList :: m [(a, Log Double)] -> PopulationT m a +fromWeightedList = PopulationT . WriterT . listT From 4f6b2977f82b135ea0bbe0ca3d09585d5f3b345e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 3 Jan 2024 09:52:39 +0100 Subject: [PATCH 9/9] Unsuccessfull attempt at fixing SMC2 --- src/Control/Monad/Bayes/Inference/RMSMC.hs | 4 +-- src/Control/Monad/Bayes/Inference/SMC.hs | 29 +++++++++++++++++-- src/Control/Monad/Bayes/Inference/SMC2.hs | 12 ++++++-- src/Control/Monad/Bayes/Population.hs | 2 +- .../Monad/Bayes/Sequential/Coroutine.hs | 2 ++ src/Control/Monad/Bayes/Traced/Static.hs | 24 +++++++++------ 6 files changed, 56 insertions(+), 17 deletions(-) diff --git a/src/Control/Monad/Bayes/Inference/RMSMC.hs b/src/Control/Monad/Bayes/Inference/RMSMC.hs index 626eeae0..f86a4c33 100644 --- a/src/Control/Monad/Bayes/Inference/RMSMC.hs +++ b/src/Control/Monad/Bayes/Inference/RMSMC.hs @@ -51,8 +51,8 @@ rmsmc :: PopulationT m a rmsmc (MCMCConfig {..}) (SMCConfig {..}) = marginal - . S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps - . S.hoistFirst (TrStat.hoist (withParticles numParticles)) + . S.sequentially (composeCopies numMCMCSteps (TrStat.hoistModel (single . flatten) . TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps + . S.hoistFirst (TrStat.hoistModel (single . flatten) . TrStat.hoist (withParticles numParticles)) -- | Resample-move Sequential Monte Carlo with a more efficient -- tracing representation. diff --git a/src/Control/Monad/Bayes/Inference/SMC.hs b/src/Control/Monad/Bayes/Inference/SMC.hs index 3f3a30b2..a729dc85 100644 --- a/src/Control/Monad/Bayes/Inference/SMC.hs +++ b/src/Control/Monad/Bayes/Inference/SMC.hs @@ -22,11 +22,18 @@ where import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure) import Control.Monad.Bayes.Population - ( PopulationT, + ( PopulationT (..), + flatten, pushEvidence, + single, withParticles, ) +import Control.Monad.Bayes.Population.Applicative qualified as Applicative import Control.Monad.Bayes.Sequential.Coroutine as Coroutine +import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT +import Control.Monad.Bayes.Weighted (WeightedT (..), weightedT) +import Control.Monad.Coroutine +import Control.Monad.Trans.Free (FreeF (..), FreeT (..)) data SMCConfig m = SMCConfig { resampler :: forall x. PopulationT m x -> PopulationT m x, @@ -34,6 +41,19 @@ data SMCConfig m = SMCConfig numParticles :: Int } +sequentialToPopulation :: (Monad m) => Coroutine.SequentialT (Applicative.PopulationT m) a -> PopulationT m a +sequentialToPopulation = + PopulationT + . weightedT + . coroutineToFree + . Coroutine.runSequentialT + where + coroutineToFree = + FreeT + . fmap (Free . fmap (\(cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont)) + . Applicative.runPopulationT + . resume + -- | Sequential importance resampling. -- Basically an SMC template that takes a custom resampler. smc :: @@ -42,12 +62,15 @@ smc :: Coroutine.SequentialT (PopulationT m) a -> PopulationT m a smc SMCConfig {..} = - Coroutine.sequentially resampler numSteps + (single . flatten) + . Coroutine.sequentially resampler numSteps + . SequentialT.hoist (single . flatten) . Coroutine.hoistFirst (withParticles numParticles) + . SequentialT.hoist (single . flatten) -- | Sequential Monte Carlo with multinomial resampling at each timestep. -- Weights are normalized at each timestep and the total weight is pushed -- as a score into the transformed monad. smcPush :: (MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a -smcPush config = smc config {resampler = (pushEvidence . resampler config)} +smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)} diff --git a/src/Control/Monad/Bayes/Inference/SMC2.hs b/src/Control/Monad/Bayes/Inference/SMC2.hs index 5570a2ba..530d8932 100644 --- a/src/Control/Monad/Bayes/Inference/SMC2.hs +++ b/src/Control/Monad/Bayes/Inference/SMC2.hs @@ -27,8 +27,10 @@ import Control.Monad.Bayes.Class import Control.Monad.Bayes.Inference.MCMC import Control.Monad.Bayes.Inference.RMSMC (rmsmc) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush) -import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT) +import Control.Monad.Bayes.Population as Pop (PopulationT, flatten, resampleMultinomial, runPopulationT, single) +import Control.Monad.Bayes.Population qualified as PopulationT import Control.Monad.Bayes.Sequential.Coroutine (SequentialT) +import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT import Control.Monad.Bayes.Traced import Control.Monad.Trans (MonadTrans (..)) import Numeric.Log (Log) @@ -71,4 +73,10 @@ smc2 k n p t param m = rmsmc MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0} SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial} - (param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . m) + (flattenSequentiallyTraced param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . flattenSMC2 . m) + +flattenSequentiallyTraced :: (Monad m) => SequentialT (TracedT (PopulationT m)) a -> SequentialT (TracedT (PopulationT m)) a +flattenSequentiallyTraced = SequentialT.hoist $ hoistModel (single . flatten) . hoist (single . flatten) + +flattenSMC2 :: (Monad m) => SequentialT (PopulationT (SMC2 m)) a -> SequentialT (PopulationT (SMC2 m)) a + flattenSMC2 = SequentialT.hoist $ single . flatten . PopulationT.hoist (SMC2 . flattenSequentiallyTraced . setup) diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index aa177de1..bad1d01e 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -238,7 +238,7 @@ pushEvidence :: (MonadFactor m) => PopulationT m a -> PopulationT m a -pushEvidence = hoist applyWeight . extractEvidence +pushEvidence = single . flatten . hoist applyWeight . extractEvidence -- | A properly weighted single sample, that is one picked at random according -- to the weights, with the sum of all weights. diff --git a/src/Control/Monad/Bayes/Sequential/Coroutine.hs b/src/Control/Monad/Bayes/Sequential/Coroutine.hs index 926c3db1..8b3b5fc1 100644 --- a/src/Control/Monad/Bayes/Sequential/Coroutine.hs +++ b/src/Control/Monad/Bayes/Sequential/Coroutine.hs @@ -22,6 +22,8 @@ module Control.Monad.Bayes.Sequential.Coroutine hoist, sequentially, sis, + runSequentialT, + extract, ) where diff --git a/src/Control/Monad/Bayes/Traced/Static.hs b/src/Control/Monad/Bayes/Traced/Static.hs index fc99b327..209b507d 100644 --- a/src/Control/Monad/Bayes/Traced/Static.hs +++ b/src/Control/Monad/Bayes/Traced/Static.hs @@ -12,6 +12,7 @@ module Control.Monad.Bayes.Traced.Static ( TracedT (..), hoist, + hoistModel, marginal, mhStep, mh, @@ -25,6 +26,7 @@ import Control.Monad.Bayes.Class MonadMeasure, ) import Control.Monad.Bayes.Density.Free (DensityT) +import Control.Monad.Bayes.Density.Free qualified as DensityT import Control.Monad.Bayes.Traced.Common ( Trace (..), bind, @@ -33,6 +35,7 @@ import Control.Monad.Bayes.Traced.Common singleton, ) import Control.Monad.Bayes.Weighted (WeightedT) +import Control.Monad.Bayes.Weighted qualified as WeightedT import Control.Monad.Trans (MonadTrans (..)) import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) @@ -72,6 +75,9 @@ instance (MonadMeasure m) => MonadMeasure (TracedT m) hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a hoist f (TracedT m d) = TracedT m (f d) +hoistModel :: (Monad m) => (forall x. m x -> m x) -> TracedT m a -> TracedT m a +hoistModel f (TracedT m d) = TracedT (WeightedT.hoist (DensityT.hoist f) m) d + -- | Discard the trace and supporting infrastructure. marginal :: (Monad m) => TracedT m a -> m a marginal (TracedT _ d) = fmap output d @@ -98,15 +104,15 @@ mhStep (TracedT m d) = TracedT m d' -- * What is the probability that it is the weekend? -- -- >>> :{ --- let --- bus = do x <- bernoulli (2/7) --- let rate = if x then 3 else 10 --- factor $ poissonPdf rate 4 --- return x --- mhRunBusSingleObs = do --- let nSamples = 2 --- sampleIOfixed $ unweighted $ mh nSamples bus --- in mhRunBusSingleObs +-- let +-- bus = do x <- bernoulli (2/7) +-- let rate = if x then 3 else 10 +-- factor $ poissonPdf rate 4 +-- return x +-- mhRunBusSingleObs = do +-- let nSamples = 2 +-- sampleIOfixed $ unweighted $ mh nSamples bus +-- in mhRunBusSingleObs -- :} -- [True,True,True] --