Multipart upload support
* Implementation of multipart upload, as described in RFC 0072

* See inveniosoftware/rfcs#91

Co-authored-by: Mirek Simek <[email protected]>
2 people authored and OARepo Bot committed Jan 21, 2025
1 parent a18ae9e commit 5402d6b
Showing 5 changed files with 304 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -60,3 +60,5 @@ target/

# Vim swapfiles
.*.sw?

.vscode/*
9 changes: 9 additions & 0 deletions invenio_s3/config.py
@@ -74,3 +74,12 @@
S3_DEFAULT_BLOCK_SIZE = 5 * 2**20
"""Default block size used to send multipart uploads to S3.
Typically 5 MB is the minimum allowed by the API."""

S3_UPLOAD_URL_EXPIRATION = 3600 * 24 * 7
"""Number of seconds a pre-signed file upload URL remains valid.

The default is 7 days so that uploads of large files with many parts can
complete before the URLs expire. This is currently the maximum expiration
allowed by AWS. See the `Amazon Boto3 documentation on presigned URLs
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.generate_presigned_url>`_
for more information.
"""
147 changes: 147 additions & 0 deletions invenio_s3/multipart_client.py
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 Miroslav Simek
#
# Invenio-S3 is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.


"""Low level client for S3 multipart uploads."""

import datetime

# WARNING: low-level code. The underlying s3fs currently does not support
# multipart uploads without keeping the S3File instance in memory between
# requests. To overcome this limitation, we use the low-level boto3 API
# directly, wrapped in the MultipartS3File class below.


class MultipartS3File:
"""Low level client for S3 multipart uploads."""

def __init__(self, fs, path, upload_id=None):
"""Initialize the low level client.
:param fs: S3FS instance
:param path: The path of the file (with bucket and version)
:param upload_id: The upload ID of the multipart upload, can be none to get a new upload.
"""
self.fs = fs
self.path = path
self.bucket, self.key, self.version_id = fs.split_path(path)
self.s3_client = fs.s3
self.acl = fs.s3_additional_kwargs.get("ACL", "")
self.upload_id = upload_id

def create_multipart_upload(self):
"""Create a new multipart upload.
:returns: The upload ID of the multipart upload.
"""
mpu = self.s3_client.create_multipart_upload(
Bucket=self.bucket, Key=self.key, ACL=self.acl
)
# TODO: error handling here
self.upload_id = mpu["UploadId"]
return self.upload_id

def get_parts(self, max_parts):
"""List the parts of the multipart upload.
:param max_parts: The maximum number of parts to list.
:returns: The list of parts, including checksums and etags.
"""
ret = self.s3_client.list_parts(
Bucket=self.bucket,
Key=self.key,
UploadId=self.upload_id,
MaxParts=max_parts,
PartNumberMarker=0,
)
return ret.get("Parts", [])

def upload_part(self, part_number, data):
"""Upload a part of the multipart upload. Will be used only in tests.
:param part_number: The part number.
:param data: The data to upload.
"""
part = self.s3_client.upload_part(
Bucket=self.bucket,
Key=self.key,
UploadId=self.upload_id,
PartNumber=part_number,
Body=data,
)
return part

def _complete_operation_part_parameters(self, part):
"""Filter parameters for the complete operation."""
ret = {}
for k in [
"PartNumber",
"ETag",
"ChecksumCRC32",
"ChecksumCRC32C",
"ChecksumSHA1",
"ChecksumSHA256",
]:
if k in part:
ret[k] = part[k]
return ret

def get_part_links(self, max_parts, url_expiration):
"""
Generate pre-signed URLs for the parts of the multipart upload.
:param max_parts: The maximum number of parts to list.
:param url_expiration: The expiration time of the URLs in seconds
:returns: The list of parts with pre-signed URLs and expiration times.
"""
expiration = datetime.datetime.utcnow() + datetime.timedelta(
seconds=url_expiration
)
expiration = expiration.replace(microsecond=0).isoformat() + "Z"

return {
"parts": [
{
"part": part + 1,
"url": self.s3_client.generate_presigned_url(
"upload_part",
Params={
"Bucket": self.bucket,
"Key": self.key,
"UploadId": self.upload_id,
"PartNumber": part + 1,
},
ExpiresIn=url_expiration,
),
"expiration": expiration,
}
for part in range(max_parts)
]
}

def complete_multipart_upload(self, parts):
"""Complete the multipart upload.
:param parts: The list of parts (as from self.get_parts), including checksums and etags.
"""
return self.s3_client.complete_multipart_upload(
Bucket=self.bucket,
Key=self.key,
UploadId=self.upload_id,
MultipartUpload={
"Parts": [
self._complete_operation_part_parameters(part) for part in parts
]
},
)

def abort_multipart_upload(self):
"""Abort the multipart upload."""
return self.s3_client.abort_multipart_upload(
Bucket=self.bucket, Key=self.key, UploadId=self.upload_id
)
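For orientation, a minimal sketch of driving MultipartS3File directly; the
filesystem handle and bucket path are illustrative, and in Invenio the storage
layer below constructs this object via multipart_file():

# Hypothetical usage sketch (not part of the commit).
# `fs` is assumed to be the S3FS instance used by the storage backend.
f = MultipartS3File(fs, "my-bucket/data/file.bin")
upload_id = f.create_multipart_upload()

# Clients normally PUT to the pre-signed URLs from get_part_links();
# upload_part() is used here only for brevity.
f.upload_part(1, b"a" * 5 * 2**20)  # every part except the last must be >= 5 MB
f.upload_part(2, b"b" * 1024)

f.complete_multipart_upload(f.get_parts(max_parts=2))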
75 changes: 75 additions & 0 deletions invenio_s3/storage.py
@@ -9,13 +9,15 @@

from functools import partial, wraps
from math import ceil
from typing import Any, Dict, Union

import s3fs
from flask import current_app
from invenio_files_rest.errors import StorageError
from invenio_files_rest.storage import PyFSFileStorage, pyfs_storage_factory

from .helpers import redirect_stream
from .multipart_client import MultipartS3File


def set_blocksize(f):
@@ -187,6 +189,79 @@ def save(self, *args, **kwargs):
"""
return super(S3FSFileStorage, self).save(*args, **kwargs)

def multipart_initialize_upload(
self, parts, size, part_size
) -> Union[None, Dict[str, str]]:
"""
Initialize a multipart upload.
:param parts: The number of parts that will be uploaded.
:param size: The total size of the file.
:param part_size: The size of each part except the last one.
:returns: a dictionary of additional metadata that should be stored between
the initialization and the commit of the upload.
"""
return {"uploadId": self.multipart_file().create_multipart_upload()}

def multipart_file(self, upload_id=None):
"""Get a low-level file object.
:param upload_id: The upload ID of the multipart upload, can be none to get a new upload.
:returns: an instance of LowLevelS3File.
"""
# WARNING: low-level code. The underlying s3fs currently does not have support
# for multipart uploads without keeping the S3File instance in memory between requests.
return MultipartS3File(*self._get_fs(), upload_id=upload_id)

def multipart_set_content(
self, part, stream, content_length, **multipart_metadata
) -> Union[None, Dict[str, str]]:
"""Set the content of a part of the multipart upload.
This method will never be called
by invenio as user will use direct pre-signed requests to S3 and will never
upload the files through Invenio.
"""
raise NotImplementedError(
"The multipart_set_content method is not implemented as it will never be called directly."
)

def multipart_commit_upload(self, **multipart_metadata):
"""Commit the multipart upload.
:param multipart_metadata: The metadata returned by the multipart_initialize_upload
and the metadata returned by the multipart_set_content for each part.
"""
f = self.multipart_file(multipart_metadata["uploadId"])
expected_parts = int(multipart_metadata["parts"])
parts = f.get_parts(max_parts=expected_parts)
if len(parts) != expected_parts:
raise ValueError(
f"Not all parts were uploaded, got {len(parts)} out of {expected_parts} parts."
)
f.complete_multipart_upload(parts)

def multipart_abort_upload(self, **multipart_metadata):
"""Abort the multipart upload.
:param multipart_metadata: The metadata returned by the multipart_initialize_upload
and the metadata returned by the multipart_set_content for each part.
"""
f = self.multipart_file(multipart_metadata["uploadId"])
f.abort_multipart_upload()

def multipart_links(self, **multipart_metadata) -> Dict[str, Any]:
"""Generate links for the parts of the multipart upload.
:param multipart_metadata: The metadata returned by the multipart_initialize_upload
and the metadata returned by the multipart_set_content for each part.
:returns: a dictionary of name of the link to invenio_records_resources.services.base.links.Link
"""
return self.multipart_file(multipart_metadata["uploadId"]).get_part_links(
int(multipart_metadata["parts"]),
current_app.config["S3_UPLOAD_URL_EXPIRATION"],
)


def s3fs_storage_factory(**kwargs):
"""File storage factory for S3."""
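On the client side, each part is sent with a plain HTTP PUT to its pre-signed
URL. A minimal sketch using the requests library (hypothetical helper, not
part of this commit); `links` is the "parts" list returned by multipart_links():

import requests

def upload_parts(links, stream, part_size):
    """PUT consecutive part_size-byte chunks of stream to the pre-signed URLs."""
    for link in links:
        chunk = stream.read(part_size)
        response = requests.put(link["url"], data=chunk)
        response.raise_for_status()  # S3 returns the part's ETag in the response headers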
71 changes: 71 additions & 0 deletions tests/test_multipart.py
@@ -0,0 +1,71 @@
"""Tests for the S3 multipart upload support."""

import pytest

MB = 2**20


def test_multipart_flow(base_app, s3fs):
part_size = 7 * MB
last_part_size = 5 * MB

# initialize the upload
upload_metadata = dict(
parts=2, part_size=part_size, size=part_size + last_part_size
)
upload_metadata |= s3fs.multipart_initialize_upload(**upload_metadata) or {}

# cannot commit yet because no parts have been uploaded
with pytest.raises(ValueError):
s3fs.multipart_commit_upload(**upload_metadata)

# check that links are generated

links = s3fs.multipart_links(**upload_metadata)["parts"]
assert len(links) == 2
assert links[0]["part"] == 1
assert "url" in links[0]
assert links[1]["part"] == 2
assert "url" in links[1]

# upload the first part manually
multipart_file = s3fs.multipart_file(upload_metadata["uploadId"])
multipart_file.upload_part(1, b"0" * part_size)
assert len(multipart_file.get_parts(2)) == 1

# still cannot commit because not all parts have been uploaded
with pytest.raises(ValueError):
s3fs.multipart_commit_upload(**upload_metadata)

# upload the second part
multipart_file.upload_part(2, b"1" * last_part_size)
assert len(multipart_file.get_parts(2)) == 2

s3fs.multipart_commit_upload(**upload_metadata)

assert s3fs.open("rb").read() == b"0" * part_size + b"1" * last_part_size


def test_multipart_abort(base_app, s3fs):
part_size = 7 * MB
last_part_size = 5 * MB

# initialize the upload
upload_metadata = dict(
parts=2, part_size=part_size, size=part_size + last_part_size
)
upload_metadata |= s3fs.multipart_initialize_upload(**upload_metadata) or {}

s3fs.multipart_abort_upload(**upload_metadata)


def test_set_content_not_supported(base_app, s3fs):
part_size = 7 * MB
last_part_size = 5 * MB

# initialize the upload
upload_metadata = dict(
parts=2, part_size=part_size, size=part_size + last_part_size
)
upload_metadata |= s3fs.multipart_initialize_upload(**upload_metadata) or {}

with pytest.raises(NotImplementedError):
s3fs.multipart_set_content(1, b"0" * part_size, part_size, **upload_metadata)
