Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8328979

Browse files
authoredJan 6, 2025··
Merge pull request #493 from stacklok/extract-input-code-snippets
Extract and process code snippets in the user query
2 parents 1bb78e2 + b9d8fce commit 8328979

File tree

1 file changed

+32
-10
lines changed
  • src/codegate/pipeline/codegate_context_retriever

1 file changed

+32
-10
lines changed
 

‎src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23

34
import structlog
45
from litellm import ChatCompletionRequest
@@ -9,7 +10,9 @@
910
PipelineResult,
1011
PipelineStep,
1112
)
13+
from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets
1214
from codegate.storage.storage_engine import StorageEngine
15+
from codegate.utils.package_extractor import PackageExtractor
1316
from codegate.utils.utils import generate_vector_string
1417

1518
logger = structlog.get_logger("codegate")
@@ -64,26 +67,45 @@ async def process(
6467
if len(user_messages) == 0:
6568
return PipelineResult(request=request)
6669

67-
context_str = "CodeGate did not find any malicious or archived packages."
70+
# Create storage engine object
71+
storage_engine = StorageEngine()
72+
73+
# Extract any code snippets
74+
snippets = extract_snippets(user_messages)
75+
76+
# Collect all packages referenced in the snippets
77+
snippet_packages = []
78+
for snippet in snippets:
79+
snippet_packages.extend(
80+
PackageExtractor.extract_packages(snippet.code, snippet.language)
81+
)
82+
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")
83+
84+
# Find bad packages in the snippets
85+
bad_snippet_packages = await storage_engine.search_by_property("name", snippet_packages)
86+
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")
87+
88+
# Remove code snippets from the user messages and search for bad packages
89+
# in the rest of the user query/messsages
90+
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)
6891

6992
# Vector search to find bad packages
70-
storage_engine = StorageEngine()
71-
searched_objects = await storage_engine.search(query=user_messages, distance=0.8, limit=100)
93+
bad_packages = await storage_engine.search(query=user_messages, distance=0.8, limit=100)
7294

73-
logger.info(
74-
f"Found {len(searched_objects)} matches in the database",
75-
searched_objects=searched_objects,
76-
)
95+
# All bad packages
96+
all_bad_packages = bad_snippet_packages + bad_packages
97+
98+
logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.")
7799

78100
# Generate context string using the searched objects
79-
logger.info(f"Adding {len(searched_objects)} packages to the context")
101+
context_str = "CodeGate did not find any malicious or archived packages."
80102

81103
# Nothing to do if no bad packages are found
82-
if len(searched_objects) == 0:
104+
if len(all_bad_packages) == 0:
83105
return PipelineResult(request=request, context=context)
84106
else:
85107
# Add context for bad packages
86-
context_str = self.generate_context_str(searched_objects, context)
108+
context_str = self.generate_context_str(all_bad_packages, context)
87109
context.bad_packages_found = True
88110

89111
last_user_idx = self.get_last_user_message_idx(request)

0 commit comments

Comments
 (0)
This repository has been archived.