
Commit

feat: playing with basic rag scripts
ErikBjare committed Sep 23, 2024
1 parent c686dab commit 4e129b1
Showing 6 changed files with 1,067 additions and 58 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -21,6 +21,7 @@ docs/_build
.coverage*
coverage.xml
junit.xml
htmlcov
__pycache__
.*cache

6 changes: 5 additions & 1 deletion gptme/dirs.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path

from platformdirs import user_config_dir, user_data_dir
from platformdirs import user_cache_dir, user_config_dir, user_data_dir


def get_config_dir() -> Path:
@@ -39,5 +39,9 @@ def _init_paths():
        path.mkdir(parents=True, exist_ok=True)


def get_cache_dir() -> Path:
    return Path(user_cache_dir("gptme"))


# run once on init
_init_paths()
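
As a minimal usage sketch (not part of the diff): the new helper resolves a per-user cache directory via platformdirs, so the exact path is platform-dependent (e.g. ~/.cache/gptme on Linux), and the "rag" subdirectory shown here mirrors what the indexer below uses.

from gptme.dirs import get_cache_dir

cache_dir = get_cache_dir()  # e.g. ~/.cache/gptme on Linux
rag_dir = cache_dir / "rag"  # subdirectory used by the RAG scripts below
rag_dir.mkdir(parents=True, exist_ok=True)
print(rag_dir)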
207 changes: 207 additions & 0 deletions gptme/tools/rag/indexer.py
@@ -0,0 +1,207 @@
import ast
import json
import logging
import os
import subprocess
import textwrap

import faiss
import numpy as np
import pathspec
from gptme.dirs import get_cache_dir
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)

data_dir = get_cache_dir() / "rag"
os.makedirs(data_dir, exist_ok=True)

metadata_file = data_dir / "index_metadata.json"


def get_git_root(path: str) -> str:
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        cwd=path,
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to find git root: {result.stderr.strip()}")
    return result.stdout.strip()


def load_gitignore_patterns(repo_root: str) -> pathspec.PathSpec:
    gitignore_path = os.path.join(repo_root, ".gitignore")
    if os.path.exists(gitignore_path):
        with open(gitignore_path, encoding="utf-8") as f:
            return pathspec.PathSpec.from_lines("gitwildmatch", f)
    return pathspec.PathSpec([])


def load_code_files(
    directory: str, ignore_patterns: pathspec.PathSpec
) -> list[tuple[str, str]]:
    code_files = []
    ignored_files_count = 0
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.relpath(os.path.join(root, file), directory)
            if ignore_patterns.match_file(file_path):
                ignored_files_count += 1
                continue
            if file.endswith(
                (".py", ".js", ".ts", ".html", ".css")
            ):  # Add more file types as needed
                logging.info(f"Processing file: {file_path}")
                with open(os.path.join(root, file), encoding="utf-8") as f:
                    code_files.append((file_path, f.read()))
    logging.info(f"Total files ignored: {ignored_files_count}")
    return code_files


def chunk_code_syntactically(
    code: str, file_path: str
) -> list[tuple[str, str, int, int]]:
    chunks = []
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # If parsing fails, fall back to simple line-based chunking
        lines = code.split("\n")
        for i in range(0, len(lines), 10):
            chunk = "\n".join(lines[i : i + 10])
            chunks.append((file_path, chunk, i + 1, min(i + 10, len(lines))))
        return chunks

    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            start_lineno = node.lineno - 1
            end_lineno = (
                node.end_lineno
                if hasattr(node, "end_lineno")
                else len(code.split("\n"))
            )

            # Include decorators
            if hasattr(node, "decorator_list") and node.decorator_list:
                start_lineno = node.decorator_list[0].lineno - 1

            chunk = code.split("\n")[start_lineno:end_lineno]
            chunk_code = textwrap.dedent("\n".join(chunk))

            chunks.append((file_path, chunk_code, start_lineno + 1, end_lineno))

    return chunks


def chunk_code_line_based(
    code: str, file_path: str, chunk_size: int = 20, language: str = "generic"
) -> list[tuple[str, str, int, int]]:
    chunks = []
    lines = code.split("\n")
    for i in range(0, len(lines), chunk_size):
        chunk = "\n".join(lines[i : i + chunk_size])
        chunks.append((file_path, chunk, i + 1, min(i + chunk_size, len(lines))))
    return chunks


def create_index(
    code_files: list[tuple[str, str]], model: SentenceTransformer
) -> tuple[faiss.Index, list[tuple[str, str, int, int]]]:
    chunks = []
    for file_path, code in code_files:
        logging.info(f"Processing file: {file_path}")
        if file_path.endswith(".py"):
            chunks.extend(chunk_code_syntactically(code, file_path))
        elif file_path.endswith(".ts"):
            logging.info(f"Processing TypeScript file: {file_path}")
            chunks.extend(chunk_code_line_based(code, file_path, language="typescript"))
        else:
            chunks.extend(chunk_code_line_based(code, file_path))
    logging.info(f"Total chunks created: {len(chunks)}")

    texts = [chunk[1] for chunk in chunks]
    batch_size = 64  # Adjust batch size as needed
    embeddings = []
    total_batches = (len(texts) + batch_size - 1) // batch_size
    for i in range(total_batches):
        batch_texts = texts[i * batch_size : (i + 1) * batch_size]
        if i % 10 == 0 or i == total_batches - 1:
            logging.info(f"Encoding batch {i + 1}/{total_batches}")
        batch_embeddings = model.encode(batch_texts, show_progress_bar=False)
        embeddings.append(batch_embeddings)
    embeddings = np.vstack(embeddings)

    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings.astype("float32"))

    return index, chunks


def load_metadata() -> dict[str, float]:
    if metadata_file.exists():
        with open(metadata_file, encoding="utf-8") as f:
            return json.load(f)
    return {}


def save_metadata(metadata: dict[str, float]):
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f)


def main():
    logging.info("Loading model...")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    logging.info("Model loaded.")

    logging.info("Finding Git root...")
    repo_root = get_git_root(".")
    logging.info(f"Git root found: {repo_root}")

    logging.info("Loading .gitignore patterns...")
    ignore_patterns = load_gitignore_patterns(repo_root)
    logging.info(".gitignore patterns loaded.")

    logging.info("Loading code files...")
    code_files = load_code_files(repo_root, ignore_patterns)
    logging.info(f"Total code files loaded: {len(code_files)}")

    logging.info("Loading previous metadata...")
    previous_metadata = load_metadata()

    logging.info("Checking for changes...")
    changed_files = []
    current_metadata = {}
    for file_path, code in code_files:
        # file_path is relative to repo_root, so resolve it before stat-ing
        current_metadata[file_path] = os.path.getmtime(os.path.join(repo_root, file_path))
        if (
            file_path not in previous_metadata
            or previous_metadata[file_path] != current_metadata[file_path]
        ):
            changed_files.append((file_path, code))

    if not changed_files:
        logging.info("No changes detected. Exiting.")
        return

    logging.info(f"Files changed: {len(changed_files)}")

    logging.info("Creating index...")
    index, chunks = create_index(changed_files, model)
    logging.info("Index created.")

    logging.info("Saving index and metadata...")
    faiss.write_index(index, str(data_dir / "code_index.faiss"))
    logging.info("Index saved.")

    logging.info("Saving metadata...")
    np.save(str(data_dir / "code_metadata.npy"), chunks)
    save_metadata(current_metadata)
    logging.info("Metadata saved.")


if __name__ == "__main__":
    main()
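
For reference, a minimal sketch (not part of the commit) of driving the indexer from Python instead of running the script directly. It assumes gptme.tools.rag is importable as a package and reuses the model name hard-coded in main() above.

from sentence_transformers import SentenceTransformer

from gptme.tools.rag.indexer import (
    create_index,
    get_git_root,
    load_code_files,
    load_gitignore_patterns,
)

model = SentenceTransformer("all-MiniLM-L6-v2")
repo_root = get_git_root(".")  # must be run from inside a git repository
code_files = load_code_files(repo_root, load_gitignore_patterns(repo_root))
index, chunks = create_index(code_files, model)
print(f"Indexed {index.ntotal} chunks from {len(code_files)} files")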
60 changes: 60 additions & 0 deletions gptme/tools/rag/retriever.py
@@ -0,0 +1,60 @@
import faiss
import numpy as np
import logging
import time
import sys
from typing import List, Tuple
from sentence_transformers import SentenceTransformer
from gptme.dirs import get_cache_dir

logging.basicConfig(level=logging.INFO)

data_dir = get_cache_dir() / "rag"

def load_index_and_metadata() -> Tuple[faiss.Index, List[Tuple[str, str, int, int]]]:
    start_time = time.time()
    index = faiss.read_index(str(data_dir / "code_index.faiss"))
    metadata = np.load(str(data_dir / "code_metadata.npy"), allow_pickle=True)
    logging.info(f"Index and metadata loaded in {time.time() - start_time:.2f} seconds.")
    return index, metadata.tolist()

def retrieve_relevant_chunks(query: str, index: faiss.Index, metadata: List[Tuple[str, str, int, int]], model: SentenceTransformer, top_k: int = 5) -> List[Tuple[str, str, int, int, float]]:
    start_time = time.time()
    query_embedding = model.encode([query])
    logging.info(f"Query embedding created in {time.time() - start_time:.2f} seconds.")

    start_time = time.time()
    distances, indices = index.search(query_embedding.astype('float32'), top_k)
    logging.info(f"Index search completed in {time.time() - start_time:.2f} seconds.")

    relevant_chunks = [(metadata[idx][0], metadata[idx][1], metadata[idx][2], metadata[idx][3], distances[0][i]) for i, idx in enumerate(indices[0])]
    return relevant_chunks

def format_chunk(chunk: Tuple[str, str, int, int, float]) -> str:
    file_path, code, start_line, end_line, distance = chunk
    return f"File: {file_path} (Lines {start_line}-{end_line}, Distance: {distance:.4f})\n{code}\n{'='*80}"

def main():
    if len(sys.argv) < 2:
        logging.error("Please provide a query as a command-line argument.")
        sys.exit(1)

    query = sys.argv[1]
    logging.info(f"Query received: {query}")

    logging.info("Loading model...")
    model = SentenceTransformer('all-MiniLM-L6-v2')
    logging.info("Model loaded.")

    logging.info("Loading index and metadata...")
    index, metadata = load_index_and_metadata()

    logging.info("Retrieving relevant chunks...")
    relevant_chunks = retrieve_relevant_chunks(query, index, metadata, model)
    logging.info("Relevant chunks retrieved.")

    for chunk in relevant_chunks:
        print(format_chunk(chunk))

if __name__ == '__main__':
    main()
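
And a similar hedged sketch (also not part of the commit) of querying the saved index programmatically, equivalent to running the script with a query argument; the query string here is only an example.

from sentence_transformers import SentenceTransformer

from gptme.tools.rag.retriever import (
    format_chunk,
    load_index_and_metadata,
    retrieve_relevant_chunks,
)

model = SentenceTransformer("all-MiniLM-L6-v2")
index, metadata = load_index_and_metadata()  # reads the files written by indexer.py
chunks = retrieve_relevant_chunks(
    "where is the cache dir resolved?", index, metadata, model, top_k=3
)
for chunk in chunks:
    print(format_chunk(chunk))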