Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pre/post-processing download requests #9383

Merged
merged 29 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/plenty-dragons-fold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Pre/post-processing download requests
2 changes: 1 addition & 1 deletion gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ async def async_move_resource_to_block_cache(
url_or_file_path = str(url_or_file_path)

if client_utils.is_http_url_like(url_or_file_path):
temp_file_path = await processing_utils.async_save_url_to_cache(
temp_file_path = await processing_utils.async_ssrf_protected_download(
url_or_file_path, cache_dir=self.GRADIO_CACHE
)

Expand Down
311 changes: 226 additions & 85 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import base64
import hashlib
import ipaddress
Expand All @@ -8,14 +9,16 @@
import os
import shutil
import socket
import ssl
import subprocess
import tempfile
import warnings
from functools import lru_cache
from collections.abc import Awaitable, Callable, Coroutine
from functools import lru_cache, wraps
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse, urlunparse
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import urlparse

import aiofiles
import httpx
Expand Down Expand Up @@ -102,9 +105,6 @@ async def handle_async_request(
sync_transport = None
async_transport = None

sync_client = httpx.Client(transport=sync_transport)
async_client = httpx.AsyncClient(transport=async_transport)

log = logging.getLogger(__name__)

if TYPE_CHECKING:
Expand Down Expand Up @@ -273,123 +273,264 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
return full_temp_file_path


@lru_cache(maxsize=256)
def resolve_with_google_dns(hostname: str) -> str | None:
url = f"https://dns.google/resolve?name={hostname}&type=A"
# Always return these URLs as is, without checking to see if they resolve
# to an internal IP address. This is because Hugging Face uses DNS splitting,
# which means that requests from HF Spaces to HF Datasets or HF Models
# may resolve to internal IP addresses even if they are publicly accessible.
PUBLIC_HOSTNAME_WHITELIST = ["hf.co", "huggingface.co"]

if wasm_utils.IS_WASM:
import pyodide.http

content = pyodide.http.open_url(url)
data = json.load(content)
else:
import urllib.request
def is_public_ip(ip: str) -> bool:
try:
ip_obj = ipaddress.ip_address(ip)
return not (
ip_obj.is_private
or ip_obj.is_loopback
or ip_obj.is_link_local
or ip_obj.is_multicast
or ip_obj.is_reserved
or (isinstance(ip_obj, ipaddress.IPv6Address) and ip_obj.is_site_local)
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
)
except ValueError:
return False

with urllib.request.urlopen(url) as response:
data = json.loads(response.read().decode())

if data.get("Status") == 0 and "Answer" in data:
for answer in data["Answer"]:
if answer["type"] == 1:
return answer["data"]
T = TypeVar("T")


# Always return these URLs as is, without checking to see if they resolve
# to an internal IP address. This is because Hugging Face uses DNS splitting,
# which means that requests from HF Spaces to HF Datasets or HF Models
# may resolve to internal IP addresses even if they are publicly accessible.
PUBLIC_URL_WHITELIST = ["hf.co", "huggingface.co"]
def lru_cache_async(maxsize: int = 128):
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
def decorator(
async_func: Callable[..., Coroutine[Any, Any, T]],
) -> Callable[..., Awaitable[T]]:
@lru_cache(maxsize=maxsize)
@wraps(async_func)
def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
return asyncio.create_task(async_func(*args, **kwargs))

return wrapper

return decorator

def get_public_url(url: str) -> str:
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise httpx.RequestError(f"Invalid scheme for URL: {url}")
hostname = parsed_url.hostname
if not hostname:
raise httpx.RequestError(f"Invalid URL: {url}, missing hostname")
if hostname.lower() in PUBLIC_URL_WHITELIST:
return url

@lru_cache(maxsize=256)
def resolve_hostname_google(hostname: str) -> list[str]:
with httpx.Client() as client:
try:
response_v4 = client.get(
f"https://dns.google/resolve?name={hostname}&type=A"
)
response_v6 = client.get(
f"https://dns.google/resolve?name={hostname}&type=AAAA"
)

ips = []
for response in [response_v4.json(), response_v6.json()]:
ips.extend([answer["data"] for answer in response.get("Answer", [])])
return ips
except Exception:
return []


@lru_cache_async(maxsize=256)
async def async_resolve_hostname_google(hostname: str) -> list[str]:
async with httpx.AsyncClient() as client:
try:
response_v4 = await client.get(
f"https://dns.google/resolve?name={hostname}&type=A"
)
response_v6 = await client.get(
f"https://dns.google/resolve?name={hostname}&type=AAAA"
)

ips = []
for response in [response_v4.json(), response_v6.json()]:
ips.extend([answer["data"] for answer in response.get("Answer", [])])
return ips
except Exception:
return []


class SecureTransport(httpx.HTTPTransport):
def __init__(self, verified_ip: str):
self.verified_ip = verified_ip
super().__init__()

def connect(
self,
hostname: str,
port: int,
timeout: float | None = None,
ssl_context: ssl.SSLContext | None = None,
):
sock = socket.create_connection((self.verified_ip, port), timeout=timeout)
if ssl_context:
sock = ssl_context.wrap_socket(sock, server_hostname=hostname)
return sock


class AsyncSecureTransport(httpx.AsyncHTTPTransport):
def __init__(self, verified_ip: str):
self.verified_ip = verified_ip
super().__init__()

async def connect(
self,
hostname: str,
port: int,
_timeout: float | None = None,
ssl_context: ssl.SSLContext | None = None,
**_kwargs: Any,
):
loop = asyncio.get_event_loop()
sock = await loop.getaddrinfo(self.verified_ip, port)
sock = socket.socket(sock[0][0], sock[0][1])
await loop.sock_connect(sock, (self.verified_ip, port))
if ssl_context:
sock = ssl_context.wrap_socket(sock, server_hostname=hostname)
return sock


def validate_url(url: str) -> str:
hostname = urlparse(url).hostname
if not hostname:
raise ValueError(f"URL {url} does not have a valid hostname")
try:
addrinfo = socket.getaddrinfo(hostname, None)
except socket.gaierror as e:
raise httpx.RequestError(
f"Cannot resolve URL with hostname: {hostname}, please download this file and use the path instead."
) from e
raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

for family, _, _, _, sockaddr in addrinfo:
ip = sockaddr[0]
if family == socket.AF_INET6:
ip = ip.split("%")[0] # Remove scope ID if present

if ipaddress.ip_address(ip).is_global:
return url

google_resolved_ip = resolve_with_google_dns(hostname)
if google_resolved_ip and ipaddress.ip_address(google_resolved_ip).is_global:
if parsed_url.scheme == "https":
return url
new_parsed = parsed_url._replace(netloc=google_resolved_ip)
if parsed_url.port:
new_parsed = new_parsed._replace(
netloc=f"{google_resolved_ip}:{parsed_url.port}"
)
return urlunparse(new_parsed)
ip_address = sockaddr[0]
if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
return ip_address

raise httpx.RequestError(
f"No public IP address found for URL: {url}, please download this file and use the path instead."
raise ValueError(f"Hostname {hostname} failed validation")


async def async_validate_url(url: str) -> str:
hostname = urlparse(url).hostname
if not hostname:
raise ValueError(f"URL {url} does not have a valid hostname")
try:
loop = asyncio.get_event_loop()
addrinfo = await loop.getaddrinfo(hostname, None)
except socket.gaierror as e:
raise ValueError(f"Unable to resolve hostname {hostname}: {e}") from e

for family, _, _, _, sockaddr in addrinfo:
ip_address = sockaddr[0]
if family in (socket.AF_INET, socket.AF_INET6) and is_public_ip(ip_address):
return ip_address

for ip_address in await async_resolve_hostname_google(hostname):
if is_public_ip(ip_address):
return ip_address

raise ValueError(f"Hostname {hostname} failed validation")


def get_with_secure_transport(url: str, trust_hostname: bool = False) -> httpx.Response:
if trust_hostname:
transport = None
else:
verified_ip = validate_url(url)
transport = SecureTransport(verified_ip)
with httpx.Client(transport=transport) as client:
return client.get(url, follow_redirects=False)


async def async_get_with_secure_transport(
url: str, trust_hostname: bool = False
) -> httpx.Response:
if trust_hostname:
transport = None
else:
verified_ip = validate_url(url)
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
transport = AsyncSecureTransport(verified_ip)
async with httpx.AsyncClient(transport=transport) as client:
return await client.get(url, follow_redirects=False)


def ssrf_protected_download(url: str, cache_dir: str) -> str:
parsed_url = urlparse(url)
hostname = parsed_url.hostname

response = get_with_secure_transport(
url, trust_hostname=hostname in PUBLIC_HOSTNAME_WHITELIST
)

while response.is_redirect:
redirect_url = response.headers["Location"]
redirect_parsed = urlparse(redirect_url)

if not redirect_parsed.hostname:
redirect_url = f"{parsed_url.scheme}://{hostname}{redirect_url}"

def save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
url = get_public_url(url)
response = get_with_secure_transport(redirect_url)

if response.status_code != 200:
raise Exception(f"Failed to download file. Status code: {response.status_code}")

content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
filename = Path(content_disposition.split("filename=")[1].strip('"'))
else:
filename = Path(url).name

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(abspath(temp_dir / name))
full_temp_file_path = str(abspath(temp_dir / filename))

if not Path(full_temp_file_path).exists():
with (
sync_client.stream("GET", url, follow_redirects=True) as response,
open(full_temp_file_path, "wb") as f,
):
for redirect in response.history:
get_public_url(str(redirect.url))

for chunk in response.iter_raw():
f.write(chunk)
with open(full_temp_file_path, "wb") as f:
f.write(response.content)

return full_temp_file_path


async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file. Uses async httpx."""
url = get_public_url(url)
async def async_ssrf_protected_download(url: str, cache_dir: str) -> str:
parsed_url = urlparse(url)
hostname = parsed_url.hostname

response = await async_get_with_secure_transport(
url, trust_hostname=hostname in PUBLIC_HOSTNAME_WHITELIST
)

while response.is_redirect:
redirect_url = response.headers["Location"]
redirect_parsed = urlparse(redirect_url)

if not redirect_parsed.hostname:
redirect_url = f"{parsed_url.scheme}://{hostname}{redirect_url}"

response = await async_get_with_secure_transport(redirect_url)

if response.status_code != 200:
raise Exception(f"Failed to download file. Status code: {response.status_code}")

content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
filename = Path(content_disposition.split("filename=")[1].strip('"')).name
else:
filename = client_utils.strip_invalid_filename_characters(Path(url).name)

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(abspath(temp_dir / name))
full_temp_file_path = str(abspath(temp_dir / filename))

if not Path(full_temp_file_path).exists():
async with async_client.stream("GET", url, follow_redirects=True) as response:
for redirect in response.history:
get_public_url(str(redirect.url))

async with aiofiles.open(full_temp_file_path, "wb") as f:
async for chunk in response.aiter_raw():
await f.write(chunk)
async with aiofiles.open(full_temp_file_path, "wb") as f:
async for chunk in response.aiter_bytes():
await f.write(chunk)

return full_temp_file_path


save_url_to_cache = ssrf_protected_download
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved


def save_base64_to_cache(
base64_encoding: str, cache_dir: str, file_name: str | None = None
) -> str:
Expand Down
Loading
Loading