Skip to content

Commit

Permalink
minimum not maximum
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jan 3, 2025
1 parent 8ef3e46 commit e284c87
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
tf.debugging.assert_greater_equal(jitter, 0.0)

mean, var = self._model.predict(at[..., None, :, :]) # [..., 1, 1, L], [..., 1, 1, L]
var = var + tf.math.maximum(var, jitter)
var = var + tf.math.minimum(var, jitter)

def sample_eps() -> tf.Tensor:
self._initialized.assign(True)
Expand Down Expand Up @@ -276,7 +276,7 @@ def sample_eps() -> tf.Tensor:
)

identity = tf.eye(batch_size, dtype=cov.dtype) # [B, B]
cov = cov + tf.math.maximum(cov, jitter) * identity
cov = cov + tf.math.minimum(cov, jitter) * identity
cov_cholesky = tf.linalg.cholesky(cov) # [..., L, B, B]

variance_contribution = cov_cholesky @ self._eps # [..., L, B, S]
Expand Down
2 changes: 1 addition & 1 deletion trieste/models/gpflux/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
continue

mean, var = layer.predict(samples, full_cov=False, full_output_cov=False)
var = var + tf.math.maximum(var, jitter)
var = var + tf.math.minimum(var, jitter)

if not self._initialized:
self._eps_list[i].assign(
Expand Down

0 comments on commit e284c87

Please sign in to comment.