diff --git a/docs/docs/usage.md b/docs/docs/usage.md index d6cd9e9e..c0ddf6f0 100644 --- a/docs/docs/usage.md +++ b/docs/docs/usage.md @@ -296,7 +296,7 @@ runWeightedT (WeightedT m) = runStateT m 1 `WeightedT m` is not an instance of `MonadDistribution`, but only as instance of `MonadFactor` (and that, only when `m` is an instance of `Monad`). However, since `StateT` is a monad transformer, there is a function `lift :: m Double -> WeightedT m Double`. -So if we take a `MonadDistribution` instance like `SamplerIO`, then `WeightedT SamplerIO` is an instance of both `MonadDistribution` and `MonadFactor`. Which means it is an instance of `MonadMeasure`. +So if we take a `MonadDistribution` instance like `SamplerIO`, then `Weighted SamplerIO` is an instance of both `MonadDistribution` and `MonadFactor`. Which means it is an instance of `MonadMeasure`. So we can successfully write `(sampler . runWeightedT) sprinkler` and get a program of type `IO (Bool, Log Double)`. When run, this will draw a sample from `sprinkler` along with an **unnormalized** density for that sample. @@ -328,18 +328,20 @@ Summary of key info on `PopulationT`: - `instance MonadFactor m => instance MonadFactor (PopulationT m)` ```haskell -newtype PopulationT m a = PopulationT (WeightedT (ListT m) a) +newtype PopulationT m a = PopulationT (WeightedT (FreeT [] m) a) ``` -So: +The `FreeT []` construction is for branching our probabilistic program into different branches, +corresponding to different choices of a random variable. + +It is interpreted, using `runPopulationT`, to: ```haskell -PopulationT m a ~ m [Log Double -> (a, Log Double)] +m [(a, Log Double)] ``` -Note that while `ListT` isn't in general a valid monad transformer, we're not requiring it to be one here. - -`PopulationT` is used to represent a collection of particles (in the statistical sense), along with their weights. +This shows that `Population` is used to compute a collection of particles (in the statistical sense), along with their weights. +Each `a` corresponds to one particle, and `Log Double` is the type of its weight. There are several useful functions associated with it: @@ -360,7 +362,7 @@ gives [([((),0.5),((),0.5)],1.0)] ``` -Observe how here we have interpreted `(spawn 2)` as of type `PopulationT Enumerator ()`. +Observe how here we have interpreted `(spawn 2)` as of type `Population Enumerator ()`. `resampleGeneric` takes a function to probabilistically select a set of indices from a vector, and makes a new population by selecting those indices. @@ -393,8 +395,8 @@ Summary of key info on `SequentialT`: ```haskell -newtype SequentialT m a = - SequentialT {runSequentialT :: Coroutine (Await ()) m a} +newtype Sequential m a = + Sequential {runSequential :: Coroutine (Await ()) m a} ``` This is a wrapper for the `Coroutine` type applied to the `Await` constructor from `Control.Monad.Coroutine`, which is defined thus: @@ -410,7 +412,7 @@ newtype Await x y = Await (x -> y) Unpacking that: ```haskell -SequentialT m a ~ m (Either (() -> SequentialT m a) a) +Sequential m a ~ m (Either (() -> Sequential m a) a) ``` As usual, `m` is going to be some other probability monad, so understand `SequentialT m a` as representing a program which, after making a random choice or doing conditioning, we either obtain an `a` value, or a paused computation, which when resumed gets us back to a new `SequentialT m a`. @@ -501,11 +503,11 @@ The latter is best understood if you're familiar with the standard use of a free ```haskell newtype SamF a = Random (Double -> a) -newtype DensityT m a = - DensityT {getDensityT :: FT SamF m a} +newtype Density m a = + Density {density :: FT SamF m a} -instance Monad m => MonadDistribution (DensityT m) where - random = DensityT $ liftF (Random id) +instance Monad m => MonadDistribution (Density m) where + random = Density $ liftF (Random id) ``` The monad-bayes implementation uses a more efficient implementation of `FreeT`, namely `FT` from the `free` package, known as the *Church transformed Free monad*. This is a technique explained in https://begriffs.com/posts/2016-02-04-difference-lists-and-codennsity.html. But that only changes the operational semantics - performance aside, it works just the same as the standard `FreeT` datatype. @@ -575,7 +577,7 @@ data Trace a = Trace } ``` -We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `DensityT` is for. +We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `Density` is for. The simplest version of `TracedT` is in `Control.Monad.Bayes.TracedT.Basic` @@ -635,13 +637,13 @@ example = do return x ``` -`(enumerator . runWeightedT) example` gives `[((False,0.0),0.5),((True,1.0),0.5)]`. This is quite edifying for understanding `(sampler . runWeightedT) example`. What it says is that there are precisely two ways the program will run, each with equal probability: either you get `False` with weight `0.0` or `True` with weight `1.0`. +`(enumerator . weighted) example` gives `[((False,0.0),0.5),((True,1.0),0.5)]`. This is quite edifying for understanding `(sampler . weighted) example`. What it says is that there are precisely two ways the program will run, each with equal probability: either you get `False` with weight `0.0` or `True` with weight `1.0`. ### Quadrature As described on the section on `Integrator`, we can interpret our probabilistic program of type `MonadDistribution m => m a` as having concrete type `Integrator a`. This views our program as an integrator, allowing us to calculate expectations, probabilities and so on via quadrature (i.e. numerical approximation of an integral). -This can also handle programs of type `MonadMeasure m => m a`, that is, programs with `factor` statements. For these cases, a function `normalize :: WeightedT Integrator a -> Integrator a` is employed. For example, +This can also handle programs of type `MonadMeasure m => m a`, that is, programs with `factor` statements. For these cases, a function `normalize :: Weighted Integrator a -> Integrator a` is employed. For example, ```haskell model :: MonadMeasure m => m Double @@ -652,7 +654,7 @@ model = do return var ``` -is really an unnormalized measure, rather than a probability distribution. `normalize` views it as of type `WeightedT Integrator Double`, which is isomorphic to `(Double -> (Double, Log Double) -> Double)`. This can be used to compute the normalization constant, and divide the integrator's output by it, all within `Integrator`. +is really an unnormalized measure, rather than a probability distribution. `normalize` views it as of type `Weighted Integrator Double`, which is isomorphic to `(Double -> (Double, Log Double) -> Double)`. This can be used to compute the normalization constant, and divide the integrator's output by it, all within `Integrator`. ### Independent forward sampling @@ -796,7 +798,7 @@ pmmh :: pmmh mcmcConf smcConf param model = (mcmc mcmcConf :: T m [(a, Log Double)] -> m [[(a, Log Double)]]) ((param :: T m b) >>= - (runPopulationT :: P (T m) a -> T m [(a, Log Double)]) + (population :: P (T m) a -> T m [(a, Log Double)]) . (pushEvidence :: P (T m) a -> P (T m) a) . Pop.hoist (lift :: forall x. m x -> T m x) . (smc smcConf :: S (P m) a -> P m a) diff --git a/monad-bayes.cabal b/monad-bayes.cabal index 8dc7b22b..ab68cf98 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -63,21 +63,21 @@ common deps , scientific ^>=0.3 , statistics >=0.14.0 && <0.17 , text >=1.2 && <2.1 + , transformers ^>=0.5.6 , vector >=0.12.0 && <0.14 , vty ^>=5.38 common test-deps build-depends: , abstract-par ^>=0.3 - , criterion >=1.5 && <1.7 + , criterion >=1.5 && <1.7 , directory ^>=1.3 , hspec ^>=2.11 , monad-bayes - , optparse-applicative >=0.17 && <0.19 + , optparse-applicative >=0.17 && <0.19 , process ^>=1.6 , QuickCheck ^>=2.14 - , time >=1.9 && <1.13 - , transformers ^>=0.5.6 + , time >=1.9 && <1.13 , typed-process ^>=0.2 autogen-modules: Paths_monad_bayes @@ -86,6 +86,7 @@ common test-deps library import: deps exposed-modules: + Control.Applicative.List Control.Monad.Bayes.Class Control.Monad.Bayes.Density.Free Control.Monad.Bayes.Density.State @@ -100,6 +101,7 @@ library Control.Monad.Bayes.Inference.TUI Control.Monad.Bayes.Integrator Control.Monad.Bayes.Population + Control.Monad.Bayes.Population.Applicative Control.Monad.Bayes.Sampler.Lazy Control.Monad.Bayes.Sampler.Strict Control.Monad.Bayes.Sequential.Coroutine @@ -114,8 +116,11 @@ library other-modules: Control.Monad.Bayes.Traced.Common default-language: Haskell2010 default-extensions: + ApplicativeDo BlockArguments + DerivingStrategies FlexibleContexts + GeneralizedNewtypeDeriving ImportQualifiedPost LambdaCase OverloadedStrings diff --git a/src/Control/Applicative/List.hs b/src/Control/Applicative/List.hs new file mode 100644 index 00000000..a6a0a99a --- /dev/null +++ b/src/Control/Applicative/List.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE StandaloneDeriving #-} + +module Control.Applicative.List where + +-- base +import Control.Applicative +import Data.Functor.Compose + +-- * Applicative ListT + +-- | _Applicative_ transformer adding a list/nondeterminism/choice effect. +-- It is not a valid monad transformer, but it is a valid 'Applicative'. +newtype ListT m a = ListT {getListT :: Compose m [] a} + deriving newtype (Functor, Applicative, Alternative) + +listT :: m [a] -> ListT m a +listT = ListT . Compose + +lift :: (Functor m) => m a -> ListT m a +lift = ListT . Compose . fmap pure + +runListT :: ListT m a -> m [a] +runListT = getCompose . getListT diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index 6a8c1803..4d36fbb6 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -79,9 +79,9 @@ import Control.Monad (replicateM, when) import Control.Monad.Cont (ContT) import Control.Monad.Except (ExceptT, lift) import Control.Monad.Identity (IdentityT) -import Control.Monad.List (ListT) import Control.Monad.Reader (ReaderT) import Control.Monad.State (StateT) +import Control.Monad.Trans.Free (FreeT) import Control.Monad.Writer (WriterT) import Data.Histogram qualified as H import Data.Histogram.Fill qualified as H @@ -390,15 +390,15 @@ instance (MonadFactor m) => MonadFactor (StateT s m) where instance (MonadMeasure m) => MonadMeasure (StateT s m) -instance (MonadDistribution m) => MonadDistribution (ListT m) where +instance (Applicative f, (MonadDistribution m)) => MonadDistribution (FreeT f m) where random = lift random bernoulli = lift . bernoulli categorical = lift . categorical -instance (MonadFactor m) => MonadFactor (ListT m) where +instance (Applicative f, (MonadFactor m)) => MonadFactor (FreeT f m) where score = lift . score -instance (MonadMeasure m) => MonadMeasure (ListT m) +instance (Applicative f, (MonadMeasure m)) => MonadMeasure (FreeT f m) instance (MonadDistribution m) => MonadDistribution (ContT r m) where random = lift random diff --git a/src/Control/Monad/Bayes/Inference/RMSMC.hs b/src/Control/Monad/Bayes/Inference/RMSMC.hs index d9bc8a9b..f86a4c33 100644 --- a/src/Control/Monad/Bayes/Inference/RMSMC.hs +++ b/src/Control/Monad/Bayes/Inference/RMSMC.hs @@ -25,7 +25,8 @@ import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..)) import Control.Monad.Bayes.Inference.SMC import Control.Monad.Bayes.Population ( PopulationT, - spawn, + flatten, + single, withParticles, ) import Control.Monad.Bayes.Sequential.Coroutine as Seq @@ -50,8 +51,8 @@ rmsmc :: PopulationT m a rmsmc (MCMCConfig {..}) (SMCConfig {..}) = marginal - . S.sequentially (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler) numSteps - . S.hoistFirst (TrStat.hoist (spawn numParticles >>)) + . S.sequentially (composeCopies numMCMCSteps (TrStat.hoistModel (single . flatten) . TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps + . S.hoistFirst (TrStat.hoistModel (single . flatten) . TrStat.hoist (withParticles numParticles)) -- | Resample-move Sequential Monte Carlo with a more efficient -- tracing representation. @@ -64,7 +65,7 @@ rmsmcBasic :: PopulationT m a rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) = TrBas.marginal - . S.sequentially (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler) numSteps + . S.sequentially (TrBas.hoist (single . flatten) . composeCopies numMCMCSteps (TrBas.hoist (single . flatten) . TrBas.mhStep) . TrBas.hoist resampler) numSteps . S.hoistFirst (TrBas.hoist (withParticles numParticles)) -- | A variant of resample-move Sequential Monte Carlo @@ -79,7 +80,7 @@ rmsmcDynamic :: PopulationT m a rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) = TrDyn.marginal - . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps TrDyn.mhStep . TrDyn.hoist resampler) numSteps + . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist (single . flatten) . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps . S.hoistFirst (TrDyn.hoist (withParticles numParticles)) -- | Apply a function a given number of times. diff --git a/src/Control/Monad/Bayes/Inference/SMC.hs b/src/Control/Monad/Bayes/Inference/SMC.hs index 3f3a30b2..a729dc85 100644 --- a/src/Control/Monad/Bayes/Inference/SMC.hs +++ b/src/Control/Monad/Bayes/Inference/SMC.hs @@ -22,11 +22,18 @@ where import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure) import Control.Monad.Bayes.Population - ( PopulationT, + ( PopulationT (..), + flatten, pushEvidence, + single, withParticles, ) +import Control.Monad.Bayes.Population.Applicative qualified as Applicative import Control.Monad.Bayes.Sequential.Coroutine as Coroutine +import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT +import Control.Monad.Bayes.Weighted (WeightedT (..), weightedT) +import Control.Monad.Coroutine +import Control.Monad.Trans.Free (FreeF (..), FreeT (..)) data SMCConfig m = SMCConfig { resampler :: forall x. PopulationT m x -> PopulationT m x, @@ -34,6 +41,19 @@ data SMCConfig m = SMCConfig numParticles :: Int } +sequentialToPopulation :: (Monad m) => Coroutine.SequentialT (Applicative.PopulationT m) a -> PopulationT m a +sequentialToPopulation = + PopulationT + . weightedT + . coroutineToFree + . Coroutine.runSequentialT + where + coroutineToFree = + FreeT + . fmap (Free . fmap (\(cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont)) + . Applicative.runPopulationT + . resume + -- | Sequential importance resampling. -- Basically an SMC template that takes a custom resampler. smc :: @@ -42,12 +62,15 @@ smc :: Coroutine.SequentialT (PopulationT m) a -> PopulationT m a smc SMCConfig {..} = - Coroutine.sequentially resampler numSteps + (single . flatten) + . Coroutine.sequentially resampler numSteps + . SequentialT.hoist (single . flatten) . Coroutine.hoistFirst (withParticles numParticles) + . SequentialT.hoist (single . flatten) -- | Sequential Monte Carlo with multinomial resampling at each timestep. -- Weights are normalized at each timestep and the total weight is pushed -- as a score into the transformed monad. smcPush :: (MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a -smcPush config = smc config {resampler = (pushEvidence . resampler config)} +smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)} diff --git a/src/Control/Monad/Bayes/Inference/SMC2.hs b/src/Control/Monad/Bayes/Inference/SMC2.hs index 5570a2ba..530d8932 100644 --- a/src/Control/Monad/Bayes/Inference/SMC2.hs +++ b/src/Control/Monad/Bayes/Inference/SMC2.hs @@ -27,8 +27,10 @@ import Control.Monad.Bayes.Class import Control.Monad.Bayes.Inference.MCMC import Control.Monad.Bayes.Inference.RMSMC (rmsmc) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush) -import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT) +import Control.Monad.Bayes.Population as Pop (PopulationT, flatten, resampleMultinomial, runPopulationT, single) +import Control.Monad.Bayes.Population qualified as PopulationT import Control.Monad.Bayes.Sequential.Coroutine (SequentialT) +import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT import Control.Monad.Bayes.Traced import Control.Monad.Trans (MonadTrans (..)) import Numeric.Log (Log) @@ -71,4 +73,10 @@ smc2 k n p t param m = rmsmc MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0} SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial} - (param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . m) + (flattenSequentiallyTraced param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . flattenSMC2 . m) + +flattenSequentiallyTraced :: (Monad m) => SequentialT (TracedT (PopulationT m)) a -> SequentialT (TracedT (PopulationT m)) a +flattenSequentiallyTraced = SequentialT.hoist $ hoistModel (single . flatten) . hoist (single . flatten) + +flattenSMC2 :: (Monad m) => SequentialT (PopulationT (SMC2 m)) a -> SequentialT (PopulationT (SMC2 m)) a + flattenSMC2 = SequentialT.hoist $ single . flatten . PopulationT.hoist (SMC2 . flattenSequentiallyTraced . setup) diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 2a384f77..bad1d01e 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -34,17 +34,21 @@ module Control.Monad.Bayes.Population collapse, popAvg, withParticles, + flatten, + single, ) where +import Control.Applicative (Alternative) import Control.Arrow (second) -import Control.Monad (replicateM) +import Control.Monad (MonadPlus, replicateM) import Control.Monad.Bayes.Class ( MonadDistribution (categorical, logCategorical, random, uniform), MonadFactor, MonadMeasure, factor, ) +import Control.Monad.Bayes.Population.Applicative qualified as Applicative import Control.Monad.Bayes.Weighted ( WeightedT, applyWeight, @@ -52,7 +56,10 @@ import Control.Monad.Bayes.Weighted runWeightedT, weightedT, ) -import Control.Monad.List (ListT (..), MonadIO, MonadTrans (..)) +import Control.Monad.Bayes.Weighted qualified as Weighted +import Control.Monad.IO.Class +import Control.Monad.Trans +import Control.Monad.Trans.Free import Data.List (unfoldr) import Data.List qualified import Data.Maybe (catMaybes) @@ -63,24 +70,29 @@ import Numeric.Log qualified as Log import Prelude hiding (all, sum) -- | A collection of weighted samples, or particles. -newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (ListT m) a} - deriving newtype (Functor, Applicative, Monad, MonadIO, MonadDistribution, MonadFactor, MonadMeasure) +-- +-- This monad transformer is internally represented as a free monad, +-- which means that each layer of its computation contains a collection of weighted samples. +-- These can be flattened with 'flatten', +-- but the result is not a monad anymore. +newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (FreeT [] m) a} + deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO, MonadPlus, MonadDistribution, MonadFactor, MonadMeasure) instance MonadTrans PopulationT where lift = PopulationT . lift . lift -- | Explicit representation of the weighted sample with weights in the log -- domain. -runPopulationT :: PopulationT m a -> m [(a, Log Double)] -runPopulationT = runListT . runWeightedT . getPopulationT +runPopulationT :: (Monad m) => PopulationT m a -> m [(a, Log Double)] +runPopulationT = iterT (fmap concat . sequence) . fmap pure . runWeightedT . getPopulationT -- | Explicit representation of the weighted sample. -explicitPopulation :: (Functor m) => PopulationT m a -> m [(a, Double)] +explicitPopulation :: (Monad m) => PopulationT m a -> m [(a, Double)] explicitPopulation = fmap (map (second (exp . ln))) . runPopulationT -- | Initialize 'PopulationT' with a concrete weighted sample. fromWeightedList :: (Monad m) => m [(a, Log Double)] -> PopulationT m a -fromWeightedList = PopulationT . weightedT . ListT +fromWeightedList = PopulationT . weightedT . FreeT . fmap (Free . fmap pure) -- | Increase the sample size by a given factor. -- The weights are adjusted such that their sum is preserved. @@ -226,7 +238,7 @@ pushEvidence :: (MonadFactor m) => PopulationT m a -> PopulationT m a -pushEvidence = hoist applyWeight . extractEvidence +pushEvidence = single . flatten . hoist applyWeight . extractEvidence -- | A properly weighted single sample, that is one picked at random according -- to the weights, with the sum of all weights. @@ -265,8 +277,18 @@ popAvg f p = do -- | Applies a transformation to the inner monad. hoist :: - (Monad n) => + (Monad m, (Monad n)) => (forall x. m x -> n x) -> PopulationT m a -> PopulationT n a -hoist f = fromWeightedList . f . runPopulationT +hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT + +-- | Flatten all layers of the free structure. +flatten :: (Monad m) => PopulationT m a -> Applicative.PopulationT m a +flatten = Applicative.fromWeightedList . runPopulationT + +-- | Create a population from a single layer of branching computations. +-- +-- Similar to 'fromWeightedListT'. +single :: (Monad m) => Applicative.PopulationT m a -> PopulationT m a +single = fromWeightedList . Applicative.runPopulationT diff --git a/src/Control/Monad/Bayes/Population/Applicative.hs b/src/Control/Monad/Bayes/Population/Applicative.hs new file mode 100644 index 00000000..7f13d484 --- /dev/null +++ b/src/Control/Monad/Bayes/Population/Applicative.hs @@ -0,0 +1,29 @@ +-- | 'PopulationT' turns a single sample into a collection of weighted samples. +-- +-- This module contains an _'Applicative'_ transformer corresponding to the Population monad transformer from the article. +-- It is based on the old-fashioned 'ListT', which is not a valid monad transformer, but a valid applicative transformer. +-- The corresponding monad transformer is contained in 'Control.Monad.Bayes.Population'. +-- One can convert from the monad transformer to the applicative transformer by 'flatten'ing. +module Control.Monad.Bayes.Population.Applicative where + +import Control.Applicative +import Control.Applicative.List +import Control.Monad.Trans.Writer.Strict +import Data.Functor.Compose +import Numeric.Log (Log) + +-- * Applicative Population transformer + +-- WriterT has to be used instead of WeightedT, +-- since WeightedT uses StateT under the hood, +-- which requires a Monad (ListT m) constraint. + +-- | A collection of weighted samples, or particles. +newtype PopulationT m a = PopulationT {getPopulationT :: WriterT (Log Double) (ListT m) a} + deriving newtype (Functor, Applicative, Alternative) + +runPopulationT :: PopulationT m a -> m [(a, Log Double)] +runPopulationT = runListT . runWriterT . getPopulationT + +fromWeightedList :: m [(a, Log Double)] -> PopulationT m a +fromWeightedList = PopulationT . WriterT . listT diff --git a/src/Control/Monad/Bayes/Sequential/Coroutine.hs b/src/Control/Monad/Bayes/Sequential/Coroutine.hs index 926c3db1..8b3b5fc1 100644 --- a/src/Control/Monad/Bayes/Sequential/Coroutine.hs +++ b/src/Control/Monad/Bayes/Sequential/Coroutine.hs @@ -22,6 +22,8 @@ module Control.Monad.Bayes.Sequential.Coroutine hoist, sequentially, sis, + runSequentialT, + extract, ) where diff --git a/src/Control/Monad/Bayes/Traced/Static.hs b/src/Control/Monad/Bayes/Traced/Static.hs index fc99b327..209b507d 100644 --- a/src/Control/Monad/Bayes/Traced/Static.hs +++ b/src/Control/Monad/Bayes/Traced/Static.hs @@ -12,6 +12,7 @@ module Control.Monad.Bayes.Traced.Static ( TracedT (..), hoist, + hoistModel, marginal, mhStep, mh, @@ -25,6 +26,7 @@ import Control.Monad.Bayes.Class MonadMeasure, ) import Control.Monad.Bayes.Density.Free (DensityT) +import Control.Monad.Bayes.Density.Free qualified as DensityT import Control.Monad.Bayes.Traced.Common ( Trace (..), bind, @@ -33,6 +35,7 @@ import Control.Monad.Bayes.Traced.Common singleton, ) import Control.Monad.Bayes.Weighted (WeightedT) +import Control.Monad.Bayes.Weighted qualified as WeightedT import Control.Monad.Trans (MonadTrans (..)) import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) @@ -72,6 +75,9 @@ instance (MonadMeasure m) => MonadMeasure (TracedT m) hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a hoist f (TracedT m d) = TracedT m (f d) +hoistModel :: (Monad m) => (forall x. m x -> m x) -> TracedT m a -> TracedT m a +hoistModel f (TracedT m d) = TracedT (WeightedT.hoist (DensityT.hoist f) m) d + -- | Discard the trace and supporting infrastructure. marginal :: (Monad m) => TracedT m a -> m a marginal (TracedT _ d) = fmap output d @@ -98,15 +104,15 @@ mhStep (TracedT m d) = TracedT m d' -- * What is the probability that it is the weekend? -- -- >>> :{ --- let --- bus = do x <- bernoulli (2/7) --- let rate = if x then 3 else 10 --- factor $ poissonPdf rate 4 --- return x --- mhRunBusSingleObs = do --- let nSamples = 2 --- sampleIOfixed $ unweighted $ mh nSamples bus --- in mhRunBusSingleObs +-- let +-- bus = do x <- bernoulli (2/7) +-- let rate = if x then 3 else 10 +-- factor $ poissonPdf rate 4 +-- return x +-- mhRunBusSingleObs = do +-- let nSamples = 2 +-- sampleIOfixed $ unweighted $ mh nSamples bus +-- in mhRunBusSingleObs -- :} -- [True,True,True] -- diff --git a/src/Control/Monad/Bayes/Weighted.hs b/src/Control/Monad/Bayes/Weighted.hs index a1bbe44a..843f430f 100644 --- a/src/Control/Monad/Bayes/Weighted.hs +++ b/src/Control/Monad/Bayes/Weighted.hs @@ -24,6 +24,8 @@ module Control.Monad.Bayes.Weighted ) where +import Control.Applicative (Alternative) +import Control.Monad (MonadPlus) import Control.Monad.Bayes.Class ( MonadDistribution, MonadFactor (..), @@ -36,7 +38,7 @@ import Numeric.Log (Log) -- | Execute the program using the prior distribution, while accumulating likelihood. newtype WeightedT m a = WeightedT (StateT (Log Double) m a) -- StateT is more efficient than WriterT - deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadDistribution) + deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO, MonadPlus, MonadTrans, MonadDistribution) instance (Monad m) => MonadFactor (WeightedT m) where score w = WeightedT (modify (* w))