Pre/post-processing download requests (#9383)

* changes

* add changeset

* change

* Update gradio/processing_utils.py

Co-authored-by: Abubakar Abid <[email protected]>

* Fix Lite's ASGI receiver to convert memoryview bodies to bytes: the multipart parser invoked in https://github.com/gradio-app/gradio/blob/98cbcaef827de7267462ccba180c7b2ffb1e825d/gradio/route_utils.py#L650 calls bytes.find(), which memoryview objects do not implement (see the minimal sketch after the webworker diff below)

* add changeset

* Fix async_get_with_secure_transport to fall back to the insecure but Pyodide-compatible transport when running under Wasm

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
Co-authored-by: Yuichiro Tachibana (Tsuchiya) <[email protected]>
5 people authored Sep 30, 2024
1 parent 3ac5d9c commit 30d13ac
Showing 5 changed files with 220 additions and 100 deletions.
6 changes: 6 additions & 0 deletions .changeset/plenty-dragons-fold.md
@@ -0,0 +1,6 @@
+---
+"@gradio/wasm": minor
+"gradio": minor
+---
+
+feat:Pre/post-processing download requests
2 changes: 1 addition & 1 deletion gradio/blocks.py
@@ -298,7 +298,7 @@ async def async_move_resource_to_block_cache(
         url_or_file_path = str(url_or_file_path)

         if client_utils.is_http_url_like(url_or_file_path):
-            temp_file_path = await processing_utils.async_save_url_to_cache(
+            temp_file_path = await processing_utils.async_ssrf_protected_download(
                 url_or_file_path, cache_dir=self.GRADIO_CACHE
             )
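
The rename points Blocks at the new SSRF-protected entry point in processing_utils. A minimal usage sketch of that helper, assuming a local Gradio install at this commit (the URL and cache directory below are hypothetical, not taken from the diff):

    import asyncio

    from gradio import processing_utils

    async def cache_remote_file() -> str:
        # async_ssrf_protected_download validates that the hostname resolves
        # to a public IP address before connecting, and re-validates every
        # redirect hop instead of following redirects blindly.
        return await processing_utils.async_ssrf_protected_download(
            "https://example.com/data/file.csv",  # hypothetical URL
            cache_dir="/tmp/gradio-cache",  # hypothetical cache location
        )

    print(asyncio.run(cache_remote_file()))  # prints the path of the cached copy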
253 changes: 163 additions & 90 deletions gradio/processing_utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import base64
 import hashlib
 import ipaddress
@@ -8,14 +9,16 @@
 import os
 import shutil
 import socket
+import ssl
 import subprocess
 import tempfile
 import warnings
-from functools import lru_cache
+from collections.abc import Awaitable, Callable, Coroutine
+from functools import lru_cache, wraps
 from io import BytesIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
-from urllib.parse import urlparse, urlunparse
+from typing import TYPE_CHECKING, Any, TypeVar
+from urllib.parse import urlparse

 import aiofiles
 import httpx
@@ -109,7 +112,6 @@ async def handle_async_request(
     async_transport = None

 sync_client = httpx.Client(transport=sync_transport)
-async_client = httpx.AsyncClient(transport=async_transport)

 log = logging.getLogger(__name__)
@@ -279,123 +281,194 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
     return full_temp_file_path


-@lru_cache(maxsize=256)
-def resolve_with_google_dns(hostname: str) -> str | None:
-    url = f"https://dns.google/resolve?name={hostname}&type=A"
+# Always return these URLs as is, without checking to see if they resolve
+# to an internal IP address. This is because Hugging Face uses DNS splitting,
+# which means that requests from HF Spaces to HF Datasets or HF Models
+# may resolve to internal IP addresses even if they are publicly accessible.
+PUBLIC_HOSTNAME_WHITELIST = ["hf.co", "huggingface.co"]

-    if wasm_utils.IS_WASM:
-        import pyodide.http
-
-        content = pyodide.http.open_url(url)
-        data = json.load(content)
-    else:
-        import urllib.request
+def is_public_ip(ip: str) -> bool:
+    try:
+        ip_obj = ipaddress.ip_address(ip)
+        return not (
+            ip_obj.is_private
+            or ip_obj.is_loopback
+            or ip_obj.is_link_local
+            or ip_obj.is_multicast
+            or ip_obj.is_reserved
+        )
+    except ValueError:
+        return False

-        with urllib.request.urlopen(url) as response:
-            data = json.loads(response.read().decode())

-    if data.get("Status") == 0 and "Answer" in data:
-        for answer in data["Answer"]:
-            if answer["type"] == 1:
-                return answer["data"]
+T = TypeVar("T")


-# Always return these URLs as is, without checking to see if they resolve
-# to an internal IP address. This is because Hugging Face uses DNS splitting,
-# which means that requests from HF Spaces to HF Datasets or HF Models
-# may resolve to internal IP addresses even if they are publicly accessible.
-PUBLIC_URL_WHITELIST = ["hf.co", "huggingface.co"]
+def lru_cache_async(maxsize: int = 128):
+    def decorator(
+        async_func: Callable[..., Coroutine[Any, Any, T]],
+    ) -> Callable[..., Awaitable[T]]:
+        @lru_cache(maxsize=maxsize)
+        @wraps(async_func)
+        def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
+            return asyncio.create_task(async_func(*args, **kwargs))
+
+        return wrapper

-def get_public_url(url: str) -> str:
-    parsed_url = urlparse(url)
-    if parsed_url.scheme not in ["http", "https"]:
-        raise httpx.RequestError(f"Invalid scheme for URL: {url}")
-    hostname = parsed_url.hostname
-    if not hostname:
-        raise httpx.RequestError(f"Invalid URL: {url}, missing hostname")
-    if hostname.lower() in PUBLIC_URL_WHITELIST:
-        return url
+    return decorator


+@lru_cache_async(maxsize=256)
+async def async_resolve_hostname_google(hostname: str) -> list[str]:
+    async with httpx.AsyncClient() as client:
+        try:
+            response_v4 = await client.get(
+                f"https://dns.google/resolve?name={hostname}&type=A"
+            )
+            response_v6 = await client.get(
+                f"https://dns.google/resolve?name={hostname}&type=AAAA"
+            )
+
+            ips = []
+            for response in [response_v4.json(), response_v6.json()]:
+                ips.extend([answer["data"] for answer in response.get("Answer", [])])
+            return ips
+        except Exception:
+            return []
+
+
+class AsyncSecureTransport(httpx.AsyncHTTPTransport):
+    def __init__(self, verified_ip: str):
+        self.verified_ip = verified_ip
+        super().__init__()
+
+    async def connect(
+        self,
+        hostname: str,
+        port: int,
+        _timeout: float | None = None,
+        ssl_context: ssl.SSLContext | None = None,
+        **_kwargs: Any,
+    ):
+        loop = asyncio.get_event_loop()
+        sock = await loop.getaddrinfo(self.verified_ip, port)
+        sock = socket.socket(sock[0][0], sock[0][1])
+        await loop.sock_connect(sock, (self.verified_ip, port))
+        if ssl_context:
+            sock = ssl_context.wrap_socket(sock, server_hostname=hostname)
+        return sock
+
+
+async def async_validate_url(url: str) -> str:
+    hostname = urlparse(url).hostname
+    if not hostname:
+        raise ValueError(f"URL {url} does not have a valid hostname")
     try:
-        addrinfo = socket.getaddrinfo(hostname, None)
+        loop = asyncio.get_event_loop()
+        addrinfo = await loop.getaddrinfo(hostname, None)
     except socket.gaierror as e:
-        raise httpx.RequestError(
-            f"Cannot resolve URL with hostname: {hostname}, please download this file and use the path instead."
-        ) from e
+        raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

     for family, _, _, _, sockaddr in addrinfo:
-        ip = sockaddr[0]
-        if family == socket.AF_INET6:
-            ip = ip.split("%")[0]  # Remove scope ID if present
-
-        if ipaddress.ip_address(ip).is_global:
-            return url
-
-    google_resolved_ip = resolve_with_google_dns(hostname)
-    if google_resolved_ip and ipaddress.ip_address(google_resolved_ip).is_global:
-        if parsed_url.scheme == "https":
-            return url
-        new_parsed = parsed_url._replace(netloc=google_resolved_ip)
-        if parsed_url.port:
-            new_parsed = new_parsed._replace(
-                netloc=f"{google_resolved_ip}:{parsed_url.port}"
-            )
-        return urlunparse(new_parsed)
+        ip_address = sockaddr[0]
+        if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
+            return ip_address

-    raise httpx.RequestError(
-        f"No public IP address found for URL: {url}, please download this file and use the path instead."
-    )
+    if not wasm_utils.IS_WASM:
+        for ip_address in await async_resolve_hostname_google(hostname):
+            if is_public_ip(ip_address):
+                return ip_address
+
+    raise ValueError(f"Hostname {hostname} failed validation")


-def save_url_to_cache(url: str, cache_dir: str) -> str:
-    """Downloads a file and makes a temporary file path for a copy if does not already
-    exist. Otherwise returns the path to the existing temp file."""
-    url = get_public_url(url)
-
-    temp_dir = hash_url(url)
-    temp_dir = Path(cache_dir) / temp_dir
+async def async_get_with_secure_transport(
+    url: str, trust_hostname: bool = False
+) -> httpx.Response:
+    if wasm_utils.IS_WASM:
+        transport = PyodideHttpTransport()
+    elif trust_hostname:
+        transport = None
+    else:
+        verified_ip = await async_validate_url(url)
+        transport = AsyncSecureTransport(verified_ip)
+    async with httpx.AsyncClient(transport=transport) as client:
+        return await client.get(url, follow_redirects=False)
+
+
+async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
+    temp_dir = Path(cache_dir) / hash_url(url)
     temp_dir.mkdir(exist_ok=True, parents=True)
-    name = client_utils.strip_invalid_filename_characters(Path(url).name)
-    full_temp_file_path = str(abspath(temp_dir / name))
+    filename = client_utils.strip_invalid_filename_characters(Path(url).name)
+    full_temp_file_path = str(abspath(temp_dir / filename))

-    if not Path(full_temp_file_path).exists():
-        with (
-            sync_client.stream("GET", url, follow_redirects=True) as response,
-            open(full_temp_file_path, "wb") as f,
-        ):
-            for redirect in response.history:
-                get_public_url(str(redirect.url))
+    if Path(full_temp_file_path).exists():
+        return full_temp_file_path
+
+    parsed_url = urlparse(url)
+    hostname = parsed_url.hostname

-            for chunk in response.iter_raw():
-                f.write(chunk)
+    response = await async_get_with_secure_transport(
+        url, trust_hostname=hostname in PUBLIC_HOSTNAME_WHITELIST
+    )

-    return full_temp_file_path
+    while response.is_redirect:
+        redirect_url = response.headers["Location"]
+        redirect_parsed = urlparse(redirect_url)
+
+        if not redirect_parsed.hostname:
+            redirect_url = f"{parsed_url.scheme}://{hostname}{redirect_url}"

-async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
-    """Downloads a file and makes a temporary file path for a copy if does not already
-    exist. Otherwise returns the path to the existing temp file. Uses async httpx."""
-    url = get_public_url(url)
+        response = await async_get_with_secure_transport(redirect_url)

-    temp_dir = hash_url(url)
-    temp_dir = Path(cache_dir) / temp_dir
+    if response.status_code != 200:
+        raise Exception(f"Failed to download file. Status code: {response.status_code}")
+
+    async with aiofiles.open(full_temp_file_path, "wb") as f:
+        async for chunk in response.aiter_bytes():
+            await f.write(chunk)
+
+    return full_temp_file_path
+
+
+def unsafe_download(url: str, cache_dir: str) -> str:
+    temp_dir = Path(cache_dir) / hash_url(url)
     temp_dir.mkdir(exist_ok=True, parents=True)
-    name = client_utils.strip_invalid_filename_characters(Path(url).name)
-    full_temp_file_path = str(abspath(temp_dir / name))
+    filename = client_utils.strip_invalid_filename_characters(Path(url).name)
+    full_temp_file_path = str(abspath(temp_dir / filename))

     if not Path(full_temp_file_path).exists():
-        async with async_client.stream("GET", url, follow_redirects=True) as response:
-            for redirect in response.history:
-                get_public_url(str(redirect.url))
+        with (
+            sync_client.stream("GET", url, follow_redirects=True) as r,
+            open(full_temp_file_path, "wb") as f,
+        ):
+            for chunk in r.iter_raw():
+                f.write(chunk)

-            async with aiofiles.open(full_temp_file_path, "wb") as f:
-                async for chunk in response.aiter_raw():
-                    await f.write(chunk)
+    # print path and file size
+    print(
+        f"Downloaded {full_temp_file_path} ({os.path.getsize(full_temp_file_path)} bytes)"
+    )
+    log.info(
+        f"Downloaded {full_temp_file_path} ({os.path.getsize(full_temp_file_path)} bytes)"
+    )

     return full_temp_file_path


+def ssrf_protected_download(url: str, cache_dir: str) -> str:
+    if wasm_utils.IS_WASM:
+        return unsafe_download(url, cache_dir)
+    else:
+        return client_utils.synchronize_async(
+            async_ssrf_protected_download, url, cache_dir
+        )
+
+
+# Custom components created with versions of gradio < 5.0 may be using the processing_utils.save_url_to_cache method, so we alias to ssrf_protected_download to preserve backwards-compatibility
+save_url_to_cache = ssrf_protected_download


 def save_base64_to_cache(
     base64_encoding: str, cache_dir: str, file_name: str | None = None
 ) -> str:
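
One detail above is easy to misread: lru_cache_async does not cache coroutine objects (which can only be awaited once) but the asyncio.Task created from them, so repeated and concurrent lookups of the same hostname share a single in-flight DNS request. A self-contained sketch of the pattern; the decorator body is copied from the diff, while fake_resolve is a hypothetical stand-in for async_resolve_hostname_google:

    import asyncio
    from collections.abc import Awaitable, Callable, Coroutine
    from functools import lru_cache, wraps
    from typing import Any, TypeVar

    T = TypeVar("T")

    def lru_cache_async(maxsize: int = 128):
        def decorator(
            async_func: Callable[..., Coroutine[Any, Any, T]],
        ) -> Callable[..., Awaitable[T]]:
            # lru_cache memoizes the synchronous wrapper, so equal arguments
            # always map to the same asyncio.Task.
            @lru_cache(maxsize=maxsize)
            @wraps(async_func)
            def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
                return asyncio.create_task(async_func(*args, **kwargs))

            return wrapper

        return decorator

    calls = 0

    @lru_cache_async(maxsize=8)
    async def fake_resolve(hostname: str) -> list[str]:
        # Hypothetical stand-in for async_resolve_hostname_google.
        global calls
        calls += 1
        await asyncio.sleep(0.01)
        return ["203.0.113.7"]

    async def main() -> None:
        results = await asyncio.gather(*(fake_resolve("example.com") for _ in range(10)))
        print(calls, results[0])  # 1 ['203.0.113.7'] -- ten calls, one lookup

    asyncio.run(main())

A cached task can be awaited any number of times; the trade-off is that a failed task stays cached too, until the LRU evicts it.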
6 changes: 5 additions & 1 deletion js/wasm/src/webworker/index.ts
@@ -140,7 +140,11 @@ async def _call_asgi_app_from_js(app_id, scope, receive, send):
     async def rcv():
         event = await receive()
-        return event.to_py()
+        py_event = event.to_py()
+        if "body" in py_event:
+            if isinstance(py_event["body"], memoryview):
+                py_event["body"] = py_event["body"].tobytes()
+        return py_event
     async def snd(event):
         await send(event)
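
The wrapper above converts ASGI message bodies before Gradio's multipart parser sees them. A minimal illustration of why, in plain CPython and independent of Lite:

    # bytes exposes find(); memoryview does not, which is what broke the parser.
    body = memoryview(b"--boundary\r\nfield-data")
    print(hasattr(b"", "find"))   # True
    print(hasattr(body, "find"))  # False: bytes.find()-style calls raise AttributeError
    print(body.tobytes().find(b"--boundary"))  # 0 once converted, as in the fix above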
(The diff for the fifth changed file did not load.)
