Skip to content

Commit

Permalink
feat(lab-3358): copy_project use a dedicated mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-olivier committed Feb 18, 2025
1 parent 966ab4f commit 46fa97c
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 174 deletions.
6 changes: 6 additions & 0 deletions src/kili/adapters/kili_api_gateway/project/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ def get_update_properties_in_project_mutation(fragment: str) -> str:
}
}
"""

# GraphQL mutation document: passes a required `CopyProjectData` variable to
# `copyProject` and aliases the mutation's result to `data` in the response.
GQL_COPY_PROJECT = """
mutation CopyProject($data: CopyProjectData!) {
data: copyProject(data: $data)
}
"""
20 changes: 19 additions & 1 deletion src/kili/adapters/kili_api_gateway/project/operations_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from .common import get_project
from .mappers import project_data_mapper, project_where_mapper
from .operations import (
GQL_COPY_PROJECT,
GQL_COUNT_PROJECTS,
GQL_CREATE_PROJECT,
get_update_properties_in_project_mutation,
)
from .types import ProjectDataKiliAPIGatewayInput
from .types import CopyProjectInput, ProjectDataKiliAPIGatewayInput


class ProjectOperationMixin(BaseOperationMixin):
Expand Down Expand Up @@ -104,3 +105,20 @@ def update_properties_in_project(
variables = {"data": data, "where": {"id": project_id}}
result = self.graphql_client.execute(mutation, variables)
return load_project_json_fields(result["data"], fields)

def copy_project(
    self,
    project_id: ProjectId,
    project_data: CopyProjectInput,
) -> ProjectId:
    """Copy a project by executing the `GQL_COPY_PROJECT` mutation.

    Args:
        project_id: Identifier sent as `projectId` in the mutation payload.
        project_data: Flags sent as `shouldCopyAssets` and `shouldCopyUsers`.

    Returns:
        The value stored under the `data` key of the mutation result,
        wrapped in `ProjectId`.
    """
    copy_options = {
        "projectId": project_id,
        "shouldCopyAssets": project_data.should_copy_assets,
        # The payload key is "shouldCopyUsers"; it is fed from `should_copy_members`.
        "shouldCopyUsers": project_data.should_copy_members,
    }
    response = self.graphql_client.execute(GQL_COPY_PROJECT, {"data": copy_options})

    # NOTE(review): a response without a "data" key yields ProjectId("") instead
    # of an error - confirm this fallback is intended.
    copied_project_id = response.get("data", "")
    return ProjectId(copied_project_id)
8 changes: 8 additions & 0 deletions src/kili/adapters/kili_api_gateway/project/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ class ProjectDataKiliAPIGatewayInput:
should_relaunch_kpi_computation: Optional[bool]
title: Optional[str]
use_honeypot: Optional[bool]


@dataclass
class CopyProjectInput:
    """Options passed to the Kili API Gateway when copying a project."""

    # Whether the members of the source project are included in the copy.
    should_copy_members: Optional[bool]
    # Whether the assets of the source project are included in the copy.
    should_copy_assets: Optional[bool]
11 changes: 7 additions & 4 deletions src/kili/entrypoints/mutations/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def copy_project( # pylint: disable=too-many-arguments
title if `None` is provided.
description: Description for the new project. Defaults to empty string
if `None` is provided.
copy_json_interface: Include json interface in the copy.
copy_quality_settings: Include quality settings in the copy.
copy_json_interface: Deprecated. Always include json interface in the copy.
copy_quality_settings: Deprecated. Always include quality settings in the copy.
copy_members: Include members in the copy.
copy_assets: Include assets in the copy.
copy_labels: Include labels in the copy.
Expand All @@ -196,12 +196,15 @@ def copy_project( # pylint: disable=too-many-arguments
Examples:
>>> kili.copy_project(from_project_id="clbqn56b331234567890l41c0")
"""
if (not copy_json_interface) or (not copy_quality_settings):
raise ValueError(
"The 'copy_json_interface' and 'copy_quality_settings' arguments are deprecated."
)

