Skip to content

Commit

Permalink
model downloading and UI improved
Browse files Browse the repository at this point in the history
  • Loading branch information
bgorlick committed Jun 18, 2024
1 parent 1ad0703 commit 9c64e9b
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 14 deletions.
6 changes: 3 additions & 3 deletions getai/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from getai.cli.utils import CLIUtils

# Configure logging
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

DEFAULT_OUTPUT_DIR = Path.home() / ".getai" / "models"
Expand Down Expand Up @@ -141,7 +141,7 @@ async def main():
)
args = parser.parse_args()

logger.info("Parsed arguments: %s", args)
logger.debug("Parsed arguments: %s", args)

if args.hf_login:
logger.info("Logging in to Hugging Face CLI")
Expand All @@ -159,7 +159,7 @@ async def main():
set_defaults(args)

# Log final values of the arguments
logger.info(
logger.debug(
"Final arguments: mode=%s, search_mode=%s, download_mode=%s, identifier=%s, branch=%s, output_dir=%s, max_connections=%s, hf_token=%s",
args.mode,
getattr(args, "search_mode", None),
Expand Down
2 changes: 1 addition & 1 deletion getai/core/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


BASE_URL = "https://huggingface.co"
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.ERROR)

file_size_pattern = re.compile(
r'<a class="[^"]*" title="Download file"[^>]*>([\d.]+ [GMK]B)'
Expand Down
21 changes: 15 additions & 6 deletions getai/core/model_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from getai import api

BASE_URL = "https://huggingface.co"
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.ERROR)

file_size_pattern = re.compile(
r'<a class="[^"]*" title="Download file"[^>]*>([\d.]+ [GMK]B)'
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
self.search_history: List[Tuple[List[Dict], int]] = []
self.prefetched_pages: Set[int] = set()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
self.logger.setLevel(logging.NOTSET)
self.filter_flag = False

self.token = hf_token
Expand Down Expand Up @@ -173,6 +173,8 @@ async def display_search_results(self):
self.filtered_models = self.main_search_models
self.filter_flag = False
await self.display_search_results()
elif user_input.lower() == "q":
break
elif user_input.isdigit() and 1 <= int(user_input) <= len(
self.get_current_models(current_page)
):
Expand Down Expand Up @@ -224,14 +226,21 @@ async def display_current_page(self, current_page: int, total_pages: int):
if self.model_branch_info.get(model_id, {}).get("has_branches", False):
branches_count = len(self.model_branch_info[model_id]["branches"])
branches_info = f" | Branches: {branches_count}"
# ascii grey
ascii_grey_bold = "\033[90m\033[1m"
ascii_bold_green = "\033[92m\033[1m"
ascii_bold_magenta = "\033[95m\033[1m"
abwhite = "\033[97m\033[1m"
qrst = "\033[0m"

print(
f"{i}. \033[94m{model_name}\033[0m by \033[94m{author}\033[0m "
f"(\033[96m{model_id}\033[0m) (\033[93mSize: {size_str}\033[0m{branches_info}) "
f"(\033[97mLast updated: {last_modified}\033[0m)"
f"{i}. \033[96m{model_name}\033[0m by \033[94m{author}\033[0m | "
f"(\033[93mSize: {size_str}\033[0m{branches_info}) "
f"(\033[97m{last_modified}\033[0m)\n"
f"{ascii_grey_bold}{'-' * 100}\033[0m"
)
print(
"Enter 'n' for next page, 'p' for previous page, 'f' to filter, 's' to sort, 'r' for previous results, 'none' to show all results, or the model # to download."
f"{ascii_bold_green}getai search commands{ascii_bold_magenta}> {abwhite} #{qrst} download model, {abwhite}'n'{qrst} (next), {abwhite}'p'{qrst} (prev), {abwhite}'f'{qrst} (filter), {abwhite}'s'{qrst} (sort), {abwhite}'r'{qrst} (results), {abwhite}'none'{qrst} (all), {abwhite} 'q'{qrst} to quit.\033[0m"
)

async def get_model_size_str(self, model_id: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion getai/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.ERROR
)


Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "getai"
version = "0.0.983"
version = "0.0.986"
description = "GetAI - An asynchronous AI search and download tool for AI models, datasets, and tools. Designed to streamline the process of downloading machine learning models, datasets, and more."
authors = ["Ben Gorlick <[email protected]>"]
license = "MIT - with attribution"
Expand All @@ -12,7 +12,7 @@ python = "^3.9"
aiohttp = "^3.9.3"
aiofiles = "^23.2.1"
prompt-toolkit = "^3.0.43"
rainbow-tqdm = "^0.1.3"
rainbow-tqdm = "^0.1.5"
types-aiofiles = "^0.1.0"
tenacity = "^8.0.1"

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="getai",
version="0.0.983",
version="0.0.986",
author="Ben Gorlick",
author_email="[email protected]",
description="GetAI - Asynchronous AI Downloader for models, datasets and tools",
Expand Down

0 comments on commit 9c64e9b

Please sign in to comment.