diff --git a/invenio.cfg b/invenio.cfg index d211678a..4af9daa8 100644 --- a/invenio.cfg +++ b/invenio.cfg @@ -76,7 +76,7 @@ from zenodo_rdm.github.schemas import CitationMetadataSchema from zenodo_rdm.legacy.resources import record_serializers from zenodo_rdm.metrics.config import METRICS_CACHE_UPDATE_INTERVAL from zenodo_rdm.moderation.errors import UserBlockedException -from zenodo_rdm.moderation.handlers import CommunityScoreHandler, RecordScoreHandler +from zenodo_rdm.moderation.handlers import CommunityModerationHandler, RecordModerationHandler from zenodo_rdm.openaire.records.components import OpenAIREComponent from zenodo_rdm.permissions import ( ZenodoCommunityPermissionPolicy, @@ -817,11 +817,11 @@ RDM_RECORDS_SERVICE_COMPONENTS = DefaultRecordsComponents + [ """Addd OpenAIRE component to records service.""" RDM_CONTENT_MODERATION_HANDLERS = [ - RecordScoreHandler(), + RecordModerationHandler(), ] """Records content moderation handlers.""" RDM_COMMUNITY_CONTENT_MODERATION_HANDLERS = [ - CommunityScoreHandler(), + CommunityModerationHandler(), ] """Community content moderation handlers.""" @@ -1062,3 +1062,5 @@ COMMUNITIES_SHOW_BROWSE_MENU_ENTRY = True JOBS_ADMINISTRATION_ENABLED = True """Enable Jobs administration view.""" + +SPAM_DETECTOR_MODEL="spam-scikit:1.0.0" diff --git a/site/setup.cfg b/site/setup.cfg index 7bd0290c..17795307 100644 --- a/site/setup.cfg +++ b/site/setup.cfg @@ -48,12 +48,14 @@ invenio_base.apps = zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE zenodo_rdm_stats = zenodo_rdm.stats.ext:ZenodoStats + zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML invenio_base.api_apps = zenodo_rdm_legacy = zenodo_rdm.legacy.ext:ZenodoLegacy profiler = zenodo_rdm.profiler:Profiler zenodo_rdm_metrics = zenodo_rdm.metrics.ext:ZenodoMetrics zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE + zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML invenio_base.api_blueprints = zenodo_rdm_legacy = zenodo_rdm.legacy.views:blueprint zenodo_rdm_legacy_records = zenodo_rdm.legacy.views:create_legacy_records_bp diff --git a/site/zenodo_rdm/ml/__init__.py b/site/zenodo_rdm/ml/__init__.py new file mode 100644 index 00000000..4899db37 --- /dev/null +++ b/site/zenodo_rdm/ml/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Zenodo-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Machine learning module.""" diff --git a/site/zenodo_rdm/ml/base.py b/site/zenodo_rdm/ml/base.py new file mode 100644 index 00000000..a85eb6a1 --- /dev/null +++ b/site/zenodo_rdm/ml/base.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Zenodo-RDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Base class for ML models.""" + + +class MLModel: + """Base class for ML models.""" + + def __init__(self, version=None, **kwargs): + """Constructor.""" + self.version = version + + def process(self, data, preprocess=None, postprocess=None, raise_exc=True): + """Pipeline function to call pre/post process with predict.""" + try: + preprocessor = preprocess or self.preprocess + postprocessor = postprocess or self.postprocess + + preprocessed = preprocessor(data) + prediction = self.predict(preprocessed) + return postprocessor(prediction) + except Exception as e: + if raise_exc: + raise e + return None + + def predict(self, data): + """Predict method to be implemented by subclass.""" + raise NotImplementedError() + + def preprocess(self, data): + """Preprocess data.""" + return data + + def postprocess(self, data): + """Postprocess data.""" + return data diff --git a/site/zenodo_rdm/ml/config.py b/site/zenodo_rdm/ml/config.py new file mode 100644 index 00000000..e1cacf94 --- /dev/null +++ b/site/zenodo_rdm/ml/config.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. + +"""Machine learning config.""" + +from .models import SpamDetectorScikit + +ML_MODELS = { + "spam_scikit": SpamDetectorScikit, +} +"""Machine learning models.""" + +# NOTE Model URL and model host need to be formattable strings for the model name. +ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME" +ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE" +ML_KUBEFLOW_TOKEN = "CHANGE SECRET" +"""Kubeflow connection config.""" diff --git a/site/zenodo_rdm/ml/ext.py b/site/zenodo_rdm/ml/ext.py new file mode 100644 index 00000000..18c72520 --- /dev/null +++ b/site/zenodo_rdm/ml/ext.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. + +"""ZenodoRDM machine learning module.""" + +from flask import current_app + +from . import config + + +class ZenodoML: + """Zenodo machine learning extension.""" + + def __init__(self, app=None): + """Extension initialization.""" + if app: + self.init_app(app) + + @staticmethod + def init_config(app): + """Initialize configuration.""" + for k in dir(config): + if k.startswith("ML_"): + app.config.setdefault(k, getattr(config, k)) + + def init_app(self, app): + """Flask application initialization.""" + self.init_config(app) + app.extensions["zenodo-ml"] = self + + def _parse_model_name_version(self, model): + """Parse model name and version.""" + vals = model.rsplit(":") + version = vals[1] if len(vals) > 1 else None + return vals[0], version + + def models(self, model, **kwargs): + """Return model based on model name.""" + models = current_app.config.get("ML_MODELS", {}) + model_name, version = self._parse_model_name_version(model) + + if model_name not in models: + raise ValueError("Model not found/registered.") + + return models[model_name](version=version, **kwargs) diff --git a/site/zenodo_rdm/ml/models.py b/site/zenodo_rdm/ml/models.py new file mode 100644 index 00000000..bde9bdc5 --- /dev/null +++ b/site/zenodo_rdm/ml/models.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Model definitions.""" + + +import json +import string + +import requests +from bs4 import BeautifulSoup +from flask import current_app + +from .base import MLModel + + +class SpamDetectorScikit(MLModel): + """Spam detection model based on Sklearn.""" + + MODEL_NAME = "sklearn-spam" + MAX_WORDS = 4000 + + def __init__(self, version, **kwargs): + """Constructor. Makes version required.""" + super().__init__(version, **kwargs) + + def preprocess(self, data): + """Preprocess data. + + Parse HTML, remove punctuation and truncate to max chars. + """ + text = BeautifulSoup(data, "html.parser").get_text() + trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation)) + parts = text.translate(trans_table).lower().strip().split(" ") + if len(parts) >= self.MAX_WORDS: + parts = parts[: self.MAX_WORDS] + return " ".join(parts) + + def postprocess(self, data): + """Postprocess data. + + Gives spam and ham probability. + """ + result = { + "spam": data["outputs"][0]["data"][0], + "ham": data["outputs"][0]["data"][1], + } + return result + + def _send_request_kubeflow(self, data): + """Send predict request to Kubeflow.""" + payload = { + "inputs": [ + { + "name": "input-0", + "shape": [1], + "datatype": "BYTES", + "data": [f"{data}"], + } + ] + } + model_ref = self.MODEL_NAME + "-" + self.version + url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref) + host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref) + access_token = current_app.config.get("ML_KUBEFLOW_TOKEN") + r = requests.post( + url, + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "Host": host, + }, + json=payload, + ) + if r.status_code != 200: + raise requests.RequestException("Prediction was not successful.", request=r) + return json.loads(r.text) + + def predict(self, data): + """Get prediction from model.""" + prediction = self._send_request_kubeflow(data) + return prediction diff --git a/site/zenodo_rdm/ml/proxies.py b/site/zenodo_rdm/ml/proxies.py new file mode 100644 index 00000000..41596c21 --- /dev/null +++ b/site/zenodo_rdm/ml/proxies.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# ZenodoRDM is free software; you can redistribute it and/or modify +# it under the terms of the MIT License; see LICENSE file for more details. +"""Proxy objects for easier access to application objects.""" + +from flask import current_app +from werkzeug.local import LocalProxy + +current_ml_models = LocalProxy(lambda: current_app.extensions["zenodo-ml"]) diff --git a/site/zenodo_rdm/moderation/config.py b/site/zenodo_rdm/moderation/config.py index 4c2af3ae..309e0008 100644 --- a/site/zenodo_rdm/moderation/config.py +++ b/site/zenodo_rdm/moderation/config.py @@ -45,21 +45,21 @@ MODERATION_EXEMPT_USERS = [] """List of users exempt from moderation.""" -MODERATION_RECORD_SCORE_RULES = [ - verified_user_rule, - links_rule, - files_rule, - text_sanitization_rule, - match_query_rule, -] +MODERATION_RECORD_SCORE_RULES = { + "verified_user_rule": verified_user_rule, + "links_rule": links_rule, + "files_rule": files_rule, + "text_sanitization_rule": text_sanitization_rule, + "match_query_rule": match_query_rule, +} """Scoring rules for record moderation.""" -MODERATION_COMMUNITY_SCORE_RULES = [ - links_rule, - text_sanitization_rule, - verified_user_rule, - match_query_rule, -] +MODERATION_COMMUNITY_SCORE_RULES = { + "links_rule": links_rule, + "text_sanitization_rule": text_sanitization_rule, + "verified_user_rule": verified_user_rule, + "match_query_rule": match_query_rule, +} """Scoring rules for communtiy moderation.""" MODERATION_PERCOLATOR_INDEX_PREFIX = "moderation-queries" diff --git a/site/zenodo_rdm/moderation/handlers.py b/site/zenodo_rdm/moderation/handlers.py index 1118aed4..b460baf3 100644 --- a/site/zenodo_rdm/moderation/handlers.py +++ b/site/zenodo_rdm/moderation/handlers.py @@ -44,7 +44,7 @@ from .uow import ExceptionOp -class BaseScoreHandler: +class BaseModerationHandler: """Base handler to calculate moderation scores based on rules.""" def __init__(self, rules=None): @@ -56,7 +56,11 @@ def rules(self): """Get scoring rules.""" if isinstance(self._rules, str): return current_app.config[self._rules] - return self._rules or [] + return self._rules or {} + + def evaluate_result(self, params): + """Evaluate aggregate result based on params.""" + return sum(params.values()) @property def should_apply_actions(self): @@ -77,17 +81,19 @@ def run(self, identity, draft=None, record=None, user=None, uow=None): ) return - score = 0 - for rule in self.rules: - score += rule(identity, draft=draft, record=record) + results = {} + for name, rule in self.rules.items(): + results[name] = rule(identity, draft=draft, record=record) action_ctx = { "user_id": user.id, "record_pid": record.pid.pid_value, - "score": score, + "results": results, } current_app.logger.debug("Moderation score calculated", extra=action_ctx) - if score > current_scores.spam_threshold: + + evaluation = self.evaluate_result(results) + if evaluation > current_scores.spam_threshold: action_ctx["action"] = "block" if self.should_apply_actions: # If user is verified, we need to (re)open the moderation @@ -102,9 +108,11 @@ def run(self, identity, draft=None, record=None, user=None, uow=None): "Block moderation action triggered", extra=action_ctx, ) - elif score < current_scores.ham_threshold: + + elif evaluation < current_scores.ham_threshold: action_ctx["action"] = "approve" + # If the user is already verified, we don't need to verify again if user.verified: current_app.logger.debug( "User is verified, skipping moderation actions", @@ -187,7 +195,7 @@ def _block(self, user, uow, action_ctx): raise UserBlockedException() -class RecordScoreHandler(BaseHandler, BaseScoreHandler): +class RecordModerationHandler(BaseHandler, BaseModerationHandler): """Handler for calculating scores for records.""" def __init__(self): @@ -222,7 +230,9 @@ def publish(self, identity, draft=None, record=None, uow=None, **kwargs): self.run(identity, record=record, user=user, uow=uow) -class CommunityScoreHandler(community_moderation.BaseHandler, BaseScoreHandler): +class CommunityModerationHandler( + community_moderation.BaseHandler, BaseModerationHandler +): """Handler for calculating scores for communities.""" def __init__(self):