diff --git a/benchmark/Speed.hs b/benchmark/Speed.hs index 0d72808c..cfb7eb28 100644 --- a/benchmark/Speed.hs +++ b/benchmark/Speed.hs @@ -4,12 +4,14 @@ module Main (main) where -import Control.Monad.Bayes.Class (MonadMeasure) +import Control.Monad (replicateM) +import Control.Monad.Bayes.Class (MonadMeasure, normal) import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (MCMCConfig, numBurnIn, numMCMCSteps, proposal), Proposal (SingleSiteMH)) import Control.Monad.Bayes.Inference.RMSMC (rmsmcDynamic) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smc) import Control.Monad.Bayes.Population (resampleSystematic, runPopulationT) import Control.Monad.Bayes.Sampler.Strict (SamplerIO, sampleIOfixed) +import Control.Monad.Bayes.Sampler.StrictFu qualified as FU import Control.Monad.Bayes.Traced (mh) import Control.Monad.Bayes.Weighted (unweighted) import Criterion.Main @@ -136,8 +138,14 @@ samplesBenchmarks lrData hmmData ldaData = benchmarks m <- models return (s, m, a) -speedLengthCSV :: FilePath -speedLengthCSV = "speed-length.csv" +normalBenchmarks :: [Benchmark] +normalBenchmarks = [ bench "Normal single sample monad bayes" $ nfIO $ do + sampleIOfixed (do xs <- replicateM 1000 $ normal 0.0 1.0 + return $ sum xs) + , bench "Normal single sample monad bayes fu" $ nfIO $ do + FU.sampleIOfixed (do xs <- replicateM 1000 $ normal 0.0 1.0 + return $ sum xs) + ] speedSamplesCSV :: FilePath speedSamplesCSV = "speed-samples.csv" @@ -146,7 +154,7 @@ rawDAT :: FilePath rawDAT = "raw.dat" cleanupLastRun :: IO () -cleanupLastRun = mapM_ removeIfExists [speedLengthCSV, speedSamplesCSV, rawDAT] +cleanupLastRun = mapM_ removeIfExists [speedSamplesCSV, rawDAT] removeIfExists :: FilePath -> IO () removeIfExists file = do @@ -162,8 +170,10 @@ main = do lrData <- sampleIOfixed (LogReg.syntheticData 1000) hmmData <- sampleIOfixed (HMM.syntheticData 1000) ldaData <- sampleIOfixed (LDA.syntheticData 5 1000) - let configLength = defaultConfig {csvFile = Just speedLengthCSV, rawDataFile = Just rawDAT} - defaultMainWith configLength (lengthBenchmarks lrData hmmData ldaData) - let configSamples = defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT} - defaultMainWith configSamples (samplesBenchmarks lrData hmmData ldaData) + defaultMainWith defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT} + (concat [ lengthBenchmarks lrData hmmData ldaData + , samplesBenchmarks lrData hmmData ldaData + , normalBenchmarks + ] + ) void $ runProcess "python plots.py" diff --git a/monad-bayes.cabal b/monad-bayes.cabal index c576c005..068c1177 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -57,6 +57,7 @@ common deps , pretty-simple ^>=4.1 , primitive >=0.7 && <0.9 , random ^>=1.2 + , random-fu , safe ^>=0.3.17 , scientific ^>=0.3 , statistics >=0.14.0 && <0.17 @@ -97,6 +98,7 @@ library Control.Monad.Bayes.Population Control.Monad.Bayes.Sampler.Lazy Control.Monad.Bayes.Sampler.Strict + Control.Monad.Bayes.Sampler.StrictFu Control.Monad.Bayes.Sequential.Coroutine Control.Monad.Bayes.Traced Control.Monad.Bayes.Traced.Basic diff --git a/shell.nix b/shell.nix index e6d91731..bb826e7a 100644 --- a/shell.nix +++ b/shell.nix @@ -1,14 +1,39 @@ -( - import - ( - let - lock = builtins.fromJSON (builtins.readFile ./flake.lock); - in - fetchTarball { - url = "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz"; - sha256 = lock.nodes.flake-compat.locked.narHash; - } - ) - {src = ./.;} -) -.shellNix +let + +myHaskellPackageOverlay = self: super: { + myHaskellPackages = super.haskellPackages.override { + overrides = hself: hsuper: rec { + }; + }; +}; + +in + +{ nixpkgs ? import { overlays = [ myHaskellPackageOverlay ]; }, compiler ? "default", doBenchmark ? false }: + + +let + + pkgs = nixpkgs; + + haskellDeps = ps: with ps; [ + abstract-par base brick containers criterion directory foldl free + histogram-fill hspec ieee754 integration lens linear log-domain + math-functions matrix monad-coroutine monad-extras mtl mwc-random + optparse-applicative pipes pretty-simple primitive process + QuickCheck random random-fu safe scientific statistics text time transformers + typed-process vector vty + ]; + +in + +pkgs.stdenv.mkDerivation { + name = "whatever"; + + buildInputs = [ + pkgs.libintlOrEmpty + pkgs.cabal-install + (pkgs.myHaskellPackages.ghcWithPackages haskellDeps) + pkgs.darwin.apple_sdk.frameworks.Cocoa + ]; +} diff --git a/src/Control/Monad/Bayes/Sampler/StrictFu.hs b/src/Control/Monad/Bayes/Sampler/StrictFu.hs new file mode 100644 index 00000000..b060e683 --- /dev/null +++ b/src/Control/Monad/Bayes/Sampler/StrictFu.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE ApplicativeDo #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} + +-- | +-- Module : Control.Monad.Bayes.Sampler +-- Description : Pseudo-random sampling monads +-- Copyright : (c) Adam Scibior, 2015-2020 +-- License : MIT +-- Maintainer : leonhard.markert@tweag.io +-- Stability : experimental +-- Portability : GHC +-- +-- 'SamplerIO' and 'SamplerST' are instances of 'MonadDistribution'. Apply a 'MonadFactor' +-- transformer to obtain a 'MonadMeasure' that can execute probabilistic models. +module Control.Monad.Bayes.Sampler.StrictFu + ( SamplerT (..), + SamplerIO, + SamplerST, + sampleIO, + sampleIOfixed, + sampleWith, + sampleSTfixed, + sampleMean, + sampler, + ) +where + +import Control.Foldl qualified as F hiding (random) +import Control.Monad.Bayes.Class + ( MonadDistribution + ( bernoulli, + beta, + -- categorical, + gamma, + -- geometric, + normal, + random, + uniform + ), + ) +import Control.Monad.Reader (ReaderT (..)) +import Control.Monad.ST (ST) +import Control.Monad.State +import Numeric.Log (Log (ln)) +import Data.Random qualified as RF +import Data.Random.Distribution.Beta qualified as RF +import Data.Random.Distribution.Bernoulli qualified as RF +import Data.Random.Distribution.Uniform as RF +import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM) + + +-- | The sampling interpretation of a probabilistic program +-- Here m is typically IO or ST +newtype SamplerT g m a = SamplerT {runSamplerT :: ReaderT g m a} deriving (Functor, Applicative, Monad, MonadIO) + +-- | convenient type synonym to show specializations of SamplerT +-- to particular pairs of monad and RNG +type SamplerIO = SamplerT (IOGenM StdGen) IO + +-- | convenient type synonym to show specializations of SamplerT +-- to particular pairs of monad and RNG +type SamplerST s = SamplerT (STGenM StdGen s) (ST s) + +instance StatefulGen g m => MonadDistribution (SamplerT g m) where + random = SamplerT (ReaderT $ RF.runRVar $ RF.stdUniform) + + uniform a b = SamplerT (ReaderT $ RF.runRVar $ RF.doubleUniform a b) + normal m s = SamplerT (ReaderT $ RF.runRVar $ RF.normal m s) + gamma shape scale = SamplerT (ReaderT $ RF.runRVar $ RF.gamma shape scale) + beta a b = SamplerT (ReaderT $ RF.runRVar $ RF.beta a b) + + bernoulli p = SamplerT (ReaderT $ RF.runRVar $ RF.bernoulli p) + -- categorical ps = error "categorical" + -- geometric p = error "geometric" + +-- | Sample with a random number generator of your choice e.g. the one +-- from `System.Random`. +-- +-- >>> import Control.Monad.Bayes.Class +-- >>> import System.Random.Stateful hiding (random) +-- >>> newIOGenM (mkStdGen 1729) >>= sampleWith random +-- 4.690861245089605e-2 +sampleWith :: SamplerT g m a -> g -> m a +sampleWith (SamplerT m) = runReaderT m + +-- | initialize random seed using system entropy, and sample +sampleIO, sampler :: SamplerIO a -> IO a +sampleIO x = initStdGen >>= newIOGenM >>= sampleWith x +sampler = sampleIO + +-- | Run the sampler with a fixed random seed +sampleIOfixed :: SamplerIO a -> IO a +sampleIOfixed x = newIOGenM (mkStdGen 1729) >>= sampleWith x + +-- | Run the sampler with a fixed random seed +sampleSTfixed :: SamplerST s b -> ST s b +sampleSTfixed x = newSTGenM (mkStdGen 1729) >>= sampleWith x + +sampleMean :: [(Double, Log Double)] -> Double +sampleMean samples = + let z = F.premap (ln . exp . snd) F.sum + w = (F.premap (\(x, y) -> x * ln (exp y)) F.sum) + s = (/) <$> w <*> z + in F.fold s samples