diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml index 6d84b0e..599bcad 100644 --- a/.github/workflows/Documentation.yml +++ b/.github/workflows/Documentation.yml @@ -28,6 +28,8 @@ jobs: pkg: - name: TuringBenchmarking dir: './TuringBenchmarking' + - name: TuringCallbacks + dir: './TuringCallbacks' steps: - name: Build and deploy diff --git a/TuringCallbacks/.github/workflows/CompatHelper.yml b/TuringCallbacks/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000..cba9134 --- /dev/null +++ b/TuringCallbacks/.github/workflows/CompatHelper.yml @@ -0,0 +1,16 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/TuringCallbacks/.github/workflows/Docs.yml b/TuringCallbacks/.github/workflows/Docs.yml new file mode 100644 index 0000000..915f26a --- /dev/null +++ b/TuringCallbacks/.github/workflows/Docs.yml @@ -0,0 +1,32 @@ +name: Documentation + +on: + push: + branches: + - main + tags: '*' + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +permissions: + contents: write + pull-requests: write + +jobs: + docs: + runs-on: ubuntu-latest + + steps: + - name: Build and deploy Documenter.jl docs + uses: TuringLang/actions/DocsDocumenter@main + + - run: | + julia --project=docs -e ' + using Documenter: doctest + using TuringCallbacks + doctest(TuringCallbacks)' diff --git a/TuringCallbacks/.github/workflows/DocsNav.yml b/TuringCallbacks/.github/workflows/DocsNav.yml new file mode 100644 index 0000000..7e86195 --- /dev/null +++ b/TuringCallbacks/.github/workflows/DocsNav.yml @@ -0,0 +1,39 @@ +name: Rebuild docs with newest navbar + +on: + # 3:25 AM UTC every Sunday -- choose an uncommon time to avoid + # periods of heavy GitHub Actions usage + schedule: + - cron: '25 3 * * 0' + # Whenever needed + workflow_dispatch: + +permissions: + contents: write + +jobs: + update-navbar: + runs-on: ubuntu-latest + + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v4 + with: + ref: gh-pages + + - name: Insert navbar + uses: TuringLang/actions/DocsNav@main + with: + doc-path: '.' 
+ + - name: Commit and push changes + run: | + if [[ -n $(git status -s) ]]; then + git config user.name github-actions[bot] + git config user.email github-actions[bot]@users.noreply.github.com + git add -A + git commit -m "Update navbar (automated)" + git push + else + echo "No changes to commit" + fi diff --git a/TuringCallbacks/.github/workflows/TagBot.yml b/TuringCallbacks/.github/workflows/TagBot.yml new file mode 100644 index 0000000..f49313b --- /dev/null +++ b/TuringCallbacks/.github/workflows/TagBot.yml @@ -0,0 +1,15 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/TuringCallbacks/.github/workflows/ci.yml b/TuringCallbacks/.github/workflows/ci.yml new file mode 100644 index 0000000..ee2e251 --- /dev/null +++ b/TuringCallbacks/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +name: CI +on: + push: + branches: + - main + pull_request: + +# needed to allow julia-actions/cache to delete old caches that it has created +permissions: + actions: write + contents: read + +# Cancel existing tests on the same PR if a new commit is added to a pull request +concurrency: + group: ${{ github.workflow }}-${{ github.ref || github.run_id }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macOS-latest + - windows-latest + + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + - uses: julia-actions/cache@v2 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + token: ${{ 
secrets.CODECOV_TOKEN }} + fail_ci_if_error: true diff --git a/TuringCallbacks/.gitignore b/TuringCallbacks/.gitignore new file mode 100644 index 0000000..3804c22 --- /dev/null +++ b/TuringCallbacks/.gitignore @@ -0,0 +1,5 @@ +*.jl.*.cov +*.jl.cov +*.jl.mem +/docs/build/ +Manifest.toml diff --git a/TuringCallbacks/LICENSE b/TuringCallbacks/LICENSE new file mode 100644 index 0000000..0020dd3 --- /dev/null +++ b/TuringCallbacks/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Tor + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/TuringCallbacks/Project.toml b/TuringCallbacks/Project.toml new file mode 100644 index 0000000..30dc900 --- /dev/null +++ b/TuringCallbacks/Project.toml @@ -0,0 +1,34 @@ +name = "TuringCallbacks" +uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c" +version = "0.4.4" +authors = ["Tor Erlend Fjelde and contributors"] + +[deps] +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" +TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" + +[weakdeps] +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" + +[extensions] +TuringCallbacksTuringExt = "Turing" + +[compat] +DataStructures = "0.18 - 0.19" +DocStringExtensions = "0.8, 0.9" +OnlineStats = "1.5" +Reexport = "0.2, 1.0" +Requires = "1" +TensorBoardLogger = "0.1.22" +Turing = "0.39, 0.40, 0.41, 0.42" +julia = "1" + +[extras] +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/TuringCallbacks/README.md b/TuringCallbacks/README.md new file mode 100644 index 0000000..22ed85b --- /dev/null +++ b/TuringCallbacks/README.md @@ -0,0 +1,8 @@ +# TuringCallbacks + +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.github.io/Deprecated/TuringCallbacks.jl/stable) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://turinglang.github.io/Deprecated/TuringCallbacks.jl/dev) + +A package containing some convenient callbacks to use when you `sample` in [`Turing.jl`](https://github.com/TuringLang/Turing.jl). + +See the dev-docs for more information. 
diff --git a/TuringCallbacks/docs/Project.toml b/TuringCallbacks/docs/Project.toml new file mode 100644 index 0000000..fbef03d --- /dev/null +++ b/TuringCallbacks/docs/Project.toml @@ -0,0 +1,6 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +TuringCallbacks = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c" + +[compat] +Documenter = "1" diff --git a/TuringCallbacks/docs/make.jl b/TuringCallbacks/docs/make.jl new file mode 100644 index 0000000..8324519 --- /dev/null +++ b/TuringCallbacks/docs/make.jl @@ -0,0 +1,17 @@ +using TuringCallbacks +using Documenter + +makedocs(; + modules=[TuringCallbacks], + authors="Tor", + repo="https://github.com/TuringLang/Deprecated/blob/{commit}{path}#L{line}", + sitename="TuringCallbacks.jl", + format=Documenter.HTML(; + prettyurls=get(ENV, "CI", "false") == "true", + canonical="https://turinglang.github.io/Deprecated/TuringCallbacks.jl", + assets=String[], + ), + pages=[ + "Home" => "index.md", + ], +) diff --git a/TuringCallbacks/docs/src/assets/tensorboard_demo_distributions_screen.png b/TuringCallbacks/docs/src/assets/tensorboard_demo_distributions_screen.png new file mode 100644 index 0000000..b91da26 Binary files /dev/null and b/TuringCallbacks/docs/src/assets/tensorboard_demo_distributions_screen.png differ diff --git a/TuringCallbacks/docs/src/assets/tensorboard_demo_histograms_screen.png b/TuringCallbacks/docs/src/assets/tensorboard_demo_histograms_screen.png new file mode 100644 index 0000000..b8ea094 Binary files /dev/null and b/TuringCallbacks/docs/src/assets/tensorboard_demo_histograms_screen.png differ diff --git a/TuringCallbacks/docs/src/assets/tensorboard_demo_initial_screen.png b/TuringCallbacks/docs/src/assets/tensorboard_demo_initial_screen.png new file mode 100644 index 0000000..1ee57e9 Binary files /dev/null and b/TuringCallbacks/docs/src/assets/tensorboard_demo_initial_screen.png differ diff --git a/TuringCallbacks/docs/src/index.md b/TuringCallbacks/docs/src/index.md new file mode 100644 index 
0000000..ab29cd5 --- /dev/null +++ b/TuringCallbacks/docs/src/index.md @@ -0,0 +1,204 @@ +```@meta +CurrentModule = TuringCallbacks +DocTestSetup = quote + using TuringCallbacks +end +``` + +```@setup setup +using TuringCallbacks +``` + +# TuringCallbacks + +```@contents +``` + +## Getting started +The package is registered in the Julia General registry, so it can be added via the Pkg REPL: +```julia +julia> ] +pkg> add TuringCallbacks +``` + +## Visualizing sampling on-the-fly +`TensorBoardCallback` is a wrapper around `Base.CoreLogging.AbstractLogger` which can be used to create a `callback` compatible with `Turing.sample`. + +To actually visualize the results of the logging, you need to have installed `tensorboard` in Python. If you do not have `tensorboard` installed, +it should hopefully be sufficient to just run +```sh +pip3 install tensorboard +``` +Then you can start up the `TensorBoard`: +```sh +python3 -m tensorboard.main --logdir tensorboard_logs/run +``` +Now we're ready to actually write some Julia code. + +The following snippet demonstrates the usage of `TensorBoardCallback` on a simple model. +This will write a set of statistics at each iteration to an event-file compatible with Tensorboard: + +```julia +using Turing, TuringCallbacks + +@model function demo(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(x) + x[i] ~ Normal(m, √s) + end +end + +xs = randn(100) .+ 1; +model = demo(xs); + +# Number of MCMC samples/steps +num_samples = 10_000 +num_adapts = 100 + +# Sampling algorithm to use +alg = NUTS(num_adapts, 0.65) + +# Create the callback +callback = TensorBoardCallback("tensorboard_logs/run") + +# Sample +chain = sample(model, alg, num_samples; callback = callback) +``` + +While this is sampling, you can head right over to `localhost:6006` in your web browser and you should be seeing some plots! 
+ +![TensorBoard dashboard](assets/tensorboard_demo_initial_screen.png) + +In particular, note the "Distributions" tab in the above picture. Clicking this, you should see something similar to: + +![TensorBoard dashboard](assets/tensorboard_demo_distributions_screen.png) + +And finally, the "Histogram" tab shows a slightly more visually pleasing version of the marginal distributions: + +![TensorBoard dashboard](assets/tensorboard_demo_histograms_screen.png) + +Note that the names of the stats following a naming `$variable_name/...` where `$variable_name` refers to name of the variable in the model. + +### Choosing what and how you log +#### Statistics +In the above example we didn't provide any statistics explicit and so it used the default statistics, e.g. `Mean` and `Variance`. But using other statistics is easy! Here's a much more interesting example: +```julia +# Create the stats (estimators are sub-types of `OnlineStats.OnlineStat`) +stats = Skip( + num_adapts, # Consider adaptation steps + Series( + # Estimators using the entire chain + Series(Mean(), Variance(), AutoCov(10), KHist(100)), + # Estimators using the entire chain but only every 10-th sample + Thin(10, Series(Mean(), Variance(), AutoCov(10), KHist(100))), + # Estimators using only the last 1000 samples + WindowStat(1000, Series(Mean(), Variance(), AutoCov(10), KHist(100))) + ) +) +# Create the callback +callback = TensorBoardCallback("tensorboard_logs/run", stats) + +# Sample +chain = sample(model, alg, num_samples; callback = callback) +``` + +Tada! Now you should be seeing waaaay more interesting statistics in your TensorBoard dashboard. See the [`OnlineStats.jl` documentation](https://joshday.github.io/OnlineStats.jl/latest/) for more on the different statistics, with the exception of [`Thin`](@ref), [`Skip`](@ref) and [`WindowStat`](@ref) which are implemented in this package. 
+ +Note that these statistic estimators are stateful, and therefore the following is *bad*: + +```@repl setup +s = AutoCov(5) +stat = Series(s, s) +# => 10 samples but `n=20` since we've called `fit!` twice for each observation +fit!(stat, randn(10)) +``` +while the following is *good*: +```@repl setup +stat = Series(AutoCov(5), AutoCov(5)) +# => 10 samples AND `n=10`; great! +fit!(stat, randn(10)) +``` + +At the moment, the only supported statistics are sub-types of `OnlineStats.OnlineStat`. If you want to log some custom statistic, again, at the moment, you have to make a sub-type and implement `OnlineStats.fit!` and `OnlineStats.value`. By default, an `OnlineStat` is passed to `tensorboard` by simply calling `OnlineStats.value(stat)`. Therefore, if you also want to customize how a stat is passed to `tensorboard`, you need to overload `TensorBoardLogger.preprocess(name, stat, data)` accordingly. + +#### Filter variables to log +Maybe you want to only log stats for certain variables, e.g. in the above example we might want to exclude `m` *and* exclude the sampler statistics: +```julia +callback = TensorBoardCallback( + "tensorboard_logs/run", stats; + exclude = ["m", ], include_extras = false +) +``` +Or you can create the filter (a mapping `variable_name -> ::Bool`) yourself: +```julia +var_filter(varname, value) = varname != "m" +callback = TensorBoardCallback( + "tensorboard_logs/run", stats; + filter = var_filter +) +``` + +## Supporting `TensorBoardCallback` with your own sampler + +It's also possible to make your own sampler compatible with `TensorBoardCallback`. + +To do so, you need to implement the following method: + +```@docs +TuringCallbacks.params_and_values +``` + +If you don't have any particular names for your parameters, you're free to make use of the convenience method + +```@docs +TuringCallbacks.default_param_names_for_values +``` + + +!!! 
note + The `params_and_values(model, sampler, transition, state; kwargs...)` is not usually overloaded, but it can sometimes be useful for defining more complex behaviors. + +For example, if the `transition` for your `MySampler` is just a `Vector{Float64}`, a basic implementation of [`TuringCallbacks.params_and_values`](@ref) would just be + +```julia +function TuringCallbacks.params_and_values(model, transition::Vector{Float64}; kwargs...) + param_names = TuringCallbacks.default_param_names_for_values(transition) + return zip(param_names, transition) +end +``` + +Or sometimes the user might pass the parameter names in as a keyword argument, and so you might want to support that with something like + +```julia +function TuringCallbacks.params_and_values(model, transition::Vector{Float64}; param_names = nothing, kwargs...) + param_names = isnothing(param_names) ? TuringCallbacks.default_param_names_for_values(transition) : param_names + return zip(param_names, transition) +end +``` + +Finally, if you additionally want to log "extra" information, e.g. some sampler statistics you're keeping track of, you also need to implement + +```@docs +TuringCallbacks.extras +``` + +## Types & Functions + +```@autodocs +Modules = [TuringCallbacks] +Private = false +Order = [:type, :function] +``` + +## Internals +```@autodocs +Modules = [TuringCallbacks] +Private = true +Public = false +``` + +## Index + +```@index +``` diff --git a/TuringCallbacks/ext/TuringCallbacksTuringExt.jl b/TuringCallbacks/ext/TuringCallbacksTuringExt.jl new file mode 100644 index 0000000..64a8624 --- /dev/null +++ b/TuringCallbacks/ext/TuringCallbacksTuringExt.jl @@ -0,0 +1,69 @@ +module TuringCallbacksTuringExt + +if isdefined(Base, :get_extension) + using Turing: Turing, DynamicPPL + using TuringCallbacks: TuringCallbacks +else + # Requires compatible. 
+ using ..Turing: Turing, DynamicPPL + using ..TuringCallbacks: TuringCallbacks +end + +const TuringTransition = Union{ + Turing.Inference.Transition, + Turing.Inference.SMCTransition, + Turing.Inference.PGTransition +} + +function TuringCallbacks.params_and_values( + model::DynamicPPL.Model, + transition::TuringTransition; + kwargs... +) + vns, vals = Turing.Inference._params_to_array(model, [transition]) + return zip(Iterators.map(string, vns), vals) +end + +function TuringCallbacks.extras( + model::DynamicPPL.Model, transition::TuringTransition; + kwargs... +) + names, vals = Turing.Inference.get_transition_extras([transition]) + return zip(string.(names), vec(vals)) +end + +default_hyperparams(sampler::DynamicPPL.Sampler) = default_hyperparams(sampler.alg) +default_hyperparams(alg::Turing.Inference.InferenceAlgorithm) = ( + string(f) => getfield(alg, f) for f in fieldnames(typeof(alg)) if f != :adtype +) + +const AlgsWithDefaultHyperparams = Union{ + Turing.Inference.HMC, + Turing.Inference.HMCDA, + Turing.Inference.NUTS, + Turing.Inference.SGHMC, + +} + +function TuringCallbacks.hyperparams( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgsWithDefaultHyperparams}; + kwargs... 
+) + return default_hyperparams(sampler) +end + +function TuringCallbacks.hyperparam_metrics( + model, + sampler::DynamicPPL.Sampler{<:Turing.Inference.NUTS} +) + return [ + "extras/acceptance_rate/stat/Mean", + "extras/max_hamiltonian_energy_error/stat/Mean", + "extras/lp/stat/Mean", + "extras/n_steps/stat/Mean", + "extras/tree_depth/stat/Mean" + ] +end + +end diff --git a/TuringCallbacks/src/TuringCallbacks.jl b/TuringCallbacks/src/TuringCallbacks.jl new file mode 100644 index 0000000..1d183c3 --- /dev/null +++ b/TuringCallbacks/src/TuringCallbacks.jl @@ -0,0 +1,34 @@ +module TuringCallbacks + +using Reexport + +using LinearAlgebra +using Logging +using DocStringExtensions + +@reexport using OnlineStats # used to compute different statistics on-the-fly + +using TensorBoardLogger +const TBL = TensorBoardLogger + +using DataStructures: DefaultDict + +@static if !isdefined(Base, :get_extension) + using Requires +end + +export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback + +include("utils.jl") +include("stats.jl") +include("tensorboardlogger.jl") +include("callbacks/tensorboard.jl") +include("callbacks/multicallback.jl") + +@static if !isdefined(Base, :get_extension) + function __init__() + @require Turing="fce5fe82-541a-59a6-adf8-730c64b5f9a0" include("../ext/TuringCallbacksTuringExt.jl") + end +end + +end diff --git a/TuringCallbacks/src/callbacks/multicallback.jl b/TuringCallbacks/src/callbacks/multicallback.jl new file mode 100644 index 0000000..2e270cd --- /dev/null +++ b/TuringCallbacks/src/callbacks/multicallback.jl @@ -0,0 +1,23 @@ +""" + MultiCallback + +A callback that combines multiple callbacks into one. + +Implements [`push!!`](@ref) to add callbacks to the list. +""" +struct MultiCallback{Cs} + callbacks::Cs +end + +MultiCallback() = MultiCallback(()) +MultiCallback(callbacks...) = MultiCallback(callbacks) + +(c::MultiCallback)(args...; kwargs...) 
= foreach(c -> c(args...; kwargs...), c.callbacks) + +""" + push!!(cb::MultiCallback, callback) + +Add a callback to the list of callbacks, mutating if possible. +""" +push!!(c::MultiCallback{<:Tuple}, callback) = MultiCallback((c.callbacks..., callback)) +push!!(c::MultiCallback{<:AbstractArray}, callback) = (push!(c.callbacks, callback); return c) diff --git a/TuringCallbacks/src/callbacks/tensorboard.jl b/TuringCallbacks/src/callbacks/tensorboard.jl new file mode 100644 index 0000000..0f1dc1d --- /dev/null +++ b/TuringCallbacks/src/callbacks/tensorboard.jl @@ -0,0 +1,319 @@ +using Dates + +""" + $(TYPEDEF) + +Wraps a `CoreLogging.AbstractLogger` to construct a callback to be +passed to `AbstractMCMC.step`. + +# Usage + + TensorBoardCallback(; kwargs...) + TensorBoardCallback(directory::string[, stats]; kwargs...) + TensorBoardCallback(lg::AbstractLogger[, stats]; kwargs...) + +Constructs an instance of a `TensorBoardCallback`, creating a `TBLogger` if `directory` is +provided instead of `lg`. + +## Arguments +- `lg`: an instance of an `AbstractLogger` which implements `TuringCallbacks.increment_step!`. +- `stats = nothing`: `OnlineStat` or lookup for variable name to statistic estimator. + If `stats isa OnlineStat`, we will create a `DefaultDict` which copies `stats` for unseen + variable names. + If `isnothing`, then a `DefaultDict` with a default constructor returning a + `OnlineStats.Series` estimator with `Mean()`, `Variance()`, and `KHist(num_bins)` + will be used. + +## Keyword arguments +- `num_bins::Int = 100`: Number of bins to use in the histograms. +- `filter = nothing`: Filter determining whether or not we should log stats for a + particular variable and value; expected signature is `filter(varname, value)`. + If `isnothing` a default-filter constructed from `exclude` and + `include` will be used. +- `exclude = String[]`: If non-empty, these variables will not be logged. +- `include = String[]`: If non-empty, only these variables will be logged. 
+- `include_extras::Bool = true`: Include extra statistics from transitions. +- `extras_include = String[]`: If non-empty, only these extra statistics will be logged. +- `extras_exclude = String[]`: If non-empty, these extra statistics will not be logged. +- `extras_filter = nothing`: Filter determining whether or not we should log + extra statistics; expected signature is `filter(extra, value)`. + If `isnothing` a default-filter constructed from `extras_exclude` and + `extras_include` will be used. +- `include_hyperparams::Bool = true`: Include hyperparameters. +- `hyperparam_include = String[]`: If non-empty, only these hyperparameters will be logged. +- `hyperparam_exclude = String[]`: If non-empty, these hyperparameters will not be logged. +- `hyperparam_filter = nothing`: Filter determining whether or not we should log + hyperparameters; expected signature is `filter(hyperparam, value)`. + If `isnothing` a default-filter constructed from `hyperparam_exclude` and + `hyperparam_include` will be used. +- `directory::String = nothing`: if specified, will together with `comment` be used to + define the logging directory. +- `comment::String = nothing`: if specified, will together with `directory` be used to + define the logging directory. + +# Fields +$(TYPEDFIELDS) +""" +struct TensorBoardCallback{L,F1,F2,F3} + "Underlying logger." + logger::AbstractLogger + "Lookup for variable name to statistic estimate." + stats::L + "Filter determining whether to include stats for a particular variable." + variable_filter::F1 + "Include extra statistics from transitions." + include_extras::Bool + "Filter determining whether to include a particular extra statistic." + extras_filter::F2 + "Include hyperparameters." + include_hyperparams::Bool + "Filter determining whether to include a particular hyperparameter." 
+ hyperparam_filter::F3 + "Prefix used for logging realizations/parameters" + param_prefix::String + "Prefix used for logging extra statistics" + extras_prefix::String +end + +function TensorBoardCallback(directory::String, args...; kwargs...) + TensorBoardCallback(args...; directory = directory, kwargs...) +end + +function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs...) + log_dir = if isnothing(directory) + "runs/$(Dates.format(now(), dateformat"Y-m-d_H-M-S"))-$(gethostname())$(comment)" + else + directory + end + + # Set up the logger + lg = TBLogger(log_dir, min_level=Logging.Info; step_increment=0) + + return TensorBoardCallback(lg, args...; kwargs...) +end + +maybe_filter(f; kwargs...) = f +maybe_filter(::Nothing; exclude=nothing, include=nothing) = NameFilter(; exclude, include) + +function TensorBoardCallback( + lg::AbstractLogger, + stats = nothing; + num_bins::Int = 100, + exclude = nothing, + include = nothing, + filter = nothing, + include_extras::Bool = true, + extras_include = nothing, + extras_exclude = nothing, + extras_filter = nothing, + include_hyperparams::Bool = false, + hyperparams_include = nothing, + hyperparams_exclude = nothing, + hyperparams_filter = nothing, + param_prefix::String = "", + extras_prefix::String = "extras/", + kwargs... +) + # Create the filters. 
+ variable_filter_f = maybe_filter(filter; include=include, exclude=exclude) + extras_filter_f = maybe_filter( + extras_filter; include=extras_include, exclude=extras_exclude + ) + hyperparams_filter_f = maybe_filter( + hyperparams_filter; include=hyperparams_include, exclude=hyperparams_exclude + ) + + # Lookups: create default ones if not given + stats_lookup = if stats isa OnlineStat + # Warn the user if they've provided a non-empty `OnlineStat` + OnlineStats.nobs(stats) > 0 && @warn("using statistic with observations as a base: $(stats)") + let o = stats + DefaultDict{String, typeof(o)}(() -> deepcopy(o)) + end + elseif !isnothing(stats) + # If it's not an `OnlineStat` nor `nothing`, assume user knows what they're doing + stats + else + # This is default + let o = OnlineStats.Series(Mean(), Variance(), KHist(num_bins)) + DefaultDict{String, typeof(o)}(() -> deepcopy(o)) + end + end + + return TensorBoardCallback( + lg, + stats_lookup, + variable_filter_f, + include_extras, + extras_filter_f, + include_hyperparams, + hyperparams_filter_f, + param_prefix, + extras_prefix + ) +end + +""" + filter_param_and_value(cb::TensorBoardCallback, param_name, value) + +Filter parameters and values from a `transition` based on the `filter` of `cb`. +""" +function filter_param_and_value(cb::TensorBoardCallback, param, value) + return cb.variable_filter(param, value) +end +function filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) + filter_param_and_value(cb, param_and_value...) +end + +""" + default_param_names_for_values(x) + +Return an iterator of `θ[i]` for each element in `x`. +""" +default_param_names_for_values(x) = ("θ[$i]" for i = 1:length(x)) + + +""" + params_and_values(model, transition[, state]; kwargs...) + params_and_values(model, sampler, transition, state; kwargs...) + +Return an iterator over parameter names and values from a `transition`. +""" +function params_and_values(model, transition, state; kwargs...) 
+ return params_and_values(model, transition; kwargs...) +end +function params_and_values(model, sampler, transition, state; kwargs...) + return params_and_values(model, transition, state; kwargs...) +end + +""" + extras(model, transition[, state]; kwargs...) + extras(model, sampler, transition, state; kwargs...) + +Return an iterator with elements of the form `(name, value)` for additional statistics in `transition`. + +Default implementation returns an empty iterator. +""" +extras(model, transition; kwargs...) = () +extras(model, transition, state; kwargs...) = extras(model, transition; kwargs...) +function extras(model, sampler, transition, state; kwargs...) + return extras(model, transition, state; kwargs...) +end + +""" + filter_extras_and_value(cb::TensorBoardCallback, name, value) + +Filter extras and values from a `transition` based on the `filter` of `cb`. +""" +function filter_extras_and_value(cb::TensorBoardCallback, name, value) + return cb.extras_filter(name, value) +end +function filter_extras_and_value(cb::TensorBoardCallback, name_and_value::Tuple) + return filter_extras_and_value(cb, name_and_value...) +end + +""" + hyperparams(model, sampler[, transition, state]; kwargs...) + +Return an iterator with elements of the form `(name, value)` for hyperparameters in `model`. +""" +function hyperparams(model, sampler; kwargs...) + @warn "`hyperparams(model, sampler; kwargs...)` is not implemented for $(typeof(model)) and $(typeof(sampler)). If you want to record hyperparameters, please implement this method." + return Pair{String, Any}[] +end +function hyperparams(model, sampler, transition, state; kwargs...) + return hyperparams(model, sampler; kwargs...) +end + +""" + filter_hyperparams_and_value(cb::TensorBoardCallback, name, value) + +Filter hyperparameters and values from a `transition` based on the `filter` of `cb`. 
+""" +function filter_hyperparams_and_value(cb::TensorBoardCallback, name, value) + return cb.hyperparam_filter(name, value) +end +function filter_hyperparams_and_value( + cb::TensorBoardCallback, + name_and_value::Union{Pair,Tuple} +) + return filter_hyperparams_and_value(cb, name_and_value...) +end + +""" + hyperparam_metrics(model, sampler[, transition, state]; kwargs...) + +Return a `Vector{String}` of metrics for hyperparameters in `model`. +""" +function hyperparam_metrics(model, sampler; kwargs...) + @warn "`hyperparam_metrics(model, sampler; kwargs...)` is not implemented for $(typeof(model)) and $(typeof(sampler)). If you want to use some of the other recorded values as hyperparameters metrics, please implement this method." + return String[] +end +function hyperparam_metrics(model, sampler, transition, state; kwargs...) + return hyperparam_metrics(model, sampler; kwargs...) +end + +increment_step!(lg::TensorBoardLogger.TBLogger, Δ_Step) = + TensorBoardLogger.increment_step!(lg, Δ_Step) + +function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; kwargs...) + stats = cb.stats + lg = cb.logger + variable_filter = Base.Fix1(filter_param_and_value, cb) + extras_filter = Base.Fix1(filter_extras_and_value, cb) + hyperparams_filter = Base.Fix1(filter_hyperparams_and_value, cb) + + if iteration == 1 && cb.include_hyperparams + # If it's the first iteration, we write the hyperparameters. + hparams = Dict(Iterators.filter( + hyperparams_filter, + hyperparams(model, sampler, transition, state; kwargs...) + )) + if !isempty(hparams) + TensorBoardLogger.write_hparams!( + lg, + hparams, + hyperparam_metrics(model, sampler) + ) + end + end + + + # TODO: Should we use the explicit interface for TensorBoardLogger? + with_logger(lg) do + for (k, val) in Iterators.filter( + variable_filter, + params_and_values(model, sampler, transition, state; kwargs...) 
+ ) + stat = stats[k] + + # Log the raw value + @info "$(cb.param_prefix)$k" val + + # Update statistic estimators + OnlineStats.fit!(stat, val) + + # Need some iterations before we start showing the stats + @info "$(cb.param_prefix)$k" stat + end + + # Transition statstics + if cb.include_extras + for (name, val) in Iterators.filter( + extras_filter, + extras(model, sampler, transition, state; kwargs...) + ) + @info "$(cb.extras_prefix)$(name)" val + + # TODO: Make this customizable. + if val isa Real + stat = stats["$(cb.extras_prefix)$(name)"] + fit!(stat, float(val)) + @info ("$(cb.extras_prefix)$(name)") stat + end + end + end + # Increment the step for the logger. + increment_step!(lg, 1) + end +end diff --git a/TuringCallbacks/src/stats.jl b/TuringCallbacks/src/stats.jl new file mode 100644 index 0000000..7115386 --- /dev/null +++ b/TuringCallbacks/src/stats.jl @@ -0,0 +1,107 @@ +################### +### OnlineStats ### +################### +""" +$(TYPEDEF) + +# Usage + + Skip(b::Int, stat::OnlineStat) + +Skips the first `b` observations before passing them on to `stat`. +""" +mutable struct Skip{T, O<:OnlineStat{T}} <: OnlineStat{T} + b::Int + current_index::Int + stat::O +end + +Skip(b::Int, stat) = Skip(b, 0, stat) + +OnlineStats.nobs(o::Skip) = OnlineStats.nobs(o.stat) +OnlineStats.value(o::Skip) = OnlineStats.value(o.stat) +function OnlineStats._fit!(o::Skip, x::Real) + if o.current_index > o.b + OnlineStats._fit!(o.stat, x) + end + o.current_index += length(x) + + return o +end + +Base.show(io::IO, o::Skip) = print( + io, + "Skip ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`" +) + +""" +$(TYPEDEF) + +# Usage + + Thin(b::Int, stat::OnlineStat) + +Thins `stat` with an interval `b`, i.e. only passes every b-th observation to `stat`. 
+"""
+mutable struct Thin{T, O<:OnlineStat{T}} <: OnlineStat{T}
+    b::Int
+    current_index::Int
+    stat::O
+end
+
+Thin(b::Int, stat) = Thin(b, 0, stat)
+
+OnlineStats.nobs(o::Thin) = OnlineStats.nobs(o.stat)
+OnlineStats.value(o::Thin) = OnlineStats.value(o.stat)
+function OnlineStats._fit!(o::Thin, x::Real)
+    # `current_index` starts at 0, so the very first observation is always
+    # fitted, and thereafter every `b`-th one (indices 0, b, 2b, ...).
+    if (o.current_index % o.b) == 0
+        OnlineStats._fit!(o.stat, x)
+    end
+    o.current_index += length(x)
+
+    return o
+end
+
+Base.show(io::IO, o::Thin) = print(
+    io,
+    "Thin ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)"
+)
+
+"""
+$(TYPEDEF)
+
+# Usage
+
+    WindowStat(b::Int, stat::O) where {O <: OnlineStat}
+
+"Wraps" `stat` in a `MovingWindow` of length `b`.
+
+`value(o::WindowStat)` will then return an `OnlineStat` of the same type as
+`stat`, which is *only* fitted on the batched data contained in the `MovingWindow`.
+
+"""
+struct WindowStat{T, O} <: OnlineStat{T}
+    window::MovingWindow{T}
+    stat::O
+end
+
+WindowStat(b::Int, T::Type, o) = WindowStat{T, typeof(o)}(MovingWindow(b, T), o)
+WindowStat(b::Int, o::OnlineStat{T}) where {T} = WindowStat{T, typeof(o)}(
+    MovingWindow(b, T), o
+)
+
+# Proxy methods to the window
+OnlineStats.nobs(o::WindowStat) = OnlineStats.nobs(o.window)
+OnlineStats._fit!(o::WindowStat, x) = OnlineStats._fit!(o.window, x)
+
+function OnlineStats.value(o::WindowStat{<:Any, <:OnlineStat})
+    stat_new = deepcopy(o.stat)
+    fit!(stat_new, OnlineStats.value(o.window))
+    return stat_new
+end
+
+function OnlineStats.value(o::WindowStat{<:Any, <:Function})
+    stat_new = o.stat()
+    fit!(stat_new, OnlineStats.value(o.window))
+    return stat_new
+end
diff --git a/TuringCallbacks/src/tensorboardlogger.jl b/TuringCallbacks/src/tensorboardlogger.jl
new file mode 100644
index 0000000..16a74fe
--- /dev/null
+++ b/TuringCallbacks/src/tensorboardlogger.jl
+#########################################
+### Overloads for `TensorBoardLogger` ###
+#########################################
+# `tb_name` is used by `preprocess` to
decide how a given `arg` should look +""" + tb_name(args...) + +Returns a `string` representing the name for `arg` or `args` in TensorBoard. + +If `length(args) > 1`, `args` are joined together by `"/"`. +""" +tb_name(arg) = string(arg) +tb_name(stat::OnlineStat) = string(nameof(typeof(stat))) +tb_name(o::Skip) = "Skip($(o.b))" +tb_name(o::Thin) = "Thin($(o.b))" +tb_name(o::WindowStat) = "WindowStat($(o.window.b))" +tb_name(o::AutoCov, b::Int) = "AutoCov(lag=$b)/corr" + +# Recursive impl +tb_name(s1::String, s2::String) = s1 * "/" * s2 +tb_name(arg1, arg2) = tb_name(arg1) * "/" * tb_name(arg2) +tb_name(arg, args...) = tb_name(arg) * "/" * tb_name(args...) + +function TBL.preprocess(name, stat::OnlineStat, data) + if OnlineStats.nobs(stat) > 0 + TBL.preprocess(tb_name(name, stat), value(stat), data) + end +end + +function TBL.preprocess(name, stat::Skip, data) + return TBL.preprocess(tb_name(name, stat), stat.stat, data) +end + +function TBL.preprocess(name, stat::Thin, data) + return TBL.preprocess(tb_name(name, stat), stat.stat, data) +end + +function TBL.preprocess(name, stat::WindowStat, data) + return TBL.preprocess(tb_name(name, stat), value(stat), data) +end + +function TBL.preprocess(name, stat::AutoCov, data) + autocor = OnlineStats.autocor(stat) + for b = 1:(stat.lag.b - 1) + # `autocor[i]` corresponds to the lag of size `i - 1` and `autocor[1] = 1.0` + bname = tb_name(stat, b) + TBL.preprocess(tb_name(name, bname), autocor[b + 1], data) + end +end + +function TBL.preprocess(name, stat::Series, data) + # Iterate through the stats and process those independently + for s in stat.stats + TBL.preprocess(name, s, data) + end +end + +function TBL.preprocess(name, hist::KHist, data) + if OnlineStats.nobs(hist) > 0 + # Creates a NORMALIZED histogram + edges = OnlineStats.edges(hist) + cnts = OnlineStats.counts(hist) + TBL.preprocess( + name, (edges, cnts ./ sum(cnts)), data + ) + end +end + +# Unlike the `preprocess` overload, this allows us to specify if we want 
to normalize
+function TBL.log_histogram(
+    logger::AbstractLogger, name::AbstractString, hist::OnlineStats.HistogramStat;
+    step=nothing, normalize=false
+)
+    # NOTE: The local must not be named `edges`: `edges = edges(hist)` would
+    # make `edges` local to the whole function body in Julia, so the call on
+    # the right-hand side would hit the (not yet assigned) local instead of
+    # `OnlineStats.edges` and throw an `UndefVarError`. Qualify the call, as
+    # is done in the `KHist` `preprocess` overload above.
+    hist_edges = OnlineStats.edges(hist)
+    cnts = Float64.(OnlineStats.counts(hist))
+    if normalize
+        return TBL.log_histogram(logger, name, (hist_edges, cnts ./ sum(cnts)); step=step)
+    else
+        return TBL.log_histogram(logger, name, (hist_edges, cnts); step=step)
+    end
+end
diff --git a/TuringCallbacks/src/utils.jl b/TuringCallbacks/src/utils.jl
new file mode 100644
index 0000000..2220659
--- /dev/null
+++ b/TuringCallbacks/src/utils.jl
+Base.@kwdef struct NameFilter{A,B}
+    include::A=nothing
+    exclude::B=nothing
+end
+
+(f::NameFilter)(name, value) = f(name)
+function (f::NameFilter)(name)
+    include, exclude = f.include, f.exclude
+    (exclude === nothing || name ∉ exclude) && (include === nothing || name ∈ include)
+end
diff --git a/TuringCallbacks/test/Project.toml b/TuringCallbacks/test/Project.toml
new file mode 100644
index 0000000..466c699
--- /dev/null
+++ b/TuringCallbacks/test/Project.toml
+[deps]
+DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
+TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
+ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
+
+[compat]
+TensorBoardLogger = "0.1.26"
+Test = "1.11.0"
+Turing = "0.39.10"
+ValueHistories = "0.5.4"
diff --git a/TuringCallbacks/test/multicallback.jl b/TuringCallbacks/test/multicallback.jl
new file mode 100644
index 0000000..69ba91f
--- /dev/null
+++ b/TuringCallbacks/test/multicallback.jl
+@testset "MultiCallback" begin
+    # Number of MCMC samples/steps
+    num_samples = 100
+    num_adapts = 50
+
+    # Sampling algorithm to use
+    alg = NUTS(num_adapts, 0.65)
+
+    callback = MultiCallback(CountingCallback(), CountingCallback())
+    chain = sample(demo_model, alg, num_samples, callback=callback)
+
+    # Both should have been triggered
an equal number of times. + counts = map(c -> c.count[], callback.callbacks) + @test counts[1] == counts[2] + @test counts[1] == num_samples + + # Add a new one and make sure it's not like the others. + callback = TuringCallbacks.push!!(callback, CountingCallback()) + counts = map(c -> c.count[], callback.callbacks) + @test counts[1] == counts[2] != counts[3] +end diff --git a/TuringCallbacks/test/runtests.jl b/TuringCallbacks/test/runtests.jl new file mode 100644 index 0000000..e9692a6 --- /dev/null +++ b/TuringCallbacks/test/runtests.jl @@ -0,0 +1,30 @@ +using Test +using DynamicPPL +using Turing +using TuringCallbacks +using TensorBoardLogger, ValueHistories + +Base.@kwdef struct CountingCallback + count::Ref{Int}=Ref(0) +end + +(c::CountingCallback)(args...; kwargs...) = c.count[] += 1 + +@model function demo(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(x) + x[i] ~ Normal(m, √s) + end +end + +function DynamicPPL.TestUtils.varnames(::DynamicPPL.Model{typeof(demo)}) + return [@varname(s), @varname(m)] +end + +const demo_model = demo(randn(100) .+ 1) + +@testset "TuringCallbacks.jl" begin + include("multicallback.jl") + include("tensorboardcallback.jl") +end diff --git a/TuringCallbacks/test/tensorboardcallback.jl b/TuringCallbacks/test/tensorboardcallback.jl new file mode 100644 index 0000000..a8043d7 --- /dev/null +++ b/TuringCallbacks/test/tensorboardcallback.jl @@ -0,0 +1,163 @@ +@testset "TensorBoardCallback" begin + tmpdir = mktempdir() + mkpath(tmpdir) + + vns = DynamicPPL.TestUtils.varnames(demo_model) + + # Number of MCMC samples/steps + num_samples = 100 + num_adapts = 50 + + # Sampling algorithm to use + alg = NUTS(num_adapts, 0.65) + + @testset "Correctness of values" begin + # Create the callback + callback = TensorBoardCallback(joinpath(tmpdir, "runs")) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Extract the values. 
+ hist = convert(MVHistory, callback.logger) + + # Compare the recorded values to the chain. + m_mean = last(last(hist["m/stat/Mean"])) + s_mean = last(last(hist["s/stat/Mean"])) + + @test m_mean ≈ mean(chain[:m]) + @test s_mean ≈ mean(chain[:s]) + end + + @testset "Default" begin + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Read the logging info. + hist = convert(MVHistory, callback.logger) + + # Check the variables. + @testset "$vn" for vn in vns + # Should have the `val` field. + @test haskey(hist, Symbol(vn, "/val")) + # Should have the `Mean` and `Variance` stat. + @test haskey(hist, Symbol(vn, "/stat/Mean")) + @test haskey(hist, Symbol(vn, "/stat/Variance")) + end + + # Check the extra statistics. + @testset "extras" begin + @test haskey(hist, Symbol("extras/lp/val")) + @test haskey(hist, Symbol("extras/acceptance_rate/val")) + end + end + + @testset "Exclude variable" begin + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + exclude=["s"] + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Read the logging info. + hist = convert(MVHistory, callback.logger) + + # Check the variables. + @testset "$vn" for vn in vns + if vn == @varname(s) + @test !haskey(hist, Symbol(vn, "/val")) + @test !haskey(hist, Symbol(vn, "/stat/Mean")) + @test !haskey(hist, Symbol(vn, "/stat/Variance")) + else + @test haskey(hist, Symbol(vn, "/val")) + @test haskey(hist, Symbol(vn, "/stat/Mean")) + @test haskey(hist, Symbol(vn, "/stat/Variance")) + end + end + + # Check the extra statistics. 
+ @testset "extras" begin + @test haskey(hist, Symbol("extras/lp/val")) + @test haskey(hist, Symbol("extras/acceptance_rate/val")) + end + end + + @testset "Exclude extras" begin + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + include_extras=false + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Read the logging info. + hist = convert(MVHistory, callback.logger) + + # Check the variables. + @testset "$vn" for vn in vns + @test haskey(hist, Symbol(vn, "/val")) + @test haskey(hist, Symbol(vn, "/stat/Mean")) + @test haskey(hist, Symbol(vn, "/stat/Variance")) + end + + # Check the extra statistics. + @testset "extras" begin + @test !haskey(hist, Symbol("extras/lp/val")) + @test !haskey(hist, Symbol("extras/acceptance_rate/val")) + end + end + + @testset "With hyperparams" begin + @testset "$alg (has hyperparam: $hashyp)" for (alg, hashyp) in [ + (HMC(0.05, 10), true), + (HMCDA(num_adapts, 0.65, 1.0), true), + (NUTS(num_adapts, 0.65), true), + (MH(), false), + ] + + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + include_hyperparams=true, + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # HACK: This touches internals so might just break at some point. + # If it some point does, let's just remove this test. 
+ # Inspiration: https://github.com/JuliaLogging/TensorBoardLogger.jl/blob/3d9c1a554a08179785459ad7b83bce0177b90275/src/Deserialization/deserialization.jl#L244-L258 + iter = TensorBoardLogger.TBEventFileCollectionIterator( + callback.logger.logdir, purge=true + ) + + found_one = false + for event_file in iter + for event in event_file + event.what === nothing && continue + !(event.what.value isa TensorBoardLogger.Summary) && continue + + for (tag, _) in event.what.value + if tag == "_hparams_/experiment" + found_one = true + break + end + end + end + + found_one && break + end + @test (hashyp && found_one) || (!hashyp && !found_one) + end + end +end