Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions invenio_vcs/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Configuration for the VCS module."""

from typing import TYPE_CHECKING

from flask import current_app

if TYPE_CHECKING:
from invenio_vcs.providers import RepositoryServiceProviderFactory

VCS_PROVIDERS = []

VCS_RELEASE_CLASS = "invenio_vcs.service:VCSRelease"
"""VCSRelease class to be used for release handling."""

VCS_TEMPLATE_INDEX = "invenio_vcs/settings/index.html"
"""Repositories list template."""

VCS_TEMPLATE_VIEW = "invenio_vcs/settings/view.html"
"""Repository detail view template."""

VCS_ERROR_HANDLERS = None
"""Definition of the way specific exceptions are handled."""

VCS_MAX_CONTRIBUTORS_NUMBER = 30
"""Max number of contributors of a release to be retrieved from vcs."""

VCS_CITATION_FILE = None
"""Citation file name."""

VCS_CITATION_METADATA_SCHEMA = None
"""Citation metadata schema."""

VCS_ZIPBALL_TIMEOUT = 300
"""Timeout for the zipball download, in seconds."""


def get_provider_list(app=current_app) -> list["RepositoryServiceProviderFactory"]:
"""Get a list of configured VCS provider factories."""
return app.config["VCS_PROVIDERS"]


def get_provider_by_id(id: str) -> "RepositoryServiceProviderFactory":
"""Get a specific VCS provider by its registered ID."""
providers = get_provider_list()
for provider in providers:
if id == provider.id:
return provider
raise Exception(f"VCS provider with ID {id} not registered")
92 changes: 92 additions & 0 deletions invenio_vcs/oauth/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Implement OAuth client handler."""

import typing

from flask import current_app, redirect, url_for
from flask_login import current_user
from invenio_db import db
from invenio_oauth2server.models import Token as ProviderToken
from invenio_oauthclient import oauth_unlink_external_id

from invenio_vcs.service import VCSService
from invenio_vcs.tasks import disconnect_provider

if typing.TYPE_CHECKING:
from invenio_vcs.providers import RepositoryServiceProviderFactory


class OAuthHandlers:
"""Provider-agnostic handler overrides to ensure VCS events are executed at certain points throughout the OAuth lifecyle."""

def __init__(self, provider_factory: "RepositoryServiceProviderFactory") -> None:
"""Instance are non-user-specific."""
self.provider_factory = provider_factory

def account_setup_handler(self, remote, token, resp):
"""Perform post initialization."""
try:
svc = VCSService(
self.provider_factory.for_user(token.remote_account.user_id)
)
svc.init_account()
svc.sync()
db.session.commit()
except Exception as e:
current_app.logger.warning(str(e), exc_info=True)

def disconnect_handler(self, remote):
"""Disconnect callback handler for the provider."""
# User must be authenticated
if not current_user.is_authenticated:
return current_app.login_manager.unauthorized()

external_method = self.provider_factory.id
external_ids = [
i.id
for i in current_user.external_identifiers
if i.method == external_method
]
if external_ids:
oauth_unlink_external_id(dict(id=external_ids[0], method=external_method))

svc = VCSService(self.provider_factory.for_user(current_user.id))
token = svc.provider.remote_token

if token:
extra_data = token.remote_account.extra_data

# Delete the token that we issued for vcs to deliver webhooks
webhook_token_id = extra_data.get("tokens", {}).get("webhook")
ProviderToken.query.filter_by(id=webhook_token_id).delete()

# Disable every vcs webhooks from our side
repos = svc.user_enabled_repositories.all()
repos_with_hooks = []
for repo in repos:
if repo.hook is not None:
repos_with_hooks.append((repo.provider_id, repo.hook))
svc.mark_repo_disabled(repo.provider_id)

# Commit any changes before running the ascynhronous task
db.session.commit()

# Send Celery task for webhooks removal and token revocation
disconnect_provider.delay(
self.provider_factory.id,
current_user.id,
token.access_token,
repos_with_hooks,
)

# Delete the RemoteAccount (along with the associated RemoteToken)
token.remote_account.delete()
db.session.commit()

return redirect(url_for("invenio_oauthclient_settings.index"))
104 changes: 104 additions & 0 deletions invenio_vcs/receivers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Task for managing vcs integration."""

from invenio_db import db
from invenio_webhooks.models import Receiver

from invenio_vcs.config import get_provider_by_id
from invenio_vcs.models import Release, ReleaseStatus, Repository
from invenio_vcs.tasks import process_release

from .errors import (
InvalidSenderError,
ReleaseAlreadyReceivedError,
RepositoryAccessError,
RepositoryDisabledError,
RepositoryNotFoundError,
)


class VCSReceiver(Receiver):
"""Handle incoming notification from vcs on a new release."""

def __init__(self, receiver_id):
"""Constructor."""
super().__init__(receiver_id)
self.provider_factory = get_provider_by_id(receiver_id)

def run(self, event):
"""Process an event.

.. note::

We should only do basic server side operation here, since we send
the rest of the processing to a Celery task which will be mainly
accessing the vcs API.
"""
self._handle_event(event)

def _handle_event(self, event):
"""Handles an incoming vcs event."""
is_create_release_event = self.provider_factory.webhook_is_create_release_event(
event.payload
)

if is_create_release_event:
self._handle_create_release(event)

def _handle_create_release(self, event):
"""Creates a release in invenio."""
try:
generic_release, generic_repo = (
self.provider_factory.webhook_event_to_generic(event.payload)
)

# Check if the release already exists
existing_release = Release.query.filter_by(
provider_id=generic_release.id,
).first()

if existing_release:
raise ReleaseAlreadyReceivedError(release=existing_release)

# Create the Release
repo = Repository.get(
self.provider_factory.id,
provider_id=generic_repo.id,
full_name=generic_repo.full_name,
)
if not repo:
raise RepositoryNotFoundError(generic_repo.full_name)

if repo.enabled:
release = Release(
provider_id=generic_release.id,
provider=self.provider_factory.id,
tag=generic_release.tag_name,
repository=repo,
event=event,
status=ReleaseStatus.RECEIVED,
)
db.session.add(release)
else:
raise RepositoryDisabledError(repo=repo)

# Process the release
# Since 'process_release' is executed asynchronously, we commit the current state of session
db.session.commit()
process_release.delay(self.provider_factory.id, release.provider_id)

except (ReleaseAlreadyReceivedError, RepositoryDisabledError) as e:
event.response_code = 409
event.response = dict(message=str(e), status=409)
except (RepositoryAccessError, InvalidSenderError) as e:
event.response_code = 403
event.response = dict(message=str(e), status=403)
except RepositoryNotFoundError as e:
event.response_code = 404
event.response = dict(message=str(e), status=404)
Loading
Loading