Skip to content

Commit

Permalink
Update pyproject.toml for PyTorch versions (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
Neeratyoy committed Jun 19, 2024
2 parents 0553bd5 + ade69db commit 9a1079c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
15 changes: 10 additions & 5 deletions neps/optimizers/bayesian_optimization/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@ def extract_configs(configs: list[SearchSpace]) -> Tuple[list, list]:
"""
config_hps = [conf.get_normalized_hp_categories() for conf in configs]
graphs = [hps["graphs"] for hps in config_hps]

_nested_graphs = np.array(graphs, dtype=object)
if _nested_graphs.ndim == 3:
graphs = _nested_graphs[:, :, 0].reshape(-1).tolist()

# Don't call np.array on structured objects
# https://github.com/numpy/numpy/issues/24546#issuecomment-1693913119
# _nested_graphs = np.array(graphs, dtype=object)
# if _nested_graphs.ndim == 3
# graphs = _nested_graphs[:, :, 0].reshape(-1).tolist()
# Long hand way of doing the above
if (len(graphs) > 0 and isinstance(graphs[0], list)
and len(graphs[0]) > 0 and isinstance(graphs[0][0], list)):
res = [_list for list_of_list in graphs for _list in list_of_list]
graphs = res
return graphs, config_hps


Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ pandas = "^2"
networkx = "^2.6.3"
nltk = "^3.6.4"
scipy = "^1"
torch = ">=1.7.0,<=2.1, !=2.0.1, !=2.1.0" # fix from: https://stackoverflow.com/a/76647180
# torch = ">=1.7.0,<=2.1, !=2.0.1, !=2.1.0" # fix from: https://stackoverflow.com/a/76647180
torch = ">1.7.0,!=2.0.1, !=2.1.0"
matplotlib = "^3"
more-itertools = "*"
portalocker = "^2"
Expand All @@ -70,7 +71,7 @@ pre-commit = "^3"
mypy = "^1"
pytest = "^7"
types-PyYAML = "^6"
torchvision = "<0.16.0" # Used in examples
torchvision = ">=0.8.0" # Used in examples
mkdocs-material = "*"
mkdocs-autorefs = "*"
mkdocs-gen-files = "*"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def no_logs_gte_error(caplog):
def test_core_examples(example):
if example.name == "analyse.py":
# Run hyperparameters example to have something to analyse
runpy.run_path(core_examples_scripts[0], run_name="__main__")
runpy.run_path(str(core_examples_scripts[0]), run_name="__main__")

runpy.run_path(example, run_name="__main__")
runpy.run_path(str(example), run_name="__main__")


@pytest.mark.ci_examples
Expand Down

0 comments on commit 9a1079c

Please sign in to comment.