feat: playing with basic rag scripts #135

Draft · wants to merge 4 commits into base: master
1 change: 1 addition & 0 deletions .gitignore
@@ -21,6 +21,7 @@ docs/_build
.coverage*
coverage.xml
junit.xml
htmlcov
__pycache__
.*cache

6 changes: 5 additions & 1 deletion gptme/dirs.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path

from platformdirs import user_config_dir, user_data_dir
from platformdirs import user_cache_dir, user_config_dir, user_data_dir


def get_config_dir() -> Path:
@@ -39,5 +39,9 @@ def _init_paths():
path.mkdir(parents=True, exist_ok=True)


def get_cache_dir() -> Path:
return Path(user_cache_dir("gptme"))


# run once on init
_init_paths()
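
The new `get_cache_dir()` helper mirrors the existing config/data-dir helpers. A quick, hedged sketch of what it resolves to (the exact path depends on `platformdirs` and the OS):

```python
from gptme.dirs import get_cache_dir

# On Linux this typically resolves to ~/.cache/gptme; other platforms differ.
print(get_cache_dir())
```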
2 changes: 2 additions & 0 deletions gptme/tools/__init__.py
@@ -16,6 +16,7 @@
from .subagent import tool as subagent_tool
from .tmux import tool as tmux_tool
from .youtube import tool as youtube_tool
from .rag import tool as rag_tool

logger = logging.getLogger(__name__)

@@ -39,6 +40,7 @@
gh_tool,
chats_tool,
youtube_tool,
rag_tool,
# python tool is loaded last to ensure all functions are registered
get_python_tool,
]
46 changes: 46 additions & 0 deletions gptme/tools/rag/__init__.py
@@ -0,0 +1,46 @@
import importlib.util

import numpy as np
from gptme.dirs import get_cache_dir

from ..base import ToolSpec

available = importlib.util.find_spec("faiss") is not None

data_dir = get_cache_dir() / "rag"


def retrieve(query: str, top_k: int = 5):
import faiss # fmt: skip
from sentence_transformers import SentenceTransformer # fmt: skip

from .indexer import main as main_indexer # fmt: skip
from .retriever import retrieve_relevant_chunks # fmt: skip

    # TODO: Add a check for whether the index already exists before reindexing
main_indexer()

    # Load the model, index, and metadata (currently reloaded on every call)
model = SentenceTransformer("all-MiniLM-L6-v2")
index = faiss.read_index(str(data_dir / "code_index.faiss"))
metadata = np.load(str(data_dir / "code_metadata.npy"), allow_pickle=True).tolist()

# Retrieve relevant chunks
relevant_chunks = retrieve_relevant_chunks(query, index, metadata, model, top_k)

# Format the results
formatted_results = []
for chunk in relevant_chunks:
file_path, code, start_line, end_line, distance = chunk
formatted_chunk = f"File: {file_path} (Lines {start_line}-{end_line}, Distance: {distance:.4f})\n{code}\n{'='*80}"
formatted_results.append(formatted_chunk)

return "\n\n".join(formatted_results)


tool = ToolSpec(
name="rag",
desc="A tool for retrieving relevant code snippets from the project.",
functions=[retrieve],
available=available,
)
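
A minimal sketch of exercising the new tool function directly, assuming `faiss` and `sentence-transformers` are installed and the working directory is inside a git repository (the query string is made up):

```python
from gptme.tools.rag import retrieve

# The first call also builds the index via the indexer's main(), as shown above.
results = retrieve("how are tools registered?", top_k=3)
print(results)  # formatted chunks: file path, line range, distance, then the code
```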
224 changes: 224 additions & 0 deletions gptme/tools/rag/indexer.py
@@ -0,0 +1,224 @@
import ast
import json
import logging
import os
import subprocess
import textwrap

import faiss
import numpy as np
import pathspec
from sentence_transformers import SentenceTransformer

from gptme.dirs import get_cache_dir

logging.basicConfig(level=logging.INFO)

data_dir = get_cache_dir() / "rag"
os.makedirs(data_dir, exist_ok=True)

metadata_file = data_dir / "index_metadata.json"


def get_git_root(path: str) -> str:
result = subprocess.run(
["git", "rev-parse", "--show-toplevel"],
cwd=path,
capture_output=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"Failed to find git root: {result.stderr.strip()}")
return result.stdout.strip()


def load_gitignore_patterns(repo_root: str) -> pathspec.PathSpec:
gitignore_path = os.path.join(repo_root, ".gitignore")
if os.path.exists(gitignore_path):
with open(gitignore_path, encoding="utf-8") as f:
return pathspec.PathSpec.from_lines("gitwildmatch", f)
return pathspec.PathSpec([])


def load_code_files(
directory: str, ignore_patterns: pathspec.PathSpec
) -> list[tuple[str, str]]:
code_files = []
ignored_files_count = 0
for root, _, files in os.walk(directory):
for file in files:
file_path = os.path.relpath(os.path.join(root, file), directory)
if ignore_patterns.match_file(file_path):
ignored_files_count += 1
continue
if file.endswith(
(".py", ".js", ".ts", ".html", ".css")
): # Add more file types as needed
logging.info(f"Processing file: {file_path}")
with open(os.path.join(root, file), encoding="utf-8") as f:
code_files.append((file_path, f.read()))
logging.info(f"Total files ignored: {ignored_files_count}")
return code_files


def chunk_code_syntactically(
code: str, file_path: str
) -> list[tuple[str, str, int, int]]:
chunks: list[tuple[str, str, int, int]] = []
try:
tree = ast.parse(code)
except SyntaxError:
# If parsing fails, fall back to simple line-based chunking
lines = code.split("\n")
for i in range(0, len(lines), 10):
chunk = "\n".join(lines[i : i + 10])
chunks.append((file_path, chunk, i + 1, min(i + 10, len(lines))))
return chunks

for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef | ast.ClassDef):
start_lineno = node.lineno - 1
end_lineno = (
node.end_lineno
if hasattr(node, "end_lineno") and node.end_lineno is not None
else len(code.split("\n"))
)

# Include decorators
if hasattr(node, "decorator_list") and node.decorator_list:
start_lineno = node.decorator_list[0].lineno - 1

chunk_lines = code.split("\n")[start_lineno:end_lineno]
chunk_code = textwrap.dedent("\n".join(chunk_lines))

chunks.append((file_path, chunk_code, start_lineno + 1, end_lineno))

    # Fall back to line-based chunks when no functions or classes were found
if not chunks:
lines = code.split("\n")
for i in range(0, len(lines), 10):
chunk = "\n".join(lines[i : i + 10])
chunks.append((file_path, chunk, i + 1, min(i + 10, len(lines))))

return chunks


def chunk_code_line_based(
code: str, file_path: str, chunk_size: int = 20, language: str = "generic"
) -> list[tuple[str, str, int, int]]:
chunks = []
lines = code.split("\n")
for i in range(0, len(lines), chunk_size):
chunk = "\n".join(lines[i : i + chunk_size])
chunks.append((file_path, chunk, i + 1, min(i + chunk_size, len(lines))))
return chunks


def should_reindex(
current_metadata: dict[str, float], previous_metadata: dict[str, float]
) -> bool:
return any(
file_path not in previous_metadata or previous_metadata[file_path] != mtime
for file_path, mtime in current_metadata.items()
)


def create_index(
code_files: list[tuple[str, str]], model: SentenceTransformer
) -> tuple[faiss.Index, list[tuple[str, str, int, int]]]:
chunks = []
for file_path, code in code_files:
logging.info(f"Processing file: {file_path}")
if file_path.endswith(".py"):
chunks.extend(chunk_code_syntactically(code, file_path))
elif file_path.endswith(".ts"):
logging.info(f"Processing TypeScript file: {file_path}")
chunks.extend(chunk_code_line_based(code, file_path, language="typescript"))
else:
chunks.extend(chunk_code_line_based(code, file_path))
logging.info(f"Total chunks created: {len(chunks)}")

texts = [chunk[1] for chunk in chunks]
batch_size = 64 # Adjust batch size as needed
embeddings_list = []
total_batches = (len(texts) + batch_size - 1) // batch_size
for i in range(total_batches):
batch_texts = texts[i * batch_size : (i + 1) * batch_size]
if i % 10 == 0 or i == total_batches - 1:
logging.info(f"Encoding batch {i + 1}/{total_batches}")
batch_embeddings = model.encode(batch_texts, show_progress_bar=False)
embeddings_list.append(batch_embeddings)
embeddings = np.vstack(embeddings_list)

dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.astype(np.float32))

return index, chunks


def load_metadata() -> dict[str, float]:
if metadata_file.exists():
with open(metadata_file, encoding="utf-8") as f:
return json.load(f)
return {}


def save_metadata(metadata: dict[str, float]):
with open(metadata_file, "w", encoding="utf-8") as f:
json.dump(metadata, f)


def main():
logging.info("Loading model...")
model = SentenceTransformer("all-MiniLM-L6-v2")
logging.info("Model loaded.")

logging.info("Finding Git root...")
repo_root = get_git_root(".")
logging.info(f"Git root found: {repo_root}")

logging.info("Loading .gitignore patterns...")
ignore_patterns = load_gitignore_patterns(repo_root)
logging.info(".gitignore patterns loaded.")

logging.info("Loading code files...")
code_files = load_code_files(repo_root, ignore_patterns)
logging.info(f"Total code files loaded: {len(code_files)}")

logging.info("Loading previous metadata...")
previous_metadata = load_metadata()

logging.info("Checking for changes...")
changed_files = []
current_metadata = {}
for file_path, code in code_files:
        current_metadata[file_path] = os.path.getmtime(os.path.join(repo_root, file_path))
if (
file_path not in previous_metadata
or previous_metadata[file_path] != current_metadata[file_path]
):
changed_files.append((file_path, code))

if not changed_files:
logging.info("No changes detected. Exiting.")
return

logging.info(f"Files changed: {len(changed_files)}")

logging.info("Creating index...")
    index, chunks = create_index(code_files, model)  # rebuild the full index, not only changed files
logging.info("Index created.")

logging.info("Saving index and metadata...")
faiss.write_index(index, str(data_dir / "code_index.faiss"))
logging.info("Index saved.")

logging.info("Saving metadata...")
np.save(str(data_dir / "code_metadata.npy"), chunks)
save_metadata(current_metadata)
logging.info("Metadata saved.")


if __name__ == "__main__":
main()
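
A hedged sketch of running the indexer on its own and inspecting what it writes to the cache directory (file names follow the code above):

```python
from gptme.dirs import get_cache_dir
from gptme.tools.rag.indexer import main as build_index

# Embeds the repo's code files and writes code_index.faiss, code_metadata.npy
# and index_metadata.json into the gptme cache dir.
build_index()
print(sorted(p.name for p in (get_cache_dir() / "rag").iterdir()))
```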
48 changes: 48 additions & 0 deletions gptme/tools/rag/retriever.py
@@ -0,0 +1,48 @@
import logging

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from gptme.dirs import get_cache_dir

logging.basicConfig(level=logging.INFO)

data_dir = get_cache_dir() / "rag"


def load_index_and_metadata() -> tuple[faiss.Index, list[tuple[str, str, int, int]]]:
index = faiss.read_index(str(data_dir / "code_index.faiss"))
metadata = np.load(str(data_dir / "code_metadata.npy"), allow_pickle=True)
return index, metadata.tolist()


def retrieve_relevant_chunks(
query: str,
index: faiss.Index,
metadata: list[tuple[str, str, int, int]],
model: SentenceTransformer,
top_k: int = 5,
) -> list[tuple[str, str, int, int, float]]:
query_embedding = model.encode([query])
distances, indices = index.search(query_embedding.astype("float32"), top_k)
return [(*metadata[idx], distances[0][i]) for i, idx in enumerate(indices[0])]


def format_chunk(chunk: tuple[str, str, int, int, float]) -> str:
file_path, code, start_line, end_line, distance = chunk
return f"File: {file_path} (Lines {start_line}-{end_line}, Distance: {distance:.4f})\n{code}\n{'='*80}"


def retrieve(query: str, top_k: int = 5) -> str:
model = SentenceTransformer("all-MiniLM-L6-v2")
index, metadata = load_index_and_metadata()
relevant_chunks = retrieve_relevant_chunks(query, index, metadata, model, top_k)
return "\n\n".join(format_chunk(chunk) for chunk in relevant_chunks)


if __name__ == "__main__":
import sys

if len(sys.argv) < 2:
print("Please provide a query as a command-line argument.")
sys.exit(1)
print(retrieve(sys.argv[1]))
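
The lower-level retriever functions can also be used without the ToolSpec wrapper; a small sketch, assuming the index has already been built:

```python
from sentence_transformers import SentenceTransformer

from gptme.tools.rag.retriever import load_index_and_metadata, retrieve_relevant_chunks

model = SentenceTransformer("all-MiniLM-L6-v2")
index, metadata = load_index_and_metadata()
for file_path, code, start, end, distance in retrieve_relevant_chunks(
    "where is the FAISS index written?", index, metadata, model, top_k=3
):
    print(f"{file_path}:{start}-{end} (distance {distance:.4f})")
```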