Skip to content

Commit

Permalink
Test against latest tensorflow (#819)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored Feb 8, 2024
1 parent 3a12b33 commit 248a49f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 60 deletions.
116 changes: 59 additions & 57 deletions tests/latest/constraints.txt
Original file line number Diff line number Diff line change
@@ -1,80 +1,82 @@
about-time==4.2.1
absl-py==1.4.0
alive-progress==3.1.4
absl-py==2.1.0
alive-progress==3.1.5
astunparse==1.6.3
autograd==1.5
cachetools==5.3.1
certifi==2023.5.7
charset-normalizer==3.1.0
check-shapes==1.0.0
cloudpickle==2.2.1
autograd==1.6.2
cachetools==5.3.2
certifi==2024.2.2
charset-normalizer==3.3.2
check-shapes==1.1.1
clarabel==0.6.0
cloudpickle==3.0.0
cma==3.2.2
contourpy==1.0.7
cvxpy==1.3.1
cycler==0.11.0
contourpy==1.2.0
cvxpy==1.4.2
cycler==0.12.1
decorator==5.1.1
Deprecated==1.2.14
dill==0.3.5.1
dill==0.3.8
dm-tree==0.1.8
ecos==2.0.12
exceptiongroup==1.1.1
dropstackframe==0.1.0
ecos==2.0.13
exceptiongroup==1.2.0
flatbuffers==23.5.26
fonttools==4.40.0
fonttools==4.48.1
future==0.18.3
gast==0.4.0
google-auth==2.19.1
google-auth-oauthlib==1.0.0
gast==0.5.4
google-auth==2.27.0
google-auth-oauthlib==1.2.0
google-pasta==0.2.0
gpflow==2.8.1
gpflux==0.4.2
gpflow==2.9.1
gpflux==0.4.3
grapheme==0.6.0
greenlet==2.0.2
grpcio==1.54.2
h5py==3.8.0
idna==3.4
greenlet==3.0.3
grpcio==1.60.1
h5py==3.10.0
idna==3.6
iniconfig==2.0.0
jax==0.4.12
keras==2.12.0
kiwisolver==1.4.4
lark==1.1.5
libclang==16.0.0
Markdown==3.4.3
MarkupSafe==2.1.3
matplotlib==3.7.1
keras==2.15.0
kiwisolver==1.4.5
lark==1.1.9
libclang==16.0.6
Markdown==3.5.2
MarkupSafe==2.1.5
matplotlib==3.8.2
ml-dtypes==0.2.0
multipledispatch==0.6.0
numpy==1.23.5
multipledispatch==1.0.0
numpy==1.26.4
oauthlib==3.2.2
opt-einsum==3.3.0
osqp==0.6.3
packaging==23.1
Pillow==9.5.0
pluggy==1.0.0
protobuf==4.23.2
pyasn1==0.5.0
osqp==0.6.5
packaging==23.2
pillow==10.2.0
pluggy==1.4.0
protobuf==4.23.4
pyasn1==0.5.1
pyasn1-modules==0.3.0
pymoo==0.6.0.1
pyparsing==3.0.9
pytest==7.3.2
pybind11==2.11.1
pymoo==0.6.1.1
pyparsing==3.1.1
pytest==8.0.0
python-dateutil==2.8.2
PyYAML==6.0
qdldl==0.1.7
PyYAML==6.0.1
qdldl==0.1.7.post0
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.10.1
scs==3.2.3
scipy==1.11.4
scs==3.2.4.post1
six==1.16.0
tabulate==0.9.0
tensorboard==2.12.3
tensorboard-data-server==0.7.0
tensorflow==2.12.0
tensorflow-estimator==2.12.0
tensorflow-io-gcs-filesystem==0.32.0
tensorflow-probability==0.19.0
termcolor==2.3.0
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-probability==0.23.0
termcolor==2.4.0
tomli==2.0.1
typing_extensions==4.6.3
urllib3==1.26.16
Werkzeug==2.3.6
typing_extensions==4.9.0
urllib3==2.2.0
Werkzeug==3.0.1
wrapt==1.14.1
11 changes: 11 additions & 0 deletions tests/unit/models/gpflow/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from time import time
from typing import Callable, Union, cast

import dill
import gpflow
import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -2122,3 +2123,13 @@ def test_multifidelity_autoregressive_samples_are_varied(model_type: str) -> Non

hf_samples = model.sample(hf_test_locations, 2)
assert hf_samples[0] != hf_samples[1]


@random_seed
def test_gpflow_wrappers_dilling(
gpflow_interface_factory: ModelFactoryType,
) -> None:
data = mock_data()
model, _ = gpflow_interface_factory(*data)
reloaded_model = dill.loads(dill.dumps(model))
assert type(reloaded_model) is type(model)
6 changes: 6 additions & 0 deletions tests/unit/test_bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import inherit_check_shapes
from packaging.version import Version

from tests.unit.test_ask_tell_optimization import DatasetChecker, LocalDatasetsFixedAcquisitionRule
from tests.util.misc import (
Expand Down Expand Up @@ -665,6 +667,10 @@ def optimize(self, dataset: Dataset) -> None:

@pytest.mark.parametrize("save_to_disk", [False, True])
def test_bayesian_optimizer_optimize_tracked_state(save_to_disk: bool) -> None:
if save_to_disk and Version(tfp.__version__) >= Version("0.23.0"):
# TODO: the latest tfp seems to have broken pickling QuadraticMeanAndRBFKernel
pytest.skip()

class _CountingRule(AcquisitionRule[State[Optional[int], TensorType], Box, ProbabilisticModel]):
def acquire(
self,
Expand Down
15 changes: 12 additions & 3 deletions trieste/models/keras/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@

from __future__ import annotations

import contextlib
from abc import abstractmethod
from typing import Any, Callable, Sequence

import dill
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

try:
from keras.src.saving.serialization_lib import SafeModeScope
except ImportError: # pragma: no cover (tested but not by coverage)
SafeModeScope = contextlib.nullcontext
from tensorflow_probability.python.layers.distribution_layer import DistributionLambda, _serialize

from trieste.types import TensorType
Expand Down Expand Up @@ -132,9 +138,12 @@ def __getstate__(self) -> dict[str, Any]:
def __setstate__(self, state: dict[str, Any]) -> None:
# When unpickling restore the model using model_from_json.
self.__dict__.update(state)
self._model = tf.keras.models.model_from_json(
state["_model"], custom_objects={"MultivariateNormalTriL": MultivariateNormalTriL}
)
# TF 2.15 disallows loading lambdas without "safe-mode" being disabled
# unfortunately, tfp.layers.DistributionLambda seems to use lambdas
with SafeModeScope(False):
self._model = tf.keras.models.model_from_json(
state["_model"], custom_objects={"MultivariateNormalTriL": MultivariateNormalTriL}
)
self._model.set_weights(state["_weights"])

# Restore the history (including any model it contains)
Expand Down

0 comments on commit 248a49f

Please sign in to comment.