diff --git a/alist_sync/common.py b/alist_sync/common.py index 377db3f..579c3be 100644 --- a/alist_sync/common.py +++ b/alist_sync/common.py @@ -46,7 +46,7 @@ def clear_cache(): clear_path(cache_dir) -def sha1(s): +def sha1(s) -> str: return hashlib.sha1(str(s).encode()).hexdigest() diff --git a/alist_sync/worker.py b/alist_sync/worker.py index 2007d0f..bd1bb54 100644 --- a/alist_sync/worker.py +++ b/alist_sync/worker.py @@ -1,7 +1,21 @@ -from typing import Literal, Any - -from pydantic import BaseModel -from alist_sdk.path_lib import AlistPath +import atexit +import datetime +import logging +from pathlib import Path +from typing import Literal, Any, Annotated + +from pydantic import ( + BaseModel, + computed_field, + Field, + AfterValidator, + PlainSerializer, + GetCoreSchemaHandler, +) +from pydantic_core import core_schema +from pymongo.collection import Collection +from pymongo.database import Database +from alist_sdk.path_lib import AlistPath as _AlistPath from alist_sync.config import cache_dir from alist_sync.common import sha1 @@ -15,25 +29,82 @@ "downloading", "uploading", "copied", + "done", +] + +logger = logging.getLogger("alist-sync.worker") + + +class AlistPath(_AlistPath): + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, + handler: GetCoreSchemaHandler, + ): + return core_schema.no_info_after_validator_function(cls, handler(source_type)) + + +A_AlistPath = Annotated[ + AlistPath | str, + AfterValidator(lambda x: x if isinstance(x, AlistPath) else AlistPath(x)), + PlainSerializer(lambda x: x.as_uri(), return_type=str), ] class Worker(BaseModel): + owner: str + created_at: datetime.datetime = datetime.datetime.now() type: WorkerType need_backup: bool - source_path: AlistPath - target_path: AlistPath - backup_dir: AlistPath | None = None - status: WorkerStatus + backup_dir: A_AlistPath | None = None + + source_path: A_AlistPath + target_path: A_AlistPath | None = None + status: WorkerStatus = "init" + error_info: BaseException | None = None + + # 私有属性 + workers: "Workers | None" = Field(None, exclude=True) + collection: Collection | None = Field(None, exclude=True) + + model_config = { + "arbitrary_types_allowed": True, + "excludes": {"workers", "collection", "tmp_file"}, + } + + @computed_field(return_type=str) + @property + def _id(self) -> str: + return sha1(f"{self.type}{self.source_path}{self.created_at}") + + @property + def tmp_file(self) -> Path: + return cache_dir.joinpath(f"download_tmp_{sha1(self.source_path)}") + + def update(self, *field: Any): + if self.status == "done" and self.workers is not None: + return self.workers.del_worker(self._id) + return self.update_mongo(*field) + + def update_mongo(self, *field): + """""" + + if field == (): + data = self.model_dump(mode="json") + else: + data = {k: self.__getattr__(k) for k in field} + + logger.debug("更新Worker: %s", data) + return self.collection.update_one( + {"_id": self._id}, + {"$set": data}, + True if field == () else False, + ) def __del__(self): self.tmp_file.unlink(missing_ok=True) - def __init__(self, **data: Any): - super().__init__(**data) - - self.tmp_file = cache_dir.joinpath(f"download_tmp_{sha1(self.source_path)}") - def backup(self): """备份""" if self.backup_dir is None: @@ -56,9 +127,67 @@ def run(self): """启动Worker""" if self.need_backup: self.backup() + if self.type == "copy": self.copy_type() elif self.type == "delete": self.delete_type() +class Workers(BaseModel): + workers: list[Worker] = [] + mongodb: Database + + model_config = {"arbitrary_types_allowed": True} + + def __init__(self, **data: Any): + super().__init__(**data) + + atexit.register(self.__del__) + + def __del__(self): + for i in cache_dir.iterdir(): + if i.name.startswith("download_tmp_"): + i.unlink(missing_ok=True) + + def load_from_mongo(self): + """从MongoDB加载Worker""" + for i in self.mongodb.workers.find(): + self.workers.append(Worker(**i)) + + def add_worker(self, worker: Worker): + self.workers.append(worker) + + def del_worker(self, _id: str): + """删除Worker""" + pass + + def run(self): + for worker in self.workers: + worker.run() + + +if __name__ == "__main__": + import os + from pymongo import MongoClient + from pymongo.server_api import ServerApi + + logging.basicConfig(level=logging.DEBUG) + + uri = os.environ["MONGODB_URI"] + + client = MongoClient(uri, server_api=ServerApi("1")) + # w = Worker( + # owner="admin", + # type="delete", + # need_backup=True, + # backer_dir=AlistPath("http://localhost:5244/local/.history"), + # source_path="http://localhost:5244/local/test.txt", + # collection=client.get_default_database().get_collection("workers"), + # ) + # print(w.update()) + + ws = Workers(mongodb=client.get_default_database()) + ws.load_from_mongo() + for w in ws.workers: + print(w) diff --git a/tests/test_worker.py b/tests/test_worker.py new file mode 100644 index 0000000..af4db78 --- /dev/null +++ b/tests/test_worker.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@File Name : test_worker.py +@Author : LeeCQ +@Date-Time : 2024/2/24 18:30 +""" + +import sys +import pytest + +from alist_sdk.path_lib import PureAlistPath, AlistPath, login_server + +from alist_sync.worker import Worker, Workers + +# 如果Python版本是3.12跳过模块 +if sys.version_info >= (3, 12): + pytest.skip("Skip this module on Python 3.12", allow_module_level=True) + + +def test_worker(): + from pymongo import MongoClient + from pymongo.server_api import ServerApi + + uri = ( + "mongodb+srv://alist-sync:alist-sync-p@a1.guggt7c.mongodb.net/alist_sync?" + "retryWrites=true&w=majority&appName=A1" + ) + + client = MongoClient(uri, server_api=ServerApi("1")) + w = Worker( + owner="admin", + type="delete", + need_backup=True, + backer_dir=AlistPath("http://localhost:5244/local/.history"), + source_path="http://localhost:5244/local/test.txt", + collection=client.get_default_database().get_collection("workers"), + ) + print(w.update()) + + +def test_workers(): + import os + from pymongo import MongoClient + from pymongo.server_api import ServerApi + + uri = os.environ["MONGODB_URI"] + + client = MongoClient(uri, server_api=ServerApi("1")) + ws = Workers() + ws.load_from_mongo()