diff --git a/python/cog/files.py b/python/cog/files.py index 2470cbc9f5..698ef57748 100644 --- a/python/cog/files.py +++ b/python/cog/files.py @@ -8,8 +8,9 @@ import requests -def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: - fh.seek(0) +def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: + if fh.seekable(): + fh.seek(0) if output_file_prefix is not None: name = getattr(fh, "name", "output") @@ -42,7 +43,8 @@ def guess_filename(obj: io.IOBase) -> str: def put_file_to_signed_endpoint( fh: io.IOBase, endpoint: str, client: requests.Session, prediction_id: Optional[str] ) -> str: - fh.seek(0) + if fh.seekable(): + fh.seek(0) filename = guess_filename(fh) content_type, _ = mimetypes.guess_type(filename) diff --git a/python/cog/types.py b/python/cog/types.py index e781ca9c2a..f4110e68ca 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -104,6 +104,7 @@ def __get_pydantic_json_schema__( } ) return json_schema + else: @classmethod @@ -223,6 +224,7 @@ def __get_pydantic_json_schema__( json_schema = handler(core_schema) json_schema.update(type="string", format="uri") return json_schema + else: @classmethod @@ -286,6 +288,15 @@ class URLFile(io.IOBase): __slots__ = ("__target__", "__url__") def __init__(self, url: str) -> None: + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in { + "http", + "https", + }: + raise ValueError( + "URLFile requires URL to conform to HTTP or HTTPS protocol" + ) + object.__setattr__(self, "name", os.path.basename(parsed.path)) object.__setattr__(self, "__url__", url) # We provide __getstate__ and __setstate__ explicitly to ensure that the @@ -413,6 +424,7 @@ def __get_pydantic_json_schema__( } ) return json_schema + else: @classmethod diff --git a/python/tests/test_json.py b/python/tests/test_json.py index 95243fbbf3..e79f1a1049 100644 --- a/python/tests/test_json.py +++ b/python/tests/test_json.py @@ -3,11 +3,12 @@ import numpy as np import pydantic +import responses import cog from cog.files import upload_file from cog.json import make_encodeable, upload_files -from cog.types import PYDANTIC_V2 +from cog.types import PYDANTIC_V2, URLFile def test_make_encodeable_recursively_encodes_tuples(): @@ -57,6 +58,20 @@ def test_upload_files(): } +@responses.activate +def test_upload_files_with_url(): + responses.get( + "https://example.com/some/url.txt", + body="file content", + status=200, + ) + + obj = {"path": URLFile("https://example.com/some/url.txt")} + assert upload_files(obj, upload_file) == { + "path": "data:text/plain;base64,ZmlsZSBjb250ZW50" + } + + def test_numpy(): class Model(pydantic.BaseModel): ndarray: np.ndarray diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 6ed027db05..95f7b237f8 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -7,6 +7,14 @@ from cog.types import Secret, URLFile, get_filename +def test_urlfile_protocol_validation(): + with pytest.raises(ValueError): + URLFile("file:///etc/shadow") + + with pytest.raises(ValueError): + URLFile("data:text/plain,hello") + + @responses.activate def test_urlfile_acts_like_response(): responses.get(