
Commit 00ca5d7

Merge pull request #1 from dahlem/feature/in-memory-index-backend
Replace SQLite with in-memory pickle-based index
2 parents 309f6ca + f8faba7 commit 00ca5d7

13 files changed

Lines changed: 1299 additions & 256 deletions

ARCHITECTURE.md

Lines changed: 142 additions & 106 deletions

README.md

Lines changed: 19 additions & 16 deletions
@@ -6,11 +6,11 @@
 [![CI](https://github.com/dahlem/torchcachex/actions/workflows/ci.yml/badge.svg)](https://github.com/dahlem/torchcachex/actions)
 [![codecov](https://codecov.io/gh/dahlem/torchcachex/branch/main/graph/badge.svg)](https://codecov.io/gh/dahlem/torchcachex)

-**Drop-in PyTorch module caching with Arrow IPC + SQLite backend**
+**Drop-in PyTorch module caching with Arrow IPC + in-memory index backend**

 `torchcachex` provides transparent, per-sample caching for non-trainable PyTorch modules with:
 - **O(1) append-only writes** via incremental Arrow IPC segments
-- **O(1) batched lookups** via SQLite index + Arrow memory-mapping
+- **O(1) batched lookups** via in-memory index + Arrow memory-mapping
 - **Native tensor storage** with automatic dtype preservation
 - **LRU hot cache** for in-process hits
 - **Async writes** (non-blocking forward pass)
@@ -422,7 +422,7 @@ Wraps a PyTorch module to add transparent per-sample caching.

 ### `ArrowIPCCacheBackend`

-Persistent cache using Arrow IPC segments with SQLite index for O(1) operations.
+Persistent cache using Arrow IPC segments with in-memory index for O(1) operations.

 **Storage Format:**
 ```
@@ -431,7 +431,7 @@ cache_dir/module_id/
   segment_000000.arrow   # Incremental Arrow IPC files
   segment_000001.arrow
   ...
-  index.db               # SQLite with WAL mode
+  index.pkl              # Pickled dict: key → (segment_id, row_offset)
   schema.json            # Auto-inferred Arrow schema
 ```

@@ -446,22 +446,23 @@ cache_dir/module_id/
 - `current_rank` (Optional[int]): Current process rank (default: None)

 **Methods:**
-- `get_batch(keys, map_location="cpu")`: O(1) batch lookup via SQLite index + memory-mapped Arrow
+- `get_batch(keys, map_location="cpu")`: O(1) batch lookup via in-memory index + memory-mapped Arrow
 - `put_batch(items)`: O(1) append-only write to pending buffer
 - `flush()`: Force flush pending writes to new Arrow segment

 **Features:**
 - **O(1) writes**: New data appended to incremental segments, no rewrites
-- **O(1) reads**: SQLite index points directly to (segment_id, row_offset)
+- **O(1) reads**: In-memory dict index points directly to (segment_id, row_offset)
 - **Native tensors**: Automatic dtype preservation via Arrow's type system
 - **Schema inference**: Automatically detects structure on first write
-- **Crash safety**: Atomic commits via SQLite WAL + temp file approach
+- **Crash safety**: Automatic index rebuild from segments on corruption
+- **No database dependencies**: Simple pickle-based index persistence

 ## Architecture

 ### Storage Design

-torchcachex uses a hybrid Arrow IPC + SQLite architecture optimized for billion-scale caching:
+torchcachex uses a hybrid Arrow IPC + in-memory index architecture optimized for billion-scale caching:

 **Components:**

@@ -471,11 +472,12 @@ torchcachex uses a hybrid Arrow IPC + in-memory index architecture optimized for billion-scale caching:
    - Memory-mapped for zero-copy reads
    - Each segment contains a batch of cached samples

-2. **SQLite Index** (`index.db`)
-   - WAL (Write-Ahead Logging) mode for concurrent reads
+2. **Pickle Index** (`index.pkl`)
+   - In-memory Python dict backed by pickle persistence
    - Maps cache keys to (segment_id, row_offset)
-   - O(1) lookups via primary key index
-   - Tracks segment metadata (file paths, row counts)
+   - O(1) lookups via dict access
+   - Atomic persistence with temp file swap
+   - Auto-rebuilds from segments on corruption

 3. **Schema File** (`schema.json`)
    - Auto-inferred from first forward pass
@@ -488,8 +490,9 @@ torchcachex uses a hybrid Arrow IPC + in-memory index architecture optimized for billion-scale caching:
 put_batch() → pending buffer → flush() → {
   1. Create Arrow RecordBatch
   2. Write to temp segment file
-  3. Update SQLite index (atomic transaction)
+  3. Update in-memory index dict
   4. Atomic rename temp → final
+  5. Persist index.pkl (atomic)
 }
 ```

@@ -498,7 +501,7 @@ put_batch() → pending buffer → flush() → {
 ```
 get_batch() → {
   1. Check LRU cache (in-memory)
-  2. Query SQLite for (segment_id, row_offset)
+  2. Query in-memory index for (segment_id, row_offset)
   3. Memory-map Arrow segment
   4. Extract rows (zero-copy)
   5. Reconstruct tensors with correct dtype
@@ -508,10 +511,10 @@ get_batch() → {
 **Scalability Properties:**

 - **Writes**: O(1) - append new segment, update index
-- **Reads**: O(1) - direct index lookup + memory-map
+- **Reads**: O(1) - direct dict lookup + memory-map
 - **Memory**: O(working set) - only LRU + current segment in memory
 - **Disk**: O(N) - one entry per sample across segments
-- **Crash Recovery**: Atomic - incomplete segments ignored, SQLite WAL ensures consistency
+- **Crash Recovery**: Atomic - incomplete segments ignored, index auto-rebuilds from segments if corrupted

 ### Schema Inference

benchmark.py

Lines changed: 13 additions & 17 deletions
@@ -12,11 +12,9 @@

 import argparse
 import os
-import shutil
 import tempfile
 import time
 from dataclasses import dataclass
-from typing import List

 import torch
 import torch.nn as nn
@@ -81,7 +79,7 @@ def forward(self, x):
         return self.fc(x)


-def benchmark_write_scaling(tmpdir: str) -> List[BenchmarkResult]:
+def benchmark_write_scaling(tmpdir: str) -> list[BenchmarkResult]:
     """Verify O(1) write scaling: flush time independent of cache size."""
     print("\n[Benchmark] Write Scaling (O(1) Verification)")
     print("=" * 60)
@@ -133,7 +131,7 @@ def benchmark_write_scaling(tmpdir: str) -> list[BenchmarkResult]:
     return results


-def benchmark_read_performance(tmpdir: str) -> List[BenchmarkResult]:
+def benchmark_read_performance(tmpdir: str) -> list[BenchmarkResult]:
     """Measure read performance at different cache sizes."""
     print("\n[Benchmark] Read Performance")
     print("=" * 60)
@@ -156,7 +154,7 @@ def benchmark_read_performance(tmpdir: str) -> list[BenchmarkResult]:
     backend.flush()

     # Benchmark random reads
-    print(f"  Benchmarking 1000 random reads...")
+    print("  Benchmarking 1000 random reads...")
     import random

     random.seed(42)
@@ -184,15 +182,13 @@ def benchmark_read_performance(tmpdir: str) -> list[BenchmarkResult]:
     return results


-def benchmark_memory_usage(tmpdir: str) -> List[BenchmarkResult]:
+def benchmark_memory_usage(tmpdir: str) -> list[BenchmarkResult]:
     """Measure memory usage at different cache sizes."""
     print("\n[Benchmark] Memory Usage")
     print("=" * 60)

     try:
         import psutil
-
-        HAS_PSUTIL = True
     except ImportError:
         print("  [Skip] psutil not installed")
         return []
@@ -237,7 +233,7 @@ def benchmark_memory_usage(tmpdir: str) -> list[BenchmarkResult]:
     return results


-def benchmark_cache_speedup(tmpdir: str) -> List[BenchmarkResult]:
+def benchmark_cache_speedup(tmpdir: str) -> list[BenchmarkResult]:
     """Compare cached vs uncached performance."""
     print("\n[Benchmark] Cache Speedup")
     print("=" * 60)
@@ -250,7 +246,7 @@ def benchmark_cache_speedup(tmpdir: str) -> list[BenchmarkResult]:
     loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

     # Benchmark WITHOUT caching
-    print(f"  Running WITHOUT cache...")
+    print("  Running WITHOUT cache...")
     module_nocache = BenchmarkModule()

     start = time.time()
@@ -261,7 +257,7 @@ def benchmark_cache_speedup(tmpdir: str) -> list[BenchmarkResult]:
     print(f"  Time: {time_nocache:.3f}s, Calls: {module_nocache.call_count}")

     # Benchmark WITH caching (first epoch - populate cache)
-    print(f"  Running WITH cache (epoch 1 - populate)...")
+    print("  Running WITH cache (epoch 1 - populate)...")
     backend = ArrowIPCCacheBackend(
         cache_dir=tmpdir,
         module_id="speedup_test",
@@ -279,7 +275,7 @@ def benchmark_cache_speedup(tmpdir: str) -> list[BenchmarkResult]:
     print(f"  Time: {time_epoch1:.3f}s, Module calls: {module_cached.call_count}")

     # Benchmark WITH caching (second epoch - cache hits)
-    print(f"  Running WITH cache (epoch 2 - cache hits)...")
+    print("  Running WITH cache (epoch 2 - cache hits)...")
     module_cached.call_count = 0

     start = time.time()
@@ -321,7 +317,7 @@ def benchmark_cache_speedup(tmpdir: str) -> list[BenchmarkResult]:
     return results


-def benchmark_async_write(tmpdir: str) -> List[BenchmarkResult]:
+def benchmark_async_write(tmpdir: str) -> list[BenchmarkResult]:
     """Compare async vs sync write performance."""
     print("\n[Benchmark] Async Write Performance")
     print("=" * 60)
@@ -334,7 +330,7 @@ def benchmark_async_write(tmpdir: str) -> list[BenchmarkResult]:
     loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

     # Benchmark SYNC writes
-    print(f"  Running with SYNC writes...")
+    print("  Running with SYNC writes...")
     backend_sync = ArrowIPCCacheBackend(
         cache_dir=tmpdir,
         module_id="async_test_sync",
@@ -355,7 +351,7 @@ def benchmark_async_write(tmpdir: str) -> list[BenchmarkResult]:
     print(f"  Time: {time_sync:.3f}s")

     # Benchmark ASYNC writes
-    print(f"  Running with ASYNC writes...")
+    print("  Running with ASYNC writes...")
     backend_async = ArrowIPCCacheBackend(
         cache_dir=tmpdir,
         module_id="async_test_async",
@@ -398,7 +394,7 @@ def benchmark_async_write(tmpdir: str) -> list[BenchmarkResult]:
     return results


-def benchmark_dtype_preservation(tmpdir: str) -> List[BenchmarkResult]:
+def benchmark_dtype_preservation(tmpdir: str) -> list[BenchmarkResult]:
     """Verify dtype preservation across different tensor types."""
     print("\n[Benchmark] Dtype Preservation")
     print("=" * 60)
@@ -450,7 +446,7 @@ def benchmark_dtype_preservation(tmpdir: str) -> list[BenchmarkResult]:
     return results


-def generate_markdown_report(all_results: List[BenchmarkResult], output_file: str):
+def generate_markdown_report(all_results: list[BenchmarkResult], output_file: str):
     """Generate markdown report from benchmark results."""
     print(f"\n[Report] Generating markdown report: {output_file}")

examples/advanced_usage.py

Lines changed: 5 additions & 5 deletions
@@ -104,14 +104,14 @@ def example_kfold_cv():

         # Train on fold (features cached progressively)
         for batch in train_loader:
-            features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
+            _features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
             # ... train classifier ...

         backend.flush()

         # Validate (reuses cached features from overlapping samples)
         for batch in val_loader:
-            features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
+            _features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
             # ... evaluate ...

         print(f"  Fold {fold + 1} complete\n")
@@ -147,7 +147,7 @@ def example_ddp_training():
     print("Training (all ranks compute, only rank 0 writes cache)...")
     for batch in loader:
         # All ranks compute features
-        features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
+        _features = cached_extractor(batch["image"], cache_ids=batch["cache_ids"])
         # ... train on features ...

     backend.flush()
@@ -201,13 +201,13 @@ def forward(self, x, cache_ids):

     print("Training Model A (populates cache)...")
     for batch in loader:
-        logits = model_a(batch["image"], cache_ids=batch["cache_ids"])
+        _logits = model_a(batch["image"], cache_ids=batch["cache_ids"])
         # ... train model A ...
     backend.flush()

     print("Training Model B (reuses Model A's cache)...")
     for batch in loader:
-        logits = model_b(batch["image"], cache_ids=batch["cache_ids"])
+        _logits = model_b(batch["image"], cache_ids=batch["cache_ids"])
         # ... train model B ...

     print("Model B reused all features from Model A's cache!\n")

examples/cli_integration.py

Lines changed: 1 addition & 1 deletion
@@ -13,6 +13,7 @@
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from rich.console import Console
+from shade_io.feature_sets.filters import RemoveConstantFeaturesFilter
 from sklearn.decomposition import PCA

 # Import shade-io components
@@ -23,7 +24,6 @@
     FilteredFeatureSet,
     SimpleFeatureSet,
 )
-from shade_io.feature_sets.filters import RemoveConstantFeaturesFilter

 logger = logging.getLogger(__name__)
 console = Console()

examples/minimal_example.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@
 print("Training...")
 for epoch in range(3):
     print(f"  Epoch {epoch + 1}/3")
-    for batch_idx, (batch_images, batch_labels) in enumerate(loader):
+    for batch_idx, (batch_images, _batch_labels) in enumerate(loader):
         # Get cache IDs for this batch
         start_idx = batch_idx * 10
         batch_cache_ids = cache_ids[start_idx : start_idx + 10]

src/torchcachex/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""torchcachex: Drop-in PyTorch module caching with Arrow IPC + SQLite backend.
+"""torchcachex: Drop-in PyTorch module caching with Arrow IPC + in-memory index backend.

 This library provides transparent, per-sample caching for non-trainable PyTorch modules
 with O(1) append-only writes, native tensor storage, batched lookups, LRU hot cache,
