diff --git a/docs/cache.md b/docs/cache.md new file mode 100644 index 0000000..a02bfe2 --- /dev/null +++ b/docs/cache.md @@ -0,0 +1,18 @@ +# Cache Management + +`struct` caches remote content under `~/.struct/cache` by default. Use the `--cache-policy` flag to control how cached data is used when fetching remote files: + +- `always` (default): use cached content when available. +- `never`: bypass the cache and do not store fetched content. +- `refresh`: always refetch remote content and update the cache. + +## Inspecting and Clearing Cache + +The `cache` command lets you inspect or clear the cache: + +```bash +struct cache inspect +struct cache clear +``` + +Use `--cache-dir` to operate on a different cache directory. diff --git a/struct_module/commands/__init__.py b/struct_module/commands/__init__.py index eb74d17..1d4e484 100644 --- a/struct_module/commands/__init__.py +++ b/struct_module/commands/__init__.py @@ -1,5 +1,5 @@ import logging -from struct_module.completers import log_level_completer +from struct_module.completers import log_level_completer, cache_policy_completer # Base command class class Command: @@ -12,6 +12,7 @@ def add_common_arguments(self): self.parser.add_argument('-l', '--log', type=str, default='INFO', help='Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)').completer = log_level_completer self.parser.add_argument('-c', '--config-file', type=str, help='Path to a configuration file') self.parser.add_argument('-i', '--log-file', type=str, help='Path to a log file') + self.parser.add_argument('--cache-policy', type=str, choices=['always', 'never', 'refresh'], default='always', help='Cache policy for remote content fetching').completer = cache_policy_completer def execute(self, args): raise NotImplementedError("Subclasses should implement this!") diff --git a/struct_module/commands/cache.py b/struct_module/commands/cache.py new file mode 100644 index 0000000..442ccc5 --- /dev/null +++ b/struct_module/commands/cache.py @@ -0,0 
class CacheCommand(Command):
    """``cache`` subcommand: inspect or clear the on-disk content cache.

    Supported actions:

    * ``inspect`` -- print one line per cache entry with its size in bytes.
    * ``clear``   -- delete every cache entry, leaving an empty directory.
    """

    def __init__(self, parser):
        """Register the ``action`` positional and ``--cache-dir`` option on *parser*."""
        super().__init__(parser)
        parser.add_argument('action', choices=['inspect', 'clear'], help='Inspect cache contents or clear cache')
        parser.add_argument('--cache-dir', type=str, default=os.path.expanduser('~/.struct/cache'), help='Path to cache directory')
        parser.set_defaults(func=self.execute)

    def execute(self, args):
        """Run the action selected on the command line.

        Args:
            args: Parsed arguments; reads ``args.action`` and ``args.cache_dir``.

        Raises:
            NotADirectoryError: If ``args.cache_dir`` exists but is not a
                directory.
        """
        cache_dir = Path(args.cache_dir)
        # Fail early with a readable message instead of letting iterdir() or
        # rmtree() raise the same exception type from deep inside the action.
        if cache_dir.exists() and not cache_dir.is_dir():
            raise NotADirectoryError(f"Cache path is not a directory: {cache_dir}")

        if args.action == 'inspect':
            self._inspect(cache_dir)
        elif args.action == 'clear':
            self._clear(cache_dir)

    def _inspect(self, cache_dir):
        """Print every top-level cache entry and its size, in sorted order."""
        # Listing once (and sorting) avoids a second directory scan and makes
        # the output deterministic.
        entries = sorted(cache_dir.iterdir()) if cache_dir.exists() else []
        if not entries:
            print('Cache is empty.')
            return
        for path in entries:
            if path.is_file():
                print(f"{path}: {path.stat().st_size} bytes")
            elif path.is_dir():
                # Directory entries (e.g. cached repository clones) were
                # previously skipped, so a cache holding only directories
                # printed nothing at all.
                print(f"{path}: {self._tree_size(path)} bytes (directory)")

    @staticmethod
    def _tree_size(directory):
        """Return the total size in bytes of the regular files below *directory*."""
        return sum(item.stat().st_size for item in directory.rglob('*') if item.is_file())

    def _clear(self, cache_dir):
        """Remove every entry inside *cache_dir*, creating the directory if absent."""
        if cache_dir.exists():
            # Empty the directory in place rather than rmtree()-ing the root:
            # rmtree refuses a root that is a symlink to a directory, and
            # deleting/recreating the root would drop its permissions.
            for entry in cache_dir.iterdir():
                if entry.is_dir() and not entry.is_symlink():
                    shutil.rmtree(entry)
                else:
                    entry.unlink()
        cache_dir.mkdir(parents=True, exist_ok=True)
        print('Cache cleared.')
'CRITICAL']) file_strategy_completer = ChoicesCompleter(['overwrite', 'skip', 'append', 'rename', 'backup']) +cache_policy_completer = ChoicesCompleter(['always', 'never', 'refresh']) structures_completer = StructuresCompleter() diff --git a/struct_module/content_fetcher.py b/struct_module/content_fetcher.py index 8c7686b..40e7335 100644 --- a/struct_module/content_fetcher.py +++ b/struct_module/content_fetcher.py @@ -6,6 +6,7 @@ from pathlib import Path import hashlib import logging +import tempfile try: import boto3 @@ -22,10 +23,11 @@ gcs_available = False class ContentFetcher: - def __init__(self, cache_dir=None): + def __init__(self, cache_dir=None, cache_policy="always"): self.logger = logging.getLogger(__name__) self.cache_dir = Path(cache_dir or os.path.expanduser("~/.struct/cache")) self.cache_dir.mkdir(parents=True, exist_ok=True) + self.cache_policy = cache_policy def fetch_content(self, content_location): """ @@ -72,15 +74,20 @@ def _fetch_http_url(self, url): cache_key = hashlib.md5(url.encode()).hexdigest() cache_file_path = self.cache_dir / cache_key - if cache_file_path.exists(): + if self.cache_policy == "always" and cache_file_path.exists(): self.logger.debug(f"Loading content from cache: {cache_file_path}") with cache_file_path.open('r') as file: return file.read() + if self.cache_policy == "refresh" and cache_file_path.exists(): + cache_file_path.unlink() + response = requests.get(url) response.raise_for_status() - with cache_file_path.open('w') as file: - file.write(response.text) + + if self.cache_policy in ["always", "refresh"]: + with cache_file_path.open('w') as file: + file.write(response.text) return response.text @@ -127,15 +134,22 @@ def _clone_or_fetch_github(self, owner, repo, branch, file_path, https=True): repo_cache_path = self.cache_dir / f"{owner}_{repo}_{branch}" clone_url = f"https://github.com/{owner}/{repo}.git" if https else f"git@github.com:{owner}/{repo}.git" - # Clone or fetch the repository + if self.cache_policy == 
"never": + with tempfile.TemporaryDirectory() as tmpdir: + subprocess.run(["git", "clone", "-b", branch, clone_url, tmpdir], check=True) + file_full_path = Path(tmpdir) / file_path + if not file_full_path.exists(): + raise FileNotFoundError(f"File {file_path} not found in repository {owner}/{repo} on branch {branch}") + with file_full_path.open('r') as file: + return file.read() + if not repo_cache_path.exists(): self.logger.debug(f"Cloning repository: {owner}/{repo} (branch: {branch})") subprocess.run(["git", "clone", "-b", branch, clone_url, str(repo_cache_path)], check=True) - else: - self.logger.debug(f"Repository already cloned. Pulling latest changes for: {repo_cache_path}") + elif self.cache_policy == "refresh": + self.logger.debug(f"Refreshing repository cache: {repo_cache_path}") subprocess.run(["git", "-C", str(repo_cache_path), "pull"], check=True) - # Read the requested file file_full_path = repo_cache_path / file_path if not file_full_path.exists(): raise FileNotFoundError(f"File {file_path} not found in repository {owner}/{repo} on branch {branch}") @@ -159,6 +173,31 @@ def _fetch_s3_file(self, s3_url): bucket_name, key = match.groups() local_file_path = self.cache_dir / Path(key).name + if self.cache_policy == "always" and local_file_path.exists(): + with local_file_path.open('r') as file: + return file.read() + + if self.cache_policy == "never": + with tempfile.TemporaryDirectory() as tmpdir: + temp_path = Path(tmpdir) / Path(key).name + session = boto3.Session() + s3_client = session.client("s3") + try: + s3_client.download_file(bucket_name, key, str(temp_path)) + except NoCredentialsError: + raise RuntimeError("AWS credentials not found. 
Ensure that your credentials are configured properly.") + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "404": + raise FileNotFoundError(f"The specified S3 key does not exist: {key}") + else: + raise RuntimeError(f"Failed to download S3 file: {e}") + with temp_path.open('r') as file: + return file.read() + + if self.cache_policy == "refresh" and local_file_path.exists(): + local_file_path.unlink() + try: session = boto3.Session() # Create a new session s3_client = session.client("s3") @@ -192,6 +231,26 @@ def _fetch_gcs_file(self, gcs_url): bucket_name, key = match.groups() local_file_path = self.cache_dir / Path(key).name + if self.cache_policy == "always" and local_file_path.exists(): + with local_file_path.open('r') as file: + return file.read() + + if self.cache_policy == "never": + with tempfile.TemporaryDirectory() as tmpdir: + temp_path = Path(tmpdir) / Path(key).name + try: + gcs_client = storage.Client() + bucket = gcs_client.bucket(bucket_name) + blob = bucket.blob(key) + blob.download_to_filename(str(temp_path)) + except GoogleAPIError as e: + raise RuntimeError(f"Failed to download GCS file: {e}") + with temp_path.open('r') as file: + return file.read() + + if self.cache_policy == "refresh" and local_file_path.exists(): + local_file_path.unlink() + try: gcs_client = storage.Client() bucket = gcs_client.bucket(bucket_name) diff --git a/struct_module/file_item.py b/struct_module/file_item.py index a664908..52d7369 100644 --- a/struct_module/file_item.py +++ b/struct_module/file_item.py @@ -25,7 +25,8 @@ def __init__(self, properties): self.skip = properties.get("skip", False) self.skip_if_exists = properties.get("skip_if_exists", False) - self.content_fetcher = ContentFetcher() + self.cache_policy = properties.get("cache_policy", "always") + self.content_fetcher = ContentFetcher(cache_policy=self.cache_policy) self.system_prompt = properties.get("system_prompt") or properties.get("global_system_prompt") 
def _mock_response(text):
    """Return a fake HTTP response whose ``.text`` attribute is *text*.

    ``MagicMock`` creates ``raise_for_status`` on first access as a no-op
    callable, so the fake always behaves like a successful response.
    """
    return MagicMock(text=text)


def test_cache_policy_always(tmp_path):
    """``always``: the second fetch is served from cache without any HTTP call."""
    fetcher = ContentFetcher(cache_dir=tmp_path, cache_policy="always")
    url = "https://example.com/data"
    with patch("requests.get", return_value=_mock_response("first")) as mock_get:
        assert fetcher.fetch_content(url) == "first"
        assert mock_get.call_count == 1
    with patch("requests.get", return_value=_mock_response("second")) as mock_get:
        # The cached "first" must win; the network must not be touched.
        assert fetcher.fetch_content(url) == "first"
        mock_get.assert_not_called()


def test_cache_policy_never(tmp_path):
    """``never``: every fetch hits the network and nothing is written to cache."""
    fetcher = ContentFetcher(cache_dir=tmp_path, cache_policy="never")
    url = "https://example.com/data"
    with patch("requests.get", side_effect=[_mock_response("a"), _mock_response("b")]) as mock_get:
        assert fetcher.fetch_content(url) == "a"
        assert fetcher.fetch_content(url) == "b"
        assert mock_get.call_count == 2
    cache_key = hashlib.md5(url.encode()).hexdigest()
    assert not (tmp_path / cache_key).exists()


def test_cache_policy_refresh(tmp_path):
    """``refresh``: every fetch hits the network and overwrites the cache entry."""
    fetcher = ContentFetcher(cache_dir=tmp_path, cache_policy="refresh")
    url = "https://example.com/data"
    with patch("requests.get", side_effect=[_mock_response("old"), _mock_response("new")]) as mock_get:
        assert fetcher.fetch_content(url) == "old"
        assert fetcher.fetch_content(url) == "new"
        assert mock_get.call_count == 2
    cache_key = hashlib.md5(url.encode()).hexdigest()
    with open(tmp_path / cache_key, "r") as f:
        assert f.read() == "new"


def test_cache_command_clear(tmp_path, capsys):
    """``cache clear`` leaves an existing but empty cache directory behind."""
    cache_dir = tmp_path / "cache"
    cache_dir.mkdir()
    (cache_dir / "file.txt").write_text("data")
    parser = argparse.ArgumentParser()
    cmd = CacheCommand(parser)
    args = parser.parse_args(["clear", "--cache-dir", str(cache_dir)])
    cmd.execute(args)
    assert not any(cache_dir.iterdir())
    captured = capsys.readouterr()
    assert "Cache cleared." in captured.out


def test_cache_command_inspect_lists_files(tmp_path, capsys):
    """``cache inspect`` reports each cached file together with its byte size."""
    cache_dir = tmp_path / "cache"
    cache_dir.mkdir()
    (cache_dir / "file.txt").write_text("data")
    parser = argparse.ArgumentParser()
    cmd = CacheCommand(parser)
    args = parser.parse_args(["inspect", "--cache-dir", str(cache_dir)])
    cmd.execute(args)
    captured = capsys.readouterr()
    assert "file.txt" in captured.out
    assert "4 bytes" in captured.out


def test_cache_command_inspect_empty(tmp_path, capsys):
    """``cache inspect`` on an empty cache directory says so explicitly."""
    cache_dir = tmp_path / "cache"
    cache_dir.mkdir()
    parser = argparse.ArgumentParser()
    cmd = CacheCommand(parser)
    args = parser.parse_args(["inspect", "--cache-dir", str(cache_dir)])
    cmd.execute(args)
    captured = capsys.readouterr()
    assert "Cache is empty." in captured.out