Skip to content

Commit

Permalink
Support URLFile in the upload_file function (#1985)
Browse files Browse the repository at this point in the history
We would like `predict` functions to be able to return a remote URL rather than a local file on disk and have it behave like a file object. And when it is passed to the file uploader it will stream the file from the remote to the destination provided.

```py
class Predictor(BasePredictor):
    def predict(self, **kwargs) -> File:
        return URLFile("https://replicate.delivery/czjl/9MBNrffKcxoqY0iprW66NF8MZaNeH322a27yE0sjFGtKMXLnA/hello.webp")
```

This PR adds an additional check to the `upload_file()` handler to call `fh.seekable()` before attempting to seek. This allows instances of `io.IOBase` that do not support seek (like `URLFile`) to be uploaded.

We also add a `name` attribute to `URLFile`. This is used by the `upload_file` function to infer the file extension and mime type.
  • Loading branch information
aron authored Oct 16, 2024
1 parent 201dd5f commit 87787f4
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 4 deletions.
8 changes: 5 additions & 3 deletions python/cog/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import requests


def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
fh.seek(0)
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
if fh.seekable():
fh.seek(0)

if output_file_prefix is not None:
name = getattr(fh, "name", "output")
Expand Down Expand Up @@ -42,7 +43,8 @@ def guess_filename(obj: io.IOBase) -> str:
def put_file_to_signed_endpoint(
fh: io.IOBase, endpoint: str, client: requests.Session, prediction_id: Optional[str]
) -> str:
fh.seek(0)
if fh.seekable():
fh.seek(0)

filename = guess_filename(fh)
content_type, _ = mimetypes.guess_type(filename)
Expand Down
12 changes: 12 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __get_pydantic_json_schema__(
}
)
return json_schema

else:

@classmethod
Expand Down Expand Up @@ -223,6 +224,7 @@ def __get_pydantic_json_schema__(
json_schema = handler(core_schema)
json_schema.update(type="string", format="uri")
return json_schema

else:

@classmethod
Expand Down Expand Up @@ -286,6 +288,15 @@ class URLFile(io.IOBase):
__slots__ = ("__target__", "__url__")

def __init__(self, url: str) -> None:
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in {
"http",
"https",
}:
raise ValueError(
"URLFile requires URL to conform to HTTP or HTTPS protocol"
)
object.__setattr__(self, "name", os.path.basename(parsed.path))
object.__setattr__(self, "__url__", url)

# We provide __getstate__ and __setstate__ explicitly to ensure that the
Expand Down Expand Up @@ -413,6 +424,7 @@ def __get_pydantic_json_schema__(
}
)
return json_schema

else:

@classmethod
Expand Down
17 changes: 16 additions & 1 deletion python/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import numpy as np
import pydantic
import responses

import cog
from cog.files import upload_file
from cog.json import make_encodeable, upload_files
from cog.types import PYDANTIC_V2
from cog.types import PYDANTIC_V2, URLFile


def test_make_encodeable_recursively_encodes_tuples():
Expand Down Expand Up @@ -57,6 +58,20 @@ def test_upload_files():
}


@responses.activate
def test_upload_files_with_url():
responses.get(
"https://example.com/some/url.txt",
body="file content",
status=200,
)

obj = {"path": URLFile("https://example.com/some/url.txt")}
assert upload_files(obj, upload_file) == {
"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"
}


def test_numpy():
class Model(pydantic.BaseModel):
ndarray: np.ndarray
Expand Down
8 changes: 8 additions & 0 deletions python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from cog.types import Secret, URLFile, get_filename


def test_urlfile_protocol_validation():
with pytest.raises(ValueError):
URLFile("file:///etc/shadow")

with pytest.raises(ValueError):
URLFile("data:text/plain,hello")


@responses.activate
def test_urlfile_acts_like_response():
responses.get(
Expand Down

0 comments on commit 87787f4

Please sign in to comment.