Pre/post-processing download requests #9383

Merged 29 commits on Sep 30, 2024
Changes from 7 commits
5 changes: 5 additions & 0 deletions .changeset/plenty-dragons-fold.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Pre/post-processing download requests
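In short, this PR replaces `save_url_to_cache` / `async_save_url_to_cache` with SSRF-protected download helpers in `gradio.processing_utils`. A minimal usage sketch, assuming only the function names and `(url, cache_dir)` signatures visible in this diff; the cache directory and URL below are made up for illustration:

```python
import asyncio
import os

from gradio import processing_utils

CACHE_DIR = "/tmp/gradio-cache"  # illustrative; components pass self.GRADIO_CACHE
os.makedirs(CACHE_DIR, exist_ok=True)

# Synchronous call sites (component postprocess paths) use the sync wrapper.
path = processing_utils.sync_ssrf_protected_httpx_download(
    "https://huggingface.co/favicon.ico",  # hypothetical public URL
    CACHE_DIR,
)
print(path)


# Async call sites (e.g. Blocks.async_move_resource_to_block_cache) await the
# coroutine directly.
async def cache_remote(url: str) -> str:
    return await processing_utils.ssrf_protected_httpx_download(url, CACHE_DIR)


print(asyncio.run(cache_remote("https://huggingface.co/favicon.ico")))
```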
4 changes: 2 additions & 2 deletions gradio/blocks.py
@@ -296,7 +296,7 @@ async def async_move_resource_to_block_cache(
             url_or_file_path = str(url_or_file_path)
 
         if client_utils.is_http_url_like(url_or_file_path):
-            temp_file_path = await processing_utils.async_save_url_to_cache(
+            temp_file_path = await processing_utils.ssrf_protected_httpx_download(
                 url_or_file_path, cache_dir=self.GRADIO_CACHE
             )

@@ -336,7 +336,7 @@ def move_resource_to_block_cache(
             url_or_file_path = str(url_or_file_path)
 
         if client_utils.is_http_url_like(url_or_file_path):
-            temp_file_path = processing_utils.save_url_to_cache(
+            temp_file_path = processing_utils.sync_ssrf_protected_httpx_download(
                 url_or_file_path, cache_dir=self.GRADIO_CACHE
             )

2 changes: 1 addition & 1 deletion gradio/components/annotated_image.py
@@ -154,7 +154,7 @@ def postprocess(
         base_img = value[0]
         if isinstance(base_img, str):
             if client_utils.is_http_url_like(base_img):
-                base_img = processing_utils.save_url_to_cache(
+                base_img = processing_utils.sync_ssrf_protected_httpx_download(
                     base_img, cache_dir=self.GRADIO_CACHE
                 )
             base_img_path = base_img
8 changes: 5 additions & 3 deletions gradio/components/file.py
@@ -163,15 +163,17 @@ def _download_files(self, value: str | list[str]) -> str | list[str]:
         if isinstance(value, list):
             for file in value:
                 if client_utils.is_http_url_like(file):
-                    downloaded_file = processing_utils.save_url_to_cache(
-                        file, self.GRADIO_CACHE
+                    downloaded_file = (
+                        processing_utils.sync_ssrf_protected_httpx_download(
+                            file, self.GRADIO_CACHE
+                        )
                     )
                     downloaded_files.append(downloaded_file)
                 else:
                     downloaded_files.append(file)
             return downloaded_files
         if client_utils.is_http_url_like(value):
-            downloaded_file = processing_utils.save_url_to_cache(
+            downloaded_file = processing_utils.sync_ssrf_protected_httpx_download(
                 value, self.GRADIO_CACHE
             )
         return downloaded_file
8 changes: 5 additions & 3 deletions gradio/components/upload_button.py
@@ -185,15 +185,17 @@ def _download_files(self, value: str | list[str]) -> str | list[str]:
         if isinstance(value, list):
             for file in value:
                 if client_utils.is_http_url_like(file):
-                    downloaded_file = processing_utils.save_url_to_cache(
-                        file, self.GRADIO_CACHE
+                    downloaded_file = (
+                        processing_utils.sync_ssrf_protected_httpx_download(
+                            file, self.GRADIO_CACHE
+                        )
                     )
                     downloaded_files.append(downloaded_file)
                 else:
                     downloaded_files.append(file)
             return downloaded_files
         if client_utils.is_http_url_like(value):
-            downloaded_file = processing_utils.save_url_to_cache(
+            downloaded_file = processing_utils.sync_ssrf_protected_httpx_download(
                 value, self.GRADIO_CACHE
             )
         return downloaded_file
2 changes: 1 addition & 1 deletion gradio/components/video.py
@@ -326,7 +326,7 @@ def _format_video(self, video: str | Path | None) -> FileData | None:
         # For cases where the video needs to be converted to another format
         # or have a watermark added.
         if is_url:
-            video = processing_utils.save_url_to_cache(
+            video = processing_utils.sync_ssrf_protected_httpx_download(
                 video, cache_dir=self.GRADIO_CACHE
            )
         if (
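All of the component call sites above run in synchronous code paths, which is why they use `sync_ssrf_protected_httpx_download`; the async call site in `gradio/blocks.py` awaits the coroutine directly. A minimal sketch of the sync-over-async pattern the wrapper relies on (the names below are illustrative placeholders, not part of this PR):

```python
import asyncio


async def download(url: str, cache_dir: str) -> str:
    # Stand-in for the real async implementation
    # (ssrf_protected_httpx_download in this PR).
    await asyncio.sleep(0)
    return f"{cache_dir}/downloaded-file"


def sync_download(url: str, cache_dir: str) -> str:
    # asyncio.run starts a fresh event loop, runs the coroutine to completion,
    # and closes the loop. This mirrors how sync_ssrf_protected_httpx_download
    # wraps the async version; note that asyncio.run raises a RuntimeError if
    # it is called from a thread that already has a running event loop.
    return asyncio.run(download(url, cache_dir))


print(sync_download("https://example.com/file.png", "/tmp/cache"))
```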
228 changes: 134 additions & 94 deletions gradio/processing_utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import base64
 import hashlib
 import ipaddress
@@ -11,13 +12,13 @@
 import subprocess
 import tempfile
 import warnings
-from functools import lru_cache
+from collections.abc import Awaitable, Callable, Coroutine
+from functools import lru_cache, wraps
 from io import BytesIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
-from urllib.parse import urlparse, urlunparse
+from typing import TYPE_CHECKING, Any, TypeVar
+from urllib.parse import urlparse
 
-import aiofiles
 import httpx
 import numpy as np
 from gradio_client import utils as client_utils
Expand Down Expand Up @@ -273,121 +274,160 @@ def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
return full_temp_file_path


@lru_cache(maxsize=256)
def resolve_with_google_dns(hostname: str) -> str | None:
url = f"https://dns.google/resolve?name={hostname}&type=A"
# Always return these URLs as is, without checking to see if they resolve
# to an internal IP address. This is because Hugging Face uses DNS splitting,
# which means that requests from HF Spaces to HF Datasets or HF Models
# may resolve to internal IP addresses even if they are publicly accessible.
PUBLIC_URL_WHITELIST = ["hf.co", "huggingface.co"]

if wasm_utils.IS_WASM:
import pyodide.http

content = pyodide.http.open_url(url)
data = json.load(content)
else:
import urllib.request
def is_public_ip(ip: str | ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
try:
ip_obj = ipaddress.ip_address(ip)
return not (
ip_obj.is_private
or ip_obj.is_loopback
or ip_obj.is_link_local
or ip_obj.is_multicast
or ip_obj.is_reserved
or (isinstance(ip_obj, ipaddress.IPv6Address) and ip_obj.is_site_local)
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
)
except ValueError:
return False


async def resolve_hostname_locally(hostname: str) -> list[str]:
try:
loop = asyncio.get_running_loop()
addrinfo = await loop.run_in_executor(
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
None, socket.getaddrinfo, hostname, None, socket.AF_UNSPEC
)
return [ip[4][0] for ip in addrinfo]
except socket.gaierror:
return []

with urllib.request.urlopen(url) as response:
data = json.loads(response.read().decode())

if data.get("Status") == 0 and "Answer" in data:
for answer in data["Answer"]:
if answer["type"] == 1:
return answer["data"]
T = TypeVar("T")


# Always return these URLs as is, without checking to see if they resolve
# to an internal IP address. This is because Hugging Face uses DNS splitting,
# which means that requests from HF Spaces to HF Datasets or HF Models
# may resolve to internal IP addresses even if they are publicly accessible.
PUBLIC_URL_WHITELIST = ["hf.co", "huggingface.co"]
def lru_cache_async(maxsize: int = 128):
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
def decorator(
async_func: Callable[..., Coroutine[Any, Any, T]],
) -> Callable[..., Awaitable[T]]:
@lru_cache(maxsize=maxsize)
@wraps(async_func)
def wrapper(*args: Any, **kwargs: Any) -> Awaitable[T]:
return asyncio.create_task(async_func(*args, **kwargs))

return wrapper

return decorator


def get_public_url(url: str) -> str:
@lru_cache_async(maxsize=256)
async def resolve_hostname_google(hostname: str) -> list[str]:
async with httpx.AsyncClient() as client:
try:
response_v4 = await client.get(
f"https://dns.google/resolve?name={hostname}&type=A"
)
response_v6 = await client.get(
f"https://dns.google/resolve?name={hostname}&type=AAAA"
)

ips = []
for response in [response_v4.json(), response_v6.json()]:
ips.extend([answer["data"] for answer in response.get("Answer", [])])
return ips
except Exception:
return []


async def validate_url(url: str):
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise httpx.RequestError(f"Invalid scheme for URL: {url}")
hostname = parsed_url.hostname
if not hostname:
raise httpx.RequestError(f"Invalid URL: {url}, missing hostname")
if hostname.lower() in PUBLIC_URL_WHITELIST:
return url

if hostname is None:
raise ValueError("Invalid URL")

if hostname in PUBLIC_URL_WHITELIST:
return

try:
addrinfo = socket.getaddrinfo(hostname, None)
except socket.gaierror as e:
raise httpx.RequestError(
f"Cannot resolve URL with hostname: {hostname}, please download this file and use the path instead."
) from e

for family, _, _, _, sockaddr in addrinfo:
ip = sockaddr[0]
if family == socket.AF_INET6:
ip = ip.split("%")[0] # Remove scope ID if present

if ipaddress.ip_address(ip).is_global:
return url

google_resolved_ip = resolve_with_google_dns(hostname)
if google_resolved_ip and ipaddress.ip_address(google_resolved_ip).is_global:
if parsed_url.scheme == "https":
return url
new_parsed = parsed_url._replace(netloc=google_resolved_ip)
if parsed_url.port:
new_parsed = new_parsed._replace(
netloc=f"{google_resolved_ip}:{parsed_url.port}"
)
return urlunparse(new_parsed)
ip = ipaddress.ip_address(hostname)
if is_public_ip(ip):
return
else:
raise ValueError(f"URL resolves to private IP: {ip}")
except ValueError:
# It's a hostname, not an IP
pass

raise httpx.RequestError(
f"No public IP address found for URL: {url}, please download this file and use the path instead."
)
local_ips = await resolve_hostname_locally(hostname)
for ip in local_ips:
if is_public_ip(ip):
return

google_ips = await resolve_hostname_google(hostname)
for ip in google_ips:
if is_public_ip(ip):
return

def save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
url = get_public_url(url)
raise ValueError(f"Unable to resolve {hostname} to a public IP address")

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(abspath(temp_dir / name))

if not Path(full_temp_file_path).exists():
with (
sync_client.stream("GET", url, follow_redirects=True) as response,
open(full_temp_file_path, "wb") as f,
):
for redirect in response.history:
get_public_url(str(redirect.url))
async def ssrf_protected_httpx_download(url: str, cache_dir: str) -> str:
max_redirects = 10
redirect_count = 0
current_url = url

for chunk in response.iter_raw():
f.write(chunk)
async with httpx.AsyncClient() as client:
while redirect_count < max_redirects:
try:
await validate_url(current_url)

return full_temp_file_path
async with client.stream(
"GET", current_url, follow_redirects=False
) as response:
if response.is_redirect:
redirect_count += 1
current_url = response.headers["Location"]
continue

content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
filename = Path(
content_disposition.split("filename=")[1].strip('"')
).name

async def async_save_url_to_cache(url: str, cache_dir: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file. Uses async httpx."""
url = get_public_url(url)
else:
filename = client_utils.strip_invalid_filename_characters(
Path(url).name
)

temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)
name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(abspath(temp_dir / name))
filepath = os.path.join(cache_dir, filename)

if not Path(full_temp_file_path).exists():
async with async_client.stream("GET", url, follow_redirects=True) as response:
for redirect in response.history:
get_public_url(str(redirect.url))
async with asyncio.Lock():
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
with open(filepath, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)

async with aiofiles.open(full_temp_file_path, "wb") as f:
async for chunk in response.aiter_raw():
await f.write(chunk)
return filepath

return full_temp_file_path
except httpx.HTTPStatusError as e:
raise ValueError(
f"HTTP error occurred requesting resource {url}: {e}"
) from e
except Exception as e:
raise ValueError(
f"An error occurred requesting resource {url}: {e}"
) from e

raise ValueError("Too many redirects")


def sync_ssrf_protected_httpx_download(url: str, cache_dir: str) -> str:
return asyncio.run(ssrf_protected_httpx_download(url, cache_dir))
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved


def save_base64_to_cache(
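The main design choice in `ssrf_protected_httpx_download` is to disable automatic redirect following and re-run `validate_url` on every hop, so a public URL cannot bounce the request to an internal address via a redirect. A standalone sketch of that per-hop pattern, using only httpx and the standard library; this is a simplified illustration with made-up names, not the PR's implementation, and it assumes absolute `Location` headers and only checks locally resolved addresses:

```python
import asyncio
import ipaddress
import socket
from urllib.parse import urlparse

import httpx


def host_is_public(hostname: str) -> bool:
    # Require at least one globally routable address from local DNS resolution.
    # Simplified stand-in for the validate_url / is_public_ip pair above.
    try:
        infos = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        return False
    return any(
        ipaddress.ip_address(info[4][0].split("%")[0]).is_global for info in infos
    )


async def fetch_with_per_hop_checks(url: str, max_redirects: int = 10) -> bytes:
    # follow_redirects=False lets us validate every hop before the request is
    # sent; with follow_redirects=True a redirect to an internal host such as
    # http://169.254.169.254/ would be fetched without ever being checked.
    async with httpx.AsyncClient() as client:
        for _ in range(max_redirects):
            hostname = urlparse(url).hostname
            if not hostname or not host_is_public(hostname):
                raise ValueError(f"Refusing to fetch non-public host: {hostname!r}")
            response = await client.get(url, follow_redirects=False)
            if response.is_redirect:
                url = response.headers["Location"]  # assumed absolute here
                continue
            return response.content
    raise ValueError("Too many redirects")


if __name__ == "__main__":
    body = asyncio.run(fetch_with_per_hop_checks("https://hf.co"))
    print(len(body))
```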