|
1 | 1 | import json
|
| 2 | +import re |
2 | 3 |
|
3 | 4 | import structlog
|
4 | 5 | from litellm import ChatCompletionRequest
|
|
9 | 10 | PipelineResult,
|
10 | 11 | PipelineStep,
|
11 | 12 | )
|
| 13 | +from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets |
12 | 14 | from codegate.storage.storage_engine import StorageEngine
|
| 15 | +from codegate.utils.package_extractor import PackageExtractor |
13 | 16 | from codegate.utils.utils import generate_vector_string
|
14 | 17 |
|
15 | 18 | logger = structlog.get_logger("codegate")
|
@@ -64,26 +67,45 @@ async def process(
|
64 | 67 | if len(user_messages) == 0:
|
65 | 68 | return PipelineResult(request=request)
|
66 | 69 |
|
67 |
| - context_str = "CodeGate did not find any malicious or archived packages." |
| 70 | + # Create storage engine object |
| 71 | + storage_engine = StorageEngine() |
| 72 | + |
| 73 | + # Extract any code snippets |
| 74 | + snippets = extract_snippets(user_messages) |
| 75 | + |
| 76 | + # Collect all packages referenced in the snippets |
| 77 | + snippet_packages = [] |
| 78 | + for snippet in snippets: |
| 79 | + snippet_packages.extend( |
| 80 | + PackageExtractor.extract_packages(snippet.code, snippet.language) |
| 81 | + ) |
| 82 | + logger.info(f"Found {len(snippet_packages)} packages in code snippets.") |
| 83 | + |
| 84 | + # Find bad packages in the snippets |
| 85 | + bad_snippet_packages = await storage_engine.search_by_property("name", snippet_packages) |
| 86 | + logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.") |
| 87 | + |
| 88 | + # Remove code snippets from the user messages and search for bad packages |
| 89 | + # in the rest of the user query/messages |
| 90 | + user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL) |
68 | 91 |
|
69 | 92 | # Vector search to find bad packages
|
70 |
| - storage_engine = StorageEngine() |
71 |
| - searched_objects = await storage_engine.search(query=user_messages, distance=0.8, limit=100) |
| 93 | + bad_packages = await storage_engine.search(query=user_messages, distance=0.8, limit=100) |
72 | 94 |
|
73 |
| - logger.info( |
74 |
| - f"Found {len(searched_objects)} matches in the database", |
75 |
| - searched_objects=searched_objects, |
76 |
| - ) |
| 95 | + # All bad packages |
| 96 | + all_bad_packages = bad_snippet_packages + bad_packages |
| 97 | + |
| 98 | + logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.") |
77 | 99 |
|
78 | 100 | # Generate context string using the searched objects
|
79 |
| - logger.info(f"Adding {len(searched_objects)} packages to the context") |
| 101 | + context_str = "CodeGate did not find any malicious or archived packages." |
80 | 102 |
|
81 | 103 | # Nothing to do if no bad packages are found
|
82 |
| - if len(searched_objects) == 0: |
| 104 | + if len(all_bad_packages) == 0: |
83 | 105 | return PipelineResult(request=request, context=context)
|
84 | 106 | else:
|
85 | 107 | # Add context for bad packages
|
86 |
| - context_str = self.generate_context_str(searched_objects, context) |
| 108 | + context_str = self.generate_context_str(all_bad_packages, context) |
87 | 109 | context.bad_packages_found = True
|
88 | 110 |
|
89 | 111 | last_user_idx = self.get_last_user_message_idx(request)
|
|
0 commit comments