diff --git a/README.md b/README.md
index 1b0256a..40d672d 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@
 pip install deepinv[benchmarks]
 and then run on python:
 ```python
-from deepinv.benchmarks import run_benchmark
+from deepinv_bench import run_benchmark
 my_solver = lambda y, physics: ... # your solver here
 results = run_benchmark(my_solver, "benchmark_name")
 ```
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..4772991
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,12 @@
+[build-system]
+requires = ["setuptools>=61.0.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "deepinv_bench"
+version = "0.0.1"
+dependencies = [
+    "torch",
+    "deepinv",
+    "benchopt"
+]
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..1a886f7
--- /dev/null
+++ b/run.py
@@ -0,0 +1,23 @@
+import deepinv as dinv
+import torch
+import benchopt
+
+def run_benchmark(model: dinv.models.Reconstructor | torch.nn.Module, benchmark_name: str) -> dict:
+    r"""
+    Run a benchmark on a given model.
+
+    :param dinv.models.Reconstructor | torch.nn.Module model: reconstruction model to benchmark.
+    :param str benchmark_name: Name of the benchmark to run.
+    :return: dict with benchmark results, including metrics and runtime.
+    """
+
+    # TODO: how can we do this with benchopt?
+    results = {}  # placeholder until the benchopt integration is implemented
+
+    return results
+
+
+if __name__ == "__main__":
+    solver = dinv.models.RAM()
+    results = run_benchmark(solver, "div2k_gaussian_deblurring")
+    print(results)
\ No newline at end of file
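
For reference, here is a minimal sketch of how a custom solver could plug into the proposed `run_benchmark` API, following the `(y, physics)` calling convention shown in the README snippet above. The `PseudoInverseReconstructor` class and its adjoint-based baseline are illustrative assumptions, not part of this PR; only `run_benchmark`, the `deepinv_bench` package name, and the `"div2k_gaussian_deblurring"` benchmark name come from the diff itself.

```python
import torch

from deepinv_bench import run_benchmark  # package name defined in this PR's pyproject.toml


class PseudoInverseReconstructor(torch.nn.Module):
    """Hypothetical baseline: reconstruct by applying the adjoint of the forward operator."""

    def forward(self, y, physics):
        # deepinv linear physics objects expose A_adjoint(y); this is only a
        # naive placeholder standing in for a real reconstruction model.
        return physics.A_adjoint(y)


solver = PseudoInverseReconstructor()
# Benchmark name taken from the __main__ block of run.py in this diff.
results = run_benchmark(solver, "div2k_gaussian_deblurring")
print(results)
```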
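
On the `# TODO` in `run.py`: one possible direction, sketched below under explicit assumptions, is to drive benchopt programmatically and restrict the run to a single solver. It assumes a benchopt benchmark folder at `benchmarks/<benchmark_name>` that already registers the model as a solver named `"ram"`, and that `benchopt.run_benchmark` accepts a benchmark path plus a `solver_names` filter; neither the folder layout nor that exact signature is confirmed by the PR.

```python
from pathlib import Path

# Assumption: benchopt exposes a Python entry point `run_benchmark`; if the
# installed version does not, the CLI (`benchopt run <path> -s <solver>`) is
# the fallback.
from benchopt import run_benchmark as benchopt_run


def run_benchmark_via_benchopt(solver_name: str, benchmark_name: str):
    # Hypothetical repository layout: one benchopt benchmark folder per task.
    benchmark_path = Path("benchmarks") / benchmark_name
    # benchopt runs every (solver, dataset) pair declared in the benchmark
    # folder; passing solver_names keeps the run to the model under test.
    return benchopt_run(str(benchmark_path), solver_names=[solver_name])


if __name__ == "__main__":
    run_benchmark_via_benchopt("ram", "div2k_gaussian_deblurring")
```

Mapping an arbitrary `torch.nn.Module` onto such a run would still require a small solver wrapper inside the benchmark folder, which is presumably what the open TODO is about.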