Automatically start nameserver and improve tests.
Currently, tests are skipped for CMAESBenchmark since serialization does not work yet (#107)
maximilianreimer committed Nov 19, 2021
1 parent 8ab4d16 commit e3caeb5
Showing 3 changed files with 75 additions and 23 deletions.
31 changes: 30 additions & 1 deletion dacbench/container/container_utils.py
@@ -3,6 +3,8 @@
import gym
import numpy as np
import enum
import time
import socket

from typing import Any, Union, Tuple, List, Dict

@@ -139,4 +141,31 @@ def deserialize_random_state(random_state_dict: Dict) -> np.random.RandomState:
def serialize_random_state(random_state: np.random.RandomState) -> Tuple[int, List, int, int, int]:
    (rnd0, rnd1, rnd2, rnd3, rnd4) = random_state.get_state()
    rnd1 = rnd1.tolist()
    return {'__type__': 'random_state', '__items__': [rnd0, rnd1, rnd2, rnd3, rnd4]}




def wait_for_port(port, host='localhost', timeout=5.0):
    """
    Taken from https://gist.github.com/butla/2d9a4c0f35ea47b7452156c96a4e7b12
    Wait until a port starts accepting TCP connections.

    Args:
        port (int): Port number.
        host (str): Host address on which the port should exist.
        timeout (float): In seconds. How long to wait before raising errors.

    Raises:
        TimeoutError: The port isn't accepting connections after the time specified in `timeout`.
    """
    start_time = time.perf_counter()
    while True:
        try:
            with socket.create_connection((host, port), timeout=timeout):
                break
        except OSError as ex:
            time.sleep(0.01)
            if time.perf_counter() - start_time >= timeout:
                raise TimeoutError('Waited too long for the port {} on host {} to start accepting '
                                   'connections.'.format(port, host)) from ex
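
For reference, a minimal usage sketch of wait_for_port, mirroring how the test suite below launches the remote runner (port 8888, localhost, and the module path all come from this commit; the snippet itself is illustrative and not part of the diff):

import subprocess
from dacbench.container.container_utils import wait_for_port

# Start the remote runner (which now also starts the Pyro4 name server) and block
# until its daemon port accepts TCP connections instead of sleeping a fixed time.
daemon_process = subprocess.Popen(["python", "dacbench/container/remote_runner.py"])
wait_for_port(8888, host="localhost", timeout=5.0)  # raises TimeoutError if the server never comes up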


40 changes: 29 additions & 11 deletions dacbench/container/remote_runner.py
@@ -1,17 +1,36 @@
import os
import sys
from typing import Tuple

from icecream import ic

os.environ["PYRO_LOGFILE"] = "pyro.log"
os.environ["PYRO_LOGLEVEL"] = "DEBUG"
import logging

from dacbench.abstract_agent import AbstractDACBenchAgent
from dacbench.abstract_benchmark import objdict, AbstractBenchmark
import Pyro4
from dacbench.abstract_benchmark import AbstractBenchmark
import Pyro4, Pyro4.naming

from dacbench.container.remote_env import RemoteEnvironmentServer, RemoteEnvironmentClient

# Needed in order to combine event loops of name_server and daemon
Pyro4.config.SERVERTYPE = "multiplex"

# Read in the verbosity level from the environment variable
log_level_str = os.environ.get('DACBENCH_DEBUG', 'false')
LOG_LEVEL = logging.DEBUG if log_level_str == 'true' else logging.INFO

root = logging.getLogger()
root.setLevel(level=LOG_LEVEL)

logger = logging.getLogger(__name__)
logger.setLevel(level=LOG_LEVEL)

# This option improves the quality of stacktraces if a container crashes
sys.excepthook = Pyro4.util.excepthook
#os.environ["PYRO_LOGFILE"] = "pyro.log"
#os.environ["PYRO_LOGLEVEL"] = "DEBUG"

# Number of tries to connect to server
MAX_TRIES = 5


@Pyro4.expose
class RemoteRunnerServer:
@@ -25,10 +44,9 @@ def start(self, config : str, benchmark : Tuple[str, str]):
        self.benchmark = benchmark.from_json(config)

    def get_environment(self) -> str:
        ic(self.benchmark)

        env = self.benchmark.get_environment()
        ic(env)

        # set up logger and stuff

        self.env = RemoteEnvironmentServer(env)
@@ -48,7 +66,6 @@ def __init__(self, benchmark : AbstractBenchmark, factory_uri : str = "PYRONAME:

        serialized_config = benchmark.to_json()
        serialized_type = benchmark.class_to_str()
        ic(serialized_config)
        self.remote_runner.start(serialized_config, serialized_type)
        self.env = None

@@ -94,10 +111,11 @@ def __call__(self):

PORT = 8888
HOST = "localhost" # add arguments
name_server = Pyro4.locateNS()
name_server_uri, name_server_daemon, _ = Pyro4.naming.startNS()
daemon = Pyro4.Daemon(HOST, PORT)
daemon.combine(name_server_daemon)
factory = RemoteRunnerServerFactory(daemon)
factory_uri = daemon.register(factory)
name_server.register("RemoteRunnerServerFactory", factory_uri)
name_server_daemon.nameserver.register("RemoteRunnerServerFactory", factory_uri)

daemon.requestLoop()
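
Since the new verbosity switch above is read from the environment, the server's debug logging can be turned on when launching it, for example like this (a sketch using only the standard library; the variable name DACBENCH_DEBUG and the literal 'true' come from the block above, everything else is illustrative):

import os
import subprocess

# Any value other than the literal string 'true' keeps the default INFO level.
env = dict(os.environ, DACBENCH_DEBUG="true")
server_process = subprocess.Popen(["python", "dacbench/container/remote_runner.py"], env=env)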
27 changes: 16 additions & 11 deletions tests/container/test_remote_runner.py
@@ -1,30 +1,29 @@
import logging
import subprocess
import unittest
import signal

from icecream import ic


from dacbench.agents import StaticAgent, RandomAgent
import dacbench.benchmarks
from dacbench.container.remote_runner import RemoteRunner
from dacbench.run_baselines import DISCRETE_ACTIONS
from dacbench.container.container_utils import wait_for_port

# todo load from config if existent
PORT = 8888
HOST = 'localhost'

from time import sleep
class TestRemoteRunner(unittest.TestCase):
    def setUp(self) -> None:
        self.name_server_process = subprocess.Popen(
            [
                "pyro4-ns"
            ]
        )
        sleep(1)
        self.daemon_process = subprocess.Popen(
            [
                "python",
                "dacbench/container/remote_runner.py"
            ]
        )
        sleep(1)
        wait_for_port(PORT, HOST)



@@ -34,9 +33,17 @@ def run_agent_on_benchmark_test(self, benchmark, agent_creation_function):
        remote_runner.run(agent, 1)

    def test_step(self):
        skip_benchmarks = ['CMAESBenchmark']
        benchmarks = dacbench.benchmarks.__all__[1:]

        for benchmark in benchmarks:
            if benchmark in skip_benchmarks:
                continue
            # todo Skipping since serialization is not done yet. https://github.com/automl/DACBench/issues/107
            # self.skipTest(reason="Skipping since serialization is not done yet. https://github.com/automl/DACBench/issues/107")
            if benchmark not in DISCRETE_ACTIONS:
                logging.warning(f"Skipping test for {benchmark} since no discrete actions are available")
                continue
            benchmark_class = getattr(dacbench.benchmarks, benchmark)
            benchmark_instance = benchmark_class()

@@ -47,12 +54,10 @@ def test_step(self):
            ]
            for agent_creation_function, agent_info in agent_creation_functions:
                with self.subTest(msg=f"[Benchmark]{benchmark}, [Agent]{agent_info}", agent_creation_function=agent_creation_function, benchmark=benchmark):
                    ic(benchmark, agent_info)
                    self.run_agent_on_benchmark_test(benchmark_instance, agent_creation_function)




    def tearDown(self) -> None:
        self.name_server_process.send_signal(signal.SIGTERM)
        self.daemon_process.send_signal(signal.SIGTERM)
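
For completeness, one way to run just this test module after the change (plain standard-library unittest, nothing here is added by the commit; run it from the repository root so the dacbench package and the remote_runner.py path resolve):

import unittest

# Equivalent to `python -m unittest tests.container.test_remote_runner -v` on the command line.
unittest.main(module="tests.container.test_remote_runner", verbosity=2)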
