Skip to content

Commit

Permalink
Relax nonzero dimension requirement on HMC (#39)
Browse files Browse the repository at this point in the history
Removes need for fixed_param
  • Loading branch information
WardBrian authored Aug 12, 2024
1 parent fdc2e1f commit 6524a22
Show file tree
Hide file tree
Showing 11 changed files with 14 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
workflow_dispatch: {}

env:
CACHE_VERSION: 1
CACHE_VERSION: 2

# only run one copy per PR
concurrency:
Expand Down
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
- [x] Add nicer ability to build models from source in the languages
- [x] download source if needed, similar to bridgestan
- [x] Version checking
- [ ] Fixed param sampler for 0 dimension parameters?
- [-] ~Fixed param sampler for 0 dimension parameters?~
- [ ] Add wrapper around generate quantities method?
- [x] Add wraper around laplace sampling?
- [x] Pathfinder: expose the no lp/no PSIS version
Expand Down
3 changes: 0 additions & 3 deletions clients/R/R/tinystan.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ sampler.tinystan_model = function(model, data = "", num_chains = 4, inits = NULL

with_model(model, data, seed, {
free_params <- get_free_params(model, model_ptr)
if (free_params == 0) {
stop("Model has no parameters to sample")
}

params <- c(HMC_SAMPLER_VARIABLES, get_parameter_names(model, model_ptr))
num_params <- length(params)
Expand Down
2 changes: 1 addition & 1 deletion clients/R/tests/testthat/test_sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ test_that("bad num_warmup handled properly", {

test_that("model with no params fails", {

expect_error(sampler(empty_model), "Model has no parameters")
expect_no_error(sampler(empty_model))

})

Expand Down
3 changes: 0 additions & 3 deletions clients/julia/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,6 @@ function sample(

with_model(model, data, seed) do model_ptr
free_params = num_free_params(model, model_ptr)
if free_params == 0
error("Model has no parameters to sample")
end

param_names = cat(HMC_SAMPLER_VARIABLES, get_names(model, model_ptr), dims = 1)
num_params = length(param_names)
Expand Down
4 changes: 3 additions & 1 deletion clients/julia/test/test_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@
end

@testset "Model without parameters" begin
@test_throws "Model has no parameters to sample" sample(empty_model)
(names, draws, metric) = sample(empty_model; save_metric=true)
@test length(names) == 7 # HMC parameters only
@test prod(size(metric)) == 0
end

@testset "Bad num_warmup" begin
Expand Down
6 changes: 3 additions & 3 deletions clients/python/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ def test_bad_num_warmup(bernoulli_model):


def test_model_no_params(empty_model):
with pytest.raises(ValueError, match="Model has no parameters to sample"):
empty_model.sample()

fit = empty_model.sample(save_metric=True)
assert len(fit.parameters) == 7 # just HMC parameters
assert fit.metric.size == 0

@pytest.mark.parametrize(
"arg, value, match",
Expand Down
2 changes: 0 additions & 2 deletions clients/python/tinystan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,6 @@ def sample(

with self._get_model(data, seed) as model:
model_params = self._num_free_params(model)
if model_params == 0:
raise ValueError("Model has no parameters to sample.")

param_names = HMC_SAMPLER_VARIABLES + self._get_parameter_names(model)

Expand Down
3 changes: 0 additions & 3 deletions clients/typescript/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ export default class StanModel {
const n_params = paramNames.length;

const free_params = this.m._tinystan_model_num_free_params(model);
if (free_params === 0) {
throw new Error("Model has no parameters to sample.");
}

// TODO: allow init_inv_metric to be specified
const init_inv_metric_ptr = NULL;
Expand Down
6 changes: 4 additions & 2 deletions clients/typescript/test/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ describe("test tinystan code with a mocked WASM module", () => {
const { mockedModule, model } = await getMockedModel({
numFreeParams: 0,
});
expect(() => model.sample({})).toThrow(/no parameters/);
const {metric} = model.sample({save_metric:true});

expect(mockedModule._tinystan_sample).toHaveBeenCalledTimes(0);
expect(metric?.[0]?.length).toEqual(0);

expect(mockedModule._tinystan_sample).toHaveBeenCalledTimes(1);
});

test("failure in model construction throws", async () => {
Expand Down

0 comments on commit 6524a22

Please sign in to comment.