return ProjectCopier(self).copy_project( # pyright: ignore[reportGeneralTypeIssues]
from_project_id,
title,
description,
copy_json_interface,
copy_quality_settings,
copy_members,
copy_assets,
copy_labels,
Expand Down
187 changes: 18 additions & 169 deletions src/kili/services/copy_project/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
"""Copy project implementation."""

import itertools
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Optional

from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions
from kili.core.constants import QUERY_BATCH_SIZE
from kili.core.utils.pagination import batcher
from kili.adapters.kili_api_gateway.project.types import CopyProjectInput
from kili.domain.asset import AssetFilters
from kili.domain.label import LabelFilters
from kili.domain.project import ProjectId
from kili.domain.types import ListOrTuple
from kili.use_cases.asset.media_downloader import get_download_assets_function
from kili.utils.tempfile import TemporaryDirectory
from kili.utils.tqdm import tqdm

if TYPE_CHECKING:
from kili.client import Kili
Expand All @@ -26,32 +19,20 @@ class ProjectCopier: # pylint: disable=too-few-public-methods

FIELDS_PROJECT = (
"title",
"inputType",
"description",
"id",
"dataConnections.id",
)
FIELDS_JSON_INTERFACE = ("jsonInterface",)
FIELDS_QUALITY_SETTINGS = (
"canSkipAsset",
"consensusTotCoverage",
"minConsensusSize",
"reviewCoverage",
"secondsToLabelBeforeAutoAssign",
"useHoneyPot",
)

def __init__(self, kili: "Kili") -> None:
self.disable_tqdm = False
self.kili = kili

def copy_project( # pylint: disable=too-many-arguments,too-many-locals
def copy_project( # pylint: disable=too-many-arguments
self,
from_project_id: str,
title: Optional[str],
description: Optional[str],
copy_json_interface: bool,
copy_quality_settings: bool,
copy_members: bool,
copy_assets: bool,
copy_labels: bool,
Expand All @@ -64,17 +45,7 @@ def copy_project( # pylint: disable=too-many-arguments,too-many-locals
logger = logging.getLogger("kili.services.copy_project")
logger.setLevel(logging.INFO)

if not any(
(copy_json_interface, copy_quality_settings, copy_members, copy_assets, copy_labels)
):
raise ValueError("At least one element has to be copied.")

if copy_labels:
if not copy_json_interface:
raise ValueError(
"`copy_json_interface` must be set to `True` for copying the source project"
" labels."
)
if not copy_assets:
raise ValueError(
"`copy_assets` must be set to `True` for copying the source project labels."
Expand All @@ -85,41 +56,31 @@ def copy_project( # pylint: disable=too-many-arguments,too-many-locals
)

fields = self.FIELDS_PROJECT
if copy_json_interface:
fields += self.FIELDS_JSON_INTERFACE
if copy_quality_settings:
fields += self.FIELDS_QUALITY_SETTINGS

src_project = self.kili.kili_api_gateway.get_project(ProjectId(from_project_id), fields)

if src_project["dataConnections"] and copy_assets:
raise NotImplementedError("Copying projects with cloud storage is not supported.")

new_project_title = title or self._generate_project_title(src_title=src_project["title"])

new_project_description = description or ""
logger.info("Copying new project...")

json_interface = src_project["jsonInterface"] if copy_json_interface else {"jobs": {}}

new_project_id = self.kili.create_project(
input_type=src_project["inputType"],
json_interface=json_interface,
title=new_project_title,
description=new_project_description,
)["id"]
logger.info(f"Creating new project with id: '{new_project_id}'")
new_project_id = self.kili.kili_api_gateway.copy_project(
ProjectId(from_project_id),
CopyProjectInput(
should_copy_members=copy_members,
should_copy_assets=copy_assets,
),
)

if copy_members:
logger.info("Copying members...")
self._copy_members(from_project_id, new_project_id)
logger.info(f"Created new project {new_project_id}")

if copy_quality_settings:
logger.info("Copying quality settings...")
self._copy_quality_settings(new_project_id, src_project)
self.kili.update_properties_in_project(
project_id=new_project_id,
title=title or self._generate_project_title(src_project["title"]),
description=description,
)

if copy_assets:
logger.info("Copying assets...")
self._copy_assets(from_project_id, new_project_id)
logger.info("Updated title/description")

if copy_labels:
logger.info("Copying labels...")
Expand All @@ -139,118 +100,6 @@ def _generate_project_title(self, src_title: str) -> str:
i += 1
return new_title

def _copy_members(self, from_project_id: str, new_project_id: str) -> None:
    """Add the activated members of the source project to the new project.

    Each retained member is appended to the new project with the role it has
    in the source project.
    """

    def _is_active(entry):
        # Keep only members whose status is "ACTIVATED" and whose `activated` flag is truthy.
        return entry["status"] == "ACTIVATED" and entry["activated"]

    source_members = self.kili.project_users(
        project_id=from_project_id,
        fields=["activated", "role", "user.email", "status", "id"],
        disable_tqdm=True,
    )
    active_members = list(filter(_is_active, source_members))

    for entry in tqdm(active_members, disable=self.disable_tqdm):
        self.kili.append_to_roles(
            project_id=new_project_id,
            user_email=entry["user"]["email"],
            role=entry["role"],
        )

def _copy_quality_settings(self, new_project_id: str, src_project: Dict) -> None:
self.kili.update_properties_in_project(
project_id=new_project_id,
can_skip_asset=src_project["canSkipAsset"],
consensus_tot_coverage=src_project["consensusTotCoverage"],
min_consensus_size=src_project["minConsensusSize"],
use_honeypot=src_project["useHoneyPot"],
review_coverage=src_project["reviewCoverage"],
seconds_to_label_before_auto_assign=src_project["secondsToLabelBeforeAutoAssign"],
)

def _copy_assets(self, from_project_id: str, new_project_id: str) -> None:
    """Copy the assets of one project into another.

    Assets are downloaded and re-uploaded batch by batch, since their
    `content` urls expire.
    """
    asset_fields = (
        "content",
        "ocrMetadata",
        "externalId",
        "isHoneypot",
        "jsonContent",
        "jsonMetadata",
    )
    source_assets = self.kili.kili_api_gateway.list_assets(
        AssetFilters(project_id=ProjectId(from_project_id)),
        asset_fields,
        QueryOptions(disable_tqdm=False),
    )

    with TemporaryDirectory() as download_dir:
        # TODO: modify download_media function so it can take a generator of assets
        for batch in batcher(source_assets, QUERY_BATCH_SIZE):
            self._upload_assets(
                new_project_id,
                self._download_assets(from_project_id, asset_fields, download_dir, batch),
            )

def _download_assets(
    self, from_project_id: str, fields: ListOrTuple[str], tmp_dir: Path, assets: List[Dict]
) -> List[Dict]:
    """Download the media of `assets` into `tmp_dir`.

    Returns:
        Whatever the download function obtained from
        `get_download_assets_function` returns for `assets`.
    """
    media_dir = str(tmp_dir.resolve())
    # Only the first element of the returned pair is used here.
    downloader, _ = get_download_assets_function(
        self.kili.kili_api_gateway,
        download_media=True,
        fields=fields,
        project_id=ProjectId(from_project_id),
        local_media_dir=media_dir,
    )
    assert downloader
    return downloader(assets)

def _upload_assets(self, new_project_id: str, assets: List[Dict]) -> List[Dict]:
# ocrMetadata field of assets need to be merged with jsonMetadata field
for asset in assets:
if isinstance(asset["jsonMetadata"], str):
try:
asset["jsonMetadata"] = json.loads(asset["jsonMetadata"])
except json.JSONDecodeError:
asset["jsonMetadata"] = {}
if asset["ocrMetadata"]:
asset["jsonMetadata"] = {**asset["jsonMetadata"], **asset["ocrMetadata"]}

# we cannot send None values in the content_array or json_content_array fields of
# kili.append_many_to_assets. So we need to sort and group the assets by the presence of
# content and jsonContent.
assets = sorted(
assets,
key=lambda asset: (bool(asset["content"]), bool(asset["jsonContent"])),
)
assets_iterator = itertools.groupby(
assets,
key=lambda asset: (bool(asset["content"]), bool(asset["jsonContent"])),
)

for key, group in assets_iterator:
has_content, has_jsoncontent = key
group = list(group)

content_array = [asset["content"] for asset in group] if has_content else None
external_id_array = [asset["externalId"] for asset in group]
is_honeypot_array = [asset["isHoneypot"] for asset in group]
json_content_array = (
[asset["jsonContent"] for asset in group] if has_jsoncontent else None
)
json_metadata_array = [asset["jsonMetadata"] for asset in group]

self.kili.append_many_to_dataset(
project_id=new_project_id,
content_array=content_array,
external_id_array=external_id_array,
is_honeypot_array=is_honeypot_array,
json_content_array=json_content_array,
json_metadata_array=json_metadata_array,
disable_tqdm=True,
)
return assets

# pylint: disable=too-many-locals
def _copy_labels(self, from_project_id: str, new_project_id: str) -> None:
assets_new_project = self.kili.kili_api_gateway.list_assets(
Expand Down

0 comments on commit 46fa97c

Please sign in to comment.