From 6a920cd52aee65b1e2b25c2eaa36269e4deee42a Mon Sep 17 00:00:00 2001 From: raulikak <43605561+raulikak@users.noreply.github.com> Date: Thu, 2 May 2024 13:28:52 +0300 Subject: [PATCH] Upload improvements (#34) * Dry-run and more compact printout of uploaded files and directories * Mark uploaded files as .file.up * Provide the statement path without 'statement/' for client tool command-line * Also in client tool login, do not expect statement/ in path * Statement reload with DB clearing * Subcommand reload into client tool * Updated client tool help * Flag --force-upload for client tool upload * Control exit delay in reset request * Drop /statement/ from login and api1 endpoints * Documentation and tiny fix for observing process exit * Launcher waits for process to get its web server up * File watchdog works as a charm * Pylint fix * Complain with client tool URL ending with / * Fixed that meta-file load order was ignored for subdirs when no data files to upload --- ClientTool.md | 17 +++-- requirements.txt | 1 + tcsfw/client_api.py | 11 ++++ tcsfw/client_tool.py | 130 ++++++++++++++++++++++++++++++++++----- tcsfw/entity_database.py | 3 + tcsfw/http_server.py | 28 ++++++++- tcsfw/launcher.py | 87 +++++++++++++++++++++++--- tcsfw/registry.py | 5 ++ tcsfw/sql_database.py | 14 +++++ 9 files changed, 263 insertions(+), 33 deletions(-) diff --git a/ClientTool.md b/ClientTool.md index 1d0a049..3ec6d9d 100644 --- a/ClientTool.md +++ b/ClientTool.md @@ -11,8 +11,8 @@ The tool main file is `tcsfw/client_tool.py`, which can be called instead. The following prompts for password of user `user1` and then fetches new API key for the ruuvi sample statement. - $ tcwfw get-key --user user1 \ - --url http://192.168.1.1/login/statement/samples/ruuvi/ruuvi + $ tcsfw get-key --user user1 \ + --url http://192.168.1.1/login/samples/ruuvi/ruuvi The API key is printed out, but with argument `--save` it is saved into known file `.tcsfw_api_key` which is read by other client subcommands. 
From now on, the API key assumed to be saved in this file. @@ -22,13 +22,20 @@ Alternatively, it can be given with `--api-key` command-line argument. Supported tool output files can be uploaded with subcommand `upload`. - $ tcwfw upload \ - --read \ - --url http://192.168.1.1/login/statement/samples/ruuvi/ruuvi + $ tcsfw upload --read \ + --url http://192.168.1.1/samples/ruuvi/ruuvi The uploaded directories and files must stick with the [supported formats](Tools.md). +## Reload + +Security statement can be reloaded, and stored data reapplied with reload subcommand. +JSON parameter can be provided to clear the DB to avoid reapplying data, e.g.: + + $ tcsfw reload --param '{"clear_db": true}' \ + --url http://192.168.1.1/samples/ruuvi/ruuvi + ## Disabling certificate validation When dealing with development servers which have TLS enabled, but does not have appropriate certificates, one can use option `--insecure`. Beware, that using this exposes you to rogue servers and MITM attacks. diff --git a/requirements.txt b/requirements.txt index ab9af65..cce2c2b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ censys aiohttp aiofiles sqlalchemy +watchdog prompt_toolkit diff --git a/tcsfw/client_api.py b/tcsfw/client_api.py index fe7cfdc..e1a77e6 100644 --- a/tcsfw/client_api.py +++ b/tcsfw/client_api.py @@ -143,6 +143,9 @@ def api_post(self, request: APIRequest, data: Optional[BinaryIO]) -> Dict: self.system_reset(param.get("evidence", {}), include_all=param.get("include_all", False)) if param.get("dump_all", False): r = {"events": list(self.api_iterate_all(request.change_path(".")))} + elif path == "reload": + # reload is actually exit + self.api_exit(request, data) elif path.startswith("event/"): e_name = path[6:] e_type = EventMap.get_event_class(e_name) @@ -155,6 +158,14 @@ def api_post(self, request: APIRequest, data: Optional[BinaryIO]) -> Dict: raise FileNotFoundError("Unknown API endpoint") return r + def api_exit(self, _request: APIRequest, data: 
bytes) -> Dict: + """Reload model""" + param = json.loads(data) if data else {} + clear_db = bool(param.get("clear_db", False)) + if clear_db: + self.registry.clear_database() + return param + def api_post_file(self, request: APIRequest, data_file: pathlib.Path) -> Dict: """Post API data in ZIP file""" path = request.path diff --git a/tcsfw/client_tool.py b/tcsfw/client_tool.py index 252a3e8..62ac4cd 100644 --- a/tcsfw/client_tool.py +++ b/tcsfw/client_tool.py @@ -24,6 +24,8 @@ def __init__(self) -> None: self.auth_token = get_api_key() self.timeout = -1 self.secure = True + self.force_upload = False + self.dry_run = False def run(self): """Run the client tool""" @@ -50,6 +52,14 @@ def run(self): upload_parser.add_argument("--meta", "-m", help="Meta-data in JSON format") upload_parser.add_argument("--url", "-u", default="", help="Server URL") upload_parser.add_argument("--api-key", help="API key for server (avoiding providing by command line)") + upload_parser.add_argument("--force-upload", "-f", action="store_true", help="Force upload for all files") + upload_parser.add_argument("--dry-run", action="store_true", help="Only print files to be uploaded") + + # Subcommand: reload statement + reload_parser = subparsers.add_parser("reload", help="Reset statement") + reload_parser.add_argument("--url", "-u", help="Server URL") + reload_parser.add_argument("--api-key", help="API key for server (avoiding providing by command line)") + reload_parser.add_argument("--param", help="Specify reload parameters by JSON") args = parser.parse_args() logging.basicConfig(format='%(message)s', level=getattr( @@ -60,6 +70,8 @@ def run(self): if args.command == "get-key": self.run_get_key(args) + elif args.command == "reload": + self.run_reload(args) else: self.run_upload(args) @@ -79,7 +91,7 @@ def run_get_key(self, args: argparse.Namespace): base_url = f"{u.scheme}://{u.netloc}" path = urlunparse(('', '', u.path, u.params, u.query, u.fragment)) - login_url = f"{base_url}/login{path}" + 
login_url = f"{base_url}/login/{path}" self.logger.info("Getting API key from %s", login_url) headers = {"X-User": user_name} # Only in development, when missing authenticator from between resp = requests.get(login_url, timeout=self.timeout, auth=(user_name, user_password), headers=headers, @@ -95,17 +107,50 @@ def run_get_key(self, args: argparse.Namespace): else: print(f"{api_key}") + def run_reload(self, args: argparse.Namespace): + """Run reload subcommand""" + url = args.url.strip() + if not url: + raise ValueError("Missing server URL") + u = urlparse(url) + path = urlunparse(('', '', u.path, u.params, u.query, u.fragment)) + + api_url = self.resolve_api_url(url) + use_url = f"{api_url}/reload{path}" + + params = json.loads(args.param) if args.param else {} + + self.auth_token = args.api_key or self.auth_token + if not self.auth_token: + raise ValueError("Missing API key") + if not self.secure: + self.logger.warning("Disabling TLS verification for server connection") + urllib3.disable_warnings() + self.logger.info("Reset %s...", use_url) + headers = { + "Content-Type": "application/json", + "X-Authorization": self.auth_token, + } + resp = requests.post(use_url, headers=headers, json=params, + timeout=self.timeout, verify=self.secure) + resp.raise_for_status() + self.logger.info("statement reload") + def run_upload(self, args: argparse.Namespace): """Run upload subcommand""" meta_json = json.loads(args.meta) if args.meta else {} + self.dry_run = args.dry_run or False url = args.url.strip() if not url: raise ValueError("Missing server URL") + + self.force_upload = args.force_upload or False + if self.dry_run: + self.logger.warning("DRY RUN, no files will be uploaded") + self.auth_token = args.api_key or self.auth_token if not self.secure: self.logger.warning("Disabling TLS verification for server connection") urllib3.disable_warnings() - if args.api_key: - self.auth_token = args.api_key.strip() if args.read: read_file = pathlib.Path(args.read) if 
read_file.is_dir(): @@ -132,6 +177,8 @@ def run_upload(self, args: argparse.Namespace): def upload_file(self, url: str, file_data: BinaryIO, meta_json: Dict): """Upload a file""" # create a temporary zip file + if self.dry_run: + return with tempfile.NamedTemporaryFile(suffix='.zip') as temp_file: self.create_zipfile(file_data, meta_json, temp_file) self.upload_file_data(url, temp_file) @@ -146,15 +193,30 @@ def upload_directory(self, url: str, path: pathlib.Path): meta_json = json.load(f) file_load_order = meta_json.get("file_order", []) if file_load_order: - # sort subdirectories based on file_load_order + # sort files by file_load_order files = FileMetaInfo.sort_load_order(files, file_load_order) + else: + meta_json = None # not {} + + # files to upload + to_upload = self.filter_data_files(files) + + if meta_json is not None and to_upload: + to_upload.insert(0, meta_file) # upload also meta file, make it first + self.logger.info("%s", meta_file.as_posix()) # meta file exists -> upload files from here - self.logger.info("Uploading directory %s", path.as_posix()) + self.logger.info("%s/", path.as_posix()) with tempfile.NamedTemporaryFile(suffix='.zip') as temp_file: - self.copy_to_zipfile(files, temp_file) + self.copy_to_zipfile(to_upload, temp_file) self.upload_file_data(url, temp_file) + # mark files as uploaded + if not self.dry_run: + for f in to_upload: + if f != meta_file: + self.mark_uploaded(f) + # visit subdirectories for subdir in files: if subdir.is_dir(): @@ -188,7 +250,7 @@ def copy_to_zipfile(self, files: List[pathlib.Path], temp_file: BinaryIO) -> boo self.logger.warning("File too large: %s (%d > 1024 M)", file.as_posix(), file_size_mb) continue # write content - self.logger.info("Adding %s", file.as_posix()) + self.logger.info("%s", file.as_posix()) zip_info = zipfile.ZipInfo(file.name) with file.open("rb") as file_data: with zip_file.open(zip_info, "w") as of: @@ -200,7 +262,24 @@ def copy_to_zipfile(self, files: List[pathlib.Path], temp_file: 
BinaryIO) -> boo def upload_file_data(self, url: str, temp_file: BinaryIO): """Upload content zip file into the server""" + if self.dry_run: + return + api_url = self.resolve_api_url(url) + upload_url = f"{api_url}/batch" + headers = { + "Content-Type": "application/zip", + } + if self.auth_token: + headers["X-Authorization"] = self.auth_token + multipart = {"file": temp_file} + resp = requests.post(upload_url, files=multipart, headers=headers, timeout=self.timeout, verify=self.secure) + resp.raise_for_status() + + def resolve_api_url(self, url: str) -> str: + """Query server for API URL""" # split URL into host and statement + if url.endswith("/"): + raise ValueError("URL should not end with /") u = urlparse(url) base_url = f"{u.scheme}://{u.netloc}" path = urlunparse(('', '', u.path, u.params, u.query, u.fragment)) @@ -216,17 +295,34 @@ def upload_file_data(self, url: str, temp_file: BinaryIO): resp = requests.get(query_url, headers=headers, timeout=self.timeout, verify=self.secure) resp.raise_for_status() api_proxy = resp.json().get("api_proxy") + return f"{base_url}/api1" if not api_proxy else f"{base_url}/proxy/{api_proxy}/api1" - upload_url = f"{base_url}/api1/batch" if not api_proxy else f"{base_url}/proxy/{api_proxy}/api1/batch" - headers = { - "Content-Type": "application/zip", - } - if self.auth_token: - headers["X-Authorization"] = self.auth_token - multipart = {"file": temp_file} - resp = requests.post(upload_url, files=multipart, headers=headers, timeout=self.timeout, verify=self.secure) - resp.raise_for_status() - return resp + def filter_data_files(self, files: List[pathlib.Path]) -> List[pathlib.Path]: + """Filter data files""" + r = [] + for f in files: + if not f.is_file() or f.name == "00meta.json": + continue + if f.suffix.lower() in {".meta", ".bak", ".tmp", ".temp"}: + continue + if f.name.startswith(".") or f.name.endswith("~"): + continue + if not self.force_upload and self.is_uploaded(f): + continue + r.append(f) + return r + + @classmethod 
+ def is_uploaded(cls, path: pathlib.Path) -> bool: + """Check if file has been uploaded""" + p = path.parent / f".{path.name}.up" + return p.exists() + + @classmethod + def mark_uploaded(cls, path: pathlib.Path): + """Mark file as uploaded""" + p = path.parent / f".{path.name}.up" + p.touch() def main(): """Main entry point""" diff --git a/tcsfw/entity_database.py b/tcsfw/entity_database.py index 20be5f8..3563e4d 100644 --- a/tcsfw/entity_database.py +++ b/tcsfw/entity_database.py @@ -37,6 +37,9 @@ def put_event(self, event: Event): """Store an event""" raise NotImplementedError() + def clear_database(self): + """Clear the database, from the disk""" + class InMemoryDatabase(EntityDatabase): """Store and retrieve events, later entities, etc.""" diff --git a/tcsfw/http_server.py b/tcsfw/http_server.py index 08553d0..d367d81 100644 --- a/tcsfw/http_server.py +++ b/tcsfw/http_server.py @@ -5,6 +5,7 @@ import json import logging import pathlib +import sys import tempfile import traceback from typing import BinaryIO, Dict, Optional, Tuple, List @@ -76,6 +77,7 @@ async def start_server(self): web.get('/api1/ping', self.handle_ping), # ping for health check web.get('/api1/proxy/{tail:.+}', self.handle_login), # query proxy configuration web.get('/api1/{tail:.+}', self.handle_http), + web.post('/api1/reload/{tail:.+}', self.handle_reload), # reload, kill the process web.post('/api1/{tail:.+}', self.handle_http), ]) rr = web.AppRunner(app) @@ -115,7 +117,7 @@ async def handle_ping(self, _request: web.Request): """Handle ping request""" return web.Response(text="{}") - async def handle_http(self, request): + async def handle_http(self, request: web.Request): """Handle normal HTTP GET or POST request""" try: self.check_permission(request) @@ -200,7 +202,7 @@ async def api_post_zip(self, api_request: APIRequest, request): res = self.api.api_post_file(api_request, pathlib.Path(temp_dir)) return res - async def handle_ws(self, request): + async def handle_ws(self, request: 
web.Request): """Handle websocket HTTP request""" assert request.path_qs.startswith("/api1/ws/") req = APIRequest.parse(request.path_qs[9:]) @@ -245,7 +247,7 @@ async def receive_loop(): self.channels.remove(channel) return ws - async def handle_login(self, request): + async def handle_login(self, request: web.Request): """Handle login or proxy query, which is launcher job. This should only be used in development.""" req = APIRequest.parse(request.path_qs) try: @@ -258,6 +260,26 @@ async def handle_login(self, request): traceback.print_exc() return web.Response(status=500) + async def handle_reload(self, request: web.Request): + """Handle reload request""" + self.check_permission(request) + req = APIRequest.parse(request.path_qs) + data = await request.content.read() if request.content else b"" + res = self.api.api_exit(req, data) + exit_delay = int(res.get("exit_delay", 1000)) + res = {} # do not return the parameters + + # reload means exiting this process, delay it for response to be sent + def do_exit(): + # return code 0 for successful exit + sys.exit(0) # pylint: disable=consider-using-sys-exit + + if exit_delay > 0: + self.loop.call_later(exit_delay / 1000, do_exit) + else: + do_exit() # no response will be sent + return web.Response(text=json.dumps(res)) + def dump_model(self, channel: WebsocketChannel): """Dump the whole model into channel""" if not channel.subscribed: diff --git a/tcsfw/launcher.py b/tcsfw/launcher.py index ec76ab1..76872c9 100644 --- a/tcsfw/launcher.py +++ b/tcsfw/launcher.py @@ -1,6 +1,7 @@ """Lauch model given from command-line""" import asyncio +from asyncio.subprocess import Process import logging import os import argparse @@ -10,10 +11,14 @@ import subprocess import sys import traceback -from typing import Dict, Set, Tuple +from typing import Dict, Optional, Set, Tuple import aiofiles from aiohttp import web +import aiohttp + +from watchdog.observers import Observer +from watchdog.events import FileSystemEvent, 
FileSystemEventHandler from tcsfw.client_api import APIRequest from tcsfw.command_basics import get_authorization @@ -29,7 +34,9 @@ def __init__(self): parser.add_argument("-l", "--log", dest="log_level", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Set the logging level", default=None) parser.add_argument("--no-db", action="store_true", - help="Do not use DB storage", default=None) + help="Do not use DB storage") + parser.add_argument("--watch", action="store_true", + help="Watch for statement file changes") args = parser.parse_args() self.logger = logging.getLogger("launcher") @@ -42,6 +49,10 @@ def __init__(self): self.api_keys: Dict[str, str] = {} self.api_key_reverse: Dict[str, str] = {} + self.change_observer: Optional[FileChangeObserver] = None + if args.watch: + self.change_observer = FileChangeObserver(self) + self.db_base_dir = None if args.no_db else pathlib.Path("app-dbs") # create sqlite DBs here self.host = None @@ -78,10 +89,10 @@ async def handle_login(self, request: web.Request): if request.method != "GET": raise NotImplementedError("Unexpected method") use_api_key = False - if request.path.startswith("/login/statement/"): - app = request.path[17:] - elif request.path.startswith("/api1/proxy/statement/"): - app = request.path[22:] + if request.path.startswith("/login/"): + app = request.path[7:] + elif request.path.startswith("/api1/proxy/"): + app = request.path[12:] use_api_key = True else: raise FileNotFoundError("Unexpected statement path") @@ -155,7 +166,8 @@ async def run_process(self, key: Tuple[str, str], app: str, api_key: str) -> int self.connected[key] = client_port python_app = f"{app}.py" - if not pathlib.Path(python_app).exists(): + app_file = pathlib.Path(python_app) + if not app_file.exists(): raise FileNotFoundError(f"App not found: {python_app}") args = [sys.executable, python_app, "--http-server", f"{client_port}"] @@ -180,9 +192,12 @@ async def wait_process(): await process.wait() await stdout_task await 
stderr_task + # free port, but leave keys in place self.clients.remove(client_port) self.connected.pop(key, None) - self.logger.info("Exit code %s from %s at port %s", process.returncode, key_str, client_port) + self.logger.info("Exit code %s from %s at port %d", process.returncode, key_str, client_port) + if self.change_observer: + self.change_observer.update_watch_list(process, remove=app_file.parent) # remove log files os.remove(stdout_file) os.remove(stderr_file) @@ -190,6 +205,25 @@ async def wait_process(): asyncio.create_task(wait_process()) self.logger.info("Launched %s at port %s", key_str, client_port) + + # wait for the process web server to start + ping_url = f"http://localhost:{client_port}/api1/ping" + self.logger.info("Pinging %s...", ping_url) + while True: + if client_port not in self.clients: + self.logger.info("Process failed/killed without starting") + raise FileNotFoundError("Process failed to start") + try: + async with aiohttp.ClientSession() as session: + async with session.get(ping_url) as resp: + if resp.status == 200: + break + except aiohttp.ClientConnectorError: + pass + await asyncio.sleep(0.1) + self.logger.info("...ping OK") + if self.change_observer: + self.change_observer.update_watch_list(process, add=app_file.parent) return client_port async def save_stream_to_file(self, stream, file_path): @@ -209,5 +243,42 @@ def generate_api_key(self, user_name: str) -> str: self.api_key_reverse[key] = user_name return key + +class FileChangeObserver(FileSystemEventHandler): + """Observe file changes""" + def __init__(self, laucher: Launcher): + self.launcher = laucher + self.watch_list: Dict[pathlib.Path, Process] = {} + self.observer = Observer() + self.observer.start() + + def update_watch_list(self, process: Process, + add: Optional[pathlib.Path] = None, remove: Optional[pathlib.Path] = None): + """Update watch list""" + if add: + path = add.as_posix() + self.launcher.logger.info("Adding watch for %s", path) + self.observer.schedule(self, 
path, recursive=True) + self.watch_list[path] = process + if remove: + path = remove.as_posix() + self.launcher.logger.info("Removing watch for %s", path) + self.watch_list.pop(path, None) + self.observer.unschedule_all() + for p in self.watch_list: + self.observer.schedule(self, p, recursive=True) + + def on_modified(self, event: FileSystemEvent) -> None: + """File modified, reload relevant process, if any""" + proc = self.watch_list.get(event.src_path) + if proc: + del self.watch_list[event.src_path] + self.launcher.logger.info("File modified: %s, reloading process", event.src_path) + try: + proc.kill() + except ProcessLookupError: + self.launcher.logger.info("Process kill failed") + + if __name__ == "__main__": Launcher() diff --git a/tcsfw/registry.py b/tcsfw/registry.py index b0d1a45..3817c40 100644 --- a/tcsfw/registry.py +++ b/tcsfw/registry.py @@ -65,6 +65,11 @@ def reset(self, evidence_filter: Dict[EvidenceSource, bool] = None, enable_all=F self.logging.reset() return self + def clear_database(self) -> Self: + """Clear the database, from the disk""" + self.database.clear_database() + return self + def apply_all_events(self) -> Self: """Apply all stored events, after reset""" while True: diff --git a/tcsfw/sql_database.py b/tcsfw/sql_database.py index bde978e..c98e10b 100644 --- a/tcsfw/sql_database.py +++ b/tcsfw/sql_database.py @@ -1,7 +1,10 @@ """SQL database by SQLAlchemy""" import json +import os +import pathlib from typing import Any, Iterator, List, Optional, Dict, Tuple, Set +from urllib.parse import urlparse from sqlalchemy import Boolean, Column, Integer, String, create_engine, delete, select from sqlalchemy.ext.declarative import declarative_base @@ -50,6 +53,7 @@ class SQLDatabase(EntityDatabase, ModelListener): """Use SQL database for storage""" def __init__(self, db_uri: str): super().__init__() + self.db_uri = db_uri self.engine = create_engine(db_uri) Base.metadata.create_all(self.engine) self.db_conn = self.engine.connect() @@ -70,6 +74,16 
@@ def __init__(self, db_uri: str): self.pending_batch = [] self.pending_source_ids = set() + def clear_database(self): + # check if DB is a local file + self.engine.dispose() + u = urlparse(self.db_uri) + if u.scheme.startswith("sqlite") and u.path: + path = pathlib.Path(u.path[1:]) if u.path.startswith("/") else pathlib.Path(u.path) + self.logger.info("Deleting DB file %s if it exists", path) + if path.exists(): + os.remove(path) + def _fill_cache(self): """Fill entity cache from database""" with Session(self.engine) as ses: