|
5 | 5 | import json |
6 | 6 | import os |
7 | 7 | import sys |
| 8 | +import time |
| 9 | +import zipfile |
8 | 10 | from typing import Any, Dict, List, Optional |
9 | 11 |
|
10 | 12 | import requests |
11 | 13 | from PIL import Image |
| 14 | +from requests.exceptions import HTTPError |
| 15 | +from tqdm import tqdm |
12 | 16 |
|
13 | 17 | from roboflow.adapters import rfapi |
14 | 18 | from roboflow.adapters.rfapi import AnnotationSaveError, ImageUploadError, RoboflowError |
@@ -662,6 +666,119 @@ def _upload_zip( |
662 | 666 | except Exception as e: |
663 | 667 | print(f"An error occured when uploading the model: {e}") |
664 | 668 |
|
| 669 | + def search_export( |
| 670 | + self, |
| 671 | + query: str, |
| 672 | + format: str = "coco", |
| 673 | + location: Optional[str] = None, |
| 674 | + dataset: Optional[str] = None, |
| 675 | + annotation_group: Optional[str] = None, |
| 676 | + name: Optional[str] = None, |
| 677 | + extract_zip: bool = True, |
| 678 | + ) -> str: |
| 679 | + """Export search results as a downloaded dataset. |
| 680 | +
|
| 681 | + Args: |
| 682 | + query: Search query string (e.g. ``"tag:annotate"`` or ``"*"``). |
| 683 | + format: Annotation format for the export (default ``"coco"``). |
| 684 | + location: Local directory to save the exported dataset. |
| 685 | + Defaults to ``./search-export-{format}``. |
| 686 | + dataset: Limit export to a specific dataset (project) slug. |
| 687 | + annotation_group: Limit export to a specific annotation group. |
| 688 | + name: Optional name for the export. |
| 689 | + extract_zip: If True (default), extract the zip and remove it. |
| 690 | + If False, keep the zip file as-is. |
| 691 | +
|
| 692 | + Returns: |
| 693 | + Absolute path to the extracted directory or the zip file. |
| 694 | +
|
| 695 | + Raises: |
| 696 | + ValueError: If both *dataset* and *annotation_group* are provided. |
| 697 | + RoboflowError: On API errors or export timeout. |
| 698 | + """ |
| 699 | + if dataset is not None and annotation_group is not None: |
| 700 | + raise ValueError("dataset and annotation_group are mutually exclusive; provide only one") |
| 701 | + |
| 702 | + if location is None: |
| 703 | + location = f"./search-export-{format}" |
| 704 | + location = os.path.abspath(location) |
| 705 | + |
| 706 | + # 1. Start the export |
| 707 | + export_id = rfapi.start_search_export( |
| 708 | + api_key=self.__api_key, |
| 709 | + workspace_url=self.url, |
| 710 | + query=query, |
| 711 | + format=format, |
| 712 | + dataset=dataset, |
| 713 | + annotation_group=annotation_group, |
| 714 | + name=name, |
| 715 | + ) |
| 716 | + print(f"Export started (id={export_id}). Polling for completion...") |
| 717 | + |
| 718 | + # 2. Poll until ready |
| 719 | + timeout = 600 |
| 720 | + poll_interval = 5 |
| 721 | + elapsed = 0 |
| 722 | + while elapsed < timeout: |
| 723 | + status = rfapi.get_search_export( |
| 724 | + api_key=self.__api_key, |
| 725 | + workspace_url=self.url, |
| 726 | + export_id=export_id, |
| 727 | + ) |
| 728 | + if status.get("ready"): |
| 729 | + break |
| 730 | + time.sleep(poll_interval) |
| 731 | + elapsed += poll_interval |
| 732 | + else: |
| 733 | + raise RoboflowError(f"Search export timed out after {timeout}s") |
| 734 | + |
| 735 | + download_url = status["link"] |
| 736 | + |
| 737 | + # 3. Download zip |
| 738 | + if not os.path.exists(location): |
| 739 | + os.makedirs(location) |
| 740 | + |
| 741 | + zip_path = os.path.join(location, "roboflow.zip") |
| 742 | + response = requests.get(download_url, stream=True) |
| 743 | + try: |
| 744 | + response.raise_for_status() |
| 745 | + except HTTPError as e: |
| 746 | + raise RoboflowError(f"Failed to download search export: {e}") |
| 747 | + |
| 748 | + total_length = response.headers.get("content-length") |
| 749 | + try: |
| 750 | + total_kib = int(total_length) // 1024 + 1 if total_length is not None else None |
| 751 | + except (TypeError, ValueError): |
| 752 | + total_kib = None |
| 753 | + with open(zip_path, "wb") as f: |
| 754 | + for chunk in tqdm( |
| 755 | + response.iter_content(chunk_size=1024), |
| 756 | + desc=f"Downloading search export to {location}", |
| 757 | + total=total_kib, |
| 758 | + ): |
| 759 | + if chunk: |
| 760 | + f.write(chunk) |
| 761 | + f.flush() |
| 762 | + |
| 763 | + if extract_zip: |
| 764 | + desc = f"Extracting search export to {location}" |
| 765 | + try: |
| 766 | + with zipfile.ZipFile(zip_path, "r") as zip_ref: |
| 767 | + for member in tqdm(zip_ref.infolist(), desc=desc): |
| 768 | + try: |
| 769 | + zip_ref.extract(member, location) |
| 770 | + except zipfile.error: |
| 771 | + raise RoboflowError("Error unzipping search export") |
| 772 | + except zipfile.BadZipFile: |
| 773 | + raise RoboflowError(f"Downloaded file is not a valid zip archive: {zip_path}") |
| 774 | + |
| 775 | + os.remove(zip_path) |
| 776 | + print(f"Search export extracted to {location}") |
| 777 | + return location |
| 778 | + else: |
| 779 | + print(f"Search export saved to {zip_path}") |
| 780 | + return zip_path |
| 781 | + |
665 | 782 | def __str__(self): |
666 | 783 | projects = self.projects() |
667 | 784 | json_value = {"name": self.name, "url": self.url, "projects": projects} |
|
0 commit comments