diff --git a/Makefile b/Makefile index 84cf255a..aeed1be5 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ install: @echo "--- 🚀 Installing project dependencies ---" - pip install -e ./browsergym/core -e ./browsergym/miniwob -e ./browsergym/webarena -e ./browsergym/visualwebarena/ -e ./browsergym/experiments -e ./browsergym/assistantbench -e ./browsergym/ + pip install -e ./browsergym/core -e ./browsergym/miniwob -e ./browsergym/webarena -e ./browsergym/visualwebarena/ -e ./browsergym/experiments -e ./browsergym/assistantbench playwright install chromium install-demo: diff --git a/browsergym/visualwebarena/pyproject.toml b/browsergym/visualwebarena/pyproject.toml index 0e0669c2..035fe16e 100644 --- a/browsergym/visualwebarena/pyproject.toml +++ b/browsergym/visualwebarena/pyproject.toml @@ -25,6 +25,9 @@ dynamic = ["dependencies", "version"] [project.urls] homepage = "https://github.com/ServiceNow/BrowserGym" +[project.optional-dependencies] +torch = ["torch>=2.0.0"] + [tool.hatch.version] path = "../core/src/browsergym/core/__init__.py" diff --git a/browsergym/visualwebarena/requirements.txt b/browsergym/visualwebarena/requirements.txt index 6bdbb875..d6e608df 100644 --- a/browsergym/visualwebarena/requirements.txt +++ b/browsergym/visualwebarena/requirements.txt @@ -2,4 +2,3 @@ browsergym-core==0.14.2 browsergym-webarena libvisualwebarena==0.0.15 requests -torch diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py index 618875c3..bf26a6d5 100644 --- a/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/__init__.py @@ -1,4 +1,5 @@ import nltk +import importlib.util from browsergym.core.registration import register_task @@ -11,6 +12,11 @@ except: nltk.download("punkt_tab", quiet=True, raise_on_error=True) +if importlib.util.find_spec("torch") is None: + raise ImportError( + "The 'torch' package is required for VisualWebArena tasks evaluation. Please install it with 'pip install torch'." + ) + ALL_VISUALWEBARENA_TASK_IDS = [] VISUALWEBARENA_TASK_IDS_WITH_RESET = [] VISUALWEBARENA_TASK_IDS_WITHOUT_RESET = [] diff --git a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py index 77c0dd40..aa908c01 100644 --- a/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py +++ b/browsergym/visualwebarena/src/browsergym/visualwebarena/task.py @@ -185,6 +185,7 @@ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]: hide_progress_bar = is_progress_bar_enabled() if hide_progress_bar: disable_progress_bar() + captioning_fn = get_captioning_fn( device=self.eval_captioning_model_device, dtype=(