Skip to content

Commit

Permalink
Merge pull request #498 from c-bata/boto3-avoid-file-close
Browse files Browse the repository at this point in the history
Fix a bug where `boto3.upload_fileobj` may close the file
  • Loading branch information
c-bata committed Jun 13, 2023
2 parents ad7613d + ca5bf5f commit f33474f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
18 changes: 16 additions & 2 deletions optuna_dashboard/artifact/boto3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import io
import shutil
from typing import TYPE_CHECKING

import boto3
Expand Down Expand Up @@ -33,9 +35,15 @@ def objective(trial: optuna.Trial) -> float:
return ...
"""

def __init__(
    self, bucket_name: str, client: Optional[S3Client] = None, *, avoid_buf_copy: bool = False
) -> None:
    """Set up the backend for one S3 bucket.

    Args:
        bucket_name: Name of the bucket that artifacts are written to.
        client: Boto3 S3 client to use. When omitted (or falsy), a client is
            created with ``boto3.client("s3")``.
        avoid_buf_copy: When ``False`` (the default), ``write()`` copies the
            source into an in-memory buffer and uploads that copy. Set to
            ``True`` to skip the copy and hand the source file object
            directly to ``upload_fileobj()``.
    """
    self.bucket = bucket_name
    self.client = client or boto3.client("s3")
    # upload_fileobj() of the Boto3 client may close the file object it is given.
    # See https://github.com/boto/boto3/issues/929
    # write() therefore uploads a copy by default; this flag lets callers opt out
    # of that copy when they do not need the source to stay open.
    self._avoid_buf_copy = avoid_buf_copy

def open(self, artifact_id: str) -> BinaryIO:
try:
Expand All @@ -49,7 +57,13 @@ def open(self, artifact_id: str) -> BinaryIO:
return body # type: ignore

def write(self, artifact_id: str, content_body: BinaryIO) -> None:
    """Upload ``content_body`` to the bucket under the key ``artifact_id``.

    Args:
        artifact_id: Object key to store the content under.
        content_body: Binary file object whose remaining content is uploaded.
    """
    fsrc: BinaryIO = content_body
    if not self._avoid_buf_copy:
        # upload_fileobj() may close the file object it is given
        # (https://github.com/boto/boto3/issues/929), so upload an in-memory
        # copy and leave the caller's file object open.
        buf = io.BytesIO()
        shutil.copyfileobj(content_body, buf)
        buf.seek(0)
        fsrc = buf
    # Exactly one upload: either the copy or, when opted out, the original.
    self.client.upload_fileobj(fsrc, self.bucket, artifact_id)

def remove(self, artifact_id: str) -> None:
try:
Expand Down
4 changes: 3 additions & 1 deletion python_tests/artifact/test_boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ def tearDown(self) -> None:
def test_upload_download(self) -> None:
    """Round-trip an artifact and check that the source buffer is left open."""
    artifact_id = "dummy-uuid"
    dummy_content = b"Hello World"
    # Keep a handle on the source buffer so its state can be checked after upload.
    buf = io.BytesIO(dummy_content)

    backend = Boto3Backend(self.bucket_name)
    backend.write(artifact_id, buf)
    assert len(self.s3_client.list_objects(Bucket=self.bucket_name)["Contents"]) == 1
    obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=artifact_id)
    assert obj["Body"].read() == dummy_content

    with backend.open(artifact_id) as f:
        actual = f.read()
    self.assertEqual(actual, dummy_content)
    # With the default settings, write() must not close the caller's file object.
    self.assertFalse(buf.closed)

def test_remove(self) -> None:
artifact_id = "dummy-uuid"
Expand Down

0 comments on commit f33474f

Please sign in to comment.