Skip to content

Commit

Permalink
docs: Mention non-blocking sampling in readme
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Mar 18, 2024
1 parent 94082dc commit cfceaf0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ trace_pymc = nutpie.sample(compiled_model)
`trace_pymc` now contains an ArviZ `InferenceData` object, including sampling
statistics and the posterior of the variables defined above.

We can also control the sampler in a non-blocking way:

```python
# The sampler will now run the the background
sampler = nutpie.sample(compiled_model, blocking=False)

# Pause and resume the sampling
sampler.pause()
sampler.resume()

# Wait for the sampler to finish (up to timeout seconds)
# sampler.wait(timeout=0.1)

# or we can also abort the sampler (and return the incomplete trace)
incomplete_trace = sampler.abort()

# or cancel and discard all progress:
sampler.cancel()
```

## Usage with Stan

In order to sample from Stan model, `bridgestan` needs to be installed.
Expand Down
22 changes: 22 additions & 0 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,28 @@ def test_blocking():
trace.posterior.a # noqa: B018


@pytest.mark.timeout(2)
def test_wait_timeout():
with pm.Model() as model:
pm.Normal("a", shape=100_000)
compiled = nutpie.compile_pymc_model(model)
sampler = nutpie.sample(compiled, chains=1, blocking=False)
with pytest.raises(TimeoutError):
sampler.wait(timeout=0.1)
sampler.cancel()


@pytest.mark.timeout(2)
def test_pause():
with pm.Model() as model:
pm.Normal("a", shape=100_000)
compiled = nutpie.compile_pymc_model(model)
sampler = nutpie.sample(compiled, chains=1, blocking=False)
sampler.pause()
sampler.resume()
sampler.cancel()


def test_pymc_model_with_coordinate():
with pm.Model() as model:
model.add_coord("foo", length=5)
Expand Down

0 comments on commit cfceaf0

Please sign in to comment.