Skip to content

Commit

Permalink
Optimise error handling in LTX support
Browse files Browse the repository at this point in the history
  • Loading branch information
acerv committed Jul 7, 2023
1 parent 511a79b commit 7fd7cb7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 46 deletions.
38 changes: 18 additions & 20 deletions libkirk/ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def __init__(self, stdin_fd: int, stdout_fd: int) -> None:
self._lock = asyncio.Lock()
self._task = None
self._messages = []
self._exc = None
self._exception = None

async def __aenter__(self) -> None:
"""
Expand Down Expand Up @@ -505,11 +505,11 @@ async def connect(self) -> None:

self._logger.info("Connecting to LTX")

self._exc = None
self._exception = None
self._task = libkirk.create_task(self._polling())

if self.exception():
raise self.exception()
if self._exception:
raise self._exception

self._logger.info("Connected")

Expand All @@ -525,18 +525,14 @@ async def disconnect(self) -> None:

while self.connected:
await asyncio.sleep(0.005)
if self._exception:
raise self._exception

if self.exception():
raise self.exception()
if self._exception:
raise self._exception

self._logger.info("Disconnected")

def exception(self) -> Exception:
"""
Return an exception if error occurs during execution.
"""
return self._exc

async def send(self, requests: list) -> None:
"""
Send requests to LTX service. The order is preserved during
Expand Down Expand Up @@ -574,13 +570,15 @@ async def on_complete(req, *args):
for req in requests:
req.add_done_coro(on_complete)

try:
await self.send(requests)
await self.send(requests)

while len(replies) != req_len:
await asyncio.sleep(0.005)
except LTXError as err:
self._exc = err
while len(replies) != req_len:
await asyncio.sleep(0.005)
if self._exception:
raise self._exception

if self._exception:
raise self._exception

return replies

Expand Down Expand Up @@ -643,13 +641,13 @@ async def _polling(self) -> None:
raise LTXError("Message must be an array")

if msg[0] == Request.ERROR:
raise LTXError(data[1])
raise LTXError(msg[1])

await self._feed_requests(msg)
except msgpack.OutOfData:
break
except LTXError as err:
self._exc = err
self._exception = err
finally:
self._logger.info("Producer has stopped")

Expand Down
25 changes: 13 additions & 12 deletions libkirk/ltx_sut.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def stop(self, iobuffer: IOBuffer = None) -> None:
requests.append(Requests.kill(slot_id))

if requests:
await self._send_request(requests)
await self._send_requests(requests)

while self._slots:
await asyncio.sleep(1e-2)
Expand All @@ -103,16 +103,17 @@ async def stop(self, iobuffer: IOBuffer = None) -> None:
if err.errno == 9:
pass

async def _send_request(self, requests: list) -> dict:
async def _send_requests(self, requests: list) -> list:
"""
Wrapper around `ltx.gather` to catch `LTXError` exception.
Send requests and check for LTXError.
"""
replies = await self._ltx.gather(requests)

if self._ltx.exception():
raise SUTError(self._ltx.exception())
reply = None
try:
reply = await self._ltx.gather(requests)
except LTXError as err:
raise SUTError(err)

return replies
return reply

async def _reserve_slot(self) -> int:
"""
Expand Down Expand Up @@ -145,7 +146,7 @@ async def ping(self) -> float:

req = Requests.ping()
start_t = time.monotonic()
replies = await self._send_request([req])
replies = await self._send_requests([req])

return (replies[req][0] * 1e-9) - start_t

Expand All @@ -163,7 +164,7 @@ async def communicate(self, iobuffer: IOBuffer = None) -> None:
except LTXError as err:
raise SUTError(err)

await self._send_request([Requests.version()])
await self._send_requests([Requests.version()])

async def run_command(
self,
Expand Down Expand Up @@ -203,7 +204,7 @@ async def _stdout_coro(data):
stdout_coro=_stdout_coro)

requests.append(exec_req)
replies = await self._send_request(requests)
replies = await self._send_requests(requests)
reply = replies[exec_req]

ret = {
Expand All @@ -230,7 +231,7 @@ async def fetch_file(self, target_path: str) -> bytes:

async with self._fetch_lock:
req = Requests.get_file(target_path)
replies = await self._send_request([req])
replies = await self._send_requests([req])
reply = replies[req]

return reply[1]
14 changes: 0 additions & 14 deletions libkirk/tests/test_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ async def test_version(self, ltx):
"""
req = Requests.version()
replies = await ltx.gather([req])
assert not ltx.exception()
assert replies[req][0] == "0.1"

async def test_ping(self, ltx):
Expand All @@ -67,7 +66,6 @@ async def test_ping(self, ltx):
start_t = time.monotonic()
req = Requests.ping()
replies = await ltx.gather([req])
assert not ltx.exception()
assert start_t < replies[req][0] * 1e-9 < time.monotonic()

async def test_execute(self, ltx):
Expand All @@ -82,7 +80,6 @@ async def _stdout_coro(data):
start_t = time.monotonic()
req = Requests.execute(0, "uname", stdout_coro=_stdout_coro)
replies = await ltx.gather([req])
assert not ltx.exception()
reply = replies[req]

assert ''.join(stdout) == "Linux\n"
Expand All @@ -104,7 +101,6 @@ async def _stdout_coro(data):
req = Requests.execute(
0, "echo -n ciao", stdout_coro=_stdout_coro)
replies = await ltx.gather([req])
assert not ltx.exception()
reply = replies[req]

assert ''.join(stdout) == "ciao"
Expand Down Expand Up @@ -134,8 +130,6 @@ async def _stdout_coro(data):
replies = await ltx.gather(req)
end_t = time.monotonic()

assert not ltx.exception()

for reply in replies.values():
assert start_t < reply[0] * 1e-9 < end_t
assert reply[1] == 1
Expand All @@ -154,7 +148,6 @@ async def test_set_file(self, ltx, tmp_path):

req = Requests.set_file(str(pfile), data)
await ltx.gather([req])
assert not ltx.exception()

assert pfile.read_bytes() == data

Expand All @@ -167,7 +160,6 @@ async def test_get_file(self, ltx, tmp_path):

req = Requests.get_file(str(pfile))
replies = await ltx.gather([req])
assert not ltx.exception()

assert replies[req][0] == str(pfile)
assert pfile.read_bytes() == replies[req][1]
Expand All @@ -180,7 +172,6 @@ async def test_kill(self, ltx):
exec_req = Requests.execute(0, "sleep 1")
kill_req = Requests.kill(0)
replies = await ltx.gather([exec_req, kill_req])
assert not ltx.exception()
reply = replies[exec_req]

assert start_t < reply[0] * 1e-9 < time.monotonic()
Expand All @@ -196,7 +187,6 @@ async def test_env(self, ltx):
env_req = Requests.env(0, "HELLO", "CIAO")
exec_req = Requests.execute(0, "echo -n $HELLO")
replies = await ltx.gather([env_req, exec_req])
assert not ltx.exception()
reply = replies[exec_req]

assert start_t < reply[0] * 1e-9 < time.monotonic()
Expand All @@ -212,7 +202,6 @@ async def test_env_multiple(self, ltx):
env_req = Requests.env(128, "HELLO", "CIAO")
exec_req = Requests.execute(0, "echo -n $HELLO")
replies = await ltx.gather([env_req, exec_req])
assert not ltx.exception()
reply = replies[exec_req]

assert start_t < reply[0] * 1e-9 < time.monotonic()
Expand All @@ -230,7 +219,6 @@ async def test_cwd(self, ltx, tmpdir):
env_req = Requests.cwd(0, path)
exec_req = Requests.execute(0, "echo -n $PWD")
replies = await ltx.gather([env_req, exec_req])
assert not ltx.exception()
reply = replies[exec_req]

assert start_t < reply[0] * 1e-9 < time.monotonic()
Expand All @@ -248,7 +236,6 @@ async def test_cwd_multiple(self, ltx, tmpdir):
env_req = Requests.cwd(128, path)
exec_req = Requests.execute(0, "echo -n $PWD")
replies = await ltx.gather([env_req, exec_req])
assert not ltx.exception()
reply = replies[exec_req]

assert start_t < reply[0] * 1e-9 < time.monotonic()
Expand All @@ -273,7 +260,6 @@ async def test_all_together(self, ltx, tmp_path):
requests.append(Requests.get_file(str(pfile)))

await ltx.gather(requests)
assert not ltx.exception()


@pytest.fixture
Expand Down

0 comments on commit 7fd7cb7

Please sign in to comment.