From ff89242229de9f23ca57e3e703e32429572d5c74 Mon Sep 17 00:00:00 2001 From: Alexander Hartl Date: Mon, 30 Dec 2024 16:25:47 +0100 Subject: [PATCH] Fix GDrive URLs --- .../core/download/downloader.py | 58 +++++++++++++++---- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/tensorflow_datasets/core/download/downloader.py b/tensorflow_datasets/core/download/downloader.py index ff5fc60e5ed..94554dd2fc3 100644 --- a/tensorflow_datasets/core/download/downloader.py +++ b/tensorflow_datasets/core/download/downloader.py @@ -33,6 +33,7 @@ from etils import epath from tensorflow_datasets.core import units from tensorflow_datasets.core import utils +from tensorflow_datasets.core import lazy_imports_lib from tensorflow_datasets.core.download import checksums as checksums_lib from tensorflow_datasets.core.download import resource as resource_lib from tensorflow_datasets.core.download import util as download_utils_lib @@ -130,6 +131,44 @@ def _get_filename(response: Response) -> str: return _basename_from_url(response.url) +def _process_gdrive_confirmation(original_url: str, contents: str) -> str: + """Process Google Drive confirmation page. + + Extracts the download link from a Google Drive confirmation page. + + Args: + original_url: The URL the confirmation page was originally + retrieved from. + contents: The confirmation page's HTML. + + Returns: + download_url: The URL for downloading the file. + """ + bs4 = lazy_imports_lib.lazy_imports.bs4 + soup = bs4.BeautifulSoup(contents, 'html.parser') + form = soup.find('form') + if not form: + raise ValueError( + f'Failed to obtain confirmation link for GDrive URL {original_url}.' + ) + action = form.get('action', '') + if not action: + raise ValueError( + f'Failed to obtain confirmation link for GDrive URL {original_url}.' 
+ ) + # Find the inputs named 'uuid', 'export', 'id' and 'confirm' + input_names = ['uuid', 'export', 'id', 'confirm'] + params = {} + for name in input_names: + input_tag = form.find('input', {'name': name}) + if input_tag: + params[name] = input_tag.get('value', '') + query_string = urllib.parse.urlencode(params) + download_url = f'{action}?{query_string}' if query_string else action + download_url = urllib.parse.urljoin(original_url, download_url) + return download_url + + class _Downloader: """Class providing async download API with checksum validation. @@ -318,11 +357,15 @@ def _open_with_requests( session.mount( 'https://', requests.adapters.HTTPAdapter(max_retries=retries) ) - if _DRIVE_URL.match(url): - url = _normalize_drive_url(url) with session.get(url, stream=True, **kwargs) as response: - _assert_status(response) - yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)) + if _DRIVE_URL.match(url) and 'Content-Disposition' not in response.headers: + download_url = _process_gdrive_confirmation(url, response.text) + with session.get(download_url, stream=True, **kwargs) as download_response: + _assert_status(download_response) + yield (download_response, download_response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)) + else: + _assert_status(response) + yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)) @@ -338,13 +381,6 @@ def _open_with_urllib( ) -def _normalize_drive_url(url: str) -> str: - """Returns Google Drive url with confirmation token.""" - # This bypasses the "Google Drive can't scan this file for viruses" warning - # when dowloading large files. - return url + '&confirm=t' - - def _assert_status(response: requests.Response) -> None: """Ensure the URL response is 200.""" if response.status_code != 200: