18 changes: 18 additions & 0 deletions docs/cache.md
@@ -0,0 +1,18 @@
# Cache Management

`struct` caches remote content under `~/.struct/cache` by default. Use the `--cache-policy` flag, available on every command, to control how cached data is used when fetching remote files (an example follows the list):

- `always` (default): use cached content when available; otherwise fetch it and add it to the cache.
- `never`: bypass the cache entirely and do not store fetched content.
- `refresh`: always refetch remote content and update the cache.
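
For example, to force a refetch of remote content during generation (a usage sketch; the structure name below is hypothetical, but `--cache-policy` is a common flag accepted by all commands):

```bash
# Ignore any cached copies, refetch remote content, and update the cache
struct generate --cache-policy refresh my-structure
```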

## Inspecting and Clearing Cache

The `cache` command lets you inspect or clear the cache:

```bash
struct cache inspect
struct cache clear
```

`inspect` lists each cached file with its size in bytes; `clear` empties the cache directory. Use `--cache-dir` to operate on a different cache directory.
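
For instance (the directory path here is illustrative):

```bash
struct cache inspect --cache-dir ./.struct-cache
```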
3 changes: 2 additions & 1 deletion struct_module/commands/__init__.py
@@ -1,5 +1,5 @@
import logging
from struct_module.completers import log_level_completer
from struct_module.completers import log_level_completer, cache_policy_completer

# Base command class
class Command:
@@ -12,6 +12,7 @@ def add_common_arguments(self):
self.parser.add_argument('-l', '--log', type=str, default='INFO', help='Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)').completer = log_level_completer
self.parser.add_argument('-c', '--config-file', type=str, help='Path to a configuration file')
self.parser.add_argument('-i', '--log-file', type=str, help='Path to a log file')
self.parser.add_argument('--cache-policy', type=str, choices=['always', 'never', 'refresh'], default='always', help='Cache policy for remote content fetching').completer = cache_policy_completer

def execute(self, args):
raise NotImplementedError("Subclasses should implement this!")
26 changes: 26 additions & 0 deletions struct_module/commands/cache.py
@@ -0,0 +1,26 @@
from struct_module.commands import Command
import os
from pathlib import Path
import shutil

class CacheCommand(Command):
def __init__(self, parser):
super().__init__(parser)
parser.add_argument('action', choices=['inspect', 'clear'], help='Inspect cache contents or clear cache')
parser.add_argument('--cache-dir', type=str, default=os.path.expanduser('~/.struct/cache'), help='Path to cache directory')
parser.set_defaults(func=self.execute)

def execute(self, args):
cache_dir = Path(args.cache_dir)
if args.action == 'inspect':
if not cache_dir.exists() or not any(cache_dir.iterdir()):
print('Cache is empty.')
return
for path in cache_dir.iterdir():
if path.is_file():
print(f"{path}: {path.stat().st_size} bytes")
elif args.action == 'clear':
if cache_dir.exists():
shutil.rmtree(cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
print('Cache cleared.')
2 changes: 2 additions & 0 deletions struct_module/commands/generate.py
@@ -160,6 +160,7 @@ def _create_structure(self, args, mappings=None):
content["config_variables"] = config_variables
content["input_store"] = args.input_store
content["non_interactive"] = args.non_interactive
content["cache_policy"] = args.cache_policy
content["mappings"] = mappings or {}
file_item = FileItem(content)
file_item.fetch_content()
@@ -172,6 +173,7 @@ def _create_structure(self, args, mappings=None):
"input_store": args.input_store,
"non_interactive": args.non_interactive,
"mappings": mappings or {},
"cache_policy": args.cache_policy,
}
)

1 change: 1 addition & 0 deletions struct_module/completers.py
@@ -51,4 +51,5 @@ def _get_available_structures(self, parsed_args):

log_level_completer = ChoicesCompleter(['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
file_strategy_completer = ChoicesCompleter(['overwrite', 'skip', 'append', 'rename', 'backup'])
cache_policy_completer = ChoicesCompleter(['always', 'never', 'refresh'])
structures_completer = StructuresCompleter()
75 changes: 67 additions & 8 deletions struct_module/content_fetcher.py
@@ -6,6 +6,7 @@
from pathlib import Path
import hashlib
import logging
import tempfile

try:
import boto3
@@ -22,10 +23,11 @@
gcs_available = False

class ContentFetcher:
def __init__(self, cache_dir=None):
def __init__(self, cache_dir=None, cache_policy="always"):
self.logger = logging.getLogger(__name__)
self.cache_dir = Path(cache_dir or os.path.expanduser("~/.struct/cache"))
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.cache_policy = cache_policy

def fetch_content(self, content_location):
"""
@@ -72,15 +74,20 @@ def _fetch_http_url(self, url):
cache_key = hashlib.md5(url.encode()).hexdigest()
cache_file_path = self.cache_dir / cache_key

if cache_file_path.exists():
if self.cache_policy == "always" and cache_file_path.exists():
self.logger.debug(f"Loading content from cache: {cache_file_path}")
with cache_file_path.open('r') as file:
return file.read()

if self.cache_policy == "refresh" and cache_file_path.exists():
cache_file_path.unlink()

response = requests.get(url)
response.raise_for_status()
with cache_file_path.open('w') as file:
file.write(response.text)

if self.cache_policy in ["always", "refresh"]:
with cache_file_path.open('w') as file:
file.write(response.text)

return response.text

@@ -127,15 +134,22 @@ def _clone_or_fetch_github(self, owner, repo, branch, file_path, https=True):
repo_cache_path = self.cache_dir / f"{owner}_{repo}_{branch}"
clone_url = f"https://github.com/{owner}/{repo}.git" if https else f"git@github.com:{owner}/{repo}.git"

# Clone or fetch the repository
if self.cache_policy == "never":
with tempfile.TemporaryDirectory() as tmpdir:
subprocess.run(["git", "clone", "-b", branch, clone_url, tmpdir], check=True)
file_full_path = Path(tmpdir) / file_path
if not file_full_path.exists():
raise FileNotFoundError(f"File {file_path} not found in repository {owner}/{repo} on branch {branch}")
with file_full_path.open('r') as file:
return file.read()

if not repo_cache_path.exists():
self.logger.debug(f"Cloning repository: {owner}/{repo} (branch: {branch})")
subprocess.run(["git", "clone", "-b", branch, clone_url, str(repo_cache_path)], check=True)
else:
self.logger.debug(f"Repository already cloned. Pulling latest changes for: {repo_cache_path}")
elif self.cache_policy == "refresh":
self.logger.debug(f"Refreshing repository cache: {repo_cache_path}")
subprocess.run(["git", "-C", str(repo_cache_path), "pull"], check=True)

# Read the requested file
file_full_path = repo_cache_path / file_path
if not file_full_path.exists():
raise FileNotFoundError(f"File {file_path} not found in repository {owner}/{repo} on branch {branch}")
@@ -159,6 +173,31 @@ def _fetch_s3_file(self, s3_url):
bucket_name, key = match.groups()
local_file_path = self.cache_dir / Path(key).name

if self.cache_policy == "always" and local_file_path.exists():
with local_file_path.open('r') as file:
return file.read()

if self.cache_policy == "never":
with tempfile.TemporaryDirectory() as tmpdir:
temp_path = Path(tmpdir) / Path(key).name
session = boto3.Session()
s3_client = session.client("s3")
try:
s3_client.download_file(bucket_name, key, str(temp_path))
except NoCredentialsError:
raise RuntimeError("AWS credentials not found. Ensure that your credentials are configured properly.")
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "404":
raise FileNotFoundError(f"The specified S3 key does not exist: {key}")
else:
raise RuntimeError(f"Failed to download S3 file: {e}")
with temp_path.open('r') as file:
return file.read()

if self.cache_policy == "refresh" and local_file_path.exists():
local_file_path.unlink()

try:
session = boto3.Session() # Create a new session
s3_client = session.client("s3")
@@ -192,6 +231,26 @@ def _fetch_gcs_file(self, gcs_url):
bucket_name, key = match.groups()
local_file_path = self.cache_dir / Path(key).name

if self.cache_policy == "always" and local_file_path.exists():
with local_file_path.open('r') as file:
return file.read()

if self.cache_policy == "never":
with tempfile.TemporaryDirectory() as tmpdir:
temp_path = Path(tmpdir) / Path(key).name
try:
gcs_client = storage.Client()
bucket = gcs_client.bucket(bucket_name)
blob = bucket.blob(key)
blob.download_to_filename(str(temp_path))
except GoogleAPIError as e:
raise RuntimeError(f"Failed to download GCS file: {e}")
with temp_path.open('r') as file:
return file.read()

if self.cache_policy == "refresh" and local_file_path.exists():
local_file_path.unlink()

try:
gcs_client = storage.Client()
bucket = gcs_client.bucket(bucket_name)
3 changes: 2 additions & 1 deletion struct_module/file_item.py
@@ -25,7 +25,8 @@ def __init__(self, properties):
self.skip = properties.get("skip", False)
self.skip_if_exists = properties.get("skip_if_exists", False)

self.content_fetcher = ContentFetcher()
self.cache_policy = properties.get("cache_policy", "always")
self.content_fetcher = ContentFetcher(cache_policy=self.cache_policy)

self.system_prompt = properties.get("system_prompt") or properties.get("global_system_prompt")
self.user_prompt = properties.get("user_prompt")
2 changes: 2 additions & 0 deletions struct_module/main.py
@@ -8,6 +8,7 @@
from struct_module.commands.list import ListCommand
from struct_module.commands.generate_schema import GenerateSchemaCommand
from struct_module.commands.mcp import MCPCommand
from struct_module.commands.cache import CacheCommand
from struct_module.logging_config import configure_logging


@@ -30,6 +31,7 @@ def main():
ListCommand(subparsers.add_parser('list', help='List available structures'))
GenerateSchemaCommand(subparsers.add_parser('generate-schema', help='Generate JSON schema for available structures'))
MCPCommand(subparsers.add_parser('mcp', help='MCP (Model Context Protocol) support'))
CacheCommand(subparsers.add_parser('cache', help='Inspect or clear cache'))

argcomplete.autocomplete(parser)

62 changes: 62 additions & 0 deletions tests/test_cache_policy.py
@@ -0,0 +1,62 @@
import hashlib
from unittest.mock import patch, MagicMock
import tempfile
from pathlib import Path
from struct_module.content_fetcher import ContentFetcher
import argparse
from struct_module.commands.cache import CacheCommand


def _mock_response(text):
mock = MagicMock()
mock.text = text
mock.raise_for_status = MagicMock()
return mock


def test_cache_policy_always():
with tempfile.TemporaryDirectory() as tmpdir:
fetcher = ContentFetcher(cache_dir=tmpdir, cache_policy="always")
url = "https://example.com/data"
with patch("requests.get", return_value=_mock_response("first")) as mock_get:
assert fetcher.fetch_content(url) == "first"
assert mock_get.call_count == 1
with patch("requests.get", return_value=_mock_response("second")) as mock_get:
assert fetcher.fetch_content(url) == "first"
mock_get.assert_not_called()


def test_cache_policy_never(tmp_path):
fetcher = ContentFetcher(cache_dir=tmp_path, cache_policy="never")
url = "https://example.com/data"
with patch("requests.get", side_effect=[_mock_response("a"), _mock_response("b")]) as mock_get:
assert fetcher.fetch_content(url) == "a"
assert fetcher.fetch_content(url) == "b"
assert mock_get.call_count == 2
cache_key = hashlib.md5(url.encode()).hexdigest()
assert not (tmp_path / cache_key).exists()


def test_cache_policy_refresh(tmp_path):
fetcher = ContentFetcher(cache_dir=tmp_path, cache_policy="refresh")
url = "https://example.com/data"
with patch("requests.get", side_effect=[_mock_response("old"), _mock_response("new")]) as mock_get:
assert fetcher.fetch_content(url) == "old"
assert fetcher.fetch_content(url) == "new"
assert mock_get.call_count == 2
cache_key = hashlib.md5(url.encode()).hexdigest()
with open(tmp_path / cache_key, "r") as f:
assert f.read() == "new"


def test_cache_command_clear(tmp_path, capsys):
cache_dir = tmp_path / "cache"
cache_dir.mkdir()
(cache_dir / "file.txt").write_text("data")
parser = argparse.ArgumentParser()
cmd = CacheCommand(parser)
args = parser.parse_args(["clear", "--cache-dir", str(cache_dir)])
cmd.execute(args)
assert not any(cache_dir.iterdir())
captured = capsys.readouterr()
assert "Cache cleared." in captured.out