def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
  """Extracts the download link from a Google Drive confirmation page.

  The first `<form>` of the page supplies the download target (its `action`
  attribute) and the query parameters (the values of its `uuid`, `export`,
  `id` and `confirm` `<input>` elements, for those that are present).

  Args:
    original_url: The URL the confirmation page was originally retrieved from.
      It is used to resolve a relative form action and in error messages.
    contents: The confirmation page's HTML.

  Returns:
    The URL for downloading the file.

  Raises:
    ValueError: If the page contains no `<form>`, or if the form has no
      non-empty `action` attribute.
  """
  bs4 = lazy_imports_lib.lazy_imports.bs4
  soup = bs4.BeautifulSoup(contents, 'html.parser')
  form = soup.find('form')
  # A missing form and a missing/empty action are reported identically.
  action = form.get('action', '') if form else ''
  if not action:
    raise ValueError(
        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
    )

  # Collect the <input>s named 'uuid', 'export', 'id' and 'confirm'; inputs
  # that are absent from the form are simply left out of the query.
  params = {}
  for name in ('uuid', 'export', 'id', 'confirm'):
    input_tag = form.find('input', {'name': name})
    if input_tag:
      params[name] = input_tag.get('value', '')

  query_string = urllib.parse.urlencode(params)
  if query_string:
    # Extend an existing query string instead of appending a second '?',
    # which would produce a malformed URL.
    separator = '&' if '?' in action else '?'
    download_url = f'{action}{separator}{query_string}'
  else:
    download_url = action
  # The form action may be relative; resolve it against the page URL.
  return urllib.parse.urljoin(original_url, download_url)
@@ -318,11 +356,26 @@ def _open_with_requests( session.mount( 'https://', requests.adapters.HTTPAdapter(max_retries=retries) ) - if _DRIVE_URL.match(url): - url = _normalize_drive_url(url) with session.get(url, stream=True, **kwargs) as response: - _assert_status(response) - yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)) + if ( + _DRIVE_URL.match(url) + and 'Content-Disposition' not in response.headers + ): + download_url = _process_gdrive_confirmation(url, response.text) + with session.get( + download_url, stream=True, **kwargs + ) as download_response: + _assert_status(download_response) + yield ( + download_response, + download_response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE), + ) + else: + _assert_status(response) + yield ( + response, + response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE), + ) @contextlib.contextmanager @@ -338,13 +391,6 @@ def _open_with_urllib( ) -def _normalize_drive_url(url: str) -> str: - """Returns Google Drive url with confirmation token.""" - # This bypasses the "Google Drive can't scan this file for viruses" warning - # when dowloading large files. 
- return url + '&confirm=t' - - def _assert_status(response: requests.Response) -> None: """Ensure the URL response is 200.""" if response.status_code != 200: diff --git a/tensorflow_datasets/core/download/downloader_test.py b/tensorflow_datasets/core/download/downloader_test.py index f1a503e800b..27bacad2b69 100644 --- a/tensorflow_datasets/core/download/downloader_test.py +++ b/tensorflow_datasets/core/download/downloader_test.py @@ -18,6 +18,7 @@ from typing import Optional from unittest import mock +import bs4 from etils import epath import pytest from tensorflow_datasets import testing @@ -36,6 +37,7 @@ def __init__(self, url, content, cookies=None, headers=None, status_code=200): self.status_code = status_code # For urllib codepath self.read = self.raw.read + self.text = '' def __enter__(self): return self @@ -78,6 +80,14 @@ def setUp(self): lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies), ).start() + bs_mock = mock.MagicMock(spec=bs4.BeautifulSoup) + form_mock = mock.MagicMock() + form_mock.get.return_value = 'x' + bs_mock.find.return_value = form_mock + mock.patch.object( + bs4, 'BeautifulSoup', autospec=True, return_value=bs_mock + ).start() + def test_ok(self): promise = self.downloader.download(self.url, self.tmp_dir) future = promise.get()