From 89336327666692c4f68b1cfdea7e8585ecee9989 Mon Sep 17 00:00:00 2001 From: Jean Humann Date: Tue, 1 Oct 2024 23:47:02 +0200 Subject: [PATCH 1/2] feat: handle multiple files as submodules Signed-off-by: Jean Humann --- server/src/main/resources/shell_wrapper.py | 131 +++++++++++++++++---- 1 file changed, 109 insertions(+), 22 deletions(-) diff --git a/server/src/main/resources/shell_wrapper.py b/server/src/main/resources/shell_wrapper.py index 88e3e312..b94c57d2 100644 --- a/server/src/main/resources/shell_wrapper.py +++ b/server/src/main/resources/shell_wrapper.py @@ -5,6 +5,7 @@ import traceback import os import logging + sys_stdin = sys.stdin sys_stdout = sys.stdout @@ -13,12 +14,14 @@ logging.basicConfig(stream=sys.stdout, level=log_level) log = logging.getLogger("session") + def setup_output(): sys.stdout.flush() sys.stderr.flush() sys.stdout = io.StringIO() sys.stderr = io.StringIO() + def _do_with_retry(attempts, action): attempts_left = attempts last_exception = None @@ -48,13 +51,14 @@ def do_read(): str_line = sys_stdin.readline() line = json.loads(str_line) return [line] - return _do_with_retry(2, do_read) + return _do_with_retry(2, do_read) def write(self, id, result): def do_write(): print(json.dumps(result), file=sys_stdout) sys_stdout.flush() + _do_with_retry(2, do_write) @@ -62,29 +66,50 @@ class GatewayController(Controller): def __init__(self, session_id): super().__init__(session_id) from py4j.java_gateway import JavaGateway, GatewayParameters + port = int(os.environ.get("PY_GATEWAY_PORT")) host = os.environ.get("PY_GATEWAY_HOST") - self.gateway = JavaGateway(gateway_parameters=GatewayParameters( - address=host, port=port, auto_convert=True)) + self.gateway = JavaGateway( + gateway_parameters=GatewayParameters( + address=host, port=port, auto_convert=True + ) + ) self.endpoint = self.gateway.entry_point def read(self): return _do_with_retry( 3, lambda: [ - {"id": stmt.getId(), "code": stmt.getCode()} for stmt in self.endpoint.statementsToProcess(self.session_id) - ] + {"id": stmt.getId(), "code": stmt.getCode()} + for stmt in self.endpoint.statementsToProcess(self.session_id) + ], ) def write(self, id, result): _do_with_retry( - 3, - lambda: self.endpoint.handleResponse(self.session_id, id, result) + 3, lambda: self.endpoint.handleResponse(self.session_id, id, result) ) -class CommandHandler: +def is_url(words: str) -> bool: + import re + + log.info(f"Checking if {words} is a URL") + length = len(words.split(" ")) + if length != 1: + log.error(f"Not a single word: {words} ({length} words)") + return False + + match = re.match(r'^https?://\S+$', words) + if match: + log.info(f"Matched: {match.group()}") + return bool(match) + else: + log.error(f"Not matched: {words}") + return False + +class CommandHandler: def __init__(self, globals) -> None: self.globals = globals @@ -98,22 +123,87 @@ def _error_response(self, error): } def _exec_then_eval(self, code): - block = ast.parse(code, mode='exec') + block = ast.parse(code, mode="exec") # assumes last node is an expression last = ast.Interactive([block.body.pop()]) - exec(compile(block, '', 'exec'), self.globals) - exec(compile(last, '', 'single'), self.globals) + exec(compile(block, "", "exec"), self.globals) + exec(compile(last, "", "single"), self.globals) + def _download_then_exec(self, url): + temp_dir = self._download_and_extract(url) + try: + self._execute_main_file(temp_dir) + finally: + temp_dir.cleanup() + + @staticmethod + def _download_and_extract(url: str): + import tempfile + import zipfile + import requests + + temp_dir = tempfile.TemporaryDirectory() + temp_file_path = os.path.join(temp_dir.name, "export") + + response = requests.get(url) + with open(temp_file_path, "wb") as f: + f.write(response.content) + + with zipfile.ZipFile(temp_file_path, "r") as zip_ref: + zip_ref.extractall(temp_dir.name) + + return temp_dir + + @staticmethod + def _add_to_pythonpath(temp_dir): + log.info(f"Adding {temp_dir.name} to pythonpath") + sys.path.append(temp_dir.name) + + @staticmethod + def _remove_from_pythonpath(temp_dir): + log.info(f"Removing {temp_dir.name} from pythonpath") + sys.path.remove(temp_dir.name) + + @staticmethod + def _remove_module(temp_dir): + modules_to_remove = [] + for name, mod in sys.modules.items(): + try: + if hasattr(mod, '__spec__') and mod.__spec__ and mod.__spec__.origin and temp_dir.name in mod.__spec__.origin: + modules_to_remove.append(name) + except AttributeError: + continue + except TypeError: + continue + + for name in modules_to_remove: + log.info(f"Unloading {name}") + del sys.modules[name] + + def _execute_main_file(self, temp_dir): + self._add_to_pythonpath(temp_dir) + main_file_path = os.path.join(temp_dir.name, "main.py") + with open(main_file_path, "r") as f: + log.info(f"Executing {main_file_path}") + self._exec_then_eval(f.read()) + self._remove_from_pythonpath(temp_dir) + self._remove_module(temp_dir) + + + def _exec_code(self, code): + if not is_url(code): + self._exec_then_eval(code) + else: + self._download_then_exec(code) def exec(self, request): try: code = request["code"].rstrip() if code: - self._exec_then_eval(code) + self._exec_code(code) return {"content": {"text/plain": str(sys.stdout.getvalue()).rstrip()}} - return {"content": {"text/plain": ""}} except Exception as e: log.exception(e) @@ -126,10 +216,7 @@ def init_globals(name): from pyspark.sql import SparkSession - spark = SparkSession \ - .builder \ - .appName(name) \ - .getOrCreate() + spark = SparkSession.builder.appName(name).getOrCreate() return {"spark": spark} @@ -138,8 +225,9 @@ def main(): setup_output() session_id = os.environ.get("LIGHTER_SESSION_ID") log.info(f"Initiating session {session_id}") - controller = TestController( - session_id) if is_test else GatewayController(session_id) + controller = ( + TestController(session_id) if is_test else GatewayController(session_id) + ) handler = CommandHandler(init_globals(session_id)) log.info("Starting session loop") @@ -153,11 +241,10 @@ def main(): log.debug("Response sent") except: exc_type, exc_value, exc_tb = sys.exc_info() - log.error( - f"Error: {traceback.format_exception(exc_type, exc_value, exc_tb)}") + log.error(f"Error: {traceback.format_exception(exc_type, exc_value, exc_tb)}") log.info("Exiting") return 1 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) From d552e50b002393f51ead08a460e117944109a0db Mon Sep 17 00:00:00 2001 From: Jean Humann Date: Thu, 3 Oct 2024 14:03:13 +0200 Subject: [PATCH 2/2] feat: handle single file from presigned URL Signed-off-by: Jean Humann --- server/src/main/resources/shell_wrapper.py | 214 ++++++++++----------- 1 file changed, 99 insertions(+), 115 deletions(-) diff --git a/server/src/main/resources/shell_wrapper.py b/server/src/main/resources/shell_wrapper.py index b94c57d2..ba8e9400 100644 --- a/server/src/main/resources/shell_wrapper.py +++ b/server/src/main/resources/shell_wrapper.py @@ -5,6 +5,13 @@ import traceback import os import logging +import re +import tempfile +import zipfile +import requests +from typing import Callable, Any, List, Dict +from pathlib import Path + sys_stdin = sys.stdin sys_stdout = sys.stdout @@ -22,53 +29,41 @@ def setup_output(): sys.stderr = io.StringIO() -def _do_with_retry(attempts, action): - attempts_left = attempts - last_exception = None - while attempts_left: +def retry(attempts: int, action: Callable[[], Any]) -> Any: + for _ in range(attempts): try: return action() except Exception as e: last_exception = e - attempts_left -= 1 raise last_exception class Controller: - def __init__(self, session_id): + def __init__(self, session_id: str): self.session_id = session_id - def read(self): + def read(self) -> List[Dict[str, Any]]: return [] - def write(self, _id, _result): + def write(self, _id: str, _result: Dict[str, Any]) -> None: pass class TestController(Controller): - def read(self): - def do_read(): - str_line = sys_stdin.readline() - line = json.loads(str_line) - return [line] - - return _do_with_retry(2, do_read) + def read(self) -> List[Dict[str, Any]]: + return retry(2, lambda: [json.loads(sys_stdin.readline())]) - def write(self, id, result): - def do_write(): - print(json.dumps(result), file=sys_stdout) - sys_stdout.flush() - - _do_with_retry(2, do_write) + def write(self, id: str, result: Dict[str, Any]) -> None: + retry(2, lambda: print(json.dumps(result), file=sys_stdout, flush=True)) class GatewayController(Controller): - def __init__(self, session_id): + def __init__(self, session_id: str): super().__init__(session_id) from py4j.java_gateway import JavaGateway, GatewayParameters - port = int(os.environ.get("PY_GATEWAY_PORT")) - host = os.environ.get("PY_GATEWAY_HOST") + port = int(os.environ.get("PY_GATEWAY_PORT", "0")) + host = os.environ.get("PY_GATEWAY_HOST", "") self.gateway = JavaGateway( gateway_parameters=GatewayParameters( address=host, port=port, auto_convert=True @@ -76,8 +71,8 @@ def __init__(self, session_id): ) self.endpoint = self.gateway.entry_point - def read(self): - return _do_with_retry( + def read(self) -> List[Dict[str, Any]]: + return retry( 3, lambda: [ {"id": stmt.getId(), "code": stmt.getCode()} @@ -85,145 +80,133 @@ def read(self): ], ) - def write(self, id, result): - _do_with_retry( - 3, lambda: self.endpoint.handleResponse(self.session_id, id, result) - ) + def write(self, id: str, result: Dict[str, Any]) -> None: + retry(3, lambda: self.endpoint.handleResponse(self.session_id, id, result)) def is_url(words: str) -> bool: - import re - - log.info(f"Checking if {words} is a URL") - length = len(words.split(" ")) - if length != 1: - log.error(f"Not a single word: {words} ({length} words)") + log.debug(f"Checking if {words} is a URL") + if len(words.split()) != 1: + log.debug(f"Not a single word: {words}") return False - - match = re.match(r'^https?://\S+$', words) + + match = re.match(r"^https?://\S+$", words) if match: - log.info(f"Matched: {match.group()}") - return bool(match) - else: - log.error(f"Not matched: {words}") - return False + log.debug(f"Matched: {match.group()}") + log.info(f"URL matched: {words}") + return True + log.debug(f"Not matched: {words}") + log.info("URL not matched") + return False class CommandHandler: - def __init__(self, globals) -> None: + def __init__(self, globals: Dict[str, Any]): self.globals = globals + self.code_file = "download" - def _error_response(self, error): + def _error_response(self, error: Exception) -> Dict[str, Any]: exc_type, exc_value, exc_tb = sys.exc_info() return { - "content": {"text/plain": str(sys.stdout.getvalue()).rstrip()}, + "content": {"text/plain": sys.stdout.getvalue().rstrip()}, "error": type(error).__name__, "message": str(error), "traceback": traceback.format_exception(exc_type, exc_value, exc_tb), } - def _exec_then_eval(self, code): + def _exec_then_eval(self, code: str) -> None: block = ast.parse(code, mode="exec") - - # assumes last node is an expression last = ast.Interactive([block.body.pop()]) - exec(compile(block, "", "exec"), self.globals) exec(compile(last, "", "single"), self.globals) - def _download_then_exec(self, url): - temp_dir = self._download_and_extract(url) - try: - self._execute_main_file(temp_dir) - finally: - temp_dir.cleanup() - - @staticmethod - def _download_and_extract(url: str): - import tempfile - import zipfile - import requests - - temp_dir = tempfile.TemporaryDirectory() - temp_file_path = os.path.join(temp_dir.name, "export") - + def _download_then_exec(self, url: str) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + self._download_and_extract(url, temp_dir) + main_file = Path(temp_dir) / "main.py" + if main_file.exists(): + self._execute_main_file(temp_dir) + else: + code_file = Path(temp_dir) / self.code_file + with code_file.open("r") as f: + self._exec_then_eval(f.read()) + + def _download_and_extract(self, url: str, temp_dir: str) -> None: + temp_file_path = Path(temp_dir) / self.code_file + self._download(url, temp_file_path) + if self._is_zip(temp_file_path): + self._extract(temp_file_path) + + def _download(self, url: str, temp_file_path: Path) -> None: response = requests.get(url) - with open(temp_file_path, "wb") as f: - f.write(response.content) + response.raise_for_status() # Raise an exception for bad status codes + temp_file_path.write_bytes(response.content) - with zipfile.ZipFile(temp_file_path, "r") as zip_ref: - zip_ref.extractall(temp_dir.name) - - return temp_dir - @staticmethod - def _add_to_pythonpath(temp_dir): - log.info(f"Adding {temp_dir.name} to pythonpath") - sys.path.append(temp_dir.name) - + def _is_zip(file_path: Path) -> bool: + return zipfile.is_zipfile(file_path) + @staticmethod - def _remove_from_pythonpath(temp_dir): - log.info(f"Removing {temp_dir.name} from pythonpath") - sys.path.remove(temp_dir.name) + def _extract(file_path: Path) -> None: + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(path=file_path.parent) + + def _execute_main_file(self, temp_dir: str) -> None: + sys.path.insert(0, temp_dir) + try: + main_file_path = Path(temp_dir) / "main.py" + with main_file_path.open("r") as f: + log.info(f"Executing {main_file_path}") + self._exec_then_eval(f.read()) + finally: + sys.path.remove(temp_dir) + self._remove_modules(temp_dir) @staticmethod - def _remove_module(temp_dir): - modules_to_remove = [] - for name, mod in sys.modules.items(): - try: - if hasattr(mod, '__spec__') and mod.__spec__ and mod.__spec__.origin and temp_dir.name in mod.__spec__.origin: - modules_to_remove.append(name) - except AttributeError: - continue - except TypeError: - continue - + def _remove_modules(temp_dir: str) -> None: + modules_to_remove = [ + name + for name, mod in sys.modules.items() + if hasattr(mod, "__spec__") + and mod.__spec__ + and mod.__spec__.origin + and temp_dir in mod.__spec__.origin + ] for name in modules_to_remove: log.info(f"Unloading {name}") del sys.modules[name] - def _execute_main_file(self, temp_dir): - self._add_to_pythonpath(temp_dir) - main_file_path = os.path.join(temp_dir.name, "main.py") - with open(main_file_path, "r") as f: - log.info(f"Executing {main_file_path}") - self._exec_then_eval(f.read()) - self._remove_from_pythonpath(temp_dir) - self._remove_module(temp_dir) - - - def _exec_code(self, code): - if not is_url(code): - self._exec_then_eval(code) - else: + def _exec_code(self, code: str) -> None: + if is_url(code): self._download_then_exec(code) + else: + self._exec_then_eval(code) - def exec(self, request): + def exec(self, request: Dict[str, str]) -> Dict[str, Any]: try: code = request["code"].rstrip() if code: self._exec_code(code) - return {"content": {"text/plain": str(sys.stdout.getvalue()).rstrip()}} + return {"content": {"text/plain": sys.stdout.getvalue().rstrip()}} return {"content": {"text/plain": ""}} except Exception as e: log.exception(e) return self._error_response(e) -def init_globals(name): +def init_globals(name: str) -> Dict[str, Any]: if is_test: return {} from pyspark.sql import SparkSession spark = SparkSession.builder.appName(name).getOrCreate() - return {"spark": spark} -def main(): +def main() -> int: setup_output() - session_id = os.environ.get("LIGHTER_SESSION_ID") + session_id = os.environ.get("LIGHTER_SESSION_ID", "") log.info(f"Initiating session {session_id}") controller = ( TestController(session_id) if is_test else GatewayController(session_id) @@ -239,11 +222,12 @@ def main(): result = handler.exec(command) controller.write(command["id"], result) log.debug("Response sent") - except: - exc_type, exc_value, exc_tb = sys.exc_info() - log.error(f"Error: {traceback.format_exception(exc_type, exc_value, exc_tb)}") - log.info("Exiting") + except Exception: + log.exception("Error in main loop") return 1 + finally: + log.info("Exiting") + return 0 if __name__ == "__main__":